diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..fe47d097a8faa7cab2ebe3a4dce2187a39ede95e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +samples/A/test_1.png filter=lfs diff=lfs merge=lfs -text +samples/A/test_2.png filter=lfs diff=lfs merge=lfs -text +samples/A/test_3.png filter=lfs diff=lfs merge=lfs -text +samples/A/test_4.png filter=lfs diff=lfs merge=lfs -text +samples/A/test_5.png filter=lfs diff=lfs merge=lfs -text +samples/B/test_1.png filter=lfs diff=lfs merge=lfs -text +samples/B/test_2.png filter=lfs diff=lfs merge=lfs -text +samples/B/test_3.png filter=lfs diff=lfs merge=lfs -text +samples/B/test_4.png filter=lfs diff=lfs merge=lfs -text +samples/B/test_5.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0eded42e17584aa9a0e7ec5afe57ebc99eb32e0a --- /dev/null +++ b/.gitignore @@ -0,0 +1,128 @@ +*.pth +gradio_cached_examples/ + +.idea +.DS_Store +work_dirs/ +pretrain_models/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/en/_build/ +docs/zh_cn/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +.DS_Store + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +data +.vscode +.idea + +# custom +*.pkl +*.pkl.json +*.log.json +work_dirs/ +mmseg/.mim + +# Pytorch +*.pth diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..16c30878679ed8477e27a0fccf2fe62cbf700f35 --- /dev/null +++ b/app.py @@ -0,0 +1,42 @@ +import gradio as gr +import glob +import torch +from opencd.apis import OpenCDInferencer + +device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + +config_file = 'configs/TTP/ttp_sam_large_levircd_infer.py' +checkpoint_file = 'ckpt/epoch_270.pth' + +# build the model from a config file and a checkpoint file +mmcd_inferencer = OpenCDInferencer( + model=config_file, + weights=checkpoint_file, + classes=['unchanged', 'changed'], + palette=[[0, 0, 0], [255, 255, 255]], + device=device +) + +def infer(img1, img2): + # test a single image + result = mmcd_inferencer([[img1, img2]], show=False, return_vis=True) + visualization = result['visualization'] + return 
visualization + + +with gr.Blocks() as demo: + with gr.Row(): + input_0 = gr.Image(label='Input Image1') + input_1 = gr.Image(label='Input Image2') + with gr.Row(): + output_gt = gr.Image(label='Predicted Mask') + btn = gr.Button("Detect") + btn.click(infer, inputs=[input_0, input_1], outputs=[output_gt]) + + img1_files = glob.glob('samples/A/*.png') + img2_files = [f.replace('A', 'B') for f in img1_files] + input_files = [[x, y] for x, y in zip(img1_files, img2_files)] + gr.Examples(input_files, fn=infer, inputs=[input_0, input_1], outputs=[output_gt], cache_examples=True) + +if __name__ == "__main__": + demo.launch() diff --git a/ckpt/epoch_270.pth b/ckpt/epoch_270.pth new file mode 100644 index 0000000000000000000000000000000000000000..c4e1955ac36fcf924e7ec1e96cf14294dee7b337 --- /dev/null +++ b/ckpt/epoch_270.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a37d3a79379f4bf3d7ecb85b71209f35cd8af7e61cae564038397e8b7fb3eaf2 +size 1415063308 diff --git a/configs/.DS_Store b/configs/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..842e2f563364fab91b43efef1590db0bfb830f6f Binary files /dev/null and b/configs/.DS_Store differ diff --git a/configs/TTP/ttp_sam_large_levircd.py b/configs/TTP/ttp_sam_large_levircd.py new file mode 100644 index 0000000000000000000000000000000000000000..71f36e42d056a1c80e12687a2ab6105f8a8d82f3 --- /dev/null +++ b/configs/TTP/ttp_sam_large_levircd.py @@ -0,0 +1,202 @@ +default_scope = 'opencd' + +work_dir = 'work_dirs/lervicd/ttp_sam_large_levircd' + +custom_imports = dict(imports=['mmseg.ttp'], allow_failed_imports=False) + +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=10, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, save_best='cd/iou_changed', max_keep_ckpts=5, greater_keys=['cd/iou_changed'], save_last=True), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='CDVisualizationHook', interval=1, + img_shape=(1024, 1024, 3)) +) +vis_backends = [dict(type='CDLocalVisBackend'), + dict(type='WandbVisBackend', + init_kwargs=dict(project='samcd', group='levircd', name='ttp_sam_large_levircd')) + ] + +visualizer = dict( + type='CDLocalVisualizer', + vis_backends=vis_backends, name='visualizer', alpha=1.0) +log_processor = dict(by_epoch=True) + +log_level = 'INFO' +load_from = None +resume = False + +crop_size = (512, 512) + +data_preprocessor = dict( + type='DualInputSegDataPreProcessor', + mean=[123.675, 116.28, 103.53] * 2, + std=[58.395, 57.12, 57.375] * 2, + bgr_to_rgb=True, + pad_val=0, + seg_pad_val=255, + size_divisor=32, + test_cfg=dict(size_divisor=32) +) + +norm_cfg = dict(type='SyncBN', requires_grad=True) +fpn_norm_cfg = dict(type='LN2d', requires_grad=True) + +sam_pretrain_ckpt_path = 'https://download.openmmlab.com/mmclassification/v1/vit_sam/vit-large-p16_sam-pre_3rdparty_sa1b-1024px_20230411-595feafd.pth' + +model = dict( + type='SiamEncoderDecoder', + data_preprocessor=data_preprocessor, + backbone=dict( + type='MMPretrainSamVisionEncoder', + encoder_cfg=dict( + type='mmpretrain.ViTSAM', + arch='large', + img_size=crop_size[0], + patch_size=16, + out_channels=256, + use_abs_pos=True, + use_rel_pos=True, + window_size=14, + layer_cfgs=dict(type='TimeFusionTransformerEncoderLayer'), + 
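# load the SA-1B pre-trained ViT-large weights released by MMPreTrain; prefix='backbone.' keeps only the image-encoder keys from that checkpoint +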
init_cfg=dict(type='Pretrained', checkpoint=sam_pretrain_ckpt_path, prefix='backbone.'), + ), + peft_cfg=dict( + r=16, + target_modules=["qkv"], + lora_dropout=0.01, + bias='lora_only', + ), + ), + neck=dict( + type='SequentialNeck', + necks=[ + dict( + type='FeatureFusionNeck', + policy='concat', + out_indices=(0,)), + dict( + type='SimpleFPN', + backbone_channel=512, + in_channels=[128, 256, 512, 512], + out_channels=256, + num_outs=5, + norm_cfg=fpn_norm_cfg), + ], + ), + decode_head=dict( + type='MLPSegHead', + out_size=(128, 128), + in_channels=[256]*5, + in_index=[0, 1, 2, 3, 4], + channels=256, + dropout_ratio=0, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='mmseg.CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(crop_size[0]//2, crop_size[1]//2)) +) # yapf: disable + +dataset_type = 'LEVIR_CD_Dataset' +data_root = '/mnt/levir_datasets/levir-cd' + + +train_pipeline = [ + dict(type='MultiImgLoadImageFromFile'), + dict(type='MultiImgLoadAnnotations'), + dict(type='MultiImgRandomRotate', prob=0.5, degree=180), + dict(type='MultiImgRandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='MultiImgRandomFlip', prob=0.5, direction='horizontal'), + dict(type='MultiImgRandomFlip', prob=0.5, direction='vertical'), + # dict(type='MultiImgExchangeTime', prob=0.5), + dict( + type='MultiImgPhotoMetricDistortion', + brightness_delta=10, + contrast_range=(0.8, 1.2), + saturation_range=(0.8, 1.2), + hue_delta=10), + dict(type='MultiImgPackSegInputs') +] +test_pipeline = [ + dict(type='MultiImgLoadImageFromFile'), + dict(type='MultiImgResize', scale=(1024, 1024), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='MultiImgLoadAnnotations'), + dict(type='MultiImgPackSegInputs') +] + +batch_size_per_gpu = 2 + +train_dataloader = dict( + batch_size=batch_size_per_gpu, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + seg_map_path='train/label', + img_path_from='train/A', + img_path_to='train/B'), + pipeline=train_pipeline) +) + +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + seg_map_path='test/label', + img_path_from='test/A', + img_path_to='test/B'), + pipeline=test_pipeline) +) + +test_dataloader = val_dataloader + +val_evaluator = dict( + type='CDMetric', +) +test_evaluator = val_evaluator + +max_epochs = 300 +base_lr = 0.0004 +param_scheduler = [ + dict( + type='LinearLR', start_factor=1e-4, by_epoch=True, begin=0, end=5, convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=max_epochs, + begin=5, + by_epoch=True, + end=max_epochs, + convert_to_iter_based=True + ), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=5) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05), +) + diff --git a/configs/TTP/ttp_sam_large_levircd_fp16.py b/configs/TTP/ttp_sam_large_levircd_fp16.py new file mode 100644 index 
0000000000000000000000000000000000000000..42028b9ce84f3e20d0bab1947e2cf924c145a113 --- /dev/null +++ b/configs/TTP/ttp_sam_large_levircd_fp16.py @@ -0,0 +1,201 @@ +default_scope = 'opencd' + +work_dir = 'work_dirs/lervicd/ttp_sam_large_levircd_fp16' + +custom_imports = dict(imports=['mmseg.ttp'], allow_failed_imports=False) + +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=10, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, save_best='cd/iou_changed', max_keep_ckpts=5, greater_keys=['cd/iou_changed'], save_last=True), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='CDVisualizationHook', interval=1, img_shape=(1024, 1024, 3)) +) +vis_backends = [dict(type='CDLocalVisBackend'), + dict(type='WandbVisBackend', init_kwargs=dict(project='samcd', group='levircd', name='ttp_sam_large_levircd_fp16')) + ] + +visualizer = dict( + type='CDLocalVisualizer', + vis_backends=vis_backends, name='visualizer', alpha=1.0) +log_processor = dict(by_epoch=True) + +log_level = 'INFO' +load_from = None +resume = False + +crop_size = (512, 512) + +data_preprocessor = dict( + type='DualInputSegDataPreProcessor', + mean=[123.675, 116.28, 103.53] * 2, + std=[58.395, 57.12, 57.375] * 2, + bgr_to_rgb=True, + pad_val=0, + seg_pad_val=255, + size_divisor=32, + test_cfg=dict(size_divisor=32) +) + +norm_cfg = dict(type='SyncBN', requires_grad=True) +fpn_norm_cfg = dict(type='LN2d', requires_grad=True) + +sam_pretrain_ckpt_path = 'https://download.openmmlab.com/mmclassification/v1/vit_sam/vit-large-p16_sam-pre_3rdparty_sa1b-1024px_20230411-595feafd.pth' + +model = dict( + type='SiamEncoderDecoder', + data_preprocessor=data_preprocessor, + backbone=dict( + type='MMPretrainSamVisionEncoder', + encoder_cfg=dict( + type='mmpretrain.ViTSAM', + arch='large', + img_size=crop_size[0], + patch_size=16, + out_channels=256, + use_abs_pos=True, + use_rel_pos=True, + window_size=14, + layer_cfgs=dict(type='TimeFusionTransformerEncoderLayer'), + init_cfg=dict(type='Pretrained', checkpoint=sam_pretrain_ckpt_path, prefix='backbone.'), + ), + peft_cfg=dict( + r=16, + target_modules=["qkv"], + lora_dropout=0.01, + bias='lora_only', + ), + ), + neck=dict( + type='SequentialNeck', + necks=[ + dict( + type='FeatureFusionNeck', + policy='concat', + out_indices=(0,)), + dict( + type='SimpleFPN', + backbone_channel=512, + in_channels=[128, 256, 512, 512], + out_channels=256, + num_outs=5, + norm_cfg=fpn_norm_cfg), + ], + ), + decode_head=dict( + type='MLPSegHead', + out_size=(128, 128), + in_channels=[256]*5, + in_index=[0, 1, 2, 3, 4], + channels=256, + dropout_ratio=0, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='mmseg.CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(crop_size[0]//2, crop_size[1]//2)) +) # yapf: disable + +dataset_type = 'LEVIR_CD_Dataset' +data_root = '/mnt/levir_datasets/levir-cd' + + +train_pipeline = [ + dict(type='MultiImgLoadImageFromFile'), + dict(type='MultiImgLoadAnnotations'), + dict(type='MultiImgRandomRotate', prob=0.5, degree=180), + dict(type='MultiImgRandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='MultiImgRandomFlip', prob=0.5, direction='horizontal'), + 
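# the two flips are sampled independently, so a crop may be flipped horizontally, vertically, both, or neither +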
dict(type='MultiImgRandomFlip', prob=0.5, direction='vertical'), + # dict(type='MultiImgExchangeTime', prob=0.5), + dict( + type='MultiImgPhotoMetricDistortion', + brightness_delta=10, + contrast_range=(0.8, 1.2), + saturation_range=(0.8, 1.2), + hue_delta=10), + dict(type='MultiImgPackSegInputs') +] +test_pipeline = [ + dict(type='MultiImgLoadImageFromFile'), + dict(type='MultiImgResize', scale=(1024, 1024), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='MultiImgLoadAnnotations'), + dict(type='MultiImgPackSegInputs') +] + +batch_size_per_gpu = 2 + +train_dataloader = dict( + batch_size=batch_size_per_gpu, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + seg_map_path='train/label', + img_path_from='train/A', + img_path_to='train/B'), + pipeline=train_pipeline) +) + +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + seg_map_path='test/label', + img_path_from='test/A', + img_path_to='test/B'), + pipeline=test_pipeline) +) + +test_dataloader = val_dataloader + +val_evaluator = dict( + type='CDMetric', +) +test_evaluator = val_evaluator + +max_epochs = 300 +base_lr = 0.0004 +param_scheduler = [ + dict( + type='LinearLR', start_factor=1e-4, by_epoch=True, begin=0, end=5, convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=max_epochs, + begin=5, + by_epoch=True, + end=max_epochs, + convert_to_iter_based=True + ), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=5) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + + +optim_wrapper = dict( + type='AmpOptimWrapper', + optimizer=dict( + type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05), + dtype='float16', +) + diff --git a/configs/TTP/ttp_sam_large_levircd_infer.py b/configs/TTP/ttp_sam_large_levircd_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..5826cdc2818f6163863e2d70dbb8720ae27b1bc6 --- /dev/null +++ b/configs/TTP/ttp_sam_large_levircd_infer.py @@ -0,0 +1,199 @@ +default_scope = 'opencd' + +work_dir = 'work_dirs/lervicd/ttp_sam_large_levircd' + +custom_imports = dict(imports=['mmseg.ttp'], allow_failed_imports=False) + +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=10, log_metric_by_epoch=True), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, save_best='cd/iou_changed', max_keep_ckpts=5, greater_keys=['cd/iou_changed'], save_last=True), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='CDVisualizationHook', interval=1, + img_shape=(1024, 1024, 3)) +) +vis_backends = [dict(type='CDLocalVisBackend')] + +visualizer = dict( + type='CDLocalVisualizer', + vis_backends=vis_backends, name='visualizer', alpha=1.0) +log_processor = dict(by_epoch=True) + +log_level = 'INFO' +load_from = None +resume = False + +crop_size = (512, 512) + +data_preprocessor = dict( + type='DualInputSegDataPreProcessor', + mean=[123.675, 116.28, 103.53] * 2, + std=[58.395, 57.12, 
57.375] * 2, + bgr_to_rgb=True, + pad_val=0, + seg_pad_val=255, + size_divisor=32, + test_cfg=dict(size_divisor=32) +) + +norm_cfg = dict(type='SyncBN', requires_grad=True) +fpn_norm_cfg = dict(type='LN2d', requires_grad=True) + +# sam_pretrain_ckpt_path = 'https://download.openmmlab.com/mmclassification/v1/vit_sam/vit-large-p16_sam-pre_3rdparty_sa1b-1024px_20230411-595feafd.pth' + +model = dict( + type='SiamEncoderDecoder', + data_preprocessor=data_preprocessor, + backbone=dict( + type='MMPretrainSamVisionEncoder', + encoder_cfg=dict( + type='mmpretrain.ViTSAM', + arch='large', + img_size=crop_size[0], + patch_size=16, + out_channels=256, + use_abs_pos=True, + use_rel_pos=True, + window_size=14, + layer_cfgs=dict(type='TimeFusionTransformerEncoderLayer'), + # init_cfg=dict(type='Pretrained', checkpoint=sam_pretrain_ckpt_path, prefix='backbone.'), + ), + peft_cfg=dict( + r=16, + target_modules=["qkv"], + lora_dropout=0.01, + bias='lora_only', + ), + ), + neck=dict( + type='SequentialNeck', + necks=[ + dict( + type='FeatureFusionNeck', + policy='concat', + out_indices=(0,)), + dict( + type='SimpleFPN', + backbone_channel=512, + in_channels=[128, 256, 512, 512], + out_channels=256, + num_outs=5, + norm_cfg=fpn_norm_cfg), + ], + ), + decode_head=dict( + type='MLPSegHead', + out_size=(128, 128), + in_channels=[256]*5, + in_index=[0, 1, 2, 3, 4], + channels=256, + dropout_ratio=0, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='mmseg.CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(crop_size[0]//2, crop_size[1]//2)) +) # yapf: disable + +dataset_type = 'LEVIR_CD_Dataset' +data_root = '/mnt/levir_datasets/levir-cd' + + +train_pipeline = [ + dict(type='MultiImgLoadImageFromFile'), + dict(type='MultiImgLoadAnnotations'), + dict(type='MultiImgRandomRotate', prob=0.5, degree=180), + dict(type='MultiImgRandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='MultiImgRandomFlip', prob=0.5, direction='horizontal'), + dict(type='MultiImgRandomFlip', prob=0.5, direction='vertical'), + # dict(type='MultiImgExchangeTime', prob=0.5), + dict( + type='MultiImgPhotoMetricDistortion', + brightness_delta=10, + contrast_range=(0.8, 1.2), + saturation_range=(0.8, 1.2), + hue_delta=10), + dict(type='MultiImgPackSegInputs') +] +test_pipeline = [ + dict(type='MultiImgLoadImageFromFile', to_float32=True), + dict(type='MultiImgResize', scale=(1024, 1024), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='MultiImgLoadAnnotations'), + dict(type='MultiImgPackSegInputs') +] + +batch_size_per_gpu = 2 + +train_dataloader = dict( + batch_size=batch_size_per_gpu, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + seg_map_path='train/label', + img_path_from='train/A', + img_path_to='train/B'), + pipeline=train_pipeline) +) + +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + seg_map_path='test/label', + img_path_from='test/A', + img_path_to='test/B'), + pipeline=test_pipeline) +) + +test_dataloader = val_dataloader + +val_evaluator = dict( + type='CDMetric', +) +test_evaluator = val_evaluator + +max_epochs = 
300 +base_lr = 0.0004 +param_scheduler = [ + dict( + type='LinearLR', start_factor=1e-4, by_epoch=True, begin=0, end=5, convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=max_epochs, + begin=5, + by_epoch=True, + end=max_epochs, + convert_to_iter_based=True + ), +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=5) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05), +) + diff --git a/demo/MMSegmentation_Tutorial.ipynb b/demo/MMSegmentation_Tutorial.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..ac8601b321e0f0a2b086e61987e39d8bab1a008c --- /dev/null +++ b/demo/MMSegmentation_Tutorial.ipynb @@ -0,0 +1,555 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FVmnaxFJvsb8" + }, + "source": [ + "# MMSegmentation Tutorial\n", + "Welcome to MMSegmentation! \n", + "\n", + "In this tutorial, we demo\n", + "* How to do inference with MMSeg trained weight\n", + "* How to train on your own dataset and visualize the results. " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QS8YHrEhbpas" + }, + "source": [ + "## Install MMSegmentation\n", + "This step may take several minutes. \n", + "\n", + "We use PyTorch 1.12 and CUDA 11.3 for this tutorial. You may install other versions by change the version number in pip install command. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "UWyLrLYaNEaL", + "outputId": "32a47fe3-f10d-47a1-f6b9-b7c235abdab1" + }, + "outputs": [], + "source": [ + "# Check nvcc version\n", + "!nvcc -V\n", + "# Check GCC version\n", + "!gcc --version" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Ki3WUBjKbutg", + "outputId": "14bd14b0-4d8c-4fa9-e3f9-da35c0efc0d5" + }, + "outputs": [], + "source": [ + "# Install PyTorch\n", + "!conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch\n", + "# Install mim\n", + "!pip install -U openmim\n", + "# Install mmengine\n", + "!mim install mmengine\n", + "# Install MMCV\n", + "!mim install 'mmcv >= 2.0.0rc1'\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "nR-hHRvbNJJZ", + "outputId": "10c3b131-d4db-458c-fc10-b94b1c6ed546" + }, + "outputs": [], + "source": [ + "!rm -rf mmsegmentation\n", + "!git clone -b main https://github.com/open-mmlab/mmsegmentation.git \n", + "%cd mmsegmentation\n", + "!pip install -e ." 
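+ ,"\n# Verify the versions that mim installed\n", + "!mim list"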
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mAE_h7XhPT7d", + "outputId": "83bf0f8e-fc69-40b1-f9fe-0025724a217c" + }, + "outputs": [], + "source": [ + "# Check Pytorch installation\n", + "import torch, torchvision\n", + "print(torch.__version__, torch.cuda.is_available())\n", + "\n", + "# Check MMSegmentation installation\n", + "import mmseg\n", + "print(mmseg.__version__)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ta51clKX4cwM" + }, + "source": [ + "## Finetune a semantic segmentation model on a new dataset\n", + "\n", + "To finetune on a customized dataset, the following steps are necessary. \n", + "1. Add a new dataset class. \n", + "2. Create a config file accordingly. \n", + "3. Perform training and evaluation. " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AcZg6x_K5Zs3" + }, + "source": [ + "### Add a new dataset\n", + "\n", + "Datasets in MMSegmentation require image and semantic segmentation maps to be placed in folders with the same prefix. To support a new dataset, we may need to modify the original file structure. \n", + "\n", + "In this tutorial, we give an example of converting the dataset. You may refer to [docs](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/tutorials/customize_datasets.md#customize-datasets-by-reorganizing-data) for details about dataset reorganization. \n", + "\n", + "We use [Stanford Background Dataset](http://dags.stanford.edu/projects/scenedataset.html) as an example. The dataset contains 715 images chosen from existing public datasets [LabelMe](http://labelme.csail.mit.edu), [MSRC](http://research.microsoft.com/en-us/projects/objectclassrecognition), [PASCAL VOC](http://pascallin.ecs.soton.ac.uk/challenges/VOC) and [Geometric Context](http://www.cs.illinois.edu/homes/dhoiem/). Images from these datasets are mainly outdoor scenes, each containing approximately 320-by-240 pixels. \n", + "In this tutorial, we use the region annotations as labels. There are 8 classes in total, i.e. sky, tree, road, grass, water, building, mountain, and foreground object. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "TFIt7MHq5Wls", + "outputId": "74a126e4-c8a4-4d2f-a910-b58b71843a23" + }, + "outputs": [], + "source": [ + "# download and unzip\n", + "!wget http://dags.stanford.edu/data/iccv09Data.tar.gz -O stanford_background.tar.gz\n", + "!tar xf stanford_background.tar.gz" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 377 + }, + "id": "78LIci7F9WWI", + "outputId": "c432ddac-5a50-47b1-daac-5a26b07afea2" + }, + "outputs": [], + "source": [ + "# Let's take a look at the dataset\n", + "import mmcv\n", + "import mmengine\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "img = mmcv.imread('iccv09Data/images/6000124.jpg')\n", + "plt.figure(figsize=(8, 6))\n", + "plt.imshow(mmcv.bgr2rgb(img))\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "L5mNQuc2GsVE" + }, + "source": [ + "We need to convert the annotation into semantic map format as an image." 
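+ ,"\n\nEach `.regions.txt` file stores integer region labels, one per pixel; the code below loads them with `np.loadtxt` and writes palette-indexed PNGs that MMSeg can read as segmentation masks."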
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WnGZfribFHCx" + }, + "outputs": [], + "source": [ + "# define dataset root and directory for images and annotations\n", + "data_root = 'iccv09Data'\n", + "img_dir = 'images'\n", + "ann_dir = 'labels'\n", + "# define class and palette for better visualization\n", + "classes = ('sky', 'tree', 'road', 'grass', 'water', 'bldg', 'mntn', 'fg obj')\n", + "palette = [[128, 128, 128], [129, 127, 38], [120, 69, 125], [53, 125, 34], \n", + " [0, 11, 123], [118, 20, 12], [122, 81, 25], [241, 134, 51]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WnGZfribFHCx" + }, + "outputs": [], + "source": [ + "import os.path as osp\n", + "import numpy as np\n", + "from PIL import Image\n", + "\n", + "# convert dataset annotation to semantic segmentation map\n", + "for file in mmengine.scandir(osp.join(data_root, ann_dir), suffix='.regions.txt'):\n", + " seg_map = np.loadtxt(osp.join(data_root, ann_dir, file)).astype(np.uint8)\n", + " seg_img = Image.fromarray(seg_map).convert('P')\n", + " seg_img.putpalette(np.array(palette, dtype=np.uint8))\n", + " seg_img.save(osp.join(data_root, ann_dir, file.replace('.regions.txt', \n", + " '.png')))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 377 + }, + "id": "5MCSS9ABfSks", + "outputId": "92b9bafc-589e-48fc-c9e9-476f125d6522" + }, + "outputs": [], + "source": [ + "# Let's take a look at the segmentation map we got\n", + "import matplotlib.patches as mpatches\n", + "img = Image.open('iccv09Data/labels/6000124.png')\n", + "plt.figure(figsize=(8, 6))\n", + "im = plt.imshow(np.array(img.convert('RGB')))\n", + "\n", + "# create a patch (proxy artist) for every color \n", + "patches = [mpatches.Patch(color=np.array(palette[i])/255., \n", + " label=classes[i]) for i in range(8)]\n", + "# put those patched as legend-handles into the legend\n", + "plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., \n", + " fontsize='large')\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WbeLYCp2k5hl" + }, + "outputs": [], + "source": [ + "# split train/val set randomly\n", + "split_dir = 'splits'\n", + "mmengine.mkdir_or_exist(osp.join(data_root, split_dir))\n", + "filename_list = [osp.splitext(filename)[0] for filename in mmengine.scandir(\n", + " osp.join(data_root, ann_dir), suffix='.png')]\n", + "with open(osp.join(data_root, split_dir, 'train.txt'), 'w') as f:\n", + " # select first 4/5 as train set\n", + " train_length = int(len(filename_list)*4/5)\n", + " f.writelines(line + '\\n' for line in filename_list[:train_length])\n", + "with open(osp.join(data_root, split_dir, 'val.txt'), 'w') as f:\n", + " # select last 1/5 as train set\n", + " f.writelines(line + '\\n' for line in filename_list[train_length:])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HchvmGYB_rrO" + }, + "source": [ + "After downloading the data, we need to implement `load_annotations` function in the new dataset class `StanfordBackgroundDataset`." 
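+ ,"\n\nWith the MMSeg 1.x API this amounts to registering a `BaseSegDataset` subclass that fixes the image/annotation suffixes and the class/palette `METAINFO`, which is exactly what the next cell does."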
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LbsWOw62_o-X" + }, + "outputs": [], + "source": [ + "from mmseg.registry import DATASETS\n", + "from mmseg.datasets import BaseSegDataset\n", + "\n", + "\n", + "@DATASETS.register_module()\n", + "class StanfordBackgroundDataset(BaseSegDataset):\n", + " METAINFO = dict(classes = classes, palette = palette)\n", + " def __init__(self, **kwargs):\n", + " super().__init__(img_suffix='.jpg', seg_map_suffix='.png', **kwargs)\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yUVtmn3Iq3WA" + }, + "source": [ + "### Create a config file\n", + "In the next step, we need to modify the config for the training. To accelerate the process, we finetune the model from trained weights." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Download config and checkpoint files\n", + "!mim download mmsegmentation --config pspnet_r50-d8_4xb2-40k_cityscapes-512x1024 --dest ." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Wwnj9tRzqX_A" + }, + "outputs": [], + "source": [ + "from mmengine import Config\n", + "cfg = Config.fromfile('configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py')\n", + "print(f'Config:\\n{cfg.pretty_text}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1y2oV5w97jQo" + }, + "source": [ + "Since the given config is used to train PSPNet on the cityscapes dataset, we need to modify it accordingly for our new dataset. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "eyKnYC1Z7iCV", + "outputId": "6195217b-187f-4675-994b-ba90d8bb3078" + }, + "outputs": [], + "source": [ + "# Since we use only one GPU, BN is used instead of SyncBN\n", + "cfg.norm_cfg = dict(type='BN', requires_grad=True)\n", + "cfg.crop_size = (256, 256)\n", + "cfg.model.data_preprocessor.size = cfg.crop_size\n", + "cfg.model.backbone.norm_cfg = cfg.norm_cfg\n", + "cfg.model.decode_head.norm_cfg = cfg.norm_cfg\n", + "cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg\n", + "# modify num classes of the model in decode/auxiliary head\n", + "cfg.model.decode_head.num_classes = 8\n", + "cfg.model.auxiliary_head.num_classes = 8\n", + "\n", + "# Modify dataset type and path\n", + "cfg.dataset_type = 'StanfordBackgroundDataset'\n", + "cfg.data_root = data_root\n", + "\n", + "cfg.train_dataloader.batch_size = 8\n", + "\n", + "cfg.train_pipeline = [\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='LoadAnnotations'),\n", + " dict(type='RandomResize', scale=(320, 240), ratio_range=(0.5, 2.0), keep_ratio=True),\n", + " dict(type='RandomCrop', crop_size=cfg.crop_size, cat_max_ratio=0.75),\n", + " dict(type='RandomFlip', prob=0.5),\n", + " dict(type='PackSegInputs')\n", + "]\n", + "\n", + "cfg.test_pipeline = [\n", + " dict(type='LoadImageFromFile'),\n", + " dict(type='Resize', scale=(320, 240), keep_ratio=True),\n", + " # add loading annotation after ``Resize`` because ground truth\n", + " # does not need to do resize data transform\n", + " dict(type='LoadAnnotations'),\n", + " dict(type='PackSegInputs')\n", + "]\n", + "\n", + "\n", + "cfg.train_dataloader.dataset.type = cfg.dataset_type\n", + "cfg.train_dataloader.dataset.data_root = cfg.data_root\n", + "cfg.train_dataloader.dataset.data_prefix = dict(img_path=img_dir, seg_map_path=ann_dir)\n", + "cfg.train_dataloader.dataset.pipeline = 
cfg.train_pipeline\n", + "cfg.train_dataloader.dataset.ann_file = 'splits/train.txt'\n", + "\n", + "cfg.val_dataloader.dataset.type = cfg.dataset_type\n", + "cfg.val_dataloader.dataset.data_root = cfg.data_root\n", + "cfg.val_dataloader.dataset.data_prefix = dict(img_path=img_dir, seg_map_path=ann_dir)\n", + "cfg.val_dataloader.dataset.pipeline = cfg.test_pipeline\n", + "cfg.val_dataloader.dataset.ann_file = 'splits/val.txt'\n", + "\n", + "cfg.test_dataloader = cfg.val_dataloader\n", + "\n", + "\n", + "# Load the pretrained weights\n", + "cfg.load_from = 'pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'\n", + "\n", + "# Set up working dir to save files and logs.\n", + "cfg.work_dir = './work_dirs/tutorial'\n", + "\n", + "cfg.train_cfg.max_iters = 200\n", + "cfg.train_cfg.val_interval = 200\n", + "cfg.default_hooks.logger.interval = 10\n", + "cfg.default_hooks.checkpoint.interval = 200\n", + "\n", + "# Set seed to facilitate reproducing the result\n", + "cfg['randomness'] = dict(seed=0)\n", + "\n", + "# Let's have a look at the final config used for training\n", + "print(f'Config:\\n{cfg.pretty_text}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QWuH14LYF2gQ" + }, + "source": [ + "### Train and Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "jYKoSfdMF12B", + "outputId": "422219ca-d7a5-4890-f09f-88c959942e64" + }, + "outputs": [], + "source": [ + "from mmengine.runner import Runner\n", + "\n", + "runner = Runner.from_cfg(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# start training\n", + "runner.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DEkWOP-NMbc_" + }, + "source": [ + "Inference with trained model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 645 + }, + "id": "ekG__UfaH_OU", + "outputId": "1437419c-869a-4902-df86-d4f6f8b2597a" + }, + "outputs": [], + "source": [ + "from mmseg.apis import init_model, inference_model, show_result_pyplot\n", + "\n", + "# Init the model from the config and the checkpoint\n", + "checkpoint_path = './work_dirs/tutorial/iter_200.pth'\n", + "model = init_model(cfg, checkpoint_path, 'cuda:0')\n", + "\n", + "img = mmcv.imread('iccv09Data/images/6000124.jpg')\n", + "result = inference_model(model, img)\n", + "plt.figure(figsize=(8, 6))\n", + "vis_result = show_result_pyplot(model, img, result)\n", + "plt.imshow(mmcv.bgr2rgb(vis_result))\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "include_colab_link": true, + "name": "MMSegmentation Tutorial.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3.10.6 ('pt1.12')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + }, + "vscode": { + "interpreter": { + "hash": "0442e67aee3d9cbb788fa6e86d60c4ffa94ad7f1943c65abfecb99a6f4696c58" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/demo/classroom__rgb_00283.jpg 
b/demo/classroom__rgb_00283.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1df37e9248390cadd0b94d180ac3d4527b1b69c7 Binary files /dev/null and b/demo/classroom__rgb_00283.jpg differ diff --git a/demo/demo.png b/demo/demo.png new file mode 100644 index 0000000000000000000000000000000000000000..1e82d7a0773cea14b36f0021fea603de0961b5d8 Binary files /dev/null and b/demo/demo.png differ diff --git a/demo/image_demo.py b/demo/image_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc34c80b29e98d08696d981cb5e1646821c4f72 --- /dev/null +++ b/demo/image_demo.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from argparse import ArgumentParser + +from mmengine.model import revert_sync_batchnorm + +from mmseg.apis import inference_model, init_model, show_result_pyplot + + +def main(): + parser = ArgumentParser() + parser.add_argument('img', help='Image file') + parser.add_argument('config', help='Config file') + parser.add_argument('checkpoint', help='Checkpoint file') + parser.add_argument('--out-file', default=None, help='Path to output file') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + parser.add_argument( + '--opacity', + type=float, + default=0.5, + help='Opacity of painted segmentation map. In (0, 1] range.') + parser.add_argument( + '--with-labels', + action='store_true', + default=False, + help='Whether to display the class labels.') + parser.add_argument( + '--title', default='result', help='The image identifier.') + args = parser.parse_args() + + # build the model from a config file and a checkpoint file + model = init_model(args.config, args.checkpoint, device=args.device) + if args.device == 'cpu': + model = revert_sync_batchnorm(model) + # test a single image + result = inference_model(model, args.img) + # show the results + show_result_pyplot( + model, + args.img, + result, + title=args.title, + opacity=args.opacity, + with_labels=args.with_labels, + draw_gt=False, + show=False if args.out_file is not None else True, + out_file=args.out_file) + + +if __name__ == '__main__': + main() diff --git a/demo/image_demo_with_inferencer.py b/demo/image_demo_with_inferencer.py new file mode 100644 index 0000000000000000000000000000000000000000..d1fa9deb9e2e6558a28ddce7f714cbb2535f956a --- /dev/null +++ b/demo/image_demo_with_inferencer.py @@ -0,0 +1,54 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from argparse import ArgumentParser + +from mmseg.apis import MMSegInferencer + + +def main(): + parser = ArgumentParser() + parser.add_argument('img', help='Image file') + parser.add_argument('model', help='Config file') + parser.add_argument('--checkpoint', default=None, help='Checkpoint file') + parser.add_argument( + '--out-dir', default='', help='Path to save result file') + parser.add_argument( + '--show', + action='store_true', + default=False, + help='Whether to display the drawn image.') + parser.add_argument( + '--dataset-name', + default='cityscapes', + help='Color palette used for segmentation map') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + parser.add_argument( + '--opacity', + type=float, + default=0.5, + help='Opacity of painted segmentation map. 
In (0, 1] range.') + parser.add_argument( + '--with-labels', + action='store_true', + default=False, + help='Whether to display the class labels.') + args = parser.parse_args() + + # build the model from a config file and a checkpoint file + mmseg_inferencer = MMSegInferencer( + args.model, + args.checkpoint, + dataset_name=args.dataset_name, + device=args.device) + + # test a single image + mmseg_inferencer( + args.img, + show=args.show, + out_dir=args.out_dir, + opacity=args.opacity, + with_labels=args.with_labels) + + +if __name__ == '__main__': + main() diff --git a/demo/inference_demo.ipynb b/demo/inference_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..455c5df4e1fac330c0be7b47e41d9a24cd458f87 --- /dev/null +++ b/demo/inference_demo.ipynb @@ -0,0 +1,120 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!mkdir ../checkpoints\n", + "!wget https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth -P ../checkpoints" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "import torch\n", + "import matplotlib.pyplot as plt\n", + "from mmengine.model.utils import revert_sync_batchnorm\n", + "from mmseg.apis import init_model, inference_model, show_result_pyplot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "config_file = '../configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'\n", + "checkpoint_file = '../checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# build the model from a config file and a checkpoint file\n", + "model = init_model(config_file, checkpoint_file, device='cpu')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# test a single image\n", + "img = 'demo.png'\n", + "if not torch.cuda.is_available():\n", + " model = revert_sync_batchnorm(model)\n", + "result = inference_model(model, img)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the results\n", + "vis_result = show_result_pyplot(model, img, result, show=False)\n", + "plt.imshow(vis_result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pt1.13", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + }, + "vscode": { + "interpreter": { + "hash": "f61d5b8fecdd960739697f6c2860080d7b76a5be5d896cb034bdb275ab3ddda0" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/demo/rs_image_inference.py b/demo/rs_image_inference.py new file mode 100644 index 
0000000000000000000000000000000000000000..799181f93c73480de6e73dc46f8e3ac5824e7a14 --- /dev/null +++ b/demo/rs_image_inference.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from argparse import ArgumentParser + +from mmseg.apis import RSImage, RSInferencer + + +def main(): + parser = ArgumentParser() + parser.add_argument('image', help='Image file path') + parser.add_argument('config', help='Config file') + parser.add_argument('checkpoint', help='Checkpoint file') + parser.add_argument( + '--output-path', + help='Path to save result image', + default='result.png') + parser.add_argument( + '--batch-size', + type=int, + default=1, + help='maximum number of windows inferred simultaneously') + parser.add_argument( + '--window-size', + help='window xsize,ysize', + default=(224, 224), + type=int, + nargs=2) + parser.add_argument( + '--stride', + help='window xstride,ystride', + default=(224, 224), + type=int, + nargs=2) + parser.add_argument( + '--thread', default=1, type=int, help='number of inference threads') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + args = parser.parse_args() + inferencer = RSInferencer.from_config_path( + args.config, + args.checkpoint, + batch_size=args.batch_size, + thread=args.thread, + device=args.device) + image = RSImage(args.image) + + inferencer.run(image, args.window_size, args.stride, args.output_path) + + +if __name__ == '__main__': + main() diff --git a/demo/video_demo.py b/demo/video_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..7e6f3d605c30142c2d7450284a3c6679e9903b70 --- /dev/null +++ b/demo/video_demo.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from argparse import ArgumentParser + +import cv2 +from mmengine.model.utils import revert_sync_batchnorm + +from mmseg.apis import inference_model, init_model +from mmseg.apis.inference import show_result_pyplot + + +def main(): + parser = ArgumentParser() + parser.add_argument('video', help='Video file or webcam id') + parser.add_argument('config', help='Config file') + parser.add_argument('checkpoint', help='Checkpoint file') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + parser.add_argument( + '--palette', + default='cityscapes', + help='Color palette used for segmentation map') + parser.add_argument( + '--show', action='store_true', help='Whether to show draw result') + parser.add_argument( + '--show-wait-time', default=1, type=int, help='Wait time after imshow') + parser.add_argument( + '--output-file', default=None, type=str, help='Output video file path') + parser.add_argument( + '--output-fourcc', + default='MJPG', + type=str, + help='Fourcc of the output video') + parser.add_argument( + '--output-fps', default=-1, type=int, help='FPS of the output video') + parser.add_argument( + '--output-height', + default=-1, + type=int, + help='Frame height of the output video') + parser.add_argument( + '--output-width', + default=-1, + type=int, + help='Frame width of the output video') + parser.add_argument( + '--opacity', + type=float, + default=0.5, + help='Opacity of painted segmentation map. In (0, 1] range.') + args = parser.parse_args() + + assert args.show or args.output_file, \ + 'At least one output should be enabled.' 
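+ # Note: SyncBN needs a distributed process group, so on a CPU device the model's SyncBN layers are reverted to plain BatchNorm below.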
+ + # build the model from a config file and a checkpoint file + model = init_model(args.config, args.checkpoint, device=args.device) + if args.device == 'cpu': + model = revert_sync_batchnorm(model) + + # build input video + if args.video.isdigit(): + args.video = int(args.video) + cap = cv2.VideoCapture(args.video) + assert (cap.isOpened()) + input_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) + input_width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) + input_fps = cap.get(cv2.CAP_PROP_FPS) + + # init output video + writer = None + output_height = None + output_width = None + if args.output_file is not None: + fourcc = cv2.VideoWriter_fourcc(*args.output_fourcc) + output_fps = args.output_fps if args.output_fps > 0 else input_fps + output_height = args.output_height if args.output_height > 0 else int( + input_height) + output_width = args.output_width if args.output_width > 0 else int( + input_width) + writer = cv2.VideoWriter(args.output_file, fourcc, output_fps, + (output_width, output_height), True) + + # start looping + try: + while True: + flag, frame = cap.read() + if not flag: + break + + # test a single image + result = inference_model(model, frame) + + # blend raw image and prediction + draw_img = show_result_pyplot(model, frame, result) + + if args.show: + cv2.imshow('video_demo', draw_img) + cv2.waitKey(args.show_wait_time) + if writer: + if draw_img.shape[0] != output_height or draw_img.shape[ + 1] != output_width: + draw_img = cv2.resize(draw_img, + (output_width, output_height)) + writer.write(draw_img) + finally: + if writer: + writer.release() + cap.release() + + +if __name__ == '__main__': + main() diff --git a/mmdet/.DS_Store b/mmdet/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..cd5b45b2dccab2ff995d81dfaddb14f2d861ec89 Binary files /dev/null and b/mmdet/.DS_Store differ diff --git a/mmdet/__init__.py b/mmdet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3ac884ac8b40c1543ed840dfcafe367fbe4bda62 --- /dev/null +++ b/mmdet/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import mmengine +from mmengine.utils import digit_version + +from .version import __version__, version_info + +mmcv_minimum_version = '2.0.0rc4' +mmcv_maximum_version = '2.2.0' +mmcv_version = digit_version(mmcv.__version__) + +mmengine_minimum_version = '0.7.1' +mmengine_maximum_version = '1.0.0' +mmengine_version = digit_version(mmengine.__version__) + +assert (mmcv_version >= digit_version(mmcv_minimum_version) + and mmcv_version < digit_version(mmcv_maximum_version)), \ + f'MMCV=={mmcv.__version__} is used but incompatible. ' \ + f'Please install mmcv>={mmcv_minimum_version}, <{mmcv_maximum_version}.' + +assert (mmengine_version >= digit_version(mmengine_minimum_version) + and mmengine_version < digit_version(mmengine_maximum_version)), \ + f'MMEngine=={mmengine.__version__} is used but incompatible. ' \ + f'Please install mmengine>={mmengine_minimum_version}, ' \ + f'<{mmengine_maximum_version}.' 
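+# digit_version converts a version string such as '2.0.0rc4' into a comparable tuple, with release candidates ordered below the final release, so the range checks above remain valid for pre-release builds.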
+ +__all__ = ['__version__', 'version_info', 'digit_version'] diff --git a/mmdet/__pycache__/__init__.cpython-311.pyc b/mmdet/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5670b110c88efd87a1c5c450e98b27a8c52d0590 Binary files /dev/null and b/mmdet/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmdet/__pycache__/registry.cpython-311.pyc b/mmdet/__pycache__/registry.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32cc072a2ab308158c17212198b6309cd9a071e3 Binary files /dev/null and b/mmdet/__pycache__/registry.cpython-311.pyc differ diff --git a/mmdet/__pycache__/version.cpython-311.pyc b/mmdet/__pycache__/version.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..322857dc767005fd33d5de7728d4769363498108 Binary files /dev/null and b/mmdet/__pycache__/version.cpython-311.pyc differ diff --git a/mmdet/apis/__init__.py b/mmdet/apis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c89dc72914b11a73e91dc7e9404f41bf10b93c6c --- /dev/null +++ b/mmdet/apis/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .det_inferencer import DetInferencer +from .inference import (async_inference_detector, inference_detector, + inference_mot, init_detector, init_track_model) + +__all__ = [ + 'init_detector', 'async_inference_detector', 'inference_detector', + 'DetInferencer', 'inference_mot', 'init_track_model' +] diff --git a/mmdet/apis/det_inferencer.py b/mmdet/apis/det_inferencer.py new file mode 100644 index 0000000000000000000000000000000000000000..9efbb00cbe93de954fe5415b6e955a1d26908e15 --- /dev/null +++ b/mmdet/apis/det_inferencer.py @@ -0,0 +1,644 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os.path as osp +import warnings +from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import mmcv +import mmengine +import numpy as np +import torch.nn as nn +from mmcv.transforms import LoadImageFromFile +from mmengine.dataset import Compose +from mmengine.fileio import (get_file_backend, isdir, join_path, + list_dir_or_file) +from mmengine.infer.infer import BaseInferencer, ModelType +from mmengine.model.utils import revert_sync_batchnorm +from mmengine.registry import init_default_scope +from mmengine.runner.checkpoint import _load_checkpoint_to_model +from mmengine.visualization import Visualizer +from rich.progress import track + +from mmdet.evaluation import INSTANCE_OFFSET +from mmdet.registry import DATASETS +from mmdet.structures import DetDataSample +from mmdet.structures.mask import encode_mask_results, mask2bbox +from mmdet.utils import ConfigType +from ..evaluation import get_classes + +try: + from panopticapi.evaluation import VOID + from panopticapi.utils import id2rgb +except ImportError: + id2rgb = None + VOID = None + +InputType = Union[str, np.ndarray] +InputsType = Union[InputType, Sequence[InputType]] +PredType = List[DetDataSample] +ImgType = Union[np.ndarray, Sequence[np.ndarray]] + +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', + '.tiff', '.webp') + + +class DetInferencer(BaseInferencer): + """Object Detection Inferencer. + + Args: + model (str, optional): Path to the config file or the model name + defined in metafile. For example, it could be + "rtmdet-s" or 'rtmdet_s_8xb32-300e_coco' or + "configs/rtmdet/rtmdet_s_8xb32-300e_coco.py". 
+ If model is not specified, user must provide the + `weights` saved by MMEngine which contains the config string. + Defaults to None. + weights (str, optional): Path to the checkpoint. If it is not specified + and model is a model name of metafile, the weights will be loaded + from metafile. Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + scope (str, optional): The scope of the model. Defaults to mmdet. + palette (str): Color palette used for visualization. The order of + priority is palette -> config -> checkpoint. Defaults to 'none'. + show_progress (bool): Control whether to display the progress + bar during the inference process. Defaults to True. + """ + + preprocess_kwargs: set = set() + forward_kwargs: set = set() + visualize_kwargs: set = { + 'return_vis', + 'show', + 'wait_time', + 'draw_pred', + 'pred_score_thr', + 'img_out_dir', + 'no_save_vis', + } + postprocess_kwargs: set = { + 'print_result', + 'pred_out_dir', + 'return_datasamples', + 'no_save_pred', + } + + def __init__(self, + model: Optional[Union[ModelType, str]] = None, + weights: Optional[str] = None, + device: Optional[str] = None, + scope: Optional[str] = 'mmdet', + palette: str = 'none', + show_progress: bool = True) -> None: + # A global counter tracking the number of images processed, for + # naming of the output images + self.num_visualized_imgs = 0 + self.num_predicted_imgs = 0 + self.palette = palette + init_default_scope(scope) + super().__init__( + model=model, weights=weights, device=device, scope=scope) + self.model = revert_sync_batchnorm(self.model) + self.show_progress = show_progress + + def _load_weights_to_model(self, model: nn.Module, + checkpoint: Optional[dict], + cfg: Optional[ConfigType]) -> None: + """Loading model weights and meta information from cfg and checkpoint. + + Args: + model (nn.Module): Model to load weights and meta information. + checkpoint (dict, optional): The loaded checkpoint. + cfg (Config or ConfigDict, optional): The loaded config. + """ + + if checkpoint is not None: + _load_checkpoint_to_model(model, checkpoint) + checkpoint_meta = checkpoint.get('meta', {}) + # save the dataset_meta in the model for convenience + if 'dataset_meta' in checkpoint_meta: + # mmdet 3.x, all keys should be lowercase + model.dataset_meta = { + k.lower(): v + for k, v in checkpoint_meta['dataset_meta'].items() + } + elif 'CLASSES' in checkpoint_meta: + # < mmdet 3.x + classes = checkpoint_meta['CLASSES'] + model.dataset_meta = {'classes': classes} + else: + warnings.warn( + 'dataset_meta or class names are not saved in the ' + 'checkpoint\'s meta data, use COCO classes by default.') + model.dataset_meta = {'classes': get_classes('coco')} + else: + warnings.warn('Checkpoint is not loaded, and the inference ' + 'result is calculated by the randomly initialized ' + 'model!') + warnings.warn('weights is None, use COCO classes by default.') + model.dataset_meta = {'classes': get_classes('coco')} + + # Priority: args.palette -> config -> checkpoint + if self.palette != 'none': + model.dataset_meta['palette'] = self.palette + else: + test_dataset_cfg = copy.deepcopy(cfg.test_dataloader.dataset) + # lazy init. We only need the metainfo. 
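+ # With lazy_init=True the dataset object is created without scanning its annotation files; DATASETS.build still populates its metainfo (classes, palette).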
+ test_dataset_cfg['lazy_init'] = True + metainfo = DATASETS.build(test_dataset_cfg).metainfo + cfg_palette = metainfo.get('palette', None) + if cfg_palette is not None: + model.dataset_meta['palette'] = cfg_palette + else: + if 'palette' not in model.dataset_meta: + warnings.warn( + 'palette does not exist, random is used by default. ' + 'You can also set the palette to customize.') + model.dataset_meta['palette'] = 'random' + + def _init_pipeline(self, cfg: ConfigType) -> Compose: + """Initialize the test pipeline.""" + pipeline_cfg = cfg.test_dataloader.dataset.pipeline + + # For inference, the key of ``img_id`` is not used. + if 'meta_keys' in pipeline_cfg[-1]: + pipeline_cfg[-1]['meta_keys'] = tuple( + meta_key for meta_key in pipeline_cfg[-1]['meta_keys'] + if meta_key != 'img_id') + + load_img_idx = self._get_transform_idx( + pipeline_cfg, ('LoadImageFromFile', LoadImageFromFile)) + if load_img_idx == -1: + raise ValueError( + 'LoadImageFromFile is not found in the test pipeline') + pipeline_cfg[load_img_idx]['type'] = 'mmdet.InferencerLoader' + return Compose(pipeline_cfg) + + def _get_transform_idx(self, pipeline_cfg: ConfigType, + name: Union[str, Tuple[str, type]]) -> int: + """Returns the index of the transform in a pipeline. + + If the transform is not found, returns -1. + """ + for i, transform in enumerate(pipeline_cfg): + if transform['type'] in name: + return i + return -1 + + def _init_visualizer(self, cfg: ConfigType) -> Optional[Visualizer]: + """Initialize visualizers. + + Args: + cfg (ConfigType): Config containing the visualizer information. + + Returns: + Visualizer or None: Visualizer initialized with config. + """ + visualizer = super()._init_visualizer(cfg) + visualizer.dataset_meta = self.model.dataset_meta + return visualizer + + def _inputs_to_list(self, inputs: InputsType) -> list: + """Preprocess the inputs to a list. + + Preprocess inputs to a list according to its type: + + - list or tuple: return inputs + - str: + - Directory path: return all files in the directory + - other cases: return a list containing the string. The string + could be a path to file, a url or other types of string according + to the task. + + Args: + inputs (InputsType): Inputs for the inferencer. + + Returns: + list: List of input for the :meth:`preprocess`. + """ + if isinstance(inputs, str): + backend = get_file_backend(inputs) + if hasattr(backend, 'isdir') and isdir(inputs): + # Backends like HttpsBackend do not implement `isdir`, so only + # those backends that implement `isdir` could accept the inputs + # as a directory + filename_list = list_dir_or_file( + inputs, list_dir=False, suffix=IMG_EXTENSIONS) + inputs = [ + join_path(inputs, filename) for filename in filename_list + ] + + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + + return list(inputs) + + def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs): + """Process the inputs into a model-feedable format. + + Customize your preprocess by overriding this method. Preprocess should + return an iterable object, of which each item will be used as the + input of ``model.test_step``. + + ``BaseInferencer.preprocess`` will return an iterable chunked data, + which will be used in __call__ like this: + + .. code-block:: python + + def __call__(self, inputs, batch_size=1, **kwargs): + chunked_data = self.preprocess(inputs, batch_size, **kwargs) + for batch in chunked_data: + preds = self.forward(batch, **kwargs) + + Args: + inputs (InputsType): Inputs given by user. + batch_size (int): batch size. 
Defaults to 1.
+
+        Yields:
+            Any: Data processed by the ``pipeline`` and ``collate_fn``.
+        """
+        chunked_data = self._get_chunk_data(inputs, batch_size)
+        yield from map(self.collate_fn, chunked_data)
+
+    def _get_chunk_data(self, inputs: Iterable, chunk_size: int):
+        """Get batch data from inputs.
+
+        Args:
+            inputs (Iterable): An iterable dataset.
+            chunk_size (int): Equivalent to batch size.
+
+        Yields:
+            list: batch data.
+        """
+        inputs_iter = iter(inputs)
+        while True:
+            try:
+                chunk_data = []
+                for _ in range(chunk_size):
+                    inputs_ = next(inputs_iter)
+                    if isinstance(inputs_, dict):
+                        if 'img' in inputs_:
+                            ori_inputs_ = inputs_['img']
+                        else:
+                            ori_inputs_ = inputs_['img_path']
+                        chunk_data.append(
+                            (ori_inputs_,
+                             self.pipeline(copy.deepcopy(inputs_))))
+                    else:
+                        chunk_data.append((inputs_, self.pipeline(inputs_)))
+                yield chunk_data
+            except StopIteration:
+                if chunk_data:
+                    yield chunk_data
+                break
+
+    # TODO: Video and webcam inputs are not supported yet, and inference may
+    # consume too much memory if your input folder contains many images.
+    # This will be optimized later.
+    def __call__(
+            self,
+            inputs: InputsType,
+            batch_size: int = 1,
+            return_vis: bool = False,
+            show: bool = False,
+            wait_time: int = 0,
+            no_save_vis: bool = False,
+            draw_pred: bool = True,
+            pred_score_thr: float = 0.3,
+            return_datasamples: bool = False,
+            print_result: bool = False,
+            no_save_pred: bool = True,
+            out_dir: str = '',
+            # by open image task
+            texts: Optional[Union[str, list]] = None,
+            # by open panoptic task
+            stuff_texts: Optional[Union[str, list]] = None,
+            # by GLIP
+            custom_entities: bool = False,
+            **kwargs) -> dict:
+        """Call the inferencer.
+
+        Args:
+            inputs (InputsType): Inputs for the inferencer.
+            batch_size (int): Inference batch size. Defaults to 1.
+            show (bool): Whether to display the visualization results in a
+                popup window. Defaults to False.
+            wait_time (float): The interval of show (s). Defaults to 0.
+            no_save_vis (bool): Whether to force not to save prediction
+                vis results. Defaults to False.
+            draw_pred (bool): Whether to draw predicted bounding boxes.
+                Defaults to True.
+            pred_score_thr (float): Minimum score of bboxes to draw.
+                Defaults to 0.3.
+            return_datasamples (bool): Whether to return results as
+                :obj:`DetDataSample`. Defaults to False.
+            print_result (bool): Whether to print the inference result w/o
+                visualization to the console. Defaults to False.
+            no_save_pred (bool): Whether to force not to save prediction
+                results. Defaults to True.
+            out_dir: Dir to save the inference results or
+                visualization. If left as empty, no file will be saved.
+                Defaults to ''.
+            texts (str | list[str]): Text prompts. Defaults to None.
+            stuff_texts (str | list[str]): Stuff text prompts of open
+                panoptic task. Defaults to None.
+            custom_entities (bool): Whether to use custom entities.
+                Defaults to False. Only used in GLIP.
+            **kwargs: Other keyword arguments passed to :meth:`preprocess`,
+                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
+                Each key in kwargs should be in the corresponding set of
+                ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
+                and ``postprocess_kwargs``.
+
+        Returns:
+            dict: Inference and visualization results.
+ """ + ( + preprocess_kwargs, + forward_kwargs, + visualize_kwargs, + postprocess_kwargs, + ) = self._dispatch_kwargs(**kwargs) + + ori_inputs = self._inputs_to_list(inputs) + + if texts is not None and isinstance(texts, str): + texts = [texts] * len(ori_inputs) + if stuff_texts is not None and isinstance(stuff_texts, str): + stuff_texts = [stuff_texts] * len(ori_inputs) + if texts is not None: + assert len(texts) == len(ori_inputs) + for i in range(len(texts)): + if isinstance(ori_inputs[i], str): + ori_inputs[i] = { + 'text': texts[i], + 'img_path': ori_inputs[i], + 'custom_entities': custom_entities + } + else: + ori_inputs[i] = { + 'text': texts[i], + 'img': ori_inputs[i], + 'custom_entities': custom_entities + } + if stuff_texts is not None: + assert len(stuff_texts) == len(ori_inputs) + for i in range(len(stuff_texts)): + ori_inputs[i]['stuff_text'] = stuff_texts[i] + + inputs = self.preprocess( + ori_inputs, batch_size=batch_size, **preprocess_kwargs) + + results_dict = {'predictions': [], 'visualization': []} + for ori_imgs, data in (track(inputs, description='Inference') + if self.show_progress else inputs): + preds = self.forward(data, **forward_kwargs) + visualization = self.visualize( + ori_imgs, + preds, + return_vis=return_vis, + show=show, + wait_time=wait_time, + draw_pred=draw_pred, + pred_score_thr=pred_score_thr, + no_save_vis=no_save_vis, + img_out_dir=out_dir, + **visualize_kwargs) + results = self.postprocess( + preds, + visualization, + return_datasamples=return_datasamples, + print_result=print_result, + no_save_pred=no_save_pred, + pred_out_dir=out_dir, + **postprocess_kwargs) + results_dict['predictions'].extend(results['predictions']) + if results['visualization'] is not None: + results_dict['visualization'].extend(results['visualization']) + return results_dict + + def visualize(self, + inputs: InputsType, + preds: PredType, + return_vis: bool = False, + show: bool = False, + wait_time: int = 0, + draw_pred: bool = True, + pred_score_thr: float = 0.3, + no_save_vis: bool = False, + img_out_dir: str = '', + **kwargs) -> Union[List[np.ndarray], None]: + """Visualize predictions. + + Args: + inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer. + preds (List[:obj:`DetDataSample`]): Predictions of the model. + return_vis (bool): Whether to return the visualization result. + Defaults to False. + show (bool): Whether to display the image in a popup window. + Defaults to False. + wait_time (float): The interval of show (s). Defaults to 0. + draw_pred (bool): Whether to draw predicted bounding boxes. + Defaults to True. + pred_score_thr (float): Minimum score of bboxes to draw. + Defaults to 0.3. + no_save_vis (bool): Whether to force not to save prediction + vis results. Defaults to False. + img_out_dir (str): Output directory of visualization results. + If left as empty, no file will be saved. Defaults to ''. + + Returns: + List[np.ndarray] or None: Returns visualization results only if + applicable. 
+        """
+        if no_save_vis is True:
+            img_out_dir = ''
+
+        if not show and img_out_dir == '' and not return_vis:
+            return None
+
+        if self.visualizer is None:
+            raise ValueError('Visualization needs the "visualizer" term '
+                             'defined in the config, but got None.')
+
+        results = []
+
+        for single_input, pred in zip(inputs, preds):
+            if isinstance(single_input, str):
+                img_bytes = mmengine.fileio.get(single_input)
+                img = mmcv.imfrombytes(img_bytes)
+                img = img[:, :, ::-1]
+                img_name = osp.basename(single_input)
+            elif isinstance(single_input, np.ndarray):
+                img = single_input.copy()
+                img_num = str(self.num_visualized_imgs).zfill(8)
+                img_name = f'{img_num}.jpg'
+            else:
+                raise ValueError('Unsupported input type: '
+                                 f'{type(single_input)}')
+
+            out_file = osp.join(img_out_dir, 'vis',
+                                img_name) if img_out_dir != '' else None
+
+            self.visualizer.add_datasample(
+                img_name,
+                img,
+                pred,
+                show=show,
+                wait_time=wait_time,
+                draw_gt=False,
+                draw_pred=draw_pred,
+                pred_score_thr=pred_score_thr,
+                out_file=out_file,
+            )
+            results.append(self.visualizer.get_image())
+            self.num_visualized_imgs += 1
+
+        return results
+
+    def postprocess(
+        self,
+        preds: PredType,
+        visualization: Optional[List[np.ndarray]] = None,
+        return_datasamples: bool = False,
+        print_result: bool = False,
+        no_save_pred: bool = False,
+        pred_out_dir: str = '',
+        **kwargs,
+    ) -> Dict:
+        """Process the predictions and visualization results from ``forward``
+        and ``visualize``.
+
+        This method should be responsible for the following tasks:
+
+        1. Convert datasamples into a json-serializable dict if needed.
+        2. Pack the predictions and visualization results and return them.
+        3. Dump or log the predictions.
+
+        Args:
+            preds (List[:obj:`DetDataSample`]): Predictions of the model.
+            visualization (Optional[np.ndarray]): Visualized predictions.
+            return_datasamples (bool): Whether to use Datasample to store
+                inference results. If False, dict will be used.
+            print_result (bool): Whether to print the inference result w/o
+                visualization to the console. Defaults to False.
+            no_save_pred (bool): Whether to force not to save prediction
+                results. Defaults to False.
+            pred_out_dir: Dir to save the inference results w/o
+                visualization. If left as empty, no file will be saved.
+                Defaults to ''.
+
+        Returns:
+            dict: Inference and visualization results with key ``predictions``
+            and ``visualization``.
+
+            - ``visualization`` (Any): Returned by :meth:`visualize`.
+            - ``predictions`` (dict or DataSample): Returned by
+              :meth:`forward` and processed in :meth:`postprocess`.
+              If ``return_datasamples=False``, it usually should be a
+              json-serializable dict containing only basic data elements such
+              as strings and numbers.
+        """
+        if no_save_pred is True:
+            pred_out_dir = ''
+
+        result_dict = {}
+        results = preds
+        if not return_datasamples:
+            results = []
+            for pred in preds:
+                result = self.pred2dict(pred, pred_out_dir)
+                results.append(result)
+        elif pred_out_dir != '':
+            warnings.warn('Saving datasamples is not supported yet when '
+                          'return_datasamples is set to True. '
+                          'Prediction results are not saved!')
+        # Add img to the results after printing and dumping
+        result_dict['predictions'] = results
+        if print_result:
+            print(result_dict)
+        result_dict['visualization'] = visualization
+        return result_dict
+
+    # TODO: The data format and fields saved in json need further discussion.
+    # Maybe it should include model name, timestamp, filename, image info, etc.
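Taken together, the methods above give `DetInferencer` a one-call workflow: `preprocess` batches the inputs, `forward` runs the model, and `visualize`/`postprocess` shape the returned dict. A minimal usage sketch, assuming a locally available config and checkpoint; every path and model name below is a placeholder, not a file shipped with this repo:

```python
from mmdet.apis import DetInferencer

# `model` may be a config path or a metafile model name (see the class docstring)
inferencer = DetInferencer(
    model='configs/rtmdet/rtmdet_s_8xb32-300e_coco.py',  # placeholder path
    weights='checkpoints/rtmdet_s.pth',                  # placeholder path
    device='cpu')

# returns a dict with 'predictions' and 'visualization' keys
results = inferencer(
    'demo.jpg',
    batch_size=1,
    pred_score_thr=0.3,
    out_dir='outputs/',   # writes vis/ and preds/ subfolders
    no_save_pred=False,   # also dump per-image JSON predictions
    return_vis=True)

pred = results['predictions'][0]  # a plain dict when return_datasamples=False
print(pred['labels'][:5], pred['scores'][:5])
```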
+ def pred2dict(self, + data_sample: DetDataSample, + pred_out_dir: str = '') -> Dict: + """Extract elements necessary to represent a prediction into a + dictionary. + + It's better to contain only basic data elements such as strings and + numbers in order to guarantee it's json-serializable. + + Args: + data_sample (:obj:`DetDataSample`): Predictions of the model. + pred_out_dir: Dir to save the inference results w/o + visualization. If left as empty, no file will be saved. + Defaults to ''. + + Returns: + dict: Prediction results. + """ + is_save_pred = True + if pred_out_dir == '': + is_save_pred = False + + if is_save_pred and 'img_path' in data_sample: + img_path = osp.basename(data_sample.img_path) + img_path = osp.splitext(img_path)[0] + out_img_path = osp.join(pred_out_dir, 'preds', + img_path + '_panoptic_seg.png') + out_json_path = osp.join(pred_out_dir, 'preds', img_path + '.json') + elif is_save_pred: + out_img_path = osp.join( + pred_out_dir, 'preds', + f'{self.num_predicted_imgs}_panoptic_seg.png') + out_json_path = osp.join(pred_out_dir, 'preds', + f'{self.num_predicted_imgs}.json') + self.num_predicted_imgs += 1 + + result = {} + if 'pred_instances' in data_sample: + masks = data_sample.pred_instances.get('masks') + pred_instances = data_sample.pred_instances.numpy() + result = { + 'labels': pred_instances.labels.tolist(), + 'scores': pred_instances.scores.tolist() + } + if 'bboxes' in pred_instances: + result['bboxes'] = pred_instances.bboxes.tolist() + if masks is not None: + if 'bboxes' not in pred_instances or pred_instances.bboxes.sum( + ) == 0: + # Fake bbox, such as the SOLO. + bboxes = mask2bbox(masks.cpu()).numpy().tolist() + result['bboxes'] = bboxes + encode_masks = encode_mask_results(pred_instances.masks) + for encode_mask in encode_masks: + if isinstance(encode_mask['counts'], bytes): + encode_mask['counts'] = encode_mask['counts'].decode() + result['masks'] = encode_masks + + if 'pred_panoptic_seg' in data_sample: + if VOID is None: + raise RuntimeError( + 'panopticapi is not installed, please install it by: ' + 'pip install git+https://github.com/cocodataset/' + 'panopticapi.git.') + + pan = data_sample.pred_panoptic_seg.sem_seg.cpu().numpy()[0] + pan[pan % INSTANCE_OFFSET == len( + self.model.dataset_meta['classes'])] = VOID + pan = id2rgb(pan).astype(np.uint8) + + if is_save_pred: + mmcv.imwrite(pan[:, :, ::-1], out_img_path) + result['panoptic_seg_path'] = out_img_path + else: + result['panoptic_seg'] = pan + + if is_save_pred: + mmengine.dump(result, out_json_path) + + return result diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..7e6f914ecabf4b9c110a4fd15310bc97d0197db9 --- /dev/null +++ b/mmdet/apis/inference.py @@ -0,0 +1,372 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
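One practical note on `pred2dict` above before moving into `mmdet/apis/inference.py`: instance masks are serialized as COCO RLE dicts with `counts` decoded from bytes to `str` so the result stays json-serializable. To recover binary masks, re-encode `counts` and decode with pycocotools; a hedged sketch, assuming pycocotools is installed and `pred` is one entry of the returned `predictions` list:

```python
from pycocotools import mask as mask_utils

rle = dict(pred['masks'][0])           # {'size': [H, W], 'counts': str}
# pred2dict decoded counts to str for JSON; decode() expects bytes
rle['counts'] = rle['counts'].encode()
binary_mask = mask_utils.decode(rle)   # (H, W) uint8 array of 0/1
print(binary_mask.shape, int(binary_mask.sum()))
```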
+import copy +import warnings +from pathlib import Path +from typing import Optional, Sequence, Union + +import numpy as np +import torch +import torch.nn as nn +from mmcv.ops import RoIPool +from mmcv.transforms import Compose +from mmengine.config import Config +from mmengine.dataset import default_collate +from mmengine.model.utils import revert_sync_batchnorm +from mmengine.registry import init_default_scope +from mmengine.runner import load_checkpoint + +from mmdet.registry import DATASETS +from mmdet.utils import ConfigType +from ..evaluation import get_classes +from ..registry import MODELS +from ..structures import DetDataSample, SampleList +from ..utils import get_test_pipeline_cfg + + +def init_detector( + config: Union[str, Path, Config], + checkpoint: Optional[str] = None, + palette: str = 'none', + device: str = 'cuda:0', + cfg_options: Optional[dict] = None, +) -> nn.Module: + """Initialize a detector from config file. + + Args: + config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path, + :obj:`Path`, or the config object. + checkpoint (str, optional): Checkpoint path. If left as None, the model + will not load any weights. + palette (str): Color palette used for visualization. If palette + is stored in checkpoint, use checkpoint's palette first, otherwise + use externally passed palette. Currently, supports 'coco', 'voc', + 'citys' and 'random'. Defaults to none. + device (str): The device where the anchors will be put on. + Defaults to cuda:0. + cfg_options (dict, optional): Options to override some settings in + the used config. + + Returns: + nn.Module: The constructed detector. + """ + if isinstance(config, (str, Path)): + config = Config.fromfile(config) + elif not isinstance(config, Config): + raise TypeError('config must be a filename or Config object, ' + f'but got {type(config)}') + if cfg_options is not None: + config.merge_from_dict(cfg_options) + elif 'init_cfg' in config.model.backbone: + config.model.backbone.init_cfg = None + + scope = config.get('default_scope', 'mmdet') + if scope is not None: + init_default_scope(config.get('default_scope', 'mmdet')) + + model = MODELS.build(config.model) + model = revert_sync_batchnorm(model) + if checkpoint is None: + warnings.simplefilter('once') + warnings.warn('checkpoint is None, use COCO classes by default.') + model.dataset_meta = {'classes': get_classes('coco')} + else: + checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') + # Weights converted from elsewhere may not have meta fields. + checkpoint_meta = checkpoint.get('meta', {}) + + # save the dataset_meta in the model for convenience + if 'dataset_meta' in checkpoint_meta: + # mmdet 3.x, all keys should be lowercase + model.dataset_meta = { + k.lower(): v + for k, v in checkpoint_meta['dataset_meta'].items() + } + elif 'CLASSES' in checkpoint_meta: + # < mmdet 3.x + classes = checkpoint_meta['CLASSES'] + model.dataset_meta = {'classes': classes} + else: + warnings.simplefilter('once') + warnings.warn( + 'dataset_meta or class names are not saved in the ' + 'checkpoint\'s meta data, use COCO classes by default.') + model.dataset_meta = {'classes': get_classes('coco')} + + # Priority: args.palette -> config -> checkpoint + if palette != 'none': + model.dataset_meta['palette'] = palette + else: + test_dataset_cfg = copy.deepcopy(config.test_dataloader.dataset) + # lazy init. We only need the metainfo. 
+ test_dataset_cfg['lazy_init'] = True + metainfo = DATASETS.build(test_dataset_cfg).metainfo + cfg_palette = metainfo.get('palette', None) + if cfg_palette is not None: + model.dataset_meta['palette'] = cfg_palette + else: + if 'palette' not in model.dataset_meta: + warnings.warn( + 'palette does not exist, random is used by default. ' + 'You can also set the palette to customize.') + model.dataset_meta['palette'] = 'random' + + model.cfg = config # save the config in the model for convenience + model.to(device) + model.eval() + return model + + +ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]] + + +def inference_detector( + model: nn.Module, + imgs: ImagesType, + test_pipeline: Optional[Compose] = None, + text_prompt: Optional[str] = None, + custom_entities: bool = False, +) -> Union[DetDataSample, SampleList]: + """Inference image(s) with the detector. + + Args: + model (nn.Module): The loaded detector. + imgs (str, ndarray, Sequence[str/ndarray]): + Either image files or loaded images. + test_pipeline (:obj:`Compose`): Test pipeline. + + Returns: + :obj:`DetDataSample` or list[:obj:`DetDataSample`]: + If imgs is a list or tuple, the same length list type results + will be returned, otherwise return the detection results directly. + """ + + if isinstance(imgs, (list, tuple)): + is_batch = True + else: + imgs = [imgs] + is_batch = False + + cfg = model.cfg + + if test_pipeline is None: + cfg = cfg.copy() + test_pipeline = get_test_pipeline_cfg(cfg) + if isinstance(imgs[0], np.ndarray): + # Calling this method across libraries will result + # in module unregistered error if not prefixed with mmdet. + test_pipeline[0].type = 'mmdet.LoadImageFromNDArray' + + test_pipeline = Compose(test_pipeline) + + if model.data_preprocessor.device.type == 'cpu': + for m in model.modules(): + assert not isinstance( + m, RoIPool + ), 'CPU inference with RoIPool is not supported currently.' + + result_list = [] + for i, img in enumerate(imgs): + # prepare data + if isinstance(img, np.ndarray): + # TODO: remove img_id. + data_ = dict(img=img, img_id=0) + else: + # TODO: remove img_id. + data_ = dict(img_path=img, img_id=0) + + if text_prompt: + data_['text'] = text_prompt + data_['custom_entities'] = custom_entities + + # build the data pipeline + data_ = test_pipeline(data_) + + data_['inputs'] = [data_['inputs']] + data_['data_samples'] = [data_['data_samples']] + + # forward the model + with torch.no_grad(): + results = model.test_step(data_)[0] + + result_list.append(results) + + if not is_batch: + return result_list[0] + else: + return result_list + + +# TODO: Awaiting refactoring +async def async_inference_detector(model, imgs): + """Async inference image(s) with the detector. + + Args: + model (nn.Module): The loaded detector. + img (str | ndarray): Either image files or loaded images. + + Returns: + Awaitable detection results. 
+    """
+    if not isinstance(imgs, (list, tuple)):
+        imgs = [imgs]
+
+    cfg = model.cfg
+
+    if isinstance(imgs[0], np.ndarray):
+        cfg = cfg.copy()
+        # set loading pipeline type
+        cfg.data.test.pipeline[0].type = 'LoadImageFromNDArray'
+
+    # cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
+    test_pipeline = Compose(cfg.data.test.pipeline)
+
+    datas = []
+    for img in imgs:
+        # prepare data
+        if isinstance(img, np.ndarray):
+            # directly add img
+            data = dict(img=img)
+        else:
+            # add information into dict
+            data = dict(img_info=dict(filename=img), img_prefix=None)
+        # build the data pipeline
+        data = test_pipeline(data)
+        datas.append(data)
+
+    for m in model.modules():
+        assert not isinstance(
+            m,
+            RoIPool), 'CPU inference with RoIPool is not supported currently.'
+
+    # We don't restore `torch.is_grad_enabled()` value during concurrent
+    # inference since execution can overlap
+    torch.set_grad_enabled(False)
+    # pass the full batch built above, not just the last sample
+    results = await model.aforward_test(datas, rescale=True)
+    return results
+
+
+def build_test_pipeline(cfg: ConfigType) -> ConfigType:
+    """Build the test_pipeline for the MOT/VIS demo. For MOT/VIS inference,
+    the original test_pipeline should drop "LoadImageFromFile" and
+    "LoadTrackAnnotations".
+
+    Args:
+        cfg (ConfigDict): The loaded config.
+    Returns:
+        ConfigType: The new test_pipeline.
+    """
+    # remove the "LoadImageFromFile" and "LoadTrackAnnotations" in pipeline
+    transform_broadcaster = cfg.test_dataloader.dataset.pipeline[0].copy()
+    for transform in transform_broadcaster['transforms']:
+        if transform['type'] == 'Resize':
+            transform_broadcaster['transforms'] = transform
+    pack_track_inputs = cfg.test_dataloader.dataset.pipeline[-1].copy()
+    test_pipeline = Compose([transform_broadcaster, pack_track_inputs])
+
+    return test_pipeline
+
+
+def inference_mot(model: nn.Module, img: np.ndarray, frame_id: int,
+                  video_len: int) -> SampleList:
+    """Run inference on an image with the MOT model.
+
+    Args:
+        model (nn.Module): The loaded MOT model.
+        img (np.ndarray): Loaded image.
+        frame_id (int): Frame id of the current image.
+        video_len (int): Length of the demo video.
+    Returns:
+        SampleList: The tracking data samples.
+    """
+    cfg = model.cfg
+    data = dict(
+        img=[img.astype(np.float32)],
+        frame_id=[frame_id],
+        ori_shape=[img.shape[:2]],
+        img_id=[frame_id + 1],
+        ori_video_length=[video_len])
+
+    test_pipeline = build_test_pipeline(cfg)
+    data = test_pipeline(data)
+
+    if not next(model.parameters()).is_cuda:
+        for m in model.modules():
+            assert not isinstance(
+                m, RoIPool
+            ), 'CPU inference with RoIPool is not supported currently.'
+
+    # forward the model
+    with torch.no_grad():
+        data = default_collate([data])
+        result = model.test_step(data)[0]
+    return result
+
+
+def init_track_model(config: Union[str, Config],
+                     checkpoint: Optional[str] = None,
+                     detector: Optional[str] = None,
+                     reid: Optional[str] = None,
+                     device: str = 'cuda:0',
+                     cfg_options: Optional[dict] = None) -> nn.Module:
+    """Initialize a model from a config file.
+
+    Args:
+        config (str or :obj:`mmengine.Config`): Config file path or the config
+            object.
+        checkpoint (Optional[str], optional): Checkpoint path. Defaults to
+            None.
+        detector (Optional[str], optional): Detector checkpoint path, used in
+            some tracking algorithms like SORT. Defaults to None.
+        reid (Optional[str], optional): ReID checkpoint path, used in
+            some tracking algorithms like SORT. Defaults to None.
+        device (str, optional): The device that the model inferences on.
+            Defaults to `cuda:0`.
+        cfg_options (Optional[dict], optional): Options to override some
+            settings in the used config. Defaults to None.
+
+    Returns:
+        nn.Module: The constructed model.
+    """
+    if isinstance(config, str):
+        config = Config.fromfile(config)
+    elif not isinstance(config, Config):
+        raise TypeError('config must be a filename or Config object, '
+                        f'but got {type(config)}')
+    if cfg_options is not None:
+        config.merge_from_dict(cfg_options)
+
+    model = MODELS.build(config.model)
+
+    if checkpoint is not None:
+        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
+        # Weights converted from elsewhere may not have meta fields.
+        checkpoint_meta = checkpoint.get('meta', {})
+        # save the dataset_meta in the model for convenience
+        if 'dataset_meta' in checkpoint_meta:
+            if 'CLASSES' in checkpoint_meta['dataset_meta']:
+                value = checkpoint_meta['dataset_meta'].pop('CLASSES')
+                checkpoint_meta['dataset_meta']['classes'] = value
+            model.dataset_meta = checkpoint_meta['dataset_meta']
+
+    if detector is not None:
+        assert not (checkpoint and detector), \
+            'Error: checkpoint and detector checkpoint cannot both exist'
+        load_checkpoint(model.detector, detector, map_location='cpu')
+
+    if reid is not None:
+        assert not (checkpoint and reid), \
+            'Error: checkpoint and reid checkpoint cannot both exist'
+        load_checkpoint(model.reid, reid, map_location='cpu')
+
+    # Some methods don't load checkpoints, or the checkpoints don't contain
+    # 'dataset_meta'. VIS needs dataset_meta, but MOT does not.
+    if not hasattr(model, 'dataset_meta'):
+        warnings.warn('dataset_meta or class names are missing; '
+                      'using None by default.')
+        model.dataset_meta = {'classes': None}
+
+    model.cfg = config  # save the config in the model for convenience
+    model.to(device)
+    model.eval()
+    return model
diff --git a/mmdet/configs/.DS_Store b/mmdet/configs/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..296c6ca071479e949bc813bd714d4519027fd1e7
Binary files /dev/null and b/mmdet/configs/.DS_Store differ
diff --git a/mmdet/configs/_base_/datasets/coco_detection.py b/mmdet/configs/_base_/datasets/coco_detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..45041f6d236be95eb7592035d31f155c61bfcb25
--- /dev/null
+++ b/mmdet/configs/_base_/datasets/coco_detection.py
@@ -0,0 +1,104 @@
+# Copyright (c) OpenMMLab. All rights reserved.
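With `mmdet/apis/inference.py` complete, its two most common entry points can be exercised in a few lines. A minimal sketch, assuming a config and checkpoint exist at the placeholder paths (the dataset config that follows is unrelated to this snippet):

```python
from mmdet.apis import init_detector, inference_detector

model = init_detector(
    'configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py',  # placeholder
    'checkpoints/faster_rcnn_r50_fpn_1x_coco.pth',         # placeholder
    device='cpu')

result = inference_detector(model, 'demo.jpg')  # a single DetDataSample
print(result.pred_instances.bboxes.shape)       # (num_dets, 4)
print(result.pred_instances.scores[:5])

# a list input returns a list of DetDataSamples of the same length
results = inference_detector(model, ['demo.jpg', 'demo2.jpg'])
```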
+from mmcv.transforms import LoadImageFromFile +from mmengine.dataset.sampler import DefaultSampler + +from mmdet.datasets import AspectRatioBatchSampler, CocoDataset +from mmdet.datasets.transforms import (LoadAnnotations, PackDetInputs, + RandomFlip, Resize) +from mmdet.evaluation import CocoMetric + +# dataset settings +dataset_type = CocoDataset +data_root = 'data/coco/' + +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +train_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadAnnotations, with_bbox=True), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + dict(type=RandomFlip, prob=0.5), + dict(type=PackDetInputs) +] +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + # If you don't have a gt annotation, delete the pipeline + dict(type=LoadAnnotations, with_bbox=True), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=AspectRatioBatchSampler), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args)) +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type=CocoMetric, + ann_file=data_root + 'annotations/instances_val2017.json', + metric='bbox', + format_only=False, + backend_args=backend_args) +test_evaluator = val_evaluator + +# inference on test dataset and +# format the output results for submission. +# test_dataloader = dict( +# batch_size=1, +# num_workers=2, +# persistent_workers=True, +# drop_last=False, +# sampler=dict(type=DefaultSampler, shuffle=False), +# dataset=dict( +# type=dataset_type, +# data_root=data_root, +# ann_file=data_root + 'annotations/image_info_test-dev2017.json', +# data_prefix=dict(img='test2017/'), +# test_mode=True, +# pipeline=test_pipeline)) +# test_evaluator = dict( +# type=CocoMetric, +# metric='bbox', +# format_only=True, +# ann_file=data_root + 'annotations/image_info_test-dev2017.json', +# outfile_prefix='./work_dirs/coco_detection/test') diff --git a/mmdet/configs/_base_/datasets/coco_instance.py b/mmdet/configs/_base_/datasets/coco_instance.py new file mode 100644 index 0000000000000000000000000000000000000000..b9575432e26b7e861c4dfcf535773b7a1990eeab --- /dev/null +++ b/mmdet/configs/_base_/datasets/coco_instance.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. 
All rights reserved. +from mmcv.transforms.loading import LoadImageFromFile +from mmengine.dataset.sampler import DefaultSampler + +from mmdet.datasets.coco import CocoDataset +from mmdet.datasets.samplers.batch_sampler import AspectRatioBatchSampler +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import LoadAnnotations +from mmdet.datasets.transforms.transforms import RandomFlip, Resize +from mmdet.evaluation.metrics.coco_metric import CocoMetric + +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' + +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +train_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadAnnotations, with_bbox=True, with_mask=True), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + dict(type=RandomFlip, prob=0.5), + dict(type=PackDetInputs) +] +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + # If you don't have a gt annotation, delete the pipeline + dict(type=LoadAnnotations, with_bbox=True, with_mask=True), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=AspectRatioBatchSampler), + dataset=dict( + type=CocoDataset, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args)) +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=CocoDataset, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type=CocoMetric, + ann_file=data_root + 'annotations/instances_val2017.json', + metric=['bbox', 'segm'], + format_only=False, + backend_args=backend_args) +test_evaluator = val_evaluator + +# inference on test dataset and +# format the output results for submission. 
+# test_dataloader = dict( +# batch_size=1, +# num_workers=2, +# persistent_workers=True, +# drop_last=False, +# sampler=dict(type=DefaultSampler, shuffle=False), +# dataset=dict( +# type=CocoDataset, +# data_root=data_root, +# ann_file=data_root + 'annotations/image_info_test-dev2017.json', +# data_prefix=dict(img='test2017/'), +# test_mode=True, +# pipeline=test_pipeline)) +# test_evaluator = dict( +# type=CocoMetric, +# metric=['bbox', 'segm'], +# format_only=True, +# ann_file=data_root + 'annotations/image_info_test-dev2017.json', +# outfile_prefix='./work_dirs/coco_instance/test') diff --git a/mmdet/configs/_base_/datasets/coco_instance_semantic.py b/mmdet/configs/_base_/datasets/coco_instance_semantic.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf5b2cfab8a98a6c97e23a8df663e8f1e90b355 --- /dev/null +++ b/mmdet/configs/_base_/datasets/coco_instance_semantic.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.transforms.loading import LoadImageFromFile +from mmengine.dataset.sampler import DefaultSampler + +from mmdet.datasets.coco import CocoDataset +from mmdet.datasets.samplers.batch_sampler import AspectRatioBatchSampler +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import LoadAnnotations +from mmdet.datasets.transforms.transforms import RandomFlip, Resize +from mmdet.evaluation.metrics.coco_metric import CocoMetric + +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' + +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +train_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadAnnotations, with_bbox=True, with_mask=True, with_seg=True), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + dict(type=RandomFlip, prob=0.5), + dict(type=PackDetInputs) +] +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + # If you don't have a gt annotation, delete the pipeline + dict(type=LoadAnnotations, with_bbox=True, with_mask=True, with_seg=True), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] + +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=AspectRatioBatchSampler), + dataset=dict( + type=CocoDataset, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/', seg='stuffthingmaps/train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args)) + +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=CocoDataset, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline, + 
backend_args=backend_args)) + +test_dataloader = val_dataloader + +val_evaluator = dict( + type=CocoMetric, + ann_file=data_root + 'annotations/instances_val2017.json', + metric=['bbox', 'segm'], + format_only=False, + backend_args=backend_args) +test_evaluator = val_evaluator diff --git a/mmdet/configs/_base_/datasets/coco_panoptic.py b/mmdet/configs/_base_/datasets/coco_panoptic.py new file mode 100644 index 0000000000000000000000000000000000000000..29d655ff619c74c5976d5f06c0c623a0d3459997 --- /dev/null +++ b/mmdet/configs/_base_/datasets/coco_panoptic.py @@ -0,0 +1,105 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.transforms.loading import LoadImageFromFile +from mmengine.dataset.sampler import DefaultSampler + +from mmdet.datasets.coco_panoptic import CocoPanopticDataset +from mmdet.datasets.samplers.batch_sampler import AspectRatioBatchSampler +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import LoadPanopticAnnotations +from mmdet.datasets.transforms.transforms import RandomFlip, Resize +from mmdet.evaluation.metrics.coco_panoptic_metric import CocoPanopticMetric + +# dataset settings +dataset_type = 'CocoPanopticDataset' +data_root = 'data/coco/' + +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +train_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadPanopticAnnotations, backend_args=backend_args), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + dict(type=RandomFlip, prob=0.5), + dict(type=PackDetInputs) +] +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + dict(type=LoadPanopticAnnotations, backend_args=backend_args), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] + +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=AspectRatioBatchSampler), + dataset=dict( + type=CocoPanopticDataset, + data_root=data_root, + ann_file='annotations/panoptic_train2017.json', + data_prefix=dict( + img='train2017/', seg='annotations/panoptic_train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args)) +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=CocoPanopticDataset, + data_root=data_root, + ann_file='annotations/panoptic_val2017.json', + data_prefix=dict(img='val2017/', seg='annotations/panoptic_val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type=CocoPanopticMetric, + ann_file=data_root + 'annotations/panoptic_val2017.json', + seg_prefix=data_root + 'annotations/panoptic_val2017/', + backend_args=backend_args) +test_evaluator = val_evaluator + +# inference on test 
dataset and +# format the output results for submission. +# test_dataloader = dict( +# batch_size=1, +# num_workers=1, +# persistent_workers=True, +# drop_last=False, +# sampler=dict(type=DefaultSampler, shuffle=False), +# dataset=dict( +# type=CocoPanopticDataset, +# data_root=data_root, +# ann_file='annotations/panoptic_image_info_test-dev2017.json', +# data_prefix=dict(img='test2017/'), +# test_mode=True, +# pipeline=test_pipeline)) +# test_evaluator = dict( +# type=CocoPanopticMetric, +# format_only=True, +# ann_file=data_root + 'annotations/panoptic_image_info_test-dev2017.json', +# outfile_prefix='./work_dirs/coco_panoptic/test') diff --git a/mmdet/configs/_base_/datasets/mot_challenge.py b/mmdet/configs/_base_/datasets/mot_challenge.py new file mode 100644 index 0000000000000000000000000000000000000000..a71520a84e52a812f83862920040d96746829285 --- /dev/null +++ b/mmdet/configs/_base_/datasets/mot_challenge.py @@ -0,0 +1,101 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.transforms import (LoadImageFromFile, RandomResize, + TransformBroadcaster) + +from mmdet.datasets import MOTChallengeDataset +from mmdet.datasets.samplers import TrackImgSampler +from mmdet.datasets.transforms import (LoadTrackAnnotations, PackTrackInputs, + PhotoMetricDistortion, RandomCrop, + RandomFlip, Resize, + UniformRefFrameSample) +from mmdet.evaluation import MOTChallengeMetric + +# dataset settings +dataset_type = MOTChallengeDataset +data_root = 'data/MOT17/' +img_scale = (1088, 1088) + +backend_args = None +# data pipeline +train_pipeline = [ + dict( + type=UniformRefFrameSample, + num_ref_imgs=1, + frame_range=10, + filter_key_img=True), + dict( + type=TransformBroadcaster, + share_random_params=True, + transforms=[ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadTrackAnnotations), + dict( + type=RandomResize, + scale=img_scale, + ratio_range=(0.8, 1.2), + keep_ratio=True, + clip_object_border=False), + dict(type=PhotoMetricDistortion) + ]), + dict( + type=TransformBroadcaster, + # different cropped positions for different frames + share_random_params=False, + transforms=[ + dict(type=RandomCrop, crop_size=img_scale, bbox_clip_border=False) + ]), + dict( + type=TransformBroadcaster, + share_random_params=True, + transforms=[ + dict(type=RandomFlip, prob=0.5), + ]), + dict(type=PackTrackInputs) +] + +test_pipeline = [ + dict( + type=TransformBroadcaster, + transforms=[ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=img_scale, keep_ratio=True), + dict(type=LoadTrackAnnotations) + ]), + dict(type=PackTrackInputs) +] + +# dataloader +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=TrackImgSampler), # image-based sampling + dataset=dict( + type=dataset_type, + data_root=data_root, + visibility_thr=-1, + ann_file='annotations/half-train_cocoformat.json', + data_prefix=dict(img_path='train'), + metainfo=dict(classes=('pedestrian', )), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + # Now we support two ways to test, image_based and video_based + # if you want to use video_based sampling, you can use as follows + # sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + sampler=dict(type=TrackImgSampler), # image-based sampling + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/half-val_cocoformat.json', + data_prefix=dict(img_path='train'), + test_mode=True, + 
pipeline=test_pipeline)) +test_dataloader = val_dataloader + +# evaluator +val_evaluator = dict( + type=MOTChallengeMetric, metric=['HOTA', 'CLEAR', 'Identity']) +test_evaluator = val_evaluator diff --git a/mmdet/configs/_base_/default_runtime.py b/mmdet/configs/_base_/default_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..ff96dbf29f3c90266a268d3831878b0a437d98b2 --- /dev/null +++ b/mmdet/configs/_base_/default_runtime.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.runner import LogProcessor +from mmengine.visualization import LocalVisBackend + +from mmdet.engine.hooks import DetVisualizationHook +from mmdet.visualization import DetLocalVisualizer + +default_scope = None + +default_hooks = dict( + timer=dict(type=IterTimerHook), + logger=dict(type=LoggerHook, interval=50), + param_scheduler=dict(type=ParamSchedulerHook), + checkpoint=dict(type=CheckpointHook, interval=1), + sampler_seed=dict(type=DistSamplerSeedHook), + visualization=dict(type=DetVisualizationHook)) + +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) + +vis_backends = [dict(type=LocalVisBackend)] +visualizer = dict( + type=DetLocalVisualizer, vis_backends=vis_backends, name='visualizer') +log_processor = dict(type=LogProcessor, window_size=50, by_epoch=True) + +log_level = 'INFO' +load_from = None +resume = False diff --git a/mmdet/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py b/mmdet/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..b9132ac40330c67e03ebc608f9527c678c72210e --- /dev/null +++ b/mmdet/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py @@ -0,0 +1,220 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
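The `_base_` dataset and runtime files above use the new import-based config style: `type=` takes the class or function object itself rather than a registry string. They can still be loaded and overridden programmatically; a hedged sketch, assuming an mmengine version with lazy-import config support (the batch-size override is just an illustration):

```python
from mmengine.config import Config

cfg = Config.fromfile('mmdet/configs/_base_/datasets/coco_detection.py')
# override a couple of dataloader settings, mirroring what cfg_options does
cfg.merge_from_dict(dict(train_dataloader=dict(batch_size=4, num_workers=4)))
print(cfg.train_dataloader.batch_size)  # 4
```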
+from mmcv.ops import RoIAlign, nms +from torch.nn import BatchNorm2d + +from mmdet.models.backbones.resnet import ResNet +from mmdet.models.data_preprocessors.data_preprocessor import \ + DetDataPreprocessor +from mmdet.models.dense_heads.rpn_head import RPNHead +from mmdet.models.detectors.cascade_rcnn import CascadeRCNN +from mmdet.models.losses.cross_entropy_loss import CrossEntropyLoss +from mmdet.models.losses.smooth_l1_loss import SmoothL1Loss +from mmdet.models.necks.fpn import FPN +from mmdet.models.roi_heads.bbox_heads.convfc_bbox_head import \ + Shared2FCBBoxHead +from mmdet.models.roi_heads.cascade_roi_head import CascadeRoIHead +from mmdet.models.roi_heads.mask_heads.fcn_mask_head import FCNMaskHead +from mmdet.models.roi_heads.roi_extractors.single_level_roi_extractor import \ + SingleRoIExtractor +from mmdet.models.task_modules.assigners.max_iou_assigner import MaxIoUAssigner +from mmdet.models.task_modules.coders.delta_xywh_bbox_coder import \ + DeltaXYWHBBoxCoder +from mmdet.models.task_modules.prior_generators.anchor_generator import \ + AnchorGenerator +from mmdet.models.task_modules.samplers.random_sampler import RandomSampler + +# model settings +model = dict( + type=CascadeRCNN, + data_preprocessor=dict( + type=DetDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_mask=True, + pad_size_divisor=32), + backbone=dict( + type=ResNet, + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type=BatchNorm2d, requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type=FPN, + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type=RPNHead, + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type=AnchorGenerator, + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type=DeltaXYWHBBoxCoder, + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type=CrossEntropyLoss, use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type=SmoothL1Loss, beta=1.0 / 9.0, loss_weight=1.0)), + roi_head=dict( + type=CascadeRoIHead, + num_stages=3, + stage_loss_weights=[1, 0.5, 0.25], + bbox_roi_extractor=dict( + type=SingleRoIExtractor, + roi_layer=dict(type=RoIAlign, output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=[ + dict( + type=Shared2FCBBoxHead, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type=DeltaXYWHBBoxCoder, + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type=CrossEntropyLoss, use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type=SmoothL1Loss, beta=1.0, loss_weight=1.0)), + dict( + type=Shared2FCBBoxHead, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type=DeltaXYWHBBoxCoder, + target_means=[0., 0., 0., 0.], + target_stds=[0.05, 0.05, 0.1, 0.1]), + reg_class_agnostic=True, + loss_cls=dict( + type=CrossEntropyLoss, use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type=SmoothL1Loss, beta=1.0, loss_weight=1.0)), + dict( + type=Shared2FCBBoxHead, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type=DeltaXYWHBBoxCoder, + target_means=[0., 0., 0., 0.], + target_stds=[0.033, 0.033, 0.067, 0.067]), + 
reg_class_agnostic=True, + loss_cls=dict( + type=CrossEntropyLoss, use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type=SmoothL1Loss, beta=1.0, loss_weight=1.0)) + ], + mask_roi_extractor=dict( + type=SingleRoIExtractor, + roi_layer=dict(type=RoIAlign, output_size=14, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + mask_head=dict( + type=FCNMaskHead, + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=80, + loss_mask=dict( + type=CrossEntropyLoss, use_mask=True, loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type=MaxIoUAssigner, + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type=RandomSampler, + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type=nms, iou_threshold=0.7), + min_bbox_size=0), + rcnn=[ + dict( + assigner=dict( + type=MaxIoUAssigner, + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type=RandomSampler, + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False), + dict( + assigner=dict( + type=MaxIoUAssigner, + pos_iou_thr=0.6, + neg_iou_thr=0.6, + min_pos_iou=0.6, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type=RandomSampler, + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False), + dict( + assigner=dict( + type=MaxIoUAssigner, + pos_iou_thr=0.7, + neg_iou_thr=0.7, + min_pos_iou=0.7, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type=RandomSampler, + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False) + ]), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type=nms, iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type=nms, iou_threshold=0.5), + max_per_img=100, + mask_thr_binary=0.5))) diff --git a/mmdet/configs/_base_/models/cascade_rcnn_r50_fpn.py b/mmdet/configs/_base_/models/cascade_rcnn_r50_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..8e6654f381f4993a57b81e6ed1f86c0558b56616 --- /dev/null +++ b/mmdet/configs/_base_/models/cascade_rcnn_r50_fpn.py @@ -0,0 +1,201 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
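A design note on the cascade head just defined: each of the three stages trains with a stricter positive-IoU threshold (0.5, 0.6, 0.7) and tighter `target_stds` in its box coder (0.1, 0.05, 0.033 for x/y), so later stages specialize in refining boxes the earlier stages already localized well. A short sketch that reads this progression back out of the config, again assuming lazy-import config support in mmengine:

```python
from mmengine.config import Config

cfg = Config.fromfile(
    'mmdet/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py')
heads = cfg.model.roi_head.bbox_head   # one bbox head per cascade stage
stages = cfg.model.train_cfg.rcnn      # one train_cfg entry per stage
for i, (head, stage) in enumerate(zip(heads, stages)):
    print(i, stage.assigner.pos_iou_thr, head.bbox_coder.target_stds)
# expected: 0 0.5 [0.1, ...]; 1 0.6 [0.05, ...]; 2 0.7 [0.033, ...]
```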
+from mmcv.ops import RoIAlign, nms +from torch.nn import BatchNorm2d + +from mmdet.models.backbones.resnet import ResNet +from mmdet.models.data_preprocessors.data_preprocessor import \ + DetDataPreprocessor +from mmdet.models.dense_heads.rpn_head import RPNHead +from mmdet.models.detectors.cascade_rcnn import CascadeRCNN +from mmdet.models.losses.cross_entropy_loss import CrossEntropyLoss +from mmdet.models.losses.smooth_l1_loss import SmoothL1Loss +from mmdet.models.necks.fpn import FPN +from mmdet.models.roi_heads.bbox_heads.convfc_bbox_head import \ + Shared2FCBBoxHead +from mmdet.models.roi_heads.cascade_roi_head import CascadeRoIHead +from mmdet.models.roi_heads.roi_extractors.single_level_roi_extractor import \ + SingleRoIExtractor +from mmdet.models.task_modules.assigners.max_iou_assigner import MaxIoUAssigner +from mmdet.models.task_modules.coders.delta_xywh_bbox_coder import \ + DeltaXYWHBBoxCoder +from mmdet.models.task_modules.prior_generators.anchor_generator import \ + AnchorGenerator +from mmdet.models.task_modules.samplers.random_sampler import RandomSampler + +# model settings +model = dict( + type=CascadeRCNN, + data_preprocessor=dict( + type=DetDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type=ResNet, + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type=BatchNorm2d, requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type=FPN, + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type=RPNHead, + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type=AnchorGenerator, + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type=DeltaXYWHBBoxCoder, + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type=CrossEntropyLoss, use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type=SmoothL1Loss, beta=1.0 / 9.0, loss_weight=1.0)), + roi_head=dict( + type=CascadeRoIHead, + num_stages=3, + stage_loss_weights=[1, 0.5, 0.25], + bbox_roi_extractor=dict( + type=SingleRoIExtractor, + roi_layer=dict(type=RoIAlign, output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=[ + dict( + type=Shared2FCBBoxHead, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type=DeltaXYWHBBoxCoder, + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type=CrossEntropyLoss, use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type=SmoothL1Loss, beta=1.0, loss_weight=1.0)), + dict( + type=Shared2FCBBoxHead, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type=DeltaXYWHBBoxCoder, + target_means=[0., 0., 0., 0.], + target_stds=[0.05, 0.05, 0.1, 0.1]), + reg_class_agnostic=True, + loss_cls=dict( + type=CrossEntropyLoss, use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type=SmoothL1Loss, beta=1.0, loss_weight=1.0)), + dict( + type=Shared2FCBBoxHead, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type=DeltaXYWHBBoxCoder, + target_means=[0., 0., 0., 0.], + target_stds=[0.033, 0.033, 0.067, 0.067]), + reg_class_agnostic=True, + loss_cls=dict( + type=CrossEntropyLoss, use_sigmoid=False, 
loss_weight=1.0), + loss_bbox=dict(type=SmoothL1Loss, beta=1.0, loss_weight=1.0)) + ]), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type=MaxIoUAssigner, + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type=RandomSampler, + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type=nms, iou_threshold=0.7), + min_bbox_size=0), + rcnn=[ + dict( + assigner=dict( + type=MaxIoUAssigner, + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type=RandomSampler, + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False), + dict( + assigner=dict( + type=MaxIoUAssigner, + pos_iou_thr=0.6, + neg_iou_thr=0.6, + min_pos_iou=0.6, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type=RandomSampler, + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False), + dict( + assigner=dict( + type=MaxIoUAssigner, + pos_iou_thr=0.7, + neg_iou_thr=0.7, + min_pos_iou=0.7, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type=RandomSampler, + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False) + ]), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type=nms, iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type=nms, iou_threshold=0.5), + max_per_img=100))) diff --git a/mmdet/configs/_base_/models/faster_rcnn_r50_fpn.py b/mmdet/configs/_base_/models/faster_rcnn_r50_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..7e18de2224d5b4d2cd16a930daf3a9b360455b36 --- /dev/null +++ b/mmdet/configs/_base_/models/faster_rcnn_r50_fpn.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
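One stylistic point worth flagging before the next file: these new-style configs pass imported classes as `type=` values instead of registry strings. Both forms resolve to the same registered component; the class form just fails at import time rather than at build time when a name is misspelled. A small sketch of the equivalence, assuming only that `mmdet.registry.MODELS` is the usual registry entry point:

```python
from mmdet.models.losses.cross_entropy_loss import CrossEntropyLoss
from mmdet.registry import MODELS

# Classic string style: the name is looked up in the registry at build time.
loss_a = MODELS.build(dict(type='CrossEntropyLoss', use_sigmoid=True))

# Pure-Python style used throughout this diff: the class object is the type.
loss_b = MODELS.build(dict(type=CrossEntropyLoss, use_sigmoid=True))

assert type(loss_a) is type(loss_b)  # same module class either way
```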
+from mmcv.ops import RoIAlign, nms +from torch.nn import BatchNorm2d + +from mmdet.models.backbones.resnet import ResNet +from mmdet.models.data_preprocessors.data_preprocessor import \ + DetDataPreprocessor +from mmdet.models.dense_heads.rpn_head import RPNHead +from mmdet.models.detectors.faster_rcnn import FasterRCNN +from mmdet.models.losses.cross_entropy_loss import CrossEntropyLoss +from mmdet.models.losses.smooth_l1_loss import L1Loss +from mmdet.models.necks.fpn import FPN +from mmdet.models.roi_heads.bbox_heads.convfc_bbox_head import \ + Shared2FCBBoxHead +from mmdet.models.roi_heads.roi_extractors.single_level_roi_extractor import \ + SingleRoIExtractor +from mmdet.models.roi_heads.standard_roi_head import StandardRoIHead +from mmdet.models.task_modules.assigners.max_iou_assigner import MaxIoUAssigner +from mmdet.models.task_modules.coders.delta_xywh_bbox_coder import \ + DeltaXYWHBBoxCoder +from mmdet.models.task_modules.prior_generators.anchor_generator import \ + AnchorGenerator +from mmdet.models.task_modules.samplers.random_sampler import RandomSampler + +# model settings +model = dict( + type=FasterRCNN, + data_preprocessor=dict( + type=DetDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type=ResNet, + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type=BatchNorm2d, requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type=FPN, + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type=RPNHead, + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type=AnchorGenerator, + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type=DeltaXYWHBBoxCoder, + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type=CrossEntropyLoss, use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type=L1Loss, loss_weight=1.0)), + roi_head=dict( + type=StandardRoIHead, + bbox_roi_extractor=dict( + type=SingleRoIExtractor, + roi_layer=dict(type=RoIAlign, output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type=Shared2FCBBoxHead, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type=DeltaXYWHBBoxCoder, + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=False, + loss_cls=dict( + type=CrossEntropyLoss, use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type=L1Loss, loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type=MaxIoUAssigner, + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type=RandomSampler, + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type=nms, iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type=MaxIoUAssigner, + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type=RandomSampler, + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + 
rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type=nms, iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type=nms, iou_threshold=0.5), + max_per_img=100) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + )) diff --git a/mmdet/configs/_base_/models/mask_rcnn_r50_caffe_c4.py b/mmdet/configs/_base_/models/mask_rcnn_r50_caffe_c4.py new file mode 100644 index 0000000000000000000000000000000000000000..3054818375f708826ee41901650a11bbbe3afca9 --- /dev/null +++ b/mmdet/configs/_base_/models/mask_rcnn_r50_caffe_c4.py @@ -0,0 +1,158 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.ops import RoIAlign, nms +from mmengine.model.weight_init import PretrainedInit +from torch.nn import BatchNorm2d + +from mmdet.models.backbones.resnet import ResNet +from mmdet.models.data_preprocessors.data_preprocessor import \ + DetDataPreprocessor +from mmdet.models.dense_heads.rpn_head import RPNHead +from mmdet.models.detectors.mask_rcnn import MaskRCNN +from mmdet.models.layers import ResLayer +from mmdet.models.losses.cross_entropy_loss import CrossEntropyLoss +from mmdet.models.losses.smooth_l1_loss import L1Loss +from mmdet.models.roi_heads.bbox_heads.bbox_head import BBoxHead +from mmdet.models.roi_heads.mask_heads.fcn_mask_head import FCNMaskHead +from mmdet.models.roi_heads.roi_extractors.single_level_roi_extractor import \ + SingleRoIExtractor +from mmdet.models.roi_heads.standard_roi_head import StandardRoIHead +from mmdet.models.task_modules.assigners.max_iou_assigner import MaxIoUAssigner +from mmdet.models.task_modules.coders.delta_xywh_bbox_coder import \ + DeltaXYWHBBoxCoder +from mmdet.models.task_modules.prior_generators.anchor_generator import \ + AnchorGenerator +from mmdet.models.task_modules.samplers.random_sampler import RandomSampler + +# model settings +norm_cfg = dict(type=BatchNorm2d, requires_grad=False) +# model settings +model = dict( + type=MaskRCNN, + data_preprocessor=dict( + type=DetDataPreprocessor, + mean=[103.530, 116.280, 123.675], + std=[1.0, 1.0, 1.0], + bgr_to_rgb=False, + pad_mask=True, + pad_size_divisor=32), + backbone=dict( + type=ResNet, + depth=50, + num_stages=3, + strides=(1, 2, 2), + dilations=(1, 1, 1), + out_indices=(2, ), + frozen_stages=1, + norm_cfg=dict(type=BatchNorm2d, requires_grad=True), + norm_eval=True, + style='caffe', + init_cfg=dict( + type=PretrainedInit, + checkpoint='open-mmlab://detectron2/resnet50_caffe')), + rpn_head=dict( + type=RPNHead, + in_channels=1024, + feat_channels=1024, + anchor_generator=dict( + type=AnchorGenerator, + scales=[2, 4, 8, 16, 32], + ratios=[0.5, 1.0, 2.0], + strides=[16]), + bbox_coder=dict( + type=DeltaXYWHBBoxCoder, + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type=CrossEntropyLoss, use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type=L1Loss, loss_weight=1.0)), + roi_head=dict( + type=StandardRoIHead, + shared_head=dict( + type=ResLayer, + depth=50, + stage=3, + stride=2, + dilation=1, + style='caffe', + norm_cfg=norm_cfg, + norm_eval=True), + bbox_roi_extractor=dict( + type=SingleRoIExtractor, + roi_layer=dict(type=RoIAlign, output_size=14, sampling_ratio=0), + out_channels=1024, + featmap_strides=[16]), + bbox_head=dict( + type=BBoxHead, + with_avg_pool=True, + roi_feat_size=7, + in_channels=2048, + num_classes=80, + bbox_coder=dict( + type=DeltaXYWHBBoxCoder, + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + 
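The Faster R-CNN `test_cfg` above ends with a reminder that Soft-NMS is also supported at test time. In this config style the swap would look roughly like the following derived file; the filename is hypothetical, `soft_nms` itself is a real `mmcv.ops` function imported the same way `nms` is in these configs, and the thresholds are copied from the inline comment:

```python
# faster_rcnn_r50_fpn_soft_nms_1x_coco.py (hypothetical derived config)
from mmengine.config import read_base

with read_base():
    from .._base_.models.faster_rcnn_r50_fpn import *

from mmcv.ops import soft_nms

# Only the rcnn test-time NMS changes; everything else is inherited.
model['test_cfg']['rcnn']['nms'] = dict(
    type=soft_nms, iou_threshold=0.5, min_score=0.05)
```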
reg_class_agnostic=False, + loss_cls=dict( + type=CrossEntropyLoss, use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type=L1Loss, loss_weight=1.0)), + mask_roi_extractor=None, + mask_head=dict( + type=FCNMaskHead, + num_convs=0, + in_channels=2048, + conv_out_channels=256, + num_classes=80, + loss_mask=dict( + type=CrossEntropyLoss, use_mask=True, loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type=MaxIoUAssigner, + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type=RandomSampler, + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=12000, + max_per_img=2000, + nms=dict(type=nms, iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type=MaxIoUAssigner, + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type=RandomSampler, + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=14, + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=6000, + max_per_img=1000, + nms=dict(type=nms, iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type=nms, iou_threshold=0.5), + max_per_img=100, + mask_thr_binary=0.5))) diff --git a/mmdet/configs/_base_/models/mask_rcnn_r50_fpn.py b/mmdet/configs/_base_/models/mask_rcnn_r50_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..c8a0b031da51c8147c8ed5c5f29502bd0c4bbe7f --- /dev/null +++ b/mmdet/configs/_base_/models/mask_rcnn_r50_fpn.py @@ -0,0 +1,154 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmcv.ops import RoIAlign, nms +from mmengine.model.weight_init import PretrainedInit +from torch.nn import BatchNorm2d + +from mmdet.models.backbones.resnet import ResNet +from mmdet.models.data_preprocessors.data_preprocessor import \ + DetDataPreprocessor +from mmdet.models.dense_heads.rpn_head import RPNHead +from mmdet.models.detectors.mask_rcnn import MaskRCNN +from mmdet.models.losses.cross_entropy_loss import CrossEntropyLoss +from mmdet.models.losses.smooth_l1_loss import L1Loss +from mmdet.models.necks.fpn import FPN +from mmdet.models.roi_heads.bbox_heads.convfc_bbox_head import \ + Shared2FCBBoxHead +from mmdet.models.roi_heads.mask_heads.fcn_mask_head import FCNMaskHead +from mmdet.models.roi_heads.roi_extractors.single_level_roi_extractor import \ + SingleRoIExtractor +from mmdet.models.roi_heads.standard_roi_head import StandardRoIHead +from mmdet.models.task_modules.assigners.max_iou_assigner import MaxIoUAssigner +from mmdet.models.task_modules.coders.delta_xywh_bbox_coder import \ + DeltaXYWHBBoxCoder +from mmdet.models.task_modules.prior_generators.anchor_generator import \ + AnchorGenerator +from mmdet.models.task_modules.samplers.random_sampler import RandomSampler + +# model settings +model = dict( + type=MaskRCNN, + data_preprocessor=dict( + type=DetDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_mask=True, + pad_size_divisor=32), + backbone=dict( + type=ResNet, + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type=BatchNorm2d, requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict( + type=PretrainedInit, checkpoint='torchvision://resnet50')), + neck=dict( + type=FPN, + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type=RPNHead, + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type=AnchorGenerator, + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type=DeltaXYWHBBoxCoder, + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type=CrossEntropyLoss, use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type=L1Loss, loss_weight=1.0)), + roi_head=dict( + type=StandardRoIHead, + bbox_roi_extractor=dict( + type=SingleRoIExtractor, + roi_layer=dict(type=RoIAlign, output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type=Shared2FCBBoxHead, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type=DeltaXYWHBBoxCoder, + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=False, + loss_cls=dict( + type=CrossEntropyLoss, use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type=L1Loss, loss_weight=1.0)), + mask_roi_extractor=dict( + type=SingleRoIExtractor, + roi_layer=dict(type=RoIAlign, output_size=14, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + mask_head=dict( + type=FCNMaskHead, + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=80, + loss_mask=dict( + type=CrossEntropyLoss, use_mask=True, loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type=MaxIoUAssigner, + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type=RandomSampler, + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + 
add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type=nms, iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type=MaxIoUAssigner, + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type=RandomSampler, + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type=nms, iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type=nms, iou_threshold=0.5), + max_per_img=100, + mask_thr_binary=0.5))) diff --git a/mmdet/configs/_base_/models/retinanet_r50_fpn.py b/mmdet/configs/_base_/models/retinanet_r50_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..33e5cc4f1fe69f66801abdfedc578293e96cd23d --- /dev/null +++ b/mmdet/configs/_base_/models/retinanet_r50_fpn.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.ops import nms +from torch.nn import BatchNorm2d + +from mmdet.models import (FPN, DetDataPreprocessor, FocalLoss, L1Loss, ResNet, + RetinaHead, RetinaNet) +from mmdet.models.task_modules import (AnchorGenerator, DeltaXYWHBBoxCoder, + MaxIoUAssigner, PseudoSampler) + +# model settings +model = dict( + type=RetinaNet, + data_preprocessor=dict( + type=DetDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type=ResNet, + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type=BatchNorm2d, requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type=FPN, + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_input', + num_outs=5), + bbox_head=dict( + type=RetinaHead, + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type=AnchorGenerator, + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type=DeltaXYWHBBoxCoder, + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type=FocalLoss, + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type=L1Loss, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict( + assigner=dict( + type=MaxIoUAssigner, + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1), + sampler=dict( + type=PseudoSampler), # Focal loss should use PseudoSampler + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type=nms, iou_threshold=0.5), + max_per_img=100)) diff --git a/mmdet/configs/_base_/schedules/schedule_1x.py b/mmdet/configs/_base_/schedules/schedule_1x.py new file mode 100644 index 0000000000000000000000000000000000000000..47d1fa6a4852c40f3f9962a47ec90e365671c61c --- /dev/null +++ b/mmdet/configs/_base_/schedules/schedule_1x.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
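The `schedule_1x` file that follows encodes the standard 12-epoch COCO recipe: a 500-iteration linear warm-up starting at 0.1% of the base LR (`start_factor=0.001`), then x0.1 step decays on entering epochs 8 and 11 (0-indexed). A quick sanity check of the resulting values in plain Python; `iters_per_epoch` is an assumed figure (roughly COCO train2017 at total batch size 16) used only for the warm-up bookkeeping:

```python
def lr_at(epoch, it, base_lr=0.02, warmup_iters=500, start_factor=0.001,
          milestones=(8, 11), gamma=0.1, iters_per_epoch=7330):
    """Effective LR under schedule_1x (values taken from the config)."""
    global_it = epoch * iters_per_epoch + it
    factor = 1.0
    if global_it < warmup_iters:  # LinearLR: interpolate start_factor -> 1
        factor *= start_factor + (1 - start_factor) * global_it / warmup_iters
    factor *= gamma ** sum(epoch >= m for m in milestones)  # MultiStepLR
    return base_lr * factor

assert abs(lr_at(0, 0) - 2e-05) < 1e-12    # warm-up starts at 0.02 * 0.001
assert abs(lr_at(8, 0) - 0.002) < 1e-12    # first x0.1 drop
assert abs(lr_at(11, 0) - 0.0002) < 1e-12  # second x0.1 drop
```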
+from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper +from mmengine.optim.scheduler.lr_scheduler import LinearLR, MultiStepLR +from mmengine.runner.loops import EpochBasedTrainLoop, TestLoop, ValLoop +from torch.optim.sgd import SGD + +# training schedule for 1x +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=12, val_interval=1) +val_cfg = dict(type=ValLoop) +test_cfg = dict(type=TestLoop) + +# learning rate +param_scheduler = [ + dict(type=LinearLR, start_factor=0.001, by_epoch=False, begin=0, end=500), + dict( + type=MultiStepLR, + begin=0, + end=12, + by_epoch=True, + milestones=[8, 11], + gamma=0.1) +] + +# optimizer +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict(type=SGD, lr=0.02, momentum=0.9, weight_decay=0.0001)) + +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (8 GPUs) x (2 samples per GPU). +auto_scale_lr = dict(enable=False, base_batch_size=16) diff --git a/mmdet/configs/_base_/schedules/schedule_2x.py b/mmdet/configs/_base_/schedules/schedule_2x.py new file mode 100644 index 0000000000000000000000000000000000000000..51ba09a4723bc6ba41b8b4cb6e623ade7db26511 --- /dev/null +++ b/mmdet/configs/_base_/schedules/schedule_2x.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper +from mmengine.optim.scheduler.lr_scheduler import LinearLR, MultiStepLR +from mmengine.runner.loops import EpochBasedTrainLoop, TestLoop, ValLoop +from torch.optim.sgd import SGD + +# training schedule for 2x +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=24, val_interval=1) +val_cfg = dict(type=ValLoop) +test_cfg = dict(type=TestLoop) + +# learning rate +param_scheduler = [ + dict(type=LinearLR, start_factor=0.001, by_epoch=False, begin=0, end=500), + dict( + type=MultiStepLR, + begin=0, + end=24, + by_epoch=True, + milestones=[16, 22], + gamma=0.1) +] + +# optimizer +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict(type=SGD, lr=0.02, momentum=0.9, weight_decay=0.0001)) + +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (8 GPUs) x (2 samples per GPU). +auto_scale_lr = dict(enable=False, base_batch_size=16) diff --git a/mmdet/configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py b/mmdet/configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..a81c25af8b9506acfa8755ff4ec99d33c661442b --- /dev/null +++ b/mmdet/configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details.
# noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.coco_instance import * + from .._base_.default_runtime import * + from .._base_.models.cascade_mask_rcnn_r50_fpn import * + from .._base_.schedules.schedule_1x import * diff --git a/mmdet/configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py b/mmdet/configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..883f09be67066283e1b59484d3483e73d82af776 --- /dev/null +++ b/mmdet/configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.coco_detection import * + from .._base_.default_runtime import * + from .._base_.models.cascade_rcnn_r50_fpn import * + from .._base_.schedules.schedule_1x import * diff --git a/mmdet/configs/common/lsj_100e_coco_detection.py b/mmdet/configs/common/lsj_100e_coco_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..ea2d6bad7f500417ad1eb3e16ca7761c6cadca0e --- /dev/null +++ b/mmdet/configs/common/lsj_100e_coco_detection.py @@ -0,0 +1,134 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .._base_.default_runtime import * + +from mmengine.dataset.sampler import DefaultSampler +from mmengine.optim import OptimWrapper +from mmengine.optim.scheduler.lr_scheduler import LinearLR, MultiStepLR +from mmengine.runner.loops import EpochBasedTrainLoop, TestLoop, ValLoop +from torch.optim import SGD + +from mmdet.datasets import CocoDataset, RepeatDataset +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import (FilterAnnotations, + LoadAnnotations, + LoadImageFromFile) +from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic, + Pad, RandomCrop, RandomFlip, + RandomResize, Resize) +from mmdet.evaluation import CocoMetric + +# dataset settings +dataset_type = CocoDataset +data_root = 'data/coco/' +image_size = (1024, 1024) + +backend_args = None + +train_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadAnnotations, with_bbox=True, with_mask=True), + dict( + type=RandomResize, + scale=image_size, + ratio_range=(0.1, 2.0), + keep_ratio=True), + dict( + type=RandomCrop, + crop_type='absolute_range', + crop_size=image_size, + recompute_bbox=True, + allow_negative_crop=True), + dict(type=FilterAnnotations, min_gt_bbox_wh=(1e-2, 1e-2)), + dict(type=RandomFlip, prob=0.5), + dict(type=PackDetInputs) +] +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + dict(type=LoadAnnotations, with_bbox=True), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] + +# Use RepeatDataset to speed up training +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + 
sampler=dict(type=DefaultSampler, shuffle=True), + dataset=dict( + type=RepeatDataset, + times=4, # simply change this from 2 to 16 for 50e - 400e training. + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args))) +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type=CocoMetric, + ann_file=data_root + 'annotations/instances_val2017.json', + metric=['bbox', 'segm'], + format_only=False, + backend_args=backend_args) +test_evaluator = val_evaluator + +max_epochs = 25 + +train_cfg = dict( + type=EpochBasedTrainLoop, max_epochs=max_epochs, val_interval=5) +val_cfg = dict(type=ValLoop) +test_cfg = dict(type=TestLoop) + +# optimizer assumes bs=64 +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict(type=SGD, lr=0.1, momentum=0.9, weight_decay=0.00004)) + +# learning rate +param_scheduler = [ + dict(type=LinearLR, start_factor=0.067, by_epoch=False, begin=0, end=500), + dict( + type=MultiStepLR, + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[22, 24], + gamma=0.1) +] + +# only keep latest 2 checkpoints +default_hooks.update(dict(checkpoint=dict(max_keep_ckpts=2))) + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. +# base_batch_size = (32 GPUs) x (2 samples per GPU) +auto_scale_lr = dict(base_batch_size=64) diff --git a/mmdet/configs/common/lsj_100e_coco_instance.py b/mmdet/configs/common/lsj_100e_coco_instance.py new file mode 100644 index 0000000000000000000000000000000000000000..90104ee503b22ef395a9b87d74ee80431575d90c --- /dev/null +++ b/mmdet/configs/common/lsj_100e_coco_instance.py @@ -0,0 +1,134 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. 
# noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .._base_.default_runtime import * + +from mmengine.dataset.sampler import DefaultSampler +from mmengine.optim import OptimWrapper +from mmengine.optim.scheduler.lr_scheduler import LinearLR, MultiStepLR +from mmengine.runner.loops import EpochBasedTrainLoop, TestLoop, ValLoop +from torch.optim import SGD + +from mmdet.datasets import CocoDataset, RepeatDataset +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import (FilterAnnotations, + LoadAnnotations, + LoadImageFromFile) +from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic, + Pad, RandomCrop, RandomFlip, + RandomResize, Resize) +from mmdet.evaluation import CocoMetric + +# dataset settings +dataset_type = CocoDataset +data_root = 'data/coco/' +image_size = (1024, 1024) + +backend_args = None + +train_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadAnnotations, with_bbox=True, with_mask=True), + dict( + type=RandomResize, + scale=image_size, + ratio_range=(0.1, 2.0), + keep_ratio=True), + dict( + type=RandomCrop, + crop_type='absolute_range', + crop_size=image_size, + recompute_bbox=True, + allow_negative_crop=True), + dict(type=FilterAnnotations, min_gt_bbox_wh=(1e-2, 1e-2)), + dict(type=RandomFlip, prob=0.5), + dict(type=PackDetInputs) +] +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + dict(type=LoadAnnotations, with_bbox=True, with_mask=True), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] + +# Use RepeatDataset to speed up training +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + dataset=dict( + type=RepeatDataset, + times=4, # simply change this from 2 to 16 for 50e - 400e training. 
+ dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args))) +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type=CocoMetric, + ann_file=data_root + 'annotations/instances_val2017.json', + metric=['bbox', 'segm'], + format_only=False, + backend_args=backend_args) +test_evaluator = val_evaluator + +max_epochs = 25 + +train_cfg = dict( + type=EpochBasedTrainLoop, max_epochs=max_epochs, val_interval=5) +val_cfg = dict(type=ValLoop) +test_cfg = dict(type=TestLoop) + +# optimizer assumes bs=64 +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict(type=SGD, lr=0.1, momentum=0.9, weight_decay=0.00004)) + +# learning rate +param_scheduler = [ + dict(type=LinearLR, start_factor=0.067, by_epoch=False, begin=0, end=500), + dict( + type=MultiStepLR, + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[22, 24], + gamma=0.1) +] + +# only keep latest 2 checkpoints +default_hooks.update(dict(checkpoint=dict(max_keep_ckpts=2))) + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. +# base_batch_size = (32 GPUs) x (2 samples per GPU) +auto_scale_lr = dict(base_batch_size=64) diff --git a/mmdet/configs/common/lsj_200e_coco_detection.py b/mmdet/configs/common/lsj_200e_coco_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..5759499e95dde6ef99246ab00c21264192ff511c --- /dev/null +++ b/mmdet/configs/common/lsj_200e_coco_detection.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .lsj_100e_coco_detection import * + +# 8x25=200e +train_dataloader.update(dict(dataset=dict(times=8))) + +# learning rate +param_scheduler = [ + dict(type=LinearLR, start_factor=0.067, by_epoch=False, begin=0, end=1000), + dict( + type=MultiStepLR, + begin=0, + end=25, + by_epoch=True, + milestones=[22, 24], + gamma=0.1) +] diff --git a/mmdet/configs/common/lsj_200e_coco_instance.py b/mmdet/configs/common/lsj_200e_coco_instance.py new file mode 100644 index 0000000000000000000000000000000000000000..77c5cdd44c488a763d320768e80b314f999ac555 --- /dev/null +++ b/mmdet/configs/common/lsj_200e_coco_instance.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. 
# noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .lsj_100e_coco_instance import * + +# 8x25=200e +train_dataloader.update(dict(dataset=dict(times=8))) + +# learning rate +param_scheduler = [ + dict(type=LinearLR, start_factor=0.067, by_epoch=False, begin=0, end=1000), + dict( + type=MultiStepLR, + begin=0, + end=25, + by_epoch=True, + milestones=[22, 24], + gamma=0.1) +] diff --git a/mmdet/configs/common/ms_3x_coco.py b/mmdet/configs/common/ms_3x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..c32b24d96aeed59a7340cd7e743dd16b7c728bf1 --- /dev/null +++ b/mmdet/configs/common/ms_3x_coco.py @@ -0,0 +1,130 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .._base_.default_runtime import * + +from mmcv.transforms import RandomResize +from mmengine.dataset import RepeatDataset +from mmengine.dataset.sampler import DefaultSampler +from mmengine.optim import OptimWrapper +from mmengine.optim.scheduler.lr_scheduler import LinearLR, MultiStepLR +from mmengine.runner.loops import EpochBasedTrainLoop, TestLoop, ValLoop +from torch.optim import SGD + +from mmdet.datasets import AspectRatioBatchSampler, CocoDataset +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import (LoadAnnotations, + LoadImageFromFile) +from mmdet.datasets.transforms.transforms import RandomFlip, Resize +from mmdet.evaluation import CocoMetric + +# dataset settings +dataset_type = CocoDataset +data_root = 'data/coco/' + +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +# In mstrain 3x config, img_scale=[(1333, 640), (1333, 800)], +# multiscale_mode='range' +train_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadAnnotations, with_bbox=True), + dict(type=RandomResize, scale=[(1333, 640), (1333, 800)], keep_ratio=True), + dict(type=RandomFlip, prob=0.5), + dict(type=PackDetInputs) +] +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + dict(type=LoadAnnotations, with_bbox=True), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + pin_memory=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=AspectRatioBatchSampler), + dataset=dict( + type=RepeatDataset, + times=3, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args))) +val_dataloader = dict( + batch_size=1, + 
num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type=CocoMetric, + ann_file=data_root + 'annotations/instances_val2017.json', + metric='bbox', + backend_args=backend_args) +test_evaluator = val_evaluator + +# training schedule for 3x with `RepeatDataset` +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=12, val_interval=1) +val_cfg = dict(type=ValLoop) +test_cfg = dict(type=TestLoop) + +# learning rate +param_scheduler = [ + dict(type=LinearLR, start_factor=0.001, by_epoch=False, begin=0, end=500), + dict( + type=MultiStepLR, + begin=0, + end=12, + by_epoch=True, + milestones=[9, 11], + gamma=0.1) +] + +# optimizer +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict(type=SGD, lr=0.02, momentum=0.9, weight_decay=0.0001)) +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (8 GPUs) x (2 samples per GPU). +auto_scale_lr = dict(enable=False, base_batch_size=16) diff --git a/mmdet/configs/common/ms_3x_coco_instance.py b/mmdet/configs/common/ms_3x_coco_instance.py new file mode 100644 index 0000000000000000000000000000000000000000..3c78909df80173eb37ff83c4ba12614e73848f29 --- /dev/null +++ b/mmdet/configs/common/ms_3x_coco_instance.py @@ -0,0 +1,136 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details.
# noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .._base_.default_runtime import * + +from mmcv.transforms import RandomChoiceResize +from mmengine.dataset import RepeatDataset +from mmengine.dataset.sampler import DefaultSampler, InfiniteSampler +from mmengine.optim import OptimWrapper +from mmengine.optim.scheduler.lr_scheduler import LinearLR, MultiStepLR +from mmengine.runner.loops import EpochBasedTrainLoop, TestLoop, ValLoop +from torch.optim import SGD + +from mmdet.datasets import AspectRatioBatchSampler, CocoDataset +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import (FilterAnnotations, + LoadAnnotations, + LoadImageFromFile) +from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic, + Pad, RandomCrop, RandomFlip, + RandomResize, Resize) +from mmdet.evaluation import CocoMetric + +# dataset settings +dataset_type = CocoDataset +data_root = 'data/coco/' + +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +train_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadAnnotations, with_bbox=True, with_mask=True), + dict( + type=RandomResize, scale=[(1333, 640), (1333, 800)], + keep_ratio=True), + dict(type=RandomFlip, prob=0.5), + dict(type=PackDetInputs) +] +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + dict(type=LoadAnnotations, with_bbox=True, with_mask=True), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] +train_dataloader.update( + dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=AspectRatioBatchSampler), + dataset=dict( + type=RepeatDataset, + times=3, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args)))) +val_dataloader.update( + dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args))) +test_dataloader = val_dataloader + +val_evaluator.update( + dict( + type=CocoMetric, + ann_file=data_root + 'annotations/instances_val2017.json', + metric='bbox', + backend_args=backend_args)) +test_evaluator = val_evaluator + +# training schedule for 3x with `RepeatDataset` +train_cfg.update(dict(type=EpochBasedTrainLoop, max_epochs=12, val_interval=1)) +val_cfg.update(dict(type=ValLoop)) +test_cfg.update(dict(type=TestLoop)) + +# learning rate +param_scheduler = [ + dict(type=LinearLR, start_factor=0.001,
by_epoch=False, begin=0, end=500), + dict( + type=MultiStepLR, + begin=0, + end=12, + by_epoch=True, + milestones=[9, 11], + gamma=0.1) +] + +# optimizer +optim_wrapper.update( + dict( + type=OptimWrapper, + optimizer=dict(type=SGD, lr=0.02, momentum=0.9, weight_decay=0.0001))) +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (8 GPUs) x (2 samples per GPU). +auto_scale_lr.update(dict(enable=False, base_batch_size=16)) diff --git a/mmdet/configs/common/ms_90k_coco.py b/mmdet/configs/common/ms_90k_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..3abf1d4a4a8cf53a4abfa43722e306ac04770e18 --- /dev/null +++ b/mmdet/configs/common/ms_90k_coco.py @@ -0,0 +1,151 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .._base_.default_runtime import * + +from mmcv.transforms import RandomChoiceResize +from mmengine.dataset import RepeatDataset +from mmengine.dataset.sampler import DefaultSampler, InfiniteSampler +from mmengine.optim import OptimWrapper +from mmengine.optim.scheduler.lr_scheduler import LinearLR, MultiStepLR +from mmengine.runner.loops import IterBasedTrainLoop, TestLoop, ValLoop +from torch.optim import SGD + +from mmdet.datasets import AspectRatioBatchSampler, CocoDataset +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import (FilterAnnotations, + LoadAnnotations, + LoadImageFromFile) +from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic, + Pad, RandomCrop, RandomFlip, + RandomResize, Resize) +from mmdet.evaluation import CocoMetric + +# dataset settings +dataset_type = CocoDataset +data_root = 'data/coco/' +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +# Align with Detectron2 +backend = 'pillow' +train_pipeline = [ + dict( + type=LoadImageFromFile, + backend_args=backend_args, + imdecode_backend=backend), + dict(type=LoadAnnotations, with_bbox=True), + dict( + type=RandomChoiceResize, + scales=[(1333, 640), (1333, 672), (1333, 704), (1333, 736), + (1333, 768), (1333, 800)], + keep_ratio=True, + backend=backend), + dict(type=RandomFlip, prob=0.5), + dict(type=PackDetInputs) +] +test_pipeline = [ + dict( + type=LoadImageFromFile, + backend_args=backend_args, + imdecode_backend=backend), + dict(type=Resize, scale=(1333, 800), keep_ratio=True, backend=backend), + dict(type=LoadAnnotations, with_bbox=True), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] +train_dataloader.update( + dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + pin_memory=True, + sampler=dict(type=InfiniteSampler, shuffle=True), + batch_sampler=dict(type=AspectRatioBatchSampler), + dataset=dict(
type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args))) +val_dataloader.update( + dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + pin_memory=True, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args))) +test_dataloader = val_dataloader + +val_evaluator.update( + dict( + type=CocoMetric, + ann_file=data_root + 'annotations/instances_val2017.json', + metric='bbox', + format_only=False, + backend_args=backend_args)) +test_evaluator = val_evaluator + +# training schedule for 90k +max_iter = 90000 +train_cfg.update( + dict(type=IterBasedTrainLoop, max_iters=max_iter, val_interval=10000)) +val_cfg.update(dict(type=ValLoop)) +test_cfg.update(dict(type=TestLoop)) + +# learning rate +param_scheduler = [ + dict(type=LinearLR, start_factor=0.001, by_epoch=False, begin=0, end=1000), + dict( + type=MultiStepLR, + begin=0, + end=max_iter, + by_epoch=False, + milestones=[60000, 80000], + gamma=0.1) +] + +# optimizer +optim_wrapper.update( + dict( + type=OptimWrapper, + optimizer=dict(type=SGD, lr=0.02, momentum=0.9, weight_decay=0.0001))) +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (8 GPUs) x (2 samples per GPU). +auto_scale_lr.update(dict(enable=False, base_batch_size=16)) + +default_hooks.update(dict(checkpoint=dict(by_epoch=False, interval=10000))) +log_processor.update(dict(by_epoch=False)) diff --git a/mmdet/configs/common/ms_poly_3x_coco_instance.py b/mmdet/configs/common/ms_poly_3x_coco_instance.py new file mode 100644 index 0000000000000000000000000000000000000000..53913a059a4db9230ebd777934cc8db5595479fe --- /dev/null +++ b/mmdet/configs/common/ms_poly_3x_coco_instance.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. 
# noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .._base_.default_runtime import * + +from mmcv.transforms import RandomChoiceResize +from mmengine.dataset import RepeatDataset +from mmengine.dataset.sampler import DefaultSampler, InfiniteSampler +from mmengine.optim import OptimWrapper +from mmengine.optim.scheduler.lr_scheduler import LinearLR, MultiStepLR +from mmengine.runner.loops import EpochBasedTrainLoop, TestLoop, ValLoop +from torch.optim import SGD + +from mmdet.datasets import AspectRatioBatchSampler, CocoDataset +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import (FilterAnnotations, + LoadAnnotations, + LoadImageFromFile) +from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic, + Pad, RandomCrop, RandomFlip, + RandomResize, Resize) +from mmdet.evaluation import CocoMetric + +# dataset settings +dataset_type = CocoDataset +data_root = 'data/coco/' +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +# In mstrain 3x config, img_scale=[(1333, 640), (1333, 800)], +# multiscale_mode='range' +train_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict( + type=LoadAnnotations, with_bbox=True, with_mask=True, poly2mask=False), + dict( + type=RandomResize, scale=[(1333, 640), (1333, 800)], + keep_ratio=True), + dict(type=RandomFlip, prob=0.5), + dict(type=PackDetInputs) +] +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + dict( + type=LoadAnnotations, with_bbox=True, with_mask=True, poly2mask=False), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] +train_dataloader.update( + dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + pin_memory=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=AspectRatioBatchSampler), + dataset=dict( + type=RepeatDataset, + times=3, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args)))) +val_dataloader.update( + dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + drop_last=False, + pin_memory=True, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args))) +test_dataloader = val_dataloader + +val_evaluator.update( + dict( + type=CocoMetric, + ann_file=data_root + 'annotations/instances_val2017.json', + metric=['bbox', 'segm'], + backend_args=backend_args)) +test_evaluator = val_evaluator + +# training schedule for 3x with `RepeatDataset` +train_cfg.update(dict(type=EpochBasedTrainLoop, max_epochs=12, val_interval=1)) +val_cfg.update(dict(type=ValLoop)) +test_cfg.update(dict(type=TestLoop))
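The `times=3` / `max_epochs=12` pairing above is what makes this a "3x" schedule: the runner counts only 12 epochs, so epoch-based hooks and milestones keep small numbers, but every epoch walks the repeated annotation list three times. Back-of-envelope, with the COCO image count and the 8-GPU layout as assumptions for illustration:

```python
raw_images = 117_266        # approx. size of COCO train2017 (assumption)
times, max_epochs = 3, 12   # RepeatDataset times / EpochBasedTrainLoop epochs
total_batch = 2 * 8         # 2 samples per GPU x 8 GPUs (assumption)

iters_per_epoch = raw_images * times // total_batch  # ~21,987 iterations
effective_passes = times * max_epochs                # 36 passes, hence "3x"
print(iters_per_epoch, effective_passes)
```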
+ +# learning rate +param_scheduler = [ + dict(type=LinearLR, start_factor=0.001, by_epoch=False, begin=0, end=500), + dict( + type=MultiStepLR, + begin=0, + end=12, + by_epoch=True, + milestones=[9, 11], + gamma=0.1) +] + +# optimizer +optim_wrapper.update( + dict( + type=OptimWrapper, + optimizer=dict(type=SGD, lr=0.02, momentum=0.9, weight_decay=0.0001))) +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (8 GPUs) x (2 samples per GPU). +auto_scale_lr.update(dict(enable=False, base_batch_size=16)) diff --git a/mmdet/configs/common/ms_poly_90k_coco_instance.py b/mmdet/configs/common/ms_poly_90k_coco_instance.py new file mode 100644 index 0000000000000000000000000000000000000000..52367350137035604ea167e5732a791c2e9cae87 --- /dev/null +++ b/mmdet/configs/common/ms_poly_90k_coco_instance.py @@ -0,0 +1,153 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .._base_.default_runtime import * + +from mmcv.transforms import RandomChoiceResize +from mmengine.dataset import RepeatDataset +from mmengine.dataset.sampler import DefaultSampler, InfiniteSampler +from mmengine.optim import OptimWrapper +from mmengine.optim.scheduler.lr_scheduler import LinearLR, MultiStepLR +from mmengine.runner.loops import IterBasedTrainLoop, TestLoop, ValLoop +from torch.optim import SGD + +from mmdet.datasets import AspectRatioBatchSampler, CocoDataset +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import (FilterAnnotations, + LoadAnnotations, + LoadImageFromFile) +from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic, + Pad, RandomCrop, RandomFlip, + RandomResize, Resize) +from mmdet.evaluation import CocoMetric + +# dataset settings +dataset_type = CocoDataset +data_root = 'data/coco/' +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +# Align with Detectron2 +backend = 'pillow' +train_pipeline = [ + dict( + type=LoadImageFromFile, + backend_args=backend_args, + imdecode_backend=backend), + dict( + type=LoadAnnotations, with_bbox=True, with_mask=True, poly2mask=False), + dict( + type=RandomChoiceResize, + scales=[(1333, 640), (1333, 672), (1333, 704), (1333, 736), + (1333, 768), (1333, 800)], + keep_ratio=True, + backend=backend), + dict(type=RandomFlip, prob=0.5), + dict(type=PackDetInputs) +] +test_pipeline = [ + dict( + type=LoadImageFromFile, + backend_args=backend_args, + imdecode_backend=backend), + dict(type=Resize, scale=(1333, 800), keep_ratio=True, backend=backend), + dict( + type=LoadAnnotations, with_bbox=True, with_mask=True, poly2mask=False), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +]
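The `poly2mask=False` flag used by these `ms_poly_*` variants keeps COCO instance masks in polygon form through the resize/flip pipeline instead of rasterizing each one to a bitmap at load time, which is much lighter on memory and CPU. Rough intuition only; the numbers below are illustrative, not a benchmark:

```python
# One polygon vertex list vs. one uint8 bitmap per instance:
polygon_bytes = 40 * 2 * 8   # ~40 (x, y) float64 vertices (assumption)
bitmap_bytes = 800 * 1333    # a full mask at the (1333, 800) test scale
print(polygon_bytes, 'B (polygon) vs', bitmap_bytes, 'B (bitmap)')
```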
+train_dataloader.update( + dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + pin_memory=True, + sampler=dict(type=InfiniteSampler, shuffle=True), + batch_sampler=dict(type=AspectRatioBatchSampler), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args))) +val_dataloader.update( + dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + pin_memory=True, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args))) +test_dataloader = val_dataloader + +val_evaluator.update( + dict( + type=CocoMetric, + ann_file=data_root + 'annotations/instances_val2017.json', + metric=['bbox', 'segm'], + format_only=False, + backend_args=backend_args)) +test_evaluator = val_evaluator + +# training schedule for 90k +max_iter = 90000 +train_cfg.update( + dict(type=IterBasedTrainLoop, max_iters=max_iter, val_interval=10000)) +val_cfg.update(dict(type=ValLoop)) +test_cfg.update(dict(type=TestLoop)) + +# learning rate +param_scheduler = [ + dict(type=LinearLR, start_factor=0.001, by_epoch=False, begin=0, end=1000), + dict( + type=MultiStepLR, + begin=0, + end=max_iter, + by_epoch=False, + milestones=[60000, 80000], + gamma=0.1) +] + +# optimizer +optim_wrapper.update( + dict( + type=OptimWrapper, + optimizer=dict(type=SGD, lr=0.02, momentum=0.9, weight_decay=0.0001))) +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (8 GPUs) x (2 samples per GPU). +auto_scale_lr.update(dict(enable=False, base_batch_size=16)) + +default_hooks.update(dict(checkpoint=dict(by_epoch=False, interval=10000))) +log_processor.update(dict(by_epoch=False)) diff --git a/mmdet/configs/common/ssj_270_coco_instance.py b/mmdet/configs/common/ssj_270_coco_instance.py new file mode 100644 index 0000000000000000000000000000000000000000..ee86fdad4eca5b87ac0066b635e098d6a927bb49 --- /dev/null +++ b/mmdet/configs/common/ssj_270_coco_instance.py @@ -0,0 +1,158 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. 
# noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .._base_.default_runtime import * + +from mmcv.transforms import RandomChoiceResize +from mmengine.dataset import RepeatDataset +from mmengine.dataset.sampler import DefaultSampler, InfiniteSampler +from mmengine.optim import OptimWrapper +from mmengine.optim.scheduler.lr_scheduler import LinearLR, MultiStepLR +from mmengine.runner.loops import IterBasedTrainLoop, TestLoop, ValLoop +from torch.optim import SGD + +from mmdet.datasets import AspectRatioBatchSampler, CocoDataset +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import (FilterAnnotations, + LoadAnnotations, + LoadImageFromFile) +from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic, + Pad, RandomCrop, RandomFlip, + RandomResize, Resize) +from mmdet.evaluation import CocoMetric + +# dataset settings +dataset_type = CocoDataset +data_root = 'data/coco/' +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +# Standard Scale Jittering (SSJ) resizes and crops an image +# with a resize range of 0.8 to 1.25 of the original image size. +train_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadAnnotations, with_bbox=True, with_mask=True), + dict( + type=RandomResize, + scale=image_size, + ratio_range=(0.8, 1.25), + keep_ratio=True), + dict( + type='RandomCrop', + crop_type='absolute_range', + crop_size=image_size, + recompute_bbox=True, + allow_negative_crop=True), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)), + dict(type=RandomFlip, prob=0.5), + dict(type=PackDetInputs) +] +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + dict(type=LoadAnnotations, with_bbox=True, with_mask=True), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] +train_dataloader.update( + dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=InfiniteSampler), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args))) +val_dataloader.update( + dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args))) +test_dataloader = val_dataloader + +val_evaluator.update( + dict( + type=CocoMetric, + ann_file=data_root + 'annotations/instances_val2017.json', + metric=['bbox', 'segm'], + format_only=False, + backend_args=backend_args)) +test_evaluator = val_evaluator + +val_evaluator = dict( + type=CocoMetric, + 
+
+# The model is trained by 270k iterations with batch_size 64,
+# which is roughly equivalent to 144 epochs.
+
+max_iter = 270000
+train_cfg = dict(
+    type=IterBasedTrainLoop, max_iters=max_iter, val_interval=10000)
+val_cfg = dict(type=ValLoop)
+test_cfg = dict(type=TestLoop)
+
+# learning rate
+param_scheduler = [
+    dict(type=LinearLR, start_factor=0.001, by_epoch=False, begin=0, end=1000),
+    dict(
+        type=MultiStepLR,
+        begin=0,
+        end=max_iter,
+        by_epoch=False,
+        milestones=[243000, 256500, 263250],
+        gamma=0.1)
+]
+
+# optimizer
+optim_wrapper = dict(
+    type=OptimWrapper,
+    optimizer=dict(type=SGD, lr=0.1, momentum=0.9, weight_decay=0.00004))
+# Default setting for scaling LR automatically
+# - `enable` means enable scaling LR automatically
+#   or not by default.
+# - `base_batch_size` = (8 GPUs) x (8 samples per GPU).
+auto_scale_lr = dict(base_batch_size=64)
+
+default_hooks.update(dict(checkpoint=dict(by_epoch=False, interval=10000)))
+log_processor.update(dict(by_epoch=False))
diff --git a/mmdet/configs/common/ssj_scp_270k_coco_instance.py b/mmdet/configs/common/ssj_scp_270k_coco_instance.py
new file mode 100644
index 0000000000000000000000000000000000000000..68bb1f0904fcb4de3e2f892355e489f52f53d960
--- /dev/null
+++ b/mmdet/configs/common/ssj_scp_270k_coco_instance.py
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .ssj_270_coco_instance import *
+
+from mmdet.datasets import MultiImageMixDataset
+from mmdet.datasets.transforms import CopyPaste
+
+# dataset settings
+dataset_type = CocoDataset
+data_root = 'data/coco/'
+image_size = (1024, 1024)
+# Example to use different file client
+# Method 1: simply set the data root and let the file I/O module
+# automatically infer from prefix (not support LMDB and Memcache yet)
+
+# data_root = 's3://openmmlab/datasets/detection/coco/'
+
+# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6
+# backend_args = dict(
+#     backend='petrel',
+#     path_mapping=dict({
+#         './data/': 's3://openmmlab/datasets/detection/',
+#         'data/': 's3://openmmlab/datasets/detection/'
+#     }))
+backend_args = None
+
+# Standard Scale Jittering (SSJ) resizes and crops an image
+# with a resize range of 0.8 to 1.25 of the original image size.
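+# Added note for intuition (illustrative numbers only): with
+# image_size = (1024, 1024) and ratio_range=(0.8, 1.25), RandomResize picks a
+# target scale between about (819, 819) and (1280, 1280); RandomCrop then
+# cuts a 1024 x 1024 window (or the whole image, if smaller), so objects end
+# up at roughly 0.8x-1.25x of their original scale.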
+load_pipeline = [
+    dict(type=LoadImageFromFile, backend_args=backend_args),
+    dict(type=LoadAnnotations, with_bbox=True, with_mask=True),
+    dict(
+        type=RandomResize,
+        scale=image_size,
+        ratio_range=(0.8, 1.25),
+        keep_ratio=True),
+    dict(
+        type=RandomCrop,
+        crop_type='absolute_range',
+        crop_size=image_size,
+        recompute_bbox=True,
+        allow_negative_crop=True),
+    dict(type=FilterAnnotations, min_gt_bbox_wh=(1e-2, 1e-2)),
+    dict(type=RandomFlip, prob=0.5),
+    dict(type=Pad, size=image_size),
+]
+train_pipeline = [
+    dict(type=CopyPaste, max_num_pasted=100),
+    dict(type=PackDetInputs)
+]
+
+# `MultiImageMixDataset` is a dataset wrapper, so it goes under `dataset`,
+# not at the dataloader level.
+train_dataloader.update(
+    dict(
+        dataset=dict(
+            type=MultiImageMixDataset,
+            dataset=dict(
+                type=dataset_type,
+                data_root=data_root,
+                ann_file='annotations/instances_train2017.json',
+                data_prefix=dict(img='train2017/'),
+                filter_cfg=dict(filter_empty_gt=True, min_size=32),
+                pipeline=load_pipeline,
+                backend_args=backend_args),
+            pipeline=train_pipeline)))
diff --git a/mmdet/configs/deformable_detr/deformable_detr_r50_16xb2_50e_coco.py b/mmdet/configs/deformable_detr/deformable_detr_r50_16xb2_50e_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee2a41639d84ed8e278af45229b451b742ac8974
--- /dev/null
+++ b/mmdet/configs/deformable_detr/deformable_detr_r50_16xb2_50e_coco.py
@@ -0,0 +1,186 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .._base_.datasets.coco_detection import *
+    from .._base_.default_runtime import *
+
+from mmcv.transforms import LoadImageFromFile, RandomChoice, RandomChoiceResize
+from mmengine.optim.optimizer import OptimWrapper
+from mmengine.optim.scheduler import MultiStepLR
+from mmengine.runner.loops import EpochBasedTrainLoop, TestLoop, ValLoop
+from torch.optim.adamw import AdamW
+
+from mmdet.datasets.transforms import (LoadAnnotations, PackDetInputs,
+                                       RandomCrop, RandomFlip, Resize)
+from mmdet.models.backbones import ResNet
+from mmdet.models.data_preprocessors import DetDataPreprocessor
+from mmdet.models.dense_heads import DeformableDETRHead
+from mmdet.models.detectors import DeformableDETR
+from mmdet.models.losses import FocalLoss, GIoULoss, L1Loss
+from mmdet.models.necks import ChannelMapper
+from mmdet.models.task_modules import (BBoxL1Cost, FocalLossCost,
+                                       HungarianAssigner, IoUCost)
+
+model = dict(
+    type=DeformableDETR,
+    num_queries=300,
+    num_feature_levels=4,
+    with_box_refine=False,
+    as_two_stage=False,
+    data_preprocessor=dict(
+        type=DetDataPreprocessor,
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        bgr_to_rgb=True,
+        pad_size_divisor=1),
+    backbone=dict(
+        type=ResNet,
+        depth=50,
+        num_stages=4,
+        out_indices=(1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN', requires_grad=False),
+        norm_eval=True,
+        style='pytorch',
+        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
+    neck=dict(
+        type=ChannelMapper,
+        in_channels=[512, 1024, 2048],
+        kernel_size=1,
+        out_channels=256,
+        act_cfg=None,
+        norm_cfg=dict(type='GN', num_groups=32),
+        num_outs=4),
+    encoder=dict(  # DeformableDetrTransformerEncoder
+        num_layers=6,
+        layer_cfg=dict(  # DeformableDetrTransformerEncoderLayer
+            self_attn_cfg=dict(  # MultiScaleDeformableAttention
+                embed_dims=256,
+                batch_first=True),
+            ffn_cfg=dict(
+                embed_dims=256,
+                feedforward_channels=1024, ffn_drop=0.1))),
+    decoder=dict(  # DeformableDetrTransformerDecoder
+        num_layers=6,
+        return_intermediate=True,
+        layer_cfg=dict(  # DeformableDetrTransformerDecoderLayer
+            self_attn_cfg=dict(  # MultiheadAttention
+                embed_dims=256,
+                num_heads=8,
+                dropout=0.1,
+                batch_first=True),
+            cross_attn_cfg=dict(  # MultiScaleDeformableAttention
+                embed_dims=256,
+                batch_first=True),
+            ffn_cfg=dict(
+                embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)),
+        post_norm_cfg=None),
+    positional_encoding=dict(num_feats=128, normalize=True, offset=-0.5),
+    bbox_head=dict(
+        type=DeformableDETRHead,
+        num_classes=80,
+        sync_cls_avg_factor=True,
+        loss_cls=dict(
+            type=FocalLoss,
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=2.0),
+        loss_bbox=dict(type=L1Loss, loss_weight=5.0),
+        loss_iou=dict(type=GIoULoss, loss_weight=2.0)),
+    # training and testing settings
+    train_cfg=dict(
+        assigner=dict(
+            type=HungarianAssigner,
+            match_costs=[
+                dict(type=FocalLossCost, weight=2.0),
+                dict(type=BBoxL1Cost, weight=5.0, box_format='xywh'),
+                dict(type=IoUCost, iou_mode='giou', weight=2.0)
+            ])),
+    test_cfg=dict(max_per_img=100))
+
+# train_pipeline, NOTE the img_scale and the Pad's size_divisor is different
+# from the default setting in mmdet.
+train_pipeline = [
+    dict(type=LoadImageFromFile, backend_args=backend_args),
+    dict(type=LoadAnnotations, with_bbox=True),
+    dict(type=RandomFlip, prob=0.5),
+    dict(
+        type=RandomChoice,
+        transforms=[
+            [
+                dict(
+                    type=RandomChoiceResize,
+                    scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                            (736, 1333), (768, 1333), (800, 1333)],
+                    resize_type=Resize,
+                    keep_ratio=True)
+            ],
+            [
+                dict(
+                    type=RandomChoiceResize,
+                    # The aspect ratio of all images in the train dataset
+                    # is < 7; this follows the original implementation
+                    scales=[(400, 4200), (500, 4200), (600, 4200)],
+                    resize_type=Resize,
+                    keep_ratio=True),
+                dict(
+                    type=RandomCrop,
+                    crop_type='absolute_range',
+                    crop_size=(384, 600),
+                    allow_negative_crop=True),
+                dict(
+                    type=RandomChoiceResize,
+                    scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                            (736, 1333), (768, 1333), (800, 1333)],
+                    resize_type=Resize,
+                    keep_ratio=True)
+            ]
+        ]),
+    dict(type=PackDetInputs)
+]
+train_dataloader.update(
+    dict(
+        dataset=dict(
+            filter_cfg=dict(filter_empty_gt=False), pipeline=train_pipeline)))
+
+# optimizer
+optim_wrapper = dict(
+    type=OptimWrapper,
+    optimizer=dict(type=AdamW, lr=0.0002, weight_decay=0.0001),
+    clip_grad=dict(max_norm=0.1, norm_type=2),
+    paramwise_cfg=dict(
+        custom_keys={
+            'backbone': dict(lr_mult=0.1),
+            'sampling_offsets': dict(lr_mult=0.1),
+            'reference_points': dict(lr_mult=0.1)
+        }))
+
+# learning policy
+max_epochs = 50
+train_cfg = dict(
+    type=EpochBasedTrainLoop, max_epochs=max_epochs, val_interval=1)
+val_cfg = dict(type=ValLoop)
+test_cfg = dict(type=TestLoop)
+
+param_scheduler = [
+    dict(
+        type=MultiStepLR,
+        begin=0,
+        end=max_epochs,
+        by_epoch=True,
+        milestones=[40],
+        gamma=0.1)
+]
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR,
+# USER SHOULD NOT CHANGE ITS VALUES.
+# base_batch_size = (16 GPUs) x (2 samples per GPU) +auto_scale_lr = dict(base_batch_size=32) diff --git a/mmdet/configs/deformable_detr/deformable_detr_refine_r50_16xb2_50e_coco.py b/mmdet/configs/deformable_detr/deformable_detr_refine_r50_16xb2_50e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..4f232d6111026488020e586440852c012dd94608 --- /dev/null +++ b/mmdet/configs/deformable_detr/deformable_detr_refine_r50_16xb2_50e_coco.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .deformable_detr_r50_16xb2_50e_coco import * + +model.update(dict(with_box_refine=True)) diff --git a/mmdet/configs/deformable_detr/deformable_detr_refine_twostage_r50_16xb2_50e_coco.py b/mmdet/configs/deformable_detr/deformable_detr_refine_twostage_r50_16xb2_50e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..1fac4d8c4f2020b6d87857fbe157419e4c4f0712 --- /dev/null +++ b/mmdet/configs/deformable_detr/deformable_detr_refine_twostage_r50_16xb2_50e_coco.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .deformable_detr_refine_r50_16xb2_50e_coco import * + +model.update(dict(as_two_stage=True)) diff --git a/mmdet/configs/detr/detr_r101_8xb2_500e_coco.py b/mmdet/configs/detr/detr_r101_8xb2_500e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..b961468114ce3adb0582378ac422649ef3bd5013 --- /dev/null +++ b/mmdet/configs/detr/detr_r101_8xb2_500e_coco.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base +from mmengine.model.weight_init import PretrainedInit + +with read_base(): + from .detr_r50_8xb2_500e_coco import * + +model.update( + dict( + backbone=dict( + depth=101, + init_cfg=dict( + type=PretrainedInit, checkpoint='torchvision://resnet101')))) diff --git a/mmdet/configs/detr/detr_r18_8xb2_500e_coco.py b/mmdet/configs/detr/detr_r18_8xb2_500e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..11360af18de729bfd9e8d8cb6597067a588852c9 --- /dev/null +++ b/mmdet/configs/detr/detr_r18_8xb2_500e_coco.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base +from mmengine.model.weight_init import PretrainedInit + +with read_base(): + from .detr_r50_8xb2_500e_coco import * + +model.update( + dict( + backbone=dict( + depth=18, + init_cfg=dict( + type=PretrainedInit, checkpoint='torchvision://resnet18')), + neck=dict(in_channels=[512]))) diff --git a/mmdet/configs/detr/detr_r50_8xb2_150e_coco.py b/mmdet/configs/detr/detr_r50_8xb2_150e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..c50726c7890cb59bee4b921179be1949ff12199e --- /dev/null +++ b/mmdet/configs/detr/detr_r50_8xb2_150e_coco.py @@ -0,0 +1,182 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmcv.transforms import RandomChoice, RandomChoiceResize +from mmcv.transforms.loading import LoadImageFromFile +from mmengine.config import read_base +from mmengine.model.weight_init import PretrainedInit +from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper +from mmengine.optim.scheduler.lr_scheduler import MultiStepLR +from mmengine.runner.loops import EpochBasedTrainLoop, TestLoop, ValLoop +from torch.nn.modules.activation import ReLU +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.optim.adamw import AdamW + +from mmdet.datasets.transforms import (LoadAnnotations, PackDetInputs, + RandomCrop, RandomFlip, Resize) +from mmdet.models import (DETR, ChannelMapper, DetDataPreprocessor, DETRHead, + ResNet) +from mmdet.models.losses.cross_entropy_loss import CrossEntropyLoss +from mmdet.models.losses.iou_loss import GIoULoss +from mmdet.models.losses.smooth_l1_loss import L1Loss +from mmdet.models.task_modules import (BBoxL1Cost, ClassificationCost, + HungarianAssigner, IoUCost) + +with read_base(): + from .._base_.datasets.coco_detection import * + from .._base_.default_runtime import * + +model = dict( + type=DETR, + num_queries=100, + data_preprocessor=dict( + type=DetDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=1), + backbone=dict( + type=ResNet, + depth=50, + num_stages=4, + out_indices=(3, ), + frozen_stages=1, + norm_cfg=dict(type=BatchNorm2d, requires_grad=False), + norm_eval=True, + style='pytorch', + init_cfg=dict( + type=PretrainedInit, checkpoint='torchvision://resnet50')), + neck=dict( + type=ChannelMapper, + in_channels=[2048], + kernel_size=1, + out_channels=256, + act_cfg=None, + norm_cfg=None, + num_outs=1), + encoder=dict( # DetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.1, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.1, + act_cfg=dict(type=ReLU, inplace=True)))), + decoder=dict( # DetrTransformerDecoder + num_layers=6, + layer_cfg=dict( # DetrTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.1, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.1, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.1, + act_cfg=dict(type=ReLU, inplace=True))), + return_intermediate=True), + positional_encoding=dict(num_feats=128, normalize=True), + bbox_head=dict( + type=DETRHead, + num_classes=80, + embed_dims=256, + loss_cls=dict( + type=CrossEntropyLoss, + bg_cls_weight=0.1, + use_sigmoid=False, + loss_weight=1.0, + class_weight=1.0), + loss_bbox=dict(type=L1Loss, loss_weight=5.0), + loss_iou=dict(type=GIoULoss, loss_weight=2.0)), + # training and testing settings + train_cfg=dict( + assigner=dict( + type=HungarianAssigner, + match_costs=[ + dict(type=ClassificationCost, weight=1.), + dict(type=BBoxL1Cost, weight=5.0, box_format='xywh'), + dict(type=IoUCost, iou_mode='giou', weight=2.0) + ])), + test_cfg=dict(max_per_img=100)) + +# train_pipeline, NOTE the img_scale and the Pad's size_divisor is different +# from the default setting in mmdet. 
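+# Added note (describes the pipeline below, no extra behavior): RandomChoice
+# picks one of two equally weighted branches per image. Branch 1 is a plain
+# multi-scale resize to one of eleven scales; branch 2 first resizes to a
+# coarse scale, takes an absolute-range crop whose height and width are each
+# sampled from 384-600 px, then resizes again to one of the same eleven
+# scales.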
+train_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadAnnotations, with_bbox=True), + dict(type=RandomFlip, prob=0.5), + dict( + type=RandomChoice, + transforms=[[ + dict( + type=RandomChoiceResize, + resize_type=Resize, + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ], + [ + dict( + type=RandomChoiceResize, + resize_type=Resize, + scales=[(400, 1333), (500, 1333), (600, 1333)], + keep_ratio=True), + dict( + type=RandomCrop, + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict( + type=RandomChoiceResize, + resize_type=Resize, + scales=[(480, 1333), (512, 1333), (544, 1333), + (576, 1333), (608, 1333), (640, 1333), + (672, 1333), (704, 1333), (736, 1333), + (768, 1333), (800, 1333)], + keep_ratio=True) + ]]), + dict(type=PackDetInputs) +] +train_dataloader.update(dataset=dict(pipeline=train_pipeline)) + +# optimizer +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict(type=AdamW, lr=0.0001, weight_decay=0.0001), + clip_grad=dict(max_norm=0.1, norm_type=2), + paramwise_cfg=dict( + custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)})) + +# learning policy +max_epochs = 150 +train_cfg = dict( + type=EpochBasedTrainLoop, max_epochs=max_epochs, val_interval=1) +val_cfg = dict(type=ValLoop) +test_cfg = dict(type=TestLoop) + +param_scheduler = [ + dict( + type=MultiStepLR, + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[100], + gamma=0.1) +] + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. +# base_batch_size = (8 GPUs) x (2 samples per GPU) +auto_scale_lr = dict(base_batch_size=16) diff --git a/mmdet/configs/detr/detr_r50_8xb2_500e_coco.py b/mmdet/configs/detr/detr_r50_8xb2_500e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..d7d0817766255a84237f0aea917806e191d161df --- /dev/null +++ b/mmdet/configs/detr/detr_r50_8xb2_500e_coco.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base +from mmengine.optim.scheduler.lr_scheduler import MultiStepLR +from mmengine.runner.loops import EpochBasedTrainLoop + +with read_base(): + from .detr_r50_8xb2_150e_coco import * + +# learning policy +max_epochs = 500 +train_cfg.update( + type=EpochBasedTrainLoop, max_epochs=max_epochs, val_interval=10) + +param_scheduler = [ + dict( + type=MultiStepLR, + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[334], + gamma=0.1) +] + +# only keep latest 2 checkpoints +default_hooks.update(checkpoint=dict(max_keep_ckpts=2)) diff --git a/mmdet/configs/dino/dino_4scale_r50_8xb2_12e_coco.py b/mmdet/configs/dino/dino_4scale_r50_8xb2_12e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..ab8e95a9a76c0cedb78c66993fc7fb7f4623029c --- /dev/null +++ b/mmdet/configs/dino/dino_4scale_r50_8xb2_12e_coco.py @@ -0,0 +1,190 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmcv.transforms import RandomChoice, RandomChoiceResize +from mmcv.transforms.loading import LoadImageFromFile +from mmengine.config import read_base +from mmengine.model.weight_init import PretrainedInit +from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper +from mmengine.optim.scheduler.lr_scheduler import MultiStepLR +from mmengine.runner.loops import EpochBasedTrainLoop, TestLoop, ValLoop +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.normalization import GroupNorm +from torch.optim.adamw import AdamW + +from mmdet.datasets.transforms import (LoadAnnotations, PackDetInputs, + RandomCrop, RandomFlip, Resize) +from mmdet.models import (DINO, ChannelMapper, DetDataPreprocessor, DINOHead, + ResNet) +from mmdet.models.losses.focal_loss import FocalLoss +from mmdet.models.losses.iou_loss import GIoULoss +from mmdet.models.losses.smooth_l1_loss import L1Loss +from mmdet.models.task_modules import (BBoxL1Cost, FocalLossCost, + HungarianAssigner, IoUCost) + +with read_base(): + from .._base_.datasets.coco_detection import * + from .._base_.default_runtime import * + +model = dict( + type=DINO, + num_queries=900, # num_matching_queries + with_box_refine=True, + as_two_stage=True, + data_preprocessor=dict( + type=DetDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=1), + backbone=dict( + type=ResNet, + depth=50, + num_stages=4, + out_indices=(1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type=BatchNorm2d, requires_grad=False), + norm_eval=True, + style='pytorch', + init_cfg=dict( + type=PretrainedInit, checkpoint='torchvision://resnet50')), + neck=dict( + type=ChannelMapper, + in_channels=[512, 1024, 2048], + kernel_size=1, + out_channels=256, + act_cfg=None, + norm_cfg=dict(type=GroupNorm, num_groups=32), + num_outs=4), + encoder=dict( + num_layers=6, + layer_cfg=dict( + self_attn_cfg=dict(embed_dims=256, num_levels=4, + dropout=0.0), # 0.1 for DeformDETR + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, # 1024 for DeformDETR + ffn_drop=0.0))), # 0.1 for DeformDETR + decoder=dict( + num_layers=6, + return_intermediate=True, + layer_cfg=dict( + self_attn_cfg=dict(embed_dims=256, num_heads=8, + dropout=0.0), # 0.1 for DeformDETR + cross_attn_cfg=dict(embed_dims=256, num_levels=4, + dropout=0.0), # 0.1 for DeformDETR + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, # 1024 for DeformDETR + ffn_drop=0.0)), # 0.1 for DeformDETR + post_norm_cfg=None), + positional_encoding=dict( + num_feats=128, + normalize=True, + offset=0.0, # -0.5 for DeformDETR + temperature=20), # 10000 for DeformDETR + bbox_head=dict( + type=DINOHead, + num_classes=80, + sync_cls_avg_factor=True, + loss_cls=dict( + type=FocalLoss, + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), # 2.0 in DeformDETR + loss_bbox=dict(type=L1Loss, loss_weight=5.0), + loss_iou=dict(type=GIoULoss, loss_weight=2.0)), + dn_cfg=dict( # TODO: Move to model.train_cfg ? 
+        label_noise_scale=0.5,
+        box_noise_scale=1.0,  # 0.4 for DN-DETR
+        group_cfg=dict(dynamic=True, num_groups=None,
+                       num_dn_queries=100)),  # TODO: half num_dn_queries
+    # training and testing settings
+    train_cfg=dict(
+        assigner=dict(
+            type=HungarianAssigner,
+            match_costs=[
+                dict(type=FocalLossCost, weight=2.0),
+                dict(type=BBoxL1Cost, weight=5.0, box_format='xywh'),
+                dict(type=IoUCost, iou_mode='giou', weight=2.0)
+            ])),
+    test_cfg=dict(max_per_img=300))  # 100 for DeformDETR
+
+# train_pipeline, NOTE the img_scale and the Pad's size_divisor is different
+# from the default setting in mmdet.
+train_pipeline = [
+    dict(type=LoadImageFromFile, backend_args=backend_args),
+    dict(type=LoadAnnotations, with_bbox=True),
+    dict(type=RandomFlip, prob=0.5),
+    dict(
+        type=RandomChoice,
+        transforms=[
+            [
+                dict(
+                    type=RandomChoiceResize,
+                    resize_type=Resize,
+                    scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                            (736, 1333), (768, 1333), (800, 1333)],
+                    keep_ratio=True)
+            ],
+            [
+                dict(
+                    type=RandomChoiceResize,
+                    resize_type=Resize,
+                    # The aspect ratio of all images in the train dataset
+                    # is < 7; this follows the original implementation
+                    scales=[(400, 4200), (500, 4200), (600, 4200)],
+                    keep_ratio=True),
+                dict(
+                    type=RandomCrop,
+                    crop_type='absolute_range',
+                    crop_size=(384, 600),
+                    allow_negative_crop=True),
+                dict(
+                    type=RandomChoiceResize,
+                    resize_type=Resize,
+                    scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                            (736, 1333), (768, 1333), (800, 1333)],
+                    keep_ratio=True)
+            ]
+        ]),
+    dict(type=PackDetInputs)
+]
+train_dataloader.update(
+    dataset=dict(
+        filter_cfg=dict(filter_empty_gt=False), pipeline=train_pipeline))
+
+# optimizer
+optim_wrapper = dict(
+    type=OptimWrapper,
+    optimizer=dict(
+        type=AdamW,
+        lr=0.0001,  # 0.0002 for DeformDETR
+        weight_decay=0.0001),
+    clip_grad=dict(max_norm=0.1, norm_type=2),
+    paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.1)})
+)  # custom_keys contains sampling_offsets and reference_points in DeformDETR # noqa
+
+# learning policy
+max_epochs = 12
+train_cfg = dict(
+    type=EpochBasedTrainLoop, max_epochs=max_epochs, val_interval=1)
+
+val_cfg = dict(type=ValLoop)
+test_cfg = dict(type=TestLoop)
+
+param_scheduler = [
+    dict(
+        type=MultiStepLR,
+        begin=0,
+        end=max_epochs,
+        by_epoch=True,
+        milestones=[11],
+        gamma=0.1)
+]
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR,
+# USER SHOULD NOT CHANGE ITS VALUES.
+# base_batch_size = (8 GPUs) x (2 samples per GPU)
+auto_scale_lr = dict(base_batch_size=16)
diff --git a/mmdet/configs/dino/dino_4scale_r50_8xb2_24e_coco.py b/mmdet/configs/dino/dino_4scale_r50_8xb2_24e_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..c10cc2184de8f71571759ecbeac56696afceb5eb
--- /dev/null
+++ b/mmdet/configs/dino/dino_4scale_r50_8xb2_24e_coco.py
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
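+# Added note: this 24e variant (and the 36e one after it) only re-reads the
+# 12e config via read_base, extends the schedule, and shifts the single LR
+# decay milestone near the end of training (epoch 20 of 24 here, 30 of 36
+# below).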
+from mmengine.config import read_base
+from mmengine.runner.loops import EpochBasedTrainLoop
+
+with read_base():
+    from .dino_4scale_r50_8xb2_12e_coco import *
+
+max_epochs = 24
+train_cfg.update(
+    dict(type=EpochBasedTrainLoop, max_epochs=max_epochs, val_interval=1))
+
+# Also extend the scheduler window to the new training length; otherwise the
+# base `end=12` would stop the scheduler before the milestone fires.
+param_scheduler[0].update(dict(end=max_epochs, milestones=[20]))
diff --git a/mmdet/configs/dino/dino_4scale_r50_8xb2_36e_coco.py b/mmdet/configs/dino/dino_4scale_r50_8xb2_36e_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..3779744322a19d2865f1e6299aba564c4ec1e3d5
--- /dev/null
+++ b/mmdet/configs/dino/dino_4scale_r50_8xb2_36e_coco.py
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.config import read_base
+from mmengine.runner.loops import EpochBasedTrainLoop
+
+with read_base():
+    from .dino_4scale_r50_8xb2_12e_coco import *
+
+max_epochs = 36
+train_cfg.update(
+    dict(type=EpochBasedTrainLoop, max_epochs=max_epochs, val_interval=1))
+
+param_scheduler[0].update(dict(end=max_epochs, milestones=[30]))
diff --git a/mmdet/configs/dino/dino_4scale_r50_improved_8xb2_12e_coco.py b/mmdet/configs/dino/dino_4scale_r50_improved_8xb2_12e_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..43c07201079fcdbad3c9ea7a471306080e006cdc
--- /dev/null
+++ b/mmdet/configs/dino/dino_4scale_r50_improved_8xb2_12e_coco.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.config import read_base
+
+with read_base():
+    from .dino_4scale_r50_8xb2_12e_coco import *
+
+# hyper-parameters taken from Deformable DETR
+model.update(
+    dict(
+        backbone=dict(frozen_stages=-1),
+        bbox_head=dict(loss_cls=dict(loss_weight=2.0)),
+        positional_encoding=dict(offset=-0.5, temperature=10000),
+        dn_cfg=dict(group_cfg=dict(num_dn_queries=300))))
+
+# optimizer
+optim_wrapper.update(
+    dict(
+        optimizer=dict(lr=0.0002),
+        paramwise_cfg=dict(
+            custom_keys={
+                'backbone': dict(lr_mult=0.1),
+                'sampling_offsets': dict(lr_mult=0.1),
+                'reference_points': dict(lr_mult=0.1)
+            })))
diff --git a/mmdet/configs/dino/dino_5scale_swin_l_8xb2_12e_coco.py b/mmdet/configs/dino/dino_5scale_swin_l_8xb2_12e_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..25aac0187ab2472dd062514fecf988dcd47504a5
--- /dev/null
+++ b/mmdet/configs/dino/dino_5scale_swin_l_8xb2_12e_coco.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
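+# Added note (an inference from how ChannelMapper behaves, not stated in this
+# diff): "5scale" means num_levels = 5 feature levels; the Swin backbone
+# below still emits four stages (out_indices=(0, 1, 2, 3)), and the neck with
+# num_outs=num_levels synthesizes the fifth, extra-stride feature map.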
+from mmengine.config import read_base
+from mmengine.model.weight_init import PretrainedInit
+
+from mmdet.models import SwinTransformer
+
+with read_base():
+    from .dino_4scale_r50_8xb2_12e_coco import *
+
+pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth'  # noqa
+num_levels = 5
+model.merge(
+    dict(
+        num_feature_levels=num_levels,
+        backbone=dict(
+            _delete_=True,
+            type=SwinTransformer,
+            pretrain_img_size=384,
+            embed_dims=192,
+            depths=[2, 2, 18, 2],
+            num_heads=[6, 12, 24, 48],
+            window_size=12,
+            mlp_ratio=4,
+            qkv_bias=True,
+            qk_scale=None,
+            drop_rate=0.,
+            attn_drop_rate=0.,
+            drop_path_rate=0.2,
+            patch_norm=True,
+            out_indices=(0, 1, 2, 3),
+            # Please only add indices that would be used
+            # in FPN, otherwise some parameters will not be used
+            with_cp=True,
+            convert_weights=True,
+            init_cfg=dict(type=PretrainedInit, checkpoint=pretrained)),
+        neck=dict(in_channels=[192, 384, 768, 1536], num_outs=num_levels),
+        encoder=dict(
+            layer_cfg=dict(self_attn_cfg=dict(num_levels=num_levels))),
+        decoder=dict(
+            layer_cfg=dict(cross_attn_cfg=dict(num_levels=num_levels)))))
diff --git a/mmdet/configs/dino/dino_5scale_swin_l_8xb2_36e_coco.py b/mmdet/configs/dino/dino_5scale_swin_l_8xb2_36e_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..494acf59f1c31fe419415920e8b65fbfb9267df1
--- /dev/null
+++ b/mmdet/configs/dino/dino_5scale_swin_l_8xb2_36e_coco.py
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.config import read_base
+from mmengine.runner.loops import EpochBasedTrainLoop
+
+with read_base():
+    from .dino_5scale_swin_l_8xb2_12e_coco import *
+
+max_epochs = 36
+train_cfg.update(
+    dict(type=EpochBasedTrainLoop, max_epochs=max_epochs, val_interval=1))
+
+param_scheduler[0].update(dict(end=max_epochs, milestones=[27, 33]))
diff --git a/mmdet/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py b/mmdet/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0a6d5a21470752fd26fa162edf5c2241afb1fed
--- /dev/null
+++ b/mmdet/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .._base_.datasets.coco_detection import *
+    from .._base_.default_runtime import *
+    from .._base_.models.faster_rcnn_r50_fpn import *
+    from .._base_.schedules.schedule_1x import *
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r101_caffe_fpn_1x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r101_caffe_fpn_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..2780f4afddc05ccd4ae1746206a6a6ad8cece39e
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r101_caffe_fpn_1x_coco.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .mask_rcnn_r50_fpn_poly_1x_coco import *
+
+from mmengine.model.weight_init import PretrainedInit
+
+model.update(
+    dict(
+        backbone=dict(
+            depth=101,
+            init_cfg=dict(
+                type=PretrainedInit,
+                checkpoint='open-mmlab://detectron2/resnet101_caffe'))))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r101_caffe_fpn_ms_poly_3x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r101_caffe_fpn_ms_poly_3x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a1badfc4f04f6ad5466d9ec3aa2d07708887927
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r101_caffe_fpn_ms_poly_3x_coco.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from ..common.ms_poly_3x_coco_instance import *
+    from .._base_.models.mask_rcnn_r50_fpn import *
+
+from mmengine.model.weight_init import PretrainedInit
+
+model.update(
+    dict(
+        # use caffe img_norm
+        data_preprocessor=dict(
+            mean=[103.530, 116.280, 123.675],
+            std=[1.0, 1.0, 1.0],
+            bgr_to_rgb=False),
+        backbone=dict(
+            depth=101,
+            norm_cfg=dict(requires_grad=False),
+            norm_eval=True,
+            style='caffe',
+            init_cfg=dict(
+                type=PretrainedInit,
+                checkpoint='open-mmlab://detectron2/resnet101_caffe'))))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r101_fpn_1x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r101_fpn_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..6770cec8eebe8c5130abd15f9bc44d5b5c5db875
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r101_fpn_1x_coco.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+# Base on the full 1x config so the dataset, runtime and schedule come along
+# with the model.
+with read_base():
+    from .mask_rcnn_r50_fpn_1x_coco import *
+
+from mmengine.model.weight_init import PretrainedInit
+
+model.update(
+    dict(
+        backbone=dict(
+            depth=101,
+            init_cfg=dict(
+                type=PretrainedInit, checkpoint='torchvision://resnet101'))))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r101_fpn_2x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r101_fpn_2x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd2aafb912ca84f776637e498d2743213a05d18a
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r101_fpn_2x_coco.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .mask_rcnn_r50_fpn_2x_coco import *
+
+from mmengine.model.weight_init import PretrainedInit
+
+model.update(
+    dict(
+        backbone=dict(
+            depth=101,
+            init_cfg=dict(
+                type=PretrainedInit, checkpoint='torchvision://resnet101'))))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r101_fpn_8xb8_amp_lsj_200e_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r101_fpn_8xb8_amp_lsj_200e_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..665808d5dc479ecb7c5a328af3861f59e460ac78
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r101_fpn_8xb8_amp_lsj_200e_coco.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .mask_rcnn_r18_fpn_8xb8_amp_lsj_200e_coco import *
+
+from mmengine.model.weight_init import PretrainedInit
+
+model.update(
+    dict(
+        backbone=dict(
+            depth=101,
+            init_cfg=dict(
+                type=PretrainedInit, checkpoint='torchvision://resnet101'))))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r101_fpn_ms_poly_3x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r101_fpn_ms_poly_3x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..14688795963cb28018f5897429b191b235a86b6b
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r101_fpn_ms_poly_3x_coco.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from ..common.ms_poly_3x_coco_instance import *
+    from .._base_.models.mask_rcnn_r50_fpn import *
+
+from mmengine.model.weight_init import PretrainedInit
+
+model.update(
+    dict(
+        backbone=dict(
+            depth=101,
+            init_cfg=dict(
+                type=PretrainedInit, checkpoint='torchvision://resnet101'))))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r18_fpn_8xb8_amp_lsj_200e_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r18_fpn_8xb8_amp_lsj_200e_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..67bd86fa0e8f8b414eec681852511db3b3d4c9c6
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r18_fpn_8xb8_amp_lsj_200e_coco.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .mask_rcnn_r50_fpn_8xb8_amp_lsj_200e_coco import *
+
+from mmengine.model.weight_init import PretrainedInit
+
+model.update(
+    dict(
+        backbone=dict(
+            depth=18,
+            init_cfg=dict(
+                type=PretrainedInit, checkpoint='torchvision://resnet18')),
+        neck=dict(in_channels=[64, 128, 256, 512])))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_c4_1x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_c4_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..494e6ba593efa663f06e1383ceba8b57b9d097b5
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_c4_1x_coco.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .._base_.datasets.coco_instance import *
+    from .._base_.default_runtime import *
+    from .._base_.models.mask_rcnn_r50_caffe_c4 import *
+    from .._base_.schedules.schedule_1x import *
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_1x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..6481fcfd49eeac603eced8e46ee3a8705add8367
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_1x_coco.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .mask_rcnn_r50_fpn_1x_coco import *
+
+from mmengine.model.weight_init import PretrainedInit
+
+model.update(
+    dict(
+        # use caffe img_norm
+        data_preprocessor=dict(
+            mean=[103.530, 116.280, 123.675],
+            std=[1.0, 1.0, 1.0],
+            bgr_to_rgb=False),
+        backbone=dict(
+            norm_cfg=dict(requires_grad=False),
+            style='caffe',
+            init_cfg=dict(
+                type=PretrainedInit,
+                checkpoint='open-mmlab://detectron2/resnet50_caffe'))))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_ms_1x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_ms_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..5952ed587a431740bc3d17ac9d2e6b5a3d326061
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_ms_1x_coco.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .mask_rcnn_r50_fpn_1x_coco import *
+
+from mmcv.transforms import RandomChoiceResize
+from mmengine.model.weight_init import PretrainedInit
+
+model.update(
+    dict(
+        # use caffe img_norm
+        data_preprocessor=dict(
+            mean=[103.530, 116.280, 123.675],
+            std=[1.0, 1.0, 1.0],
+            bgr_to_rgb=False),
+        backbone=dict(
+            norm_cfg=dict(requires_grad=False),
+            style='caffe',
+            init_cfg=dict(
+                type=PretrainedInit,
+                checkpoint='open-mmlab://detectron2/resnet50_caffe'))))
+
+# `{{_base_.backend_args}}` templating is old-style config syntax; in a pure
+# Python config the star-imported `backend_args` is referenced directly.
+train_pipeline = [
+    dict(type=LoadImageFromFile, backend_args=backend_args),
+    dict(type=LoadAnnotations, with_bbox=True, with_mask=True),
+    dict(
+        type=RandomChoiceResize,
+        scales=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
+                (1333, 768), (1333, 800)],
+        keep_ratio=True),
+    dict(type=RandomFlip, prob=0.5),
+    dict(type=PackDetInputs),
+]
+
+train_dataloader.update(dict(dataset=dict(pipeline=train_pipeline)))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_ms_poly_1x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_ms_poly_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..d62b9ebe958b3a8f790a6e9581942494f42bf7d6
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_ms_poly_1x_coco.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .mask_rcnn_r50_fpn_1x_coco import *
+
+from mmcv.transforms import RandomChoiceResize
+from mmengine.model.weight_init import PretrainedInit
+
+model.update(
+    dict(
+        # use caffe img_norm
+        data_preprocessor=dict(
+            mean=[103.530, 116.280, 123.675],
+            std=[1.0, 1.0, 1.0],
+            bgr_to_rgb=False),
+        backbone=dict(
+            norm_cfg=dict(requires_grad=False),
+            style='caffe',
+            init_cfg=dict(
+                type=PretrainedInit,
+                checkpoint='open-mmlab://detectron2/resnet50_caffe'))))
+train_pipeline = [
+    dict(type=LoadImageFromFile, backend_args=backend_args),
+    dict(
+        type=LoadAnnotations, with_bbox=True, with_mask=True, poly2mask=False),
+    dict(
+        type=RandomChoiceResize,
+        scales=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
+                (1333, 768), (1333, 800)],
+        keep_ratio=True),
+    dict(type=RandomFlip, prob=0.5),
+    dict(type=PackDetInputs)
+]
+
+train_dataloader.update(dict(dataset=dict(pipeline=train_pipeline)))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_ms_poly_2x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_ms_poly_2x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa41b7e00ca153814f28ac29638cc497e7a2d3e9
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_ms_poly_2x_coco.py
@@ -0,0 +1,23 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .mask_rcnn_r50_caffe_fpn_ms_poly_1x_coco import *
+
+train_cfg.update(dict(max_epochs=24))
+# learning rate
+param_scheduler = [
+    dict(type=LinearLR, start_factor=0.001, by_epoch=False, begin=0, end=500),
+    dict(
+        type=MultiStepLR,
+        begin=0,
+        end=24,
+        by_epoch=True,
+        milestones=[16, 22],
+        gamma=0.1)
+]
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_ms_poly_3x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_ms_poly_3x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5f9b977b2dfaebfe01d834ac4ad8cf4522fe9c0
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_ms_poly_3x_coco.py
@@ -0,0 +1,23 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .mask_rcnn_r50_caffe_fpn_ms_poly_1x_coco import *
+
+train_cfg.update(dict(max_epochs=36))
+# learning rate; the scheduler must span all 36 epochs for the
+# milestones at 28 and 34 to take effect
+param_scheduler = [
+    dict(type=LinearLR, start_factor=0.001, by_epoch=False, begin=0, end=500),
+    dict(
+        type=MultiStepLR,
+        begin=0,
+        end=36,
+        by_epoch=True,
+        milestones=[28, 34],
+        gamma=0.1)
+]
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_poly_1x_coco_v1.py b/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_poly_1x_coco_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..28ba7c77ddf10d295d371db2f46d6c1f117ac7c6
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_poly_1x_coco_v1.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .mask_rcnn_r50_fpn_1x_coco import *
+
+from mmengine.model.weight_init import PretrainedInit
+
+from mmdet.models.losses import SmoothL1Loss
+
+model.update(
+    dict(
+        # use caffe img_norm
+        data_preprocessor=dict(
+            mean=[103.530, 116.280, 123.675],
+            std=[1.0, 1.0, 1.0],
+            bgr_to_rgb=False),
+        backbone=dict(
+            norm_cfg=dict(requires_grad=False),
+            style='caffe',
+            init_cfg=dict(
+                type=PretrainedInit,
+                checkpoint='open-mmlab://detectron2/resnet50_caffe')),
+        rpn_head=dict(
+            loss_bbox=dict(type=SmoothL1Loss, beta=1.0 / 9.0,
+                           loss_weight=1.0)),
+        roi_head=dict(
+            bbox_roi_extractor=dict(
+                roi_layer=dict(
+                    type=RoIAlign, output_size=7, sampling_ratio=2,
+                    aligned=False)),
+            bbox_head=dict(
+                loss_bbox=dict(type=SmoothL1Loss, beta=1.0, loss_weight=1.0)),
+            mask_roi_extractor=dict(
+                roi_layer=dict(
+                    type=RoIAlign, output_size=14, sampling_ratio=2,
+                    aligned=False)))))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..8145d08fee85c1758d3794cee952a3b7200b14bd
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .._base_.datasets.coco_instance import *
+    from .._base_.default_runtime import *
+    from .._base_.models.mask_rcnn_r50_fpn import *
+    from .._base_.schedules.schedule_1x import *
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_wandb_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_wandb_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2c0876541289d10832a1b26ddb6e91f6a66d89a
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_wandb_coco.py
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .._base_.datasets.coco_instance import *
+    from .._base_.default_runtime import *
+    from .._base_.models.mask_rcnn_r50_fpn import *
+    from .._base_.schedules.schedule_1x import *
+
+from mmengine.visualization import LocalVisBackend, WandbVisBackend
+
+# `vis_backends` from the base runtime is a list, so it is redefined here
+# rather than `update`d.
+vis_backends = [dict(type=LocalVisBackend), dict(type=WandbVisBackend)]
+visualizer.update(dict(vis_backends=vis_backends))
+
+# MMEngine supports the following two ways; users can choose
+# according to convenience
+# default_hooks.checkpoint.interval = 4
+default_hooks.update(dict(checkpoint=dict(interval=4)))
+
+# train_cfg.val_interval = 2
+train_cfg.update(dict(val_interval=2))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_2x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_2x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..6be010b4508d6ba300a1305a1d405ec9a265ae07
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_2x_coco.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .._base_.datasets.coco_instance import *
+    from .._base_.default_runtime import *
+    from .._base_.models.mask_rcnn_r50_fpn import *
+    from .._base_.schedules.schedule_2x import *
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_8xb8_amp_lsj_200e_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_8xb8_amp_lsj_200e_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef101fec61e72abc0eb90266d453b5b22331378d
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_8xb8_amp_lsj_200e_coco.py
@@ -0,0 +1 @@
+# Copyright (c) OpenMMLab. All rights reserved.
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_amp_1x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_amp_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..110c3c475429701a92321676d17f829f82cbfb76
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_amp_1x_coco.py
@@ -0,0 +1,14 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .mask_rcnn_r50_fpn_1x_coco import *
+
+from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper
+
+optim_wrapper.update(dict(type=AmpOptimWrapper))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_ms_poly_-3x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_ms_poly_-3x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff4eec6d2be0f4bd61c7bd04057fd58b303120c8
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_ms_poly_-3x_coco.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .._base_.models.mask_rcnn_r50_fpn import *
+    from ..common.ms_poly_3x_coco_instance import *
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_poly_1x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_poly_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..012e711cb96f9aa67460b86694838d592fd1ae25
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_r50_fpn_poly_1x_coco.py
@@ -0,0 +1,23 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .._base_.datasets.coco_instance import *
+    from .._base_.default_runtime import *
+    from .._base_.models.mask_rcnn_r50_fpn import *
+    from .._base_.schedules.schedule_1x import *
+
+train_pipeline = [
+    dict(type=LoadImageFromFile, backend_args=backend_args),
+    dict(
+        type=LoadAnnotations, with_bbox=True, with_mask=True, poly2mask=False),
+    dict(type=Resize, scale=(1333, 800), keep_ratio=True),
+    dict(type=RandomFlip, prob=0.5),
+    dict(type=PackDetInputs),
+]
+train_dataloader.update(dict(dataset=dict(pipeline=train_pipeline)))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x4d_fpn_1x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x4d_fpn_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..5429b1bd5a62f4786936d19e65d6281807d800bf
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x4d_fpn_1x_coco.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .mask_rcnn_r101_fpn_1x_coco import *
+
+from mmengine.model.weight_init import PretrainedInit
+
+from mmdet.models.backbones.resnext import ResNeXt
+
+model.update(
+    dict(
+        backbone=dict(
+            type=ResNeXt,
+            depth=101,
+            groups=32,
+            base_width=4,
+            num_stages=4,
+            out_indices=(0, 1, 2, 3),
+            frozen_stages=1,
+            norm_cfg=dict(type=BatchNorm2d, requires_grad=True),
+            style='pytorch',
+            init_cfg=dict(
+                type=PretrainedInit,
+                checkpoint='open-mmlab://resnext101_32x4d'))))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x4d_fpn_2x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x4d_fpn_2x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebae6c1dbc3a234ede68ba5b7a6e199edf966ead
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x4d_fpn_2x_coco.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .mask_rcnn_r50_fpn_2x_coco import *
+
+from mmengine.model.weight_init import PretrainedInit
+
+from mmdet.models import ResNeXt
+
+model.update(
+    dict(
+        backbone=dict(
+            type=ResNeXt,
+            depth=101,
+            groups=32,
+            base_width=4,
+            num_stages=4,
+            out_indices=(0, 1, 2, 3),
+            frozen_stages=1,
+            norm_cfg=dict(type=BatchNorm2d, requires_grad=True),
+            style='pytorch',
+            init_cfg=dict(
+                type=PretrainedInit,
+                checkpoint='open-mmlab://resnext101_32x4d'))))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x4d_fpn_ms_poly_3x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x4d_fpn_ms_poly_3x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..aff45d89f351037cec3115271feab678eac3382f
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x4d_fpn_ms_poly_3x_coco.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from ..common.ms_poly_3x_coco_instance import *
+    from .._base_.models.mask_rcnn_r50_fpn import *
+
+from mmengine.model.weight_init import PretrainedInit
+
+from mmdet.models.backbones import ResNeXt
+
+model.update(
+    dict(
+        backbone=dict(
+            type=ResNeXt,
+            depth=101,
+            groups=32,
+            base_width=4,
+            num_stages=4,
+            out_indices=(0, 1, 2, 3),
+            frozen_stages=1,
+            norm_cfg=dict(type=BatchNorm2d, requires_grad=True),
+            style='pytorch',
+            init_cfg=dict(
+                type=PretrainedInit,
+                checkpoint='open-mmlab://resnext101_32x4d'))))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x8d_fpn_1x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x8d_fpn_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9f2095dc2dff7396896a9b2af2fb05bcd765c69
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x8d_fpn_1x_coco.py
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .mask_rcnn_x101_32x4d_fpn_1x_coco import *
+
+model.update(
+    dict(
+        # ResNeXt-101-32x8d model trained with Caffe2 at FB,
+        # so the mean and std need to be changed.
+        data_preprocessor=dict(
+            mean=[103.530, 116.280, 123.675],
+            std=[57.375, 57.120, 58.395],
+            bgr_to_rgb=False),
+        backbone=dict(
+            type=ResNeXt,
+            depth=101,
+            groups=32,
+            base_width=8,
+            num_stages=4,
+            out_indices=(0, 1, 2, 3),
+            frozen_stages=1,
+            norm_cfg=dict(type=BatchNorm2d, requires_grad=False),
+            style='pytorch',
+            init_cfg=dict(
+                type=PretrainedInit,
+                checkpoint='open-mmlab://detectron2/resnext101_32x8d'))))
diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x8d_fpn_ms_poly_1x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x8d_fpn_ms_poly_1x_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..8eded941751ce71b9c63baa565275802c7ee9bb2
--- /dev/null
+++ b/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x8d_fpn_ms_poly_1x_coco.py
@@ -0,0 +1,54 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
+# mmcv >= 2.0.1
+# mmengine >= 0.8.0
+
+from mmengine.config import read_base
+
+with read_base():
+    from .mask_rcnn_r101_fpn_1x_coco import *
+
+from mmcv.transforms import RandomChoiceResize, RandomFlip
+from mmcv.transforms.loading import LoadImageFromFile
+
+from mmdet.datasets.transforms.formatting import PackDetInputs
+from mmdet.datasets.transforms.loading import LoadAnnotations
+from mmdet.models.backbones import ResNeXt
+
+model.update(
+    dict(
+        # ResNeXt-101-32x8d model trained with Caffe2 at FB,
+        # so the mean and std need to be changed.
+ data_preprocessor=dict( + mean=[103.530, 116.280, 123.675], + std=[57.375, 57.120, 58.395], + bgr_to_rgb=False), + backbone=dict( + type=ResNeXt, + depth=101, + groups=32, + base_width=8, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type=BatchNorm2d, requires_grad=False), + style='pytorch', + init_cfg=dict( + type=PretrainedInit, + checkpoint='open-mmlab://detectron2/resnext101_32x8d'))) + +backend_args = None +train_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict( + type=LoadAnnotations, with_bbox=True, with_mask=True, poly2mask=False), + dict( + type=RandomChoiceResize, + scales=[(1333, 640), (1333, 672), (1333, 704), (1333, 736), + (1333, 768), (1333, 800)], + keep_ratio=True), + dict(type=RandomFlip, prob=0.5), + dict(type=PackDetInputs), +] + +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x8d_fpn_ms_poly_3x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x8d_fpn_ms_poly_3x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f584675f6da93ec7c188753c2f0478bac25ba8 --- /dev/null +++ b/mmdet/configs/mask_rcnn/mask_rcnn_x101_32x8d_fpn_ms_poly_3x_coco.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from ..common.ms_poly_3x_coco_instance import * + from .._base_.models.mask_rcnn_r50_fpn import * + +from mmdet.models.backbones import ResNeXt + +model = dict( + # ResNeXt-101-32x8d model trained with Caffe2 at FB, + # so the mean and std need to be changed. + data_preprocessor=dict( + mean=[103.530, 116.280, 123.675], + std=[57.375, 57.120, 58.395], + bgr_to_rgb=False), + backbone=dict( + type=ResNeXt, + depth=101, + groups=32, + base_width=8, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type=BatchNorm2d, requires_grad=False), + style='pytorch', + init_cfg=dict( + type=PretrainedInit, + checkpoint='open-mmlab://detectron2/resnext101_32x8d'))) diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_x101_64_4d_fpn_1x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_x101_64_4d_fpn_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..8bb6f636e641138b902f69a543da0bd8a656db3d --- /dev/null +++ b/mmdet/configs/mask_rcnn/mask_rcnn_x101_64_4d_fpn_1x_coco.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. 
# noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .mask_rcnn_x101_32x4d_fpn_1x_coco import * + +model = dict( + backbone=dict( + type=ResNeXt, + depth=101, + groups=64, + base_width=4, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type=BatchNorm2d, requires_grad=True), + style='pytorch', + init_cfg=dict( + type=PretrainedInit, checkpoint='open-mmlab://resnext101_64x4d'))) diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_x101_64x4d_fpn_2x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_x101_64x4d_fpn_2x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..d661076dcf37df9668d6cdf726ecfc5720c561df --- /dev/null +++ b/mmdet/configs/mask_rcnn/mask_rcnn_x101_64x4d_fpn_2x_coco.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .mask_rcnn_x101_32x4d_fpn_2x_coco import * + +model = dict( + backbone=dict( + type=ResNeXt, + depth=101, + groups=64, + base_width=4, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type=BatchNorm2d, requires_grad=True), + style='pytorch', + init_cfg=dict( + type=PretrainedInit, checkpoint='open-mmlab://resnext101_64x4d'))) diff --git a/mmdet/configs/mask_rcnn/mask_rcnn_x101_64x4d_fpn_ms_poly_3x_coco.py b/mmdet/configs/mask_rcnn/mask_rcnn_x101_64x4d_fpn_ms_poly_3x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..d9ab3643ec27665f7a9411d95c7e01711dfe7623 --- /dev/null +++ b/mmdet/configs/mask_rcnn/mask_rcnn_x101_64x4d_fpn_ms_poly_3x_coco.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from ..common.ms_poly_3x_coco_instance import * + from .._base_.models.mask_rcnn_r50_fpn import * + +from mmdet.models.backbones import ResNeXt + +model = dict( + backbone=dict( + type=ResNeXt, + depth=101, + groups=64, + base_width=4, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type=BatchNorm2d, requires_grad=True), + style='pytorch', + init_cfg=dict( + type=PretrainedInit, checkpoint='open-mmlab://resnext101_64x4d'))) diff --git a/mmdet/configs/maskformer/maskformer_r50_ms_16xb1_75e_coco.py b/mmdet/configs/maskformer/maskformer_r50_ms_16xb1_75e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..70744013afcad76834d05ccc8aa6303dc6399bc0 --- /dev/null +++ b/mmdet/configs/maskformer/maskformer_r50_ms_16xb1_75e_coco.py @@ -0,0 +1,249 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
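+
+# This is a pure-Python style config: the imports below bind concrete
+# classes (losses, assigners, schedulers) instead of registry strings.
+# A minimal sketch of loading it for inspection (assuming mmengine >= 0.8.0
+# and this file's path within the repo):
+#
+#     from mmengine.config import Config
+#     cfg = Config.fromfile(
+#         'mmdet/configs/maskformer/maskformer_r50_ms_16xb1_75e_coco.py')
+#     print(cfg.model['type'], cfg.train_cfg['max_epochs'])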
+from mmcv.transforms import RandomChoice, RandomChoiceResize +from mmengine.config import read_base +from mmengine.model.weight_init import PretrainedInit +from mmengine.optim.optimizer import OptimWrapper +from mmengine.optim.scheduler import MultiStepLR +from mmengine.runner import EpochBasedTrainLoop, TestLoop, ValLoop +from torch.nn.modules.activation import ReLU +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.normalization import GroupNorm +from torch.optim.adamw import AdamW + +from mmdet.datasets.transforms.transforms import RandomCrop +from mmdet.models import MaskFormer +from mmdet.models.backbones import ResNet +from mmdet.models.data_preprocessors.data_preprocessor import \ + DetDataPreprocessor +from mmdet.models.dense_heads.maskformer_head import MaskFormerHead +from mmdet.models.layers.pixel_decoder import TransformerEncoderPixelDecoder +from mmdet.models.losses import CrossEntropyLoss, DiceLoss, FocalLoss +from mmdet.models.seg_heads.panoptic_fusion_heads import MaskFormerFusionHead +from mmdet.models.task_modules.assigners.hungarian_assigner import \ + HungarianAssigner +from mmdet.models.task_modules.assigners.match_cost import (ClassificationCost, + DiceCost, + FocalLossCost) +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +with read_base(): + from .._base_.datasets.coco_panoptic import * + from .._base_.default_runtime import * + +data_preprocessor = dict( + type=DetDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=1, + pad_mask=True, + mask_pad_value=0, + pad_seg=True, + seg_pad_value=255) + +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes +model = dict( + type=MaskFormer, + data_preprocessor=data_preprocessor, + backbone=dict( + type=ResNet, + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type=BatchNorm2d, requires_grad=False), + norm_eval=True, + style='pytorch', + init_cfg=dict( + type=PretrainedInit, checkpoint='torchvision://resnet50')), + panoptic_head=dict( + type=MaskFormerHead, + in_channels=[256, 512, 1024, 2048], # pass to pixel_decoder inside + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=100, + pixel_decoder=dict( + type=TransformerEncoderPixelDecoder, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.1, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.1, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # DetrTransformerDecoder + num_layers=6, + layer_cfg=dict( # DetrTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.1, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.1, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.1, + act_cfg=dict(type=ReLU, inplace=True))), + return_intermediate=True), + loss_cls=dict( + 
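+            # Per-query classification over the 80 + 53 = 133 classes plus
+            # one extra "no object" class; the trailing 0.1 in class_weight
+            # down-weights that extra class, following MaskFormer.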
type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=1.0, + reduction='mean', + class_weight=[1.0] * num_classes + [0.1]), + loss_mask=dict( + type=FocalLoss, + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + reduction='mean', + loss_weight=20.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=1.0)), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), + train_cfg=dict( + assigner=dict( + type=HungarianAssigner, + match_costs=[ + dict(type=ClassificationCost, weight=1.0), + dict(type=FocalLossCost, weight=20.0, binary_input=True), + dict(type=DiceCost, weight=1.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=False, + # max_per_image is for instance segmentation. + max_per_image=100, + object_mask_thr=0.8, + iou_thr=0.8, + # In MaskFormer's panoptic postprocessing, + # it will not filter masks whose score is smaller than 0.5 . + filter_low_score=False), + init_cfg=None) + +# dataset settings +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=LoadPanopticAnnotations, + with_bbox=True, + with_mask=True, + with_seg=True), + dict(type=RandomFlip, prob=0.5), + # dict(type=Resize, scale=(1333, 800), keep_ratio=True), + dict( + type=RandomChoice, + transforms=[[ + dict( + type=RandomChoiceResize, + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + resize_type=Resize, + keep_ratio=True) + ], + [ + dict( + type=RandomChoiceResize, + scales=[(400, 1333), (500, 1333), (600, 1333)], + resize_type=Resize, + keep_ratio=True), + dict( + type=RandomCrop, + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict( + type=RandomChoiceResize, + scales=[(480, 1333), (512, 1333), (544, 1333), + (576, 1333), (608, 1333), (640, 1333), + (672, 1333), (704, 1333), (736, 1333), + (768, 1333), (800, 1333)], + resize_type=Resize, + keep_ratio=True) + ]]), + dict(type=PackDetInputs) +] + +train_dataloader.update( + dict(batch_size=1, num_workers=1, dataset=dict(pipeline=train_pipeline))) + +val_dataloader.update(dict(batch_size=1, num_workers=1)) + +test_dataloader = val_dataloader + +# optimizer +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict( + type=AdamW, + lr=0.0001, + weight_decay=0.0001, + eps=1e-8, + betas=(0.9, 0.999)), + paramwise_cfg=dict( + custom_keys={ + 'backbone': dict(lr_mult=0.1, decay_mult=1.0), + 'query_embed': dict(lr_mult=1.0, decay_mult=0.0) + }, + norm_decay_mult=0.0), + clip_grad=dict(max_norm=0.01, norm_type=2)) + +max_epochs = 75 + +# learning rate +param_scheduler = dict( + type=MultiStepLR, + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[50], + gamma=0.1) + +train_cfg = dict( + type=EpochBasedTrainLoop, max_epochs=max_epochs, val_interval=1) +val_cfg = dict(type=ValLoop) +test_cfg = dict(type=TestLoop) + +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (16 GPUs) x (1 samples per GPU). 
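+# When enable=True, mmengine rescales the optimizer lr linearly with the
+# actual total batch size. A minimal sketch of the rule (assuming the
+# standard linear-scaling behaviour, not the exact implementation):
+#
+#     scaled_lr = base_lr * (num_gpus * samples_per_gpu) / base_batch_size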
+auto_scale_lr = dict(enable=False, base_batch_size=16) diff --git a/mmdet/configs/maskformer/maskformer_swin_l_p4_w12_64xb1_ms_300e_coco.py b/mmdet/configs/maskformer/maskformer_swin_l_p4_w12_64xb1_ms_300e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..2affe520918d0f26c0a858f97bb69646a2860f87 --- /dev/null +++ b/mmdet/configs/maskformer/maskformer_swin_l_p4_w12_64xb1_ms_300e_coco.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base +from mmengine.optim.scheduler import LinearLR + +from mmdet.models.backbones import SwinTransformer +from mmdet.models.layers import PixelDecoder + +with read_base(): + from .maskformer_r50_ms_16xb1_75e_coco import * + +pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth' # noqa +depths = [2, 2, 18, 2] +model.update( + dict( + backbone=dict( + _delete_=True, + type=SwinTransformer, + pretrain_img_size=384, + embed_dims=192, + patch_size=4, + window_size=12, + mlp_ratio=4, + depths=depths, + num_heads=[6, 12, 24, 48], + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.3, + patch_norm=True, + out_indices=(0, 1, 2, 3), + with_cp=False, + convert_weights=True, + init_cfg=dict(type=PretrainedInit, checkpoint=pretrained)), + panoptic_head=dict( + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + pixel_decoder=dict( + _delete_=True, + type=PixelDecoder, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU)), + enforce_decoder_input_project=True))) + +# optimizer + +# weight_decay = 0.01 +# norm_weight_decay = 0.0 +# embed_weight_decay = 0.0 +embed_multi = dict(lr_mult=1.0, decay_mult=0.0) +norm_multi = dict(lr_mult=1.0, decay_mult=0.0) +custom_keys = { + 'norm': norm_multi, + 'absolute_pos_embed': embed_multi, + 'relative_position_bias_table': embed_multi, + 'query_embed': embed_multi +} + +optim_wrapper.update( + dict( + optimizer=dict(lr=6e-5, weight_decay=0.01), + paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0))) + +max_epochs = 300 + +# learning rate +param_scheduler = [ + dict(type=LinearLR, start_factor=1e-6, by_epoch=False, begin=0, end=1500), + dict( + type=MultiStepLR, + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[250], + gamma=0.1) +] + +train_cfg.update(dict(max_epochs=max_epochs)) + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. +# base_batch_size = (64 GPUs) x (1 samples per GPU) +auto_scale_lr.update(dict(base_batch_size=64)) diff --git a/mmdet/configs/panoptic_fpn/panoptic_fpn_r50_fpn_1x_coco.py b/mmdet/configs/panoptic_fpn/panoptic_fpn_r50_fpn_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..fc8932803ca0d1fd52bee7d450fc12898e0ec7b3 --- /dev/null +++ b/mmdet/configs/panoptic_fpn/panoptic_fpn_r50_fpn_1x_coco.py @@ -0,0 +1,64 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. 
# noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mask_rcnn_r50_fpn import * + from .._base_.datasets.coco_panoptic import * + from .._base_.schedules.schedule_1x import * + from .._base_.default_runtime import * + +from mmcv.ops import nms +from torch.nn import GroupNorm + +from mmdet.models.data_preprocessors.data_preprocessor import \ + DetDataPreprocessor +from mmdet.models.detectors.panoptic_fpn import PanopticFPN +from mmdet.models.losses.cross_entropy_loss import CrossEntropyLoss +from mmdet.models.seg_heads.panoptic_fpn_head import PanopticFPNHead +from mmdet.models.seg_heads.panoptic_fusion_heads import HeuristicFusionHead + +model.update( + dict( + type=PanopticFPN, + data_preprocessor=dict( + type=DetDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + pad_mask=True, + mask_pad_value=0, + pad_seg=True, + seg_pad_value=255), + semantic_head=dict( + type=PanopticFPNHead, + num_things_classes=80, + num_stuff_classes=53, + in_channels=256, + inner_channels=128, + start_level=0, + end_level=4, + norm_cfg=dict(type=GroupNorm, num_groups=32, requires_grad=True), + conv_cfg=None, + loss_seg=dict( + type=CrossEntropyLoss, ignore_index=255, loss_weight=0.5)), + panoptic_fusion_head=dict( + type=HeuristicFusionHead, + num_things_classes=80, + num_stuff_classes=53), + test_cfg=dict( + rcnn=dict( + score_thr=0.6, + nms=dict(type=nms, iou_threshold=0.5, class_agnostic=True), + max_per_img=100, + mask_thr_binary=0.5), + # used in HeuristicFusionHead + panoptic=dict(mask_overlap=0.5, stuff_area_limit=4096)))) + +# Forced to remove NumClassCheckHook +custom_hooks = [] diff --git a/mmdet/configs/qdtrack/qdtrack_faster_rcnn_r50_fpn_4e_base.py b/mmdet/configs/qdtrack/qdtrack_faster_rcnn_r50_fpn_4e_base.py new file mode 100644 index 0000000000000000000000000000000000000000..c672e82c6498092b57c389be01af64a9e26d14bc --- /dev/null +++ b/mmdet/configs/qdtrack/qdtrack_faster_rcnn_r50_fpn_4e_base.py @@ -0,0 +1,141 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
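+
+# Note: under read_base(), inherited symbols such as `model` below are
+# ordinary Python dicts, so this file mutates them directly: it pops the
+# data preprocessor, overrides sub-keys, then deletes `model` before
+# redefining it as the QDTrack wrapper around the tuned detector.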
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.faster_rcnn_r50_fpn import * + from .._base_.models.faster_rcnn_r50_fpn import model + from .._base_.default_runtime import * + +from mmcv.ops import RoIAlign +from mmengine.hooks import LoggerHook, SyncBuffersHook +from mmengine.model.weight_init import PretrainedInit +from mmengine.optim import MultiStepLR, OptimWrapper +from mmengine.runner.runner import EpochBasedTrainLoop, TestLoop, ValLoop +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.normalization import GroupNorm +from torch.optim import SGD + +from mmdet.engine.hooks import TrackVisualizationHook +from mmdet.models import (QDTrack, QuasiDenseEmbedHead, QuasiDenseTracker, + QuasiDenseTrackHead, SingleRoIExtractor, + TrackDataPreprocessor) +from mmdet.models.losses import (L1Loss, MarginL2Loss, + MultiPosCrossEntropyLoss, SmoothL1Loss) +from mmdet.models.task_modules import (CombinedSampler, + InstanceBalancedPosSampler, + MaxIoUAssigner, RandomSampler) +from mmdet.visualization import TrackLocalVisualizer + +detector = model +detector.pop('data_preprocessor') + +detector['backbone'].update( + dict( + norm_cfg=dict(type=BatchNorm2d, requires_grad=False), + style='caffe', + init_cfg=dict( + type=PretrainedInit, + checkpoint='open-mmlab://detectron2/resnet50_caffe'))) +detector.rpn_head.loss_bbox.update( + dict(type=SmoothL1Loss, beta=1.0 / 9.0, loss_weight=1.0)) +detector.rpn_head.bbox_coder.update(dict(clip_border=False)) +detector.roi_head.bbox_head.update(dict(num_classes=1)) +detector.roi_head.bbox_head.bbox_coder.update(dict(clip_border=False)) +detector['init_cfg'] = dict( + type=PretrainedInit, + checkpoint= # noqa: E251 + 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/' + 'faster_rcnn_r50_fpn_1x_coco-person/' + 'faster_rcnn_r50_fpn_1x_coco-person_20201216_175929-d022e227.pth' + # noqa: E501 +) +del model + +model = dict( + type=QDTrack, + data_preprocessor=dict( + type=TrackDataPreprocessor, + mean=[103.530, 116.280, 123.675], + std=[1.0, 1.0, 1.0], + bgr_to_rgb=False, + pad_size_divisor=32), + detector=detector, + track_head=dict( + type=QuasiDenseTrackHead, + roi_extractor=dict( + type=SingleRoIExtractor, + roi_layer=dict(type=RoIAlign, output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + embed_head=dict( + type=QuasiDenseEmbedHead, + num_convs=4, + num_fcs=1, + embed_channels=256, + norm_cfg=dict(type=GroupNorm, num_groups=32), + loss_track=dict(type=MultiPosCrossEntropyLoss, loss_weight=0.25), + loss_track_aux=dict( + type=MarginL2Loss, + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + loss_bbox=dict(type=L1Loss, loss_weight=1.0), + train_cfg=dict( + assigner=dict( + type=MaxIoUAssigner, + pos_iou_thr=0.7, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type=CombinedSampler, + num=256, + pos_fraction=0.5, + neg_pos_ub=3, + add_gt_as_proposals=True, + pos_sampler=dict(type=InstanceBalancedPosSampler), + neg_sampler=dict(type=RandomSampler)))), + tracker=dict( + type=QuasiDenseTracker, + init_score_thr=0.9, + obj_score_thr=0.5, + match_score_thr=0.5, + memo_tracklet_frames=30, + memo_backdrop_frames=1, + memo_momentum=0.8, + nms_conf_thr=0.5, + nms_backdrop_iou_thr=0.3, + nms_class_iou_thr=0.7, + with_cats=True, + match_metric='bisoftmax')) +# optimizer +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict(type=SGD, lr=0.02, momentum=0.9, 
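+        # lr=0.02 assumes the total batch of 16 (8 GPUs x 2 imgs/GPU) named
+        # in the downstream config; scale it linearly for other batch sizes.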
weight_decay=0.0001), + clip_grad=dict(max_norm=35, norm_type=2)) +# learning policy +param_scheduler = [ + dict(type=MultiStepLR, begin=0, end=4, by_epoch=True, milestones=[3]) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=4, val_interval=4) +val_cfg = dict(type=ValLoop) +test_cfg = dict(type=TestLoop) + +default_hooks.update( + logger=dict(type=LoggerHook, interval=50), + visualization=dict(type=TrackVisualizationHook, draw=False)) + +visualizer.update( + type=TrackLocalVisualizer, vis_backends=vis_backends, name='visualizer') + +# custom hooks +custom_hooks = [ + # Synchronize model buffers such as running_mean and running_var in BN + # at the end of each epoch + dict(type=SyncBuffersHook) +] diff --git a/mmdet/configs/qdtrack/qdtrack_faster_rcnn_r50_fpn_8xb2-4e_mot17halftrain_test-mot17halfval.py b/mmdet/configs/qdtrack/qdtrack_faster_rcnn_r50_fpn_8xb2-4e_mot17halftrain_test-mot17halfval.py new file mode 100644 index 0000000000000000000000000000000000000000..2fa715e1b3806f9f9816e3b23a100167c791f0b8 --- /dev/null +++ b/mmdet/configs/qdtrack/qdtrack_faster_rcnn_r50_fpn_8xb2-4e_mot17halftrain_test-mot17halfval.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.mot_challenge import * + from .qdtrack_faster_rcnn_r50_fpn_4e_base import * + +from mmdet.evaluation import CocoVideoMetric, MOTChallengeMetric + +# evaluator +val_evaluator = [ + dict(type=CocoVideoMetric, metric=['bbox'], classwise=True), + dict(type=MOTChallengeMetric, metric=['HOTA', 'CLEAR', 'Identity']) +] diff --git a/mmdet/configs/retinanet/retinanet_r50_fpn_1x_coco.py b/mmdet/configs/retinanet/retinanet_r50_fpn_1x_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..847600e61b3daf556ff24d06af2f08249deb2284 --- /dev/null +++ b/mmdet/configs/retinanet/retinanet_r50_fpn_1x_coco.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .._base_.models.retinanet_r50_fpn import * + from .._base_.datasets.coco_detection import * + from .._base_.schedules.schedule_1x import * + from .._base_.default_runtime import * + from .retinanet_tta import * + +from torch.optim.sgd import SGD + +# optimizer +optim_wrapper.update( + dict(optimizer=dict(type=SGD, lr=0.01, momentum=0.9, weight_decay=0.0001))) diff --git a/mmdet/configs/retinanet/retinanet_tta.py b/mmdet/configs/retinanet/retinanet_tta.py new file mode 100644 index 0000000000000000000000000000000000000000..4e340e5854e58a332ee174b0e69e7f3f9ec2c486 --- /dev/null +++ b/mmdet/configs/retinanet/retinanet_tta.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
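+
+# Test-time augmentation for RetinaNet: TestTimeAug runs every combination
+# of the three scales and two flip states, and DetTTAModel merges the
+# per-view detections (boxes mapped back to the original frame) with NMS.
+# A usage sketch, assuming mmdetection's tools/test.py entry point:
+#
+#     python tools/test.py <config> <checkpoint> --tta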
+from mmcv.transforms.loading import LoadImageFromFile +from mmcv.transforms.processing import TestTimeAug + +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import LoadAnnotations +from mmdet.datasets.transforms.transforms import RandomFlip, Resize +from mmdet.models.test_time_augs.det_tta import DetTTAModel + +tta_model = dict( + type=DetTTAModel, + tta_cfg=dict(nms=dict(type='nms', iou_threshold=0.5), max_per_img=100)) + +img_scales = [(1333, 800), (666, 400), (2000, 1200)] +tta_pipeline = [ + dict(type=LoadImageFromFile, backend_args=None), + dict( + type=TestTimeAug, + transforms=[ + [dict(type=Resize, scale=s, keep_ratio=True) for s in img_scales], + [dict(type=RandomFlip, prob=1.), + dict(type=RandomFlip, prob=0.)], + [dict(type=LoadAnnotations, with_bbox=True)], + [ + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction')) + ] + ]) +] diff --git a/mmdet/configs/rtmdet/rtmdet_ins_l_8xb32_300e_coco.py b/mmdet/configs/rtmdet/rtmdet_ins_l_8xb32_300e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..302d7cda110b7a598ba525549e1d96d27ee51990 --- /dev/null +++ b/mmdet/configs/rtmdet/rtmdet_ins_l_8xb32_300e_coco.py @@ -0,0 +1,134 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .rtmdet_l_8xb32_300e_coco import * + +from mmcv.transforms.loading import LoadImageFromFile +from mmcv.transforms.processing import RandomResize +from mmengine.hooks.ema_hook import EMAHook +from torch.nn.modules.activation import SiLU + +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import (FilterAnnotations, + LoadAnnotations) +from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic, + Pad, RandomCrop, RandomFlip, + Resize, YOLOXHSVRandomAug) +from mmdet.engine.hooks.pipeline_switch_hook import PipelineSwitchHook +from mmdet.models.dense_heads.rtmdet_ins_head import RTMDetInsSepBNHead +from mmdet.models.layers.ema import ExpMomentumEMA +from mmdet.models.losses.dice_loss import DiceLoss +from mmdet.models.losses.gfocal_loss import QualityFocalLoss +from mmdet.models.losses.iou_loss import GIoULoss +from mmdet.models.task_modules.coders.distance_point_bbox_coder import \ + DistancePointBBoxCoder +from mmdet.models.task_modules.prior_generators.point_generator import \ + MlvlPointGenerator + +model.merge( + dict( + bbox_head=dict( + _delete_=True, + type=RTMDetInsSepBNHead, + num_classes=80, + in_channels=256, + stacked_convs=2, + share_conv=True, + pred_kernel_size=1, + feat_channels=256, + act_cfg=dict(type=SiLU, inplace=True), + norm_cfg=dict(type='SyncBN', requires_grad=True), + anchor_generator=dict( + type=MlvlPointGenerator, offset=0, strides=[8, 16, 32]), + bbox_coder=dict(type=DistancePointBBoxCoder), + loss_cls=dict( + type=QualityFocalLoss, + use_sigmoid=True, + beta=2.0, + loss_weight=1.0), + loss_bbox=dict(type=GIoULoss, loss_weight=2.0), + loss_mask=dict( + type=DiceLoss, loss_weight=2.0, eps=5e-6, reduction='mean')), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100, + mask_thr_binary=0.5), + )) + +train_pipeline = [ + 
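+    # Stage-1 pipeline: heavy Mosaic/MixUp with large scale jitter
+    # (ratio_range 0.1-2.0). PipelineSwitchHook at the bottom of this file
+    # swaps in the lighter train_pipeline_stage2 at epoch 280 of 300.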
dict(type=LoadImageFromFile, backend_args=backend_args), + dict( + type=LoadAnnotations, with_bbox=True, with_mask=True, poly2mask=False), + dict(type=CachedMosaic, img_scale=(640, 640), pad_val=114.0), + dict( + type=RandomResize, + scale=(1280, 1280), + ratio_range=(0.1, 2.0), + resize_type=Resize, + keep_ratio=True), + dict( + type=RandomCrop, + crop_size=(640, 640), + recompute_bbox=True, + allow_negative_crop=True), + dict(type=YOLOXHSVRandomAug), + dict(type=RandomFlip, prob=0.5), + dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict( + type=CachedMixUp, + img_scale=(640, 640), + ratio_range=(1.0, 1.0), + max_cached_images=20, + pad_val=(114, 114, 114)), + dict(type=FilterAnnotations, min_gt_bbox_wh=(1, 1)), + dict(type=PackDetInputs) +] + +train_dataloader.update( + dict(pin_memory=True, dataset=dict(pipeline=train_pipeline))) + +train_pipeline_stage2 = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict( + type=LoadAnnotations, with_bbox=True, with_mask=True, poly2mask=False), + dict( + type=RandomResize, + scale=(640, 640), + ratio_range=(0.1, 2.0), + resize_type=Resize, + keep_ratio=True), + dict( + type=RandomCrop, + crop_size=(640, 640), + recompute_bbox=True, + allow_negative_crop=True), + dict(type=FilterAnnotations, min_gt_bbox_wh=(1, 1)), + dict(type=YOLOXHSVRandomAug), + dict(type=RandomFlip, prob=0.5), + dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict(type=PackDetInputs) +] +custom_hooks = [ + dict( + type=EMAHook, + ema_type=ExpMomentumEMA, + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type=PipelineSwitchHook, + switch_epoch=280, + switch_pipeline=train_pipeline_stage2) +] + +val_evaluator.update(dict(metric=['bbox', 'segm'])) +test_evaluator = val_evaluator diff --git a/mmdet/configs/rtmdet/rtmdet_ins_m_8xb32_300e_coco.py b/mmdet/configs/rtmdet/rtmdet_ins_m_8xb32_300e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..d90be9293a18cfd703ee1d9993b03237fb3c3dab --- /dev/null +++ b/mmdet/configs/rtmdet/rtmdet_ins_m_8xb32_300e_coco.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .rtmdet_ins_l_8xb32_300e_coco import * + +model.update( + dict( + backbone=dict(deepen_factor=0.67, widen_factor=0.75), + neck=dict( + in_channels=[192, 384, 768], out_channels=192, num_csp_blocks=2), + bbox_head=dict(in_channels=192, feat_channels=192))) diff --git a/mmdet/configs/rtmdet/rtmdet_ins_s_8xb32_300e_coco.py b/mmdet/configs/rtmdet/rtmdet_ins_s_8xb32_300e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..58b5b1aff0cff8d770798288b74237bc5183d37b --- /dev/null +++ b/mmdet/configs/rtmdet/rtmdet_ins_s_8xb32_300e_coco.py @@ -0,0 +1,101 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. 
# noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .rtmdet_ins_l_8xb32_300e_coco import * + +from mmcv.transforms.loading import LoadImageFromFile +from mmcv.transforms.processing import RandomResize +from mmengine.hooks.ema_hook import EMAHook + +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import (FilterAnnotations, + LoadAnnotations) +from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic, + Pad, RandomCrop, RandomFlip, + Resize, YOLOXHSVRandomAug) +from mmdet.engine.hooks.pipeline_switch_hook import PipelineSwitchHook +from mmdet.models.layers.ema import ExpMomentumEMA + +checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-s_imagenet_600e.pth' # noqa +model.update( + dict( + backbone=dict( + deepen_factor=0.33, + widen_factor=0.5, + init_cfg=dict( + type='Pretrained', prefix='backbone.', checkpoint=checkpoint)), + neck=dict( + in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1), + bbox_head=dict(in_channels=128, feat_channels=128))) + +train_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict( + type=LoadAnnotations, with_bbox=True, with_mask=True, poly2mask=False), + dict(type=CachedMosaic, img_scale=(640, 640), pad_val=114.0), + dict( + type=RandomResize, + scale=(1280, 1280), + ratio_range=(0.5, 2.0), + resize_type=Resize, + keep_ratio=True), + dict( + type=RandomCrop, + crop_size=(640, 640), + recompute_bbox=True, + allow_negative_crop=True), + dict(type=YOLOXHSVRandomAug), + dict(type=RandomFlip, prob=0.5), + dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict( + type=CachedMixUp, + img_scale=(640, 640), + ratio_range=(1.0, 1.0), + max_cached_images=20, + pad_val=(114, 114, 114)), + dict(type=FilterAnnotations, min_gt_bbox_wh=(1, 1)), + dict(type=PackDetInputs) +] + +train_pipeline_stage2 = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict( + type=LoadAnnotations, with_bbox=True, with_mask=True, poly2mask=False), + dict( + type=RandomResize, + scale=(640, 640), + ratio_range=(0.5, 2.0), + resize_type=Resize, + keep_ratio=True), + dict( + type=RandomCrop, + crop_size=(640, 640), + recompute_bbox=True, + allow_negative_crop=True), + dict(type=FilterAnnotations, min_gt_bbox_wh=(1, 1)), + dict(type=YOLOXHSVRandomAug), + dict(type=RandomFlip, prob=0.5), + dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict(type=PackDetInputs) +] + +train_dataloader.update(dict(dataset=dict(pipeline=train_pipeline))) + +custom_hooks = [ + dict( + type=EMAHook, + ema_type=ExpMomentumEMA, + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type=PipelineSwitchHook, + switch_epoch=280, + switch_pipeline=train_pipeline_stage2) +] diff --git a/mmdet/configs/rtmdet/rtmdet_ins_tiny_8xb32_300e_coco.py b/mmdet/configs/rtmdet/rtmdet_ins_tiny_8xb32_300e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..0356b1951da584034cf65014a39c7440fc3da56d --- /dev/null +++ b/mmdet/configs/rtmdet/rtmdet_ins_tiny_8xb32_300e_coco.py @@ -0,0 +1,67 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. 
# noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .rtmdet_ins_s_8xb32_300e_coco import * + +from mmcv.transforms.loading import LoadImageFromFile +from mmcv.transforms.processing import RandomResize + +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import (FilterAnnotations, + LoadAnnotations) +from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic, + Pad, RandomCrop, RandomFlip, + Resize, YOLOXHSVRandomAug) + +checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth' # noqa + +model.update( + dict( + backbone=dict( + deepen_factor=0.167, + widen_factor=0.375, + init_cfg=dict( + type='Pretrained', prefix='backbone.', checkpoint=checkpoint)), + neck=dict( + in_channels=[96, 192, 384], out_channels=96, num_csp_blocks=1), + bbox_head=dict(in_channels=96, feat_channels=96))) + +train_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict( + type=LoadAnnotations, with_bbox=True, with_mask=True, poly2mask=False), + dict( + type=CachedMosaic, + img_scale=(640, 640), + pad_val=114.0, + max_cached_images=20, + random_pop=False), + dict( + type=RandomResize, + scale=(1280, 1280), + ratio_range=(0.5, 2.0), + resize_type=Resize, + keep_ratio=True), + dict(type=RandomCrop, crop_size=(640, 640)), + dict(type=YOLOXHSVRandomAug), + dict(type=RandomFlip, prob=0.5), + dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict( + type=CachedMixUp, + img_scale=(640, 640), + ratio_range=(1.0, 1.0), + max_cached_images=10, + random_pop=False, + pad_val=(114, 114, 114), + prob=0.5), + dict(type=FilterAnnotations, min_gt_bbox_wh=(1, 1)), + dict(type=PackDetInputs) +] + +train_dataloader.update(dict(dataset=dict(pipeline=train_pipeline))) diff --git a/mmdet/configs/rtmdet/rtmdet_ins_x_8xb16_300e_coco.py b/mmdet/configs/rtmdet/rtmdet_ins_x_8xb16_300e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..555b10102f67ee625d65dbfe0894eb4b41198595 --- /dev/null +++ b/mmdet/configs/rtmdet/rtmdet_ins_x_8xb16_300e_coco.py @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. 
# noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .rtmdet_ins_l_8xb32_300e_coco import * +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR + +model.update( + dict( + backbone=dict(deepen_factor=1.33, widen_factor=1.25), + neck=dict( + in_channels=[320, 640, 1280], out_channels=320, num_csp_blocks=4), + bbox_head=dict(in_channels=320, feat_channels=320))) + +base_lr = 0.002 + +# optimizer +optim_wrapper.update(dict(optimizer=dict(lr=base_lr))) + +# learning rate +param_scheduler = [ + dict( + type=LinearLR, start_factor=1.0e-5, by_epoch=False, begin=0, end=1000), + dict( + # use cosine lr from 150 to 300 epoch + type=CosineAnnealingLR, + eta_min=base_lr * 0.05, + begin=max_epochs // 2, + end=max_epochs, + T_max=max_epochs // 2, + by_epoch=True, + convert_to_iter_based=True), +] diff --git a/mmdet/configs/rtmdet/rtmdet_l_8xb32_300e_coco.py b/mmdet/configs/rtmdet/rtmdet_l_8xb32_300e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..5dcda7bf994db9f3f5c785d8dea824b3ab8e56a2 --- /dev/null +++ b/mmdet/configs/rtmdet/rtmdet_l_8xb32_300e_coco.py @@ -0,0 +1,220 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .._base_.default_runtime import * + from .._base_.schedules.schedule_1x import * + from .._base_.datasets.coco_detection import * + from .rtmdet_tta import * + +from mmcv.ops import nms +from mmcv.transforms.loading import LoadImageFromFile +from mmcv.transforms.processing import RandomResize +from mmengine.hooks.ema_hook import EMAHook +from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from torch.nn import SyncBatchNorm +from torch.nn.modules.activation import SiLU +from torch.optim.adamw import AdamW + +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import LoadAnnotations +from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic, + Pad, RandomCrop, RandomFlip, + Resize, YOLOXHSVRandomAug) +from mmdet.engine.hooks.pipeline_switch_hook import PipelineSwitchHook +from mmdet.models.backbones.cspnext import CSPNeXt +from mmdet.models.data_preprocessors.data_preprocessor import \ + DetDataPreprocessor +from mmdet.models.dense_heads.rtmdet_head import RTMDetSepBNHead +from mmdet.models.detectors.rtmdet import RTMDet +from mmdet.models.layers.ema import ExpMomentumEMA +from mmdet.models.losses.gfocal_loss import QualityFocalLoss +from mmdet.models.losses.iou_loss import GIoULoss +from mmdet.models.necks.cspnext_pafpn import CSPNeXtPAFPN +from mmdet.models.task_modules.assigners.dynamic_soft_label_assigner import \ + DynamicSoftLabelAssigner +from mmdet.models.task_modules.coders.distance_point_bbox_coder import \ + DistancePointBBoxCoder +from mmdet.models.task_modules.prior_generators.point_generator import \ + MlvlPointGenerator + +model = dict( + type=RTMDet, + data_preprocessor=dict( + type=DetDataPreprocessor, + mean=[103.53, 116.28, 123.675], + std=[57.375, 57.12, 58.395], + bgr_to_rgb=False, + batch_augments=None), + backbone=dict( + type=CSPNeXt, + arch='P5', + expand_ratio=0.5, + deepen_factor=1, + widen_factor=1, + 
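+        # deepen_factor/widen_factor of 1 define the L variant; the
+        # tiny/s/m/x configs in this folder only rescale these two factors
+        # together with the matching neck and head channel widths.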
channel_attention=True, + norm_cfg=dict(type=SyncBatchNorm), + act_cfg=dict(type=SiLU, inplace=True)), + neck=dict( + type=CSPNeXtPAFPN, + in_channels=[256, 512, 1024], + out_channels=256, + num_csp_blocks=3, + expand_ratio=0.5, + norm_cfg=dict(type=SyncBatchNorm), + act_cfg=dict(type=SiLU, inplace=True)), + bbox_head=dict( + type=RTMDetSepBNHead, + num_classes=80, + in_channels=256, + stacked_convs=2, + feat_channels=256, + anchor_generator=dict( + type=MlvlPointGenerator, offset=0, strides=[8, 16, 32]), + bbox_coder=dict(type=DistancePointBBoxCoder), + loss_cls=dict( + type=QualityFocalLoss, use_sigmoid=True, beta=2.0, + loss_weight=1.0), + loss_bbox=dict(type=GIoULoss, loss_weight=2.0), + with_objectness=False, + exp_on_reg=True, + share_conv=True, + pred_kernel_size=1, + norm_cfg=dict(type=SyncBatchNorm), + act_cfg=dict(type=SiLU, inplace=True)), + train_cfg=dict( + assigner=dict(type=DynamicSoftLabelAssigner, topk=13), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=30000, + min_bbox_size=0, + score_thr=0.001, + nms=dict(type=nms, iou_threshold=0.65), + max_per_img=300), +) + +train_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadAnnotations, with_bbox=True), + dict(type=CachedMosaic, img_scale=(640, 640), pad_val=114.0), + dict( + type=RandomResize, + scale=(1280, 1280), + ratio_range=(0.1, 2.0), + resize_type=Resize, + keep_ratio=True), + dict(type=RandomCrop, crop_size=(640, 640)), + dict(type=YOLOXHSVRandomAug), + dict(type=RandomFlip, prob=0.5), + dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict( + type=CachedMixUp, + img_scale=(640, 640), + ratio_range=(1.0, 1.0), + max_cached_images=20, + pad_val=(114, 114, 114)), + dict(type=PackDetInputs) +] + +train_pipeline_stage2 = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadAnnotations, with_bbox=True), + dict( + type=RandomResize, + scale=(640, 640), + ratio_range=(0.1, 2.0), + resize_type=Resize, + keep_ratio=True), + dict(type=RandomCrop, crop_size=(640, 640)), + dict(type=YOLOXHSVRandomAug), + dict(type=RandomFlip, prob=0.5), + dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict(type=PackDetInputs) +] + +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(640, 640), keep_ratio=True), + dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict(type=LoadAnnotations, with_bbox=True), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] + +train_dataloader.update( + dict( + batch_size=32, + num_workers=10, + batch_sampler=None, + pin_memory=True, + dataset=dict(pipeline=train_pipeline))) +val_dataloader.update( + dict(batch_size=5, num_workers=10, dataset=dict(pipeline=test_pipeline))) +test_dataloader = val_dataloader + +max_epochs = 300 +stage2_num_epochs = 20 +base_lr = 0.004 +interval = 10 + +train_cfg.update( + dict( + max_epochs=max_epochs, + val_interval=interval, + dynamic_intervals=[(max_epochs - stage2_num_epochs, 1)])) + +val_evaluator.update(dict(proposal_nums=(100, 1, 10))) +test_evaluator = val_evaluator + +# optimizer +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict(type=AdamW, lr=base_lr, weight_decay=0.05), + paramwise_cfg=dict( + norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True)) + +# learning rate +param_scheduler = [ + dict( + type=LinearLR, start_factor=1.0e-5, by_epoch=False, begin=0, end=1000), + dict( + # use 
cosine lr from 150 to 300 epoch + type=CosineAnnealingLR, + eta_min=base_lr * 0.05, + begin=max_epochs // 2, + end=max_epochs, + T_max=max_epochs // 2, + by_epoch=True, + convert_to_iter_based=True), +] + +# hooks +default_hooks.update( + dict( + checkpoint=dict( + interval=interval, + max_keep_ckpts=3 # only keep latest 3 checkpoints + ))) + +custom_hooks = [ + dict( + type=EMAHook, + ema_type=ExpMomentumEMA, + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type=PipelineSwitchHook, + switch_epoch=max_epochs - stage2_num_epochs, + switch_pipeline=train_pipeline_stage2) +] diff --git a/mmdet/configs/rtmdet/rtmdet_m_8xb32_300e_coco.py b/mmdet/configs/rtmdet/rtmdet_m_8xb32_300e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..e741d8220fe8831894b7b803060031e18dbac62b --- /dev/null +++ b/mmdet/configs/rtmdet/rtmdet_m_8xb32_300e_coco.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .rtmdet_l_8xb32_300e_coco import * + +model.update( + dict( + backbone=dict(deepen_factor=0.67, widen_factor=0.75), + neck=dict( + in_channels=[192, 384, 768], out_channels=192, num_csp_blocks=2), + bbox_head=dict(in_channels=192, feat_channels=192))) diff --git a/mmdet/configs/rtmdet/rtmdet_s_8xb32_300e_coco.py b/mmdet/configs/rtmdet/rtmdet_s_8xb32_300e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..db21b747e95a15c69af1c17c16a5e6cfd4a2be78 --- /dev/null +++ b/mmdet/configs/rtmdet/rtmdet_s_8xb32_300e_coco.py @@ -0,0 +1,88 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. 
# noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .rtmdet_l_8xb32_300e_coco import * + +from mmcv.transforms.loading import LoadImageFromFile +from mmcv.transforms.processing import RandomResize +from mmengine.hooks.ema_hook import EMAHook + +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import LoadAnnotations +from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic, + Pad, RandomCrop, RandomFlip, + Resize, YOLOXHSVRandomAug) +from mmdet.engine.hooks.pipeline_switch_hook import PipelineSwitchHook +from mmdet.models.layers.ema import ExpMomentumEMA + +checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-s_imagenet_600e.pth' # noqa +model.update( + dict( + backbone=dict( + deepen_factor=0.33, + widen_factor=0.5, + init_cfg=dict( + type='Pretrained', prefix='backbone.', checkpoint=checkpoint)), + neck=dict( + in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1), + bbox_head=dict(in_channels=128, feat_channels=128, exp_on_reg=False))) + +train_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadAnnotations, with_bbox=True), + dict(type=CachedMosaic, img_scale=(640, 640), pad_val=114.0), + dict( + type=RandomResize, + scale=(1280, 1280), + ratio_range=(0.5, 2.0), + resize_type=Resize, + keep_ratio=True), + dict(type=RandomCrop, crop_size=(640, 640)), + dict(type=YOLOXHSVRandomAug), + dict(type=RandomFlip, prob=0.5), + dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict( + type=CachedMixUp, + img_scale=(640, 640), + ratio_range=(1.0, 1.0), + max_cached_images=20, + pad_val=(114, 114, 114)), + dict(type=PackDetInputs) +] + +train_pipeline_stage2 = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadAnnotations, with_bbox=True), + dict( + type=RandomResize, + scale=(640, 640), + ratio_range=(0.5, 2.0), + resize_type=Resize, + keep_ratio=True), + dict(type=RandomCrop, crop_size=(640, 640)), + dict(type=YOLOXHSVRandomAug), + dict(type=RandomFlip, prob=0.5), + dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict(type=PackDetInputs) +] + +train_dataloader.update(dict(dataset=dict(pipeline=train_pipeline))) + +custom_hooks = [ + dict( + type=EMAHook, + ema_type=ExpMomentumEMA, + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type=PipelineSwitchHook, + switch_epoch=280, + switch_pipeline=train_pipeline_stage2) +] diff --git a/mmdet/configs/rtmdet/rtmdet_tiny_8xb32_300e_coco.py b/mmdet/configs/rtmdet/rtmdet_tiny_8xb32_300e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..949d056f16303751d121aeba8f3d859de07b06d2 --- /dev/null +++ b/mmdet/configs/rtmdet/rtmdet_tiny_8xb32_300e_coco.py @@ -0,0 +1,64 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. 
# noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .rtmdet_s_8xb32_300e_coco import * + +from mmcv.transforms.loading import LoadImageFromFile +from mmcv.transforms.processing import RandomResize + +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import LoadAnnotations +from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic, + Pad, RandomCrop, RandomFlip, + Resize, YOLOXHSVRandomAug) + +checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth' # noqa + +model.update( + dict( + backbone=dict( + deepen_factor=0.167, + widen_factor=0.375, + init_cfg=dict( + type='Pretrained', prefix='backbone.', checkpoint=checkpoint)), + neck=dict( + in_channels=[96, 192, 384], out_channels=96, num_csp_blocks=1), + bbox_head=dict(in_channels=96, feat_channels=96, exp_on_reg=False))) + +train_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadAnnotations, with_bbox=True), + dict( + type=CachedMosaic, + img_scale=(640, 640), + pad_val=114.0, + max_cached_images=20, + random_pop=False), + dict( + type=RandomResize, + scale=(1280, 1280), + ratio_range=(0.5, 2.0), + resize_type=Resize, + keep_ratio=True), + dict(type=RandomCrop, crop_size=(640, 640)), + dict(type=YOLOXHSVRandomAug), + dict(type=RandomFlip, prob=0.5), + dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict( + type=CachedMixUp, + img_scale=(640, 640), + ratio_range=(1.0, 1.0), + max_cached_images=10, + random_pop=False, + pad_val=(114, 114, 114), + prob=0.5), + dict(type=PackDetInputs) +] + +train_dataloader.update(dict(dataset=dict(pipeline=train_pipeline))) diff --git a/mmdet/configs/rtmdet/rtmdet_tta.py b/mmdet/configs/rtmdet/rtmdet_tta.py new file mode 100644 index 0000000000000000000000000000000000000000..f27b7aa4a3bf13a28cab3e25be755a9792620ece --- /dev/null +++ b/mmdet/configs/rtmdet/rtmdet_tta.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.transforms.loading import LoadImageFromFile +from mmcv.transforms.processing import TestTimeAug + +from mmdet.datasets.transforms.formatting import PackDetInputs +from mmdet.datasets.transforms.loading import LoadAnnotations +from mmdet.datasets.transforms.transforms import Pad, RandomFlip, Resize +from mmdet.models.test_time_augs.det_tta import DetTTAModel + +tta_model = dict( + type=DetTTAModel, + tta_cfg=dict(nms=dict(type='nms', iou_threshold=0.6), max_per_img=100)) + +img_scales = [(640, 640), (320, 320), (960, 960)] + +tta_pipeline = [ + dict(type=LoadImageFromFile, backend_args=None), + dict( + type=TestTimeAug, + transforms=[ + [dict(type=Resize, scale=s, keep_ratio=True) for s in img_scales], + [ + # ``RandomFlip`` must be placed before ``Pad``, otherwise + # bounding box coordinates after flipping cannot be + # recovered correctly. + dict(type=RandomFlip, prob=1.), + dict(type=RandomFlip, prob=0.) 
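+                # prob=1. and prob=0. enumerate the flipped and unflipped
+                # branches; TestTimeAug runs every scale/flip combination,
+                # i.e. 3 scales x 2 flips = 6 forward passes per image.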
+ ], + [ + dict( + type=Pad, + size=(960, 960), + pad_val=dict(img=(114, 114, 114))), + ], + [dict(type=LoadAnnotations, with_bbox=True)], + [ + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction')) + ] + ]) +] diff --git a/mmdet/configs/rtmdet/rtmdet_x_8xb32_300e_coco.py b/mmdet/configs/rtmdet/rtmdet_x_8xb32_300e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..04d67d0ca8f08860462eb0eafc645c403e792394 --- /dev/null +++ b/mmdet/configs/rtmdet/rtmdet_x_8xb32_300e_coco.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa +# mmcv >= 2.0.1 +# mmengine >= 0.8.0 + +from mmengine.config import read_base + +with read_base(): + from .rtmdet_l_8xb32_300e_coco import * + +model.update( + dict( + backbone=dict(deepen_factor=1.33, widen_factor=1.25), + neck=dict( + in_channels=[320, 640, 1280], out_channels=320, num_csp_blocks=4), + bbox_head=dict(in_channels=320, feat_channels=320))) diff --git a/mmdet/datasets/.DS_Store b/mmdet/datasets/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..4174e96aaa7a6a3e53b99897a4a9a28a1f6cd38f Binary files /dev/null and b/mmdet/datasets/.DS_Store differ diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..044efe4cad7d64ab641f8dde017e48b7e8b9cc09 --- /dev/null +++ b/mmdet/datasets/__init__.py @@ -0,0 +1,46 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .ade20k import (ADE20KInstanceDataset, ADE20KPanopticDataset, + ADE20KSegDataset) +from .base_det_dataset import BaseDetDataset +from .base_semseg_dataset import BaseSegDataset +from .base_video_dataset import BaseVideoDataset +from .cityscapes import CityscapesDataset +from .coco import CocoDataset +from .coco_caption import CocoCaptionDataset +from .coco_panoptic import CocoPanopticDataset +from .coco_semantic import CocoSegDataset +from .crowdhuman import CrowdHumanDataset +from .dataset_wrappers import ConcatDataset, MultiImageMixDataset +from .deepfashion import DeepFashionDataset +from .dsdl import DSDLDetDataset +from .isaid import iSAIDDataset +from .lvis import LVISDataset, LVISV1Dataset, LVISV05Dataset +from .mot_challenge_dataset import MOTChallengeDataset +from .objects365 import Objects365V1Dataset, Objects365V2Dataset +from .openimages import OpenImagesChallengeDataset, OpenImagesDataset +from .refcoco import RefCocoDataset +from .reid_dataset import ReIDDataset +from .samplers import (AspectRatioBatchSampler, ClassAwareSampler, + GroupMultiSourceSampler, MultiSourceSampler, + TrackAspectRatioBatchSampler, TrackImgSampler) +from .utils import get_loading_pipeline +from .v3det import V3DetDataset +from .voc import VOCDataset +from .wider_face import WIDERFaceDataset +from .xml_style import XMLDataset +from .youtube_vis_dataset import YouTubeVISDataset + +__all__ = [ + 'XMLDataset', 'CocoDataset', 'DeepFashionDataset', 'VOCDataset', + 'CityscapesDataset', 'LVISDataset', 'LVISV05Dataset', 'LVISV1Dataset', + 'WIDERFaceDataset', 'get_loading_pipeline', 'CocoPanopticDataset', + 'MultiImageMixDataset', 'OpenImagesDataset', 'OpenImagesChallengeDataset', + 'AspectRatioBatchSampler', 'ClassAwareSampler', 'MultiSourceSampler', + 'GroupMultiSourceSampler', 'BaseDetDataset', 'CrowdHumanDataset', + 
'Objects365V1Dataset', 'Objects365V2Dataset', 'DSDLDetDataset', + 'BaseVideoDataset', 'MOTChallengeDataset', 'TrackImgSampler', + 'ReIDDataset', 'YouTubeVISDataset', 'TrackAspectRatioBatchSampler', + 'ADE20KPanopticDataset', 'CocoCaptionDataset', 'RefCocoDataset', + 'BaseSegDataset', 'ADE20KSegDataset', 'CocoSegDataset', + 'ADE20KInstanceDataset', 'iSAIDDataset', 'V3DetDataset', 'ConcatDataset' +] diff --git a/mmdet/datasets/ade20k.py b/mmdet/datasets/ade20k.py new file mode 100644 index 0000000000000000000000000000000000000000..573271cb5d0cb83571564272895bddde9a5f6ad7 --- /dev/null +++ b/mmdet/datasets/ade20k.py @@ -0,0 +1,260 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import List + +from mmengine import fileio + +from mmdet.registry import DATASETS +from .base_semseg_dataset import BaseSegDataset +from .coco import CocoDataset +from .coco_panoptic import CocoPanopticDataset + +ADE_PALETTE = [(120, 120, 120), (180, 120, 120), (6, 230, 230), (80, 50, 50), + (4, 200, 3), (120, 120, 80), (140, 140, 140), (204, 5, 255), + (230, 230, 230), (4, 250, 7), (224, 5, 255), (235, 255, 7), + (150, 5, 61), (120, 120, 70), (8, 255, 51), (255, 6, 82), + (143, 255, 140), (204, 255, 4), (255, 51, 7), (204, 70, 3), + (0, 102, 200), (61, 230, 250), (255, 6, 51), (11, 102, 255), + (255, 7, 71), (255, 9, 224), (9, 7, 230), (220, 220, 220), + (255, 9, 92), (112, 9, 255), (8, 255, 214), (7, 255, 224), + (255, 184, 6), (10, 255, 71), (255, 41, 10), (7, 255, 255), + (224, 255, 8), (102, 8, 255), (255, 61, 6), (255, 194, 7), + (255, 122, 8), (0, 255, 20), (255, 8, 41), (255, 5, 153), + (6, 51, 255), (235, 12, 255), (160, 150, 20), (0, 163, 255), + (140, 140, 140), (250, 10, 15), (20, 255, 0), (31, 255, 0), + (255, 31, 0), (255, 224, 0), (153, 255, 0), (0, 0, 255), + (255, 71, 0), (0, 235, 255), (0, 173, 255), (31, 0, 255), + (11, 200, 200), (255, 82, 0), (0, 255, 245), (0, 61, 255), + (0, 255, 112), (0, 255, 133), (255, 0, 0), (255, 163, 0), + (255, 102, 0), (194, 255, 0), (0, 143, 255), (51, 255, 0), + (0, 82, 255), (0, 255, 41), (0, 255, 173), (10, 0, 255), + (173, 255, 0), (0, 255, 153), (255, 92, 0), (255, 0, 255), + (255, 0, 245), (255, 0, 102), (255, 173, 0), (255, 0, 20), + (255, 184, 184), (0, 31, 255), (0, 255, 61), (0, 71, 255), + (255, 0, 204), (0, 255, 194), (0, 255, 82), (0, 10, 255), + (0, 112, 255), (51, 0, 255), (0, 194, 255), (0, 122, 255), + (0, 255, 163), (255, 153, 0), (0, 255, 10), (255, 112, 0), + (143, 255, 0), (82, 0, 255), (163, 255, 0), (255, 235, 0), + (8, 184, 170), (133, 0, 255), (0, 255, 92), (184, 0, 255), + (255, 0, 31), (0, 184, 255), (0, 214, 255), (255, 0, 112), + (92, 255, 0), (0, 224, 255), (112, 224, 255), (70, 184, 160), + (163, 0, 255), (153, 0, 255), (71, 255, 0), (255, 0, 163), + (255, 204, 0), (255, 0, 143), (0, 255, 235), (133, 255, 0), + (255, 0, 235), (245, 0, 255), (255, 0, 122), (255, 245, 0), + (10, 190, 212), (214, 255, 0), (0, 204, 255), (20, 0, 255), + (255, 255, 0), (0, 153, 255), (0, 41, 255), (0, 255, 204), + (41, 0, 255), (41, 255, 0), (173, 0, 255), (0, 245, 255), + (71, 0, 255), (122, 0, 255), (0, 255, 184), (0, 92, 255), + (184, 255, 0), (0, 133, 255), (255, 214, 0), (25, 194, 194), + (102, 255, 0), (92, 0, 255)] + + +@DATASETS.register_module() +class ADE20KPanopticDataset(CocoPanopticDataset): + METAINFO = { + 'classes': + ('bed', 'window', 'cabinet', 'person', 'door', 'table', 'curtain', + 'chair', 'car', 'painting, picture', 'sofa', 'shelf', 'mirror', + 'armchair', 'seat', 'fence', 'desk', 'wardrobe, closet, 
press', + 'lamp', 'tub', 'rail', 'cushion', 'box', 'column, pillar', + 'signboard, sign', 'chest of drawers, chest, bureau, dresser', + 'counter', 'sink', 'fireplace', 'refrigerator, icebox', 'stairs', + 'case, display case, showcase, vitrine', + 'pool table, billiard table, snooker table', 'pillow', + 'screen door, screen', 'bookcase', 'coffee table', + 'toilet, can, commode, crapper, pot, potty, stool, throne', 'flower', + 'book', 'bench', 'countertop', 'stove', 'palm, palm tree', + 'kitchen island', 'computer', 'swivel chair', 'boat', + 'arcade machine', 'bus', 'towel', 'light', 'truck', 'chandelier', + 'awning, sunshade, sunblind', 'street lamp', 'booth', 'tv', + 'airplane', 'clothes', 'pole', + 'bannister, banister, balustrade, balusters, handrail', + 'ottoman, pouf, pouffe, puff, hassock', 'bottle', 'van', 'ship', + 'fountain', 'washer, automatic washer, washing machine', + 'plaything, toy', 'stool', 'barrel, cask', 'basket, handbasket', + 'bag', 'minibike, motorbike', 'oven', 'ball', 'food, solid food', + 'step, stair', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', + 'dishwasher', 'screen', 'sculpture', 'hood, exhaust hood', 'sconce', + 'vase', 'traffic light', 'tray', 'trash can', 'fan', 'plate', + 'monitor', 'bulletin board', 'radiator', 'glass, drinking glass', + 'clock', 'flag', 'wall', 'building', 'sky', 'floor', 'tree', + 'ceiling', 'road, route', 'grass', 'sidewalk, pavement', + 'earth, ground', 'mountain, mount', 'plant', 'water', 'house', 'sea', + 'rug', 'field', 'rock, stone', 'base, pedestal, stand', 'sand', + 'skyscraper', 'grandstand, covered stand', 'path', 'runway', + 'stairway, staircase', 'river', 'bridge, span', 'blind, screen', + 'hill', 'bar', 'hovel, hut, hutch, shack, shanty', 'tower', + 'dirt track', 'land, ground, soil', + 'escalator, moving staircase, moving stairway', + 'buffet, counter, sideboard', + 'poster, posting, placard, notice, bill, card', 'stage', + 'conveyer belt, conveyor belt, conveyer, conveyor, transporter', + 'canopy', 'pool', 'falls', 'tent', 'cradle', 'tank, storage tank', + 'lake', 'blanket, cover', 'pier', 'crt screen', 'shower'), + 'thing_classes': + ('bed', 'window', 'cabinet', 'person', 'door', 'table', 'curtain', + 'chair', 'car', 'painting, picture', 'sofa', 'shelf', 'mirror', + 'armchair', 'seat', 'fence', 'desk', 'wardrobe, closet, press', + 'lamp', 'tub', 'rail', 'cushion', 'box', 'column, pillar', + 'signboard, sign', 'chest of drawers, chest, bureau, dresser', + 'counter', 'sink', 'fireplace', 'refrigerator, icebox', 'stairs', + 'case, display case, showcase, vitrine', + 'pool table, billiard table, snooker table', 'pillow', + 'screen door, screen', 'bookcase', 'coffee table', + 'toilet, can, commode, crapper, pot, potty, stool, throne', 'flower', + 'book', 'bench', 'countertop', 'stove', 'palm, palm tree', + 'kitchen island', 'computer', 'swivel chair', 'boat', + 'arcade machine', 'bus', 'towel', 'light', 'truck', 'chandelier', + 'awning, sunshade, sunblind', 'street lamp', 'booth', 'tv', + 'airplane', 'clothes', 'pole', + 'bannister, banister, balustrade, balusters, handrail', + 'ottoman, pouf, pouffe, puff, hassock', 'bottle', 'van', 'ship', + 'fountain', 'washer, automatic washer, washing machine', + 'plaything, toy', 'stool', 'barrel, cask', 'basket, handbasket', + 'bag', 'minibike, motorbike', 'oven', 'ball', 'food, solid food', + 'step, stair', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', + 'dishwasher', 'screen', 'sculpture', 'hood, exhaust hood', 'sconce', + 'vase', 'traffic light', 'tray', 'trash 
can', 'fan', 'plate', + 'monitor', 'bulletin board', 'radiator', 'glass, drinking glass', + 'clock', 'flag'), + 'stuff_classes': + ('wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road, route', + 'grass', 'sidewalk, pavement', 'earth, ground', 'mountain, mount', + 'plant', 'water', 'house', 'sea', 'rug', 'field', 'rock, stone', + 'base, pedestal, stand', 'sand', 'skyscraper', + 'grandstand, covered stand', 'path', 'runway', 'stairway, staircase', + 'river', 'bridge, span', 'blind, screen', 'hill', 'bar', + 'hovel, hut, hutch, shack, shanty', 'tower', 'dirt track', + 'land, ground, soil', 'escalator, moving staircase, moving stairway', + 'buffet, counter, sideboard', + 'poster, posting, placard, notice, bill, card', 'stage', + 'conveyer belt, conveyor belt, conveyer, conveyor, transporter', + 'canopy', 'pool', 'falls', 'tent', 'cradle', 'tank, storage tank', + 'lake', 'blanket, cover', 'pier', 'crt screen', 'shower'), + 'palette': + ADE_PALETTE + } + + +@DATASETS.register_module() +class ADE20KInstanceDataset(CocoDataset): + METAINFO = { + 'classes': + ('bed', 'windowpane', 'cabinet', 'person', 'door', 'table', 'curtain', + 'chair', 'car', 'painting', 'sofa', 'shelf', 'mirror', 'armchair', + 'seat', 'fence', 'desk', 'wardrobe', 'lamp', 'bathtub', 'railing', + 'cushion', 'box', 'column', 'signboard', 'chest of drawers', + 'counter', 'sink', 'fireplace', 'refrigerator', 'stairs', 'case', + 'pool table', 'pillow', 'screen door', 'bookcase', 'coffee table', + 'toilet', 'flower', 'book', 'bench', 'countertop', 'stove', 'palm', + 'kitchen island', 'computer', 'swivel chair', 'boat', + 'arcade machine', 'bus', 'towel', 'light', 'truck', 'chandelier', + 'awning', 'streetlight', 'booth', 'television receiver', 'airplane', + 'apparel', 'pole', 'bannister', 'ottoman', 'bottle', 'van', 'ship', + 'fountain', 'washer', 'plaything', 'stool', 'barrel', 'basket', 'bag', + 'minibike', 'oven', 'ball', 'food', 'step', 'trade name', 'microwave', + 'pot', 'animal', 'bicycle', 'dishwasher', 'screen', 'sculpture', + 'hood', 'sconce', 'vase', 'traffic light', 'tray', 'ashcan', 'fan', + 'plate', 'monitor', 'bulletin board', 'radiator', 'glass', 'clock', + 'flag'), + 'palette': [(204, 5, 255), (230, 230, 230), (224, 5, 255), + (150, 5, 61), (8, 255, 51), (255, 6, 82), (255, 51, 7), + (204, 70, 3), (0, 102, 200), (255, 6, 51), (11, 102, 255), + (255, 7, 71), (220, 220, 220), (8, 255, 214), + (7, 255, 224), (255, 184, 6), (10, 255, 71), (7, 255, 255), + (224, 255, 8), (102, 8, 255), (255, 61, 6), (255, 194, 7), + (0, 255, 20), (255, 8, 41), (255, 5, 153), (6, 51, 255), + (235, 12, 255), (0, 163, 255), (250, 10, 15), (20, 255, 0), + (255, 224, 0), (0, 0, 255), (255, 71, 0), (0, 235, 255), + (0, 173, 255), (0, 255, 245), (0, 255, 112), (0, 255, 133), + (255, 0, 0), (255, 163, 0), (194, 255, 0), (0, 143, 255), + (51, 255, 0), (0, 82, 255), (0, 255, 41), (0, 255, 173), + (10, 0, 255), (173, 255, 0), (255, 92, 0), (255, 0, 245), + (255, 0, 102), (255, 173, 0), (255, 0, 20), (0, 31, 255), + (0, 255, 61), (0, 71, 255), (255, 0, 204), (0, 255, 194), + (0, 255, 82), (0, 112, 255), (51, 0, 255), (0, 122, 255), + (255, 153, 0), (0, 255, 10), (163, 255, 0), (255, 235, 0), + (8, 184, 170), (184, 0, 255), (255, 0, 31), (0, 214, 255), + (255, 0, 112), (92, 255, 0), (70, 184, 160), (163, 0, 255), + (71, 255, 0), (255, 0, 163), (255, 204, 0), (255, 0, 143), + (133, 255, 0), (255, 0, 235), (245, 0, 255), (255, 0, 122), + (255, 245, 0), (214, 255, 0), (0, 204, 255), (255, 255, 0), + (0, 153, 255), (0, 41, 255), (0, 255, 204), 
(41, 0, 255), + (41, 255, 0), (173, 0, 255), (0, 245, 255), (0, 255, 184), + (0, 92, 255), (184, 255, 0), (255, 214, 0), (25, 194, 194), + (102, 255, 0), (92, 0, 255)], + } + + +@DATASETS.register_module() +class ADE20KSegDataset(BaseSegDataset): + """ADE20K dataset. + + In segmentation map annotation for ADE20K, 0 stands for background, which + is not included in 150 categories. The ``img_suffix`` is fixed to '.jpg', + and ``seg_map_suffix`` is fixed to '.png'. + """ + METAINFO = dict( + classes=('wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', + 'bed ', 'windowpane', 'grass', 'cabinet', 'sidewalk', + 'person', 'earth', 'door', 'table', 'mountain', 'plant', + 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', + 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', + 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', + 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', + 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', + 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', + 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', + 'screen door', 'stairway', 'river', 'bridge', 'bookcase', + 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', + 'bench', 'countertop', 'stove', 'palm', 'kitchen island', + 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', + 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', + 'chandelier', 'awning', 'streetlight', 'booth', + 'television receiver', 'airplane', 'dirt track', 'apparel', + 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', + 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', + 'conveyer belt', 'canopy', 'washer', 'plaything', + 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', + 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', + 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', + 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', + 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', + 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate', + 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', + 'clock', 'flag'), + palette=ADE_PALETTE) + + def __init__(self, + img_suffix='.jpg', + seg_map_suffix='.png', + return_classes=False, + **kwargs) -> None: + self.return_classes = return_classes + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) + + def load_data_list(self) -> List[dict]: + """Load annotation from directory or annotation file. + + Returns: + List[dict]: All data info of dataset. + """ + data_list = [] + img_dir = self.data_prefix.get('img_path', None) + ann_dir = self.data_prefix.get('seg_map_path', None) + for img in fileio.list_dir_or_file( + dir_path=img_dir, + list_dir=False, + suffix=self.img_suffix, + recursive=True, + backend_args=self.backend_args): + data_info = dict(img_path=osp.join(img_dir, img)) + if ann_dir is not None: + seg_map = img.replace(self.img_suffix, self.seg_map_suffix) + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['label_map'] = self.label_map + if self.return_classes: + data_info['text'] = list(self._metainfo['classes']) + data_list.append(data_info) + return data_list diff --git a/mmdet/datasets/api_wrappers/__init__.py b/mmdet/datasets/api_wrappers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8e3c41a2f87b14d10339955208e0502aeeeb7082 --- /dev/null +++ b/mmdet/datasets/api_wrappers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
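+# This package re-exports the COCO API wrappers, so downstream code can write
+# `from mmdet.datasets.api_wrappers import COCO`. A minimal usage sketch
+# (illustrative; the annotation path is hypothetical):
+#
+#     coco = COCO('annotations/instances_val2017.json')
+#     img_ids = coco.get_img_ids()  # snake_case alias of getImgIds
+#     anns = coco.load_anns(coco.get_ann_ids(img_ids=img_ids[:1]))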
+from .coco_api import COCO, COCOeval, COCOPanoptic +from .cocoeval_mp import COCOevalMP + +__all__ = ['COCO', 'COCOeval', 'COCOPanoptic', 'COCOevalMP'] diff --git a/mmdet/datasets/api_wrappers/coco_api.py b/mmdet/datasets/api_wrappers/coco_api.py new file mode 100644 index 0000000000000000000000000000000000000000..40f7f2c9b930de3dadd967db9d131913fc9bf54c --- /dev/null +++ b/mmdet/datasets/api_wrappers/coco_api.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This file add snake case alias for coco api + +import warnings +from collections import defaultdict +from typing import List, Optional, Union + +import pycocotools +from pycocotools.coco import COCO as _COCO +from pycocotools.cocoeval import COCOeval as _COCOeval + + +class COCO(_COCO): + """This class is almost the same as official pycocotools package. + + It implements some snake case function aliases. So that the COCO class has + the same interface as LVIS class. + """ + + def __init__(self, annotation_file=None): + if getattr(pycocotools, '__version__', '0') >= '12.0.2': + warnings.warn( + 'mmpycocotools is deprecated. Please install official pycocotools by "pip install pycocotools"', # noqa: E501 + UserWarning) + super().__init__(annotation_file=annotation_file) + self.img_ann_map = self.imgToAnns + self.cat_img_map = self.catToImgs + + def get_ann_ids(self, img_ids=[], cat_ids=[], area_rng=[], iscrowd=None): + return self.getAnnIds(img_ids, cat_ids, area_rng, iscrowd) + + def get_cat_ids(self, cat_names=[], sup_names=[], cat_ids=[]): + return self.getCatIds(cat_names, sup_names, cat_ids) + + def get_img_ids(self, img_ids=[], cat_ids=[]): + return self.getImgIds(img_ids, cat_ids) + + def load_anns(self, ids): + return self.loadAnns(ids) + + def load_cats(self, ids): + return self.loadCats(ids) + + def load_imgs(self, ids): + return self.loadImgs(ids) + + +# just for the ease of import +COCOeval = _COCOeval + + +class COCOPanoptic(COCO): + """This wrapper is for loading the panoptic style annotation file. + + The format is shown in the CocoPanopticDataset class. + + Args: + annotation_file (str, optional): Path of annotation file. + Defaults to None. + """ + + def __init__(self, annotation_file: Optional[str] = None) -> None: + super(COCOPanoptic, self).__init__(annotation_file) + + def createIndex(self) -> None: + """Create index.""" + # create index + print('creating index...') + # anns stores 'segment_id -> annotation' + anns, cats, imgs = {}, {}, {} + img_to_anns, cat_to_imgs = defaultdict(list), defaultdict(list) + if 'annotations' in self.dataset: + for ann in self.dataset['annotations']: + for seg_ann in ann['segments_info']: + # to match with instance.json + seg_ann['image_id'] = ann['image_id'] + img_to_anns[ann['image_id']].append(seg_ann) + # segment_id is not unique in coco dataset orz... 
+ # annotations from different images but + # may have same segment_id + if seg_ann['id'] in anns.keys(): + anns[seg_ann['id']].append(seg_ann) + else: + anns[seg_ann['id']] = [seg_ann] + + # filter out annotations from other images + img_to_anns_ = defaultdict(list) + for k, v in img_to_anns.items(): + img_to_anns_[k] = [x for x in v if x['image_id'] == k] + img_to_anns = img_to_anns_ + + if 'images' in self.dataset: + for img_info in self.dataset['images']: + img_info['segm_file'] = img_info['file_name'].replace( + 'jpg', 'png') + imgs[img_info['id']] = img_info + + if 'categories' in self.dataset: + for cat in self.dataset['categories']: + cats[cat['id']] = cat + + if 'annotations' in self.dataset and 'categories' in self.dataset: + for ann in self.dataset['annotations']: + for seg_ann in ann['segments_info']: + cat_to_imgs[seg_ann['category_id']].append(ann['image_id']) + + print('index created!') + + self.anns = anns + self.imgToAnns = img_to_anns + self.catToImgs = cat_to_imgs + self.imgs = imgs + self.cats = cats + + def load_anns(self, + ids: Union[List[int], int] = []) -> Optional[List[dict]]: + """Load anns with the specified ids. + + ``self.anns`` is a list of annotation lists instead of a + list of annotations. + + Args: + ids (Union[List[int], int]): Integer ids specifying anns. + + Returns: + anns (List[dict], optional): Loaded ann objects. + """ + anns = [] + + if hasattr(ids, '__iter__') and hasattr(ids, '__len__'): + # self.anns is a list of annotation lists instead of + # a list of annotations + for id in ids: + anns += self.anns[id] + return anns + elif type(ids) == int: + return self.anns[ids] diff --git a/mmdet/datasets/api_wrappers/cocoeval_mp.py b/mmdet/datasets/api_wrappers/cocoeval_mp.py new file mode 100644 index 0000000000000000000000000000000000000000..b3673ea7a7edc593cb49fb336f352a20c1b1015b --- /dev/null +++ b/mmdet/datasets/api_wrappers/cocoeval_mp.py @@ -0,0 +1,296 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
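+# COCOevalMP parallelizes COCOeval's per-image evaluation by splitting the
+# category ids into chunks and scoring each chunk in a worker process. A
+# minimal usage sketch (illustrative; `coco_gt`/`coco_dt` are assumed to be
+# already-loaded ground-truth and detection COCO objects):
+#
+#     evaluator = COCOevalMP(coco_gt, coco_dt, iouType='bbox')
+#     evaluator.evaluate()    # multi-process, see `evaluate` below
+#     evaluator.accumulate()  # inherited from COCOeval
+#     evaluator.summarize()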
+import copy +import itertools +import time +from collections import defaultdict + +import numpy as np +import torch.multiprocessing as mp +from mmengine.logging import MMLogger +from pycocotools.cocoeval import COCOeval +from tqdm import tqdm + + +class COCOevalMP(COCOeval): + + def _prepare(self): + ''' + Prepare ._gts and ._dts for evaluation based on params + :return: None + ''' + + def _toMask(anns, coco): + # modify ann['segmentation'] by reference + for ann in anns: + rle = coco.annToRLE(ann) + ann['segmentation'] = rle + + p = self.params + if p.useCats: + gts = [] + dts = [] + img_ids = set(p.imgIds) + cat_ids = set(p.catIds) + for gt in self.cocoGt.dataset['annotations']: + if (gt['category_id'] in cat_ids) and (gt['image_id'] + in img_ids): + gts.append(gt) + for dt in self.cocoDt.dataset['annotations']: + if (dt['category_id'] in cat_ids) and (dt['image_id'] + in img_ids): + dts.append(dt) + # gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)) # noqa + # dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)) # noqa + # gts=self.cocoGt.dataset['annotations'] + # dts=self.cocoDt.dataset['annotations'] + else: + gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds)) + dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds)) + + # convert ground truth to mask if iouType == 'segm' + if p.iouType == 'segm': + _toMask(gts, self.cocoGt) + _toMask(dts, self.cocoDt) + # set ignore flag + for gt in gts: + gt['ignore'] = gt['ignore'] if 'ignore' in gt else 0 + gt['ignore'] = 'iscrowd' in gt and gt['iscrowd'] + if p.iouType == 'keypoints': + gt['ignore'] = (gt['num_keypoints'] == 0) or gt['ignore'] + self._gts = defaultdict(list) # gt for evaluation + self._dts = defaultdict(list) # dt for evaluation + for gt in gts: + self._gts[gt['image_id'], gt['category_id']].append(gt) + for dt in dts: + self._dts[dt['image_id'], dt['category_id']].append(dt) + self.evalImgs = defaultdict( + list) # per-image per-category evaluation results + self.eval = {} # accumulated evaluation results + + def evaluate(self): + """Run per image evaluation on given images and store results (a list + of dict) in self.evalImgs. + + :return: None + """ + tic = time.time() + print('Running per image evaluation...') + p = self.params + # add backward compatibility if useSegm is specified in params + if p.useSegm is not None: + p.iouType = 'segm' if p.useSegm == 1 else 'bbox' + print('useSegm (deprecated) is not None. Running {} evaluation'. 
+ format(p.iouType)) + print('Evaluate annotation type *{}*'.format(p.iouType)) + p.imgIds = list(np.unique(p.imgIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + p.maxDets = sorted(p.maxDets) + self.params = p + + # loop through images, area range, max detection number + catIds = p.catIds if p.useCats else [-1] + + nproc = 8 + split_size = len(catIds) // nproc + mp_params = [] + for i in range(nproc): + begin = i * split_size + end = (i + 1) * split_size + if i == nproc - 1: + end = len(catIds) + mp_params.append((catIds[begin:end], )) + + MMLogger.get_current_instance().info( + 'start multi processing evaluation ...') + with mp.Pool(nproc) as pool: + self.evalImgs = pool.starmap(self._evaluateImg, mp_params) + + self.evalImgs = list(itertools.chain(*self.evalImgs)) + + self._paramsEval = copy.deepcopy(self.params) + toc = time.time() + print('DONE (t={:0.2f}s).'.format(toc - tic)) + + def _evaluateImg(self, catids_chunk): + self._prepare() + p = self.params + maxDet = max(p.maxDets) + all_params = [] + for catId in catids_chunk: + for areaRng in p.areaRng: + for imgId in p.imgIds: + all_params.append((catId, areaRng, imgId)) + evalImgs = [ + self.evaluateImg(imgId, catId, areaRng, maxDet) + for catId, areaRng, imgId in tqdm(all_params) + ] + return evalImgs + + def evaluateImg(self, imgId, catId, aRng, maxDet): + p = self.params + if p.useCats: + gt = self._gts[imgId, catId] + dt = self._dts[imgId, catId] + else: + gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]] + dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]] + if len(gt) == 0 and len(dt) == 0: + return None + + for g in gt: + if g['ignore'] or (g['area'] < aRng[0] or g['area'] > aRng[1]): + g['_ignore'] = 1 + else: + g['_ignore'] = 0 + + # sort dt highest score first, sort gt ignore last + gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort') + gt = [gt[i] for i in gtind] + dtind = np.argsort([-d['score'] for d in dt], kind='mergesort') + dt = [dt[i] for i in dtind[0:maxDet]] + iscrowd = [int(o['iscrowd']) for o in gt] + # load computed ious + # ious = self.ious[imgId, catId][:, gtind] if len(self.ious[imgId, catId]) > 0 else self.ious[imgId, catId] # noqa + ious = self.computeIoU(imgId, catId) + ious = ious[:, gtind] if len(ious) > 0 else ious + + T = len(p.iouThrs) + G = len(gt) + D = len(dt) + gtm = np.zeros((T, G)) + dtm = np.zeros((T, D)) + gtIg = np.array([g['_ignore'] for g in gt]) + dtIg = np.zeros((T, D)) + if not len(ious) == 0: + for tind, t in enumerate(p.iouThrs): + for dind, d in enumerate(dt): + # information about best match so far (m=-1 -> unmatched) + iou = min([t, 1 - 1e-10]) + m = -1 + for gind, g in enumerate(gt): + # if this gt already matched, and not a crowd, continue + if gtm[tind, gind] > 0 and not iscrowd[gind]: + continue + # if dt matched to reg gt, and on ignore gt, stop + if m > -1 and gtIg[m] == 0 and gtIg[gind] == 1: + break + # continue to next gt unless better match made + if ious[dind, gind] < iou: + continue + # if match successful and best so far, + # store appropriately + iou = ious[dind, gind] + m = gind + # if match made store id of match for both dt and gt + if m == -1: + continue + dtIg[tind, dind] = gtIg[m] + dtm[tind, dind] = gt[m]['id'] + gtm[tind, m] = d['id'] + # set unmatched detections outside of area range to ignore + a = np.array([d['area'] < aRng[0] or d['area'] > aRng[1] + for d in dt]).reshape((1, len(dt))) + dtIg = np.logical_or(dtIg, np.logical_and(dtm == 0, np.repeat(a, T, + 0))) + # store results for given image and category + + 
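+        # dtm/gtm hold, per IoU threshold, the id of the matched counterpart
+        # (0 means unmatched); dtIg/gtIg flag entries ignored by accumulate().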
return { + 'image_id': imgId, + 'category_id': catId, + 'aRng': aRng, + 'maxDet': maxDet, + 'dtIds': [d['id'] for d in dt], + 'gtIds': [g['id'] for g in gt], + 'dtMatches': dtm, + 'gtMatches': gtm, + 'dtScores': [d['score'] for d in dt], + 'gtIgnore': gtIg, + 'dtIgnore': dtIg, + } + + def summarize(self): + """Compute and display summary metrics for evaluation results. + + Note this function can *only* be applied on the default parameter + setting + """ + + def _summarize(ap=1, iouThr=None, areaRng='all', maxDets=100): + p = self.params + iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}' # noqa + titleStr = 'Average Precision' if ap == 1 else 'Average Recall' + typeStr = '(AP)' if ap == 1 else '(AR)' + iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \ + if iouThr is None else '{:0.2f}'.format(iouThr) + + aind = [ + i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng + ] + mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets] + if ap == 1: + # dimension of precision: [TxRxKxAxM] + s = self.eval['precision'] + # IoU + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + s = s[:, :, :, aind, mind] + else: + # dimension of recall: [TxKxAxM] + s = self.eval['recall'] + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + s = s[:, :, aind, mind] + if len(s[s > -1]) == 0: + mean_s = -1 + else: + mean_s = np.mean(s[s > -1]) + print( + iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, + mean_s)) + return mean_s + + def _summarizeDets(): + stats = [] + stats.append(_summarize(1, maxDets=self.params.maxDets[-1])) + stats.append( + _summarize(1, iouThr=.5, maxDets=self.params.maxDets[-1])) + stats.append( + _summarize(1, iouThr=.75, maxDets=self.params.maxDets[-1])) + for area_rng in ('small', 'medium', 'large'): + stats.append( + _summarize( + 1, areaRng=area_rng, maxDets=self.params.maxDets[-1])) + for max_det in self.params.maxDets: + stats.append(_summarize(0, maxDets=max_det)) + for area_rng in ('small', 'medium', 'large'): + stats.append( + _summarize( + 0, areaRng=area_rng, maxDets=self.params.maxDets[-1])) + stats = np.array(stats) + return stats + + def _summarizeKps(): + stats = np.zeros((10, )) + stats[0] = _summarize(1, maxDets=20) + stats[1] = _summarize(1, maxDets=20, iouThr=.5) + stats[2] = _summarize(1, maxDets=20, iouThr=.75) + stats[3] = _summarize(1, maxDets=20, areaRng='medium') + stats[4] = _summarize(1, maxDets=20, areaRng='large') + stats[5] = _summarize(0, maxDets=20) + stats[6] = _summarize(0, maxDets=20, iouThr=.5) + stats[7] = _summarize(0, maxDets=20, iouThr=.75) + stats[8] = _summarize(0, maxDets=20, areaRng='medium') + stats[9] = _summarize(0, maxDets=20, areaRng='large') + return stats + + if not self.eval: + raise Exception('Please run accumulate() first') + iouType = self.params.iouType + if iouType == 'segm' or iouType == 'bbox': + summarize = _summarizeDets + elif iouType == 'keypoints': + summarize = _summarizeKps + self.stats = summarize() diff --git a/mmdet/datasets/base_det_dataset.py b/mmdet/datasets/base_det_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..57bc7098387ee0bec2d53641c1ea7ce2c3cdf618 --- /dev/null +++ b/mmdet/datasets/base_det_dataset.py @@ -0,0 +1,124 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
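+# BaseDetDataset is rarely instantiated directly; concrete datasets such as
+# CocoDataset subclass it and are usually built through the registry. A
+# minimal sketch (illustrative; the paths are hypothetical):
+#
+#     from mmdet.registry import DATASETS
+#     dataset = DATASETS.build(
+#         dict(type='CocoDataset',
+#              ann_file='annotations/instances_train2017.json',
+#              data_prefix=dict(img='train2017/'),
+#              pipeline=[]))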
+import os.path as osp +from typing import List, Optional + +from mmengine.dataset import BaseDataset +from mmengine.fileio import load +from mmengine.utils import is_abs + +from ..registry import DATASETS + + +@DATASETS.register_module() +class BaseDetDataset(BaseDataset): + """Base dataset for detection. + + Args: + proposal_file (str, optional): Proposals file path. Defaults to None. + file_client_args (dict): Arguments to instantiate the + corresponding backend in mmdet <= 3.0.0rc6. Defaults to None. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + return_classes (bool): Whether to return class information + for open vocabulary-based algorithms. Defaults to False. + """ + + def __init__(self, + *args, + seg_map_suffix: str = '.png', + proposal_file: Optional[str] = None, + file_client_args: dict = None, + backend_args: dict = None, + return_classes: bool = False, + **kwargs) -> None: + self.seg_map_suffix = seg_map_suffix + self.proposal_file = proposal_file + self.backend_args = backend_args + self.return_classes = return_classes + if file_client_args is not None: + raise RuntimeError( + 'The `file_client_args` is deprecated, ' + 'please use `backend_args` instead, please refer to' + 'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501 + ) + super().__init__(*args, **kwargs) + + def full_init(self) -> None: + """Load annotation file and set ``BaseDataset._fully_initialized`` to + True. + + If ``lazy_init=False``, ``full_init`` will be called during the + instantiation and ``self._fully_initialized`` will be set to True. If + ``obj._fully_initialized=False``, the class method decorated by + ``force_full_init`` will call ``full_init`` automatically. + + Several steps to initialize annotation: + + - load_data_list: Load annotations from annotation file. + - load_proposals: Load proposals from proposal file, if + `self.proposal_file` is not None. + - filter data information: Filter annotations according to + filter_cfg. + - slice_data: Slice dataset according to ``self._indices`` + - serialize_data: Serialize ``self.data_list`` if + ``self.serialize_data`` is True. + """ + if self._fully_initialized: + return + # load data information + self.data_list = self.load_data_list() + # get proposals from file + if self.proposal_file is not None: + self.load_proposals() + # filter illegal data, such as data that has no annotations. + self.data_list = self.filter_data() + + # Get subset data according to indices. + if self._indices is not None: + self.data_list = self._get_unserialized_subset(self._indices) + + # serialize data_list + if self.serialize_data: + self.data_bytes, self.data_address = self._serialize_data() + + self._fully_initialized = True + + def load_proposals(self) -> None: + """Load proposals from proposals file. + + The `proposals_list` should be a dict[img_path: proposals] + with the same length as `data_list`. And the `proposals` should be + a `dict` or :obj:`InstanceData` usually contains following keys. + + - bboxes (np.ndarry): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - scores (np.ndarry): Classification scores, has a shape + (num_instance, ). 
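+
+        A hypothetical entry of a loaded proposal file (illustrative; the
+        file-name key is `last_dir/basename` of the image path)::
+
+            {'train2017/000000000009.jpg': dict(
+                bboxes=np.array([[0., 0., 10., 10.]]),
+                scores=np.array([0.98]))}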
+        """
+        # TODO: Add Unit Test after Dump-Proposal Metric is fully supported
+        if not is_abs(self.proposal_file):
+            self.proposal_file = osp.join(self.data_root, self.proposal_file)
+        proposals_list = load(
+            self.proposal_file, backend_args=self.backend_args)
+        assert len(self.data_list) == len(proposals_list)
+        for data_info in self.data_list:
+            img_path = data_info['img_path']
+            # `file_name` is the key to obtain the proposals from the
+            # `proposals_list`.
+            file_name = osp.join(
+                osp.split(osp.split(img_path)[0])[-1],
+                osp.split(img_path)[-1])
+            proposals = proposals_list[file_name]
+            data_info['proposals'] = proposals
+
+    def get_cat_ids(self, idx: int) -> List[int]:
+        """Get COCO category ids by index.
+
+        Args:
+            idx (int): Index of data.
+
+        Returns:
+            List[int]: All categories in the image of the specified index.
+        """
+        instances = self.get_data_info(idx)['instances']
+        return [instance['bbox_label'] for instance in instances]
diff --git a/mmdet/datasets/base_semseg_dataset.py b/mmdet/datasets/base_semseg_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d10f762a21a897ab8274fbe9eefab054691a7c60
--- /dev/null
+++ b/mmdet/datasets/base_semseg_dataset.py
@@ -0,0 +1,265 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os.path as osp
+from typing import Callable, Dict, List, Optional, Sequence, Union
+
+import mmengine
+import mmengine.fileio as fileio
+import numpy as np
+from mmengine.dataset import BaseDataset, Compose
+
+from mmdet.registry import DATASETS
+
+
+@DATASETS.register_module()
+class BaseSegDataset(BaseDataset):
+    """Custom dataset for semantic segmentation. An example of file structure
+    is as follows.
+
+    .. code-block:: none
+
+        ├── data
+        │   ├── my_dataset
+        │   │   ├── img_dir
+        │   │   │   ├── train
+        │   │   │   │   ├── xxx{img_suffix}
+        │   │   │   │   ├── yyy{img_suffix}
+        │   │   │   │   ├── zzz{img_suffix}
+        │   │   │   ├── val
+        │   │   ├── ann_dir
+        │   │   │   ├── train
+        │   │   │   │   ├── xxx{seg_map_suffix}
+        │   │   │   │   ├── yyy{seg_map_suffix}
+        │   │   │   │   ├── zzz{seg_map_suffix}
+        │   │   │   ├── val
+
+    An img/gt_semantic_seg pair of BaseSegDataset should share the same name
+    except for the suffix. A valid img/gt_semantic_seg filename pair should be
+    like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (the extension is
+    also included in the suffix). If split is given, then ``xxx`` is specified
+    in the txt file. Otherwise, all files in ``img_dir/`` and ``ann_dir`` will
+    be loaded. Please refer to ``docs/en/tutorials/new_dataset.md`` for more
+    details.
+
+
+    Args:
+        ann_file (str): Annotation file path. Defaults to ''.
+        metainfo (dict, optional): Meta information for the dataset, such as
+            which classes to load. Defaults to None.
+        data_root (str, optional): The root directory for ``data_prefix`` and
+            ``ann_file``. Defaults to None.
+        data_prefix (dict, optional): Prefix for training data. Defaults to
+            dict(img_path=None, seg_map_path=None).
+        img_suffix (str): Suffix of images. Default: '.jpg'
+        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+        filter_cfg (dict, optional): Config for filtering data. Defaults to
+            None.
+        indices (int or Sequence[int], optional): Support using the first few
+            data in the annotation file to facilitate training/testing on a
+            smaller dataset. Defaults to None, which means using all
+            ``data_infos``.
+        serialize_data (bool, optional): Whether to hold memory using
+            serialized objects; when enabled, data loader workers can use
+            shared RAM from the master process instead of making a copy.
+            Defaults to True.
+        pipeline (list, optional): Processing pipeline.
Defaults to []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Defaults to False. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=True``. Defaults to False. + use_label_map (bool, optional): Whether to use label map. + Defaults to False. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Defaults to 1000. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4 required. + """ + METAINFO: dict = dict() + + def __init__(self, + ann_file: str = '', + img_suffix='.jpg', + seg_map_suffix='.png', + metainfo: Optional[dict] = None, + data_root: Optional[str] = None, + data_prefix: dict = dict(img_path='', seg_map_path=''), + filter_cfg: Optional[dict] = None, + indices: Optional[Union[int, Sequence[int]]] = None, + serialize_data: bool = True, + pipeline: List[Union[dict, Callable]] = [], + test_mode: bool = False, + lazy_init: bool = False, + use_label_map: bool = False, + max_refetch: int = 1000, + backend_args: Optional[dict] = None) -> None: + + self.img_suffix = img_suffix + self.seg_map_suffix = seg_map_suffix + self.backend_args = backend_args.copy() if backend_args else None + + self.data_root = data_root + self.data_prefix = copy.copy(data_prefix) + self.ann_file = ann_file + self.filter_cfg = copy.deepcopy(filter_cfg) + self._indices = indices + self.serialize_data = serialize_data + self.test_mode = test_mode + self.max_refetch = max_refetch + self.data_list: List[dict] = [] + self.data_bytes: np.ndarray + + # Set meta information. + self._metainfo = self._load_metainfo(copy.deepcopy(metainfo)) + + # Get label map for custom classes + new_classes = self._metainfo.get('classes', None) + self.label_map = self.get_label_map( + new_classes) if use_label_map else None + self._metainfo.update(dict(label_map=self.label_map)) + + # Update palette based on label map or generate palette + # if it is not defined + updated_palette = self._update_palette() + self._metainfo.update(dict(palette=updated_palette)) + + # Join paths. + if self.data_root is not None: + self._join_prefix() + + # Build pipeline. + self.pipeline = Compose(pipeline) + # Full initialize the dataset. + if not lazy_init: + self.full_init() + + if test_mode: + assert self._metainfo.get('classes') is not None, \ + 'dataset metainfo `classes` should be specified when testing' + + @classmethod + def get_label_map(cls, + new_classes: Optional[Sequence] = None + ) -> Union[Dict, None]: + """Require label mapping. + + The ``label_map`` is a dictionary, its keys are the old label ids and + its values are the new label ids, and is used for changing pixel + labels in load_annotations. If and only if old classes in cls.METAINFO + is not equal to new classes in self._metainfo and nether of them is not + None, `label_map` is not None. + + Args: + new_classes (list, tuple, optional): The new classes name from + metainfo. Default to None. 
+ + + Returns: + dict, optional: The mapping from old classes in cls.METAINFO to + new classes in self._metainfo + """ + old_classes = cls.METAINFO.get('classes', None) + if (new_classes is not None and old_classes is not None + and list(new_classes) != list(old_classes)): + + label_map = {} + if not set(new_classes).issubset(cls.METAINFO['classes']): + raise ValueError( + f'new classes {new_classes} is not a ' + f'subset of classes {old_classes} in METAINFO.') + for i, c in enumerate(old_classes): + if c not in new_classes: + # 0 is background + label_map[i] = 0 + else: + label_map[i] = new_classes.index(c) + return label_map + else: + return None + + def _update_palette(self) -> list: + """Update palette after loading metainfo. + + If length of palette is equal to classes, just return the palette. + If palette is not defined, it will randomly generate a palette. + If classes is updated by customer, it will return the subset of + palette. + + Returns: + Sequence: Palette for current dataset. + """ + palette = self._metainfo.get('palette', []) + classes = self._metainfo.get('classes', []) + # palette does match classes + if len(palette) == len(classes): + return palette + + if len(palette) == 0: + # Get random state before set seed, and restore + # random state later. + # It will prevent loss of randomness, as the palette + # may be different in each iteration if not specified. + # See: https://github.com/open-mmlab/mmdetection/issues/5844 + state = np.random.get_state() + np.random.seed(42) + # random palette + new_palette = np.random.randint( + 0, 255, size=(len(classes), 3)).tolist() + np.random.set_state(state) + elif len(palette) >= len(classes) and self.label_map is not None: + new_palette = [] + # return subset of palette + for old_id, new_id in sorted( + self.label_map.items(), key=lambda x: x[1]): + # 0 is background + if new_id != 0: + new_palette.append(palette[old_id]) + new_palette = type(palette)(new_palette) + elif len(palette) >= len(classes): + # Allow palette length is greater than classes. + return palette + else: + raise ValueError('palette does not match classes ' + f'as metainfo is {self._metainfo}.') + return new_palette + + def load_data_list(self) -> List[dict]: + """Load annotation from directory or annotation file. + + Returns: + list[dict]: All data info of dataset. 
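+
+        Each returned item has the form (illustrative)::
+
+            dict(img_path='.../xxx.jpg', seg_map_path='.../xxx.png',
+                 label_map=None)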
+ """ + data_list = [] + img_dir = self.data_prefix.get('img_path', None) + ann_dir = self.data_prefix.get('seg_map_path', None) + if not osp.isdir(self.ann_file) and self.ann_file: + assert osp.isfile(self.ann_file), \ + f'Failed to load `ann_file` {self.ann_file}' + lines = mmengine.list_from_file( + self.ann_file, backend_args=self.backend_args) + for line in lines: + img_name = line.strip() + data_info = dict( + img_path=osp.join(img_dir, img_name + self.img_suffix)) + if ann_dir is not None: + seg_map = img_name + self.seg_map_suffix + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['label_map'] = self.label_map + data_list.append(data_info) + else: + for img in fileio.list_dir_or_file( + dir_path=img_dir, + list_dir=False, + suffix=self.img_suffix, + recursive=True, + backend_args=self.backend_args): + data_info = dict(img_path=osp.join(img_dir, img)) + if ann_dir is not None: + seg_map = img.replace(self.img_suffix, self.seg_map_suffix) + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['label_map'] = self.label_map + data_list.append(data_info) + data_list = sorted(data_list, key=lambda x: x['img_path']) + return data_list diff --git a/mmdet/datasets/base_video_dataset.py b/mmdet/datasets/base_video_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4a7a25f16206f06c7b64a7ce4c3588efd5455e --- /dev/null +++ b/mmdet/datasets/base_video_dataset.py @@ -0,0 +1,304 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os.path as osp +from collections import defaultdict +from typing import Any, List, Tuple + +import mmengine.fileio as fileio +from mmengine.dataset import BaseDataset +from mmengine.logging import print_log + +from mmdet.datasets.api_wrappers import COCO +from mmdet.registry import DATASETS + + +@DATASETS.register_module() +class BaseVideoDataset(BaseDataset): + """Base video dataset for VID, MOT and VIS tasks.""" + + META = dict(classes=None) + # ann_id is unique in coco dataset. + ANN_ID_UNIQUE = True + + def __init__(self, *args, backend_args: dict = None, **kwargs): + self.backend_args = backend_args + super().__init__(*args, **kwargs) + + def load_data_list(self) -> Tuple[List[dict], List]: + """Load annotations from an annotation file named as ``self.ann_file``. + + Returns: + tuple(list[dict], list): A list of annotation and a list of + valid data indices. + """ + with fileio.get_local_path(self.ann_file) as local_path: + self.coco = COCO(local_path) + # The order of returned `cat_ids` will not + # change with the order of the classes + self.cat_ids = self.coco.get_cat_ids( + cat_names=self.metainfo['classes']) + self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} + self.cat_img_map = copy.deepcopy(self.coco.cat_img_map) + # used in `filter_data` + self.img_ids_with_ann = set() + + img_ids = self.coco.get_img_ids() + total_ann_ids = [] + # if ``video_id`` is not in the annotation file, we will assign a big + # unique video_id for this video. 
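+        # (an image-level dataset without `video_id` thus becomes one
+        # synthetic single-image video per image)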
+ single_video_id = 100000 + videos = {} + for img_id in img_ids: + raw_img_info = self.coco.load_imgs([img_id])[0] + raw_img_info['img_id'] = img_id + if 'video_id' not in raw_img_info: + single_video_id = single_video_id + 1 + video_id = single_video_id + else: + video_id = raw_img_info['video_id'] + + if video_id not in videos: + videos[video_id] = { + 'video_id': video_id, + 'images': [], + 'video_length': 0 + } + + videos[video_id]['video_length'] += 1 + ann_ids = self.coco.get_ann_ids( + img_ids=[img_id], cat_ids=self.cat_ids) + raw_ann_info = self.coco.load_anns(ann_ids) + total_ann_ids.extend(ann_ids) + + parsed_data_info = self.parse_data_info( + dict(raw_img_info=raw_img_info, raw_ann_info=raw_ann_info)) + + if len(parsed_data_info['instances']) > 0: + self.img_ids_with_ann.add(parsed_data_info['img_id']) + + videos[video_id]['images'].append(parsed_data_info) + + data_list = [v for v in videos.values()] + + if self.ANN_ID_UNIQUE: + assert len(set(total_ann_ids)) == len( + total_ann_ids + ), f"Annotation ids in '{self.ann_file}' are not unique!" + + del self.coco + + return data_list + + def parse_data_info(self, raw_data_info: dict) -> dict: + """Parse raw annotation to target format. + + Args: + raw_data_info (dict): Raw data information loaded from + ``ann_file``. + + Returns: + dict: Parsed annotation. + """ + img_info = raw_data_info['raw_img_info'] + ann_info = raw_data_info['raw_ann_info'] + data_info = {} + + data_info.update(img_info) + if self.data_prefix.get('img_path', None) is not None: + img_path = osp.join(self.data_prefix['img_path'], + img_info['file_name']) + else: + img_path = img_info['file_name'] + data_info['img_path'] = img_path + + instances = [] + for i, ann in enumerate(ann_info): + instance = {} + + if ann.get('ignore', False): + continue + x1, y1, w, h = ann['bbox'] + inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0)) + inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if ann['area'] <= 0 or w < 1 or h < 1: + continue + if ann['category_id'] not in self.cat_ids: + continue + bbox = [x1, y1, x1 + w, y1 + h] + + if ann.get('iscrowd', False): + instance['ignore_flag'] = 1 + else: + instance['ignore_flag'] = 0 + instance['bbox'] = bbox + instance['bbox_label'] = self.cat2label[ann['category_id']] + if ann.get('segmentation', None): + instance['mask'] = ann['segmentation'] + if ann.get('instance_id', None): + instance['instance_id'] = ann['instance_id'] + else: + # image dataset usually has no `instance_id`. + # Therefore, we set it to `i`. + instance['instance_id'] = i + instances.append(instance) + data_info['instances'] = instances + return data_info + + def filter_data(self) -> List[int]: + """Filter image annotations according to filter_cfg. + + Returns: + list[int]: Filtered results. 
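+            (Despite the ``List[int]`` annotation, this method actually
+            returns the filtered ``data_list`` of per-video dicts.)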
+        """
+        if self.test_mode:
+            return self.data_list
+
+        num_imgs_before_filter = sum(
+            [len(info['images']) for info in self.data_list])
+        num_imgs_after_filter = 0
+
+        # obtain images that contain annotations of the required categories
+        ids_in_cat = set()
+        for i, class_id in enumerate(self.cat_ids):
+            ids_in_cat |= set(self.cat_img_map[class_id])
+        # merge the image id sets of the two conditions and use the merged set
+        # to filter out images if self.filter_empty_gt=True
+        ids_in_cat &= self.img_ids_with_ann
+
+        new_data_list = []
+        for video_data_info in self.data_list:
+            imgs_data_info = video_data_info['images']
+            valid_imgs_data_info = []
+
+            for data_info in imgs_data_info:
+                img_id = data_info['img_id']
+                width = data_info['width']
+                height = data_info['height']
+                # TODO: simplify these conditions
+                if self.filter_cfg is None:
+                    if img_id not in ids_in_cat:
+                        video_data_info['video_length'] -= 1
+                        continue
+                    if min(width, height) >= 32:
+                        valid_imgs_data_info.append(data_info)
+                        num_imgs_after_filter += 1
+                    else:
+                        video_data_info['video_length'] -= 1
+                else:
+                    if self.filter_cfg.get('filter_empty_gt',
+                                           True) and img_id not in ids_in_cat:
+                        video_data_info['video_length'] -= 1
+                        continue
+                    if min(width, height) >= self.filter_cfg.get(
+                            'min_size', 32):
+                        valid_imgs_data_info.append(data_info)
+                        num_imgs_after_filter += 1
+                    else:
+                        video_data_info['video_length'] -= 1
+            video_data_info['images'] = valid_imgs_data_info
+            new_data_list.append(video_data_info)
+
+        print_log(
+            'The number of samples before and after filtering: '
+            f'{num_imgs_before_filter} / {num_imgs_after_filter}', 'current')
+        return new_data_list
+
+    def prepare_data(self, idx) -> Any:
+        """Get data processed by ``self.pipeline``. Note that ``idx`` is a
+        video index by default, since the base element of a video dataset is a
+        video. However, in some cases we need to specify both the video index
+        and the frame index. For example, in training mode we may want to
+        sample specific frames, and all frames must be sampled once in an
+        epoch; in test mode we may want to output the data of a single image
+        rather than the whole video, to save memory.
+
+        Args:
+            idx (int): The index of ``data_info``.
+
+        Returns:
+            Any: Depends on ``self.pipeline``.
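+
+        Example (illustrative)::
+
+            dataset.prepare_data(0)       # the whole of video 0
+            dataset.prepare_data((0, 5))  # frame 5 of video 0 (the key frame
+                                          # when training)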
+        """
+        if isinstance(idx, tuple):
+            assert len(idx) == 2, ('The length of idx must be 2: '
+                                   '(video_index, frame_index)')
+            video_idx, frame_idx = idx[0], idx[1]
+        else:
+            video_idx, frame_idx = idx, None
+
+        data_info = self.get_data_info(video_idx)
+        if self.test_mode:
+            # Support two test modes: frame-level and video-level
+            final_data_info = defaultdict(list)
+            if frame_idx is None:
+                frames_idx_list = list(range(data_info['video_length']))
+            else:
+                frames_idx_list = [frame_idx]
+            for index in frames_idx_list:
+                frame_ann = data_info['images'][index]
+                frame_ann['video_id'] = data_info['video_id']
+                # Collate data_list (list of dict to dict of list)
+                for key, value in frame_ann.items():
+                    final_data_info[key].append(value)
+                # copy video-level info into the image level
+                # TODO: the value of this key is the same as that of
+                # `video_length` in test mode
+                final_data_info['ori_video_length'].append(
+                    data_info['video_length'])
+
+            final_data_info['video_length'] = [len(frames_idx_list)
+                                               ] * len(frames_idx_list)
+            return self.pipeline(final_data_info)
+        else:
+            # Specify `key_frame_id` for the frame sampling in the pipeline
+            if frame_idx is not None:
+                data_info['key_frame_id'] = frame_idx
+            return self.pipeline(data_info)
+
+    def get_cat_ids(self, index) -> List[int]:
+        """Following image detection, we provide this interface function. Get
+        category ids by video index and frame index.
+
+        Args:
+            index: The index of the dataset. It supports two kinds of inputs:
+                Tuple:
+                    video_idx (int): Index of video.
+                    frame_idx (int): Index of frame.
+                Int: Index of video.
+
+        Returns:
+            List[int]: All categories in the image of the specified video
+            index and frame index.
+        """
+        if isinstance(index, tuple):
+            assert len(
+                index
+            ) == 2, f'Expect the length of index is 2, but got {len(index)}'
+            video_idx, frame_idx = index
+            instances = self.get_data_info(
+                video_idx)['images'][frame_idx]['instances']
+            return [instance['bbox_label'] for instance in instances]
+        else:
+            cat_ids = []
+            for img in self.get_data_info(index)['images']:
+                for instance in img['instances']:
+                    cat_ids.append(instance['bbox_label'])
+            return cat_ids
+
+    @property
+    def num_all_imgs(self):
+        """Get the number of all the images in this video dataset."""
+        return sum(
+            [len(self.get_data_info(i)['images']) for i in range(len(self))])
+
+    def get_len_per_video(self, idx):
+        """Get the length of one video.
+
+        Args:
+            idx (int): Index of video.
+
+        Returns:
+            int: The length of the video.
+        """
+        return len(self.get_data_info(idx)['images'])
diff --git a/mmdet/datasets/cityscapes.py b/mmdet/datasets/cityscapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..09755eb1e8b0f0c278085bd2fafbb7247a3fc946
--- /dev/null
+++ b/mmdet/datasets/cityscapes.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
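+# CityscapesDataset reads COCO-style instance annotations (e.g. as produced
+# by the cityscapesScripts conversion tools); only `filter_data` differs from
+# CocoDataset, additionally dropping images whose instances are all crowd
+# regions when `filter_empty_gt` is set.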
+# Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/datasets/cityscapes.py # noqa +# and https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalInstanceLevelSemanticLabeling.py # noqa + +from typing import List + +from mmdet.registry import DATASETS +from .coco import CocoDataset + + +@DATASETS.register_module() +class CityscapesDataset(CocoDataset): + """Dataset for Cityscapes.""" + + METAINFO = { + 'classes': ('person', 'rider', 'car', 'truck', 'bus', 'train', + 'motorcycle', 'bicycle'), + 'palette': [(220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70), + (0, 60, 100), (0, 80, 100), (0, 0, 230), (119, 11, 32)] + } + + def filter_data(self) -> List[dict]: + """Filter annotations according to filter_cfg. + + Returns: + List[dict]: Filtered results. + """ + if self.test_mode: + return self.data_list + + if self.filter_cfg is None: + return self.data_list + + filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) + min_size = self.filter_cfg.get('min_size', 0) + + # obtain images that contain annotation + ids_with_ann = set(data_info['img_id'] for data_info in self.data_list) + # obtain images that contain annotations of the required categories + ids_in_cat = set() + for i, class_id in enumerate(self.cat_ids): + ids_in_cat |= set(self.cat_img_map[class_id]) + # merge the image id sets of the two conditions and use the merged set + # to filter out images if self.filter_empty_gt=True + ids_in_cat &= ids_with_ann + + valid_data_infos = [] + for i, data_info in enumerate(self.data_list): + img_id = data_info['img_id'] + width = data_info['width'] + height = data_info['height'] + all_is_crowd = all([ + instance['ignore_flag'] == 1 + for instance in data_info['instances'] + ]) + if filter_empty_gt and (img_id not in ids_in_cat or all_is_crowd): + continue + if min(width, height) >= min_size: + valid_data_infos.append(data_info) + + return valid_data_infos diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py new file mode 100644 index 0000000000000000000000000000000000000000..277b75988da6accc54973b4983bf5a08c379b668 --- /dev/null +++ b/mmdet/datasets/coco.py @@ -0,0 +1,200 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
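+# A minimal instantiation sketch (illustrative; the paths are hypothetical):
+#
+#     dataset = CocoDataset(
+#         ann_file='annotations/instances_val2017.json',
+#         data_prefix=dict(img='val2017/'),
+#         test_mode=True,
+#         pipeline=[])
+#     data_info = dataset.get_data_info(0)  # img_path, instances, ...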
+import copy +import os.path as osp +from typing import List, Union + +from mmengine.fileio import get_local_path + +from mmdet.registry import DATASETS +from .api_wrappers import COCO +from .base_det_dataset import BaseDetDataset + + +@DATASETS.register_module() +class CocoDataset(BaseDetDataset): + """Dataset for COCO.""" + + METAINFO = { + 'classes': + ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', + 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', + 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', + 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', + 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', + 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', + 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy bear', 'hair drier', 'toothbrush'), + # palette is a list of color tuples, which is used for visualization. + 'palette': + [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228), + (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30), + (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), + (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255), + (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255), + (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92), + (209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164), + (92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0), + (174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174), + (255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54), + (207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51), + (74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65), + (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0), + (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161), + (163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120), + (183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133), + (166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62), + (65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45), + (196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1), + (246, 0, 122), (191, 162, 208)] + } + COCOAPI = COCO + # ann_id is unique in coco dataset. + ANN_ID_UNIQUE = True + + def load_data_list(self) -> List[dict]: + """Load annotations from an annotation file named as ``self.ann_file`` + + Returns: + List[dict]: A list of annotation. 
+ """ # noqa: E501 + with get_local_path( + self.ann_file, backend_args=self.backend_args) as local_path: + self.coco = self.COCOAPI(local_path) + # The order of returned `cat_ids` will not + # change with the order of the `classes` + self.cat_ids = self.coco.get_cat_ids( + cat_names=self.metainfo['classes']) + self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} + self.cat_img_map = copy.deepcopy(self.coco.cat_img_map) + + img_ids = self.coco.get_img_ids() + data_list = [] + total_ann_ids = [] + for img_id in img_ids: + raw_img_info = self.coco.load_imgs([img_id])[0] + raw_img_info['img_id'] = img_id + + ann_ids = self.coco.get_ann_ids(img_ids=[img_id]) + raw_ann_info = self.coco.load_anns(ann_ids) + total_ann_ids.extend(ann_ids) + + parsed_data_info = self.parse_data_info({ + 'raw_ann_info': + raw_ann_info, + 'raw_img_info': + raw_img_info + }) + data_list.append(parsed_data_info) + if self.ANN_ID_UNIQUE: + assert len(set(total_ann_ids)) == len( + total_ann_ids + ), f"Annotation ids in '{self.ann_file}' are not unique!" + + del self.coco + + return data_list + + def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]: + """Parse raw annotation to target format. + + Args: + raw_data_info (dict): Raw data information load from ``ann_file`` + + Returns: + Union[dict, List[dict]]: Parsed annotation. + """ + img_info = raw_data_info['raw_img_info'] + ann_info = raw_data_info['raw_ann_info'] + + data_info = {} + + # TODO: need to change data_prefix['img'] to data_prefix['img_path'] + img_path = osp.join(self.data_prefix['img'], img_info['file_name']) + if self.data_prefix.get('seg', None): + seg_map_path = osp.join( + self.data_prefix['seg'], + img_info['file_name'].rsplit('.', 1)[0] + self.seg_map_suffix) + else: + seg_map_path = None + data_info['img_path'] = img_path + data_info['img_id'] = img_info['img_id'] + data_info['seg_map_path'] = seg_map_path + data_info['height'] = img_info['height'] + data_info['width'] = img_info['width'] + + if self.return_classes: + data_info['text'] = self.metainfo['classes'] + data_info['custom_entities'] = True + + instances = [] + for i, ann in enumerate(ann_info): + instance = {} + + if ann.get('ignore', False): + continue + x1, y1, w, h = ann['bbox'] + inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0)) + inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if ann['area'] <= 0 or w < 1 or h < 1: + continue + if ann['category_id'] not in self.cat_ids: + continue + bbox = [x1, y1, x1 + w, y1 + h] + + if ann.get('iscrowd', False): + instance['ignore_flag'] = 1 + else: + instance['ignore_flag'] = 0 + instance['bbox'] = bbox + instance['bbox_label'] = self.cat2label[ann['category_id']] + + if ann.get('segmentation', None): + instance['mask'] = ann['segmentation'] + + instances.append(instance) + data_info['instances'] = instances + return data_info + + def filter_data(self) -> List[dict]: + """Filter annotations according to filter_cfg. + + Returns: + List[dict]: Filtered results. 
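+            Filtering is controlled by ``filter_cfg``, e.g.
+            ``dict(filter_empty_gt=True, min_size=32)``.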
+ """ + if self.test_mode: + return self.data_list + + if self.filter_cfg is None: + return self.data_list + + filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) + min_size = self.filter_cfg.get('min_size', 0) + + # obtain images that contain annotation + ids_with_ann = set(data_info['img_id'] for data_info in self.data_list) + # obtain images that contain annotations of the required categories + ids_in_cat = set() + for i, class_id in enumerate(self.cat_ids): + ids_in_cat |= set(self.cat_img_map[class_id]) + # merge the image id sets of the two conditions and use the merged set + # to filter out images if self.filter_empty_gt=True + ids_in_cat &= ids_with_ann + + valid_data_infos = [] + for i, data_info in enumerate(self.data_list): + img_id = data_info['img_id'] + width = data_info['width'] + height = data_info['height'] + if filter_empty_gt and img_id not in ids_in_cat: + continue + if min(width, height) >= min_size: + valid_data_infos.append(data_info) + + return valid_data_infos diff --git a/mmdet/datasets/coco_caption.py b/mmdet/datasets/coco_caption.py new file mode 100644 index 0000000000000000000000000000000000000000..ee695fe9a768f2be5345c6ad6bafc74177f252c0 --- /dev/null +++ b/mmdet/datasets/coco_caption.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset +from mmengine.fileio import get_file_backend + +from mmdet.registry import DATASETS + + +@DATASETS.register_module() +class CocoCaptionDataset(BaseDataset): + """COCO2014 Caption dataset.""" + + def load_data_list(self) -> List[dict]: + """Load data list.""" + img_prefix = self.data_prefix['img_path'] + annotations = mmengine.load(self.ann_file) + file_backend = get_file_backend(img_prefix) + + data_list = [] + for ann in annotations: + data_info = { + 'img_id': Path(ann['image']).stem.split('_')[-1], + 'img_path': file_backend.join_path(img_prefix, ann['image']), + 'gt_caption': ann['caption'], + } + + data_list.append(data_info) + + return data_list diff --git a/mmdet/datasets/coco_panoptic.py b/mmdet/datasets/coco_panoptic.py new file mode 100644 index 0000000000000000000000000000000000000000..d5ca78555095f8266c402be885a31cdfc24e5925 --- /dev/null +++ b/mmdet/datasets/coco_panoptic.py @@ -0,0 +1,292 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Callable, List, Optional, Sequence, Union + +from mmdet.registry import DATASETS +from .api_wrappers import COCOPanoptic +from .coco import CocoDataset + + +@DATASETS.register_module() +class CocoPanopticDataset(CocoDataset): + """Coco dataset for Panoptic segmentation. + + The annotation format is shown as follows. The `ann` field is optional + for testing. + + .. code-block:: none + + [ + { + 'filename': f'{image_id:012}.png', + 'image_id':9 + 'segments_info': + [ + { + 'id': 8345037, (segment_id in panoptic png, + convert from rgb) + 'category_id': 51, + 'iscrowd': 0, + 'bbox': (x1, y1, w, h), + 'area': 24315 + }, + ... + ] + }, + ... + ] + + Args: + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to None. + data_prefix (dict, optional): Prefix for training data. Defaults to + ``dict(img=None, ann=None, seg=None)``. 
The prefix ``seg``, which is
+            for the panoptic segmentation map, must not be None.
+        filter_cfg (dict, optional): Config for filter data. Defaults to None.
+        indices (int or Sequence[int], optional): Support using first few
+            data in annotation file to facilitate training/testing on a
+            smaller dataset. Defaults to None, which means using all
+            ``data_infos``.
+        serialize_data (bool, optional): Whether to hold memory using
+            serialized objects; when enabled, data loader workers can use
+            shared RAM from the master process instead of making a copy.
+            Defaults to True.
+        pipeline (list, optional): Processing pipeline. Defaults to [].
+        test_mode (bool, optional): ``test_mode=True`` means in test phase.
+            Defaults to False.
+        lazy_init (bool, optional): Whether to load annotation during
+            instantiation. In some cases, such as visualization, only the
+            meta information of the dataset is needed, which makes loading
+            the annotation file unnecessary. ``BaseDataset`` can skip loading
+            annotations to save time by setting ``lazy_init=True``. Defaults
+            to False.
+        max_refetch (int, optional): The maximum number of extra cycles to
+            fetch a valid image if ``BaseDataset.prepare_data`` returns a
+            None image. Defaults to 1000.
+    """
+
+    METAINFO = {
+        'classes':
+        ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
+         'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
+         'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
+         'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
+         'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
+         'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
+         'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
+         'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
+         'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
+         'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
+         'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
+         'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
+         'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
+         'blanket', 'bridge', 'cardboard', 'counter', 'curtain', 'door-stuff',
+         'floor-wood', 'flower', 'fruit', 'gravel', 'house', 'light',
+         'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield',
+         'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow',
+         'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile',
+         'wall-wood', 'water-other', 'window-blind', 'window-other',
+         'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged',
+         'cabinet-merged', 'table-merged', 'floor-other-merged',
+         'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged',
+         'paper-merged', 'food-other-merged', 'building-other-merged',
+         'rock-merged', 'wall-other-merged', 'rug-merged'),
+        'thing_classes':
+        ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
+         'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
+         'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
+         'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
+         'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
+         'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
+         'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
+         'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
+         'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
+         'couch', 'potted plant', 'bed', 'dining table',
'toilet', 'tv', + 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy bear', 'hair drier', 'toothbrush'), + 'stuff_classes': + ('banner', 'blanket', 'bridge', 'cardboard', 'counter', 'curtain', + 'door-stuff', 'floor-wood', 'flower', 'fruit', 'gravel', 'house', + 'light', 'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield', + 'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow', + 'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile', + 'wall-wood', 'water-other', 'window-blind', 'window-other', + 'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged', + 'cabinet-merged', 'table-merged', 'floor-other-merged', + 'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged', + 'paper-merged', 'food-other-merged', 'building-other-merged', + 'rock-merged', 'wall-other-merged', 'rug-merged'), + 'palette': + [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228), + (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30), + (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), + (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255), + (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255), + (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92), + (209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164), + (92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0), + (174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174), + (255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54), + (207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51), + (74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65), + (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0), + (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161), + (163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120), + (183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133), + (166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62), + (65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45), + (196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1), + (246, 0, 122), (191, 162, 208), (255, 255, 128), (147, 211, 203), + (150, 100, 100), (168, 171, 172), (146, 112, 198), (210, 170, 100), + (92, 136, 89), (218, 88, 184), (241, 129, 0), (217, 17, 255), + (124, 74, 181), (70, 70, 70), (255, 228, 255), (154, 208, 0), + (193, 0, 92), (76, 91, 113), (255, 180, 195), (106, 154, 176), + (230, 150, 140), (60, 143, 255), (128, 64, 128), (92, 82, 55), + (254, 212, 124), (73, 77, 174), (255, 160, 98), (255, 255, 255), + (104, 84, 109), (169, 164, 131), (225, 199, 255), (137, 54, 74), + (135, 158, 223), (7, 246, 231), (107, 255, 200), (58, 41, 149), + (183, 121, 142), (255, 73, 97), (107, 142, 35), (190, 153, 153), + (146, 139, 141), (70, 130, 180), (134, 199, 156), (209, 226, 140), + (96, 36, 108), (96, 96, 96), (64, 170, 64), (152, 251, 152), + (208, 229, 228), (206, 186, 171), (152, 161, 64), (116, 112, 0), + (0, 114, 143), (102, 102, 156), (250, 141, 255)] + } + COCOAPI = COCOPanoptic + # ann_id is not unique in coco panoptic dataset. 
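+    # Each segment id is decoded from the pixel color of the panoptic PNG
+    # (id = R + G * 256 + B * 256**2, cf. ``panopticapi.utils.rgb2id``) and
+    # is therefore only unique within a single image, hence: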
+ ANN_ID_UNIQUE = False + + def __init__(self, + ann_file: str = '', + metainfo: Optional[dict] = None, + data_root: Optional[str] = None, + data_prefix: dict = dict(img=None, ann=None, seg=None), + filter_cfg: Optional[dict] = None, + indices: Optional[Union[int, Sequence[int]]] = None, + serialize_data: bool = True, + pipeline: List[Union[dict, Callable]] = [], + test_mode: bool = False, + lazy_init: bool = False, + max_refetch: int = 1000, + backend_args: dict = None, + **kwargs) -> None: + super().__init__( + ann_file=ann_file, + metainfo=metainfo, + data_root=data_root, + data_prefix=data_prefix, + filter_cfg=filter_cfg, + indices=indices, + serialize_data=serialize_data, + pipeline=pipeline, + test_mode=test_mode, + lazy_init=lazy_init, + max_refetch=max_refetch, + backend_args=backend_args, + **kwargs) + + def parse_data_info(self, raw_data_info: dict) -> dict: + """Parse raw annotation to target format. + + Args: + raw_data_info (dict): Raw data information load from ``ann_file``. + + Returns: + dict: Parsed annotation. + """ + img_info = raw_data_info['raw_img_info'] + ann_info = raw_data_info['raw_ann_info'] + # filter out unmatched annotations which have + # same segment_id but belong to other image + ann_info = [ + ann for ann in ann_info if ann['image_id'] == img_info['img_id'] + ] + data_info = {} + + img_path = osp.join(self.data_prefix['img'], img_info['file_name']) + if self.data_prefix.get('seg', None): + seg_map_path = osp.join( + self.data_prefix['seg'], + img_info['file_name'].replace('jpg', 'png')) + else: + seg_map_path = None + data_info['img_path'] = img_path + data_info['img_id'] = img_info['img_id'] + data_info['seg_map_path'] = seg_map_path + data_info['height'] = img_info['height'] + data_info['width'] = img_info['width'] + + if self.return_classes: + data_info['text'] = self.metainfo['thing_classes'] + data_info['stuff_text'] = self.metainfo['stuff_classes'] + data_info['custom_entities'] = True # no important + + instances = [] + segments_info = [] + for ann in ann_info: + instance = {} + x1, y1, w, h = ann['bbox'] + if ann['area'] <= 0 or w < 1 or h < 1: + continue + bbox = [x1, y1, x1 + w, y1 + h] + category_id = ann['category_id'] + contiguous_cat_id = self.cat2label[category_id] + + is_thing = self.coco.load_cats(ids=category_id)[0]['isthing'] + if is_thing: + is_crowd = ann.get('iscrowd', False) + instance['bbox'] = bbox + instance['bbox_label'] = contiguous_cat_id + if not is_crowd: + instance['ignore_flag'] = 0 + else: + instance['ignore_flag'] = 1 + is_thing = False + + segment_info = { + 'id': ann['id'], + 'category': contiguous_cat_id, + 'is_thing': is_thing + } + segments_info.append(segment_info) + if len(instance) > 0 and is_thing: + instances.append(instance) + data_info['instances'] = instances + data_info['segments_info'] = segments_info + return data_info + + def filter_data(self) -> List[dict]: + """Filter images too small or without ground truth. + + Returns: + List[dict]: ``self.data_list`` after filtering. + """ + if self.test_mode: + return self.data_list + + if self.filter_cfg is None: + return self.data_list + + filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) + min_size = self.filter_cfg.get('min_size', 0) + + ids_with_ann = set() + # check whether images have legal thing annotations. 
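+        # Stuff-only images contribute no detection targets, so an image
+        # only counts as annotated if at least one of its segments is a
+        # thing (see the ``is_thing`` check below).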
+ for data_info in self.data_list: + for segment_info in data_info['segments_info']: + if not segment_info['is_thing']: + continue + ids_with_ann.add(data_info['img_id']) + + valid_data_list = [] + for data_info in self.data_list: + img_id = data_info['img_id'] + width = data_info['width'] + height = data_info['height'] + if filter_empty_gt and img_id not in ids_with_ann: + continue + if min(width, height) >= min_size: + valid_data_list.append(data_info) + + return valid_data_list diff --git a/mmdet/datasets/coco_semantic.py b/mmdet/datasets/coco_semantic.py new file mode 100644 index 0000000000000000000000000000000000000000..752568454456c1e5edcb2a24c6c2b46f042cb334 --- /dev/null +++ b/mmdet/datasets/coco_semantic.py @@ -0,0 +1,90 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import DATASETS +from .ade20k import ADE20KSegDataset + + +@DATASETS.register_module() +class CocoSegDataset(ADE20KSegDataset): + """COCO dataset. + + In segmentation map annotation for COCO. The ``img_suffix`` is fixed to + '.jpg', and ``seg_map_suffix`` is fixed to '.png'. + """ + + METAINFO = dict( + classes=( + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', + 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', + 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', + 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', + 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', + 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet', + 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile', + 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain', + 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble', + 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower', + 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel', + 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal', + 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', + 'paper', 'pavement', 'pillow', 'plant-other', 'plastic', + 'platform', 'playingfield', 'railing', 'railroad', 'river', 'road', + 'rock', 'roof', 'rug', 'salad', 'sand', 'sea', 'shelf', + 'sky-other', 'skyscraper', 'snow', 'solid-other', 'stairs', + 'stone', 'straw', 'structural-other', 'table', 'tent', + 'textile-other', 'towel', 'tree', 'vegetable', 'wall-brick', + 'wall-concrete', 'wall-other', 'wall-panel', 'wall-stone', + 'wall-tile', 'wall-wood', 'water-other', 'waterdrops', + 'window-blind', 'window-other', 'wood'), + palette=[(120, 120, 120), (180, 120, 120), (6, 230, 230), (80, 50, 50), + (4, 200, 3), (120, 120, 80), (140, 140, 140), (204, 5, 255), + (230, 230, 230), (4, 250, 7), (224, 5, 255), (235, 255, 7), + (150, 5, 61), (120, 120, 70), (8, 255, 51), (255, 6, 82), + (143, 255, 140), (204, 255, 4), (255, 51, 7), (204, 70, 3), + (0, 102, 200), (61, 230, 250), (255, 6, 51), (11, 102, 255), + (255, 7, 71), (255, 9, 224), (9, 7, 230), (220, 220, 220), + (255, 9, 
92), (112, 9, 255), (8, 255, 214), (7, 255, 224), + (255, 184, 6), (10, 255, 71), (255, 41, 10), (7, 255, 255), + (224, 255, 8), (102, 8, 255), (255, 61, 6), (255, 194, 7), + (255, 122, 8), (0, 255, 20), (255, 8, 41), (255, 5, 153), + (6, 51, 255), (235, 12, 255), (160, 150, 20), (0, 163, 255), + (140, 140, 140), (250, 10, 15), (20, 255, 0), (31, 255, 0), + (255, 31, 0), (255, 224, 0), (153, 255, 0), (0, 0, 255), + (255, 71, 0), (0, 235, 255), (0, 173, 255), (31, 0, 255), + (11, 200, 200), (255, 82, 0), (0, 255, 245), (0, 61, 255), + (0, 255, 112), (0, 255, 133), (255, 0, 0), (255, 163, 0), + (255, 102, 0), (194, 255, 0), (0, 143, 255), (51, 255, 0), + (0, 82, 255), (0, 255, 41), (0, 255, 173), (10, 0, 255), + (173, 255, 0), (0, 255, 153), (255, 92, 0), (255, 0, 255), + (255, 0, 245), (255, 0, 102), (255, 173, 0), (255, 0, 20), + (255, 184, 184), (0, 31, 255), (0, 255, 61), (0, 71, 255), + (255, 0, 204), (0, 255, 194), (0, 255, 82), (0, 10, 255), + (0, 112, 255), (51, 0, 255), (0, 194, 255), (0, 122, 255), + (0, 255, 163), (255, 153, 0), (0, 255, 10), (255, 112, 0), + (143, 255, 0), (82, 0, 255), (163, 255, 0), (255, 235, 0), + (8, 184, 170), (133, 0, 255), (0, 255, 92), (184, 0, 255), + (255, 0, 31), (0, 184, 255), (0, 214, 255), (255, 0, 112), + (92, 255, 0), (0, 224, 255), (112, 224, 255), (70, 184, 160), + (163, 0, 255), (153, 0, 255), (71, 255, 0), (255, 0, 163), + (255, 204, 0), (255, 0, 143), (0, 255, 235), (133, 255, 0), + (255, 0, 235), (245, 0, 255), (255, 0, 122), (255, 245, 0), + (10, 190, 212), (214, 255, 0), (0, 204, 255), (20, 0, 255), + (255, 255, 0), (0, 153, 255), (0, 41, 255), (0, 255, 204), + (41, 0, 255), (41, 255, 0), (173, 0, 255), (0, 245, 255), + (71, 0, 255), (122, 0, 255), (0, 255, 184), (0, 92, 255), + (184, 255, 0), (0, 133, 255), (255, 214, 0), (25, 194, 194), + (102, 255, 0), (92, 0, 255), (107, 255, 200), (58, 41, 149), + (183, 121, 142), (255, 73, 97), (107, 142, 35), + (190, 153, 153), (146, 139, 141), (70, 130, 180), + (134, 199, 156), (209, 226, 140), (96, 36, 108), (96, 96, 96), + (64, 170, 64), (152, 251, 152), (208, 229, 228), + (206, 186, 171), (152, 161, 64), (116, 112, 0), (0, 114, 143), + (102, 102, 156), (250, 141, 255)]) diff --git a/mmdet/datasets/crowdhuman.py b/mmdet/datasets/crowdhuman.py new file mode 100644 index 0000000000000000000000000000000000000000..650176ee545ba6a10a816517553b3b77718d945b --- /dev/null +++ b/mmdet/datasets/crowdhuman.py @@ -0,0 +1,159 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import logging +import os.path as osp +import warnings +from typing import List, Union + +import mmcv +from mmengine.dist import get_rank +from mmengine.fileio import dump, get, get_text, load +from mmengine.logging import print_log +from mmengine.utils import ProgressBar + +from mmdet.registry import DATASETS +from .base_det_dataset import BaseDetDataset + + +@DATASETS.register_module() +class CrowdHumanDataset(BaseDetDataset): + r"""Dataset for CrowdHuman. + + Args: + data_root (str): The root directory for + ``data_prefix`` and ``ann_file``. + ann_file (str): Annotation file path. + extra_ann_file (str | optional):The path of extra image metas + for CrowdHuman. It can be created by CrowdHumanDataset + automatically or by tools/misc/get_crowdhuman_id_hw.py + manually. Defaults to None. + """ + + METAINFO = { + 'classes': ('person', ), + # palette is a list of color tuples, which is used for visualization. 
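+        # CrowdHuman annotates only the single 'person' category, so a
+        # single color is sufficient.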
+        'palette': [(220, 20, 60)]
+    }
+
+    def __init__(self, data_root, ann_file, extra_ann_file=None, **kwargs):
+        # extra_ann_file records the size of each image. This file is
+        # created automatically the first time mmdet loads the CrowdHuman
+        # dataset.
+        if extra_ann_file is not None:
+            self.extra_ann_exist = True
+            self.extra_anns = load(extra_ann_file)
+        else:
+            ann_file_name = osp.basename(ann_file)
+            if 'train' in ann_file_name:
+                self.extra_ann_file = osp.join(data_root, 'id_hw_train.json')
+            elif 'val' in ann_file_name:
+                self.extra_ann_file = osp.join(data_root, 'id_hw_val.json')
+            self.extra_ann_exist = False
+            if not osp.isfile(self.extra_ann_file):
+                print_log(
+                    'extra_ann_file does not exist, prepare to collect '
+                    'image height and width...',
+                    level=logging.INFO)
+                self.extra_anns = {}
+            else:
+                self.extra_ann_exist = True
+                self.extra_anns = load(self.extra_ann_file)
+        super().__init__(data_root=data_root, ann_file=ann_file, **kwargs)
+
+    def load_data_list(self) -> List[dict]:
+        """Load annotations from an annotation file named as ``self.ann_file``
+
+        Returns:
+            List[dict]: A list of annotations.
+        """  # noqa: E501
+        anno_strs = get_text(
+            self.ann_file, backend_args=self.backend_args).strip().split('\n')
+        print_log('loading CrowdHuman annotation...', level=logging.INFO)
+        data_list = []
+        prog_bar = ProgressBar(len(anno_strs))
+        for i, anno_str in enumerate(anno_strs):
+            anno_dict = json.loads(anno_str)
+            parsed_data_info = self.parse_data_info(anno_dict)
+            data_list.append(parsed_data_info)
+            prog_bar.update()
+        if not self.extra_ann_exist and get_rank() == 0:
+            # TODO: support file client
+            try:
+                dump(self.extra_anns, self.extra_ann_file, file_format='json')
+            except:  # noqa
+                warnings.warn(
+                    'Cache files cannot be saved automatically! To speed up '
+                    'loading the dataset, please manually generate the cache '
+                    'file with tools/misc/get_crowdhuman_id_hw.py')
+
+            print_log(
+                f'\nsave extra_ann_file in {self.data_root}',
+                level=logging.INFO)
+
+        del self.extra_anns
+        print_log('\nDone', level=logging.INFO)
+        return data_list
+
+    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
+        """Parse raw annotation to target format.
+
+        Args:
+            raw_data_info (dict): Raw data information loaded from
+                ``ann_file``.
+
+        Returns:
+            Union[dict, List[dict]]: Parsed annotation.
+        """
+        data_info = {}
+        img_path = osp.join(self.data_prefix['img'],
+                            f"{raw_data_info['ID']}.jpg")
+        data_info['img_path'] = img_path
+        data_info['img_id'] = raw_data_info['ID']
+
+        if not self.extra_ann_exist:
+            img_bytes = get(img_path, backend_args=self.backend_args)
+            img = mmcv.imfrombytes(img_bytes, backend='cv2')
+            data_info['height'], data_info['width'] = img.shape[:2]
+            self.extra_anns[raw_data_info['ID']] = img.shape[:2]
+            del img, img_bytes
+        else:
+            data_info['height'], data_info['width'] = self.extra_anns[
+                raw_data_info['ID']]
+
+        instances = []
+        for i, ann in enumerate(raw_data_info['gtboxes']):
+            instance = {}
+            if ann['tag'] not in self.metainfo['classes']:
+                instance['bbox_label'] = -1
+                instance['ignore_flag'] = 1
+            else:
+                instance['bbox_label'] = self.metainfo['classes'].index(
+                    ann['tag'])
+                instance['ignore_flag'] = 0
+            if 'extra' in ann:
+                if 'ignore' in ann['extra']:
+                    if ann['extra']['ignore'] != 0:
+                        instance['bbox_label'] = -1
+                        instance['ignore_flag'] = 1
+
+            x1, y1, w, h = ann['fbox']
+            bbox = [x1, y1, x1 + w, y1 + h]
+            instance['bbox'] = bbox
+
+            # Record the full bbox (fbox), head bbox (hbox) and visible
# bbox (vbox) as additional information. If you need this
+            # information, handle it in the pipeline instead of overriding
+            # CrowdHumanDataset.
+            instance['fbox'] = bbox
+            hbox = ann['hbox']
+            instance['hbox'] = [
+                hbox[0], hbox[1], hbox[0] + hbox[2], hbox[1] + hbox[3]
+            ]
+            vbox = ann['vbox']
+            instance['vbox'] = [
+                vbox[0], vbox[1], vbox[0] + vbox[2], vbox[1] + vbox[3]
+            ]
+
+            instances.append(instance)
+
+        data_info['instances'] = instances
+        return data_info
diff --git a/mmdet/datasets/dataset_wrappers.py b/mmdet/datasets/dataset_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e651e2b990220bcfb85066fecf856e68896c5409
--- /dev/null
+++ b/mmdet/datasets/dataset_wrappers.py
@@ -0,0 +1,252 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import collections
+import copy
+from typing import List, Sequence, Union
+
+from mmengine.dataset import BaseDataset
+from mmengine.dataset import ConcatDataset as MMENGINE_ConcatDataset
+from mmengine.dataset import force_full_init
+
+from mmdet.registry import DATASETS, TRANSFORMS
+
+
+@DATASETS.register_module()
+class MultiImageMixDataset:
+    """A dataset wrapper for multi-image mixed data augmentation.
+
+    Suitable for training with multi-image mixed data augmentations such as
+    Mosaic and MixUp. A mixed-image transform must provide a `get_indexes`
+    method to obtain the indexes of the extra images, and `skip_type_keys`
+    can be set to skip selected transforms while the pipeline runs.
+
+    Args:
+        dataset (:obj:`CustomDataset`): The dataset to be mixed.
+        pipeline (Sequence[dict]): Sequence of transform objects or
+            config dicts to be composed.
+        skip_type_keys (list[str], optional): Sequence of transform type
+            strings to be skipped in the pipeline. Defaults to None.
+        max_refetch (int): The maximum number of retry iterations for getting
+            valid results from the pipeline. If the results are still None
+            after `max_refetch` iterations, the iteration is terminated and
+            an error is raised. Defaults to 15.
+        lazy_init (bool): Whether to load annotations during instantiation.
+            Defaults to False.
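+
+    Examples:
+        A minimal config sketch. The dataset paths, image scale and choice
+        of transforms below are illustrative only; the wrapped dataset must
+        load images and annotations in its own pipeline before the mixed
+        transforms run:
+
+        .. code-block:: python
+
+            train_dataset = dict(
+                type='MultiImageMixDataset',
+                dataset=dict(
+                    type='CocoDataset',
+                    ann_file='annotations/instances_train2017.json',
+                    data_prefix=dict(img='train2017/'),
+                    pipeline=[
+                        dict(type='LoadImageFromFile'),
+                        dict(type='LoadAnnotations', with_bbox=True)
+                    ]),
+                pipeline=[
+                    dict(type='Mosaic', img_scale=(640, 640)),
+                    dict(type='RandomFlip', prob=0.5),
+                    dict(type='PackDetInputs')
+                ])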
+ """ + + def __init__(self, + dataset: Union[BaseDataset, dict], + pipeline: Sequence[str], + skip_type_keys: Union[Sequence[str], None] = None, + max_refetch: int = 15, + lazy_init: bool = False) -> None: + assert isinstance(pipeline, collections.abc.Sequence) + if skip_type_keys is not None: + assert all([ + isinstance(skip_type_key, str) + for skip_type_key in skip_type_keys + ]) + self._skip_type_keys = skip_type_keys + + self.pipeline = [] + self.pipeline_types = [] + for transform in pipeline: + if isinstance(transform, dict): + self.pipeline_types.append(transform['type']) + transform = TRANSFORMS.build(transform) + self.pipeline.append(transform) + else: + raise TypeError('pipeline must be a dict') + + self.dataset: BaseDataset + if isinstance(dataset, dict): + self.dataset = DATASETS.build(dataset) + elif isinstance(dataset, BaseDataset): + self.dataset = dataset + else: + raise TypeError( + 'elements in datasets sequence should be config or ' + f'`BaseDataset` instance, but got {type(dataset)}') + + self._metainfo = self.dataset.metainfo + if hasattr(self.dataset, 'flag'): + self.flag = self.dataset.flag + self.num_samples = len(self.dataset) + self.max_refetch = max_refetch + + self._fully_initialized = False + if not lazy_init: + self.full_init() + + @property + def metainfo(self) -> dict: + """Get the meta information of the multi-image-mixed dataset. + + Returns: + dict: The meta information of multi-image-mixed dataset. + """ + return copy.deepcopy(self._metainfo) + + def full_init(self): + """Loop to ``full_init`` each dataset.""" + if self._fully_initialized: + return + + self.dataset.full_init() + self._ori_len = len(self.dataset) + self._fully_initialized = True + + @force_full_init + def get_data_info(self, idx: int) -> dict: + """Get annotation by index. + + Args: + idx (int): Global index of ``ConcatDataset``. + + Returns: + dict: The idx-th annotation of the datasets. + """ + return self.dataset.get_data_info(idx) + + @force_full_init + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + results = copy.deepcopy(self.dataset[idx]) + for (transform, transform_type) in zip(self.pipeline, + self.pipeline_types): + if self._skip_type_keys is not None and \ + transform_type in self._skip_type_keys: + continue + + if hasattr(transform, 'get_indexes'): + for i in range(self.max_refetch): + # Make sure the results passed the loading pipeline + # of the original dataset is not None. + indexes = transform.get_indexes(self.dataset) + if not isinstance(indexes, collections.abc.Sequence): + indexes = [indexes] + mix_results = [ + copy.deepcopy(self.dataset[index]) for index in indexes + ] + if None not in mix_results: + results['mix_results'] = mix_results + break + else: + raise RuntimeError( + 'The loading pipeline of the original dataset' + ' always return None. Please check the correctness ' + 'of the dataset and its pipeline.') + + for i in range(self.max_refetch): + # To confirm the results passed the training pipeline + # of the wrapper is not None. + updated_results = transform(copy.deepcopy(results)) + if updated_results is not None: + results = updated_results + break + else: + raise RuntimeError( + 'The training pipeline of the dataset wrapper' + ' always return None.Please check the correctness ' + 'of the dataset and its pipeline.') + + if 'mix_results' in results: + results.pop('mix_results') + + return results + + def update_skip_type_keys(self, skip_type_keys): + """Update skip_type_keys. It is called by an external hook. 
+
+        Args:
+            skip_type_keys (list[str], optional): Sequence of transform type
+                strings to be skipped in the pipeline.
+        """
+        assert all([
+            isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
+        ])
+        self._skip_type_keys = skip_type_keys
+
+
+@DATASETS.register_module()
+class ConcatDataset(MMENGINE_ConcatDataset):
+    """A wrapper of concatenated dataset.
+
+    Same as ``torch.utils.data.dataset.ConcatDataset``, but additionally
+    supports ``lazy_init`` and ``get_dataset_source``.
+
+    Note:
+        ``ConcatDataset`` should not inherit from ``BaseDataset``, since
+        ``get_subset`` and ``get_subset_`` could produce a sub-dataset with
+        ambiguous semantics that conflicts with the original dataset. If you
+        want to use a sub-dataset of ``ConcatDataset``, set the ``indices``
+        argument on the wrapped datasets, which inherit from ``BaseDataset``.
+
+    Args:
+        datasets (Sequence[BaseDataset] or Sequence[dict]): A list of datasets
+            which will be concatenated.
+        lazy_init (bool, optional): Whether to load annotation during
+            instantiation. Defaults to False.
+        ignore_keys (List[str] or str): Ignore the keys that can be
+            unequal in `dataset.metainfo`. Defaults to None.
+            `New in version 0.3.0.`
+    """
+
+    def __init__(self,
+                 datasets: Sequence[Union[BaseDataset, dict]],
+                 lazy_init: bool = False,
+                 ignore_keys: Union[str, List[str], None] = None):
+        self.datasets: List[BaseDataset] = []
+        for i, dataset in enumerate(datasets):
+            if isinstance(dataset, dict):
+                self.datasets.append(DATASETS.build(dataset))
+            elif isinstance(dataset, BaseDataset):
+                self.datasets.append(dataset)
+            else:
+                raise TypeError(
+                    'elements in datasets sequence should be config or '
+                    f'`BaseDataset` instance, but got {type(dataset)}')
+        if ignore_keys is None:
+            self.ignore_keys = []
+        elif isinstance(ignore_keys, str):
+            self.ignore_keys = [ignore_keys]
+        elif isinstance(ignore_keys, list):
+            self.ignore_keys = ignore_keys
+        else:
+            raise TypeError('ignore_keys should be a list or str, '
+                            f'but got {type(ignore_keys)}')
+
+        meta_keys: set = set()
+        for dataset in self.datasets:
+            meta_keys |= dataset.metainfo.keys()
+        # if the metainfo of multiple datasets are the same, use the metainfo
+        # of the first dataset, else the metainfo is a list with the metainfo
+        # of all the datasets
+        is_all_same = True
+        self._metainfo_first = self.datasets[0].metainfo
+        for i, dataset in enumerate(self.datasets, 1):
+            for key in meta_keys:
+                if key in self.ignore_keys:
+                    continue
+                # a key missing from either metainfo also counts as a
+                # difference (and guards against a KeyError below)
+                if key not in self._metainfo_first or \
+                        key not in dataset.metainfo:
+                    is_all_same = False
+                    break
+                if self._metainfo_first[key] != dataset.metainfo[key]:
+                    is_all_same = False
+                    break
+
+        if is_all_same:
+            self._metainfo = self.datasets[0].metainfo
+        else:
+            self._metainfo = [dataset.metainfo for dataset in self.datasets]
+
+        self._fully_initialized = False
+        if not lazy_init:
+            self.full_init()
+
+    def get_dataset_source(self, idx: int) -> int:
+        dataset_idx, _ = self._get_ori_dataset_idx(idx)
+        return dataset_idx
diff --git a/mmdet/datasets/deepfashion.py b/mmdet/datasets/deepfashion.py
new file mode 100644
index 0000000000000000000000000000000000000000..f853fc63398d598b90a88323e660ba6f4d81e2df
--- /dev/null
+++ b/mmdet/datasets/deepfashion.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
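+# DeepFashion annotations follow the COCO format, so this dataset only
+# customizes ``METAINFO`` (class names and palette) and inherits all
+# loading logic from ``CocoDataset``.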
+from mmdet.registry import DATASETS
+from .coco import CocoDataset
+
+
+@DATASETS.register_module()
+class DeepFashionDataset(CocoDataset):
+    """Dataset for DeepFashion."""
+
+    METAINFO = {
+        'classes': ('top', 'skirt', 'leggings', 'dress', 'outer', 'pants',
+                    'bag', 'neckwear', 'headwear', 'eyeglass', 'belt',
+                    'footwear', 'hair', 'skin', 'face'),
+        # palette is a list of color tuples, which is used for visualization.
+        'palette': [(0, 192, 64), (0, 64, 96), (128, 192, 192), (0, 64, 64),
+                    (0, 192, 224), (0, 192, 192), (128, 192, 64),
+                    (0, 192, 96), (128, 32, 192), (0, 0, 224), (0, 0, 64),
+                    (0, 160, 192), (128, 0, 96), (128, 0, 192), (0, 32, 192)]
+    }
diff --git a/mmdet/datasets/dsdl.py b/mmdet/datasets/dsdl.py
new file mode 100644
index 0000000000000000000000000000000000000000..75570a2a6396e0e7a4ce5cac5dbf2a23cd164629
--- /dev/null
+++ b/mmdet/datasets/dsdl.py
@@ -0,0 +1,192 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from typing import List
+
+from mmdet.registry import DATASETS
+from .base_det_dataset import BaseDetDataset
+
+try:
+    from dsdl.dataset import DSDLDataset
+except ImportError:
+    DSDLDataset = None
+
+
+@DATASETS.register_module()
+class DSDLDetDataset(BaseDetDataset):
+    """Dataset for dsdl detection.
+
+    Args:
+        with_bbox (bool): Whether to load bboxes. Defaults to True.
+        with_polygon (bool): Whether to load polygons. Defaults to False.
+        with_mask (bool): Whether to load seg map masks. Defaults to False.
+        with_imagelevel_label (bool): Whether to load image-level labels.
+            Defaults to False.
+        with_hierarchy (bool): Whether to load hierarchy information.
+            Defaults to False.
+        specific_key_path (dict): Paths of specific keys that cannot be
+            loaded by their field names.
+        pre_transform (dict): Pre-transform functions applied before loading.
+    """
+
+    METAINFO = {}
+
+    def __init__(self,
+                 with_bbox: bool = True,
+                 with_polygon: bool = False,
+                 with_mask: bool = False,
+                 with_imagelevel_label: bool = False,
+                 with_hierarchy: bool = False,
+                 specific_key_path: dict = {},
+                 pre_transform: dict = {},
+                 **kwargs) -> None:
+
+        if DSDLDataset is None:
+            raise RuntimeError(
+                'Package dsdl is not installed. Please run "pip install dsdl".'
+            )
+
+        self.with_hierarchy = with_hierarchy
+        self.specific_key_path = specific_key_path
+
+        loc_config = dict(type='LocalFileReader', working_dir='')
+        if kwargs.get('data_root'):
+            kwargs['ann_file'] = os.path.join(kwargs['data_root'],
+                                              kwargs['ann_file'])
+        self.required_fields = ['Image', 'ImageShape', 'Label', 'ignore_flag']
+        if with_bbox:
+            self.required_fields.append('Bbox')
+        if with_polygon:
+            self.required_fields.append('Polygon')
+        if with_mask:
+            self.required_fields.append('LabelMap')
+        if with_imagelevel_label:
+            self.required_fields.append('image_level_labels')
+            assert 'image_level_labels' in specific_key_path.keys(
+            ), '`image_level_labels` not specified in `specific_key_path` !'
+
+        self.extra_keys = [
+            key for key in self.specific_key_path.keys()
+            if key not in self.required_fields
+        ]
+
+        self.dsdldataset = DSDLDataset(
+            dsdl_yaml=kwargs['ann_file'],
+            location_config=loc_config,
+            required_fields=self.required_fields,
+            specific_key_path=specific_key_path,
+            transform=pre_transform,
+        )
+
+        BaseDetDataset.__init__(self, **kwargs)
+
+    def load_data_list(self) -> List[dict]:
+        """Load data info from a dsdl yaml file named ``self.ann_file``.
+
+        Returns:
+            List[dict]: A list of data info.
+ """ + if self.with_hierarchy: + # get classes_names and relation_matrix + classes_names, relation_matrix = \ + self.dsdldataset.class_dom.get_hierarchy_info() + self._metainfo['classes'] = tuple(classes_names) + self._metainfo['RELATION_MATRIX'] = relation_matrix + + else: + self._metainfo['classes'] = tuple(self.dsdldataset.class_names) + + data_list = [] + + for i, data in enumerate(self.dsdldataset): + # basic image info, including image id, path and size. + datainfo = dict( + img_id=i, + img_path=os.path.join(self.data_prefix['img_path'], + data['Image'][0].location), + width=data['ImageShape'][0].width, + height=data['ImageShape'][0].height, + ) + + # get image label info + if 'image_level_labels' in data.keys(): + if self.with_hierarchy: + # get leaf node name when using hierarchy classes + datainfo['image_level_labels'] = [ + self._metainfo['classes'].index(i.leaf_node_name) + for i in data['image_level_labels'] + ] + else: + datainfo['image_level_labels'] = [ + self._metainfo['classes'].index(i.name) + for i in data['image_level_labels'] + ] + + # get semantic segmentation info + if 'LabelMap' in data.keys(): + datainfo['seg_map_path'] = data['LabelMap'] + + # load instance info + instances = [] + if 'Bbox' in data.keys(): + for idx in range(len(data['Bbox'])): + bbox = data['Bbox'][idx] + if self.with_hierarchy: + # get leaf node name when using hierarchy classes + label = data['Label'][idx].leaf_node_name + label_index = self._metainfo['classes'].index(label) + else: + label = data['Label'][idx].name + label_index = self._metainfo['classes'].index(label) + + instance = {} + instance['bbox'] = bbox.xyxy + instance['bbox_label'] = label_index + + if 'ignore_flag' in data.keys(): + # get ignore flag + instance['ignore_flag'] = data['ignore_flag'][idx] + else: + instance['ignore_flag'] = 0 + + if 'Polygon' in data.keys(): + # get polygon info + polygon = data['Polygon'][idx] + instance['mask'] = polygon.openmmlabformat + + for key in self.extra_keys: + # load extra instance info + instance[key] = data[key][idx] + + instances.append(instance) + + datainfo['instances'] = instances + # append a standard sample in data list + if len(datainfo['instances']) > 0: + data_list.append(datainfo) + + return data_list + + def filter_data(self) -> List[dict]: + """Filter annotations according to filter_cfg. + + Returns: + List[dict]: Filtered results. + """ + if self.test_mode: + return self.data_list + + filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) \ + if self.filter_cfg is not None else False + min_size = self.filter_cfg.get('min_size', 0) \ + if self.filter_cfg is not None else 0 + + valid_data_list = [] + for i, data_info in enumerate(self.data_list): + width = data_info['width'] + height = data_info['height'] + if filter_empty_gt and len(data_info['instances']) == 0: + continue + if min(width, height) >= min_size: + valid_data_list.append(data_info) + + return valid_data_list diff --git a/mmdet/datasets/isaid.py b/mmdet/datasets/isaid.py new file mode 100644 index 0000000000000000000000000000000000000000..87067d8459c4dd6e80e5f808f613e0bd600b5f2f --- /dev/null +++ b/mmdet/datasets/isaid.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import DATASETS +from .coco import CocoDataset + + +@DATASETS.register_module() +class iSAIDDataset(CocoDataset): + """Dataset for iSAID instance segmentation. + + iSAID: A Large-scale Dataset for Instance Segmentation + in Aerial Images. 
+ + For more detail, please refer to "projects/iSAID/README.md" + """ + + METAINFO = dict( + classes=('background', 'ship', 'store_tank', 'baseball_diamond', + 'tennis_court', 'basketball_court', 'Ground_Track_Field', + 'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter', + 'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane', + 'Harbor'), + palette=[(0, 0, 0), (0, 0, 63), (0, 63, 63), (0, 63, 0), (0, 63, 127), + (0, 63, 191), (0, 63, 255), (0, 127, 63), (0, 127, 127), + (0, 0, 127), (0, 0, 191), (0, 0, 255), (0, 191, 127), + (0, 127, 191), (0, 127, 255), (0, 100, 155)]) diff --git a/mmdet/datasets/lvis.py b/mmdet/datasets/lvis.py new file mode 100644 index 0000000000000000000000000000000000000000..b9629f5d463da183f0b4ab4c5d0f7ff7b07e4348 --- /dev/null +++ b/mmdet/datasets/lvis.py @@ -0,0 +1,638 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import warnings +from typing import List + +from mmengine.fileio import get_local_path + +from mmdet.registry import DATASETS +from .coco import CocoDataset + + +@DATASETS.register_module() +class LVISV05Dataset(CocoDataset): + """LVIS v0.5 dataset for detection.""" + + METAINFO = { + 'classes': + ('acorn', 'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', + 'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet', + 'antenna', 'apple', 'apple_juice', 'applesauce', 'apricot', 'apron', + 'aquarium', 'armband', 'armchair', 'armoire', 'armor', 'artichoke', + 'trash_can', 'ashtray', 'asparagus', 'atomizer', 'avocado', 'award', + 'awning', 'ax', 'baby_buggy', 'basketball_backboard', 'backpack', + 'handbag', 'suitcase', 'bagel', 'bagpipe', 'baguet', 'bait', 'ball', + 'ballet_skirt', 'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage', + 'bandanna', 'banjo', 'banner', 'barbell', 'barge', 'barrel', + 'barrette', 'barrow', 'baseball_base', 'baseball', 'baseball_bat', + 'baseball_cap', 'baseball_glove', 'basket', 'basketball_hoop', + 'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat', 'bath_towel', + 'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball', + 'bead', 'beaker', 'bean_curd', 'beanbag', 'beanie', 'bear', 'bed', + 'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle', + 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle', + 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor', 'binder', + 'binoculars', 'bird', 'birdfeeder', 'birdbath', 'birdcage', + 'birdhouse', 'birthday_cake', 'birthday_card', 'biscuit_(bread)', + 'pirate_flag', 'black_sheep', 'blackboard', 'blanket', 'blazer', + 'blender', 'blimp', 'blinker', 'blueberry', 'boar', 'gameboard', + 'boat', 'bobbin', 'bobby_pin', 'boiled_egg', 'bolo_tie', 'deadbolt', + 'bolt', 'bonnet', 'book', 'book_bag', 'bookcase', 'booklet', + 'bookmark', 'boom_microphone', 'boot', 'bottle', 'bottle_opener', + 'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie', + 'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'bowling_pin', + 'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere', + 'bread-bin', 'breechcloth', 'bridal_gown', 'briefcase', + 'bristle_brush', 'broccoli', 'broach', 'broom', 'brownie', + 'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'bull', + 'bulldog', 'bulldozer', 'bullet_train', 'bulletin_board', + 'bulletproof_vest', 'bullhorn', 'corned_beef', 'bun', 'bunk_bed', + 'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butcher_knife', + 'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car', + 'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf', + 
'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)', + 'can', 'can_opener', 'candelabrum', 'candle', 'candle_holder', + 'candy_bar', 'candy_cane', 'walking_cane', 'canister', 'cannon', + 'canoe', 'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap', + 'cape', 'cappuccino', 'car_(automobile)', 'railcar_(part_of_a_train)', + 'elevator_car', 'car_battery', 'identity_card', 'card', 'cardigan', + 'cargo_ship', 'carnation', 'horse_carriage', 'carrot', 'tote_bag', + 'cart', 'carton', 'cash_register', 'casserole', 'cassette', 'cast', + 'cat', 'cauliflower', 'caviar', 'cayenne_(spice)', 'CD_player', + 'celery', 'cellular_telephone', 'chain_mail', 'chair', + 'chaise_longue', 'champagne', 'chandelier', 'chap', 'checkbook', + 'checkerboard', 'cherry', 'chessboard', + 'chest_of_drawers_(furniture)', 'chicken_(animal)', 'chicken_wire', + 'chickpea', 'Chihuahua', 'chili_(vegetable)', 'chime', 'chinaware', + 'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar', + 'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker', + 'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider', + 'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet', + 'clasp', 'cleansing_agent', 'clementine', 'clip', 'clipboard', + 'clock', 'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag', + 'coaster', 'coat', 'coat_hanger', 'coatrack', 'cock', 'coconut', + 'coffee_filter', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil', + 'coin', 'colander', 'coleslaw', 'coloring_material', + 'combination_lock', 'pacifier', 'comic_book', 'computer_keyboard', + 'concrete_mixer', 'cone', 'control', 'convertible_(automobile)', + 'sofa_bed', 'cookie', 'cookie_jar', 'cooking_utensil', + 'cooler_(for_food)', 'cork_(bottle_plug)', 'corkboard', 'corkscrew', + 'edible_corn', 'cornbread', 'cornet', 'cornice', 'cornmeal', 'corset', + 'romaine_lettuce', 'costume', 'cougar', 'coverall', 'cowbell', + 'cowboy_hat', 'crab_(animal)', 'cracker', 'crape', 'crate', 'crayon', + 'cream_pitcher', 'credit_card', 'crescent_roll', 'crib', 'crock_pot', + 'crossbar', 'crouton', 'crow', 'crown', 'crucifix', 'cruise_ship', + 'police_cruiser', 'crumb', 'crutch', 'cub_(animal)', 'cube', + 'cucumber', 'cufflink', 'cup', 'trophy_cup', 'cupcake', 'hair_curler', + 'curling_iron', 'curtain', 'cushion', 'custard', 'cutting_tool', + 'cylinder', 'cymbal', 'dachshund', 'dagger', 'dartboard', + 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk', + 'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', + 'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher', + 'dishwasher_detergent', 'diskette', 'dispenser', 'Dixie_cup', 'dog', + 'dog_collar', 'doll', 'dollar', 'dolphin', 'domestic_ass', 'eye_mask', + 'doorbell', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly', + 'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit', + 'dresser', 'drill', 'drinking_fountain', 'drone', 'dropper', + 'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling', + 'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan', + 'Dutch_oven', 'eagle', 'earphone', 'earplug', 'earring', 'easel', + 'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater', + 'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk', + 'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan', + 'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)', + 'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)', + 'fire_alarm', 'fire_engine', 'fire_extinguisher', 'fire_hose', + 'fireplace', 'fireplug', 
'fish', 'fish_(food)', 'fishbowl', + 'fishing_boat', 'fishing_rod', 'flag', 'flagpole', 'flamingo', + 'flannel', 'flash', 'flashlight', 'fleece', 'flip-flop_(sandal)', + 'flipper_(footwear)', 'flower_arrangement', 'flute_glass', 'foal', + 'folding_chair', 'food_processor', 'football_(American)', + 'football_helmet', 'footstool', 'fork', 'forklift', 'freight_car', + 'French_toast', 'freshener', 'frisbee', 'frog', 'fruit_juice', + 'fruit_salad', 'frying_pan', 'fudge', 'funnel', 'futon', 'gag', + 'garbage', 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle', + 'garlic', 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'giant_panda', + 'gift_wrap', 'ginger', 'giraffe', 'cincture', + 'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles', + 'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose', + 'gorilla', 'gourd', 'surgical_gown', 'grape', 'grasshopper', 'grater', + 'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle', + 'grillroom', 'grinder_(tool)', 'grits', 'grizzly', 'grocery_bag', + 'guacamole', 'guitar', 'gull', 'gun', 'hair_spray', 'hairbrush', + 'hairnet', 'hairpin', 'ham', 'hamburger', 'hammer', 'hammock', + 'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel', + 'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw', + 'hardback_book', 'harmonium', 'hat', 'hatbox', 'hatch', 'veil', + 'headband', 'headboard', 'headlight', 'headscarf', 'headset', + 'headstall_(for_horses)', 'hearing_aid', 'heart', 'heater', + 'helicopter', 'helmet', 'heron', 'highchair', 'hinge', 'hippopotamus', + 'hockey_stick', 'hog', 'home_plate_(baseball)', 'honey', 'fume_hood', + 'hook', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce', + 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear', + 'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate', + 'ice_tea', 'igniter', 'incense', 'inhaler', 'iPod', + 'iron_(for_clothing)', 'ironing_board', 'jacket', 'jam', 'jean', + 'jeep', 'jelly_bean', 'jersey', 'jet_plane', 'jewelry', 'joystick', + 'jumpsuit', 'kayak', 'keg', 'kennel', 'kettle', 'key', 'keycard', + 'kilt', 'kimono', 'kitchen_sink', 'kitchen_table', 'kite', 'kitten', + 'kiwi_fruit', 'knee_pad', 'knife', 'knight_(chess_piece)', + 'knitting_needle', 'knob', 'knocker_(on_a_door)', 'koala', 'lab_coat', + 'ladder', 'ladle', 'ladybug', 'lamb_(animal)', 'lamb-chop', 'lamp', + 'lamppost', 'lampshade', 'lantern', 'lanyard', 'laptop_computer', + 'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)', + 'Lego', 'lemon', 'lemonade', 'lettuce', 'license_plate', 'life_buoy', + 'life_jacket', 'lightbulb', 'lightning_rod', 'lime', 'limousine', + 'linen_paper', 'lion', 'lip_balm', 'lipstick', 'liquor', 'lizard', + 'Loafer_(type_of_shoe)', 'log', 'lollipop', 'lotion', + 'speaker_(stereo_equipment)', 'loveseat', 'machine_gun', 'magazine', + 'magnet', 'mail_slot', 'mailbox_(at_home)', 'mallet', 'mammoth', + 'mandarin_orange', 'manger', 'manhole', 'map', 'marker', 'martini', + 'mascot', 'mashed_potato', 'masher', 'mask', 'mast', + 'mat_(gym_equipment)', 'matchbox', 'mattress', 'measuring_cup', + 'measuring_stick', 'meatball', 'medicine', 'melon', 'microphone', + 'microscope', 'microwave_oven', 'milestone', 'milk', 'minivan', + 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', 'money', + 'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor', + 'motor_scooter', 'motor_vehicle', 'motorboat', 'motorcycle', + 'mound_(baseball)', 'mouse_(animal_rodent)', + 'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom', + 
'music_stool', 'musical_instrument', 'nailfile', 'nameplate', + 'napkin', 'neckerchief', 'necklace', 'necktie', 'needle', 'nest', + 'newsstand', 'nightshirt', 'nosebag_(for_animals)', + 'noseband_(for_animals)', 'notebook', 'notepad', 'nut', 'nutcracker', + 'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil', + 'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'oregano', + 'ostrich', 'ottoman', 'overalls_(clothing)', 'owl', 'packet', + 'inkpad', 'pad', 'paddle', 'padlock', 'paintbox', 'paintbrush', + 'painting', 'pajamas', 'palette', 'pan_(for_cooking)', + 'pan_(metal_container)', 'pancake', 'pantyhose', 'papaya', + 'paperclip', 'paper_plate', 'paper_towel', 'paperback_book', + 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)', + 'parchment', 'parka', 'parking_meter', 'parrot', + 'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport', + 'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter', + 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'pegboard', + 'pelican', 'pen', 'pencil', 'pencil_box', 'pencil_sharpener', + 'pendulum', 'penguin', 'pennant', 'penny_(coin)', 'pepper', + 'pepper_mill', 'perfume', 'persimmon', 'baby', 'pet', 'petfood', + 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano', + 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow', + 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball', + 'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)', + 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat', + 'plate', 'platter', 'playing_card', 'playpen', 'pliers', + 'plow_(farm_equipment)', 'pocket_watch', 'pocketknife', + 'poker_(fire_stirring_tool)', 'pole', 'police_van', 'polo_shirt', + 'poncho', 'pony', 'pool_table', 'pop_(soda)', 'portrait', + 'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot', + 'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', + 'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune', + 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher', + 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit', + 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', 'radish', + 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat', + 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt', + 'recliner', 'record_player', 'red_cabbage', 'reflector', + 'remote_control', 'rhinoceros', 'rib_(food)', 'rifle', 'ring', + 'river_boat', 'road_map', 'robe', 'rocking_chair', 'roller_skate', + 'Rollerblade', 'rolling_pin', 'root_beer', + 'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)', + 'plastic_bag', 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', + 'safety_pin', 'sail', 'salad', 'salad_plate', 'salami', + 'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker', + 'sandal_(type_of_shoe)', 'sandwich', 'satchel', 'saucepan', 'saucer', + 'sausage', 'sawhorse', 'saxophone', 'scale_(measuring_instrument)', + 'scarecrow', 'scarf', 'school_bus', 'scissors', 'scoreboard', + 'scrambled_eggs', 'scraper', 'scratcher', 'screwdriver', + 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane', + 'seashell', 'seedling', 'serving_dish', 'sewing_machine', 'shaker', + 'shampoo', 'shark', 'sharpener', 'Sharpie', 'shaver_(electric)', + 'shaving_cream', 'shawl', 'shears', 'sheep', 'shepherd_dog', + 'sherbert', 'shield', 'shirt', 'shoe', 'shopping_bag', + 'shopping_cart', 'short_pants', 'shot_glass', 'shoulder_bag', + 'shovel', 'shower_head', 'shower_curtain', 
'shredder_(for_paper)', + 'sieve', 'signboard', 'silo', 'sink', 'skateboard', 'skewer', 'ski', + 'ski_boot', 'ski_parka', 'ski_pole', 'skirt', 'sled', 'sleeping_bag', + 'sling_(bandage)', 'slipper_(footwear)', 'smoothie', 'snake', + 'snowboard', 'snowman', 'snowmobile', 'soap', 'soccer_ball', 'sock', + 'soda_fountain', 'carbonated_water', 'sofa', 'softball', + 'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon', + 'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)', + 'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'sponge', + 'spoon', 'sportswear', 'spotlight', 'squirrel', + 'stapler_(stapling_machine)', 'starfish', 'statue_(sculpture)', + 'steak_(food)', 'steak_knife', 'steamer_(kitchen_appliance)', + 'steering_wheel', 'stencil', 'stepladder', 'step_stool', + 'stereo_(sound_system)', 'stew', 'stirrer', 'stirrup', + 'stockings_(leg_wear)', 'stool', 'stop_sign', 'brake_light', 'stove', + 'strainer', 'strap', 'straw_(for_drinking)', 'strawberry', + 'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer', + 'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower', + 'sunglasses', 'sunhat', 'sunscreen', 'surfboard', 'sushi', 'mop', + 'sweat_pants', 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato', + 'swimsuit', 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table', + 'table', 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', + 'taillight', 'tambourine', 'army_tank', 'tank_(storage_vessel)', + 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure', + 'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup', + 'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth', + 'telephone_pole', 'telephoto_lens', 'television_camera', + 'television_set', 'tennis_ball', 'tennis_racket', 'tequila', + 'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread', + 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', + 'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', + 'toaster_oven', 'toilet', 'toilet_tissue', 'tomato', 'tongs', + 'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover', + 'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy', + 'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike', + 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray', + 'tree_house', 'trench_coat', 'triangle_(musical_instrument)', + 'tricycle', 'tripod', 'trousers', 'truck', 'truffle_(chocolate)', + 'trunk', 'vat', 'turban', 'turkey_(bird)', 'turkey_(food)', 'turnip', + 'turtle', 'turtleneck_(clothing)', 'typewriter', 'umbrella', + 'underwear', 'unicycle', 'urinal', 'urn', 'vacuum_cleaner', 'valve', + 'vase', 'vending_machine', 'vent', 'videotape', 'vinegar', 'violin', + 'vodka', 'volleyball', 'vulture', 'waffle', 'waffle_iron', 'wagon', + 'wagon_wheel', 'walking_stick', 'wall_clock', 'wall_socket', 'wallet', + 'walrus', 'wardrobe', 'wasabi', 'automatic_washer', 'watch', + 'water_bottle', 'water_cooler', 'water_faucet', 'water_filter', + 'water_heater', 'water_jug', 'water_gun', 'water_scooter', + 'water_ski', 'water_tower', 'watering_can', 'watermelon', + 'weathervane', 'webcam', 'wedding_cake', 'wedding_ring', 'wet_suit', + 'wheel', 'wheelchair', 'whipped_cream', 'whiskey', 'whistle', 'wick', + 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)', + 'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket', + 'wineglass', 'wing_chair', 'blinder_(for_horses)', 'wok', 'wolf', + 'wooden_spoon', 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', + 'yak', 'yogurt', 
'yoke_(animal_equipment)', 'zebra', 'zucchini'),
+        'palette':
+        None
+    }
+
+    def load_data_list(self) -> List[dict]:
+        """Load annotations from an annotation file named as ``self.ann_file``
+
+        Returns:
+            List[dict]: A list of annotations.
+        """  # noqa: E501
+        try:
+            import lvis
+            if getattr(lvis, '__version__', '0') >= '10.5.3':
+                warnings.warn(
+                    'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"',  # noqa: E501
+                    UserWarning)
+            from lvis import LVIS
+        except ImportError:
+            raise ImportError(
+                'Package lvis is not installed. Please run "pip install git+https://github.com/lvis-dataset/lvis-api.git".'  # noqa: E501
+            )
+        with get_local_path(
+                self.ann_file, backend_args=self.backend_args) as local_path:
+            self.lvis = LVIS(local_path)
+        self.cat_ids = self.lvis.get_cat_ids()
+        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+        self.cat_img_map = copy.deepcopy(self.lvis.cat_img_map)
+
+        img_ids = self.lvis.get_img_ids()
+        data_list = []
+        total_ann_ids = []
+        for img_id in img_ids:
+            raw_img_info = self.lvis.load_imgs([img_id])[0]
+            raw_img_info['img_id'] = img_id
+            if raw_img_info['file_name'].startswith('COCO'):
+                # Convert from the COCO 2014 file naming convention of
+                # COCO_[train/val/test]2014_000000000000.jpg to the 2017
+                # naming convention of 000000000000.jpg
+                # (LVIS v1 will fix this naming issue)
+                raw_img_info['file_name'] = raw_img_info['file_name'][-16:]
+            ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
+            raw_ann_info = self.lvis.load_anns(ann_ids)
+            total_ann_ids.extend(ann_ids)
+
+            parsed_data_info = self.parse_data_info({
+                'raw_ann_info':
+                raw_ann_info,
+                'raw_img_info':
+                raw_img_info
+            })
+            data_list.append(parsed_data_info)
+        if self.ANN_ID_UNIQUE:
+            assert len(set(total_ann_ids)) == len(
+                total_ann_ids
+            ), f"Annotation ids in '{self.ann_file}' are not unique!"
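+        # At this point the LVIS API object is no longer needed; it is
+        # deleted below to free the memory it holds.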
+ + del self.lvis + + return data_list + + +LVISDataset = LVISV05Dataset +DATASETS.register_module(name='LVISDataset', module=LVISDataset) + + +@DATASETS.register_module() +class LVISV1Dataset(LVISDataset): + """LVIS v1 dataset for detection.""" + + METAINFO = { + 'classes': + ('aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', + 'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet', + 'antenna', 'apple', 'applesauce', 'apricot', 'apron', 'aquarium', + 'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor', + 'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer', + 'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy', + 'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel', + 'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon', + 'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo', + 'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow', + 'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap', + 'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)', + 'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)', + 'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie', + 'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper', + 'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt', + 'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor', + 'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath', + 'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card', + 'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket', + 'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry', + 'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg', + 'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase', + 'booklet', 'bookmark', 'boom_microphone', 'boot', 'bottle', + 'bottle_opener', 'bouquet', 'bow_(weapon)', + 'bow_(decorative_ribbons)', 'bow-tie', 'bowl', 'pipe_bowl', + 'bowler_hat', 'bowling_ball', 'box', 'boxing_glove', 'suspenders', + 'bracelet', 'brass_plaque', 'brassiere', 'bread-bin', 'bread', + 'breechcloth', 'bridal_gown', 'briefcase', 'broccoli', 'broach', + 'broom', 'brownie', 'brussels_sprouts', 'bubble_gum', 'bucket', + 'horse_buggy', 'bull', 'bulldog', 'bulldozer', 'bullet_train', + 'bulletin_board', 'bulletproof_vest', 'bullhorn', 'bun', 'bunk_bed', + 'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butter', + 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car', 'cabinet', + 'locker', 'cake', 'calculator', 'calendar', 'calf', 'camcorder', + 'camel', 'camera', 'camera_lens', 'camper_(vehicle)', 'can', + 'can_opener', 'candle', 'candle_holder', 'candy_bar', 'candy_cane', + 'walking_cane', 'canister', 'canoe', 'cantaloup', 'canteen', + 'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino', + 'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car', + 'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship', + 'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton', + 'cash_register', 'casserole', 'cassette', 'cast', 'cat', + 'cauliflower', 'cayenne_(spice)', 'CD_player', 'celery', + 'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue', + 'chalice', 'chandelier', 'chap', 'checkbook', 'checkerboard', + 'cherry', 'chessboard', 'chicken_(animal)', 'chickpea', + 'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)', + 'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk', + 'chocolate_mousse', 'choker', 
'chopping_board', 'chopstick', + 'Christmas_tree', 'slide', 'cider', 'cigar_box', 'cigarette', + 'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent', + 'cleat_(for_securing_rope)', 'clementine', 'clip', 'clipboard', + 'clippers_(for_plants)', 'cloak', 'clock', 'clock_tower', + 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 'coat', + 'coat_hanger', 'coatrack', 'cock', 'cockroach', 'cocoa_(beverage)', + 'coconut', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil', + 'coin', 'colander', 'coleslaw', 'coloring_material', + 'combination_lock', 'pacifier', 'comic_book', 'compass', + 'computer_keyboard', 'condiment', 'cone', 'control', + 'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie', + 'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)', + 'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 'cornet', + 'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall', + 'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker', + 'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib', + 'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown', + 'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch', + 'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup', + 'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain', + 'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard', + 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk', + 'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', + 'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher', + 'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup', + 'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin', + 'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove', + 'dragonfly', 'drawer', 'underdrawers', 'dress', 'dress_hat', + 'dress_suit', 'dresser', 'drill', 'drone', 'dropper', + 'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling', + 'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan', 'eagle', + 'earphone', 'earplug', 'earring', 'easel', 'eclair', 'eel', 'egg', + 'egg_roll', 'egg_yolk', 'eggbeater', 'eggplant', 'electric_chair', + 'refrigerator', 'elephant', 'elk', 'envelope', 'eraser', 'escargot', + 'eyepatch', 'falcon', 'fan', 'faucet', 'fedora', 'ferret', + 'Ferris_wheel', 'ferry', 'fig_(fruit)', 'fighter_jet', 'figurine', + 'file_cabinet', 'file_(tool)', 'fire_alarm', 'fire_engine', + 'fire_extinguisher', 'fire_hose', 'fireplace', 'fireplug', + 'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl', 'fishing_rod', + 'flag', 'flagpole', 'flamingo', 'flannel', 'flap', 'flash', + 'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)', + 'flower_arrangement', 'flute_glass', 'foal', 'folding_chair', + 'food_processor', 'football_(American)', 'football_helmet', + 'footstool', 'fork', 'forklift', 'freight_car', 'French_toast', + 'freshener', 'frisbee', 'frog', 'fruit_juice', 'frying_pan', 'fudge', + 'funnel', 'futon', 'gag', 'garbage', 'garbage_truck', 'garden_hose', + 'gargle', 'gargoyle', 'garlic', 'gasmask', 'gazelle', 'gelatin', + 'gemstone', 'generator', 'giant_panda', 'gift_wrap', 'ginger', + 'giraffe', 'cincture', 'glass_(drink_container)', 'globe', 'glove', + 'goat', 'goggles', 'goldfish', 'golf_club', 'golfcart', + 'gondola_(boat)', 'goose', 'gorilla', 'gourd', 'grape', 'grater', + 'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle', + 'grill', 'grits', 'grizzly', 'grocery_bag', 'guitar', 'gull', 'gun', + 'hairbrush', 'hairnet', 
'hairpin', 'halter_top', 'ham', 'hamburger', + 'hammer', 'hammock', 'hamper', 'hamster', 'hair_dryer', 'hand_glass', + 'hand_towel', 'handcart', 'handcuff', 'handkerchief', 'handle', + 'handsaw', 'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil', + 'headband', 'headboard', 'headlight', 'headscarf', 'headset', + 'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet', + 'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog', + 'home_plate_(baseball)', 'honey', 'fume_hood', 'hook', 'hookah', + 'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce', + 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear', + 'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate', + 'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board', + 'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey', + 'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit', 'kayak', + 'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono', + 'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit', + 'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)', + 'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)', + 'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard', + 'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather', + 'legging_(clothing)', 'Lego', 'legume', 'lemon', 'lemonade', + 'lettuce', 'license_plate', 'life_buoy', 'life_jacket', 'lightbulb', + 'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor', + 'lizard', 'log', 'lollipop', 'speaker_(stereo_equipment)', 'loveseat', + 'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)', + 'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange', + 'manger', 'manhole', 'map', 'marker', 'martini', 'mascot', + 'mashed_potato', 'masher', 'mask', 'mast', 'mat_(gym_equipment)', + 'matchbox', 'mattress', 'measuring_cup', 'measuring_stick', + 'meatball', 'medicine', 'melon', 'microphone', 'microscope', + 'microwave_oven', 'milestone', 'milk', 'milk_can', 'milkshake', + 'minivan', 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', + 'money', 'monitor_(computer_equipment) computer_monitor', 'monkey', + 'motor', 'motor_scooter', 'motor_vehicle', 'motorcycle', + 'mound_(baseball)', 'mouse_(computer_equipment)', 'mousepad', + 'muffin', 'mug', 'mushroom', 'music_stool', 'musical_instrument', + 'nailfile', 'napkin', 'neckerchief', 'necklace', 'necktie', 'needle', + 'nest', 'newspaper', 'newsstand', 'nightshirt', + 'nosebag_(for_animals)', 'noseband_(for_animals)', 'notebook', + 'notepad', 'nut', 'nutcracker', 'oar', 'octopus_(food)', + 'octopus_(animal)', 'oil_lamp', 'olive_oil', 'omelet', 'onion', + 'orange_(fruit)', 'orange_juice', 'ostrich', 'ottoman', 'oven', + 'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad', 'paddle', + 'padlock', 'paintbrush', 'painting', 'pajamas', 'palette', + 'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose', + 'papaya', 'paper_plate', 'paper_towel', 'paperback_book', + 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)', + 'parasol', 'parchment', 'parka', 'parking_meter', 'parrot', + 'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport', + 'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter', + 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg', + 'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box', + 'pencil_sharpener', 'pendulum', 'penguin', 'pennant', 'penny_(coin)', + 'pepper', 'pepper_mill', 'perfume', 
'persimmon', 'person', 'pet', + 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano', + 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow', + 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball', + 'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)', + 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat', + 'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)', + 'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)', + 'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)', + 'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot', + 'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', + 'pretzel', 'printer', 'projectile_(weapon)', 'projector', 'propeller', + 'prune', 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', + 'puncher', 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', + 'rabbit', 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', + 'radish', 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', + 'rat', 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt', + 'recliner', 'record_player', 'reflector', 'remote_control', + 'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map', + 'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade', + 'rolling_pin', 'root_beer', 'router_(computer_equipment)', + 'rubber_band', 'runner_(carpet)', 'plastic_bag', + 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin', + 'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)', + 'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)', + 'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse', + 'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf', + 'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver', + 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane', + 'seashell', 'sewing_machine', 'shaker', 'shampoo', 'shark', + 'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl', + 'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt', + 'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass', + 'shoulder_bag', 'shovel', 'shower_head', 'shower_cap', + 'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink', + 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole', + 'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)', + 'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman', + 'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball', + 'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon', + 'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)', + 'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish', + 'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)', + 'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish', + 'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel', + 'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew', + 'stirrer', 'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove', + 'strainer', 'strap', 'straw_(for_drinking)', 'strawberry', + 'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer', + 'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower', + 'sunglasses', 'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants', + 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit', + 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table', 
'table', + 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight', + 'tambourine', 'army_tank', 'tank_(storage_vessel)', + 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure', + 'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup', + 'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth', + 'telephone_pole', 'telephoto_lens', 'television_camera', + 'television_set', 'tennis_ball', 'tennis_racket', 'tequila', + 'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread', + 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', + 'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', + 'toaster_oven', 'toilet', 'toilet_tissue', 'tomato', 'tongs', + 'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover', + 'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy', + 'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike', + 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray', + 'trench_coat', 'triangle_(musical_instrument)', 'tricycle', 'tripod', + 'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat', 'turban', + 'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)', + 'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn', + 'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest', + 'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture', + 'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick', + 'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe', + 'washbasin', 'automatic_washer', 'watch', 'water_bottle', + 'water_cooler', 'water_faucet', 'water_heater', 'water_jug', + 'water_gun', 'water_scooter', 'water_ski', 'water_tower', + 'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake', + 'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream', + 'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)', + 'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket', + 'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon', + 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt', + 'yoke_(animal_equipment)', 'zebra', 'zucchini'), + 'palette': + None + } + + def load_data_list(self) -> List[dict]: + """Load annotations from an annotation file named as ``self.ann_file`` + + Returns: + List[dict]: A list of annotation. + """ # noqa: E501 + try: + import lvis + if getattr(lvis, '__version__', '0') >= '10.5.3': + warnings.warn( + 'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', # noqa: E501 + UserWarning) + from lvis import LVIS + except ImportError: + raise ImportError( + 'Package lvis is not installed. Please run "pip install git+https://github.com/lvis-dataset/lvis-api.git".' # noqa: E501 + ) + with get_local_path( + self.ann_file, backend_args=self.backend_args) as local_path: + self.lvis = LVIS(local_path) + self.cat_ids = self.lvis.get_cat_ids() + self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} + self.cat_img_map = copy.deepcopy(self.lvis.cat_img_map) + + img_ids = self.lvis.get_img_ids() + data_list = [] + total_ann_ids = [] + for img_id in img_ids: + raw_img_info = self.lvis.load_imgs([img_id])[0] + raw_img_info['img_id'] = img_id + # coco_url is used in LVISv1 instead of file_name + # e.g. 
http://images.cocodataset.org/train2017/000000391895.jpg
+            # the train/val split is specified in the url
+            raw_img_info['file_name'] = raw_img_info['coco_url'].replace(
+                'http://images.cocodataset.org/', '')
+            ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
+            raw_ann_info = self.lvis.load_anns(ann_ids)
+            total_ann_ids.extend(ann_ids)
+            parsed_data_info = self.parse_data_info({
+                'raw_ann_info':
+                raw_ann_info,
+                'raw_img_info':
+                raw_img_info
+            })
+            data_list.append(parsed_data_info)
+        if self.ANN_ID_UNIQUE:
+            assert len(set(total_ann_ids)) == len(
+                total_ann_ids
+            ), f"Annotation ids in '{self.ann_file}' are not unique!"
+
+        del self.lvis
+
+        return data_list
diff --git a/mmdet/datasets/mot_challenge_dataset.py b/mmdet/datasets/mot_challenge_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffbdc48ebf8d4a4ba11a605c8bc2a479cf2a0c96
--- /dev/null
+++ b/mmdet/datasets/mot_challenge_dataset.py
@@ -0,0 +1,88 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import List, Union
+
+from mmdet.registry import DATASETS
+from .base_video_dataset import BaseVideoDataset
+
+
+@DATASETS.register_module()
+class MOTChallengeDataset(BaseVideoDataset):
+    """Dataset for MOTChallenge.
+
+    Args:
+        visibility_thr (float, optional): The minimum visibility
+            for the objects during training. Defaults to -1.
+    """
+
+    def __init__(self, visibility_thr: float = -1, *args, **kwargs):
+        self.visibility_thr = visibility_thr
+        super().__init__(*args, **kwargs)
+
+    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
+        """Parse raw annotation to target format. The difference between this
+        function and the one in ``BaseVideoDataset`` is that the parsing here
+        adds ``visibility`` and ``mot_conf``.
+
+        Args:
+            raw_data_info (dict): Raw data information loaded from
+                ``ann_file``.
+
+        Returns:
+            Union[dict, List[dict]]: Parsed annotation.
+ """ + img_info = raw_data_info['raw_img_info'] + ann_info = raw_data_info['raw_ann_info'] + data_info = {} + + data_info.update(img_info) + if self.data_prefix.get('img_path', None) is not None: + img_path = osp.join(self.data_prefix['img_path'], + img_info['file_name']) + else: + img_path = img_info['file_name'] + data_info['img_path'] = img_path + + instances = [] + for i, ann in enumerate(ann_info): + instance = {} + + if (not self.test_mode) and (ann['visibility'] < + self.visibility_thr): + continue + if ann.get('ignore', False): + continue + x1, y1, w, h = ann['bbox'] + inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0)) + inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if ann['area'] <= 0 or w < 1 or h < 1: + continue + if ann['category_id'] not in self.cat_ids: + continue + bbox = [x1, y1, x1 + w, y1 + h] + + if ann.get('iscrowd', False): + instance['ignore_flag'] = 1 + else: + instance['ignore_flag'] = 0 + instance['bbox'] = bbox + instance['bbox_label'] = self.cat2label[ann['category_id']] + instance['instance_id'] = ann['instance_id'] + instance['category_id'] = ann['category_id'] + instance['mot_conf'] = ann['mot_conf'] + instance['visibility'] = ann['visibility'] + if len(instance) > 0: + instances.append(instance) + if not self.test_mode: + assert len(instances) > 0, f'No valid instances found in ' \ + f'image {data_info["img_path"]}!' + data_info['instances'] = instances + return data_info diff --git a/mmdet/datasets/objects365.py b/mmdet/datasets/objects365.py new file mode 100644 index 0000000000000000000000000000000000000000..e99869bfa309635af3c03cbfa77f732db3f50637 --- /dev/null +++ b/mmdet/datasets/objects365.py @@ -0,0 +1,284 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os.path as osp +from typing import List + +from mmengine.fileio import get_local_path + +from mmdet.registry import DATASETS +from .api_wrappers import COCO +from .coco import CocoDataset + +# images exist in annotations but not in image folder. 
+objv2_ignore_list = [ + osp.join('patch16', 'objects365_v2_00908726.jpg'), + osp.join('patch6', 'objects365_v1_00320532.jpg'), + osp.join('patch6', 'objects365_v1_00320534.jpg'), +] + + +@DATASETS.register_module() +class Objects365V1Dataset(CocoDataset): + """Objects365 v1 dataset for detection.""" + + METAINFO = { + 'classes': + ('person', 'sneakers', 'chair', 'hat', 'lamp', 'bottle', + 'cabinet/shelf', 'cup', 'car', 'glasses', 'picture/frame', 'desk', + 'handbag', 'street lights', 'book', 'plate', 'helmet', + 'leather shoes', 'pillow', 'glove', 'potted plant', 'bracelet', + 'flower', 'tv', 'storage box', 'vase', 'bench', 'wine glass', 'boots', + 'bowl', 'dining table', 'umbrella', 'boat', 'flag', 'speaker', + 'trash bin/can', 'stool', 'backpack', 'couch', 'belt', 'carpet', + 'basket', 'towel/napkin', 'slippers', 'barrel/bucket', 'coffee table', + 'suv', 'toy', 'tie', 'bed', 'traffic light', 'pen/pencil', + 'microphone', 'sandals', 'canned', 'necklace', 'mirror', 'faucet', + 'bicycle', 'bread', 'high heels', 'ring', 'van', 'watch', 'sink', + 'horse', 'fish', 'apple', 'camera', 'candle', 'teddy bear', 'cake', + 'motorcycle', 'wild bird', 'laptop', 'knife', 'traffic sign', + 'cell phone', 'paddle', 'truck', 'cow', 'power outlet', 'clock', + 'drum', 'fork', 'bus', 'hanger', 'nightstand', 'pot/pan', 'sheep', + 'guitar', 'traffic cone', 'tea pot', 'keyboard', 'tripod', 'hockey', + 'fan', 'dog', 'spoon', 'blackboard/whiteboard', 'balloon', + 'air conditioner', 'cymbal', 'mouse', 'telephone', 'pickup truck', + 'orange', 'banana', 'airplane', 'luggage', 'skis', 'soccer', + 'trolley', 'oven', 'remote', 'baseball glove', 'paper towel', + 'refrigerator', 'train', 'tomato', 'machinery vehicle', 'tent', + 'shampoo/shower gel', 'head phone', 'lantern', 'donut', + 'cleaning products', 'sailboat', 'tangerine', 'pizza', 'kite', + 'computer box', 'elephant', 'toiletries', 'gas stove', 'broccoli', + 'toilet', 'stroller', 'shovel', 'baseball bat', 'microwave', + 'skateboard', 'surfboard', 'surveillance camera', 'gun', 'life saver', + 'cat', 'lemon', 'liquid soap', 'zebra', 'duck', 'sports car', + 'giraffe', 'pumpkin', 'piano', 'stop sign', 'radiator', 'converter', + 'tissue ', 'carrot', 'washing machine', 'vent', 'cookies', + 'cutting/chopping board', 'tennis racket', 'candy', + 'skating and skiing shoes', 'scissors', 'folder', 'baseball', + 'strawberry', 'bow tie', 'pigeon', 'pepper', 'coffee machine', + 'bathtub', 'snowboard', 'suitcase', 'grapes', 'ladder', 'pear', + 'american football', 'basketball', 'potato', 'paint brush', 'printer', + 'billiards', 'fire hydrant', 'goose', 'projector', 'sausage', + 'fire extinguisher', 'extension cord', 'facial mask', 'tennis ball', + 'chopsticks', 'electronic stove and gas stove', 'pie', 'frisbee', + 'kettle', 'hamburger', 'golf club', 'cucumber', 'clutch', 'blender', + 'tong', 'slide', 'hot dog', 'toothbrush', 'facial cleanser', 'mango', + 'deer', 'egg', 'violin', 'marker', 'ship', 'chicken', 'onion', + 'ice cream', 'tape', 'wheelchair', 'plum', 'bar soap', 'scale', + 'watermelon', 'cabbage', 'router/modem', 'golf ball', 'pine apple', + 'crane', 'fire truck', 'peach', 'cello', 'notepaper', 'tricycle', + 'toaster', 'helicopter', 'green beans', 'brush', 'carriage', 'cigar', + 'earphone', 'penguin', 'hurdle', 'swing', 'radio', 'CD', + 'parking meter', 'swan', 'garlic', 'french fries', 'horn', 'avocado', + 'saxophone', 'trumpet', 'sandwich', 'cue', 'kiwi fruit', 'bear', + 'fishing rod', 'cherry', 'tablet', 'green vegetables', 'nuts', 'corn', + 'key', 'screwdriver', 
'globe', 'broom', 'pliers', 'volleyball',
+         'hammer', 'eggplant', 'trophy', 'dates', 'board eraser', 'rice',
+         'tape measure/ruler', 'dumbbell', 'hamimelon', 'stapler', 'camel',
+         'lettuce', 'goldfish', 'meat balls', 'medal', 'toothpaste',
+         'antelope', 'shrimp', 'rickshaw', 'trombone', 'pomegranate',
+         'coconut', 'jellyfish', 'mushroom', 'calculator', 'treadmill',
+         'butterfly', 'egg tart', 'cheese', 'pig', 'pomelo', 'race car',
+         'rice cooker', 'tuba', 'crosswalk sign', 'papaya', 'hair drier',
+         'green onion', 'chips', 'dolphin', 'sushi', 'urinal', 'donkey',
+         'electric drill', 'spring rolls', 'tortoise/turtle', 'parrot',
+         'flute', 'measuring cup', 'shark', 'steak', 'poker card',
+         'binoculars', 'llama', 'radish', 'noodles', 'yak', 'mop', 'crab',
+         'microscope', 'barbell', 'bread/bun', 'baozi', 'lion', 'red cabbage',
+         'polar bear', 'lighter', 'seal', 'mangosteen', 'comb', 'eraser',
+         'pitaya', 'scallop', 'pencil case', 'saw', 'table tennis paddle',
+         'okra', 'starfish', 'eagle', 'monkey', 'durian', 'game board',
+         'rabbit', 'french horn', 'ambulance', 'asparagus', 'hoverboard',
+         'pasta', 'target', 'hotair balloon', 'chainsaw', 'lobster', 'iron',
+         'flashlight'),
+        'palette':
+        None
+    }
+
+    COCOAPI = COCO
+    # ann_id is unique in coco dataset.
+    ANN_ID_UNIQUE = True
+
+    def load_data_list(self) -> List[dict]:
+        """Load annotations from an annotation file named as ``self.ann_file``
+
+        Returns:
+            List[dict]: A list of annotations.
+        """  # noqa: E501
+        with get_local_path(
+                self.ann_file, backend_args=self.backend_args) as local_path:
+            self.coco = self.COCOAPI(local_path)
+
+        # the 'categories' lists in objects365_train.json and
+        # objects365_val.json are ordered inconsistently, so sort the
+        # list (and the cats dict) before calling get_cat_ids
+        cats = self.coco.cats
+        sorted_cats = {i: cats[i] for i in sorted(cats)}
+        self.coco.cats = sorted_cats
+        categories = self.coco.dataset['categories']
+        sorted_categories = sorted(categories, key=lambda i: i['id'])
+        self.coco.dataset['categories'] = sorted_categories
+        # The order of returned `cat_ids` will not
+        # change with the order of the `classes`
+        self.cat_ids = self.coco.get_cat_ids(
+            cat_names=self.metainfo['classes'])
+        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+        self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)
+
+        img_ids = self.coco.get_img_ids()
+        data_list = []
+        total_ann_ids = []
+        for img_id in img_ids:
+            raw_img_info = self.coco.load_imgs([img_id])[0]
+            raw_img_info['img_id'] = img_id
+
+            ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
+            raw_ann_info = self.coco.load_anns(ann_ids)
+            total_ann_ids.extend(ann_ids)
+
+            parsed_data_info = self.parse_data_info({
+                'raw_ann_info':
+                raw_ann_info,
+                'raw_img_info':
+                raw_img_info
+            })
+            data_list.append(parsed_data_info)
+        if self.ANN_ID_UNIQUE:
+            assert len(set(total_ann_ids)) == len(
+                total_ann_ids
+            ), f"Annotation ids in '{self.ann_file}' are not unique!"
+ + del self.coco + + return data_list + + +@DATASETS.register_module() +class Objects365V2Dataset(CocoDataset): + """Objects365 v2 dataset for detection.""" + METAINFO = { + 'classes': + ('Person', 'Sneakers', 'Chair', 'Other Shoes', 'Hat', 'Car', 'Lamp', + 'Glasses', 'Bottle', 'Desk', 'Cup', 'Street Lights', 'Cabinet/shelf', + 'Handbag/Satchel', 'Bracelet', 'Plate', 'Picture/Frame', 'Helmet', + 'Book', 'Gloves', 'Storage box', 'Boat', 'Leather Shoes', 'Flower', + 'Bench', 'Potted Plant', 'Bowl/Basin', 'Flag', 'Pillow', 'Boots', + 'Vase', 'Microphone', 'Necklace', 'Ring', 'SUV', 'Wine Glass', 'Belt', + 'Moniter/TV', 'Backpack', 'Umbrella', 'Traffic Light', 'Speaker', + 'Watch', 'Tie', 'Trash bin Can', 'Slippers', 'Bicycle', 'Stool', + 'Barrel/bucket', 'Van', 'Couch', 'Sandals', 'Bakset', 'Drum', + 'Pen/Pencil', 'Bus', 'Wild Bird', 'High Heels', 'Motorcycle', + 'Guitar', 'Carpet', 'Cell Phone', 'Bread', 'Camera', 'Canned', + 'Truck', 'Traffic cone', 'Cymbal', 'Lifesaver', 'Towel', + 'Stuffed Toy', 'Candle', 'Sailboat', 'Laptop', 'Awning', 'Bed', + 'Faucet', 'Tent', 'Horse', 'Mirror', 'Power outlet', 'Sink', 'Apple', + 'Air Conditioner', 'Knife', 'Hockey Stick', 'Paddle', 'Pickup Truck', + 'Fork', 'Traffic Sign', 'Ballon', 'Tripod', 'Dog', 'Spoon', 'Clock', + 'Pot', 'Cow', 'Cake', 'Dinning Table', 'Sheep', 'Hanger', + 'Blackboard/Whiteboard', 'Napkin', 'Other Fish', 'Orange/Tangerine', + 'Toiletry', 'Keyboard', 'Tomato', 'Lantern', 'Machinery Vehicle', + 'Fan', 'Green Vegetables', 'Banana', 'Baseball Glove', 'Airplane', + 'Mouse', 'Train', 'Pumpkin', 'Soccer', 'Skiboard', 'Luggage', + 'Nightstand', 'Tea pot', 'Telephone', 'Trolley', 'Head Phone', + 'Sports Car', 'Stop Sign', 'Dessert', 'Scooter', 'Stroller', 'Crane', + 'Remote', 'Refrigerator', 'Oven', 'Lemon', 'Duck', 'Baseball Bat', + 'Surveillance Camera', 'Cat', 'Jug', 'Broccoli', 'Piano', 'Pizza', + 'Elephant', 'Skateboard', 'Surfboard', 'Gun', + 'Skating and Skiing shoes', 'Gas stove', 'Donut', 'Bow Tie', 'Carrot', + 'Toilet', 'Kite', 'Strawberry', 'Other Balls', 'Shovel', 'Pepper', + 'Computer Box', 'Toilet Paper', 'Cleaning Products', 'Chopsticks', + 'Microwave', 'Pigeon', 'Baseball', 'Cutting/chopping Board', + 'Coffee Table', 'Side Table', 'Scissors', 'Marker', 'Pie', 'Ladder', + 'Snowboard', 'Cookies', 'Radiator', 'Fire Hydrant', 'Basketball', + 'Zebra', 'Grape', 'Giraffe', 'Potato', 'Sausage', 'Tricycle', + 'Violin', 'Egg', 'Fire Extinguisher', 'Candy', 'Fire Truck', + 'Billards', 'Converter', 'Bathtub', 'Wheelchair', 'Golf Club', + 'Briefcase', 'Cucumber', 'Cigar/Cigarette ', 'Paint Brush', 'Pear', + 'Heavy Truck', 'Hamburger', 'Extractor', 'Extention Cord', 'Tong', + 'Tennis Racket', 'Folder', 'American Football', 'earphone', 'Mask', + 'Kettle', 'Tennis', 'Ship', 'Swing', 'Coffee Machine', 'Slide', + 'Carriage', 'Onion', 'Green beans', 'Projector', 'Frisbee', + 'Washing Machine/Drying Machine', 'Chicken', 'Printer', 'Watermelon', + 'Saxophone', 'Tissue', 'Toothbrush', 'Ice cream', 'Hotair ballon', + 'Cello', 'French Fries', 'Scale', 'Trophy', 'Cabbage', 'Hot dog', + 'Blender', 'Peach', 'Rice', 'Wallet/Purse', 'Volleyball', 'Deer', + 'Goose', 'Tape', 'Tablet', 'Cosmetics', 'Trumpet', 'Pineapple', + 'Golf Ball', 'Ambulance', 'Parking meter', 'Mango', 'Key', 'Hurdle', + 'Fishing Rod', 'Medal', 'Flute', 'Brush', 'Penguin', 'Megaphone', + 'Corn', 'Lettuce', 'Garlic', 'Swan', 'Helicopter', 'Green Onion', + 'Sandwich', 'Nuts', 'Speed Limit Sign', 'Induction Cooker', 'Broom', + 'Trombone', 'Plum', 'Rickshaw', 'Goldfish', 'Kiwi fruit', 
+ 'Router/modem', 'Poker Card', 'Toaster', 'Shrimp', 'Sushi', 'Cheese', + 'Notepaper', 'Cherry', 'Pliers', 'CD', 'Pasta', 'Hammer', 'Cue', + 'Avocado', 'Hamimelon', 'Flask', 'Mushroon', 'Screwdriver', 'Soap', + 'Recorder', 'Bear', 'Eggplant', 'Board Eraser', 'Coconut', + 'Tape Measur/ Ruler', 'Pig', 'Showerhead', 'Globe', 'Chips', 'Steak', + 'Crosswalk Sign', 'Stapler', 'Campel', 'Formula 1 ', 'Pomegranate', + 'Dishwasher', 'Crab', 'Hoverboard', 'Meat ball', 'Rice Cooker', + 'Tuba', 'Calculator', 'Papaya', 'Antelope', 'Parrot', 'Seal', + 'Buttefly', 'Dumbbell', 'Donkey', 'Lion', 'Urinal', 'Dolphin', + 'Electric Drill', 'Hair Dryer', 'Egg tart', 'Jellyfish', 'Treadmill', + 'Lighter', 'Grapefruit', 'Game board', 'Mop', 'Radish', 'Baozi', + 'Target', 'French', 'Spring Rolls', 'Monkey', 'Rabbit', 'Pencil Case', + 'Yak', 'Red Cabbage', 'Binoculars', 'Asparagus', 'Barbell', 'Scallop', + 'Noddles', 'Comb', 'Dumpling', 'Oyster', 'Table Teniis paddle', + 'Cosmetics Brush/Eyeliner Pencil', 'Chainsaw', 'Eraser', 'Lobster', + 'Durian', 'Okra', 'Lipstick', 'Cosmetics Mirror', 'Curling', + 'Table Tennis '), + 'palette': + None + } + + COCOAPI = COCO + # ann_id is unique in coco dataset. + ANN_ID_UNIQUE = True + + def load_data_list(self) -> List[dict]: + """Load annotations from an annotation file named as ``self.ann_file`` + + Returns: + List[dict]: A list of annotation. + """ # noqa: E501 + with get_local_path( + self.ann_file, backend_args=self.backend_args) as local_path: + self.coco = self.COCOAPI(local_path) + # The order of returned `cat_ids` will not + # change with the order of the `classes` + self.cat_ids = self.coco.get_cat_ids( + cat_names=self.metainfo['classes']) + self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} + self.cat_img_map = copy.deepcopy(self.coco.cat_img_map) + + img_ids = self.coco.get_img_ids() + data_list = [] + total_ann_ids = [] + for img_id in img_ids: + raw_img_info = self.coco.load_imgs([img_id])[0] + raw_img_info['img_id'] = img_id + + ann_ids = self.coco.get_ann_ids(img_ids=[img_id]) + raw_ann_info = self.coco.load_anns(ann_ids) + total_ann_ids.extend(ann_ids) + + # file_name should be `patchX/xxx.jpg` + file_name = osp.join( + osp.split(osp.split(raw_img_info['file_name'])[0])[-1], + osp.split(raw_img_info['file_name'])[-1]) + + if file_name in objv2_ignore_list: + continue + + raw_img_info['file_name'] = file_name + parsed_data_info = self.parse_data_info({ + 'raw_ann_info': + raw_ann_info, + 'raw_img_info': + raw_img_info + }) + data_list.append(parsed_data_info) + if self.ANN_ID_UNIQUE: + assert len(set(total_ann_ids)) == len( + total_ann_ids + ), f"Annotation ids in '{self.ann_file}' are not unique!" + + del self.coco + + return data_list diff --git a/mmdet/datasets/openimages.py b/mmdet/datasets/openimages.py new file mode 100644 index 0000000000000000000000000000000000000000..a3c6c8ec44fdfe86a653fc6a716009836f7d471c --- /dev/null +++ b/mmdet/datasets/openimages.py @@ -0,0 +1,484 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import csv +import os.path as osp +from collections import defaultdict +from typing import Dict, List, Optional + +import numpy as np +from mmengine.fileio import get_local_path, load +from mmengine.utils import is_abs + +from mmdet.registry import DATASETS +from .base_det_dataset import BaseDetDataset + + +@DATASETS.register_module() +class OpenImagesDataset(BaseDetDataset): + """Open Images dataset for detection. + + Args: + ann_file (str): Annotation file path. 
+ label_file (str): File path of the label description file that + maps the classes names in MID format to their short + descriptions. + meta_file (str): File path to get image metas. + hierarchy_file (str): The file path of the class hierarchy. + image_level_ann_file (str): Human-verified image level annotation, + which is used in evaluation. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + """ + + METAINFO: dict = dict(dataset_type='oid_v6') + + def __init__(self, + label_file: str, + meta_file: str, + hierarchy_file: str, + image_level_ann_file: Optional[str] = None, + **kwargs) -> None: + self.label_file = label_file + self.meta_file = meta_file + self.hierarchy_file = hierarchy_file + self.image_level_ann_file = image_level_ann_file + super().__init__(**kwargs) + + def load_data_list(self) -> List[dict]: + """Load annotations from an annotation file named as ``self.ann_file`` + + Returns: + List[dict]: A list of annotation. + """ + classes_names, label_id_mapping = self._parse_label_file( + self.label_file) + self._metainfo['classes'] = classes_names + self.label_id_mapping = label_id_mapping + + if self.image_level_ann_file is not None: + img_level_anns = self._parse_img_level_ann( + self.image_level_ann_file) + else: + img_level_anns = None + + # OpenImagesMetric can get the relation matrix from the dataset meta + relation_matrix = self._get_relation_matrix(self.hierarchy_file) + self._metainfo['RELATION_MATRIX'] = relation_matrix + + data_list = [] + with get_local_path( + self.ann_file, backend_args=self.backend_args) as local_path: + with open(local_path, 'r') as f: + reader = csv.reader(f) + last_img_id = None + instances = [] + for i, line in enumerate(reader): + if i == 0: + continue + img_id = line[0] + if last_img_id is None: + last_img_id = img_id + label_id = line[2] + assert label_id in self.label_id_mapping + label = int(self.label_id_mapping[label_id]) + bbox = [ + float(line[4]), # xmin + float(line[6]), # ymin + float(line[5]), # xmax + float(line[7]) # ymax + ] + is_occluded = True if int(line[8]) == 1 else False + is_truncated = True if int(line[9]) == 1 else False + is_group_of = True if int(line[10]) == 1 else False + is_depiction = True if int(line[11]) == 1 else False + is_inside = True if int(line[12]) == 1 else False + + instance = dict( + bbox=bbox, + bbox_label=label, + ignore_flag=0, + is_occluded=is_occluded, + is_truncated=is_truncated, + is_group_of=is_group_of, + is_depiction=is_depiction, + is_inside=is_inside) + last_img_path = osp.join(self.data_prefix['img'], + f'{last_img_id}.jpg') + if img_id != last_img_id: + # switch to a new image, record previous image's data. 
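+                        # (this grouping assumes the CSV rows are sorted by
+                        # image id, as in the official OpenImages files)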
+ data_info = dict( + img_path=last_img_path, + img_id=last_img_id, + instances=instances, + ) + data_list.append(data_info) + instances = [] + instances.append(instance) + last_img_id = img_id + data_list.append( + dict( + img_path=last_img_path, + img_id=last_img_id, + instances=instances, + )) + + # add image metas to data list + img_metas = load( + self.meta_file, file_format='pkl', backend_args=self.backend_args) + assert len(img_metas) == len(data_list) + for i, meta in enumerate(img_metas): + img_id = data_list[i]['img_id'] + assert f'{img_id}.jpg' == osp.split(meta['filename'])[-1] + h, w = meta['ori_shape'][:2] + data_list[i]['height'] = h + data_list[i]['width'] = w + # denormalize bboxes + for j in range(len(data_list[i]['instances'])): + data_list[i]['instances'][j]['bbox'][0] *= w + data_list[i]['instances'][j]['bbox'][2] *= w + data_list[i]['instances'][j]['bbox'][1] *= h + data_list[i]['instances'][j]['bbox'][3] *= h + # add image-level annotation + if img_level_anns is not None: + img_labels = [] + confidences = [] + img_ann_list = img_level_anns.get(img_id, []) + for ann in img_ann_list: + img_labels.append(int(ann['image_level_label'])) + confidences.append(float(ann['confidence'])) + data_list[i]['image_level_labels'] = np.array( + img_labels, dtype=np.int64) + data_list[i]['confidences'] = np.array( + confidences, dtype=np.float32) + return data_list + + def _parse_label_file(self, label_file: str) -> tuple: + """Get classes name and index mapping from cls-label-description file. + + Args: + label_file (str): File path of the label description file that + maps the classes names in MID format to their short + descriptions. + + Returns: + tuple: Class name of OpenImages. + """ + + index_list = [] + classes_names = [] + with get_local_path( + label_file, backend_args=self.backend_args) as local_path: + with open(local_path, 'r') as f: + reader = csv.reader(f) + for line in reader: + # self.cat2label[line[0]] = line[1] + classes_names.append(line[1]) + index_list.append(line[0]) + index_mapping = {index: i for i, index in enumerate(index_list)} + return classes_names, index_mapping + + def _parse_img_level_ann(self, + img_level_ann_file: str) -> Dict[str, List[dict]]: + """Parse image level annotations from csv style ann_file. + + Args: + img_level_ann_file (str): CSV style image level annotation + file path. + + Returns: + Dict[str, List[dict]]: Annotations where item of the defaultdict + indicates an image, each of which has (n) dicts. + Keys of dicts are: + + - `image_level_label` (int): Label id. + - `confidence` (float): Labels that are human-verified to be + present in an image have confidence = 1 (positive labels). + Labels that are human-verified to be absent from an image + have confidence = 0 (negative labels). Machine-generated + labels have fractional confidences, generally >= 0.5. + The higher the confidence, the smaller the chance for + the label to be a false positive. + """ + + item_lists = defaultdict(list) + with get_local_path( + img_level_ann_file, + backend_args=self.backend_args) as local_path: + with open(local_path, 'r') as f: + reader = csv.reader(f) + for i, line in enumerate(reader): + if i == 0: + continue + img_id = line[0] + item_lists[img_id].append( + dict( + image_level_label=int( + self.label_id_mapping[line[2]]), + confidence=float(line[3]))) + return item_lists + + def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray: + """Get the matrix of class hierarchy from the hierarchy file. 
Hierarchy
+        for 600 classes can be found at https://storage.googleapis.com/openimages/2018_04/bbox_labels_600_hierarchy_visualizer/circle.html.
+
+        Args:
+            hierarchy_file (str): File path to the hierarchy for classes.
+
+        Returns:
+            np.ndarray: The matrix of the corresponding relationship between
+            the parent class and the child class, of shape
+            (class_num, class_num).
+        """  # noqa
+
+        hierarchy = load(
+            hierarchy_file, file_format='json', backend_args=self.backend_args)
+        class_num = len(self._metainfo['classes'])
+        relation_matrix = np.eye(class_num, class_num)
+        relation_matrix = self._convert_hierarchy_tree(hierarchy,
+                                                       relation_matrix)
+        return relation_matrix
+
+    def _convert_hierarchy_tree(self,
+                                hierarchy_map: dict,
+                                relation_matrix: np.ndarray,
+                                parents: list = [],
+                                get_all_parents: bool = True) -> np.ndarray:
+        """Get matrix of the corresponding relationship between the parent
+        class and the child class.
+
+        Args:
+            hierarchy_map (dict): Including label name and corresponding
+                subcategory. Keys of dicts are:
+
+                - `LabelName` (str): Name of the label.
+                - `Subcategory` (dict | list): Corresponding subcategory(ies).
+            relation_matrix (ndarray): The matrix of the corresponding
+                relationship between the parent class and the child class,
+                of shape (class_num, class_num).
+            parents (list): Corresponding parent class.
+            get_all_parents (bool): Whether to get all parent names.
+                Defaults to True.
+
+        Returns:
+            ndarray: The matrix of the corresponding relationship between
+            the parent class and the child class, of shape
+            (class_num, class_num).
+        """
+
+        if 'Subcategory' in hierarchy_map:
+            for node in hierarchy_map['Subcategory']:
+                if 'LabelName' in node:
+                    children_name = node['LabelName']
+                    children_index = self.label_id_mapping[children_name]
+                    children = [children_index]
+                else:
+                    continue
+                if len(parents) > 0:
+                    for parent_index in parents:
+                        if get_all_parents:
+                            children.append(parent_index)
+                        relation_matrix[children_index, parent_index] = 1
+                relation_matrix = self._convert_hierarchy_tree(
+                    node, relation_matrix, parents=children)
+        return relation_matrix
+
+    def _join_prefix(self):
+        """Join ``self.data_root`` with annotation path."""
+        super()._join_prefix()
+        if not is_abs(self.label_file) and self.label_file:
+            self.label_file = osp.join(self.data_root, self.label_file)
+        if not is_abs(self.meta_file) and self.meta_file:
+            self.meta_file = osp.join(self.data_root, self.meta_file)
+        if not is_abs(self.hierarchy_file) and self.hierarchy_file:
+            self.hierarchy_file = osp.join(self.data_root, self.hierarchy_file)
+        if self.image_level_ann_file and not is_abs(self.image_level_ann_file):
+            self.image_level_ann_file = osp.join(self.data_root,
+                                                 self.image_level_ann_file)
+
+
+@DATASETS.register_module()
+class OpenImagesChallengeDataset(OpenImagesDataset):
+    """Open Images Challenge dataset for detection.
+
+    Args:
+        ann_file (str): Open Images Challenge box annotation in txt format.
+    """
+
+    METAINFO: dict = dict(dataset_type='oid_challenge')
+
+    def __init__(self, ann_file: str, **kwargs) -> None:
+        if not ann_file.endswith('txt'):
+            raise TypeError('The annotation file of Open Images Challenge '
+                            'should be a txt file.')
+
+        super().__init__(ann_file=ann_file, **kwargs)
+
+    def load_data_list(self) -> List[dict]:
+        """Load annotations from an annotation file named as ``self.ann_file``
+
+        Returns:
+            List[dict]: A list of annotations.
+ """ + classes_names, label_id_mapping = self._parse_label_file( + self.label_file) + self._metainfo['classes'] = classes_names + self.label_id_mapping = label_id_mapping + + if self.image_level_ann_file is not None: + img_level_anns = self._parse_img_level_ann( + self.image_level_ann_file) + else: + img_level_anns = None + + # OpenImagesMetric can get the relation matrix from the dataset meta + relation_matrix = self._get_relation_matrix(self.hierarchy_file) + self._metainfo['RELATION_MATRIX'] = relation_matrix + + data_list = [] + with get_local_path( + self.ann_file, backend_args=self.backend_args) as local_path: + with open(local_path, 'r') as f: + lines = f.readlines() + i = 0 + while i < len(lines): + instances = [] + filename = lines[i].rstrip() + i += 2 + img_gt_size = int(lines[i]) + i += 1 + for j in range(img_gt_size): + sp = lines[i + j].split() + instances.append( + dict( + bbox=[ + float(sp[1]), + float(sp[2]), + float(sp[3]), + float(sp[4]) + ], + bbox_label=int(sp[0]) - 1, # labels begin from 1 + ignore_flag=0, + is_group_ofs=True if int(sp[5]) == 1 else False)) + i += img_gt_size + data_list.append( + dict( + img_path=osp.join(self.data_prefix['img'], filename), + instances=instances, + )) + + # add image metas to data list + img_metas = load( + self.meta_file, file_format='pkl', backend_args=self.backend_args) + assert len(img_metas) == len(data_list) + for i, meta in enumerate(img_metas): + img_id = osp.split(data_list[i]['img_path'])[-1][:-4] + assert img_id == osp.split(meta['filename'])[-1][:-4] + h, w = meta['ori_shape'][:2] + data_list[i]['height'] = h + data_list[i]['width'] = w + data_list[i]['img_id'] = img_id + # denormalize bboxes + for j in range(len(data_list[i]['instances'])): + data_list[i]['instances'][j]['bbox'][0] *= w + data_list[i]['instances'][j]['bbox'][2] *= w + data_list[i]['instances'][j]['bbox'][1] *= h + data_list[i]['instances'][j]['bbox'][3] *= h + # add image-level annotation + if img_level_anns is not None: + img_labels = [] + confidences = [] + img_ann_list = img_level_anns.get(img_id, []) + for ann in img_ann_list: + img_labels.append(int(ann['image_level_label'])) + confidences.append(float(ann['confidence'])) + data_list[i]['image_level_labels'] = np.array( + img_labels, dtype=np.int64) + data_list[i]['confidences'] = np.array( + confidences, dtype=np.float32) + return data_list + + def _parse_label_file(self, label_file: str) -> tuple: + """Get classes name and index mapping from cls-label-description file. + + Args: + label_file (str): File path of the label description file that + maps the classes names in MID format to their short + descriptions. + + Returns: + tuple: Class name of OpenImages. + """ + label_list = [] + id_list = [] + index_mapping = {} + with get_local_path( + label_file, backend_args=self.backend_args) as local_path: + with open(local_path, 'r') as f: + reader = csv.reader(f) + for line in reader: + label_name = line[0] + label_id = int(line[2]) + label_list.append(line[1]) + id_list.append(label_id) + index_mapping[label_name] = label_id - 1 + indexes = np.argsort(id_list) + classes_names = [] + for index in indexes: + classes_names.append(label_list[index]) + return classes_names, index_mapping + + def _parse_img_level_ann(self, image_level_ann_file): + """Parse image level annotations from csv style ann_file. + + Args: + image_level_ann_file (str): CSV style image level annotation + file path. 
+
+        Returns:
+            defaultdict[list[dict]]: Annotations where each item of the
+            defaultdict corresponds to one image and holds its n annotation
+            dicts. Keys of dicts are:
+
+                - `image_level_label` (int): of shape 1.
+                - `confidence` (float): of shape 1.
+        """
+
+        item_lists = defaultdict(list)
+        with get_local_path(
+                image_level_ann_file,
+                backend_args=self.backend_args) as local_path:
+            with open(local_path, 'r') as f:
+                reader = csv.reader(f)
+                for i, line in enumerate(reader):
+                    # skip the csv header row
+                    if i == 0:
+                        continue
+                    img_id = line[0]
+                    label_id = line[1]
+                    assert label_id in self.label_id_mapping
+                    image_level_label = int(self.label_id_mapping[label_id])
+                    confidence = float(line[2])
+                    item_lists[img_id].append(
+                        dict(
+                            image_level_label=image_level_label,
+                            confidence=confidence))
+        return item_lists
+
+    def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray:
+        """Get the matrix of class hierarchy from the hierarchy file.
+
+        Args:
+            hierarchy_file (str): File path to the hierarchy for classes.
+
+        Returns:
+            np.ndarray: The matrix of the corresponding
+            relationship between the parent class and the child class,
+            of shape (class_num, class_num).
+        """
+        with get_local_path(
+                hierarchy_file, backend_args=self.backend_args) as local_path:
+            class_label_tree = np.load(local_path, allow_pickle=True)
+        return class_label_tree[1:, 1:]
diff --git a/mmdet/datasets/refcoco.py b/mmdet/datasets/refcoco.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dae75fd547216a5b69033cc821b93a1d9ac6abc
--- /dev/null
+++ b/mmdet/datasets/refcoco.py
@@ -0,0 +1,163 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import collections
+import os.path as osp
+import random
+from typing import Dict, List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+
+from mmdet.registry import DATASETS
+
+
+@DATASETS.register_module()
+class RefCocoDataset(BaseDataset):
+    """RefCOCO dataset.
+
+    The `Refcoco` and `Refcoco+` datasets are based on
+    `ReferItGame: Referring to Objects in Photographs of Natural Scenes
+    `_.
+
+    The `Refcocog` dataset is based on
+    `Generation and Comprehension of Unambiguous Object Descriptions
+    `_.
+
+    Args:
+        ann_file (str): Annotation file path.
+        data_root (str): The root directory for ``data_prefix`` and
+            ``ann_file``. Defaults to ''.
+        data_prefix (dict): Prefix for training data, holding the image
+            directory under the ``img_path`` key.
+        split_file (str): Split file path.
+        split (str): Split name. Defaults to 'train'.
+        text_mode (str): How the referring expressions of an image are
+            sampled: 'original' keeps all of them, 'random' picks one at
+            random, 'concat' joins them into a single string, and
+            'select_first' takes the first one. Defaults to 'random'.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
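+
+    Example (an illustrative sketch; the file names below follow the usual
+    RefCOCO release layout and may differ in your setup):
+
+        >>> dataset = RefCocoDataset(
+        ...     data_root='data/refcoco',
+        ...     ann_file='instances.json',
+        ...     split_file='refs(unc).p',
+        ...     data_prefix=dict(img_path='train2014/'),
+        ...     split='train',
+        ...     text_mode='random')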
+ """ + + def __init__(self, + data_root: str, + ann_file: str, + split_file: str, + data_prefix: Dict, + split: str = 'train', + text_mode: str = 'random', + **kwargs): + self.split_file = split_file + self.split = split + + assert text_mode in ['original', 'random', 'concat', 'select_first'] + self.text_mode = text_mode + super().__init__( + data_root=data_root, + data_prefix=data_prefix, + ann_file=ann_file, + **kwargs, + ) + + def _join_prefix(self): + if not mmengine.is_abs(self.split_file) and self.split_file: + self.split_file = osp.join(self.data_root, self.split_file) + + return super()._join_prefix() + + def _init_refs(self): + """Initialize the refs for RefCOCO.""" + anns, imgs = {}, {} + for ann in self.instances['annotations']: + anns[ann['id']] = ann + for img in self.instances['images']: + imgs[img['id']] = img + + refs, ref_to_ann = {}, {} + for ref in self.splits: + # ids + ref_id = ref['ref_id'] + ann_id = ref['ann_id'] + # add mapping related to ref + refs[ref_id] = ref + ref_to_ann[ref_id] = anns[ann_id] + + self.refs = refs + self.ref_to_ann = ref_to_ann + + def load_data_list(self) -> List[dict]: + """Load data list.""" + self.splits = mmengine.load(self.split_file, file_format='pkl') + self.instances = mmengine.load(self.ann_file, file_format='json') + self._init_refs() + img_prefix = self.data_prefix['img_path'] + + ref_ids = [ + ref['ref_id'] for ref in self.splits if ref['split'] == self.split + ] + full_anno = [] + for ref_id in ref_ids: + ref = self.refs[ref_id] + ann = self.ref_to_ann[ref_id] + ann.update(ref) + full_anno.append(ann) + + image_id_list = [] + final_anno = {} + for anno in full_anno: + image_id_list.append(anno['image_id']) + final_anno[anno['ann_id']] = anno + annotations = [value for key, value in final_anno.items()] + + coco_train_id = [] + image_annot = {} + for i in range(len(self.instances['images'])): + coco_train_id.append(self.instances['images'][i]['id']) + image_annot[self.instances['images'][i] + ['id']] = self.instances['images'][i] + + images = [] + for image_id in list(set(image_id_list)): + images += [image_annot[image_id]] + + data_list = [] + + grounding_dict = collections.defaultdict(list) + for anno in annotations: + image_id = int(anno['image_id']) + grounding_dict[image_id].append(anno) + + join_path = mmengine.fileio.get_file_backend(img_prefix).join_path + for image in images: + img_id = image['id'] + instances = [] + sentences = [] + for grounding_anno in grounding_dict[img_id]: + texts = [x['raw'].lower() for x in grounding_anno['sentences']] + # random select one text + if self.text_mode == 'random': + idx = random.randint(0, len(texts) - 1) + text = [texts[idx]] + # concat all texts + elif self.text_mode == 'concat': + text = [''.join(texts)] + # select the first text + elif self.text_mode == 'select_first': + text = [texts[0]] + # use all texts + elif self.text_mode == 'original': + text = texts + else: + raise ValueError(f'Invalid text mode "{self.text_mode}".') + ins = [{ + 'mask': grounding_anno['segmentation'], + 'ignore_flag': 0 + }] * len(text) + instances.extend(ins) + sentences.extend(text) + data_info = { + 'img_path': join_path(img_prefix, image['file_name']), + 'img_id': img_id, + 'instances': instances, + 'text': sentences + } + data_list.append(data_info) + + if len(data_list) == 0: + raise ValueError(f'No sample in split "{self.split}".') + + return data_list diff --git a/mmdet/datasets/reid_dataset.py b/mmdet/datasets/reid_dataset.py new file mode 100644 index 
0000000000000000000000000000000000000000..1eed3ee4f0358edf59d19695c2b28394336dffd3
--- /dev/null
+++ b/mmdet/datasets/reid_dataset.py
@@ -0,0 +1,127 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os.path as osp
+from collections import defaultdict
+from typing import Any, Dict, List
+
+import numpy as np
+from mmengine.dataset import BaseDataset
+from mmengine.utils import check_file_exist
+
+from mmdet.registry import DATASETS
+
+
+@DATASETS.register_module()
+class ReIDDataset(BaseDataset):
+    """Dataset for ReID.
+
+    Args:
+        triplet_sampler (dict, optional): The sampler for hard mining
+            triplet loss. Its keys are:
+
+            - num_ids (int): The number of person ids.
+            - ins_per_id (int): The number of images for each person.
+
+            Defaults to None.
+    """
+
+    def __init__(self, triplet_sampler: dict = None, *args, **kwargs):
+        self.triplet_sampler = triplet_sampler
+        super().__init__(*args, **kwargs)
+
+    def load_data_list(self) -> List[dict]:
+        """Load annotations from an annotation file named as
+        ``self.ann_file``.
+
+        Returns:
+            list[dict]: A list of annotations.
+        """
+        assert isinstance(self.ann_file, str)
+        check_file_exist(self.ann_file)
+        data_list = []
+        with open(self.ann_file) as f:
+            samples = [x.strip().split(' ') for x in f.readlines()]
+            for filename, gt_label in samples:
+                info = dict(img_prefix=self.data_prefix)
+                if self.data_prefix['img_path'] is not None:
+                    info['img_path'] = osp.join(self.data_prefix['img_path'],
+                                                filename)
+                else:
+                    info['img_path'] = filename
+                info['gt_label'] = np.array(gt_label, dtype=np.int64)
+                data_list.append(info)
+        self._parse_ann_info(data_list)
+        return data_list
+
+    def _parse_ann_info(self, data_list: List[dict]):
+        """Parse person id annotations."""
+        index_tmp_dic = defaultdict(list)  # pid -> [idx1, ..., idxN]
+        self.index_dic = dict()  # pid -> array([idx1, ..., idxN])
+        for idx, info in enumerate(data_list):
+            pid = info['gt_label']
+            index_tmp_dic[int(pid)].append(idx)
+        for pid, idxs in index_tmp_dic.items():
+            self.index_dic[pid] = np.asarray(idxs, dtype=np.int64)
+        self.pids = np.asarray(list(self.index_dic.keys()), dtype=np.int64)
+
+    def prepare_data(self, idx: int) -> Any:
+        """Get data processed by ``self.pipeline``.
+
+        Args:
+            idx (int): The index of ``data_info``.
+
+        Returns:
+            Any: Depends on ``self.pipeline``.
+        """
+        data_info = self.get_data_info(idx)
+        if self.triplet_sampler is not None:
+            img_info = self.triplet_sampling(data_info['gt_label'],
+                                             **self.triplet_sampler)
+            data_info = copy.deepcopy(img_info)  # triplet -> list
+        else:
+            data_info = copy.deepcopy(data_info)  # no triplet -> dict
+        return self.pipeline(data_info)
+
+    def triplet_sampling(self,
+                         pos_pid,
+                         num_ids: int = 8,
+                         ins_per_id: int = 4) -> Dict:
+        """Triplet sampler for hard mining triplet loss. First, for one
+        pos_pid, randomly sample ins_per_id images with the same person id.
+        Then, randomly sample num_ids - 1 negative person ids. Finally,
+        randomly sample ins_per_id images for each negative id.
+
+        Args:
+            pos_pid (ndarray): The person id of the anchor.
+            num_ids (int): The number of person ids.
+            ins_per_id (int): The number of images for each person.
+
+        Returns:
+            Dict: Annotation information of num_ids X ins_per_id images.
+        """
+        assert len(self.pids) >= num_ids, \
+            'The number of person ids in the training set must be greater ' \
+            'than or equal to the number of person ids in the sample.'
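+        # e.g. with the defaults num_ids=8 and ins_per_id=4, the returned
+        # batch covers 8 * 4 = 32 images: 4 sharing pos_pid plus 4 for each
+        # of the 7 sampled negative person ids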
+
+        pos_idxs = self.index_dic[int(
+            pos_pid)]  # all positive idxs for pos_pid
+        idxs_list = []
+        # select positive samples
+        idxs_list.extend(pos_idxs[np.random.choice(
+            pos_idxs.shape[0], ins_per_id, replace=True)])
+        # select negative ids
+        neg_pids = np.random.choice(
+            [i for i, _ in enumerate(self.pids) if i != pos_pid],
+            num_ids - 1,
+            replace=False)
+        # select negative samples for each negative id
+        for neg_pid in neg_pids:
+            neg_idxs = self.index_dic[neg_pid]
+            idxs_list.extend(neg_idxs[np.random.choice(
+                neg_idxs.shape[0], ins_per_id, replace=True)])
+        # return the final triplet batch
+        triplet_img_infos = []
+        for idx in idxs_list:
+            triplet_img_infos.append(copy.deepcopy(self.get_data_info(idx)))
+        # Collect data_list scatters (list of dict -> dict of list)
+        out = dict()
+        for key in triplet_img_infos[0].keys():
+            out[key] = [_info[key] for _info in triplet_img_infos]
+        return out
diff --git a/mmdet/datasets/samplers/__init__.py b/mmdet/datasets/samplers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a942ff2199cc19e5957e312ab0a944d52e5081cc
--- /dev/null
+++ b/mmdet/datasets/samplers/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .batch_sampler import (AspectRatioBatchSampler,
+                            MultiDataAspectRatioBatchSampler,
+                            TrackAspectRatioBatchSampler)
+from .class_aware_sampler import ClassAwareSampler
+from .multi_data_sampler import MultiDataSampler
+from .multi_source_sampler import GroupMultiSourceSampler, MultiSourceSampler
+from .track_img_sampler import TrackImgSampler
+
+__all__ = [
+    'ClassAwareSampler', 'AspectRatioBatchSampler', 'MultiSourceSampler',
+    'GroupMultiSourceSampler', 'TrackImgSampler',
+    'TrackAspectRatioBatchSampler', 'MultiDataSampler',
+    'MultiDataAspectRatioBatchSampler'
+]
diff --git a/mmdet/datasets/samplers/batch_sampler.py b/mmdet/datasets/samplers/batch_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c17789c4e3ea51f1fa140d039a679f797a7660f6
--- /dev/null
+++ b/mmdet/datasets/samplers/batch_sampler.py
@@ -0,0 +1,193 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence
+
+from torch.utils.data import BatchSampler, Sampler
+
+from mmdet.datasets.samplers.track_img_sampler import TrackImgSampler
+from mmdet.registry import DATA_SAMPLERS
+
+
+# TODO: maybe replace with a data_loader wrapper
+@DATA_SAMPLERS.register_module()
+class AspectRatioBatchSampler(BatchSampler):
+    """A sampler wrapper for grouping images with similar aspect ratio
+    (< 1 or >= 1) into the same batch.
+
+    Args:
+        sampler (Sampler): Base sampler.
+        batch_size (int): Size of mini-batch.
+        drop_last (bool): If ``True``, the sampler will drop the last batch
+            if its size would be less than ``batch_size``.
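+
+    Example (illustrative sketch; assumes ``dataset`` is an mmdet dataset
+    whose ``get_data_info`` returns dicts with 'width' and 'height'):
+
+        >>> from mmengine.dataset import DefaultSampler
+        >>> sampler = DefaultSampler(dataset, shuffle=True)
+        >>> batch_sampler = AspectRatioBatchSampler(
+        ...     sampler, batch_size=2, drop_last=False)
+        >>> # every full batch holds indices of images that share an
+        >>> # orientation (all w < h, or all w >= h); only the leftover
+        >>> # batches at the end of an epoch may mix the two groups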
+ """ + + def __init__(self, + sampler: Sampler, + batch_size: int, + drop_last: bool = False) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError('batch_size should be a positive integer value, ' + f'but got batch_size={batch_size}') + self.sampler = sampler + self.batch_size = batch_size + self.drop_last = drop_last + # two groups for w < h and w >= h + self._aspect_ratio_buckets = [[] for _ in range(2)] + + def __iter__(self) -> Sequence[int]: + for idx in self.sampler: + data_info = self.sampler.dataset.get_data_info(idx) + width, height = data_info['width'], data_info['height'] + bucket_id = 0 if width < height else 1 + bucket = self._aspect_ratio_buckets[bucket_id] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] + + # yield the rest data and reset the bucket + left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[ + 1] + self._aspect_ratio_buckets = [[] for _ in range(2)] + while len(left_data) > 0: + if len(left_data) <= self.batch_size: + if not self.drop_last: + yield left_data[:] + left_data = [] + else: + yield left_data[:self.batch_size] + left_data = left_data[self.batch_size:] + + def __len__(self) -> int: + if self.drop_last: + return len(self.sampler) // self.batch_size + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size + + +@DATA_SAMPLERS.register_module() +class TrackAspectRatioBatchSampler(AspectRatioBatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio (< 1 or. + + >= 1) into a same batch. + + Args: + sampler (Sampler): Base sampler. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. + """ + + def __iter__(self) -> Sequence[int]: + for idx in self.sampler: + # hard code to solve TrackImgSampler + if isinstance(self.sampler, TrackImgSampler): + video_idx, _ = idx + else: + video_idx = idx + # video_idx + data_info = self.sampler.dataset.get_data_info(video_idx) + # data_info {video_id, images, video_length} + img_data_info = data_info['images'][0] + width, height = img_data_info['width'], img_data_info['height'] + bucket_id = 0 if width < height else 1 + bucket = self._aspect_ratio_buckets[bucket_id] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] + + # yield the rest data and reset the bucket + left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[ + 1] + self._aspect_ratio_buckets = [[] for _ in range(2)] + while len(left_data) > 0: + if len(left_data) <= self.batch_size: + if not self.drop_last: + yield left_data[:] + left_data = [] + else: + yield left_data[:self.batch_size] + left_data = left_data[self.batch_size:] + + +@DATA_SAMPLERS.register_module() +class MultiDataAspectRatioBatchSampler(BatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio (< 1 or. + + >= 1) into a same batch for multi-source datasets. + + Args: + sampler (Sampler): Base sampler. + batch_size (Sequence(int)): Size of mini-batch for multi-source + datasets. + num_datasets(int): Number of multi-source datasets. 
+ drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. + """ + + def __init__(self, + sampler: Sampler, + batch_size: Sequence[int], + num_datasets: int, + drop_last: bool = True) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + self.sampler = sampler + self.batch_size = batch_size + self.num_datasets = num_datasets + self.drop_last = drop_last + # two groups for w < h and w >= h for each dataset --> 2 * num_datasets + self._buckets = [[] for _ in range(2 * self.num_datasets)] + + def __iter__(self) -> Sequence[int]: + for idx in self.sampler: + data_info = self.sampler.dataset.get_data_info(idx) + width, height = data_info['width'], data_info['height'] + dataset_source_idx = self.sampler.dataset.get_dataset_source(idx) + aspect_ratio_bucket_id = 0 if width < height else 1 + bucket_id = dataset_source_idx * 2 + aspect_ratio_bucket_id + bucket = self._buckets[bucket_id] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size[dataset_source_idx]: + yield bucket[:] + del bucket[:] + + # yield the rest data and reset the bucket + for i in range(self.num_datasets): + left_data = self._buckets[i * 2 + 0] + self._buckets[i * 2 + 1] + while len(left_data) > 0: + if len(left_data) <= self.batch_size[i]: + if not self.drop_last: + yield left_data[:] + left_data = [] + else: + yield left_data[:self.batch_size[i]] + left_data = left_data[self.batch_size[i]:] + + self._buckets = [[] for _ in range(2 * self.num_datasets)] + + def __len__(self) -> int: + sizes = [0 for _ in range(self.num_datasets)] + for idx in self.sampler: + dataset_source_idx = self.sampler.dataset.get_dataset_source(idx) + sizes[dataset_source_idx] += 1 + + if self.drop_last: + lens = 0 + for i in range(self.num_datasets): + lens += sizes[i] // self.batch_size[i] + return lens + else: + lens = 0 + for i in range(self.num_datasets): + lens += (sizes[i] + self.batch_size[i] - + 1) // self.batch_size[i] + return lens diff --git a/mmdet/datasets/samplers/class_aware_sampler.py b/mmdet/datasets/samplers/class_aware_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..6ca2f9b3ffb7c780ab25cc3704b67589763259e0 --- /dev/null +++ b/mmdet/datasets/samplers/class_aware_sampler.py @@ -0,0 +1,192 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, Iterator, Optional, Union + +import numpy as np +import torch +from mmengine.dataset import BaseDataset +from mmengine.dist import get_dist_info, sync_random_seed +from torch.utils.data import Sampler + +from mmdet.registry import DATA_SAMPLERS + + +@DATA_SAMPLERS.register_module() +class ClassAwareSampler(Sampler): + r"""Sampler that restricts data loading to the label of the dataset. + + A class-aware sampling strategy to effectively tackle the + non-uniform class distribution. The length of the training data is + consistent with source data. Simple improvements based on `Relay + Backpropagation for Effective Learning of Deep Convolutional + Neural Networks `_ + + The implementation logic is referred to + https://github.com/Sense-X/TSD/blob/master/mmdet/datasets/samplers/distributed_classaware_sampler.py + + Args: + dataset: Dataset used for sampling. + seed (int, optional): random seed used to shuffle the sampler. + This number should be identical across all + processes in the distributed group. Defaults to None. 
+ num_sample_class (int): The number of samples taken from each + per-label list. Defaults to 1. + """ + + def __init__(self, + dataset: BaseDataset, + seed: Optional[int] = None, + num_sample_class: int = 1) -> None: + rank, world_size = get_dist_info() + self.rank = rank + self.world_size = world_size + + self.dataset = dataset + self.epoch = 0 + # Must be the same across all workers. If None, will use a + # random seed shared among workers + # (require synchronization among all workers) + if seed is None: + seed = sync_random_seed() + self.seed = seed + + # The number of samples taken from each per-label list + assert num_sample_class > 0 and isinstance(num_sample_class, int) + self.num_sample_class = num_sample_class + # Get per-label image list from dataset + self.cat_dict = self.get_cat2imgs() + + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / world_size)) + self.total_size = self.num_samples * self.world_size + + # get number of images containing each category + self.num_cat_imgs = [len(x) for x in self.cat_dict.values()] + # filter labels without images + self.valid_cat_inds = [ + i for i, length in enumerate(self.num_cat_imgs) if length != 0 + ] + self.num_classes = len(self.valid_cat_inds) + + def get_cat2imgs(self) -> Dict[int, list]: + """Get a dict with class as key and img_ids as values. + + Returns: + dict[int, list]: A dict of per-label image list, + the item of the dict indicates a label index, + corresponds to the image index that contains the label. + """ + classes = self.dataset.metainfo.get('classes', None) + if classes is None: + raise ValueError('dataset metainfo must contain `classes`') + # sort the label index + cat2imgs = {i: [] for i in range(len(classes))} + for i in range(len(self.dataset)): + cat_ids = set(self.dataset.get_cat_ids(i)) + for cat in cat_ids: + cat2imgs[cat].append(i) + return cat2imgs + + def __iter__(self) -> Iterator[int]: + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch + self.seed) + + # initialize label list + label_iter_list = RandomCycleIter(self.valid_cat_inds, generator=g) + # initialize each per-label image list + data_iter_dict = dict() + for i in self.valid_cat_inds: + data_iter_dict[i] = RandomCycleIter(self.cat_dict[i], generator=g) + + def gen_cat_img_inds(cls_list, data_dict, num_sample_cls): + """Traverse the categories and extract `num_sample_cls` image + indexes of the corresponding categories one by one.""" + id_indices = [] + for _ in range(len(cls_list)): + cls_idx = next(cls_list) + for _ in range(num_sample_cls): + id = next(data_dict[cls_idx]) + id_indices.append(id) + return id_indices + + # deterministically shuffle based on epoch + num_bins = int( + math.ceil(self.total_size * 1.0 / self.num_classes / + self.num_sample_class)) + indices = [] + for i in range(num_bins): + indices += gen_cat_img_inds(label_iter_list, data_iter_dict, + self.num_sample_class) + + # fix extra samples to make it evenly divisible + if len(indices) >= self.total_size: + indices = indices[:self.total_size] + else: + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + offset = self.num_samples * self.rank + indices = indices[offset:offset + self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self) -> int: + """The number of samples in this rank.""" + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + """Sets the epoch for this sampler. 
+
+        When :attr:`shuffle=True`, this ensures all replicas use a different
+        random ordering for each epoch. Otherwise, the next iteration of this
+        sampler will yield the same ordering.
+
+        Args:
+            epoch (int): Epoch number.
+        """
+        self.epoch = epoch
+
+
+class RandomCycleIter:
+    """Shuffle the list and do it again after the list has been traversed.
+
+    The implementation logic is referred to
+    https://github.com/wutong16/DistributionBalancedLoss/blob/master/mllt/datasets/loader/sampler.py
+
+    Example:
+        >>> label_list = [0, 1, 2, 4, 5]
+        >>> g = torch.Generator()
+        >>> g.manual_seed(0)
+        >>> label_iter_list = RandomCycleIter(label_list, generator=g)
+        >>> index = next(label_iter_list)
+
+    Args:
+        data (list or ndarray): The data that needs to be shuffled.
+        generator: A torch.Generator object, which is used in setting the seed
+            for generating random numbers.
+    """  # noqa: W605
+
+    def __init__(self,
+                 data: Union[list, np.ndarray],
+                 generator: torch.Generator = None) -> None:
+        self.data = data
+        self.length = len(data)
+        self.index = torch.randperm(self.length, generator=generator).numpy()
+        self.i = 0
+        self.generator = generator
+
+    def __iter__(self) -> Iterator:
+        return self
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def __next__(self):
+        if self.i == self.length:
+            self.index = torch.randperm(
+                self.length, generator=self.generator).numpy()
+            self.i = 0
+        idx = self.data[self.index[self.i]]
+        self.i += 1
+        return idx
diff --git a/mmdet/datasets/samplers/multi_data_sampler.py b/mmdet/datasets/samplers/multi_data_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3a4b60d84122ce9eb2090095e9744c2bd73cc3d
--- /dev/null
+++ b/mmdet/datasets/samplers/multi_data_sampler.py
@@ -0,0 +1,110 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Iterator, Optional, Sequence, Sized
+
+import torch
+from mmengine.dist import get_dist_info, sync_random_seed
+from mmengine.registry import DATA_SAMPLERS
+from torch.utils.data import Sampler
+
+
+@DATA_SAMPLERS.register_module()
+class MultiDataSampler(Sampler):
+    """The default data sampler for both distributed and non-distributed
+    environments.
+
+    It has several differences from the PyTorch ``DistributedSampler`` as
+    below:
+
+    1. This sampler supports non-distributed environments.
+
+    2. The round up behaviors are a little different.
+
+       - If ``round_up=True``, this sampler will add extra samples to make
+         the number of samples evenly divisible by the world size. This
+         behavior is the same as the ``DistributedSampler`` with
+         ``drop_last=False``.
+       - If ``round_up=False``, this sampler won't remove or add any samples
+         while the ``DistributedSampler`` with ``drop_last=True`` will remove
+         tail samples.
+
+    Args:
+        dataset (Sized): The dataset.
+        dataset_ratio (Sequence[int]): The ratios of the different datasets.
+        seed (int, optional): Random seed used to shuffle the sampler if
+            :attr:`shuffle=True`. This number should be identical across all
+            processes in the distributed group. Defaults to None.
+        round_up (bool): Whether to add extra samples to make the number of
+            samples evenly divisible by the world size. Defaults to True.
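+
+    Example:
+        >>> # A sketch of the weighting with hypothetical sizes: two
+        >>> # concatenated datasets of 100 and 300 samples with
+        >>> # dataset_ratio=[1, 1] get per-sample weights
+        >>> # 300/100 * 1/2 = 1.5 and 300/300 * 1/2 = 0.5, so both datasets
+        >>> # carry equal total probability mass (100 * 1.5 == 300 * 0.5)
+        >>> # under the multinomial sampling in ``__iter__``.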
+ """ + + def __init__(self, + dataset: Sized, + dataset_ratio: Sequence[int], + seed: Optional[int] = None, + round_up: bool = True) -> None: + rank, world_size = get_dist_info() + self.rank = rank + self.world_size = world_size + + self.dataset = dataset + self.dataset_ratio = dataset_ratio + + if seed is None: + seed = sync_random_seed() + self.seed = seed + self.epoch = 0 + self.round_up = round_up + + if self.round_up: + self.num_samples = math.ceil(len(self.dataset) / world_size) + self.total_size = self.num_samples * self.world_size + else: + self.num_samples = math.ceil( + (len(self.dataset) - rank) / world_size) + self.total_size = len(self.dataset) + + self.sizes = [len(dataset) for dataset in self.dataset.datasets] + + dataset_weight = [ + torch.ones(s) * max(self.sizes) / s * r / sum(self.dataset_ratio) + for i, (r, s) in enumerate(zip(self.dataset_ratio, self.sizes)) + ] + self.weights = torch.cat(dataset_weight) + + def __iter__(self) -> Iterator[int]: + """Iterate the indices.""" + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + + indices = torch.multinomial( + self.weights, len(self.weights), generator=g, + replacement=True).tolist() + + # add extra samples to make it evenly divisible + if self.round_up: + indices = ( + indices * + int(self.total_size / len(indices) + 1))[:self.total_size] + + # subsample + indices = indices[self.rank:self.total_size:self.world_size] + + return iter(indices) + + def __len__(self) -> int: + """The number of samples in this rank.""" + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + """Sets the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas use a different + random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch diff --git a/mmdet/datasets/samplers/multi_source_sampler.py b/mmdet/datasets/samplers/multi_source_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..6efcde35e1375547239825a8f78a9e74f7825290 --- /dev/null +++ b/mmdet/datasets/samplers/multi_source_sampler.py @@ -0,0 +1,214 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import itertools +from typing import Iterator, List, Optional, Sized, Union + +import numpy as np +import torch +from mmengine.dataset import BaseDataset +from mmengine.dist import get_dist_info, sync_random_seed +from torch.utils.data import Sampler + +from mmdet.registry import DATA_SAMPLERS + + +@DATA_SAMPLERS.register_module() +class MultiSourceSampler(Sampler): + r"""Multi-Source Infinite Sampler. + + According to the sampling ratio, sample data from different + datasets to form batches. + + Args: + dataset (Sized): The dataset. + batch_size (int): Size of mini-batch. + source_ratio (list[int | float]): The sampling ratio of different + source datasets in a mini-batch. + shuffle (bool): Whether shuffle the dataset or not. Defaults to True. + seed (int, optional): Random seed. If None, set a random seed. + Defaults to None. 
+ + Examples: + >>> dataset_type = 'ConcatDataset' + >>> sub_dataset_type = 'CocoDataset' + >>> data_root = 'data/coco/' + >>> sup_ann = '../coco_semi_annos/instances_train2017.1@10.json' + >>> unsup_ann = '../coco_semi_annos/' \ + >>> 'instances_train2017.1@10-unlabeled.json' + >>> dataset = dict(type=dataset_type, + >>> datasets=[ + >>> dict( + >>> type=sub_dataset_type, + >>> data_root=data_root, + >>> ann_file=sup_ann, + >>> data_prefix=dict(img='train2017/'), + >>> filter_cfg=dict(filter_empty_gt=True, min_size=32), + >>> pipeline=sup_pipeline), + >>> dict( + >>> type=sub_dataset_type, + >>> data_root=data_root, + >>> ann_file=unsup_ann, + >>> data_prefix=dict(img='train2017/'), + >>> filter_cfg=dict(filter_empty_gt=True, min_size=32), + >>> pipeline=unsup_pipeline), + >>> ]) + >>> train_dataloader = dict( + >>> batch_size=5, + >>> num_workers=5, + >>> persistent_workers=True, + >>> sampler=dict(type='MultiSourceSampler', + >>> batch_size=5, source_ratio=[1, 4]), + >>> batch_sampler=None, + >>> dataset=dataset) + """ + + def __init__(self, + dataset: Sized, + batch_size: int, + source_ratio: List[Union[int, float]], + shuffle: bool = True, + seed: Optional[int] = None) -> None: + + assert hasattr(dataset, 'cumulative_sizes'),\ + f'The dataset must be ConcatDataset, but get {dataset}' + assert isinstance(batch_size, int) and batch_size > 0, \ + 'batch_size must be a positive integer value, ' \ + f'but got batch_size={batch_size}' + assert isinstance(source_ratio, list), \ + f'source_ratio must be a list, but got source_ratio={source_ratio}' + assert len(source_ratio) == len(dataset.cumulative_sizes), \ + 'The length of source_ratio must be equal to ' \ + f'the number of datasets, but got source_ratio={source_ratio}' + + rank, world_size = get_dist_info() + self.rank = rank + self.world_size = world_size + + self.dataset = dataset + self.cumulative_sizes = [0] + dataset.cumulative_sizes + self.batch_size = batch_size + self.source_ratio = source_ratio + + self.num_per_source = [ + int(batch_size * sr / sum(source_ratio)) for sr in source_ratio + ] + self.num_per_source[0] = batch_size - sum(self.num_per_source[1:]) + + assert sum(self.num_per_source) == batch_size, \ + 'The sum of num_per_source must be equal to ' \ + f'batch_size, but get {self.num_per_source}' + + self.seed = sync_random_seed() if seed is None else seed + self.shuffle = shuffle + self.source2inds = { + source: self._indices_of_rank(len(ds)) + for source, ds in enumerate(dataset.datasets) + } + + def _infinite_indices(self, sample_size: int) -> Iterator[int]: + """Infinitely yield a sequence of indices.""" + g = torch.Generator() + g.manual_seed(self.seed) + while True: + if self.shuffle: + yield from torch.randperm(sample_size, generator=g).tolist() + else: + yield from torch.arange(sample_size).tolist() + + def _indices_of_rank(self, sample_size: int) -> Iterator[int]: + """Slice the infinite indices by rank.""" + yield from itertools.islice( + self._infinite_indices(sample_size), self.rank, None, + self.world_size) + + def __iter__(self) -> Iterator[int]: + batch_buffer = [] + while True: + for source, num in enumerate(self.num_per_source): + batch_buffer_per_source = [] + for idx in self.source2inds[source]: + idx += self.cumulative_sizes[source] + batch_buffer_per_source.append(idx) + if len(batch_buffer_per_source) == num: + batch_buffer += batch_buffer_per_source + break + yield from batch_buffer + batch_buffer = [] + + def __len__(self) -> int: + return len(self.dataset) + + def set_epoch(self, epoch: int) -> 
None:
+        """Not supported in epoch-based runners."""
+        pass
+
+
+@DATA_SAMPLERS.register_module()
+class GroupMultiSourceSampler(MultiSourceSampler):
+    r"""Group Multi-Source Infinite Sampler.
+
+    According to the sampling ratio, sample data from different
+    datasets but the same group to form batches.
+
+    Args:
+        dataset (Sized): The dataset.
+        batch_size (int): Size of mini-batch.
+        source_ratio (list[int | float]): The sampling ratio of different
+            source datasets in a mini-batch.
+        shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
+        seed (int, optional): Random seed. If None, set a random seed.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 dataset: BaseDataset,
+                 batch_size: int,
+                 source_ratio: List[Union[int, float]],
+                 shuffle: bool = True,
+                 seed: Optional[int] = None) -> None:
+        super().__init__(
+            dataset=dataset,
+            batch_size=batch_size,
+            source_ratio=source_ratio,
+            shuffle=shuffle,
+            seed=seed)
+
+        self._get_source_group_info()
+        self.group_source2inds = [{
+            source:
+            self._indices_of_rank(self.group2size_per_source[source][group])
+            for source in range(len(dataset.datasets))
+        } for group in range(len(self.group_ratio))]
+
+    def _get_source_group_info(self) -> None:
+        self.group2size_per_source = [{0: 0, 1: 0}, {0: 0, 1: 0}]
+        self.group2inds_per_source = [{0: [], 1: []}, {0: [], 1: []}]
+        for source, dataset in enumerate(self.dataset.datasets):
+            for idx in range(len(dataset)):
+                data_info = dataset.get_data_info(idx)
+                width, height = data_info['width'], data_info['height']
+                group = 0 if width < height else 1
+                self.group2size_per_source[source][group] += 1
+                self.group2inds_per_source[source][group].append(idx)
+
+        self.group_sizes = np.zeros(2, dtype=np.int64)
+        for group2size in self.group2size_per_source:
+            for group, size in group2size.items():
+                self.group_sizes[group] += size
+        self.group_ratio = self.group_sizes / sum(self.group_sizes)
+
+    def __iter__(self) -> Iterator[int]:
+        batch_buffer = []
+        while True:
+            group = np.random.choice(
+                list(range(len(self.group_ratio))), p=self.group_ratio)
+            for source, num in enumerate(self.num_per_source):
+                batch_buffer_per_source = []
+                for idx in self.group_source2inds[group][source]:
+                    idx = self.group2inds_per_source[source][group][
+                        idx] + self.cumulative_sizes[source]
+                    batch_buffer_per_source.append(idx)
+                    if len(batch_buffer_per_source) == num:
+                        batch_buffer += batch_buffer_per_source
+                        break
+            yield from batch_buffer
+            batch_buffer = []
diff --git a/mmdet/datasets/samplers/track_img_sampler.py b/mmdet/datasets/samplers/track_img_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7db629f40f3f24bdf14cd852ccc4472d1d50f1b
--- /dev/null
+++ b/mmdet/datasets/samplers/track_img_sampler.py
@@ -0,0 +1,146 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import random
+from typing import Iterator, Optional, Sized
+
+import numpy as np
+from mmengine.dataset import ClassBalancedDataset, ConcatDataset
+from mmengine.dist import get_dist_info, sync_random_seed
+from torch.utils.data import Sampler
+
+from mmdet.registry import DATA_SAMPLERS
+from ..base_video_dataset import BaseVideoDataset
+
+
+@DATA_SAMPLERS.register_module()
+class TrackImgSampler(Sampler):
+    """Sampler providing image-level sampling outputs for video datasets in
+    tracking tasks. It can be used in both distributed and non-distributed
+    environments.
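+
+    In short, it flattens every (video, frame) pair in the dataset into a
+    single index list and samples at image granularity.
+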
+    If using the default sampler in pytorch, the subsequent data receiver
+    will get one whole video, which is not desired in some cases:
+    (Take a non-distributed environment as an example)
+    1. In test mode, we want only one image to be fed into the data pipeline.
+    This is in consideration of memory usage, since feeding the whole video
+    commonly requires a large amount of memory (>=20G on the MOTChallenge17
+    dataset), which is not available on some machines.
+    2. In training mode, we may want to make sure all the images in one video
+    are randomly sampled once in one epoch, and this cannot be guaranteed by
+    the default sampler in pytorch.
+
+    Args:
+        dataset (Sized): Dataset used for sampling.
+        seed (int, optional): random seed used to shuffle the sampler. This
+            number should be identical across all processes in the distributed
+            group. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        dataset: Sized,
+        seed: Optional[int] = None,
+    ) -> None:
+        rank, world_size = get_dist_info()
+        self.rank = rank
+        self.world_size = world_size
+        self.epoch = 0
+        if seed is None:
+            self.seed = sync_random_seed()
+        else:
+            self.seed = seed
+
+        self.dataset = dataset
+        self.indices = []
+        # Hard code here to handle the different dataset wrappers
+        if isinstance(self.dataset, ConcatDataset):
+            cat_datasets = self.dataset.datasets
+            assert isinstance(
+                cat_datasets[0], BaseVideoDataset
+            ), f'expected BaseVideoDataset, but got {type(cat_datasets[0])}'
+            self.test_mode = cat_datasets[0].test_mode
+            assert not self.test_mode, \
+                "'ConcatDataset' should not exist in test mode"
+            for dataset in cat_datasets:
+                num_videos = len(dataset)
+                for video_ind in range(num_videos):
+                    self.indices.extend([
+                        (video_ind, frame_ind) for frame_ind in range(
+                            dataset.get_len_per_video(video_ind))
+                    ])
+        elif isinstance(self.dataset, ClassBalancedDataset):
+            ori_dataset = self.dataset.dataset
+            assert isinstance(
+                ori_dataset, BaseVideoDataset
+            ), f'expected BaseVideoDataset, but got {type(ori_dataset)}'
+            self.test_mode = ori_dataset.test_mode
+            assert not self.test_mode, \
+                "'ClassBalancedDataset' should not exist in test mode"
+            video_indices = self.dataset.repeat_indices
+            for index in video_indices:
+                self.indices.extend([(index, frame_ind) for frame_ind in range(
+                    ori_dataset.get_len_per_video(index))])
+        else:
+            assert isinstance(
+                self.dataset, BaseVideoDataset), (
+                    'TrackImgSampler is only supported in BaseVideoDataset '
+                    'or the dataset wrappers ClassBalancedDataset and '
+                    f'ConcatDataset, but got {type(self.dataset)}')
+            self.test_mode = self.dataset.test_mode
+            num_videos = len(self.dataset)
+
+            if self.test_mode:
+                # in test mode, the images belonging to the same video must
+                # be put on the same device.
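+                # e.g. with 5 videos and world_size=2, np.array_split below
+                # gives rank 0 the frames of videos [0, 1, 2] and rank 1 the
+                # frames of videos [3, 4]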
+ if num_videos < self.world_size: + raise ValueError(f'only {num_videos} videos loaded,' + f'but {self.world_size} gpus were given.') + chunks = np.array_split( + list(range(num_videos)), self.world_size) + for videos_inds in chunks: + indices_chunk = [] + for video_ind in videos_inds: + indices_chunk.extend([ + (video_ind, frame_ind) for frame_ind in range( + self.dataset.get_len_per_video(video_ind)) + ]) + self.indices.append(indices_chunk) + else: + for video_ind in range(num_videos): + self.indices.extend([ + (video_ind, frame_ind) for frame_ind in range( + self.dataset.get_len_per_video(video_ind)) + ]) + + if self.test_mode: + self.num_samples = len(self.indices[self.rank]) + self.total_size = sum( + [len(index_list) for index_list in self.indices]) + else: + self.num_samples = int( + math.ceil(len(self.indices) * 1.0 / self.world_size)) + self.total_size = self.num_samples * self.world_size + + def __iter__(self) -> Iterator: + if self.test_mode: + # in test mode, the order of frames can not be shuffled. + indices = self.indices[self.rank] + else: + # deterministically shuffle based on epoch + rng = random.Random(self.epoch + self.seed) + indices = rng.sample(self.indices, len(self.indices)) + + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.world_size] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/mmdet/datasets/transforms/__init__.py b/mmdet/datasets/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1f30d6c13528ba4d2f6031786c80b22eec8e6bd4 --- /dev/null +++ b/mmdet/datasets/transforms/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .augment_wrappers import AutoAugment, RandAugment +from .colorspace import (AutoContrast, Brightness, Color, ColorTransform, + Contrast, Equalize, Invert, Posterize, Sharpness, + Solarize, SolarizeAdd) +from .formatting import (ImageToTensor, PackDetInputs, PackReIDInputs, + PackTrackInputs, ToTensor, Transpose) +from .frame_sampling import BaseFrameSample, UniformRefFrameSample +from .geometric import (GeomTransform, Rotate, ShearX, ShearY, TranslateX, + TranslateY) +from .instaboost import InstaBoost +from .loading import (FilterAnnotations, InferencerLoader, LoadAnnotations, + LoadEmptyAnnotations, LoadImageFromNDArray, + LoadMultiChannelImageFromFiles, LoadPanopticAnnotations, + LoadProposals, LoadTrackAnnotations) +from .transformers_glip import GTBoxSubOne_GLIP, RandomFlip_GLIP +from .transforms import (Albu, CachedMixUp, CachedMosaic, CopyPaste, CutOut, + Expand, FixScaleResize, FixShapeResize, + MinIoURandomCrop, MixUp, Mosaic, Pad, + PhotoMetricDistortion, RandomAffine, + RandomCenterCropPad, RandomCrop, RandomErasing, + RandomFlip, RandomShift, Resize, ResizeShortestEdge, + SegRescale, YOLOXHSVRandomAug) +from .wrappers import MultiBranch, ProposalBroadcaster, RandomOrder + +__all__ = [ + 'PackDetInputs', 'ToTensor', 'ImageToTensor', 'Transpose', + 'LoadImageFromNDArray', 'LoadAnnotations', 'LoadPanopticAnnotations', + 'LoadMultiChannelImageFromFiles', 'LoadProposals', 'Resize', 'RandomFlip', + 'RandomCrop', 'SegRescale', 'MinIoURandomCrop', 'Expand', + 'PhotoMetricDistortion', 'Albu', 'InstaBoost', 'RandomCenterCropPad', + 'AutoAugment', 'CutOut', 'ShearX', 'ShearY', 'Rotate', 'Color', 'Equalize', + 'Brightness', 'Contrast', 'TranslateX', 'TranslateY', 'RandomShift', + 'Mosaic', 'MixUp', 'RandomAffine', 'YOLOXHSVRandomAug', 'CopyPaste', + 'FilterAnnotations', 'Pad', 'GeomTransform', 'ColorTransform', + 'RandAugment', 'Sharpness', 'Solarize', 'SolarizeAdd', 'Posterize', + 'AutoContrast', 'Invert', 'MultiBranch', 'RandomErasing', + 'LoadEmptyAnnotations', 'RandomOrder', 'CachedMosaic', 'CachedMixUp', + 'FixShapeResize', 'ProposalBroadcaster', 'InferencerLoader', + 'LoadTrackAnnotations', 'BaseFrameSample', 'UniformRefFrameSample', + 'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize', + 'ResizeShortestEdge', 'GTBoxSubOne_GLIP', 'RandomFlip_GLIP' +] diff --git a/mmdet/datasets/transforms/augment_wrappers.py b/mmdet/datasets/transforms/augment_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..19fae6efdf66aa4c26bb85a2f2c96a1e079320b8 --- /dev/null +++ b/mmdet/datasets/transforms/augment_wrappers.py @@ -0,0 +1,264 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import numpy as np +from mmcv.transforms import RandomChoice +from mmcv.transforms.utils import cache_randomness +from mmengine.config import ConfigDict + +from mmdet.registry import TRANSFORMS + +# AutoAugment uses reinforcement learning to search for +# some widely useful data augmentation strategies, +# here we provide AUTOAUG_POLICIES_V0. +# For AUTOAUG_POLICIES_V0, each tuple is an augmentation +# operation of the form (operation, probability, magnitude). +# Each element in policies is a policy that will be applied +# sequentially on the image. + +# RandAugment defines a data augmentation search space, RANDAUG_SPACE, +# sampling 1~3 data augmentations each time, and +# setting the magnitude of each data augmentation randomly, +# which will be applied sequentially on the image. 
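+# For example, the first policy in AUTOAUG_POLICIES_V0 below,
+# [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], applies Equalize with
+# probability 0.8 at magnitude level 1 and then ShearY with probability 0.8
+# at level 4 to the same image.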
+
+_MAX_LEVEL = 10
+
+AUTOAUG_POLICIES_V0 = [
+    [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
+    [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+    [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
+    [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+    [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+    [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+    [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+    [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
+    [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+    [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
+    [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
+    [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+    [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
+    [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+    [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
+    [('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)],
+    [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
+    [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
+    [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
+    [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+    [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+    [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
+    [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
+    [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+    [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
+]
+
+
+def policies_v0():
+    """Autoaugment policies that were used in the AutoAugment paper."""
+    policies = list()
+    for policy_args in AUTOAUG_POLICIES_V0:
+        policy = list()
+        for args in policy_args:
+            policy.append(dict(type=args[0], prob=args[1], level=args[2]))
+        policies.append(policy)
+    return policies
+
+
+RANDAUG_SPACE = [[dict(type='AutoContrast')], [dict(type='Equalize')],
+                 [dict(type='Invert')], [dict(type='Rotate')],
+                 [dict(type='Posterize')], [dict(type='Solarize')],
+                 [dict(type='SolarizeAdd')], [dict(type='Color')],
+                 [dict(type='Contrast')], [dict(type='Brightness')],
+                 [dict(type='Sharpness')], [dict(type='ShearX')],
+                 [dict(type='ShearY')], [dict(type='TranslateX')],
+                 [dict(type='TranslateY')]]
+
+
+def level_to_mag(level: Optional[int], min_mag: float,
+                 max_mag: float) -> float:
+    """Map from level to magnitude."""
+    if level is None:
+        return round(np.random.rand() * (max_mag - min_mag) + min_mag, 1)
+    else:
+        return round(level / _MAX_LEVEL * (max_mag - min_mag) + min_mag, 1)
+
+
+@TRANSFORMS.register_module()
+class AutoAugment(RandomChoice):
+    """Auto augmentation.
+
+    This data augmentation is proposed in `AutoAugment: Learning
+    Augmentation Policies from Data <https://arxiv.org/abs/1805.09501>`_
+    and in `Learning Data Augmentation Strategies for Object Detection
+    <https://arxiv.org/abs/1906.11172>`_.
+
+    Required Keys:
+
+    - img
+    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
+    - gt_bboxes_labels (np.int64) (optional)
+    - gt_masks (BitmapMasks | PolygonMasks) (optional)
+    - gt_ignore_flags (bool) (optional)
+    - gt_seg_map (np.uint8) (optional)
+
+    Modified Keys:
+
+    - img
+    - img_shape
+    - gt_bboxes
+    - gt_bboxes_labels
+    - gt_masks
+    - gt_ignore_flags
+    - gt_seg_map
+
+    Added Keys:
+
+    - homography_matrix
+
+    Args:
+        policies (List[List[Union[dict, ConfigDict]]]):
+            The policies of auto augmentation. Each policy in ``policies``
+            is a specific augmentation policy, and is composed of several
+            augmentations. When AutoAugment is called, a random policy in
+            ``policies`` will be selected to augment images.
+            Defaults to policies_v0().
+        prob (list[float], optional): The probabilities associated
+            with each policy. The length should be equal to the policy
+            number and the sum should be 1. If not given, a uniform
+            distribution will be assumed. Defaults to None.
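+
+    Note:
+        ``AutoAugment`` inherits from :obj:`mmcv.transforms.RandomChoice`:
+        each call picks one policy at random and applies its augmentations
+        in order.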
+
+    Examples:
+        >>> policies = [
+        >>>     [
+        >>>         dict(type='Sharpness', prob=0.0, level=8),
+        >>>         dict(type='ShearX', prob=0.4, level=0,)
+        >>>     ],
+        >>>     [
+        >>>         dict(type='Rotate', prob=0.6, level=10),
+        >>>         dict(type='Color', prob=1.0, level=6)
+        >>>     ]
+        >>> ]
+        >>> augmentation = AutoAugment(policies)
+        >>> img = np.ones((100, 100, 3))
+        >>> gt_bboxes = np.ones((10, 4))
+        >>> results = dict(img=img, gt_bboxes=gt_bboxes)
+        >>> results = augmentation(results)
+    """
+
+    def __init__(self,
+                 policies: List[List[Union[dict, ConfigDict]]] = policies_v0(),
+                 prob: Optional[List[float]] = None) -> None:
+        assert isinstance(policies, list) and len(policies) > 0, \
+            'Policies must be a non-empty list.'
+        for policy in policies:
+            assert isinstance(policy, list) and len(policy) > 0, \
+                'Each policy in policies must be a non-empty list.'
+            for augment in policy:
+                assert isinstance(augment, dict) and 'type' in augment, \
+                    'Each specific augmentation must be a dict with key' \
+                    ' "type".'
+        super().__init__(transforms=policies, prob=prob)
+        self.policies = policies
+
+    def __repr__(self) -> str:
+        return f'{self.__class__.__name__}(policies={self.policies}, ' \
+               f'prob={self.prob})'
+
+
+@TRANSFORMS.register_module()
+class RandAugment(RandomChoice):
+    """Rand augmentation.
+
+    This data augmentation is proposed in `RandAugment:
+    Practical automated data augmentation with a reduced
+    search space <https://arxiv.org/abs/1909.13719>`_.
+
+    Required Keys:
+
+    - img
+    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
+    - gt_bboxes_labels (np.int64) (optional)
+    - gt_masks (BitmapMasks | PolygonMasks) (optional)
+    - gt_ignore_flags (bool) (optional)
+    - gt_seg_map (np.uint8) (optional)
+
+    Modified Keys:
+
+    - img
+    - img_shape
+    - gt_bboxes
+    - gt_bboxes_labels
+    - gt_masks
+    - gt_ignore_flags
+    - gt_seg_map
+
+    Added Keys:
+
+    - homography_matrix
+
+    Args:
+        aug_space (List[List[Union[dict, ConfigDict]]]): The augmentation
+            space of rand augmentation. Each augmentation transform in
+            ``aug_space`` is a specific transform wrapped in a single-element
+            list. When RandAugment is called, ``aug_num`` random transforms
+            in ``aug_space`` will be selected to augment images.
+            Defaults to RANDAUG_SPACE.
+        aug_num (int): Number of augmentations to apply sequentially.
+            Defaults to 2.
+        prob (list[float], optional): The probabilities associated with
+            each augmentation. The length should be equal to the
+            augmentation space and the sum should be 1. If not given,
+            a uniform distribution will be assumed. Defaults to None.
+
+    Examples:
+        >>> aug_space = [
+        >>>     [dict(type='Sharpness')],
+        >>>     [dict(type='ShearX')],
+        >>>     [dict(type='Color')],
+        >>> ]
+        >>> augmentation = RandAugment(aug_space)
+        >>> img = np.ones((100, 100, 3))
+        >>> gt_bboxes = np.ones((10, 4))
+        >>> results = dict(img=img, gt_bboxes=gt_bboxes)
+        >>> results = augmentation(results)
+    """
+
+    def __init__(self,
+                 aug_space: List[List[Union[dict, ConfigDict]]] = RANDAUG_SPACE,
+                 aug_num: int = 2,
+                 prob: Optional[List[float]] = None) -> None:
+        assert isinstance(aug_space, list) and len(aug_space) > 0, \
+            'Augmentation space must be a non-empty list.'
+        for aug in aug_space:
+            assert isinstance(aug, list) and len(aug) == 1, \
+                'Each augmentation in aug_space must be a single-element list.'
+            for transform in aug:
+                assert isinstance(transform, dict) and 'type' in transform, \
+                    'Each specific transform must be a dict with key' \
+                    ' "type".'
+ super().__init__(transforms=aug_space, prob=prob) + self.aug_space = aug_space + self.aug_num = aug_num + + @cache_randomness + def random_pipeline_index(self): + indices = np.arange(len(self.transforms)) + return np.random.choice( + indices, self.aug_num, p=self.prob, replace=False) + + def transform(self, results: dict) -> dict: + """Transform function to use RandAugment. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with RandAugment. + """ + for idx in self.random_pipeline_index(): + results = self.transforms[idx](results) + return results + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(' \ + f'aug_space={self.aug_space}, '\ + f'aug_num={self.aug_num}, ' \ + f'prob={self.prob})' diff --git a/mmdet/datasets/transforms/colorspace.py b/mmdet/datasets/transforms/colorspace.py new file mode 100644 index 0000000000000000000000000000000000000000..e0ba2e97c7eedf65df5ab8942ee461f48a785f39 --- /dev/null +++ b/mmdet/datasets/transforms/colorspace.py @@ -0,0 +1,493 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Optional + +import mmcv +import numpy as np +from mmcv.transforms import BaseTransform +from mmcv.transforms.utils import cache_randomness + +from mmdet.registry import TRANSFORMS +from .augment_wrappers import _MAX_LEVEL, level_to_mag + + +@TRANSFORMS.register_module() +class ColorTransform(BaseTransform): + """Base class for color transformations. All color transformations need to + inherit from this base class. ``ColorTransform`` unifies the class + attributes and class functions of color transformations (Color, Brightness, + Contrast, Sharpness, Solarize, SolarizeAdd, Equalize, AutoContrast, Invert, + and Posterize), and only distort color channels, without impacting the + locations of the instances. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + prob (float): The probability for performing the geometric + transformation and should be in range [0, 1]. Defaults to 1.0. + level (int, optional): The level should be in range [0, _MAX_LEVEL]. + If level is None, it will generate from [0, _MAX_LEVEL] randomly. + Defaults to None. + min_mag (float): The minimum magnitude for color transformation. + Defaults to 0.1. + max_mag (float): The maximum magnitude for color transformation. + Defaults to 1.9. + """ + + def __init__(self, + prob: float = 1.0, + level: Optional[int] = None, + min_mag: float = 0.1, + max_mag: float = 1.9) -> None: + assert 0 <= prob <= 1.0, f'The probability of the transformation ' \ + f'should be in range [0,1], got {prob}.' + assert level is None or isinstance(level, int), \ + f'The level should be None or type int, got {type(level)}.' + assert level is None or 0 <= level <= _MAX_LEVEL, \ + f'The level should be in range [0,{_MAX_LEVEL}], got {level}.' + assert isinstance(min_mag, float), \ + f'min_mag should be type float, got {type(min_mag)}.' + assert isinstance(max_mag, float), \ + f'max_mag should be type float, got {type(max_mag)}.' 
+ assert min_mag <= max_mag, \ + f'min_mag should smaller than max_mag, ' \ + f'got min_mag={min_mag} and max_mag={max_mag}' + self.prob = prob + self.level = level + self.min_mag = min_mag + self.max_mag = max_mag + + def _transform_img(self, results: dict, mag: float) -> None: + """Transform the image.""" + pass + + @cache_randomness + def _random_disable(self): + """Randomly disable the transform.""" + return np.random.rand() > self.prob + + @cache_randomness + def _get_mag(self): + """Get the magnitude of the transform.""" + return level_to_mag(self.level, self.min_mag, self.max_mag) + + def transform(self, results: dict) -> dict: + """Transform function for images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Transformed results. + """ + + if self._random_disable(): + return results + mag = self._get_mag() + self._transform_img(results, mag) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'level={self.level}, ' + repr_str += f'min_mag={self.min_mag}, ' + repr_str += f'max_mag={self.max_mag})' + return repr_str + + +@TRANSFORMS.register_module() +class Color(ColorTransform): + """Adjust the color balance of the image, in a manner similar to the + controls on a colour TV set. A magnitude=0 gives a black & white image, + whereas magnitude=1 gives the original image. The bboxes, masks and + segmentations are not modified. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + prob (float): The probability for performing Color transformation. + Defaults to 1.0. + level (int, optional): Should be in range [0,_MAX_LEVEL]. + If level is None, it will generate from [0, _MAX_LEVEL] randomly. + Defaults to None. + min_mag (float): The minimum magnitude for Color transformation. + Defaults to 0.1. + max_mag (float): The maximum magnitude for Color transformation. + Defaults to 1.9. + """ + + def __init__(self, + prob: float = 1.0, + level: Optional[int] = None, + min_mag: float = 0.1, + max_mag: float = 1.9) -> None: + assert 0. <= min_mag <= 2.0, \ + f'min_mag for Color should be in range [0,2], got {min_mag}.' + assert 0. <= max_mag <= 2.0, \ + f'max_mag for Color should be in range [0,2], got {max_mag}.' + super().__init__( + prob=prob, level=level, min_mag=min_mag, max_mag=max_mag) + + def _transform_img(self, results: dict, mag: float) -> None: + """Apply Color transformation to image.""" + # NOTE defaultly the image should be BGR format + img = results['img'] + results['img'] = mmcv.adjust_color(img, mag).astype(img.dtype) + + +@TRANSFORMS.register_module() +class Brightness(ColorTransform): + """Adjust the brightness of the image. A magnitude=0 gives a black image, + whereas magnitude=1 gives the original image. The bboxes, masks and + segmentations are not modified. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + prob (float): The probability for performing Brightness transformation. + Defaults to 1.0. + level (int, optional): Should be in range [0,_MAX_LEVEL]. + If level is None, it will generate from [0, _MAX_LEVEL] randomly. + Defaults to None. + min_mag (float): The minimum magnitude for Brightness transformation. + Defaults to 0.1. + max_mag (float): The maximum magnitude for Brightness transformation. + Defaults to 1.9. + """ + + def __init__(self, + prob: float = 1.0, + level: Optional[int] = None, + min_mag: float = 0.1, + max_mag: float = 1.9) -> None: + assert 0. 
<= min_mag <= 2.0, \
+            f'min_mag for Brightness should be in range [0,2], got {min_mag}.'
+        assert 0. <= max_mag <= 2.0, \
+            f'max_mag for Brightness should be in range [0,2], got {max_mag}.'
+        super().__init__(
+            prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
+
+    def _transform_img(self, results: dict, mag: float) -> None:
+        """Adjust the brightness of the image."""
+        img = results['img']
+        results['img'] = mmcv.adjust_brightness(img, mag).astype(img.dtype)
+
+
+@TRANSFORMS.register_module()
+class Contrast(ColorTransform):
+    """Control the contrast of the image. A magnitude=0 gives a gray image,
+    whereas magnitude=1 gives the original image. The bboxes, masks and
+    segmentations are not modified.
+
+    Required Keys:
+
+    - img
+
+    Modified Keys:
+
+    - img
+
+    Args:
+        prob (float): The probability for performing Contrast transformation.
+            Defaults to 1.0.
+        level (int, optional): Should be in range [0,_MAX_LEVEL].
+            If level is None, it will generate from [0, _MAX_LEVEL] randomly.
+            Defaults to None.
+        min_mag (float): The minimum magnitude for Contrast transformation.
+            Defaults to 0.1.
+        max_mag (float): The maximum magnitude for Contrast transformation.
+            Defaults to 1.9.
+    """
+
+    def __init__(self,
+                 prob: float = 1.0,
+                 level: Optional[int] = None,
+                 min_mag: float = 0.1,
+                 max_mag: float = 1.9) -> None:
+        assert 0. <= min_mag <= 2.0, \
+            f'min_mag for Contrast should be in range [0,2], got {min_mag}.'
+        assert 0. <= max_mag <= 2.0, \
+            f'max_mag for Contrast should be in range [0,2], got {max_mag}.'
+        super().__init__(
+            prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
+
+    def _transform_img(self, results: dict, mag: float) -> None:
+        """Adjust the image contrast."""
+        img = results['img']
+        results['img'] = mmcv.adjust_contrast(img, mag).astype(img.dtype)
+
+
+@TRANSFORMS.register_module()
+class Sharpness(ColorTransform):
+    """Adjust image sharpness. A positive magnitude would enhance the
+    sharpness and a negative magnitude would make the image blurry. A
+    magnitude=0 gives the original image.
+
+    Required Keys:
+
+    - img
+
+    Modified Keys:
+
+    - img
+
+    Args:
+        prob (float): The probability for performing Sharpness transformation.
+            Defaults to 1.0.
+        level (int, optional): Should be in range [0,_MAX_LEVEL].
+            If level is None, it will generate from [0, _MAX_LEVEL] randomly.
+            Defaults to None.
+        min_mag (float): The minimum magnitude for Sharpness transformation.
+            Defaults to 0.1.
+        max_mag (float): The maximum magnitude for Sharpness transformation.
+            Defaults to 1.9.
+    """
+
+    def __init__(self,
+                 prob: float = 1.0,
+                 level: Optional[int] = None,
+                 min_mag: float = 0.1,
+                 max_mag: float = 1.9) -> None:
+        assert 0. <= min_mag <= 2.0, \
+            f'min_mag for Sharpness should be in range [0,2], got {min_mag}.'
+        assert 0. <= max_mag <= 2.0, \
+            f'max_mag for Sharpness should be in range [0,2], got {max_mag}.'
+        super().__init__(
+            prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
+
+    def _transform_img(self, results: dict, mag: float) -> None:
+        """Adjust the image sharpness."""
+        img = results['img']
+        results['img'] = mmcv.adjust_sharpness(img, mag).astype(img.dtype)
+
+
+@TRANSFORMS.register_module()
+class Solarize(ColorTransform):
+    """Solarize images (invert all pixels above a threshold value given by
+    the magnitude).
+
+    Required Keys:
+
+    - img
+
+    Modified Keys:
+
+    - img
+
+    Args:
+        prob (float): The probability for performing Solarize transformation.
+            Defaults to 1.0.
+        level (int, optional): Should be in range [0,_MAX_LEVEL].
+ If level is None, it will generate from [0, _MAX_LEVEL] randomly. + Defaults to None. + min_mag (float): The minimum magnitude for Solarize transformation. + Defaults to 0.0. + max_mag (float): The maximum magnitude for Solarize transformation. + Defaults to 256.0. + """ + + def __init__(self, + prob: float = 1.0, + level: Optional[int] = None, + min_mag: float = 0.0, + max_mag: float = 256.0) -> None: + assert 0. <= min_mag <= 256.0, f'min_mag for Solarize should be ' \ + f'in range [0, 256], got {min_mag}.' + assert 0. <= max_mag <= 256.0, f'max_mag for Solarize should be ' \ + f'in range [0, 256], got {max_mag}.' + super().__init__( + prob=prob, level=level, min_mag=min_mag, max_mag=max_mag) + + def _transform_img(self, results: dict, mag: float) -> None: + """Invert all pixel values above magnitude.""" + img = results['img'] + results['img'] = mmcv.solarize(img, mag).astype(img.dtype) + + +@TRANSFORMS.register_module() +class SolarizeAdd(ColorTransform): + """SolarizeAdd images. For each pixel in the image that is less than 128, + add an additional amount to it decided by the magnitude. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + prob (float): The probability for performing SolarizeAdd + transformation. Defaults to 1.0. + level (int, optional): Should be in range [0,_MAX_LEVEL]. + If level is None, it will generate from [0, _MAX_LEVEL] randomly. + Defaults to None. + min_mag (float): The minimum magnitude for SolarizeAdd transformation. + Defaults to 0.0. + max_mag (float): The maximum magnitude for SolarizeAdd transformation. + Defaults to 110.0. + """ + + def __init__(self, + prob: float = 1.0, + level: Optional[int] = None, + min_mag: float = 0.0, + max_mag: float = 110.0) -> None: + assert 0. <= min_mag <= 110.0, f'min_mag for SolarizeAdd should be ' \ + f'in range [0, 110], got {min_mag}.' + assert 0. <= max_mag <= 110.0, f'max_mag for SolarizeAdd should be ' \ + f'in range [0, 110], got {max_mag}.' + super().__init__( + prob=prob, level=level, min_mag=min_mag, max_mag=max_mag) + + def _transform_img(self, results: dict, mag: float) -> None: + """SolarizeAdd the image.""" + img = results['img'] + img_solarized = np.where(img < 128, np.minimum(img + mag, 255), img) + results['img'] = img_solarized.astype(img.dtype) + + +@TRANSFORMS.register_module() +class Posterize(ColorTransform): + """Posterize images (reduce the number of bits for each color channel). + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + prob (float): The probability for performing Posterize + transformation. Defaults to 1.0. + level (int, optional): Should be in range [0,_MAX_LEVEL]. + If level is None, it will generate from [0, _MAX_LEVEL] randomly. + Defaults to None. + min_mag (float): The minimum magnitude for Posterize transformation. + Defaults to 0.0. + max_mag (float): The maximum magnitude for Posterize transformation. + Defaults to 4.0. + """ + + def __init__(self, + prob: float = 1.0, + level: Optional[int] = None, + min_mag: float = 0.0, + max_mag: float = 4.0) -> None: + assert 0. <= min_mag <= 8.0, f'min_mag for Posterize should be ' \ + f'in range [0, 8], got {min_mag}.' + assert 0. <= max_mag <= 8.0, f'max_mag for Posterize should be ' \ + f'in range [0, 8], got {max_mag}.' 
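+        # note: the magnitude is interpreted as a bit count --
+        # _transform_img below passes math.ceil(mag) to mmcv.posterize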
+ super().__init__( + prob=prob, level=level, min_mag=min_mag, max_mag=max_mag) + + def _transform_img(self, results: dict, mag: float) -> None: + """Posterize the image.""" + img = results['img'] + results['img'] = mmcv.posterize(img, math.ceil(mag)).astype(img.dtype) + + +@TRANSFORMS.register_module() +class Equalize(ColorTransform): + """Equalize the image histogram. The bboxes, masks and segmentations are + not modified. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + prob (float): The probability for performing Equalize transformation. + Defaults to 1.0. + level (int, optional): No use for Equalize transformation. + Defaults to None. + min_mag (float): No use for Equalize transformation. Defaults to 0.1. + max_mag (float): No use for Equalize transformation. Defaults to 1.9. + """ + + def _transform_img(self, results: dict, mag: float) -> None: + """Equalizes the histogram of one image.""" + img = results['img'] + results['img'] = mmcv.imequalize(img).astype(img.dtype) + + +@TRANSFORMS.register_module() +class AutoContrast(ColorTransform): + """Auto adjust image contrast. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + prob (float): The probability for performing AutoContrast should + be in range [0, 1]. Defaults to 1.0. + level (int, optional): No use for AutoContrast transformation. + Defaults to None. + min_mag (float): No use for AutoContrast transformation. + Defaults to 0.1. + max_mag (float): No use for AutoContrast transformation. + Defaults to 1.9. + """ + + def _transform_img(self, results: dict, mag: float) -> None: + """Auto adjust image contrast.""" + img = results['img'] + results['img'] = mmcv.auto_contrast(img).astype(img.dtype) + + +@TRANSFORMS.register_module() +class Invert(ColorTransform): + """Invert images. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + prob (float): The probability for performing invert therefore should + be in range [0, 1]. Defaults to 1.0. + level (int, optional): No use for Invert transformation. + Defaults to None. + min_mag (float): No use for Invert transformation. Defaults to 0.1. + max_mag (float): No use for Invert transformation. Defaults to 1.9. + """ + + def _transform_img(self, results: dict, mag: float) -> None: + """Invert the image.""" + img = results['img'] + results['img'] = mmcv.iminvert(img).astype(img.dtype) diff --git a/mmdet/datasets/transforms/formatting.py b/mmdet/datasets/transforms/formatting.py new file mode 100644 index 0000000000000000000000000000000000000000..05263807c0eab470b0c73f435d327ad8cadb60b3 --- /dev/null +++ b/mmdet/datasets/transforms/formatting.py @@ -0,0 +1,512 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +import numpy as np +from mmcv.transforms import to_tensor +from mmcv.transforms.base import BaseTransform +from mmengine.structures import InstanceData, PixelData + +from mmdet.registry import TRANSFORMS +from mmdet.structures import DetDataSample, ReIDDataSample, TrackDataSample +from mmdet.structures.bbox import BaseBoxes + + +@TRANSFORMS.register_module() +class PackDetInputs(BaseTransform): + """Pack the inputs data for the detection / semantic segmentation / + panoptic segmentation. + + The ``img_meta`` item is always populated. The contents of the + ``img_meta`` dictionary depends on ``meta_keys``. 
By default this includes: + + - ``img_id``: id of the image + + - ``img_path``: path to the image file + + - ``ori_shape``: original shape of the image as a tuple (h, w) + + - ``img_shape``: shape of the image input to the network as a tuple \ + (h, w). Note that images may be zero padded on the \ + bottom/right if the batch tensor is larger than this shape. + + - ``scale_factor``: a float indicating the preprocessing scale + + - ``flip``: a boolean indicating if image flip transform was used + + - ``flip_direction``: the flipping direction + + Args: + meta_keys (Sequence[str], optional): Meta keys to be converted to + ``mmcv.DataContainer`` and collected in ``data[img_metas]``. + Default: ``('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction')`` + """ + mapping_table = { + 'gt_bboxes': 'bboxes', + 'gt_bboxes_labels': 'labels', + 'gt_masks': 'masks' + } + + def __init__(self, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction')): + self.meta_keys = meta_keys + + def transform(self, results: dict) -> dict: + """Method to pack the input data. + + Args: + results (dict): Result dict from the data pipeline. + + Returns: + dict: + + - 'inputs' (obj:`torch.Tensor`): The forward data of models. + - 'data_sample' (obj:`DetDataSample`): The annotation info of the + sample. + """ + packed_results = dict() + if 'img' in results: + img = results['img'] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + # To improve the computational speed by by 3-5 times, apply: + # If image is not contiguous, use + # `numpy.transpose()` followed by `numpy.ascontiguousarray()` + # If image is already contiguous, use + # `torch.permute()` followed by `torch.contiguous()` + # Refer to https://github.com/open-mmlab/mmdetection/pull/9533 + # for more details + if not img.flags.c_contiguous: + img = np.ascontiguousarray(img.transpose(2, 0, 1)) + img = to_tensor(img) + else: + img = to_tensor(img).permute(2, 0, 1).contiguous() + + packed_results['inputs'] = img + + if 'gt_ignore_flags' in results: + valid_idx = np.where(results['gt_ignore_flags'] == 0)[0] + ignore_idx = np.where(results['gt_ignore_flags'] == 1)[0] + + data_sample = DetDataSample() + instance_data = InstanceData() + ignore_instance_data = InstanceData() + + for key in self.mapping_table.keys(): + if key not in results: + continue + if key == 'gt_masks' or isinstance(results[key], BaseBoxes): + if 'gt_ignore_flags' in results: + instance_data[ + self.mapping_table[key]] = results[key][valid_idx] + ignore_instance_data[ + self.mapping_table[key]] = results[key][ignore_idx] + else: + instance_data[self.mapping_table[key]] = results[key] + else: + if 'gt_ignore_flags' in results: + instance_data[self.mapping_table[key]] = to_tensor( + results[key][valid_idx]) + ignore_instance_data[self.mapping_table[key]] = to_tensor( + results[key][ignore_idx]) + else: + instance_data[self.mapping_table[key]] = to_tensor( + results[key]) + data_sample.gt_instances = instance_data + data_sample.ignored_instances = ignore_instance_data + + if 'proposals' in results: + proposals = InstanceData( + bboxes=to_tensor(results['proposals']), + scores=to_tensor(results['proposals_scores'])) + data_sample.proposals = proposals + + if 'gt_seg_map' in results: + gt_sem_seg_data = dict( + sem_seg=to_tensor(results['gt_seg_map'][None, ...].copy())) + gt_sem_seg_data = PixelData(**gt_sem_seg_data) + if 'ignore_index' in results: + metainfo = dict(ignore_index=results['ignore_index']) + 
gt_sem_seg_data.set_metainfo(metainfo) + data_sample.gt_sem_seg = gt_sem_seg_data + + img_meta = {} + for key in self.meta_keys: + if key in results: + img_meta[key] = results[key] + data_sample.set_metainfo(img_meta) + packed_results['data_samples'] = data_sample + + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(meta_keys={self.meta_keys})' + return repr_str + + +@TRANSFORMS.register_module() +class ToTensor: + """Convert some results to :obj:`torch.Tensor` by given keys. + + Args: + keys (Sequence[str]): Keys that need to be converted to Tensor. + """ + + def __init__(self, keys): + self.keys = keys + + def __call__(self, results): + """Call function to convert data in results to :obj:`torch.Tensor`. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data converted + to :obj:`torch.Tensor`. + """ + for key in self.keys: + results[key] = to_tensor(results[key]) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + +@TRANSFORMS.register_module() +class ImageToTensor: + """Convert image to :obj:`torch.Tensor` by given keys. + + The dimension order of input image is (H, W, C). The pipeline will convert + it to (C, H, W). If only 2 dimensions (H, W) are given, the output will be + (1, H, W). + + Args: + keys (Sequence[str]): Key of images to be converted to Tensor. + """ + + def __init__(self, keys): + self.keys = keys + + def __call__(self, results): + """Call function to convert image in results to :obj:`torch.Tensor` and + transpose the channel order. + + Args: + results (dict): Result dict contains the image data to convert. + + Returns: + dict: The result dict contains the image converted + to :obj:`torch.Tensor` and permuted to (C, H, W) order. + """ + for key in self.keys: + img = results[key] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + results[key] = to_tensor(img).permute(2, 0, 1).contiguous() + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + +@TRANSFORMS.register_module() +class Transpose: + """Transpose some results by given keys. + + Args: + keys (Sequence[str]): Keys of results to be transposed. + order (Sequence[int]): Order of transpose. + """ + + def __init__(self, keys, order): + self.keys = keys + self.order = order + + def __call__(self, results): + """Call function to transpose the channel order of data in results. + + Args: + results (dict): Result dict contains the data to transpose. + + Returns: + dict: The result dict contains the data transposed to \ + ``self.order``. + """ + for key in self.keys: + results[key] = results[key].transpose(self.order) + return results + + def __repr__(self): + return self.__class__.__name__ + \ + f'(keys={self.keys}, order={self.order})' + + +@TRANSFORMS.register_module() +class WrapFieldsToLists: + """Wrap fields of the data dictionary into lists for evaluation. + + This class can be used as a last step of a test or validation + pipeline for single image evaluation or inference.
+ + Example: + >>> test_pipeline = [ + >>> dict(type='LoadImageFromFile'), + >>> dict(type='Normalize', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + to_rgb=True), + >>> dict(type='Pad', size_divisor=32), + >>> dict(type='ImageToTensor', keys=['img']), + >>> dict(type='Collect', keys=['img']), + >>> dict(type='WrapFieldsToLists') + >>> ] + """ + + def __call__(self, results): + """Call function to wrap fields into lists. + + Args: + results (dict): Result dict contains the data to wrap. + + Returns: + dict: The result dict where each value is wrapped \ + into a list. + """ + + # Wrap dict fields into lists + for key, val in results.items(): + results[key] = [val] + return results + + def __repr__(self): + return f'{self.__class__.__name__}()' + + +@TRANSFORMS.register_module() +class PackTrackInputs(BaseTransform): + """Pack the input data for the multi object tracking and video instance + segmentation. All the information of images is packed to ``inputs``. All + the information except images is packed to ``data_samples``. In order to + get the original annotation and meta info, we add `instances` key into meta + keys. + + Args: + meta_keys (Sequence[str]): Meta keys to be collected in + ``data_sample.metainfo``. Defaults to None. + default_meta_keys (tuple): Default meta keys. Defaults to ('img_id', + 'img_path', 'ori_shape', 'img_shape', 'scale_factor', + 'flip', 'flip_direction', 'frame_id', 'is_video_data', + 'video_id', 'video_length', 'instances'). + """ + mapping_table = { + 'gt_bboxes': 'bboxes', + 'gt_bboxes_labels': 'labels', + 'gt_masks': 'masks', + 'gt_instances_ids': 'instances_ids' + } + + def __init__(self, + meta_keys: Optional[dict] = None, + default_meta_keys: tuple = ('img_id', 'img_path', 'ori_shape', + 'img_shape', 'scale_factor', + 'flip', 'flip_direction', + 'frame_id', 'video_id', + 'video_length', + 'ori_video_length', 'instances')): + self.meta_keys = default_meta_keys + if meta_keys is not None: + if isinstance(meta_keys, str): + meta_keys = (meta_keys, ) + else: + assert isinstance(meta_keys, tuple), \ + 'meta_keys must be str or tuple' + self.meta_keys += meta_keys + + def transform(self, results: dict) -> dict: + """Method to pack the input data. + Args: + results (dict): Result dict from the data pipeline. + Returns: + dict: + - 'inputs' (dict[Tensor]): The forward data of models. + - 'data_samples' (obj:`TrackDataSample`): The annotation info of + the samples. + """ + packed_results = dict() + packed_results['inputs'] = dict() + + # 1. Pack images + if 'img' in results: + imgs = results['img'] + imgs = np.stack(imgs, axis=0) + imgs = imgs.transpose(0, 3, 1, 2) + packed_results['inputs'] = to_tensor(imgs) + + # 2.
Pack InstanceData + if 'gt_ignore_flags' in results: + gt_ignore_flags_list = results['gt_ignore_flags'] + valid_idx_list, ignore_idx_list = [], [] + for gt_ignore_flags in gt_ignore_flags_list: + valid_idx = np.where(gt_ignore_flags == 0)[0] + ignore_idx = np.where(gt_ignore_flags == 1)[0] + valid_idx_list.append(valid_idx) + ignore_idx_list.append(ignore_idx) + + assert 'img_id' in results, \ + "'img_id' must be contained in the results " \ + 'for counting the number of images' + + num_imgs = len(results['img_id']) + instance_data_list = [InstanceData() for _ in range(num_imgs)] + ignore_instance_data_list = [InstanceData() for _ in range(num_imgs)] + + for key in self.mapping_table.keys(): + if key not in results: + continue + if key == 'gt_masks': + mapped_key = self.mapping_table[key] + gt_masks_list = results[key] + if 'gt_ignore_flags' in results: + for i, gt_mask in enumerate(gt_masks_list): + valid_idx, ignore_idx = valid_idx_list[ + i], ignore_idx_list[i] + instance_data_list[i][mapped_key] = gt_mask[valid_idx] + ignore_instance_data_list[i][mapped_key] = gt_mask[ + ignore_idx] + + else: + for i, gt_mask in enumerate(gt_masks_list): + instance_data_list[i][mapped_key] = gt_mask + + else: + anns_list = results[key] + if 'gt_ignore_flags' in results: + for i, ann in enumerate(anns_list): + valid_idx, ignore_idx = valid_idx_list[ + i], ignore_idx_list[i] + instance_data_list[i][ + self.mapping_table[key]] = to_tensor( + ann[valid_idx]) + ignore_instance_data_list[i][ + self.mapping_table[key]] = to_tensor( + ann[ignore_idx]) + else: + for i, ann in enumerate(anns_list): + instance_data_list[i][ + self.mapping_table[key]] = to_tensor(ann) + + det_data_samples_list = [] + for i in range(num_imgs): + det_data_sample = DetDataSample() + det_data_sample.gt_instances = instance_data_list[i] + det_data_sample.ignored_instances = ignore_instance_data_list[i] + det_data_samples_list.append(det_data_sample) + + # 3. Pack metainfo + for key in self.meta_keys: + if key not in results: + continue + img_metas_list = results[key] + for i, img_meta in enumerate(img_metas_list): + det_data_samples_list[i].set_metainfo({f'{key}': img_meta}) + + track_data_sample = TrackDataSample() + track_data_sample.video_data_samples = det_data_samples_list + if 'key_frame_flags' in results: + key_frame_flags = np.asarray(results['key_frame_flags']) + key_frames_inds = np.where(key_frame_flags)[0].tolist() + ref_frames_inds = np.where(~key_frame_flags)[0].tolist() + track_data_sample.set_metainfo( + dict(key_frames_inds=key_frames_inds)) + track_data_sample.set_metainfo( + dict(ref_frames_inds=ref_frames_inds)) + + packed_results['data_samples'] = track_data_sample + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(meta_keys={self.meta_keys})' + return repr_str + + +@TRANSFORMS.register_module() +class PackReIDInputs(BaseTransform): + """Pack the input data for ReID. The ``meta_info`` item is always + populated. The contents of the ``meta_info`` dictionary depend on + ``meta_keys``. By default this includes: + + - ``img_path``: path to the image file. + - ``ori_shape``: original shape of the image as a tuple (H, W). + - ``img_shape``: shape of the image input to the network as a tuple + (H, W). Note that images may be zero padded on the bottom/right + if the batch tensor is larger than this shape. + - ``scale``: scale of the image as a tuple (W, H).
+ - ``scale_factor``: a float indicating the pre-processing scale. + - ``flip``: a boolean indicating if image flip transform was used. + - ``flip_direction``: the flipping direction. + Args: + meta_keys (Sequence[str], optional): The meta keys to be saved in the + ``metainfo`` of the packed ``data_sample``. + """ + default_meta_keys = ('img_path', 'ori_shape', 'img_shape', 'scale', + 'scale_factor') + + def __init__(self, meta_keys: Sequence[str] = ()) -> None: + self.meta_keys = self.default_meta_keys + if meta_keys is not None: + if isinstance(meta_keys, str): + meta_keys = (meta_keys, ) + else: + assert isinstance(meta_keys, tuple), \ + 'meta_keys must be str or tuple.' + self.meta_keys += meta_keys + + def transform(self, results: dict) -> dict: + """Method to pack the input data. + Args: + results (dict): Result dict from the data pipeline. + Returns: + dict: + - 'inputs' (dict[Tensor]): The forward data of models. + - 'data_samples' (obj:`ReIDDataSample`): The meta info of the + sample. + """ + packed_results = dict(inputs=dict(), data_samples=None) + assert 'img' in results, 'Missing the key ``img``.' + _type = type(results['img']) + label = results['gt_label'] + + if _type == list: + img = results['img'] + label = np.stack(label, axis=0) # (N,) + assert all([type(v) == _type for v in results.values()]), \ + 'All items in the results must have the same type.' + else: + img = [results['img']] + + img = np.stack(img, axis=3) # (H, W, C, N) + img = img.transpose(3, 2, 0, 1) # (N, C, H, W) + img = np.ascontiguousarray(img) + + packed_results['inputs'] = to_tensor(img) + + data_sample = ReIDDataSample() + data_sample.set_gt_label(label) + + meta_info = dict() + for key in self.meta_keys: + meta_info[key] = results[key] + data_sample.set_metainfo(meta_info) + packed_results['data_samples'] = data_sample + + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(meta_keys={self.meta_keys})' + return repr_str diff --git a/mmdet/datasets/transforms/frame_sampling.py b/mmdet/datasets/transforms/frame_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..a91f1e7880f8f061f183dc30a01758d97b7d03da --- /dev/null +++ b/mmdet/datasets/transforms/frame_sampling.py @@ -0,0 +1,177 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +from collections import defaultdict +from typing import Dict, List, Optional, Union + +from mmcv.transforms import BaseTransform + +from mmdet.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class BaseFrameSample(BaseTransform): + """Directly get the key frame, no reference frames. + + Args: + collect_video_keys (list[str]): The keys of video info to be + collected. + """ + + def __init__(self, + collect_video_keys: List[str] = ['video_id', 'video_length']): + self.collect_video_keys = collect_video_keys + + def prepare_data(self, video_infos: dict, + sampled_inds: List[int]) -> Dict[str, List]: + """Prepare data for the subsequent pipeline. + + Args: + video_infos (dict): The whole video information. + sampled_inds (list[int]): The sampled frame indices. + + Returns: + dict: The processed data information.
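Example (an illustrative sketch of the collation performed below, with
hypothetical frame dicts; not part of the original patch):

    >>> from collections import defaultdict
    >>> out = defaultdict(list)
    >>> for frame in [{'img_id': 0}, {'img_id': 1}]:
    ...     for k, v in frame.items():
    ...         out[k].append(v)
    >>> dict(out)
    {'img_id': [0, 1]}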
+ """ + frames_anns = video_infos['images'] + final_data_info = defaultdict(list) + # for data in frames_anns: + for index in sampled_inds: + data = frames_anns[index] + # copy the info in video-level into img-level + for key in self.collect_video_keys: + if key == 'video_length': + data['ori_video_length'] = video_infos[key] + data['video_length'] = len(sampled_inds) + else: + data[key] = video_infos[key] + # Collate data_list (list of dict to dict of list) + for key, value in data.items(): + final_data_info[key].append(value) + + return final_data_info + + def transform(self, video_infos: dict) -> Optional[Dict[str, List]]: + """Transform the video information. + + Args: + video_infos (dict): The whole video information. + + Returns: + dict: The data information of the key frames. + """ + if 'key_frame_id' in video_infos: + key_frame_id = video_infos['key_frame_id'] + assert isinstance(video_infos['key_frame_id'], int) + else: + key_frame_id = random.sample( + list(range(video_infos['video_length'])), 1)[0] + results = self.prepare_data(video_infos, [key_frame_id]) + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(collect_video_keys={self.collect_video_keys})' + return repr_str + + +@TRANSFORMS.register_module() +class UniformRefFrameSample(BaseFrameSample): + """Uniformly sample reference frames. + + Args: + num_ref_imgs (int): Number of reference frames to be sampled. + frame_range (int | list[int]): Range of frames to be sampled around + key frame. If int, the range is [-frame_range, frame_range]. + Defaults to 10. + filter_key_img (bool): Whether to filter the key frame when + sampling reference frames. Defaults to True. + collect_video_keys (list[str]): The keys of video info to be + collected. + """ + + def __init__(self, + num_ref_imgs: int = 1, + frame_range: Union[int, List[int]] = 10, + filter_key_img: bool = True, + collect_video_keys: List[str] = ['video_id', 'video_length']): + self.num_ref_imgs = num_ref_imgs + self.filter_key_img = filter_key_img + if isinstance(frame_range, int): + assert frame_range >= 0, 'frame_range can not be a negative value.' + frame_range = [-frame_range, frame_range] + elif isinstance(frame_range, list): + assert len(frame_range) == 2, 'The length must be 2.' + assert frame_range[0] <= 0 and frame_range[1] >= 0 + for i in frame_range: + assert isinstance(i, int), 'Each element must be int.' + else: + raise TypeError('The type of frame_range must be int or list.') + self.frame_range = frame_range + super().__init__(collect_video_keys=collect_video_keys) + + def sampling_frames(self, video_length: int, key_frame_id: int): + """Sampling frames. + + Args: + video_length (int): The length of the video. + key_frame_id (int): The key frame id. + + Returns: + list[int]: The sampled frame indices. 
+ """ + if video_length > 1: + left = max(0, key_frame_id + self.frame_range[0]) + right = min(key_frame_id + self.frame_range[1], video_length - 1) + frame_ids = list(range(0, video_length)) + + valid_ids = frame_ids[left:right + 1] + if self.filter_key_img and key_frame_id in valid_ids: + valid_ids.remove(key_frame_id) + assert len( + valid_ids + ) > 0, 'After filtering key frame, there are no valid frames' + if len(valid_ids) < self.num_ref_imgs: + valid_ids = valid_ids * self.num_ref_imgs + ref_frame_ids = random.sample(valid_ids, self.num_ref_imgs) + else: + ref_frame_ids = [key_frame_id] * self.num_ref_imgs + + sampled_frames_ids = [key_frame_id] + ref_frame_ids + sampled_frames_ids = sorted(sampled_frames_ids) + + key_frames_ind = sampled_frames_ids.index(key_frame_id) + key_frame_flags = [False] * len(sampled_frames_ids) + key_frame_flags[key_frames_ind] = True + return sampled_frames_ids, key_frame_flags + + def transform(self, video_infos: dict) -> Optional[Dict[str, List]]: + """Transform the video information. + + Args: + video_infos (dict): The whole video information. + + Returns: + dict: The data information of the sampled frames. + """ + if 'key_frame_id' in video_infos: + key_frame_id = video_infos['key_frame_id'] + assert isinstance(video_infos['key_frame_id'], int) + else: + key_frame_id = random.sample( + list(range(video_infos['video_length'])), 1)[0] + + (sampled_frames_ids, key_frame_flags) = self.sampling_frames( + video_infos['video_length'], key_frame_id=key_frame_id) + results = self.prepare_data(video_infos, sampled_frames_ids) + results['key_frame_flags'] = key_frame_flags + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(num_ref_imgs={self.num_ref_imgs}, ' + repr_str += f'frame_range={self.frame_range}, ' + repr_str += f'filter_key_img={self.filter_key_img}, ' + repr_str += f'collect_video_keys={self.collect_video_keys})' + return repr_str diff --git a/mmdet/datasets/transforms/geometric.py b/mmdet/datasets/transforms/geometric.py new file mode 100644 index 0000000000000000000000000000000000000000..d2cd6be258f73a69aa2c2b36fef64c6c4e46a2a4 --- /dev/null +++ b/mmdet/datasets/transforms/geometric.py @@ -0,0 +1,754 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Optional, Union + +import cv2 +import mmcv +import numpy as np +from mmcv.transforms import BaseTransform +from mmcv.transforms.utils import cache_randomness + +from mmdet.registry import TRANSFORMS +from mmdet.structures.bbox import autocast_box_type +from .augment_wrappers import _MAX_LEVEL, level_to_mag + + +@TRANSFORMS.register_module() +class GeomTransform(BaseTransform): + """Base class for geometric transformations. All geometric transformations + need to inherit from this base class. ``GeomTransform`` unifies the class + attributes and class functions of geometric transformations (ShearX, + ShearY, Rotate, TranslateX, and TranslateY), and records the homography + matrix. + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_masks (BitmapMasks | PolygonMasks) (optional) + - gt_seg_map (np.uint8) (optional) + + Modified Keys: + + - img + - gt_bboxes + - gt_masks + - gt_seg_map + + Added Keys: + + - homography_matrix + + Args: + prob (float): The probability for performing the geometric + transformation and should be in range [0, 1]. Defaults to 1.0. + level (int, optional): The level should be in range [0, _MAX_LEVEL]. + If level is None, it will generate from [0, _MAX_LEVEL] randomly. 
+ Defaults to None. + min_mag (float): The minimum magnitude for geometric transformation. + Defaults to 0.0. + max_mag (float): The maximum magnitude for geometric transformation. + Defaults to 1.0. + reversal_prob (float): The probability that reverses the geometric + transformation magnitude. Should be in range [0,1]. + Defaults to 0.5. + img_border_value (int | float | tuple): The filled values for + image border. If float, the same fill value will be used for + all the three channels of image. If tuple, it should be 3 elements. + Defaults to 128. + mask_border_value (int): The fill value used for masks. Defaults to 0. + seg_ignore_label (int): The fill value used for segmentation map. + Note this value must equal ``ignore_label`` in ``semantic_head`` + of the corresponding config. Defaults to 255. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. Defaults + to 'bilinear'. + """ + + def __init__(self, + prob: float = 1.0, + level: Optional[int] = None, + min_mag: float = 0.0, + max_mag: float = 1.0, + reversal_prob: float = 0.5, + img_border_value: Union[int, float, tuple] = 128, + mask_border_value: int = 0, + seg_ignore_label: int = 255, + interpolation: str = 'bilinear') -> None: + assert 0 <= prob <= 1.0, f'The probability of the transformation ' \ + f'should be in range [0,1], got {prob}.' + assert level is None or isinstance(level, int), \ + f'The level should be None or type int, got {type(level)}.' + assert level is None or 0 <= level <= _MAX_LEVEL, \ + f'The level should be in range [0,{_MAX_LEVEL}], got {level}.' + assert isinstance(min_mag, float), \ + f'min_mag should be type float, got {type(min_mag)}.' + assert isinstance(max_mag, float), \ + f'max_mag should be type float, got {type(max_mag)}.' + assert min_mag <= max_mag, \ + f'min_mag should be smaller than max_mag, ' \ + f'got min_mag={min_mag} and max_mag={max_mag}' + assert isinstance(reversal_prob, float), \ + f'reversal_prob should be type float, got {type(reversal_prob)}.' + assert 0 <= reversal_prob <= 1.0, \ + f'The reversal probability of the transformation magnitude ' \ + f'should be in range [0,1], got {reversal_prob}.' + if isinstance(img_border_value, (float, int)): + img_border_value = tuple([float(img_border_value)] * 3) + elif isinstance(img_border_value, tuple): + assert len(img_border_value) == 3, \ + f'img_border_value as tuple must have 3 elements, ' \ + f'got {len(img_border_value)}.' + img_border_value = tuple([float(val) for val in img_border_value]) + else: + raise ValueError( + 'img_border_value must be float or tuple with 3 elements.') + assert np.all([0 <= val <= 255 for val in img_border_value]), 'all ' \ + 'elements of img_border_value should be in range [0,255], ' \ + f'got {img_border_value}.'
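# Illustrative sketch (not part of the patch) of how ``_get_mag`` below
# draws a signed magnitude, assuming mmdet's usual linear ``level_to_mag``
# with ``_MAX_LEVEL`` = 10: the level is mapped into [min_mag, max_mag]
# and the sign is flipped with probability 1 - reversal_prob.
import numpy as np

def demo_sample_mag(level, min_mag, max_mag, reversal_prob, max_level=10):
    frac = np.random.rand() if level is None else level / max_level
    mag = min_mag + (max_mag - min_mag) * frac
    return -mag if np.random.rand() > reversal_prob else mag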
+ self.prob = prob + self.level = level + self.min_mag = min_mag + self.max_mag = max_mag + self.reversal_prob = reversal_prob + self.img_border_value = img_border_value + self.mask_border_value = mask_border_value + self.seg_ignore_label = seg_ignore_label + self.interpolation = interpolation + + def _transform_img(self, results: dict, mag: float) -> None: + """Transform the image.""" + pass + + def _transform_masks(self, results: dict, mag: float) -> None: + """Transform the masks.""" + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + """Transform the segmentation map.""" + pass + + def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray: + """Get the homography matrix for the geometric transformation.""" + return np.eye(3, dtype=np.float32) + + def _transform_bboxes(self, results: dict, mag: float) -> None: + """Transform the bboxes.""" + results['gt_bboxes'].project_(self.homography_matrix) + results['gt_bboxes'].clip_(results['img_shape']) + + def _record_homography_matrix(self, results: dict) -> None: + """Record the homography matrix for the geometric transformation.""" + if results.get('homography_matrix', None) is None: + results['homography_matrix'] = self.homography_matrix + else: + results['homography_matrix'] = self.homography_matrix @ results[ + 'homography_matrix'] + + @cache_randomness + def _random_disable(self): + """Randomly disable the transform.""" + return np.random.rand() > self.prob + + @cache_randomness + def _get_mag(self): + """Get the magnitude of the transform.""" + mag = level_to_mag(self.level, self.min_mag, self.max_mag) + return -mag if np.random.rand() > self.reversal_prob else mag + + @autocast_box_type() + def transform(self, results: dict) -> dict: + """Transform function for images, bounding boxes, masks and semantic + segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Transformed results. + """ + + if self._random_disable(): + return results + mag = self._get_mag() + self.homography_matrix = self._get_homography_matrix(results, mag) + self._record_homography_matrix(results) + self._transform_img(results, mag) + if results.get('gt_bboxes', None) is not None: + self._transform_bboxes(results, mag) + if results.get('gt_masks', None) is not None: + self._transform_masks(results, mag) + if results.get('gt_seg_map', None) is not None: + self._transform_seg(results, mag) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'level={self.level}, ' + repr_str += f'min_mag={self.min_mag}, ' + repr_str += f'max_mag={self.max_mag}, ' + repr_str += f'reversal_prob={self.reversal_prob}, ' + repr_str += f'img_border_value={self.img_border_value}, ' + repr_str += f'mask_border_value={self.mask_border_value}, ' + repr_str += f'seg_ignore_label={self.seg_ignore_label}, ' + repr_str += f'interpolation={self.interpolation})' + return repr_str + + +@TRANSFORMS.register_module() +class ShearX(GeomTransform): + """Shear the images, bboxes, masks and segmentation map horizontally. + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_masks (BitmapMasks | PolygonMasks) (optional) + - gt_seg_map (np.uint8) (optional) + + Modified Keys: + + - img + - gt_bboxes + - gt_masks + - gt_seg_map + + Added Keys: + + - homography_matrix + + Args: + prob (float): The probability for performing Shear and should be in + range [0, 1]. Defaults to 1.0. 
+ level (int, optional): The level should be in range [0, _MAX_LEVEL]. + If level is None, it will generate from [0, _MAX_LEVEL] randomly. + Defaults to None. + min_mag (float): The minimum angle for the horizontal shear. + Defaults to 0.0. + max_mag (float): The maximum angle for the horizontal shear. + Defaults to 30.0. + reversal_prob (float): The probability that reverses the horizontal + shear magnitude. Should be in range [0,1]. Defaults to 0.5. + img_border_value (int | float | tuple): The filled values for + image border. If float, the same fill value will be used for + all the three channels of image. If tuple, it should be 3 elements. + Defaults to 128. + mask_border_value (int): The fill value used for masks. Defaults to 0. + seg_ignore_label (int): The fill value used for segmentation map. + Note this value must equal ``ignore_label`` in ``semantic_head`` + of the corresponding config. Defaults to 255. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. Defaults + to 'bilinear'. + """ + + def __init__(self, + prob: float = 1.0, + level: Optional[int] = None, + min_mag: float = 0.0, + max_mag: float = 30.0, + reversal_prob: float = 0.5, + img_border_value: Union[int, float, tuple] = 128, + mask_border_value: int = 0, + seg_ignore_label: int = 255, + interpolation: str = 'bilinear') -> None: + assert 0. <= min_mag <= 90., \ + f'min_mag angle for ShearX should be ' \ + f'in range [0, 90], got {min_mag}.' + assert 0. <= max_mag <= 90., \ + f'max_mag angle for ShearX should be ' \ + f'in range [0, 90], got {max_mag}.' + super().__init__( + prob=prob, + level=level, + min_mag=min_mag, + max_mag=max_mag, + reversal_prob=reversal_prob, + img_border_value=img_border_value, + mask_border_value=mask_border_value, + seg_ignore_label=seg_ignore_label, + interpolation=interpolation) + + @cache_randomness + def _get_mag(self): + """Get the magnitude of the transform.""" + mag = level_to_mag(self.level, self.min_mag, self.max_mag) + mag = np.tan(mag * np.pi / 180) + return -mag if np.random.rand() > self.reversal_prob else mag + + def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray: + """Get the homography matrix for ShearX.""" + return np.array([[1, mag, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32) + + def _transform_img(self, results: dict, mag: float) -> None: + """Shear the image horizontally.""" + results['img'] = mmcv.imshear( + results['img'], + mag, + direction='horizontal', + border_value=self.img_border_value, + interpolation=self.interpolation) + + def _transform_masks(self, results: dict, mag: float) -> None: + """Shear the masks horizontally.""" + results['gt_masks'] = results['gt_masks'].shear( + results['img_shape'], + mag, + direction='horizontal', + border_value=self.mask_border_value, + interpolation=self.interpolation) + + def _transform_seg(self, results: dict, mag: float) -> None: + """Shear the segmentation map horizontally.""" + results['gt_seg_map'] = mmcv.imshear( + results['gt_seg_map'], + mag, + direction='horizontal', + border_value=self.seg_ignore_label, + interpolation='nearest') + + +@TRANSFORMS.register_module() +class ShearY(GeomTransform): + """Shear the images, bboxes, masks and segmentation map vertically.
+ + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_masks (BitmapMasks | PolygonMasks) (optional) + - gt_seg_map (np.uint8) (optional) + + Modified Keys: + + - img + - gt_bboxes + - gt_masks + - gt_seg_map + + Added Keys: + + - homography_matrix + + Args: + prob (float): The probability for performing ShearY and should be in + range [0, 1]. Defaults to 1.0. + level (int, optional): The level should be in range [0, _MAX_LEVEL]. + If level is None, it will generate from [0, _MAX_LEVEL] randomly. + Defaults to None. + min_mag (float): The minimum angle for the vertical shear. + Defaults to 0.0. + max_mag (float): The maximum angle for the vertical shear. + Defaults to 30.0. + reversal_prob (float): The probability that reverses the vertical + shear magnitude. Should be in range [0,1]. Defaults to 0.5. + img_border_value (int | float | tuple): The filled values for + image border. If float, the same fill value will be used for + all the three channels of image. If tuple, it should be 3 elements. + Defaults to 128. + mask_border_value (int): The fill value used for masks. Defaults to 0. + seg_ignore_label (int): The fill value used for segmentation map. + Note this value must equal ``ignore_label`` in ``semantic_head`` + of the corresponding config. Defaults to 255. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. Defaults + to 'bilinear'. + """ + + def __init__(self, + prob: float = 1.0, + level: Optional[int] = None, + min_mag: float = 0.0, + max_mag: float = 30., + reversal_prob: float = 0.5, + img_border_value: Union[int, float, tuple] = 128, + mask_border_value: int = 0, + seg_ignore_label: int = 255, + interpolation: str = 'bilinear') -> None: + assert 0. <= min_mag <= 90., \ + f'min_mag angle for ShearY should be ' \ + f'in range [0, 90], got {min_mag}.' + assert 0. <= max_mag <= 90., \ + f'max_mag angle for ShearY should be ' \ + f'in range [0, 90], got {max_mag}.'
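# Illustrative sketch (not part of the patch): ``_get_mag`` below converts
# the sampled shear angle to a shear factor via tan(angle), and the
# vertical-shear homography maps (x, y) to (x, y + mag * x).
import numpy as np

demo_mag = np.tan(30. * np.pi / 180)  # a 30 degree angle -> factor ~0.577
demo_h = np.array([[1, 0, 0], [demo_mag, 1, 0], [0, 0, 1]], dtype=np.float32)
assert np.allclose(demo_h @ np.array([10., 0., 1.]),
                   [10., demo_mag * 10., 1.])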
+ super().__init__( + prob=prob, + level=level, + min_mag=min_mag, + max_mag=max_mag, + reversal_prob=reversal_prob, + img_border_value=img_border_value, + mask_border_value=mask_border_value, + seg_ignore_label=seg_ignore_label, + interpolation=interpolation) + + @cache_randomness + def _get_mag(self): + """Get the magnitude of the transform.""" + mag = level_to_mag(self.level, self.min_mag, self.max_mag) + mag = np.tan(mag * np.pi / 180) + return -mag if np.random.rand() > self.reversal_prob else mag + + def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray: + """Get the homography matrix for ShearY.""" + return np.array([[1, 0, 0], [mag, 1, 0], [0, 0, 1]], dtype=np.float32) + + def _transform_img(self, results: dict, mag: float) -> None: + """Shear the image vertically.""" + results['img'] = mmcv.imshear( + results['img'], + mag, + direction='vertical', + border_value=self.img_border_value, + interpolation=self.interpolation) + + def _transform_masks(self, results: dict, mag: float) -> None: + """Shear the masks vertically.""" + results['gt_masks'] = results['gt_masks'].shear( + results['img_shape'], + mag, + direction='vertical', + border_value=self.mask_border_value, + interpolation=self.interpolation) + + def _transform_seg(self, results: dict, mag: float) -> None: + """Shear the segmentation map vertically.""" + results['gt_seg_map'] = mmcv.imshear( + results['gt_seg_map'], + mag, + direction='vertical', + border_value=self.seg_ignore_label, + interpolation='nearest') + + +@TRANSFORMS.register_module() +class Rotate(GeomTransform): + """Rotate the images, bboxes, masks and segmentation map. + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_masks (BitmapMasks | PolygonMasks) (optional) + - gt_seg_map (np.uint8) (optional) + + Modified Keys: + + - img + - gt_bboxes + - gt_masks + - gt_seg_map + + Added Keys: + + - homography_matrix + + Args: + prob (float): The probability for performing the transformation and + should be in range [0, 1]. Defaults to 1.0. + level (int, optional): The level should be in range [0, _MAX_LEVEL]. + If level is None, it will generate from [0, _MAX_LEVEL] randomly. + Defaults to None. + min_mag (float): The minimum angle for rotation. + Defaults to 0.0. + max_mag (float): The maximum angle for rotation. + Defaults to 30.0. + reversal_prob (float): The probability that reverses the rotation + magnitude. Should be in range [0,1]. Defaults to 0.5. + img_border_value (int | float | tuple): The filled values for + image border. If float, the same fill value will be used for + all the three channels of image. If tuple, it should be 3 elements. + Defaults to 128. + mask_border_value (int): The fill value used for masks. Defaults to 0. + seg_ignore_label (int): The fill value used for segmentation map. + Note this value must equal ``ignore_label`` in ``semantic_head`` + of the corresponding config. Defaults to 255. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. Defaults + to 'bilinear'. + """ + + def __init__(self, + prob: float = 1.0, + level: Optional[int] = None, + min_mag: float = 0.0, + max_mag: float = 30.0, + reversal_prob: float = 0.5, + img_border_value: Union[int, float, tuple] = 128, + mask_border_value: int = 0, + seg_ignore_label: int = 255, + interpolation: str = 'bilinear') -> None: + assert 0.
<= min_mag <= 180., \ + f'min_mag for Rotate should be in range [0,180], got {min_mag}.' + assert 0. <= max_mag <= 180., \ + f'max_mag for Rotate should be in range [0,180], got {max_mag}.' + super().__init__( + prob=prob, + level=level, + min_mag=min_mag, + max_mag=max_mag, + reversal_prob=reversal_prob, + img_border_value=img_border_value, + mask_border_value=mask_border_value, + seg_ignore_label=seg_ignore_label, + interpolation=interpolation) + + def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray: + """Get the homography matrix for Rotate.""" + img_shape = results['img_shape'] + center = ((img_shape[1] - 1) * 0.5, (img_shape[0] - 1) * 0.5) + cv2_rotation_matrix = cv2.getRotationMatrix2D(center, -mag, 1.0) + return np.concatenate( + [cv2_rotation_matrix, + np.array([0, 0, 1]).reshape((1, 3))]).astype(np.float32) + + def _transform_img(self, results: dict, mag: float) -> None: + """Rotate the image.""" + results['img'] = mmcv.imrotate( + results['img'], + mag, + border_value=self.img_border_value, + interpolation=self.interpolation) + + def _transform_masks(self, results: dict, mag: float) -> None: + """Rotate the masks.""" + results['gt_masks'] = results['gt_masks'].rotate( + results['img_shape'], + mag, + border_value=self.mask_border_value, + interpolation=self.interpolation) + + def _transform_seg(self, results: dict, mag: float) -> None: + """Rotate the segmentation map.""" + results['gt_seg_map'] = mmcv.imrotate( + results['gt_seg_map'], + mag, + border_value=self.seg_ignore_label, + interpolation='nearest') + + +@TRANSFORMS.register_module() +class TranslateX(GeomTransform): + """Translate the images, bboxes, masks and segmentation map horizontally. + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_masks (BitmapMasks | PolygonMasks) (optional) + - gt_seg_map (np.uint8) (optional) + + Modified Keys: + + - img + - gt_bboxes + - gt_masks + - gt_seg_map + + Added Keys: + + - homography_matrix + + Args: + prob (float): The probability for performing the transformation and + should be in range [0, 1]. Defaults to 1.0. + level (int, optional): The level should be in range [0, _MAX_LEVEL]. + If level is None, it will generate from [0, _MAX_LEVEL] randomly. + Defaults to None. + min_mag (float): The minimum pixel's offset ratio for horizontal + translation. Defaults to 0.0. + max_mag (float): The maximum pixel's offset ratio for horizontal + translation. Defaults to 0.1. + reversal_prob (float): The probability that reverses the horizontal + translation magnitude. Should be in range [0,1]. Defaults to 0.5. + img_border_value (int | float | tuple): The filled values for + image border. If float, the same fill value will be used for + all the three channels of image. If tuple, it should be 3 elements. + Defaults to 128. + mask_border_value (int): The fill value used for masks. Defaults to 0. + seg_ignore_label (int): The fill value used for segmentation map. + Note this value must equal ``ignore_label`` in ``semantic_head`` + of the corresponding config. Defaults to 255. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. Defaults + to 'bilinear'.
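Example (an illustrative sketch of the ratio-to-pixel conversion used by
the methods below; values are hypothetical, not part of the original
patch):

    >>> import numpy as np
    >>> mag = int(100 * 0.1)  # offset ratio 0.1 on a 100-pixel-wide image
    >>> np.array([[1, 0, mag], [0, 1, 0], [0, 0, 1]]) @ [5, 7, 1]
    array([15,  7,  1])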
+ """ + + def __init__(self, + prob: float = 1.0, + level: Optional[int] = None, + min_mag: float = 0.0, + max_mag: float = 0.1, + reversal_prob: float = 0.5, + img_border_value: Union[int, float, tuple] = 128, + mask_border_value: int = 0, + seg_ignore_label: int = 255, + interpolation: str = 'bilinear') -> None: + assert 0. <= min_mag <= 1., \ + f'min_mag ratio for TranslateX should be ' \ + f'in range [0, 1], got {min_mag}.' + assert 0. <= max_mag <= 1., \ + f'max_mag ratio for TranslateX should be ' \ + f'in range [0, 1], got {max_mag}.' + super().__init__( + prob=prob, + level=level, + min_mag=min_mag, + max_mag=max_mag, + reversal_prob=reversal_prob, + img_border_value=img_border_value, + mask_border_value=mask_border_value, + seg_ignore_label=seg_ignore_label, + interpolation=interpolation) + + def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray: + """Get the homography matrix for TranslateX.""" + mag = int(results['img_shape'][1] * mag) + return np.array([[1, 0, mag], [0, 1, 0], [0, 0, 1]], dtype=np.float32) + + def _transform_img(self, results: dict, mag: float) -> None: + """Translate the image horizontally.""" + mag = int(results['img_shape'][1] * mag) + results['img'] = mmcv.imtranslate( + results['img'], + mag, + direction='horizontal', + border_value=self.img_border_value, + interpolation=self.interpolation) + + def _transform_masks(self, results: dict, mag: float) -> None: + """Translate the masks horizontally.""" + mag = int(results['img_shape'][1] * mag) + results['gt_masks'] = results['gt_masks'].translate( + results['img_shape'], + mag, + direction='horizontal', + border_value=self.mask_border_value, + interpolation=self.interpolation) + + def _transform_seg(self, results: dict, mag: float) -> None: + """Translate the segmentation map horizontally.""" + mag = int(results['img_shape'][1] * mag) + results['gt_seg_map'] = mmcv.imtranslate( + results['gt_seg_map'], + mag, + direction='horizontal', + border_value=self.seg_ignore_label, + interpolation='nearest') + + +@TRANSFORMS.register_module() +class TranslateY(GeomTransform): + """Translate the images, bboxes, masks and segmentation map vertically. + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_masks (BitmapMasks | PolygonMasks) (optional) + - gt_seg_map (np.uint8) (optional) + + Modified Keys: + + - img + - gt_bboxes + - gt_masks + - gt_seg_map + + Added Keys: + + - homography_matrix + + Args: + prob (float): The probability for perform transformation and + should be in range 0 to 1. Defaults to 1.0. + level (int, optional): The level should be in range [0, _MAX_LEVEL]. + If level is None, it will generate from [0, _MAX_LEVEL] randomly. + Defaults to None. + min_mag (float): The minimum pixel's offset ratio for vertical + translation. Defaults to 0.0. + max_mag (float): The maximum pixel's offset ratio for vertical + translation. Defaults to 0.1. + reversal_prob (float): The probability that reverses the vertical + translation magnitude. Should be in range [0,1]. Defaults to 0.5. + img_border_value (int | float | tuple): The filled values for + image border. If float, the same fill value will be used for + all the three channels of image. If tuple, it should be 3 elements. + Defaults to 128. + mask_border_value (int): The fill value used for masks. Defaults to 0. + seg_ignore_label (int): The fill value used for segmentation map. + Note this value must equals ``ignore_label`` in ``semantic_head`` + of the corresponding config. Defaults to 255. 
+ interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. Defaults + to 'bilinear'. + """ + + def __init__(self, + prob: float = 1.0, + level: Optional[int] = None, + min_mag: float = 0.0, + max_mag: float = 0.1, + reversal_prob: float = 0.5, + img_border_value: Union[int, float, tuple] = 128, + mask_border_value: int = 0, + seg_ignore_label: int = 255, + interpolation: str = 'bilinear') -> None: + assert 0. <= min_mag <= 1., \ + f'min_mag ratio for TranslateY should be ' \ + f'in range [0,1], got {min_mag}.' + assert 0. <= max_mag <= 1., \ + f'max_mag ratio for TranslateY should be ' \ + f'in range [0,1], got {max_mag}.' + super().__init__( + prob=prob, + level=level, + min_mag=min_mag, + max_mag=max_mag, + reversal_prob=reversal_prob, + img_border_value=img_border_value, + mask_border_value=mask_border_value, + seg_ignore_label=seg_ignore_label, + interpolation=interpolation) + + def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray: + """Get the homography matrix for TranslateY.""" + mag = int(results['img_shape'][0] * mag) + return np.array([[1, 0, 0], [0, 1, mag], [0, 0, 1]], dtype=np.float32) + + def _transform_img(self, results: dict, mag: float) -> None: + """Translate the image vertically.""" + mag = int(results['img_shape'][0] * mag) + results['img'] = mmcv.imtranslate( + results['img'], + mag, + direction='vertical', + border_value=self.img_border_value, + interpolation=self.interpolation) + + def _transform_masks(self, results: dict, mag: float) -> None: + """Translate masks vertically.""" + mag = int(results['img_shape'][0] * mag) + results['gt_masks'] = results['gt_masks'].translate( + results['img_shape'], + mag, + direction='vertical', + border_value=self.mask_border_value, + interpolation=self.interpolation) + + def _transform_seg(self, results: dict, mag: float) -> None: + """Translate segmentation map vertically.""" + mag = int(results['img_shape'][0] * mag) + results['gt_seg_map'] = mmcv.imtranslate( + results['gt_seg_map'], + mag, + direction='vertical', + border_value=self.seg_ignore_label, + interpolation='nearest') diff --git a/mmdet/datasets/transforms/instaboost.py b/mmdet/datasets/transforms/instaboost.py new file mode 100644 index 0000000000000000000000000000000000000000..30dc1603643ec8d398bfade95f5ec1c9b8f89c8d --- /dev/null +++ b/mmdet/datasets/transforms/instaboost.py @@ -0,0 +1,150 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import numpy as np +from mmcv.transforms import BaseTransform + +from mmdet.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class InstaBoost(BaseTransform): + r"""Data augmentation method in `InstaBoost: Boosting Instance + Segmentation Via Probability Map Guided Copy-Pasting + `_. + + Refer to https://github.com/GothicAi/Instaboost for implementation details. + + + Required Keys: + + - img (np.uint8) + - instances + + Modified Keys: + + - img (np.uint8) + - instances + + Args: + action_candidate (tuple): Action candidates. "normal", "horizontal", \ + "vertical", "skip" are supported. Defaults to ('normal', \ + 'horizontal', 'skip'). + action_prob (tuple): Corresponding action probabilities. Should be \ + the same length as action_candidate. Defaults to (1, 0, 0). + scale (tuple): (min scale, max scale). Defaults to (0.8, 1.2). + dx (int): The maximum x-axis shift will be (instance width) / dx. + Defaults to 15. 
+ dy (int): The maximum y-axis shift will be (instance height) / dy. + Defaults to 15. + theta (tuple): (min rotation degree, max rotation degree). \ + Defaults to (-1, 1). + color_prob (float): Probability of applying color augmentation. + Defaults to 0.5. + hflag (bool): Whether to use heatmap guidance. Defaults to False. + aug_ratio (float): Probability of applying this transformation. \ + Defaults to 0.5. + """ + + def __init__(self, + action_candidate: tuple = ('normal', 'horizontal', 'skip'), + action_prob: tuple = (1, 0, 0), + scale: tuple = (0.8, 1.2), + dx: int = 15, + dy: int = 15, + theta: tuple = (-1, 1), + color_prob: float = 0.5, + hflag: bool = False, + aug_ratio: float = 0.5) -> None: + + import matplotlib + import matplotlib.pyplot as plt + default_backend = plt.get_backend() + + try: + import instaboostfast as instaboost + except ImportError: + raise ImportError( + 'Please run "pip install instaboostfast" ' + 'to install instaboostfast first for instaboost augmentation.') + + # instaboost will modify the default backend + # and cause visualization to fail. + matplotlib.use(default_backend) + + self.cfg = instaboost.InstaBoostConfig(action_candidate, action_prob, + scale, dx, dy, theta, + color_prob, hflag) + self.aug_ratio = aug_ratio + + def _load_anns(self, results: dict) -> Tuple[list, list]: + """Convert raw anns to instaboost expected input format.""" + anns = [] + ignore_anns = [] + for instance in results['instances']: + label = instance['bbox_label'] + bbox = instance['bbox'] + mask = instance['mask'] + x1, y1, x2, y2 = bbox + # assert (x2 - x1) >= 1 and (y2 - y1) >= 1 + bbox = [x1, y1, x2 - x1, y2 - y1] + + if instance['ignore_flag'] == 0: + anns.append({ + 'category_id': label, + 'segmentation': mask, + 'bbox': bbox + }) + else: + # Ignored instances are kept aside and not augmented + ignore_anns.append(instance) + return anns, ignore_anns + + def _parse_anns(self, results: dict, anns: list, ignore_anns: list, + img: np.ndarray) -> dict: + """Restore the result of instaboost processing to the original anns + format.""" + instances = [] + for ann in anns: + x1, y1, w, h = ann['bbox'] + # TODO: a more fundamental bug in instaboost needs to be fixed + if w <= 0 or h <= 0: + continue + bbox = [x1, y1, x1 + w, y1 + h] + instances.append( + dict( + bbox=bbox, + bbox_label=ann['category_id'], + mask=ann['segmentation'], + ignore_flag=0)) + + instances.extend(ignore_anns) + results['img'] = img + results['instances'] = instances + return results + + def transform(self, results) -> dict: + """The transform function.""" + img = results['img'] + ori_type = img.dtype + if 'instances' not in results or len(results['instances']) == 0: + return results + + anns, ignore_anns = self._load_anns(results) + if np.random.choice([0, 1], p=[1 - self.aug_ratio, self.aug_ratio]): + try: + import instaboostfast as instaboost + except ImportError: + raise ImportError('Please run "pip install instaboostfast" ' + 'to install instaboostfast first.') + anns, img = instaboost.get_new_data( + anns, img.astype(np.uint8), self.cfg, background=None) + + results = self._parse_anns(results, anns, ignore_anns, + img.astype(ori_type)) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(aug_ratio={self.aug_ratio})' + return repr_str diff --git a/mmdet/datasets/transforms/loading.py b/mmdet/datasets/transforms/loading.py new file mode 100644 index 0000000000000000000000000000000000000000..722d4b0e7c830dfde2412746db1258b880167a2f --- /dev/null +++
b/mmdet/datasets/transforms/loading.py @@ -0,0 +1,1074 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple, Union + +import mmcv +import numpy as np +import pycocotools.mask as maskUtils +import torch +from mmcv.transforms import BaseTransform +from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations +from mmcv.transforms import LoadImageFromFile +from mmengine.fileio import get +from mmengine.structures import BaseDataElement + +from mmdet.registry import TRANSFORMS +from mmdet.structures.bbox import get_box_type +from mmdet.structures.bbox.box_type import autocast_box_type +from mmdet.structures.mask import BitmapMasks, PolygonMasks + + +@TRANSFORMS.register_module() +class LoadImageFromNDArray(LoadImageFromFile): + """Load an image from ``results['img']``. + + Similar to :obj:`LoadImageFromFile`, but the image has been loaded as + :obj:`np.ndarray` in ``results['img']``. Can be used when loading images + from a webcam. + + Required Keys: + + - img + + Modified Keys: + + - img + - img_path + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + """ + + def transform(self, results: dict) -> dict: + """Transform function to add image meta information. + + Args: + results (dict): Result dict with Webcam read image in + ``results['img']``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + img = results['img'] + if self.to_float32: + img = img.astype(np.float32) + + results['img_path'] = None + results['img'] = img + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + return results + + +@TRANSFORMS.register_module() +class LoadMultiChannelImageFromFiles(BaseTransform): + """Load multi-channel images from a list of separate channel files. + + Required Keys: + + - img_path + + Modified Keys: + + - img + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + color_type (str): The flag argument for :func:``mmcv.imfrombytes``. + Defaults to 'unchanged'. + imdecode_backend (str): The image decoding backend type. The backend + argument for :func:``mmcv.imfrombytes``. + See :func:``mmcv.imfrombytes`` for details. + Defaults to 'cv2'. + file_client_args (dict): Arguments to instantiate the + corresponding backend in mmdet <= 3.0.0rc6. Defaults to None. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend in mmdet >= 3.0.0rc7. Defaults to None. + """ + + def __init__( + self, + to_float32: bool = False, + color_type: str = 'unchanged', + imdecode_backend: str = 'cv2', + file_client_args: dict = None, + backend_args: dict = None, + ) -> None: + self.to_float32 = to_float32 + self.color_type = color_type + self.imdecode_backend = imdecode_backend + self.backend_args = backend_args + if file_client_args is not None: + raise RuntimeError( + 'The `file_client_args` is deprecated, ' + 'please use `backend_args` instead, please refer to ' + 'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501 + ) + + def transform(self, results: dict) -> dict: + """Transform function to load multiple images and get image meta + information. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+ + Returns: + dict: The dict contains loaded images and meta information. + """ + + assert isinstance(results['img_path'], list) + img = [] + for name in results['img_path']: + img_bytes = get(name, backend_args=self.backend_args) + img.append( + mmcv.imfrombytes( + img_bytes, + flag=self.color_type, + backend=self.imdecode_backend)) + img = np.stack(img, axis=-1) + if self.to_float32: + img = img.astype(np.float32) + + results['img'] = img + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'to_float32={self.to_float32}, ' + f"color_type='{self.color_type}', " + f"imdecode_backend='{self.imdecode_backend}', " + f'backend_args={self.backend_args})') + return repr_str + + +@TRANSFORMS.register_module() +class LoadAnnotations(MMCV_LoadAnnotations): + """Load and process the ``instances`` and ``seg_map`` annotation provided + by dataset. + + The annotation format is as follows: + + .. code-block:: python + + { + 'instances': + [ + { + # List of 4 numbers representing the bounding box of the + # instance, in (x1, y1, x2, y2) order. + 'bbox': [x1, y1, x2, y2], + + # Label of image classification. + 'bbox_label': 1, + + # Used in instance/panoptic segmentation. The segmentation mask + # of the instance or the information of segments. + # 1. If list[list[float]], it represents a list of polygons, + # one for each connected component of the object. Each + # list[float] is one simple polygon in the format of + # [x1, y1, ..., xn, yn] (n >= 3). The Xs and Ys are absolute + # coordinates in unit of pixels. + # 2. If dict, it represents the per-pixel segmentation mask in + # COCO's compressed RLE format. The dict should have keys + # “size” and “counts”. Can be loaded by pycocotools + 'mask': list[list[float]] or dict, + + } + ] + # Filename of semantic or panoptic segmentation ground truth file. + 'seg_map_path': 'a/b/c' + } + + After this module, the annotation has been changed to the format below: + + .. code-block:: python + + { + # In (x1, y1, x2, y2) order, float type. N is the number of bboxes + # in an image + 'gt_bboxes': BaseBoxes(N, 4) + # In int type. + 'gt_bboxes_labels': np.ndarray(N, ) + # In built-in class + 'gt_masks': PolygonMasks (H, W) or BitmapMasks (H, W) + # In uint8 type. + 'gt_seg_map': np.ndarray (H, W) + # in (x, y, v) order, float type. + } + + Required Keys: + + - height + - width + - instances + + - bbox (optional) + - bbox_label + - mask (optional) + - ignore_flag + + - seg_map_path (optional) + + Added Keys: + + - gt_bboxes (BaseBoxes[torch.float32]) + - gt_bboxes_labels (np.int64) + - gt_masks (BitmapMasks | PolygonMasks) + - gt_seg_map (np.uint8) + - gt_ignore_flags (bool) + + Args: + with_bbox (bool): Whether to parse and load the bbox annotation. + Defaults to True. + with_label (bool): Whether to parse and load the label annotation. + Defaults to True. + with_mask (bool): Whether to parse and load the mask annotation. + Defaults to False. + with_seg (bool): Whether to parse and load the semantic segmentation + annotation. Defaults to False. + poly2mask (bool): Whether to convert mask to bitmap. Defaults to True. + box_type (str): The box type used to wrap the bboxes. If ``box_type`` + is None, gt_bboxes will keep being np.ndarray. Defaults to 'hbox'. + reduce_zero_label (bool): Whether to reduce all label values + by 1. Usually used for datasets where 0 is background label. + Defaults to False. + ignore_index (int): The label index to be ignored.
+ Valid only if reduce_zero_label is true. Defaults to 255. + imdecode_backend (str): The image decoding backend type. The backend + argument for :func:``mmcv.imfrombytes``. + See :func:``mmcv.imfrombytes`` for details. + Defaults to 'cv2'. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + """ + + def __init__( + self, + with_mask: bool = False, + poly2mask: bool = True, + box_type: str = 'hbox', + # use for semseg + reduce_zero_label: bool = False, + ignore_index: int = 255, + **kwargs) -> None: + super(LoadAnnotations, self).__init__(**kwargs) + self.with_mask = with_mask + self.poly2mask = poly2mask + self.box_type = box_type + self.reduce_zero_label = reduce_zero_label + self.ignore_index = ignore_index + + def _load_bboxes(self, results: dict) -> None: + """Private function to load bounding box annotations. + + Args: + results (dict): Result dict from :obj:``mmengine.BaseDataset``. + Returns: + dict: The dict contains loaded bounding box annotations. + """ + gt_bboxes = [] + gt_ignore_flags = [] + for instance in results.get('instances', []): + gt_bboxes.append(instance['bbox']) + gt_ignore_flags.append(instance['ignore_flag']) + if self.box_type is None: + results['gt_bboxes'] = np.array( + gt_bboxes, dtype=np.float32).reshape((-1, 4)) + else: + _, box_type_cls = get_box_type(self.box_type) + results['gt_bboxes'] = box_type_cls(gt_bboxes, dtype=torch.float32) + results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool) + + def _load_labels(self, results: dict) -> None: + """Private function to load label annotations. + + Args: + results (dict): Result dict from :obj:``mmengine.BaseDataset``. + + Returns: + dict: The dict contains loaded label annotations. + """ + gt_bboxes_labels = [] + for instance in results.get('instances', []): + gt_bboxes_labels.append(instance['bbox_label']) + # TODO: Inconsistent with mmcv, consider how to deal with it later. + results['gt_bboxes_labels'] = np.array( + gt_bboxes_labels, dtype=np.int64) + + def _poly2mask(self, mask_ann: Union[list, dict], img_h: int, + img_w: int) -> np.ndarray: + """Private function to convert masks represented with polygons to + bitmaps. + + Args: + mask_ann (list | dict): Polygon mask annotation input. + img_h (int): The height of output mask. + img_w (int): The width of output mask. + + Returns: + np.ndarray: The decoded bitmap mask of shape (img_h, img_w). + """ + + if isinstance(mask_ann, list): + # polygon -- a single object might consist of multiple parts + # we merge all parts into one mask rle code + rles = maskUtils.frPyObjects(mask_ann, img_h, img_w) + rle = maskUtils.merge(rles) + elif isinstance(mask_ann['counts'], list): + # uncompressed RLE + rle = maskUtils.frPyObjects(mask_ann, img_h, img_w) + else: + # rle + rle = mask_ann + mask = maskUtils.decode(rle) + return mask + + def _process_masks(self, results: dict) -> list: + """Process gt_masks and filter invalid polygons. + + Args: + results (dict): Result dict from :obj:``mmengine.BaseDataset``. + + Returns: + list: Processed gt_masks. + """ + gt_masks = [] + gt_ignore_flags = [] + for instance in results.get('instances', []): + gt_mask = instance['mask'] + # If the annotation of segmentation mask is invalid, + # ignore the whole instance.
+            if isinstance(gt_mask, list):
+                gt_mask = [
+                    np.array(polygon) for polygon in gt_mask
+                    if len(polygon) % 2 == 0 and len(polygon) >= 6
+                ]
+                if len(gt_mask) == 0:
+                    # ignore this instance and set gt_mask to a fake mask
+                    instance['ignore_flag'] = 1
+                    gt_mask = [np.zeros(6)]
+            elif not self.poly2mask:
+                # `PolygonMasks` requires a polygon of format List[np.array];
+                # other formats are invalid.
+                instance['ignore_flag'] = 1
+                gt_mask = [np.zeros(6)]
+            elif isinstance(gt_mask, dict) and \
+                    not (gt_mask.get('counts') is not None and
+                         gt_mask.get('size') is not None and
+                         isinstance(gt_mask['counts'], (list, str))):
+                # if gt_mask is a dict, it should include `counts` and `size`
+                # so that `BitmapMasks` can decode the compressed RLE
+                instance['ignore_flag'] = 1
+                gt_mask = [np.zeros(6)]
+            gt_masks.append(gt_mask)
+            # re-process gt_ignore_flags
+            gt_ignore_flags.append(instance['ignore_flag'])
+        results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)
+        return gt_masks
+
+    def _load_masks(self, results: dict) -> None:
+        """Private function to load mask annotations.
+
+        Args:
+            results (dict): Result dict from :obj:``mmengine.BaseDataset``.
+        """
+        h, w = results['ori_shape']
+        gt_masks = self._process_masks(results)
+        if self.poly2mask:
+            gt_masks = BitmapMasks(
+                [self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
+        else:
+            # fake polygon masks will be ignored in `PackDetInputs`
+            gt_masks = PolygonMasks([mask for mask in gt_masks], h, w)
+        results['gt_masks'] = gt_masks
+
+    def _load_seg_map(self, results: dict) -> None:
+        """Private function to load semantic segmentation annotations.
+
+        Args:
+            results (dict): Result dict from :obj:``mmcv.BaseDataset``.
+
+        Returns:
+            dict: The dict contains loaded semantic segmentation annotations.
+        """
+        if results.get('seg_map_path', None) is None:
+            return
+
+        img_bytes = get(
+            results['seg_map_path'], backend_args=self.backend_args)
+        gt_semantic_seg = mmcv.imfrombytes(
+            img_bytes, flag='unchanged',
+            backend=self.imdecode_backend).squeeze()
+
+        if self.reduce_zero_label:
+            # avoid an underflow conversion
+            gt_semantic_seg[gt_semantic_seg == 0] = self.ignore_index
+            gt_semantic_seg = gt_semantic_seg - 1
+            gt_semantic_seg[gt_semantic_seg == self.ignore_index -
+                            1] = self.ignore_index
+
+        # modify if custom classes
+        if results.get('label_map', None) is not None:
+            # Use a copy to avoid repeatedly replacing values in
+            # `gt_semantic_seg`; the bug is reported in
+            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
+            gt_semantic_seg_copy = gt_semantic_seg.copy()
+            for old_id, new_id in results['label_map'].items():
+                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
+        results['gt_seg_map'] = gt_semantic_seg
+        results['ignore_index'] = self.ignore_index
+
+    def transform(self, results: dict) -> dict:
+        """Function to load multiple types of annotations.
+
+        Args:
+            results (dict): Result dict from :obj:``mmengine.BaseDataset``.
+
+        Returns:
+            dict: The dict contains loaded bounding box, label and
+                semantic segmentation annotations.
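+
+        A minimal input sketch (the field values are illustrative only):
+
+        .. code-block:: python
+
+            transform = LoadAnnotations(with_bbox=True, with_label=True)
+            results = dict(
+                ori_shape=(32, 32),
+                instances=[
+                    dict(bbox=[0, 0, 10, 10], bbox_label=1, ignore_flag=0)
+                ])
+            results = transform(results)
+            # results['gt_bboxes'] -> HorizontalBoxes of shape (1, 4)
+            # results['gt_bboxes_labels'] -> np.array([1])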
+ """ + + if self.with_bbox: + self._load_bboxes(results) + if self.with_label: + self._load_labels(results) + if self.with_mask: + self._load_masks(results) + if self.with_seg: + self._load_seg_map(results) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(with_bbox={self.with_bbox}, ' + repr_str += f'with_label={self.with_label}, ' + repr_str += f'with_mask={self.with_mask}, ' + repr_str += f'with_seg={self.with_seg}, ' + repr_str += f'poly2mask={self.poly2mask}, ' + repr_str += f"imdecode_backend='{self.imdecode_backend}', " + repr_str += f'backend_args={self.backend_args})' + return repr_str + + +@TRANSFORMS.register_module() +class LoadPanopticAnnotations(LoadAnnotations): + """Load multiple types of panoptic annotations. + + The annotation format is as the following: + + .. code-block:: python + + { + 'instances': + [ + { + # List of 4 numbers representing the bounding box of the + # instance, in (x1, y1, x2, y2) order. + 'bbox': [x1, y1, x2, y2], + + # Label of image classification. + 'bbox_label': 1, + }, + ... + ] + 'segments_info': + [ + { + # id = cls_id + instance_id * INSTANCE_OFFSET + 'id': int, + + # Contiguous category id defined in dataset. + 'category': int + + # Thing flag. + 'is_thing': bool + }, + ... + ] + + # Filename of semantic or panoptic segmentation ground truth file. + 'seg_map_path': 'a/b/c' + } + + After this module, the annotation has been changed to the format below: + + .. code-block:: python + + { + # In (x1, y1, x2, y2) order, float type. N is the number of bboxes + # in an image + 'gt_bboxes': BaseBoxes(N, 4) + # In int type. + 'gt_bboxes_labels': np.ndarray(N, ) + # In built-in class + 'gt_masks': PolygonMasks (H, W) or BitmapMasks (H, W) + # In uint8 type. + 'gt_seg_map': np.ndarray (H, W) + # in (x, y, v) order, float type. + } + + Required Keys: + + - height + - width + - instances + - bbox + - bbox_label + - ignore_flag + - segments_info + - id + - category + - is_thing + - seg_map_path + + Added Keys: + + - gt_bboxes (BaseBoxes[torch.float32]) + - gt_bboxes_labels (np.int64) + - gt_masks (BitmapMasks | PolygonMasks) + - gt_seg_map (np.uint8) + - gt_ignore_flags (bool) + + Args: + with_bbox (bool): Whether to parse and load the bbox annotation. + Defaults to True. + with_label (bool): Whether to parse and load the label annotation. + Defaults to True. + with_mask (bool): Whether to parse and load the mask annotation. + Defaults to True. + with_seg (bool): Whether to parse and load the semantic segmentation + annotation. Defaults to False. + box_type (str): The box mode used to wrap the bboxes. + imdecode_backend (str): The image decoding backend type. The backend + argument for :func:``mmcv.imfrombytes``. + See :fun:``mmcv.imfrombytes`` for details. + Defaults to 'cv2'. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend in mmdet >= 3.0.0rc7. Defaults to None. 
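+
+    A typical pipeline entry (illustrative; the rest of the pipeline is
+    assumed):
+
+    .. code-block:: python
+
+        train_pipeline = [
+            dict(type='LoadImageFromFile'),
+            dict(type='LoadPanopticAnnotations', backend_args=None),
+        ]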
+ """ + + def __init__(self, + with_bbox: bool = True, + with_label: bool = True, + with_mask: bool = True, + with_seg: bool = True, + box_type: str = 'hbox', + imdecode_backend: str = 'cv2', + backend_args: dict = None) -> None: + try: + from panopticapi import utils + except ImportError: + raise ImportError( + 'panopticapi is not installed, please install it by: ' + 'pip install git+https://github.com/cocodataset/' + 'panopticapi.git.') + self.rgb2id = utils.rgb2id + + super(LoadPanopticAnnotations, self).__init__( + with_bbox=with_bbox, + with_label=with_label, + with_mask=with_mask, + with_seg=with_seg, + with_keypoints=False, + box_type=box_type, + imdecode_backend=imdecode_backend, + backend_args=backend_args) + + def _load_masks_and_semantic_segs(self, results: dict) -> None: + """Private function to load mask and semantic segmentation annotations. + + In gt_semantic_seg, the foreground label is from ``0`` to + ``num_things - 1``, the background label is from ``num_things`` to + ``num_things + num_stuff - 1``, 255 means the ignored label (``VOID``). + + Args: + results (dict): Result dict from :obj:``mmdet.CustomDataset``. + """ + # seg_map_path is None, when inference on the dataset without gts. + if results.get('seg_map_path', None) is None: + return + + img_bytes = get( + results['seg_map_path'], backend_args=self.backend_args) + pan_png = mmcv.imfrombytes( + img_bytes, flag='color', channel_order='rgb').squeeze() + pan_png = self.rgb2id(pan_png) + + gt_masks = [] + gt_seg = np.zeros_like(pan_png) + 255 # 255 as ignore + + for segment_info in results['segments_info']: + mask = (pan_png == segment_info['id']) + gt_seg = np.where(mask, segment_info['category'], gt_seg) + + # The legal thing masks + if segment_info.get('is_thing'): + gt_masks.append(mask.astype(np.uint8)) + + if self.with_mask: + h, w = results['ori_shape'] + gt_masks = BitmapMasks(gt_masks, h, w) + results['gt_masks'] = gt_masks + + if self.with_seg: + results['gt_seg_map'] = gt_seg + + def transform(self, results: dict) -> dict: + """Function to load multiple types panoptic annotations. + + Args: + results (dict): Result dict from :obj:``mmdet.CustomDataset``. + + Returns: + dict: The dict contains loaded bounding box, label, mask and + semantic segmentation annotations. + """ + + if self.with_bbox: + self._load_bboxes(results) + if self.with_label: + self._load_labels(results) + if self.with_mask or self.with_seg: + # The tasks completed by '_load_masks' and '_load_semantic_segs' + # in LoadAnnotations are merged to one function. + self._load_masks_and_semantic_segs(results) + + return results + + +@TRANSFORMS.register_module() +class LoadProposals(BaseTransform): + """Load proposal pipeline. + + Required Keys: + + - proposals + + Modified Keys: + + - proposals + + Args: + num_max_proposals (int, optional): Maximum number of proposals to load. + If not specified, all proposals will be loaded. + """ + + def __init__(self, num_max_proposals: Optional[int] = None) -> None: + self.num_max_proposals = num_max_proposals + + def transform(self, results: dict) -> dict: + """Transform function to load proposals from file. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`. + + Returns: + dict: The dict contains loaded proposal annotations. 
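+
+        A minimal sketch (illustrative values; assumes ``numpy as np``):
+
+        .. code-block:: python
+
+            loader = LoadProposals(num_max_proposals=2)
+            results = dict(proposals=dict(
+                bboxes=np.array([[0, 0, 5, 5], [1, 1, 6, 6], [2, 2, 7, 7]]),
+                scores=np.array([0.9, 0.8, 0.7])))
+            results = loader(results)
+            # results['proposals'] keeps only the first two boxes and
+            # results['proposals_scores'] the matching scores.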
+ """ + + proposals = results['proposals'] + # the type of proposals should be `dict` or `InstanceData` + assert isinstance(proposals, dict) \ + or isinstance(proposals, BaseDataElement) + bboxes = proposals['bboxes'].astype(np.float32) + assert bboxes.shape[1] == 4, \ + f'Proposals should have shapes (n, 4), but found {bboxes.shape}' + + if 'scores' in proposals: + scores = proposals['scores'].astype(np.float32) + assert bboxes.shape[0] == scores.shape[0] + else: + scores = np.zeros(bboxes.shape[0], dtype=np.float32) + + if self.num_max_proposals is not None: + # proposals should sort by scores during dumping the proposals + bboxes = bboxes[:self.num_max_proposals] + scores = scores[:self.num_max_proposals] + + if len(bboxes) == 0: + bboxes = np.zeros((0, 4), dtype=np.float32) + scores = np.zeros(0, dtype=np.float32) + + results['proposals'] = bboxes + results['proposals_scores'] = scores + return results + + def __repr__(self): + return self.__class__.__name__ + \ + f'(num_max_proposals={self.num_max_proposals})' + + +@TRANSFORMS.register_module() +class FilterAnnotations(BaseTransform): + """Filter invalid annotations. + + Required Keys: + + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_bboxes_labels (np.int64) (optional) + - gt_masks (BitmapMasks | PolygonMasks) (optional) + - gt_ignore_flags (bool) (optional) + + Modified Keys: + + - gt_bboxes (optional) + - gt_bboxes_labels (optional) + - gt_masks (optional) + - gt_ignore_flags (optional) + + Args: + min_gt_bbox_wh (tuple[float]): Minimum width and height of ground truth + boxes. Default: (1., 1.) + min_gt_mask_area (int): Minimum foreground area of ground truth masks. + Default: 1 + by_box (bool): Filter instances with bounding boxes not meeting the + min_gt_bbox_wh threshold. Default: True + by_mask (bool): Filter instances with masks not meeting + min_gt_mask_area threshold. Default: False + keep_empty (bool): Whether to return None when it + becomes an empty bbox after filtering. Defaults to True. + """ + + def __init__(self, + min_gt_bbox_wh: Tuple[int, int] = (1, 1), + min_gt_mask_area: int = 1, + by_box: bool = True, + by_mask: bool = False, + keep_empty: bool = True) -> None: + # TODO: add more filter options + assert by_box or by_mask + self.min_gt_bbox_wh = min_gt_bbox_wh + self.min_gt_mask_area = min_gt_mask_area + self.by_box = by_box + self.by_mask = by_mask + self.keep_empty = keep_empty + + @autocast_box_type() + def transform(self, results: dict) -> Union[dict, None]: + """Transform function to filter annotations. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. 
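+
+        A minimal sketch (thresholds are illustrative):
+
+        .. code-block:: python
+
+            # Drop boxes whose width or height is not larger than 2 px, but
+            # keep the image even if every annotation is filtered out.
+            filter_ann = FilterAnnotations(
+                min_gt_bbox_wh=(2, 2), keep_empty=False)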
+ """ + assert 'gt_bboxes' in results + gt_bboxes = results['gt_bboxes'] + if gt_bboxes.shape[0] == 0: + return results + + tests = [] + if self.by_box: + tests.append( + ((gt_bboxes.widths > self.min_gt_bbox_wh[0]) & + (gt_bboxes.heights > self.min_gt_bbox_wh[1])).numpy()) + if self.by_mask: + assert 'gt_masks' in results + gt_masks = results['gt_masks'] + tests.append(gt_masks.areas >= self.min_gt_mask_area) + + keep = tests[0] + for t in tests[1:]: + keep = keep & t + + if not keep.any(): + if self.keep_empty: + return None + + keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_masks', 'gt_ignore_flags') + for key in keys: + if key in results: + results[key] = results[key][keep] + + return results + + def __repr__(self): + return self.__class__.__name__ + \ + f'(min_gt_bbox_wh={self.min_gt_bbox_wh}, ' \ + f'keep_empty={self.keep_empty})' + + +@TRANSFORMS.register_module() +class LoadEmptyAnnotations(BaseTransform): + """Load Empty Annotations for unlabeled images. + + Added Keys: + - gt_bboxes (np.float32) + - gt_bboxes_labels (np.int64) + - gt_masks (BitmapMasks | PolygonMasks) + - gt_seg_map (np.uint8) + - gt_ignore_flags (bool) + + Args: + with_bbox (bool): Whether to load the pseudo bbox annotation. + Defaults to True. + with_label (bool): Whether to load the pseudo label annotation. + Defaults to True. + with_mask (bool): Whether to load the pseudo mask annotation. + Default: False. + with_seg (bool): Whether to load the pseudo semantic segmentation + annotation. Defaults to False. + seg_ignore_label (int): The fill value used for segmentation map. + Note this value must equals ``ignore_label`` in ``semantic_head`` + of the corresponding config. Defaults to 255. + """ + + def __init__(self, + with_bbox: bool = True, + with_label: bool = True, + with_mask: bool = False, + with_seg: bool = False, + seg_ignore_label: int = 255) -> None: + self.with_bbox = with_bbox + self.with_label = with_label + self.with_mask = with_mask + self.with_seg = with_seg + self.seg_ignore_label = seg_ignore_label + + def transform(self, results: dict) -> dict: + """Transform function to load empty annotations. + + Args: + results (dict): Result dict. + Returns: + dict: Updated result dict. + """ + if self.with_bbox: + results['gt_bboxes'] = np.zeros((0, 4), dtype=np.float32) + results['gt_ignore_flags'] = np.zeros((0, ), dtype=bool) + if self.with_label: + results['gt_bboxes_labels'] = np.zeros((0, ), dtype=np.int64) + if self.with_mask: + # TODO: support PolygonMasks + h, w = results['img_shape'] + gt_masks = np.zeros((0, h, w), dtype=np.uint8) + results['gt_masks'] = BitmapMasks(gt_masks, h, w) + if self.with_seg: + h, w = results['img_shape'] + results['gt_seg_map'] = self.seg_ignore_label * np.ones( + (h, w), dtype=np.uint8) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(with_bbox={self.with_bbox}, ' + repr_str += f'with_label={self.with_label}, ' + repr_str += f'with_mask={self.with_mask}, ' + repr_str += f'with_seg={self.with_seg}, ' + repr_str += f'seg_ignore_label={self.seg_ignore_label})' + return repr_str + + +@TRANSFORMS.register_module() +class InferencerLoader(BaseTransform): + """Load an image from ``results['img']``. + + Similar with :obj:`LoadImageFromFile`, but the image has been loaded as + :obj:`np.ndarray` in ``results['img']``. Can be used when loading image + from webcam. 
+ + Required Keys: + + - img + + Modified Keys: + + - img + - img_path + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + """ + + def __init__(self, **kwargs) -> None: + super().__init__() + self.from_file = TRANSFORMS.build( + dict(type='LoadImageFromFile', **kwargs)) + self.from_ndarray = TRANSFORMS.build( + dict(type='mmdet.LoadImageFromNDArray', **kwargs)) + + def transform(self, results: Union[str, np.ndarray, dict]) -> dict: + """Transform function to add image meta information. + + Args: + results (str, np.ndarray or dict): The result. + + Returns: + dict: The dict contains loaded image and meta information. + """ + if isinstance(results, str): + inputs = dict(img_path=results) + elif isinstance(results, np.ndarray): + inputs = dict(img=results) + elif isinstance(results, dict): + inputs = results + else: + raise NotImplementedError + + if 'img' in inputs: + return self.from_ndarray(inputs) + return self.from_file(inputs) + + +@TRANSFORMS.register_module() +class LoadTrackAnnotations(LoadAnnotations): + """Load and process the ``instances`` and ``seg_map`` annotation provided + by dataset. It must load ``instances_ids`` which is only used in the + tracking tasks. The annotation format is as the following: + + .. code-block:: python + { + 'instances': + [ + { + # List of 4 numbers representing the bounding box of the + # instance, in (x1, y1, x2, y2) order. + 'bbox': [x1, y1, x2, y2], + # Label of image classification. + 'bbox_label': 1, + # Used in tracking. + # Id of instances. + 'instance_id': 100, + # Used in instance/panoptic segmentation. The segmentation mask + # of the instance or the information of segments. + # 1. If list[list[float]], it represents a list of polygons, + # one for each connected component of the object. Each + # list[float] is one simple polygon in the format of + # [x1, y1, ..., xn, yn] (n >= 3). The Xs and Ys are absolute + # coordinates in unit of pixels. + # 2. If dict, it represents the per-pixel segmentation mask in + # COCO's compressed RLE format. The dict should have keys + # “size” and “counts”. Can be loaded by pycocotools + 'mask': list[list[float]] or dict, + } + ] + # Filename of semantic or panoptic segmentation ground truth file. + 'seg_map_path': 'a/b/c' + } + + After this module, the annotation has been changed to the format below: + .. code-block:: python + { + # In (x1, y1, x2, y2) order, float type. N is the number of bboxes + # in an image + 'gt_bboxes': np.ndarray(N, 4) + # In int type. + 'gt_bboxes_labels': np.ndarray(N, ) + # In built-in class + 'gt_masks': PolygonMasks (H, W) or BitmapMasks (H, W) + # In uint8 type. + 'gt_seg_map': np.ndarray (H, W) + # in (x, y, v) order, float type. + } + + Required Keys: + + - height (optional) + - width (optional) + - instances + - bbox (optional) + - bbox_label + - instance_id (optional) + - mask (optional) + - ignore_flag (optional) + - seg_map_path (optional) + + Added Keys: + + - gt_bboxes (np.float32) + - gt_bboxes_labels (np.int32) + - gt_instances_ids (np.int32) + - gt_masks (BitmapMasks | PolygonMasks) + - gt_seg_map (np.uint8) + - gt_ignore_flags (np.bool) + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def _load_bboxes(self, results: dict) -> None: + """Private function to load bounding box annotations. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. 
+
+        Returns:
+            dict: The dict contains loaded bounding box annotations.
+        """
+        gt_bboxes = []
+        gt_ignore_flags = []
+        # TODO: use bbox_type
+        for instance in results['instances']:
+            # Datasets that are only used for evaluation may not have
+            # ground-truth boxes.
+            if 'bbox' in instance:
+                gt_bboxes.append(instance['bbox'])
+            if 'ignore_flag' in instance:
+                gt_ignore_flags.append(instance['ignore_flag'])
+
+        # TODO: check this case
+        if len(gt_bboxes) != len(gt_ignore_flags):
+            # There may be no ``gt_ignore_flags`` in some cases, we treat them
+            # as all False in order to keep the length of ``gt_bboxes`` and
+            # ``gt_ignore_flags`` the same
+            gt_ignore_flags = [False] * len(gt_bboxes)
+
+        results['gt_bboxes'] = np.array(
+            gt_bboxes, dtype=np.float32).reshape(-1, 4)
+        results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)
+
+    def _load_instances_ids(self, results: dict) -> None:
+        """Private function to load instance id annotations.
+
+        Args:
+            results (dict): Result dict from :obj:``mmcv.BaseDataset``.
+
+        Returns:
+            dict: The dict containing instance id annotations.
+        """
+        gt_instances_ids = []
+        for instance in results['instances']:
+            gt_instances_ids.append(instance['instance_id'])
+        results['gt_instances_ids'] = np.array(
+            gt_instances_ids, dtype=np.int32)
+
+    def transform(self, results: dict) -> dict:
+        """Function to load multiple types of annotations.
+
+        Args:
+            results (dict): Result dict from :obj:``mmcv.BaseDataset``.
+
+        Returns:
+            dict: The dict contains loaded bounding box, label, instance id
+                and semantic segmentation annotations.
+        """
+        results = super().transform(results)
+        self._load_instances_ids(results)
+        return results
+
+    def __repr__(self) -> str:
+        repr_str = self.__class__.__name__
+        repr_str += f'(with_bbox={self.with_bbox}, '
+        repr_str += f'with_label={self.with_label}, '
+        repr_str += f'with_mask={self.with_mask}, '
+        repr_str += f'with_seg={self.with_seg}, '
+        repr_str += f'poly2mask={self.poly2mask}, '
+        repr_str += f"imdecode_backend='{self.imdecode_backend}', "
+        repr_str += f'backend_args={self.backend_args})'
+        return repr_str
diff --git a/mmdet/datasets/transforms/transformers_glip.py b/mmdet/datasets/transforms/transformers_glip.py
new file mode 100644
index 0000000000000000000000000000000000000000..60c4f87d1b86c13f886da27584114b6420b8b8cb
--- /dev/null
+++ b/mmdet/datasets/transforms/transformers_glip.py
@@ -0,0 +1,66 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import numpy as np
+from mmcv.transforms import BaseTransform
+
+from mmdet.registry import TRANSFORMS
+from mmdet.structures.bbox import HorizontalBoxes, autocast_box_type
+from .transforms import RandomFlip
+
+
+@TRANSFORMS.register_module()
+class GTBoxSubOne_GLIP(BaseTransform):
+    """Subtract 1 from the x2 and y2 coordinates of the gt_bboxes."""
+
+    def transform(self, results: dict) -> dict:
+        if 'gt_bboxes' in results:
+            gt_bboxes = results['gt_bboxes']
+            if isinstance(gt_bboxes, np.ndarray):
+                gt_bboxes[:, 2:] -= 1
+                results['gt_bboxes'] = gt_bboxes
+            elif isinstance(gt_bboxes, HorizontalBoxes):
+                gt_bboxes = results['gt_bboxes'].tensor
+                gt_bboxes[:, 2:] -= 1
+                results['gt_bboxes'] = HorizontalBoxes(gt_bboxes)
+            else:
+                raise NotImplementedError
+        return results
+
+
+@TRANSFORMS.register_module()
+class RandomFlip_GLIP(RandomFlip):
+    """Flip the image & bboxes & masks & segs horizontally or vertically.
+ + When using horizontal flipping, the corresponding bbox x-coordinate needs + to be additionally subtracted by one. + """ + + @autocast_box_type() + def _flip(self, results: dict) -> None: + """Flip images, bounding boxes, and semantic segmentation map.""" + # flip image + results['img'] = mmcv.imflip( + results['img'], direction=results['flip_direction']) + + img_shape = results['img'].shape[:2] + + # flip bboxes + if results.get('gt_bboxes', None) is not None: + results['gt_bboxes'].flip_(img_shape, results['flip_direction']) + # Only change this line + if results['flip_direction'] == 'horizontal': + results['gt_bboxes'].translate_([-1, 0]) + + # TODO: check it + # flip masks + if results.get('gt_masks', None) is not None: + results['gt_masks'] = results['gt_masks'].flip( + results['flip_direction']) + + # flip segs + if results.get('gt_seg_map', None) is not None: + results['gt_seg_map'] = mmcv.imflip( + results['gt_seg_map'], direction=results['flip_direction']) + + # record homography matrix for flip + self._record_homography_matrix(results) diff --git a/mmdet/datasets/transforms/transforms.py b/mmdet/datasets/transforms/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..4ac2bf75b5435bc220b12f369e69acf172492df4 --- /dev/null +++ b/mmdet/datasets/transforms/transforms.py @@ -0,0 +1,3854 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import inspect +import math +import warnings +from typing import List, Optional, Sequence, Tuple, Union + +import cv2 +import mmcv +import numpy as np +from mmcv.image import imresize +from mmcv.image.geometric import _scale_size +from mmcv.transforms import BaseTransform +from mmcv.transforms import Pad as MMCV_Pad +from mmcv.transforms import RandomFlip as MMCV_RandomFlip +from mmcv.transforms import Resize as MMCV_Resize +from mmcv.transforms.utils import avoid_cache_randomness, cache_randomness +from mmengine.dataset import BaseDataset +from mmengine.utils import is_str +from numpy import random + +from mmdet.registry import TRANSFORMS +from mmdet.structures.bbox import HorizontalBoxes, autocast_box_type +from mmdet.structures.mask import BitmapMasks, PolygonMasks +from mmdet.utils import log_img_scale + +try: + from imagecorruptions import corrupt +except ImportError: + corrupt = None + +try: + import albumentations + from albumentations import Compose +except ImportError: + albumentations = None + Compose = None + +Number = Union[int, float] + + +def _fixed_scale_size( + size: Tuple[int, int], + scale: Union[float, int, tuple], +) -> Tuple[int, int]: + """Rescale a size by a ratio. + + Args: + size (tuple[int]): (w, h). + scale (float | tuple(float)): Scaling factor. + + Returns: + tuple[int]: scaled size. + """ + if isinstance(scale, (float, int)): + scale = (scale, scale) + w, h = size + # don't need o.5 offset + return int(w * float(scale[0])), int(h * float(scale[1])) + + +def rescale_size(old_size: tuple, + scale: Union[float, int, tuple], + return_scale: bool = False) -> tuple: + """Calculate the new size to be rescaled to. + + Args: + old_size (tuple[int]): The old size (w, h) of image. + scale (float | tuple[int]): The scaling factor or maximum size. + If it is a float number, then the image will be rescaled by this + factor, else if it is a tuple of 2 integers, then the image will + be rescaled as large as possible within the scale. + return_scale (bool): Whether to return the scaling factor besides the + rescaled image size. + + Returns: + tuple[int]: The new rescaled image size. 
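+
+    Worked examples (values follow from the rules above; note the floor in
+    ``_fixed_scale_size``):
+
+    .. code-block:: python
+
+        rescale_size((1333, 800), 0.5)          # -> (666, 400)
+        rescale_size((1333, 800), (1280, 720))  # -> (1199, 720)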
+ """ + w, h = old_size + if isinstance(scale, (float, int)): + if scale <= 0: + raise ValueError(f'Invalid scale {scale}, must be positive.') + scale_factor = scale + elif isinstance(scale, tuple): + max_long_edge = max(scale) + max_short_edge = min(scale) + scale_factor = min(max_long_edge / max(h, w), + max_short_edge / min(h, w)) + else: + raise TypeError( + f'Scale must be a number or tuple of int, but got {type(scale)}') + # only change this + new_size = _fixed_scale_size((w, h), scale_factor) + + if return_scale: + return new_size, scale_factor + else: + return new_size + + +def imrescale( + img: np.ndarray, + scale: Union[float, Tuple[int, int]], + return_scale: bool = False, + interpolation: str = 'bilinear', + backend: Optional[str] = None +) -> Union[np.ndarray, Tuple[np.ndarray, float]]: + """Resize image while keeping the aspect ratio. + + Args: + img (ndarray): The input image. + scale (float | tuple[int]): The scaling factor or maximum size. + If it is a float number, then the image will be rescaled by this + factor, else if it is a tuple of 2 integers, then the image will + be rescaled as large as possible within the scale. + return_scale (bool): Whether to return the scaling factor besides the + rescaled image. + interpolation (str): Same as :func:`resize`. + backend (str | None): Same as :func:`resize`. + + Returns: + ndarray: The rescaled image. + """ + h, w = img.shape[:2] + new_size, scale_factor = rescale_size((w, h), scale, return_scale=True) + rescaled_img = imresize( + img, new_size, interpolation=interpolation, backend=backend) + if return_scale: + return rescaled_img, scale_factor + else: + return rescaled_img + + +@TRANSFORMS.register_module() +class Resize(MMCV_Resize): + """Resize images & bbox & seg. + + This transform resizes the input image according to ``scale`` or + ``scale_factor``. Bboxes, masks, and seg map are then resized + with the same scale factor. + if ``scale`` and ``scale_factor`` are both set, it will use ``scale`` to + resize. + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_masks (BitmapMasks | PolygonMasks) (optional) + - gt_seg_map (np.uint8) (optional) + + Modified Keys: + + - img + - img_shape + - gt_bboxes + - gt_masks + - gt_seg_map + + + Added Keys: + + - scale + - scale_factor + - keep_ratio + - homography_matrix + + Args: + scale (int or tuple): Images scales for resizing. Defaults to None + scale_factor (float or tuple[float]): Scale factors for resizing. + Defaults to None. + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. Defaults to False. + clip_object_border (bool): Whether to clip the objects + outside the border of the image. In some dataset like MOT17, the gt + bboxes are allowed to cross the border of images. Therefore, we + don't need to clip the gt bboxes in these cases. Defaults to True. + backend (str): Image resize backend, choices are 'cv2' and 'pillow'. + These two backends generates slightly different results. Defaults + to 'cv2'. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. Defaults + to 'bilinear'. 
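+
+    A typical pipeline entry (illustrative):
+
+    .. code-block:: python
+
+        dict(type='Resize', scale=(1333, 800), keep_ratio=True)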
+ """ + + def _resize_masks(self, results: dict) -> None: + """Resize masks with ``results['scale']``""" + if results.get('gt_masks', None) is not None: + if self.keep_ratio: + results['gt_masks'] = results['gt_masks'].rescale( + results['scale']) + else: + results['gt_masks'] = results['gt_masks'].resize( + results['img_shape']) + + def _resize_bboxes(self, results: dict) -> None: + """Resize bounding boxes with ``results['scale_factor']``.""" + if results.get('gt_bboxes', None) is not None: + results['gt_bboxes'].rescale_(results['scale_factor']) + if self.clip_object_border: + results['gt_bboxes'].clip_(results['img_shape']) + + def _record_homography_matrix(self, results: dict) -> None: + """Record the homography matrix for the Resize.""" + w_scale, h_scale = results['scale_factor'] + homography_matrix = np.array( + [[w_scale, 0, 0], [0, h_scale, 0], [0, 0, 1]], dtype=np.float32) + if results.get('homography_matrix', None) is None: + results['homography_matrix'] = homography_matrix + else: + results['homography_matrix'] = homography_matrix @ results[ + 'homography_matrix'] + + @autocast_box_type() + def transform(self, results: dict) -> dict: + """Transform function to resize images, bounding boxes and semantic + segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map', + 'scale', 'scale_factor', 'height', 'width', and 'keep_ratio' keys + are updated in result dict. + """ + if self.scale: + results['scale'] = self.scale + else: + img_shape = results['img'].shape[:2] + results['scale'] = _scale_size(img_shape[::-1], self.scale_factor) + self._resize_img(results) + self._resize_bboxes(results) + self._resize_masks(results) + self._resize_seg(results) + self._record_homography_matrix(results) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(scale={self.scale}, ' + repr_str += f'scale_factor={self.scale_factor}, ' + repr_str += f'keep_ratio={self.keep_ratio}, ' + repr_str += f'clip_object_border={self.clip_object_border}), ' + repr_str += f'backend={self.backend}), ' + repr_str += f'interpolation={self.interpolation})' + return repr_str + + +@TRANSFORMS.register_module() +class FixScaleResize(Resize): + """Compared to Resize, FixScaleResize fixes the scaling issue when + `keep_ratio=true`.""" + + def _resize_img(self, results): + """Resize images with ``results['scale']``.""" + if results.get('img', None) is not None: + if self.keep_ratio: + img, scale_factor = imrescale( + results['img'], + results['scale'], + interpolation=self.interpolation, + return_scale=True, + backend=self.backend) + new_h, new_w = img.shape[:2] + h, w = results['img'].shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img, w_scale, h_scale = mmcv.imresize( + results['img'], + results['scale'], + interpolation=self.interpolation, + return_scale=True, + backend=self.backend) + results['img'] = img + results['img_shape'] = img.shape[:2] + results['scale_factor'] = (w_scale, h_scale) + results['keep_ratio'] = self.keep_ratio + + +@TRANSFORMS.register_module() +class ResizeShortestEdge(BaseTransform): + """Resize the image and mask while keeping the aspect ratio unchanged. + + Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/augmentation_impl.py#L130 # noqa:E501 + + This transform attempts to scale the shorter edge to the given + `scale`, as long as the longer edge does not exceed `max_size`. 
+ If `max_size` is reached, then downscale so that the longer + edge does not exceed `max_size`. + + Required Keys: + - img + - gt_seg_map (optional) + Modified Keys: + - img + - img_shape + - gt_seg_map (optional)) + Added Keys: + - scale + - scale_factor + - keep_ratio + + Args: + scale (Union[int, Tuple[int, int]]): The target short edge length. + If it's tuple, will select the min value as the short edge length. + max_size (int): The maximum allowed longest edge length. + """ + + def __init__(self, + scale: Union[int, Tuple[int, int]], + max_size: Optional[int] = None, + resize_type: str = 'Resize', + **resize_kwargs) -> None: + super().__init__() + self.scale = scale + self.max_size = max_size + + self.resize_cfg = dict(type=resize_type, **resize_kwargs) + self.resize = TRANSFORMS.build({'scale': 0, **self.resize_cfg}) + + def _get_output_shape( + self, img: np.ndarray, + short_edge_length: Union[int, Tuple[int, int]]) -> Tuple[int, int]: + """Compute the target image shape with the given `short_edge_length`. + + Args: + img (np.ndarray): The input image. + short_edge_length (Union[int, Tuple[int, int]]): The target short + edge length. If it's tuple, will select the min value as the + short edge length. + """ + h, w = img.shape[:2] + if isinstance(short_edge_length, int): + size = short_edge_length * 1.0 + elif isinstance(short_edge_length, tuple): + size = min(short_edge_length) * 1.0 + scale = size / min(h, w) + if h < w: + new_h, new_w = size, scale * w + else: + new_h, new_w = scale * h, size + + if self.max_size and max(new_h, new_w) > self.max_size: + scale = self.max_size * 1.0 / max(new_h, new_w) + new_h *= scale + new_w *= scale + + new_h = int(new_h + 0.5) + new_w = int(new_w + 0.5) + return new_w, new_h + + def transform(self, results: dict) -> dict: + self.resize.scale = self._get_output_shape(results['img'], self.scale) + return self.resize(results) + + +@TRANSFORMS.register_module() +class FixShapeResize(Resize): + """Resize images & bbox & seg to the specified size. + + This transform resizes the input image according to ``width`` and + ``height``. Bboxes, masks, and seg map are then resized + with the same parameters. + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_masks (BitmapMasks | PolygonMasks) (optional) + - gt_seg_map (np.uint8) (optional) + + Modified Keys: + + - img + - img_shape + - gt_bboxes + - gt_masks + - gt_seg_map + + + Added Keys: + + - scale + - scale_factor + - keep_ratio + - homography_matrix + + Args: + width (int): width for resizing. + height (int): height for resizing. + Defaults to None. + pad_val (Number | dict[str, Number], optional): Padding value for if + the pad_mode is "constant". If it is a single number, the value + to pad the image is the number and to pad the semantic + segmentation map is 255. If it is a dict, it should have the + following keys: + + - img: The value to pad the image. + - seg: The value to pad the semantic segmentation map. + Defaults to dict(img=0, seg=255). + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. Defaults to False. + clip_object_border (bool): Whether to clip the objects + outside the border of the image. In some dataset like MOT17, the gt + bboxes are allowed to cross the border of images. Therefore, we + don't need to clip the gt bboxes in these cases. Defaults to True. + backend (str): Image resize backend, choices are 'cv2' and 'pillow'. + These two backends generates slightly different results. Defaults + to 'cv2'. 
+ interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. Defaults + to 'bilinear'. + """ + + def __init__(self, + width: int, + height: int, + pad_val: Union[Number, dict] = dict(img=0, seg=255), + keep_ratio: bool = False, + clip_object_border: bool = True, + backend: str = 'cv2', + interpolation: str = 'bilinear') -> None: + assert width is not None and height is not None, ( + '`width` and' + '`height` can not be `None`') + + self.width = width + self.height = height + self.scale = (width, height) + + self.backend = backend + self.interpolation = interpolation + self.keep_ratio = keep_ratio + self.clip_object_border = clip_object_border + + if keep_ratio is True: + # padding to the fixed size when keep_ratio=True + self.pad_transform = Pad(size=self.scale, pad_val=pad_val) + + @autocast_box_type() + def transform(self, results: dict) -> dict: + """Transform function to resize images, bounding boxes and semantic + segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map', + 'scale', 'scale_factor', 'height', 'width', and 'keep_ratio' keys + are updated in result dict. + """ + img = results['img'] + h, w = img.shape[:2] + if self.keep_ratio: + scale_factor = min(self.width / w, self.height / h) + results['scale_factor'] = (scale_factor, scale_factor) + real_w, real_h = int(w * float(scale_factor) + + 0.5), int(h * float(scale_factor) + 0.5) + img, scale_factor = mmcv.imrescale( + results['img'], (real_w, real_h), + interpolation=self.interpolation, + return_scale=True, + backend=self.backend) + # the w_scale and h_scale has minor difference + # a real fix should be done in the mmcv.imrescale in the future + results['img'] = img + results['img_shape'] = img.shape[:2] + results['keep_ratio'] = self.keep_ratio + results['scale'] = (real_w, real_h) + else: + results['scale'] = (self.width, self.height) + results['scale_factor'] = (self.width / w, self.height / h) + super()._resize_img(results) + + self._resize_bboxes(results) + self._resize_masks(results) + self._resize_seg(results) + self._record_homography_matrix(results) + if self.keep_ratio: + self.pad_transform(results) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(width={self.width}, height={self.height}, ' + repr_str += f'keep_ratio={self.keep_ratio}, ' + repr_str += f'clip_object_border={self.clip_object_border}), ' + repr_str += f'backend={self.backend}), ' + repr_str += f'interpolation={self.interpolation})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomFlip(MMCV_RandomFlip): + """Flip the image & bbox & mask & segmentation map. Added or Updated keys: + flip, flip_direction, img, gt_bboxes, and gt_seg_map. There are 3 flip + modes: + + - ``prob`` is float, ``direction`` is string: the image will be + ``direction``ly flipped with probability of ``prob`` . + E.g., ``prob=0.5``, ``direction='horizontal'``, + then image will be horizontally flipped with probability of 0.5. + - ``prob`` is float, ``direction`` is list of string: the image will + be ``direction[i]``ly flipped with probability of + ``prob/len(direction)``. + E.g., ``prob=0.5``, ``direction=['horizontal', 'vertical']``, + then image will be horizontally flipped with probability of 0.25, + vertically with probability of 0.25. 
+ - ``prob`` is list of float, ``direction`` is list of string: + given ``len(prob) == len(direction)``, the image will + be ``direction[i]``ly flipped with probability of ``prob[i]``. + E.g., ``prob=[0.3, 0.5]``, ``direction=['horizontal', + 'vertical']``, then image will be horizontally flipped with + probability of 0.3, vertically with probability of 0.5. + + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_masks (BitmapMasks | PolygonMasks) (optional) + - gt_seg_map (np.uint8) (optional) + + Modified Keys: + + - img + - gt_bboxes + - gt_masks + - gt_seg_map + + Added Keys: + + - flip + - flip_direction + - homography_matrix + + + Args: + prob (float | list[float], optional): The flipping probability. + Defaults to None. + direction(str | list[str]): The flipping direction. Options + If input is a list, the length must equal ``prob``. Each + element in ``prob`` indicates the flip probability of + corresponding direction. Defaults to 'horizontal'. + """ + + def _record_homography_matrix(self, results: dict) -> None: + """Record the homography matrix for the RandomFlip.""" + cur_dir = results['flip_direction'] + h, w = results['img'].shape[:2] + + if cur_dir == 'horizontal': + homography_matrix = np.array([[-1, 0, w], [0, 1, 0], [0, 0, 1]], + dtype=np.float32) + elif cur_dir == 'vertical': + homography_matrix = np.array([[1, 0, 0], [0, -1, h], [0, 0, 1]], + dtype=np.float32) + elif cur_dir == 'diagonal': + homography_matrix = np.array([[-1, 0, w], [0, -1, h], [0, 0, 1]], + dtype=np.float32) + else: + homography_matrix = np.eye(3, dtype=np.float32) + + if results.get('homography_matrix', None) is None: + results['homography_matrix'] = homography_matrix + else: + results['homography_matrix'] = homography_matrix @ results[ + 'homography_matrix'] + + @autocast_box_type() + def _flip(self, results: dict) -> None: + """Flip images, bounding boxes, and semantic segmentation map.""" + # flip image + results['img'] = mmcv.imflip( + results['img'], direction=results['flip_direction']) + + img_shape = results['img'].shape[:2] + + # flip bboxes + if results.get('gt_bboxes', None) is not None: + results['gt_bboxes'].flip_(img_shape, results['flip_direction']) + + # flip masks + if results.get('gt_masks', None) is not None: + results['gt_masks'] = results['gt_masks'].flip( + results['flip_direction']) + + # flip segs + if results.get('gt_seg_map', None) is not None: + results['gt_seg_map'] = mmcv.imflip( + results['gt_seg_map'], direction=results['flip_direction']) + + # record homography matrix for flip + self._record_homography_matrix(results) + + +@TRANSFORMS.register_module() +class RandomShift(BaseTransform): + """Shift the image and box given shift pixels and probability. + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) + - gt_bboxes_labels (np.int64) + - gt_ignore_flags (bool) (optional) + + Modified Keys: + + - img + - gt_bboxes + - gt_bboxes_labels + - gt_ignore_flags (bool) (optional) + + Args: + prob (float): Probability of shifts. Defaults to 0.5. + max_shift_px (int): The max pixels for shifting. Defaults to 32. + filter_thr_px (int): The width and height threshold for filtering. + The bbox and the rest of the targets below the width and + height threshold will be filtered. Defaults to 1. 
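+
+    A typical pipeline entry (illustrative; these are the defaults):
+
+    .. code-block:: python
+
+        dict(type='RandomShift', prob=0.5, max_shift_px=32)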
+    """
+
+    def __init__(self,
+                 prob: float = 0.5,
+                 max_shift_px: int = 32,
+                 filter_thr_px: int = 1) -> None:
+        assert 0 <= prob <= 1
+        assert max_shift_px >= 0
+        self.prob = prob
+        self.max_shift_px = max_shift_px
+        self.filter_thr_px = int(filter_thr_px)
+
+    @cache_randomness
+    def _random_prob(self) -> float:
+        return random.uniform(0, 1)
+
+    @autocast_box_type()
+    def transform(self, results: dict) -> dict:
+        """Transform function to randomly shift images and bounding boxes.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Shift results.
+        """
+        if self._random_prob() < self.prob:
+            img_shape = results['img'].shape[:2]
+
+            random_shift_x = random.randint(-self.max_shift_px,
+                                            self.max_shift_px)
+            random_shift_y = random.randint(-self.max_shift_px,
+                                            self.max_shift_px)
+            new_x = max(0, random_shift_x)
+            ori_x = max(0, -random_shift_x)
+            new_y = max(0, random_shift_y)
+            ori_y = max(0, -random_shift_y)
+
+            # TODO: support mask and semantic segmentation maps.
+            bboxes = results['gt_bboxes'].clone()
+            bboxes.translate_([random_shift_x, random_shift_y])
+
+            # clip border
+            bboxes.clip_(img_shape)
+
+            # remove invalid bboxes
+            valid_inds = (bboxes.widths > self.filter_thr_px).numpy() & (
+                bboxes.heights > self.filter_thr_px).numpy()
+            # If the shift does not contain any gt-bbox area, skip this
+            # image.
+            if not valid_inds.any():
+                return results
+            bboxes = bboxes[valid_inds]
+            results['gt_bboxes'] = bboxes
+            results['gt_bboxes_labels'] = results['gt_bboxes_labels'][
+                valid_inds]
+
+            if results.get('gt_ignore_flags', None) is not None:
+                results['gt_ignore_flags'] = \
+                    results['gt_ignore_flags'][valid_inds]
+
+            # shift img
+            img = results['img']
+            new_img = np.zeros_like(img)
+            img_h, img_w = img.shape[:2]
+            new_h = img_h - np.abs(random_shift_y)
+            new_w = img_w - np.abs(random_shift_x)
+            new_img[new_y:new_y + new_h, new_x:new_x + new_w] \
+                = img[ori_y:ori_y + new_h, ori_x:ori_x + new_w]
+            results['img'] = new_img
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(prob={self.prob}, '
+        repr_str += f'max_shift_px={self.max_shift_px}, '
+        repr_str += f'filter_thr_px={self.filter_thr_px})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class Pad(MMCV_Pad):
+    """Pad the image & segmentation map.
+
+    There are three padding modes: (1) pad to a fixed size, (2) pad to the
+    minimum size that is divisible by some number, and (3) pad to a square.
+    Padding to a square and padding to a size-divisible minimum can also be
+    used at the same time.
+
+    Required Keys:
+
+    - img
+    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
+    - gt_masks (BitmapMasks | PolygonMasks) (optional)
+    - gt_seg_map (np.uint8) (optional)
+
+    Modified Keys:
+
+    - img
+    - img_shape
+    - gt_masks
+    - gt_seg_map
+
+    Added Keys:
+
+    - pad_shape
+    - pad_fixed_size
+    - pad_size_divisor
+
+    Args:
+        size (tuple, optional): Fixed padding size.
+            Expected padding shape (width, height). Defaults to None.
+        size_divisor (int, optional): The divisor of padded size. Defaults to
+            None.
+        pad_to_square (bool): Whether to pad the image into a square.
+            Currently only used for YOLOX. Defaults to False.
+        pad_val (Number | dict[str, Number], optional): Padding value used
+            when ``padding_mode`` is "constant". If it is a single number,
+            the value to pad the image is the number and to pad the semantic
+            segmentation map is 255. If it is a dict, it should have the
+            following keys:
+
+            - img: The value to pad the image.
+            - seg: The value to pad the semantic segmentation map.
+ Defaults to dict(img=0, seg=255). + padding_mode (str): Type of padding. Should be: constant, edge, + reflect or symmetric. Defaults to 'constant'. + + - constant: pads with a constant value, this value is specified + with pad_val. + - edge: pads with the last value at the edge of the image. + - reflect: pads with reflection of image without repeating the last + value on the edge. For example, padding [1, 2, 3, 4] with 2 + elements on both sides in reflect mode will result in + [3, 2, 1, 2, 3, 4, 3, 2]. + - symmetric: pads with reflection of image repeating the last value + on the edge. For example, padding [1, 2, 3, 4] with 2 elements on + both sides in symmetric mode will result in + [2, 1, 1, 2, 3, 4, 4, 3] + """ + + def _pad_masks(self, results: dict) -> None: + """Pad masks according to ``results['pad_shape']``.""" + if results.get('gt_masks', None) is not None: + pad_val = self.pad_val.get('masks', 0) + pad_shape = results['pad_shape'][:2] + results['gt_masks'] = results['gt_masks'].pad( + pad_shape, pad_val=pad_val) + + def transform(self, results: dict) -> dict: + """Call function to pad images, masks, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Updated result dict. + """ + self._pad_img(results) + self._pad_seg(results) + self._pad_masks(results) + return results + + +@TRANSFORMS.register_module() +class RandomCrop(BaseTransform): + """Random crop the image & bboxes & masks. + + The absolute ``crop_size`` is sampled based on ``crop_type`` and + ``image_size``, then the cropped results are generated. + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_bboxes_labels (np.int64) (optional) + - gt_masks (BitmapMasks | PolygonMasks) (optional) + - gt_ignore_flags (bool) (optional) + - gt_seg_map (np.uint8) (optional) + + Modified Keys: + + - img + - img_shape + - gt_bboxes (optional) + - gt_bboxes_labels (optional) + - gt_masks (optional) + - gt_ignore_flags (optional) + - gt_seg_map (optional) + - gt_instances_ids (options, only used in MOT/VIS) + + Added Keys: + + - homography_matrix + + Args: + crop_size (tuple): The relative ratio or absolute pixels of + (width, height). + crop_type (str, optional): One of "relative_range", "relative", + "absolute", "absolute_range". "relative" randomly crops + (h * crop_size[0], w * crop_size[1]) part from an input of size + (h, w). "relative_range" uniformly samples relative crop size from + range [crop_size[0], 1] and [crop_size[1], 1] for height and width + respectively. "absolute" crops from an input with absolute size + (crop_size[0], crop_size[1]). "absolute_range" uniformly samples + crop_h in range [crop_size[0], min(h, crop_size[1])] and crop_w + in range [crop_size[0], min(w, crop_size[1])]. + Defaults to "absolute". + allow_negative_crop (bool, optional): Whether to allow a crop that does + not contain any bbox area. Defaults to False. + recompute_bbox (bool, optional): Whether to re-compute the boxes based + on cropped instance masks. Defaults to False. + bbox_clip_border (bool, optional): Whether clip the objects outside + the border of the image. Defaults to True. + + Note: + - If the image is smaller than the absolute crop size, return the + original image. + - The keys for bboxes, labels and masks must be aligned. That is, + ``gt_bboxes`` corresponds to ``gt_labels`` and ``gt_masks``, and + ``gt_bboxes_ignore`` corresponds to ``gt_labels_ignore`` and + ``gt_masks_ignore``. 
+ - If the crop does not contain any gt-bbox region and + ``allow_negative_crop`` is set to False, skip this image. + """ + + def __init__(self, + crop_size: tuple, + crop_type: str = 'absolute', + allow_negative_crop: bool = False, + recompute_bbox: bool = False, + bbox_clip_border: bool = True) -> None: + if crop_type not in [ + 'relative_range', 'relative', 'absolute', 'absolute_range' + ]: + raise ValueError(f'Invalid crop_type {crop_type}.') + if crop_type in ['absolute', 'absolute_range']: + assert crop_size[0] > 0 and crop_size[1] > 0 + assert isinstance(crop_size[0], int) and isinstance( + crop_size[1], int) + if crop_type == 'absolute_range': + assert crop_size[0] <= crop_size[1] + else: + assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1 + self.crop_size = crop_size + self.crop_type = crop_type + self.allow_negative_crop = allow_negative_crop + self.bbox_clip_border = bbox_clip_border + self.recompute_bbox = recompute_bbox + + def _crop_data(self, results: dict, crop_size: Tuple[int, int], + allow_negative_crop: bool) -> Union[dict, None]: + """Function to randomly crop images, bounding boxes, masks, semantic + segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + crop_size (Tuple[int, int]): Expected absolute size after + cropping, (h, w). + allow_negative_crop (bool): Whether to allow a crop that does not + contain any bbox area. + + Returns: + results (Union[dict, None]): Randomly cropped results, 'img_shape' + key in result dict is updated according to crop size. None will + be returned when there is no valid bbox after cropping. + """ + assert crop_size[0] > 0 and crop_size[1] > 0 + img = results['img'] + margin_h = max(img.shape[0] - crop_size[0], 0) + margin_w = max(img.shape[1] - crop_size[1], 0) + offset_h, offset_w = self._rand_offset((margin_h, margin_w)) + crop_y1, crop_y2 = offset_h, offset_h + crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + crop_size[1] + + # Record the homography matrix for the RandomCrop + homography_matrix = np.array( + [[1, 0, -offset_w], [0, 1, -offset_h], [0, 0, 1]], + dtype=np.float32) + if results.get('homography_matrix', None) is None: + results['homography_matrix'] = homography_matrix + else: + results['homography_matrix'] = homography_matrix @ results[ + 'homography_matrix'] + + # crop the image + img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] + img_shape = img.shape + results['img'] = img + results['img_shape'] = img_shape[:2] + + # crop bboxes accordingly and clip to the image boundary + if results.get('gt_bboxes', None) is not None: + bboxes = results['gt_bboxes'] + bboxes.translate_([-offset_w, -offset_h]) + if self.bbox_clip_border: + bboxes.clip_(img_shape[:2]) + valid_inds = bboxes.is_inside(img_shape[:2]).numpy() + # If the crop does not contain any gt-bbox area and + # allow_negative_crop is False, skip this image. 
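+            # (A ``None`` return signals the dataset wrapper to re-sample
+            # another index instead of using this image.)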
+ if (not valid_inds.any() and not allow_negative_crop): + return None + + results['gt_bboxes'] = bboxes[valid_inds] + + if results.get('gt_ignore_flags', None) is not None: + results['gt_ignore_flags'] = \ + results['gt_ignore_flags'][valid_inds] + + if results.get('gt_bboxes_labels', None) is not None: + results['gt_bboxes_labels'] = \ + results['gt_bboxes_labels'][valid_inds] + + if results.get('gt_masks', None) is not None: + results['gt_masks'] = results['gt_masks'][ + valid_inds.nonzero()[0]].crop( + np.asarray([crop_x1, crop_y1, crop_x2, crop_y2])) + if self.recompute_bbox: + results['gt_bboxes'] = results['gt_masks'].get_bboxes( + type(results['gt_bboxes'])) + + # We should remove the instance ids corresponding to invalid boxes. + if results.get('gt_instances_ids', None) is not None: + results['gt_instances_ids'] = \ + results['gt_instances_ids'][valid_inds] + + # crop semantic seg + if results.get('gt_seg_map', None) is not None: + results['gt_seg_map'] = results['gt_seg_map'][crop_y1:crop_y2, + crop_x1:crop_x2] + + return results + + @cache_randomness + def _rand_offset(self, margin: Tuple[int, int]) -> Tuple[int, int]: + """Randomly generate crop offset. + + Args: + margin (Tuple[int, int]): The upper bound for the offset generated + randomly. + + Returns: + Tuple[int, int]: The random offset for the crop. + """ + margin_h, margin_w = margin + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + + return offset_h, offset_w + + @cache_randomness + def _get_crop_size(self, image_size: Tuple[int, int]) -> Tuple[int, int]: + """Randomly generates the absolute crop size based on `crop_type` and + `image_size`. + + Args: + image_size (Tuple[int, int]): (h, w). + + Returns: + crop_size (Tuple[int, int]): (crop_h, crop_w) in absolute pixels. + """ + h, w = image_size + if self.crop_type == 'absolute': + return min(self.crop_size[1], h), min(self.crop_size[0], w) + elif self.crop_type == 'absolute_range': + crop_h = np.random.randint( + min(h, self.crop_size[0]), + min(h, self.crop_size[1]) + 1) + crop_w = np.random.randint( + min(w, self.crop_size[0]), + min(w, self.crop_size[1]) + 1) + return crop_h, crop_w + elif self.crop_type == 'relative': + crop_w, crop_h = self.crop_size + return int(h * crop_h + 0.5), int(w * crop_w + 0.5) + else: + # 'relative_range' + crop_size = np.asarray(self.crop_size, dtype=np.float32) + crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size) + return int(h * crop_h + 0.5), int(w * crop_w + 0.5) + + @autocast_box_type() + def transform(self, results: dict) -> Union[dict, None]: + """Transform function to randomly crop images, bounding boxes, masks, + semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + results (Union[dict, None]): Randomly cropped results, 'img_shape' + key in result dict is updated according to crop size. None will + be returned when there is no valid bbox after cropping. 
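+
+        A typical pipeline entry (illustrative; ``crop_size`` is
+        (width, height)):
+
+        .. code-block:: python
+
+            dict(type='RandomCrop', crop_size=(384, 384),
+                 crop_type='absolute', allow_negative_crop=False)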
+ """ + image_size = results['img'].shape[:2] + crop_size = self._get_crop_size(image_size) + results = self._crop_data(results, crop_size, self.allow_negative_crop) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(crop_size={self.crop_size}, ' + repr_str += f'crop_type={self.crop_type}, ' + repr_str += f'allow_negative_crop={self.allow_negative_crop}, ' + repr_str += f'recompute_bbox={self.recompute_bbox}, ' + repr_str += f'bbox_clip_border={self.bbox_clip_border})' + return repr_str + + +@TRANSFORMS.register_module() +class SegRescale(BaseTransform): + """Rescale semantic segmentation maps. + + This transform rescale the ``gt_seg_map`` according to ``scale_factor``. + + Required Keys: + + - gt_seg_map + + Modified Keys: + + - gt_seg_map + + Args: + scale_factor (float): The scale factor of the final output. Defaults + to 1. + backend (str): Image rescale backend, choices are 'cv2' and 'pillow'. + These two backends generates slightly different results. Defaults + to 'cv2'. + """ + + def __init__(self, scale_factor: float = 1, backend: str = 'cv2') -> None: + self.scale_factor = scale_factor + self.backend = backend + + def transform(self, results: dict) -> dict: + """Transform function to scale the semantic segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with semantic segmentation map scaled. + """ + if self.scale_factor != 1: + results['gt_seg_map'] = mmcv.imrescale( + results['gt_seg_map'], + self.scale_factor, + interpolation='nearest', + backend=self.backend) + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(scale_factor={self.scale_factor}, ' + repr_str += f'backend={self.backend})' + return repr_str + + +@TRANSFORMS.register_module() +class PhotoMetricDistortion(BaseTransform): + """Apply photometric distortion to image sequentially, every transformation + is applied with a probability of 0.5. The position of random contrast is in + second or second to last. + + 1. random brightness + 2. random contrast (mode 0) + 3. convert color from BGR to HSV + 4. random saturation + 5. random hue + 6. convert color from HSV to BGR + 7. random contrast (mode 1) + 8. randomly swap channels + + Required Keys: + + - img (np.uint8) + + Modified Keys: + + - img (np.float32) + + Args: + brightness_delta (int): delta of brightness. + contrast_range (sequence): range of contrast. + saturation_range (sequence): range of saturation. + hue_delta (int): delta of hue. 
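+
+    A typical pipeline entry (illustrative; these are the defaults):
+
+    .. code-block:: python
+
+        dict(
+            type='PhotoMetricDistortion',
+            brightness_delta=32,
+            contrast_range=(0.5, 1.5),
+            saturation_range=(0.5, 1.5),
+            hue_delta=18)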
+ """ + + def __init__(self, + brightness_delta: int = 32, + contrast_range: Sequence[Number] = (0.5, 1.5), + saturation_range: Sequence[Number] = (0.5, 1.5), + hue_delta: int = 18) -> None: + self.brightness_delta = brightness_delta + self.contrast_lower, self.contrast_upper = contrast_range + self.saturation_lower, self.saturation_upper = saturation_range + self.hue_delta = hue_delta + + @cache_randomness + def _random_flags(self) -> Sequence[Number]: + mode = random.randint(2) + brightness_flag = random.randint(2) + contrast_flag = random.randint(2) + saturation_flag = random.randint(2) + hue_flag = random.randint(2) + swap_flag = random.randint(2) + delta_value = random.uniform(-self.brightness_delta, + self.brightness_delta) + alpha_value = random.uniform(self.contrast_lower, self.contrast_upper) + saturation_value = random.uniform(self.saturation_lower, + self.saturation_upper) + hue_value = random.uniform(-self.hue_delta, self.hue_delta) + swap_value = random.permutation(3) + + return (mode, brightness_flag, contrast_flag, saturation_flag, + hue_flag, swap_flag, delta_value, alpha_value, + saturation_value, hue_value, swap_value) + + def transform(self, results: dict) -> dict: + """Transform function to perform photometric distortion on images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images distorted. + """ + assert 'img' in results, '`img` is not found in results' + img = results['img'] + img = img.astype(np.float32) + + (mode, brightness_flag, contrast_flag, saturation_flag, hue_flag, + swap_flag, delta_value, alpha_value, saturation_value, hue_value, + swap_value) = self._random_flags() + + # random brightness + if brightness_flag: + img += delta_value + + # mode == 0 --> do random contrast first + # mode == 1 --> do random contrast last + if mode == 1: + if contrast_flag: + img *= alpha_value + + # convert color from BGR to HSV + img = mmcv.bgr2hsv(img) + + # random saturation + if saturation_flag: + img[..., 1] *= saturation_value + # For image(type=float32), after convert bgr to hsv by opencv, + # valid saturation value range is [0, 1] + if saturation_value > 1: + img[..., 1] = img[..., 1].clip(0, 1) + + # random hue + if hue_flag: + img[..., 0] += hue_value + img[..., 0][img[..., 0] > 360] -= 360 + img[..., 0][img[..., 0] < 0] += 360 + + # convert color from HSV to BGR + img = mmcv.hsv2bgr(img) + + # random contrast + if mode == 0: + if contrast_flag: + img *= alpha_value + + # randomly swap channels + if swap_flag: + img = img[..., swap_value] + + results['img'] = img + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(brightness_delta={self.brightness_delta}, ' + repr_str += 'contrast_range=' + repr_str += f'{(self.contrast_lower, self.contrast_upper)}, ' + repr_str += 'saturation_range=' + repr_str += f'{(self.saturation_lower, self.saturation_upper)}, ' + repr_str += f'hue_delta={self.hue_delta})' + return repr_str + + +@TRANSFORMS.register_module() +class Expand(BaseTransform): + """Random expand the image & bboxes & masks & segmentation map. + + Randomly place the original image on a canvas of ``ratio`` x original image + size filled with mean values. The ratio is in the range of ratio_range. 
+
+    Required Keys:
+
+    - img
+    - img_shape
+    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
+    - gt_masks (BitmapMasks | PolygonMasks) (optional)
+    - gt_seg_map (np.uint8) (optional)
+
+    Modified Keys:
+
+    - img
+    - img_shape
+    - gt_bboxes
+    - gt_masks
+    - gt_seg_map
+
+    Args:
+        mean (sequence): mean value of dataset.
+        to_rgb (bool): whether to convert the order of mean to align with
+            RGB.
+        ratio_range (sequence): range of expand ratio.
+        seg_ignore_label (int): label of ignore segmentation map.
+        prob (float): probability of applying this transformation.
+    """
+
+    def __init__(self,
+                 mean: Sequence[Number] = (0, 0, 0),
+                 to_rgb: bool = True,
+                 ratio_range: Sequence[Number] = (1, 4),
+                 seg_ignore_label: Optional[int] = None,
+                 prob: float = 0.5) -> None:
+        self.to_rgb = to_rgb
+        self.ratio_range = ratio_range
+        if to_rgb:
+            self.mean = mean[::-1]
+        else:
+            self.mean = mean
+        self.min_ratio, self.max_ratio = ratio_range
+        self.seg_ignore_label = seg_ignore_label
+        self.prob = prob
+
+    @cache_randomness
+    def _random_prob(self) -> float:
+        return random.uniform(0, 1)
+
+    @cache_randomness
+    def _random_ratio(self) -> float:
+        return random.uniform(self.min_ratio, self.max_ratio)
+
+    @cache_randomness
+    def _random_left_top(self, ratio: float, h: int,
+                         w: int) -> Tuple[int, int]:
+        left = int(random.uniform(0, w * ratio - w))
+        top = int(random.uniform(0, h * ratio - h))
+        return left, top
+
+    @autocast_box_type()
+    def transform(self, results: dict) -> dict:
+        """Transform function to expand images, bounding boxes, masks,
+        segmentation map.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Result dict with images, bounding boxes, masks, segmentation
+            map expanded.
+        """
+        if self._random_prob() > self.prob:
+            return results
+        assert 'img' in results, '`img` is not found in results'
+        img = results['img']
+        h, w, c = img.shape
+        ratio = self._random_ratio()
+        # speed up expand for large images
+        if np.all(self.mean == self.mean[0]):
+            expand_img = np.empty((int(h * ratio), int(w * ratio), c),
+                                  img.dtype)
+            expand_img.fill(self.mean[0])
+        else:
+            expand_img = np.full((int(h * ratio), int(w * ratio), c),
+                                 self.mean,
+                                 dtype=img.dtype)
+        left, top = self._random_left_top(ratio, h, w)
+        expand_img[top:top + h, left:left + w] = img
+        results['img'] = expand_img
+        results['img_shape'] = expand_img.shape[:2]
+
+        # expand bboxes
+        if results.get('gt_bboxes', None) is not None:
+            results['gt_bboxes'].translate_([left, top])
+
+        # expand masks
+        if results.get('gt_masks', None) is not None:
+            results['gt_masks'] = results['gt_masks'].expand(
+                int(h * ratio), int(w * ratio), top, left)
+
+        # expand segmentation map
+        if results.get('gt_seg_map', None) is not None:
+            gt_seg = results['gt_seg_map']
+            expand_gt_seg = np.full((int(h * ratio), int(w * ratio)),
+                                    self.seg_ignore_label,
+                                    dtype=gt_seg.dtype)
+            expand_gt_seg[top:top + h, left:left + w] = gt_seg
+            results['gt_seg_map'] = expand_gt_seg
+
+        return results
+
+    def __repr__(self) -> str:
+        repr_str = self.__class__.__name__
+        repr_str += f'(mean={self.mean}, to_rgb={self.to_rgb}, '
+        repr_str += f'ratio_range={self.ratio_range}, '
+        repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
+        repr_str += f'prob={self.prob})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class MinIoURandomCrop(BaseTransform):
+    """Random crop the image & bboxes & masks & segmentation map, the cropped
+    patches have minimum IoU requirement with original image & bboxes & masks
+    & segmentation map, the IoU threshold is randomly selected from min_ious.
+
+    Required Keys:
+
+    - img
+    - img_shape
+    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
+    - gt_bboxes_labels (np.int64) (optional)
+    - gt_masks (BitmapMasks | PolygonMasks) (optional)
+    - gt_ignore_flags (bool) (optional)
+    - gt_seg_map (np.uint8) (optional)
+
+    Modified Keys:
+
+    - img
+    - img_shape
+    - gt_bboxes
+    - gt_bboxes_labels
+    - gt_masks
+    - gt_ignore_flags
+    - gt_seg_map
+
+    Args:
+        min_ious (Sequence[float]): minimum IoU threshold for all
+            intersections with bounding boxes.
+        min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
+            where a >= min_crop_size).
+        bbox_clip_border (bool, optional): Whether clip the objects outside
+            the border of the image. Defaults to True.
+    """
+
+    def __init__(self,
+                 min_ious: Sequence[float] = (0.1, 0.3, 0.5, 0.7, 0.9),
+                 min_crop_size: float = 0.3,
+                 bbox_clip_border: bool = True) -> None:
+
+        self.min_ious = min_ious
+        self.sample_mode = (1, *min_ious, 0)
+        self.min_crop_size = min_crop_size
+        self.bbox_clip_border = bbox_clip_border
+
+    @cache_randomness
+    def _random_mode(self) -> Number:
+        return random.choice(self.sample_mode)
+
+    @autocast_box_type()
+    def transform(self, results: dict) -> dict:
+        """Transform function to crop images and bounding boxes with minimum
+        IoU constraint.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Result dict with images and bounding boxes cropped, \
+            'img_shape' key is updated.
+        """
+        assert 'img' in results, '`img` is not found in results'
+        assert 'gt_bboxes' in results, '`gt_bboxes` is not found in results'
+        img = results['img']
+        boxes = results['gt_bboxes']
+        h, w, c = img.shape
+        while True:
+            mode = self._random_mode()
+            self.mode = mode
+            if mode == 1:
+                return results
+
+            min_iou = self.mode
+            for i in range(50):
+                new_w = random.uniform(self.min_crop_size * w, w)
+                new_h = random.uniform(self.min_crop_size * h, h)
+
+                # h / w in [0.5, 2]
+                if new_h / new_w < 0.5 or new_h / new_w > 2:
+                    continue
+
+                left = random.uniform(w - new_w)
+                top = random.uniform(h - new_h)
+
+                patch = np.array(
+                    (int(left), int(top), int(left + new_w), int(top + new_h)))
+                # Line or point crop is not allowed
+                if patch[2] == patch[0] or patch[3] == patch[1]:
+                    continue
+                overlaps = boxes.overlaps(
+                    HorizontalBoxes(patch.reshape(-1, 4).astype(np.float32)),
+                    boxes).numpy().reshape(-1)
+                if len(overlaps) > 0 and overlaps.min() < min_iou:
+                    continue
+
+                # center of boxes should be inside the cropped image
+                # only adjust boxes and instance masks when the gt is not empty
+                if len(overlaps) > 0:
+                    # adjust boxes
+                    def is_center_of_bboxes_in_patch(boxes, patch):
+                        centers = boxes.centers.numpy()
+                        mask = ((centers[:, 0] > patch[0]) *
+                                (centers[:, 1] > patch[1]) *
+                                (centers[:, 0] < patch[2]) *
+                                (centers[:, 1] < patch[3]))
+                        return mask
+
+                    mask = is_center_of_bboxes_in_patch(boxes, patch)
+                    if not mask.any():
+                        continue
+                    if results.get('gt_bboxes', None) is not None:
+                        boxes = results['gt_bboxes']
+                        mask = is_center_of_bboxes_in_patch(boxes, patch)
+                        boxes = boxes[mask]
+                        boxes.translate_([-patch[0], -patch[1]])
+                        if self.bbox_clip_border:
+                            boxes.clip_(
+                                [patch[3] - patch[1], patch[2] - patch[0]])
+                        results['gt_bboxes'] = boxes
+
+                        # ignore_flags
+                        if results.get('gt_ignore_flags', None) is not None:
+                            results['gt_ignore_flags'] = \
+                                results['gt_ignore_flags'][mask]
+
+                        # labels
+                        if results.get('gt_bboxes_labels', None) is not None:
+                            results['gt_bboxes_labels'] = \
+                                results['gt_bboxes_labels'][mask]
+
+                        # mask fields
+                        if results.get('gt_masks', None) is not None:
+                            results['gt_masks'] = results['gt_masks'][
+                                mask.nonzero()[0]].crop(patch)
+                # adjust the img no matter whether the gt is empty before crop
+                img = img[patch[1]:patch[3], patch[0]:patch[2]]
+                results['img'] = img
+                results['img_shape'] = img.shape[:2]
+
+                # seg fields
+                if results.get('gt_seg_map', None) is not None:
+                    results['gt_seg_map'] = results['gt_seg_map'][
+                        patch[1]:patch[3], patch[0]:patch[2]]
+                return results
+
+    def __repr__(self) -> str:
+        repr_str = self.__class__.__name__
+        repr_str += f'(min_ious={self.min_ious}, '
+        repr_str += f'min_crop_size={self.min_crop_size}, '
+        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class Corrupt(BaseTransform):
+    """Corruption augmentation.
+
+    Corruption transforms implemented based on
+    `imagecorruptions <https://github.com/bethgelab/imagecorruptions>`_.
+
+    Required Keys:
+
+    - img (np.uint8)
+
+    Modified Keys:
+
+    - img (np.uint8)
+
+    Args:
+        corruption (str): Corruption name.
+        severity (int): The severity of corruption. Defaults to 1.
+    """
+
+    def __init__(self, corruption: str, severity: int = 1) -> None:
+        self.corruption = corruption
+        self.severity = severity
+
+    def transform(self, results: dict) -> dict:
+        """Call function to corrupt image.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Result dict with images corrupted.
+        """
+
+        if corrupt is None:
+            raise RuntimeError('imagecorruptions is not installed')
+        results['img'] = corrupt(
+            results['img'].astype(np.uint8),
+            corruption_name=self.corruption,
+            severity=self.severity)
+        return results
+
+    def __repr__(self) -> str:
+        repr_str = self.__class__.__name__
+        repr_str += f'(corruption={self.corruption}, '
+        repr_str += f'severity={self.severity})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+@avoid_cache_randomness
+class Albu(BaseTransform):
+    """Albumentation augmentation.
+
+    Adds custom transformations from the Albumentations library.
+    Please visit `https://albumentations.readthedocs.io`
+    for more information.
+
+    Required Keys:
+
+    - img (np.uint8)
+    - gt_bboxes (HorizontalBoxes[torch.float32]) (optional)
+    - gt_masks (BitmapMasks | PolygonMasks) (optional)
+
+    Modified Keys:
+
+    - img (np.uint8)
+    - gt_bboxes (HorizontalBoxes[torch.float32]) (optional)
+    - gt_masks (BitmapMasks | PolygonMasks) (optional)
+    - img_shape (tuple)
+
+    An example of ``transforms`` is as follows:
+
+    .. code-block::
+
+        [
+            dict(
+                type='ShiftScaleRotate',
+                shift_limit=0.0625,
+                scale_limit=0.0,
+                rotate_limit=0,
+                interpolation=1,
+                p=0.5),
+            dict(
+                type='RandomBrightnessContrast',
+                brightness_limit=[0.1, 0.3],
+                contrast_limit=[0.1, 0.3],
+                p=0.2),
+            dict(type='ChannelShuffle', p=0.1),
+            dict(
+                type='OneOf',
+                transforms=[
+                    dict(type='Blur', blur_limit=3, p=1.0),
+                    dict(type='MedianBlur', blur_limit=3, p=1.0)
+                ],
+                p=0.1),
+        ]
+
+    Args:
+        transforms (list[dict]): A list of albu transformations.
+        bbox_params (dict, optional): Bbox_params for albumentation `Compose`.
+        keymap (dict, optional): Contains
+            {'input key':'albumentation-style key'}
+        skip_img_without_anno (bool): Whether to skip the image if no ann left
+            after aug. Defaults to False.
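+
+    A hypothetical construction sketch (parameter values are illustrative,
+    not prescribed by this file; requires albumentations to be installed):
+
+    Examples:
+        >>> albu = Albu(
+        ...     transforms=[dict(type='ChannelShuffle', p=0.1)],
+        ...     bbox_params=dict(
+        ...         type='BboxParams',
+        ...         format='pascal_voc',
+        ...         label_fields=['gt_bboxes_labels', 'gt_ignore_flags'],
+        ...         filter_lost_elements=True),
+        ...     keymap=dict(img='image', gt_bboxes='bboxes'))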
+ """ + + def __init__(self, + transforms: List[dict], + bbox_params: Optional[dict] = None, + keymap: Optional[dict] = None, + skip_img_without_anno: bool = False) -> None: + if Compose is None: + raise RuntimeError('albumentations is not installed') + + # Args will be modified later, copying it will be safer + transforms = copy.deepcopy(transforms) + if bbox_params is not None: + bbox_params = copy.deepcopy(bbox_params) + if keymap is not None: + keymap = copy.deepcopy(keymap) + self.transforms = transforms + self.filter_lost_elements = False + self.skip_img_without_anno = skip_img_without_anno + + # A simple workaround to remove masks without boxes + if (isinstance(bbox_params, dict) and 'label_fields' in bbox_params + and 'filter_lost_elements' in bbox_params): + self.filter_lost_elements = True + self.origin_label_fields = bbox_params['label_fields'] + bbox_params['label_fields'] = ['idx_mapper'] + del bbox_params['filter_lost_elements'] + + self.bbox_params = ( + self.albu_builder(bbox_params) if bbox_params else None) + self.aug = Compose([self.albu_builder(t) for t in self.transforms], + bbox_params=self.bbox_params) + + if not keymap: + self.keymap_to_albu = { + 'img': 'image', + 'gt_masks': 'masks', + 'gt_bboxes': 'bboxes' + } + else: + self.keymap_to_albu = keymap + self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()} + + def albu_builder(self, cfg: dict) -> albumentations: + """Import a module from albumentations. + + It inherits some of :func:`build_from_cfg` logic. + + Args: + cfg (dict): Config dict. It should at least contain the key "type". + + Returns: + obj: The constructed object. + """ + + assert isinstance(cfg, dict) and 'type' in cfg + args = cfg.copy() + obj_type = args.pop('type') + if is_str(obj_type): + if albumentations is None: + raise RuntimeError('albumentations is not installed') + obj_cls = getattr(albumentations, obj_type) + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError( + f'type must be a str or valid type, but got {type(obj_type)}') + + if 'transforms' in args: + args['transforms'] = [ + self.albu_builder(transform) + for transform in args['transforms'] + ] + + return obj_cls(**args) + + @staticmethod + def mapper(d: dict, keymap: dict) -> dict: + """Dictionary mapper. Renames keys according to keymap provided. + + Args: + d (dict): old dict + keymap (dict): {'old_key':'new_key'} + Returns: + dict: new dict. 
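+
+        A tiny illustration (hypothetical keys):
+
+        Examples:
+            >>> Albu.mapper({'img': 0, 'scale': 1}, {'img': 'image'})
+            {'image': 0, 'scale': 1}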
+ """ + updated_dict = {} + for k, v in zip(d.keys(), d.values()): + new_k = keymap.get(k, k) + updated_dict[new_k] = d[k] + return updated_dict + + @autocast_box_type() + def transform(self, results: dict) -> Union[dict, None]: + """Transform function of Albu.""" + # TODO: gt_seg_map is not currently supported + # dict to albumentations format + results = self.mapper(results, self.keymap_to_albu) + results, ori_masks = self._preprocess_results(results) + results = self.aug(**results) + results = self._postprocess_results(results, ori_masks) + if results is None: + return None + # back to the original format + results = self.mapper(results, self.keymap_back) + results['img_shape'] = results['img'].shape[:2] + return results + + def _preprocess_results(self, results: dict) -> tuple: + """Pre-processing results to facilitate the use of Albu.""" + if 'bboxes' in results: + # to list of boxes + if not isinstance(results['bboxes'], HorizontalBoxes): + raise NotImplementedError( + 'Albu only supports horizontal boxes now') + bboxes = results['bboxes'].numpy() + results['bboxes'] = [x for x in bboxes] + # add pseudo-field for filtration + if self.filter_lost_elements: + results['idx_mapper'] = np.arange(len(results['bboxes'])) + + # TODO: Support mask structure in albu + ori_masks = None + if 'masks' in results: + if isinstance(results['masks'], PolygonMasks): + raise NotImplementedError( + 'Albu only supports BitMap masks now') + ori_masks = results['masks'] + if albumentations.__version__ < '0.5': + results['masks'] = results['masks'].masks + else: + results['masks'] = [mask for mask in results['masks'].masks] + + return results, ori_masks + + def _postprocess_results( + self, + results: dict, + ori_masks: Optional[Union[BitmapMasks, + PolygonMasks]] = None) -> dict: + """Post-processing Albu output.""" + # albumentations may return np.array or list on different versions + if 'gt_bboxes_labels' in results and isinstance( + results['gt_bboxes_labels'], list): + results['gt_bboxes_labels'] = np.array( + results['gt_bboxes_labels'], dtype=np.int64) + if 'gt_ignore_flags' in results and isinstance( + results['gt_ignore_flags'], list): + results['gt_ignore_flags'] = np.array( + results['gt_ignore_flags'], dtype=bool) + + if 'bboxes' in results: + if isinstance(results['bboxes'], list): + results['bboxes'] = np.array( + results['bboxes'], dtype=np.float32) + results['bboxes'] = results['bboxes'].reshape(-1, 4) + results['bboxes'] = HorizontalBoxes(results['bboxes']) + + # filter label_fields + if self.filter_lost_elements: + + for label in self.origin_label_fields: + results[label] = np.array( + [results[label][i] for i in results['idx_mapper']]) + if 'masks' in results: + assert ori_masks is not None + results['masks'] = np.array( + [results['masks'][i] for i in results['idx_mapper']]) + results['masks'] = ori_masks.__class__( + results['masks'], ori_masks.height, ori_masks.width) + + if (not len(results['idx_mapper']) + and self.skip_img_without_anno): + return None + elif 'masks' in results: + results['masks'] = ori_masks.__class__(results['masks'], + ori_masks.height, + ori_masks.width) + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + f'(transforms={self.transforms})' + return repr_str + + +@TRANSFORMS.register_module() +@avoid_cache_randomness +class RandomCenterCropPad(BaseTransform): + """Random center crop and random around padding for CornerNet. + + This operation generates randomly cropped image from the original image and + pads it simultaneously. 
Different from :class:`RandomCrop`, the output
+    shape may not equal ``crop_size`` strictly. We choose a random value
+    from ``ratios`` and the output shape could be larger or smaller than
+    ``crop_size``. The padding operation is also different from :class:`Pad`,
+    here we use around padding instead of right-bottom padding.
+
+    The relation between output image (padding image) and original image:
+
+    .. code:: text
+
+                        output image
+
+               +----------------------------+
+               |          padded area       |
+        +------|----------------------------|----------+
+        |      |         cropped area       |          |
+        |      |         +---------------+  |          |
+        |      |         |    .   center |  |          | original image
+        |      |         |        range  |  |          |
+        |      |         +---------------+  |          |
+        +------|----------------------------|----------+
+               |          padded area       |
+               +----------------------------+
+
+    There are 5 main areas in the figure:
+
+    - output image: output image of this operation, also called padding
+      image in following instruction.
+    - original image: input image of this operation.
+    - padded area: non-intersect area of output image and original image.
+    - cropped area: the overlap of output image and original image.
+    - center range: a smaller area where the random center is chosen from.
+      The center range is computed from ``border`` and the original image's
+      shape to avoid the random center being too close to the original
+      image's border.
+
+    Also, this operation acts differently in train and test mode; the summary
+    pipeline is listed below.
+
+    Train pipeline:
+
+    1. Choose a ``random_ratio`` from ``ratios``, the shape of padding image
+       will be ``random_ratio * crop_size``.
+    2. Choose a ``random_center`` in center range.
+    3. Generate the padding image whose center matches ``random_center``.
+    4. Initialize the padding image with pixel value equal to ``mean``.
+    5. Copy the cropped area to the padding image.
+    6. Refine annotations.
+
+    Test pipeline:
+
+    1. Compute output shape according to ``test_pad_mode``.
+    2. Generate the padding image whose center matches the original image
+       center.
+    3. Initialize the padding image with pixel value equal to ``mean``.
+    4. Copy the ``cropped area`` to the padding image.
+
+    Required Keys:
+
+    - img (np.float32)
+    - img_shape (tuple)
+    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
+    - gt_bboxes_labels (np.int64) (optional)
+    - gt_ignore_flags (bool) (optional)
+
+    Modified Keys:
+
+    - img (np.float32)
+    - img_shape (tuple)
+    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
+    - gt_bboxes_labels (np.int64) (optional)
+    - gt_ignore_flags (bool) (optional)
+
+    Args:
+        crop_size (tuple, optional): expected size after crop, final size
+            will be computed according to ratio. Requires (width, height)
+            in train mode, and None in test mode.
+        ratios (tuple, optional): randomly select a ratio from the tuple and
+            crop image to (crop_size[0] * ratio) * (crop_size[1] * ratio).
+            Only available in train mode. Defaults to (0.9, 1.0, 1.1).
+        border (int, optional): max distance from the center-select area to
+            the image border. Only available in train mode. Defaults to 128.
+        mean (sequence, optional): Mean values of 3 channels.
+        std (sequence, optional): Std values of 3 channels.
+        to_rgb (bool, optional): Whether to convert the image from BGR to RGB.
+        test_mode (bool): whether to involve random variables in the
+            transform. In train mode, crop_size is fixed, and center coords
+            and ratio are randomly selected from predefined lists. In test
+            mode, crop_size is the image's original shape, and center coords
+            and ratio are fixed. Defaults to False.
+        test_pad_mode (tuple, optional): padding method and padding shape
+            value, only available in test mode. Default is using
+            'logical_or' with 127 as padding shape value.
+
+            - 'logical_or': final_shape = input_shape | padding_shape_value
+            - 'size_divisor': final_shape = int(
+              ceil(input_shape / padding_shape_value) * padding_shape_value)
+
+            Defaults to ('logical_or', 127).
+        test_pad_add_pix (int): Extra padding pixel in test mode.
+            Defaults to 0.
+        bbox_clip_border (bool): Whether clip the objects outside
+            the border of the image. Defaults to True.
+    """
+
+    def __init__(self,
+                 crop_size: Optional[tuple] = None,
+                 ratios: Optional[tuple] = (0.9, 1.0, 1.1),
+                 border: Optional[int] = 128,
+                 mean: Optional[Sequence] = None,
+                 std: Optional[Sequence] = None,
+                 to_rgb: Optional[bool] = None,
+                 test_mode: bool = False,
+                 test_pad_mode: Optional[tuple] = ('logical_or', 127),
+                 test_pad_add_pix: int = 0,
+                 bbox_clip_border: bool = True) -> None:
+        if test_mode:
+            assert crop_size is None, 'crop_size must be None in test mode'
+            assert ratios is None, 'ratios must be None in test mode'
+            assert border is None, 'border must be None in test mode'
+            assert isinstance(test_pad_mode, (list, tuple))
+            assert test_pad_mode[0] in ['logical_or', 'size_divisor']
+        else:
+            assert isinstance(crop_size, (list, tuple))
+            assert crop_size[0] > 0 and crop_size[1] > 0, (
+                'crop_size must be > 0 in train mode')
+            assert isinstance(ratios, (list, tuple))
+            assert test_pad_mode is None, (
+                'test_pad_mode must be None in train mode')
+
+        self.crop_size = crop_size
+        self.ratios = ratios
+        self.border = border
+        # We do not set default value to mean, std and to_rgb because these
+        # hyper-parameters are easy to forget but could affect the
+        # performance. Please use the same setting as Normalize for
+        # performance assurance.
+        assert mean is not None and std is not None and to_rgb is not None
+        self.to_rgb = to_rgb
+        self.input_mean = mean
+        self.input_std = std
+        if to_rgb:
+            self.mean = mean[::-1]
+            self.std = std[::-1]
+        else:
+            self.mean = mean
+            self.std = std
+        self.test_mode = test_mode
+        self.test_pad_mode = test_pad_mode
+        self.test_pad_add_pix = test_pad_add_pix
+        self.bbox_clip_border = bbox_clip_border
+
+    def _get_border(self, border, size):
+        """Get final border for the target size.
+
+        This function generates a ``final_border`` according to the image's
+        shape. The area between ``final_border`` and ``size - final_border``
+        is the ``center range``. We randomly choose the center from the
+        ``center range`` to avoid the random center being too close to the
+        original image's border. Also, ``center range`` should be larger
+        than 0.
+
+        Args:
+            border (int): The initial border, default is 128.
+            size (int): The width or height of original image.
+        Returns:
+            int: The final border.
+        """
+        k = 2 * border / size
+        i = pow(2, np.ceil(np.log2(np.ceil(k))) + (k == int(k)))
+        return border // i
+
+    def _filter_boxes(self, patch, boxes):
+        """Check whether the center of each box is in the patch.
+
+        Args:
+            patch (list[int]): The cropped area, [left, top, right, bottom].
+            boxes (numpy array, (N x 4)): Ground truth boxes.
+
+        Returns:
+            mask (numpy array, (N,)): Each box is inside or outside the patch.
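+
+        A hypothetical illustration (the ``HorizontalBoxes`` usage and the
+        train-mode constructor arguments below are assumptions for the
+        sketch):
+
+        Examples:
+            >>> import numpy as np
+            >>> boxes = HorizontalBoxes(
+            ...     np.array([[20, 20, 40, 40]], dtype=np.float32))
+            >>> pad = RandomCenterCropPad(
+            ...     crop_size=(100, 100), mean=(0, 0, 0), std=(1, 1, 1),
+            ...     to_rgb=False, test_pad_mode=None)
+            >>> pad._filter_boxes([0, 0, 50, 50], boxes)
+            array([ True])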
+        """
+        center = boxes.centers.numpy()
+        mask = (center[:, 0] > patch[0]) * (center[:, 1] > patch[1]) * (
+            center[:, 0] < patch[2]) * (
+                center[:, 1] < patch[3])
+        return mask
+
+    def _crop_image_and_paste(self, image, center, size):
+        """Crop image with a given center and size, then paste the cropped
+        image to a blank image with two centers aligned.
+
+        This function is equivalent to generating a blank image with ``size``
+        as its shape. Then cover it on the original image with two centers (
+        the center of blank image and the random center of original image)
+        aligned. The overlap area is pasted from the original image and the
+        outside area is filled with ``mean pixel``.
+
+        Args:
+            image (np array, H x W x C): Original image.
+            center (list[int]): Target crop center coord.
+            size (list[int]): Target crop size. [target_h, target_w]
+
+        Returns:
+            cropped_img (np array, target_h x target_w x C): Cropped image.
+            border (np array, 4): The distance of four border of
+                ``cropped_img`` to the original image area, [top, bottom,
+                left, right]
+            patch (list[int]): The cropped area, [left, top, right, bottom].
+        """
+        center_y, center_x = center
+        target_h, target_w = size
+        img_h, img_w, img_c = image.shape
+
+        x0 = max(0, center_x - target_w // 2)
+        x1 = min(center_x + target_w // 2, img_w)
+        y0 = max(0, center_y - target_h // 2)
+        y1 = min(center_y + target_h // 2, img_h)
+        patch = np.array((int(x0), int(y0), int(x1), int(y1)))
+
+        left, right = center_x - x0, x1 - center_x
+        top, bottom = center_y - y0, y1 - center_y
+
+        cropped_center_y, cropped_center_x = target_h // 2, target_w // 2
+        cropped_img = np.zeros((target_h, target_w, img_c), dtype=image.dtype)
+        for i in range(img_c):
+            cropped_img[:, :, i] += self.mean[i]
+        y_slice = slice(cropped_center_y - top, cropped_center_y + bottom)
+        x_slice = slice(cropped_center_x - left, cropped_center_x + right)
+        cropped_img[y_slice, x_slice, :] = image[y0:y1, x0:x1, :]
+
+        border = np.array([
+            cropped_center_y - top, cropped_center_y + bottom,
+            cropped_center_x - left, cropped_center_x + right
+        ],
+                          dtype=np.float32)
+
+        return cropped_img, border, patch
+
+    def _train_aug(self, results):
+        """Random crop and around padding the original image.
+
+        Args:
+            results (dict): Image information in the augment pipeline.
+
+        Returns:
+            results (dict): The updated dict.
+        """
+        img = results['img']
+        h, w, c = img.shape
+        gt_bboxes = results['gt_bboxes']
+        while True:
+            scale = random.choice(self.ratios)
+            new_h = int(self.crop_size[1] * scale)
+            new_w = int(self.crop_size[0] * scale)
+            h_border = self._get_border(self.border, h)
+            w_border = self._get_border(self.border, w)
+
+            for i in range(50):
+                center_x = random.randint(low=w_border, high=w - w_border)
+                center_y = random.randint(low=h_border, high=h - h_border)
+
+                cropped_img, border, patch = self._crop_image_and_paste(
+                    img, [center_y, center_x], [new_h, new_w])
+
+                # if the image does not have a valid bbox, any crop patch is
+                # valid
+                if len(gt_bboxes) == 0:
+                    results['img'] = cropped_img
+                    results['img_shape'] = cropped_img.shape[:2]
+                    return results
+
+                mask = self._filter_boxes(patch, gt_bboxes)
+                if not mask.any():
+                    continue
+
+                results['img'] = cropped_img
+                results['img_shape'] = cropped_img.shape[:2]
+
+                x0, y0, x1, y1 = patch
+
+                left_w, top_h = center_x - x0, center_y - y0
+                cropped_center_x, cropped_center_y = new_w // 2, new_h // 2
+
+                # crop bboxes accordingly and clip to the image boundary
+                gt_bboxes = gt_bboxes[mask]
+                gt_bboxes.translate_([
+                    cropped_center_x - left_w - x0,
+                    cropped_center_y - top_h - y0
+                ])
+                if self.bbox_clip_border:
+                    gt_bboxes.clip_([new_h, new_w])
+                keep = gt_bboxes.is_inside([new_h, new_w]).numpy()
+                gt_bboxes = gt_bboxes[keep]
+
+                results['gt_bboxes'] = gt_bboxes
+
+                # ignore_flags
+                if results.get('gt_ignore_flags', None) is not None:
+                    gt_ignore_flags = results['gt_ignore_flags'][mask]
+                    results['gt_ignore_flags'] = \
+                        gt_ignore_flags[keep]
+
+                # labels
+                if results.get('gt_bboxes_labels', None) is not None:
+                    gt_labels = results['gt_bboxes_labels'][mask]
+                    results['gt_bboxes_labels'] = gt_labels[keep]
+
+                if 'gt_masks' in results or 'gt_seg_map' in results:
+                    raise NotImplementedError(
+                        'RandomCenterCropPad only supports bbox.')
+
+                return results
+
+    def _test_aug(self, results):
+        """Around padding the original image without cropping.
+
+        The padding mode and value are from ``test_pad_mode``.
+
+        Args:
+            results (dict): Image information in the augment pipeline.
+
+        Returns:
+            results (dict): The updated dict.
+        """
+        img = results['img']
+        h, w, c = img.shape
+        if self.test_pad_mode[0] in ['logical_or']:
+            # self.test_pad_add_pix is only used for centernet
+            target_h = (h | self.test_pad_mode[1]) + self.test_pad_add_pix
+            target_w = (w | self.test_pad_mode[1]) + self.test_pad_add_pix
+        elif self.test_pad_mode[0] in ['size_divisor']:
+            divisor = self.test_pad_mode[1]
+            target_h = int(np.ceil(h / divisor)) * divisor
+            target_w = int(np.ceil(w / divisor)) * divisor
+        else:
+            raise NotImplementedError(
+                'RandomCenterCropPad only supports two testing pad modes: '
+                'logical_or and size_divisor.')
+
+        cropped_img, border, _ = self._crop_image_and_paste(
+            img, [h // 2, w // 2], [target_h, target_w])
+        results['img'] = cropped_img
+        results['img_shape'] = cropped_img.shape[:2]
+        results['border'] = border
+        return results
+
+    @autocast_box_type()
+    def transform(self, results: dict) -> dict:
+        img = results['img']
+        assert img.dtype == np.float32, (
+            'RandomCenterCropPad needs the input image of dtype np.float32,'
+            ' please set "to_float32=True" in "LoadImageFromFile" pipeline')
+        h, w, c = img.shape
+        assert c == len(self.mean)
+        if self.test_mode:
+            return self._test_aug(results)
+        else:
+            return self._train_aug(results)
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(crop_size={self.crop_size}, '
+        repr_str += f'ratios={self.ratios}, '
+        repr_str += f'border={self.border}, '
+        repr_str += f'mean={self.input_mean}, '
+        repr_str += f'std={self.input_std}, '
+        repr_str += f'to_rgb={self.to_rgb}, '
+        repr_str += f'test_mode={self.test_mode}, '
+        repr_str += f'test_pad_mode={self.test_pad_mode}, '
+        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class CutOut(BaseTransform):
+    """CutOut operation.
+
+    Randomly drop some regions of image used in
+    `Cutout <https://arxiv.org/abs/1708.04552>`_.
+
+    Required Keys:
+
+    - img
+
+    Modified Keys:
+
+    - img
+
+    Args:
+        n_holes (int or tuple[int, int]): Number of regions to be dropped.
+ If it is given as a list, number of holes will be randomly + selected from the closed interval [``n_holes[0]``, ``n_holes[1]``]. + cutout_shape (tuple[int, int] or list[tuple[int, int]], optional): + The candidate shape of dropped regions. It can be + ``tuple[int, int]`` to use a fixed cutout shape, or + ``list[tuple[int, int]]`` to randomly choose shape + from the list. Defaults to None. + cutout_ratio (tuple[float, float] or list[tuple[float, float]], + optional): The candidate ratio of dropped regions. It can be + ``tuple[float, float]`` to use a fixed ratio or + ``list[tuple[float, float]]`` to randomly choose ratio + from the list. Please note that ``cutout_shape`` and + ``cutout_ratio`` cannot be both given at the same time. + Defaults to None. + fill_in (tuple[float, float, float] or tuple[int, int, int]): The value + of pixel to fill in the dropped regions. Defaults to (0, 0, 0). + """ + + def __init__( + self, + n_holes: Union[int, Tuple[int, int]], + cutout_shape: Optional[Union[Tuple[int, int], + List[Tuple[int, int]]]] = None, + cutout_ratio: Optional[Union[Tuple[float, float], + List[Tuple[float, float]]]] = None, + fill_in: Union[Tuple[float, float, float], Tuple[int, int, + int]] = (0, 0, 0) + ) -> None: + + assert (cutout_shape is None) ^ (cutout_ratio is None), \ + 'Either cutout_shape or cutout_ratio should be specified.' + assert (isinstance(cutout_shape, (list, tuple)) + or isinstance(cutout_ratio, (list, tuple))) + if isinstance(n_holes, tuple): + assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1] + else: + n_holes = (n_holes, n_holes) + self.n_holes = n_holes + self.fill_in = fill_in + self.with_ratio = cutout_ratio is not None + self.candidates = cutout_ratio if self.with_ratio else cutout_shape + if not isinstance(self.candidates, list): + self.candidates = [self.candidates] + + @autocast_box_type() + def transform(self, results: dict) -> dict: + """Call function to drop some regions of image.""" + h, w, c = results['img'].shape + n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1) + for _ in range(n_holes): + x1 = np.random.randint(0, w) + y1 = np.random.randint(0, h) + index = np.random.randint(0, len(self.candidates)) + if not self.with_ratio: + cutout_w, cutout_h = self.candidates[index] + else: + cutout_w = int(self.candidates[index][0] * w) + cutout_h = int(self.candidates[index][1] * h) + + x2 = np.clip(x1 + cutout_w, 0, w) + y2 = np.clip(y1 + cutout_h, 0, h) + results['img'][y1:y2, x1:x2, :] = self.fill_in + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(n_holes={self.n_holes}, ' + repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio + else f'cutout_shape={self.candidates}, ') + repr_str += f'fill_in={self.fill_in})' + return repr_str + + +@TRANSFORMS.register_module() +class Mosaic(BaseTransform): + """Mosaic augmentation. + + Given 4 images, mosaic transform combines them into + one output image. The output image is composed of the parts from each sub- + image. + + .. code:: text + + mosaic transform + center_x + +------------------------------+ + | pad | pad | + | +-----------+ | + | | | | + | | image1 |--------+ | + | | | | | + | | | image2 | | + center_y |----+-------------+-----------| + | | cropped | | + |pad | image3 | image4 | + | | | | + +----|-------------+-----------+ + | | + +-------------+ + + The mosaic transform steps are as follows: + + 1. Choose the mosaic center as the intersections of 4 images + 2. 
Get the left top image according to the index, and randomly + sample another 3 images from the custom dataset. + 3. Sub image will be cropped if image is larger than mosaic patch + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_bboxes_labels (np.int64) (optional) + - gt_ignore_flags (bool) (optional) + - mix_results (List[dict]) + + Modified Keys: + + - img + - img_shape + - gt_bboxes (optional) + - gt_bboxes_labels (optional) + - gt_ignore_flags (optional) + + Args: + img_scale (Sequence[int]): Image size before mosaic pipeline of single + image. The shape order should be (width, height). + Defaults to (640, 640). + center_ratio_range (Sequence[float]): Center ratio range of mosaic + output. Defaults to (0.5, 1.5). + bbox_clip_border (bool, optional): Whether to clip the objects outside + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. + pad_val (int): Pad value. Defaults to 114. + prob (float): Probability of applying this transformation. + Defaults to 1.0. + """ + + def __init__(self, + img_scale: Tuple[int, int] = (640, 640), + center_ratio_range: Tuple[float, float] = (0.5, 1.5), + bbox_clip_border: bool = True, + pad_val: float = 114.0, + prob: float = 1.0) -> None: + assert isinstance(img_scale, tuple) + assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. ' \ + f'got {prob}.' + + log_img_scale(img_scale, skip_square=True, shape_order='wh') + self.img_scale = img_scale + self.center_ratio_range = center_ratio_range + self.bbox_clip_border = bbox_clip_border + self.pad_val = pad_val + self.prob = prob + + @cache_randomness + def get_indexes(self, dataset: BaseDataset) -> int: + """Call function to collect indexes. + + Args: + dataset (:obj:`MultiImageMixDataset`): The dataset. + + Returns: + list: indexes. + """ + + indexes = [random.randint(0, len(dataset)) for _ in range(3)] + return indexes + + @autocast_box_type() + def transform(self, results: dict) -> dict: + """Mosaic transform function. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. 
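+
+        A hypothetical wiring sketch (``mix_results`` must be pre-filled,
+        e.g. by ``MultiImageMixDataset``; the values below are illustrative,
+        not prescribed by this file):
+
+        Examples:
+            >>> train_pipeline = [
+            ...     dict(type='Mosaic', img_scale=(640, 640), pad_val=114.0),
+            ...     dict(type='RandomAffine', border=(-320, -320)),
+            ... ]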
+ """ + if random.uniform(0, 1) > self.prob: + return results + + assert 'mix_results' in results + mosaic_bboxes = [] + mosaic_bboxes_labels = [] + mosaic_ignore_flags = [] + if len(results['img'].shape) == 3: + mosaic_img = np.full( + (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2), 3), + self.pad_val, + dtype=results['img'].dtype) + else: + mosaic_img = np.full( + (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2)), + self.pad_val, + dtype=results['img'].dtype) + + # mosaic center x, y + center_x = int( + random.uniform(*self.center_ratio_range) * self.img_scale[0]) + center_y = int( + random.uniform(*self.center_ratio_range) * self.img_scale[1]) + center_position = (center_x, center_y) + + loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right') + for i, loc in enumerate(loc_strs): + if loc == 'top_left': + results_patch = copy.deepcopy(results) + else: + results_patch = copy.deepcopy(results['mix_results'][i - 1]) + + img_i = results_patch['img'] + h_i, w_i = img_i.shape[:2] + # keep_ratio resize + scale_ratio_i = min(self.img_scale[1] / h_i, + self.img_scale[0] / w_i) + img_i = mmcv.imresize( + img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i))) + + # compute the combine parameters + paste_coord, crop_coord = self._mosaic_combine( + loc, center_position, img_i.shape[:2][::-1]) + x1_p, y1_p, x2_p, y2_p = paste_coord + x1_c, y1_c, x2_c, y2_c = crop_coord + + # crop and paste image + mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c] + + # adjust coordinate + gt_bboxes_i = results_patch['gt_bboxes'] + gt_bboxes_labels_i = results_patch['gt_bboxes_labels'] + gt_ignore_flags_i = results_patch['gt_ignore_flags'] + + padw = x1_p - x1_c + padh = y1_p - y1_c + gt_bboxes_i.rescale_([scale_ratio_i, scale_ratio_i]) + gt_bboxes_i.translate_([padw, padh]) + mosaic_bboxes.append(gt_bboxes_i) + mosaic_bboxes_labels.append(gt_bboxes_labels_i) + mosaic_ignore_flags.append(gt_ignore_flags_i) + + mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0) + mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0) + mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0) + + if self.bbox_clip_border: + mosaic_bboxes.clip_([2 * self.img_scale[1], 2 * self.img_scale[0]]) + # remove outside bboxes + inside_inds = mosaic_bboxes.is_inside( + [2 * self.img_scale[1], 2 * self.img_scale[0]]).numpy() + mosaic_bboxes = mosaic_bboxes[inside_inds] + mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds] + mosaic_ignore_flags = mosaic_ignore_flags[inside_inds] + + results['img'] = mosaic_img + results['img_shape'] = mosaic_img.shape[:2] + results['gt_bboxes'] = mosaic_bboxes + results['gt_bboxes_labels'] = mosaic_bboxes_labels + results['gt_ignore_flags'] = mosaic_ignore_flags + return results + + def _mosaic_combine( + self, loc: str, center_position_xy: Sequence[float], + img_shape_wh: Sequence[int]) -> Tuple[Tuple[int], Tuple[int]]: + """Calculate global coordinate of mosaic image and local coordinate of + cropped sub-image. + + Args: + loc (str): Index for the sub-image, loc in ('top_left', + 'top_right', 'bottom_left', 'bottom_right'). + center_position_xy (Sequence[float]): Mixing center for 4 images, + (x, y). + img_shape_wh (Sequence[int]): Width and height of sub-image + + Returns: + tuple[tuple[float]]: Corresponding coordinate of pasting and + cropping + - paste_coord (tuple): paste corner coordinate in mosaic image. + - crop_coord (tuple): crop corner coordinate in mosaic image. 
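+
+        A worked example (values invented for illustration): with
+        ``loc='top_left'``, ``center_position_xy=(320, 320)`` and
+        ``img_shape_wh=(200, 100)``, ``paste_coord`` is (120, 220, 320, 320)
+        and ``crop_coord`` is (0, 0, 200, 100): the whole sub-image is pasted
+        with its bottom-right corner at the mosaic center.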
+ """ + assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right') + if loc == 'top_left': + # index0 to top left part of image + x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \ + max(center_position_xy[1] - img_shape_wh[1], 0), \ + center_position_xy[0], \ + center_position_xy[1] + crop_coord = img_shape_wh[0] - (x2 - x1), img_shape_wh[1] - ( + y2 - y1), img_shape_wh[0], img_shape_wh[1] + + elif loc == 'top_right': + # index1 to top right part of image + x1, y1, x2, y2 = center_position_xy[0], \ + max(center_position_xy[1] - img_shape_wh[1], 0), \ + min(center_position_xy[0] + img_shape_wh[0], + self.img_scale[0] * 2), \ + center_position_xy[1] + crop_coord = 0, img_shape_wh[1] - (y2 - y1), min( + img_shape_wh[0], x2 - x1), img_shape_wh[1] + + elif loc == 'bottom_left': + # index2 to bottom left part of image + x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \ + center_position_xy[1], \ + center_position_xy[0], \ + min(self.img_scale[1] * 2, center_position_xy[1] + + img_shape_wh[1]) + crop_coord = img_shape_wh[0] - (x2 - x1), 0, img_shape_wh[0], min( + y2 - y1, img_shape_wh[1]) + + else: + # index3 to bottom right part of image + x1, y1, x2, y2 = center_position_xy[0], \ + center_position_xy[1], \ + min(center_position_xy[0] + img_shape_wh[0], + self.img_scale[0] * 2), \ + min(self.img_scale[1] * 2, center_position_xy[1] + + img_shape_wh[1]) + crop_coord = 0, 0, min(img_shape_wh[0], + x2 - x1), min(y2 - y1, img_shape_wh[1]) + + paste_coord = x1, y1, x2, y2 + return paste_coord, crop_coord + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(img_scale={self.img_scale}, ' + repr_str += f'center_ratio_range={self.center_ratio_range}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'prob={self.prob})' + return repr_str + + +@TRANSFORMS.register_module() +class MixUp(BaseTransform): + """MixUp data augmentation. + + .. code:: text + + mixup transform + +------------------------------+ + | mixup image | | + | +--------|--------+ | + | | | | | + |---------------+ | | + | | | | + | | image | | + | | | | + | | | | + | |-----------------+ | + | pad | + +------------------------------+ + + The mixup transform steps are as follows: + + 1. Another random image is picked by dataset and embedded in + the top left patch(after padding and resizing) + 2. The target of mixup transform is the weighted average of mixup + image and origin image. + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_bboxes_labels (np.int64) (optional) + - gt_ignore_flags (bool) (optional) + - mix_results (List[dict]) + + + Modified Keys: + + - img + - img_shape + - gt_bboxes (optional) + - gt_bboxes_labels (optional) + - gt_ignore_flags (optional) + + + Args: + img_scale (Sequence[int]): Image output size after mixup pipeline. + The shape order should be (width, height). Defaults to (640, 640). + ratio_range (Sequence[float]): Scale ratio of mixup image. + Defaults to (0.5, 1.5). + flip_ratio (float): Horizontal flip ratio of mixup image. + Defaults to 0.5. + pad_val (int): Pad value. Defaults to 114. + max_iters (int): The maximum number of iterations. If the number of + iterations is greater than `max_iters`, but gt_bbox is still + empty, then the iteration is terminated. Defaults to 15. + bbox_clip_border (bool, optional): Whether to clip the objects outside + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. 
Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. + """ + + def __init__(self, + img_scale: Tuple[int, int] = (640, 640), + ratio_range: Tuple[float, float] = (0.5, 1.5), + flip_ratio: float = 0.5, + pad_val: float = 114.0, + max_iters: int = 15, + bbox_clip_border: bool = True) -> None: + assert isinstance(img_scale, tuple) + log_img_scale(img_scale, skip_square=True, shape_order='wh') + self.dynamic_scale = img_scale + self.ratio_range = ratio_range + self.flip_ratio = flip_ratio + self.pad_val = pad_val + self.max_iters = max_iters + self.bbox_clip_border = bbox_clip_border + + @cache_randomness + def get_indexes(self, dataset: BaseDataset) -> int: + """Call function to collect indexes. + + Args: + dataset (:obj:`MultiImageMixDataset`): The dataset. + + Returns: + list: indexes. + """ + + for i in range(self.max_iters): + index = random.randint(0, len(dataset)) + gt_bboxes_i = dataset[index]['gt_bboxes'] + if len(gt_bboxes_i) != 0: + break + + return index + + @autocast_box_type() + def transform(self, results: dict) -> dict: + """MixUp transform function. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. + """ + + assert 'mix_results' in results + assert len( + results['mix_results']) == 1, 'MixUp only support 2 images now !' + + if results['mix_results'][0]['gt_bboxes'].shape[0] == 0: + # empty bbox + return results + + retrieve_results = results['mix_results'][0] + retrieve_img = retrieve_results['img'] + + jit_factor = random.uniform(*self.ratio_range) + is_flip = random.uniform(0, 1) > self.flip_ratio + + if len(retrieve_img.shape) == 3: + out_img = np.ones( + (self.dynamic_scale[1], self.dynamic_scale[0], 3), + dtype=retrieve_img.dtype) * self.pad_val + else: + out_img = np.ones( + self.dynamic_scale[::-1], + dtype=retrieve_img.dtype) * self.pad_val + + # 1. keep_ratio resize + scale_ratio = min(self.dynamic_scale[1] / retrieve_img.shape[0], + self.dynamic_scale[0] / retrieve_img.shape[1]) + retrieve_img = mmcv.imresize( + retrieve_img, (int(retrieve_img.shape[1] * scale_ratio), + int(retrieve_img.shape[0] * scale_ratio))) + + # 2. paste + out_img[:retrieve_img.shape[0], :retrieve_img.shape[1]] = retrieve_img + + # 3. scale jit + scale_ratio *= jit_factor + out_img = mmcv.imresize(out_img, (int(out_img.shape[1] * jit_factor), + int(out_img.shape[0] * jit_factor))) + + # 4. flip + if is_flip: + out_img = out_img[:, ::-1, :] + + # 5. random crop + ori_img = results['img'] + origin_h, origin_w = out_img.shape[:2] + target_h, target_w = ori_img.shape[:2] + padded_img = np.ones((max(origin_h, target_h), max( + origin_w, target_w), 3)) * self.pad_val + padded_img = padded_img.astype(np.uint8) + padded_img[:origin_h, :origin_w] = out_img + + x_offset, y_offset = 0, 0 + if padded_img.shape[0] > target_h: + y_offset = random.randint(0, padded_img.shape[0] - target_h) + if padded_img.shape[1] > target_w: + x_offset = random.randint(0, padded_img.shape[1] - target_w) + padded_cropped_img = padded_img[y_offset:y_offset + target_h, + x_offset:x_offset + target_w] + + # 6. adjust bbox + retrieve_gt_bboxes = retrieve_results['gt_bboxes'] + retrieve_gt_bboxes.rescale_([scale_ratio, scale_ratio]) + if self.bbox_clip_border: + retrieve_gt_bboxes.clip_([origin_h, origin_w]) + + if is_flip: + retrieve_gt_bboxes.flip_([origin_h, origin_w], + direction='horizontal') + + # 7. 
filter + cp_retrieve_gt_bboxes = retrieve_gt_bboxes.clone() + cp_retrieve_gt_bboxes.translate_([-x_offset, -y_offset]) + if self.bbox_clip_border: + cp_retrieve_gt_bboxes.clip_([target_h, target_w]) + + # 8. mix up + ori_img = ori_img.astype(np.float32) + mixup_img = 0.5 * ori_img + 0.5 * padded_cropped_img.astype(np.float32) + + retrieve_gt_bboxes_labels = retrieve_results['gt_bboxes_labels'] + retrieve_gt_ignore_flags = retrieve_results['gt_ignore_flags'] + + mixup_gt_bboxes = cp_retrieve_gt_bboxes.cat( + (results['gt_bboxes'], cp_retrieve_gt_bboxes), dim=0) + mixup_gt_bboxes_labels = np.concatenate( + (results['gt_bboxes_labels'], retrieve_gt_bboxes_labels), axis=0) + mixup_gt_ignore_flags = np.concatenate( + (results['gt_ignore_flags'], retrieve_gt_ignore_flags), axis=0) + + # remove outside bbox + inside_inds = mixup_gt_bboxes.is_inside([target_h, target_w]).numpy() + mixup_gt_bboxes = mixup_gt_bboxes[inside_inds] + mixup_gt_bboxes_labels = mixup_gt_bboxes_labels[inside_inds] + mixup_gt_ignore_flags = mixup_gt_ignore_flags[inside_inds] + + results['img'] = mixup_img.astype(np.uint8) + results['img_shape'] = mixup_img.shape[:2] + results['gt_bboxes'] = mixup_gt_bboxes + results['gt_bboxes_labels'] = mixup_gt_bboxes_labels + results['gt_ignore_flags'] = mixup_gt_ignore_flags + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(dynamic_scale={self.dynamic_scale}, ' + repr_str += f'ratio_range={self.ratio_range}, ' + repr_str += f'flip_ratio={self.flip_ratio}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'max_iters={self.max_iters}, ' + repr_str += f'bbox_clip_border={self.bbox_clip_border})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomAffine(BaseTransform): + """Random affine transform data augmentation. + + This operation randomly generates affine transform matrix which including + rotation, translation, shear and scaling transforms. + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_bboxes_labels (np.int64) (optional) + - gt_ignore_flags (bool) (optional) + + Modified Keys: + + - img + - img_shape + - gt_bboxes (optional) + - gt_bboxes_labels (optional) + - gt_ignore_flags (optional) + + Args: + max_rotate_degree (float): Maximum degrees of rotation transform. + Defaults to 10. + max_translate_ratio (float): Maximum ratio of translation. + Defaults to 0.1. + scaling_ratio_range (tuple[float]): Min and max ratio of + scaling transform. Defaults to (0.5, 1.5). + max_shear_degree (float): Maximum degrees of shear + transform. Defaults to 2. + border (tuple[int]): Distance from width and height sides of input + image to adjust output shape. Only used in mosaic dataset. + Defaults to (0, 0). + border_val (tuple[int]): Border padding values of 3 channels. + Defaults to (114, 114, 114). + bbox_clip_border (bool, optional): Whether to clip the objects outside + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. 
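+
+    A minimal usage sketch (shapes and values are invented for illustration;
+    ``HorizontalBoxes`` is assumed to be the box type in use):
+
+    Examples:
+        >>> import numpy as np
+        >>> transform = RandomAffine(
+        ...     max_rotate_degree=10.0, scaling_ratio_range=(0.8, 1.2))
+        >>> results = dict(
+        ...     img=np.zeros((100, 100, 3), dtype=np.uint8),
+        ...     gt_bboxes=HorizontalBoxes(np.zeros((0, 4), np.float32)),
+        ...     gt_bboxes_labels=np.zeros((0, ), dtype=np.int64),
+        ...     gt_ignore_flags=np.zeros((0, ), dtype=bool))
+        >>> results = transform(results)
+        >>> results['img'].shape
+        (100, 100, 3)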
+ """ + + def __init__(self, + max_rotate_degree: float = 10.0, + max_translate_ratio: float = 0.1, + scaling_ratio_range: Tuple[float, float] = (0.5, 1.5), + max_shear_degree: float = 2.0, + border: Tuple[int, int] = (0, 0), + border_val: Tuple[int, int, int] = (114, 114, 114), + bbox_clip_border: bool = True) -> None: + assert 0 <= max_translate_ratio <= 1 + assert scaling_ratio_range[0] <= scaling_ratio_range[1] + assert scaling_ratio_range[0] > 0 + self.max_rotate_degree = max_rotate_degree + self.max_translate_ratio = max_translate_ratio + self.scaling_ratio_range = scaling_ratio_range + self.max_shear_degree = max_shear_degree + self.border = border + self.border_val = border_val + self.bbox_clip_border = bbox_clip_border + + @cache_randomness + def _get_random_homography_matrix(self, height, width): + # Rotation + rotation_degree = random.uniform(-self.max_rotate_degree, + self.max_rotate_degree) + rotation_matrix = self._get_rotation_matrix(rotation_degree) + + # Scaling + scaling_ratio = random.uniform(self.scaling_ratio_range[0], + self.scaling_ratio_range[1]) + scaling_matrix = self._get_scaling_matrix(scaling_ratio) + + # Shear + x_degree = random.uniform(-self.max_shear_degree, + self.max_shear_degree) + y_degree = random.uniform(-self.max_shear_degree, + self.max_shear_degree) + shear_matrix = self._get_shear_matrix(x_degree, y_degree) + + # Translation + trans_x = random.uniform(-self.max_translate_ratio, + self.max_translate_ratio) * width + trans_y = random.uniform(-self.max_translate_ratio, + self.max_translate_ratio) * height + translate_matrix = self._get_translation_matrix(trans_x, trans_y) + + warp_matrix = ( + translate_matrix @ shear_matrix @ rotation_matrix @ scaling_matrix) + return warp_matrix + + @autocast_box_type() + def transform(self, results: dict) -> dict: + img = results['img'] + height = img.shape[0] + self.border[1] * 2 + width = img.shape[1] + self.border[0] * 2 + + warp_matrix = self._get_random_homography_matrix(height, width) + + img = cv2.warpPerspective( + img, + warp_matrix, + dsize=(width, height), + borderValue=self.border_val) + results['img'] = img + results['img_shape'] = img.shape[:2] + + bboxes = results['gt_bboxes'] + num_bboxes = len(bboxes) + if num_bboxes: + bboxes.project_(warp_matrix) + if self.bbox_clip_border: + bboxes.clip_([height, width]) + # remove outside bbox + valid_index = bboxes.is_inside([height, width]).numpy() + results['gt_bboxes'] = bboxes[valid_index] + results['gt_bboxes_labels'] = results['gt_bboxes_labels'][ + valid_index] + results['gt_ignore_flags'] = results['gt_ignore_flags'][ + valid_index] + + if 'gt_masks' in results: + raise NotImplementedError('RandomAffine only supports bbox.') + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(max_rotate_degree={self.max_rotate_degree}, ' + repr_str += f'max_translate_ratio={self.max_translate_ratio}, ' + repr_str += f'scaling_ratio_range={self.scaling_ratio_range}, ' + repr_str += f'max_shear_degree={self.max_shear_degree}, ' + repr_str += f'border={self.border}, ' + repr_str += f'border_val={self.border_val}, ' + repr_str += f'bbox_clip_border={self.bbox_clip_border})' + return repr_str + + @staticmethod + def _get_rotation_matrix(rotate_degrees: float) -> np.ndarray: + radian = math.radians(rotate_degrees) + rotation_matrix = np.array( + [[np.cos(radian), -np.sin(radian), 0.], + [np.sin(radian), np.cos(radian), 0.], [0., 0., 1.]], + dtype=np.float32) + return rotation_matrix + + @staticmethod + def 
_get_scaling_matrix(scale_ratio: float) -> np.ndarray:
+        scaling_matrix = np.array(
+            [[scale_ratio, 0., 0.], [0., scale_ratio, 0.], [0., 0., 1.]],
+            dtype=np.float32)
+        return scaling_matrix
+
+    @staticmethod
+    def _get_shear_matrix(x_shear_degrees: float,
+                          y_shear_degrees: float) -> np.ndarray:
+        x_radian = math.radians(x_shear_degrees)
+        y_radian = math.radians(y_shear_degrees)
+        shear_matrix = np.array([[1, np.tan(x_radian), 0.],
+                                 [np.tan(y_radian), 1, 0.], [0., 0., 1.]],
+                                dtype=np.float32)
+        return shear_matrix
+
+    @staticmethod
+    def _get_translation_matrix(x: float, y: float) -> np.ndarray:
+        translation_matrix = np.array([[1, 0., x], [0., 1, y], [0., 0., 1.]],
+                                      dtype=np.float32)
+        return translation_matrix
+
+
+@TRANSFORMS.register_module()
+class YOLOXHSVRandomAug(BaseTransform):
+    """Apply HSV augmentation to image sequentially. It is referenced from
+    https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/data/data_augment.py#L21.
+
+    Required Keys:
+
+    - img
+
+    Modified Keys:
+
+    - img
+
+    Args:
+        hue_delta (int): delta of hue. Defaults to 5.
+        saturation_delta (int): delta of saturation. Defaults to 30.
+        value_delta (int): delta of value. Defaults to 30.
+    """
+
+    def __init__(self,
+                 hue_delta: int = 5,
+                 saturation_delta: int = 30,
+                 value_delta: int = 30) -> None:
+        self.hue_delta = hue_delta
+        self.saturation_delta = saturation_delta
+        self.value_delta = value_delta
+
+    @cache_randomness
+    def _get_hsv_gains(self):
+        hsv_gains = np.random.uniform(-1, 1, 3) * [
+            self.hue_delta, self.saturation_delta, self.value_delta
+        ]
+        # random selection of h, s, v
+        hsv_gains *= np.random.randint(0, 2, 3)
+        # prevent overflow
+        hsv_gains = hsv_gains.astype(np.int16)
+        return hsv_gains
+
+    def transform(self, results: dict) -> dict:
+        img = results['img']
+        hsv_gains = self._get_hsv_gains()
+        img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.int16)
+
+        img_hsv[..., 0] = (img_hsv[..., 0] + hsv_gains[0]) % 180
+        img_hsv[..., 1] = np.clip(img_hsv[..., 1] + hsv_gains[1], 0, 255)
+        img_hsv[..., 2] = np.clip(img_hsv[..., 2] + hsv_gains[2], 0, 255)
+        cv2.cvtColor(img_hsv.astype(img.dtype), cv2.COLOR_HSV2BGR, dst=img)
+
+        results['img'] = img
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(hue_delta={self.hue_delta}, '
+        repr_str += f'saturation_delta={self.saturation_delta}, '
+        repr_str += f'value_delta={self.value_delta})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class CopyPaste(BaseTransform):
+    """Simple Copy-Paste is a Strong Data Augmentation Method for Instance
+    Segmentation. The simple copy-paste transform steps are as follows:
+
+    1. The destination image is already resized with aspect ratio kept,
+       cropped and padded.
+    2. Randomly select a source image, which is also already resized
+       with aspect ratio kept, cropped and padded in a similar way
+       as the destination image.
+    3. Randomly select some objects from the source image.
+    4. Paste these source objects to the destination image directly,
+       since the source and destination images have the same size.
+    5. Update object masks of the destination image, since some original
+       objects may be occluded.
+    6. Generate bboxes from the updated destination masks and
+       filter some objects which are totally occluded, and adjust bboxes
+       which are partly occluded.
+    7. Append selected source bboxes, masks, and labels.
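+
+    A hypothetical wiring sketch (``CopyPaste`` consumes ``mix_results``,
+    which is typically pre-filled by a mixing dataset wrapper such as
+    ``MultiImageMixDataset``; the snippet below is illustrative, not
+    prescribed by this file):
+
+    .. code-block::
+
+        dict(type='MultiImageMixDataset',
+             dataset=train_dataset,
+             pipeline=[dict(type='CopyPaste', max_num_pasted=100)])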
+
+    Required Keys:
+
+    - img
+    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
+    - gt_bboxes_labels (np.int64) (optional)
+    - gt_ignore_flags (bool) (optional)
+    - gt_masks (BitmapMasks) (optional)
+
+    Modified Keys:
+
+    - img
+    - gt_bboxes (optional)
+    - gt_bboxes_labels (optional)
+    - gt_ignore_flags (optional)
+    - gt_masks (optional)
+
+    Args:
+        max_num_pasted (int): The maximum number of pasted objects.
+            Defaults to 100.
+        bbox_occluded_thr (int): The threshold of occluded bbox.
+            Defaults to 10.
+        mask_occluded_thr (int): The threshold of occluded mask.
+            Defaults to 300.
+        selected (bool): Whether to select objects or not. If False, all
+            objects of the source image will be pasted to the destination
+            image. Defaults to True.
+        paste_by_box (bool): Whether to use boxes as masks when masks are
+            not available. Defaults to False.
+    """
+
+    def __init__(
+        self,
+        max_num_pasted: int = 100,
+        bbox_occluded_thr: int = 10,
+        mask_occluded_thr: int = 300,
+        selected: bool = True,
+        paste_by_box: bool = False,
+    ) -> None:
+        self.max_num_pasted = max_num_pasted
+        self.bbox_occluded_thr = bbox_occluded_thr
+        self.mask_occluded_thr = mask_occluded_thr
+        self.selected = selected
+        self.paste_by_box = paste_by_box
+
+    @cache_randomness
+    def get_indexes(self, dataset: BaseDataset) -> int:
+        """Call function to collect the index of a source image.
+
+        Args:
+            dataset (:obj:`MultiImageMixDataset`): The dataset.
+        Returns:
+            int: Index of the selected source image.
+        """
+        # `random.randint` is inclusive at both ends, so the upper bound
+        # must be `len(dataset) - 1` to stay a valid index.
+        return random.randint(0, len(dataset) - 1)
+
+    @autocast_box_type()
+    def transform(self, results: dict) -> dict:
+        """Transform function to make a copy-paste of image.
+
+        Args:
+            results (dict): Result dict.
+        Returns:
+            dict: Result dict with copy-paste transformed.
+        """
+
+        assert 'mix_results' in results
+        num_images = len(results['mix_results'])
+        assert num_images == 1, \
+            f'CopyPaste mixes exactly one source image into the ' \
+            f'destination image, but got {num_images} mix results'
+        if self.selected:
+            selected_results = self._select_object(results['mix_results'][0])
+        else:
+            selected_results = results['mix_results'][0]
+        return self._copy_paste(results, selected_results)
+
+    @cache_randomness
+    def _get_selected_inds(self, num_bboxes: int) -> np.ndarray:
+        max_num_pasted = min(num_bboxes + 1, self.max_num_pasted)
+        num_pasted = np.random.randint(0, max_num_pasted)
+        return np.random.choice(num_bboxes, size=num_pasted, replace=False)
+
+    def get_gt_masks(self, results: dict) -> BitmapMasks:
+        """Get gt_masks originally or generated based on bboxes.
+
+        If gt_masks is not contained in results,
+        it will be generated based on gt_bboxes.
+        Args:
+            results (dict): Result dict.
+        Returns:
+            BitmapMasks: gt_masks, originally or generated based on bboxes.
+ """ + if results.get('gt_masks', None) is not None: + if self.paste_by_box: + warnings.warn('gt_masks is already contained in results, ' + 'so paste_by_box is disabled.') + return results['gt_masks'] + else: + if not self.paste_by_box: + raise RuntimeError('results does not contain masks.') + return results['gt_bboxes'].create_masks(results['img'].shape[:2]) + + def _select_object(self, results: dict) -> dict: + """Select some objects from the source results.""" + bboxes = results['gt_bboxes'] + labels = results['gt_bboxes_labels'] + masks = self.get_gt_masks(results) + ignore_flags = results['gt_ignore_flags'] + + selected_inds = self._get_selected_inds(bboxes.shape[0]) + + selected_bboxes = bboxes[selected_inds] + selected_labels = labels[selected_inds] + selected_masks = masks[selected_inds] + selected_ignore_flags = ignore_flags[selected_inds] + + results['gt_bboxes'] = selected_bboxes + results['gt_bboxes_labels'] = selected_labels + results['gt_masks'] = selected_masks + results['gt_ignore_flags'] = selected_ignore_flags + return results + + def _copy_paste(self, dst_results: dict, src_results: dict) -> dict: + """CopyPaste transform function. + + Args: + dst_results (dict): Result dict of the destination image. + src_results (dict): Result dict of the source image. + Returns: + dict: Updated result dict. + """ + dst_img = dst_results['img'] + dst_bboxes = dst_results['gt_bboxes'] + dst_labels = dst_results['gt_bboxes_labels'] + dst_masks = self.get_gt_masks(dst_results) + dst_ignore_flags = dst_results['gt_ignore_flags'] + + src_img = src_results['img'] + src_bboxes = src_results['gt_bboxes'] + src_labels = src_results['gt_bboxes_labels'] + src_masks = src_results['gt_masks'] + src_ignore_flags = src_results['gt_ignore_flags'] + + if len(src_bboxes) == 0: + return dst_results + + # update masks and generate bboxes from updated masks + composed_mask = np.where(np.any(src_masks.masks, axis=0), 1, 0) + updated_dst_masks = self._get_updated_masks(dst_masks, composed_mask) + updated_dst_bboxes = updated_dst_masks.get_bboxes(type(dst_bboxes)) + assert len(updated_dst_bboxes) == len(updated_dst_masks) + + # filter totally occluded objects + l1_distance = (updated_dst_bboxes.tensor - dst_bboxes.tensor).abs() + bboxes_inds = (l1_distance <= self.bbox_occluded_thr).all( + dim=-1).numpy() + masks_inds = updated_dst_masks.masks.sum( + axis=(1, 2)) > self.mask_occluded_thr + valid_inds = bboxes_inds | masks_inds + + # Paste source objects to destination image directly + img = dst_img * (1 - composed_mask[..., np.newaxis] + ) + src_img * composed_mask[..., np.newaxis] + bboxes = src_bboxes.cat([updated_dst_bboxes[valid_inds], src_bboxes]) + labels = np.concatenate([dst_labels[valid_inds], src_labels]) + masks = np.concatenate( + [updated_dst_masks.masks[valid_inds], src_masks.masks]) + ignore_flags = np.concatenate( + [dst_ignore_flags[valid_inds], src_ignore_flags]) + + dst_results['img'] = img + dst_results['gt_bboxes'] = bboxes + dst_results['gt_bboxes_labels'] = labels + dst_results['gt_masks'] = BitmapMasks(masks, masks.shape[1], + masks.shape[2]) + dst_results['gt_ignore_flags'] = ignore_flags + + return dst_results + + def _get_updated_masks(self, masks: BitmapMasks, + composed_mask: np.ndarray) -> BitmapMasks: + """Update masks with composed mask.""" + assert masks.masks.shape[-2:] == composed_mask.shape[-2:], \ + 'Cannot compare two arrays of different size' + masks.masks = np.where(composed_mask, 0, masks.masks) + return masks + + def __repr__(self): + repr_str = 
self.__class__.__name__
+        repr_str += f'(max_num_pasted={self.max_num_pasted}, '
+        repr_str += f'bbox_occluded_thr={self.bbox_occluded_thr}, '
+        repr_str += f'mask_occluded_thr={self.mask_occluded_thr}, '
+        repr_str += f'selected={self.selected}, '
+        repr_str += f'paste_by_box={self.paste_by_box})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class RandomErasing(BaseTransform):
+    """RandomErasing operation.
+
+    Random Erasing randomly selects a rectangle region in an image and
+    erases its pixels with random values.
+    `Random Erasing <https://arxiv.org/abs/1708.04896>`_.
+
+    Required Keys:
+
+    - img
+    - gt_bboxes (HorizontalBoxes[torch.float32]) (optional)
+    - gt_bboxes_labels (np.int64) (optional)
+    - gt_ignore_flags (bool) (optional)
+    - gt_masks (BitmapMasks) (optional)
+
+    Modified Keys:
+
+    - img
+    - gt_bboxes (optional)
+    - gt_bboxes_labels (optional)
+    - gt_ignore_flags (optional)
+    - gt_masks (optional)
+
+    Args:
+        n_patches (int or tuple[int, int]): Number of regions to be dropped.
+            If it is given as a tuple, the number of patches will be randomly
+            selected from the closed interval [``n_patches[0]``,
+            ``n_patches[1]``].
+        ratio (float or tuple[float, float]): The ratio of erased regions.
+            It can be ``float`` to use a fixed ratio or
+            ``tuple[float, float]`` to randomly choose a ratio from the
+            interval.
+        squared (bool): Whether to erase square regions. Defaults to True.
+        bbox_erased_thr (float): The threshold for the maximum proportion of
+            a bbox's area that may be erased. When the erased proportion of
+            a bbox exceeds this threshold, the bbox is removed.
+            Defaults to 0.9.
+        img_border_value (int or float or tuple): The fill values for the
+            image. If float, the same fill value is used for all three
+            channels of the image. If tuple, it should have 3 elements.
+            Defaults to 128.
+        mask_border_value (int): The fill value used for masks. Defaults to 0.
+        seg_ignore_label (int): The fill value used for the segmentation map.
+            Note this value must equal ``ignore_label`` in ``semantic_head``
+            of the corresponding config. Defaults to 255.
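+
+    Examples:
+        >>> # illustrative values only, not taken from any config: erase
+        >>> # 1 to 5 square patches, each side spanning up to 20% of the
+        >>> # corresponding image side
+        >>> transform = RandomErasing(n_patches=(1, 5), ratio=(0, 0.2))
+        >>> results = transform(results)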
+ """ + + def __init__( + self, + n_patches: Union[int, Tuple[int, int]], + ratio: Union[float, Tuple[float, float]], + squared: bool = True, + bbox_erased_thr: float = 0.9, + img_border_value: Union[int, float, tuple] = 128, + mask_border_value: int = 0, + seg_ignore_label: int = 255, + ) -> None: + if isinstance(n_patches, tuple): + assert len(n_patches) == 2 and 0 <= n_patches[0] < n_patches[1] + else: + n_patches = (n_patches, n_patches) + if isinstance(ratio, tuple): + assert len(ratio) == 2 and 0 <= ratio[0] < ratio[1] <= 1 + else: + ratio = (ratio, ratio) + + self.n_patches = n_patches + self.ratio = ratio + self.squared = squared + self.bbox_erased_thr = bbox_erased_thr + self.img_border_value = img_border_value + self.mask_border_value = mask_border_value + self.seg_ignore_label = seg_ignore_label + + @cache_randomness + def _get_patches(self, img_shape: Tuple[int, int]) -> List[list]: + """Get patches for random erasing.""" + patches = [] + n_patches = np.random.randint(self.n_patches[0], self.n_patches[1] + 1) + for _ in range(n_patches): + if self.squared: + ratio = np.random.random() * (self.ratio[1] - + self.ratio[0]) + self.ratio[0] + ratio = (ratio, ratio) + else: + ratio = (np.random.random() * (self.ratio[1] - self.ratio[0]) + + self.ratio[0], np.random.random() * + (self.ratio[1] - self.ratio[0]) + self.ratio[0]) + ph, pw = int(img_shape[0] * ratio[0]), int(img_shape[1] * ratio[1]) + px1, py1 = np.random.randint(0, + img_shape[1] - pw), np.random.randint( + 0, img_shape[0] - ph) + px2, py2 = px1 + pw, py1 + ph + patches.append([px1, py1, px2, py2]) + return np.array(patches) + + def _transform_img(self, results: dict, patches: List[list]) -> None: + """Random erasing the image.""" + for patch in patches: + px1, py1, px2, py2 = patch + results['img'][py1:py2, px1:px2, :] = self.img_border_value + + def _transform_bboxes(self, results: dict, patches: List[list]) -> None: + """Random erasing the bboxes.""" + bboxes = results['gt_bboxes'] + # TODO: unify the logic by using operators in BaseBoxes. 
+ assert isinstance(bboxes, HorizontalBoxes) + bboxes = bboxes.numpy() + left_top = np.maximum(bboxes[:, None, :2], patches[:, :2]) + right_bottom = np.minimum(bboxes[:, None, 2:], patches[:, 2:]) + wh = np.maximum(right_bottom - left_top, 0) + inter_areas = wh[:, :, 0] * wh[:, :, 1] + bbox_areas = (bboxes[:, 2] - bboxes[:, 0]) * ( + bboxes[:, 3] - bboxes[:, 1]) + bboxes_erased_ratio = inter_areas.sum(-1) / (bbox_areas + 1e-7) + valid_inds = bboxes_erased_ratio < self.bbox_erased_thr + results['gt_bboxes'] = HorizontalBoxes(bboxes[valid_inds]) + results['gt_bboxes_labels'] = results['gt_bboxes_labels'][valid_inds] + results['gt_ignore_flags'] = results['gt_ignore_flags'][valid_inds] + if results.get('gt_masks', None) is not None: + results['gt_masks'] = results['gt_masks'][valid_inds] + + def _transform_masks(self, results: dict, patches: List[list]) -> None: + """Random erasing the masks.""" + for patch in patches: + px1, py1, px2, py2 = patch + results['gt_masks'].masks[:, py1:py2, + px1:px2] = self.mask_border_value + + def _transform_seg(self, results: dict, patches: List[list]) -> None: + """Random erasing the segmentation map.""" + for patch in patches: + px1, py1, px2, py2 = patch + results['gt_seg_map'][py1:py2, px1:px2] = self.seg_ignore_label + + @autocast_box_type() + def transform(self, results: dict) -> dict: + """Transform function to erase some regions of image.""" + patches = self._get_patches(results['img_shape']) + self._transform_img(results, patches) + if results.get('gt_bboxes', None) is not None: + self._transform_bboxes(results, patches) + if results.get('gt_masks', None) is not None: + self._transform_masks(results, patches) + if results.get('gt_seg_map', None) is not None: + self._transform_seg(results, patches) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(n_patches={self.n_patches}, ' + repr_str += f'ratio={self.ratio}, ' + repr_str += f'squared={self.squared}, ' + repr_str += f'bbox_erased_thr={self.bbox_erased_thr}, ' + repr_str += f'img_border_value={self.img_border_value}, ' + repr_str += f'mask_border_value={self.mask_border_value}, ' + repr_str += f'seg_ignore_label={self.seg_ignore_label})' + return repr_str + + +@TRANSFORMS.register_module() +class CachedMosaic(Mosaic): + """Cached mosaic augmentation. + + Cached mosaic transform will random select images from the cache + and combine them into one output image. + + .. code:: text + + mosaic transform + center_x + +------------------------------+ + | pad | pad | + | +-----------+ | + | | | | + | | image1 |--------+ | + | | | | | + | | | image2 | | + center_y |----+-------------+-----------| + | | cropped | | + |pad | image3 | image4 | + | | | | + +----|-------------+-----------+ + | | + +-------------+ + + The cached mosaic transform steps are as follows: + + 1. Append the results from the last transform into the cache. + 2. Choose the mosaic center as the intersections of 4 images + 3. Get the left top image according to the index, and randomly + sample another 3 images from the result cache. + 4. Sub image will be cropped if image is larger than mosaic patch + + Required Keys: + + - img + - gt_bboxes (np.float32) (optional) + - gt_bboxes_labels (np.int64) (optional) + - gt_ignore_flags (bool) (optional) + + Modified Keys: + + - img + - img_shape + - gt_bboxes (optional) + - gt_bboxes_labels (optional) + - gt_ignore_flags (optional) + + Args: + img_scale (Sequence[int]): Image size before mosaic pipeline of single + image. 
The shape order should be (width, height). + Defaults to (640, 640). + center_ratio_range (Sequence[float]): Center ratio range of mosaic + output. Defaults to (0.5, 1.5). + bbox_clip_border (bool, optional): Whether to clip the objects outside + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. + pad_val (int): Pad value. Defaults to 114. + prob (float): Probability of applying this transformation. + Defaults to 1.0. + max_cached_images (int): The maximum length of the cache. The larger + the cache, the stronger the randomness of this transform. As a + rule of thumb, providing 10 caches for each image suffices for + randomness. Defaults to 40. + random_pop (bool): Whether to randomly pop a result from the cache + when the cache is full. If set to False, use FIFO popping method. + Defaults to True. + """ + + def __init__(self, + *args, + max_cached_images: int = 40, + random_pop: bool = True, + **kwargs) -> None: + super().__init__(*args, **kwargs) + self.results_cache = [] + self.random_pop = random_pop + assert max_cached_images >= 4, 'The length of cache must >= 4, ' \ + f'but got {max_cached_images}.' + self.max_cached_images = max_cached_images + + @cache_randomness + def get_indexes(self, cache: list) -> list: + """Call function to collect indexes. + + Args: + cache (list): The results cache. + + Returns: + list: indexes. + """ + + indexes = [random.randint(0, len(cache) - 1) for _ in range(3)] + return indexes + + @autocast_box_type() + def transform(self, results: dict) -> dict: + """Mosaic transform function. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. + """ + # cache and pop images + self.results_cache.append(copy.deepcopy(results)) + if len(self.results_cache) > self.max_cached_images: + if self.random_pop: + index = random.randint(0, len(self.results_cache) - 1) + else: + index = 0 + self.results_cache.pop(index) + + if len(self.results_cache) <= 4: + return results + + if random.uniform(0, 1) > self.prob: + return results + indices = self.get_indexes(self.results_cache) + mix_results = [copy.deepcopy(self.results_cache[i]) for i in indices] + + # TODO: refactor mosaic to reuse these code. 
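+        # Assembly sketch: a canvas at twice `img_scale` is filled with
+        # `pad_val`, each of the four sub-images is keep-ratio resized and
+        # cropped against the random center, and the boxes/masks are
+        # rescaled and translated to their pasted positions.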
+ mosaic_bboxes = [] + mosaic_bboxes_labels = [] + mosaic_ignore_flags = [] + mosaic_masks = [] + with_mask = True if 'gt_masks' in results else False + + if len(results['img'].shape) == 3: + mosaic_img = np.full( + (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2), 3), + self.pad_val, + dtype=results['img'].dtype) + else: + mosaic_img = np.full( + (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2)), + self.pad_val, + dtype=results['img'].dtype) + + # mosaic center x, y + center_x = int( + random.uniform(*self.center_ratio_range) * self.img_scale[0]) + center_y = int( + random.uniform(*self.center_ratio_range) * self.img_scale[1]) + center_position = (center_x, center_y) + + loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right') + for i, loc in enumerate(loc_strs): + if loc == 'top_left': + results_patch = copy.deepcopy(results) + else: + results_patch = copy.deepcopy(mix_results[i - 1]) + + img_i = results_patch['img'] + h_i, w_i = img_i.shape[:2] + # keep_ratio resize + scale_ratio_i = min(self.img_scale[1] / h_i, + self.img_scale[0] / w_i) + img_i = mmcv.imresize( + img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i))) + + # compute the combine parameters + paste_coord, crop_coord = self._mosaic_combine( + loc, center_position, img_i.shape[:2][::-1]) + x1_p, y1_p, x2_p, y2_p = paste_coord + x1_c, y1_c, x2_c, y2_c = crop_coord + + # crop and paste image + mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c] + + # adjust coordinate + gt_bboxes_i = results_patch['gt_bboxes'] + gt_bboxes_labels_i = results_patch['gt_bboxes_labels'] + gt_ignore_flags_i = results_patch['gt_ignore_flags'] + + padw = x1_p - x1_c + padh = y1_p - y1_c + gt_bboxes_i.rescale_([scale_ratio_i, scale_ratio_i]) + gt_bboxes_i.translate_([padw, padh]) + mosaic_bboxes.append(gt_bboxes_i) + mosaic_bboxes_labels.append(gt_bboxes_labels_i) + mosaic_ignore_flags.append(gt_ignore_flags_i) + if with_mask and results_patch.get('gt_masks', None) is not None: + gt_masks_i = results_patch['gt_masks'] + gt_masks_i = gt_masks_i.rescale(float(scale_ratio_i)) + gt_masks_i = gt_masks_i.translate( + out_shape=(int(self.img_scale[0] * 2), + int(self.img_scale[1] * 2)), + offset=padw, + direction='horizontal') + gt_masks_i = gt_masks_i.translate( + out_shape=(int(self.img_scale[0] * 2), + int(self.img_scale[1] * 2)), + offset=padh, + direction='vertical') + mosaic_masks.append(gt_masks_i) + + mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0) + mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0) + mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0) + + if self.bbox_clip_border: + mosaic_bboxes.clip_([2 * self.img_scale[1], 2 * self.img_scale[0]]) + # remove outside bboxes + inside_inds = mosaic_bboxes.is_inside( + [2 * self.img_scale[1], 2 * self.img_scale[0]]).numpy() + mosaic_bboxes = mosaic_bboxes[inside_inds] + mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds] + mosaic_ignore_flags = mosaic_ignore_flags[inside_inds] + + results['img'] = mosaic_img + results['img_shape'] = mosaic_img.shape[:2] + results['gt_bboxes'] = mosaic_bboxes + results['gt_bboxes_labels'] = mosaic_bboxes_labels + results['gt_ignore_flags'] = mosaic_ignore_flags + + if with_mask: + mosaic_masks = mosaic_masks[0].cat(mosaic_masks) + results['gt_masks'] = mosaic_masks[inside_inds] + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(img_scale={self.img_scale}, ' + repr_str += f'center_ratio_range={self.center_ratio_range}, ' + repr_str += 
f'pad_val={self.pad_val}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'max_cached_images={self.max_cached_images}, ' + repr_str += f'random_pop={self.random_pop})' + return repr_str + + +@TRANSFORMS.register_module() +class CachedMixUp(BaseTransform): + """Cached mixup data augmentation. + + .. code:: text + + mixup transform + +------------------------------+ + | mixup image | | + | +--------|--------+ | + | | | | | + |---------------+ | | + | | | | + | | image | | + | | | | + | | | | + | |-----------------+ | + | pad | + +------------------------------+ + + The cached mixup transform steps are as follows: + + 1. Append the results from the last transform into the cache. + 2. Another random image is picked from the cache and embedded in + the top left patch(after padding and resizing) + 3. The target of mixup transform is the weighted average of mixup + image and origin image. + + Required Keys: + + - img + - gt_bboxes (np.float32) (optional) + - gt_bboxes_labels (np.int64) (optional) + - gt_ignore_flags (bool) (optional) + - mix_results (List[dict]) + + + Modified Keys: + + - img + - img_shape + - gt_bboxes (optional) + - gt_bboxes_labels (optional) + - gt_ignore_flags (optional) + + + Args: + img_scale (Sequence[int]): Image output size after mixup pipeline. + The shape order should be (width, height). Defaults to (640, 640). + ratio_range (Sequence[float]): Scale ratio of mixup image. + Defaults to (0.5, 1.5). + flip_ratio (float): Horizontal flip ratio of mixup image. + Defaults to 0.5. + pad_val (int): Pad value. Defaults to 114. + max_iters (int): The maximum number of iterations. If the number of + iterations is greater than `max_iters`, but gt_bbox is still + empty, then the iteration is terminated. Defaults to 15. + bbox_clip_border (bool, optional): Whether to clip the objects outside + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. + max_cached_images (int): The maximum length of the cache. The larger + the cache, the stronger the randomness of this transform. As a + rule of thumb, providing 10 caches for each image suffices for + randomness. Defaults to 20. + random_pop (bool): Whether to randomly pop a result from the cache + when the cache is full. If set to False, use FIFO popping method. + Defaults to True. + prob (float): Probability of applying this transformation. + Defaults to 1.0. + """ + + def __init__(self, + img_scale: Tuple[int, int] = (640, 640), + ratio_range: Tuple[float, float] = (0.5, 1.5), + flip_ratio: float = 0.5, + pad_val: float = 114.0, + max_iters: int = 15, + bbox_clip_border: bool = True, + max_cached_images: int = 20, + random_pop: bool = True, + prob: float = 1.0) -> None: + assert isinstance(img_scale, tuple) + assert max_cached_images >= 2, 'The length of cache must >= 2, ' \ + f'but got {max_cached_images}.' + assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. ' \ + f'got {prob}.' + self.dynamic_scale = img_scale + self.ratio_range = ratio_range + self.flip_ratio = flip_ratio + self.pad_val = pad_val + self.max_iters = max_iters + self.bbox_clip_border = bbox_clip_border + self.results_cache = [] + + self.max_cached_images = max_cached_images + self.random_pop = random_pop + self.prob = prob + + @cache_randomness + def get_indexes(self, cache: list) -> int: + """Call function to collect indexes. + + Args: + cache (list): The result cache. + + Returns: + int: index. 
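+
+        Note:
+            Up to ``max_iters`` cache entries are sampled until one with a
+            non-empty ``gt_bboxes`` is found; the last sampled index is
+            returned either way.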
+ """ + + for i in range(self.max_iters): + index = random.randint(0, len(cache) - 1) + gt_bboxes_i = cache[index]['gt_bboxes'] + if len(gt_bboxes_i) != 0: + break + return index + + @autocast_box_type() + def transform(self, results: dict) -> dict: + """MixUp transform function. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. + """ + # cache and pop images + self.results_cache.append(copy.deepcopy(results)) + if len(self.results_cache) > self.max_cached_images: + if self.random_pop: + index = random.randint(0, len(self.results_cache) - 1) + else: + index = 0 + self.results_cache.pop(index) + + if len(self.results_cache) <= 1: + return results + + if random.uniform(0, 1) > self.prob: + return results + + index = self.get_indexes(self.results_cache) + retrieve_results = copy.deepcopy(self.results_cache[index]) + + # TODO: refactor mixup to reuse these code. + if retrieve_results['gt_bboxes'].shape[0] == 0: + # empty bbox + return results + + retrieve_img = retrieve_results['img'] + with_mask = True if 'gt_masks' in results else False + + jit_factor = random.uniform(*self.ratio_range) + is_flip = random.uniform(0, 1) > self.flip_ratio + + if len(retrieve_img.shape) == 3: + out_img = np.ones( + (self.dynamic_scale[1], self.dynamic_scale[0], 3), + dtype=retrieve_img.dtype) * self.pad_val + else: + out_img = np.ones( + self.dynamic_scale[::-1], + dtype=retrieve_img.dtype) * self.pad_val + + # 1. keep_ratio resize + scale_ratio = min(self.dynamic_scale[1] / retrieve_img.shape[0], + self.dynamic_scale[0] / retrieve_img.shape[1]) + retrieve_img = mmcv.imresize( + retrieve_img, (int(retrieve_img.shape[1] * scale_ratio), + int(retrieve_img.shape[0] * scale_ratio))) + + # 2. paste + out_img[:retrieve_img.shape[0], :retrieve_img.shape[1]] = retrieve_img + + # 3. scale jit + scale_ratio *= jit_factor + out_img = mmcv.imresize(out_img, (int(out_img.shape[1] * jit_factor), + int(out_img.shape[0] * jit_factor))) + + # 4. flip + if is_flip: + out_img = out_img[:, ::-1, :] + + # 5. random crop + ori_img = results['img'] + origin_h, origin_w = out_img.shape[:2] + target_h, target_w = ori_img.shape[:2] + padded_img = np.ones((max(origin_h, target_h), max( + origin_w, target_w), 3)) * self.pad_val + padded_img = padded_img.astype(np.uint8) + padded_img[:origin_h, :origin_w] = out_img + + x_offset, y_offset = 0, 0 + if padded_img.shape[0] > target_h: + y_offset = random.randint(0, padded_img.shape[0] - target_h) + if padded_img.shape[1] > target_w: + x_offset = random.randint(0, padded_img.shape[1] - target_w) + padded_cropped_img = padded_img[y_offset:y_offset + target_h, + x_offset:x_offset + target_w] + + # 6. adjust bbox + retrieve_gt_bboxes = retrieve_results['gt_bboxes'] + retrieve_gt_bboxes.rescale_([scale_ratio, scale_ratio]) + if with_mask: + retrieve_gt_masks = retrieve_results['gt_masks'].rescale( + scale_ratio) + + if self.bbox_clip_border: + retrieve_gt_bboxes.clip_([origin_h, origin_w]) + + if is_flip: + retrieve_gt_bboxes.flip_([origin_h, origin_w], + direction='horizontal') + if with_mask: + retrieve_gt_masks = retrieve_gt_masks.flip() + + # 7. 
filter + cp_retrieve_gt_bboxes = retrieve_gt_bboxes.clone() + cp_retrieve_gt_bboxes.translate_([-x_offset, -y_offset]) + if with_mask: + retrieve_gt_masks = retrieve_gt_masks.translate( + out_shape=(target_h, target_w), + offset=-x_offset, + direction='horizontal') + retrieve_gt_masks = retrieve_gt_masks.translate( + out_shape=(target_h, target_w), + offset=-y_offset, + direction='vertical') + + if self.bbox_clip_border: + cp_retrieve_gt_bboxes.clip_([target_h, target_w]) + + # 8. mix up + ori_img = ori_img.astype(np.float32) + mixup_img = 0.5 * ori_img + 0.5 * padded_cropped_img.astype(np.float32) + + retrieve_gt_bboxes_labels = retrieve_results['gt_bboxes_labels'] + retrieve_gt_ignore_flags = retrieve_results['gt_ignore_flags'] + + mixup_gt_bboxes = cp_retrieve_gt_bboxes.cat( + (results['gt_bboxes'], cp_retrieve_gt_bboxes), dim=0) + mixup_gt_bboxes_labels = np.concatenate( + (results['gt_bboxes_labels'], retrieve_gt_bboxes_labels), axis=0) + mixup_gt_ignore_flags = np.concatenate( + (results['gt_ignore_flags'], retrieve_gt_ignore_flags), axis=0) + if with_mask: + mixup_gt_masks = retrieve_gt_masks.cat( + [results['gt_masks'], retrieve_gt_masks]) + + # remove outside bbox + inside_inds = mixup_gt_bboxes.is_inside([target_h, target_w]).numpy() + mixup_gt_bboxes = mixup_gt_bboxes[inside_inds] + mixup_gt_bboxes_labels = mixup_gt_bboxes_labels[inside_inds] + mixup_gt_ignore_flags = mixup_gt_ignore_flags[inside_inds] + if with_mask: + mixup_gt_masks = mixup_gt_masks[inside_inds] + + results['img'] = mixup_img.astype(np.uint8) + results['img_shape'] = mixup_img.shape[:2] + results['gt_bboxes'] = mixup_gt_bboxes + results['gt_bboxes_labels'] = mixup_gt_bboxes_labels + results['gt_ignore_flags'] = mixup_gt_ignore_flags + if with_mask: + results['gt_masks'] = mixup_gt_masks + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(dynamic_scale={self.dynamic_scale}, ' + repr_str += f'ratio_range={self.ratio_range}, ' + repr_str += f'flip_ratio={self.flip_ratio}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'max_iters={self.max_iters}, ' + repr_str += f'bbox_clip_border={self.bbox_clip_border}, ' + repr_str += f'max_cached_images={self.max_cached_images}, ' + repr_str += f'random_pop={self.random_pop}, ' + repr_str += f'prob={self.prob})' + return repr_str diff --git a/mmdet/datasets/transforms/wrappers.py b/mmdet/datasets/transforms/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..3a17711c06bfbd4dc0038dce9ea7796d1476c37e --- /dev/null +++ b/mmdet/datasets/transforms/wrappers.py @@ -0,0 +1,277 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +from mmcv.transforms import BaseTransform, Compose +from mmcv.transforms.utils import cache_random_params, cache_randomness + +from mmdet.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class MultiBranch(BaseTransform): + r"""Multiple branch pipeline wrapper. + + Generate multiple data-augmented versions of the same image. + `MultiBranch` needs to specify the branch names of all + pipelines of the dataset, perform corresponding data augmentation + for the current branch, and return None for other branches, + which ensures the consistency of return format across + different samples. + + Args: + branch_field (list): List of branch names. + branch_pipelines (dict): Dict of different pipeline configs + to be composed. 
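+
+    Branch names in ``branch_pipelines`` are expected to appear in
+    ``branch_field``; branches without a pipeline for the current sample
+    keep ``None`` placeholders, so every sample shares the same return
+    structure.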
+ + Examples: + >>> branch_field = ['sup', 'unsup_teacher', 'unsup_student'] + >>> sup_pipeline = [ + >>> dict(type='LoadImageFromFile'), + >>> dict(type='LoadAnnotations', with_bbox=True), + >>> dict(type='Resize', scale=(1333, 800), keep_ratio=True), + >>> dict(type='RandomFlip', prob=0.5), + >>> dict( + >>> type='MultiBranch', + >>> branch_field=branch_field, + >>> sup=dict(type='PackDetInputs')) + >>> ] + >>> weak_pipeline = [ + >>> dict(type='LoadImageFromFile'), + >>> dict(type='LoadAnnotations', with_bbox=True), + >>> dict(type='Resize', scale=(1333, 800), keep_ratio=True), + >>> dict(type='RandomFlip', prob=0.0), + >>> dict( + >>> type='MultiBranch', + >>> branch_field=branch_field, + >>> sup=dict(type='PackDetInputs')) + >>> ] + >>> strong_pipeline = [ + >>> dict(type='LoadImageFromFile'), + >>> dict(type='LoadAnnotations', with_bbox=True), + >>> dict(type='Resize', scale=(1333, 800), keep_ratio=True), + >>> dict(type='RandomFlip', prob=1.0), + >>> dict( + >>> type='MultiBranch', + >>> branch_field=branch_field, + >>> sup=dict(type='PackDetInputs')) + >>> ] + >>> unsup_pipeline = [ + >>> dict(type='LoadImageFromFile'), + >>> dict(type='LoadEmptyAnnotations'), + >>> dict( + >>> type='MultiBranch', + >>> branch_field=branch_field, + >>> unsup_teacher=weak_pipeline, + >>> unsup_student=strong_pipeline) + >>> ] + >>> from mmcv.transforms import Compose + >>> sup_branch = Compose(sup_pipeline) + >>> unsup_branch = Compose(unsup_pipeline) + >>> print(sup_branch) + >>> Compose( + >>> LoadImageFromFile(ignore_empty=False, to_float32=False, color_type='color', imdecode_backend='cv2') # noqa + >>> LoadAnnotations(with_bbox=True, with_label=True, with_mask=False, with_seg=False, poly2mask=True, imdecode_backend='cv2') # noqa + >>> Resize(scale=(1333, 800), scale_factor=None, keep_ratio=True, clip_object_border=True), backend=cv2), interpolation=bilinear) # noqa + >>> RandomFlip(prob=0.5, direction=horizontal) + >>> MultiBranch(branch_pipelines=['sup']) + >>> ) + >>> print(unsup_branch) + >>> Compose( + >>> LoadImageFromFile(ignore_empty=False, to_float32=False, color_type='color', imdecode_backend='cv2') # noqa + >>> LoadEmptyAnnotations(with_bbox=True, with_label=True, with_mask=False, with_seg=False, seg_ignore_label=255) # noqa + >>> MultiBranch(branch_pipelines=['unsup_teacher', 'unsup_student']) + >>> ) + """ + + def __init__(self, branch_field: List[str], + **branch_pipelines: dict) -> None: + self.branch_field = branch_field + self.branch_pipelines = { + branch: Compose(pipeline) + for branch, pipeline in branch_pipelines.items() + } + + def transform(self, results: dict) -> dict: + """Transform function to apply transforms sequentially. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: + + - 'inputs' (Dict[str, obj:`torch.Tensor`]): The forward data of + models from different branches. + - 'data_sample' (Dict[str,obj:`DetDataSample`]): The annotation + info of the sample from different branches. + """ + + multi_results = {} + for branch in self.branch_field: + multi_results[branch] = {'inputs': None, 'data_samples': None} + for branch, pipeline in self.branch_pipelines.items(): + branch_results = pipeline(copy.deepcopy(results)) + # If one branch pipeline returns None, + # it will sample another data from dataset. 
+ if branch_results is None: + return None + multi_results[branch] = branch_results + + format_results = {} + for branch, results in multi_results.items(): + for key in results.keys(): + if format_results.get(key, None) is None: + format_results[key] = {branch: results[key]} + else: + format_results[key][branch] = results[key] + return format_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(branch_pipelines={list(self.branch_pipelines.keys())})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomOrder(Compose): + """Shuffle the transform Sequence.""" + + @cache_randomness + def _random_permutation(self): + return np.random.permutation(len(self.transforms)) + + def transform(self, results: Dict) -> Optional[Dict]: + """Transform function to apply transforms in random order. + + Args: + results (dict): A result dict contains the results to transform. + + Returns: + dict or None: Transformed results. + """ + inds = self._random_permutation() + for idx in inds: + t = self.transforms[idx] + results = t(results) + if results is None: + return None + return results + + def __repr__(self): + """Compute the string representation.""" + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += f'{t.__class__.__name__}, ' + format_string += ')' + return format_string + + +@TRANSFORMS.register_module() +class ProposalBroadcaster(BaseTransform): + """A transform wrapper to apply the wrapped transforms to process both + `gt_bboxes` and `proposals` without adding any codes. It will do the + following steps: + + 1. Scatter the broadcasting targets to a list of inputs of the wrapped + transforms. The type of the list should be list[dict, dict], which + the first is the original inputs, the second is the processing + results that `gt_bboxes` being rewritten by the `proposals`. + 2. Apply ``self.transforms``, with same random parameters, which is + sharing with a context manager. The type of the outputs is a + list[dict, dict]. + 3. Gather the outputs, update the `proposals` in the first item of + the outputs with the `gt_bboxes` in the second . + + Args: + transforms (list, optional): Sequence of transform + object or config dict to be wrapped. Defaults to []. + + Note: The `TransformBroadcaster` in MMCV can achieve the same operation as + `ProposalBroadcaster`, but need to set more complex parameters. + + Examples: + >>> pipeline = [ + >>> dict(type='LoadImageFromFile'), + >>> dict(type='LoadProposals', num_max_proposals=2000), + >>> dict(type='LoadAnnotations', with_bbox=True), + >>> dict( + >>> type='ProposalBroadcaster', + >>> transforms=[ + >>> dict(type='Resize', scale=(1333, 800), + >>> keep_ratio=True), + >>> dict(type='RandomFlip', prob=0.5), + >>> ]), + >>> dict(type='PackDetInputs')] + """ + + def __init__(self, transforms: List[Union[dict, Callable]] = []) -> None: + self.transforms = Compose(transforms) + + def transform(self, results: dict) -> dict: + """Apply wrapped transform functions to process both `gt_bboxes` and + `proposals`. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Updated result dict. + """ + assert results.get('proposals', None) is not None, \ + '`proposals` should be in the results, please delete ' \ + '`ProposalBroadcaster` in your configs, or check whether ' \ + 'you have load proposals successfully.' 
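+        # Scatter/apply/gather: duplicate the inputs with `proposals`
+        # standing in for `gt_bboxes`, run the wrapped transforms under a
+        # shared random-parameter cache, then copy the transformed boxes
+        # back into `proposals`.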
+ + inputs = self._process_input(results) + outputs = self._apply_transforms(inputs) + outputs = self._process_output(outputs) + return outputs + + def _process_input(self, data: dict) -> list: + """Scatter the broadcasting targets to a list of inputs of the wrapped + transforms. + + Args: + data (dict): The original input data. + + Returns: + list[dict]: A list of input data. + """ + cp_data = copy.deepcopy(data) + cp_data['gt_bboxes'] = cp_data['proposals'] + scatters = [data, cp_data] + return scatters + + def _apply_transforms(self, inputs: list) -> list: + """Apply ``self.transforms``. + + Args: + inputs (list[dict, dict]): list of input data. + + Returns: + list[dict]: The output of the wrapped pipeline. + """ + assert len(inputs) == 2 + ctx = cache_random_params + with ctx(self.transforms): + output_scatters = [self.transforms(_input) for _input in inputs] + return output_scatters + + def _process_output(self, output_scatters: list) -> dict: + """Gathering and renaming data items. + + Args: + output_scatters (list[dict, dict]): The output of the wrapped + pipeline. + + Returns: + dict: Updated result dict. + """ + assert isinstance(output_scatters, list) and \ + isinstance(output_scatters[0], dict) and \ + len(output_scatters) == 2 + outputs = output_scatters[0] + outputs['proposals'] = output_scatters[1]['gt_bboxes'] + return outputs diff --git a/mmdet/datasets/utils.py b/mmdet/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d794eb4b06ec9db56ff3a5fc7b817d1d9332a989 --- /dev/null +++ b/mmdet/datasets/utils.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from mmcv.transforms import LoadImageFromFile + +from mmdet.datasets.transforms import LoadAnnotations, LoadPanopticAnnotations +from mmdet.registry import TRANSFORMS + + +def get_loading_pipeline(pipeline): + """Only keep loading image and annotations related configuration. + + Args: + pipeline (list[dict]): Data pipeline configs. + + Returns: + list[dict]: The new pipeline list with only keep + loading image and annotations related configuration. + + Examples: + >>> pipelines = [ + ... dict(type='LoadImageFromFile'), + ... dict(type='LoadAnnotations', with_bbox=True), + ... dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), + ... dict(type='RandomFlip', flip_ratio=0.5), + ... dict(type='Normalize', **img_norm_cfg), + ... dict(type='Pad', size_divisor=32), + ... dict(type='DefaultFormatBundle'), + ... dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) + ... ] + >>> expected_pipelines = [ + ... dict(type='LoadImageFromFile'), + ... dict(type='LoadAnnotations', with_bbox=True) + ... ] + >>> assert expected_pipelines ==\ + ... get_loading_pipeline(pipelines) + """ + loading_pipeline_cfg = [] + for cfg in pipeline: + obj_cls = TRANSFORMS.get(cfg['type']) + # TODO:use more elegant way to distinguish loading modules + if obj_cls is not None and obj_cls in (LoadImageFromFile, + LoadAnnotations, + LoadPanopticAnnotations): + loading_pipeline_cfg.append(cfg) + assert len(loading_pipeline_cfg) == 2, \ + 'The data pipeline in your config file must include ' \ + 'loading image and annotations related pipeline.' + return loading_pipeline_cfg diff --git a/mmdet/datasets/v3det.py b/mmdet/datasets/v3det.py new file mode 100644 index 0000000000000000000000000000000000000000..25bfe3bc718841143653c54954240186c3376955 --- /dev/null +++ b/mmdet/datasets/v3det.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
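+# Note: V3Det class names are not hard-coded; they are read at dataset
+# init time from the label file below (13,204 entries per its file name).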
+import os.path +from typing import Optional + +import mmengine + +from mmdet.registry import DATASETS +from .coco import CocoDataset + + +@DATASETS.register_module() +class V3DetDataset(CocoDataset): + """Dataset for V3Det.""" + + METAINFO = { + 'classes': None, + 'palette': None, + } + + def __init__( + self, + *args, + metainfo: Optional[dict] = None, + data_root: str = '', + label_file='annotations/category_name_13204_v3det_2023_v1.txt', # noqa + **kwargs) -> None: + class_names = tuple( + mmengine.list_from_file(os.path.join(data_root, label_file))) + if metainfo is None: + metainfo = {'classes': class_names} + super().__init__( + *args, data_root=data_root, metainfo=metainfo, **kwargs) diff --git a/mmdet/datasets/voc.py b/mmdet/datasets/voc.py new file mode 100644 index 0000000000000000000000000000000000000000..65e73f2f0bd4f2b16d5237cd3b5f342e44cf0438 --- /dev/null +++ b/mmdet/datasets/voc.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import DATASETS +from .xml_style import XMLDataset + + +@DATASETS.register_module() +class VOCDataset(XMLDataset): + """Dataset for PASCAL VOC.""" + + METAINFO = { + 'classes': + ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', + 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', + 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'), + # palette is a list of color tuples, which is used for visualization. + 'palette': [(106, 0, 228), (119, 11, 32), (165, 42, 42), (0, 0, 192), + (197, 226, 255), (0, 60, 100), (0, 0, 142), (255, 77, 255), + (153, 69, 1), (120, 166, 157), (0, 182, 199), + (0, 226, 252), (182, 182, 255), (0, 0, 230), (220, 20, 60), + (163, 255, 0), (0, 82, 0), (3, 95, 161), (0, 80, 100), + (183, 130, 88)] + } + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if 'VOC2007' in self.sub_data_root: + self._metainfo['dataset_type'] = 'VOC2007' + elif 'VOC2012' in self.sub_data_root: + self._metainfo['dataset_type'] = 'VOC2012' + else: + self._metainfo['dataset_type'] = None diff --git a/mmdet/datasets/wider_face.py b/mmdet/datasets/wider_face.py new file mode 100644 index 0000000000000000000000000000000000000000..62c7fff869ab970b6f96908a998ba6feb25ea205 --- /dev/null +++ b/mmdet/datasets/wider_face.py @@ -0,0 +1,90 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import xml.etree.ElementTree as ET + +from mmengine.dist import is_main_process +from mmengine.fileio import get_local_path, list_from_file +from mmengine.utils import ProgressBar + +from mmdet.registry import DATASETS +from mmdet.utils.typing_utils import List, Union +from .xml_style import XMLDataset + + +@DATASETS.register_module() +class WIDERFaceDataset(XMLDataset): + """Reader for the WIDER Face dataset in PASCAL VOC format. + + Conversion scripts can be found in + https://github.com/sovrasov/wider-face-pascal-voc-annotations + """ + METAINFO = {'classes': ('face', ), 'palette': [(0, 255, 0)]} + + def load_data_list(self) -> List[dict]: + """Load annotation from XML style ann_file. + + Returns: + list[dict]: Annotation info from XML file. + """ + assert self._metainfo.get('classes', None) is not None, \ + 'classes in `XMLDataset` can not be None.' 
+ self.cat2label = { + cat: i + for i, cat in enumerate(self._metainfo['classes']) + } + + data_list = [] + img_ids = list_from_file(self.ann_file, backend_args=self.backend_args) + + # loading process takes around 10 mins + if is_main_process(): + prog_bar = ProgressBar(len(img_ids)) + + for img_id in img_ids: + raw_img_info = {} + raw_img_info['img_id'] = img_id + raw_img_info['file_name'] = f'{img_id}.jpg' + parsed_data_info = self.parse_data_info(raw_img_info) + data_list.append(parsed_data_info) + + if is_main_process(): + prog_bar.update() + return data_list + + def parse_data_info(self, img_info: dict) -> Union[dict, List[dict]]: + """Parse raw annotation to target format. + + Args: + img_info (dict): Raw image information, usually it includes + `img_id`, `file_name`, and `xml_path`. + + Returns: + Union[dict, List[dict]]: Parsed annotation. + """ + data_info = {} + img_id = img_info['img_id'] + xml_path = osp.join(self.data_prefix['img'], 'Annotations', + f'{img_id}.xml') + data_info['img_id'] = img_id + data_info['xml_path'] = xml_path + + # deal with xml file + with get_local_path( + xml_path, backend_args=self.backend_args) as local_path: + raw_ann_info = ET.parse(local_path) + root = raw_ann_info.getroot() + size = root.find('size') + width = int(size.find('width').text) + height = int(size.find('height').text) + folder = root.find('folder').text + img_path = osp.join(self.data_prefix['img'], folder, + img_info['file_name']) + data_info['img_path'] = img_path + + data_info['height'] = height + data_info['width'] = width + + # Coordinates are in range [0, width - 1 or height - 1] + data_info['instances'] = self._parse_instance_info( + raw_ann_info, minus_one=False) + return data_info diff --git a/mmdet/datasets/xml_style.py b/mmdet/datasets/xml_style.py new file mode 100644 index 0000000000000000000000000000000000000000..06045ea0092238abdac9622511b336586858f8f5 --- /dev/null +++ b/mmdet/datasets/xml_style.py @@ -0,0 +1,186 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import xml.etree.ElementTree as ET +from typing import List, Optional, Union + +import mmcv +from mmengine.fileio import get, get_local_path, list_from_file + +from mmdet.registry import DATASETS +from .base_det_dataset import BaseDetDataset + + +@DATASETS.register_module() +class XMLDataset(BaseDetDataset): + """XML dataset for detection. + + Args: + img_subdir (str): Subdir where images are stored. Default: JPEGImages. + ann_subdir (str): Subdir where annotations are. Default: Annotations. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + """ + + def __init__(self, + img_subdir: str = 'JPEGImages', + ann_subdir: str = 'Annotations', + **kwargs) -> None: + self.img_subdir = img_subdir + self.ann_subdir = ann_subdir + super().__init__(**kwargs) + + @property + def sub_data_root(self) -> str: + """Return the sub data root.""" + return self.data_prefix.get('sub_data_root', '') + + def load_data_list(self) -> List[dict]: + """Load annotation from XML style ann_file. + + Returns: + list[dict]: Annotation info from XML file. + """ + assert self._metainfo.get('classes', None) is not None, \ + '`classes` in `XMLDataset` can not be None.' 
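+        # Build a name -> contiguous label index lookup once; per-object
+        # labels in `_parse_instance_info` are resolved through it.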
+ self.cat2label = { + cat: i + for i, cat in enumerate(self._metainfo['classes']) + } + + data_list = [] + img_ids = list_from_file(self.ann_file, backend_args=self.backend_args) + for img_id in img_ids: + file_name = osp.join(self.img_subdir, f'{img_id}.jpg') + xml_path = osp.join(self.sub_data_root, self.ann_subdir, + f'{img_id}.xml') + + raw_img_info = {} + raw_img_info['img_id'] = img_id + raw_img_info['file_name'] = file_name + raw_img_info['xml_path'] = xml_path + + parsed_data_info = self.parse_data_info(raw_img_info) + data_list.append(parsed_data_info) + return data_list + + @property + def bbox_min_size(self) -> Optional[int]: + """Return the minimum size of bounding boxes in the images.""" + if self.filter_cfg is not None: + return self.filter_cfg.get('bbox_min_size', None) + else: + return None + + def parse_data_info(self, img_info: dict) -> Union[dict, List[dict]]: + """Parse raw annotation to target format. + + Args: + img_info (dict): Raw image information, usually it includes + `img_id`, `file_name`, and `xml_path`. + + Returns: + Union[dict, List[dict]]: Parsed annotation. + """ + data_info = {} + img_path = osp.join(self.sub_data_root, img_info['file_name']) + data_info['img_path'] = img_path + data_info['img_id'] = img_info['img_id'] + data_info['xml_path'] = img_info['xml_path'] + + # deal with xml file + with get_local_path( + img_info['xml_path'], + backend_args=self.backend_args) as local_path: + raw_ann_info = ET.parse(local_path) + root = raw_ann_info.getroot() + size = root.find('size') + if size is not None: + width = int(size.find('width').text) + height = int(size.find('height').text) + else: + img_bytes = get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, backend='cv2') + height, width = img.shape[:2] + del img, img_bytes + + data_info['height'] = height + data_info['width'] = width + + data_info['instances'] = self._parse_instance_info( + raw_ann_info, minus_one=True) + + return data_info + + def _parse_instance_info(self, + raw_ann_info: ET, + minus_one: bool = True) -> List[dict]: + """parse instance information. + + Args: + raw_ann_info (ElementTree): ElementTree object. + minus_one (bool): Whether to subtract 1 from the coordinates. + Defaults to True. + + Returns: + List[dict]: List of instances. + """ + instances = [] + for obj in raw_ann_info.findall('object'): + instance = {} + name = obj.find('name').text + if name not in self._metainfo['classes']: + continue + difficult = obj.find('difficult') + difficult = 0 if difficult is None else int(difficult.text) + bnd_box = obj.find('bndbox') + bbox = [ + int(float(bnd_box.find('xmin').text)), + int(float(bnd_box.find('ymin').text)), + int(float(bnd_box.find('xmax').text)), + int(float(bnd_box.find('ymax').text)) + ] + + # VOC needs to subtract 1 from the coordinates + if minus_one: + bbox = [x - 1 for x in bbox] + + ignore = False + if self.bbox_min_size is not None: + assert not self.test_mode + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + if w < self.bbox_min_size or h < self.bbox_min_size: + ignore = True + if difficult or ignore: + instance['ignore_flag'] = 1 + else: + instance['ignore_flag'] = 0 + instance['bbox'] = bbox + instance['bbox_label'] = self.cat2label[name] + instances.append(instance) + return instances + + def filter_data(self) -> List[dict]: + """Filter annotations according to filter_cfg. + + Returns: + List[dict]: Filtered results. 
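+
+        Note:
+            In ``test_mode`` the data list is returned unfiltered.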
+ """ + if self.test_mode: + return self.data_list + + filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) \ + if self.filter_cfg is not None else False + min_size = self.filter_cfg.get('min_size', 0) \ + if self.filter_cfg is not None else 0 + + valid_data_infos = [] + for i, data_info in enumerate(self.data_list): + width = data_info['width'] + height = data_info['height'] + if filter_empty_gt and len(data_info['instances']) == 0: + continue + if min(width, height) >= min_size: + valid_data_infos.append(data_info) + + return valid_data_infos diff --git a/mmdet/datasets/youtube_vis_dataset.py b/mmdet/datasets/youtube_vis_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..38c3d3909f1b8fd795c181546094056c54c9c4b2 --- /dev/null +++ b/mmdet/datasets/youtube_vis_dataset.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import DATASETS +from .base_video_dataset import BaseVideoDataset + + +@DATASETS.register_module() +class YouTubeVISDataset(BaseVideoDataset): + """YouTube VIS dataset for video instance segmentation. + + Args: + dataset_version (str): Select dataset year version. + """ + + def __init__(self, dataset_version: str, *args, **kwargs): + self.set_dataset_classes(dataset_version) + super().__init__(*args, **kwargs) + + @classmethod + def set_dataset_classes(cls, dataset_version: str) -> None: + """Pass the category of the corresponding year to metainfo. + + Args: + dataset_version (str): Select dataset year version. + """ + classes_2019_version = ('person', 'giant_panda', 'lizard', 'parrot', + 'skateboard', 'sedan', 'ape', 'dog', 'snake', + 'monkey', 'hand', 'rabbit', 'duck', 'cat', + 'cow', 'fish', 'train', 'horse', 'turtle', + 'bear', 'motorbike', 'giraffe', 'leopard', + 'fox', 'deer', 'owl', 'surfboard', 'airplane', + 'truck', 'zebra', 'tiger', 'elephant', + 'snowboard', 'boat', 'shark', 'mouse', 'frog', + 'eagle', 'earless_seal', 'tennis_racket') + + classes_2021_version = ('airplane', 'bear', 'bird', 'boat', 'car', + 'cat', 'cow', 'deer', 'dog', 'duck', + 'earless_seal', 'elephant', 'fish', + 'flying_disc', 'fox', 'frog', 'giant_panda', + 'giraffe', 'horse', 'leopard', 'lizard', + 'monkey', 'motorbike', 'mouse', 'parrot', + 'person', 'rabbit', 'shark', 'skateboard', + 'snake', 'snowboard', 'squirrel', 'surfboard', + 'tennis_racket', 'tiger', 'train', 'truck', + 'turtle', 'whale', 'zebra') + + if dataset_version == '2019': + cls.METAINFO = dict(classes=classes_2019_version) + elif dataset_version == '2021': + cls.METAINFO = dict(classes=classes_2021_version) + else: + raise NotImplementedError('Not supported YouTubeVIS dataset' + f'version: {dataset_version}') diff --git a/mmdet/engine/.DS_Store b/mmdet/engine/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5ebe30b0e1168acef4cdac75e7bd935121afea83 Binary files /dev/null and b/mmdet/engine/.DS_Store differ diff --git a/mmdet/engine/__init__.py b/mmdet/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c91ace6ffa20948af572d3a0fd594e8a0b091775 --- /dev/null +++ b/mmdet/engine/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .hooks import * # noqa: F401, F403 +from .optimizers import * # noqa: F401, F403 +from .runner import * # noqa: F401, F403 +from .schedulers import * # noqa: F401, F403 diff --git a/mmdet/engine/hooks/__init__.py b/mmdet/engine/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc03693b24fec39c430717348eb8f7947ed90ee --- /dev/null +++ b/mmdet/engine/hooks/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .checkloss_hook import CheckInvalidLossHook +from .mean_teacher_hook import MeanTeacherHook +from .memory_profiler_hook import MemoryProfilerHook +from .num_class_check_hook import NumClassCheckHook +from .pipeline_switch_hook import PipelineSwitchHook +from .set_epoch_info_hook import SetEpochInfoHook +from .sync_norm_hook import SyncNormHook +from .utils import trigger_visualization_hook +from .visualization_hook import DetVisualizationHook, TrackVisualizationHook +from .yolox_mode_switch_hook import YOLOXModeSwitchHook + +__all__ = [ + 'YOLOXModeSwitchHook', 'SyncNormHook', 'CheckInvalidLossHook', + 'SetEpochInfoHook', 'MemoryProfilerHook', 'DetVisualizationHook', + 'NumClassCheckHook', 'MeanTeacherHook', 'trigger_visualization_hook', + 'PipelineSwitchHook', 'TrackVisualizationHook' +] diff --git a/mmdet/engine/hooks/checkloss_hook.py b/mmdet/engine/hooks/checkloss_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..3ebfcd5dfcd7ae329399723d3a9c0fc0a0d722ef --- /dev/null +++ b/mmdet/engine/hooks/checkloss_hook.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +from mmengine.hooks import Hook +from mmengine.runner import Runner + +from mmdet.registry import HOOKS + + +@HOOKS.register_module() +class CheckInvalidLossHook(Hook): + """Check invalid loss hook. + + This hook will regularly check whether the loss is valid + during training. + + Args: + interval (int): Checking interval (every k iterations). + Default: 50. + """ + + def __init__(self, interval: int = 50) -> None: + self.interval = interval + + def after_train_iter(self, + runner: Runner, + batch_idx: int, + data_batch: Optional[dict] = None, + outputs: Optional[dict] = None) -> None: + """Regularly check whether the loss is valid every n iterations. + + Args: + runner (:obj:`Runner`): The runner of the training process. + batch_idx (int): The index of the current batch in the train loop. + data_batch (dict, Optional): Data from dataloader. + Defaults to None. + outputs (dict, Optional): Outputs from model. Defaults to None. + """ + if self.every_n_train_iters(runner, self.interval): + assert torch.isfinite(outputs['loss']), \ + runner.logger.info('loss become infinite or NaN!') diff --git a/mmdet/engine/hooks/mean_teacher_hook.py b/mmdet/engine/hooks/mean_teacher_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..b924c0a5934248d05e7ce1add50e7574b739b9c7 --- /dev/null +++ b/mmdet/engine/hooks/mean_teacher_hook.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch.nn as nn +from mmengine.hooks import Hook +from mmengine.model import is_model_wrapper +from mmengine.runner import Runner + +from mmdet.registry import HOOKS + + +@HOOKS.register_module() +class MeanTeacherHook(Hook): + """Mean Teacher Hook. + + Mean Teacher is an efficient semi-supervised learning method in + `Mean Teacher `_. 
+ This method requires two models with exactly the same structure, + as the student model and the teacher model, respectively. + The student model updates the parameters through gradient descent, + and the teacher model updates the parameters through + exponential moving average of the student model. + Compared with the student model, the teacher model + is smoother and accumulates more knowledge. + + Args: + momentum (float): The momentum used for updating teacher's parameter. + Teacher's parameter are updated with the formula: + `teacher = (1-momentum) * teacher + momentum * student`. + Defaults to 0.001. + interval (int): Update teacher's parameter every interval iteration. + Defaults to 1. + skip_buffers (bool): Whether to skip the model buffers, such as + batchnorm running stats (running_mean, running_var), it does not + perform the ema operation. Default to True. + """ + + def __init__(self, + momentum: float = 0.001, + interval: int = 1, + skip_buffer=True) -> None: + assert 0 < momentum < 1 + self.momentum = momentum + self.interval = interval + self.skip_buffers = skip_buffer + + def before_train(self, runner: Runner) -> None: + """To check that teacher model and student model exist.""" + model = runner.model + if is_model_wrapper(model): + model = model.module + assert hasattr(model, 'teacher') + assert hasattr(model, 'student') + # only do it at initial stage + if runner.iter == 0: + self.momentum_update(model, 1) + + def after_train_iter(self, + runner: Runner, + batch_idx: int, + data_batch: Optional[dict] = None, + outputs: Optional[dict] = None) -> None: + """Update teacher's parameter every self.interval iterations.""" + if (runner.iter + 1) % self.interval != 0: + return + model = runner.model + if is_model_wrapper(model): + model = model.module + self.momentum_update(model, self.momentum) + + def momentum_update(self, model: nn.Module, momentum: float) -> None: + """Compute the moving average of the parameters using exponential + moving average.""" + if self.skip_buffers: + for (src_name, src_parm), (dst_name, dst_parm) in zip( + model.student.named_parameters(), + model.teacher.named_parameters()): + dst_parm.data.mul_(1 - momentum).add_( + src_parm.data, alpha=momentum) + else: + for (src_parm, + dst_parm) in zip(model.student.state_dict().values(), + model.teacher.state_dict().values()): + # exclude num_tracking + if dst_parm.dtype.is_floating_point: + dst_parm.data.mul_(1 - momentum).add_( + src_parm.data, alpha=momentum) diff --git a/mmdet/engine/hooks/memory_profiler_hook.py b/mmdet/engine/hooks/memory_profiler_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..3dcdcae0b669ade46026d28c46b35f35d90b504b --- /dev/null +++ b/mmdet/engine/hooks/memory_profiler_hook.py @@ -0,0 +1,121 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +from mmengine.hooks import Hook +from mmengine.runner import Runner + +from mmdet.registry import HOOKS +from mmdet.structures import DetDataSample + + +@HOOKS.register_module() +class MemoryProfilerHook(Hook): + """Memory profiler hook recording memory information including virtual + memory, swap memory, and the memory of the current process. + + Args: + interval (int): Checking interval (every k iterations). + Default: 50. 
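+
+    Examples:
+        >>> # a hypothetical config snippet; requires `psutil` and
+        >>> # `memory_profiler` to be installed (see ``__init__`` below)
+        >>> custom_hooks = [dict(type='MemoryProfilerHook', interval=50)]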
+ """ + + def __init__(self, interval: int = 50) -> None: + try: + from psutil import swap_memory, virtual_memory + self._swap_memory = swap_memory + self._virtual_memory = virtual_memory + except ImportError: + raise ImportError('psutil is not installed, please install it by: ' + 'pip install psutil') + + try: + from memory_profiler import memory_usage + self._memory_usage = memory_usage + except ImportError: + raise ImportError( + 'memory_profiler is not installed, please install it by: ' + 'pip install memory_profiler') + + self.interval = interval + + def _record_memory_information(self, runner: Runner) -> None: + """Regularly record memory information. + + Args: + runner (:obj:`Runner`): The runner of the training or evaluation + process. + """ + # in Byte + virtual_memory = self._virtual_memory() + swap_memory = self._swap_memory() + # in MB + process_memory = self._memory_usage()[0] + factor = 1024 * 1024 + runner.logger.info( + 'Memory information ' + 'available_memory: ' + f'{round(virtual_memory.available / factor)} MB, ' + 'used_memory: ' + f'{round(virtual_memory.used / factor)} MB, ' + f'memory_utilization: {virtual_memory.percent} %, ' + 'available_swap_memory: ' + f'{round((swap_memory.total - swap_memory.used) / factor)}' + ' MB, ' + f'used_swap_memory: {round(swap_memory.used / factor)} MB, ' + f'swap_memory_utilization: {swap_memory.percent} %, ' + 'current_process_memory: ' + f'{round(process_memory)} MB') + + def after_train_iter(self, + runner: Runner, + batch_idx: int, + data_batch: Optional[dict] = None, + outputs: Optional[dict] = None) -> None: + """Regularly record memory information. + + Args: + runner (:obj:`Runner`): The runner of the training process. + batch_idx (int): The index of the current batch in the train loop. + data_batch (dict, optional): Data from dataloader. + Defaults to None. + outputs (dict, optional): Outputs from model. Defaults to None. + """ + if self.every_n_inner_iters(batch_idx, self.interval): + self._record_memory_information(runner) + + def after_val_iter( + self, + runner: Runner, + batch_idx: int, + data_batch: Optional[dict] = None, + outputs: Optional[Sequence[DetDataSample]] = None) -> None: + """Regularly record memory information. + + Args: + runner (:obj:`Runner`): The runner of the validation process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict, optional): Data from dataloader. + Defaults to None. + outputs (Sequence[:obj:`DetDataSample`], optional): + Outputs from model. Defaults to None. + """ + if self.every_n_inner_iters(batch_idx, self.interval): + self._record_memory_information(runner) + + def after_test_iter( + self, + runner: Runner, + batch_idx: int, + data_batch: Optional[dict] = None, + outputs: Optional[Sequence[DetDataSample]] = None) -> None: + """Regularly record memory information. + + Args: + runner (:obj:`Runner`): The runner of the testing process. + batch_idx (int): The index of the current batch in the test loop. + data_batch (dict, optional): Data from dataloader. + Defaults to None. + outputs (Sequence[:obj:`DetDataSample`], optional): + Outputs from model. Defaults to None. 
+ """ + if self.every_n_inner_iters(batch_idx, self.interval): + self._record_memory_information(runner) diff --git a/mmdet/engine/hooks/num_class_check_hook.py b/mmdet/engine/hooks/num_class_check_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..6588473acfbd3ffe8e80eb163aa7ee449332e6b8 --- /dev/null +++ b/mmdet/engine/hooks/num_class_check_hook.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import VGG +from mmengine.hooks import Hook +from mmengine.runner import Runner + +from mmdet.registry import HOOKS + + +@HOOKS.register_module() +class NumClassCheckHook(Hook): + """Check whether the `num_classes` in head matches the length of `classes` + in `dataset.metainfo`.""" + + def _check_head(self, runner: Runner, mode: str) -> None: + """Check whether the `num_classes` in head matches the length of + `classes` in `dataset.metainfo`. + + Args: + runner (:obj:`Runner`): The runner of the training or evaluation + process. + """ + assert mode in ['train', 'val'] + model = runner.model + dataset = runner.train_dataloader.dataset if mode == 'train' else \ + runner.val_dataloader.dataset + if dataset.metainfo.get('classes', None) is None: + runner.logger.warning( + f'Please set `classes` ' + f'in the {dataset.__class__.__name__} `metainfo` and' + f'check if it is consistent with the `num_classes` ' + f'of head') + else: + classes = dataset.metainfo['classes'] + assert type(classes) is not str, \ + (f'`classes` in {dataset.__class__.__name__}' + f'should be a tuple of str.' + f'Add comma if number of classes is 1 as ' + f'classes = ({classes},)') + from mmdet.models.roi_heads.mask_heads import FusedSemanticHead + for name, module in model.named_modules(): + if hasattr(module, 'num_classes') and not name.endswith( + 'rpn_head') and not isinstance( + module, (VGG, FusedSemanticHead)): + assert module.num_classes == len(classes), \ + (f'The `num_classes` ({module.num_classes}) in ' + f'{module.__class__.__name__} of ' + f'{model.__class__.__name__} does not matches ' + f'the length of `classes` ' + f'{len(classes)}) in ' + f'{dataset.__class__.__name__}') + + def before_train_epoch(self, runner: Runner) -> None: + """Check whether the training dataset is compatible with head. + + Args: + runner (:obj:`Runner`): The runner of the training or evaluation + process. + """ + self._check_head(runner, 'train') + + def before_val_epoch(self, runner: Runner) -> None: + """Check whether the dataset in val epoch is compatible with head. + + Args: + runner (:obj:`Runner`): The runner of the training or evaluation + process. + """ + self._check_head(runner, 'val') diff --git a/mmdet/engine/hooks/pipeline_switch_hook.py b/mmdet/engine/hooks/pipeline_switch_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..a5abd897803b11793ebace86e45aac8f59938545 --- /dev/null +++ b/mmdet/engine/hooks/pipeline_switch_hook.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.transforms import Compose +from mmengine.hooks import Hook + +from mmdet.registry import HOOKS + + +@HOOKS.register_module() +class PipelineSwitchHook(Hook): + """Switch data pipeline at switch_epoch. + + Args: + switch_epoch (int): switch pipeline at this epoch. + switch_pipeline (list[dict]): the pipeline to switch to. 
+ """ + + def __init__(self, switch_epoch, switch_pipeline): + self.switch_epoch = switch_epoch + self.switch_pipeline = switch_pipeline + self._restart_dataloader = False + self._has_switched = False + + def before_train_epoch(self, runner): + """switch pipeline.""" + epoch = runner.epoch + train_loader = runner.train_dataloader + if epoch >= self.switch_epoch and not self._has_switched: + runner.logger.info('Switch pipeline now!') + # The dataset pipeline cannot be updated when persistent_workers + # is True, so we need to force the dataloader's multi-process + # restart. This is a very hacky approach. + train_loader.dataset.pipeline = Compose(self.switch_pipeline) + if hasattr(train_loader, 'persistent_workers' + ) and train_loader.persistent_workers is True: + train_loader._DataLoader__initialized = False + train_loader._iterator = None + self._restart_dataloader = True + self._has_switched = True + else: + # Once the restart is complete, we need to restore + # the initialization flag. + if self._restart_dataloader: + train_loader._DataLoader__initialized = True diff --git a/mmdet/engine/hooks/set_epoch_info_hook.py b/mmdet/engine/hooks/set_epoch_info_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..183f3167445dc0818e4fa37bdd2049d3876ed031 --- /dev/null +++ b/mmdet/engine/hooks/set_epoch_info_hook.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.hooks import Hook +from mmengine.model.wrappers import is_model_wrapper + +from mmdet.registry import HOOKS + + +@HOOKS.register_module() +class SetEpochInfoHook(Hook): + """Set runner's epoch information to the model.""" + + def before_train_epoch(self, runner): + epoch = runner.epoch + model = runner.model + if is_model_wrapper(model): + model = model.module + model.set_epoch(epoch) diff --git a/mmdet/engine/hooks/sync_norm_hook.py b/mmdet/engine/hooks/sync_norm_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..a1734380c83157c911568098abfce761fb3c9a1f --- /dev/null +++ b/mmdet/engine/hooks/sync_norm_hook.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict + +from mmengine.dist import get_dist_info +from mmengine.hooks import Hook +from torch import nn + +from mmdet.registry import HOOKS +from mmdet.utils import all_reduce_dict + + +def get_norm_states(module: nn.Module) -> OrderedDict: + """Get the state_dict of batch norms in the module.""" + async_norm_states = OrderedDict() + for name, child in module.named_modules(): + if isinstance(child, nn.modules.batchnorm._NormBase): + for k, v in child.state_dict().items(): + async_norm_states['.'.join([name, k])] = v + return async_norm_states + + +@HOOKS.register_module() +class SyncNormHook(Hook): + """Synchronize Norm states before validation, currently used in YOLOX.""" + + def before_val_epoch(self, runner): + """Synchronizing norm.""" + module = runner.model + _, world_size = get_dist_info() + if world_size == 1: + return + norm_states = get_norm_states(module) + if len(norm_states) == 0: + return + # TODO: use `all_reduce_dict` in mmengine + norm_states = all_reduce_dict(norm_states, op='mean') + module.load_state_dict(norm_states, strict=False) diff --git a/mmdet/engine/hooks/utils.py b/mmdet/engine/hooks/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d267cfe77be163c0520568b7b7936f4453914aab --- /dev/null +++ b/mmdet/engine/hooks/utils.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+def trigger_visualization_hook(cfg, args):
+    default_hooks = cfg.default_hooks
+    if 'visualization' in default_hooks:
+        visualization_hook = default_hooks['visualization']
+        # Turn on visualization
+        visualization_hook['draw'] = True
+        if args.show:
+            visualization_hook['show'] = True
+            visualization_hook['wait_time'] = args.wait_time
+        if args.show_dir:
+            visualization_hook['test_out_dir'] = args.show_dir
+    else:
+        raise RuntimeError(
+            'VisualizationHook must be included in default_hooks. '
+            'Refer to the usage of '
+            '"visualization=dict(type=\'VisualizationHook\')"')
+
+    return cfg
diff --git a/mmdet/engine/hooks/visualization_hook.py b/mmdet/engine/hooks/visualization_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..fad0f907ebc2ad47673bd9dfb5082988e57cf862
--- /dev/null
+++ b/mmdet/engine/hooks/visualization_hook.py
@@ -0,0 +1,312 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import warnings
+from typing import Optional, Sequence
+
+import mmcv
+from mmengine.fileio import get
+from mmengine.hooks import Hook
+from mmengine.runner import Runner
+from mmengine.utils import mkdir_or_exist
+from mmengine.visualization import Visualizer
+
+from mmdet.datasets.samplers import TrackImgSampler
+from mmdet.registry import HOOKS
+from mmdet.structures import DetDataSample, TrackDataSample
+
+
+@HOOKS.register_module()
+class DetVisualizationHook(Hook):
+    """Detection Visualization Hook. Used to visualize validation and testing
+    process prediction results.
+
+    In the testing phase:
+
+    1. If ``show`` is True, only the prediction results are visualized
+       without storing data, so ``vis_backends`` needs to be excluded.
+    2. If ``test_out_dir`` is specified, the prediction results need to be
+       saved to ``test_out_dir``. To avoid ``vis_backends`` also storing
+       data, ``vis_backends`` needs to be excluded.
+    3. ``vis_backends`` takes effect if the user specifies neither ``show``
+       nor ``test_out_dir``. You can set ``vis_backends`` to WandbVisBackend
+       or TensorboardVisBackend to store the prediction results in Wandb or
+       Tensorboard.
+
+    Args:
+        draw (bool): Whether to draw prediction results. If it is False,
+            no drawing will be done. Defaults to False.
+        interval (int): The interval of visualization. Defaults to 50.
+        score_thr (float): The threshold to visualize the bboxes
+            and masks. Defaults to 0.3.
+        show (bool): Whether to display the drawn image. Defaults to False.
+        wait_time (float): The interval of show (s). Defaults to 0.
+        test_out_dir (str, optional): Directory where painted images
+            will be saved in the testing process.
+        backend_args (dict, optional): Arguments to instantiate the
+            corresponding backend. Defaults to None.
+    """
+
+    def __init__(self,
+                 draw: bool = False,
+                 interval: int = 50,
+                 score_thr: float = 0.3,
+                 show: bool = False,
+                 wait_time: float = 0.,
+                 test_out_dir: Optional[str] = None,
+                 backend_args: Optional[dict] = None):
+        self._visualizer: Visualizer = Visualizer.get_current_instance()
+        self.interval = interval
+        self.score_thr = score_thr
+        self.show = show
+        if self.show:
+            # No need to think about vis backends.
+ self._visualizer._vis_backends = {} + warnings.warn('The show is True, it means that only ' + 'the prediction results are visualized ' + 'without storing data, so vis_backends ' + 'needs to be excluded.') + + self.wait_time = wait_time + self.backend_args = backend_args + self.draw = draw + self.test_out_dir = test_out_dir + self._test_index = 0 + + def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[DetDataSample]) -> None: + """Run after every ``self.interval`` validation iterations. + + Args: + runner (:obj:`Runner`): The runner of the validation process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`DetDataSample`]]): A batch of data samples + that contain annotations and predictions. + """ + if self.draw is False: + return + + # There is no guarantee that the same batch of images + # is visualized for each evaluation. + total_curr_iter = runner.iter + batch_idx + + # Visualize only the first data + img_path = outputs[0].img_path + img_bytes = get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + + if total_curr_iter % self.interval == 0: + self._visualizer.add_datasample( + osp.basename(img_path) if self.show else 'val_img', + img, + data_sample=outputs[0], + show=self.show, + wait_time=self.wait_time, + pred_score_thr=self.score_thr, + step=total_curr_iter) + + def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[DetDataSample]) -> None: + """Run after every testing iterations. + + Args: + runner (:obj:`Runner`): The runner of the testing process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`DetDataSample`]): A batch of data samples + that contain annotations and predictions. + """ + if self.draw is False: + return + + if self.test_out_dir is not None: + self.test_out_dir = osp.join(runner.work_dir, runner.timestamp, + self.test_out_dir) + mkdir_or_exist(self.test_out_dir) + + for data_sample in outputs: + self._test_index += 1 + + img_path = data_sample.img_path + img_bytes = get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + + out_file = None + if self.test_out_dir is not None: + out_file = osp.basename(img_path) + out_file = osp.join(self.test_out_dir, out_file) + + self._visualizer.add_datasample( + osp.basename(img_path) if self.show else 'test_img', + img, + data_sample=data_sample, + show=self.show, + wait_time=self.wait_time, + pred_score_thr=self.score_thr, + out_file=out_file, + step=self._test_index) + + +@HOOKS.register_module() +class TrackVisualizationHook(Hook): + """Tracking Visualization Hook. Used to visualize validation and testing + process prediction results. + + In the testing phase: + + 1. If ``show`` is True, it means that only the prediction results are + visualized without storing data, so ``vis_backends`` needs to + be excluded. + 2. If ``test_out_dir`` is specified, it means that the prediction results + need to be saved to ``test_out_dir``. In order to avoid vis_backends + also storing data, so ``vis_backends`` needs to be excluded. + 3. ``vis_backends`` takes effect if the user does not specify ``show`` + and `test_out_dir``. You can set ``vis_backends`` to WandbVisBackend or + TensorboardVisBackend to store the prediction result in Wandb or + Tensorboard. 
+ + Args: + draw (bool): whether to draw prediction results. If it is False, + it means that no drawing will be done. Defaults to False. + frame_interval (int): The interval of visualization. Defaults to 30. + score_thr (float): The threshold to visualize the bboxes + and masks. Defaults to 0.3. + show (bool): Whether to display the drawn image. Default to False. + wait_time (float): The interval of show (s). Defaults to 0. + test_out_dir (str, optional): directory where painted images + will be saved in testing process. + backend_args (dict): Arguments to instantiate a file client. + Defaults to ``None``. + """ + + def __init__(self, + draw: bool = False, + frame_interval: int = 30, + score_thr: float = 0.3, + show: bool = False, + wait_time: float = 0., + test_out_dir: Optional[str] = None, + backend_args: dict = None) -> None: + self._visualizer: Visualizer = Visualizer.get_current_instance() + self.frame_interval = frame_interval + self.score_thr = score_thr + self.show = show + if self.show: + # No need to think about vis backends. + self._visualizer._vis_backends = {} + warnings.warn('The show is True, it means that only ' + 'the prediction results are visualized ' + 'without storing data, so vis_backends ' + 'needs to be excluded.') + + self.wait_time = wait_time + self.backend_args = backend_args + self.draw = draw + self.test_out_dir = test_out_dir + self.image_idx = 0 + + def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[TrackDataSample]) -> None: + """Run after every ``self.interval`` validation iteration. + + Args: + runner (:obj:`Runner`): The runner of the validation process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`TrackDataSample`]): Outputs from model. + """ + if self.draw is False: + return + + assert len(outputs) == 1,\ + 'only batch_size=1 is supported while validating.' + + sampler = runner.val_dataloader.sampler + if isinstance(sampler, TrackImgSampler): + if self.every_n_inner_iters(batch_idx, self.frame_interval): + total_curr_iter = runner.iter + batch_idx + track_data_sample = outputs[0] + self.visualize_single_image(track_data_sample[0], + total_curr_iter) + else: + # video visualization DefaultSampler + if self.every_n_inner_iters(batch_idx, 1): + track_data_sample = outputs[0] + video_length = len(track_data_sample) + + for frame_id in range(video_length): + if frame_id % self.frame_interval == 0: + total_curr_iter = runner.iter + self.image_idx + \ + frame_id + img_data_sample = track_data_sample[frame_id] + self.visualize_single_image(img_data_sample, + total_curr_iter) + self.image_idx = self.image_idx + video_length + + def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[TrackDataSample]) -> None: + """Run after every testing iteration. + + Args: + runner (:obj:`Runner`): The runner of the testing process. + batch_idx (int): The index of the current batch in the test loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`TrackDataSample`]): Outputs from model. + """ + if self.draw is False: + return + + assert len(outputs) == 1, \ + 'only batch_size=1 is supported while testing.' 
+ + if self.test_out_dir is not None: + self.test_out_dir = osp.join(runner.work_dir, runner.timestamp, + self.test_out_dir) + mkdir_or_exist(self.test_out_dir) + + sampler = runner.test_dataloader.sampler + if isinstance(sampler, TrackImgSampler): + if self.every_n_inner_iters(batch_idx, self.frame_interval): + track_data_sample = outputs[0] + self.visualize_single_image(track_data_sample[0], batch_idx) + else: + # video visualization DefaultSampler + if self.every_n_inner_iters(batch_idx, 1): + track_data_sample = outputs[0] + video_length = len(track_data_sample) + + for frame_id in range(video_length): + if frame_id % self.frame_interval == 0: + img_data_sample = track_data_sample[frame_id] + self.visualize_single_image(img_data_sample, + self.image_idx + frame_id) + self.image_idx = self.image_idx + video_length + + def visualize_single_image(self, img_data_sample: DetDataSample, + step: int) -> None: + """ + Args: + img_data_sample (DetDataSample): single image output. + step (int): The index of the current image. + """ + img_path = img_data_sample.img_path + img_bytes = get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + + out_file = None + if self.test_out_dir is not None: + video_name = img_path.split('/')[-3] + mkdir_or_exist(osp.join(self.test_out_dir, video_name)) + out_file = osp.join(self.test_out_dir, video_name, + osp.basename(img_path)) + + self._visualizer.add_datasample( + osp.basename(img_path) if self.show else 'test_img', + img, + data_sample=img_data_sample, + show=self.show, + wait_time=self.wait_time, + pred_score_thr=self.score_thr, + out_file=out_file, + step=step) diff --git a/mmdet/engine/hooks/yolox_mode_switch_hook.py b/mmdet/engine/hooks/yolox_mode_switch_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..05a2c69068bedd1c6fb3836e1fc34568e9f6bc83 --- /dev/null +++ b/mmdet/engine/hooks/yolox_mode_switch_hook.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +from mmengine.hooks import Hook +from mmengine.model import is_model_wrapper + +from mmdet.registry import HOOKS + + +@HOOKS.register_module() +class YOLOXModeSwitchHook(Hook): + """Switch the mode of YOLOX during training. + + This hook turns off the mosaic and mixup data augmentation and switches + to use L1 loss in bbox_head. + + Args: + num_last_epochs (int): The number of latter epochs in the end of the + training to close the data augmentation and switch to L1 loss. + Defaults to 15. + skip_type_keys (Sequence[str], optional): Sequence of type string to be + skip pipeline. Defaults to ('Mosaic', 'RandomAffine', 'MixUp'). 
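+
+    Examples:
+        >>> # (editor's sketch) close Mosaic/MixUp and add the extra L1
+        >>> # loss for the final 15 epochs via the config:
+        >>> custom_hooks = [
+        ...     dict(type='YOLOXModeSwitchHook', num_last_epochs=15)
+        ... ]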
+ """ + + def __init__( + self, + num_last_epochs: int = 15, + skip_type_keys: Sequence[str] = ('Mosaic', 'RandomAffine', 'MixUp') + ) -> None: + self.num_last_epochs = num_last_epochs + self.skip_type_keys = skip_type_keys + self._restart_dataloader = False + self._has_switched = False + + def before_train_epoch(self, runner) -> None: + """Close mosaic and mixup augmentation and switches to use L1 loss.""" + epoch = runner.epoch + train_loader = runner.train_dataloader + model = runner.model + # TODO: refactor after mmengine using model wrapper + if is_model_wrapper(model): + model = model.module + epoch_to_be_switched = ((epoch + 1) >= + runner.max_epochs - self.num_last_epochs) + if epoch_to_be_switched and not self._has_switched: + runner.logger.info('No mosaic and mixup aug now!') + # The dataset pipeline cannot be updated when persistent_workers + # is True, so we need to force the dataloader's multi-process + # restart. This is a very hacky approach. + train_loader.dataset.update_skip_type_keys(self.skip_type_keys) + if hasattr(train_loader, 'persistent_workers' + ) and train_loader.persistent_workers is True: + train_loader._DataLoader__initialized = False + train_loader._iterator = None + self._restart_dataloader = True + runner.logger.info('Add additional L1 loss now!') + if hasattr(model, 'detector'): + model.detector.bbox_head.use_l1 = True + else: + model.bbox_head.use_l1 = True + self._has_switched = True + else: + # Once the restart is complete, we need to restore + # the initialization flag. + if self._restart_dataloader: + train_loader._DataLoader__initialized = True diff --git a/mmdet/engine/optimizers/__init__.py b/mmdet/engine/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..83db069ee34cad0888bbf388d3cc7030ba49bbbb --- /dev/null +++ b/mmdet/engine/optimizers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .layer_decay_optimizer_constructor import \ + LearningRateDecayOptimizerConstructor + +__all__ = ['LearningRateDecayOptimizerConstructor'] diff --git a/mmdet/engine/optimizers/layer_decay_optimizer_constructor.py b/mmdet/engine/optimizers/layer_decay_optimizer_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..73028a0aef698d63dcba8c4935d6ef6c577d0f46 --- /dev/null +++ b/mmdet/engine/optimizers/layer_decay_optimizer_constructor.py @@ -0,0 +1,158 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +from typing import List + +import torch.nn as nn +from mmengine.dist import get_dist_info +from mmengine.logging import MMLogger +from mmengine.optim import DefaultOptimWrapperConstructor + +from mmdet.registry import OPTIM_WRAPPER_CONSTRUCTORS + + +def get_layer_id_for_convnext(var_name, max_layer_id): + """Get the layer id to set the different learning rates in ``layer_wise`` + decay_type. + + Args: + var_name (str): The key of the model. + max_layer_id (int): Maximum layer id. + + Returns: + int: The id number corresponding to different learning rate in + ``LearningRateDecayOptimizerConstructor``. 
+ """ + + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.downsample_layers'): + stage_id = int(var_name.split('.')[2]) + if stage_id == 0: + layer_id = 0 + elif stage_id == 1: + layer_id = 2 + elif stage_id == 2: + layer_id = 3 + elif stage_id == 3: + layer_id = max_layer_id + return layer_id + elif var_name.startswith('backbone.stages'): + stage_id = int(var_name.split('.')[2]) + block_id = int(var_name.split('.')[3]) + if stage_id == 0: + layer_id = 1 + elif stage_id == 1: + layer_id = 2 + elif stage_id == 2: + layer_id = 3 + block_id // 3 + elif stage_id == 3: + layer_id = max_layer_id + return layer_id + else: + return max_layer_id + 1 + + +def get_stage_id_for_convnext(var_name, max_stage_id): + """Get the stage id to set the different learning rates in ``stage_wise`` + decay_type. + + Args: + var_name (str): The key of the model. + max_stage_id (int): Maximum stage id. + + Returns: + int: The id number corresponding to different learning rate in + ``LearningRateDecayOptimizerConstructor``. + """ + + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.downsample_layers'): + return 0 + elif var_name.startswith('backbone.stages'): + stage_id = int(var_name.split('.')[2]) + return stage_id + 1 + else: + return max_stage_id - 1 + + +@OPTIM_WRAPPER_CONSTRUCTORS.register_module() +class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor): + # Different learning rates are set for different layers of backbone. + # Note: Currently, this optimizer constructor is built for ConvNeXt. + + def add_params(self, params: List[dict], module: nn.Module, + **kwargs) -> None: + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (list[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + """ + logger = MMLogger.get_current_instance() + + parameter_groups = {} + logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}') + num_layers = self.paramwise_cfg.get('num_layers') + 2 + decay_rate = self.paramwise_cfg.get('decay_rate') + decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise') + logger.info('Build LearningRateDecayOptimizerConstructor ' + f'{decay_type} {decay_rate} - {num_layers}') + weight_decay = self.base_wd + for name, param in module.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith('.bias') or name in ( + 'pos_embed', 'cls_token'): + group_name = 'no_decay' + this_weight_decay = 0. 
+ else: + group_name = 'decay' + this_weight_decay = weight_decay + if 'layer_wise' in decay_type: + if 'ConvNeXt' in module.backbone.__class__.__name__: + layer_id = get_layer_id_for_convnext( + name, self.paramwise_cfg.get('num_layers')) + logger.info(f'set param {name} as id {layer_id}') + else: + raise NotImplementedError() + elif decay_type == 'stage_wise': + if 'ConvNeXt' in module.backbone.__class__.__name__: + layer_id = get_stage_id_for_convnext(name, num_layers) + logger.info(f'set param {name} as id {layer_id}') + else: + raise NotImplementedError() + group_name = f'layer_{layer_id}_{group_name}' + + if group_name not in parameter_groups: + scale = decay_rate**(num_layers - layer_id - 1) + + parameter_groups[group_name] = { + 'weight_decay': this_weight_decay, + 'params': [], + 'param_names': [], + 'lr_scale': scale, + 'group_name': group_name, + 'lr': scale * self.base_lr, + } + + parameter_groups[group_name]['params'].append(param) + parameter_groups[group_name]['param_names'].append(name) + rank, _ = get_dist_info() + if rank == 0: + to_display = {} + for key in parameter_groups: + to_display[key] = { + 'param_names': parameter_groups[key]['param_names'], + 'lr_scale': parameter_groups[key]['lr_scale'], + 'lr': parameter_groups[key]['lr'], + 'weight_decay': parameter_groups[key]['weight_decay'], + } + logger.info(f'Param groups = {json.dumps(to_display, indent=2)}') + params.extend(parameter_groups.values()) diff --git a/mmdet/engine/runner/__init__.py b/mmdet/engine/runner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8bcce4448e48e2d64354ba6770f9f426fb3d869 --- /dev/null +++ b/mmdet/engine/runner/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .loops import TeacherStudentValLoop + +__all__ = ['TeacherStudentValLoop'] diff --git a/mmdet/engine/runner/loops.py b/mmdet/engine/runner/loops.py new file mode 100644 index 0000000000000000000000000000000000000000..afe53afa5c80facf3ba6c224bd358e0859dade32 --- /dev/null +++ b/mmdet/engine/runner/loops.py @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
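+#
+# Usage sketch (editor's addition): swap this loop in for the default
+# ``ValLoop`` in a semi-supervised config, e.g.
+#
+#     val_cfg = dict(type='TeacherStudentValLoop')
+#
+# Metrics are then reported per branch under prefixed keys such as
+# ``teacher/<metric>`` and ``student/<metric>``.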
+from mmengine.model import is_model_wrapper +from mmengine.runner import ValLoop + +from mmdet.registry import LOOPS + + +@LOOPS.register_module() +class TeacherStudentValLoop(ValLoop): + """Loop for validation of model teacher and student.""" + + def run(self): + """Launch validation for model teacher and student.""" + self.runner.call_hook('before_val') + self.runner.call_hook('before_val_epoch') + self.runner.model.eval() + + model = self.runner.model + if is_model_wrapper(model): + model = model.module + assert hasattr(model, 'teacher') + assert hasattr(model, 'student') + + predict_on = model.semi_test_cfg.get('predict_on', None) + multi_metrics = dict() + for _predict_on in ['teacher', 'student']: + model.semi_test_cfg['predict_on'] = _predict_on + for idx, data_batch in enumerate(self.dataloader): + self.run_iter(idx, data_batch) + # compute metrics + metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + multi_metrics.update( + {'/'.join((_predict_on, k)): v + for k, v in metrics.items()}) + model.semi_test_cfg['predict_on'] = predict_on + + self.runner.call_hook('after_val_epoch', metrics=multi_metrics) + self.runner.call_hook('after_val') diff --git a/mmdet/engine/schedulers/__init__.py b/mmdet/engine/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..01261646fa8255c643e86ba0517019760a50d387 --- /dev/null +++ b/mmdet/engine/schedulers/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .quadratic_warmup import (QuadraticWarmupLR, QuadraticWarmupMomentum, + QuadraticWarmupParamScheduler) + +__all__ = [ + 'QuadraticWarmupParamScheduler', 'QuadraticWarmupMomentum', + 'QuadraticWarmupLR' +] diff --git a/mmdet/engine/schedulers/quadratic_warmup.py b/mmdet/engine/schedulers/quadratic_warmup.py new file mode 100644 index 0000000000000000000000000000000000000000..639b47854887786bf3f81d6d0a375033d190d91e --- /dev/null +++ b/mmdet/engine/schedulers/quadratic_warmup.py @@ -0,0 +1,131 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.optim.scheduler.lr_scheduler import LRSchedulerMixin +from mmengine.optim.scheduler.momentum_scheduler import MomentumSchedulerMixin +from mmengine.optim.scheduler.param_scheduler import INF, _ParamScheduler +from torch.optim import Optimizer + +from mmdet.registry import PARAM_SCHEDULERS + + +@PARAM_SCHEDULERS.register_module() +class QuadraticWarmupParamScheduler(_ParamScheduler): + r"""Warm up the parameter value of each parameter group by quadratic + formula: + + .. math:: + + X_{t} = X_{t-1} + \frac{2t+1}{{(end-begin)}^{2}} \times X_{base} + + Args: + optimizer (Optimizer): Wrapped optimizer. + param_name (str): Name of the parameter to be adjusted, such as + ``lr``, ``momentum``. + begin (int): Step at which to start updating the parameters. + Defaults to 0. + end (int): Step at which to stop updating the parameters. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. 
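+
+    Examples:
+        >>> # (editor's sketch) a 500-iter quadratic lr warmup in a config:
+        >>> param_scheduler = [
+        ...     dict(type='QuadraticWarmupLR', by_epoch=False, begin=0,
+        ...          end=500)
+        ... ]
+        >>> # Since sum(2*t + 1 for t in range(T)) == T**2, the parameter
+        >>> # reaches exactly its base value after ``end - begin`` steps.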
+ """ + + def __init__(self, + optimizer: Optimizer, + param_name: str, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + if end >= INF: + raise ValueError('``end`` must be less than infinity,' + 'Please set ``end`` parameter of ' + '``QuadraticWarmupScheduler`` as the ' + 'number of warmup end.') + self.total_iters = end - begin + super().__init__( + optimizer=optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) + + @classmethod + def build_iter_from_epoch(cls, + *args, + begin=0, + end=INF, + by_epoch=True, + epoch_length=None, + **kwargs): + """Build an iter-based instance of this scheduler from an epoch-based + config.""" + assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ + 'be converted to iter-based.' + assert epoch_length is not None and epoch_length > 0, \ + f'`epoch_length` must be a positive integer, ' \ + f'but got {epoch_length}.' + by_epoch = False + begin = begin * epoch_length + if end != INF: + end = end * epoch_length + return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) + + def _get_value(self): + """Compute value using chainable form of the scheduler.""" + if self.last_step == 0: + return [ + base_value * (2 * self.last_step + 1) / self.total_iters**2 + for base_value in self.base_values + ] + + return [ + group[self.param_name] + base_value * + (2 * self.last_step + 1) / self.total_iters**2 + for base_value, group in zip(self.base_values, + self.optimizer.param_groups) + ] + + +@PARAM_SCHEDULERS.register_module() +class QuadraticWarmupLR(LRSchedulerMixin, QuadraticWarmupParamScheduler): + """Warm up the learning rate of each parameter group by quadratic formula. + + Args: + optimizer (Optimizer): Wrapped optimizer. + begin (int): Step at which to start updating the parameters. + Defaults to 0. + end (int): Step at which to stop updating the parameters. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. + """ + + +@PARAM_SCHEDULERS.register_module() +class QuadraticWarmupMomentum(MomentumSchedulerMixin, + QuadraticWarmupParamScheduler): + """Warm up the momentum value of each parameter group by quadratic formula. + + Args: + optimizer (Optimizer): Wrapped optimizer. + begin (int): Step at which to start updating the parameters. + Defaults to 0. + end (int): Step at which to stop updating the parameters. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. + """ diff --git a/mmdet/evaluation/.DS_Store b/mmdet/evaluation/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..7a43a0c5397a8186b127c9792f71541bef9cd828 Binary files /dev/null and b/mmdet/evaluation/.DS_Store differ diff --git a/mmdet/evaluation/__init__.py b/mmdet/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f70dc226d30f7b8e4ee5a44ca163ad1ae04eabf5 --- /dev/null +++ b/mmdet/evaluation/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .functional import * # noqa: F401,F403 +from .metrics import * # noqa: F401,F403 diff --git a/mmdet/evaluation/functional/__init__.py b/mmdet/evaluation/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..96d58ebd3ab0dd714a6f361622a7faf2a09486cb --- /dev/null +++ b/mmdet/evaluation/functional/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .bbox_overlaps import bbox_overlaps +from .cityscapes_utils import evaluateImgLists +from .class_names import (cityscapes_classes, coco_classes, + coco_panoptic_classes, dataset_aliases, get_classes, + imagenet_det_classes, imagenet_vid_classes, + objects365v1_classes, objects365v2_classes, + oid_challenge_classes, oid_v6_classes, voc_classes) +from .mean_ap import average_precision, eval_map, print_map_summary +from .panoptic_utils import (INSTANCE_OFFSET, pq_compute_multi_core, + pq_compute_single_core) +from .recall import (eval_recalls, plot_iou_recall, plot_num_recall, + print_recall_summary) +from .ytvis import YTVIS +from .ytviseval import YTVISeval + +__all__ = [ + 'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes', + 'coco_classes', 'cityscapes_classes', 'dataset_aliases', 'get_classes', + 'average_precision', 'eval_map', 'print_map_summary', 'eval_recalls', + 'print_recall_summary', 'plot_num_recall', 'plot_iou_recall', + 'oid_v6_classes', 'oid_challenge_classes', 'INSTANCE_OFFSET', + 'pq_compute_single_core', 'pq_compute_multi_core', 'bbox_overlaps', + 'objects365v1_classes', 'objects365v2_classes', 'coco_panoptic_classes', + 'evaluateImgLists', 'YTVIS', 'YTVISeval' +] diff --git a/mmdet/evaluation/functional/bbox_overlaps.py b/mmdet/evaluation/functional/bbox_overlaps.py new file mode 100644 index 0000000000000000000000000000000000000000..5d6eb82fcfc8d5444dd2a13b7d95b978f8206a55 --- /dev/null +++ b/mmdet/evaluation/functional/bbox_overlaps.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np + + +def bbox_overlaps(bboxes1, + bboxes2, + mode='iou', + eps=1e-6, + use_legacy_coordinate=False): + """Calculate the ious between each bbox of bboxes1 and bboxes2. + + Args: + bboxes1 (ndarray): Shape (n, 4) + bboxes2 (ndarray): Shape (k, 4) + mode (str): IOU (intersection over union) or IOF (intersection + over foreground) + use_legacy_coordinate (bool): Whether to use coordinate system in + mmdet v1.x. which means width, height should be + calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively. + Note when function is used in `VOCDataset`, it should be + True to align with the official implementation + `http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar` + Default: False. + + Returns: + ious (ndarray): Shape (n, k) + """ + + assert mode in ['iou', 'iof'] + if not use_legacy_coordinate: + extra_length = 0. + else: + extra_length = 1. 
+ bboxes1 = bboxes1.astype(np.float32) + bboxes2 = bboxes2.astype(np.float32) + rows = bboxes1.shape[0] + cols = bboxes2.shape[0] + ious = np.zeros((rows, cols), dtype=np.float32) + if rows * cols == 0: + return ious + exchange = False + if bboxes1.shape[0] > bboxes2.shape[0]: + bboxes1, bboxes2 = bboxes2, bboxes1 + ious = np.zeros((cols, rows), dtype=np.float32) + exchange = True + area1 = (bboxes1[:, 2] - bboxes1[:, 0] + extra_length) * ( + bboxes1[:, 3] - bboxes1[:, 1] + extra_length) + area2 = (bboxes2[:, 2] - bboxes2[:, 0] + extra_length) * ( + bboxes2[:, 3] - bboxes2[:, 1] + extra_length) + for i in range(bboxes1.shape[0]): + x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0]) + y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1]) + x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2]) + y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3]) + overlap = np.maximum(x_end - x_start + extra_length, 0) * np.maximum( + y_end - y_start + extra_length, 0) + if mode == 'iou': + union = area1[i] + area2 - overlap + else: + union = area1[i] if not exchange else area2 + union = np.maximum(union, eps) + ious[i, :] = overlap / union + if exchange: + ious = ious.T + return ious diff --git a/mmdet/evaluation/functional/cityscapes_utils.py b/mmdet/evaluation/functional/cityscapes_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5ced3680deefe333af7cca3675a6359c02dd96f8 --- /dev/null +++ b/mmdet/evaluation/functional/cityscapes_utils.py @@ -0,0 +1,302 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) https://github.com/mcordts/cityscapesScripts +# A wrapper of `cityscapesscripts` which supports loading groundtruth +# image from `backend_args`. +import json +import os +import sys +from pathlib import Path +from typing import Optional, Union + +import mmcv +import numpy as np +from mmengine.fileio import get + +try: + import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as CSEval # noqa: E501 + from cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling import \ + CArgs # noqa: E501 + from cityscapesscripts.evaluation.instance import Instance + from cityscapesscripts.helpers.csHelpers import (id2label, labels, + writeDict2JSON) + HAS_CITYSCAPESAPI = True +except ImportError: + CArgs = object + HAS_CITYSCAPESAPI = False + + +def evaluateImgLists(prediction_list: list, + groundtruth_list: list, + args: CArgs, + backend_args: Optional[dict] = None, + dump_matches: bool = False) -> dict: + """A wrapper of obj:``cityscapesscripts.evaluation. + + evalInstanceLevelSemanticLabeling.evaluateImgLists``. Support loading + groundtruth image from file backend. + Args: + prediction_list (list): A list of prediction txt file. + groundtruth_list (list): A list of groundtruth image file. + args (CArgs): A global object setting in + obj:``cityscapesscripts.evaluation. + evalInstanceLevelSemanticLabeling`` + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + dump_matches (bool): whether dump matches.json. Defaults to False. + Returns: + dict: The computed metric. + """ + if not HAS_CITYSCAPESAPI: + raise RuntimeError('Failed to import `cityscapesscripts`.' 
+ 'Please try to install official ' + 'cityscapesscripts by ' + '"pip install cityscapesscripts"') + # determine labels of interest + CSEval.setInstanceLabels(args) + # get dictionary of all ground truth instances + gt_instances = getGtInstances( + groundtruth_list, args, backend_args=backend_args) + # match predictions and ground truth + matches = matchGtWithPreds(prediction_list, groundtruth_list, gt_instances, + args, backend_args) + if dump_matches: + CSEval.writeDict2JSON(matches, 'matches.json') + # evaluate matches + apScores = CSEval.evaluateMatches(matches, args) + # averages + avgDict = CSEval.computeAverages(apScores, args) + # result dict + resDict = CSEval.prepareJSONDataForResults(avgDict, apScores, args) + if args.JSONOutput: + # create output folder if necessary + path = os.path.dirname(args.exportFile) + CSEval.ensurePath(path) + # Write APs to JSON + CSEval.writeDict2JSON(resDict, args.exportFile) + + CSEval.printResults(avgDict, args) + + return resDict + + +def matchGtWithPreds(prediction_list: list, + groundtruth_list: list, + gt_instances: dict, + args: CArgs, + backend_args=None): + """A wrapper of obj:``cityscapesscripts.evaluation. + + evalInstanceLevelSemanticLabeling.matchGtWithPreds``. Support loading + groundtruth image from file backend. + Args: + prediction_list (list): A list of prediction txt file. + groundtruth_list (list): A list of groundtruth image file. + gt_instances (dict): Groundtruth dict. + args (CArgs): A global object setting in + obj:``cityscapesscripts.evaluation. + evalInstanceLevelSemanticLabeling`` + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + Returns: + dict: The processed prediction and groundtruth result. + """ + if not HAS_CITYSCAPESAPI: + raise RuntimeError('Failed to import `cityscapesscripts`.' + 'Please try to install official ' + 'cityscapesscripts by ' + '"pip install cityscapesscripts"') + matches: dict = dict() + if not args.quiet: + print(f'Matching {len(prediction_list)} pairs of images...') + + count = 0 + for (pred, gt) in zip(prediction_list, groundtruth_list): + # Read input files + gt_image = readGTImage(gt, backend_args) + pred_info = readPredInfo(pred) + # Get and filter ground truth instances + unfiltered_instances = gt_instances[gt] + cur_gt_instances_orig = CSEval.filterGtInstances( + unfiltered_instances, args) + + # Try to assign all predictions + (cur_gt_instances, + cur_pred_instances) = CSEval.assignGt2Preds(cur_gt_instances_orig, + gt_image, pred_info, args) + + # append to global dict + matches[gt] = {} + matches[gt]['groundTruth'] = cur_gt_instances + matches[gt]['prediction'] = cur_pred_instances + + count += 1 + if not args.quiet: + print(f'\rImages Processed: {count}', end=' ') + sys.stdout.flush() + + if not args.quiet: + print('') + + return matches + + +def readGTImage(image_file: Union[str, Path], + backend_args: Optional[dict] = None) -> np.ndarray: + """Read an image from path. + + Same as obj:``cityscapesscripts.evaluation. + evalInstanceLevelSemanticLabeling.readGTImage``, but support loading + groundtruth image from file backend. + Args: + image_file (str or Path): Either a str or pathlib.Path. + backend_args (dict, optional): Instantiates the corresponding file + backend. It may contain `backend` key to specify the file + backend. 
If it contains, the file backend corresponding to this + value will be used and initialized with the remaining values, + otherwise the corresponding file backend will be selected + based on the prefix of the file path. Defaults to None. + Returns: + np.ndarray: The groundtruth image. + """ + img_bytes = get(image_file, backend_args=backend_args) + img = mmcv.imfrombytes(img_bytes, flag='unchanged', backend='pillow') + return img + + +def readPredInfo(prediction_file: str) -> dict: + """A wrapper of obj:``cityscapesscripts.evaluation. + + evalInstanceLevelSemanticLabeling.readPredInfo``. + Args: + prediction_file (str): The prediction txt file. + Returns: + dict: The processed prediction results. + """ + if not HAS_CITYSCAPESAPI: + raise RuntimeError('Failed to import `cityscapesscripts`.' + 'Please try to install official ' + 'cityscapesscripts by ' + '"pip install cityscapesscripts"') + printError = CSEval.printError + + predInfo = {} + if (not os.path.isfile(prediction_file)): + printError(f"Infofile '{prediction_file}' " + 'for the predictions not found.') + with open(prediction_file) as f: + for line in f: + splittedLine = line.split(' ') + if len(splittedLine) != 3: + printError('Invalid prediction file. Expected content: ' + 'relPathPrediction1 labelIDPrediction1 ' + 'confidencePrediction1') + if os.path.isabs(splittedLine[0]): + printError('Invalid prediction file. First entry in each ' + 'line must be a relative path.') + + filename = os.path.join( + os.path.dirname(prediction_file), splittedLine[0]) + + imageInfo = {} + imageInfo['labelID'] = int(float(splittedLine[1])) + imageInfo['conf'] = float(splittedLine[2]) # type: ignore + predInfo[filename] = imageInfo + + return predInfo + + +def getGtInstances(groundtruth_list: list, + args: CArgs, + backend_args: Optional[dict] = None) -> dict: + """A wrapper of obj:``cityscapesscripts.evaluation. + + evalInstanceLevelSemanticLabeling.getGtInstances``. Support loading + groundtruth image from file backend. + Args: + groundtruth_list (list): A list of groundtruth image file. + args (CArgs): A global object setting in + obj:``cityscapesscripts.evaluation. + evalInstanceLevelSemanticLabeling`` + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + Returns: + dict: The computed metric. + """ + if not HAS_CITYSCAPESAPI: + raise RuntimeError('Failed to import `cityscapesscripts`.' + 'Please try to install official ' + 'cityscapesscripts by ' + '"pip install cityscapesscripts"') + # if there is a global statistics json, then load it + if (os.path.isfile(args.gtInstancesFile)): + if not args.quiet: + print('Loading ground truth instances from JSON.') + with open(args.gtInstancesFile) as json_file: + gt_instances = json.load(json_file) + # otherwise create it + else: + if (not args.quiet): + print('Creating ground truth instances from png files.') + gt_instances = instances2dict( + groundtruth_list, args, backend_args=backend_args) + writeDict2JSON(gt_instances, args.gtInstancesFile) + + return gt_instances + + +def instances2dict(image_list: list, + args: CArgs, + backend_args: Optional[dict] = None) -> dict: + """A wrapper of obj:``cityscapesscripts.evaluation. + + evalInstanceLevelSemanticLabeling.instances2dict``. Support loading + groundtruth image from file backend. + Args: + image_list (list): A list of image file. + args (CArgs): A global object setting in + obj:``cityscapesscripts.evaluation. 
+ evalInstanceLevelSemanticLabeling`` + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + Returns: + dict: The processed groundtruth results. + """ + if not HAS_CITYSCAPESAPI: + raise RuntimeError('Failed to import `cityscapesscripts`.' + 'Please try to install official ' + 'cityscapesscripts by ' + '"pip install cityscapesscripts"') + imgCount = 0 + instanceDict = {} + + if not isinstance(image_list, list): + image_list = [image_list] + + if not args.quiet: + print(f'Processing {len(image_list)} images...') + + for image_name in image_list: + # Load image + img_bytes = get(image_name, backend_args=backend_args) + imgNp = mmcv.imfrombytes(img_bytes, flag='unchanged', backend='pillow') + + # Initialize label categories + instances: dict = {} + for label in labels: + instances[label.name] = [] + + # Loop through all instance ids in instance image + for instanceId in np.unique(imgNp): + instanceObj = Instance(imgNp, instanceId) + + instances[id2label[instanceObj.labelID].name].append( + instanceObj.toDict()) + + instanceDict[image_name] = instances + imgCount += 1 + + if not args.quiet: + print(f'\rImages Processed: {imgCount}', end=' ') + sys.stdout.flush() + + return instanceDict diff --git a/mmdet/evaluation/functional/class_names.py b/mmdet/evaluation/functional/class_names.py new file mode 100644 index 0000000000000000000000000000000000000000..d0ea7094685de38a9196d1240d23beb1b44d4138 --- /dev/null +++ b/mmdet/evaluation/functional/class_names.py @@ -0,0 +1,517 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.utils import is_str + + +def wider_face_classes() -> list: + """Class names of WIDERFace.""" + return ['face'] + + +def voc_classes() -> list: + """Class names of PASCAL VOC.""" + return [ + 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', + 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', + 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' + ] + + +def imagenet_det_classes() -> list: + """Class names of ImageNet Det.""" + return [ + 'accordion', 'airplane', 'ant', 'antelope', 'apple', 'armadillo', + 'artichoke', 'axe', 'baby_bed', 'backpack', 'bagel', 'balance_beam', + 'banana', 'band_aid', 'banjo', 'baseball', 'basketball', 'bathing_cap', + 'beaker', 'bear', 'bee', 'bell_pepper', 'bench', 'bicycle', 'binder', + 'bird', 'bookshelf', 'bow_tie', 'bow', 'bowl', 'brassiere', 'burrito', + 'bus', 'butterfly', 'camel', 'can_opener', 'car', 'cart', 'cattle', + 'cello', 'centipede', 'chain_saw', 'chair', 'chime', 'cocktail_shaker', + 'coffee_maker', 'computer_keyboard', 'computer_mouse', 'corkscrew', + 'cream', 'croquet_ball', 'crutch', 'cucumber', 'cup_or_mug', 'diaper', + 'digital_clock', 'dishwasher', 'dog', 'domestic_cat', 'dragonfly', + 'drum', 'dumbbell', 'electric_fan', 'elephant', 'face_powder', 'fig', + 'filing_cabinet', 'flower_pot', 'flute', 'fox', 'french_horn', 'frog', + 'frying_pan', 'giant_panda', 'goldfish', 'golf_ball', 'golfcart', + 'guacamole', 'guitar', 'hair_dryer', 'hair_spray', 'hamburger', + 'hammer', 'hamster', 'harmonica', 'harp', 'hat_with_a_wide_brim', + 'head_cabbage', 'helmet', 'hippopotamus', 'horizontal_bar', 'horse', + 'hotdog', 'iPod', 'isopod', 'jellyfish', 'koala_bear', 'ladle', + 'ladybug', 'lamp', 'laptop', 'lemon', 'lion', 'lipstick', 'lizard', + 'lobster', 'maillot', 'maraca', 'microphone', 'microwave', 'milk_can', + 'miniskirt', 'monkey', 'motorcycle', 'mushroom', 'nail', 'neck_brace', + 'oboe', 'orange', 'otter', 
'pencil_box', 'pencil_sharpener', 'perfume', + 'person', 'piano', 'pineapple', 'ping-pong_ball', 'pitcher', 'pizza', + 'plastic_bag', 'plate_rack', 'pomegranate', 'popsicle', 'porcupine', + 'power_drill', 'pretzel', 'printer', 'puck', 'punching_bag', 'purse', + 'rabbit', 'racket', 'ray', 'red_panda', 'refrigerator', + 'remote_control', 'rubber_eraser', 'rugby_ball', 'ruler', + 'salt_or_pepper_shaker', 'saxophone', 'scorpion', 'screwdriver', + 'seal', 'sheep', 'ski', 'skunk', 'snail', 'snake', 'snowmobile', + 'snowplow', 'soap_dispenser', 'soccer_ball', 'sofa', 'spatula', + 'squirrel', 'starfish', 'stethoscope', 'stove', 'strainer', + 'strawberry', 'stretcher', 'sunglasses', 'swimming_trunks', 'swine', + 'syringe', 'table', 'tape_player', 'tennis_ball', 'tick', 'tie', + 'tiger', 'toaster', 'traffic_light', 'train', 'trombone', 'trumpet', + 'turtle', 'tv_or_monitor', 'unicycle', 'vacuum', 'violin', + 'volleyball', 'waffle_iron', 'washer', 'water_bottle', 'watercraft', + 'whale', 'wine_bottle', 'zebra' + ] + + +def imagenet_vid_classes() -> list: + """Class names of ImageNet VID.""" + return [ + 'airplane', 'antelope', 'bear', 'bicycle', 'bird', 'bus', 'car', + 'cattle', 'dog', 'domestic_cat', 'elephant', 'fox', 'giant_panda', + 'hamster', 'horse', 'lion', 'lizard', 'monkey', 'motorcycle', 'rabbit', + 'red_panda', 'sheep', 'snake', 'squirrel', 'tiger', 'train', 'turtle', + 'watercraft', 'whale', 'zebra' + ] + + +def coco_classes() -> list: + """Class names of COCO.""" + return [ + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', + 'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign', + 'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', + 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', + 'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard', + 'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork', + 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', + 'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', + 'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush' + ] + + +def coco_panoptic_classes() -> list: + """Class names of COCO panoptic.""" + return [ + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', + 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', + 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', + 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', + 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', + 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', + 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', + 'blanket', 'bridge', 'cardboard', 'counter', 'curtain', 'door-stuff', + 'floor-wood', 'flower', 'fruit', 'gravel', 'house', 'light', + 'mirror-stuff', 'net', 'pillow', 
'platform', 'playingfield', + 'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow', + 'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile', + 'wall-wood', 'water-other', 'window-blind', 'window-other', + 'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged', + 'cabinet-merged', 'table-merged', 'floor-other-merged', + 'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged', + 'paper-merged', 'food-other-merged', 'building-other-merged', + 'rock-merged', 'wall-other-merged', 'rug-merged' + ] + + +def cityscapes_classes() -> list: + """Class names of Cityscapes.""" + return [ + 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', + 'bicycle' + ] + + +def oid_challenge_classes() -> list: + """Class names of Open Images Challenge.""" + return [ + 'Footwear', 'Jeans', 'House', 'Tree', 'Woman', 'Man', 'Land vehicle', + 'Person', 'Wheel', 'Bus', 'Human face', 'Bird', 'Dress', 'Girl', + 'Vehicle', 'Building', 'Cat', 'Car', 'Belt', 'Elephant', 'Dessert', + 'Butterfly', 'Train', 'Guitar', 'Poster', 'Book', 'Boy', 'Bee', + 'Flower', 'Window', 'Hat', 'Human head', 'Dog', 'Human arm', 'Drink', + 'Human mouth', 'Human hair', 'Human nose', 'Human hand', 'Table', + 'Marine invertebrates', 'Fish', 'Sculpture', 'Rose', 'Street light', + 'Glasses', 'Fountain', 'Skyscraper', 'Swimwear', 'Brassiere', 'Drum', + 'Duck', 'Countertop', 'Furniture', 'Ball', 'Human leg', 'Boat', + 'Balloon', 'Bicycle helmet', 'Goggles', 'Door', 'Human eye', 'Shirt', + 'Toy', 'Teddy bear', 'Pasta', 'Tomato', 'Human ear', + 'Vehicle registration plate', 'Microphone', 'Musical keyboard', + 'Tower', 'Houseplant', 'Flowerpot', 'Fruit', 'Vegetable', + 'Musical instrument', 'Suit', 'Motorcycle', 'Bagel', 'French fries', + 'Hamburger', 'Chair', 'Salt and pepper shakers', 'Snail', 'Airplane', + 'Horse', 'Laptop', 'Computer keyboard', 'Football helmet', 'Cocktail', + 'Juice', 'Tie', 'Computer monitor', 'Human beard', 'Bottle', + 'Saxophone', 'Lemon', 'Mouse', 'Sock', 'Cowboy hat', 'Sun hat', + 'Football', 'Porch', 'Sunglasses', 'Lobster', 'Crab', 'Picture frame', + 'Van', 'Crocodile', 'Surfboard', 'Shorts', 'Helicopter', 'Helmet', + 'Sports uniform', 'Taxi', 'Swan', 'Goose', 'Coat', 'Jacket', 'Handbag', + 'Flag', 'Skateboard', 'Television', 'Tire', 'Spoon', 'Palm tree', + 'Stairs', 'Salad', 'Castle', 'Oven', 'Microwave oven', 'Wine', + 'Ceiling fan', 'Mechanical fan', 'Cattle', 'Truck', 'Box', 'Ambulance', + 'Desk', 'Wine glass', 'Reptile', 'Tank', 'Traffic light', 'Billboard', + 'Tent', 'Insect', 'Spider', 'Treadmill', 'Cupboard', 'Shelf', + 'Seat belt', 'Human foot', 'Bicycle', 'Bicycle wheel', 'Couch', + 'Bookcase', 'Fedora', 'Backpack', 'Bench', 'Oyster', + 'Moths and butterflies', 'Lavender', 'Waffle', 'Fork', 'Animal', + 'Accordion', 'Mobile phone', 'Plate', 'Coffee cup', 'Saucer', + 'Platter', 'Dagger', 'Knife', 'Bull', 'Tortoise', 'Sea turtle', 'Deer', + 'Weapon', 'Apple', 'Ski', 'Taco', 'Traffic sign', 'Beer', 'Necklace', + 'Sunflower', 'Piano', 'Organ', 'Harpsichord', 'Bed', 'Cabinetry', + 'Nightstand', 'Curtain', 'Chest of drawers', 'Drawer', 'Parrot', + 'Sandal', 'High heels', 'Tableware', 'Cart', 'Mushroom', 'Kite', + 'Missile', 'Seafood', 'Camera', 'Paper towel', 'Toilet paper', + 'Sombrero', 'Radish', 'Lighthouse', 'Segway', 'Pig', 'Watercraft', + 'Golf cart', 'studio couch', 'Dolphin', 'Whale', 'Earrings', 'Otter', + 'Sea lion', 'Whiteboard', 'Monkey', 'Gondola', 'Zebra', + 'Baseball glove', 'Scarf', 'Adhesive tape', 'Trousers', 'Scoreboard', + 'Lily', 
'Carnivore', 'Power plugs and sockets', 'Office building', + 'Sandwich', 'Swimming pool', 'Headphones', 'Tin can', 'Crown', 'Doll', + 'Cake', 'Frog', 'Beetle', 'Ant', 'Gas stove', 'Canoe', 'Falcon', + 'Blue jay', 'Egg', 'Fire hydrant', 'Raccoon', 'Muffin', 'Wall clock', + 'Coffee', 'Mug', 'Tea', 'Bear', 'Waste container', 'Home appliance', + 'Candle', 'Lion', 'Mirror', 'Starfish', 'Marine mammal', 'Wheelchair', + 'Umbrella', 'Alpaca', 'Violin', 'Cello', 'Brown bear', 'Canary', 'Bat', + 'Ruler', 'Plastic bag', 'Penguin', 'Watermelon', 'Harbor seal', 'Pen', + 'Pumpkin', 'Harp', 'Kitchen appliance', 'Roller skates', 'Bust', + 'Coffee table', 'Tennis ball', 'Tennis racket', 'Ladder', 'Boot', + 'Bowl', 'Stop sign', 'Volleyball', 'Eagle', 'Paddle', 'Chicken', + 'Skull', 'Lamp', 'Beehive', 'Maple', 'Sink', 'Goldfish', 'Tripod', + 'Coconut', 'Bidet', 'Tap', 'Bathroom cabinet', 'Toilet', + 'Filing cabinet', 'Pretzel', 'Table tennis racket', 'Bronze sculpture', + 'Rocket', 'Mouse', 'Hamster', 'Lizard', 'Lifejacket', 'Goat', + 'Washing machine', 'Trumpet', 'Horn', 'Trombone', 'Sheep', + 'Tablet computer', 'Pillow', 'Kitchen & dining room table', + 'Parachute', 'Raven', 'Glove', 'Loveseat', 'Christmas tree', + 'Shellfish', 'Rifle', 'Shotgun', 'Sushi', 'Sparrow', 'Bread', + 'Toaster', 'Watch', 'Asparagus', 'Artichoke', 'Suitcase', 'Antelope', + 'Broccoli', 'Ice cream', 'Racket', 'Banana', 'Cookie', 'Cucumber', + 'Dragonfly', 'Lynx', 'Caterpillar', 'Light bulb', 'Office supplies', + 'Miniskirt', 'Skirt', 'Fireplace', 'Potato', 'Light switch', + 'Croissant', 'Cabbage', 'Ladybug', 'Handgun', 'Luggage and bags', + 'Window blind', 'Snowboard', 'Baseball bat', 'Digital clock', + 'Serving tray', 'Infant bed', 'Sofa bed', 'Guacamole', 'Fox', 'Pizza', + 'Snowplow', 'Jet ski', 'Refrigerator', 'Lantern', 'Convenience store', + 'Sword', 'Rugby ball', 'Owl', 'Ostrich', 'Pancake', 'Strawberry', + 'Carrot', 'Tart', 'Dice', 'Turkey', 'Rabbit', 'Invertebrate', 'Vase', + 'Stool', 'Swim cap', 'Shower', 'Clock', 'Jellyfish', 'Aircraft', + 'Chopsticks', 'Orange', 'Snake', 'Sewing machine', 'Kangaroo', 'Mixer', + 'Food processor', 'Shrimp', 'Towel', 'Porcupine', 'Jaguar', 'Cannon', + 'Limousine', 'Mule', 'Squirrel', 'Kitchen knife', 'Tiara', 'Tiger', + 'Bow and arrow', 'Candy', 'Rhinoceros', 'Shark', 'Cricket ball', + 'Doughnut', 'Plumbing fixture', 'Camel', 'Polar bear', 'Coin', + 'Printer', 'Blender', 'Giraffe', 'Billiard table', 'Kettle', + 'Dinosaur', 'Pineapple', 'Zucchini', 'Jug', 'Barge', 'Teapot', + 'Golf ball', 'Binoculars', 'Scissors', 'Hot dog', 'Door handle', + 'Seahorse', 'Bathtub', 'Leopard', 'Centipede', 'Grapefruit', 'Snowman', + 'Cheetah', 'Alarm clock', 'Grape', 'Wrench', 'Wok', 'Bell pepper', + 'Cake stand', 'Barrel', 'Woodpecker', 'Flute', 'Corded phone', + 'Willow', 'Punching bag', 'Pomegranate', 'Telephone', 'Pear', + 'Common fig', 'Bench', 'Wood-burning stove', 'Burrito', 'Nail', + 'Turtle', 'Submarine sandwich', 'Drinking straw', 'Peach', 'Popcorn', + 'Frying pan', 'Picnic basket', 'Honeycomb', 'Envelope', 'Mango', + 'Cutting board', 'Pitcher', 'Stationary bicycle', 'Dumbbell', + 'Personal care', 'Dog bed', 'Snowmobile', 'Oboe', 'Briefcase', + 'Squash', 'Tick', 'Slow cooker', 'Coffeemaker', 'Measuring cup', + 'Crutch', 'Stretcher', 'Screwdriver', 'Flashlight', 'Spatula', + 'Pressure cooker', 'Ring binder', 'Beaker', 'Torch', 'Winter melon' + ] + + +def oid_v6_classes() -> list: + """Class names of Open Images V6.""" + return [ + 'Tortoise', 'Container', 'Magpie', 'Sea turtle', 'Football', + 
'Ambulance', 'Ladder', 'Toothbrush', 'Syringe', 'Sink', 'Toy', + 'Organ (Musical Instrument)', 'Cassette deck', 'Apple', 'Human eye', + 'Cosmetics', 'Paddle', 'Snowman', 'Beer', 'Chopsticks', 'Human beard', + 'Bird', 'Parking meter', 'Traffic light', 'Croissant', 'Cucumber', + 'Radish', 'Towel', 'Doll', 'Skull', 'Washing machine', 'Glove', 'Tick', + 'Belt', 'Sunglasses', 'Banjo', 'Cart', 'Ball', 'Backpack', 'Bicycle', + 'Home appliance', 'Centipede', 'Boat', 'Surfboard', 'Boot', + 'Headphones', 'Hot dog', 'Shorts', 'Fast food', 'Bus', 'Boy', + 'Screwdriver', 'Bicycle wheel', 'Barge', 'Laptop', 'Miniskirt', + 'Drill (Tool)', 'Dress', 'Bear', 'Waffle', 'Pancake', 'Brown bear', + 'Woodpecker', 'Blue jay', 'Pretzel', 'Bagel', 'Tower', 'Teapot', + 'Person', 'Bow and arrow', 'Swimwear', 'Beehive', 'Brassiere', 'Bee', + 'Bat (Animal)', 'Starfish', 'Popcorn', 'Burrito', 'Chainsaw', + 'Balloon', 'Wrench', 'Tent', 'Vehicle registration plate', 'Lantern', + 'Toaster', 'Flashlight', 'Billboard', 'Tiara', 'Limousine', 'Necklace', + 'Carnivore', 'Scissors', 'Stairs', 'Computer keyboard', 'Printer', + 'Traffic sign', 'Chair', 'Shirt', 'Poster', 'Cheese', 'Sock', + 'Fire hydrant', 'Land vehicle', 'Earrings', 'Tie', 'Watercraft', + 'Cabinetry', 'Suitcase', 'Muffin', 'Bidet', 'Snack', 'Snowmobile', + 'Clock', 'Medical equipment', 'Cattle', 'Cello', 'Jet ski', 'Camel', + 'Coat', 'Suit', 'Desk', 'Cat', 'Bronze sculpture', 'Juice', 'Gondola', + 'Beetle', 'Cannon', 'Computer mouse', 'Cookie', 'Office building', + 'Fountain', 'Coin', 'Calculator', 'Cocktail', 'Computer monitor', + 'Box', 'Stapler', 'Christmas tree', 'Cowboy hat', 'Hiking equipment', + 'Studio couch', 'Drum', 'Dessert', 'Wine rack', 'Drink', 'Zucchini', + 'Ladle', 'Human mouth', 'Dairy Product', 'Dice', 'Oven', 'Dinosaur', + 'Ratchet (Device)', 'Couch', 'Cricket ball', 'Winter melon', 'Spatula', + 'Whiteboard', 'Pencil sharpener', 'Door', 'Hat', 'Shower', 'Eraser', + 'Fedora', 'Guacamole', 'Dagger', 'Scarf', 'Dolphin', 'Sombrero', + 'Tin can', 'Mug', 'Tap', 'Harbor seal', 'Stretcher', 'Can opener', + 'Goggles', 'Human body', 'Roller skates', 'Coffee cup', + 'Cutting board', 'Blender', 'Plumbing fixture', 'Stop sign', + 'Office supplies', 'Volleyball (Ball)', 'Vase', 'Slow cooker', + 'Wardrobe', 'Coffee', 'Whisk', 'Paper towel', 'Personal care', 'Food', + 'Sun hat', 'Tree house', 'Flying disc', 'Skirt', 'Gas stove', + 'Salt and pepper shakers', 'Mechanical fan', 'Face powder', 'Fax', + 'Fruit', 'French fries', 'Nightstand', 'Barrel', 'Kite', 'Tart', + 'Treadmill', 'Fox', 'Flag', 'French horn', 'Window blind', + 'Human foot', 'Golf cart', 'Jacket', 'Egg (Food)', 'Street light', + 'Guitar', 'Pillow', 'Human leg', 'Isopod', 'Grape', 'Human ear', + 'Power plugs and sockets', 'Panda', 'Giraffe', 'Woman', 'Door handle', + 'Rhinoceros', 'Bathtub', 'Goldfish', 'Houseplant', 'Goat', + 'Baseball bat', 'Baseball glove', 'Mixing bowl', + 'Marine invertebrates', 'Kitchen utensil', 'Light switch', 'House', + 'Horse', 'Stationary bicycle', 'Hammer', 'Ceiling fan', 'Sofa bed', + 'Adhesive tape', 'Harp', 'Sandal', 'Bicycle helmet', 'Saucer', + 'Harpsichord', 'Human hair', 'Heater', 'Harmonica', 'Hamster', + 'Curtain', 'Bed', 'Kettle', 'Fireplace', 'Scale', 'Drinking straw', + 'Insect', 'Hair dryer', 'Kitchenware', 'Indoor rower', 'Invertebrate', + 'Food processor', 'Bookcase', 'Refrigerator', 'Wood-burning stove', + 'Punching bag', 'Common fig', 'Cocktail shaker', 'Jaguar (Animal)', + 'Golf ball', 'Fashion accessory', 'Alarm clock', 'Filing cabinet', + 
'Artichoke', 'Table', 'Tableware', 'Kangaroo', 'Koala', 'Knife', + 'Bottle', 'Bottle opener', 'Lynx', 'Lavender (Plant)', 'Lighthouse', + 'Dumbbell', 'Human head', 'Bowl', 'Humidifier', 'Porch', 'Lizard', + 'Billiard table', 'Mammal', 'Mouse', 'Motorcycle', + 'Musical instrument', 'Swim cap', 'Frying pan', 'Snowplow', + 'Bathroom cabinet', 'Missile', 'Bust', 'Man', 'Waffle iron', 'Milk', + 'Ring binder', 'Plate', 'Mobile phone', 'Baked goods', 'Mushroom', + 'Crutch', 'Pitcher (Container)', 'Mirror', 'Personal flotation device', + 'Table tennis racket', 'Pencil case', 'Musical keyboard', 'Scoreboard', + 'Briefcase', 'Kitchen knife', 'Nail (Construction)', 'Tennis ball', + 'Plastic bag', 'Oboe', 'Chest of drawers', 'Ostrich', 'Piano', 'Girl', + 'Plant', 'Potato', 'Hair spray', 'Sports equipment', 'Pasta', + 'Penguin', 'Pumpkin', 'Pear', 'Infant bed', 'Polar bear', 'Mixer', + 'Cupboard', 'Jacuzzi', 'Pizza', 'Digital clock', 'Pig', 'Reptile', + 'Rifle', 'Lipstick', 'Skateboard', 'Raven', 'High heels', 'Red panda', + 'Rose', 'Rabbit', 'Sculpture', 'Saxophone', 'Shotgun', 'Seafood', + 'Submarine sandwich', 'Snowboard', 'Sword', 'Picture frame', 'Sushi', + 'Loveseat', 'Ski', 'Squirrel', 'Tripod', 'Stethoscope', 'Submarine', + 'Scorpion', 'Segway', 'Training bench', 'Snake', 'Coffee table', + 'Skyscraper', 'Sheep', 'Television', 'Trombone', 'Tea', 'Tank', 'Taco', + 'Telephone', 'Torch', 'Tiger', 'Strawberry', 'Trumpet', 'Tree', + 'Tomato', 'Train', 'Tool', 'Picnic basket', 'Cooking spray', + 'Trousers', 'Bowling equipment', 'Football helmet', 'Truck', + 'Measuring cup', 'Coffeemaker', 'Violin', 'Vehicle', 'Handbag', + 'Paper cutter', 'Wine', 'Weapon', 'Wheel', 'Worm', 'Wok', 'Whale', + 'Zebra', 'Auto part', 'Jug', 'Pizza cutter', 'Cream', 'Monkey', 'Lion', + 'Bread', 'Platter', 'Chicken', 'Eagle', 'Helicopter', 'Owl', 'Duck', + 'Turtle', 'Hippopotamus', 'Crocodile', 'Toilet', 'Toilet paper', + 'Squid', 'Clothing', 'Footwear', 'Lemon', 'Spider', 'Deer', 'Frog', + 'Banana', 'Rocket', 'Wine glass', 'Countertop', 'Tablet computer', + 'Waste container', 'Swimming pool', 'Dog', 'Book', 'Elephant', 'Shark', + 'Candle', 'Leopard', 'Axe', 'Hand dryer', 'Soap dispenser', + 'Porcupine', 'Flower', 'Canary', 'Cheetah', 'Palm tree', 'Hamburger', + 'Maple', 'Building', 'Fish', 'Lobster', 'Garden Asparagus', + 'Furniture', 'Hedgehog', 'Airplane', 'Spoon', 'Otter', 'Bull', + 'Oyster', 'Horizontal bar', 'Convenience store', 'Bomb', 'Bench', + 'Ice cream', 'Caterpillar', 'Butterfly', 'Parachute', 'Orange', + 'Antelope', 'Beaker', 'Moths and butterflies', 'Window', 'Closet', + 'Castle', 'Jellyfish', 'Goose', 'Mule', 'Swan', 'Peach', 'Coconut', + 'Seat belt', 'Raccoon', 'Chisel', 'Fork', 'Lamp', 'Camera', + 'Squash (Plant)', 'Racket', 'Human face', 'Human arm', 'Vegetable', + 'Diaper', 'Unicycle', 'Falcon', 'Chime', 'Snail', 'Shellfish', + 'Cabbage', 'Carrot', 'Mango', 'Jeans', 'Flowerpot', 'Pineapple', + 'Drawer', 'Stool', 'Envelope', 'Cake', 'Dragonfly', 'Common sunflower', + 'Microwave oven', 'Honeycomb', 'Marine mammal', 'Sea lion', 'Ladybug', + 'Shelf', 'Watch', 'Candy', 'Salad', 'Parrot', 'Handgun', 'Sparrow', + 'Van', 'Grinder', 'Spice rack', 'Light bulb', 'Corded phone', + 'Sports uniform', 'Tennis racket', 'Wall clock', 'Serving tray', + 'Kitchen & dining room table', 'Dog bed', 'Cake stand', + 'Cat furniture', 'Bathroom accessory', 'Facial tissue holder', + 'Pressure cooker', 'Kitchen appliance', 'Tire', 'Ruler', + 'Luggage and bags', 'Microphone', 'Broccoli', 'Umbrella', 'Pastry', + 'Grapefruit', 
'Band-aid', 'Animal', 'Bell pepper', 'Turkey', 'Lily', + 'Pomegranate', 'Doughnut', 'Glasses', 'Human nose', 'Pen', 'Ant', + 'Car', 'Aircraft', 'Human hand', 'Skunk', 'Teddy bear', 'Watermelon', + 'Cantaloupe', 'Dishwasher', 'Flute', 'Balance beam', 'Sandwich', + 'Shrimp', 'Sewing machine', 'Binoculars', 'Rays and skates', 'Ipod', + 'Accordion', 'Willow', 'Crab', 'Crown', 'Seahorse', 'Perfume', + 'Alpaca', 'Taxi', 'Canoe', 'Remote control', 'Wheelchair', + 'Rugby ball', 'Armadillo', 'Maracas', 'Helmet' + ] + + +def objects365v1_classes() -> list: + """Class names of Objects365 V1.""" + return [ + 'person', 'sneakers', 'chair', 'hat', 'lamp', 'bottle', + 'cabinet/shelf', 'cup', 'car', 'glasses', 'picture/frame', 'desk', + 'handbag', 'street lights', 'book', 'plate', 'helmet', 'leather shoes', + 'pillow', 'glove', 'potted plant', 'bracelet', 'flower', 'tv', + 'storage box', 'vase', 'bench', 'wine glass', 'boots', 'bowl', + 'dining table', 'umbrella', 'boat', 'flag', 'speaker', 'trash bin/can', + 'stool', 'backpack', 'couch', 'belt', 'carpet', 'basket', + 'towel/napkin', 'slippers', 'barrel/bucket', 'coffee table', 'suv', + 'toy', 'tie', 'bed', 'traffic light', 'pen/pencil', 'microphone', + 'sandals', 'canned', 'necklace', 'mirror', 'faucet', 'bicycle', + 'bread', 'high heels', 'ring', 'van', 'watch', 'sink', 'horse', 'fish', + 'apple', 'camera', 'candle', 'teddy bear', 'cake', 'motorcycle', + 'wild bird', 'laptop', 'knife', 'traffic sign', 'cell phone', 'paddle', + 'truck', 'cow', 'power outlet', 'clock', 'drum', 'fork', 'bus', + 'hanger', 'nightstand', 'pot/pan', 'sheep', 'guitar', 'traffic cone', + 'tea pot', 'keyboard', 'tripod', 'hockey', 'fan', 'dog', 'spoon', + 'blackboard/whiteboard', 'balloon', 'air conditioner', 'cymbal', + 'mouse', 'telephone', 'pickup truck', 'orange', 'banana', 'airplane', + 'luggage', 'skis', 'soccer', 'trolley', 'oven', 'remote', + 'baseball glove', 'paper towel', 'refrigerator', 'train', 'tomato', + 'machinery vehicle', 'tent', 'shampoo/shower gel', 'head phone', + 'lantern', 'donut', 'cleaning products', 'sailboat', 'tangerine', + 'pizza', 'kite', 'computer box', 'elephant', 'toiletries', 'gas stove', + 'broccoli', 'toilet', 'stroller', 'shovel', 'baseball bat', + 'microwave', 'skateboard', 'surfboard', 'surveillance camera', 'gun', + 'life saver', 'cat', 'lemon', 'liquid soap', 'zebra', 'duck', + 'sports car', 'giraffe', 'pumpkin', 'piano', 'stop sign', 'radiator', + 'converter', 'tissue ', 'carrot', 'washing machine', 'vent', 'cookies', + 'cutting/chopping board', 'tennis racket', 'candy', + 'skating and skiing shoes', 'scissors', 'folder', 'baseball', + 'strawberry', 'bow tie', 'pigeon', 'pepper', 'coffee machine', + 'bathtub', 'snowboard', 'suitcase', 'grapes', 'ladder', 'pear', + 'american football', 'basketball', 'potato', 'paint brush', 'printer', + 'billiards', 'fire hydrant', 'goose', 'projector', 'sausage', + 'fire extinguisher', 'extension cord', 'facial mask', 'tennis ball', + 'chopsticks', 'electronic stove and gas stove', 'pie', 'frisbee', + 'kettle', 'hamburger', 'golf club', 'cucumber', 'clutch', 'blender', + 'tong', 'slide', 'hot dog', 'toothbrush', 'facial cleanser', 'mango', + 'deer', 'egg', 'violin', 'marker', 'ship', 'chicken', 'onion', + 'ice cream', 'tape', 'wheelchair', 'plum', 'bar soap', 'scale', + 'watermelon', 'cabbage', 'router/modem', 'golf ball', 'pine apple', + 'crane', 'fire truck', 'peach', 'cello', 'notepaper', 'tricycle', + 'toaster', 'helicopter', 'green beans', 'brush', 'carriage', 'cigar', + 'earphone', 'penguin', 
'hurdle', 'swing', 'radio', 'CD', + 'parking meter', 'swan', 'garlic', 'french fries', 'horn', 'avocado', + 'saxophone', 'trumpet', 'sandwich', 'cue', 'kiwi fruit', 'bear', + 'fishing rod', 'cherry', 'tablet', 'green vegetables', 'nuts', 'corn', + 'key', 'screwdriver', 'globe', 'broom', 'pliers', 'volleyball', + 'hammer', 'eggplant', 'trophy', 'dates', 'board eraser', 'rice', + 'tape measure/ruler', 'dumbbell', 'hamimelon', 'stapler', 'camel', + 'lettuce', 'goldfish', 'meat balls', 'medal', 'toothpaste', 'antelope', + 'shrimp', 'rickshaw', 'trombone', 'pomegranate', 'coconut', + 'jellyfish', 'mushroom', 'calculator', 'treadmill', 'butterfly', + 'egg tart', 'cheese', 'pig', 'pomelo', 'race car', 'rice cooker', + 'tuba', 'crosswalk sign', 'papaya', 'hair drier', 'green onion', + 'chips', 'dolphin', 'sushi', 'urinal', 'donkey', 'electric drill', + 'spring rolls', 'tortoise/turtle', 'parrot', 'flute', 'measuring cup', + 'shark', 'steak', 'poker card', 'binoculars', 'llama', 'radish', + 'noodles', 'yak', 'mop', 'crab', 'microscope', 'barbell', 'bread/bun', + 'baozi', 'lion', 'red cabbage', 'polar bear', 'lighter', 'seal', + 'mangosteen', 'comb', 'eraser', 'pitaya', 'scallop', 'pencil case', + 'saw', 'table tennis paddle', 'okra', 'starfish', 'eagle', 'monkey', + 'durian', 'game board', 'rabbit', 'french horn', 'ambulance', + 'asparagus', 'hoverboard', 'pasta', 'target', 'hotair balloon', + 'chainsaw', 'lobster', 'iron', 'flashlight' + ] + + +def objects365v2_classes() -> list: + """Class names of Objects365 V2.""" + return [ + 'Person', 'Sneakers', 'Chair', 'Other Shoes', 'Hat', 'Car', 'Lamp', + 'Glasses', 'Bottle', 'Desk', 'Cup', 'Street Lights', 'Cabinet/shelf', + 'Handbag/Satchel', 'Bracelet', 'Plate', 'Picture/Frame', 'Helmet', + 'Book', 'Gloves', 'Storage box', 'Boat', 'Leather Shoes', 'Flower', + 'Bench', 'Potted Plant', 'Bowl/Basin', 'Flag', 'Pillow', 'Boots', + 'Vase', 'Microphone', 'Necklace', 'Ring', 'SUV', 'Wine Glass', 'Belt', + 'Moniter/TV', 'Backpack', 'Umbrella', 'Traffic Light', 'Speaker', + 'Watch', 'Tie', 'Trash bin Can', 'Slippers', 'Bicycle', 'Stool', + 'Barrel/bucket', 'Van', 'Couch', 'Sandals', 'Bakset', 'Drum', + 'Pen/Pencil', 'Bus', 'Wild Bird', 'High Heels', 'Motorcycle', 'Guitar', + 'Carpet', 'Cell Phone', 'Bread', 'Camera', 'Canned', 'Truck', + 'Traffic cone', 'Cymbal', 'Lifesaver', 'Towel', 'Stuffed Toy', + 'Candle', 'Sailboat', 'Laptop', 'Awning', 'Bed', 'Faucet', 'Tent', + 'Horse', 'Mirror', 'Power outlet', 'Sink', 'Apple', 'Air Conditioner', + 'Knife', 'Hockey Stick', 'Paddle', 'Pickup Truck', 'Fork', + 'Traffic Sign', 'Ballon', 'Tripod', 'Dog', 'Spoon', 'Clock', 'Pot', + 'Cow', 'Cake', 'Dinning Table', 'Sheep', 'Hanger', + 'Blackboard/Whiteboard', 'Napkin', 'Other Fish', 'Orange/Tangerine', + 'Toiletry', 'Keyboard', 'Tomato', 'Lantern', 'Machinery Vehicle', + 'Fan', 'Green Vegetables', 'Banana', 'Baseball Glove', 'Airplane', + 'Mouse', 'Train', 'Pumpkin', 'Soccer', 'Skiboard', 'Luggage', + 'Nightstand', 'Tea pot', 'Telephone', 'Trolley', 'Head Phone', + 'Sports Car', 'Stop Sign', 'Dessert', 'Scooter', 'Stroller', 'Crane', + 'Remote', 'Refrigerator', 'Oven', 'Lemon', 'Duck', 'Baseball Bat', + 'Surveillance Camera', 'Cat', 'Jug', 'Broccoli', 'Piano', 'Pizza', + 'Elephant', 'Skateboard', 'Surfboard', 'Gun', + 'Skating and Skiing shoes', 'Gas stove', 'Donut', 'Bow Tie', 'Carrot', + 'Toilet', 'Kite', 'Strawberry', 'Other Balls', 'Shovel', 'Pepper', + 'Computer Box', 'Toilet Paper', 'Cleaning Products', 'Chopsticks', + 'Microwave', 'Pigeon', 'Baseball', 
'Cutting/chopping Board', + 'Coffee Table', 'Side Table', 'Scissors', 'Marker', 'Pie', 'Ladder', + 'Snowboard', 'Cookies', 'Radiator', 'Fire Hydrant', 'Basketball', + 'Zebra', 'Grape', 'Giraffe', 'Potato', 'Sausage', 'Tricycle', 'Violin', + 'Egg', 'Fire Extinguisher', 'Candy', 'Fire Truck', 'Billards', + 'Converter', 'Bathtub', 'Wheelchair', 'Golf Club', 'Briefcase', + 'Cucumber', 'Cigar/Cigarette ', 'Paint Brush', 'Pear', 'Heavy Truck', + 'Hamburger', 'Extractor', 'Extention Cord', 'Tong', 'Tennis Racket', + 'Folder', 'American Football', 'earphone', 'Mask', 'Kettle', 'Tennis', + 'Ship', 'Swing', 'Coffee Machine', 'Slide', 'Carriage', 'Onion', + 'Green beans', 'Projector', 'Frisbee', + 'Washing Machine/Drying Machine', 'Chicken', 'Printer', 'Watermelon', + 'Saxophone', 'Tissue', 'Toothbrush', 'Ice cream', 'Hotair ballon', + 'Cello', 'French Fries', 'Scale', 'Trophy', 'Cabbage', 'Hot dog', + 'Blender', 'Peach', 'Rice', 'Wallet/Purse', 'Volleyball', 'Deer', + 'Goose', 'Tape', 'Tablet', 'Cosmetics', 'Trumpet', 'Pineapple', + 'Golf Ball', 'Ambulance', 'Parking meter', 'Mango', 'Key', 'Hurdle', + 'Fishing Rod', 'Medal', 'Flute', 'Brush', 'Penguin', 'Megaphone', + 'Corn', 'Lettuce', 'Garlic', 'Swan', 'Helicopter', 'Green Onion', + 'Sandwich', 'Nuts', 'Speed Limit Sign', 'Induction Cooker', 'Broom', + 'Trombone', 'Plum', 'Rickshaw', 'Goldfish', 'Kiwi fruit', + 'Router/modem', 'Poker Card', 'Toaster', 'Shrimp', 'Sushi', 'Cheese', + 'Notepaper', 'Cherry', 'Pliers', 'CD', 'Pasta', 'Hammer', 'Cue', + 'Avocado', 'Hamimelon', 'Flask', 'Mushroon', 'Screwdriver', 'Soap', + 'Recorder', 'Bear', 'Eggplant', 'Board Eraser', 'Coconut', + 'Tape Measur/ Ruler', 'Pig', 'Showerhead', 'Globe', 'Chips', 'Steak', + 'Crosswalk Sign', 'Stapler', 'Campel', 'Formula 1 ', 'Pomegranate', + 'Dishwasher', 'Crab', 'Hoverboard', 'Meat ball', 'Rice Cooker', 'Tuba', + 'Calculator', 'Papaya', 'Antelope', 'Parrot', 'Seal', 'Buttefly', + 'Dumbbell', 'Donkey', 'Lion', 'Urinal', 'Dolphin', 'Electric Drill', + 'Hair Dryer', 'Egg tart', 'Jellyfish', 'Treadmill', 'Lighter', + 'Grapefruit', 'Game board', 'Mop', 'Radish', 'Baozi', 'Target', + 'French', 'Spring Rolls', 'Monkey', 'Rabbit', 'Pencil Case', 'Yak', + 'Red Cabbage', 'Binoculars', 'Asparagus', 'Barbell', 'Scallop', + 'Noddles', 'Comb', 'Dumpling', 'Oyster', 'Table Teniis paddle', + 'Cosmetics Brush/Eyeliner Pencil', 'Chainsaw', 'Eraser', 'Lobster', + 'Durian', 'Okra', 'Lipstick', 'Cosmetics Mirror', 'Curling', + 'Table Tennis ' + ] + + +dataset_aliases = { + 'voc': ['voc', 'pascal_voc', 'voc07', 'voc12'], + 'imagenet_det': ['det', 'imagenet_det', 'ilsvrc_det'], + 'imagenet_vid': ['vid', 'imagenet_vid', 'ilsvrc_vid'], + 'coco': ['coco', 'mscoco', 'ms_coco'], + 'coco_panoptic': ['coco_panoptic', 'panoptic'], + 'wider_face': ['WIDERFaceDataset', 'wider_face', 'WIDERFace'], + 'cityscapes': ['cityscapes'], + 'oid_challenge': ['oid_challenge', 'openimages_challenge'], + 'oid_v6': ['oid_v6', 'openimages_v6'], + 'objects365v1': ['objects365v1', 'obj365v1'], + 'objects365v2': ['objects365v2', 'obj365v2'] +} + + +def get_classes(dataset) -> list: + """Get class names of a dataset.""" + alias2name = {} + for name, aliases in dataset_aliases.items(): + for alias in aliases: + alias2name[alias] = name + + if is_str(dataset): + if dataset in alias2name: + labels = eval(alias2name[dataset] + '_classes()') + else: + raise ValueError(f'Unrecognized dataset: {dataset}') + else: + raise TypeError(f'dataset must a str, but got {type(dataset)}') + return labels diff --git 
a/mmdet/evaluation/functional/mean_ap.py b/mmdet/evaluation/functional/mean_ap.py new file mode 100644 index 0000000000000000000000000000000000000000..989972a48467f74fa915fa6f3807d0db3becdba2 --- /dev/null +++ b/mmdet/evaluation/functional/mean_ap.py @@ -0,0 +1,792 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from multiprocessing import Pool + +import numpy as np +from mmengine.logging import print_log +from mmengine.utils import is_str +from terminaltables import AsciiTable + +from .bbox_overlaps import bbox_overlaps +from .class_names import get_classes + + +def average_precision(recalls, precisions, mode='area'): + """Calculate average precision (for single or multiple scales). + + Args: + recalls (ndarray): shape (num_scales, num_dets) or (num_dets, ) + precisions (ndarray): shape (num_scales, num_dets) or (num_dets, ) + mode (str): 'area' or '11points', 'area' means calculating the area + under precision-recall curve, '11points' means calculating + the average precision of recalls at [0, 0.1, ..., 1] + + Returns: + float or ndarray: calculated average precision + """ + no_scale = False + if recalls.ndim == 1: + no_scale = True + recalls = recalls[np.newaxis, :] + precisions = precisions[np.newaxis, :] + assert recalls.shape == precisions.shape and recalls.ndim == 2 + num_scales = recalls.shape[0] + ap = np.zeros(num_scales, dtype=np.float32) + if mode == 'area': + zeros = np.zeros((num_scales, 1), dtype=recalls.dtype) + ones = np.ones((num_scales, 1), dtype=recalls.dtype) + mrec = np.hstack((zeros, recalls, ones)) + mpre = np.hstack((zeros, precisions, zeros)) + for i in range(mpre.shape[1] - 1, 0, -1): + mpre[:, i - 1] = np.maximum(mpre[:, i - 1], mpre[:, i]) + for i in range(num_scales): + ind = np.where(mrec[i, 1:] != mrec[i, :-1])[0] + ap[i] = np.sum( + (mrec[i, ind + 1] - mrec[i, ind]) * mpre[i, ind + 1]) + elif mode == '11points': + for i in range(num_scales): + for thr in np.arange(0, 1 + 1e-3, 0.1): + precs = precisions[i, recalls[i, :] >= thr] + prec = precs.max() if precs.size > 0 else 0 + ap[i] += prec + ap /= 11 + else: + raise ValueError( + 'Unrecognized mode, only "area" and "11points" are supported') + if no_scale: + ap = ap[0] + return ap + + +def tpfp_imagenet(det_bboxes, + gt_bboxes, + gt_bboxes_ignore=None, + default_iou_thr=0.5, + area_ranges=None, + use_legacy_coordinate=False, + **kwargs): + """Check if detected bboxes are true positive or false positive. + + Args: + det_bbox (ndarray): Detected bboxes of this image, of shape (m, 5). + gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4). + gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image, + of shape (k, 4). Defaults to None + default_iou_thr (float): IoU threshold to be considered as matched for + medium and large bboxes (small ones have special rules). + Defaults to 0.5. + area_ranges (list[tuple] | None): Range of bbox areas to be evaluated, + in the format [(min1, max1), (min2, max2), ...]. Defaults to None. + use_legacy_coordinate (bool): Whether to use coordinate system in + mmdet v1.x. which means width, height should be + calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively. + Defaults to False. + + Returns: + tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of + each array is (num_scales, m). + """ + + if not use_legacy_coordinate: + extra_length = 0. + else: + extra_length = 1. 
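+    # Clarifying note on the two coordinate conventions: with
+    # use_legacy_coordinate=True a box (0, 0, 9, 9) has width and height
+    # 9 - 0 + 1 = 10, while the default convention gives 9. The
+    # `extra_length` term above absorbs exactly this off-by-one difference
+    # in every width, height and area computation below.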
+ + # an indicator of ignored gts + gt_ignore_inds = np.concatenate( + (np.zeros(gt_bboxes.shape[0], + dtype=bool), np.ones(gt_bboxes_ignore.shape[0], dtype=bool))) + # stack gt_bboxes and gt_bboxes_ignore for convenience + gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore)) + + num_dets = det_bboxes.shape[0] + num_gts = gt_bboxes.shape[0] + if area_ranges is None: + area_ranges = [(None, None)] + num_scales = len(area_ranges) + # tp and fp are of shape (num_scales, num_gts), each row is tp or fp + # of a certain scale. + tp = np.zeros((num_scales, num_dets), dtype=np.float32) + fp = np.zeros((num_scales, num_dets), dtype=np.float32) + if gt_bboxes.shape[0] == 0: + if area_ranges == [(None, None)]: + fp[...] = 1 + else: + det_areas = ( + det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * ( + det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length) + for i, (min_area, max_area) in enumerate(area_ranges): + fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1 + return tp, fp + ious = bbox_overlaps( + det_bboxes, gt_bboxes - 1, use_legacy_coordinate=use_legacy_coordinate) + gt_w = gt_bboxes[:, 2] - gt_bboxes[:, 0] + extra_length + gt_h = gt_bboxes[:, 3] - gt_bboxes[:, 1] + extra_length + iou_thrs = np.minimum((gt_w * gt_h) / ((gt_w + 10.0) * (gt_h + 10.0)), + default_iou_thr) + # sort all detections by scores in descending order + sort_inds = np.argsort(-det_bboxes[:, -1]) + for k, (min_area, max_area) in enumerate(area_ranges): + gt_covered = np.zeros(num_gts, dtype=bool) + # if no area range is specified, gt_area_ignore is all False + if min_area is None: + gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool) + else: + gt_areas = gt_w * gt_h + gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area) + for i in sort_inds: + max_iou = -1 + matched_gt = -1 + # find best overlapped available gt + for j in range(num_gts): + # different from PASCAL VOC: allow finding other gts if the + # best overlapped ones are already matched by other det bboxes + if gt_covered[j]: + continue + elif ious[i, j] >= iou_thrs[j] and ious[i, j] > max_iou: + max_iou = ious[i, j] + matched_gt = j + # there are 4 cases for a det bbox: + # 1. it matches a gt, tp = 1, fp = 0 + # 2. it matches an ignored gt, tp = 0, fp = 0 + # 3. it matches no gt and within area range, tp = 0, fp = 1 + # 4. it matches no gt but is beyond area range, tp = 0, fp = 0 + if matched_gt >= 0: + gt_covered[matched_gt] = 1 + if not (gt_ignore_inds[matched_gt] + or gt_area_ignore[matched_gt]): + tp[k, i] = 1 + elif min_area is None: + fp[k, i] = 1 + else: + bbox = det_bboxes[i, :4] + area = (bbox[2] - bbox[0] + extra_length) * ( + bbox[3] - bbox[1] + extra_length) + if area >= min_area and area < max_area: + fp[k, i] = 1 + return tp, fp + + +def tpfp_default(det_bboxes, + gt_bboxes, + gt_bboxes_ignore=None, + iou_thr=0.5, + area_ranges=None, + use_legacy_coordinate=False, + **kwargs): + """Check if detected bboxes are true positive or false positive. + + Args: + det_bbox (ndarray): Detected bboxes of this image, of shape (m, 5). + gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4). + gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image, + of shape (k, 4). Defaults to None + iou_thr (float): IoU threshold to be considered as matched. + Defaults to 0.5. + area_ranges (list[tuple] | None): Range of bbox areas to be + evaluated, in the format [(min1, max1), (min2, max2), ...]. + Defaults to None. + use_legacy_coordinate (bool): Whether to use coordinate system in + mmdet v1.x. 
which means width, height should be + calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively. + Defaults to False. + + Returns: + tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of + each array is (num_scales, m). + """ + + if not use_legacy_coordinate: + extra_length = 0. + else: + extra_length = 1. + + # an indicator of ignored gts + gt_ignore_inds = np.concatenate( + (np.zeros(gt_bboxes.shape[0], + dtype=bool), np.ones(gt_bboxes_ignore.shape[0], dtype=bool))) + # stack gt_bboxes and gt_bboxes_ignore for convenience + gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore)) + + num_dets = det_bboxes.shape[0] + num_gts = gt_bboxes.shape[0] + if area_ranges is None: + area_ranges = [(None, None)] + num_scales = len(area_ranges) + # tp and fp are of shape (num_scales, num_gts), each row is tp or fp of + # a certain scale + tp = np.zeros((num_scales, num_dets), dtype=np.float32) + fp = np.zeros((num_scales, num_dets), dtype=np.float32) + + # if there is no gt bboxes in this image, then all det bboxes + # within area range are false positives + if gt_bboxes.shape[0] == 0: + if area_ranges == [(None, None)]: + fp[...] = 1 + else: + det_areas = ( + det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * ( + det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length) + for i, (min_area, max_area) in enumerate(area_ranges): + fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1 + return tp, fp + + ious = bbox_overlaps( + det_bboxes, gt_bboxes, use_legacy_coordinate=use_legacy_coordinate) + # for each det, the max iou with all gts + ious_max = ious.max(axis=1) + # for each det, which gt overlaps most with it + ious_argmax = ious.argmax(axis=1) + # sort all dets in descending order by scores + sort_inds = np.argsort(-det_bboxes[:, -1]) + for k, (min_area, max_area) in enumerate(area_ranges): + gt_covered = np.zeros(num_gts, dtype=bool) + # if no area range is specified, gt_area_ignore is all False + if min_area is None: + gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool) + else: + gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + extra_length) * ( + gt_bboxes[:, 3] - gt_bboxes[:, 1] + extra_length) + gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area) + for i in sort_inds: + if ious_max[i] >= iou_thr: + matched_gt = ious_argmax[i] + if not (gt_ignore_inds[matched_gt] + or gt_area_ignore[matched_gt]): + if not gt_covered[matched_gt]: + gt_covered[matched_gt] = True + tp[k, i] = 1 + else: + fp[k, i] = 1 + # otherwise ignore this detected bbox, tp = 0, fp = 0 + elif min_area is None: + fp[k, i] = 1 + else: + bbox = det_bboxes[i, :4] + area = (bbox[2] - bbox[0] + extra_length) * ( + bbox[3] - bbox[1] + extra_length) + if area >= min_area and area < max_area: + fp[k, i] = 1 + return tp, fp + + +def tpfp_openimages(det_bboxes, + gt_bboxes, + gt_bboxes_ignore=None, + iou_thr=0.5, + area_ranges=None, + use_legacy_coordinate=False, + gt_bboxes_group_of=None, + use_group_of=True, + ioa_thr=0.5, + **kwargs): + """Check if detected bboxes are true positive or false positive. + + Args: + det_bbox (ndarray): Detected bboxes of this image, of shape (m, 5). + gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4). + gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image, + of shape (k, 4). Defaults to None + iou_thr (float): IoU threshold to be considered as matched. + Defaults to 0.5. + area_ranges (list[tuple] | None): Range of bbox areas to be + evaluated, in the format [(min1, max1), (min2, max2), ...]. + Defaults to None. 
+        use_legacy_coordinate (bool): Whether to use the coordinate system of
+            mmdet v1.x, which means width and height should be
+            calculated as 'x2 - x1 + 1' and 'y2 - y1 + 1' respectively.
+            Defaults to False.
+        gt_bboxes_group_of (ndarray): GT group_of flags of this image, of
+            shape (k, 1). Defaults to None.
+        use_group_of (bool): Whether to use group-of boxes when calculating
+            TP and FP; only used in Open Images evaluation. Defaults to True.
+        ioa_thr (float | None): IoA threshold to be considered as matched;
+            only used in Open Images evaluation. Defaults to 0.5.
+
+    Returns:
+        tuple[np.ndarray]: A tuple (tp, fp, det_bboxes). The elements of
+            tp and fp are 0 and 1, and each array has shape (num_scales, m).
+            det_bboxes contains the detections that remain after filtering
+            out those matched to group-of gts during Open Images evaluation.
+    """
+
+    if not use_legacy_coordinate:
+        extra_length = 0.
+    else:
+        extra_length = 1.
+
+    # an indicator of ignored gts
+    gt_ignore_inds = np.concatenate(
+        (np.zeros(gt_bboxes.shape[0],
+                  dtype=bool), np.ones(gt_bboxes_ignore.shape[0], dtype=bool)))
+    # stack gt_bboxes and gt_bboxes_ignore for convenience
+    gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))
+
+    num_dets = det_bboxes.shape[0]
+    num_gts = gt_bboxes.shape[0]
+    if area_ranges is None:
+        area_ranges = [(None, None)]
+    num_scales = len(area_ranges)
+    # tp and fp are of shape (num_scales, num_gts), each row is tp or fp of
+    # a certain scale
+    tp = np.zeros((num_scales, num_dets), dtype=np.float32)
+    fp = np.zeros((num_scales, num_dets), dtype=np.float32)
+
+    # if there are no gt bboxes in this image, then all det bboxes
+    # within area range are false positives
+    if gt_bboxes.shape[0] == 0:
+        if area_ranges == [(None, None)]:
+            fp[...] = 1
+        else:
+            det_areas = (
+                det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * (
+                    det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length)
+            for i, (min_area, max_area) in enumerate(area_ranges):
+                fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
+        return tp, fp, det_bboxes
+
+    if gt_bboxes_group_of is not None and use_group_of:
+        # When handling group-of boxes, divide the gt boxes into two parts:
+        # non-group-of and group-of. Then calculate IoUs against the
+        # non-group-of gts and IoAs against the group-of gts, respectively.
+        # Only used in Open Images evaluation.
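+        # A worked example of the two overlap measures used below: for a
+        # 10x10 detection half-covered by a 40x40 group-of gt box, IoU is
+        # 50 / (100 + 1600 - 50) ~= 0.03, while IoA (the intersection over
+        # the detection's own area, mode='iof') is 50 / 100 = 0.5. IoA is
+        # therefore the more meaningful measure for large group-of regions.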
+ assert gt_bboxes_group_of.shape[0] == gt_bboxes.shape[0] + non_group_gt_bboxes = gt_bboxes[~gt_bboxes_group_of] + group_gt_bboxes = gt_bboxes[gt_bboxes_group_of] + num_gts_group = group_gt_bboxes.shape[0] + ious = bbox_overlaps(det_bboxes, non_group_gt_bboxes) + ioas = bbox_overlaps(det_bboxes, group_gt_bboxes, mode='iof') + else: + # if not consider group-of boxes, only calculate ious through gt boxes + ious = bbox_overlaps( + det_bboxes, gt_bboxes, use_legacy_coordinate=use_legacy_coordinate) + ioas = None + + if ious.shape[1] > 0: + # for each det, the max iou with all gts + ious_max = ious.max(axis=1) + # for each det, which gt overlaps most with it + ious_argmax = ious.argmax(axis=1) + # sort all dets in descending order by scores + sort_inds = np.argsort(-det_bboxes[:, -1]) + for k, (min_area, max_area) in enumerate(area_ranges): + gt_covered = np.zeros(num_gts, dtype=bool) + # if no area range is specified, gt_area_ignore is all False + if min_area is None: + gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool) + else: + gt_areas = ( + gt_bboxes[:, 2] - gt_bboxes[:, 0] + extra_length) * ( + gt_bboxes[:, 3] - gt_bboxes[:, 1] + extra_length) + gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area) + for i in sort_inds: + if ious_max[i] >= iou_thr: + matched_gt = ious_argmax[i] + if not (gt_ignore_inds[matched_gt] + or gt_area_ignore[matched_gt]): + if not gt_covered[matched_gt]: + gt_covered[matched_gt] = True + tp[k, i] = 1 + else: + fp[k, i] = 1 + # otherwise ignore this detected bbox, tp = 0, fp = 0 + elif min_area is None: + fp[k, i] = 1 + else: + bbox = det_bboxes[i, :4] + area = (bbox[2] - bbox[0] + extra_length) * ( + bbox[3] - bbox[1] + extra_length) + if area >= min_area and area < max_area: + fp[k, i] = 1 + else: + # if there is no no-group-of gt bboxes in this image, + # then all det bboxes within area range are false positives. + # Only used in OpenImages evaluation. + if area_ranges == [(None, None)]: + fp[...] = 1 + else: + det_areas = ( + det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * ( + det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length) + for i, (min_area, max_area) in enumerate(area_ranges): + fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1 + + if ioas is None or ioas.shape[1] <= 0: + return tp, fp, det_bboxes + else: + # The evaluation of group-of TP and FP are done in two stages: + # 1. All detections are first matched to non group-of boxes; true + # positives are determined. + # 2. Detections that are determined as false positives are matched + # against group-of boxes and calculated group-of TP and FP. + # Only used in OpenImages evaluation. 
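+        # Bookkeeping for the second stage: `det_bboxes_group` keeps, for
+        # every group-of gt box, only the highest-scoring detection matched
+        # to it (scores sit in the last column), and `match_group_of` marks
+        # detections absorbed this way so that they are excluded from the
+        # plain tp/fp arrays during the final concatenation.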
+ det_bboxes_group = np.zeros( + (num_scales, ioas.shape[1], det_bboxes.shape[1]), dtype=float) + match_group_of = np.zeros((num_scales, num_dets), dtype=bool) + tp_group = np.zeros((num_scales, num_gts_group), dtype=np.float32) + ioas_max = ioas.max(axis=1) + # for each det, which gt overlaps most with it + ioas_argmax = ioas.argmax(axis=1) + # sort all dets in descending order by scores + sort_inds = np.argsort(-det_bboxes[:, -1]) + for k, (min_area, max_area) in enumerate(area_ranges): + box_is_covered = tp[k] + # if no area range is specified, gt_area_ignore is all False + if min_area is None: + gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool) + else: + gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * ( + gt_bboxes[:, 3] - gt_bboxes[:, 1]) + gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area) + for i in sort_inds: + matched_gt = ioas_argmax[i] + if not box_is_covered[i]: + if ioas_max[i] >= ioa_thr: + if not (gt_ignore_inds[matched_gt] + or gt_area_ignore[matched_gt]): + if not tp_group[k, matched_gt]: + tp_group[k, matched_gt] = 1 + match_group_of[k, i] = True + else: + match_group_of[k, i] = True + + if det_bboxes_group[k, matched_gt, -1] < \ + det_bboxes[i, -1]: + det_bboxes_group[k, matched_gt] = \ + det_bboxes[i] + + fp_group = (tp_group <= 0).astype(float) + tps = [] + fps = [] + # concatenate tp, fp, and det-boxes which not matched group of + # gt boxes and tp_group, fp_group, and det_bboxes_group which + # matched group of boxes respectively. + for i in range(num_scales): + tps.append( + np.concatenate((tp[i][~match_group_of[i]], tp_group[i]))) + fps.append( + np.concatenate((fp[i][~match_group_of[i]], fp_group[i]))) + det_bboxes = np.concatenate( + (det_bboxes[~match_group_of[i]], det_bboxes_group[i])) + + tp = np.vstack(tps) + fp = np.vstack(fps) + return tp, fp, det_bboxes + + +def get_cls_results(det_results, annotations, class_id): + """Get det results and gt information of a certain class. + + Args: + det_results (list[list]): Same as `eval_map()`. + annotations (list[dict]): Same as `eval_map()`. + class_id (int): ID of a specific class. + + Returns: + tuple[list[np.ndarray]]: detected bboxes, gt bboxes, ignored gt bboxes + """ + cls_dets = [img_res[class_id] for img_res in det_results] + cls_gts = [] + cls_gts_ignore = [] + for ann in annotations: + gt_inds = ann['labels'] == class_id + cls_gts.append(ann['bboxes'][gt_inds, :]) + + if ann.get('labels_ignore', None) is not None: + ignore_inds = ann['labels_ignore'] == class_id + cls_gts_ignore.append(ann['bboxes_ignore'][ignore_inds, :]) + else: + cls_gts_ignore.append(np.empty((0, 4), dtype=np.float32)) + + return cls_dets, cls_gts, cls_gts_ignore + + +def get_cls_group_ofs(annotations, class_id): + """Get `gt_group_of` of a certain class, which is used in Open Images. + + Args: + annotations (list[dict]): Same as `eval_map()`. + class_id (int): ID of a specific class. + + Returns: + list[np.ndarray]: `gt_group_of` of a certain class. + """ + gt_group_ofs = [] + for ann in annotations: + gt_inds = ann['labels'] == class_id + if ann.get('gt_is_group_ofs', None) is not None: + gt_group_ofs.append(ann['gt_is_group_ofs'][gt_inds]) + else: + gt_group_ofs.append(np.empty((0, 1), dtype=bool)) + + return gt_group_ofs + + +def eval_map(det_results, + annotations, + scale_ranges=None, + iou_thr=0.5, + ioa_thr=None, + dataset=None, + logger=None, + tpfp_fn=None, + nproc=4, + use_legacy_coordinate=False, + use_group_of=False, + eval_mode='area'): + """Evaluate mAP of a dataset. 
+
+    Args:
+        det_results (list[list]): [[cls1_det, cls2_det, ...], ...].
+            The outer list indicates images, and the inner list indicates
+            per-class detected bboxes.
+        annotations (list[dict]): Ground truth annotations where each item of
+            the list indicates an image. Keys of annotations are:
+
+            - `bboxes`: numpy array of shape (n, 4)
+            - `labels`: numpy array of shape (n, )
+            - `bboxes_ignore` (optional): numpy array of shape (k, 4)
+            - `labels_ignore` (optional): numpy array of shape (k, )
+        scale_ranges (list[tuple] | None): Range of scales to be evaluated,
+            in the format [(min1, max1), (min2, max2), ...]. A range of
+            (32, 64) means the area range between (32**2, 64**2).
+            Defaults to None.
+        iou_thr (float): IoU threshold to be considered as matched.
+            Defaults to 0.5.
+        ioa_thr (float | None): IoA threshold to be considered as matched;
+            only used in Open Images evaluation. Defaults to None.
+        dataset (list[str] | str | None): Dataset name or dataset classes;
+            there are minor differences in metrics for different datasets,
+            e.g. "voc", "imagenet_det", etc. Defaults to None.
+        logger (logging.Logger | str | None): The way to print the mAP
+            summary. See `mmengine.logging.print_log()` for details.
+            Defaults to None.
+        tpfp_fn (callable | None): The function used to determine true/
+            false positives. If None, :func:`tpfp_default` is used as default
+            unless dataset is 'det' or 'vid' (:func:`tpfp_imagenet` in this
+            case). If it is given as a function, then this function is used
+            to evaluate tp & fp. Defaults to None.
+        nproc (int): Processes used for computing TP and FP.
+            Defaults to 4.
+        use_legacy_coordinate (bool): Whether to use the coordinate system of
+            mmdet v1.x, which means width and height should be
+            calculated as 'x2 - x1 + 1' and 'y2 - y1 + 1' respectively.
+            Defaults to False.
+        use_group_of (bool): Whether to use group-of boxes when calculating
+            TP and FP; only used in Open Images evaluation. Defaults to False.
+        eval_mode (str): 'area' or '11points'. 'area' means calculating the
+            area under the precision-recall curve, '11points' means
+            calculating the average precision of recalls at [0, 0.1, ..., 1].
+            PASCAL VOC2007 uses '11points' as its default evaluation mode,
+            while the others use 'area'. Defaults to 'area'.
+
+    Returns:
+        tuple: (mAP, [dict, dict, ...])
+    """
+    assert len(det_results) == len(annotations)
+    assert eval_mode in ['area', '11points'], \
+        f'Unrecognized {eval_mode} mode, only "area" and "11points" ' \
+        'are supported'
+    if not use_legacy_coordinate:
+        extra_length = 0.
+    else:
+        extra_length = 1.
+
+    num_imgs = len(det_results)
+    num_scales = len(scale_ranges) if scale_ranges is not None else 1
+    num_classes = len(det_results[0])  # positive class num
+    area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges]
+                   if scale_ranges is not None else None)
+
+    # There is no need to use multiple processes when num_imgs == 1.
+    if num_imgs > 1:
+        assert nproc > 0, 'nproc must be at least one.'
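+        # Cap the pool size at the number of images: each worker evaluates
+        # one (dets, gts) pair at a time, so spawning more processes than
+        # images would leave the extra workers idle.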
+ nproc = min(nproc, num_imgs) + pool = Pool(nproc) + + eval_results = [] + for i in range(num_classes): + # get gt and det bboxes of this class + cls_dets, cls_gts, cls_gts_ignore = get_cls_results( + det_results, annotations, i) + # choose proper function according to datasets to compute tp and fp + if tpfp_fn is None: + if dataset in ['det', 'vid']: + tpfp_fn = tpfp_imagenet + elif dataset in ['oid_challenge', 'oid_v6'] \ + or use_group_of is True: + tpfp_fn = tpfp_openimages + else: + tpfp_fn = tpfp_default + if not callable(tpfp_fn): + raise ValueError( + f'tpfp_fn has to be a function or None, but got {tpfp_fn}') + + if num_imgs > 1: + # compute tp and fp for each image with multiple processes + args = [] + if use_group_of: + # used in Open Images Dataset evaluation + gt_group_ofs = get_cls_group_ofs(annotations, i) + args.append(gt_group_ofs) + args.append([use_group_of for _ in range(num_imgs)]) + if ioa_thr is not None: + args.append([ioa_thr for _ in range(num_imgs)]) + + tpfp = pool.starmap( + tpfp_fn, + zip(cls_dets, cls_gts, cls_gts_ignore, + [iou_thr for _ in range(num_imgs)], + [area_ranges for _ in range(num_imgs)], + [use_legacy_coordinate for _ in range(num_imgs)], *args)) + else: + tpfp = tpfp_fn( + cls_dets[0], + cls_gts[0], + cls_gts_ignore[0], + iou_thr, + area_ranges, + use_legacy_coordinate, + gt_bboxes_group_of=(get_cls_group_ofs(annotations, i)[0] + if use_group_of else None), + use_group_of=use_group_of, + ioa_thr=ioa_thr) + tpfp = [tpfp] + + if use_group_of: + tp, fp, cls_dets = tuple(zip(*tpfp)) + else: + tp, fp = tuple(zip(*tpfp)) + # calculate gt number of each scale + # ignored gts or gts beyond the specific scale are not counted + num_gts = np.zeros(num_scales, dtype=int) + for j, bbox in enumerate(cls_gts): + if area_ranges is None: + num_gts[0] += bbox.shape[0] + else: + gt_areas = (bbox[:, 2] - bbox[:, 0] + extra_length) * ( + bbox[:, 3] - bbox[:, 1] + extra_length) + for k, (min_area, max_area) in enumerate(area_ranges): + num_gts[k] += np.sum((gt_areas >= min_area) + & (gt_areas < max_area)) + # sort all det bboxes by score, also sort tp and fp + cls_dets = np.vstack(cls_dets) + num_dets = cls_dets.shape[0] + sort_inds = np.argsort(-cls_dets[:, -1]) + tp = np.hstack(tp)[:, sort_inds] + fp = np.hstack(fp)[:, sort_inds] + # calculate recall and precision with tp and fp + tp = np.cumsum(tp, axis=1) + fp = np.cumsum(fp, axis=1) + eps = np.finfo(np.float32).eps + recalls = tp / np.maximum(num_gts[:, np.newaxis], eps) + precisions = tp / np.maximum((tp + fp), eps) + # calculate AP + if scale_ranges is None: + recalls = recalls[0, :] + precisions = precisions[0, :] + num_gts = num_gts.item() + ap = average_precision(recalls, precisions, eval_mode) + eval_results.append({ + 'num_gts': num_gts, + 'num_dets': num_dets, + 'recall': recalls, + 'precision': precisions, + 'ap': ap + }) + + if num_imgs > 1: + pool.close() + + if scale_ranges is not None: + # shape (num_classes, num_scales) + all_ap = np.vstack([cls_result['ap'] for cls_result in eval_results]) + all_num_gts = np.vstack( + [cls_result['num_gts'] for cls_result in eval_results]) + mean_ap = [] + for i in range(num_scales): + if np.any(all_num_gts[:, i] > 0): + mean_ap.append(all_ap[all_num_gts[:, i] > 0, i].mean()) + else: + mean_ap.append(0.0) + else: + aps = [] + for cls_result in eval_results: + if cls_result['num_gts'] > 0: + aps.append(cls_result['ap']) + mean_ap = np.array(aps).mean().item() if aps else 0.0 + + print_map_summary( + mean_ap, eval_results, dataset, area_ranges, logger=logger) + + 
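+    # For reference, `mean_ap` is a float (or a list of floats, one per
+    # scale range) and each entry of `eval_results` is a per-class dict
+    # with keys 'num_gts', 'num_dets', 'recall', 'precision' and 'ap'.
+    # A minimal usage sketch (hypothetical toy inputs; the score sits in
+    # the last column of each detection):
+    #   det_results = [[np.array([[10, 10, 50, 50, 0.9]])]]  # 1 img, 1 class
+    #   annotations = [dict(bboxes=np.array([[12., 12., 48., 48.]]),
+    #                       labels=np.array([0]))]
+    #   mean_ap, per_class = eval_map(det_results, annotations, iou_thr=0.5)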
return mean_ap, eval_results + + +def print_map_summary(mean_ap, + results, + dataset=None, + scale_ranges=None, + logger=None): + """Print mAP and results of each class. + + A table will be printed to show the gts/dets/recall/AP of each class and + the mAP. + + Args: + mean_ap (float): Calculated from `eval_map()`. + results (list[dict]): Calculated from `eval_map()`. + dataset (list[str] | str | None): Dataset name or dataset classes. + scale_ranges (list[tuple] | None): Range of scales to be evaluated. + logger (logging.Logger | str | None): The way to print the mAP + summary. See `mmengine.logging.print_log()` for details. + Defaults to None. + """ + + if logger == 'silent': + return + + if isinstance(results[0]['ap'], np.ndarray): + num_scales = len(results[0]['ap']) + else: + num_scales = 1 + + if scale_ranges is not None: + assert len(scale_ranges) == num_scales + + num_classes = len(results) + + recalls = np.zeros((num_scales, num_classes), dtype=np.float32) + aps = np.zeros((num_scales, num_classes), dtype=np.float32) + num_gts = np.zeros((num_scales, num_classes), dtype=int) + for i, cls_result in enumerate(results): + if cls_result['recall'].size > 0: + recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1] + aps[:, i] = cls_result['ap'] + num_gts[:, i] = cls_result['num_gts'] + + if dataset is None: + label_names = [str(i) for i in range(num_classes)] + elif is_str(dataset): + label_names = get_classes(dataset) + else: + label_names = dataset + + if not isinstance(mean_ap, list): + mean_ap = [mean_ap] + + header = ['class', 'gts', 'dets', 'recall', 'ap'] + for i in range(num_scales): + if scale_ranges is not None: + print_log(f'Scale range {scale_ranges[i]}', logger=logger) + table_data = [header] + for j in range(num_classes): + row_data = [ + label_names[j], num_gts[i, j], results[j]['num_dets'], + f'{recalls[i, j]:.3f}', f'{aps[i, j]:.3f}' + ] + table_data.append(row_data) + table_data.append(['mAP', '', '', '', f'{mean_ap[i]:.3f}']) + table = AsciiTable(table_data) + table.inner_footing_row_border = True + print_log('\n' + table.table, logger=logger) diff --git a/mmdet/evaluation/functional/panoptic_utils.py b/mmdet/evaluation/functional/panoptic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6faa8ed52bc46c2cb74b1974b8daa521e616e996 --- /dev/null +++ b/mmdet/evaluation/functional/panoptic_utils.py @@ -0,0 +1,228 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Copyright (c) 2018, Alexander Kirillov +# This file supports `backend_args` for `panopticapi`, +# the source code is copied from `panopticapi`, +# only the way to load the gt images is modified. +import multiprocessing +import os + +import mmcv +import numpy as np +from mmengine.fileio import get + +# A custom value to distinguish instance ID and category ID; need to +# be greater than the number of categories. +# For a pixel in the panoptic result map: +# pan_id = ins_id * INSTANCE_OFFSET + cat_id +INSTANCE_OFFSET = 1000 + +try: + from panopticapi.evaluation import OFFSET, VOID, PQStat + from panopticapi.utils import rgb2id +except ImportError: + PQStat = None + rgb2id = None + VOID = 0 + OFFSET = 256 * 256 * 256 + + +def pq_compute_single_core(proc_id, + annotation_set, + gt_folder, + pred_folder, + categories, + backend_args=None, + print_log=False): + """The single core function to evaluate the metric of Panoptic + Segmentation. + + Same as the function with the same name in `panopticapi`. Only the function + to load the images is changed to use the file client. 
+
+    Args:
+        proc_id (int): The id of the mini process.
+        annotation_set (list): The matched (gt_ann, pred_ann) pairs assigned
+            to this process.
+        gt_folder (str): The path of the ground truth images.
+        pred_folder (str): The path of the prediction images.
+        categories (dict): The categories of the dataset, keyed by
+            category id.
+        backend_args (dict, optional): Arguments to instantiate the file
+            backend. If None, the local backend is used.
+        print_log (bool): Whether to print the log. Defaults to False.
+    """
+    if PQStat is None:
+        raise RuntimeError(
+            'panopticapi is not installed, please install it by: '
+            'pip install git+https://github.com/cocodataset/'
+            'panopticapi.git.')
+
+    pq_stat = PQStat()
+
+    idx = 0
+    for gt_ann, pred_ann in annotation_set:
+        if print_log and idx % 100 == 0:
+            print('Core: {}, {} from {} images processed'.format(
+                proc_id, idx, len(annotation_set)))
+        idx += 1
+        # The gt images can be on the local disk or `ceph`, so we use
+        # backend here.
+        img_bytes = get(
+            os.path.join(gt_folder, gt_ann['file_name']),
+            backend_args=backend_args)
+        pan_gt = mmcv.imfrombytes(img_bytes, flag='color', channel_order='rgb')
+        pan_gt = rgb2id(pan_gt)
+
+        # The predictions can only be on the local disk now.
+        pan_pred = mmcv.imread(
+            os.path.join(pred_folder, pred_ann['file_name']),
+            flag='color',
+            channel_order='rgb')
+        pan_pred = rgb2id(pan_pred)
+
+        gt_segms = {el['id']: el for el in gt_ann['segments_info']}
+        pred_segms = {el['id']: el for el in pred_ann['segments_info']}
+
+        # predicted segments area calculation + prediction sanity checks
+        pred_labels_set = set(el['id'] for el in pred_ann['segments_info'])
+        labels, labels_cnt = np.unique(pan_pred, return_counts=True)
+        for label, label_cnt in zip(labels, labels_cnt):
+            if label not in pred_segms:
+                if label == VOID:
+                    continue
+                raise KeyError(
+                    'In the image with ID {} segment with ID {} is '
+                    'presented in PNG and not presented in JSON.'.format(
+                        gt_ann['image_id'], label))
+            pred_segms[label]['area'] = label_cnt
+            pred_labels_set.remove(label)
+            if pred_segms[label]['category_id'] not in categories:
+                raise KeyError(
+                    'In the image with ID {} segment with ID {} has '
+                    'unknown category_id {}.'.format(
+                        gt_ann['image_id'], label,
+                        pred_segms[label]['category_id']))
+        if len(pred_labels_set) != 0:
+            raise KeyError(
+                'In the image with ID {} the following segment IDs {} '
+                'are presented in JSON and not presented in PNG.'.format(
+                    gt_ann['image_id'], list(pred_labels_set)))
+
+        # confusion matrix calculation
+        pan_gt_pred = pan_gt.astype(np.uint64) * OFFSET + pan_pred.astype(
+            np.uint64)
+        gt_pred_map = {}
+        labels, labels_cnt = np.unique(pan_gt_pred, return_counts=True)
+        for label, intersection in zip(labels, labels_cnt):
+            gt_id = label // OFFSET
+            pred_id = label % OFFSET
+            gt_pred_map[(gt_id, pred_id)] = intersection
+
+        # count all matched pairs
+        gt_matched = set()
+        pred_matched = set()
+        for label_tuple, intersection in gt_pred_map.items():
+            gt_label, pred_label = label_tuple
+            if gt_label not in gt_segms:
+                continue
+            if pred_label not in pred_segms:
+                continue
+            if gt_segms[gt_label]['iscrowd'] == 1:
+                continue
+            if gt_segms[gt_label]['category_id'] != pred_segms[pred_label][
+                    'category_id']:
+                continue
+
+            union = pred_segms[pred_label]['area'] + gt_segms[gt_label][
+                'area'] - intersection - gt_pred_map.get((VOID, pred_label), 0)
+            iou = intersection / union
+            if iou > 0.5:
+                pq_stat[gt_segms[gt_label]['category_id']].tp += 1
+                pq_stat[gt_segms[gt_label]['category_id']].iou += iou
+                gt_matched.add(gt_label)
+                pred_matched.add(pred_label)
+
+        # count false negatives
+        crowd_labels_dict = {}
+        for gt_label, gt_info
+            if gt_label in gt_matched:
+                continue
+            # crowd segments are ignored
+            if gt_info['iscrowd'] == 1:
+                crowd_labels_dict[gt_info['category_id']] = gt_label
+                continue
+            pq_stat[gt_info['category_id']].fn += 1
+
+        # count false positives
+        for pred_label, pred_info in pred_segms.items():
+            if pred_label in pred_matched:
+                continue
+            # intersection of the segment with VOID
+            intersection = gt_pred_map.get((VOID, pred_label), 0)
+            # plus intersection with the corresponding CROWD region if it
+            # exists
+            if pred_info['category_id'] in crowd_labels_dict:
+                intersection += gt_pred_map.get(
+                    (crowd_labels_dict[pred_info['category_id']], pred_label),
+                    0)
+            # a predicted segment is ignored if more than half of it
+            # corresponds to VOID and CROWD regions
+            if intersection / pred_info['area'] > 0.5:
+                continue
+            pq_stat[pred_info['category_id']].fp += 1
+
+    if print_log:
+        print('Core: {}, all {} images processed'.format(
+            proc_id, len(annotation_set)))
+    return pq_stat
+
+
+def pq_compute_multi_core(matched_annotations_list,
+                          gt_folder,
+                          pred_folder,
+                          categories,
+                          backend_args=None,
+                          nproc=32):
+    """Evaluate the metrics of Panoptic Segmentation with multiprocessing.
+
+    Same as the function with the same name in `panopticapi`.
+
+    Args:
+        matched_annotations_list (list): The matched annotation list. Each
+            element is a tuple of annotations of the same image with the
+            format (gt_anns, pred_anns).
+        gt_folder (str): The path of the ground truth images.
+        pred_folder (str): The path of the prediction images.
+        categories (str): The categories of the dataset.
+        backend_args (object): Arguments to instantiate the file backend.
+            If None, the local backend is used.
+        nproc (int): Number of processes for panoptic quality computing.
+            Defaults to 32. When ``nproc`` exceeds the number of cpu cores,
+            the number of cpu cores is used.
+    """
+    if PQStat is None:
+        raise RuntimeError(
+            'panopticapi is not installed, please install it by: '
+            'pip install git+https://github.com/cocodataset/'
+            'panopticapi.git.')
+
+    cpu_num = min(nproc, multiprocessing.cpu_count())
+
+    annotations_split = np.array_split(matched_annotations_list, cpu_num)
+    print('Number of cores: {}, images per core: {}'.format(
+        cpu_num, len(annotations_split[0])))
+    workers = multiprocessing.Pool(processes=cpu_num)
+    processes = []
+    for proc_id, annotation_set in enumerate(annotations_split):
+        p = workers.apply_async(pq_compute_single_core,
+                                (proc_id, annotation_set, gt_folder,
+                                 pred_folder, categories, backend_args))
+        processes.append(p)
+
+    # Close the process pool, otherwise it will lead to memory
+    # leaking problems.
+    workers.close()
+    workers.join()
+
+    pq_stat = PQStat()
+    for p in processes:
+        pq_stat += p.get()
+
+    return pq_stat
diff --git a/mmdet/evaluation/functional/recall.py b/mmdet/evaluation/functional/recall.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bce2bf3614ab454dbbdf48efc4650018cc71b13
--- /dev/null
+++ b/mmdet/evaluation/functional/recall.py
@@ -0,0 +1,199 @@
+# Copyright (c) OpenMMLab. All rights reserved.
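+"""Proposal recall evaluation.
+
+A minimal usage sketch of ``eval_recalls`` (defined below); the array
+values are made up for illustration::
+
+    import numpy as np
+    gts = [np.array([[10., 10., 50., 50.]], dtype=np.float32)]
+    proposals = [np.array([[12., 12., 48., 48., 0.9]], dtype=np.float32)]
+    # average recall of the top-1 proposal at IoU threshold 0.5
+    ar = eval_recalls(gts, proposals, proposal_nums=[1], iou_thrs=[0.5])
+"""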
+from collections.abc import Sequence
+
+import numpy as np
+from mmengine.logging import print_log
+from terminaltables import AsciiTable
+
+from .bbox_overlaps import bbox_overlaps
+
+
+def _recalls(all_ious, proposal_nums, thrs):
+
+    img_num = all_ious.shape[0]
+    total_gt_num = sum([ious.shape[0] for ious in all_ious])
+
+    _ious = np.zeros((proposal_nums.size, total_gt_num), dtype=np.float32)
+    for k, proposal_num in enumerate(proposal_nums):
+        tmp_ious = np.zeros(0)
+        for i in range(img_num):
+            ious = all_ious[i][:, :proposal_num].copy()
+            gt_ious = np.zeros((ious.shape[0]))
+            if ious.size == 0:
+                tmp_ious = np.hstack((tmp_ious, gt_ious))
+                continue
+            # greedily match each gt to its best remaining proposal, then
+            # remove both from further matching
+            for j in range(ious.shape[0]):
+                gt_max_overlaps = ious.argmax(axis=1)
+                max_ious = ious[np.arange(0, ious.shape[0]), gt_max_overlaps]
+                gt_idx = max_ious.argmax()
+                gt_ious[j] = max_ious[gt_idx]
+                box_idx = gt_max_overlaps[gt_idx]
+                ious[gt_idx, :] = -1
+                ious[:, box_idx] = -1
+            tmp_ious = np.hstack((tmp_ious, gt_ious))
+        _ious[k, :] = tmp_ious
+
+    _ious = np.fliplr(np.sort(_ious, axis=1))
+    recalls = np.zeros((proposal_nums.size, thrs.size))
+    for i, thr in enumerate(thrs):
+        recalls[:, i] = (_ious >= thr).sum(axis=1) / float(total_gt_num)
+
+    return recalls
+
+
+def set_recall_param(proposal_nums, iou_thrs):
+    """Check proposal_nums and iou_thrs and set the correct format."""
+    if isinstance(proposal_nums, Sequence):
+        _proposal_nums = np.array(proposal_nums)
+    elif isinstance(proposal_nums, int):
+        _proposal_nums = np.array([proposal_nums])
+    else:
+        _proposal_nums = proposal_nums
+
+    if iou_thrs is None:
+        _iou_thrs = np.array([0.5])
+    elif isinstance(iou_thrs, Sequence):
+        _iou_thrs = np.array(iou_thrs)
+    elif isinstance(iou_thrs, float):
+        _iou_thrs = np.array([iou_thrs])
+    else:
+        _iou_thrs = iou_thrs
+
+    return _proposal_nums, _iou_thrs
+
+
+def eval_recalls(gts,
+                 proposals,
+                 proposal_nums=None,
+                 iou_thrs=0.5,
+                 logger=None,
+                 use_legacy_coordinate=False):
+    """Calculate recalls.
+
+    Args:
+        gts (list[ndarray]): a list of arrays of shape (n, 4)
+        proposals (list[ndarray]): a list of arrays of shape (k, 4) or (k, 5)
+        proposal_nums (int | Sequence[int]): Top N proposals to be evaluated.
+            Default: None.
+        iou_thrs (float | Sequence[float]): IoU thresholds. Default: 0.5.
+        logger (logging.Logger | str | None): The way to print the recall
+            summary. See `mmengine.logging.print_log()` for details.
+            Default: None.
+        use_legacy_coordinate (bool): Whether to use the coordinate system
+            of mmdet v1.x, where "1" is added to both height and width, so
+            w and h are computed as 'x2 - x1 + 1' and 'y2 - y1 + 1'.
+            Default: False.
+
+    Returns:
+        ndarray: recalls of different ious and proposal nums
+    """
+
+    img_num = len(gts)
+    assert img_num == len(proposals)
+    proposal_nums, iou_thrs = set_recall_param(proposal_nums, iou_thrs)
+    all_ious = []
+    for i in range(img_num):
+        if proposals[i].ndim == 2 and proposals[i].shape[1] == 5:
+            scores = proposals[i][:, 4]
+            sort_idx = np.argsort(scores)[::-1]
+            img_proposal = proposals[i][sort_idx, :]
+        else:
+            img_proposal = proposals[i]
+        prop_num = min(img_proposal.shape[0], proposal_nums[-1])
+        if gts[i] is None or gts[i].shape[0] == 0:
+            ious = np.zeros((0, img_proposal.shape[0]), dtype=np.float32)
+        else:
+            ious = bbox_overlaps(
+                gts[i],
+                img_proposal[:prop_num, :4],
+                use_legacy_coordinate=use_legacy_coordinate)
+        all_ious.append(ious)
+    # per-image IoU matrices can have different shapes, so keep an object
+    # array (a ragged `np.array` raises an error on recent numpy)
+    all_ious = np.array(all_ious, dtype=object)
+    recalls = _recalls(all_ious, proposal_nums, iou_thrs)
+
+    print_recall_summary(recalls, proposal_nums, iou_thrs, logger=logger)
+    return recalls
+
+
+def print_recall_summary(recalls,
+                         proposal_nums,
+                         iou_thrs,
+                         row_idxs=None,
+                         col_idxs=None,
+                         logger=None):
+    """Print recalls in a table.
+
+    Args:
+        recalls (ndarray): calculated from `bbox_recalls`
+        proposal_nums (ndarray or list): top N proposals
+        iou_thrs (ndarray or list): iou thresholds
+        row_idxs (ndarray): which rows (proposal nums) to print
+        col_idxs (ndarray): which cols (iou thresholds) to print
+        logger (logging.Logger | str | None): The way to print the recall
+            summary. See `mmengine.logging.print_log()` for details.
+            Default: None.
+    """
+    proposal_nums = np.array(proposal_nums, dtype=np.int32)
+    iou_thrs = np.array(iou_thrs)
+    if row_idxs is None:
+        row_idxs = np.arange(proposal_nums.size)
+    if col_idxs is None:
+        col_idxs = np.arange(iou_thrs.size)
+    row_header = [''] + iou_thrs[col_idxs].tolist()
+    table_data = [row_header]
+    for i, num in enumerate(proposal_nums[row_idxs]):
+        row = [f'{val:.3f}' for val in recalls[row_idxs[i], col_idxs].tolist()]
+        row.insert(0, num)
+        table_data.append(row)
+    table = AsciiTable(table_data)
+    print_log('\n' + table.table, logger=logger)
+
+
+def plot_num_recall(recalls, proposal_nums):
+    """Plot Proposal_num-Recalls curve.
+
+    Args:
+        recalls (ndarray or list): shape (k,)
+        proposal_nums (ndarray or list): same shape as `recalls`
+    """
+    if isinstance(proposal_nums, np.ndarray):
+        _proposal_nums = proposal_nums.tolist()
+    else:
+        _proposal_nums = proposal_nums
+    if isinstance(recalls, np.ndarray):
+        _recalls = recalls.tolist()
+    else:
+        _recalls = recalls
+
+    import matplotlib.pyplot as plt
+    f = plt.figure()
+    plt.plot([0] + _proposal_nums, [0] + _recalls)
+    plt.xlabel('Proposal num')
+    plt.ylabel('Recall')
+    # use the plain list here; `proposal_nums` may not be an ndarray
+    plt.axis([0, max(_proposal_nums), 0, 1])
+    f.show()
+
+
+def plot_iou_recall(recalls, iou_thrs):
+    """Plot IoU-Recalls curve.
+
+    Args:
+        recalls (ndarray or list): shape (k,)
+        iou_thrs (ndarray or list): same shape as `recalls`
+    """
+    if isinstance(iou_thrs, np.ndarray):
+        _iou_thrs = iou_thrs.tolist()
+    else:
+        _iou_thrs = iou_thrs
+    if isinstance(recalls, np.ndarray):
+        _recalls = recalls.tolist()
+    else:
+        _recalls = recalls
+
+    import matplotlib.pyplot as plt
+    f = plt.figure()
+    plt.plot(_iou_thrs + [1.0], _recalls + [0.])
+    plt.xlabel('IoU')
+    plt.ylabel('Recall')
+    # use the plain list here; `iou_thrs` may not be an ndarray
+    plt.axis([min(_iou_thrs), 1, 0, 1])
+    f.show()
diff --git a/mmdet/evaluation/functional/ytvis.py b/mmdet/evaluation/functional/ytvis.py
new file mode 100644
index 0000000000000000000000000000000000000000..c65a7e9bc956c7de42e0d6e511dabb3d7325782d
--- /dev/null
+++ b/mmdet/evaluation/functional/ytvis.py
@@ -0,0 +1,305 @@
+# Copyright (c) Github URL
+# Copied from
+# https://github.com/youtubevos/cocoapi/blob/master/PythonAPI/pycocotools/ytvos.py
+__author__ = 'ychfan'
+# Interface for accessing the YouTubeVIS dataset.
+
+# The following API functions are defined:
+#  YTVIS      - YTVIS api class that loads the YouTubeVIS annotation file
+#               and prepares data structures.
+#  decodeMask - Decode binary mask M encoded via run-length encoding.
+#  encodeMask - Encode binary mask M using run-length encoding.
+#  getAnnIds  - Get ann ids that satisfy given filter conditions.
+#  getCatIds  - Get cat ids that satisfy given filter conditions.
+#  getVidIds  - Get vid ids that satisfy given filter conditions.
+#  loadAnns   - Load anns with the specified ids.
+#  loadCats   - Load cats with the specified ids.
+#  loadVids   - Load vids with the specified ids.
+#  annToMask  - Convert segmentation in an annotation to binary mask.
+#  loadRes    - Load algorithm results and create API for accessing them.
+
+# Microsoft COCO Toolbox.      version 2.0
+# Data, paper, and tutorials available at:  http://mscoco.org/
+# Code written by Piotr Dollar and Tsung-Yi Lin, 2014.
+# Licensed under the Simplified BSD License [see bsd.txt]
+
+import copy
+import itertools
+import json
+import sys
+import time
+from collections import defaultdict
+
+import numpy as np
+from pycocotools import mask as maskUtils
+
+PYTHON_VERSION = sys.version_info[0]
+
+
+def _isArrayLike(obj):
+    return hasattr(obj, '__iter__') and hasattr(obj, '__len__')
+
+
+class YTVIS:
+
+    def __init__(self, annotation_file=None):
+        """Constructor of the YTVIS helper class for reading and
+        visualizing annotations.
+
+        :param annotation_file (str | dict): location of annotation file or
+            dict results.
+ :return: + """ + # load dataset + self.dataset, self.anns, self.cats, self.vids = dict(), dict(), dict( + ), dict() + self.vidToAnns, self.catToVids = defaultdict(list), defaultdict(list) + if annotation_file is not None: + print('loading annotations into memory...') + tic = time.time() + if type(annotation_file) == str: + dataset = json.load(open(annotation_file, 'r')) + else: + dataset = annotation_file + assert type( + dataset + ) == dict, 'annotation file format {} not supported'.format( + type(dataset)) + print('Done (t={:0.2f}s)'.format(time.time() - tic)) + self.dataset = dataset + self.createIndex() + + def createIndex(self): + # create index + print('creating index...') + anns, cats, vids = {}, {}, {} + vidToAnns, catToVids = defaultdict(list), defaultdict(list) + if 'annotations' in self.dataset: + for ann in self.dataset['annotations']: + vidToAnns[ann['video_id']].append(ann) + anns[ann['id']] = ann + + if 'videos' in self.dataset: + for vid in self.dataset['videos']: + vids[vid['id']] = vid + + if 'categories' in self.dataset: + for cat in self.dataset['categories']: + cats[cat['id']] = cat + + if 'annotations' in self.dataset and 'categories' in self.dataset: + for ann in self.dataset['annotations']: + catToVids[ann['category_id']].append(ann['video_id']) + + print('index created!') + + # create class members + self.anns = anns + self.vidToAnns = vidToAnns + self.catToVids = catToVids + self.vids = vids + self.cats = cats + + def getAnnIds(self, vidIds=[], catIds=[], areaRng=[], iscrowd=None): + """Get ann ids that satisfy given filter conditions. default skips that + filter. + + :param vidIds (int array) : get anns for given vids + catIds (int array) : get anns for given cats + areaRng (float array) : get anns for given area range + iscrowd (boolean) : get anns for given crowd label + :return: ids (int array) : integer array of ann ids + """ + vidIds = vidIds if _isArrayLike(vidIds) else [vidIds] + catIds = catIds if _isArrayLike(catIds) else [catIds] + + if len(vidIds) == len(catIds) == len(areaRng) == 0: + anns = self.dataset['annotations'] + else: + if not len(vidIds) == 0: + lists = [ + self.vidToAnns[vidId] for vidId in vidIds + if vidId in self.vidToAnns + ] + anns = list(itertools.chain.from_iterable(lists)) + else: + anns = self.dataset['annotations'] + anns = anns if len(catIds) == 0 else [ + ann for ann in anns if ann['category_id'] in catIds + ] + anns = anns if len(areaRng) == 0 else [ + ann for ann in anns if ann['avg_area'] > areaRng[0] + and ann['avg_area'] < areaRng[1] + ] + if iscrowd is not None: + ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd] + else: + ids = [ann['id'] for ann in anns] + return ids + + def getCatIds(self, catNms=[], supNms=[], catIds=[]): + """filtering parameters. default skips that filter. 
+ + :param catNms (str array) : get cats for given cat names + :param supNms (str array) : get cats for given supercategory names + :param catIds (int array) : get cats for given cat ids + :return: ids (int array) : integer array of cat ids + """ + catNms = catNms if _isArrayLike(catNms) else [catNms] + supNms = supNms if _isArrayLike(supNms) else [supNms] + catIds = catIds if _isArrayLike(catIds) else [catIds] + + if len(catNms) == len(supNms) == len(catIds) == 0: + cats = self.dataset['categories'] + else: + cats = self.dataset['categories'] + cats = cats if len(catNms) == 0 else [ + cat for cat in cats if cat['name'] in catNms + ] + cats = cats if len(supNms) == 0 else [ + cat for cat in cats if cat['supercategory'] in supNms + ] + cats = cats if len(catIds) == 0 else [ + cat for cat in cats if cat['id'] in catIds + ] + ids = [cat['id'] for cat in cats] + return ids + + def getVidIds(self, vidIds=[], catIds=[]): + """Get vid ids that satisfy given filter conditions. + + :param vidIds (int array) : get vids for given ids + :param catIds (int array) : get vids with all given cats + :return: ids (int array) : integer array of vid ids + """ + vidIds = vidIds if _isArrayLike(vidIds) else [vidIds] + catIds = catIds if _isArrayLike(catIds) else [catIds] + + if len(vidIds) == len(catIds) == 0: + ids = self.vids.keys() + else: + ids = set(vidIds) + for i, catId in enumerate(catIds): + if i == 0 and len(ids) == 0: + ids = set(self.catToVids[catId]) + else: + ids &= set(self.catToVids[catId]) + return list(ids) + + def loadAnns(self, ids=[]): + """Load anns with the specified ids. + + :param ids (int array) : integer ids specifying anns + :return: anns (object array) : loaded ann objects + """ + if _isArrayLike(ids): + return [self.anns[id] for id in ids] + elif type(ids) == int: + return [self.anns[ids]] + + def loadCats(self, ids=[]): + """Load cats with the specified ids. + + :param ids (int array) : integer ids specifying cats + :return: cats (object array) : loaded cat objects + """ + if _isArrayLike(ids): + return [self.cats[id] for id in ids] + elif type(ids) == int: + return [self.cats[ids]] + + def loadVids(self, ids=[]): + """Load anns with the specified ids. + + :param ids (int array) : integer ids specifying vid + :return: vids (object array) : loaded vid objects + """ + if _isArrayLike(ids): + return [self.vids[id] for id in ids] + elif type(ids) == int: + return [self.vids[ids]] + + def loadRes(self, resFile): + """Load result file and return a result api object. 
+ + :param resFile (str) : file name of result file + :return: res (obj) : result api object + """ + res = YTVIS() + res.dataset['videos'] = [img for img in self.dataset['videos']] + + print('Loading and preparing results...') + tic = time.time() + if type(resFile) == str or (PYTHON_VERSION == 2 + and type(resFile) == str): + anns = json.load(open(resFile)) + elif type(resFile) == np.ndarray: + anns = self.loadNumpyAnnotations(resFile) + else: + anns = resFile + assert type(anns) == list, 'results in not an array of objects' + annsVidIds = [ann['video_id'] for ann in anns] + assert set(annsVidIds) == (set(annsVidIds) & set(self.getVidIds())), \ + 'Results do not correspond to current coco set' + if 'segmentations' in anns[0]: + res.dataset['categories'] = copy.deepcopy( + self.dataset['categories']) + for id, ann in enumerate(anns): + ann['areas'] = [] + if 'bboxes' not in ann: + ann['bboxes'] = [] + for seg in ann['segmentations']: + # now only support compressed RLE format + # as segmentation results + if seg: + ann['areas'].append(maskUtils.area(seg)) + if len(ann['bboxes']) < len(ann['areas']): + ann['bboxes'].append(maskUtils.toBbox(seg)) + else: + ann['areas'].append(None) + if len(ann['bboxes']) < len(ann['areas']): + ann['bboxes'].append(None) + ann['id'] = id + 1 + l_ori = [a for a in ann['areas'] if a] + if len(l_ori) == 0: + ann['avg_area'] = 0 + else: + ann['avg_area'] = np.array(l_ori).mean() + ann['iscrowd'] = 0 + print('DONE (t={:0.2f}s)'.format(time.time() - tic)) + + res.dataset['annotations'] = anns + res.createIndex() + return res + + def annToRLE(self, ann, frameId): + """Convert annotation which can be polygons, uncompressed RLE to RLE. + + :return: binary mask (numpy 2D array) + """ + t = self.vids[ann['video_id']] + h, w = t['height'], t['width'] + segm = ann['segmentations'][frameId] + if type(segm) == list: + # polygon -- a single object might consist of multiple parts + # we merge all parts into one mask rle code + rles = maskUtils.frPyObjects(segm, h, w) + rle = maskUtils.merge(rles) + elif type(segm['counts']) == list: + # uncompressed RLE + rle = maskUtils.frPyObjects(segm, h, w) + else: + # rle + rle = segm + return rle + + def annToMask(self, ann, frameId): + """Convert annotation which can be polygons, uncompressed RLE, or RLE + to binary mask. + + :return: binary mask (numpy 2D array) + """ + rle = self.annToRLE(ann, frameId) + m = maskUtils.decode(rle) + return m diff --git a/mmdet/evaluation/functional/ytviseval.py b/mmdet/evaluation/functional/ytviseval.py new file mode 100644 index 0000000000000000000000000000000000000000..fdaf110d37c61b4e02873a4dc83e1722a70a29f1 --- /dev/null +++ b/mmdet/evaluation/functional/ytviseval.py @@ -0,0 +1,623 @@ +# Copyright (c) Github URL +# Copied from +# https://github.com/youtubevos/cocoapi/blob/master/PythonAPI/pycocotools/ytvoseval.py +__author__ = 'ychfan' + +import copy +import datetime +import time +from collections import defaultdict + +import numpy as np +from pycocotools import mask as maskUtils + + +class YTVISeval: + # Interface for evaluating video instance segmentation on + # the YouTubeVIS dataset. + # + # The usage for YTVISeval is as follows: + # cocoGt=..., cocoDt=... 
# load dataset and results + # E = YTVISeval(cocoGt,cocoDt); # initialize YTVISeval object + # E.params.recThrs = ...; # set parameters as desired + # E.evaluate(); # run per image evaluation + # E.accumulate(); # accumulate per image results + # E.summarize(); # display summary metrics of results + # For example usage see evalDemo.m and http://mscoco.org/. + # + # The evaluation parameters are as follows (defaults in brackets): + # imgIds - [all] N img ids to use for evaluation + # catIds - [all] K cat ids to use for evaluation + # iouThrs - [.5:.05:.95] T=10 IoU thresholds for evaluation + # recThrs - [0:.01:1] R=101 recall thresholds for evaluation + # areaRng - [...] A=4 object area ranges for evaluation + # maxDets - [1 10 100] M=3 thresholds on max detections per image + # iouType - ['segm'] set iouType to 'segm', 'bbox' or 'keypoints' + # iouType replaced the now DEPRECATED useSegm parameter. + # useCats - [1] if true use category labels for evaluation + # Note: if useCats=0 category labels are ignored as in proposal scoring. + # Note: multiple areaRngs [Ax2] and maxDets [Mx1] can be specified. + # + # evaluate(): evaluates detections on every image and every category and + # concats the results into the "evalImgs" with fields: + # dtIds - [1xD] id for each of the D detections (dt) + # gtIds - [1xG] id for each of the G ground truths (gt) + # dtMatches - [TxD] matching gt id at each IoU or 0 + # gtMatches - [TxG] matching dt id at each IoU or 0 + # dtScores - [1xD] confidence of each dt + # gtIgnore - [1xG] ignore flag for each gt + # dtIgnore - [TxD] ignore flag for each dt at each IoU + # + # accumulate(): accumulates the per-image, per-category evaluation + # results in "evalImgs" into the dictionary "eval" with fields: + # params - parameters used for evaluation + # date - date evaluation was performed + # counts - [T,R,K,A,M] parameter dimensions (see above) + # precision - [TxRxKxAxM] precision for every evaluation setting + # recall - [TxKxAxM] max recall for every evaluation setting + # Note: precision and recall==-1 for settings with no gt objects. + # + # See also coco, mask, pycocoDemo, pycocoEvalDemo + # + # Microsoft COCO Toolbox. version 2.0 + # Data, paper, and tutorials available at: http://mscoco.org/ + # Code written by Piotr Dollar and Tsung-Yi Lin, 2015. + # Licensed under the Simplified BSD License [see coco/license.txt] + def __init__(self, cocoGt=None, cocoDt=None, iouType='segm'): + """Initialize CocoEval using coco APIs for gt and dt. + + :param cocoGt: coco object with ground truth annotations + :param cocoDt: coco object with detection results + :return: None + """ + if not iouType: + print('iouType not specified. 
use default iouType segm') + self.cocoGt = cocoGt # ground truth COCO API + self.cocoDt = cocoDt # detections COCO API + self.params = {} # evaluation parameters + self.evalVids = defaultdict( + list) # per-image per-category evaluation results [KxAxI] elements + self.eval = {} # accumulated evaluation results + self._gts = defaultdict(list) # gt for evaluation + self._dts = defaultdict(list) # dt for evaluation + self.params = Params(iouType=iouType) # parameters + self._paramsEval = {} # parameters for evaluation + self.stats = [] # result summarization + self.ious = {} # ious between all gts and dts + if cocoGt is not None: + self.params.vidIds = sorted(cocoGt.getVidIds()) + self.params.catIds = sorted(cocoGt.getCatIds()) + + def _prepare(self): + ''' + Prepare ._gts and ._dts for evaluation based on params + :return: None + ''' + + def _toMask(anns, coco): + # modify ann['segmentation'] by reference + for ann in anns: + for i, a in enumerate(ann['segmentations']): + if a: + rle = coco.annToRLE(ann, i) + ann['segmentations'][i] = rle + l_ori = [a for a in ann['areas'] if a] + if len(l_ori) == 0: + ann['avg_area'] = 0 + else: + ann['avg_area'] = np.array(l_ori).mean() + + p = self.params + if p.useCats: + gts = self.cocoGt.loadAnns( + self.cocoGt.getAnnIds(vidIds=p.vidIds, catIds=p.catIds)) + dts = self.cocoDt.loadAnns( + self.cocoDt.getAnnIds(vidIds=p.vidIds, catIds=p.catIds)) + else: + gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(vidIds=p.vidIds)) + dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(vidIds=p.vidIds)) + + # convert ground truth to mask if iouType == 'segm' + if p.iouType == 'segm': + _toMask(gts, self.cocoGt) + _toMask(dts, self.cocoDt) + # set ignore flag + for gt in gts: + gt['ignore'] = gt['ignore'] if 'ignore' in gt else 0 + gt['ignore'] = 'iscrowd' in gt and gt['iscrowd'] + if p.iouType == 'keypoints': + gt['ignore'] = (gt['num_keypoints'] == 0) or gt['ignore'] + self._gts = defaultdict(list) # gt for evaluation + self._dts = defaultdict(list) # dt for evaluation + for gt in gts: + self._gts[gt['video_id'], gt['category_id']].append(gt) + for dt in dts: + self._dts[dt['video_id'], dt['category_id']].append(dt) + self.evalVids = defaultdict( + list) # per-image per-category evaluation results + self.eval = {} # accumulated evaluation results + + def evaluate(self): + ''' + Run per image evaluation on given images and store + results (a list of dict) in self.evalVids + :return: None + ''' + tic = time.time() + print('Running per image evaluation...') + p = self.params + # add backward compatibility if useSegm is specified in params + if p.useSegm is not None: + p.iouType = 'segm' if p.useSegm == 1 else 'bbox' + print('useSegm (deprecated) is not None. Running {} evaluation'. 
+ format(p.iouType)) + print('Evaluate annotation type *{}*'.format(p.iouType)) + p.vidIds = list(np.unique(p.vidIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + p.maxDets = sorted(p.maxDets) + self.params = p + + self._prepare() + # loop through images, area range, max detection number + catIds = p.catIds if p.useCats else [-1] + + if p.iouType == 'segm' or p.iouType == 'bbox': + computeIoU = self.computeIoU + elif p.iouType == 'keypoints': + computeIoU = self.computeOks + self.ious = {(vidId, catId): computeIoU(vidId, catId) + for vidId in p.vidIds for catId in catIds} + + evaluateVid = self.evaluateVid + maxDet = p.maxDets[-1] + + self.evalImgs = [ + evaluateVid(vidId, catId, areaRng, maxDet) for catId in catIds + for areaRng in p.areaRng for vidId in p.vidIds + ] + self._paramsEval = copy.deepcopy(self.params) + toc = time.time() + print('DONE (t={:0.2f}s).'.format(toc - tic)) + + def computeIoU(self, vidId, catId): + p = self.params + if p.useCats: + gt = self._gts[vidId, catId] + dt = self._dts[vidId, catId] + else: + gt = [_ for cId in p.catIds for _ in self._gts[vidId, cId]] + dt = [_ for cId in p.catIds for _ in self._dts[vidId, cId]] + if len(gt) == 0 and len(dt) == 0: + return [] + inds = np.argsort([-d['score'] for d in dt], kind='mergesort') + dt = [dt[i] for i in inds] + if len(dt) > p.maxDets[-1]: + dt = dt[0:p.maxDets[-1]] + + if p.iouType == 'segm': + g = [g['segmentations'] for g in gt] + d = [d['segmentations'] for d in dt] + elif p.iouType == 'bbox': + g = [g['bboxes'] for g in gt] + d = [d['bboxes'] for d in dt] + else: + raise Exception('unknown iouType for iou computation') + + # compute iou between each dt and gt region + + def iou_seq(d_seq, g_seq): + i = .0 + u = .0 + for d, g in zip(d_seq, g_seq): + if d and g: + i += maskUtils.area(maskUtils.merge([d, g], True)) + u += maskUtils.area(maskUtils.merge([d, g], False)) + elif not d and g: + u += maskUtils.area(g) + elif d and not g: + u += maskUtils.area(d) + if not u > .0: + print('Mask sizes in video {} and category {} may not match!'. 
+ format(vidId, catId)) + iou = i / u if u > .0 else .0 + return iou + + ious = np.zeros([len(d), len(g)]) + for i, j in np.ndindex(ious.shape): + ious[i, j] = iou_seq(d[i], g[j]) + + return ious + + def computeOks(self, imgId, catId): + p = self.params + + gts = self._gts[imgId, catId] + dts = self._dts[imgId, catId] + inds = np.argsort([-d['score'] for d in dts], kind='mergesort') + dts = [dts[i] for i in inds] + if len(dts) > p.maxDets[-1]: + dts = dts[0:p.maxDets[-1]] + # if len(gts) == 0 and len(dts) == 0: + if len(gts) == 0 or len(dts) == 0: + return [] + ious = np.zeros((len(dts), len(gts))) + sigmas = np.array([ + .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, + .87, .87, .89, .89 + ]) / 10.0 + vars = (sigmas * 2)**2 + k = len(sigmas) + # compute oks between each detection and ground truth object + for j, gt in enumerate(gts): + # create bounds for ignore regions(double the gt bbox) + g = np.array(gt['keypoints']) + xg = g[0::3] + yg = g[1::3] + vg = g[2::3] + k1 = np.count_nonzero(vg > 0) + bb = gt['bbox'] + x0 = bb[0] - bb[2] + x1 = bb[0] + bb[2] * 2 + y0 = bb[1] - bb[3] + y1 = bb[1] + bb[3] * 2 + for i, dt in enumerate(dts): + d = np.array(dt['keypoints']) + xd = d[0::3] + yd = d[1::3] + if k1 > 0: + # measure the per-keypoint distance if keypoints visible + dx = xd - xg + dy = yd - yg + else: + # measure minimum distance to keypoints + z = np.zeros((k)) + dx = np.max((z, x0 - xd), axis=0) + np.max( + (z, xd - x1), axis=0) + dy = np.max((z, y0 - yd), axis=0) + np.max( + (z, yd - y1), axis=0) + e = (dx**2 + dy**2) / vars / (gt['avg_area'] + + np.spacing(1)) / 2 + if k1 > 0: + e = e[vg > 0] + ious[i, j] = np.sum(np.exp(-e)) / e.shape[0] + return ious + + def evaluateVid(self, vidId, catId, aRng, maxDet): + ''' + perform evaluation for single category and image + :return: dict (single image results) + ''' + p = self.params + if p.useCats: + gt = self._gts[vidId, catId] + dt = self._dts[vidId, catId] + else: + gt = [_ for cId in p.catIds for _ in self._gts[vidId, cId]] + dt = [_ for cId in p.catIds for _ in self._dts[vidId, cId]] + if len(gt) == 0 and len(dt) == 0: + return None + + for g in gt: + if g['ignore'] or (g['avg_area'] < aRng[0] + or g['avg_area'] > aRng[1]): + g['_ignore'] = 1 + else: + g['_ignore'] = 0 + + # sort dt highest score first, sort gt ignore last + gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort') + gt = [gt[i] for i in gtind] + dtind = np.argsort([-d['score'] for d in dt], kind='mergesort') + dt = [dt[i] for i in dtind[0:maxDet]] + iscrowd = [int(o['iscrowd']) for o in gt] + # load computed ious + ious = self.ious[vidId, catId][:, gtind] if len( + self.ious[vidId, catId]) > 0 else self.ious[vidId, catId] + + T = len(p.iouThrs) + G = len(gt) + D = len(dt) + gtm = np.zeros((T, G)) + dtm = np.zeros((T, D)) + gtIg = np.array([g['_ignore'] for g in gt]) + dtIg = np.zeros((T, D)) + if not len(ious) == 0: + for tind, t in enumerate(p.iouThrs): + for dind, d in enumerate(dt): + # information about best match so far (m=-1 -> unmatched) + iou = min([t, 1 - 1e-10]) + m = -1 + for gind, g in enumerate(gt): + # if this gt already matched, and not a crowd, continue + if gtm[tind, gind] > 0 and not iscrowd[gind]: + continue + # if dt matched to reg gt, and on ignore gt, stop + if m > -1 and gtIg[m] == 0 and gtIg[gind] == 1: + break + # continue to next gt unless better match made + if ious[dind, gind] < iou: + continue + # if match successful and best so far, + # store appropriately + iou = ious[dind, gind] + m = gind + # if match made 
store id of match for both dt and gt
+                    if m == -1:
+                        continue
+                    dtIg[tind, dind] = gtIg[m]
+                    dtm[tind, dind] = gt[m]['id']
+                    gtm[tind, m] = d['id']
+        # set unmatched detections outside of area range to ignore
+        a = np.array([
+            d['avg_area'] < aRng[0] or d['avg_area'] > aRng[1] for d in dt
+        ]).reshape((1, len(dt)))
+        dtIg = np.logical_or(dtIg, np.logical_and(dtm == 0,
+                                                  np.repeat(a, T, 0)))
+        # store results for given image and category
+        return {
+            'video_id': vidId,
+            'category_id': catId,
+            'aRng': aRng,
+            'maxDet': maxDet,
+            'dtIds': [d['id'] for d in dt],
+            'gtIds': [g['id'] for g in gt],
+            'dtMatches': dtm,
+            'gtMatches': gtm,
+            'dtScores': [d['score'] for d in dt],
+            'gtIgnore': gtIg,
+            'dtIgnore': dtIg,
+        }
+
+    def accumulate(self, p=None):
+        """Accumulate per image evaluation results and store the result in
+        self.eval.
+
+        :param p: input params for evaluation
+        :return: None
+        """
+        print('Accumulating evaluation results...')
+        tic = time.time()
+        if not self.evalImgs:
+            print('Please run evaluate() first')
+        # allow input of customized parameters
+        if p is None:
+            p = self.params
+        p.catIds = p.catIds if p.useCats == 1 else [-1]
+        T = len(p.iouThrs)
+        R = len(p.recThrs)
+        K = len(p.catIds) if p.useCats else 1
+        A = len(p.areaRng)
+        M = len(p.maxDets)
+        precision = -np.ones(
+            (T, R, K, A, M))  # -1 for the precision of absent categories
+        recall = -np.ones((T, K, A, M))
+        scores = -np.ones((T, R, K, A, M))
+
+        # create dictionary for future indexing
+        _pe = self._paramsEval
+        catIds = _pe.catIds if _pe.useCats else [-1]
+        setK = set(catIds)
+        setA = set(map(tuple, _pe.areaRng))
+        setM = set(_pe.maxDets)
+        setI = set(_pe.vidIds)
+        # get inds to evaluate
+        k_list = [n for n, k in enumerate(p.catIds) if k in setK]
+        m_list = [m for n, m in enumerate(p.maxDets) if m in setM]
+        a_list = [
+            n for n, a in enumerate(map(lambda x: tuple(x), p.areaRng))
+            if a in setA
+        ]
+        i_list = [n for n, i in enumerate(p.vidIds) if i in setI]
+        I0 = len(_pe.vidIds)
+        A0 = len(_pe.areaRng)
+        # retrieve E at each category, area range, and max number of detections
+        for k, k0 in enumerate(k_list):
+            Nk = k0 * A0 * I0
+            for a, a0 in enumerate(a_list):
+                Na = a0 * I0
+                for m, maxDet in enumerate(m_list):
+                    E = [self.evalImgs[Nk + Na + i] for i in i_list]
+                    E = [e for e in E if e is not None]
+                    if len(E) == 0:
+                        continue
+                    dtScores = np.concatenate(
+                        [e['dtScores'][0:maxDet] for e in E])
+
+                    inds = np.argsort(-dtScores, kind='mergesort')
+                    dtScoresSorted = dtScores[inds]
+
+                    dtm = np.concatenate(
+                        [e['dtMatches'][:, 0:maxDet] for e in E],
+                        axis=1)[:, inds]
+                    dtIg = np.concatenate(
+                        [e['dtIgnore'][:, 0:maxDet] for e in E],
+                        axis=1)[:, inds]
+                    gtIg = np.concatenate([e['gtIgnore'] for e in E])
+                    npig = np.count_nonzero(gtIg == 0)
+                    if npig == 0:
+                        continue
+                    tps = np.logical_and(dtm, np.logical_not(dtIg))
+                    fps = np.logical_and(
+                        np.logical_not(dtm), np.logical_not(dtIg))
+
+                    # `np.float` was removed in NumPy 1.24; use `np.float64`
+                    tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float64)
+                    fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float64)
+                    for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
+                        tp = np.array(tp)
+                        fp = np.array(fp)
+                        nd_ori = len(tp)
+                        rc = tp / npig
+                        pr = tp / (fp + tp + np.spacing(1))
+                        q = np.zeros((R, ))
+                        ss = np.zeros((R, ))
+
+                        if nd_ori:
+                            recall[t, k, a, m] = rc[-1]
+                        else:
+                            recall[t, k, a, m] = 0
+
+                        # using python lists here gives a significant speed
+                        # improvement
+                        pr = pr.tolist()
+                        q = q.tolist()
+
+                        # make the precision envelope monotonically decreasing
+                        for i in range(nd_ori - 1, 0, -1):
+                            if pr[i] > pr[i - 1]:
+                                pr[i - 1] = pr[i]
+
+                        inds = np.searchsorted(rc, p.recThrs, side='left')
+                        try:
+                            for ri, pi in enumerate(inds):
+                                q[ri] = pr[pi]
+                                ss[ri] = dtScoresSorted[pi]
+                        except Exception:
+                            pass
+                        precision[t, :, k, a, m] = np.array(q)
+                        scores[t, :, k, a, m] = np.array(ss)
+        self.eval = {
+            'params': p,
+            'counts': [T, R, K, A, M],
+            'date': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+            'precision': precision,
+            'recall': recall,
+            'scores': scores,
+        }
+        toc = time.time()
+        print('DONE (t={:0.2f}s).'.format(toc - tic))
+
+    def summarize(self):
+        """Compute and display summary metrics for evaluation results.
+
+        Note this function can *only* be applied on the default parameter
+        setting.
+        """
+
+        def _summarize(ap=1, iouThr=None, areaRng='all', maxDets=100):
+            p = self.params
+            iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | ' \
+                   'maxDets={:>3d} ] = {:0.3f}'
+            titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
+            typeStr = '(AP)' if ap == 1 else '(AR)'
+            iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
+                if iouThr is None else '{:0.2f}'.format(iouThr)
+
+            aind = [
+                i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng
+            ]
+            mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
+            if ap == 1:
+                # dimension of precision: [TxRxKxAxM]
+                s = self.eval['precision']
+                # IoU
+                if iouThr is not None:
+                    t = np.where(iouThr == p.iouThrs)[0]
+                    s = s[t]
+                s = s[:, :, :, aind, mind]
+            else:
+                # dimension of recall: [TxKxAxM]
+                s = self.eval['recall']
+                if iouThr is not None:
+                    t = np.where(iouThr == p.iouThrs)[0]
+                    s = s[t]
+                s = s[:, :, aind, mind]
+            if len(s[s > -1]) == 0:
+                mean_s = -1
+            else:
+                mean_s = np.mean(s[s > -1])
+            print(
+                iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets,
+                            mean_s))
+            return mean_s
+
+        def _summarizeDets():
+            stats = np.zeros((12, ))
+            stats[0] = _summarize(1)
+            stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
+            stats[2] = _summarize(
+                1, iouThr=.75, maxDets=self.params.maxDets[2])
+            stats[3] = _summarize(
+                1, areaRng='small', maxDets=self.params.maxDets[2])
+            stats[4] = _summarize(
+                1, areaRng='medium', maxDets=self.params.maxDets[2])
+            stats[5] = _summarize(
+                1, areaRng='large', maxDets=self.params.maxDets[2])
+            stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
+            stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
+            stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
+            stats[9] = _summarize(
+                0, areaRng='small', maxDets=self.params.maxDets[2])
+            stats[10] = _summarize(
+                0, areaRng='medium', maxDets=self.params.maxDets[2])
+            stats[11] = _summarize(
+                0, areaRng='large', maxDets=self.params.maxDets[2])
+            return stats
+
+        def _summarizeKps():
+            stats = np.zeros((10, ))
+            stats[0] = _summarize(1, maxDets=20)
+            stats[1] = _summarize(1, maxDets=20, iouThr=.5)
+            stats[2] = _summarize(1, maxDets=20, iouThr=.75)
+            stats[3] = _summarize(1, maxDets=20, areaRng='medium')
+            stats[4] = _summarize(1, maxDets=20, areaRng='large')
+            stats[5] = _summarize(0, maxDets=20)
+            stats[6] = _summarize(0, maxDets=20, iouThr=.5)
+            stats[7] = _summarize(0, maxDets=20, iouThr=.75)
+            stats[8] = _summarize(0, maxDets=20, areaRng='medium')
+            stats[9] = _summarize(0, maxDets=20, areaRng='large')
+            return stats
+
+        if not self.eval:
+            raise Exception('Please run accumulate() first')
+        iouType = self.params.iouType
+        if iouType == 'segm' or iouType == 'bbox':
+            summarize = _summarizeDets
+        elif iouType == 'keypoints':
+            summarize = _summarizeKps
+        self.stats = summarize()
+
+    def __str__(self):
+        # `__str__` must return a string; print the summary and return an
+        # empty one
+        self.summarize()
+        return ''
+
+
+class Params:
+    """Params for coco evaluation api."""
+
+    def
setDetParams(self): + self.vidIds = [] + self.catIds = [] + # np.arange causes trouble. the data point on arange + # is slightly larger than the true value + self.iouThrs = np.linspace( + .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True) + self.recThrs = np.linspace( + .0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True) + self.maxDets = [1, 10, 100] + self.areaRng = [[0**2, 1e5**2], [0**2, 128**2], [128**2, 256**2], + [256**2, 1e5**2]] + self.areaRngLbl = ['all', 'small', 'medium', 'large'] + self.useCats = 1 + + def setKpParams(self): + self.vidIds = [] + self.catIds = [] + # np.arange causes trouble. the data point on arange + # is slightly larger than the true value + self.iouThrs = np.linspace( + .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True) + self.recThrs = np.linspace( + .0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True) + self.maxDets = [20] + self.areaRng = [[0**2, 1e5**2], [32**2, 96**2], [96**2, 1e5**2]] + self.areaRngLbl = ['all', 'medium', 'large'] + self.useCats = 1 + + def __init__(self, iouType='segm'): + if iouType == 'segm' or iouType == 'bbox': + self.setDetParams() + elif iouType == 'keypoints': + self.setKpParams() + else: + raise Exception('iouType not supported') + self.iouType = iouType + # useSegm is deprecated + self.useSegm = None diff --git a/mmdet/evaluation/metrics/__init__.py b/mmdet/evaluation/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e1ec0e46250290bcd1bfe5afbf688b76d29f1881 --- /dev/null +++ b/mmdet/evaluation/metrics/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base_video_metric import BaseVideoMetric +from .cityscapes_metric import CityScapesMetric +from .coco_caption_metric import COCOCaptionMetric +from .coco_metric import CocoMetric +from .coco_occluded_metric import CocoOccludedSeparatedMetric +from .coco_panoptic_metric import CocoPanopticMetric +from .coco_video_metric import CocoVideoMetric +from .crowdhuman_metric import CrowdHumanMetric +from .dump_det_results import DumpDetResults +from .dump_proposals_metric import DumpProposals +from .lvis_metric import LVISMetric +from .mot_challenge_metric import MOTChallengeMetric +from .openimages_metric import OpenImagesMetric +from .refseg_metric import RefSegMetric +from .reid_metric import ReIDMetrics +from .semseg_metric import SemSegMetric +from .voc_metric import VOCMetric +from .youtube_vis_metric import YouTubeVISMetric + +__all__ = [ + 'CityScapesMetric', 'CocoMetric', 'CocoPanopticMetric', 'OpenImagesMetric', + 'VOCMetric', 'LVISMetric', 'CrowdHumanMetric', 'DumpProposals', + 'CocoOccludedSeparatedMetric', 'DumpDetResults', 'BaseVideoMetric', + 'MOTChallengeMetric', 'CocoVideoMetric', 'ReIDMetrics', 'YouTubeVISMetric', + 'COCOCaptionMetric', 'SemSegMetric', 'RefSegMetric' +] diff --git a/mmdet/evaluation/metrics/base_video_metric.py b/mmdet/evaluation/metrics/base_video_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..90c7cdcbed5f12b59b6978ccba7576d6d2c25c5e --- /dev/null +++ b/mmdet/evaluation/metrics/base_video_metric.py @@ -0,0 +1,173 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
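+# A subclass sketch (illustrative only; `MyVideoMetric` and the
+# `pred_track_instances` attribute are hypothetical): implement the
+# `process_video`/`process_image` hooks that `BaseVideoMetric.process`
+# dispatches to, e.g.
+#
+#     class MyVideoMetric(BaseVideoMetric):
+#         def process_video(self, data_samples):
+#             for ds in data_samples:
+#                 self.results.append(ds.pred_track_instances)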
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+import warnings
+from typing import Optional, Sequence
+
+import torch
+from mmengine.dist import (barrier, broadcast, broadcast_object_list,
+                           get_dist_info, is_main_process)
+from mmengine.evaluator import BaseMetric
+from mmengine.utils import mkdir_or_exist
+
+
+class BaseVideoMetric(BaseMetric):
+    """Base class for a metric in video tasks.
+
+    The metric first processes each batch of data_samples and predictions,
+    and appends the processed results to the results list. Then it
+    collects all results together from all ranks if distributed training
+    is used. Finally, it computes the metrics of the entire dataset.
+
+    A subclass of :class:`BaseVideoMetric` should assign a meaningful value
+    to the class attribute `default_prefix`. See the argument `prefix` for
+    details.
+    """
+
+    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
+        """Process one batch of data samples and predictions.
+
+        The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch (dict): A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of data samples that
+                contain annotations and predictions.
+        """
+        for track_data_sample in data_samples:
+            video_data_samples = track_data_sample['video_data_samples']
+            ori_video_len = video_data_samples[0].ori_video_length
+            if ori_video_len == len(video_data_samples):
+                # video process
+                self.process_video(video_data_samples)
+            else:
+                # image process
+                self.process_image(video_data_samples, ori_video_len)
+
+    def evaluate(self, size: int = 1) -> dict:
+        """Evaluate the model performance of the whole dataset after
+        processing all batches.
+
+        Args:
+            size (int): Length of the entire validation dataset.
+
+        Returns:
+            dict: Evaluation metrics dict on the val dataset. The keys are the
+            names of the metrics, and the values are corresponding results.
+        """
+        if len(self.results) == 0:
+            warnings.warn(
+                f'{self.__class__.__name__} got empty `self.results`. Please '
+                'ensure that the processed results are properly added into '
+                '`self.results` in `process` method.')
+
+        results = collect_tracking_results(self.results, self.collect_device)
+
+        if is_main_process():
+            _metrics = self.compute_metrics(results)  # type: ignore
+            # Add prefix to metric names
+            if self.prefix:
+                _metrics = {
+                    '/'.join((self.prefix, k)): v
+                    for k, v in _metrics.items()
+                }
+            metrics = [_metrics]
+        else:
+            metrics = [None]  # type: ignore
+
+        broadcast_object_list(metrics)
+
+        # reset the results list
+        self.results.clear()
+        return metrics[0]
+
+
+def collect_tracking_results(results: list,
+                             device: str = 'cpu',
+                             tmpdir: Optional[str] = None) -> Optional[list]:
+    """Collect results in distributed environments.
+
+    Different from ``mmengine.dist.collect_results``, tracking metrics do not
+    use the parameter ``size`` (the length of the entire validation dataset):
+    for videos it equals the number of videos, while computing the metrics
+    requires the number of images.
+
+    Args:
+        results (list): Result list containing result parts to be
+            collected. Each item of ``result_part`` should be a picklable
+            object.
+        device (str): Device name. Optional values are 'cpu' and 'gpu'.
+        tmpdir (str | None): Temporal directory for collected results to
+            store. If set to None, it will create a temporal directory for it.
+            ``tmpdir`` should be None when device is 'gpu'. Defaults to None.
+
+    Returns:
+        list or None: The collected results.
+    """
+    if device not in ['gpu', 'cpu']:
+        raise NotImplementedError(
+            f"device must be 'cpu' or 'gpu', but got {device}")
+
+    if device == 'gpu':
+        assert tmpdir is None, 'tmpdir should be None when device is "gpu"'
+        raise NotImplementedError('GPU collecting is not supported yet')
+    else:
+        return collect_tracking_results_cpu(results, tmpdir)
+
+
+def collect_tracking_results_cpu(result_part: list,
+                                 tmpdir: Optional[str] = None
+                                 ) -> Optional[list]:
+    """Collect results in cpu mode.
+
+    Saves the results on different gpus to 'tmpdir' and collects them by the
+    rank 0 worker.
+
+    Args:
+        result_part (list): The part of prediction results.
+        tmpdir (str): Path of directory to save the temporary results from
+            different gpus under cpu mode. If None, use `tempfile.mkdtemp()`
+            to make a temporary path. Defaults to None.
+
+    Returns:
+        list or None: The collected results.
+    """
+    rank, world_size = get_dist_info()
+    if world_size == 1:
+        return result_part
+
+    # create a tmp dir if it is not specified
+    if tmpdir is None:
+        MAX_LEN = 512
+        # 32 is whitespace
+        dir_tensor = torch.full((MAX_LEN, ), 32, dtype=torch.uint8)
+        if rank == 0:
+            mkdir_or_exist('.dist_test')
+            tmpdir = tempfile.mkdtemp(dir='.dist_test')
+            tmpdir = torch.tensor(
+                bytearray(tmpdir.encode()), dtype=torch.uint8)
+            dir_tensor[:len(tmpdir)] = tmpdir
+        broadcast(dir_tensor, 0)
+        tmpdir = dir_tensor.numpy().tobytes().decode().rstrip()
+    else:
+        mkdir_or_exist(tmpdir)
+
+    # dump the part result to the dir
+    with open(osp.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f:  # type: ignore
+        pickle.dump(result_part, f, protocol=2)
+
+    barrier()
+
+    # collect all parts
+    if rank != 0:
+        return None
+    else:
+        # load results of all parts from tmp dir
+        part_list = []
+        for i in range(world_size):
+            path = osp.join(tmpdir, f'part_{i}.pkl')  # type: ignore
+            with open(path, 'rb') as f:
+                part_list.extend(pickle.load(f))
+        shutil.rmtree(tmpdir)
+        return part_list
diff --git a/mmdet/evaluation/metrics/cityscapes_metric.py b/mmdet/evaluation/metrics/cityscapes_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5cdc179a3c76ef3742dd3ee6692c7deb9905459
--- /dev/null
+++ b/mmdet/evaluation/metrics/cityscapes_metric.py
@@ -0,0 +1,205 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+import shutil
+import tempfile
+from collections import OrderedDict
+from typing import Dict, Optional, Sequence
+
+import mmcv
+import numpy as np
+from mmengine.dist import is_main_process
+from mmengine.evaluator import BaseMetric
+from mmengine.logging import MMLogger
+
+from mmdet.registry import METRICS
+
+try:
+    import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as CSEval  # noqa: E501
+    import cityscapesscripts.helpers.labels as CSLabels
+
+    from mmdet.evaluation.functional import evaluateImgLists
+    HAS_CITYSCAPESAPI = True
+except ImportError:
+    HAS_CITYSCAPESAPI = False
+
+
+@METRICS.register_module()
+class CityScapesMetric(BaseMetric):
+    """CityScapes metric for instance segmentation.
+
+    Args:
+        outfile_prefix (str): The prefix of txt and png files. The txt and
+            png files will be saved in a directory whose path is
+            "outfile_prefix.results/".
+        seg_prefix (str, optional): Path to the directory which contains the
+            cityscapes instance segmentation masks. It's necessary for
+            training and validation. It could be None when inferring on the
+            test dataset. Defaults to None.
+        format_only (bool): Format the output results without performing
+            evaluation. It is useful when you want to format the result
+            to a specific format and submit it to the test server.
+            Defaults to False.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+        dump_matches (bool): Whether to dump the matches.json file during
+            evaluation. Defaults to False.
+        file_client_args (dict, optional): Arguments to instantiate the
+            corresponding backend in mmdet <= 3.0.0rc6. Defaults to None.
+        backend_args (dict, optional): Arguments to instantiate the
+            corresponding backend. Defaults to None.
+    """
+    default_prefix: Optional[str] = 'cityscapes'
+
+    def __init__(self,
+                 outfile_prefix: str,
+                 seg_prefix: Optional[str] = None,
+                 format_only: bool = False,
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None,
+                 dump_matches: bool = False,
+                 file_client_args: dict = None,
+                 backend_args: dict = None) -> None:
+
+        if not HAS_CITYSCAPESAPI:
+            raise RuntimeError('Failed to import `cityscapesscripts`. '
+                               'Please try to install the official '
+                               'cityscapesscripts by '
+                               '"pip install cityscapesscripts"')
+        super().__init__(collect_device=collect_device, prefix=prefix)
+
+        self.tmp_dir = None
+        self.format_only = format_only
+        if self.format_only:
+            assert outfile_prefix is not None, (
+                'outfile_prefix must not be None when format_only is True, '
+                'otherwise the result files will be saved to a temp '
+                'directory which will be cleaned up at the end.')
+        else:
+            assert seg_prefix is not None, (
+                '`seg_prefix` is necessary when computing the CityScapes '
+                'metrics')
+
+        if outfile_prefix is None:
+            self.tmp_dir = tempfile.TemporaryDirectory()
+            self.outfile_prefix = osp.join(self.tmp_dir.name, 'results')
+        else:
+            # the directory to save the predicted results
+            self.outfile_prefix = osp.join(outfile_prefix, 'results')  # type: ignore # yapf: disable # noqa: E501
+
+        dir_name = osp.expanduser(self.outfile_prefix)
+
+        if osp.exists(dir_name) and is_main_process():
+            logger: MMLogger = MMLogger.get_current_instance()
+            logger.info('remove previous results.')
+            shutil.rmtree(dir_name)
+        os.makedirs(dir_name, exist_ok=True)
+
+        self.backend_args = backend_args
+        if file_client_args is not None:
+            raise RuntimeError(
+                'The `file_client_args` is deprecated, '
+                'please use `backend_args` instead, please refer to '
+                'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py'  # noqa: E501
+            )
+
+        self.seg_prefix = seg_prefix
+        self.dump_matches = dump_matches
+
+    def __del__(self) -> None:
+        """Clean up the results if necessary."""
+        if self.tmp_dir is not None:
+            self.tmp_dir.cleanup()
+
+    # TODO: data_batch is no longer needed, consider adjusting the
+    #  parameter position
+    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
+        """Process one batch of data samples and predictions. The processed
+        results should be stored in ``self.results``, which will be used to
+        compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch (dict): A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of data samples that
+                contain annotations and predictions.
+        """
+        for data_sample in data_samples:
+            # parse pred
+            result = dict()
+            pred = data_sample['pred_instances']
+            filename = data_sample['img_path']
+            basename = osp.splitext(osp.basename(filename))[0]
+            pred_txt = osp.join(self.outfile_prefix, basename + '_pred.txt')
+            result['pred_txt'] = pred_txt
+            labels = pred['labels'].cpu().numpy()
+            masks = pred['masks'].cpu().numpy().astype(np.uint8)
+            if 'mask_scores' in pred:
+                # some detectors use different scores for bbox and mask
+                mask_scores = pred['mask_scores'].cpu().numpy()
+            else:
+                mask_scores = pred['scores'].cpu().numpy()
+
+            with open(pred_txt, 'w') as f:
+                for i, (label, mask, mask_score) in enumerate(
+                        zip(labels, masks, mask_scores)):
+                    class_name = self.dataset_meta['classes'][label]
+                    class_id = CSLabels.name2label[class_name].id
+                    png_filename = osp.join(
+                        self.outfile_prefix,
+                        basename + f'_{i}_{class_name}.png')
+                    mmcv.imwrite(mask, png_filename)
+                    f.write(f'{osp.basename(png_filename)} '
+                            f'{class_id} {mask_score}\n')
+
+            # parse gt
+            gt = dict()
+            img_path = filename.replace('leftImg8bit.png',
+                                        'gtFine_instanceIds.png')
+            gt['file_name'] = img_path.replace('leftImg8bit', 'gtFine')
+
+            self.results.append((gt, result))
+
+    def compute_metrics(self, results: list) -> Dict[str, float]:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list): The processed results of each batch.
+
+        Returns:
+            Dict[str, float]: The computed metrics. The keys are the names of
+            the metrics, and the values are corresponding results.
+        """
+        logger: MMLogger = MMLogger.get_current_instance()
+
+        if self.format_only:
+            logger.info(
+                f'results are saved to {osp.dirname(self.outfile_prefix)}')
+            return OrderedDict()
+        logger.info('start to compute metrics')
+
+        # split gt and prediction list
+        gts, preds = zip(*results)
+        # set global states in cityscapes evaluation API
+        gt_instances_file = osp.join(self.outfile_prefix, 'gtInstances.json')  # type: ignore # yapf: disable # noqa: E501
+        CSEval.args.JSONOutput = False
+        CSEval.args.colorized = False
+        CSEval.args.gtInstancesFile = gt_instances_file
+
+        groundTruthImgList = [gt['file_name'] for gt in gts]
+        predictionImgList = [pred['pred_txt'] for pred in preds]
+        CSEval_results = evaluateImgLists(
+            predictionImgList,
+            groundTruthImgList,
+            CSEval.args,
+            self.backend_args,
+            dump_matches=self.dump_matches)['averages']
+
+        eval_results = OrderedDict()
+        eval_results['mAP'] = CSEval_results['allAp']
+        eval_results['AP@50'] = CSEval_results['allAp50%']
+
+        return eval_results
diff --git a/mmdet/evaluation/metrics/coco_caption_metric.py b/mmdet/evaluation/metrics/coco_caption_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8c7350150f73d8d568597b352e33ad2a202c609
--- /dev/null
+++ b/mmdet/evaluation/metrics/coco_caption_metric.py
@@ -0,0 +1,135 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os
+import tempfile
+from typing import List, Optional
+
+from mmengine.evaluator import BaseMetric
+from mmengine.utils import track_iter_progress
+from pycocotools.coco import COCO
+
+from mmdet.registry import METRICS
+
+try:
+    from pycocoevalcap.eval import COCOEvalCap
+except ImportError:
+    COCOEvalCap = None
+
+
+@METRICS.register_module()
+class COCOCaptionMetric(BaseMetric):
+    """Coco Caption evaluation wrapper.
+
+    Saves the generated captions, transforms them into coco format, and
+    calls the COCO API for caption metrics.
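+
+    A minimal evaluator config might look like this (the annotation file
+    path is illustrative)::
+
+        val_evaluator = dict(
+            type='COCOCaptionMetric',
+            ann_file='data/coco/annotations/captions_val2014.json')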
+
+    Args:
+        ann_file (str): the path of the COCO format caption ground truth
+            json file, loaded for evaluation.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+    """
+
+    def __init__(self,
+                 ann_file: str,
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None):
+        if COCOEvalCap is None:
+            raise RuntimeError(
+                'COCOEvalCap is not installed, please install it by: '
+                'pip install pycocoevalcap')
+
+        super().__init__(collect_device=collect_device, prefix=prefix)
+        self.ann_file = ann_file
+
+    def process(self, data_batch, data_samples):
+        """Process one batch of data samples.
+
+        The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch: A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from the model.
+        """
+
+        for data_sample in data_samples:
+            result = dict()
+
+            result['caption'] = data_sample['pred_caption']
+            result['image_id'] = int(data_sample['img_id'])
+
+            # Save the result to `self.results`.
+            self.results.append(result)
+
+    def compute_metrics(self, results: List):
+        """Compute the metrics from processed results.
+
+        Args:
+            results (dict): The processed results of each batch.
+
+        Returns:
+            Dict: The computed metrics. The keys are the names of the metrics,
+            and the values are corresponding results.
+        """
+        # NOTICE: don't access `self.results` from the method.
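+        # The flow below: dump the collected predictions to a temporary
+        # json file (deduplicated by `image_id` in `save_result`), then
+        # score it against `self.ann_file` with pycocoevalcap.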
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+
+            eval_result_file = save_result(
+                result=results,
+                result_dir=temp_dir,
+                filename='caption_pred',
+                remove_duplicate='image_id',
+            )
+
+            coco_val = coco_caption_eval(eval_result_file, self.ann_file)
+
+        return coco_val
+
+
+def save_result(result, result_dir, filename, remove_duplicate=''):
+    """Save predictions as a json file for evaluation."""
+    # combine results from all processes
+    if remove_duplicate:
+        result_new = []
+        id_list = []
+        for res in track_iter_progress(result):
+            if res[remove_duplicate] not in id_list:
+                id_list.append(res[remove_duplicate])
+                result_new.append(res)
+        result = result_new
+
+    final_result_file_url = os.path.join(result_dir, '%s.json' % filename)
+    print(f'result file saved to {final_result_file_url}')
+    with open(final_result_file_url, 'w') as f:
+        json.dump(result, f)
+
+    return final_result_file_url
+
+
+def coco_caption_eval(results_file, ann_file):
+    """Evaluate the prediction json file against the gt json file."""
+    # create coco object and coco_result object
+    coco = COCO(ann_file)
+    coco_result = coco.loadRes(results_file)
+
+    # create coco_eval object by taking coco and coco_result
+    coco_eval = COCOEvalCap(coco, coco_result)
+
+    # make sure the image ids are the same
+    coco_eval.params['image_id'] = coco_result.getImgIds()
+
+    # this may take some time on the first run
+    coco_eval.evaluate()
+
+    # print output evaluation scores
+    for metric, score in coco_eval.eval.items():
+        print(f'{metric}: {score:.3f}')
+
+    return coco_eval.eval
diff --git a/mmdet/evaluation/metrics/coco_metric.py b/mmdet/evaluation/metrics/coco_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfdc66e03b96e62366a921c137fc5a5727e26302
--- /dev/null
+++ b/mmdet/evaluation/metrics/coco_metric.py
@@ -0,0 +1,597 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import datetime
+import itertools
+import os.path as osp
+import tempfile
+from collections import OrderedDict
+from typing import Dict, List, Optional, Sequence, Union
+
+import numpy as np
+import torch
+from mmengine.evaluator import BaseMetric
+from mmengine.fileio import dump, get_local_path, load
+from mmengine.logging import MMLogger
+from terminaltables import AsciiTable
+
+from mmdet.datasets.api_wrappers import COCO, COCOeval, COCOevalMP
+from mmdet.registry import METRICS
+from mmdet.structures.mask import encode_mask_results
+from ..functional import eval_recalls
+
+
+@METRICS.register_module()
+class CocoMetric(BaseMetric):
+    """COCO evaluation metric.
+
+    Evaluate AR, AP, and mAP for detection tasks including proposal/box
+    detection and instance segmentation. Please refer to
+    https://cocodataset.org/#detection-eval for more details.
+
+    Args:
+        ann_file (str, optional): Path to the coco format annotation file.
+            If not specified, ground truth annotations from the dataset will
+            be converted to coco format. Defaults to None.
+        metric (str | List[str]): Metrics to be evaluated. Valid metrics
+            include 'bbox', 'segm', 'proposal', and 'proposal_fast'.
+            Defaults to 'bbox'.
+        classwise (bool): Whether to evaluate the metric class-wise.
+            Defaults to False.
+        proposal_nums (Sequence[int]): Numbers of proposals to be evaluated.
+            Defaults to (100, 300, 1000).
+        iou_thrs (float | List[float], optional): IoU threshold to compute AP
+            and AR. If not specified, IoUs from 0.5 to 0.95 will be used.
+            Defaults to None.
+        metric_items (List[str], optional): Metric result names to be
+            recorded in the evaluation result. Defaults to None.
+        format_only (bool): Format the output results without performing
+            evaluation. It is useful when you want to format the result
+            to a specific format and submit it to the test server.
+            Defaults to False.
+        outfile_prefix (str, optional): The prefix of json files. It includes
+            the file path and the prefix of filename, e.g., "a/b/prefix".
+            If not specified, a temp file will be created. Defaults to None.
+        file_client_args (dict, optional): Arguments to instantiate the
+            corresponding backend in mmdet <= 3.0.0rc6. Defaults to None.
+        backend_args (dict, optional): Arguments to instantiate the
+            corresponding backend. Defaults to None.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+        sort_categories (bool): Whether to sort categories in annotations.
+            Only used for `Objects365V1Dataset`. Defaults to False.
+        use_mp_eval (bool): Whether to use multi-processing evaluation.
+            Defaults to False.
+    """
+    default_prefix: Optional[str] = 'coco'
+
+    def __init__(self,
+                 ann_file: Optional[str] = None,
+                 metric: Union[str, List[str]] = 'bbox',
+                 classwise: bool = False,
+                 proposal_nums: Sequence[int] = (100, 300, 1000),
+                 iou_thrs: Optional[Union[float, Sequence[float]]] = None,
+                 metric_items: Optional[Sequence[str]] = None,
+                 format_only: bool = False,
+                 outfile_prefix: Optional[str] = None,
+                 file_client_args: Optional[dict] = None,
+                 backend_args: Optional[dict] = None,
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None,
+                 sort_categories: bool = False,
+                 use_mp_eval: bool = False) -> None:
+        super().__init__(collect_device=collect_device, prefix=prefix)
+        # coco evaluation metrics
+        self.metrics = metric if isinstance(metric, list) else [metric]
+        allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
+        for metric in self.metrics:
+            if metric not in allowed_metrics:
+                raise KeyError(
+                    "metric should be one of 'bbox', 'segm', 'proposal', "
+                    f"'proposal_fast', but got {metric}.")
+
+        # do class wise evaluation, default False
+        self.classwise = classwise
+        # whether to use multi processing evaluation, default False
+        self.use_mp_eval = use_mp_eval
+
+        # proposal_nums used to compute recall or precision.
+        self.proposal_nums = list(proposal_nums)
+
+        # iou_thrs used to compute recall or precision.
+        if iou_thrs is None:
+            iou_thrs = np.linspace(
+                .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
+        self.iou_thrs = iou_thrs
+        self.metric_items = metric_items
+        self.format_only = format_only
+        if self.format_only:
+            assert outfile_prefix is not None, (
+                'outfile_prefix must not be None when format_only is True, '
+                'otherwise the result files will be saved to a temp '
+                'directory which will be cleaned up at the end.')
+
+        self.outfile_prefix = outfile_prefix
+
+        self.backend_args = backend_args
+        if file_client_args is not None:
+            raise RuntimeError(
+                'The `file_client_args` is deprecated, '
+                'please use `backend_args` instead, please refer to '
+                'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py'  # noqa: E501
+            )
+
+        # if ann_file is not specified,
+        # initialize coco api with the converted dataset
+        if ann_file is not None:
+            with get_local_path(
+                    ann_file, backend_args=self.backend_args) as local_path:
+                self._coco_api = COCO(local_path)
+            if sort_categories:
+                # The 'categories' list in objects365_train.json and
+                # objects365_val.json is inconsistent, so sort the
+                # list (or dict) before getting cat_ids.
+                cats = self._coco_api.cats
+                sorted_cats = {i: cats[i] for i in sorted(cats)}
+                self._coco_api.cats = sorted_cats
+                categories = self._coco_api.dataset['categories']
+                sorted_categories = sorted(
+                    categories, key=lambda i: i['id'])
+                self._coco_api.dataset['categories'] = sorted_categories
+        else:
+            self._coco_api = None
+
+        # handle dataset lazy init
+        self.cat_ids = None
+        self.img_ids = None
+
+    def fast_eval_recall(self,
+                         results: List[dict],
+                         proposal_nums: Sequence[int],
+                         iou_thrs: Sequence[float],
+                         logger: Optional[MMLogger] = None) -> np.ndarray:
+        """Evaluate proposal recall with COCO's fast_eval_recall.
+
+        Args:
+            results (List[dict]): Results of the dataset.
+            proposal_nums (Sequence[int]): Proposal numbers used for
+                evaluation.
+            iou_thrs (Sequence[float]): IoU thresholds used for evaluation.
+            logger (MMLogger, optional): Logger used for logging the recall
+                summary.
+        Returns:
+            np.ndarray: Averaged recall results.
+        """
+        gt_bboxes = []
+        pred_bboxes = [result['bboxes'] for result in results]
+        for i in range(len(self.img_ids)):
+            ann_ids = self._coco_api.get_ann_ids(img_ids=self.img_ids[i])
+            ann_info = self._coco_api.load_anns(ann_ids)
+            if len(ann_info) == 0:
+                gt_bboxes.append(np.zeros((0, 4)))
+                continue
+            bboxes = []
+            for ann in ann_info:
+                if ann.get('ignore', False) or ann['iscrowd']:
+                    continue
+                x1, y1, w, h = ann['bbox']
+                bboxes.append([x1, y1, x1 + w, y1 + h])
+            bboxes = np.array(bboxes, dtype=np.float32)
+            if bboxes.shape[0] == 0:
+                bboxes = np.zeros((0, 4))
+            gt_bboxes.append(bboxes)
+
+        recalls = eval_recalls(
+            gt_bboxes, pred_bboxes, proposal_nums, iou_thrs, logger=logger)
+        ar = recalls.mean(axis=1)
+        return ar
+
+    def xyxy2xywh(self, bbox: np.ndarray) -> list:
+        """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO
+        evaluation.
+
+        Args:
+            bbox (numpy.ndarray): The bounding boxes, shape (4, ), in
+                ``xyxy`` order.
+
+        Returns:
+            list[float]: The converted bounding boxes, in ``xywh`` order.
+        """
+
+        _bbox: List = bbox.tolist()
+        return [
+            _bbox[0],
+            _bbox[1],
+            _bbox[2] - _bbox[0],
+            _bbox[3] - _bbox[1],
+        ]
+
+    def results2json(self, results: Sequence[dict],
+                     outfile_prefix: str) -> dict:
+        """Dump the detection results to a COCO style json file.
+
+        There are 3 types of results: proposals, bbox predictions, mask
+        predictions, and they have different data types. This method will
+        automatically recognize the type, and dump them to json files.
+
+        Args:
+            results (Sequence[dict]): Testing results of the
+                dataset.
+            outfile_prefix (str): The filename prefix of the json files. If
+                the prefix is "somepath/xxx", the json files will be named
+                "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
+                "somepath/xxx.proposal.json".
+ + Returns: + dict: Possible keys are "bbox", "segm", "proposal", and + values are corresponding filenames. + """ + bbox_json_results = [] + segm_json_results = [] if 'masks' in results[0] else None + for idx, result in enumerate(results): + image_id = result.get('img_id', idx) + labels = result['labels'] + bboxes = result['bboxes'] + scores = result['scores'] + # bbox results + for i, label in enumerate(labels): + data = dict() + data['image_id'] = image_id + data['bbox'] = self.xyxy2xywh(bboxes[i]) + data['score'] = float(scores[i]) + data['category_id'] = self.cat_ids[label] + bbox_json_results.append(data) + + if segm_json_results is None: + continue + + # segm results + masks = result['masks'] + mask_scores = result.get('mask_scores', scores) + for i, label in enumerate(labels): + data = dict() + data['image_id'] = image_id + data['bbox'] = self.xyxy2xywh(bboxes[i]) + data['score'] = float(mask_scores[i]) + data['category_id'] = self.cat_ids[label] + if isinstance(masks[i]['counts'], bytes): + masks[i]['counts'] = masks[i]['counts'].decode() + data['segmentation'] = masks[i] + segm_json_results.append(data) + + result_files = dict() + result_files['bbox'] = f'{outfile_prefix}.bbox.json' + result_files['proposal'] = f'{outfile_prefix}.bbox.json' + dump(bbox_json_results, result_files['bbox']) + + if segm_json_results is not None: + result_files['segm'] = f'{outfile_prefix}.segm.json' + dump(segm_json_results, result_files['segm']) + + return result_files + + def gt_to_coco_json(self, gt_dicts: Sequence[dict], + outfile_prefix: str) -> str: + """Convert ground truth to coco format json file. + + Args: + gt_dicts (Sequence[dict]): Ground truth of the dataset. + outfile_prefix (str): The filename prefix of the json files. If the + prefix is "somepath/xxx", the json file will be named + "somepath/xxx.gt.json". + Returns: + str: The filename of the json file. 
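+
+        Note (layout inferred from ``process`` below): each ``gt_dict``
+        carries ``img_id``, ``width``, ``height`` and an ``anns`` list whose
+        items provide ``bbox`` (in ``xyxy`` order), ``bbox_label`` and
+        optionally ``ignore_flag`` and ``mask``.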
+ """ + categories = [ + dict(id=id, name=name) + for id, name in enumerate(self.dataset_meta['classes']) + ] + image_infos = [] + annotations = [] + + for idx, gt_dict in enumerate(gt_dicts): + img_id = gt_dict.get('img_id', idx) + image_info = dict( + id=img_id, + width=gt_dict['width'], + height=gt_dict['height'], + file_name='') + image_infos.append(image_info) + for ann in gt_dict['anns']: + label = ann['bbox_label'] + bbox = ann['bbox'] + coco_bbox = [ + bbox[0], + bbox[1], + bbox[2] - bbox[0], + bbox[3] - bbox[1], + ] + + annotation = dict( + id=len(annotations) + + 1, # coco api requires id starts with 1 + image_id=img_id, + bbox=coco_bbox, + iscrowd=ann.get('ignore_flag', 0), + category_id=int(label), + area=coco_bbox[2] * coco_bbox[3]) + if ann.get('mask', None): + mask = ann['mask'] + # area = mask_util.area(mask) + if isinstance(mask, dict) and isinstance( + mask['counts'], bytes): + mask['counts'] = mask['counts'].decode() + annotation['segmentation'] = mask + # annotation['area'] = float(area) + annotations.append(annotation) + + info = dict( + date_created=str(datetime.datetime.now()), + description='Coco json file converted by mmdet CocoMetric.') + coco_json = dict( + info=info, + images=image_infos, + categories=categories, + licenses=None, + ) + if len(annotations) > 0: + coco_json['annotations'] = annotations + converted_json_path = f'{outfile_prefix}.gt.json' + dump(coco_json, converted_json_path) + return converted_json_path + + # TODO: data_batch is no longer needed, consider adjusting the + # parameter position + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of data samples that + contain annotations and predictions. + """ + for data_sample in data_samples: + result = dict() + pred = data_sample['pred_instances'] + result['img_id'] = data_sample['img_id'] + result['bboxes'] = pred['bboxes'].cpu().numpy() + result['scores'] = pred['scores'].cpu().numpy() + result['labels'] = pred['labels'].cpu().numpy() + # encode mask to RLE + if 'masks' in pred: + result['masks'] = encode_mask_results( + pred['masks'].detach().cpu().numpy()) if isinstance( + pred['masks'], torch.Tensor) else pred['masks'] + # some detectors use different scores for bbox and mask + if 'mask_scores' in pred: + result['mask_scores'] = pred['mask_scores'].cpu().numpy() + + # parse gt + gt = dict() + gt['width'] = data_sample['ori_shape'][1] + gt['height'] = data_sample['ori_shape'][0] + gt['img_id'] = data_sample['img_id'] + if self._coco_api is None: + # TODO: Need to refactor to support LoadAnnotations + assert 'instances' in data_sample, \ + 'ground truth is required for evaluation when ' \ + '`ann_file` is not provided' + gt['anns'] = data_sample['instances'] + # add converted result to the results list + self.results.append((gt, result)) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. 
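+
+        Example (an illustrative sketch, not part of the original docs; the
+        annotation path is a placeholder)::
+
+            >>> from mmdet.evaluation import CocoMetric
+            >>> metric = CocoMetric(
+            ...     ann_file='annotations/instances_val2017.json',
+            ...     metric=['bbox'])
+            >>> # ``process`` is invoked once per batch by the mmengine
+            >>> # evaluator loop; ``compute_metrics`` then consumes the
+            >>> # gathered ``self.results``.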
+ """ + logger: MMLogger = MMLogger.get_current_instance() + + # split gt and prediction list + gts, preds = zip(*results) + + tmp_dir = None + if self.outfile_prefix is None: + tmp_dir = tempfile.TemporaryDirectory() + outfile_prefix = osp.join(tmp_dir.name, 'results') + else: + outfile_prefix = self.outfile_prefix + + if self._coco_api is None: + # use converted gt json file to initialize coco api + logger.info('Converting ground truth to coco format...') + coco_json_path = self.gt_to_coco_json( + gt_dicts=gts, outfile_prefix=outfile_prefix) + self._coco_api = COCO(coco_json_path) + + # handle lazy init + if self.cat_ids is None: + self.cat_ids = self._coco_api.get_cat_ids( + cat_names=self.dataset_meta['classes']) + if self.img_ids is None: + self.img_ids = self._coco_api.get_img_ids() + + # convert predictions to coco format and dump to json file + result_files = self.results2json(preds, outfile_prefix) + + eval_results = OrderedDict() + if self.format_only: + logger.info('results are saved in ' + f'{osp.dirname(outfile_prefix)}') + return eval_results + + for metric in self.metrics: + logger.info(f'Evaluating {metric}...') + + # TODO: May refactor fast_eval_recall to an independent metric? + # fast eval recall + if metric == 'proposal_fast': + ar = self.fast_eval_recall( + preds, self.proposal_nums, self.iou_thrs, logger=logger) + log_msg = [] + for i, num in enumerate(self.proposal_nums): + eval_results[f'AR@{num}'] = ar[i] + log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}') + log_msg = ''.join(log_msg) + logger.info(log_msg) + continue + + # evaluate proposal, bbox and segm + iou_type = 'bbox' if metric == 'proposal' else metric + if metric not in result_files: + raise KeyError(f'{metric} is not in results') + try: + predictions = load(result_files[metric]) + if iou_type == 'segm': + # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa + # When evaluating mask AP, if the results contain bbox, + # cocoapi will use the box area instead of the mask area + # for calculating the instance area. Though the overall AP + # is not affected, this leads to different + # small/medium/large mask AP results. 
+ for x in predictions: + x.pop('bbox') + coco_dt = self._coco_api.loadRes(predictions) + + except IndexError: + logger.error( + 'The testing results of the whole dataset is empty.') + break + + if self.use_mp_eval: + coco_eval = COCOevalMP(self._coco_api, coco_dt, iou_type) + else: + coco_eval = COCOeval(self._coco_api, coco_dt, iou_type) + + coco_eval.params.catIds = self.cat_ids + coco_eval.params.imgIds = self.img_ids + coco_eval.params.maxDets = list(self.proposal_nums) + coco_eval.params.iouThrs = self.iou_thrs + + # mapping of cocoEval.stats + coco_metric_names = { + 'mAP': 0, + 'mAP_50': 1, + 'mAP_75': 2, + 'mAP_s': 3, + 'mAP_m': 4, + 'mAP_l': 5, + 'AR@100': 6, + 'AR@300': 7, + 'AR@1000': 8, + 'AR_s@1000': 9, + 'AR_m@1000': 10, + 'AR_l@1000': 11 + } + metric_items = self.metric_items + if metric_items is not None: + for metric_item in metric_items: + if metric_item not in coco_metric_names: + raise KeyError( + f'metric item "{metric_item}" is not supported') + + if metric == 'proposal': + coco_eval.params.useCats = 0 + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + if metric_items is None: + metric_items = [ + 'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000', + 'AR_m@1000', 'AR_l@1000' + ] + + for item in metric_items: + val = float( + f'{coco_eval.stats[coco_metric_names[item]]:.3f}') + eval_results[item] = val + else: + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + if self.classwise: # Compute per-category AP + # Compute per-category AP + # from https://github.com/facebookresearch/detectron2/ + precisions = coco_eval.eval['precision'] + # precision: (iou, recall, cls, area range, max dets) + assert len(self.cat_ids) == precisions.shape[2] + + results_per_category = [] + for idx, cat_id in enumerate(self.cat_ids): + t = [] + # area range index 0: all area ranges + # max dets index -1: typically 100 per image + nm = self._coco_api.loadCats(cat_id)[0] + precision = precisions[:, :, idx, 0, -1] + precision = precision[precision > -1] + if precision.size: + ap = np.mean(precision) + else: + ap = float('nan') + t.append(f'{nm["name"]}') + t.append(f'{round(ap, 3)}') + eval_results[f'{nm["name"]}_precision'] = round(ap, 3) + + # indexes of IoU @50 and @75 + for iou in [0, 5]: + precision = precisions[iou, :, idx, 0, -1] + precision = precision[precision > -1] + if precision.size: + ap = np.mean(precision) + else: + ap = float('nan') + t.append(f'{round(ap, 3)}') + + # indexes of area of small, median and large + for area in [1, 2, 3]: + precision = precisions[:, :, idx, area, -1] + precision = precision[precision > -1] + if precision.size: + ap = np.mean(precision) + else: + ap = float('nan') + t.append(f'{round(ap, 3)}') + results_per_category.append(tuple(t)) + + num_columns = len(results_per_category[0]) + results_flatten = list( + itertools.chain(*results_per_category)) + headers = [ + 'category', 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', + 'mAP_m', 'mAP_l' + ] + results_2d = itertools.zip_longest(*[ + results_flatten[i::num_columns] + for i in range(num_columns) + ]) + table_data = [headers] + table_data += [result for result in results_2d] + table = AsciiTable(table_data) + logger.info('\n' + table.table) + + if metric_items is None: + metric_items = [ + 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l' + ] + + for metric_item in metric_items: + key = f'{metric}_{metric_item}' + val = coco_eval.stats[coco_metric_names[metric_item]] + eval_results[key] = float(f'{round(val, 3)}') + + ap = coco_eval.stats[:6] + 
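+                # Also log the six headline numbers (mAP, mAP_50, mAP_75,
+                # mAP_s, mAP_m, mAP_l) on one line for easy copy-paste.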
+                logger.info(f'{metric}_mAP_copypaste: {ap[0]:.3f} '
+                            f'{ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
+                            f'{ap[4]:.3f} {ap[5]:.3f}')
+
+        if tmp_dir is not None:
+            tmp_dir.cleanup()
+        return eval_results
diff --git a/mmdet/evaluation/metrics/coco_occluded_metric.py b/mmdet/evaluation/metrics/coco_occluded_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..81235a04e6ee1929cfd6b5cdc284d239765b0d69
--- /dev/null
+++ b/mmdet/evaluation/metrics/coco_occluded_metric.py
@@ -0,0 +1,204 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Union
+
+import mmengine
+import numpy as np
+from mmengine.fileio import load
+from mmengine.logging import print_log
+from pycocotools import mask as coco_mask
+from terminaltables import AsciiTable
+
+from mmdet.registry import METRICS
+from .coco_metric import CocoMetric
+
+
+@METRICS.register_module()
+class CocoOccludedSeparatedMetric(CocoMetric):
+    """Metric of separated and occluded masks, as presented in the paper
+    `A Tri-Layer Plugin to Improve Occluded Detection`_.
+
+    Separated COCO and Occluded COCO are automatically generated subsets of
+    the COCO val dataset, collecting separated objects and partially occluded
+    objects for a large variety of categories. In this way, occlusion is
+    divided into two major categories: separated and partially occluded.
+
+    - Separation: the target object segmentation mask is separated into
+      distinct regions by the occluder.
+    - Partial Occlusion: the target object is partially occluded but the
+      segmentation mask is connected.
+
+    These two new scalable real-image datasets benchmark a model's
+    capability to detect occluded objects of 80 common categories.
+
+    Please cite the paper if you use this dataset:
+
+    @article{zhan2022triocc,
+        title={A Tri-Layer Plugin to Improve Occluded Detection},
+        author={Zhan, Guanqi and Xie, Weidi and Zisserman, Andrew},
+        journal={British Machine Vision Conference},
+        year={2022}
+    }
+
+    Args:
+        occluded_ann (str): Path to the occluded coco annotation file.
+        separated_ann (str): Path to the separated coco annotation file.
+        score_thr (float): Score threshold of the detection masks.
+            Defaults to 0.3.
+        iou_thr (float): IoU threshold for the recall calculation.
+            Defaults to 0.75.
+        metric (str | List[str]): Metrics to be evaluated. Valid metrics
+            include 'bbox', 'segm', 'proposal', and 'proposal_fast'.
+            Defaults to ['bbox', 'segm'].
+    """
+    default_prefix: Optional[str] = 'coco'
+
+    def __init__(
+        self,
+        *args,
+        occluded_ann:
+        str = 'https://www.robots.ox.ac.uk/~vgg/research/tpod/datasets/occluded_coco.pkl',  # noqa
+        separated_ann:
+        str = 'https://www.robots.ox.ac.uk/~vgg/research/tpod/datasets/separated_coco.pkl',  # noqa
+        score_thr: float = 0.3,
+        iou_thr: float = 0.75,
+        metric: Union[str, List[str]] = ['bbox', 'segm'],
+        **kwargs) -> None:
+        super().__init__(*args, metric=metric, **kwargs)
+        self.occluded_ann = load(occluded_ann)
+        self.separated_ann = load(separated_ann)
+        self.score_thr = score_thr
+        self.iou_thr = iou_thr
+
+    def compute_metrics(self, results: list) -> Dict[str, float]:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list): The processed results of each batch.
+
+        Returns:
+            Dict[str, float]: The computed metrics. The keys are the names of
+            the metrics, and the values are corresponding results.
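+
+        Example (an illustrative config sketch; the annotation path is a
+        placeholder)::
+
+            val_evaluator = dict(
+                type='CocoOccludedSeparatedMetric',
+                ann_file='data/coco/annotations/instances_val2017.json',
+                metric=['bbox', 'segm'])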
+ """ + coco_metric_res = super().compute_metrics(results) + eval_res = self.evaluate_occluded_separated(results) + coco_metric_res.update(eval_res) + return coco_metric_res + + def evaluate_occluded_separated(self, results: List[tuple]) -> dict: + """Compute the recall of occluded and separated masks. + + Args: + results (list[tuple]): Testing results of the dataset. + + Returns: + dict[str, float]: The recall of occluded and separated masks. + """ + dict_det = {} + print_log('processing detection results...') + prog_bar = mmengine.ProgressBar(len(results)) + for i in range(len(results)): + gt, dt = results[i] + img_id = dt['img_id'] + cur_img_name = self._coco_api.imgs[img_id]['file_name'] + if cur_img_name not in dict_det.keys(): + dict_det[cur_img_name] = [] + + for bbox, score, label, mask in zip(dt['bboxes'], dt['scores'], + dt['labels'], dt['masks']): + cur_binary_mask = coco_mask.decode(mask) + dict_det[cur_img_name].append([ + score, self.dataset_meta['classes'][label], + cur_binary_mask, bbox + ]) + dict_det[cur_img_name].sort( + key=lambda x: (-x[0], x[3][0], x[3][1]) + ) # rank by confidence from high to low, avoid same confidence + prog_bar.update() + print_log('\ncomputing occluded mask recall...', logger='current') + occluded_correct_num, occluded_recall = self.compute_recall( + dict_det, gt_ann=self.occluded_ann, is_occ=True) + print_log( + f'\nCOCO occluded mask recall: {occluded_recall:.2f}%', + logger='current') + print_log( + f'COCO occluded mask success num: {occluded_correct_num}', + logger='current') + print_log('computing separated mask recall...', logger='current') + separated_correct_num, separated_recall = self.compute_recall( + dict_det, gt_ann=self.separated_ann, is_occ=False) + print_log( + f'\nCOCO separated mask recall: {separated_recall:.2f}%', + logger='current') + print_log( + f'COCO separated mask success num: {separated_correct_num}', + logger='current') + table_data = [ + ['mask type', 'recall', 'num correct'], + ['occluded', f'{occluded_recall:.2f}%', occluded_correct_num], + ['separated', f'{separated_recall:.2f}%', separated_correct_num] + ] + table = AsciiTable(table_data) + print_log('\n' + table.table, logger='current') + return dict( + occluded_recall=occluded_recall, separated_recall=separated_recall) + + def compute_recall(self, + result_dict: dict, + gt_ann: list, + is_occ: bool = True) -> tuple: + """Compute the recall of occluded or separated masks. + + Args: + result_dict (dict): Processed mask results. + gt_ann (list): Occluded or separated coco annotations. + is_occ (bool): Whether the annotation is occluded mask. + Defaults to True. + Returns: + tuple: number of correct masks and the recall. 
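+
+        Note (layout inferred from the loop body below): each entry of
+        ``gt_ann`` is a tuple ``(file_name, class_name, _, bbox, rle_mask)``;
+        ``bbox`` is stored as ``xywh`` for occluded annotations (hence the
+        conversion when ``is_occ=True``) and as ``xyxy`` otherwise.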
+        """
+        correct = 0
+        prog_bar = mmengine.ProgressBar(len(gt_ann))
+        for iter_i in range(len(gt_ann)):
+            cur_item = gt_ann[iter_i]
+            cur_img_name = cur_item[0]
+            cur_gt_bbox = cur_item[3]
+            if is_occ:
+                cur_gt_bbox = [
+                    cur_gt_bbox[0], cur_gt_bbox[1],
+                    cur_gt_bbox[0] + cur_gt_bbox[2],
+                    cur_gt_bbox[1] + cur_gt_bbox[3]
+                ]
+            cur_gt_class = cur_item[1]
+            cur_gt_mask = coco_mask.decode(cur_item[4])
+
+            assert cur_img_name in result_dict.keys()
+            cur_detections = result_dict[cur_img_name]
+
+            correct_flag = False
+            for i in range(len(cur_detections)):
+                cur_det_confidence = cur_detections[i][0]
+                if cur_det_confidence < self.score_thr:
+                    break
+                cur_det_class = cur_detections[i][1]
+                if cur_det_class != cur_gt_class:
+                    continue
+                cur_det_mask = cur_detections[i][2]
+                cur_iou = self.mask_iou(cur_det_mask, cur_gt_mask)
+                if cur_iou >= self.iou_thr:
+                    correct_flag = True
+                    break
+            if correct_flag:
+                correct += 1
+            prog_bar.update()
+        recall = correct / len(gt_ann) * 100
+        return correct, recall
+
+    def mask_iou(self, mask1: np.ndarray, mask2: np.ndarray) -> float:
+        """Compute IoU between two masks."""
+        mask1_area = np.count_nonzero(mask1 == 1)
+        mask2_area = np.count_nonzero(mask2 == 1)
+        intersection = np.count_nonzero(np.logical_and(mask1 == 1, mask2 == 1))
+        iou = intersection / (mask1_area + mask2_area - intersection)
+        return iou
diff --git a/mmdet/evaluation/metrics/coco_panoptic_metric.py b/mmdet/evaluation/metrics/coco_panoptic_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..1554c0908d1e1143579929872f6bd1266a7e7a13
--- /dev/null
+++ b/mmdet/evaluation/metrics/coco_panoptic_metric.py
@@ -0,0 +1,618 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import datetime
+import itertools
+import os.path as osp
+import tempfile
+from typing import Dict, Optional, Sequence, Tuple, Union
+
+import mmcv
+import numpy as np
+from mmengine.evaluator import BaseMetric
+from mmengine.fileio import dump, get_local_path, load
+from mmengine.logging import MMLogger, print_log
+from terminaltables import AsciiTable
+
+from mmdet.datasets.api_wrappers import COCOPanoptic
+from mmdet.registry import METRICS
+from ..functional import (INSTANCE_OFFSET, pq_compute_multi_core,
+                          pq_compute_single_core)
+
+try:
+    import panopticapi
+    from panopticapi.evaluation import VOID, PQStat
+    from panopticapi.utils import id2rgb, rgb2id
+except ImportError:
+    panopticapi = None
+    id2rgb = None
+    rgb2id = None
+    VOID = None
+    PQStat = None
+
+
+@METRICS.register_module()
+class CocoPanopticMetric(BaseMetric):
+    """COCO panoptic segmentation evaluation metric.
+
+    Evaluate PQ, SQ, RQ for panoptic segmentation tasks. Please refer to
+    https://cocodataset.org/#panoptic-eval for more details.
+
+    Args:
+        ann_file (str, optional): Path to the coco format annotation file.
+            If not specified, ground truth annotations from the dataset will
+            be converted to coco format. Defaults to None.
+        seg_prefix (str, optional): Path to the directory which contains the
+            coco panoptic segmentation mask. It should be specified when
+            evaluating. Defaults to None.
+        classwise (bool): Whether to evaluate the metric class-wise.
+            Defaults to False.
+        outfile_prefix (str, optional): The prefix of json files. It includes
+            the file path and the prefix of filename, e.g., "a/b/prefix".
+            If not specified, a temp file will be created.
+            It should be specified when format_only is True. Defaults to None.
+        format_only (bool): Format the output results without performing
+            evaluation. It is useful when you want to format the result
+            to a specific format and submit it to the test server.
+            Defaults to False.
+        nproc (int): Number of processes for panoptic quality computing.
+            Defaults to 32. When ``nproc`` exceeds the number of cpu cores,
+            the number of cpu cores is used.
+        file_client_args (dict, optional): Arguments to instantiate the
+            corresponding backend in mmdet <= 3.0.0rc6. Defaults to None.
+        backend_args (dict, optional): Arguments to instantiate the
+            corresponding backend. Defaults to None.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+    """
+    default_prefix: Optional[str] = 'coco_panoptic'
+
+    def __init__(self,
+                 ann_file: Optional[str] = None,
+                 seg_prefix: Optional[str] = None,
+                 classwise: bool = False,
+                 format_only: bool = False,
+                 outfile_prefix: Optional[str] = None,
+                 nproc: int = 32,
+                 file_client_args: Optional[dict] = None,
+                 backend_args: Optional[dict] = None,
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None) -> None:
+        if panopticapi is None:
+            raise RuntimeError(
+                'panopticapi is not installed, please install it by: '
+                'pip install git+https://github.com/cocodataset/'
+                'panopticapi.git.')
+
+        super().__init__(collect_device=collect_device, prefix=prefix)
+        self.classwise = classwise
+        self.format_only = format_only
+        if self.format_only:
+            assert outfile_prefix is not None, (
+                'outfile_prefix must not be None when format_only is True, '
+                'otherwise the result files will be saved to a temp '
+                'directory which will be cleaned up at the end.')
+
+        self.tmp_dir = None
+        # outfile_prefix should be a prefix of a path which points to a shared
+        # storage when train or test with multi nodes.
+        self.outfile_prefix = outfile_prefix
+        if outfile_prefix is None:
+            self.tmp_dir = tempfile.TemporaryDirectory()
+            self.outfile_prefix = osp.join(self.tmp_dir.name, 'results')
+        # the directory to save predicted panoptic segmentation mask
+        self.seg_out_dir = f'{self.outfile_prefix}.panoptic'
+        self.nproc = nproc
+        self.seg_prefix = seg_prefix
+
+        self.cat_ids = None
+        self.cat2label = None
+
+        self.backend_args = backend_args
+        if file_client_args is not None:
+            raise RuntimeError(
+                'The `file_client_args` is deprecated, '
+                'please use `backend_args` instead, please refer to '
+                'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py'  # noqa: E501
+            )
+
+        if ann_file:
+            with get_local_path(
+                    ann_file, backend_args=self.backend_args) as local_path:
+                self._coco_api = COCOPanoptic(local_path)
+            self.categories = self._coco_api.cats
+        else:
+            self._coco_api = None
+            self.categories = None
+
+    def __del__(self) -> None:
+        """Clean up."""
+        if self.tmp_dir is not None:
+            self.tmp_dir.cleanup()
+
+    def gt_to_coco_json(self, gt_dicts: Sequence[dict],
+                        outfile_prefix: str) -> Tuple[str, str]:
+        """Convert ground truth to coco panoptic segmentation format json
+        file.
+
+        Args:
+            gt_dicts (Sequence[dict]): Ground truth of the dataset.
+            outfile_prefix (str): The filename prefix of the json file. If the
+                prefix is "somepath/xxx", the json file will be named
+                "somepath/xxx.gt.json".
+ + Returns: + Tuple[str, str]: The filename of the json file and the name of the\ + directory which contains panoptic segmentation masks. + """ + assert len(gt_dicts) > 0, 'gt_dicts is empty.' + gt_folder = osp.dirname(gt_dicts[0]['seg_map_path']) + converted_json_path = f'{outfile_prefix}.gt.json' + + categories = [] + for id, name in enumerate(self.dataset_meta['classes']): + isthing = 1 if name in self.dataset_meta['thing_classes'] else 0 + categories.append({'id': id, 'name': name, 'isthing': isthing}) + + image_infos = [] + annotations = [] + for gt_dict in gt_dicts: + img_id = gt_dict['image_id'] + image_info = { + 'id': img_id, + 'width': gt_dict['width'], + 'height': gt_dict['height'], + 'file_name': osp.split(gt_dict['seg_map_path'])[-1] + } + image_infos.append(image_info) + + pan_png = mmcv.imread(gt_dict['seg_map_path']).squeeze() + pan_png = pan_png[:, :, ::-1] + pan_png = rgb2id(pan_png) + segments_info = [] + for segment_info in gt_dict['segments_info']: + id = segment_info['id'] + label = segment_info['category'] + mask = pan_png == id + isthing = categories[label]['isthing'] + if isthing: + iscrowd = 1 if not segment_info['is_thing'] else 0 + else: + iscrowd = 0 + + new_segment_info = { + 'id': id, + 'category_id': label, + 'isthing': isthing, + 'iscrowd': iscrowd, + 'area': mask.sum() + } + segments_info.append(new_segment_info) + + segm_file = image_info['file_name'].replace('jpg', 'png') + annotation = dict( + image_id=img_id, + segments_info=segments_info, + file_name=segm_file) + annotations.append(annotation) + pan_png = id2rgb(pan_png) + + info = dict( + date_created=str(datetime.datetime.now()), + description='Coco json file converted by mmdet CocoPanopticMetric.' + ) + coco_json = dict( + info=info, + images=image_infos, + categories=categories, + licenses=None, + ) + if len(annotations) > 0: + coco_json['annotations'] = annotations + dump(coco_json, converted_json_path) + return converted_json_path, gt_folder + + def result2json(self, results: Sequence[dict], + outfile_prefix: str) -> Tuple[str, str]: + """Dump the panoptic results to a COCO style json file and a directory. + + Args: + results (Sequence[dict]): Testing results of the dataset. + outfile_prefix (str): The filename prefix of the json files and the + directory. + + Returns: + Tuple[str, str]: The json file and the directory which contains \ + panoptic segmentation masks. The filename of the json is + "somepath/xxx.panoptic.json" and name of the directory is + "somepath/xxx.panoptic". + """ + label2cat = dict((v, k) for (k, v) in self.cat2label.items()) + pred_annotations = [] + for idx in range(len(results)): + result = results[idx] + for segment_info in result['segments_info']: + sem_label = segment_info['category_id'] + # convert sem_label to json label + cat_id = label2cat[sem_label] + segment_info['category_id'] = label2cat[sem_label] + is_thing = self.categories[cat_id]['isthing'] + segment_info['isthing'] = is_thing + pred_annotations.append(result) + pan_json_results = dict(annotations=pred_annotations) + json_filename = f'{outfile_prefix}.panoptic.json' + dump(pan_json_results, json_filename) + return json_filename, ( + self.seg_out_dir + if self.tmp_dir is None else tempfile.gettempdir()) + + def _parse_predictions(self, + pred: dict, + img_id: int, + segm_file: str, + label2cat=None) -> dict: + """Parse panoptic segmentation predictions. + + Args: + pred (dict): Panoptic segmentation predictions. + img_id (int): Image id. + segm_file (str): Segmentation file name. 
+ label2cat (dict): Mapping from label to category id. + Defaults to None. + + Returns: + dict: Parsed predictions. + """ + result = dict() + result['img_id'] = img_id + # shape (1, H, W) -> (H, W) + pan = pred['pred_panoptic_seg']['sem_seg'].cpu().numpy()[0] + ignore_index = pred['pred_panoptic_seg'].get( + 'ignore_index', len(self.dataset_meta['classes'])) + pan_labels = np.unique(pan) + segments_info = [] + for pan_label in pan_labels: + sem_label = pan_label % INSTANCE_OFFSET + # We reserve the length of dataset_meta['classes'] + # and ignore_index for VOID label + if sem_label == len( + self.dataset_meta['classes']) or sem_label == ignore_index: + continue + mask = pan == pan_label + area = mask.sum() + segments_info.append({ + 'id': + int(pan_label), + # when ann_file provided, sem_label should be cat_id, otherwise + # sem_label should be a continuous id, not the cat_id + # defined in dataset + 'category_id': + label2cat[sem_label] if label2cat else sem_label, + 'area': + int(area) + }) + # evaluation script uses 0 for VOID label. + pan[pan % INSTANCE_OFFSET == len(self.dataset_meta['classes'])] = VOID + pan[pan % INSTANCE_OFFSET == ignore_index] = VOID + + pan = id2rgb(pan).astype(np.uint8) + mmcv.imwrite(pan[:, :, ::-1], osp.join(self.seg_out_dir, segm_file)) + result = { + 'image_id': img_id, + 'segments_info': segments_info, + 'file_name': segm_file + } + + return result + + def _compute_batch_pq_stats(self, data_samples: Sequence[dict]): + """Process gts and predictions when ``outfile_prefix`` is not set, gts + are from dataset or a json file which is defined by ``ann_file``. + + Intermediate results, ``pq_stats``, are computed here and put into + ``self.results``. + """ + if self._coco_api is None: + categories = dict() + for id, name in enumerate(self.dataset_meta['classes']): + isthing = 1 if name in self.dataset_meta['thing_classes']\ + else 0 + categories[id] = {'id': id, 'name': name, 'isthing': isthing} + label2cat = None + else: + categories = self.categories + cat_ids = self._coco_api.get_cat_ids( + cat_names=self.dataset_meta['classes']) + label2cat = {i: cat_id for i, cat_id in enumerate(cat_ids)} + + for data_sample in data_samples: + # parse pred + img_id = data_sample['img_id'] + segm_file = osp.basename(data_sample['img_path']).replace( + 'jpg', 'png') + result = self._parse_predictions( + pred=data_sample, + img_id=img_id, + segm_file=segm_file, + label2cat=label2cat) + + # parse gt + gt = dict() + gt['image_id'] = img_id + gt['width'] = data_sample['ori_shape'][1] + gt['height'] = data_sample['ori_shape'][0] + gt['file_name'] = segm_file + + if self._coco_api is None: + # get segments_info from data_sample + seg_map_path = osp.join(self.seg_prefix, segm_file) + pan_png = mmcv.imread(seg_map_path).squeeze() + pan_png = pan_png[:, :, ::-1] + pan_png = rgb2id(pan_png) + segments_info = [] + + for segment_info in data_sample['segments_info']: + id = segment_info['id'] + label = segment_info['category'] + mask = pan_png == id + isthing = categories[label]['isthing'] + if isthing: + iscrowd = 1 if not segment_info['is_thing'] else 0 + else: + iscrowd = 0 + + new_segment_info = { + 'id': id, + 'category_id': label, + 'isthing': isthing, + 'iscrowd': iscrowd, + 'area': mask.sum() + } + segments_info.append(new_segment_info) + else: + # get segments_info from annotation file + segments_info = self._coco_api.imgToAnns[img_id] + + gt['segments_info'] = segments_info + + pq_stats = pq_compute_single_core( + proc_id=0, + annotation_set=[(gt, result)], + 
+                gt_folder=self.seg_prefix,
+                pred_folder=self.seg_out_dir,
+                categories=categories,
+                backend_args=self.backend_args)
+
+            self.results.append(pq_stats)
+
+    def _process_gt_and_predictions(self, data_samples: Sequence[dict]):
+        """Process gts and predictions when ``outfile_prefix`` is set.
+
+        The predictions will be saved to the directory specified by
+        ``outfile_prefix``. The matched pair (gt, result) will be put into
+        ``self.results``.
+        """
+        for data_sample in data_samples:
+            # parse pred
+            img_id = data_sample['img_id']
+            segm_file = osp.basename(data_sample['img_path']).replace(
+                'jpg', 'png')
+            result = self._parse_predictions(
+                pred=data_sample, img_id=img_id, segm_file=segm_file)
+
+            # parse gt
+            gt = dict()
+            gt['image_id'] = img_id
+            gt['width'] = data_sample['ori_shape'][1]
+            gt['height'] = data_sample['ori_shape'][0]
+
+            if self._coco_api is None:
+                # get segments_info from dataset
+                gt['segments_info'] = data_sample['segments_info']
+                gt['seg_map_path'] = data_sample['seg_map_path']
+
+            self.results.append((gt, result))
+
+    # TODO: data_batch is no longer needed, consider adjusting the
+    # parameter position
+    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
+        """Process one batch of data samples and predictions. The processed
+        results should be stored in ``self.results``, which will be used to
+        compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch (dict): A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of data samples that
+                contain annotations and predictions.
+        """
+        # If ``self.tmp_dir`` is None, save gt and predictions to
+        # ``self.results``; otherwise, compute pq_stats here.
+        if self.tmp_dir is None:
+            self._process_gt_and_predictions(data_samples)
+        else:
+            self._compute_batch_pq_stats(data_samples)
+
+    def compute_metrics(self, results: list) -> Dict[str, float]:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list): The processed results of each batch. There
+                are two cases:
+
+                - When ``outfile_prefix`` is not provided, the elements in
+                  results are pq_stats which can be summed directly to get PQ.
+                - When ``outfile_prefix`` is provided, the elements in
+                  results are tuples like (gt, pred).
+
+        Returns:
+            Dict[str, float]: The computed metrics. The keys are the names of
+            the metrics, and the values are corresponding results.
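+
+        Example (an illustrative config sketch; paths are placeholders)::
+
+            val_evaluator = dict(
+                type='CocoPanopticMetric',
+                ann_file='annotations/panoptic_val2017.json',
+                seg_prefix='annotations/panoptic_val2017/')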
+ """ + logger: MMLogger = MMLogger.get_current_instance() + + if self.tmp_dir is None: + # do evaluation after collect all the results + + # split gt and prediction list + gts, preds = zip(*results) + + if self._coco_api is None: + # use converted gt json file to initialize coco api + logger.info('Converting ground truth to coco format...') + coco_json_path, gt_folder = self.gt_to_coco_json( + gt_dicts=gts, outfile_prefix=self.outfile_prefix) + self._coco_api = COCOPanoptic(coco_json_path) + else: + gt_folder = self.seg_prefix + + self.cat_ids = self._coco_api.get_cat_ids( + cat_names=self.dataset_meta['classes']) + self.cat2label = { + cat_id: i + for i, cat_id in enumerate(self.cat_ids) + } + self.img_ids = self._coco_api.get_img_ids() + self.categories = self._coco_api.cats + + # convert predictions to coco format and dump to json file + json_filename, pred_folder = self.result2json( + results=preds, outfile_prefix=self.outfile_prefix) + + if self.format_only: + logger.info('results are saved in ' + f'{osp.dirname(self.outfile_prefix)}') + return dict() + + imgs = self._coco_api.imgs + gt_json = self._coco_api.img_ann_map + gt_json = [{ + 'image_id': k, + 'segments_info': v, + 'file_name': imgs[k]['segm_file'] + } for k, v in gt_json.items()] + pred_json = load(json_filename) + pred_json = dict( + (el['image_id'], el) for el in pred_json['annotations']) + + # match the gt_anns and pred_anns in the same image + matched_annotations_list = [] + for gt_ann in gt_json: + img_id = gt_ann['image_id'] + if img_id not in pred_json.keys(): + raise Exception('no prediction for the image' + ' with id: {}'.format(img_id)) + matched_annotations_list.append((gt_ann, pred_json[img_id])) + + pq_stat = pq_compute_multi_core( + matched_annotations_list, + gt_folder, + pred_folder, + self.categories, + backend_args=self.backend_args, + nproc=self.nproc) + + else: + # aggregate the results generated in process + if self._coco_api is None: + categories = dict() + for id, name in enumerate(self.dataset_meta['classes']): + isthing = 1 if name in self.dataset_meta[ + 'thing_classes'] else 0 + categories[id] = { + 'id': id, + 'name': name, + 'isthing': isthing + } + self.categories = categories + + pq_stat = PQStat() + for result in results: + pq_stat += result + + metrics = [('All', None), ('Things', True), ('Stuff', False)] + pq_results = {} + + for name, isthing in metrics: + pq_results[name], classwise_results = pq_stat.pq_average( + self.categories, isthing=isthing) + if name == 'All': + pq_results['classwise'] = classwise_results + + classwise_results = None + if self.classwise: + classwise_results = { + k: v + for k, v in zip(self.dataset_meta['classes'], + pq_results['classwise'].values()) + } + + print_panoptic_table(pq_results, classwise_results, logger=logger) + results = parse_pq_results(pq_results) + + return results + + +def parse_pq_results(pq_results: dict) -> dict: + """Parse the Panoptic Quality results. + + Args: + pq_results (dict): Panoptic Quality results. + + Returns: + dict: Panoptic Quality results parsed. 
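+
+    Example (a minimal sketch with made-up numbers)::
+
+        >>> pq = {'All': {'pq': 0.5, 'sq': 0.75, 'rq': 0.625},
+        ...       'Things': {'pq': 0.5, 'sq': 0.75, 'rq': 0.625},
+        ...       'Stuff': {'pq': 0.5, 'sq': 0.75, 'rq': 0.625}}
+        >>> parse_pq_results(pq)['PQ']
+        50.0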
+    """
+    result = dict()
+    result['PQ'] = 100 * pq_results['All']['pq']
+    result['SQ'] = 100 * pq_results['All']['sq']
+    result['RQ'] = 100 * pq_results['All']['rq']
+    result['PQ_th'] = 100 * pq_results['Things']['pq']
+    result['SQ_th'] = 100 * pq_results['Things']['sq']
+    result['RQ_th'] = 100 * pq_results['Things']['rq']
+    result['PQ_st'] = 100 * pq_results['Stuff']['pq']
+    result['SQ_st'] = 100 * pq_results['Stuff']['sq']
+    result['RQ_st'] = 100 * pq_results['Stuff']['rq']
+    return result
+
+
+def print_panoptic_table(
+        pq_results: dict,
+        classwise_results: Optional[dict] = None,
+        logger: Optional[Union['MMLogger', str]] = None) -> None:
+    """Print the panoptic evaluation results table.
+
+    Args:
+        pq_results (dict): The Panoptic Quality results.
+        classwise_results (dict, optional): The classwise Panoptic Quality
+            results. The keys are class names and the values are metrics.
+            Defaults to None.
+        logger (:obj:`MMLogger` | str, optional): Logger used for printing
+            related information during evaluation. Defaults to None.
+    """
+
+    headers = ['', 'PQ', 'SQ', 'RQ', 'categories']
+    data = [headers]
+    for name in ['All', 'Things', 'Stuff']:
+        numbers = [
+            f'{(pq_results[name][k] * 100):0.3f}' for k in ['pq', 'sq', 'rq']
+        ]
+        row = [name] + numbers + [pq_results[name]['n']]
+        data.append(row)
+    table = AsciiTable(data)
+    print_log('Panoptic Evaluation Results:\n' + table.table, logger=logger)
+
+    if classwise_results is not None:
+        class_metrics = [(name, ) + tuple(f'{(metrics[k] * 100):0.3f}'
+                                          for k in ['pq', 'sq', 'rq'])
+                         for name, metrics in classwise_results.items()]
+        num_columns = min(8, len(class_metrics) * 4)
+        results_flatten = list(itertools.chain(*class_metrics))
+        headers = ['category', 'PQ', 'SQ', 'RQ'] * (num_columns // 4)
+        results_2d = itertools.zip_longest(
+            *[results_flatten[i::num_columns] for i in range(num_columns)])
+        data = [headers]
+        data += [result for result in results_2d]
+        table = AsciiTable(data)
+        print_log(
+            'Classwise Panoptic Evaluation Results:\n' + table.table,
+            logger=logger)
diff --git a/mmdet/evaluation/metrics/coco_video_metric.py b/mmdet/evaluation/metrics/coco_video_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5c75d025a6109762db21a600e3d866764caf1cb
--- /dev/null
+++ b/mmdet/evaluation/metrics/coco_video_metric.py
@@ -0,0 +1,80 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Sequence
+
+from mmengine.dist import broadcast_object_list, is_main_process
+
+from mmdet.registry import METRICS
+from .base_video_metric import collect_tracking_results
+from .coco_metric import CocoMetric
+
+
+@METRICS.register_module()
+class CocoVideoMetric(CocoMetric):
+    """COCO evaluation metric.
+
+    Evaluate AR, AP, and mAP for detection tasks including proposal/box
+    detection and instance segmentation. Please refer to
+    https://cocodataset.org/#detection-eval for more details.
+    """
+
+    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
+        """Process one batch of data samples and predictions.
+
+        The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch (dict): A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of data samples that
+                contain annotations and predictions.
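+
+        Note (summary of the branch logic below): each ``track_data_sample``
+        bundles per-frame samples; whole videos are unrolled frame by frame,
+        while sampled single frames fall back to image-level processing.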
+ """ + for track_data_sample in data_samples: + video_data_samples = track_data_sample['video_data_samples'] + ori_video_len = video_data_samples[0].ori_video_length + video_len = len(video_data_samples) + if ori_video_len == video_len: + # video process + for frame_id in range(video_len): + img_data_sample = video_data_samples[frame_id].to_dict() + super().process(None, [img_data_sample]) + else: + # image process + img_data_sample = video_data_samples[0].to_dict() + super().process(None, [img_data_sample]) + + def evaluate(self, size: int = 1) -> dict: + """Evaluate the model performance of the whole dataset after processing + all batches. + + Args: + size (int): Length of the entire validation dataset. + Returns: + dict: Evaluation metrics dict on the val dataset. The keys are the + names of the metrics, and the values are corresponding results. + """ + if len(self.results) == 0: + warnings.warn( + f'{self.__class__.__name__} got empty `self.results`. Please ' + 'ensure that the processed results are properly added into ' + '`self.results` in `process` method.') + + results = collect_tracking_results(self.results, self.collect_device) + + if is_main_process(): + _metrics = self.compute_metrics(results) # type: ignore + # Add prefix to metric names + if self.prefix: + _metrics = { + '/'.join((self.prefix, k)): v + for k, v in _metrics.items() + } + metrics = [_metrics] + else: + metrics = [None] # type: ignore + + broadcast_object_list(metrics) + + # reset the results list + self.results.clear() + return metrics[0] diff --git a/mmdet/evaluation/metrics/crowdhuman_metric.py b/mmdet/evaluation/metrics/crowdhuman_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..50ac210ae8606bab6cada69418334c113c90fb38 --- /dev/null +++ b/mmdet/evaluation/metrics/crowdhuman_metric.py @@ -0,0 +1,824 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import json +import os.path as osp +import tempfile +from collections import OrderedDict +from multiprocessing import Process, Queue +from typing import Dict, List, Optional, Sequence, Union + +import numpy as np +from mmengine.evaluator import BaseMetric +from mmengine.fileio import dump, get_text, load +from mmengine.logging import MMLogger +from scipy.sparse import csr_matrix +from scipy.sparse.csgraph import maximum_bipartite_matching + +from mmdet.evaluation.functional.bbox_overlaps import bbox_overlaps +from mmdet.registry import METRICS + +PERSON_CLASSES = ['background', 'person'] + + +@METRICS.register_module() +class CrowdHumanMetric(BaseMetric): + """CrowdHuman evaluation metric. + + Evaluate Average Precision (AP), Miss Rate (MR) and Jaccard Index (JI) + for detection tasks. + + Args: + ann_file (str): Path to the annotation file. + metric (str | List[str]): Metrics to be evaluated. Valid metrics + include 'AP', 'MR' and 'JI'. Defaults to 'AP'. + format_only (bool): Format the output results without perform + evaluation. It is useful when you want to format the result + to a specific format and submit it to the test server. + Defaults to False. + outfile_prefix (str, optional): The prefix of json files. It includes + the file path and the prefix of filename, e.g., "a/b/prefix". + If not specified, a temp file will be created. Defaults to None. + file_client_args (dict, optional): Arguments to instantiate the + corresponding backend in mmdet <= 3.0.0rc6. Defaults to None. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. 
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+        eval_mode (int): Select the evaluation mode. Valid modes include
+            0 (just body box), 1 (just head box) and 2 (both of them).
+            Defaults to 0.
+        iou_thres (float): IoU threshold. Defaults to 0.5.
+        compare_matching_method (str, optional): Matching method to compare
+            the detection results with the ground truth when computing 'AP'
+            and 'MR'. Valid methods include VOC and None (CALTECH).
+            Defaults to None.
+        mr_ref (str): Different parameter selection to calculate MR. Valid
+            refs include CALTECH_-2 and CALTECH_-4. Defaults to CALTECH_-2.
+        num_ji_process (int): The number of processes used to evaluate JI.
+            Defaults to 10.
+    """
+    default_prefix: Optional[str] = 'crowd_human'
+
+    def __init__(self,
+                 ann_file: str,
+                 metric: Union[str, List[str]] = ['AP', 'MR', 'JI'],
+                 format_only: bool = False,
+                 outfile_prefix: Optional[str] = None,
+                 file_client_args: Optional[dict] = None,
+                 backend_args: Optional[dict] = None,
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None,
+                 eval_mode: int = 0,
+                 iou_thres: float = 0.5,
+                 compare_matching_method: Optional[str] = None,
+                 mr_ref: str = 'CALTECH_-2',
+                 num_ji_process: int = 10) -> None:
+        super().__init__(collect_device=collect_device, prefix=prefix)
+
+        self.ann_file = ann_file
+        # crowdhuman evaluation metrics
+        self.metrics = metric if isinstance(metric, list) else [metric]
+        allowed_metrics = ['MR', 'AP', 'JI']
+        for metric in self.metrics:
+            if metric not in allowed_metrics:
+                raise KeyError(f"metric should be one of 'MR', 'AP', 'JI', "
+                               f'but got {metric}.')
+
+        self.format_only = format_only
+        if self.format_only:
+            assert outfile_prefix is not None, (
+                'outfile_prefix must not be None when format_only is True, '
+                'otherwise the result files will be saved to a temp '
+                'directory which will be cleaned up at the end.')
+        self.outfile_prefix = outfile_prefix
+        self.backend_args = backend_args
+        if file_client_args is not None:
+            raise RuntimeError(
+                'The `file_client_args` is deprecated, '
+                'please use `backend_args` instead, please refer to '
+                'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py'  # noqa: E501
+            )
+
+        assert eval_mode in [0, 1, 2], \
+            'Unknown eval mode. eval_mode should be one of 0, 1, 2.'
+        assert compare_matching_method is None or \
+               compare_matching_method == 'VOC', \
+            'The alternative compare_matching_method is VOC. ' \
+            'This parameter defaults to CALTECH (None).'
+        assert mr_ref == 'CALTECH_-2' or mr_ref == 'CALTECH_-4', \
+            "mr_ref should be one of 'CALTECH_-2', 'CALTECH_-4'."
+ self.eval_mode = eval_mode + self.iou_thres = iou_thres + self.compare_matching_method = compare_matching_method + self.mr_ref = mr_ref + self.num_ji_process = num_ji_process + + @staticmethod + def results2json(results: Sequence[dict], outfile_prefix: str) -> str: + """Dump the detection results to a json file.""" + result_file_path = f'{outfile_prefix}.json' + bbox_json_results = [] + for i, result in enumerate(results): + ann, pred = result + dump_dict = dict() + dump_dict['ID'] = ann['ID'] + dump_dict['width'] = ann['width'] + dump_dict['height'] = ann['height'] + dtboxes = [] + bboxes = pred.tolist() + for _, single_bbox in enumerate(bboxes): + temp_dict = dict() + x1, y1, x2, y2, score = single_bbox + temp_dict['box'] = [x1, y1, x2 - x1, y2 - y1] + temp_dict['score'] = score + temp_dict['tag'] = 1 + dtboxes.append(temp_dict) + dump_dict['dtboxes'] = dtboxes + bbox_json_results.append(dump_dict) + dump(bbox_json_results, result_file_path) + return result_file_path + + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of data samples that + contain annotations and predictions. + """ + for data_sample in data_samples: + ann = dict() + ann['ID'] = data_sample['img_id'] + ann['width'] = data_sample['ori_shape'][1] + ann['height'] = data_sample['ori_shape'][0] + pred_bboxes = data_sample['pred_instances']['bboxes'].cpu().numpy() + pred_scores = data_sample['pred_instances']['scores'].cpu().numpy() + + pred_bbox_scores = np.hstack( + [pred_bboxes, pred_scores.reshape((-1, 1))]) + + self.results.append((ann, pred_bbox_scores)) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + eval_results(Dict[str, float]): The computed metrics. + The keys are the names of the metrics, and the values + are corresponding results. 
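+
+        Example (an illustrative config sketch; the odgt path is a
+        placeholder)::
+
+            val_evaluator = dict(
+                type='CrowdHumanMetric',
+                ann_file='data/CrowdHuman/annotation_val.odgt',
+                metric=['AP', 'MR', 'JI'])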
+ """ + logger: MMLogger = MMLogger.get_current_instance() + + tmp_dir = None + if self.outfile_prefix is None: + tmp_dir = tempfile.TemporaryDirectory() + outfile_prefix = osp.join(tmp_dir.name, 'result') + else: + outfile_prefix = self.outfile_prefix + + # convert predictions to coco format and dump to json file + result_file = self.results2json(results, outfile_prefix) + eval_results = OrderedDict() + if self.format_only: + logger.info(f'results are saved in {osp.dirname(outfile_prefix)}') + return eval_results + + # load evaluation samples + eval_samples = self.load_eval_samples(result_file) + + if 'AP' in self.metrics or 'MR' in self.metrics: + score_list = self.compare(eval_samples) + gt_num = sum([eval_samples[i].gt_num for i in eval_samples]) + ign_num = sum([eval_samples[i].ign_num for i in eval_samples]) + gt_num = gt_num - ign_num + img_num = len(eval_samples) + + for metric in self.metrics: + logger.info(f'Evaluating {metric}...') + if metric == 'AP': + AP = self.eval_ap(score_list, gt_num, img_num) + eval_results['mAP'] = float(f'{round(AP, 4)}') + if metric == 'MR': + MR = self.eval_mr(score_list, gt_num, img_num) + eval_results['mMR'] = float(f'{round(MR, 4)}') + if metric == 'JI': + JI = self.eval_ji(eval_samples) + eval_results['JI'] = float(f'{round(JI, 4)}') + if tmp_dir is not None: + tmp_dir.cleanup() + + return eval_results + + def load_eval_samples(self, result_file): + """Load data from annotations file and detection results. + + Args: + result_file (str): The file path of the saved detection results. + + Returns: + Dict[Image]: The detection result packaged by Image + """ + gt_str = get_text( + self.ann_file, backend_args=self.backend_args).strip().split('\n') + gt_records = [json.loads(line) for line in gt_str] + + pred_records = load(result_file, backend_args=self.backend_args) + eval_samples = dict() + for gt_record, pred_record in zip(gt_records, pred_records): + assert gt_record['ID'] == pred_record['ID'], \ + 'please set val_dataloader.sampler.shuffle=False and try again' + eval_samples[pred_record['ID']] = Image(self.eval_mode) + eval_samples[pred_record['ID']].load(gt_record, 'box', None, + PERSON_CLASSES, True) + eval_samples[pred_record['ID']].load(pred_record, 'box', None, + PERSON_CLASSES, False) + eval_samples[pred_record['ID']].clip_all_boader() + return eval_samples + + def compare(self, samples): + """Match the detection results with the ground_truth. + + Args: + samples (dict[Image]): The detection result packaged by Image. + + Returns: + score_list(list[tuple[ndarray, int, str]]): Matching result. + a list of tuples (dtbox, label, imgID) in the descending + sort of dtbox.score. + """ + score_list = list() + for id in samples: + if self.compare_matching_method == 'VOC': + result = samples[id].compare_voc(self.iou_thres) + else: + result = samples[id].compare_caltech(self.iou_thres) + score_list.extend(result) + # In the descending sort of dtbox score. + score_list.sort(key=lambda x: x[0][-1], reverse=True) + return score_list + + @staticmethod + def eval_ap(score_list, gt_num, img_num): + """Evaluate by average precision. + + Args: + score_list(list[tuple[ndarray, int, str]]): Matching result. + a list of tuples (dtbox, label, imgID) in the descending + sort of dtbox.score. + gt_num(int): The number of gt boxes in the entire dataset. + img_num(int): The number of images in the entire dataset. + + Returns: + ap(float): result of average precision. 
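+
+        Example (a minimal hand-built sketch; ``label`` 1 marks a matched
+        detection, 0 a false positive, and the list is already sorted by
+        descending score as ``compare`` guarantees)::
+
+            import numpy as np
+            score_list = [
+                (np.array([0., 0., 10., 10., 0.9]), 1, 'img_0'),  # TP
+                (np.array([5., 5., 15., 15., 0.8]), 0, 'img_0'),  # FP
+                (np.array([2., 2., 12., 12., 0.7]), 1, 'img_1'),  # TP
+            ]
+            # With gt_num=2 and img_num=2 this returns roughly 0.292.
+            ap = CrowdHumanMetric.eval_ap(score_list, gt_num=2, img_num=2)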
+ """ + + # calculate general ap score + def _calculate_map(_recall, _precision): + assert len(_recall) == len(_precision) + area = 0 + for k in range(1, len(_recall)): + delta_h = (_precision[k - 1] + _precision[k]) / 2 + delta_w = _recall[k] - _recall[k - 1] + area += delta_w * delta_h + return area + + tp, fp = 0.0, 0.0 + rpX, rpY = list(), list() + + fpn = [] + recalln = [] + thr = [] + fppi = [] + for i, item in enumerate(score_list): + if item[1] == 1: + tp += 1.0 + elif item[1] == 0: + fp += 1.0 + fn = gt_num - tp + recall = tp / (tp + fn) + precision = tp / (tp + fp) + rpX.append(recall) + rpY.append(precision) + fpn.append(fp) + recalln.append(tp) + thr.append(item[0][-1]) + fppi.append(fp / img_num) + + ap = _calculate_map(rpX, rpY) + return ap + + def eval_mr(self, score_list, gt_num, img_num): + """Evaluate by Caltech-style log-average miss rate. + + Args: + score_list(list[tuple[ndarray, int, str]]): Matching result. + a list of tuples (dtbox, label, imgID) in the descending + sort of dtbox.score. + gt_num(int): The number of gt boxes in the entire dataset. + img_num(int): The number of image in the entire dataset. + + Returns: + mr(float): result of miss rate. + """ + + # find greater_than + def _find_gt(lst, target): + for idx, _item in enumerate(lst): + if _item >= target: + return idx + return len(lst) - 1 + + if self.mr_ref == 'CALTECH_-2': + # CALTECH_MRREF_2: anchor points (from 10^-2 to 1) as in + # P.Dollar's paper + ref = [ + 0.0100, 0.0178, 0.03160, 0.0562, 0.1000, 0.1778, 0.3162, + 0.5623, 1.000 + ] + else: + # CALTECH_MRREF_4: anchor points (from 10^-4 to 1) as in + # S.Zhang's paper + ref = [ + 0.0001, 0.0003, 0.00100, 0.0032, 0.0100, 0.0316, 0.1000, + 0.3162, 1.000 + ] + + tp, fp = 0.0, 0.0 + fppiX, fppiY = list(), list() + for i, item in enumerate(score_list): + if item[1] == 1: + tp += 1.0 + elif item[1] == 0: + fp += 1.0 + + fn = gt_num - tp + recall = tp / (tp + fn) + missrate = 1.0 - recall + fppi = fp / img_num + fppiX.append(fppi) + fppiY.append(missrate) + + score = list() + for pos in ref: + argmin = _find_gt(fppiX, pos) + if argmin >= 0: + score.append(fppiY[argmin]) + score = np.array(score) + mr = np.exp(np.log(score).mean()) + return mr + + def eval_ji(self, samples): + """Evaluate by JI using multi_process. + + Args: + samples(Dict[str, Image]): The detection result packaged by Image. + + Returns: + ji(float): result of jaccard index. + """ + import math + res_line = [] + res_ji = [] + for i in range(10): + score_thr = 1e-1 * i + total = len(samples) + stride = math.ceil(total / self.num_ji_process) + result_queue = Queue(10000) + results, procs = [], [] + records = list(samples.items()) + for i in range(self.num_ji_process): + start = i * stride + end = np.min([start + stride, total]) + sample_data = dict(records[start:end]) + p = Process( + target=self.compute_ji_with_ignore, + args=(result_queue, sample_data, score_thr)) + p.start() + procs.append(p) + for i in range(total): + t = result_queue.get() + results.append(t) + for p in procs: + p.join() + line, mean_ratio = self.gather(results) + line = 'score_thr:{:.1f}, {}'.format(score_thr, line) + res_line.append(line) + res_ji.append(mean_ratio) + return max(res_ji) + + def compute_ji_with_ignore(self, result_queue, dt_result, score_thr): + """Compute JI with ignore. + + Args: + result_queue(Queue): The Queue for save compute result when + multi_process. + dt_result(dict[Image]): Detection result packaged by Image. + score_thr(float): The threshold of detection score. 
+ Returns: + dict: compute result. + """ + for ID, record in dt_result.items(): + gt_boxes = record.gt_boxes + dt_boxes = record.dt_boxes + keep = dt_boxes[:, -1] > score_thr + dt_boxes = dt_boxes[keep][:, :-1] + + gt_tag = np.array(gt_boxes[:, -1] != -1) + matches = self.compute_ji_matching(dt_boxes, gt_boxes[gt_tag, :4]) + # get the unmatched_indices + matched_indices = np.array([j for (j, _) in matches]) + unmatched_indices = list( + set(np.arange(dt_boxes.shape[0])) - set(matched_indices)) + num_ignore_dt = self.get_ignores(dt_boxes[unmatched_indices], + gt_boxes[~gt_tag, :4]) + matched_indices = np.array([j for (_, j) in matches]) + unmatched_indices = list( + set(np.arange(gt_boxes[gt_tag].shape[0])) - + set(matched_indices)) + num_ignore_gt = self.get_ignores( + gt_boxes[gt_tag][unmatched_indices], gt_boxes[~gt_tag, :4]) + # compute results + eps = 1e-6 + k = len(matches) + m = gt_tag.sum() - num_ignore_gt + n = dt_boxes.shape[0] - num_ignore_dt + ratio = k / (m + n - k + eps) + recall = k / (m + eps) + cover = k / (n + eps) + noise = 1 - cover + result_dict = dict( + ratio=ratio, + recall=recall, + cover=cover, + noise=noise, + k=k, + m=m, + n=n) + result_queue.put_nowait(result_dict) + + @staticmethod + def gather(results): + """Integrate test results.""" + assert len(results) + img_num = 0 + for result in results: + if result['n'] != 0 or result['m'] != 0: + img_num += 1 + mean_ratio = np.sum([rb['ratio'] for rb in results]) / img_num + valids = np.sum([rb['k'] for rb in results]) + total = np.sum([rb['n'] for rb in results]) + gtn = np.sum([rb['m'] for rb in results]) + line = 'mean_ratio:{:.4f}, valids:{}, total:{}, gtn:{}'\ + .format(mean_ratio, valids, total, gtn) + return line, mean_ratio + + def compute_ji_matching(self, dt_boxes, gt_boxes): + """Match the annotation box for each detection box. + + Args: + dt_boxes(ndarray): Detection boxes. + gt_boxes(ndarray): Ground_truth boxes. + + Returns: + matches_(list[tuple[int, int]]): Match result. + """ + assert dt_boxes.shape[-1] > 3 and gt_boxes.shape[-1] > 3 + if dt_boxes.shape[0] < 1 or gt_boxes.shape[0] < 1: + return list() + + ious = bbox_overlaps(dt_boxes, gt_boxes, mode='iou') + input_ = copy.deepcopy(ious) + input_[input_ < self.iou_thres] = 0 + match_scipy = maximum_bipartite_matching( + csr_matrix(input_), perm_type='column') + matches_ = [] + for i in range(len(match_scipy)): + if match_scipy[i] != -1: + matches_.append((i, int(match_scipy[i]))) + return matches_ + + def get_ignores(self, dt_boxes, gt_boxes): + """Get the number of ignore bboxes.""" + if gt_boxes.size: + ioas = bbox_overlaps(dt_boxes, gt_boxes, mode='iof') + ioas = np.max(ioas, axis=1) + rows = np.where(ioas > self.iou_thres)[0] + return len(rows) + else: + return 0 + + +class Image(object): + """Data structure for evaluation of CrowdHuman. + + Note: + This implementation is modified from https://github.com/Purkialo/ + CrowdDet/blob/master/lib/evaluate/APMRToolkits/image.py + + Args: + mode (int): Select the mode of evaluate. Valid mode include + 0(just body box), 1(just head box) and 2(both of them). + Defaults to 0. + """ + + def __init__(self, mode): + self.ID = None + self.width = None + self.height = None + self.dt_boxes = None + self.gt_boxes = None + self.eval_mode = mode + + self.ign_num = None + self.gt_num = None + self.dt_num = None + + def load(self, record, body_key, head_key, class_names, gt_flag): + """Loading information for evaluation. + + Args: + record (dict): Label information or test results. 
+ The format might look something like this: + { + 'ID': '273271,c9db000d5146c15', + 'gtboxes': [ + {'fbox': [72, 202, 163, 503], 'tag': 'person', ...}, + {'fbox': [199, 180, 144, 499], 'tag': 'person', ...}, + ... + ] + } + or: + { + 'ID': '273271,c9db000d5146c15', + 'width': 800, + 'height': 1067, + 'dtboxes': [ + { + 'box': [306.22, 205.95, 164.05, 394.04], + 'score': 0.99, + 'tag': 1 + }, + { + 'box': [403.60, 178.66, 157.15, 421.33], + 'score': 0.99, + 'tag': 1 + }, + ... + ] + } + body_key (str, None): key of detection body box. + Valid when loading detection results and self.eval_mode!=1. + head_key (str, None): key of detection head box. + Valid when loading detection results and self.eval_mode!=0. + class_names (list[str]):class names of data set. + Defaults to ['background', 'person']. + gt_flag (bool): Indicate whether record is ground truth + or predicting the outcome. + """ + if 'ID' in record and self.ID is None: + self.ID = record['ID'] + if 'width' in record and self.width is None: + self.width = record['width'] + if 'height' in record and self.height is None: + self.height = record['height'] + if gt_flag: + self.gt_num = len(record['gtboxes']) + body_bbox, head_bbox = self.load_gt_boxes(record, 'gtboxes', + class_names) + if self.eval_mode == 0: + self.gt_boxes = body_bbox + self.ign_num = (body_bbox[:, -1] == -1).sum() + elif self.eval_mode == 1: + self.gt_boxes = head_bbox + self.ign_num = (head_bbox[:, -1] == -1).sum() + else: + gt_tag = np.array([ + body_bbox[i, -1] != -1 and head_bbox[i, -1] != -1 + for i in range(len(body_bbox)) + ]) + self.ign_num = (gt_tag == 0).sum() + self.gt_boxes = np.hstack( + (body_bbox[:, :-1], head_bbox[:, :-1], + gt_tag.reshape(-1, 1))) + + if not gt_flag: + self.dt_num = len(record['dtboxes']) + if self.eval_mode == 0: + self.dt_boxes = self.load_det_boxes(record, 'dtboxes', + body_key, 'score') + elif self.eval_mode == 1: + self.dt_boxes = self.load_det_boxes(record, 'dtboxes', + head_key, 'score') + else: + body_dtboxes = self.load_det_boxes(record, 'dtboxes', body_key, + 'score') + head_dtboxes = self.load_det_boxes(record, 'dtboxes', head_key, + 'score') + self.dt_boxes = np.hstack((body_dtboxes, head_dtboxes)) + + @staticmethod + def load_gt_boxes(dict_input, key_name, class_names): + """load ground_truth and transform [x, y, w, h] to [x1, y1, x2, y2]""" + assert key_name in dict_input + if len(dict_input[key_name]) < 1: + return np.empty([0, 5]) + head_bbox = [] + body_bbox = [] + for rb in dict_input[key_name]: + if rb['tag'] in class_names: + body_tag = class_names.index(rb['tag']) + head_tag = copy.deepcopy(body_tag) + else: + body_tag = -1 + head_tag = -1 + if 'extra' in rb: + if 'ignore' in rb['extra']: + if rb['extra']['ignore'] != 0: + body_tag = -1 + head_tag = -1 + if 'head_attr' in rb: + if 'ignore' in rb['head_attr']: + if rb['head_attr']['ignore'] != 0: + head_tag = -1 + head_bbox.append(np.hstack((rb['hbox'], head_tag))) + body_bbox.append(np.hstack((rb['fbox'], body_tag))) + head_bbox = np.array(head_bbox) + head_bbox[:, 2:4] += head_bbox[:, :2] + body_bbox = np.array(body_bbox) + body_bbox[:, 2:4] += body_bbox[:, :2] + return body_bbox, head_bbox + + @staticmethod + def load_det_boxes(dict_input, key_name, key_box, key_score, key_tag=None): + """load detection boxes.""" + assert key_name in dict_input + if len(dict_input[key_name]) < 1: + return np.empty([0, 5]) + else: + assert key_box in dict_input[key_name][0] + if key_score: + assert key_score in dict_input[key_name][0] + if key_tag: + assert key_tag in 
dict_input[key_name][0] + if key_score: + if key_tag: + bboxes = np.vstack([ + np.hstack((rb[key_box], rb[key_score], rb[key_tag])) + for rb in dict_input[key_name] + ]) + else: + bboxes = np.vstack([ + np.hstack((rb[key_box], rb[key_score])) + for rb in dict_input[key_name] + ]) + else: + if key_tag: + bboxes = np.vstack([ + np.hstack((rb[key_box], rb[key_tag])) + for rb in dict_input[key_name] + ]) + else: + bboxes = np.vstack( + [rb[key_box] for rb in dict_input[key_name]]) + bboxes[:, 2:4] += bboxes[:, :2] + return bboxes + + def clip_all_boader(self): + """Make sure boxes are within the image range.""" + + def _clip_boundary(boxes, height, width): + assert boxes.shape[-1] >= 4 + boxes[:, 0] = np.minimum(np.maximum(boxes[:, 0], 0), width - 1) + boxes[:, 1] = np.minimum(np.maximum(boxes[:, 1], 0), height - 1) + boxes[:, 2] = np.maximum(np.minimum(boxes[:, 2], width), 0) + boxes[:, 3] = np.maximum(np.minimum(boxes[:, 3], height), 0) + return boxes + + assert self.dt_boxes.shape[-1] >= 4 + assert self.gt_boxes.shape[-1] >= 4 + assert self.width is not None and self.height is not None + if self.eval_mode == 2: + self.dt_boxes[:, :4] = _clip_boundary(self.dt_boxes[:, :4], + self.height, self.width) + self.gt_boxes[:, :4] = _clip_boundary(self.gt_boxes[:, :4], + self.height, self.width) + self.dt_boxes[:, 4:8] = _clip_boundary(self.dt_boxes[:, 4:8], + self.height, self.width) + self.gt_boxes[:, 4:8] = _clip_boundary(self.gt_boxes[:, 4:8], + self.height, self.width) + else: + self.dt_boxes = _clip_boundary(self.dt_boxes, self.height, + self.width) + self.gt_boxes = _clip_boundary(self.gt_boxes, self.height, + self.width) + + def compare_voc(self, thres): + """Match the detection results with the ground_truth by VOC. + + Args: + thres (float): IOU threshold. + + Returns: + score_list(list[tuple[ndarray, int, str]]): Matching result. + a list of tuples (dtbox, label, imgID) in the descending + sort of dtbox.score. + """ + if self.dt_boxes is None: + return list() + dtboxes = self.dt_boxes + gtboxes = self.gt_boxes if self.gt_boxes is not None else list() + dtboxes.sort(key=lambda x: x.score, reverse=True) + gtboxes.sort(key=lambda x: x.ign) + + score_list = list() + for i, dt in enumerate(dtboxes): + maxpos = -1 + maxiou = thres + + for j, gt in enumerate(gtboxes): + overlap = dt.iou(gt) + if overlap > maxiou: + maxiou = overlap + maxpos = j + + if maxpos >= 0: + if gtboxes[maxpos].ign == 0: + gtboxes[maxpos].matched = 1 + dtboxes[i].matched = 1 + score_list.append((dt, self.ID)) + else: + dtboxes[i].matched = -1 + else: + dtboxes[i].matched = 0 + score_list.append((dt, self.ID)) + return score_list + + def compare_caltech(self, thres): + """Match the detection results with the ground_truth by Caltech + matching strategy. + + Args: + thres (float): IOU threshold. + + Returns: + score_list(list[tuple[ndarray, int, str]]): Matching result. + a list of tuples (dtbox, label, imgID) in the descending + sort of dtbox.score. 
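+
+        Note (editor's summary of the implementation below): unlike
+        VOC matching, ignored ground-truth boxes (tag <= 0) are
+        matched by IoF (intersection over the detection's area)
+        rather than IoU, and detections assigned to them are excluded
+        from the score list, i.e. they count neither as true nor as
+        false positives.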
+ """ + if self.dt_boxes is None or self.gt_boxes is None: + return list() + + dtboxes = self.dt_boxes if self.dt_boxes is not None else list() + gtboxes = self.gt_boxes if self.gt_boxes is not None else list() + dt_matched = np.zeros(dtboxes.shape[0]) + gt_matched = np.zeros(gtboxes.shape[0]) + + dtboxes = np.array(sorted(dtboxes, key=lambda x: x[-1], reverse=True)) + gtboxes = np.array(sorted(gtboxes, key=lambda x: x[-1], reverse=True)) + if len(dtboxes): + overlap_iou = bbox_overlaps(dtboxes, gtboxes, mode='iou') + overlap_ioa = bbox_overlaps(dtboxes, gtboxes, mode='iof') + else: + return list() + + score_list = list() + for i, dt in enumerate(dtboxes): + maxpos = -1 + maxiou = thres + for j, gt in enumerate(gtboxes): + if gt_matched[j] == 1: + continue + if gt[-1] > 0: + overlap = overlap_iou[i][j] + if overlap > maxiou: + maxiou = overlap + maxpos = j + else: + if maxpos >= 0: + break + else: + overlap = overlap_ioa[i][j] + if overlap > thres: + maxiou = overlap + maxpos = j + if maxpos >= 0: + if gtboxes[maxpos, -1] > 0: + gt_matched[maxpos] = 1 + dt_matched[i] = 1 + score_list.append((dt, 1, self.ID)) + else: + dt_matched[i] = -1 + else: + dt_matched[i] = 0 + score_list.append((dt, 0, self.ID)) + return score_list diff --git a/mmdet/evaluation/metrics/dump_det_results.py b/mmdet/evaluation/metrics/dump_det_results.py new file mode 100644 index 0000000000000000000000000000000000000000..f3071d19a6ad0199458d13dfe6f570f181a5ea7f --- /dev/null +++ b/mmdet/evaluation/metrics/dump_det_results.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Sequence + +from mmengine.evaluator import DumpResults +from mmengine.evaluator.metric import _to_cpu + +from mmdet.registry import METRICS +from mmdet.structures.mask import encode_mask_results + + +@METRICS.register_module() +class DumpDetResults(DumpResults): + """Dump model predictions to a pickle file for offline evaluation. + + Different from `DumpResults` in MMEngine, it compresses instance + segmentation masks into RLE format. + + Args: + out_file_path (str): Path of the dumped file. Must end with '.pkl' + or '.pickle'. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + """ + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """transfer tensors in predictions to CPU.""" + data_samples = _to_cpu(data_samples) + for data_sample in data_samples: + # remove gt + data_sample.pop('gt_instances', None) + data_sample.pop('ignored_instances', None) + data_sample.pop('gt_panoptic_seg', None) + + if 'pred_instances' in data_sample: + pred = data_sample['pred_instances'] + # encode mask to RLE + if 'masks' in pred: + pred['masks'] = encode_mask_results(pred['masks'].numpy()) + if 'pred_panoptic_seg' in data_sample: + warnings.warn( + 'Panoptic segmentation map will not be compressed. ' + 'The dumped file will be extremely large! ' + 'Suggest using `CocoPanopticMetric` to save the coco ' + 'format json and segmentation png files directly.') + self.results.extend(data_samples) diff --git a/mmdet/evaluation/metrics/dump_proposals_metric.py b/mmdet/evaluation/metrics/dump_proposals_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..9e9c53654c15d4b1f7e6555a9a7c53f844cb071f --- /dev/null +++ b/mmdet/evaluation/metrics/dump_proposals_metric.py @@ -0,0 +1,119 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
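+# Editor's note -- a minimal usage sketch (hypothetical paths; the
+# argument names follow the class defined below):
+#
+#     test_evaluator = dict(
+#         type='DumpProposals',
+#         output_dir='work_dirs/proposals',
+#         proposals_file='rpn_proposals.pkl',
+#         num_max_proposals=1000)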
+import os +import os.path as osp +from typing import Optional, Sequence + +from mmengine.dist import is_main_process +from mmengine.evaluator import BaseMetric +from mmengine.fileio import dump +from mmengine.logging import MMLogger +from mmengine.structures import InstanceData + +from mmdet.registry import METRICS + + +@METRICS.register_module() +class DumpProposals(BaseMetric): + """Dump proposals pseudo metric. + + Args: + output_dir (str): The root directory for ``proposals_file``. + Defaults to ''. + proposals_file (str): Proposals file path. Defaults to 'proposals.pkl'. + num_max_proposals (int, optional): Maximum number of proposals to dump. + If not specified, all proposals will be dumped. + file_client_args (dict, optional): Arguments to instantiate the + corresponding backend in mmdet <= 3.0.0rc6. Defaults to None. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + """ + + default_prefix: Optional[str] = 'dump_proposals' + + def __init__(self, + output_dir: str = '', + proposals_file: str = 'proposals.pkl', + num_max_proposals: Optional[int] = None, + file_client_args: dict = None, + backend_args: dict = None, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + self.num_max_proposals = num_max_proposals + # TODO: update after mmengine finish refactor fileio. + self.backend_args = backend_args + if file_client_args is not None: + raise RuntimeError( + 'The `file_client_args` is deprecated, ' + 'please use `backend_args` instead, please refer to' + 'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501 + ) + self.output_dir = output_dir + assert proposals_file.endswith(('.pkl', '.pickle')), \ + 'The output file must be a pkl file.' + + self.proposals_file = os.path.join(self.output_dir, proposals_file) + if is_main_process(): + os.makedirs(self.output_dir, exist_ok=True) + + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of data samples that + contain annotations and predictions. + """ + for data_sample in data_samples: + pred = data_sample['pred_instances'] + # `bboxes` is sorted by `scores` + ranked_scores, rank_inds = pred['scores'].sort(descending=True) + ranked_bboxes = pred['bboxes'][rank_inds, :] + + ranked_bboxes = ranked_bboxes.cpu().numpy() + ranked_scores = ranked_scores.cpu().numpy() + + pred_instance = InstanceData() + pred_instance.bboxes = ranked_bboxes + pred_instance.scores = ranked_scores + if self.num_max_proposals is not None: + pred_instance = pred_instance[:self.num_max_proposals] + + img_path = data_sample['img_path'] + # `file_name` is the key to obtain the proposals from the + # `proposals_list`. 
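+            # e.g. (illustrative) 'data/coco/val2017/000001.jpg'
+            # -> 'val2017/000001.jpg': the last directory level joined
+            # with the file name.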
+ file_name = osp.join( + osp.split(osp.split(img_path)[0])[-1], + osp.split(img_path)[-1]) + result = {file_name: pred_instance} + self.results.append(result) + + def compute_metrics(self, results: list) -> dict: + """Dump the processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + dict: An empty dict. + """ + logger: MMLogger = MMLogger.get_current_instance() + dump_results = {} + for result in results: + dump_results.update(result) + dump( + dump_results, + file=self.proposals_file, + backend_args=self.backend_args) + logger.info(f'Results are saved at {self.proposals_file}') + return {} diff --git a/mmdet/evaluation/metrics/lvis_metric.py b/mmdet/evaluation/metrics/lvis_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..e4dd6141c0e3f94758a040fd2e2a72ea43ea9b63 --- /dev/null +++ b/mmdet/evaluation/metrics/lvis_metric.py @@ -0,0 +1,364 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import itertools +import os.path as osp +import tempfile +import warnings +from collections import OrderedDict +from typing import Dict, List, Optional, Sequence, Union + +import numpy as np +from mmengine.fileio import get_local_path +from mmengine.logging import MMLogger +from terminaltables import AsciiTable + +from mmdet.registry import METRICS +from mmdet.structures.mask import encode_mask_results +from ..functional import eval_recalls +from .coco_metric import CocoMetric + +try: + import lvis + if getattr(lvis, '__version__', '0') >= '10.5.3': + warnings.warn( + 'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', # noqa: E501 + UserWarning) + from lvis import LVIS, LVISEval, LVISResults +except ImportError: + lvis = None + LVISEval = None + LVISResults = None + + +@METRICS.register_module() +class LVISMetric(CocoMetric): + """LVIS evaluation metric. + + Args: + ann_file (str, optional): Path to the coco format annotation file. + If not specified, ground truth annotations from the dataset will + be converted to coco format. Defaults to None. + metric (str | List[str]): Metrics to be evaluated. Valid metrics + include 'bbox', 'segm', 'proposal', and 'proposal_fast'. + Defaults to 'bbox'. + classwise (bool): Whether to evaluate the metric class-wise. + Defaults to False. + proposal_nums (Sequence[int]): Numbers of proposals to be evaluated. + Defaults to (100, 300, 1000). + iou_thrs (float | List[float], optional): IoU threshold to compute AP + and AR. If not specified, IoUs from 0.5 to 0.95 will be used. + Defaults to None. + metric_items (List[str], optional): Metric result names to be + recorded in the evaluation result. Defaults to None. + format_only (bool): Format the output results without perform + evaluation. It is useful when you want to format the result + to a specific format and submit it to the test server. + Defaults to False. + outfile_prefix (str, optional): The prefix of json files. It includes + the file path and the prefix of filename, e.g., "a/b/prefix". + If not specified, a temp file will be created. Defaults to None. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. 
+ file_client_args (dict, optional): Arguments to instantiate the + corresponding backend in mmdet <= 3.0.0rc6. Defaults to None. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + """ + + default_prefix: Optional[str] = 'lvis' + + def __init__(self, + ann_file: Optional[str] = None, + metric: Union[str, List[str]] = 'bbox', + classwise: bool = False, + proposal_nums: Sequence[int] = (100, 300, 1000), + iou_thrs: Optional[Union[float, Sequence[float]]] = None, + metric_items: Optional[Sequence[str]] = None, + format_only: bool = False, + outfile_prefix: Optional[str] = None, + collect_device: str = 'cpu', + prefix: Optional[str] = None, + file_client_args: dict = None, + backend_args: dict = None) -> None: + if lvis is None: + raise RuntimeError( + 'Package lvis is not installed. Please run "pip install ' + 'git+https://github.com/lvis-dataset/lvis-api.git".') + super().__init__(collect_device=collect_device, prefix=prefix) + # coco evaluation metrics + self.metrics = metric if isinstance(metric, list) else [metric] + allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast'] + for metric in self.metrics: + if metric not in allowed_metrics: + raise KeyError( + "metric should be one of 'bbox', 'segm', 'proposal', " + f"'proposal_fast', but got {metric}.") + + # do class wise evaluation, default False + self.classwise = classwise + + # proposal_nums used to compute recall or precision. + self.proposal_nums = list(proposal_nums) + + # iou_thrs used to compute recall or precision. + if iou_thrs is None: + iou_thrs = np.linspace( + .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True) + self.iou_thrs = iou_thrs + self.metric_items = metric_items + self.format_only = format_only + if self.format_only: + assert outfile_prefix is not None, 'outfile_prefix must be not' + 'None when format_only is True, otherwise the result files will' + 'be saved to a temp directory which will be cleaned up at the end.' + + self.outfile_prefix = outfile_prefix + self.backend_args = backend_args + if file_client_args is not None: + raise RuntimeError( + 'The `file_client_args` is deprecated, ' + 'please use `backend_args` instead, please refer to' + 'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501 + ) + + # if ann_file is not specified, + # initialize lvis api with the converted dataset + if ann_file is not None: + with get_local_path( + ann_file, backend_args=self.backend_args) as local_path: + self._lvis_api = LVIS(local_path) + else: + self._lvis_api = None + + # handle dataset lazy init + self.cat_ids = None + self.img_ids = None + + def fast_eval_recall(self, + results: List[dict], + proposal_nums: Sequence[int], + iou_thrs: Sequence[float], + logger: Optional[MMLogger] = None) -> np.ndarray: + """Evaluate proposal recall with LVIS's fast_eval_recall. + + Args: + results (List[dict]): Results of the dataset. + proposal_nums (Sequence[int]): Proposal numbers used for + evaluation. + iou_thrs (Sequence[float]): IoU thresholds used for evaluation. + logger (MMLogger, optional): Logger used for logging the recall + summary. + Returns: + np.ndarray: Averaged recall results. 
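+
+        Note (added for clarity): the returned array has one entry per
+        element of ``proposal_nums``, each averaged over all IoU
+        thresholds (``recalls.mean(axis=1)`` below).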
+ """ + gt_bboxes = [] + pred_bboxes = [result['bboxes'] for result in results] + for i in range(len(self.img_ids)): + ann_ids = self._lvis_api.get_ann_ids(img_ids=[self.img_ids[i]]) + ann_info = self._lvis_api.load_anns(ann_ids) + if len(ann_info) == 0: + gt_bboxes.append(np.zeros((0, 4))) + continue + bboxes = [] + for ann in ann_info: + x1, y1, w, h = ann['bbox'] + bboxes.append([x1, y1, x1 + w, y1 + h]) + bboxes = np.array(bboxes, dtype=np.float32) + if bboxes.shape[0] == 0: + bboxes = np.zeros((0, 4)) + gt_bboxes.append(bboxes) + + recalls = eval_recalls( + gt_bboxes, pred_bboxes, proposal_nums, iou_thrs, logger=logger) + ar = recalls.mean(axis=1) + return ar + + # TODO: data_batch is no longer needed, consider adjusting the + # parameter position + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of data samples that + contain annotations and predictions. + """ + for data_sample in data_samples: + result = dict() + pred = data_sample['pred_instances'] + result['img_id'] = data_sample['img_id'] + result['bboxes'] = pred['bboxes'].cpu().numpy() + result['scores'] = pred['scores'].cpu().numpy() + result['labels'] = pred['labels'].cpu().numpy() + # encode mask to RLE + if 'masks' in pred: + result['masks'] = encode_mask_results( + pred['masks'].detach().cpu().numpy()) + # some detectors use different scores for bbox and mask + if 'mask_scores' in pred: + result['mask_scores'] = pred['mask_scores'].cpu().numpy() + + # parse gt + gt = dict() + gt['width'] = data_sample['ori_shape'][1] + gt['height'] = data_sample['ori_shape'][0] + gt['img_id'] = data_sample['img_id'] + if self._lvis_api is None: + # TODO: Need to refactor to support LoadAnnotations + assert 'instances' in data_sample, \ + 'ground truth is required for evaluation when ' \ + '`ann_file` is not provided' + gt['anns'] = data_sample['instances'] + # add converted result to the results list + self.results.append((gt, result)) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. 
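+
+        Example (editor's illustrative sketch; the values are made up,
+        the key pattern ``{metric}_{name}`` follows the implementation
+        below)::
+
+            {'bbox_AP': 0.253, 'bbox_AP50': 0.412, 'bbox_APr': 0.101}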
+ """ + logger: MMLogger = MMLogger.get_current_instance() + + # split gt and prediction list + gts, preds = zip(*results) + + tmp_dir = None + if self.outfile_prefix is None: + tmp_dir = tempfile.TemporaryDirectory() + outfile_prefix = osp.join(tmp_dir.name, 'results') + else: + outfile_prefix = self.outfile_prefix + + if self._lvis_api is None: + # use converted gt json file to initialize coco api + logger.info('Converting ground truth to coco format...') + coco_json_path = self.gt_to_coco_json( + gt_dicts=gts, outfile_prefix=outfile_prefix) + self._lvis_api = LVIS(coco_json_path) + + # handle lazy init + if self.cat_ids is None: + self.cat_ids = self._lvis_api.get_cat_ids() + if self.img_ids is None: + self.img_ids = self._lvis_api.get_img_ids() + + # convert predictions to coco format and dump to json file + result_files = self.results2json(preds, outfile_prefix) + + eval_results = OrderedDict() + if self.format_only: + logger.info('results are saved in ' + f'{osp.dirname(outfile_prefix)}') + return eval_results + + lvis_gt = self._lvis_api + + for metric in self.metrics: + logger.info(f'Evaluating {metric}...') + + # TODO: May refactor fast_eval_recall to an independent metric? + # fast eval recall + if metric == 'proposal_fast': + ar = self.fast_eval_recall( + preds, self.proposal_nums, self.iou_thrs, logger=logger) + log_msg = [] + for i, num in enumerate(self.proposal_nums): + eval_results[f'AR@{num}'] = ar[i] + log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}') + log_msg = ''.join(log_msg) + logger.info(log_msg) + continue + + try: + lvis_dt = LVISResults(lvis_gt, result_files[metric]) + except IndexError: + logger.info( + 'The testing results of the whole dataset is empty.') + break + + iou_type = 'bbox' if metric == 'proposal' else metric + lvis_eval = LVISEval(lvis_gt, lvis_dt, iou_type) + lvis_eval.params.imgIds = self.img_ids + metric_items = self.metric_items + if metric == 'proposal': + lvis_eval.params.useCats = 0 + lvis_eval.params.maxDets = list(self.proposal_nums) + lvis_eval.evaluate() + lvis_eval.accumulate() + lvis_eval.summarize() + if metric_items is None: + metric_items = ['AR@300', 'ARs@300', 'ARm@300', 'ARl@300'] + for k, v in lvis_eval.get_results().items(): + if k in metric_items: + val = float('{:.3f}'.format(float(v))) + eval_results[k] = val + + else: + lvis_eval.evaluate() + lvis_eval.accumulate() + lvis_eval.summarize() + lvis_results = lvis_eval.get_results() + if self.classwise: # Compute per-category AP + # Compute per-category AP + # from https://github.com/facebookresearch/detectron2/ + precisions = lvis_eval.eval['precision'] + # precision: (iou, recall, cls, area range, max dets) + assert len(self.cat_ids) == precisions.shape[2] + + results_per_category = [] + for idx, catId in enumerate(self.cat_ids): + # area range index 0: all area ranges + # max dets index -1: typically 100 per image + # the dimensions of precisions are + # [num_thrs, num_recalls, num_cats, num_area_rngs] + nm = self._lvis_api.load_cats([catId])[0] + precision = precisions[:, :, idx, 0] + precision = precision[precision > -1] + if precision.size: + ap = np.mean(precision) + else: + ap = float('nan') + results_per_category.append( + (f'{nm["name"]}', f'{float(ap):0.3f}')) + eval_results[f'{nm["name"]}_precision'] = round(ap, 3) + + num_columns = min(6, len(results_per_category) * 2) + results_flatten = list( + itertools.chain(*results_per_category)) + headers = ['category', 'AP'] * (num_columns // 2) + results_2d = itertools.zip_longest(*[ + results_flatten[i::num_columns] + for i in 
range(num_columns) + ]) + table_data = [headers] + table_data += [result for result in results_2d] + table = AsciiTable(table_data) + logger.info('\n' + table.table) + + if metric_items is None: + metric_items = [ + 'AP', 'AP50', 'AP75', 'APs', 'APm', 'APl', 'APr', + 'APc', 'APf' + ] + + for k, v in lvis_results.items(): + if k in metric_items: + key = '{}_{}'.format(metric, k) + val = float('{:.3f}'.format(float(v))) + eval_results[key] = val + + lvis_eval.print_results() + if tmp_dir is not None: + tmp_dir.cleanup() + return eval_results diff --git a/mmdet/evaluation/metrics/mot_challenge_metric.py b/mmdet/evaluation/metrics/mot_challenge_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..a5513c44e81de7dd869d4c5c802bfac0387bdbf6 --- /dev/null +++ b/mmdet/evaluation/metrics/mot_challenge_metric.py @@ -0,0 +1,443 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import shutil +import tempfile +from collections import defaultdict +from typing import List, Optional, Union + +import numpy as np +import torch + +try: + import trackeval +except ImportError: + trackeval = None +from mmengine.dist import (all_gather_object, barrier, broadcast, + broadcast_object_list, get_dist_info, + is_main_process) +from mmengine.logging import MMLogger + +from mmdet.registry import METRICS, TASK_UTILS +from .base_video_metric import BaseVideoMetric + + +def get_tmpdir() -> str: + """return the same tmpdir for all processes.""" + rank, world_size = get_dist_info() + MAX_LEN = 512 + # 32 is whitespace + dir_tensor = torch.full((MAX_LEN, ), 32, dtype=torch.uint8) + if rank == 0: + tmpdir = tempfile.mkdtemp() + tmpdir = torch.tensor(bytearray(tmpdir.encode()), dtype=torch.uint8) + dir_tensor[:len(tmpdir)] = tmpdir + broadcast(dir_tensor, 0) + tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() + return tmpdir + + +@METRICS.register_module() +class MOTChallengeMetric(BaseVideoMetric): + """Evaluation metrics for MOT Challenge. + + Args: + metric (str | list[str]): Metrics to be evaluated. Options are + 'HOTA', 'CLEAR', 'Identity'. + Defaults to ['HOTA', 'CLEAR', 'Identity']. + outfile_prefix (str, optional): Path to save the formatted results. + Defaults to None. + track_iou_thr (float): IoU threshold for tracking evaluation. + Defaults to 0.5. + benchmark (str): Benchmark to be evaluated. Defaults to 'MOT17'. + format_only (bool): If True, only formatting the results to the + official format and not performing evaluation. Defaults to False. + postprocess_tracklet_cfg (List[dict], optional): configs for tracklets + postprocessing methods. `InterpolateTracklets` is supported. + Defaults to [] + - InterpolateTracklets: + - min_num_frames (int, optional): The minimum length of a + track that will be interpolated. Defaults to 5. + - max_num_frames (int, optional): The maximum disconnected + length in a track. Defaults to 20. + - use_gsi (bool, optional): Whether to use the GSI (Gaussian- + smoothed interpolation) method. Defaults to False. + - smooth_tau (int, optional): smoothing parameter in GSI. + Defaults to 10. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. 
Default: None
+    """
+    TRACKER = 'default-tracker'
+    allowed_metrics = ['HOTA', 'CLEAR', 'Identity']
+    allowed_benchmarks = ['MOT15', 'MOT16', 'MOT17', 'MOT20', 'DanceTrack']
+    default_prefix: Optional[str] = 'motchallenge-metric'
+
+    def __init__(self,
+                 metric: Union[str, List[str]] = ['HOTA', 'CLEAR', 'Identity'],
+                 outfile_prefix: Optional[str] = None,
+                 track_iou_thr: float = 0.5,
+                 benchmark: str = 'MOT17',
+                 format_only: bool = False,
+                 use_postprocess: bool = False,
+                 postprocess_tracklet_cfg: Optional[List[dict]] = [],
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None) -> None:
+        super().__init__(collect_device=collect_device, prefix=prefix)
+        if trackeval is None:
+            raise RuntimeError(
+                'trackeval is not installed, please install it by: '
+                'pip install '
+                'git+https://github.com/JonathonLuiten/TrackEval.git. '
+                'trackeval needs an older numpy version, please install '
+                'it by: pip install -U numpy==1.23.5')
+        if isinstance(metric, list):
+            metrics = metric
+        elif isinstance(metric, str):
+            metrics = [metric]
+        else:
+            raise TypeError('metric must be a list or a str.')
+        for metric in metrics:
+            if metric not in self.allowed_metrics:
+                raise KeyError(f'metric {metric} is not supported.')
+        self.metrics = metrics
+        self.format_only = format_only
+        if self.format_only:
+            assert outfile_prefix is not None, \
+                'outfile_prefix must not be None when format_only is ' \
+                'True, otherwise the result files will be saved to a ' \
+                'temp directory which will be cleaned up at the end.'
+        self.use_postprocess = use_postprocess
+        self.postprocess_tracklet_cfg = postprocess_tracklet_cfg.copy()
+        self.postprocess_tracklet_methods = [
+            TASK_UTILS.build(cfg) for cfg in self.postprocess_tracklet_cfg
+        ]
+        assert benchmark in self.allowed_benchmarks
+        self.benchmark = benchmark
+        self.track_iou_thr = track_iou_thr
+        self.tmp_dir = tempfile.TemporaryDirectory()
+        self.tmp_dir.name = get_tmpdir()
+        self.seq_info = defaultdict(
+            lambda: dict(seq_length=-1, gt_tracks=[], pred_tracks=[]))
+        self.gt_dir = self._get_gt_dir()
+        self.pred_dir = self._get_pred_dir(outfile_prefix)
+        self.seqmap = osp.join(self.pred_dir, 'videoseq.txt')
+        with open(self.seqmap, 'w') as f:
+            f.write('name\n')
+
+    def __del__(self):
+        # Do not clean up tmpdir too early: across multiple consecutive
+        # ValLoops the value of `self.tmp_dir.name` stays the same, so
+        # calling `tmp_dir.cleanup()` in compute_metrics would fail.
+ self.tmp_dir.cleanup() + + def _get_pred_dir(self, outfile_prefix): + """Get directory to save the prediction results.""" + logger: MMLogger = MMLogger.get_current_instance() + + if outfile_prefix is None: + outfile_prefix = self.tmp_dir.name + else: + if osp.exists(outfile_prefix) and is_main_process(): + logger.info('remove previous results.') + shutil.rmtree(outfile_prefix) + pred_dir = osp.join(outfile_prefix, self.TRACKER) + os.makedirs(pred_dir, exist_ok=True) + return pred_dir + + def _get_gt_dir(self): + """Get directory to save the gt files.""" + output_dir = osp.join(self.tmp_dir.name, 'gt') + os.makedirs(output_dir, exist_ok=True) + return output_dir + + def transform_gt_and_pred(self, img_data_sample, video, frame_id): + + video = img_data_sample['img_path'].split(os.sep)[-3] + # load gts + if 'instances' in img_data_sample: + gt_instances = img_data_sample['instances'] + gt_tracks = [ + np.array([ + frame_id + 1, gt_instances[i]['instance_id'], + gt_instances[i]['bbox'][0], gt_instances[i]['bbox'][1], + gt_instances[i]['bbox'][2] - gt_instances[i]['bbox'][0], + gt_instances[i]['bbox'][3] - gt_instances[i]['bbox'][1], + gt_instances[i]['mot_conf'], + gt_instances[i]['category_id'], + gt_instances[i]['visibility'] + ]) for i in range(len(gt_instances)) + ] + self.seq_info[video]['gt_tracks'].extend(gt_tracks) + + # load predictions + assert 'pred_track_instances' in img_data_sample + if self.use_postprocess: + pred_instances = img_data_sample['pred_track_instances'] + pred_tracks = [ + pred_instances['bboxes'][i] + for i in range(len(pred_instances['bboxes'])) + ] + else: + pred_instances = img_data_sample['pred_track_instances'] + pred_tracks = [ + np.array([ + frame_id + 1, pred_instances['instances_id'][i].cpu(), + pred_instances['bboxes'][i][0].cpu(), + pred_instances['bboxes'][i][1].cpu(), + (pred_instances['bboxes'][i][2] - + pred_instances['bboxes'][i][0]).cpu(), + (pred_instances['bboxes'][i][3] - + pred_instances['bboxes'][i][1]).cpu(), + pred_instances['scores'][i].cpu() + ]) for i in range(len(pred_instances['instances_id'])) + ] + self.seq_info[video]['pred_tracks'].extend(pred_tracks) + + def process_image(self, data_samples, video_len): + + img_data_sample = data_samples[0].to_dict() + video = img_data_sample['img_path'].split(os.sep)[-3] + frame_id = img_data_sample['frame_id'] + if self.seq_info[video]['seq_length'] == -1: + self.seq_info[video]['seq_length'] = video_len + self.transform_gt_and_pred(img_data_sample, video, frame_id) + + if frame_id == video_len - 1: + # postprocessing + if self.postprocess_tracklet_cfg: + info = self.seq_info[video] + pred_tracks = np.array(info['pred_tracks']) + for postprocess_tracklet_methods in \ + self.postprocess_tracklet_methods: + pred_tracks = postprocess_tracklet_methods\ + .forward(pred_tracks) + info['pred_tracks'] = pred_tracks + self._save_one_video_gts_preds(video) + + def process_video(self, data_samples): + + video_len = len(data_samples) + for frame_id in range(video_len): + img_data_sample = data_samples[frame_id].to_dict() + # load basic info + video = img_data_sample['img_path'].split(os.sep)[-3] + if self.seq_info[video]['seq_length'] == -1: + self.seq_info[video]['seq_length'] = video_len + self.transform_gt_and_pred(img_data_sample, video, frame_id) + + if self.postprocess_tracklet_cfg: + info = self.seq_info[video] + pred_tracks = np.array(info['pred_tracks']) + for postprocess_tracklet_methods in \ + self.postprocess_tracklet_methods: + pred_tracks = postprocess_tracklet_methods \ + 
.forward(pred_tracks) + info['pred_tracks'] = pred_tracks + self._save_one_video_gts_preds(video) + + def _save_one_video_gts_preds(self, seq: str) -> None: + """Save the gt and prediction results.""" + info = self.seq_info[seq] + # save predictions + pred_file = osp.join(self.pred_dir, seq + '.txt') + + pred_tracks = np.array(info['pred_tracks']) + + with open(pred_file, 'wt') as f: + for tracks in pred_tracks: + line = '%d,%d,%.3f,%.3f,%.3f,%.3f,%.3f,-1,-1,-1\n' % ( + tracks[0], tracks[1], tracks[2], tracks[3], tracks[4], + tracks[5], tracks[6]) + f.writelines(line) + + info['pred_tracks'] = [] + # save gts + if info['gt_tracks']: + gt_file = osp.join(self.gt_dir, seq + '.txt') + with open(gt_file, 'wt') as f: + for tracks in info['gt_tracks']: + line = '%d,%d,%d,%d,%d,%d,%d,%d,%.5f\n' % ( + tracks[0], tracks[1], tracks[2], tracks[3], tracks[4], + tracks[5], tracks[6], tracks[7], tracks[8]) + f.writelines(line) + info['gt_tracks'].clear() + # save seq info + with open(self.seqmap, 'a') as f: + f.write(seq + '\n') + f.close() + + def compute_metrics(self, results: list = None) -> dict: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + Defaults to None. + + Returns: + dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + logger: MMLogger = MMLogger.get_current_instance() + + # NOTICE: don't access `self.results` from the method. + eval_results = dict() + + if self.format_only: + return eval_results + + eval_config = trackeval.Evaluator.get_default_eval_config() + + # need to split out the tracker name + # caused by the implementation of TrackEval + pred_dir_tmp = self.pred_dir.rsplit(osp.sep, 1)[0] + dataset_config = self.get_dataset_cfg(self.gt_dir, pred_dir_tmp) + + evaluator = trackeval.Evaluator(eval_config) + dataset = [trackeval.datasets.MotChallenge2DBox(dataset_config)] + metrics = [ + getattr(trackeval.metrics, + metric)(dict(METRICS=[metric], THRESHOLD=0.5)) + for metric in self.metrics + ] + output_res, _ = evaluator.evaluate(dataset, metrics) + output_res = output_res['MotChallenge2DBox'][ + self.TRACKER]['COMBINED_SEQ']['pedestrian'] + + if 'HOTA' in self.metrics: + logger.info('Evaluating HOTA Metrics...') + eval_results['HOTA'] = np.average(output_res['HOTA']['HOTA']) + eval_results['AssA'] = np.average(output_res['HOTA']['AssA']) + eval_results['DetA'] = np.average(output_res['HOTA']['DetA']) + + if 'CLEAR' in self.metrics: + logger.info('Evaluating CLEAR Metrics...') + eval_results['MOTA'] = np.average(output_res['CLEAR']['MOTA']) + eval_results['MOTP'] = np.average(output_res['CLEAR']['MOTP']) + eval_results['IDSW'] = np.average(output_res['CLEAR']['IDSW']) + eval_results['TP'] = np.average(output_res['CLEAR']['CLR_TP']) + eval_results['FP'] = np.average(output_res['CLEAR']['CLR_FP']) + eval_results['FN'] = np.average(output_res['CLEAR']['CLR_FN']) + eval_results['Frag'] = np.average(output_res['CLEAR']['Frag']) + eval_results['MT'] = np.average(output_res['CLEAR']['MT']) + eval_results['ML'] = np.average(output_res['CLEAR']['ML']) + + if 'Identity' in self.metrics: + logger.info('Evaluating Identity Metrics...') + eval_results['IDF1'] = np.average(output_res['Identity']['IDF1']) + eval_results['IDTP'] = np.average(output_res['Identity']['IDTP']) + eval_results['IDFN'] = np.average(output_res['Identity']['IDFN']) + eval_results['IDFP'] = np.average(output_res['Identity']['IDFP']) + eval_results['IDP'] = 
np.average(output_res['Identity']['IDP']) + eval_results['IDR'] = np.average(output_res['Identity']['IDR']) + + return eval_results + + def evaluate(self, size: int = 1) -> dict: + """Evaluate the model performance of the whole dataset after processing + all batches. + + Args: + size (int): Length of the entire validation dataset. + Defaults to None. + + Returns: + dict: Evaluation metrics dict on the val dataset. The keys are the + names of the metrics, and the values are corresponding results. + """ + # wait for all processes to complete prediction. + barrier() + + # gather seq_info and convert the list of dict to a dict. + # convert self.seq_info to dict first to make it picklable. + gathered_seq_info = all_gather_object(dict(self.seq_info)) + all_seq_info = dict() + for _seq_info in gathered_seq_info: + all_seq_info.update(_seq_info) + self.seq_info = all_seq_info + + if is_main_process(): + _metrics = self.compute_metrics() # type: ignore + # Add prefix to metric names + if self.prefix: + _metrics = { + '/'.join((self.prefix, k)): v + for k, v in _metrics.items() + } + metrics = [_metrics] + else: + metrics = [None] # type: ignore + + broadcast_object_list(metrics) + + # reset the results list + self.results.clear() + return metrics[0] + + def get_dataset_cfg(self, gt_folder: str, tracker_folder: str): + """Get default configs for trackeval.datasets.MotChallenge2DBox. + + Args: + gt_folder (str): the name of the GT folder + tracker_folder (str): the name of the tracker folder + + Returns: + Dataset Configs for MotChallenge2DBox. + """ + dataset_config = dict( + # Location of GT data + GT_FOLDER=gt_folder, + # Trackers location + TRACKERS_FOLDER=tracker_folder, + # Where to save eval results + # (if None, same as TRACKERS_FOLDER) + OUTPUT_FOLDER=None, + # Use self.TRACKER as the default tracker + TRACKERS_TO_EVAL=[self.TRACKER], + # Option values: ['pedestrian'] + CLASSES_TO_EVAL=['pedestrian'], + # Option Values: 'MOT15', 'MOT16', 'MOT17', 'MOT20', 'DanceTrack' + BENCHMARK=self.benchmark, + # Option Values: 'train', 'test' + SPLIT_TO_EVAL='val' if self.benchmark == 'DanceTrack' else 'train', + # Whether tracker input files are zipped + INPUT_AS_ZIP=False, + # Whether to print current config + PRINT_CONFIG=True, + # Whether to perform preprocessing + # (never done for MOT15) + DO_PREPROC=False if self.benchmark == 'MOT15' else True, + # Tracker files are in + # TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER + TRACKER_SUB_FOLDER='', + # Output files are saved in + # OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER + OUTPUT_SUB_FOLDER='', + # Names of trackers to display + # (if None: TRACKERS_TO_EVAL) + TRACKER_DISPLAY_NAMES=None, + # Where seqmaps are found + # (if None: GT_FOLDER/seqmaps) + SEQMAP_FOLDER=None, + # Directly specify seqmap file + # (if none use seqmap_folder/benchmark-split_to_eval) + SEQMAP_FILE=self.seqmap, + # If not None, specify sequences to eval + # and their number of timesteps + SEQ_INFO={ + seq: info['seq_length'] + for seq, info in self.seq_info.items() + }, + # '{gt_folder}/{seq}.txt' + GT_LOC_FORMAT='{gt_folder}/{seq}.txt', + # If False, data is in GT_FOLDER/BENCHMARK-SPLIT_TO_EVAL/ and in + # TRACKERS_FOLDER/BENCHMARK-SPLIT_TO_EVAL/tracker/ + # If True, the middle 'benchmark-split' folder is skipped for both. 
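+            # e.g. (illustrative) the resulting layout is
+            # GT_FOLDER/<seq>.txt and
+            # TRACKERS_FOLDER/<tracker>/<seq>.txt, matching the files
+            # written by `_save_one_video_gts_preds` above.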
+ SKIP_SPLIT_FOL=True, + ) + + return dataset_config diff --git a/mmdet/evaluation/metrics/openimages_metric.py b/mmdet/evaluation/metrics/openimages_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..d75c59e0e711c90bb1e5fbcc1529e95864e99e9a --- /dev/null +++ b/mmdet/evaluation/metrics/openimages_metric.py @@ -0,0 +1,237 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from collections import OrderedDict +from typing import List, Optional, Sequence, Union + +import numpy as np +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger, print_log + +from mmdet.registry import METRICS +from ..functional import eval_map + + +@METRICS.register_module() +class OpenImagesMetric(BaseMetric): + """OpenImages evaluation metric. + + Evaluate detection mAP for OpenImages. Please refer to + https://storage.googleapis.com/openimages/web/evaluation.html for more + details. + + Args: + iou_thrs (float or List[float]): IoU threshold. Defaults to 0.5. + ioa_thrs (float or List[float]): IoA threshold. Defaults to 0.5. + scale_ranges (List[tuple], optional): Scale ranges for evaluating + mAP. If not specified, all bounding boxes would be included in + evaluation. Defaults to None + use_group_of (bool): Whether consider group of groud truth bboxes + during evaluating. Defaults to True. + get_supercategory (bool): Whether to get parent class of the + current class. Default: True. + filter_labels (bool): Whether filter unannotated classes. + Default: True. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + """ + default_prefix: Optional[str] = 'openimages' + + def __init__(self, + iou_thrs: Union[float, List[float]] = 0.5, + ioa_thrs: Union[float, List[float]] = 0.5, + scale_ranges: Optional[List[tuple]] = None, + use_group_of: bool = True, + get_supercategory: bool = True, + filter_labels: bool = True, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + self.iou_thrs = [iou_thrs] if isinstance(iou_thrs, float) else iou_thrs + self.ioa_thrs = [ioa_thrs] if (isinstance(ioa_thrs, float) + or ioa_thrs is None) else ioa_thrs + assert isinstance(self.iou_thrs, list) and isinstance( + self.ioa_thrs, list) + assert len(self.iou_thrs) == len(self.ioa_thrs) + + self.scale_ranges = scale_ranges + self.use_group_of = use_group_of + self.get_supercategory = get_supercategory + self.filter_labels = filter_labels + + def _get_supercategory_ann(self, instances: List[dict]) -> List[dict]: + """Get parent classes's annotation of the corresponding class. + + Args: + instances (List[dict]): A list of annotations of the instances. + + Returns: + List[dict]: Annotations extended with super-category. 
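+
+        Example (hypothetical labels for illustration): if the
+        dataset's ``RELATION_MATRIX`` marks ``animal`` as a parent of
+        ``dog``, every ``dog`` instance additionally yields a copy of
+        itself labelled ``animal``.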
+ """ + supercat_instances = [] + relation_matrix = self.dataset_meta['RELATION_MATRIX'] + for instance in instances: + labels = np.where(relation_matrix[instance['bbox_label']])[0] + for label in labels: + if label == instance['bbox_label']: + continue + new_instance = copy.deepcopy(instance) + new_instance['bbox_label'] = label + supercat_instances.append(new_instance) + return supercat_instances + + def _process_predictions(self, pred_bboxes: np.ndarray, + pred_scores: np.ndarray, pred_labels: np.ndarray, + gt_instances: list, + image_level_labels: np.ndarray) -> tuple: + """Process results of the corresponding class of the detection bboxes. + + Note: It will choose to do the following two processing according to + the parameters: + + 1. Whether to add parent classes of the corresponding class of the + detection bboxes. + + 2. Whether to ignore the classes that unannotated on that image. + + Args: + pred_bboxes (np.ndarray): bboxes predicted by the model + pred_scores (np.ndarray): scores predicted by the model + pred_labels (np.ndarray): labels predicted by the model + gt_instances (list): ground truth annotations + image_level_labels (np.ndarray): human-verified image level labels + + Returns: + tuple: Processed bboxes, scores, and labels. + """ + processed_bboxes = copy.deepcopy(pred_bboxes) + processed_scores = copy.deepcopy(pred_scores) + processed_labels = copy.deepcopy(pred_labels) + gt_labels = np.array([ins['bbox_label'] for ins in gt_instances], + dtype=np.int64) + if image_level_labels is not None: + allowed_classes = np.unique( + np.append(gt_labels, image_level_labels)) + else: + allowed_classes = np.unique(gt_labels) + relation_matrix = self.dataset_meta['RELATION_MATRIX'] + pred_classes = np.unique(pred_labels) + for pred_class in pred_classes: + classes = np.where(relation_matrix[pred_class])[0] + for cls in classes: + if (cls in allowed_classes and cls != pred_class + and self.get_supercategory): + # add super-supercategory preds + index = np.where(pred_labels == pred_class)[0] + processed_scores = np.concatenate( + [processed_scores, pred_scores[index]]) + processed_bboxes = np.concatenate( + [processed_bboxes, pred_bboxes[index]]) + extend_labels = np.full(index.shape, cls, dtype=np.int64) + processed_labels = np.concatenate( + [processed_labels, extend_labels]) + elif cls not in allowed_classes and self.filter_labels: + # remove unannotated preds + index = np.where(processed_labels != cls)[0] + processed_scores = processed_scores[index] + processed_bboxes = processed_bboxes[index] + processed_labels = processed_labels[index] + return processed_bboxes, processed_scores, processed_labels + + # TODO: data_batch is no longer needed, consider adjusting the + # parameter position + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of data samples that + contain annotations and predictions. 
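+
+        Note (added for clarity): ground-truth instances are first
+        extended with their super-categories (via
+        ``_get_supercategory_ann``), and predictions are expanded and
+        filtered accordingly by ``_process_predictions``.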
+ """ + for data_sample in data_samples: + gt = copy.deepcopy(data_sample) + # add super-category instances + # TODO: Need to refactor to support LoadAnnotations + instances = gt['instances'] + if self.get_supercategory: + supercat_instances = self._get_supercategory_ann(instances) + instances.extend(supercat_instances) + gt_labels = [] + gt_bboxes = [] + is_group_ofs = [] + for ins in instances: + gt_labels.append(ins['bbox_label']) + gt_bboxes.append(ins['bbox']) + is_group_ofs.append(ins['is_group_of']) + ann = dict( + labels=np.array(gt_labels, dtype=np.int64), + bboxes=np.array(gt_bboxes, dtype=np.float32).reshape((-1, 4)), + gt_is_group_ofs=np.array(is_group_ofs, dtype=bool)) + + image_level_labels = gt.get('image_level_labels', None) + pred = data_sample['pred_instances'] + pred_bboxes = pred['bboxes'].cpu().numpy() + pred_scores = pred['scores'].cpu().numpy() + pred_labels = pred['labels'].cpu().numpy() + + pred_bboxes, pred_scores, pred_labels = self._process_predictions( + pred_bboxes, pred_scores, pred_labels, instances, + image_level_labels) + + dets = [] + for label in range(len(self.dataset_meta['classes'])): + index = np.where(pred_labels == label)[0] + pred_bbox_scores = np.hstack( + [pred_bboxes[index], pred_scores[index].reshape((-1, 1))]) + dets.append(pred_bbox_scores) + self.results.append((ann, dets)) + + def compute_metrics(self, results: list) -> dict: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + logger = MMLogger.get_current_instance() + gts, preds = zip(*results) + eval_results = OrderedDict() + # get dataset type + dataset_type = self.dataset_meta.get('dataset_type') + if dataset_type not in ['oid_challenge', 'oid_v6']: + dataset_type = 'oid_v6' + print_log( + 'Cannot infer dataset type from the length of the' + ' classes. Set `oid_v6` as dataset type.', + logger='current') + mean_aps = [] + for i, (iou_thr, + ioa_thr) in enumerate(zip(self.iou_thrs, self.ioa_thrs)): + if self.use_group_of: + assert ioa_thr is not None, 'ioa_thr must have value when' \ + ' using group_of in evaluation.' + print_log(f'\n{"-" * 15}iou_thr, ioa_thr: {iou_thr}, {ioa_thr}' + f'{"-" * 15}') + mean_ap, _ = eval_map( + preds, + gts, + scale_ranges=self.scale_ranges, + iou_thr=iou_thr, + ioa_thr=ioa_thr, + dataset=dataset_type, + logger=logger, + use_group_of=self.use_group_of) + + mean_aps.append(mean_ap) + eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3) + eval_results['mAP'] = sum(mean_aps) / len(mean_aps) + return eval_results diff --git a/mmdet/evaluation/metrics/refseg_metric.py b/mmdet/evaluation/metrics/refseg_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..0faee07007e809ef08e86a88e8b11c2be1a64034 --- /dev/null +++ b/mmdet/evaluation/metrics/refseg_metric.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
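+# Editor's note on the two metrics computed below: cIoU is the
+# cumulative intersection over the cumulative union accumulated over
+# the whole dataset, while mIoU is the mean of the per-sample IoUs;
+# both are reported as percentages.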
+from typing import Sequence + +import torch +from mmengine.evaluator import BaseMetric + +from mmdet.registry import METRICS + + +@METRICS.register_module() +class RefSegMetric(BaseMetric): + """Referring Expression Segmentation Metric.""" + + def __init__(self, metric: Sequence = ('cIoU', 'mIoU'), **kwargs): + super().__init__(**kwargs) + assert set(metric).issubset(['cIoU', 'mIoU']), \ + f'Only support cIoU and mIoU, but got {metric}' + assert len(metric) > 0, 'metrics should not be empty' + self.metrics = metric + + def compute_iou(self, pred_seg: torch.Tensor, + gt_seg: torch.Tensor) -> tuple: + overlap = pred_seg & gt_seg + union = pred_seg | gt_seg + return overlap, union + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data and data_samples. + + The processed results should be stored in ``self.results``, which will + be used to compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + pred_label = data_sample['pred_instances']['masks'].bool() + label = data_sample['gt_masks'].to_tensor( + pred_label.dtype, pred_label.device).bool() + # calculate iou + overlap, union = self.compute_iou(pred_label, label) + + bs = len(pred_label) + iou = overlap.reshape(bs, -1).sum(-1) * 1.0 / union.reshape( + bs, -1).sum(-1) + iou = torch.nan_to_num_(iou, nan=0.0) + self.results.append((overlap.sum(), union.sum(), iou.sum(), bs)) + + def compute_metrics(self, results: list) -> dict: + results = tuple(zip(*results)) + assert len(results) == 4 + cum_i = sum(results[0]) + cum_u = sum(results[1]) + iou = sum(results[2]) + seg_total = sum(results[3]) + + metrics = {} + if 'cIoU' in self.metrics: + metrics['cIoU'] = cum_i * 100 / cum_u + if 'mIoU' in self.metrics: + metrics['mIoU'] = iou * 100 / seg_total + return metrics diff --git a/mmdet/evaluation/metrics/reid_metric.py b/mmdet/evaluation/metrics/reid_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..d74df1433cdb093cfb0377b734fc5479401e09e7 --- /dev/null +++ b/mmdet/evaluation/metrics/reid_metric.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence, Union + +import numpy as np +import torch +from mmengine.evaluator import BaseMetric + +from mmdet.registry import METRICS + + +@METRICS.register_module() +class ReIDMetrics(BaseMetric): + """mAP and CMC evaluation metrics for the ReID task. + + Args: + metric (str | list[str]): Metrics to be evaluated. + Default value is `mAP`. + metric_options: (dict, optional): Options for calculating metrics. + Allowed keys are 'rank_list' and 'max_rank'. Defaults to None. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. 
Default: None + """ + allowed_metrics = ['mAP', 'CMC'] + default_prefix: Optional[str] = 'reid-metric' + + def __init__(self, + metric: Union[str, Sequence[str]] = 'mAP', + metric_options: Optional[dict] = None, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device, prefix) + + if isinstance(metric, list): + metrics = metric + elif isinstance(metric, str): + metrics = [metric] + else: + raise TypeError('metric must be a list or a str.') + for metric in metrics: + if metric not in self.allowed_metrics: + raise KeyError(f'metric {metric} is not supported.') + self.metrics = metrics + + self.metric_options = metric_options or dict( + rank_list=[1, 5, 10, 20], max_rank=20) + for rank in self.metric_options['rank_list']: + assert 1 <= rank <= self.metric_options['max_rank'] + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. + + The processed results should be stored in ``self.results``, which will + be used to compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of data samples that + contain annotations and predictions. + """ + for data_sample in data_samples: + pred_feature = data_sample['pred_feature'] + assert isinstance(pred_feature, torch.Tensor) + gt_label = data_sample.get('gt_label', data_sample['gt_label']) + assert isinstance(gt_label['label'], torch.Tensor) + result = dict( + pred_feature=pred_feature.data.cpu(), + gt_label=gt_label['label'].cpu()) + self.results.append(result) + + def compute_metrics(self, results: list) -> dict: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. + metrics = {} + + pids = torch.cat([result['gt_label'] for result in results]).numpy() + features = torch.stack([result['pred_feature'] for result in results]) + + n, c = features.size() + mat = torch.pow(features, 2).sum(dim=1, keepdim=True).expand(n, n) + distmat = mat + mat.t() + distmat.addmm_(features, features.t(), beta=1, alpha=-2) + distmat = distmat.numpy() + + indices = np.argsort(distmat, axis=1) + matches = (pids[indices] == pids[:, np.newaxis]).astype(np.int32) + + all_cmc = [] + all_AP = [] + num_valid_q = 0. + for q_idx in range(n): + # remove self + raw_cmc = matches[q_idx][1:] + if not np.any(raw_cmc): + # this condition is true when query identity + # does not appear in gallery + continue + + cmc = raw_cmc.cumsum() + cmc[cmc > 1] = 1 + + all_cmc.append(cmc[:self.metric_options['max_rank']]) + num_valid_q += 1. + + # compute average precision + num_rel = raw_cmc.sum() + tmp_cmc = raw_cmc.cumsum() + tmp_cmc = [x / (i + 1.) 
for i, x in enumerate(tmp_cmc)]
+            tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
+            AP = tmp_cmc.sum() / num_rel
+            all_AP.append(AP)
+
+        assert num_valid_q > 0, \
+            'Error: all query identities do not appear in gallery'
+
+        all_cmc = np.asarray(all_cmc)
+        all_cmc = all_cmc.sum(0) / num_valid_q
+        mAP = np.mean(all_AP)
+
+        if 'mAP' in self.metrics:
+            metrics['mAP'] = np.around(mAP, decimals=3)
+        if 'CMC' in self.metrics:
+            for rank in self.metric_options['rank_list']:
+                metrics[f'R{rank}'] = np.around(all_cmc[rank - 1], decimals=3)
+
+        return metrics
diff --git a/mmdet/evaluation/metrics/semseg_metric.py b/mmdet/evaluation/metrics/semseg_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..3215f6788a6155bdbceb6a91259008b4d851868e
--- /dev/null
+++ b/mmdet/evaluation/metrics/semseg_metric.py
@@ -0,0 +1,279 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from collections import OrderedDict
+from typing import Dict, Optional, Sequence, Union
+
+import numpy as np
+import torch
+from mmcv import imwrite
+from mmengine.dist import is_main_process
+from mmengine.evaluator import BaseMetric
+from mmengine.logging import MMLogger, print_log
+from mmengine.utils import mkdir_or_exist
+from PIL import Image
+
+try:
+    from prettytable import PrettyTable
+except ImportError:
+    PrettyTable = None
+
+from mmdet.registry import METRICS
+
+
+@METRICS.register_module()
+class SemSegMetric(BaseMetric):
+    """mIoU evaluation metric.
+
+    Args:
+        iou_metrics (list[str] | str): Metrics to be calculated, the options
+            include 'mIoU', 'mDice' and 'mFscore'.
+        beta (int): Determines the weight of recall in the combined score.
+            Default: 1.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        output_dir (str): The directory for output prediction. Defaults to
+            None.
+        format_only (bool): Only format results for submission without
+            performing evaluation. It is useful when you want to save the
+            result to a specific format and submit it to the test server.
+            Defaults to False.
+        backend_args (dict, optional): Arguments to instantiate the
+            corresponding backend. Defaults to None.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+    """
+
+    def __init__(self,
+                 iou_metrics: Sequence[str] = ['mIoU'],
+                 beta: int = 1,
+                 collect_device: str = 'cpu',
+                 output_dir: Optional[str] = None,
+                 format_only: bool = False,
+                 backend_args: dict = None,
+                 prefix: Optional[str] = None) -> None:
+        super().__init__(collect_device=collect_device, prefix=prefix)
+
+        if isinstance(iou_metrics, str):
+            iou_metrics = [iou_metrics]
+        if not set(iou_metrics).issubset(set(['mIoU', 'mDice', 'mFscore'])):
+            raise KeyError(f'metrics {iou_metrics} is not supported. '
+                           f'Only supports mIoU/mDice/mFscore.')
+        self.metrics = iou_metrics
+        self.beta = beta
+        self.output_dir = output_dir
+        if self.output_dir and is_main_process():
+            mkdir_or_exist(self.output_dir)
+        self.format_only = format_only
+        self.backend_args = backend_args
+
+    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
+        """Process one batch of data and data_samples.
+
+        The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch (dict): A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from the model.
+        """
+        num_classes = len(self.dataset_meta['classes'])
+        for data_sample in data_samples:
+            pred_label = data_sample['pred_sem_seg']['sem_seg'].squeeze()
+            # format_only always for test dataset without ground truth
+            if not self.format_only:
+                label = data_sample['gt_sem_seg']['sem_seg'].squeeze().to(
+                    pred_label)
+                ignore_index = data_sample['pred_sem_seg'].get(
+                    'ignore_index', 255)
+                self.results.append(
+                    self._compute_pred_stats(pred_label, label, num_classes,
+                                             ignore_index))
+
+            # format_result
+            if self.output_dir is not None:
+                basename = osp.splitext(osp.basename(
+                    data_sample['img_path']))[0]
+                png_filename = osp.abspath(
+                    osp.join(self.output_dir, f'{basename}.png'))
+                output_mask = pred_label.cpu().numpy()
+                output = Image.fromarray(output_mask.astype(np.uint8))
+                imwrite(output, png_filename, backend_args=self.backend_args)
+
+    def compute_metrics(self, results: list) -> Dict[str, float]:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list): The processed results of each batch.
+
+        Returns:
+            Dict[str, float]: The computed metrics. The keys are the names of
+            the metrics, and the values are corresponding results. The keys
+            mainly include aAcc, mIoU, mAcc, mDice, mFscore, mPrecision and
+            mRecall.
+        """
+        logger: MMLogger = MMLogger.get_current_instance()
+        if self.format_only:
+            logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
+            return OrderedDict()
+
+        ret_metrics = self.get_return_metrics(results)
+
+        # summary table
+        ret_metrics_summary = OrderedDict({
+            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
+            for ret_metric, ret_metric_value in ret_metrics.items()
+        })
+        metrics = dict()
+        for key, val in ret_metrics_summary.items():
+            if key == 'aAcc':
+                metrics[key] = val
+            else:
+                metrics['m' + key] = val
+
+        print_semantic_table(ret_metrics, self.dataset_meta['classes'], logger)
+
+        return metrics
+
+    def _compute_pred_stats(self, pred_label: torch.tensor,
+                            label: torch.tensor, num_classes: int,
+                            ignore_index: int):
+        """Parse semantic segmentation predictions.
+
+        Args:
+            pred_label (torch.tensor): Prediction segmentation map.
+                The shape is (H, W).
+            label (torch.tensor): Ground truth segmentation map.
+                The shape is (H, W).
+            num_classes (int): Number of categories.
+            ignore_index (int): Label index to be ignored in evaluation.
+
+        Returns:
+            dict: Per-class histograms, with keys ``area_intersect``,
+            ``area_union``, ``area_pred_label`` and ``area_label``, each a
+            ``torch.Tensor`` of shape ``(num_classes, )``.
+        """
+        assert pred_label.shape == label.shape
+        mask = label != ignore_index
+        label, pred_label = label[mask], pred_label[mask]
+
+        intersect = pred_label[pred_label == label]
+        area_intersect = torch.histc(
+            intersect.float(), bins=num_classes, min=0, max=num_classes - 1)
+        area_pred_label = torch.histc(
+            pred_label.float(), bins=num_classes, min=0, max=num_classes - 1)
+        area_label = torch.histc(
+            label.float(), bins=num_classes, min=0, max=num_classes - 1)
+        area_union = area_pred_label + area_label - area_intersect
+        result = dict(
+            area_intersect=area_intersect,
+            area_union=area_union,
+            area_pred_label=area_pred_label,
+            area_label=area_label)
+        return result
+
+    def get_return_metrics(self, results: list) -> dict:
+        """Calculate evaluation metrics.
+ + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, np.ndarray]: per category evaluation metrics, + shape (num_classes, ). + """ + + def f_score(precision, recall, beta=1): + """calculate the f-score value. + + Args: + precision (float | torch.Tensor): The precision value. + recall (float | torch.Tensor): The recall value. + beta (int): Determines the weight of recall in the combined + score. Default: 1. + + Returns: + [torch.tensor]: The f-score value. + """ + score = (1 + beta**2) * (precision * recall) / ( + (beta**2 * precision) + recall) + return score + + total_area_intersect = sum([r['area_intersect'] for r in results]) + total_area_union = sum([r['area_union'] for r in results]) + total_area_pred_label = sum([r['area_pred_label'] for r in results]) + total_area_label = sum([r['area_label'] for r in results]) + + all_acc = total_area_intersect / total_area_label + ret_metrics = OrderedDict({'aAcc': all_acc}) + for metric in self.metrics: + if metric == 'mIoU': + iou = total_area_intersect / total_area_union + acc = total_area_intersect / total_area_label + ret_metrics['IoU'] = iou + ret_metrics['Acc'] = acc + elif metric == 'mDice': + dice = 2 * total_area_intersect / ( + total_area_pred_label + total_area_label) + acc = total_area_intersect / total_area_label + ret_metrics['Dice'] = dice + ret_metrics['Acc'] = acc + elif metric == 'mFscore': + precision = total_area_intersect / total_area_pred_label + recall = total_area_intersect / total_area_label + f_value = torch.tensor([ + f_score(x[0], x[1], self.beta) + for x in zip(precision, recall) + ]) + ret_metrics['Fscore'] = f_value + ret_metrics['Precision'] = precision + ret_metrics['Recall'] = recall + + ret_metrics = { + metric: value.cpu().numpy() + for metric, value in ret_metrics.items() + } + + return ret_metrics + + +def print_semantic_table( + results: dict, + class_names: list, + logger: Optional[Union['MMLogger', str]] = None) -> None: + """Print semantic segmentation evaluation results table. + + Args: + results (dict): The evaluation results. + class_names (list): Class names. + logger (MMLogger | str, optional): Logger used for printing. + Default: None. + """ + # each class table + results.pop('aAcc', None) + ret_metrics_class = OrderedDict({ + ret_metric: np.round(ret_metric_value * 100, 2) + for ret_metric, ret_metric_value in results.items() + }) + + print_log('per class results:', logger) + if PrettyTable: + class_table_data = PrettyTable() + ret_metrics_class.update({'Class': class_names}) + ret_metrics_class.move_to_end('Class', last=False) + for key, val in ret_metrics_class.items(): + class_table_data.add_column(key, val) + print_log('\n' + class_table_data.get_string(), logger=logger) + else: + logger.warning( + '`prettytable` is not installed, for better table format, ' + 'please consider installing it with "pip install prettytable"') + print_result = {} + for class_name, iou, acc in zip(class_names, ret_metrics_class['IoU'], + ret_metrics_class['Acc']): + print_result[class_name] = {'IoU': iou, 'Acc': acc} + print_log(print_result, logger) diff --git a/mmdet/evaluation/metrics/voc_metric.py b/mmdet/evaluation/metrics/voc_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..32d8c075de9c8b4fb842ad7f64f87a10c4d68546 --- /dev/null +++ b/mmdet/evaluation/metrics/voc_metric.py @@ -0,0 +1,176 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy
+import warnings
+from collections import OrderedDict
+from typing import List, Optional, Sequence, Union
+
+import numpy as np
+from mmengine.evaluator import BaseMetric
+from mmengine.logging import MMLogger
+
+from mmdet.registry import METRICS
+from ..functional import eval_map, eval_recalls
+
+
+@METRICS.register_module()
+class VOCMetric(BaseMetric):
+    """Pascal VOC evaluation metric.
+
+    Args:
+        iou_thrs (float or List[float]): IoU threshold. Defaults to 0.5.
+        scale_ranges (List[tuple], optional): Scale ranges for evaluating
+            mAP. If not specified, all bounding boxes would be included in
+            evaluation. Defaults to None.
+        metric (str | list[str]): Metrics to be evaluated. Options are
+            'mAP' and 'recall'. If it is a list, the first metric in the
+            list will be used for evaluation.
+        proposal_nums (Sequence[int]): Proposal number used for evaluating
+            recalls, such as recall@100, recall@1000.
+            Default: (100, 300, 1000).
+        eval_mode (str): 'area' or '11points'. 'area' means calculating the
+            area under the precision-recall curve, '11points' means
+            calculating the average precision of recalls at
+            [0, 0.1, ..., 1]. PASCAL VOC2007 defaults to '11points', while
+            PASCAL VOC2012 defaults to 'area'.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+    """
+
+    default_prefix: Optional[str] = 'pascal_voc'
+
+    def __init__(self,
+                 iou_thrs: Union[float, List[float]] = 0.5,
+                 scale_ranges: Optional[List[tuple]] = None,
+                 metric: Union[str, List[str]] = 'mAP',
+                 proposal_nums: Sequence[int] = (100, 300, 1000),
+                 eval_mode: str = '11points',
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None) -> None:
+        super().__init__(collect_device=collect_device, prefix=prefix)
+        self.iou_thrs = [iou_thrs] if isinstance(iou_thrs, float) \
+            else iou_thrs
+        self.scale_ranges = scale_ranges
+        # voc evaluation metrics
+        if not isinstance(metric, str):
+            assert len(metric) == 1
+            metric = metric[0]
+        allowed_metrics = ['recall', 'mAP']
+        if metric not in allowed_metrics:
+            raise KeyError(
+                f"metric should be one of 'recall', 'mAP', but got {metric}.")
+        self.metric = metric
+        self.proposal_nums = proposal_nums
+        assert eval_mode in ['area', '11points'], \
+            'Unrecognized mode, only "area" and "11points" are supported'
+        self.eval_mode = eval_mode
+
+    # TODO: data_batch is no longer needed, consider adjusting the
+    # parameter position
+    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
+        """Process one batch of data samples and predictions. The processed
+        results should be stored in ``self.results``, which will be used to
+        compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch (dict): A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of data samples that
+                contain annotations and predictions.
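+
+        Example:
+            Each stored result pairs ``ann`` with ``dets``, where ``dets[c]``
+            is an ``(N_c, 5)`` array of ``[x1, y1, x2, y2, score]`` rows for
+            class ``c``. A toy illustration of how one row set is built:
+
+            >>> import numpy as np
+            >>> pred_bboxes = np.array([[0., 0., 10., 10.]])
+            >>> pred_scores = np.array([0.9])
+            >>> dets_c = np.hstack(
+            ...     [pred_bboxes, pred_scores.reshape((-1, 1))])  # (1, 5)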
+ """ + for data_sample in data_samples: + gt = copy.deepcopy(data_sample) + # TODO: Need to refactor to support LoadAnnotations + gt_instances = gt['gt_instances'] + gt_ignore_instances = gt['ignored_instances'] + ann = dict( + labels=gt_instances['labels'].cpu().numpy(), + bboxes=gt_instances['bboxes'].cpu().numpy(), + bboxes_ignore=gt_ignore_instances['bboxes'].cpu().numpy(), + labels_ignore=gt_ignore_instances['labels'].cpu().numpy()) + + pred = data_sample['pred_instances'] + pred_bboxes = pred['bboxes'].cpu().numpy() + pred_scores = pred['scores'].cpu().numpy() + pred_labels = pred['labels'].cpu().numpy() + + dets = [] + for label in range(len(self.dataset_meta['classes'])): + index = np.where(pred_labels == label)[0] + pred_bbox_scores = np.hstack( + [pred_bboxes[index], pred_scores[index].reshape((-1, 1))]) + dets.append(pred_bbox_scores) + + self.results.append((ann, dets)) + + def compute_metrics(self, results: list) -> dict: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + logger: MMLogger = MMLogger.get_current_instance() + gts, preds = zip(*results) + eval_results = OrderedDict() + if self.metric == 'mAP': + assert isinstance(self.iou_thrs, list) + dataset_type = self.dataset_meta.get('dataset_type') + if dataset_type in ['VOC2007', 'VOC2012']: + dataset_name = 'voc' + if dataset_type == 'VOC2007' and self.eval_mode != '11points': + warnings.warn('Pascal VOC2007 uses `11points` as default ' + 'evaluate mode, but you are using ' + f'{self.eval_mode}.') + elif dataset_type == 'VOC2012' and self.eval_mode != 'area': + warnings.warn('Pascal VOC2012 uses `area` as default ' + 'evaluate mode, but you are using ' + f'{self.eval_mode}.') + else: + dataset_name = self.dataset_meta['classes'] + + mean_aps = [] + for iou_thr in self.iou_thrs: + logger.info(f'\n{"-" * 15}iou_thr: {iou_thr}{"-" * 15}') + # Follow the official implementation, + # http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar + # we should use the legacy coordinate system in mmdet 1.x, + # which means w, h should be computed as 'x2 - x1 + 1` and + # `y2 - y1 + 1` + mean_ap, _ = eval_map( + preds, + gts, + scale_ranges=self.scale_ranges, + iou_thr=iou_thr, + dataset=dataset_name, + logger=logger, + eval_mode=self.eval_mode, + use_legacy_coordinate=True) + mean_aps.append(mean_ap) + eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3) + eval_results['mAP'] = sum(mean_aps) / len(mean_aps) + eval_results.move_to_end('mAP', last=False) + elif self.metric == 'recall': + gt_bboxes = [gt['bboxes'] for gt in gts] + pr_bboxes = [pred[0] for pred in preds] + recalls = eval_recalls( + gt_bboxes, + pr_bboxes, + self.proposal_nums, + self.iou_thrs, + logger=logger, + use_legacy_coordinate=True) + for i, num in enumerate(self.proposal_nums): + for j, iou_thr in enumerate(self.iou_thrs): + eval_results[f'recall@{num}@{iou_thr}'] = recalls[i, j] + if recalls.shape[1] > 1: + ar = recalls.mean(axis=1) + for i, num in enumerate(self.proposal_nums): + eval_results[f'AR@{num}'] = ar[i] + return eval_results diff --git a/mmdet/evaluation/metrics/youtube_vis_metric.py b/mmdet/evaluation/metrics/youtube_vis_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..5abc77a591c7ee5d67cdf4dc4c4926c84894ba1d --- /dev/null +++ b/mmdet/evaluation/metrics/youtube_vis_metric.py @@ -0,0 +1,426 @@ +# 
+import os.path as osp
+import tempfile
+import warnings
+import zipfile
+from collections import OrderedDict, defaultdict
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+import mmengine
+import numpy as np
+from mmengine.dist import (all_gather_object, barrier, broadcast_object_list,
+                           is_main_process)
+from mmengine.logging import MMLogger
+
+from mmdet.registry import METRICS
+from mmdet.structures.mask import encode_mask_results
+from ..functional import YTVIS, YTVISeval
+from .base_video_metric import BaseVideoMetric, collect_tracking_results
+
+
+@METRICS.register_module()
+class YouTubeVISMetric(BaseVideoMetric):
+    """mAP evaluation metrics for the VIS task.
+
+    Args:
+        metric (str | list[str]): Metrics to be evaluated.
+            Default value is `youtube_vis_ap`.
+        metric_items (List[str], optional): Metric result names to be
+            recorded in the evaluation result. Defaults to None.
+        outfile_prefix (str | None): The prefix of json files. It includes
+            the file path and the prefix of filename, e.g., "a/b/prefix".
+            If not specified, a temp file will be created. Defaults to None.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Default: None
+        format_only (bool): If True, only formatting the results to the
+            official format and not performing evaluation. Defaults to False.
+    """
+
+    default_prefix: Optional[str] = 'youtube_vis'
+
+    def __init__(self,
+                 metric: Union[str, List[str]] = 'youtube_vis_ap',
+                 metric_items: Optional[Sequence[str]] = None,
+                 outfile_prefix: Optional[str] = None,
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None,
+                 format_only: bool = False) -> None:
+        super().__init__(collect_device=collect_device, prefix=prefix)
+        # vis evaluation metrics
+        self.metrics = metric if isinstance(metric, list) else [metric]
+        self.format_only = format_only
+        allowed_metrics = ['youtube_vis_ap']
+        for metric in self.metrics:
+            if metric not in allowed_metrics:
+                raise KeyError(
+                    f"metric should be 'youtube_vis_ap', but got {metric}.")
+
+        self.metric_items = metric_items
+        self.outfile_prefix = outfile_prefix
+        self.per_video_res = []
+        self.categories = []
+        self._vis_meta_info = defaultdict(list)  # record video and image infos
+
+    def process_video(self, data_samples):
+
+        video_length = len(data_samples)
+        for frame_id in range(video_length):
+            result = dict()
+            img_data_sample = data_samples[frame_id].to_dict()
+            pred = img_data_sample['pred_track_instances']
+            video_id = img_data_sample['video_id']
+
+            result['img_id'] = img_data_sample['img_id']
+            result['bboxes'] = pred['bboxes'].cpu().numpy()
+            result['scores'] = pred['scores'].cpu().numpy()
+            result['labels'] = pred['labels'].cpu().numpy()
+            result['instances_id'] = pred['instances_id'].cpu().numpy()
+            # encode mask to RLE
+            assert 'masks' in pred, \
+                'masks must exist in YouTube-VIS metric'
+            result['masks'] = encode_mask_results(
+                pred['masks'].detach().cpu().numpy())
+
+            # parse gt
+            gt = dict()
+            gt['width'] = img_data_sample['ori_shape'][1]
+            gt['height'] = img_data_sample['ori_shape'][0]
+            gt['img_id'] = img_data_sample['img_id']
+            gt['frame_id'] = frame_id
+            gt['video_id'] = video_id
+            gt['video_length'] = 
video_length + + if 'instances' in img_data_sample: + gt['anns'] = img_data_sample['instances'] + else: + gt['anns'] = dict() + self.per_video_res.append((result, gt)) + + preds, gts = zip(*self.per_video_res) + # format the results + # we must format gts first to update self._vis_meta_info + gt_results = self._format_one_video_gts(gts) + pred_results = self._format_one_video_preds(preds) + self.per_video_res.clear() + # add converted result to the results list + self.results.append((pred_results, gt_results)) + + def compute_metrics(self, results: List) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (List): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. + """ + # split gt and prediction list + tmp_pred_results, tmp_gt_results = zip(*results) + gt_results = self.format_gts(tmp_gt_results) + pred_results = self.format_preds(tmp_pred_results) + + if self.format_only: + self.save_pred_results(pred_results) + return dict() + + ytvis = YTVIS(gt_results) + + ytvis_dets = ytvis.loadRes(pred_results) + vid_ids = ytvis.getVidIds() + + iou_type = metric = 'segm' + eval_results = OrderedDict() + ytvisEval = YTVISeval(ytvis, ytvis_dets, iou_type) + ytvisEval.params.vidIds = vid_ids + ytvisEval.evaluate() + ytvisEval.accumulate() + ytvisEval.summarize() + + coco_metric_names = { + 'mAP': 0, + 'mAP_50': 1, + 'mAP_75': 2, + 'mAP_s': 3, + 'mAP_m': 4, + 'mAP_l': 5, + 'AR@1': 6, + 'AR@10': 7, + 'AR@100': 8, + 'AR_s@100': 9, + 'AR_m@100': 10, + 'AR_l@100': 11 + } + metric_items = self.metric_items + if metric_items is not None: + for metric_item in metric_items: + if metric_item not in coco_metric_names: + raise KeyError( + f'metric item "{metric_item}" is not supported') + + if metric_items is None: + metric_items = [ + 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l' + ] + for metric_item in metric_items: + key = f'{metric}_{metric_item}' + val = float( + f'{ytvisEval.stats[coco_metric_names[metric_item]]:.3f}') + eval_results[key] = val + + return eval_results + + def format_gts(self, gts: Tuple[List]) -> dict: + """Gather all ground-truth from self.results.""" + self.categories = [ + dict(id=id + 1, name=name) + for id, name in enumerate(self.dataset_meta['classes']) + ] + gt_results = dict( + categories=self.categories, + videos=self._vis_meta_info['videos'], + annotations=[]) + for gt_result in gts: + gt_results['annotations'].extend(gt_result) + return gt_results + + def format_preds(self, preds: Tuple[List]) -> List: + """Gather all predictions from self.results.""" + pred_results = [] + for pred_result in preds: + pred_results.extend(pred_result) + return pred_results + + def _format_one_video_preds(self, pred_dicts: Tuple[dict]) -> List: + """Convert the annotation to the format of YouTube-VIS. + + This operation is to make it easier to use the official eval API. + + Args: + pred_dicts (Tuple[dict]): Prediction of the dataset. + + Returns: + List: The formatted predictions. 
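+
+        Example:
+            A sketch of one returned element (values illustrative; the RLE
+            ``counts`` string is abbreviated):
+
+            >>> result = dict(
+            ...     video_id=1,
+            ...     score=0.9,  # mean score over the track
+            ...     category_id=3,  # majority-voted label, 1-based
+            ...     segmentations=[None, {'size': [720, 1280],
+            ...                           'counts': '...'}])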
+ """ + # Collate preds scatters (tuple of dict to dict of list) + preds = defaultdict(list) + for pred in pred_dicts: + for key in pred.keys(): + preds[key].append(pred[key]) + + img_infos = self._vis_meta_info['images'] + vid_infos = self._vis_meta_info['videos'] + inds = [i for i, _ in enumerate(img_infos) if _['frame_id'] == 0] + inds.append(len(img_infos)) + json_results = [] + video_id = vid_infos[-1]['id'] + # collect data for each instances in a video. + collect_data = dict() + for frame_id, (masks, scores, labels, ids) in enumerate( + zip(preds['masks'], preds['scores'], preds['labels'], + preds['instances_id'])): + + assert len(masks) == len(labels) + for j, id in enumerate(ids): + if id not in collect_data: + collect_data[id] = dict( + category_ids=[], scores=[], segmentations=dict()) + collect_data[id]['category_ids'].append(labels[j]) + collect_data[id]['scores'].append(scores[j]) + if isinstance(masks[j]['counts'], bytes): + masks[j]['counts'] = masks[j]['counts'].decode() + collect_data[id]['segmentations'][frame_id] = masks[j] + + # transform the collected data into official format + for id, id_data in collect_data.items(): + output = dict() + output['video_id'] = video_id + output['score'] = np.array(id_data['scores']).mean().item() + # majority voting for sequence category + output['category_id'] = np.bincount( + np.array(id_data['category_ids'])).argmax().item() + 1 + output['segmentations'] = [] + for frame_id in range(inds[-1] - inds[-2]): + if frame_id in id_data['segmentations']: + output['segmentations'].append( + id_data['segmentations'][frame_id]) + else: + output['segmentations'].append(None) + json_results.append(output) + + return json_results + + def _format_one_video_gts(self, gt_dicts: Tuple[dict]) -> List: + """Convert the annotation to the format of YouTube-VIS. + + This operation is to make it easier to use the official eval API. + + Args: + gt_dicts (Tuple[dict]): Ground truth of the dataset. + + Returns: + list: The formatted gts. 
+ """ + video_infos = [] + image_infos = [] + instance_infos = defaultdict(list) + len_videos = dict() # mapping from instance_id to video_length + vis_anns = [] + + # get video infos + for gt_dict in gt_dicts: + frame_id = gt_dict['frame_id'] + video_id = gt_dict['video_id'] + img_id = gt_dict['img_id'] + image_info = dict( + id=img_id, + width=gt_dict['width'], + height=gt_dict['height'], + frame_id=frame_id, + file_name='') + image_infos.append(image_info) + if frame_id == 0: + video_info = dict( + id=video_id, + width=gt_dict['width'], + height=gt_dict['height'], + file_name='') + video_infos.append(video_info) + + for ann in gt_dict['anns']: + label = ann['bbox_label'] + bbox = ann['bbox'] + instance_id = ann['instance_id'] + # update video length + len_videos[instance_id] = gt_dict['video_length'] + coco_bbox = [ + bbox[0], + bbox[1], + bbox[2] - bbox[0], + bbox[3] - bbox[1], + ] + + annotation = dict( + video_id=video_id, + frame_id=frame_id, + bbox=coco_bbox, + instance_id=instance_id, + iscrowd=ann.get('ignore_flag', 0), + category_id=int(label) + 1, + area=coco_bbox[2] * coco_bbox[3]) + if ann.get('mask', None): + mask = ann['mask'] + # area = mask_util.area(mask) + if isinstance(mask, dict) and isinstance( + mask['counts'], bytes): + mask['counts'] = mask['counts'].decode() + annotation['segmentation'] = mask + + instance_infos[instance_id].append(annotation) + + # update vis meta info + self._vis_meta_info['images'].extend(image_infos) + self._vis_meta_info['videos'].extend(video_infos) + + for instance_id, ann_infos in instance_infos.items(): + cur_video_len = len_videos[instance_id] + segm = [None] * cur_video_len + bbox = [None] * cur_video_len + area = [None] * cur_video_len + # In the official format, no instances are represented by + # 'None', however, only images with instances are recorded + # in the current annotations, so we need to use 'None' to + # initialize these lists. + for ann_info in ann_infos: + frame_id = ann_info['frame_id'] + segm[frame_id] = ann_info['segmentation'] + bbox[frame_id] = ann_info['bbox'] + area[frame_id] = ann_info['area'] + instance = dict( + category_id=ann_infos[0]['category_id'], + segmentations=segm, + bboxes=bbox, + video_id=ann_infos[0]['video_id'], + areas=area, + id=instance_id, + iscrowd=ann_infos[0]['iscrowd']) + vis_anns.append(instance) + return vis_anns + + def save_pred_results(self, pred_results: List) -> None: + """Save the results to a zip file (standard format for YouTube-VIS + Challenge). + + Args: + pred_results (list): Testing results of the + dataset. + """ + logger: MMLogger = MMLogger.get_current_instance() + if self.outfile_prefix is None: + tmp_dir = tempfile.TemporaryDirectory() + outfile_prefix = osp.join(tmp_dir.name, 'results') + else: + outfile_prefix = self.outfile_prefix + mmengine.dump(pred_results, f'{outfile_prefix}.json') + # zip the json file in order to submit to the test server. + zip_file_name = f'{outfile_prefix}.submission_file.zip' + zf = zipfile.ZipFile(zip_file_name, 'w', zipfile.ZIP_DEFLATED) + logger.info(f"zip the 'results.json' into '{zip_file_name}', " + 'please submmit the zip file to the test server') + zf.write(f'{outfile_prefix}.json', 'results.json') + zf.close() + + def evaluate(self, size: int) -> dict: + """Evaluate the model performance of the whole dataset after processing + all batches. + + Args: + size (int): Length of the entire validation dataset. + + Returns: + dict: Evaluation metrics dict on the val dataset. 
The keys are the + names of the metrics, and the values are corresponding results. + """ + # wait for all processes to complete prediction. + barrier() + + if len(self.results) == 0: + warnings.warn( + f'{self.__class__.__name__} got empty `self.results`. Please ' + 'ensure that the processed results are properly added into ' + '`self.results` in `process` method.') + + results = collect_tracking_results(self.results, self.collect_device) + + # gather seq_info + gathered_seq_info = all_gather_object(self._vis_meta_info['videos']) + all_seq_info = [] + for _seq_info in gathered_seq_info: + all_seq_info.extend(_seq_info) + # update self._vis_meta_info + self._vis_meta_info = dict(videos=all_seq_info) + + if is_main_process(): + _metrics = self.compute_metrics(results) # type: ignore + # Add prefix to metric names + if self.prefix: + _metrics = { + '/'.join((self.prefix, k)): v + for k, v in _metrics.items() + } + metrics = [_metrics] + else: + metrics = [None] # type: ignore + + broadcast_object_list(metrics) + + # reset the results list + self.results.clear() + # reset the vis_meta_info + self._vis_meta_info.clear() + return metrics[0] diff --git a/mmdet/models/.DS_Store b/mmdet/models/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..fd1ae5e4004bc1275db7485bfd056ed249fa79e9 Binary files /dev/null and b/mmdet/models/.DS_Store differ diff --git a/mmdet/models/__init__.py b/mmdet/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c0a0d5e8d350d81e72787ff73fd85c2176783b43 --- /dev/null +++ b/mmdet/models/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .backbones import * # noqa: F401,F403 +from .data_preprocessors import * # noqa: F401,F403 +from .dense_heads import * # noqa: F401,F403 +from .detectors import * # noqa: F401,F403 +from .language_models import * # noqa: F401,F403 +from .layers import * # noqa: F401,F403 +from .losses import * # noqa: F401,F403 +from .mot import * # noqa: F401,F403 +from .necks import * # noqa: F401,F403 +from .reid import * # noqa: F401,F403 +from .roi_heads import * # noqa: F401,F403 +from .seg_heads import * # noqa: F401,F403 +from .task_modules import * # noqa: F401,F403 +from .test_time_augs import * # noqa: F401,F403 +from .trackers import * # noqa: F401,F403 +from .tracking_heads import * # noqa: F401,F403 +from .vis import * # noqa: F401,F403 diff --git a/mmdet/models/__pycache__/__init__.cpython-311.pyc b/mmdet/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e9e5d8f6109c874d6a7abc3e43f1be9c55bebe8 Binary files /dev/null and b/mmdet/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmdet/models/backbones/__init__.py b/mmdet/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e16ff85f7037b36fb2046fcbcd3af523050a6516 --- /dev/null +++ b/mmdet/models/backbones/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .csp_darknet import CSPDarknet +from .cspnext import CSPNeXt +from .darknet import Darknet +from .detectors_resnet import DetectoRS_ResNet +from .detectors_resnext import DetectoRS_ResNeXt +from .efficientnet import EfficientNet +from .hourglass import HourglassNet +from .hrnet import HRNet +from .mobilenet_v2 import MobileNetV2 +from .pvt import PyramidVisionTransformer, PyramidVisionTransformerV2 +from .regnet import RegNet +from .res2net import Res2Net +from .resnest import ResNeSt +from .resnet import ResNet, ResNetV1d +from .resnext import ResNeXt +from .ssd_vgg import SSDVGG +from .swin import SwinTransformer +from .trident_resnet import TridentResNet + +__all__ = [ + 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', + 'MobileNetV2', 'Res2Net', 'HourglassNet', 'DetectoRS_ResNet', + 'DetectoRS_ResNeXt', 'Darknet', 'ResNeSt', 'TridentResNet', 'CSPDarknet', + 'SwinTransformer', 'PyramidVisionTransformer', + 'PyramidVisionTransformerV2', 'EfficientNet', 'CSPNeXt' +] diff --git a/mmdet/models/backbones/__pycache__/__init__.cpython-311.pyc b/mmdet/models/backbones/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb1565f8c22b24533c1e8b1ede7a8f3a8751900d Binary files /dev/null and b/mmdet/models/backbones/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmdet/models/backbones/__pycache__/csp_darknet.cpython-311.pyc b/mmdet/models/backbones/__pycache__/csp_darknet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e1e64169b8422160b42a4606b01f60f0d226ab6 Binary files /dev/null and b/mmdet/models/backbones/__pycache__/csp_darknet.cpython-311.pyc differ diff --git a/mmdet/models/backbones/csp_darknet.py b/mmdet/models/backbones/csp_darknet.py new file mode 100644 index 0000000000000000000000000000000000000000..a890b486f255befa23fe5a3e9746f8f9298ac33f --- /dev/null +++ b/mmdet/models/backbones/csp_darknet.py @@ -0,0 +1,286 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from mmengine.model import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmdet.registry import MODELS +from ..layers import CSPLayer + + +class Focus(nn.Module): + """Focus width and height information into channel space. + + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + kernel_size (int): The kernel size of the convolution. Default: 1 + stride (int): The stride of the convolution. Default: 1 + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN', momentum=0.03, eps=0.001). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='Swish'). 
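+
+    Example:
+        A quick shape check (illustrative; the output channel count follows
+        ``out_channels``):
+
+        >>> import torch
+        >>> focus = Focus(3, 64, kernel_size=3)
+        >>> focus(torch.rand(1, 3, 64, 64)).shape
+        torch.Size([1, 64, 32, 32])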
+ """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + conv_cfg=None, + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='Swish')): + super().__init__() + self.conv = ConvModule( + in_channels * 4, + out_channels, + kernel_size, + stride, + padding=(kernel_size - 1) // 2, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x): + # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2) + patch_top_left = x[..., ::2, ::2] + patch_top_right = x[..., ::2, 1::2] + patch_bot_left = x[..., 1::2, ::2] + patch_bot_right = x[..., 1::2, 1::2] + x = torch.cat( + ( + patch_top_left, + patch_bot_left, + patch_top_right, + patch_bot_right, + ), + dim=1, + ) + return self.conv(x) + + +class SPPBottleneck(BaseModule): + """Spatial pyramid pooling layer used in YOLOv3-SPP. + + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + kernel_sizes (tuple[int]): Sequential of kernel sizes of pooling + layers. Default: (5, 9, 13). + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='Swish'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_sizes=(5, 9, 13), + conv_cfg=None, + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='Swish'), + init_cfg=None): + super().__init__(init_cfg) + mid_channels = in_channels // 2 + self.conv1 = ConvModule( + in_channels, + mid_channels, + 1, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.poolings = nn.ModuleList([ + nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) + for ks in kernel_sizes + ]) + conv2_channels = mid_channels * (len(kernel_sizes) + 1) + self.conv2 = ConvModule( + conv2_channels, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x): + x = self.conv1(x) + with torch.cuda.amp.autocast(enabled=False): + x = torch.cat( + [x] + [pooling(x) for pooling in self.poolings], dim=1) + x = self.conv2(x) + return x + + +@MODELS.register_module() +class CSPDarknet(BaseModule): + """CSP-Darknet backbone used in YOLOv5 and YOLOX. + + Args: + arch (str): Architecture of CSP-Darknet, from {P5, P6}. + Default: P5. + deepen_factor (float): Depth multiplier, multiply number of + blocks in CSP layer by this amount. Default: 1.0. + widen_factor (float): Width multiplier, multiply number of + channels in each layer by this amount. Default: 1.0. + out_indices (Sequence[int]): Output from which stages. + Default: (2, 3, 4). + frozen_stages (int): Stages to be frozen (stop grad and set eval + mode). -1 means not freezing any parameters. Default: -1. + use_depthwise (bool): Whether to use depthwise separable convolution. + Default: False. + arch_ovewrite(list): Overwrite default arch settings. Default: None. + spp_kernal_sizes: (tuple[int]): Sequential of kernel sizes of SPP + layers. Default: (5, 9, 13). + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1). 
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
+
+    Example:
+        >>> from mmdet.models import CSPDarknet
+        >>> import torch
+        >>> self = CSPDarknet(arch='P5')
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 416, 416)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        ...
+        (1, 256, 52, 52)
+        (1, 512, 26, 26)
+        (1, 1024, 13, 13)
+    """
+    # From left to right:
+    # in_channels, out_channels, num_blocks, add_identity, use_spp
+    arch_settings = {
+        'P5': [[64, 128, 3, True, False], [128, 256, 9, True, False],
+               [256, 512, 9, True, False], [512, 1024, 3, False, True]],
+        'P6': [[64, 128, 3, True, False], [128, 256, 9, True, False],
+               [256, 512, 9, True, False], [512, 768, 3, True, False],
+               [768, 1024, 3, False, True]]
+    }
+
+    def __init__(self,
+                 arch='P5',
+                 deepen_factor=1.0,
+                 widen_factor=1.0,
+                 out_indices=(2, 3, 4),
+                 frozen_stages=-1,
+                 use_depthwise=False,
+                 arch_ovewrite=None,
+                 spp_kernal_sizes=(5, 9, 13),
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+                 act_cfg=dict(type='Swish'),
+                 norm_eval=False,
+                 init_cfg=dict(
+                     type='Kaiming',
+                     layer='Conv2d',
+                     a=math.sqrt(5),
+                     distribution='uniform',
+                     mode='fan_in',
+                     nonlinearity='leaky_relu')):
+        super().__init__(init_cfg)
+        arch_setting = self.arch_settings[arch]
+        if arch_ovewrite:
+            arch_setting = arch_ovewrite
+        assert set(out_indices).issubset(
+            i for i in range(len(arch_setting) + 1))
+        if frozen_stages not in range(-1, len(arch_setting) + 1):
+            raise ValueError('frozen_stages must be in range(-1, '
+                             'len(arch_setting) + 1). But received '
+                             f'{frozen_stages}')
+
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+        self.use_depthwise = use_depthwise
+        self.norm_eval = norm_eval
+        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
+
+        self.stem = Focus(
+            3,
+            int(arch_setting[0][0] * widen_factor),
+            kernel_size=3,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        self.layers = ['stem']
+
+        for i, (in_channels, out_channels, num_blocks, add_identity,
+                use_spp) in enumerate(arch_setting):
+            in_channels = int(in_channels * widen_factor)
+            out_channels = int(out_channels * widen_factor)
+            num_blocks = max(round(num_blocks * deepen_factor), 1)
+            stage = []
+            conv_layer = conv(
+                in_channels,
+                out_channels,
+                3,
+                stride=2,
+                padding=1,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+            stage.append(conv_layer)
+            if use_spp:
+                spp = SPPBottleneck(
+                    out_channels,
+                    out_channels,
+                    kernel_sizes=spp_kernal_sizes,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg)
+                stage.append(spp)
+            csp_layer = CSPLayer(
+                out_channels,
+                out_channels,
+                num_blocks=num_blocks,
+                add_identity=add_identity,
+                use_depthwise=use_depthwise,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+            stage.append(csp_layer)
+            self.add_module(f'stage{i + 1}', nn.Sequential(*stage))
+            self.layers.append(f'stage{i + 1}')
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            for i in range(self.frozen_stages + 1):
+                m = getattr(self, self.layers[i])
+                m.eval()
+                for param in m.parameters():
+                    param.requires_grad = False
+
+    def train(self, mode=True):
+        super(CSPDarknet, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                if isinstance(m, _BatchNorm):
+                    m.eval()
+
+    def forward(self, x):
+        outs = []
+        for i, layer_name in enumerate(self.layers):
+            layer = getattr(self, layer_name)
+            x = layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+        return tuple(outs)
diff --git a/mmdet/models/backbones/cspnext.py b/mmdet/models/backbones/cspnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..269725a70224047a1f7f7564ba8199e38df25cc8
--- /dev/null
+++ b/mmdet/models/backbones/cspnext.py
@@ -0,0 +1,195 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Sequence, Tuple
+
+import torch.nn as nn
+from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+from mmengine.model import BaseModule
+from torch import Tensor
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmdet.registry import MODELS
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+from ..layers import CSPLayer
+from .csp_darknet import SPPBottleneck
+
+
+@MODELS.register_module()
+class CSPNeXt(BaseModule):
+    """CSPNeXt backbone used in RTMDet.
+
+    Args:
+        arch (str): Architecture of CSPNeXt, from {P5, P6}.
+            Defaults to P5.
+        expand_ratio (float): Ratio to adjust the number of channels of the
+            hidden layer. Defaults to 0.5.
+        deepen_factor (float): Depth multiplier, multiply number of
+            blocks in CSP layer by this amount. Defaults to 1.0.
+        widen_factor (float): Width multiplier, multiply number of
+            channels in each layer by this amount. Defaults to 1.0.
+        out_indices (Sequence[int]): Output from which stages.
+            Defaults to (2, 3, 4).
+        frozen_stages (int): Stages to be frozen (stop grad and set eval
+            mode). -1 means not freezing any parameters. Defaults to -1.
+        use_depthwise (bool): Whether to use depthwise separable convolution.
+            Defaults to False.
+        arch_ovewrite (list): Overwrite default arch settings.
+            Defaults to None.
+        spp_kernel_sizes (tuple[int]): Sequence of kernel sizes of SPP
+            layers. Defaults to (5, 9, 13).
+        channel_attention (bool): Whether to add channel attention in each
+            stage. Defaults to True.
+        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
+            convolution layer. Defaults to None.
+        norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
+            config norm layer. Defaults to
+            dict(type='BN', momentum=0.03, eps=0.001).
+        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
+            Defaults to dict(type='SiLU').
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Defaults to False.
+        init_cfg (:obj:`ConfigDict` or dict or list[dict] or
+            list[:obj:`ConfigDict`]): Initialization config dict.
+    """
+    # From left to right:
+    # in_channels, out_channels, num_blocks, add_identity, use_spp
+    arch_settings = {
+        'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
+               [256, 512, 6, True, False], [512, 1024, 3, False, True]],
+        'P6': [[64, 128, 3, True, False], [128, 256, 6, True, False],
+               [256, 512, 6, True, False], [512, 768, 3, True, False],
+               [768, 1024, 3, False, True]]
+    }
+
+    def __init__(
+        self,
+        arch: str = 'P5',
+        deepen_factor: float = 1.0,
+        widen_factor: float = 1.0,
+        out_indices: Sequence[int] = (2, 3, 4),
+        frozen_stages: int = -1,
+        use_depthwise: bool = False,
+        expand_ratio: float = 0.5,
+        arch_ovewrite: dict = None,
+        spp_kernel_sizes: Sequence[int] = (5, 9, 13),
+        channel_attention: bool = True,
+        conv_cfg: OptConfigType = None,
+        norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
+        act_cfg: ConfigType = dict(type='SiLU'),
+        norm_eval: bool = False,
+        init_cfg: OptMultiConfig = dict(
+            type='Kaiming',
+            layer='Conv2d',
+            a=math.sqrt(5),
+            distribution='uniform',
+            mode='fan_in',
+            nonlinearity='leaky_relu')
+    ) -> None:
+        super().__init__(init_cfg=init_cfg)
+        arch_setting = self.arch_settings[arch]
+        if arch_ovewrite:
+            arch_setting = arch_ovewrite
+        assert set(out_indices).issubset(
+            i for i in range(len(arch_setting) + 1))
+        if frozen_stages not in range(-1, len(arch_setting) + 1):
+            raise ValueError('frozen_stages must be in range(-1, '
+                             'len(arch_setting) + 1). But received '
+                             f'{frozen_stages}')
+
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+        self.use_depthwise = use_depthwise
+        self.norm_eval = norm_eval
+        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
+        self.stem = nn.Sequential(
+            ConvModule(
+                3,
+                int(arch_setting[0][0] * widen_factor // 2),
+                3,
+                padding=1,
+                stride=2,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg),
+            ConvModule(
+                int(arch_setting[0][0] * widen_factor // 2),
+                int(arch_setting[0][0] * widen_factor // 2),
+                3,
+                padding=1,
+                stride=1,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg),
+            ConvModule(
+                int(arch_setting[0][0] * widen_factor // 2),
+                int(arch_setting[0][0] * widen_factor),
+                3,
+                padding=1,
+                stride=1,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg))
+        self.layers = ['stem']
+
+        for i, (in_channels, out_channels, num_blocks, add_identity,
+                use_spp) in enumerate(arch_setting):
+            in_channels = int(in_channels * widen_factor)
+            out_channels = int(out_channels * widen_factor)
+            num_blocks = max(round(num_blocks * deepen_factor), 1)
+            stage = []
+            conv_layer = conv(
+                in_channels,
+                out_channels,
+                3,
+                stride=2,
+                padding=1,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+            stage.append(conv_layer)
+            if use_spp:
+                spp = SPPBottleneck(
+                    out_channels,
+                    out_channels,
+                    kernel_sizes=spp_kernel_sizes,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg)
+                stage.append(spp)
+            csp_layer = CSPLayer(
+                out_channels,
+                out_channels,
+                num_blocks=num_blocks,
+                add_identity=add_identity,
+                use_depthwise=use_depthwise,
+                use_cspnext_block=True,
+                expand_ratio=expand_ratio,
+                channel_attention=channel_attention,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+            stage.append(csp_layer)
+            self.add_module(f'stage{i + 1}', nn.Sequential(*stage))
+            self.layers.append(f'stage{i + 1}')
+
+    def _freeze_stages(self) -> None:
+        if self.frozen_stages >= 0:
+            for i in range(self.frozen_stages + 1):
+                m = getattr(self, self.layers[i])
+                m.eval()
+                for param in m.parameters():
+                    param.requires_grad = False
+
+    def train(self, mode=True) -> None:
+        super().train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                if isinstance(m, _BatchNorm):
+                    m.eval()
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
+        outs = []
+        for i, layer_name in enumerate(self.layers):
+            layer = getattr(self, layer_name)
+            x = layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+        return tuple(outs)
diff --git a/mmdet/models/backbones/darknet.py b/mmdet/models/backbones/darknet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d44da1e03f04a7e0801c10e5338277cf6244ab1
--- /dev/null
+++ b/mmdet/models/backbones/darknet.py
@@ -0,0 +1,213 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Copyright (c) 2019 Western Digital Corporation or its affiliates.
+
+import warnings
+
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmdet.registry import MODELS
+
+
+class ResBlock(BaseModule):
+    """The basic residual block used in Darknet. Each ResBlock consists of two
+    ConvModules and the input is added to the final output. Each ConvModule is
+    composed of Conv, BN, and LeakyReLU. In the YOLOv3 paper, the first conv
+    layer has half as many filters as the second one. The first conv layer
+    has a filter size of 1x1 and the second one a filter size of 3x3.
+
+    Args:
+        in_channels (int): The input channels.
Must be even. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True) + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='LeakyReLU', negative_slope=0.1), + init_cfg=None): + super(ResBlock, self).__init__(init_cfg) + assert in_channels % 2 == 0 # ensure the in_channels is even + half_in_channels = in_channels // 2 + + # shortcut + cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) + + self.conv1 = ConvModule(in_channels, half_in_channels, 1, **cfg) + self.conv2 = ConvModule( + half_in_channels, in_channels, 3, padding=1, **cfg) + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.conv2(out) + out = out + residual + + return out + + +@MODELS.register_module() +class Darknet(BaseModule): + """Darknet backbone. + + Args: + depth (int): Depth of Darknet. Currently only support 53. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True) + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + + Example: + >>> from mmdet.models import Darknet + >>> import torch + >>> self = Darknet(depth=53) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 416, 416) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + ... 
+ (1, 256, 52, 52) + (1, 512, 26, 26) + (1, 1024, 13, 13) + """ + + # Dict(depth: (layers, channels)) + arch_settings = { + 53: ((1, 2, 8, 8, 4), ((32, 64), (64, 128), (128, 256), (256, 512), + (512, 1024))) + } + + def __init__(self, + depth=53, + out_indices=(3, 4, 5), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='LeakyReLU', negative_slope=0.1), + norm_eval=True, + pretrained=None, + init_cfg=None): + super(Darknet, self).__init__(init_cfg) + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for darknet') + + self.depth = depth + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.layers, self.channels = self.arch_settings[depth] + + cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) + + self.conv1 = ConvModule(3, 32, 3, padding=1, **cfg) + + self.cr_blocks = ['conv1'] + for i, n_layers in enumerate(self.layers): + layer_name = f'conv_res_block{i + 1}' + in_c, out_c = self.channels[i] + self.add_module( + layer_name, + self.make_conv_res_block(in_c, out_c, n_layers, **cfg)) + self.cr_blocks.append(layer_name) + + self.norm_eval = norm_eval + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be specified at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + outs = [] + for i, layer_name in enumerate(self.cr_blocks): + cr_block = getattr(self, layer_name) + x = cr_block(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for i in range(self.frozen_stages): + m = getattr(self, self.cr_blocks[i]) + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(Darknet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() + + @staticmethod + def make_conv_res_block(in_channels, + out_channels, + res_repeat, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='LeakyReLU', + negative_slope=0.1)): + """In Darknet backbone, ConvLayer is usually followed by ResBlock. This + function will make that. The Conv layers always have 3x3 filters with + stride=2. The number of the filters in Conv layer is the same as the + out channels of the ResBlock. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + res_repeat (int): The number of ResBlocks. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True) + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1). 
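+
+        Example (illustrative sketch; the sizes are arbitrary and only
+        assume the defaults above, so the stride-2 conv halves the
+        resolution):
+            >>> import torch
+            >>> from mmdet.models import Darknet
+            >>> block = Darknet.make_conv_res_block(32, 64, res_repeat=1)
+            >>> tuple(block(torch.rand(1, 32, 416, 416)).shape)
+            (1, 64, 208, 208)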
+ """ + + cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) + + model = nn.Sequential() + model.add_module( + 'conv', + ConvModule( + in_channels, out_channels, 3, stride=2, padding=1, **cfg)) + for idx in range(res_repeat): + model.add_module('res{}'.format(idx), + ResBlock(out_channels, **cfg)) + return model diff --git a/mmdet/models/backbones/detectors_resnet.py b/mmdet/models/backbones/detectors_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f33424fce4a933d675f1f1d3d4ad89e0173c5f9e --- /dev/null +++ b/mmdet/models/backbones/detectors_resnet.py @@ -0,0 +1,353 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.logging import MMLogger +from mmengine.model import Sequential, constant_init, kaiming_init +from mmengine.runner.checkpoint import load_checkpoint +from torch.nn.modules.batchnorm import _BatchNorm + +from mmdet.registry import MODELS +from .resnet import BasicBlock +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNet + + +class Bottleneck(_Bottleneck): + r"""Bottleneck for the ResNet backbone in `DetectoRS + `_. + + This bottleneck allows the users to specify whether to use + SAC (Switchable Atrous Convolution) and RFP (Recursive Feature Pyramid). + + Args: + inplanes (int): The number of input channels. + planes (int): The number of output channels before expansion. + rfp_inplanes (int, optional): The number of channels from RFP. + Default: None. If specified, an additional conv layer will be + added for ``rfp_feat``. Otherwise, the structure is the same as + base class. + sac (dict, optional): Dictionary to construct SAC. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ Default: None + """ + expansion = 4 + + def __init__(self, + inplanes, + planes, + rfp_inplanes=None, + sac=None, + init_cfg=None, + **kwargs): + super(Bottleneck, self).__init__( + inplanes, planes, init_cfg=init_cfg, **kwargs) + + assert sac is None or isinstance(sac, dict) + self.sac = sac + self.with_sac = sac is not None + if self.with_sac: + self.conv2 = build_conv_layer( + self.sac, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + bias=False) + + self.rfp_inplanes = rfp_inplanes + if self.rfp_inplanes: + self.rfp_conv = build_conv_layer( + None, + self.rfp_inplanes, + planes * self.expansion, + 1, + stride=1, + bias=True) + if init_cfg is None: + self.init_cfg = dict( + type='Constant', val=0, override=dict(name='rfp_conv')) + + def rfp_forward(self, x, rfp_feat): + """The forward function that also takes the RFP features as input.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + if self.rfp_inplanes: + rfp_feat = self.rfp_conv(rfp_feat) + out = out + rfp_feat + + out = self.relu(out) + + return out + + +class ResLayer(Sequential): + """ResLayer to build ResNet style backbone for RPF in detectoRS. + + The difference between this module and base class is that we pass + ``rfp_inplanes`` to the first block. + + Args: + block (nn.Module): block used to build ResLayer. + inplanes (int): inplanes of block. + planes (int): planes of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + downsample_first (bool): Downsample at the first block or last block. + False for Hourglass, True for ResNet. Default: True + rfp_inplanes (int, optional): The number of channels from RFP. + Default: None. If specified, an additional conv layer will be + added for ``rfp_feat``. Otherwise, the structure is the same as + base class. 
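+
+    Example (illustrative sketch; the channels are arbitrary and chosen so
+    no downsample branch is built, and only the first block receives the
+    RFP lateral, matching the note above):
+        >>> from mmdet.models.backbones.detectors_resnet import (
+        ...     Bottleneck, ResLayer)
+        >>> layer = ResLayer(Bottleneck, inplanes=64, planes=16,
+        ...                  num_blocks=2, rfp_inplanes=64)
+        >>> layer[0].rfp_inplanes
+        64
+        >>> layer[1].rfp_inplanes is None
+        True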
+ """ + + def __init__(self, + block, + inplanes, + planes, + num_blocks, + stride=1, + avg_down=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + downsample_first=True, + rfp_inplanes=None, + **kwargs): + self.block = block + assert downsample_first, f'downsample_first={downsample_first} is ' \ + 'not supported in DetectoRS' + + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = [] + conv_stride = stride + if avg_down and stride != 1: + conv_stride = 1 + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=conv_stride, + bias=False), + build_norm_layer(norm_cfg, planes * block.expansion)[1] + ]) + downsample = nn.Sequential(*downsample) + + layers = [] + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + rfp_inplanes=rfp_inplanes, + **kwargs)) + inplanes = planes * block.expansion + for _ in range(1, num_blocks): + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + + super(ResLayer, self).__init__(*layers) + + +@MODELS.register_module() +class DetectoRS_ResNet(ResNet): + """ResNet backbone for DetectoRS. + + Args: + sac (dict, optional): Dictionary to construct SAC (Switchable Atrous + Convolution). Default: None. + stage_with_sac (list): Which stage to use sac. Default: (False, False, + False, False). + rfp_inplanes (int, optional): The number of channels from RFP. + Default: None. If specified, an additional conv layer will be + added for ``rfp_feat``. Otherwise, the structure is the same as + base class. + output_img (bool): If ``True``, the input image will be inserted into + the starting position of output. Default: False. 
+ """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + sac=None, + stage_with_sac=(False, False, False, False), + rfp_inplanes=None, + output_img=False, + pretrained=None, + init_cfg=None, + **kwargs): + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be specified at the same time' + self.pretrained = pretrained + if init_cfg is not None: + assert isinstance(init_cfg, dict), \ + f'init_cfg must be a dict, but got {type(init_cfg)}' + if 'type' in init_cfg: + assert init_cfg.get('type') == 'Pretrained', \ + 'Only can initialize module by loading a pretrained model' + else: + raise KeyError('`init_cfg` must contain the key "type"') + self.pretrained = init_cfg.get('checkpoint') + self.sac = sac + self.stage_with_sac = stage_with_sac + self.rfp_inplanes = rfp_inplanes + self.output_img = output_img + super(DetectoRS_ResNet, self).__init__(**kwargs) + + self.inplanes = self.stem_channels + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = self.strides[i] + dilation = self.dilations[i] + dcn = self.dcn if self.stage_with_dcn[i] else None + sac = self.sac if self.stage_with_sac[i] else None + if self.plugins is not None: + stage_plugins = self.make_stage_plugins(self.plugins, i) + else: + stage_plugins = None + planes = self.base_channels * 2**i + res_layer = self.make_res_layer( + block=self.block, + inplanes=self.inplanes, + planes=planes, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=self.with_cp, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + dcn=dcn, + sac=sac, + rfp_inplanes=rfp_inplanes if i > 0 else None, + plugins=stage_plugins) + self.inplanes = planes * self.block.expansion + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + # In order to be properly initialized by RFP + def init_weights(self): + # Calling this method will cause parameter initialization exception + # super(DetectoRS_ResNet, self).init_weights() + + if isinstance(self.pretrained, str): + logger = MMLogger.get_current_instance() + load_checkpoint(self, self.pretrained, strict=False, logger=logger) + elif self.pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + + if self.dcn is not None: + for m in self.modules(): + if isinstance(m, Bottleneck) and hasattr( + m.conv2, 'conv_offset'): + constant_init(m.conv2.conv_offset, 0) + + if self.zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + constant_init(m.norm3, 0) + elif isinstance(m, BasicBlock): + constant_init(m.norm2, 0) + else: + raise TypeError('pretrained must be a str or None') + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer`` for DetectoRS.""" + return ResLayer(**kwargs) + + def forward(self, x): + """Forward function.""" + outs = list(super(DetectoRS_ResNet, self).forward(x)) + if self.output_img: + outs.insert(0, x) + return tuple(outs) + + def rfp_forward(self, x, rfp_feats): + """Forward function for RFP.""" + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + rfp_feat = 
rfp_feats[i] if i > 0 else None + for layer in res_layer: + x = layer.rfp_forward(x, rfp_feat) + if i in self.out_indices: + outs.append(x) + return tuple(outs) diff --git a/mmdet/models/backbones/detectors_resnext.py b/mmdet/models/backbones/detectors_resnext.py new file mode 100644 index 0000000000000000000000000000000000000000..4bbd63154bb47910e27cf6a75e4b359e050063e1 --- /dev/null +++ b/mmdet/models/backbones/detectors_resnext.py @@ -0,0 +1,123 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmdet.registry import MODELS +from .detectors_resnet import Bottleneck as _Bottleneck +from .detectors_resnet import DetectoRS_ResNet + + +class Bottleneck(_Bottleneck): + expansion = 4 + + def __init__(self, + inplanes, + planes, + groups=1, + base_width=4, + base_channels=64, + **kwargs): + """Bottleneck block for ResNeXt. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if + it is "caffe", the stride-two layer is the first 1x1 conv layer. + """ + super(Bottleneck, self).__init__(inplanes, planes, **kwargs) + + if groups == 1: + width = self.planes + else: + width = math.floor(self.planes * + (base_width / base_channels)) * groups + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, width, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.inplanes, + width, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + self.with_modulated_dcn = False + if self.with_dcn: + fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + if self.with_sac: + self.conv2 = build_conv_layer( + self.sac, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + elif not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + self.conv_cfg, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + self.conv2 = build_conv_layer( + self.dcn, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + width, + self.planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + +@MODELS.register_module() +class DetectoRS_ResNeXt(DetectoRS_ResNet): + """ResNeXt backbone for DetectoRS. + + Args: + groups (int): The number of groups in ResNeXt. + base_width (int): The base width of ResNeXt. 
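+
+    Example (illustrative sketch of the grouped width computed in the
+    ``Bottleneck`` above, for the common ResNeXt 32x4d setting):
+        >>> import math
+        >>> groups, base_width, base_channels, planes = 32, 4, 64, 64
+        >>> math.floor(planes * (base_width / base_channels)) * groups
+        128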
+ """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, groups=1, base_width=4, **kwargs): + self.groups = groups + self.base_width = base_width + super(DetectoRS_ResNeXt, self).__init__(**kwargs) + + def make_res_layer(self, **kwargs): + return super().make_res_layer( + groups=self.groups, + base_width=self.base_width, + base_channels=self.base_channels, + **kwargs) diff --git a/mmdet/models/backbones/efficientnet.py b/mmdet/models/backbones/efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..8484afe2e34e2bf8327e8aefedb968bd9a1e7792 --- /dev/null +++ b/mmdet/models/backbones/efficientnet.py @@ -0,0 +1,418 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn.bricks import ConvModule, DropPath +from mmengine.model import BaseModule, Sequential + +from mmdet.registry import MODELS +from ..layers import InvertedResidual, SELayer +from ..utils import make_divisible + + +class EdgeResidual(BaseModule): + """Edge Residual Block. + + Args: + in_channels (int): The input channels of this module. + out_channels (int): The output channels of this module. + mid_channels (int): The input channels of the second convolution. + kernel_size (int): The kernel size of the first convolution. + Defaults to 3. + stride (int): The stride of the first convolution. Defaults to 1. + se_cfg (dict, optional): Config dict for se layer. Defaults to None, + which means no se layer. + with_residual (bool): Use residual connection. Defaults to True. + conv_cfg (dict, optional): Config dict for convolution layer. + Defaults to None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='ReLU')``. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict | list[dict], optional): Initialization config dict. 
+ """ + + def __init__(self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + stride=1, + se_cfg=None, + with_residual=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + drop_path_rate=0., + with_cp=False, + init_cfg=None, + **kwargs): + super(EdgeResidual, self).__init__(init_cfg=init_cfg) + assert stride in [1, 2] + self.with_cp = with_cp + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0 else nn.Identity() + self.with_se = se_cfg is not None + self.with_residual = ( + stride == 1 and in_channels == out_channels and with_residual) + + if self.with_se: + assert isinstance(se_cfg, dict) + + self.conv1 = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + if self.with_se: + self.se = SELayer(**se_cfg) + + self.conv2 = ConvModule( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=stride, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, x): + + def _inner_forward(x): + out = x + out = self.conv1(out) + + if self.with_se: + out = self.se(out) + + out = self.conv2(out) + + if self.with_residual: + return x + self.drop_path(out) + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +def model_scaling(layer_setting, arch_setting): + """Scaling operation to the layer's parameters according to the + arch_setting.""" + # scale width + new_layer_setting = copy.deepcopy(layer_setting) + for layer_cfg in new_layer_setting: + for block_cfg in layer_cfg: + block_cfg[1] = make_divisible(block_cfg[1] * arch_setting[0], 8) + + # scale depth + split_layer_setting = [new_layer_setting[0]] + for layer_cfg in new_layer_setting[1:-1]: + tmp_index = [0] + for i in range(len(layer_cfg) - 1): + if layer_cfg[i + 1][1] != layer_cfg[i][1]: + tmp_index.append(i + 1) + tmp_index.append(len(layer_cfg)) + for i in range(len(tmp_index) - 1): + split_layer_setting.append(layer_cfg[tmp_index[i]:tmp_index[i + + 1]]) + split_layer_setting.append(new_layer_setting[-1]) + + num_of_layers = [len(layer_cfg) for layer_cfg in split_layer_setting[1:-1]] + new_layers = [ + int(math.ceil(arch_setting[1] * num)) for num in num_of_layers + ] + + merge_layer_setting = [split_layer_setting[0]] + for i, layer_cfg in enumerate(split_layer_setting[1:-1]): + if new_layers[i] <= num_of_layers[i]: + tmp_layer_cfg = layer_cfg[:new_layers[i]] + else: + tmp_layer_cfg = copy.deepcopy(layer_cfg) + [layer_cfg[-1]] * ( + new_layers[i] - num_of_layers[i]) + if tmp_layer_cfg[0][3] == 1 and i != 0: + merge_layer_setting[-1] += tmp_layer_cfg.copy() + else: + merge_layer_setting.append(tmp_layer_cfg.copy()) + merge_layer_setting.append(split_layer_setting[-1]) + + return merge_layer_setting + + +@MODELS.register_module() +class EfficientNet(BaseModule): + """EfficientNet backbone. + + Args: + arch (str): Architecture of efficientnet. Defaults to b0. + out_indices (Sequence[int]): Output from which stages. + Defaults to (6, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Defaults to None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Defaults to dict(type='BN'). 
+ act_cfg (dict): Config dict for activation layer. + Defaults to dict(type='Swish'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + """ + + # Parameters to build layers. + # 'b' represents the architecture of normal EfficientNet family includes + # 'b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8'. + # 'e' represents the architecture of EfficientNet-EdgeTPU including 'es', + # 'em', 'el'. + # 6 parameters are needed to construct a layer, From left to right: + # - kernel_size: The kernel size of the block + # - out_channel: The number of out_channels of the block + # - se_ratio: The sequeeze ratio of SELayer. + # - stride: The stride of the block + # - expand_ratio: The expand_ratio of the mid_channels + # - block_type: -1: Not a block, 0: InvertedResidual, 1: EdgeResidual + layer_settings = { + 'b': [[[3, 32, 0, 2, 0, -1]], + [[3, 16, 4, 1, 1, 0]], + [[3, 24, 4, 2, 6, 0], + [3, 24, 4, 1, 6, 0]], + [[5, 40, 4, 2, 6, 0], + [5, 40, 4, 1, 6, 0]], + [[3, 80, 4, 2, 6, 0], + [3, 80, 4, 1, 6, 0], + [3, 80, 4, 1, 6, 0], + [5, 112, 4, 1, 6, 0], + [5, 112, 4, 1, 6, 0], + [5, 112, 4, 1, 6, 0]], + [[5, 192, 4, 2, 6, 0], + [5, 192, 4, 1, 6, 0], + [5, 192, 4, 1, 6, 0], + [5, 192, 4, 1, 6, 0], + [3, 320, 4, 1, 6, 0]], + [[1, 1280, 0, 1, 0, -1]] + ], + 'e': [[[3, 32, 0, 2, 0, -1]], + [[3, 24, 0, 1, 3, 1]], + [[3, 32, 0, 2, 8, 1], + [3, 32, 0, 1, 8, 1]], + [[3, 48, 0, 2, 8, 1], + [3, 48, 0, 1, 8, 1], + [3, 48, 0, 1, 8, 1], + [3, 48, 0, 1, 8, 1]], + [[5, 96, 0, 2, 8, 0], + [5, 96, 0, 1, 8, 0], + [5, 96, 0, 1, 8, 0], + [5, 96, 0, 1, 8, 0], + [5, 96, 0, 1, 8, 0], + [5, 144, 0, 1, 8, 0], + [5, 144, 0, 1, 8, 0], + [5, 144, 0, 1, 8, 0], + [5, 144, 0, 1, 8, 0]], + [[5, 192, 0, 2, 8, 0], + [5, 192, 0, 1, 8, 0]], + [[1, 1280, 0, 1, 0, -1]] + ] + } # yapf: disable + + # Parameters to build different kinds of architecture. + # From left to right: scaling factor for width, scaling factor for depth, + # resolution. + arch_settings = { + 'b0': (1.0, 1.0, 224), + 'b1': (1.0, 1.1, 240), + 'b2': (1.1, 1.2, 260), + 'b3': (1.2, 1.4, 300), + 'b4': (1.4, 1.8, 380), + 'b5': (1.6, 2.2, 456), + 'b6': (1.8, 2.6, 528), + 'b7': (2.0, 3.1, 600), + 'b8': (2.2, 3.6, 672), + 'es': (1.0, 1.0, 224), + 'em': (1.0, 1.1, 240), + 'el': (1.2, 1.4, 300) + } + + def __init__(self, + arch='b0', + drop_path_rate=0., + out_indices=(6, ), + frozen_stages=0, + conv_cfg=dict(type='Conv2dAdaptivePadding'), + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='Swish'), + norm_eval=False, + with_cp=False, + init_cfg=[ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + layer=['_BatchNorm', 'GroupNorm'], + val=1) + ]): + super(EfficientNet, self).__init__(init_cfg) + assert arch in self.arch_settings, \ + f'"{arch}" is not one of the arch_settings ' \ + f'({", ".join(self.arch_settings.keys())})' + self.arch_setting = self.arch_settings[arch] + self.layer_setting = self.layer_settings[arch[:1]] + for index in out_indices: + if index not in range(0, len(self.layer_setting)): + raise ValueError('the item in out_indices must in ' + f'range(0, {len(self.layer_setting)}). ' + f'But received {index}') + + if frozen_stages not in range(len(self.layer_setting) + 1): + raise ValueError('frozen_stages must be in range(0, ' + f'{len(self.layer_setting) + 1}). 
' + f'But received {frozen_stages}') + self.drop_path_rate = drop_path_rate + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.layer_setting = model_scaling(self.layer_setting, + self.arch_setting) + block_cfg_0 = self.layer_setting[0][0] + block_cfg_last = self.layer_setting[-1][0] + self.in_channels = make_divisible(block_cfg_0[1], 8) + self.out_channels = block_cfg_last[1] + self.layers = nn.ModuleList() + self.layers.append( + ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=block_cfg_0[0], + stride=block_cfg_0[3], + padding=block_cfg_0[0] // 2, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.make_layer() + # Avoid building unused layers in mmdetection. + if len(self.layers) < max(self.out_indices) + 1: + self.layers.append( + ConvModule( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=block_cfg_last[0], + stride=block_cfg_last[3], + padding=block_cfg_last[0] // 2, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + def make_layer(self): + # Without the first and the final conv block. + layer_setting = self.layer_setting[1:-1] + + total_num_blocks = sum([len(x) for x in layer_setting]) + block_idx = 0 + dpr = [ + x.item() + for x in torch.linspace(0, self.drop_path_rate, total_num_blocks) + ] # stochastic depth decay rule + + for i, layer_cfg in enumerate(layer_setting): + # Avoid building unused layers in mmdetection. + if i > max(self.out_indices) - 1: + break + layer = [] + for i, block_cfg in enumerate(layer_cfg): + (kernel_size, out_channels, se_ratio, stride, expand_ratio, + block_type) = block_cfg + + mid_channels = int(self.in_channels * expand_ratio) + out_channels = make_divisible(out_channels, 8) + if se_ratio <= 0: + se_cfg = None + else: + # In mmdetection, the `divisor` is deleted to align + # the logic of SELayer with mmpretrain. + se_cfg = dict( + channels=mid_channels, + ratio=expand_ratio * se_ratio, + act_cfg=(self.act_cfg, dict(type='Sigmoid'))) + if block_type == 1: # edge tpu + if i > 0 and expand_ratio == 3: + with_residual = False + expand_ratio = 4 + else: + with_residual = True + mid_channels = int(self.in_channels * expand_ratio) + if se_cfg is not None: + # In mmdetection, the `divisor` is deleted to align + # the logic of SELayer with mmpretrain. + se_cfg = dict( + channels=mid_channels, + ratio=se_ratio * expand_ratio, + act_cfg=(self.act_cfg, dict(type='Sigmoid'))) + block = partial(EdgeResidual, with_residual=with_residual) + else: + block = InvertedResidual + layer.append( + block( + in_channels=self.in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + drop_path_rate=dpr[block_idx], + with_cp=self.with_cp, + # In mmdetection, `with_expand_conv` is set to align + # the logic of InvertedResidual with mmpretrain. 
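+                        # Note: ``with_expand_conv`` matters only for
+                        # InvertedResidual, whose 1x1 expand conv is
+                        # skipped when expand_ratio == 1 (so mid_channels
+                        # equals in_channels); EdgeResidual absorbs the
+                        # kwarg via ``**kwargs`` and ignores it.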
+ with_expand_conv=(mid_channels != self.in_channels))) + self.in_channels = out_channels + block_idx += 1 + self.layers.append(Sequential(*layer)) + + def forward(self, x): + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(EfficientNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() diff --git a/mmdet/models/backbones/hourglass.py b/mmdet/models/backbones/hourglass.py new file mode 100644 index 0000000000000000000000000000000000000000..bb58799f7b32138b3f58383419ddce9aa6d5ca18 --- /dev/null +++ b/mmdet/models/backbones/hourglass.py @@ -0,0 +1,225 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptMultiConfig +from ..layers import ResLayer +from .resnet import BasicBlock + + +class HourglassModule(BaseModule): + """Hourglass Module for HourglassNet backbone. + + Generate module recursively and use BasicBlock as the base unit. + + Args: + depth (int): Depth of current HourglassModule. + stage_channels (list[int]): Feature channels of sub-modules in current + and follow-up HourglassModule. + stage_blocks (list[int]): Number of sub-modules stacked in current and + follow-up HourglassModule. + norm_cfg (ConfigType): Dictionary to construct and config norm layer. + Defaults to `dict(type='BN', requires_grad=True)` + upsample_cfg (ConfigType): Config dict for interpolate layer. + Defaults to `dict(mode='nearest')` + init_cfg (dict or ConfigDict, optional): the config to control the + initialization. + """ + + def __init__(self, + depth: int, + stage_channels: List[int], + stage_blocks: List[int], + norm_cfg: ConfigType = dict(type='BN', requires_grad=True), + upsample_cfg: ConfigType = dict(mode='nearest'), + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg) + + self.depth = depth + + cur_block = stage_blocks[0] + next_block = stage_blocks[1] + + cur_channel = stage_channels[0] + next_channel = stage_channels[1] + + self.up1 = ResLayer( + BasicBlock, cur_channel, cur_channel, cur_block, norm_cfg=norm_cfg) + + self.low1 = ResLayer( + BasicBlock, + cur_channel, + next_channel, + cur_block, + stride=2, + norm_cfg=norm_cfg) + + if self.depth > 1: + self.low2 = HourglassModule(depth - 1, stage_channels[1:], + stage_blocks[1:]) + else: + self.low2 = ResLayer( + BasicBlock, + next_channel, + next_channel, + next_block, + norm_cfg=norm_cfg) + + self.low3 = ResLayer( + BasicBlock, + next_channel, + cur_channel, + cur_block, + norm_cfg=norm_cfg, + downsample_first=False) + + self.up2 = F.interpolate + self.upsample_cfg = upsample_cfg + + def forward(self, x: torch.Tensor) -> nn.Module: + """Forward function.""" + up1 = self.up1(x) + low1 = self.low1(x) + low2 = self.low2(low1) + low3 = self.low3(low2) + # Fixing `scale factor` (e.g. 2) is common for upsampling, but + # in some cases the spatial size is mismatched and error will arise. 
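+        # E.g. a 15x15 ``up1`` yields an 8x8 ``low3`` after the stride-2
+        # branch; ``scale_factor=2`` would give 16x16 and break the sum
+        # with ``up1``, so the default ``dict(mode='nearest')`` (which has
+        # no ``scale_factor``) resizes to ``up1``'s exact size instead.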
+ if 'scale_factor' in self.upsample_cfg: + up2 = self.up2(low3, **self.upsample_cfg) + else: + shape = up1.shape[2:] + up2 = self.up2(low3, size=shape, **self.upsample_cfg) + return up1 + up2 + + +@MODELS.register_module() +class HourglassNet(BaseModule): + """HourglassNet backbone. + + Stacked Hourglass Networks for Human Pose Estimation. + More details can be found in the `paper + `_ . + + Args: + downsample_times (int): Downsample times in a HourglassModule. + num_stacks (int): Number of HourglassModule modules stacked, + 1 for Hourglass-52, 2 for Hourglass-104. + stage_channels (Sequence[int]): Feature channel of each sub-module in a + HourglassModule. + stage_blocks (Sequence[int]): Number of sub-modules stacked in a + HourglassModule. + feat_channel (int): Feature channel of conv after a HourglassModule. + norm_cfg (norm_cfg): Dictionary to construct and config norm layer. + init_cfg (dict or ConfigDict, optional): the config to control the + initialization. + + Example: + >>> from mmdet.models import HourglassNet + >>> import torch + >>> self = HourglassNet() + >>> self.eval() + >>> inputs = torch.rand(1, 3, 511, 511) + >>> level_outputs = self.forward(inputs) + >>> for level_output in level_outputs: + ... print(tuple(level_output.shape)) + (1, 256, 128, 128) + (1, 256, 128, 128) + """ + + def __init__(self, + downsample_times: int = 5, + num_stacks: int = 2, + stage_channels: Sequence = (256, 256, 384, 384, 384, 512), + stage_blocks: Sequence = (2, 2, 2, 2, 2, 4), + feat_channel: int = 256, + norm_cfg: ConfigType = dict(type='BN', requires_grad=True), + init_cfg: OptMultiConfig = None) -> None: + assert init_cfg is None, 'To prevent abnormal initialization ' \ + 'behavior, init_cfg is not allowed to be set' + super().__init__(init_cfg) + + self.num_stacks = num_stacks + assert self.num_stacks >= 1 + assert len(stage_channels) == len(stage_blocks) + assert len(stage_channels) > downsample_times + + cur_channel = stage_channels[0] + + self.stem = nn.Sequential( + ConvModule( + 3, cur_channel // 2, 7, padding=3, stride=2, + norm_cfg=norm_cfg), + ResLayer( + BasicBlock, + cur_channel // 2, + cur_channel, + 1, + stride=2, + norm_cfg=norm_cfg)) + + self.hourglass_modules = nn.ModuleList([ + HourglassModule(downsample_times, stage_channels, stage_blocks) + for _ in range(num_stacks) + ]) + + self.inters = ResLayer( + BasicBlock, + cur_channel, + cur_channel, + num_stacks - 1, + norm_cfg=norm_cfg) + + self.conv1x1s = nn.ModuleList([ + ConvModule( + cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None) + for _ in range(num_stacks - 1) + ]) + + self.out_convs = nn.ModuleList([ + ConvModule( + cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg) + for _ in range(num_stacks) + ]) + + self.remap_convs = nn.ModuleList([ + ConvModule( + feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None) + for _ in range(num_stacks - 1) + ]) + + self.relu = nn.ReLU(inplace=True) + + def init_weights(self) -> None: + """Init module weights.""" + # Training Centripetal Model needs to reset parameters for Conv2d + super().init_weights() + for m in self.modules(): + if isinstance(m, nn.Conv2d): + m.reset_parameters() + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + """Forward function.""" + inter_feat = self.stem(x) + out_feats = [] + + for ind in range(self.num_stacks): + single_hourglass = self.hourglass_modules[ind] + out_conv = self.out_convs[ind] + + hourglass_feat = single_hourglass(inter_feat) + out_feat = out_conv(hourglass_feat) + out_feats.append(out_feat) + + if 
ind < self.num_stacks - 1: + inter_feat = self.conv1x1s[ind]( + inter_feat) + self.remap_convs[ind]( + out_feat) + inter_feat = self.inters[ind](self.relu(inter_feat)) + + return out_feats diff --git a/mmdet/models/backbones/hrnet.py b/mmdet/models/backbones/hrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..77bd3cc7125bb7ba03cd201ab3a55174b01dde50 --- /dev/null +++ b/mmdet/models/backbones/hrnet.py @@ -0,0 +1,589 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule, ModuleList, Sequential +from torch.nn.modules.batchnorm import _BatchNorm + +from mmdet.registry import MODELS +from .resnet import BasicBlock, Bottleneck + + +class HRModule(BaseModule): + """High-Resolution Module for HRNet. + + In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange + is in this module. + """ + + def __init__(self, + num_branches, + blocks, + num_blocks, + in_channels, + num_channels, + multiscale_output=True, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + block_init_cfg=None, + init_cfg=None): + super(HRModule, self).__init__(init_cfg) + self.block_init_cfg = block_init_cfg + self._check_branches(num_branches, num_blocks, in_channels, + num_channels) + + self.in_channels = in_channels + self.num_branches = num_branches + + self.multiscale_output = multiscale_output + self.norm_cfg = norm_cfg + self.conv_cfg = conv_cfg + self.with_cp = with_cp + self.branches = self._make_branches(num_branches, blocks, num_blocks, + num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=False) + + def _check_branches(self, num_branches, num_blocks, in_channels, + num_channels): + if num_branches != len(num_blocks): + error_msg = f'NUM_BRANCHES({num_branches}) ' \ + f'!= NUM_BLOCKS({len(num_blocks)})' + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = f'NUM_BRANCHES({num_branches}) ' \ + f'!= NUM_CHANNELS({len(num_channels)})' + raise ValueError(error_msg) + + if num_branches != len(in_channels): + error_msg = f'NUM_BRANCHES({num_branches}) ' \ + f'!= NUM_INCHANNELS({len(in_channels)})' + raise ValueError(error_msg) + + def _make_one_branch(self, + branch_index, + block, + num_blocks, + num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.in_channels[branch_index] != \ + num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + build_conv_layer( + self.conv_cfg, + self.in_channels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(self.norm_cfg, num_channels[branch_index] * + block.expansion)[1]) + + layers = [] + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index], + stride, + downsample=downsample, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=self.block_init_cfg)) + self.in_channels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index], + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=self.block_init_cfg)) + + return Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( 
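+                # each branch is an independent ResLayer kept at its own
+                # resolution; cross-resolution exchange is handled by the
+                # fuse layers built next in ``_make_fuse_layers``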
+ self._make_one_branch(i, block, num_blocks, num_channels)) + + return ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + in_channels = self.in_channels + fuse_layers = [] + num_out_branches = num_branches if self.multiscale_output else 1 + for i in range(num_out_branches): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, in_channels[i])[1], + nn.Upsample( + scale_factor=2**(j - i), mode='nearest'))) + elif j == i: + fuse_layer.append(None) + else: + conv_downsamples = [] + for k in range(i - j): + if k == i - j - 1: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[i])[1])) + else: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[j], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[j])[1], + nn.ReLU(inplace=False))) + fuse_layer.append(nn.Sequential(*conv_downsamples)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def forward(self, x): + """Forward function.""" + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = 0 + for j in range(self.num_branches): + if i == j: + y += x[j] + else: + y += self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + return x_fuse + + +@MODELS.register_module() +class HRNet(BaseModule): + """HRNet backbone. + + `High-Resolution Representations for Labeling Pixels and Regions + arXiv: `_. + + Args: + extra (dict): Detailed configuration for each stage of HRNet. + There must be 4 stages, the configuration for each stage must have + 5 keys: + + - num_modules(int): The number of HRModule in this stage. + - num_branches(int): The number of branches in the HRModule. + - block(str): The type of convolution block. + - num_blocks(tuple): The number of blocks in each branch. + The length must be equal to num_branches. + - num_channels(tuple): The number of channels in each branch. + The length must be equal to num_branches. + in_channels (int): Number of input image channels. Default: 3. + conv_cfg (dict): Dictionary to construct and config conv layer. + norm_cfg (dict): Dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: True. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: False. + multiscale_output (bool): Whether to output multi-level features + produced by multiple branches. If False, only the first level + feature will be output. Default: True. + pretrained (str, optional): Model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
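+
+    Note: each stage's ``num_channels`` entries are multiplied by the
+    block expansion (1 for ``BASIC``, 4 for ``BOTTLENECK``) before the
+    transition layers are built, so with the all-``BASIC`` stages in the
+    example below the branch channels stay (32, 64, 128, 256).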
+ + Example: + >>> from mmdet.models import HRNet + >>> import torch + >>> extra = dict( + >>> stage1=dict( + >>> num_modules=1, + >>> num_branches=1, + >>> block='BOTTLENECK', + >>> num_blocks=(4, ), + >>> num_channels=(64, )), + >>> stage2=dict( + >>> num_modules=1, + >>> num_branches=2, + >>> block='BASIC', + >>> num_blocks=(4, 4), + >>> num_channels=(32, 64)), + >>> stage3=dict( + >>> num_modules=4, + >>> num_branches=3, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4), + >>> num_channels=(32, 64, 128)), + >>> stage4=dict( + >>> num_modules=3, + >>> num_branches=4, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4, 4), + >>> num_channels=(32, 64, 128, 256))) + >>> self = HRNet(extra, in_channels=1) + >>> self.eval() + >>> inputs = torch.rand(1, 1, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 32, 8, 8) + (1, 64, 4, 4) + (1, 128, 2, 2) + (1, 256, 1, 1) + """ + + blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} + + def __init__(self, + extra, + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN'), + norm_eval=True, + with_cp=False, + zero_init_residual=False, + multiscale_output=True, + pretrained=None, + init_cfg=None): + super(HRNet, self).__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be specified at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + # Assert configurations of 4 stages are in extra + assert 'stage1' in extra and 'stage2' in extra \ + and 'stage3' in extra and 'stage4' in extra + # Assert whether the length of `num_blocks` and `num_channels` are + # equal to `num_branches` + for i in range(4): + cfg = extra[f'stage{i + 1}'] + assert len(cfg['num_blocks']) == cfg['num_branches'] and \ + len(cfg['num_channels']) == cfg['num_branches'] + + self.extra = extra + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + self.zero_init_residual = zero_init_residual + + # stem net + self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1) + self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2) + + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + 64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + self.conv_cfg, + 64, + 64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.relu = nn.ReLU(inplace=True) + + # stage 1 + self.stage1_cfg = self.extra['stage1'] + num_channels = self.stage1_cfg['num_channels'][0] + block_type = self.stage1_cfg['block'] + num_blocks = self.stage1_cfg['num_blocks'][0] + + block = self.blocks_dict[block_type] + stage1_out_channels = num_channels * block.expansion + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + + # stage 2 + self.stage2_cfg = self.extra['stage2'] + num_channels = self.stage2_cfg['num_channels'] + block_type = self.stage2_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = 
[channel * block.expansion for channel in num_channels] + self.transition1 = self._make_transition_layer([stage1_out_channels], + num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + # stage 3 + self.stage3_cfg = self.extra['stage3'] + num_channels = self.stage3_cfg['num_channels'] + block_type = self.stage3_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [channel * block.expansion for channel in num_channels] + self.transition2 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + # stage 4 + self.stage4_cfg = self.extra['stage4'] + num_channels = self.stage4_cfg['num_channels'] + block_type = self.stage4_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [channel * block.expansion for channel in num_channels] + self.transition3 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multiscale_output=multiscale_output) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: the normalization layer named "norm2" """ + return getattr(self, self.norm2_name) + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + num_channels_cur_layer[i])[1], + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv_downsamples = [] + for j in range(i + 1 - num_branches_pre): + in_channels = num_channels_pre_layer[-1] + out_channels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else in_channels + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, out_channels)[1], + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv_downsamples)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + build_conv_layer( + self.conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(self.norm_cfg, planes * block.expansion)[1]) + + layers = [] + block_init_cfg = None + if self.pretrained is None and not hasattr( + self, 'init_cfg') and self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm3')) + layers.append( + block( + inplanes, + planes, + stride, + downsample=downsample, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=block_init_cfg, + )) + inplanes = planes * 
block.expansion + for i in range(1, blocks): + layers.append( + block( + inplanes, + planes, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=block_init_cfg)) + + return Sequential(*layers) + + def _make_stage(self, layer_config, in_channels, multiscale_output=True): + num_modules = layer_config['num_modules'] + num_branches = layer_config['num_branches'] + num_blocks = layer_config['num_blocks'] + num_channels = layer_config['num_channels'] + block = self.blocks_dict[layer_config['block']] + + hr_modules = [] + block_init_cfg = None + if self.pretrained is None and not hasattr( + self, 'init_cfg') and self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm3')) + + for i in range(num_modules): + # multi_scale_output is only used for the last module + if not multiscale_output and i == num_modules - 1: + reset_multiscale_output = False + else: + reset_multiscale_output = True + + hr_modules.append( + HRModule( + num_branches, + block, + num_blocks, + in_channels, + num_channels, + reset_multiscale_output, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + block_init_cfg=block_init_cfg)) + + return Sequential(*hr_modules), in_channels + + def forward(self, x): + """Forward function.""" + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['num_branches']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['num_branches']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['num_branches']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + return y_list + + def train(self, mode=True): + """Convert the model into training mode will keeping the normalization + layer freezed.""" + super(HRNet, self).train(mode) + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmdet/models/backbones/mobilenet_v2.py b/mmdet/models/backbones/mobilenet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..a4fd0519ad4d5106e1acb82624d6393052596ce8 --- /dev/null +++ b/mmdet/models/backbones/mobilenet_v2.py @@ -0,0 +1,198 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmdet.registry import MODELS +from ..layers import InvertedResidual +from ..utils import make_divisible + + +@MODELS.register_module() +class MobileNetV2(BaseModule): + """MobileNetV2 backbone. + + Args: + widen_factor (float): Width multiplier, multiply number of + channels in each layer by this amount. Default: 1.0. + out_indices (Sequence[int], optional): Output from which stages. + Default: (1, 2, 4, 7). + frozen_stages (int): Stages to be frozen (all param fixed). 
+ Default: -1, which means not freezing any parameters. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + # Parameters to build layers. 4 parameters are needed to construct a + # layer, from left to right: expand_ratio, channel, num_blocks, stride. + arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2], + [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2], + [6, 320, 1, 1]] + + def __init__(self, + widen_factor=1., + out_indices=(1, 2, 4, 7), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None): + super(MobileNetV2, self).__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be specified at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + self.widen_factor = widen_factor + self.out_indices = out_indices + if not set(out_indices).issubset(set(range(0, 8))): + raise ValueError('out_indices must be a subset of range' + f'(0, 8). But received {out_indices}') + + if frozen_stages not in range(-1, 8): + raise ValueError('frozen_stages must be in range(-1, 8). 
' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.in_channels = make_divisible(32 * widen_factor, 8) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.layers = [] + + for i, layer_cfg in enumerate(self.arch_settings): + expand_ratio, channel, num_blocks, stride = layer_cfg + out_channels = make_divisible(channel * widen_factor, 8) + inverted_res_layer = self.make_layer( + out_channels=out_channels, + num_blocks=num_blocks, + stride=stride, + expand_ratio=expand_ratio) + layer_name = f'layer{i + 1}' + self.add_module(layer_name, inverted_res_layer) + self.layers.append(layer_name) + + if widen_factor > 1.0: + self.out_channel = int(1280 * widen_factor) + else: + self.out_channel = 1280 + + layer = ConvModule( + in_channels=self.in_channels, + out_channels=self.out_channel, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.add_module('conv2', layer) + self.layers.append('conv2') + + def make_layer(self, out_channels, num_blocks, stride, expand_ratio): + """Stack InvertedResidual blocks to build a layer for MobileNetV2. + + Args: + out_channels (int): out_channels of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + expand_ratio (int): Expand the number of channels of the + hidden layer in InvertedResidual by this ratio. Default: 6. + """ + layers = [] + for i in range(num_blocks): + if i >= 1: + stride = 1 + layers.append( + InvertedResidual( + self.in_channels, + out_channels, + mid_channels=int(round(self.in_channels * expand_ratio)), + stride=stride, + with_expand_conv=expand_ratio != 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def forward(self, x): + """Forward function.""" + x = self.conv1(x) + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep normalization layer + frozen.""" + super(MobileNetV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmdet/models/backbones/pvt.py b/mmdet/models/backbones/pvt.py new file mode 100644 index 0000000000000000000000000000000000000000..8b250f63c1b22f21a892faf4c41ccc2d20e83e13 --- /dev/null +++ b/mmdet/models/backbones/pvt.py @@ -0,0 +1,665 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
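+# The components below follow the Pyramid Vision Transformer design:
+# MixFFN replaces the Linear layers of a standard FFN with 1x1 convs
+# (optionally adding a 3x3 depthwise conv as positional encoding), and
+# SpatialReductionAttention shrinks keys/values by ``sr_ratio`` before
+# the multi-head attention to reduce its cost.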
+import math +import warnings +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import MultiheadAttention +from mmengine.logging import MMLogger +from mmengine.model import (BaseModule, ModuleList, Sequential, constant_init, + normal_init, trunc_normal_init) +from mmengine.model.weight_init import trunc_normal_ +from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict +from torch.nn.modules.utils import _pair as to_2tuple + +from mmdet.registry import MODELS +from ..layers import PatchEmbed, nchw_to_nlc, nlc_to_nchw + + +class MixFFN(BaseModule): + """An implementation of MixFFN of PVT. + + The differences between MixFFN & FFN: + 1. Use 1X1 Conv to replace Linear layer. + 2. Introduce 3X3 Depth-wise Conv to encode positional information. + + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. + feedforward_channels (int): The hidden dimension of FFNs. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='GELU'). + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + Default: None. + use_conv (bool): If True, add 3x3 DWConv between two Linear layers. + Defaults: False. + init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + feedforward_channels, + act_cfg=dict(type='GELU'), + ffn_drop=0., + dropout_layer=None, + use_conv=False, + init_cfg=None): + super(MixFFN, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.act_cfg = act_cfg + activate = build_activation_layer(act_cfg) + + in_channels = embed_dims + fc1 = Conv2d( + in_channels=in_channels, + out_channels=feedforward_channels, + kernel_size=1, + stride=1, + bias=True) + if use_conv: + # 3x3 depth wise conv to provide positional encode information + dw_conv = Conv2d( + in_channels=feedforward_channels, + out_channels=feedforward_channels, + kernel_size=3, + stride=1, + padding=(3 - 1) // 2, + bias=True, + groups=feedforward_channels) + fc2 = Conv2d( + in_channels=feedforward_channels, + out_channels=in_channels, + kernel_size=1, + stride=1, + bias=True) + drop = nn.Dropout(ffn_drop) + layers = [fc1, activate, drop, fc2, drop] + if use_conv: + layers.insert(1, dw_conv) + self.layers = Sequential(*layers) + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else torch.nn.Identity() + + def forward(self, x, hw_shape, identity=None): + out = nlc_to_nchw(x, hw_shape) + out = self.layers(out) + out = nchw_to_nlc(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +class SpatialReductionAttention(MultiheadAttention): + """An implementation of Spatial Reduction Attention of PVT. + + This module is modified from MultiheadAttention which is a module from + mmcv.cnn.bricks.transformer. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. + Default: 0.0. 
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. Default: None. + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: False. + qkv_bias (bool): enable bias for qkv if True. Default: True. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (int): The ratio of spatial reduction of Spatial Reduction + Attention of PVT. Default: 1. + init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + dropout_layer=None, + batch_first=True, + qkv_bias=True, + norm_cfg=dict(type='LN'), + sr_ratio=1, + init_cfg=None): + super().__init__( + embed_dims, + num_heads, + attn_drop, + proj_drop, + batch_first=batch_first, + dropout_layer=dropout_layer, + bias=qkv_bias, + init_cfg=init_cfg) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = Conv2d( + in_channels=embed_dims, + out_channels=embed_dims, + kernel_size=sr_ratio, + stride=sr_ratio) + # The ret[0] of build_norm_layer is norm name. + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + + # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa + from mmdet import digit_version, mmcv_version + if mmcv_version < digit_version('1.3.17'): + warnings.warn('The legacy version of forward function in' + 'SpatialReductionAttention is deprecated in' + 'mmcv>=1.3.17 and will no longer support in the' + 'future. Please upgrade your mmcv.') + self.forward = self.legacy_forward + + def forward(self, x, hw_shape, identity=None): + + x_q = x + if self.sr_ratio > 1: + x_kv = nlc_to_nchw(x, hw_shape) + x_kv = self.sr(x_kv) + x_kv = nchw_to_nlc(x_kv) + x_kv = self.norm(x_kv) + else: + x_kv = x + + if identity is None: + identity = x_q + + # Because the dataflow('key', 'query', 'value') of + # ``torch.nn.MultiheadAttention`` is (num_queries, batch, + # embed_dims), We should adjust the shape of dataflow from + # batch_first (batch, num_queries, embed_dims) to num_queries_first + # (num_queries ,batch, embed_dims), and recover ``attn_output`` + # from num_queries_first to batch_first. + if self.batch_first: + x_q = x_q.transpose(0, 1) + x_kv = x_kv.transpose(0, 1) + + out = self.attn(query=x_q, key=x_kv, value=x_kv)[0] + + if self.batch_first: + out = out.transpose(0, 1) + + return identity + self.dropout_layer(self.proj_drop(out)) + + def legacy_forward(self, x, hw_shape, identity=None): + """multi head attention forward in mmcv version < 1.3.17.""" + x_q = x + if self.sr_ratio > 1: + x_kv = nlc_to_nchw(x, hw_shape) + x_kv = self.sr(x_kv) + x_kv = nchw_to_nlc(x_kv) + x_kv = self.norm(x_kv) + else: + x_kv = x + + if identity is None: + identity = x_q + + out = self.attn(query=x_q, key=x_kv, value=x_kv)[0] + + return identity + self.dropout_layer(self.proj_drop(out)) + + +class PVTEncoderLayer(BaseModule): + """Implements one encoder layer in PVT. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed. + after the feed forward layer. Default: 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): stochastic depth rate. Default: 0.0. + qkv_bias (bool): enable bias for qkv if True. + Default: True. + act_cfg (dict): The activation config for FFNs. 
+ Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (int): The ratio of spatial reduction of Spatial Reduction + Attention of PVT. Default: 1. + use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN. + Default: False. + init_cfg (dict, optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + sr_ratio=1, + use_conv_ffn=False, + init_cfg=None): + super(PVTEncoderLayer, self).__init__(init_cfg=init_cfg) + + # The ret[0] of build_norm_layer is norm name. + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + + self.attn = SpatialReductionAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio) + + # The ret[0] of build_norm_layer is norm name. + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + + self.ffn = MixFFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + use_conv=use_conv_ffn, + act_cfg=act_cfg) + + def forward(self, x, hw_shape): + x = self.attn(self.norm1(x), hw_shape, identity=x) + x = self.ffn(self.norm2(x), hw_shape, identity=x) + + return x + + +class AbsolutePositionEmbedding(BaseModule): + """An implementation of the absolute position embedding in PVT. + + Args: + pos_shape (int): The shape of the absolute position embedding. + pos_dim (int): The dimension of the absolute position embedding. + drop_rate (float): Probability of an element to be zeroed. + Default: 0.0. + """ + + def __init__(self, pos_shape, pos_dim, drop_rate=0., init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(pos_shape, int): + pos_shape = to_2tuple(pos_shape) + elif isinstance(pos_shape, tuple): + if len(pos_shape) == 1: + pos_shape = to_2tuple(pos_shape[0]) + assert len(pos_shape) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(pos_shape)}' + self.pos_shape = pos_shape + self.pos_dim = pos_dim + + self.pos_embed = nn.Parameter( + torch.zeros(1, pos_shape[0] * pos_shape[1], pos_dim)) + self.drop = nn.Dropout(p=drop_rate) + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + + def resize_pos_embed(self, pos_embed, input_shape, mode='bilinear'): + """Resize pos_embed weights. + + Resize pos_embed using bilinear interpolate method. + + Args: + pos_embed (torch.Tensor): Position embedding weights. + input_shape (tuple): Tuple for (downsampled input image height, + downsampled input image width). + mode (str): Algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'bilinear'``. + + Return: + torch.Tensor: The resized pos_embed of shape [B, L_new, C]. 
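+
+        Example (an illustrative sketch; the 7x7 grid and the 64-dim
+        embedding are arbitrary choices):
+            >>> ape = AbsolutePositionEmbedding(pos_shape=7, pos_dim=64)
+            >>> tuple(ape.resize_pos_embed(ape.pos_embed, (14, 14)).shape)
+            (1, 196, 64)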
+ """ + assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]' + pos_h, pos_w = self.pos_shape + pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):] + pos_embed_weight = pos_embed_weight.reshape( + 1, pos_h, pos_w, self.pos_dim).permute(0, 3, 1, 2).contiguous() + pos_embed_weight = F.interpolate( + pos_embed_weight, size=input_shape, mode=mode) + pos_embed_weight = torch.flatten(pos_embed_weight, + 2).transpose(1, 2).contiguous() + pos_embed = pos_embed_weight + + return pos_embed + + def forward(self, x, hw_shape, mode='bilinear'): + pos_embed = self.resize_pos_embed(self.pos_embed, hw_shape, mode) + return self.drop(x + pos_embed) + + +@MODELS.register_module() +class PyramidVisionTransformer(BaseModule): + """Pyramid Vision Transformer (PVT) + + Implementation of `Pyramid Vision Transformer: A Versatile Backbone for + Dense Prediction without Convolutions + `_. + + Args: + pretrain_img_size (int | tuple[int]): The size of input image when + pretrain. Defaults: 224. + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): Embedding dimension. Default: 64. + num_stags (int): The num of stages. Default: 4. + num_layers (Sequence[int]): The layer number of each transformer encode + layer. Default: [3, 4, 6, 3]. + num_heads (Sequence[int]): The attention heads of each transformer + encode layer. Default: [1, 2, 5, 8]. + patch_sizes (Sequence[int]): The patch_size of each patch embedding. + Default: [4, 2, 2, 2]. + strides (Sequence[int]): The stride of each patch embedding. + Default: [4, 2, 2, 2]. + paddings (Sequence[int]): The padding of each patch embedding. + Default: [0, 0, 0, 0]. + sr_ratios (Sequence[int]): The spatial reduction rate of each + transformer encode layer. Default: [8, 4, 2, 1]. + out_indices (Sequence[int] | int): Output from which stages. + Default: (0, 1, 2, 3). + mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the + embedding dim of each transformer encode layer. + Default: [8, 8, 4, 4]. + qkv_bias (bool): Enable bias for qkv if True. Default: True. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0. + drop_path_rate (float): stochastic depth rate. Default 0.1. + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults: True. + use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN. + Default: False. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + pretrained (str, optional): model pretrained path. Default: None. + convert_weights (bool): The flag indicates whether the + pre-trained model is from the original repo. We may need + to convert some keys to make it compatible. + Default: True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
+ """ + + def __init__(self, + pretrain_img_size=224, + in_channels=3, + embed_dims=64, + num_stages=4, + num_layers=[3, 4, 6, 3], + num_heads=[1, 2, 5, 8], + patch_sizes=[4, 2, 2, 2], + strides=[4, 2, 2, 2], + paddings=[0, 0, 0, 0], + sr_ratios=[8, 4, 2, 1], + out_indices=(0, 1, 2, 3), + mlp_ratios=[8, 8, 4, 4], + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + use_abs_pos_embed=True, + norm_after_stage=False, + use_conv_ffn=False, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN', eps=1e-6), + pretrained=None, + convert_weights=True, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.convert_weights = convert_weights + if isinstance(pretrain_img_size, int): + pretrain_img_size = to_2tuple(pretrain_img_size) + elif isinstance(pretrain_img_size, tuple): + if len(pretrain_img_size) == 1: + pretrain_img_size = to_2tuple(pretrain_img_size[0]) + assert len(pretrain_img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(pretrain_img_size)}' + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + self.init_cfg = init_cfg + else: + raise TypeError('pretrained must be a str or None') + + self.embed_dims = embed_dims + + self.num_stages = num_stages + self.num_layers = num_layers + self.num_heads = num_heads + self.patch_sizes = patch_sizes + self.strides = strides + self.sr_ratios = sr_ratios + assert num_stages == len(num_layers) == len(num_heads) \ + == len(patch_sizes) == len(strides) == len(sr_ratios) + + self.out_indices = out_indices + assert max(out_indices) < self.num_stages + self.pretrained = pretrained + + # transformer encoder + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(num_layers)) + ] # stochastic num_layer decay rule + + cur = 0 + self.layers = ModuleList() + for i, num_layer in enumerate(num_layers): + embed_dims_i = embed_dims * num_heads[i] + patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims_i, + kernel_size=patch_sizes[i], + stride=strides[i], + padding=paddings[i], + bias=True, + norm_cfg=norm_cfg) + + layers = ModuleList() + if use_abs_pos_embed: + pos_shape = pretrain_img_size // np.prod(patch_sizes[:i + 1]) + pos_embed = AbsolutePositionEmbedding( + pos_shape=pos_shape, + pos_dim=embed_dims_i, + drop_rate=drop_rate) + layers.append(pos_embed) + layers.extend([ + PVTEncoderLayer( + embed_dims=embed_dims_i, + num_heads=num_heads[i], + feedforward_channels=mlp_ratios[i] * embed_dims_i, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[cur + idx], + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + sr_ratio=sr_ratios[i], + use_conv_ffn=use_conv_ffn) for idx in range(num_layer) + ]) + in_channels = embed_dims_i + # The ret[0] of build_norm_layer is norm name. 
+            if norm_after_stage:
+                norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
+            else:
+                norm = nn.Identity()
+            self.layers.append(ModuleList([patch_embed, layers, norm]))
+            cur += num_layer
+
+    def init_weights(self):
+        logger = MMLogger.get_current_instance()
+        if self.init_cfg is None:
+            logger.warning(f'No pre-trained weights for '
+                           f'{self.__class__.__name__}, '
+                           f'training starts from scratch')
+            for m in self.modules():
+                if isinstance(m, nn.Linear):
+                    trunc_normal_init(m, std=.02, bias=0.)
+                elif isinstance(m, nn.LayerNorm):
+                    constant_init(m, 1.0)
+                elif isinstance(m, nn.Conv2d):
+                    fan_out = m.kernel_size[0] * m.kernel_size[
+                        1] * m.out_channels
+                    fan_out //= m.groups
+                    normal_init(m, 0, math.sqrt(2.0 / fan_out))
+                elif isinstance(m, AbsolutePositionEmbedding):
+                    m.init_weights()
+        else:
+            assert 'checkpoint' in self.init_cfg, \
+                f'Only support specifying `Pretrained` in `init_cfg` in ' \
+                f'{self.__class__.__name__}'
+            checkpoint = CheckpointLoader.load_checkpoint(
+                self.init_cfg.checkpoint, logger=logger, map_location='cpu')
+            logger.warning(f'Load pre-trained model for '
+                           f'{self.__class__.__name__} from original repo')
+            if 'state_dict' in checkpoint:
+                state_dict = checkpoint['state_dict']
+            elif 'model' in checkpoint:
+                state_dict = checkpoint['model']
+            else:
+                state_dict = checkpoint
+            if self.convert_weights:
+                # Because pvt backbones are not supported by mmpretrain,
+                # we need to convert pre-trained weights to match this
+                # implementation.
+                state_dict = pvt_convert(state_dict)
+            load_state_dict(self, state_dict, strict=False, logger=logger)
+
+    def forward(self, x):
+        outs = []
+
+        for i, layer in enumerate(self.layers):
+            x, hw_shape = layer[0](x)
+
+            for block in layer[1]:
+                x = block(x, hw_shape)
+            x = layer[2](x)
+            x = nlc_to_nchw(x, hw_shape)
+            if i in self.out_indices:
+                outs.append(x)
+
+        return outs
+
+
+@MODELS.register_module()
+class PyramidVisionTransformerV2(PyramidVisionTransformer):
+    """Implementation of `PVTv2: Improved Baselines with Pyramid Vision
+    Transformer <https://arxiv.org/abs/2106.13797>`_."""
+
+    def __init__(self, **kwargs):
+        super(PyramidVisionTransformerV2, self).__init__(
+            patch_sizes=[7, 3, 3, 3],
+            paddings=[3, 1, 1, 1],
+            use_abs_pos_embed=False,
+            norm_after_stage=True,
+            use_conv_ffn=True,
+            **kwargs)
+
+
+def pvt_convert(ckpt):
+    new_ckpt = OrderedDict()
+    # Process the concat between q linear weights and kv linear weights
+    use_abs_pos_embed = False
+    use_conv_ffn = False
+    for k in ckpt.keys():
+        if k.startswith('pos_embed'):
+            use_abs_pos_embed = True
+        if k.find('dwconv') >= 0:
+            use_conv_ffn = True
+    for k, v in ckpt.items():
+        if k.startswith('head'):
+            continue
+        if k.startswith('norm.'):
+            continue
+        if k.startswith('cls_token'):
+            continue
+        if k.startswith('pos_embed'):
+            stage_i = int(k.replace('pos_embed', ''))
+            new_k = k.replace(f'pos_embed{stage_i}',
+                              f'layers.{stage_i - 1}.1.0.pos_embed')
+            if stage_i == 4 and v.size(1) == 50:  # 1 (cls token) + 7 * 7
+                new_v = v[:, 1:, :]  # remove cls token
+            else:
+                new_v = v
+        elif k.startswith('patch_embed'):
+            stage_i = int(k.split('.')[0].replace('patch_embed', ''))
+            new_k = k.replace(f'patch_embed{stage_i}',
+                              f'layers.{stage_i - 1}.0')
+            new_v = v
+            if 'proj.' in new_k:
+                new_k = new_k.replace('proj.', 'projection.')
+        elif k.startswith('block'):
+            stage_i = int(k.split('.')[0].replace('block', ''))
+            layer_i = int(k.split('.')[1])
+            new_layer_i = layer_i + use_abs_pos_embed
+            new_k = k.replace(f'block{stage_i}.{layer_i}',
+                              f'layers.{stage_i - 1}.1.{new_layer_i}')
+            new_v = v
+            if 'attn.q.' in new_k:
+                sub_item_k = k.replace('q.', 'kv.')
+                new_k = new_k.replace('q.', 'attn.in_proj_')
+                new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
+            elif 'attn.kv.' in new_k:
+                # kv weights were already merged into in_proj above
+                continue
+            elif 'attn.proj.' in new_k:
+                new_k = new_k.replace('proj.', 'attn.out_proj.')
+            elif 'attn.sr.' in new_k:
+                # spatial-reduction conv/norm keys already match; keep as-is
+                pass
+            elif 'mlp.' in new_k:
+                new_k = new_k.replace('mlp.', 'ffn.layers.')
+                if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
+                    new_v = v.reshape((*v.shape, 1, 1))
+                new_k = new_k.replace('fc1.', '0.')
+                new_k = new_k.replace('dwconv.dwconv.', '1.')
+                if use_conv_ffn:
+                    new_k = new_k.replace('fc2.', '4.')
+                else:
+                    new_k = new_k.replace('fc2.', '3.')
+        elif k.startswith('norm'):
+            stage_i = int(k[4])
+            new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i - 1}.2')
+            new_v = v
+        else:
+            new_k = k
+            new_v = v
+        new_ckpt[new_k] = new_v
+
+    return new_ckpt
diff --git a/mmdet/models/backbones/regnet.py b/mmdet/models/backbones/regnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..55d3ce075f0cec68de4537a71ed569151d684562
--- /dev/null
+++ b/mmdet/models/backbones/regnet.py
@@ -0,0 +1,356 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import numpy as np
+import torch.nn as nn
+from mmcv.cnn import build_conv_layer, build_norm_layer
+
+from mmdet.registry import MODELS
+from .resnet import ResNet
+from .resnext import Bottleneck
+
+
+@MODELS.register_module()
+class RegNet(ResNet):
+    """RegNet backbone.
+
+    More details can be found in `paper
+    <https://arxiv.org/abs/2003.13678>`_.
+
+    Args:
+        arch (dict): The parameter of RegNets.
+
+            - w0 (int): initial width
+            - wa (float): slope of width
+            - wm (float): quantization parameter to quantize the width
+            - depth (int): depth of the backbone
+            - group_w (int): width of group
+            - bot_mul (float): bottleneck ratio, i.e. expansion of bottleneck.
+        strides (Sequence[int]): Strides of the first block of each stage.
+        base_channels (int): Base channels after stem layer.
+        in_channels (int): Number of input image channels. Default: 3.
+        dilations (Sequence[int]): Dilation of each stage.
+        out_indices (Sequence[int]): Output from which stages.
+        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+            layer is the 3x3 conv layer, otherwise the stride-two layer is
+            the first 1x1 conv layer.
+        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+            not freezing any parameters.
+        norm_cfg (dict): dictionary to construct and config norm layer.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed.
+        zero_init_residual (bool): whether to use zero init for last norm layer
+            in resblocks to let them behave as identity.
+        pretrained (str, optional): model pretrained path. Default: None
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None
+
+    Example:
+        >>> from mmdet.models import RegNet
+        >>> import torch
+        >>> self = RegNet(
+        ...     arch=dict(
+        ...         w0=88,
+        ...         wa=26.31,
+        ...         wm=2.25,
+        ...         group_w=48,
+        ...         depth=25,
+        ...         bot_mul=1.0))
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 32, 32)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...
print(tuple(level_out.shape)) + (1, 96, 8, 8) + (1, 192, 4, 4) + (1, 432, 2, 2) + (1, 1008, 1, 1) + """ + arch_settings = { + 'regnetx_400mf': + dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22, bot_mul=1.0), + 'regnetx_800mf': + dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16, bot_mul=1.0), + 'regnetx_1.6gf': + dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18, bot_mul=1.0), + 'regnetx_3.2gf': + dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25, bot_mul=1.0), + 'regnetx_4.0gf': + dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23, bot_mul=1.0), + 'regnetx_6.4gf': + dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17, bot_mul=1.0), + 'regnetx_8.0gf': + dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23, bot_mul=1.0), + 'regnetx_12gf': + dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, bot_mul=1.0), + } + + def __init__(self, + arch, + in_channels=3, + stem_channels=32, + base_channels=32, + strides=(2, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + dcn=None, + stage_with_dcn=(False, False, False, False), + plugins=None, + with_cp=False, + zero_init_residual=True, + pretrained=None, + init_cfg=None): + super(ResNet, self).__init__(init_cfg) + + # Generate RegNet parameters first + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'"arch": "{arch}" is not one of the' \ + ' arch_settings' + arch = self.arch_settings[arch] + elif not isinstance(arch, dict): + raise ValueError('Expect "arch" to be either a string ' + f'or a dict, got {type(arch)}') + + widths, num_stages = self.generate_regnet( + arch['w0'], + arch['wa'], + arch['wm'], + arch['depth'], + ) + # Convert to per stage format + stage_widths, stage_blocks = self.get_stages_from_blocks(widths) + # Generate group widths and bot muls + group_widths = [arch['group_w'] for _ in range(num_stages)] + self.bottleneck_ratio = [arch['bot_mul'] for _ in range(num_stages)] + # Adjust the compatibility of stage_widths and group_widths + stage_widths, group_widths = self.adjust_width_group( + stage_widths, self.bottleneck_ratio, group_widths) + + # Group params by stage + self.stage_widths = stage_widths + self.group_widths = group_widths + self.depth = sum(stage_blocks) + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.dcn = dcn + self.stage_with_dcn = stage_with_dcn + if dcn is not None: + assert len(stage_with_dcn) == num_stages + self.plugins = plugins + self.zero_init_residual = zero_init_residual + self.block = Bottleneck + expansion_bak = self.block.expansion + self.block.expansion = 1 + self.stage_blocks = stage_blocks[:num_stages] + + self._make_stem_layer(in_channels, stem_channels) + + block_init_cfg = None + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be specified at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use 
"init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + if self.zero_init_residual: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm3')) + else: + raise TypeError('pretrained must be a str or None') + + self.inplanes = stem_channels + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = self.strides[i] + dilation = self.dilations[i] + group_width = self.group_widths[i] + width = int(round(self.stage_widths[i] * self.bottleneck_ratio[i])) + stage_groups = width // group_width + + dcn = self.dcn if self.stage_with_dcn[i] else None + if self.plugins is not None: + stage_plugins = self.make_stage_plugins(self.plugins, i) + else: + stage_plugins = None + + res_layer = self.make_res_layer( + block=self.block, + inplanes=self.inplanes, + planes=self.stage_widths[i], + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=self.with_cp, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + dcn=dcn, + plugins=stage_plugins, + groups=stage_groups, + base_width=group_width, + base_channels=self.stage_widths[i], + init_cfg=block_init_cfg) + self.inplanes = self.stage_widths[i] + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = stage_widths[-1] + self.block.expansion = expansion_bak + + def _make_stem_layer(self, in_channels, base_channels): + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + base_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, base_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + + def generate_regnet(self, + initial_width, + width_slope, + width_parameter, + depth, + divisor=8): + """Generates per block width from RegNet parameters. + + Args: + initial_width ([int]): Initial width of the backbone + width_slope ([float]): Slope of the quantized linear function + width_parameter ([int]): Parameter used to quantize the width. + depth ([int]): Depth of the backbone. + divisor (int, optional): The divisor of channels. Defaults to 8. + + Returns: + list, int: return a list of widths of each stage and the number \ + of stages + """ + assert width_slope >= 0 + assert initial_width > 0 + assert width_parameter > 1 + assert initial_width % divisor == 0 + widths_cont = np.arange(depth) * width_slope + initial_width + ks = np.round( + np.log(widths_cont / initial_width) / np.log(width_parameter)) + widths = initial_width * np.power(width_parameter, ks) + widths = np.round(np.divide(widths, divisor)) * divisor + num_stages = len(np.unique(widths)) + widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist() + return widths, num_stages + + @staticmethod + def quantize_float(number, divisor): + """Converts a float to closest non-zero int divisible by divisor. + + Args: + number (int): Original number to be quantized. + divisor (int): Divisor used to quantize the number. + + Returns: + int: quantized number that is divisible by devisor. 
+ """ + return int(round(number / divisor) * divisor) + + def adjust_width_group(self, widths, bottleneck_ratio, groups): + """Adjusts the compatibility of widths and groups. + + Args: + widths (list[int]): Width of each stage. + bottleneck_ratio (float): Bottleneck ratio. + groups (int): number of groups in each stage + + Returns: + tuple(list): The adjusted widths and groups of each stage. + """ + bottleneck_width = [ + int(w * b) for w, b in zip(widths, bottleneck_ratio) + ] + groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_width)] + bottleneck_width = [ + self.quantize_float(w_bot, g) + for w_bot, g in zip(bottleneck_width, groups) + ] + widths = [ + int(w_bot / b) + for w_bot, b in zip(bottleneck_width, bottleneck_ratio) + ] + return widths, groups + + def get_stages_from_blocks(self, widths): + """Gets widths/stage_blocks of network at each stage. + + Args: + widths (list[int]): Width in each stage. + + Returns: + tuple(list): width and depth of each stage + """ + width_diff = [ + width != width_prev + for width, width_prev in zip(widths + [0], [0] + widths) + ] + stage_widths = [ + width for width, diff in zip(widths, width_diff[:-1]) if diff + ] + stage_blocks = np.diff([ + depth for depth, diff in zip(range(len(width_diff)), width_diff) + if diff + ]).tolist() + return stage_widths, stage_blocks + + def forward(self, x): + """Forward function.""" + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) diff --git a/mmdet/models/backbones/res2net.py b/mmdet/models/backbones/res2net.py new file mode 100644 index 0000000000000000000000000000000000000000..958fc88465c6769cb4c50907c92335331e8b7834 --- /dev/null +++ b/mmdet/models/backbones/res2net.py @@ -0,0 +1,327 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import Sequential + +from mmdet.registry import MODELS +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNet + + +class Bottle2neck(_Bottleneck): + expansion = 4 + + def __init__(self, + inplanes, + planes, + scales=4, + base_width=26, + base_channels=64, + stage_type='normal', + **kwargs): + """Bottle2neck block for Res2Net. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if + it is "caffe", the stride-two layer is the first 1x1 conv layer. + """ + super(Bottle2neck, self).__init__(inplanes, planes, **kwargs) + assert scales > 1, 'Res2Net degenerates to ResNet when scales = 1.' 
+ width = int(math.floor(self.planes * (base_width / base_channels))) + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width * scales, postfix=1) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.inplanes, + width * scales, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + + if stage_type == 'stage' and self.conv2_stride != 1: + self.pool = nn.AvgPool2d( + kernel_size=3, stride=self.conv2_stride, padding=1) + convs = [] + bns = [] + + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + if not self.with_dcn or fallback_on_stride: + for i in range(scales - 1): + convs.append( + build_conv_layer( + self.conv_cfg, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + bias=False)) + bns.append( + build_norm_layer(self.norm_cfg, width, postfix=i + 1)[1]) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + for i in range(scales - 1): + convs.append( + build_conv_layer( + self.dcn, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + bias=False)) + bns.append( + build_norm_layer(self.norm_cfg, width, postfix=i + 1)[1]) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + + self.conv3 = build_conv_layer( + self.conv_cfg, + width * scales, + self.planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.stage_type = stage_type + self.scales = scales + self.width = width + delattr(self, 'conv2') + delattr(self, self.norm2_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + spx = torch.split(out, self.width, 1) + sp = self.convs[0](spx[0].contiguous()) + sp = self.relu(self.bns[0](sp)) + out = sp + for i in range(1, self.scales - 1): + if self.stage_type == 'stage': + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp.contiguous()) + sp = self.relu(self.bns[i](sp)) + out = torch.cat((out, sp), 1) + + if self.stage_type == 'normal' or self.conv2_stride == 1: + out = torch.cat((out, spx[self.scales - 1]), 1) + elif self.stage_type == 'stage': + out = torch.cat((out, self.pool(spx[self.scales - 1])), 1) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Res2Layer(Sequential): + """Res2Layer to build Res2Net style backbone. + + Args: + block (nn.Module): block used to build ResLayer. + inplanes (int): inplanes of block. + planes (int): planes of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. 
Default: 1 + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottle2neck. Default: False + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + scales (int): Scales used in Res2Net. Default: 4 + base_width (int): Basic width of each scale. Default: 26 + """ + + def __init__(self, + block, + inplanes, + planes, + num_blocks, + stride=1, + avg_down=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + scales=4, + base_width=26, + **kwargs): + self.block = block + + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False), + build_conv_layer( + conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=1, + bias=False), + build_norm_layer(norm_cfg, planes * block.expansion)[1], + ) + + layers = [] + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + scales=scales, + base_width=base_width, + stage_type='stage', + **kwargs)) + inplanes = planes * block.expansion + for i in range(1, num_blocks): + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + scales=scales, + base_width=base_width, + **kwargs)) + super(Res2Layer, self).__init__(*layers) + + +@MODELS.register_module() +class Res2Net(ResNet): + """Res2Net backbone. + + Args: + scales (int): Scales used in Res2Net. Default: 4 + base_width (int): Basic width of each scale. Default: 26 + depth (int): Depth of res2net, from {50, 101, 152}. + in_channels (int): Number of input image channels. Default: 3. + num_stages (int): Res2net stages. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottle2neck. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + norm_cfg (dict): Dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + plugins (list[dict]): List of plugins for stages, each dict contains: + + - cfg (dict, required): Cfg dict to build plugin. + - position (str, required): Position inside block to insert + plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'. + - stages (tuple[bool], optional): Stages to apply plugin, length + should be same as 'num_stages'. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ Default: None + + Example: + >>> from mmdet.models import Res2Net + >>> import torch + >>> self = Res2Net(depth=50, scales=4, base_width=26) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 256, 8, 8) + (1, 512, 4, 4) + (1, 1024, 2, 2) + (1, 2048, 1, 1) + """ + + arch_settings = { + 50: (Bottle2neck, (3, 4, 6, 3)), + 101: (Bottle2neck, (3, 4, 23, 3)), + 152: (Bottle2neck, (3, 8, 36, 3)) + } + + def __init__(self, + scales=4, + base_width=26, + style='pytorch', + deep_stem=True, + avg_down=True, + pretrained=None, + init_cfg=None, + **kwargs): + self.scales = scales + self.base_width = base_width + super(Res2Net, self).__init__( + style='pytorch', + deep_stem=True, + avg_down=True, + pretrained=pretrained, + init_cfg=init_cfg, + **kwargs) + + def make_res_layer(self, **kwargs): + return Res2Layer( + scales=self.scales, + base_width=self.base_width, + base_channels=self.base_channels, + **kwargs) diff --git a/mmdet/models/backbones/resnest.py b/mmdet/models/backbones/resnest.py new file mode 100644 index 0000000000000000000000000000000000000000..d4466c4cc416237bee1f870b52e3c20a849c5a60 --- /dev/null +++ b/mmdet/models/backbones/resnest.py @@ -0,0 +1,322 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule + +from mmdet.registry import MODELS +from ..layers import ResLayer +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNetV1d + + +class RSoftmax(nn.Module): + """Radix Softmax module in ``SplitAttentionConv2d``. + + Args: + radix (int): Radix of input. + groups (int): Groups of input. + """ + + def __init__(self, radix, groups): + super().__init__() + self.radix = radix + self.groups = groups + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttentionConv2d(BaseModule): + """Split-Attention Conv2d in ResNeSt. + + Args: + in_channels (int): Number of channels in the input feature map. + channels (int): Number of intermediate channels. + kernel_size (int | tuple[int]): Size of the convolution kernel. + stride (int | tuple[int]): Stride of the convolution. + padding (int | tuple[int]): Zero-padding added to both sides of + dilation (int | tuple[int]): Spacing between kernel elements. + groups (int): Number of blocked connections from input channels to + output channels. + groups (int): Same as nn.Conv2d. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels. Default: 4. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. + dcn (dict): Config dict for DCN. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ Default: None + """ + + def __init__(self, + in_channels, + channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + radix=2, + reduction_factor=4, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + init_cfg=None): + super(SplitAttentionConv2d, self).__init__(init_cfg) + inter_channels = max(in_channels * radix // reduction_factor, 32) + self.radix = radix + self.groups = groups + self.channels = channels + self.with_dcn = dcn is not None + self.dcn = dcn + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + if self.with_dcn and not fallback_on_stride: + assert conv_cfg is None, 'conv_cfg must be None for DCN' + conv_cfg = dcn + self.conv = build_conv_layer( + conv_cfg, + in_channels, + channels * radix, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups * radix, + bias=False) + # To be consistent with original implementation, starting from 0 + self.norm0_name, norm0 = build_norm_layer( + norm_cfg, channels * radix, postfix=0) + self.add_module(self.norm0_name, norm0) + self.relu = nn.ReLU(inplace=True) + self.fc1 = build_conv_layer( + None, channels, inter_channels, 1, groups=self.groups) + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, inter_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.fc2 = build_conv_layer( + None, inter_channels, channels * radix, 1, groups=self.groups) + self.rsoftmax = RSoftmax(radix, groups) + + @property + def norm0(self): + """nn.Module: the normalization layer named "norm0" """ + return getattr(self, self.norm0_name) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def forward(self, x): + x = self.conv(x) + x = self.norm0(x) + x = self.relu(x) + + batch, rchannel = x.shape[:2] + batch = x.size(0) + if self.radix > 1: + splits = x.view(batch, self.radix, -1, *x.shape[2:]) + gap = splits.sum(dim=1) + else: + gap = x + gap = F.adaptive_avg_pool2d(gap, 1) + gap = self.fc1(gap) + + gap = self.norm1(gap) + gap = self.relu(gap) + + atten = self.fc2(gap) + atten = self.rsoftmax(atten).view(batch, -1, 1, 1) + + if self.radix > 1: + attens = atten.view(batch, self.radix, -1, *atten.shape[2:]) + out = torch.sum(attens * splits, dim=1) + else: + out = atten * x + return out.contiguous() + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeSt. + + Args: + inplane (int): Input planes of this block. + planes (int): Middle planes of this block. + groups (int): Groups of conv2. + base_width (int): Base of width in terms of base channels. Default: 4. + base_channels (int): Base of channels for calculating width. + Default: 64. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels in + SplitAttentionConv2d. Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + kwargs (dict): Key word arguments for base class. 
+ """ + expansion = 4 + + def __init__(self, + inplanes, + planes, + groups=1, + base_width=4, + base_channels=64, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + """Bottleneck block for ResNeSt.""" + super(Bottleneck, self).__init__(inplanes, planes, **kwargs) + + if groups == 1: + width = self.planes + else: + width = math.floor(self.planes * + (base_width / base_channels)) * groups + + self.avg_down_stride = avg_down_stride and self.conv2_stride > 1 + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width, postfix=1) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.inplanes, + width, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.with_modulated_dcn = False + self.conv2 = SplitAttentionConv2d( + width, + width, + kernel_size=3, + stride=1 if self.avg_down_stride else self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + radix=radix, + reduction_factor=reduction_factor, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + dcn=self.dcn) + delattr(self, self.norm2_name) + + if self.avg_down_stride: + self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1) + + self.conv3 = build_conv_layer( + self.conv_cfg, + width, + self.planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + + if self.avg_down_stride: + out = self.avd_layer(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class ResNeSt(ResNetV1d): + """ResNeSt backbone. + + Args: + groups (int): Number of groups of Bottleneck. Default: 1 + base_width (int): Base width of Bottleneck. Default: 4 + radix (int): Radix of SplitAttentionConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels in + SplitAttentionConv2d. Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + kwargs (dict): Keyword arguments for ResNet. 
+ """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)), + 200: (Bottleneck, (3, 24, 36, 3)) + } + + def __init__(self, + groups=1, + base_width=4, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + self.groups = groups + self.base_width = base_width + self.radix = radix + self.reduction_factor = reduction_factor + self.avg_down_stride = avg_down_stride + super(ResNeSt, self).__init__(**kwargs) + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer( + groups=self.groups, + base_width=self.base_width, + base_channels=self.base_channels, + radix=self.radix, + reduction_factor=self.reduction_factor, + avg_down_stride=self.avg_down_stride, + **kwargs) diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1d6f48f94f286e3c5e3179f752a7b36ea77c0d45 --- /dev/null +++ b/mmdet/models/backbones/resnet.py @@ -0,0 +1,672 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer +from mmengine.model import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmdet.registry import MODELS +from ..layers import ResLayer + + +class BasicBlock(BaseModule): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_cfg=None): + super(BasicBlock, self).__init__(init_cfg) + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, planes, planes, 3, padding=1, bias=False) + self.add_module(self.norm2_name, norm2) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Bottleneck(BaseModule): + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_cfg=None): + """Bottleneck block for ResNet. 
+ + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if + it is "caffe", the stride-two layer is the first 1x1 conv layer. + """ + super(Bottleneck, self).__init__(init_cfg) + assert style in ['pytorch', 'caffe'] + assert dcn is None or isinstance(dcn, dict) + assert plugins is None or isinstance(plugins, list) + if plugins is not None: + allowed_position = ['after_conv1', 'after_conv2', 'after_conv3'] + assert all(p['position'] in allowed_position for p in plugins) + + self.inplanes = inplanes + self.planes = planes + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.dcn = dcn + self.with_dcn = dcn is not None + self.plugins = plugins + self.with_plugins = plugins is not None + + if self.with_plugins: + # collect plugins for conv1/conv2/conv3 + self.after_conv1_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv1' + ] + self.after_conv2_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv2' + ] + self.after_conv3_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv3' + ] + + if self.style == 'pytorch': + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + norm_cfg, planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = dcn.pop('fallback_on_stride', False) + if not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + conv_cfg, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + self.conv2 = build_conv_layer( + dcn, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + conv_cfg, + planes, + planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + if self.with_plugins: + self.after_conv1_plugin_names = self.make_block_plugins( + planes, self.after_conv1_plugins) + self.after_conv2_plugin_names = self.make_block_plugins( + planes, self.after_conv2_plugins) + self.after_conv3_plugin_names = self.make_block_plugins( + planes * self.expansion, self.after_conv3_plugins) + + def make_block_plugins(self, in_channels, plugins): + """make plugins for block. + + Args: + in_channels (int): Input channels of plugin. + plugins (list[dict]): List of plugins cfg to build. + + Returns: + list[str]: List of the names of plugin. 
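+
+        Example (illustrative; assumes a plugin type such as ``ContextBlock``
+        is registered with the plugin-layer registry):
+            >>> plugins = [dict(type='ContextBlock', ratio=1. / 16)]
+            >>> names = self.make_block_plugins(64, plugins)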
+ """ + assert isinstance(plugins, list) + plugin_names = [] + for plugin in plugins: + plugin = plugin.copy() + name, layer = build_plugin_layer( + plugin, + in_channels=in_channels, + postfix=plugin.pop('postfix', '')) + assert not hasattr(self, name), f'duplicate plugin {name}' + self.add_module(name, layer) + plugin_names.append(name) + return plugin_names + + def forward_plugin(self, x, plugin_names): + out = x + for name in plugin_names: + out = getattr(self, name)(out) + return out + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + @property + def norm3(self): + """nn.Module: normalization layer after the third convolution layer""" + return getattr(self, self.norm3_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class ResNet(BaseModule): + """ResNet backbone. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + stem_channels (int | None): Number of stem channels. If not specified, + it will be the same as `base_channels`. Default: None. + base_channels (int): Number of base channels of res layer. Default: 64. + in_channels (int): Number of input image channels. Default: 3. + num_stages (int): Resnet stages. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + norm_cfg (dict): Dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + plugins (list[dict]): List of plugins for stages, each dict contains: + + - cfg (dict, required): Cfg dict to build plugin. + - position (str, required): Position inside block to insert + plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'. + - stages (tuple[bool], optional): Stages to apply plugin, length + should be same as 'num_stages'. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. 
+ zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + + Example: + >>> from mmdet.models import ResNet + >>> import torch + >>> self = ResNet(depth=18) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 8, 8) + (1, 128, 4, 4) + (1, 256, 2, 2) + (1, 512, 1, 1) + """ + + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth, + in_channels=3, + stem_channels=None, + base_channels=64, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + dcn=None, + stage_with_dcn=(False, False, False, False), + plugins=None, + with_cp=False, + zero_init_residual=True, + pretrained=None, + init_cfg=None): + super(ResNet, self).__init__(init_cfg) + self.zero_init_residual = zero_init_residual + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + + block_init_cfg = None + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be specified at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + block = self.arch_settings[depth][0] + if self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm3')) + else: + raise TypeError('pretrained must be a str or None') + + self.depth = depth + if stem_channels is None: + stem_channels = base_channels + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.dcn = dcn + self.stage_with_dcn = stage_with_dcn + if dcn is not None: + assert len(stage_with_dcn) == num_stages + self.plugins = plugins + self.block, stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + self.inplanes = stem_channels + + self._make_stem_layer(in_channels, stem_channels) + + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = strides[i] + dilation = dilations[i] + dcn = self.dcn if self.stage_with_dcn[i] else None + if plugins is not None: + 
stage_plugins = self.make_stage_plugins(plugins, i)
+            else:
+                stage_plugins = None
+            planes = base_channels * 2**i
+            res_layer = self.make_res_layer(
+                block=self.block,
+                inplanes=self.inplanes,
+                planes=planes,
+                num_blocks=num_blocks,
+                stride=stride,
+                dilation=dilation,
+                style=self.style,
+                avg_down=self.avg_down,
+                with_cp=with_cp,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                dcn=dcn,
+                plugins=stage_plugins,
+                init_cfg=block_init_cfg)
+            self.inplanes = planes * self.block.expansion
+            layer_name = f'layer{i + 1}'
+            self.add_module(layer_name, res_layer)
+            self.res_layers.append(layer_name)
+
+        self._freeze_stages()
+
+        self.feat_dim = self.block.expansion * base_channels * 2**(
+            len(self.stage_blocks) - 1)
+
+    def make_stage_plugins(self, plugins, stage_idx):
+        """Make plugins for the ``stage_idx``-th ResNet stage.
+
+        Currently we support inserting ``context_block``,
+        ``empirical_attention_block`` and ``nonlocal_block`` into backbones
+        like ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
+        Bottleneck.
+
+        An example of the plugins format could be:
+
+        Examples:
+            >>> plugins=[
+            ...     dict(cfg=dict(type='xxx', arg1='xxx'),
+            ...          stages=(False, True, True, True),
+            ...          position='after_conv2'),
+            ...     dict(cfg=dict(type='yyy'),
+            ...          stages=(True, True, True, True),
+            ...          position='after_conv3'),
+            ...     dict(cfg=dict(type='zzz', postfix='1'),
+            ...          stages=(True, True, True, True),
+            ...          position='after_conv3'),
+            ...     dict(cfg=dict(type='zzz', postfix='2'),
+            ...          stages=(True, True, True, True),
+            ...          position='after_conv3')
+            ... ]
+            >>> self = ResNet(depth=18)
+            >>> stage_plugins = self.make_stage_plugins(plugins, 0)
+            >>> assert len(stage_plugins) == 3
+
+        Suppose ``stage_idx=0``, the structure of blocks in the stage would be:
+
+        .. code-block:: none
+
+            conv1->conv2->conv3->yyy->zzz1->zzz2
+
+        Suppose ``stage_idx=1``, the structure of blocks in the stage would be:
+
+        .. code-block:: none
+
+            conv1->conv2->xxx->conv3->yyy->zzz1->zzz2
+
+        If ``stages`` is missing, the plugin would be applied to all stages.
+
+        Args:
+            plugins (list[dict]): List of plugin cfgs to build. The postfix is
+                required if multiple plugins of the same type are inserted.
+            stage_idx (int): Index of the stage to build.
+
+        Returns:
+            list[dict]: Plugins for the current stage.
+        """
+        stage_plugins = []
+        for plugin in plugins:
+            plugin = plugin.copy()
+            stages = plugin.pop('stages', None)
+            assert stages is None or len(stages) == self.num_stages
+            # whether to insert plugin into current stage
+            if stages is None or stages[stage_idx]:
+                stage_plugins.append(plugin)
+
+        return stage_plugins
+
+    def make_res_layer(self, **kwargs):
+        """Pack all blocks in a stage into a ``ResLayer``."""
+        return ResLayer(**kwargs)
+
+    @property
+    def norm1(self):
+        """nn.Module: the normalization layer named "norm1" """
+        return getattr(self, self.norm1_name)
+
+    def _make_stem_layer(self, in_channels, stem_channels):
+        if self.deep_stem:
+            self.stem = nn.Sequential(
+                build_conv_layer(
+                    self.conv_cfg,
+                    in_channels,
+                    stem_channels // 2,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+                nn.ReLU(inplace=True),
+                build_conv_layer(
+                    self.conv_cfg,
+                    stem_channels // 2,
+                    stem_channels // 2,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+                nn.ReLU(inplace=True),
+                build_conv_layer(
+                    self.conv_cfg,
+                    stem_channels // 2,
+                    stem_channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels)[1],
+                nn.ReLU(inplace=True))
+        else:
+            self.conv1 = build_conv_layer(
+                self.conv_cfg,
+                in_channels,
+                stem_channels,
+                kernel_size=7,
+                stride=2,
+                padding=3,
+                bias=False)
+            self.norm1_name, norm1 = build_norm_layer(
+                self.norm_cfg, stem_channels, postfix=1)
+            self.add_module(self.norm1_name, norm1)
+            self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            if self.deep_stem:
+                self.stem.eval()
+                for param in self.stem.parameters():
+                    param.requires_grad = False
+            else:
+                self.norm1.eval()
+                for m in [self.conv1, self.norm1]:
+                    for param in m.parameters():
+                        param.requires_grad = False
+
+        for i in range(1, self.frozen_stages + 1):
+            m = getattr(self, f'layer{i}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+    def forward(self, x):
+        """Forward function."""
+        if self.deep_stem:
+            x = self.stem(x)
+        else:
+            x = self.conv1(x)
+            x = self.norm1(x)
+            x = self.relu(x)
+        x = self.maxpool(x)
+        outs = []
+        for i, layer_name in enumerate(self.res_layers):
+            res_layer = getattr(self, layer_name)
+            x = res_layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+        return tuple(outs)
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keeping the
+        normalization layers frozen."""
+        super(ResNet, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval has an effect on BatchNorm only
+                if isinstance(m, _BatchNorm):
+                    m.eval()
+
+
+@MODELS.register_module()
+class ResNetV1d(ResNet):
+    r"""ResNetV1d variant described in `Bag of Tricks
+    <https://arxiv.org/abs/1812.01187>`_.
+
+    Compared with the default ResNet (ResNetV1b), ResNetV1d replaces the 7x7
+    conv in the input stem with three 3x3 convs. In the downsampling block, a
+    2x2 avg_pool with stride 2 is added before the conv, whose stride is
+    changed to 1.
+ """ + + def __init__(self, **kwargs): + super(ResNetV1d, self).__init__( + deep_stem=True, avg_down=True, **kwargs) diff --git a/mmdet/models/backbones/resnext.py b/mmdet/models/backbones/resnext.py new file mode 100644 index 0000000000000000000000000000000000000000..df3d79e046c3ab9b289bcfeb6f937c87f6c09bfa --- /dev/null +++ b/mmdet/models/backbones/resnext.py @@ -0,0 +1,154 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmdet.registry import MODELS +from ..layers import ResLayer +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNet + + +class Bottleneck(_Bottleneck): + expansion = 4 + + def __init__(self, + inplanes, + planes, + groups=1, + base_width=4, + base_channels=64, + **kwargs): + """Bottleneck block for ResNeXt. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if + it is "caffe", the stride-two layer is the first 1x1 conv layer. + """ + super(Bottleneck, self).__init__(inplanes, planes, **kwargs) + + if groups == 1: + width = self.planes + else: + width = math.floor(self.planes * + (base_width / base_channels)) * groups + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, width, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.inplanes, + width, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + self.with_modulated_dcn = False + if self.with_dcn: + fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + if not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + self.conv_cfg, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + self.conv2 = build_conv_layer( + self.dcn, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + width, + self.planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + if self.with_plugins: + self._del_block_plugins(self.after_conv1_plugin_names + + self.after_conv2_plugin_names + + self.after_conv3_plugin_names) + self.after_conv1_plugin_names = self.make_block_plugins( + width, self.after_conv1_plugins) + self.after_conv2_plugin_names = self.make_block_plugins( + width, self.after_conv2_plugins) + self.after_conv3_plugin_names = self.make_block_plugins( + self.planes * self.expansion, self.after_conv3_plugins) + + def _del_block_plugins(self, plugin_names): + """delete plugins for block if exist. + + Args: + plugin_names (list[str]): List of plugins name to delete. + """ + assert isinstance(plugin_names, list) + for plugin_name in plugin_names: + del self._modules[plugin_name] + + +@MODELS.register_module() +class ResNeXt(ResNet): + """ResNeXt backbone. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Default: 3. + num_stages (int): Resnet stages. Default: 4. + groups (int): Group of resnext. + base_width (int): Base width of resnext. 
+ strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. + norm_cfg (dict): dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + zero_init_residual (bool): whether to use zero init for last norm layer + in resblocks to let them behave as identity. + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, groups=1, base_width=4, **kwargs): + self.groups = groups + self.base_width = base_width + super(ResNeXt, self).__init__(**kwargs) + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``""" + return ResLayer( + groups=self.groups, + base_width=self.base_width, + base_channels=self.base_channels, + **kwargs) diff --git a/mmdet/models/backbones/ssd_vgg.py b/mmdet/models/backbones/ssd_vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..843e82e2722f93b9b2abb5180c827c8f2a430b48 --- /dev/null +++ b/mmdet/models/backbones/ssd_vgg.py @@ -0,0 +1,128 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +from mmcv.cnn import VGG +from mmengine.model import BaseModule + +from mmdet.registry import MODELS +from ..necks import ssd_neck + + +@MODELS.register_module() +class SSDVGG(VGG, BaseModule): + """VGG Backbone network for single-shot-detection. + + Args: + depth (int): Depth of vgg, from {11, 13, 16, 19}. + with_last_pool (bool): Whether to add a pooling layer at the last + of the model + ceil_mode (bool): When True, will use `ceil` instead of `floor` + to compute the output shape. + out_indices (Sequence[int]): Output from which stages. + out_feature_indices (Sequence[int]): Output from which feature map. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + input_size (int, optional): Deprecated argumment. + Width and height of input, from {300, 512}. + l2_norm_scale (float, optional) : Deprecated argumment. + L2 normalization layer init scale. + + Example: + >>> self = SSDVGG(input_size=300, depth=11) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 300, 300) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... 
print(tuple(level_out.shape)) + (1, 1024, 19, 19) + (1, 512, 10, 10) + (1, 256, 5, 5) + (1, 256, 3, 3) + (1, 256, 1, 1) + """ + extra_setting = { + 300: (256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256), + 512: (256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128), + } + + def __init__(self, + depth, + with_last_pool=False, + ceil_mode=True, + out_indices=(3, 4), + out_feature_indices=(22, 34), + pretrained=None, + init_cfg=None, + input_size=None, + l2_norm_scale=None): + # TODO: in_channels for mmcv.VGG + super(SSDVGG, self).__init__( + depth, + with_last_pool=with_last_pool, + ceil_mode=ceil_mode, + out_indices=out_indices) + + self.features.add_module( + str(len(self.features)), + nn.MaxPool2d(kernel_size=3, stride=1, padding=1)) + self.features.add_module( + str(len(self.features)), + nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)) + self.features.add_module( + str(len(self.features)), nn.ReLU(inplace=True)) + self.features.add_module( + str(len(self.features)), nn.Conv2d(1024, 1024, kernel_size=1)) + self.features.add_module( + str(len(self.features)), nn.ReLU(inplace=True)) + self.out_feature_indices = out_feature_indices + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be specified at the same time' + + if init_cfg is not None: + self.init_cfg = init_cfg + elif isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict(type='Constant', val=1, layer='BatchNorm2d'), + dict(type='Normal', std=0.01, layer='Linear'), + ] + else: + raise TypeError('pretrained must be a str or None') + + if input_size is not None: + warnings.warn('DeprecationWarning: input_size is deprecated') + if l2_norm_scale is not None: + warnings.warn('DeprecationWarning: l2_norm_scale in VGG is ' + 'deprecated, it has been moved to SSDNeck.') + + def init_weights(self, pretrained=None): + super(VGG, self).init_weights() + + def forward(self, x): + """Forward function.""" + outs = [] + for i, layer in enumerate(self.features): + x = layer(x) + if i in self.out_feature_indices: + outs.append(x) + + if len(outs) == 1: + return outs[0] + else: + return tuple(outs) + + +class L2Norm(ssd_neck.L2Norm): + + def __init__(self, **kwargs): + super(L2Norm, self).__init__(**kwargs) + warnings.warn('DeprecationWarning: L2Norm in ssd_vgg.py ' + 'is deprecated, please use L2Norm in ' + 'mmdet/models/necks/ssd_neck.py instead') diff --git a/mmdet/models/backbones/swin.py b/mmdet/models/backbones/swin.py new file mode 100644 index 0000000000000000000000000000000000000000..062190fa077d7b01e0c1db76bea0cfb5dc7b6620 --- /dev/null +++ b/mmdet/models/backbones/swin.py @@ -0,0 +1,819 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
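+# NOTE: a minimal usage sketch for the backbone defined below (illustrative
+# only, not part of the original file; with the defaults the four stages emit
+# features at strides 4/8/16/32 with 96/192/384/768 channels):
+#
+#     import torch
+#     model = SwinTransformer(embed_dims=96, depths=(2, 2, 6, 2),
+#                             num_heads=(3, 6, 12, 24))
+#     model.eval()
+#     outs = model(torch.rand(1, 3, 224, 224))
+#     # [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]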
+import warnings +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, build_dropout +from mmengine.logging import MMLogger +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import (constant_init, trunc_normal_, + trunc_normal_init) +from mmengine.runner.checkpoint import CheckpointLoader +from mmengine.utils import to_2tuple + +from mmdet.registry import MODELS +from ..layers import PatchEmbed, PatchMerging + + +class WindowMSA(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): The height and width of the window. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + init_cfg (dict | None, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None): + + super().__init__() + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = qk_scale or head_embed_dims**-0.5 + self.init_cfg = init_cfg + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # About 2x faster than original impl + Wh, Ww = self.window_size + rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww) + rel_position_index = rel_index_coords + rel_index_coords.T + rel_position_index = rel_position_index.flip(1).contiguous() + self.register_buffer('relative_position_index', rel_position_index) + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + + self.softmax = nn.Softmax(dim=-1) + + def init_weights(self): + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor | None, Optional): mask with shape of (num_windows, + Wh*Ww, Wh*Ww), value should be between (-inf, 0]. 
+ """ + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + # make torchscript happy (cannot use tensor as tuple) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @staticmethod + def double_step_seq(step1, len1, step2, len2): + seq1 = torch.arange(0, step1 * len1, step1) + seq2 = torch.arange(0, step2 * len2, step2) + return (seq1[:, None] + seq2[None, :]).reshape(1, -1) + + +class ShiftWindowMSA(BaseModule): + """Shifted Window Multihead Self-Attention Module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. + shift_size (int, optional): The shift step of each window towards + right-bottom. If zero, act as regular window-msa. Defaults to 0. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Defaults: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Defaults: 0. + proj_drop_rate (float, optional): Dropout ratio of output. + Defaults: 0. + dropout_layer (dict, optional): The dropout_layer used before output. + Defaults: dict(type='DropPath', drop_prob=0.). + init_cfg (dict, optional): The extra config for initialization. + Default: None. 
+ """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + shift_size=0, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0, + proj_drop_rate=0, + dropout_layer=dict(type='DropPath', drop_prob=0.), + init_cfg=None): + super().__init__(init_cfg) + + self.window_size = window_size + self.shift_size = shift_size + assert 0 <= self.shift_size < self.window_size + + self.w_msa = WindowMSA( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=to_2tuple(window_size), + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=proj_drop_rate, + init_cfg=None) + + self.drop = build_dropout(dropout_layer) + + def forward(self, query, hw_shape): + B, L, C = query.shape + H, W = hw_shape + assert L == H * W, 'input feature has wrong size' + query = query.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b)) + H_pad, W_pad = query.shape[1], query.shape[2] + + # cyclic shift + if self.shift_size > 0: + shifted_query = torch.roll( + query, + shifts=(-self.shift_size, -self.shift_size), + dims=(1, 2)) + + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device) + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + # nW, window_size, window_size, 1 + mask_windows = self.window_partition(img_mask) + mask_windows = mask_windows.view( + -1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + else: + shifted_query = query + attn_mask = None + + # nW*B, window_size, window_size, C + query_windows = self.window_partition(shifted_query) + # nW*B, window_size*window_size, C + query_windows = query_windows.view(-1, self.window_size**2, C) + + # W-MSA/SW-MSA (nW*B, window_size*window_size, C) + attn_windows = self.w_msa(query_windows, mask=attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + + # B H' W' C + shifted_x = self.window_reverse(attn_windows, H_pad, W_pad) + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + x = self.drop(x) + return x + + def window_reverse(self, windows, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + window_size = self.window_size + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + def window_partition(self, x): + """ + Args: + x: (B, H, W, C) + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + window_size = self.window_size + x = 
x.view(B, H // window_size, window_size, W // window_size,
+                   window_size, C)
+        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
+        windows = windows.view(-1, window_size, window_size, C)
+        return windows
+
+
+class SwinBlock(BaseModule):
+    """Swin Transformer block.
+
+    Args:
+        embed_dims (int): The feature dimension.
+        num_heads (int): Parallel attention heads.
+        feedforward_channels (int): The hidden dimension for FFNs.
+        window_size (int, optional): The local window scale. Default: 7.
+        shift (bool, optional): Whether to shift the window or not.
+            Default: False.
+        qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
+        qk_scale (float | None, optional): Override default qk scale of
+            head_dim ** -0.5 if set. Default: None.
+        drop_rate (float, optional): Dropout rate. Default: 0.
+        attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
+        drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
+        act_cfg (dict, optional): The config dict of activation function.
+            Default: dict(type='GELU').
+        norm_cfg (dict, optional): The config dict of normalization.
+            Default: dict(type='LN').
+        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+            will save some memory while slowing down the training speed.
+            Default: False.
+        init_cfg (dict | list | None, optional): The init config.
+            Default: None.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 num_heads,
+                 feedforward_channels,
+                 window_size=7,
+                 shift=False,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='LN'),
+                 with_cp=False,
+                 init_cfg=None):
+
+        super(SwinBlock, self).__init__()
+
+        self.init_cfg = init_cfg
+        self.with_cp = with_cp
+
+        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
+        self.attn = ShiftWindowMSA(
+            embed_dims=embed_dims,
+            num_heads=num_heads,
+            window_size=window_size,
+            shift_size=window_size // 2 if shift else 0,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop_rate=attn_drop_rate,
+            proj_drop_rate=drop_rate,
+            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+            init_cfg=None)
+
+        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
+        self.ffn = FFN(
+            embed_dims=embed_dims,
+            feedforward_channels=feedforward_channels,
+            num_fcs=2,
+            ffn_drop=drop_rate,
+            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+            act_cfg=act_cfg,
+            add_identity=True,
+            init_cfg=None)
+
+    def forward(self, x, hw_shape):
+
+        def _inner_forward(x):
+            identity = x
+            x = self.norm1(x)
+            x = self.attn(x, hw_shape)
+
+            x = x + identity
+
+            identity = x
+            x = self.norm2(x)
+            x = self.ffn(x, identity=identity)
+
+            return x
+
+        if self.with_cp and x.requires_grad:
+            x = cp.checkpoint(_inner_forward, x)
+        else:
+            x = _inner_forward(x)
+
+        return x
+
+
+class SwinBlockSequence(BaseModule):
+    """Implements one stage in Swin Transformer.
+
+    Args:
+        embed_dims (int): The feature dimension.
+        num_heads (int): Parallel attention heads.
+        feedforward_channels (int): The hidden dimension for FFNs.
+        depth (int): The number of blocks in this stage.
+        window_size (int, optional): The local window scale. Default: 7.
+        qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
+        qk_scale (float | None, optional): Override default qk scale of
+            head_dim ** -0.5 if set. Default: None.
+        drop_rate (float, optional): Dropout rate. Default: 0.
+        attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
+        drop_path_rate (float | list[float], optional): Stochastic depth
+            rate. Default: 0.
+ downsample (BaseModule | None, optional): The downsample operation + module. Default: None. + act_cfg (dict, optional): The config dict of activation function. + Default: dict(type='GELU'). + norm_cfg (dict, optional): The config dict of normalization. + Default: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + init_cfg (dict | list | None, optional): The init config. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + depth, + window_size=7, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + downsample=None, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(drop_path_rate, list): + drop_path_rates = drop_path_rate + assert len(drop_path_rates) == depth + else: + drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)] + + self.blocks = ModuleList() + for i in range(depth): + block = SwinBlock( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=feedforward_channels, + window_size=window_size, + shift=False if i % 2 == 0 else True, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rates[i], + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + init_cfg=None) + self.blocks.append(block) + + self.downsample = downsample + + def forward(self, x, hw_shape): + for block in self.blocks: + x = block(x, hw_shape) + + if self.downsample: + x_down, down_hw_shape = self.downsample(x, hw_shape) + return x_down, down_hw_shape, x, hw_shape + else: + return x, hw_shape, x, hw_shape + + +@MODELS.register_module() +class SwinTransformer(BaseModule): + """ Swin Transformer + A PyTorch implement of : `Swin Transformer: + Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/abs/2103.14030 + + Inspiration from + https://github.com/microsoft/Swin-Transformer + + Args: + pretrain_img_size (int | tuple[int]): The size of input image when + pretrain. Defaults: 224. + in_channels (int): The num of input channels. + Defaults: 3. + embed_dims (int): The feature dimension. Default: 96. + patch_size (int | tuple[int]): Patch size. Default: 4. + window_size (int): Window size. Default: 7. + mlp_ratio (int): Ratio of mlp hidden dim to embedding dim. + Default: 4. + depths (tuple[int]): Depths of each Swin Transformer stage. + Default: (2, 2, 6, 2). + num_heads (tuple[int]): Parallel attention heads of each Swin + Transformer stage. Default: (3, 6, 12, 24). + strides (tuple[int]): The patch merging or patch embedding stride of + each Swin Transformer stage. (In swin, we set kernel size equal to + stride.) Default: (4, 2, 2, 2). + out_indices (tuple[int]): Output from which stages. + Default: (0, 1, 2, 3). + qkv_bias (bool, optional): If True, add a learnable bias to query, key, + value. Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + patch_norm (bool): If add a norm layer for patch embed and patch + merging. Default: True. + drop_rate (float): Dropout rate. Defaults: 0. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Defaults: 0.1. + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults: False. 
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='GELU').
+        norm_cfg (dict): Config dict for normalization layer at
+            output of backbone. Defaults: dict(type='LN').
+        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+            will save some memory while slowing down the training speed.
+            Default: False.
+        pretrained (str, optional): model pretrained path. Default: None.
+        convert_weights (bool): The flag indicates whether the
+            pre-trained model is from the original repo. We may need
+            to convert some keys to make it compatible.
+            Default: False.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            Default: -1 (-1 means not freezing any parameters).
+        init_cfg (dict, optional): The Config for initialization.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 pretrain_img_size=224,
+                 in_channels=3,
+                 embed_dims=96,
+                 patch_size=4,
+                 window_size=7,
+                 mlp_ratio=4,
+                 depths=(2, 2, 6, 2),
+                 num_heads=(3, 6, 12, 24),
+                 strides=(4, 2, 2, 2),
+                 out_indices=(0, 1, 2, 3),
+                 qkv_bias=True,
+                 qk_scale=None,
+                 patch_norm=True,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.1,
+                 use_abs_pos_embed=False,
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='LN'),
+                 with_cp=False,
+                 pretrained=None,
+                 convert_weights=False,
+                 frozen_stages=-1,
+                 init_cfg=None):
+        self.convert_weights = convert_weights
+        self.frozen_stages = frozen_stages
+        if isinstance(pretrain_img_size, int):
+            pretrain_img_size = to_2tuple(pretrain_img_size)
+        elif isinstance(pretrain_img_size, tuple):
+            if len(pretrain_img_size) == 1:
+                pretrain_img_size = to_2tuple(pretrain_img_size[0])
+            assert len(pretrain_img_size) == 2, \
+                f'The size of image should have length 1 or 2, ' \
+                f'but got {len(pretrain_img_size)}'
+
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be specified at the same time'
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
+                          'please use "init_cfg" instead')
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is None:
+            self.init_cfg = init_cfg
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+        super(SwinTransformer, self).__init__(init_cfg=init_cfg)
+
+        num_layers = len(depths)
+        self.out_indices = out_indices
+        self.use_abs_pos_embed = use_abs_pos_embed
+
+        assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
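+        # With non-overlapping patches, the first stride doubles as the patch
+        # size; e.g. with the default patch_size=4, a 224x224 pretrain image
+        # yields 224 // 4 = 56 patches per side, i.e. 56 * 56 = 3136 absolute
+        # position embeddings below (worked numbers for the defaults, not
+        # part of the original code).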
+
+        self.patch_embed = PatchEmbed(
+            in_channels=in_channels,
+            embed_dims=embed_dims,
+            conv_type='Conv2d',
+            kernel_size=patch_size,
+            stride=strides[0],
+            norm_cfg=norm_cfg if patch_norm else None,
+            init_cfg=None)
+
+        if self.use_abs_pos_embed:
+            patch_row = pretrain_img_size[0] // patch_size
+            patch_col = pretrain_img_size[1] // patch_size
+            num_patches = patch_row * patch_col
+            self.absolute_pos_embed = nn.Parameter(
+                torch.zeros((1, num_patches, embed_dims)))
+
+        self.drop_after_pos = nn.Dropout(p=drop_rate)
+
+        # set stochastic depth decay rule
+        total_depth = sum(depths)
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
+        ]
+
+        self.stages = ModuleList()
+        in_channels = embed_dims
+        for i in range(num_layers):
+            if i < num_layers - 1:
+                downsample = PatchMerging(
+                    in_channels=in_channels,
+                    out_channels=2 * in_channels,
+                    stride=strides[i + 1],
+                    norm_cfg=norm_cfg if patch_norm else None,
+                    init_cfg=None)
+            else:
+                downsample = None
+
+            stage = SwinBlockSequence(
+                embed_dims=in_channels,
+                num_heads=num_heads[i],
+                feedforward_channels=mlp_ratio * in_channels,
+                depth=depths[i],
+                window_size=window_size,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop_rate=drop_rate,
+                attn_drop_rate=attn_drop_rate,
+                drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],
+                downsample=downsample,
+                act_cfg=act_cfg,
+                norm_cfg=norm_cfg,
+                with_cp=with_cp,
+                init_cfg=None)
+            self.stages.append(stage)
+            if downsample:
+                in_channels = downsample.out_channels
+
+        self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
+        # Add a norm layer for each output
+        for i in out_indices:
+            layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
+            layer_name = f'norm{i}'
+            self.add_module(layer_name, layer)
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keeping the layers
+        frozen."""
+        super(SwinTransformer, self).train(mode)
+        self._freeze_stages()
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            self.patch_embed.eval()
+            for param in self.patch_embed.parameters():
+                param.requires_grad = False
+            if self.use_abs_pos_embed:
+                self.absolute_pos_embed.requires_grad = False
+            self.drop_after_pos.eval()
+
+        for i in range(1, self.frozen_stages + 1):
+
+            if (i - 1) in self.out_indices:
+                norm_layer = getattr(self, f'norm{i-1}')
+                norm_layer.eval()
+                for param in norm_layer.parameters():
+                    param.requires_grad = False
+
+            m = self.stages[i - 1]
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+    def init_weights(self):
+        logger = MMLogger.get_current_instance()
+        if self.init_cfg is None:
+            logger.warning(f'No pre-trained weights for '
+                           f'{self.__class__.__name__}, '
+                           f'training starts from scratch')
+            if self.use_abs_pos_embed:
+                trunc_normal_(self.absolute_pos_embed, std=0.02)
+            for m in self.modules():
+                if isinstance(m, nn.Linear):
+                    trunc_normal_init(m, std=.02, bias=0.)
+ elif isinstance(m, nn.LayerNorm): + constant_init(m, 1.0) + else: + assert 'checkpoint' in self.init_cfg, f'Only support ' \ + f'specify `Pretrained` in ' \ + f'`init_cfg` in ' \ + f'{self.__class__.__name__} ' + ckpt = CheckpointLoader.load_checkpoint( + self.init_cfg.checkpoint, logger=logger, map_location='cpu') + if 'state_dict' in ckpt: + _state_dict = ckpt['state_dict'] + elif 'model' in ckpt: + _state_dict = ckpt['model'] + else: + _state_dict = ckpt + if self.convert_weights: + # supported loading weight from original repo, + _state_dict = swin_converter(_state_dict) + + state_dict = OrderedDict() + for k, v in _state_dict.items(): + if k.startswith('backbone.'): + state_dict[k[9:]] = v + + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # reshape absolute position embedding + if state_dict.get('absolute_pos_embed') is not None: + absolute_pos_embed = state_dict['absolute_pos_embed'] + N1, L, C1 = absolute_pos_embed.size() + N2, C2, H, W = self.absolute_pos_embed.size() + if N1 != N2 or C1 != C2 or L != H * W: + logger.warning('Error in loading absolute_pos_embed, pass') + else: + state_dict['absolute_pos_embed'] = absolute_pos_embed.view( + N2, H, W, C2).permute(0, 3, 1, 2).contiguous() + + # interpolate position bias table if needed + relative_position_bias_table_keys = [ + k for k in state_dict.keys() + if 'relative_position_bias_table' in k + ] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + table_current = self.state_dict()[table_key] + L1, nH1 = table_pretrained.size() + L2, nH2 = table_current.size() + if nH1 != nH2: + logger.warning(f'Error in loading {table_key}, pass') + elif L1 != L2: + S1 = int(L1**0.5) + S2 = int(L2**0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1), + size=(S2, S2), + mode='bicubic') + state_dict[table_key] = table_pretrained_resized.view( + nH2, L2).permute(1, 0).contiguous() + + # load state_dict + self.load_state_dict(state_dict, False) + + def forward(self, x): + x, hw_shape = self.patch_embed(x) + + if self.use_abs_pos_embed: + x = x + self.absolute_pos_embed + x = self.drop_after_pos(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape, out, out_hw_shape = stage(x, hw_shape) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(out) + out = out.view(-1, *out_hw_shape, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + return outs + + +def swin_converter(ckpt): + + new_ckpt = OrderedDict() + + def correct_unfold_reduction_order(x): + out_channel, in_channel = x.shape + x = x.reshape(out_channel, 4, in_channel // 4) + x = x[:, [0, 2, 1, 3], :].transpose(1, + 2).reshape(out_channel, in_channel) + return x + + def correct_unfold_norm_order(x): + in_channel = x.shape[0] + x = x.reshape(4, in_channel // 4) + x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) + return x + + for k, v in ckpt.items(): + if k.startswith('head'): + continue + elif k.startswith('layers'): + new_v = v + if 'attn.' in k: + new_k = k.replace('attn.', 'attn.w_msa.') + elif 'mlp.' in k: + if 'mlp.fc1.' in k: + new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') + elif 'mlp.fc2.' in k: + new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') + else: + new_k = k.replace('mlp.', 'ffn.') + elif 'downsample' in k: + new_k = k + if 'reduction.' 
in k: + new_v = correct_unfold_reduction_order(v) + elif 'norm.' in k: + new_v = correct_unfold_norm_order(v) + else: + new_k = k + new_k = new_k.replace('layers', 'stages', 1) + elif k.startswith('patch_embed'): + new_v = v + if 'proj' in k: + new_k = k.replace('proj', 'projection') + else: + new_k = k + else: + new_v = v + new_k = k + + new_ckpt['backbone.' + new_k] = new_v + + return new_ckpt diff --git a/mmdet/models/backbones/trident_resnet.py b/mmdet/models/backbones/trident_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..22c76354522ff8533b094df6858ec361ba400c1e --- /dev/null +++ b/mmdet/models/backbones/trident_resnet.py @@ -0,0 +1,298 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule +from torch.nn.modules.utils import _pair + +from mmdet.models.backbones.resnet import Bottleneck, ResNet +from mmdet.registry import MODELS + + +class TridentConv(BaseModule): + """Trident Convolution Module. + + Args: + in_channels (int): Number of channels in input. + out_channels (int): Number of channels in output. + kernel_size (int): Size of convolution kernel. + stride (int, optional): Convolution stride. Default: 1. + trident_dilations (tuple[int, int, int], optional): Dilations of + different trident branch. Default: (1, 2, 3). + test_branch_idx (int, optional): In inference, all 3 branches will + be used if `test_branch_idx==-1`, otherwise only branch with + index `test_branch_idx` will be used. Default: 1. + bias (bool, optional): Whether to use bias in convolution or not. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ Default: None + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + trident_dilations=(1, 2, 3), + test_branch_idx=1, + bias=False, + init_cfg=None): + super(TridentConv, self).__init__(init_cfg) + self.num_branch = len(trident_dilations) + self.with_bias = bias + self.test_branch_idx = test_branch_idx + self.stride = _pair(stride) + self.kernel_size = _pair(kernel_size) + self.paddings = _pair(trident_dilations) + self.dilations = trident_dilations + self.in_channels = in_channels + self.out_channels = out_channels + self.bias = bias + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.bias = None + + def extra_repr(self): + tmpstr = f'in_channels={self.in_channels}' + tmpstr += f', out_channels={self.out_channels}' + tmpstr += f', kernel_size={self.kernel_size}' + tmpstr += f', num_branch={self.num_branch}' + tmpstr += f', test_branch_idx={self.test_branch_idx}' + tmpstr += f', stride={self.stride}' + tmpstr += f', paddings={self.paddings}' + tmpstr += f', dilations={self.dilations}' + tmpstr += f', bias={self.bias}' + return tmpstr + + def forward(self, inputs): + if self.training or self.test_branch_idx == -1: + outputs = [ + F.conv2d(input, self.weight, self.bias, self.stride, padding, + dilation) for input, dilation, padding in zip( + inputs, self.dilations, self.paddings) + ] + else: + assert len(inputs) == 1 + outputs = [ + F.conv2d(inputs[0], self.weight, self.bias, self.stride, + self.paddings[self.test_branch_idx], + self.dilations[self.test_branch_idx]) + ] + + return outputs + + +# Since TridentNet is defined over ResNet50 and ResNet101, here we +# only support TridentBottleneckBlock. +class TridentBottleneck(Bottleneck): + """BottleBlock for TridentResNet. + + Args: + trident_dilations (tuple[int, int, int]): Dilations of different + trident branch. + test_branch_idx (int): In inference, all 3 branches will be used + if `test_branch_idx==-1`, otherwise only branch with index + `test_branch_idx` will be used. + concat_output (bool): Whether to concat the output list to a Tensor. + `True` only in the last Block. 
+ """ + + def __init__(self, trident_dilations, test_branch_idx, concat_output, + **kwargs): + + super(TridentBottleneck, self).__init__(**kwargs) + self.trident_dilations = trident_dilations + self.num_branch = len(trident_dilations) + self.concat_output = concat_output + self.test_branch_idx = test_branch_idx + self.conv2 = TridentConv( + self.planes, + self.planes, + kernel_size=3, + stride=self.conv2_stride, + bias=False, + trident_dilations=self.trident_dilations, + test_branch_idx=test_branch_idx, + init_cfg=dict( + type='Kaiming', + distribution='uniform', + mode='fan_in', + override=dict(name='conv2'))) + + def forward(self, x): + + def _inner_forward(x): + num_branch = ( + self.num_branch + if self.training or self.test_branch_idx == -1 else 1) + identity = x + if not isinstance(x, list): + x = (x, ) * num_branch + identity = x + if self.downsample is not None: + identity = [self.downsample(b) for b in x] + + out = [self.conv1(b) for b in x] + out = [self.norm1(b) for b in out] + out = [self.relu(b) for b in out] + + if self.with_plugins: + for k in range(len(out)): + out[k] = self.forward_plugin(out[k], + self.after_conv1_plugin_names) + + out = self.conv2(out) + out = [self.norm2(b) for b in out] + out = [self.relu(b) for b in out] + if self.with_plugins: + for k in range(len(out)): + out[k] = self.forward_plugin(out[k], + self.after_conv2_plugin_names) + + out = [self.conv3(b) for b in out] + out = [self.norm3(b) for b in out] + + if self.with_plugins: + for k in range(len(out)): + out[k] = self.forward_plugin(out[k], + self.after_conv3_plugin_names) + + out = [ + out_b + identity_b for out_b, identity_b in zip(out, identity) + ] + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = [self.relu(b) for b in out] + if self.concat_output: + out = torch.cat(out, dim=0) + return out + + +def make_trident_res_layer(block, + inplanes, + planes, + num_blocks, + stride=1, + trident_dilations=(1, 2, 3), + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + test_branch_idx=-1): + """Build Trident Res Layers.""" + + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = [] + conv_stride = stride + downsample.extend([ + build_conv_layer( + conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=conv_stride, + bias=False), + build_norm_layer(norm_cfg, planes * block.expansion)[1] + ]) + downsample = nn.Sequential(*downsample) + + layers = [] + for i in range(num_blocks): + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride if i == 0 else 1, + trident_dilations=trident_dilations, + downsample=downsample if i == 0 else None, + style=style, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + dcn=dcn, + plugins=plugins, + test_branch_idx=test_branch_idx, + concat_output=True if i == num_blocks - 1 else False)) + inplanes = planes * block.expansion + return nn.Sequential(*layers) + + +@MODELS.register_module() +class TridentResNet(ResNet): + """The stem layer, stage 1 and stage 2 in Trident ResNet are identical to + ResNet, while in stage 3, Trident BottleBlock is utilized to replace the + normal BottleBlock to yield trident output. Different branch shares the + convolution weight but uses different dilations to achieve multi-scale + output. 
+ + / stage3(b0) \ + x - stem - stage1 - stage2 - stage3(b1) - output + \ stage3(b2) / + + Args: + depth (int): Depth of resnet, from {50, 101, 152}. + num_branch (int): Number of branches in TridentNet. + test_branch_idx (int): In inference, all 3 branches will be used + if `test_branch_idx==-1`, otherwise only branch with index + `test_branch_idx` will be used. + trident_dilations (tuple[int]): Dilations of different trident branch. + len(trident_dilations) should be equal to num_branch. + """ # noqa + + def __init__(self, depth, num_branch, test_branch_idx, trident_dilations, + **kwargs): + + assert num_branch == len(trident_dilations) + assert depth in (50, 101, 152) + super(TridentResNet, self).__init__(depth, **kwargs) + assert self.num_stages == 3 + self.test_branch_idx = test_branch_idx + self.num_branch = num_branch + + last_stage_idx = self.num_stages - 1 + stride = self.strides[last_stage_idx] + dilation = trident_dilations + dcn = self.dcn if self.stage_with_dcn[last_stage_idx] else None + if self.plugins is not None: + stage_plugins = self.make_stage_plugins(self.plugins, + last_stage_idx) + else: + stage_plugins = None + planes = self.base_channels * 2**last_stage_idx + res_layer = make_trident_res_layer( + TridentBottleneck, + inplanes=(self.block.expansion * self.base_channels * + 2**(last_stage_idx - 1)), + planes=planes, + num_blocks=self.stage_blocks[last_stage_idx], + stride=stride, + trident_dilations=dilation, + style=self.style, + with_cp=self.with_cp, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + dcn=dcn, + plugins=stage_plugins, + test_branch_idx=self.test_branch_idx) + + layer_name = f'layer{last_stage_idx + 1}' + + self.__setattr__(layer_name, res_layer) + self.res_layers.pop(last_stage_idx) + self.res_layers.insert(last_stage_idx, layer_name) + + self._freeze_stages() diff --git a/mmdet/models/data_preprocessors/__init__.py b/mmdet/models/data_preprocessors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..201a1da6a4f320a17cea9c65d5c102bfdd7700d8 --- /dev/null +++ b/mmdet/models/data_preprocessors/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .data_preprocessor import (BatchFixedSizePad, BatchResize, + BatchSyncRandomResize, BoxInstDataPreprocessor, + DetDataPreprocessor, + MultiBranchDataPreprocessor) +from .reid_data_preprocessor import ReIDDataPreprocessor +from .track_data_preprocessor import TrackDataPreprocessor + +__all__ = [ + 'DetDataPreprocessor', 'BatchSyncRandomResize', 'BatchFixedSizePad', + 'MultiBranchDataPreprocessor', 'BatchResize', 'BoxInstDataPreprocessor', + 'TrackDataPreprocessor', 'ReIDDataPreprocessor' +] diff --git a/mmdet/models/data_preprocessors/data_preprocessor.py b/mmdet/models/data_preprocessors/data_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..55b5c35b3a4888c95c6646df3fa080347afe4704 --- /dev/null +++ b/mmdet/models/data_preprocessors/data_preprocessor.py @@ -0,0 +1,793 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
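+# NOTE: a rough usage sketch for the preprocessor defined below (illustrative
+# only, not part of the original file; the mean/std values are the common
+# ImageNet statistics):
+#
+#     preprocessor = DetDataPreprocessor(
+#         mean=[123.675, 116.28, 103.53],
+#         std=[58.395, 57.12, 57.375],
+#         bgr_to_rgb=True,
+#         pad_size_divisor=32)
+#     out = preprocessor(data, training=True)
+#     inputs, data_samples = out['inputs'], out['data_samples']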
+import random
+from numbers import Number
+from typing import List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.dist import barrier, broadcast, get_dist_info
+from mmengine.logging import MessageHub
+from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor
+from mmengine.structures import PixelData
+from mmengine.utils import is_seq_of
+from torch import Tensor
+
+from mmdet.models.utils import unfold_wo_center
+from mmdet.models.utils.misc import samplelist_boxtype2tensor
+from mmdet.registry import MODELS
+from mmdet.structures import DetDataSample
+from mmdet.structures.mask import BitmapMasks
+from mmdet.utils import ConfigType
+
+try:
+    import skimage
+except ImportError:
+    skimage = None
+
+
+@MODELS.register_module()
+class DetDataPreprocessor(ImgDataPreprocessor):
+    """Image pre-processor for detection tasks.
+
+    Compared with the :class:`mmengine.ImgDataPreprocessor`,
+
+    1. It supports batch augmentations.
+    2. It will additionally append batch_input_shape and pad_shape
+       to data_samples considering the object detection task.
+
+    It provides the data pre-processing as follows:
+
+    - Collate and move data to the target device.
+    - Pad inputs to the maximum size of current batch with defined
+      ``pad_value``. The padding size is divisible by a defined
+      ``pad_size_divisor``.
+    - Stack inputs to batch_inputs.
+    - Convert inputs from bgr to rgb if the shape of input is (3, H, W).
+    - Normalize image with defined std and mean.
+    - Do batch augmentations during training.
+
+    Args:
+        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
+            Defaults to None.
+        std (Sequence[Number], optional): The pixel standard deviation of
+            R, G, B channels. Defaults to None.
+        pad_size_divisor (int): The size of padded image should be
+            divisible by ``pad_size_divisor``. Defaults to 1.
+        pad_value (Number): The padded pixel value. Defaults to 0.
+        pad_mask (bool): Whether to pad instance masks. Defaults to False.
+        mask_pad_value (int): The padded pixel value for instance masks.
+            Defaults to 0.
+        pad_seg (bool): Whether to pad semantic segmentation maps.
+            Defaults to False.
+        seg_pad_value (int): The padded pixel value for semantic
+            segmentation maps. Defaults to 255.
+        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
+            Defaults to False.
+        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
+            Defaults to False.
+        boxtype2tensor (bool): Whether to convert the ``BaseBoxes`` type of
+            bboxes data to ``Tensor`` type. Defaults to True.
+        non_blocking (bool): Whether to block the current process
+            when transferring data to the device. Defaults to False.
+ batch_augments (list[dict], optional): Batch-level augmentations + """ + + def __init__(self, + mean: Sequence[Number] = None, + std: Sequence[Number] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + pad_mask: bool = False, + mask_pad_value: int = 0, + pad_seg: bool = False, + seg_pad_value: int = 255, + bgr_to_rgb: bool = False, + rgb_to_bgr: bool = False, + boxtype2tensor: bool = True, + non_blocking: Optional[bool] = False, + batch_augments: Optional[List[dict]] = None): + super().__init__( + mean=mean, + std=std, + pad_size_divisor=pad_size_divisor, + pad_value=pad_value, + bgr_to_rgb=bgr_to_rgb, + rgb_to_bgr=rgb_to_bgr, + non_blocking=non_blocking) + if batch_augments is not None: + self.batch_augments = nn.ModuleList( + [MODELS.build(aug) for aug in batch_augments]) + else: + self.batch_augments = None + self.pad_mask = pad_mask + self.mask_pad_value = mask_pad_value + self.pad_seg = pad_seg + self.seg_pad_value = seg_pad_value + self.boxtype2tensor = boxtype2tensor + + def forward(self, data: dict, training: bool = False) -> dict: + """Perform normalization,padding and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (dict): Data sampled from dataloader. + training (bool): Whether to enable training time augmentation. + + Returns: + dict: Data in the same format as the model input. + """ + batch_pad_shape = self._get_pad_shape(data) + data = super().forward(data=data, training=training) + inputs, data_samples = data['inputs'], data['data_samples'] + + if data_samples is not None: + # NOTE the batched image size information may be useful, e.g. + # in DETR, this is needed for the construction of masks, which is + # then used for the transformer_head. + batch_input_shape = tuple(inputs[0].size()[-2:]) + for data_sample, pad_shape in zip(data_samples, batch_pad_shape): + data_sample.set_metainfo({ + 'batch_input_shape': batch_input_shape, + 'pad_shape': pad_shape + }) + + if self.boxtype2tensor: + samplelist_boxtype2tensor(data_samples) + + if self.pad_mask and training: + self.pad_gt_masks(data_samples) + + if self.pad_seg and training: + self.pad_gt_sem_seg(data_samples) + + if training and self.batch_augments is not None: + for batch_aug in self.batch_augments: + inputs, data_samples = batch_aug(inputs, data_samples) + + return {'inputs': inputs, 'data_samples': data_samples} + + def _get_pad_shape(self, data: dict) -> List[tuple]: + """Get the pad_shape of each image based on data and + pad_size_divisor.""" + _batch_inputs = data['inputs'] + # Process data with `pseudo_collate`. + if is_seq_of(_batch_inputs, torch.Tensor): + batch_pad_shape = [] + for ori_input in _batch_inputs: + pad_h = int( + np.ceil(ori_input.shape[1] / + self.pad_size_divisor)) * self.pad_size_divisor + pad_w = int( + np.ceil(ori_input.shape[2] / + self.pad_size_divisor)) * self.pad_size_divisor + batch_pad_shape.append((pad_h, pad_w)) + # Process data with `default_collate`. 
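+        # default_collate stacks same-shaped images into one NCHW tensor, so
+        # a single pad shape applies to the whole batch; e.g. with
+        # pad_size_divisor=32, a 300x500 input pads to 320x512 because
+        # ceil(300/32) * 32 = 320 and ceil(500/32) * 32 = 512 (illustrative
+        # numbers, not from the original code).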
+ elif isinstance(_batch_inputs, torch.Tensor): + assert _batch_inputs.dim() == 4, ( + 'The input of `ImgDataPreprocessor` should be a NCHW tensor ' + 'or a list of tensor, but got a tensor with shape: ' + f'{_batch_inputs.shape}') + pad_h = int( + np.ceil(_batch_inputs.shape[2] / + self.pad_size_divisor)) * self.pad_size_divisor + pad_w = int( + np.ceil(_batch_inputs.shape[3] / + self.pad_size_divisor)) * self.pad_size_divisor + batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0] + else: + raise TypeError('Output of `cast_data` should be a dict ' + 'or a tuple with inputs and data_samples, but got' + f'{type(data)}: {data}') + return batch_pad_shape + + def pad_gt_masks(self, + batch_data_samples: Sequence[DetDataSample]) -> None: + """Pad gt_masks to shape of batch_input_shape.""" + if 'masks' in batch_data_samples[0].gt_instances: + for data_samples in batch_data_samples: + masks = data_samples.gt_instances.masks + data_samples.gt_instances.masks = masks.pad( + data_samples.batch_input_shape, + pad_val=self.mask_pad_value) + + def pad_gt_sem_seg(self, + batch_data_samples: Sequence[DetDataSample]) -> None: + """Pad gt_sem_seg to shape of batch_input_shape.""" + if 'gt_sem_seg' in batch_data_samples[0]: + for data_samples in batch_data_samples: + gt_sem_seg = data_samples.gt_sem_seg.sem_seg + h, w = gt_sem_seg.shape[-2:] + pad_h, pad_w = data_samples.batch_input_shape + gt_sem_seg = F.pad( + gt_sem_seg, + pad=(0, max(pad_w - w, 0), 0, max(pad_h - h, 0)), + mode='constant', + value=self.seg_pad_value) + data_samples.gt_sem_seg = PixelData(sem_seg=gt_sem_seg) + + +@MODELS.register_module() +class BatchSyncRandomResize(nn.Module): + """Batch random resize which synchronizes the random size across ranks. + + Args: + random_size_range (tuple): The multi-scale random range during + multi-scale training. + interval (int): The iter interval of change + image size. Defaults to 10. + size_divisor (int): Image size divisible factor. + Defaults to 32. 
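+
+    The example below is an illustrative sketch (not part of the original
+    docs); it assumes a single-process run, where rank 0 simply picks every
+    new size:
+
+    .. code-block:: python
+
+        resize = BatchSyncRandomResize(
+            random_size_range=(480, 800), interval=10, size_divisor=32)
+        # inputs: (N, C, H, W) float tensor; data_samples: list of
+        # DetDataSample. Every 10 iterations a new target height that is a
+        # multiple of 32 is drawn from roughly [480, 800], the width follows
+        # the batch aspect ratio, and the size is broadcast to all ranks.
+        # inputs, data_samples = resize(inputs, data_samples)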
+ """ + + def __init__(self, + random_size_range: Tuple[int, int], + interval: int = 10, + size_divisor: int = 32) -> None: + super().__init__() + self.rank, self.world_size = get_dist_info() + self._input_size = None + self._random_size_range = (round(random_size_range[0] / size_divisor), + round(random_size_range[1] / size_divisor)) + self._interval = interval + self._size_divisor = size_divisor + + def forward( + self, inputs: Tensor, data_samples: List[DetDataSample] + ) -> Tuple[Tensor, List[DetDataSample]]: + """resize a batch of images and bboxes to shape ``self._input_size``""" + h, w = inputs.shape[-2:] + if self._input_size is None: + self._input_size = (h, w) + scale_y = self._input_size[0] / h + scale_x = self._input_size[1] / w + if scale_x != 1 or scale_y != 1: + inputs = F.interpolate( + inputs, + size=self._input_size, + mode='bilinear', + align_corners=False) + for data_sample in data_samples: + img_shape = (int(data_sample.img_shape[0] * scale_y), + int(data_sample.img_shape[1] * scale_x)) + pad_shape = (int(data_sample.pad_shape[0] * scale_y), + int(data_sample.pad_shape[1] * scale_x)) + data_sample.set_metainfo({ + 'img_shape': img_shape, + 'pad_shape': pad_shape, + 'batch_input_shape': self._input_size + }) + data_sample.gt_instances.bboxes[ + ..., + 0::2] = data_sample.gt_instances.bboxes[..., + 0::2] * scale_x + data_sample.gt_instances.bboxes[ + ..., + 1::2] = data_sample.gt_instances.bboxes[..., + 1::2] * scale_y + if 'ignored_instances' in data_sample: + data_sample.ignored_instances.bboxes[ + ..., 0::2] = data_sample.ignored_instances.bboxes[ + ..., 0::2] * scale_x + data_sample.ignored_instances.bboxes[ + ..., 1::2] = data_sample.ignored_instances.bboxes[ + ..., 1::2] * scale_y + message_hub = MessageHub.get_current_instance() + if (message_hub.get_info('iter') + 1) % self._interval == 0: + self._input_size = self._get_random_size( + aspect_ratio=float(w / h), device=inputs.device) + return inputs, data_samples + + def _get_random_size(self, aspect_ratio: float, + device: torch.device) -> Tuple[int, int]: + """Randomly generate a shape in ``_random_size_range`` and broadcast to + all ranks.""" + tensor = torch.LongTensor(2).to(device) + if self.rank == 0: + size = random.randint(*self._random_size_range) + size = (self._size_divisor * size, + self._size_divisor * int(aspect_ratio * size)) + tensor[0] = size[0] + tensor[1] = size[1] + barrier() + broadcast(tensor, 0) + input_size = (tensor[0].item(), tensor[1].item()) + return input_size + + +@MODELS.register_module() +class BatchFixedSizePad(nn.Module): + """Fixed size padding for batch images. + + Args: + size (Tuple[int, int]): Fixed padding size. Expected padding + shape (h, w). Defaults to None. + img_pad_value (int): The padded pixel value for images. + Defaults to 0. + pad_mask (bool): Whether to pad instance masks. Defaults to False. + mask_pad_value (int): The padded pixel value for instance masks. + Defaults to 0. + pad_seg (bool): Whether to pad semantic segmentation maps. + Defaults to False. + seg_pad_value (int): The padded pixel value for semantic + segmentation maps. Defaults to 255. 
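+
+    A minimal usage sketch (illustrative shapes, not from the original
+    docs):
+
+    .. code-block:: python
+
+        pad = BatchFixedSizePad(size=(1024, 1024), pad_seg=True)
+        # a (N, C, 800, 800) batch comes back right/bottom padded to
+        # (N, C, 1024, 1024); batches that already meet the target size
+        # are returned unchanged.
+        # inputs, data_samples = pad(inputs, data_samples)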
+    """
+
+    def __init__(self,
+                 size: Tuple[int, int],
+                 img_pad_value: int = 0,
+                 pad_mask: bool = False,
+                 mask_pad_value: int = 0,
+                 pad_seg: bool = False,
+                 seg_pad_value: int = 255) -> None:
+        super().__init__()
+        self.size = size
+        self.pad_mask = pad_mask
+        self.pad_seg = pad_seg
+        self.img_pad_value = img_pad_value
+        self.mask_pad_value = mask_pad_value
+        self.seg_pad_value = seg_pad_value
+
+    def forward(
+        self,
+        inputs: Tensor,
+        data_samples: Optional[List[dict]] = None
+    ) -> Tuple[Tensor, Optional[List[dict]]]:
+        """Pad image, instance masks, semantic segmentation maps."""
+        src_h, src_w = inputs.shape[-2:]
+        dst_h, dst_w = self.size
+
+        if src_h >= dst_h and src_w >= dst_w:
+            return inputs, data_samples
+
+        inputs = F.pad(
+            inputs,
+            pad=(0, max(0, dst_w - src_w), 0, max(0, dst_h - src_h)),
+            mode='constant',
+            value=self.img_pad_value)
+
+        if data_samples is not None:
+            # update batch_input_shape
+            for data_sample in data_samples:
+                data_sample.set_metainfo({
+                    'batch_input_shape': (dst_h, dst_w),
+                    'pad_shape': (dst_h, dst_w)
+                })
+
+            if self.pad_mask:
+                for data_sample in data_samples:
+                    masks = data_sample.gt_instances.masks
+                    data_sample.gt_instances.masks = masks.pad(
+                        (dst_h, dst_w), pad_val=self.mask_pad_value)
+
+            if self.pad_seg:
+                for data_sample in data_samples:
+                    gt_sem_seg = data_sample.gt_sem_seg.sem_seg
+                    h, w = gt_sem_seg.shape[-2:]
+                    gt_sem_seg = F.pad(
+                        gt_sem_seg,
+                        pad=(0, max(0, dst_w - w), 0, max(0, dst_h - h)),
+                        mode='constant',
+                        value=self.seg_pad_value)
+                    data_sample.gt_sem_seg = PixelData(sem_seg=gt_sem_seg)
+
+        return inputs, data_samples
+
+
+@MODELS.register_module()
+class MultiBranchDataPreprocessor(BaseDataPreprocessor):
+    """DataPreprocessor wrapper for multi-branch data.
+
+    Take semi-supervised object detection as an example, assume that
+    the ratio of labeled data and unlabeled data in a batch is 1:2,
+    `sup` indicates the branch where the labeled data is augmented,
+    `unsup_teacher` and `unsup_student` indicate the branches where
+    the unlabeled data is augmented by different pipelines.
+
+    The input format of multi-branch data is shown as below:
+
+    .. code-block:: none
+        {
+            'inputs':
+                {
+                    'sup': [Tensor, None, None],
+                    'unsup_teacher': [None, Tensor, Tensor],
+                    'unsup_student': [None, Tensor, Tensor],
+                },
+            'data_sample':
+                {
+                    'sup': [DetDataSample, None, None],
+                    'unsup_teacher': [None, DetDataSample, DetDataSample],
+                    'unsup_student': [None, DetDataSample, DetDataSample],
+                }
+        }
+
+    The format of multi-branch data
+    after filtering None is shown as below:
+
+    .. code-block:: none
+        {
+            'inputs':
+                {
+                    'sup': [Tensor],
+                    'unsup_teacher': [Tensor, Tensor],
+                    'unsup_student': [Tensor, Tensor],
+                },
+            'data_sample':
+                {
+                    'sup': [DetDataSample],
+                    'unsup_teacher': [DetDataSample, DetDataSample],
+                    'unsup_student': [DetDataSample, DetDataSample],
+                }
+        }
+
+    In order to reuse `DetDataPreprocessor` for the data
+    from different branches, the format of multi-branch data
+    grouped by branch is as below:
+
+    .. code-block:: none
+        {
+            'sup':
+                {
+                    'inputs': [Tensor]
+                    'data_sample': [DetDataSample, DetDataSample]
+                },
+            'unsup_teacher':
+                {
+                    'inputs': [Tensor, Tensor]
+                    'data_sample': [DetDataSample, DetDataSample]
+                },
+            'unsup_student':
+                {
+                    'inputs': [Tensor, Tensor]
+                    'data_sample': [DetDataSample, DetDataSample]
+                },
+        }
+
+    After preprocessing data from different branches,
+    the multi-branch data needs to be reformatted as:
+
+    .. code-block:: none
+        {
+            'inputs':
+                {
+                    'sup': [Tensor],
+                    'unsup_teacher': [Tensor, Tensor],
+                    'unsup_student': [Tensor, Tensor],
+                },
+            'data_sample':
+                {
+                    'sup': [DetDataSample],
+                    'unsup_teacher': [DetDataSample, DetDataSample],
+                    'unsup_student': [DetDataSample, DetDataSample],
+                }
+        }
+
+    Args:
+        data_preprocessor (:obj:`ConfigDict` or dict): Config of
+            :class:`DetDataPreprocessor` to process the input data.
+    """
+
+    def __init__(self, data_preprocessor: ConfigType) -> None:
+        super().__init__()
+        self.data_preprocessor = MODELS.build(data_preprocessor)
+
+    def forward(self, data: dict, training: bool = False) -> dict:
+        """Perform normalization, padding and bgr2rgb conversion based on
+        ``BaseDataPreprocessor`` for multi-branch data.
+
+        Args:
+            data (dict): Data sampled from dataloader.
+            training (bool): Whether to enable training time augmentation.
+
+        Returns:
+            dict:
+
+            - 'inputs' (Dict[str, obj:`torch.Tensor`]): The forward data of
+              models from different branches.
+            - 'data_sample' (Dict[str, obj:`DetDataSample`]): The annotation
+              info of the sample from different branches.
+        """
+
+        if training is False:
+            return self.data_preprocessor(data, training)
+
+        # Filter out branches with a value of None
+        for key in data.keys():
+            for branch in data[key].keys():
+                data[key][branch] = list(
+                    filter(lambda x: x is not None, data[key][branch]))
+
+        # Group data by branch
+        multi_branch_data = {}
+        for key in data.keys():
+            for branch in data[key].keys():
+                if multi_branch_data.get(branch, None) is None:
+                    multi_branch_data[branch] = {key: data[key][branch]}
+                elif multi_branch_data[branch].get(key, None) is None:
+                    multi_branch_data[branch][key] = data[key][branch]
+                else:
+                    multi_branch_data[branch][key].append(data[key][branch])
+
+        # Preprocess data from different branches
+        for branch, _data in multi_branch_data.items():
+            multi_branch_data[branch] = self.data_preprocessor(_data, training)
+
+        # Format data by inputs and data_samples
+        format_data = {}
+        for branch in multi_branch_data.keys():
+            for key in multi_branch_data[branch].keys():
+                if format_data.get(key, None) is None:
+                    format_data[key] = {branch: multi_branch_data[branch][key]}
+                elif format_data[key].get(branch, None) is None:
+                    format_data[key][branch] = multi_branch_data[branch][key]
+                else:
+                    format_data[key][branch].append(
+                        multi_branch_data[branch][key])
+
+        return format_data
+
+    @property
+    def device(self):
+        return self.data_preprocessor.device
+
+    def to(self, device: Optional[Union[int, torch.device]], *args,
+           **kwargs) -> nn.Module:
+        """Overrides this method to set the :attr:`device`
+
+        Args:
+            device (int or torch.device, optional): The desired device of the
+                parameters and buffers in this module.
+
+        Returns:
+            nn.Module: The model itself.
+        """
+
+        return self.data_preprocessor.to(device, *args, **kwargs)
+
+    def cuda(self, *args, **kwargs) -> nn.Module:
+        """Overrides this method to set the :attr:`device`
+
+        Returns:
+            nn.Module: The model itself.
+        """
+
+        return self.data_preprocessor.cuda(*args, **kwargs)
+
+    def cpu(self, *args, **kwargs) -> nn.Module:
+        """Overrides this method to set the :attr:`device`
+
+        Returns:
+            nn.Module: The model itself.
+        """
+
+        return self.data_preprocessor.cpu(*args, **kwargs)
+
+
+@MODELS.register_module()
+class BatchResize(nn.Module):
+    """Batch resize during training. This implementation is modified from
+    https://github.com/Purkialo/CrowdDet/blob/master/lib/data/CrowdHuman.py.
+
+    It provides the data pre-processing as follows:
+    - All images in a batch are padded to a uniform size and stacked into
+      a torch.Tensor by `DetDataPreprocessor`.
+    - `BatchResize` resizes all images to the target size.
+    - Images are then padded so that their size is divisible by
+      ``pad_size_divisor``.
+
+    Args:
+        scale (tuple): Image scales for resizing.
+        pad_size_divisor (int): Image size divisible factor.
+            Defaults to 1.
+        pad_value (Number): The padded pixel value. Defaults to 0.
+    """
+
+    def __init__(
+        self,
+        scale: tuple,
+        pad_size_divisor: int = 1,
+        pad_value: Union[float, int] = 0,
+    ) -> None:
+        super().__init__()
+        self.min_size = min(scale)
+        self.max_size = max(scale)
+        self.pad_size_divisor = pad_size_divisor
+        self.pad_value = pad_value
+
+    def forward(
+        self, inputs: Tensor, data_samples: List[DetDataSample]
+    ) -> Tuple[Tensor, List[DetDataSample]]:
+        """Resize a batch of images and bboxes."""
+
+        batch_height, batch_width = inputs.shape[-2:]
+        target_height, target_width, scale = self.get_target_size(
+            batch_height, batch_width)
+
+        inputs = F.interpolate(
+            inputs,
+            size=(target_height, target_width),
+            mode='bilinear',
+            align_corners=False)
+
+        inputs = self.get_padded_tensor(inputs, self.pad_value)
+
+        if data_samples is not None:
+            batch_input_shape = tuple(inputs.size()[-2:])
+            for data_sample in data_samples:
+                img_shape = [
+                    int(scale * _) for _ in list(data_sample.img_shape)
+                ]
+                data_sample.set_metainfo({
+                    'img_shape': tuple(img_shape),
+                    'batch_input_shape': batch_input_shape,
+                    'pad_shape': batch_input_shape,
+                    'scale_factor': (scale, scale)
+                })
+
+                data_sample.gt_instances.bboxes *= scale
+                data_sample.ignored_instances.bboxes *= scale
+
+        return inputs, data_samples
+
+    def get_target_size(self, height: int,
+                        width: int) -> Tuple[int, int, float]:
+        """Get the target size of a batch of images based on data and scale."""
+        im_size_min = np.min([height, width])
+        im_size_max = np.max([height, width])
+        scale = self.min_size / im_size_min
+        if scale * im_size_max > self.max_size:
+            scale = self.max_size / im_size_max
+        target_height, target_width = int(round(height * scale)), int(
+            round(width * scale))
+        return target_height, target_width, scale
+
+    def get_padded_tensor(self, tensor: Tensor, pad_value: int) -> Tensor:
+        """Pad images according to pad_size_divisor."""
+        assert tensor.ndim == 4
+        target_height, target_width = tensor.shape[-2], tensor.shape[-1]
+        divisor = self.pad_size_divisor
+        padded_height = (target_height + divisor - 1) // divisor * divisor
+        padded_width = (target_width + divisor - 1) // divisor * divisor
+        padded_tensor = torch.ones([
+            tensor.shape[0], tensor.shape[1], padded_height, padded_width
+        ]) * pad_value
+        padded_tensor = padded_tensor.type_as(tensor)
+        padded_tensor[:, :, :target_height, :target_width] = tensor
+        return padded_tensor
+
+
+@MODELS.register_module()
+class BoxInstDataPreprocessor(DetDataPreprocessor):
+    """Pseudo mask pre-processor for BoxInst.
+
+    Comparing with the :class:`mmdet.DetDataPreprocessor`,
+
+    1. It generates masks using box annotations.
+    2. It computes the image color similarity in LAB color space.
+
+    Args:
+        mask_stride (int): The mask output stride in boxinst. Defaults to 4.
+        pairwise_size (int): The size of neighborhood for each pixel.
+            Defaults to 3.
+        pairwise_dilation (int): The dilation of neighborhood for each pixel.
+            Defaults to 2.
+        pairwise_color_thresh (float): The thresh of image color similarity.
+            Defaults to 0.3.
+ bottom_pixels_removed (int): The length of removed pixels in bottom. + It is caused by the annotation error in coco dataset. + Defaults to 10. + """ + + def __init__(self, + *arg, + mask_stride: int = 4, + pairwise_size: int = 3, + pairwise_dilation: int = 2, + pairwise_color_thresh: float = 0.3, + bottom_pixels_removed: int = 10, + **kwargs) -> None: + super().__init__(*arg, **kwargs) + self.mask_stride = mask_stride + self.pairwise_size = pairwise_size + self.pairwise_dilation = pairwise_dilation + self.pairwise_color_thresh = pairwise_color_thresh + self.bottom_pixels_removed = bottom_pixels_removed + + if skimage is None: + raise RuntimeError('skimage is not installed,\ + please install it by: pip install scikit-image') + + def get_images_color_similarity(self, inputs: Tensor, + image_masks: Tensor) -> Tensor: + """Compute the image color similarity in LAB color space.""" + assert inputs.dim() == 4 + assert inputs.size(0) == 1 + + unfolded_images = unfold_wo_center( + inputs, + kernel_size=self.pairwise_size, + dilation=self.pairwise_dilation) + diff = inputs[:, :, None] - unfolded_images + similarity = torch.exp(-torch.norm(diff, dim=1) * 0.5) + + unfolded_weights = unfold_wo_center( + image_masks[None, None], + kernel_size=self.pairwise_size, + dilation=self.pairwise_dilation) + unfolded_weights = torch.max(unfolded_weights, dim=1)[0] + + return similarity * unfolded_weights + + def forward(self, data: dict, training: bool = False) -> dict: + """Get pseudo mask labels using color similarity.""" + det_data = super().forward(data, training) + inputs, data_samples = det_data['inputs'], det_data['data_samples'] + + if training: + # get image masks and remove bottom pixels + b_img_h, b_img_w = data_samples[0].batch_input_shape + img_masks = [] + for i in range(inputs.shape[0]): + img_h, img_w = data_samples[i].img_shape + img_mask = inputs.new_ones((img_h, img_w)) + pixels_removed = int(self.bottom_pixels_removed * + float(img_h) / float(b_img_h)) + if pixels_removed > 0: + img_mask[-pixels_removed:, :] = 0 + pad_w = b_img_w - img_w + pad_h = b_img_h - img_h + img_mask = F.pad(img_mask, (0, pad_w, 0, pad_h), 'constant', + 0.) + img_masks.append(img_mask) + img_masks = torch.stack(img_masks, dim=0) + start = int(self.mask_stride // 2) + img_masks = img_masks[:, start::self.mask_stride, + start::self.mask_stride] + + # Get origin rgb image for color similarity + ori_imgs = inputs * self.std + self.mean + downsampled_imgs = F.avg_pool2d( + ori_imgs.float(), + kernel_size=self.mask_stride, + stride=self.mask_stride, + padding=0) + + # Compute color similarity for pseudo mask generation + for im_i, data_sample in enumerate(data_samples): + # TODO: Support rgb2lab in mmengine? 
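+                # Color similarity is measured in LAB space (as in the
+                # BoxInst paper) rather than RGB, since LAB distances track
+                # perceived color differences more closely. rgb2lab expects
+                # an HWC array, hence the permute/numpy round trip below;
+                # the .byte() cast assumes the de-normalized image lies in
+                # the 0-255 range.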
+ images_lab = skimage.color.rgb2lab( + downsampled_imgs[im_i].byte().permute(1, 2, + 0).cpu().numpy()) + images_lab = torch.as_tensor( + images_lab, device=ori_imgs.device, dtype=torch.float32) + images_lab = images_lab.permute(2, 0, 1)[None] + images_color_similarity = self.get_images_color_similarity( + images_lab, img_masks[im_i]) + pairwise_mask = (images_color_similarity >= + self.pairwise_color_thresh).float() + + per_im_bboxes = data_sample.gt_instances.bboxes + if per_im_bboxes.shape[0] > 0: + per_im_masks = [] + for per_box in per_im_bboxes: + mask_full = torch.zeros((b_img_h, b_img_w), + device=self.device).float() + mask_full[int(per_box[1]):int(per_box[3] + 1), + int(per_box[0]):int(per_box[2] + 1)] = 1.0 + per_im_masks.append(mask_full) + per_im_masks = torch.stack(per_im_masks, dim=0) + pairwise_masks = torch.cat( + [pairwise_mask for _ in range(per_im_bboxes.shape[0])], + dim=0) + else: + per_im_masks = torch.zeros((0, b_img_h, b_img_w)) + pairwise_masks = torch.zeros( + (0, self.pairwise_size**2 - 1, b_img_h, b_img_w)) + + # TODO: Support BitmapMasks with tensor? + data_sample.gt_instances.masks = BitmapMasks( + per_im_masks.cpu().numpy(), b_img_h, b_img_w) + data_sample.gt_instances.pairwise_masks = pairwise_masks + return {'inputs': inputs, 'data_samples': data_samples} diff --git a/mmdet/models/data_preprocessors/reid_data_preprocessor.py b/mmdet/models/data_preprocessors/reid_data_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..3d0a1d45d97ba350e8845c6620f3b73f05545e61 --- /dev/null +++ b/mmdet/models/data_preprocessors/reid_data_preprocessor.py @@ -0,0 +1,216 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from numbers import Number +from typing import Optional, Sequence + +import torch +import torch.nn.functional as F +from mmengine.model import BaseDataPreprocessor, stack_batch + +from mmdet.registry import MODELS + +try: + import mmpretrain + from mmpretrain.models.utils.batch_augments import RandomBatchAugment + from mmpretrain.structures import (batch_label_to_onehot, cat_batch_labels, + tensor_split) +except ImportError: + mmpretrain = None + + +def stack_batch_scores(elements, device=None): + """Stack the ``score`` of a batch of :obj:`LabelData` to a tensor. + + Args: + elements (List[LabelData]): A batch of :obj`LabelData`. + device (torch.device, optional): The output device of the batch label. + Defaults to None. + Returns: + torch.Tensor: The stacked score tensor. + """ + item = elements[0] + if 'score' not in item._data_fields: + return None + + batch_score = torch.stack([element.score for element in elements]) + if device is not None: + batch_score = batch_score.to(device) + return batch_score + + +@MODELS.register_module() +class ReIDDataPreprocessor(BaseDataPreprocessor): + """Image pre-processor for classification tasks. + + Comparing with the :class:`mmengine.model.ImgDataPreprocessor`, + + 1. It won't do normalization if ``mean`` is not specified. + 2. It does normalization and color space conversion after stacking batch. + 3. It supports batch augmentations like mixup and cutmix. + + It provides the data pre-processing as follows + + - Collate and move data to the target device. + - Pad inputs to the maximum size of current batch with defined + ``pad_value``. The padding size can be divisible by a defined + ``pad_size_divisor`` + - Stack inputs to batch_inputs. + - Convert inputs from bgr to rgb if the shape of input is (3, H, W). + - Normalize image with defined std and mean. 
+ - Do batch augmentations like Mixup and Cutmix during training. + + Args: + mean (Sequence[Number], optional): The pixel mean of R, G, B channels. + Defaults to None. + std (Sequence[Number], optional): The pixel standard deviation of + R, G, B channels. Defaults to None. + pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + pad_value (Number): The padded pixel value. Defaults to 0. + to_rgb (bool): whether to convert image from BGR to RGB. + Defaults to False. + to_onehot (bool): Whether to generate one-hot format gt-labels and set + to data samples. Defaults to False. + num_classes (int, optional): The number of classes. Defaults to None. + batch_augments (dict, optional): The batch augmentations settings, + including "augments" and "probs". For more details, see + :class:`mmpretrain.models.RandomBatchAugment`. + """ + + def __init__(self, + mean: Sequence[Number] = None, + std: Sequence[Number] = None, + pad_size_divisor: int = 1, + pad_value: Number = 0, + to_rgb: bool = False, + to_onehot: bool = False, + num_classes: Optional[int] = None, + batch_augments: Optional[dict] = None): + if mmpretrain is None: + raise RuntimeError('Please run "pip install openmim" and ' + 'run "mim install mmpretrain" to ' + 'install mmpretrain first.') + super().__init__() + self.pad_size_divisor = pad_size_divisor + self.pad_value = pad_value + self.to_rgb = to_rgb + self.to_onehot = to_onehot + self.num_classes = num_classes + + if mean is not None: + assert std is not None, 'To enable the normalization in ' \ + 'preprocessing, please specify both `mean` and `std`.' + # Enable the normalization in preprocessing. + self._enable_normalize = True + self.register_buffer('mean', + torch.tensor(mean).view(-1, 1, 1), False) + self.register_buffer('std', + torch.tensor(std).view(-1, 1, 1), False) + else: + self._enable_normalize = False + + if batch_augments is not None: + self.batch_augments = RandomBatchAugment(**batch_augments) + if not self.to_onehot: + from mmengine.logging import MMLogger + MMLogger.get_current_instance().info( + 'Because batch augmentations are enabled, the data ' + 'preprocessor automatically enables the `to_onehot` ' + 'option to generate one-hot format labels.') + self.to_onehot = True + else: + self.batch_augments = None + + def forward(self, data: dict, training: bool = False) -> dict: + """Perform normalization, padding, bgr2rgb conversion and batch + augmentation based on ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. + + Returns: + dict: Data in the same format as the model input. + """ + inputs = self.cast_data(data['inputs']) + + if isinstance(inputs, torch.Tensor): + # The branch if use `default_collate` as the collate_fn in the + # dataloader. + + # ------ To RGB ------ + if self.to_rgb and inputs.size(1) == 3: + inputs = inputs.flip(1) + + # -- Normalization --- + inputs = inputs.float() + if self._enable_normalize: + inputs = (inputs - self.mean) / self.std + + # ------ Padding ----- + if self.pad_size_divisor > 1: + h, w = inputs.shape[-2:] + + target_h = math.ceil( + h / self.pad_size_divisor) * self.pad_size_divisor + target_w = math.ceil( + w / self.pad_size_divisor) * self.pad_size_divisor + pad_h = target_h - h + pad_w = target_w - w + inputs = F.pad(inputs, (0, pad_w, 0, pad_h), 'constant', + self.pad_value) + else: + # The branch if use `pseudo_collate` as the collate_fn in the + # dataloader. 
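+            # Unlike the tensor branch above, each sample here may have its
+            # own (H, W), so conversion and normalization run per image and
+            # stack_batch then pads everything to one divisor-aligned shape.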
+ + processed_inputs = [] + for input_ in inputs: + # ------ To RGB ------ + if self.to_rgb and input_.size(0) == 3: + input_ = input_.flip(0) + + # -- Normalization --- + input_ = input_.float() + if self._enable_normalize: + input_ = (input_ - self.mean) / self.std + + processed_inputs.append(input_) + # Combine padding and stack + inputs = stack_batch(processed_inputs, self.pad_size_divisor, + self.pad_value) + + data_samples = data.get('data_samples', None) + sample_item = data_samples[0] if data_samples is not None else None + if 'gt_label' in sample_item: + gt_labels = [sample.gt_label for sample in data_samples] + gt_labels_tensor = [gt_label.label for gt_label in gt_labels] + batch_label, label_indices = cat_batch_labels(gt_labels_tensor) + batch_label = batch_label.to(self.device) + + batch_score = stack_batch_scores(gt_labels, device=self.device) + if batch_score is None and self.to_onehot: + assert batch_label is not None, \ + 'Cannot generate onehot format labels because no labels.' + num_classes = self.num_classes or data_samples[0].get( + 'num_classes') + assert num_classes is not None, \ + 'Cannot generate one-hot format labels because not set ' \ + '`num_classes` in `data_preprocessor`.' + batch_score = batch_label_to_onehot(batch_label, label_indices, + num_classes) + + # ----- Batch Augmentations ---- + if training and self.batch_augments is not None: + inputs, batch_score = self.batch_augments(inputs, batch_score) + + # ----- scatter labels and scores to data samples --- + if batch_label is not None: + for sample, label in zip( + data_samples, tensor_split(batch_label, + label_indices)): + sample.set_gt_label(label) + if batch_score is not None: + for sample, score in zip(data_samples, batch_score): + sample.set_gt_score(score) + + return {'inputs': inputs, 'data_samples': data_samples} diff --git a/mmdet/models/data_preprocessors/track_data_preprocessor.py b/mmdet/models/data_preprocessors/track_data_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..40a65b8eaebacdaddd574768fbb00e8c5a072d85 --- /dev/null +++ b/mmdet/models/data_preprocessors/track_data_preprocessor.py @@ -0,0 +1,266 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Sequence, Union + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.model.utils import stack_batch + +from mmdet.models.utils.misc import samplelist_boxtype2tensor +from mmdet.registry import MODELS +from mmdet.structures import TrackDataSample +from mmdet.structures.mask import BitmapMasks +from .data_preprocessor import DetDataPreprocessor + + +@MODELS.register_module() +class TrackDataPreprocessor(DetDataPreprocessor): + """Image pre-processor for tracking tasks. + + Accepts the data sampled by the dataloader, and preprocesses + it into the format of the model input. ``TrackDataPreprocessor`` + provides the tracking data pre-processing as follows: + + - Collate and move data to the target device. + - Pad inputs to the maximum size of current batch with defined + ``pad_value``. The padding size can be divisible by a defined + ``pad_size_divisor`` + - Stack inputs to inputs. + - Convert inputs from bgr to rgb if the shape of input is (1, 3, H, W). + - Normalize image with defined std and mean. + - Do batch augmentations during training. + - Record the information of ``batch_input_shape`` and ``pad_shape``. + + Args: + mean (Sequence[Number], optional): The pixel mean of R, G, B + channels. Defaults to None. 
+        std (Sequence[Number], optional): The pixel standard deviation of
+            R, G, B channels. Defaults to None.
+        pad_size_divisor (int): The size of padded image should be
+            divisible by ``pad_size_divisor``. Defaults to 1.
+        pad_value (Number): The padded pixel value. Defaults to 0.
+        pad_mask (bool): Whether to pad instance masks. Defaults to False.
+        mask_pad_value (int): The padded pixel value for instance masks.
+            Defaults to 0.
+        bgr_to_rgb (bool): whether to convert image from BGR to RGB.
+            Defaults to False.
+        rgb_to_bgr (bool): whether to convert image from RGB to BGR.
+            Defaults to False.
+        use_det_processor (bool): whether to use DetDataPreprocessor
+            in the training phase. This is mainly for tracking models
+            that are fed a single image rather than a group of images
+            during training. Defaults to False.
+        boxtype2tensor (bool): Whether to convert the ``BaseBoxes`` type of
+            bboxes data to ``Tensor`` type. Defaults to True.
+        batch_augments (list[dict], optional): Batch-level augmentations
+    """
+
+    def __init__(self,
+                 mean: Optional[Sequence[Union[float, int]]] = None,
+                 std: Optional[Sequence[Union[float, int]]] = None,
+                 use_det_processor: bool = False,
+                 **kwargs):
+        super().__init__(mean=mean, std=std, **kwargs)
+        self.use_det_processor = use_det_processor
+        if mean is not None and not self.use_det_processor:
+            # overwrite the ``register_buffer`` in ``ImgDataPreprocessor``
+            # since ``mean`` and ``std`` must broadcast over inputs of shape
+            # (T, C, H, W), where T is the temporal length of the video.
+            self.register_buffer('mean',
+                                 torch.tensor(mean).view(1, -1, 1, 1), False)
+            self.register_buffer('std',
+                                 torch.tensor(std).view(1, -1, 1, 1), False)
+
+    def forward(self, data: dict, training: bool = False) -> Dict:
+        """Perform normalization, padding and bgr2rgb conversion based on
+        ``TrackDataPreprocessor``.
+
+        Args:
+            data (dict): data sampled from dataloader.
+            training (bool): Whether to enable training time augmentation.
+
+        Returns:
+            dict: Data in the same format as the model input.
+        """
+        if self.use_det_processor and training:
+            batch_pad_shape = self._get_pad_shape(data)
+        else:
+            batch_pad_shape = self._get_track_pad_shape(data)
+
+        data = self.cast_data(data)
+        imgs, data_samples = data['inputs'], data['data_samples']
+
+        if self.use_det_processor and training:
+            assert imgs[0].dim() == 3, \
+                'Only support 3-dim inputs when using the det preprocessor ' \
+                'in training'
+            if self._channel_conversion:
+                imgs = [_img[[2, 1, 0], ...] for _img in imgs]
+            # Convert to `float`
+            imgs = [_img.float() for _img in imgs]
+            if self._enable_normalize:
+                imgs = [(_img - self.mean) / self.std for _img in imgs]
+            inputs = stack_batch(imgs, self.pad_size_divisor, self.pad_value)
+        else:
+            assert imgs[0].dim() == 4, \
+                'Only support 4-dim inputs when using the track ' \
+                'preprocessor in training'
+            # The shape of imgs[0] is (T, C, H, W).
+            channel = imgs[0].size(1)
+            if self._channel_conversion and channel == 3:
+                imgs = [_img[:, [2, 1, 0], ...] for _img in imgs]
+            # Convert to `float`
+            imgs = [_img.float() for _img in imgs]
+            if self._enable_normalize:
+                imgs = [(_img - self.mean) / self.std for _img in imgs]
+            inputs = stack_track_batch(imgs, self.pad_size_divisor,
+                                       self.pad_value)
+
+        if data_samples is not None:
+            # NOTE the batched image size information may be useful, e.g.
+            # in DETR, this is needed for the construction of masks, which is
+            # then used for the transformer_head.
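+            # `inputs` is (N, C, H, W) on the detection path and
+            # (N, T, C, H, W) on the tracking path; in both cases the last
+            # two dims are the padded spatial size shared by the batch.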
+            batch_input_shape = tuple(inputs.size()[-2:])
+            if self.use_det_processor and training:
+                for data_sample, pad_shape in zip(data_samples,
+                                                  batch_pad_shape):
+                    data_sample.set_metainfo({
+                        'batch_input_shape': batch_input_shape,
+                        'pad_shape': pad_shape
+                    })
+                if self.boxtype2tensor:
+                    samplelist_boxtype2tensor(data_samples)
+                if self.pad_mask:
+                    self.pad_gt_masks(data_samples)
+            else:
+                for track_data_sample, pad_shapes in zip(
+                        data_samples, batch_pad_shape):
+                    for i in range(len(track_data_sample)):
+                        det_data_sample = track_data_sample[i]
+                        det_data_sample.set_metainfo({
+                            'batch_input_shape': batch_input_shape,
+                            'pad_shape': pad_shapes[i]
+                        })
+                if self.pad_mask and training:
+                    self.pad_track_gt_masks(data_samples)
+
+        if training and self.batch_augments is not None:
+            for batch_aug in self.batch_augments:
+                if self.use_det_processor and training:
+                    inputs, data_samples = batch_aug(inputs, data_samples)
+                else:
+                    # we only support T==1 when using batch augments.
+                    # Only yolox needs batch_aug, and yolox can only process
+                    # (N, C, H, W) shape.
+                    # The shape of `inputs` is (N, T, C, H, W), hence, we use
+                    # inputs[:, 0] to change the shape to (N, C, H, W).
+                    assert inputs.size(1) == 1 and len(data_samples[0]) == 1, \
+                        'Only support a sequence length of 1 when using ' \
+                        'batch augments.'
+                    det_data_samples = [
+                        track_data_sample[0]
+                        for track_data_sample in data_samples
+                    ]
+                    aug_inputs, aug_det_samples = batch_aug(
+                        inputs[:, 0], det_data_samples)
+                    inputs = aug_inputs.unsqueeze(1)
+                    for track_data_sample, det_sample in zip(
+                            data_samples, aug_det_samples):
+                        track_data_sample.video_data_samples = [det_sample]
+
+        # Note: inputs may contain a large number of frames, so we must make
+        # sure that the memory is contiguous for a stable forward pass
+        inputs = inputs.contiguous()
+
+        return dict(inputs=inputs, data_samples=data_samples)
+
+    def _get_track_pad_shape(self, data: dict) -> List[List[tuple]]:
+        """Get the pad_shape of each image based on data and pad_size_divisor.
+
+        Args:
+            data (dict): Data sampled from dataloader.
+
+        Returns:
+            List[List[tuple]]: The padded shape of each image in each
+            sequence.
+        """
+        batch_pad_shape = []
+        for imgs in data['inputs']:
+            # The sequence images in one sample among a batch have the same
+            # original shape
+            pad_h = int(np.ceil(imgs.shape[-2] /
+                                self.pad_size_divisor)) * self.pad_size_divisor
+            pad_w = int(np.ceil(imgs.shape[-1] /
+                                self.pad_size_divisor)) * self.pad_size_divisor
+            pad_shapes = [(pad_h, pad_w)] * imgs.size(0)
+            batch_pad_shape.append(pad_shapes)
+        return batch_pad_shape
+
+    def pad_track_gt_masks(self,
+                           data_samples: Sequence[TrackDataSample]) -> None:
+        """Pad gt_masks to shape of batch_input_shape."""
+        if 'masks' in data_samples[0][0].get('gt_instances', None):
+            for track_data_sample in data_samples:
+                for i in range(len(track_data_sample)):
+                    det_data_sample = track_data_sample[i]
+                    masks = det_data_sample.gt_instances.masks
+                    # TODO: whether to use BitmapMasks
+                    assert isinstance(masks, BitmapMasks)
+                    batch_input_shape = det_data_sample.batch_input_shape
+                    det_data_sample.gt_instances.masks = masks.pad(
+                        batch_input_shape, pad_val=self.mask_pad_value)
+
+
+def stack_track_batch(tensors: List[torch.Tensor],
+                      pad_size_divisor: int = 0,
+                      pad_value: Union[int, float] = 0) -> torch.Tensor:
+    """Stack multiple tensors to form a batch and pad the images to the max
+    shape using the right-bottom padding mode. If ``pad_size_divisor > 0``,
+    add padding to ensure the common height and width is divisible by
+    ``pad_size_divisor``. The difference between this function and
+    ``stack_batch`` in MMEngine is that this function can process batch
+    sequence images with shape (N, T, C, H, W).
+
+    Args:
+        tensors (List[Tensor]): The input multiple tensors. Each is a
+            TCHW 4D-tensor. T denotes the number of key/reference frames.
+        pad_size_divisor (int): If ``pad_size_divisor > 0``, add padding
+            to ensure the common height and width is divisible by
+            ``pad_size_divisor``. This depends on the model, and many
+            models need a divisibility of 32. Defaults to 0.
+        pad_value (int, float): The padding value. Defaults to 0.
+
+    Returns:
+        Tensor: The NTCHW 5D-tensor. N denotes the batch size.
+    """
+    assert isinstance(tensors, list), \
+        f'Expected input type to be list, but got {type(tensors)}'
+    assert len(set([tensor.ndim for tensor in tensors])) == 1, \
+        f'Expected the dimensions of all tensors to be the same, ' \
+        f'but got {[tensor.ndim for tensor in tensors]}'
+    assert tensors[0].ndim == 4, f'Expected tensor dimension to be 4, ' \
+                                 f'but got {tensors[0].ndim}'
+    assert len(set([tensor.shape[0] for tensor in tensors])) == 1, \
+        f'Expected the temporal length (dim 0) of all tensors to be the ' \
+        f'same, but got {[tensor.shape[0] for tensor in tensors]}'
+
+    tensor_sizes = [(tensor.shape[-2], tensor.shape[-1]) for tensor in tensors]
+    max_size = np.stack(tensor_sizes).max(0)
+
+    if pad_size_divisor > 1:
+        # the last two dims are H,W, both subject to divisibility requirement
+        max_size = (
+            max_size +
+            (pad_size_divisor - 1)) // pad_size_divisor * pad_size_divisor
+
+    padded_samples = []
+    for tensor in tensors:
+        padding_size = [
+            0, max_size[-1] - tensor.shape[-1], 0,
+            max_size[-2] - tensor.shape[-2]
+        ]
+        if sum(padding_size) == 0:
+            padded_samples.append(tensor)
+        else:
+            padded_samples.append(F.pad(tensor, padding_size, value=pad_value))
+
+    return torch.stack(padded_samples, dim=0)
diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9b55ec2a4230a741e9a2c696ec434bf9cc8bafa
--- /dev/null
+++ b/mmdet/models/dense_heads/__init__.py
@@ -0,0 +1,72 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .anchor_free_head import AnchorFreeHead +from .anchor_head import AnchorHead +from .atss_head import ATSSHead +from .atss_vlfusion_head import ATSSVLFusionHead +from .autoassign_head import AutoAssignHead +from .boxinst_head import BoxInstBboxHead, BoxInstMaskHead +from .cascade_rpn_head import CascadeRPNHead, StageCascadeRPNHead +from .centernet_head import CenterNetHead +from .centernet_update_head import CenterNetUpdateHead +from .centripetal_head import CentripetalHead +from .condinst_head import CondInstBboxHead, CondInstMaskHead +from .conditional_detr_head import ConditionalDETRHead +from .corner_head import CornerHead +from .dab_detr_head import DABDETRHead +from .ddod_head import DDODHead +from .ddq_detr_head import DDQDETRHead +from .deformable_detr_head import DeformableDETRHead +from .detr_head import DETRHead +from .dino_head import DINOHead +from .embedding_rpn_head import EmbeddingRPNHead +from .fcos_head import FCOSHead +from .fovea_head import FoveaHead +from .free_anchor_retina_head import FreeAnchorRetinaHead +from .fsaf_head import FSAFHead +from .ga_retina_head import GARetinaHead +from .ga_rpn_head import GARPNHead +from .gfl_head import GFLHead +from .grounding_dino_head import GroundingDINOHead +from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead +from .lad_head import LADHead +from .ld_head import LDHead +from .mask2former_head import Mask2FormerHead +from .maskformer_head import MaskFormerHead +from .nasfcos_head import NASFCOSHead +from .paa_head import PAAHead +from .pisa_retinanet_head import PISARetinaHead +from .pisa_ssd_head import PISASSDHead +from .reppoints_head import RepPointsHead +from .retina_head import RetinaHead +from .retina_sepbn_head import RetinaSepBNHead +from .rpn_head import RPNHead +from .rtmdet_head import RTMDetHead, RTMDetSepBNHead +from .rtmdet_ins_head import RTMDetInsHead, RTMDetInsSepBNHead +from .sabl_retina_head import SABLRetinaHead +from .solo_head import DecoupledSOLOHead, DecoupledSOLOLightHead, SOLOHead +from .solov2_head import SOLOV2Head +from .ssd_head import SSDHead +from .tood_head import TOODHead +from .vfnet_head import VFNetHead +from .yolact_head import YOLACTHead, YOLACTProtonet +from .yolo_head import YOLOV3Head +from .yolof_head import YOLOFHead +from .yolox_head import YOLOXHead + +__all__ = [ + 'AnchorFreeHead', 'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', + 'RPNHead', 'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', + 'SSDHead', 'FCOSHead', 'RepPointsHead', 'FoveaHead', + 'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead', + 'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'YOLACTHead', + 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead', 'SABLRetinaHead', + 'CentripetalHead', 'VFNetHead', 'StageCascadeRPNHead', 'CascadeRPNHead', + 'EmbeddingRPNHead', 'LDHead', 'AutoAssignHead', 'DETRHead', 'YOLOFHead', + 'DeformableDETRHead', 'CenterNetHead', 'YOLOXHead', 'SOLOHead', + 'DecoupledSOLOHead', 'DecoupledSOLOLightHead', 'SOLOV2Head', 'LADHead', + 'TOODHead', 'MaskFormerHead', 'Mask2FormerHead', 'DDODHead', + 'CenterNetUpdateHead', 'RTMDetHead', 'RTMDetSepBNHead', 'CondInstBboxHead', + 'CondInstMaskHead', 'RTMDetInsHead', 'RTMDetInsSepBNHead', + 'BoxInstBboxHead', 'BoxInstMaskHead', 'ConditionalDETRHead', 'DINOHead', + 'ATSSVLFusionHead', 'DABDETRHead', 'DDQDETRHead', 'GroundingDINOHead' +] diff --git a/mmdet/models/dense_heads/anchor_free_head.py b/mmdet/models/dense_heads/anchor_free_head.py new file mode 100644 index 
0000000000000000000000000000000000000000..90a9b3625b8fef12a2ee3a964c89597b597cb2ec --- /dev/null +++ b/mmdet/models/dense_heads/anchor_free_head.py @@ -0,0 +1,317 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractmethod +from typing import Any, List, Sequence, Tuple, Union + +import torch.nn as nn +from mmcv.cnn import ConvModule +from numpy import ndarray +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType, + OptInstanceList) +from ..task_modules.prior_generators import MlvlPointGenerator +from ..utils import multi_apply +from .base_dense_head import BaseDenseHead + +StrideType = Union[Sequence[int], Sequence[Tuple[int, int]]] + + +@MODELS.register_module() +class AnchorFreeHead(BaseDenseHead): + """Anchor-free head (FCOS, Fovea, RepPoints, etc.). + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + feat_channels (int): Number of hidden channels. Used in child classes. + stacked_convs (int): Number of stacking convs of the head. + strides (Sequence[int] or Sequence[Tuple[int, int]]): Downsample + factor of each feature map. + dcn_on_last_conv (bool): If true, use dcn in the last layer of + towers. Defaults to False. + conv_bias (bool or str): If specified as `auto`, it will be decided by + the norm_cfg. Bias of conv will be set as True if `norm_cfg` is + None, otherwise False. Default: "auto". + loss_cls (:obj:`ConfigDict` or dict): Config of classification loss. + loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss. + bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder. Defaults + 'DistancePointBBoxCoder'. + conv_cfg (:obj:`ConfigDict` or dict, Optional): Config dict for + convolution layer. Defaults to None. + norm_cfg (:obj:`ConfigDict` or dict, Optional): Config dict for + normalization layer. Defaults to None. + train_cfg (:obj:`ConfigDict` or dict, Optional): Training config of + anchor-free head. + test_cfg (:obj:`ConfigDict` or dict, Optional): Testing config of + anchor-free head. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict]): Initialization config dict. 
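+
+    A construction sketch (illustrative values only; ``AnchorFreeHead``
+    itself leaves ``loss_by_feat`` and ``get_targets`` abstract, so in
+    practice a concrete subclass such as ``FCOSHead`` is built instead):
+
+    .. code-block:: python
+
+        head_cfg = dict(
+            type='FCOSHead',
+            num_classes=80,
+            in_channels=256,
+            strides=(8, 16, 32, 64, 128))
+        # head = MODELS.build(head_cfg)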
+ """ # noqa: W605 + + _version = 1 + + def __init__( + self, + num_classes: int, + in_channels: int, + feat_channels: int = 256, + stacked_convs: int = 4, + strides: StrideType = (4, 8, 16, 32, 64), + dcn_on_last_conv: bool = False, + conv_bias: Union[bool, str] = 'auto', + loss_cls: ConfigType = dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox: ConfigType = dict(type='IoULoss', loss_weight=1.0), + bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'), + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + init_cfg: MultiConfig = dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', name='conv_cls', std=0.01, bias_prob=0.01)) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.num_classes = num_classes + self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) + if self.use_sigmoid_cls: + self.cls_out_channels = num_classes + else: + self.cls_out_channels = num_classes + 1 + self.in_channels = in_channels + self.feat_channels = feat_channels + self.stacked_convs = stacked_convs + self.strides = strides + self.dcn_on_last_conv = dcn_on_last_conv + assert conv_bias == 'auto' or isinstance(conv_bias, bool) + self.conv_bias = conv_bias + self.loss_cls = MODELS.build(loss_cls) + self.loss_bbox = MODELS.build(loss_bbox) + self.bbox_coder = TASK_UTILS.build(bbox_coder) + + self.prior_generator = MlvlPointGenerator(strides) + + # In order to keep a more general interface and be consistent with + # anchor_head. We can think of point like one anchor + self.num_base_priors = self.prior_generator.num_base_priors[0] + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.fp16_enabled = False + + self._init_layers() + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self._init_cls_convs() + self._init_reg_convs() + self._init_predictor() + + def _init_cls_convs(self) -> None: + """Initialize classification conv layers of the head.""" + self.cls_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + if self.dcn_on_last_conv and i == self.stacked_convs - 1: + conv_cfg = dict(type='DCNv2') + else: + conv_cfg = self.conv_cfg + self.cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=self.norm_cfg, + bias=self.conv_bias)) + + def _init_reg_convs(self) -> None: + """Initialize bbox regression conv layers of the head.""" + self.reg_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + if self.dcn_on_last_conv and i == self.stacked_convs - 1: + conv_cfg = dict(type='DCNv2') + else: + conv_cfg = self.conv_cfg + self.reg_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=self.norm_cfg, + bias=self.conv_bias)) + + def _init_predictor(self) -> None: + """Initialize predictor layers of the head.""" + self.conv_cls = nn.Conv2d( + self.feat_channels, self.cls_out_channels, 3, padding=1) + self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1) + + def _load_from_state_dict(self, state_dict: dict, prefix: str, + local_metadata: dict, strict: bool, + missing_keys: Union[List[str], str], + unexpected_keys: Union[List[str], str], + error_msgs: Union[List[str], str]) -> 
None: + """Hack some keys of the model state dict so that can load checkpoints + of previous version.""" + version = local_metadata.get('version', None) + if version is None: + # the key is different in early versions + # for example, 'fcos_cls' become 'conv_cls' now + bbox_head_keys = [ + k for k in state_dict.keys() if k.startswith(prefix) + ] + ori_predictor_keys = [] + new_predictor_keys = [] + # e.g. 'fcos_cls' or 'fcos_reg' + for key in bbox_head_keys: + ori_predictor_keys.append(key) + key = key.split('.') + if len(key) < 2: + conv_name = None + elif key[1].endswith('cls'): + conv_name = 'conv_cls' + elif key[1].endswith('reg'): + conv_name = 'conv_reg' + elif key[1].endswith('centerness'): + conv_name = 'conv_centerness' + else: + conv_name = None + if conv_name is not None: + key[1] = conv_name + new_predictor_keys.append('.'.join(key)) + else: + ori_predictor_keys.pop(-1) + for i in range(len(new_predictor_keys)): + state_dict[new_predictor_keys[i]] = state_dict.pop( + ori_predictor_keys[i]) + super()._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, + error_msgs) + + def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor], List[Tensor]]: + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: Usually contain classification scores and bbox predictions. + + - cls_scores (list[Tensor]): Box scores for each scale level, \ + each is a 4D-tensor, the channel number is \ + num_points * num_classes. + - bbox_preds (list[Tensor]): Box energies / deltas for each scale \ + level, each is a 4D-tensor, the channel number is num_points * 4. + """ + return multi_apply(self.forward_single, x)[:2] + + def forward_single(self, x: Tensor) -> Tuple[Tensor, ...]: + """Forward features of a single scale level. + + Args: + x (Tensor): FPN feature maps of the specified stride. + + Returns: + tuple: Scores for each class, bbox predictions, features + after classification and regression conv layers, some + models needs these features like FCOS. + """ + cls_feat = x + reg_feat = x + + for cls_layer in self.cls_convs: + cls_feat = cls_layer(cls_feat) + cls_score = self.conv_cls(cls_feat) + + for reg_layer in self.reg_convs: + reg_feat = reg_layer(reg_feat) + bbox_pred = self.conv_reg(reg_feat) + return cls_score, bbox_pred, cls_feat, reg_feat + + @abstractmethod + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level, + each is a 4D-tensor, the channel number is + num_points * num_classes. + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level, each is a 4D-tensor, the channel number is + num_points * 4. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. 
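+
+        Returns:
+            dict: A dictionary of loss components.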
+ """ + + raise NotImplementedError + + @abstractmethod + def get_targets(self, points: List[Tensor], + batch_gt_instances: InstanceList) -> Any: + """Compute regression, classification and centerness targets for points + in multiple images. + + Args: + points (list[Tensor]): Points of each fpn level, each has shape + (num_points, 2). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + """ + raise NotImplementedError + + # TODO refactor aug_test + def aug_test(self, + aug_batch_feats: List[Tensor], + aug_batch_img_metas: List[List[Tensor]], + rescale: bool = False) -> List[ndarray]: + """Test function with test time augmentation. + + Args: + aug_batch_feats (list[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains features for all images in the batch. + aug_batch_img_metas (list[list[dict]]): the outer list indicates + test-time augs (multiscale, flip, etc.) and the inner list + indicates images in a batch. each dict has image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[ndarray]: bbox results of each class + """ + return self.aug_test_bboxes( + aug_batch_feats, aug_batch_img_metas, rescale=rescale) diff --git a/mmdet/models/dense_heads/anchor_head.py b/mmdet/models/dense_heads/anchor_head.py new file mode 100644 index 0000000000000000000000000000000000000000..4578caca818550397875a0df34c128f461e6ec75 --- /dev/null +++ b/mmdet/models/dense_heads/anchor_head.py @@ -0,0 +1,530 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures.bbox import BaseBoxes, cat_boxes, get_box_tensor +from mmdet.utils import (ConfigType, InstanceList, OptConfigType, + OptInstanceList, OptMultiConfig) +from ..task_modules.prior_generators import (AnchorGenerator, + anchor_inside_flags) +from ..task_modules.samplers import PseudoSampler +from ..utils import images_to_levels, multi_apply, unmap +from .base_dense_head import BaseDenseHead + + +@MODELS.register_module() +class AnchorHead(BaseDenseHead): + """Anchor-based head (RPN, RetinaNet, SSD, etc.). + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + feat_channels (int): Number of hidden channels. Used in child classes. + anchor_generator (dict): Config dict for anchor generator + bbox_coder (dict): Config of bounding box coder. + reg_decoded_bbox (bool): If true, the regression loss would be + applied directly on decoded bounding boxes, converting both + the predicted boxes and regression targets to absolute + coordinates format. Default False. It should be `True` when + using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head. + loss_cls (dict): Config of classification loss. + loss_bbox (dict): Config of localization loss. + train_cfg (dict): Training config of anchor head. + test_cfg (dict): Testing config of anchor head. + init_cfg (dict or list[dict], optional): Initialization config dict. 
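+
+    A forward sketch (toy sizes; assumes the mmdet registries are populated,
+    and notes that the default anchor generator yields 9 base anchors per
+    location):
+
+    .. code-block:: python
+
+        import torch
+        head = AnchorHead(num_classes=2, in_channels=32)
+        feats = [torch.rand(1, 32, s, s) for s in (64, 32, 16, 8, 4)]
+        cls_scores, bbox_preds = head(feats)
+        # cls_scores[0]: (1, 9 * 2, 64, 64); bbox_preds[0]: (1, 9 * 4, 64, 64)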
+ """ # noqa: W605 + + def __init__( + self, + num_classes: int, + in_channels: int, + feat_channels: int = 256, + anchor_generator: ConfigType = dict( + type='AnchorGenerator', + scales=[8, 16, 32], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder: ConfigType = dict( + type='DeltaXYWHBBoxCoder', + clip_border=True, + target_means=(.0, .0, .0, .0), + target_stds=(1.0, 1.0, 1.0, 1.0)), + reg_decoded_bbox: bool = False, + loss_cls: ConfigType = dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox: ConfigType = dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0), + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = dict( + type='Normal', layer='Conv2d', std=0.01) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.num_classes = num_classes + self.feat_channels = feat_channels + self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) + if self.use_sigmoid_cls: + self.cls_out_channels = num_classes + else: + self.cls_out_channels = num_classes + 1 + + if self.cls_out_channels <= 0: + raise ValueError(f'num_classes={num_classes} is too small') + self.reg_decoded_bbox = reg_decoded_bbox + + self.bbox_coder = TASK_UTILS.build(bbox_coder) + self.loss_cls = MODELS.build(loss_cls) + self.loss_bbox = MODELS.build(loss_bbox) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + if self.train_cfg: + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + if train_cfg.get('sampler', None) is not None: + self.sampler = TASK_UTILS.build( + self.train_cfg['sampler'], default_args=dict(context=self)) + else: + self.sampler = PseudoSampler(context=self) + + self.fp16_enabled = False + + self.prior_generator = TASK_UTILS.build(anchor_generator) + + # Usually the numbers of anchors for each level are the same + # except SSD detectors. So it is an int in the most dense + # heads but a list of int in SSDHead + self.num_base_priors = self.prior_generator.num_base_priors[0] + self._init_layers() + + @property + def num_anchors(self) -> int: + warnings.warn('DeprecationWarning: `num_anchors` is deprecated, ' + 'for consistency or also use ' + '`num_base_priors` instead') + return self.prior_generator.num_base_priors[0] + + @property + def anchor_generator(self) -> AnchorGenerator: + warnings.warn('DeprecationWarning: anchor_generator is deprecated, ' + 'please use "prior_generator" instead') + return self.prior_generator + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self.conv_cls = nn.Conv2d(self.in_channels, + self.num_base_priors * self.cls_out_channels, + 1) + reg_dim = self.bbox_coder.encode_size + self.conv_reg = nn.Conv2d(self.in_channels, + self.num_base_priors * reg_dim, 1) + + def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Forward feature of a single scale level. + + Args: + x (Tensor): Features of a single scale level. + + Returns: + tuple: + cls_score (Tensor): Cls scores for a single scale level \ + the channels number is num_base_priors * num_classes. + bbox_pred (Tensor): Box energies / deltas for a single scale \ + level, the channels number is num_base_priors * 4. + """ + cls_score = self.conv_cls(x) + bbox_pred = self.conv_reg(x) + return cls_score, bbox_pred + + def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor]]: + """Forward features from the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. 
+ + Returns: + tuple: A tuple of classification scores and bbox prediction. + + - cls_scores (list[Tensor]): Classification scores for all \ + scale levels, each is a 4D-tensor, the channels number \ + is num_base_priors * num_classes. + - bbox_preds (list[Tensor]): Box energies / deltas for all \ + scale levels, each is a 4D-tensor, the channels number \ + is num_base_priors * 4. + """ + return multi_apply(self.forward_single, x) + + def get_anchors(self, + featmap_sizes: List[tuple], + batch_img_metas: List[dict], + device: Union[torch.device, str] = 'cuda') \ + -> Tuple[List[List[Tensor]], List[List[Tensor]]]: + """Get anchors according to feature map sizes. + + Args: + featmap_sizes (list[tuple]): Multi-level feature map sizes. + batch_img_metas (list[dict]): Image meta info. + device (torch.device | str): Device for returned tensors. + Defaults to cuda. + + Returns: + tuple: + + - anchor_list (list[list[Tensor]]): Anchors of each image. + - valid_flag_list (list[list[Tensor]]): Valid flags of each + image. + """ + num_imgs = len(batch_img_metas) + + # since feature map sizes of all images are the same, we only compute + # anchors for one time + multi_level_anchors = self.prior_generator.grid_priors( + featmap_sizes, device=device) + anchor_list = [multi_level_anchors for _ in range(num_imgs)] + + # for each image, we compute valid flags of multi level anchors + valid_flag_list = [] + for img_id, img_meta in enumerate(batch_img_metas): + multi_level_flags = self.prior_generator.valid_flags( + featmap_sizes, img_meta['pad_shape'], device) + valid_flag_list.append(multi_level_flags) + + return anchor_list, valid_flag_list + + def _get_targets_single(self, + flat_anchors: Union[Tensor, BaseBoxes], + valid_flags: Tensor, + gt_instances: InstanceData, + img_meta: dict, + gt_instances_ignore: Optional[InstanceData] = None, + unmap_outputs: bool = True) -> tuple: + """Compute regression and classification targets for anchors in a + single image. + + Args: + flat_anchors (Tensor or :obj:`BaseBoxes`): Multi-level anchors + of the image, which are concatenated into a single tensor + or box type of shape (num_anchors, 4) + valid_flags (Tensor): Multi level valid flags of the image, + which are concatenated into a single tensor of + shape (num_anchors, ). + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for current image. + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. Defaults to True. + + Returns: + tuple: + + - labels (Tensor): Labels of each level. + - label_weights (Tensor): Label weights of each level. + - bbox_targets (Tensor): BBox targets of each level. + - bbox_weights (Tensor): BBox weights of each level. + - pos_inds (Tensor): positive samples indexes. + - neg_inds (Tensor): negative samples indexes. + - sampling_result (:obj:`SamplingResult`): Sampling results. + """ + inside_flags = anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], + self.train_cfg['allowed_border']) + if not inside_flags.any(): + raise ValueError( + 'There is no valid anchor inside the image boundary. 
Please ' + 'check the image size and anchor sizes, or set ' + '``allowed_border`` to -1 to skip the condition.') + # assign gt and sample anchors + anchors = flat_anchors[inside_flags] + + pred_instances = InstanceData(priors=anchors) + assign_result = self.assigner.assign(pred_instances, gt_instances, + gt_instances_ignore) + # No sampling is required except for RPN and + # Guided Anchoring algorithms + sampling_result = self.sampler.sample(assign_result, pred_instances, + gt_instances) + + num_valid_anchors = anchors.shape[0] + target_dim = gt_instances.bboxes.size(-1) if self.reg_decoded_bbox \ + else self.bbox_coder.encode_size + bbox_targets = anchors.new_zeros(num_valid_anchors, target_dim) + bbox_weights = anchors.new_zeros(num_valid_anchors, target_dim) + + # TODO: Considering saving memory, is it necessary to be long? + labels = anchors.new_full((num_valid_anchors, ), + self.num_classes, + dtype=torch.long) + label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + # `bbox_coder.encode` accepts tensor or box type inputs and generates + # tensor targets. If regressing decoded boxes, the code will convert + # box type `pos_bbox_targets` to tensor. + if len(pos_inds) > 0: + if not self.reg_decoded_bbox: + pos_bbox_targets = self.bbox_coder.encode( + sampling_result.pos_priors, sampling_result.pos_gt_bboxes) + else: + pos_bbox_targets = sampling_result.pos_gt_bboxes + pos_bbox_targets = get_box_tensor(pos_bbox_targets) + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1.0 + + labels[pos_inds] = sampling_result.pos_gt_labels + if self.train_cfg['pos_weight'] <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.train_cfg['pos_weight'] + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_anchors.size(0) + labels = unmap( + labels, num_total_anchors, inside_flags, + fill=self.num_classes) # fill bg label + label_weights = unmap(label_weights, num_total_anchors, + inside_flags) + bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) + bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) + + return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, + neg_inds, sampling_result) + + def get_targets(self, + anchor_list: List[List[Tensor]], + valid_flag_list: List[List[Tensor]], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None, + unmap_outputs: bool = True, + return_sampling_results: bool = False) -> tuple: + """Compute regression and classification targets for anchors in + multiple images. + + Args: + anchor_list (list[list[Tensor]]): Multi level anchors of each + image. The outer list indicates images, and the inner list + corresponds to feature levels of the image. Each element of + the inner list is a tensor of shape (num_anchors, 4). + valid_flag_list (list[list[Tensor]]): Multi level valid flags of + each image. The outer list indicates images, and the inner list + corresponds to feature levels of the image. Each element of + the inner list is a tensor of shape (num_anchors, ) + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. 
+ batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. Defaults to True. + return_sampling_results (bool): Whether to return the sampling + results. Defaults to False. + + Returns: + tuple: Usually returns a tuple containing learning targets. + + - labels_list (list[Tensor]): Labels of each level. + - label_weights_list (list[Tensor]): Label weights of each + level. + - bbox_targets_list (list[Tensor]): BBox targets of each level. + - bbox_weights_list (list[Tensor]): BBox weights of each level. + - avg_factor (int): Average factor that is used to average + the loss. When using sampling method, avg_factor is usually + the sum of positive and negative priors. When using + `PseudoSampler`, `avg_factor` is usually equal to the number + of positive priors. + + additional_returns: This function enables user-defined returns from + `self._get_targets_single`. These returns are currently refined + to properties at each feature map (i.e. having HxW dimension). + The results will be concatenated after the end + """ + num_imgs = len(batch_img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + + if batch_gt_instances_ignore is None: + batch_gt_instances_ignore = [None] * num_imgs + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + # concat all level anchors to a single tensor + concat_anchor_list = [] + concat_valid_flag_list = [] + for i in range(num_imgs): + assert len(anchor_list[i]) == len(valid_flag_list[i]) + concat_anchor_list.append(cat_boxes(anchor_list[i])) + concat_valid_flag_list.append(torch.cat(valid_flag_list[i])) + + # compute targets for each image + results = multi_apply( + self._get_targets_single, + concat_anchor_list, + concat_valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore, + unmap_outputs=unmap_outputs) + (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights, + pos_inds_list, neg_inds_list, sampling_results_list) = results[:7] + rest_results = list(results[7:]) # user-added return values + # Get `avg_factor` of all images, which calculate in `SamplingResult`. + # When using sampling method, avg_factor is usually the sum of + # positive and negative priors. When using `PseudoSampler`, + # `avg_factor` is usually equal to the number of positive priors. + avg_factor = sum( + [results.avg_factor for results in sampling_results_list]) + # update `_raw_positive_infos`, which will be used when calling + # `get_positive_infos`. + self._raw_positive_infos.update(sampling_results=sampling_results_list) + # split targets to a list w.r.t. 
multiple levels + labels_list = images_to_levels(all_labels, num_level_anchors) + label_weights_list = images_to_levels(all_label_weights, + num_level_anchors) + bbox_targets_list = images_to_levels(all_bbox_targets, + num_level_anchors) + bbox_weights_list = images_to_levels(all_bbox_weights, + num_level_anchors) + res = (labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, avg_factor) + if return_sampling_results: + res = res + (sampling_results_list, ) + for i, r in enumerate(rest_results): # user-added return values + rest_results[i] = images_to_levels(r, num_level_anchors) + + return res + tuple(rest_results) + + def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor, + anchors: Tensor, labels: Tensor, + label_weights: Tensor, bbox_targets: Tensor, + bbox_weights: Tensor, avg_factor: int) -> tuple: + """Calculate the loss of a single scale level based on the features + extracted by the detection head. + + Args: + cls_score (Tensor): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W). + bbox_pred (Tensor): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + anchors (Tensor): Box reference for each scale level with shape + (N, num_total_anchors, 4). + labels (Tensor): Labels of each anchors with shape + (N, num_total_anchors). + label_weights (Tensor): Label weights of each anchor with shape + (N, num_total_anchors) + bbox_targets (Tensor): BBox regression targets of each anchor + weight shape (N, num_total_anchors, 4). + bbox_weights (Tensor): BBox regression loss weights of each anchor + with shape (N, num_total_anchors, 4). + avg_factor (int): Average factor that is used to average the loss. + + Returns: + tuple: loss components. + """ + # classification loss + labels = labels.reshape(-1) + label_weights = label_weights.reshape(-1) + cls_score = cls_score.permute(0, 2, 3, + 1).reshape(-1, self.cls_out_channels) + loss_cls = self.loss_cls( + cls_score, labels, label_weights, avg_factor=avg_factor) + # regression loss + target_dim = bbox_targets.size(-1) + bbox_targets = bbox_targets.reshape(-1, target_dim) + bbox_weights = bbox_weights.reshape(-1, target_dim) + bbox_pred = bbox_pred.permute(0, 2, 3, + 1).reshape(-1, + self.bbox_coder.encode_size) + if self.reg_decoded_bbox: + # When the regression loss (e.g. `IouLoss`, `GIouLoss`) + # is applied directly on the decoded bounding boxes, it + # decodes the already encoded coordinates to absolute format. + anchors = anchors.reshape(-1, anchors.size(-1)) + bbox_pred = self.bbox_coder.decode(anchors, bbox_pred) + bbox_pred = get_box_tensor(bbox_pred) + loss_bbox = self.loss_bbox( + bbox_pred, bbox_targets, bbox_weights, avg_factor=avg_factor) + return loss_cls, loss_bbox + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + has shape (N, num_anchors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. 
+ batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict: A dictionary of loss components. + """ + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + avg_factor) = cls_reg_targets + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + # concat all level anchors and flags to a single tensor + concat_anchor_list = [] + for i in range(len(anchor_list)): + concat_anchor_list.append(cat_boxes(anchor_list[i])) + all_anchor_list = images_to_levels(concat_anchor_list, + num_level_anchors) + + losses_cls, losses_bbox = multi_apply( + self.loss_by_feat_single, + cls_scores, + bbox_preds, + all_anchor_list, + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + avg_factor=avg_factor) + return dict(loss_cls=losses_cls, loss_bbox=losses_bbox) diff --git a/mmdet/models/dense_heads/atss_head.py b/mmdet/models/dense_heads/atss_head.py new file mode 100644 index 0000000000000000000000000000000000000000..2ce71b3eff5e0ed624ec7ae16e8db80c90e8ffa1 --- /dev/null +++ b/mmdet/models/dense_heads/atss_head.py @@ -0,0 +1,524 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, Scale +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType, + OptInstanceList, reduce_mean) +from ..task_modules.prior_generators import anchor_inside_flags +from ..utils import images_to_levels, multi_apply, unmap +from .anchor_head import AnchorHead + + +@MODELS.register_module() +class ATSSHead(AnchorHead): + """Detection Head of `ATSS `_. + + ATSS head structure is similar with FCOS, however ATSS use anchor boxes + and assign label by Adaptive Training Sample Selection instead max-iou. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + pred_kernel_size (int): Kernel size of ``nn.Conv2d`` + stacked_convs (int): Number of stacking convs of the head. + conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + convolution layer. Defaults to None. + norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization + layer. Defaults to ``dict(type='GN', num_groups=32, + requires_grad=True)``. + reg_decoded_bbox (bool): If true, the regression loss would be + applied directly on decoded bounding boxes, converting both + the predicted boxes and regression targets to absolute + coordinates format. Defaults to False. It should be `True` when + using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head. + loss_centerness (:obj:`ConfigDict` or dict): Config of centerness loss. + Defaults to ``dict(type='CrossEntropyLoss', use_sigmoid=True, + loss_weight=1.0)``. 
+ init_cfg (:obj:`ConfigDict` or dict or list[dict] or + list[:obj:`ConfigDict`]): Initialization config dict. + """ + + def __init__(self, + num_classes: int, + in_channels: int, + pred_kernel_size: int = 3, + stacked_convs: int = 4, + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict( + type='GN', num_groups=32, requires_grad=True), + reg_decoded_bbox: bool = True, + loss_centerness: ConfigType = dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0), + init_cfg: MultiConfig = dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', + name='atss_cls', + std=0.01, + bias_prob=0.01)), + **kwargs) -> None: + self.pred_kernel_size = pred_kernel_size + self.stacked_convs = stacked_convs + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + super().__init__( + num_classes=num_classes, + in_channels=in_channels, + reg_decoded_bbox=reg_decoded_bbox, + init_cfg=init_cfg, + **kwargs) + + self.sampling = False + self.loss_centerness = MODELS.build(loss_centerness) + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self.relu = nn.ReLU(inplace=True) + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + self.cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + self.reg_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + pred_pad_size = self.pred_kernel_size // 2 + self.atss_cls = nn.Conv2d( + self.feat_channels, + self.num_anchors * self.cls_out_channels, + self.pred_kernel_size, + padding=pred_pad_size) + self.atss_reg = nn.Conv2d( + self.feat_channels, + self.num_base_priors * 4, + self.pred_kernel_size, + padding=pred_pad_size) + self.atss_centerness = nn.Conv2d( + self.feat_channels, + self.num_base_priors * 1, + self.pred_kernel_size, + padding=pred_pad_size) + self.scales = nn.ModuleList( + [Scale(1.0) for _ in self.prior_generator.strides]) + + def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor]]: + """Forward features from the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: Usually a tuple of classification scores and bbox prediction + cls_scores (list[Tensor]): Classification scores for all scale + levels, each is a 4D-tensor, the channels number is + num_anchors * num_classes. + bbox_preds (list[Tensor]): Box energies / deltas for all scale + levels, each is a 4D-tensor, the channels number is + num_anchors * 4. + """ + return multi_apply(self.forward_single, x, self.scales) + + def forward_single(self, x: Tensor, scale: Scale) -> Sequence[Tensor]: + """Forward feature of a single scale level. + + Args: + x (Tensor): Features of a single scale level. + scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize + the bbox prediction. + + Returns: + tuple: + cls_score (Tensor): Cls scores for a single scale level + the channels number is num_anchors * num_classes. + bbox_pred (Tensor): Box energies / deltas for a single scale + level, the channels number is num_anchors * 4. + centerness (Tensor): Centerness for a single scale level, the + channel number is (N, num_anchors * 1, H, W). 
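+
+        Example:
+            An illustrative sketch (the config below is an assumption;
+            typical ATSS configs use a single anchor per location)::
+
+                >>> import torch
+                >>> head = ATSSHead(
+                ...     num_classes=80,
+                ...     in_channels=256,
+                ...     anchor_generator=dict(
+                ...         type='AnchorGenerator',
+                ...         ratios=[1.0],
+                ...         octave_base_scale=8,
+                ...         scales_per_octave=1,
+                ...         strides=[8, 16, 32, 64, 128]))
+                >>> x = torch.rand(2, 256, 32, 32)
+                >>> cls_score, bbox_pred, centerness = head.forward_single(
+                ...     x, head.scales[0])
+                >>> # channels: num_anchors * num_classes, num_anchors * 4,
+                >>> # num_anchors * 1
+                >>> [t.shape[1] for t in (cls_score, bbox_pred, centerness)]
+                [80, 4, 1]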
+ """ + cls_feat = x + reg_feat = x + for cls_conv in self.cls_convs: + cls_feat = cls_conv(cls_feat) + for reg_conv in self.reg_convs: + reg_feat = reg_conv(reg_feat) + cls_score = self.atss_cls(cls_feat) + # we just follow atss, not apply exp in bbox_pred + bbox_pred = scale(self.atss_reg(reg_feat)).float() + centerness = self.atss_centerness(reg_feat) + return cls_score, bbox_pred, centerness + + def loss_by_feat_single(self, anchors: Tensor, cls_score: Tensor, + bbox_pred: Tensor, centerness: Tensor, + labels: Tensor, label_weights: Tensor, + bbox_targets: Tensor, avg_factor: float) -> dict: + """Calculate the loss of a single scale level based on the features + extracted by the detection head. + + Args: + cls_score (Tensor): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W). + bbox_pred (Tensor): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + anchors (Tensor): Box reference for each scale level with shape + (N, num_total_anchors, 4). + labels (Tensor): Labels of each anchors with shape + (N, num_total_anchors). + label_weights (Tensor): Label weights of each anchor with shape + (N, num_total_anchors) + bbox_targets (Tensor): BBox regression targets of each anchor with + shape (N, num_total_anchors, 4). + avg_factor (float): Average factor that is used to average + the loss. When using sampling method, avg_factor is usually + the sum of positive and negative priors. When using + `PseudoSampler`, `avg_factor` is usually equal to the number + of positive priors. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + + anchors = anchors.reshape(-1, 4) + cls_score = cls_score.permute(0, 2, 3, 1).reshape( + -1, self.cls_out_channels).contiguous() + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) + centerness = centerness.permute(0, 2, 3, 1).reshape(-1) + bbox_targets = bbox_targets.reshape(-1, 4) + labels = labels.reshape(-1) + label_weights = label_weights.reshape(-1) + + # classification loss + loss_cls = self.loss_cls( + cls_score, labels, label_weights, avg_factor=avg_factor) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().squeeze(1) + + if len(pos_inds) > 0: + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + pos_anchors = anchors[pos_inds] + pos_centerness = centerness[pos_inds] + + centerness_targets = self.centerness_target( + pos_anchors, pos_bbox_targets) + pos_decode_bbox_pred = self.bbox_coder.decode( + pos_anchors, pos_bbox_pred) + + # regression loss + loss_bbox = self.loss_bbox( + pos_decode_bbox_pred, + pos_bbox_targets, + weight=centerness_targets, + avg_factor=1.0) + + # centerness loss + loss_centerness = self.loss_centerness( + pos_centerness, centerness_targets, avg_factor=avg_factor) + + else: + loss_bbox = bbox_pred.sum() * 0 + loss_centerness = centerness.sum() * 0 + centerness_targets = bbox_targets.new_tensor(0.) + + return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum() + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + centernesses: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. 
+ + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W) + centernesses (list[Tensor]): Centerness for each scale + level with shape (N, num_anchors * 1, H, W) + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + featmap_sizes = [featmap.size()[-2:] for featmap in bbox_preds] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + + (anchor_list, labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, avg_factor) = cls_reg_targets + avg_factor = reduce_mean( + torch.tensor(avg_factor, dtype=torch.float, device=device)).item() + + losses_cls, losses_bbox, loss_centerness, \ + bbox_avg_factor = multi_apply( + self.loss_by_feat_single, + anchor_list, + cls_scores, + bbox_preds, + centernesses, + labels_list, + label_weights_list, + bbox_targets_list, + avg_factor=avg_factor) + + bbox_avg_factor = sum(bbox_avg_factor) + bbox_avg_factor = reduce_mean(bbox_avg_factor).clamp_(min=1).item() + losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox)) + return dict( + loss_cls=losses_cls, + loss_bbox=losses_bbox, + loss_centerness=loss_centerness) + + def centerness_target(self, anchors: Tensor, gts: Tensor) -> Tensor: + """Calculate the centerness between anchors and gts. + + Only calculate pos centerness targets, otherwise there may be nan. + + Args: + anchors (Tensor): Anchors with shape (N, 4), "xyxy" format. + gts (Tensor): Ground truth bboxes with shape (N, 4), "xyxy" format. + + Returns: + Tensor: Centerness between anchors and gts. + """ + anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2 + anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2 + l_ = anchors_cx - gts[:, 0] + t_ = anchors_cy - gts[:, 1] + r_ = gts[:, 2] - anchors_cx + b_ = gts[:, 3] - anchors_cy + + left_right = torch.stack([l_, r_], dim=1) + top_bottom = torch.stack([t_, b_], dim=1) + centerness = torch.sqrt( + (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * + (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])) + assert not torch.isnan(centerness).any() + return centerness + + def get_targets(self, + anchor_list: List[List[Tensor]], + valid_flag_list: List[List[Tensor]], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None, + unmap_outputs: bool = True) -> tuple: + """Get targets for ATSS head. + + This method is almost the same as `AnchorHead.get_targets()`. Besides + returning the targets as the parent method does, it also returns the + anchors as the first element of the returned tuple. 
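+
+        A sketch of how the returned tuple is consumed in
+        ``loss_by_feat`` (illustrative, mirroring the unpacking used
+        there)::
+
+            (anchor_list, labels_list, label_weights_list,
+             bbox_targets_list, bbox_weights_list,
+             avg_factor) = self.get_targets(
+                 anchor_list, valid_flag_list, batch_gt_instances,
+                 batch_img_metas)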
+ """ + num_imgs = len(batch_img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + num_level_anchors_list = [num_level_anchors] * num_imgs + + # concat all level anchors and flags to a single tensor + for i in range(num_imgs): + assert len(anchor_list[i]) == len(valid_flag_list[i]) + anchor_list[i] = torch.cat(anchor_list[i]) + valid_flag_list[i] = torch.cat(valid_flag_list[i]) + + # compute targets for each image + if batch_gt_instances_ignore is None: + batch_gt_instances_ignore = [None] * num_imgs + (all_anchors, all_labels, all_label_weights, all_bbox_targets, + all_bbox_weights, pos_inds_list, neg_inds_list, + sampling_results_list) = multi_apply( + self._get_targets_single, + anchor_list, + valid_flag_list, + num_level_anchors_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore, + unmap_outputs=unmap_outputs) + # Get `avg_factor` of all images, which calculate in `SamplingResult`. + # When using sampling method, avg_factor is usually the sum of + # positive and negative priors. When using `PseudoSampler`, + # `avg_factor` is usually equal to the number of positive priors. + avg_factor = sum( + [results.avg_factor for results in sampling_results_list]) + # split targets to a list w.r.t. multiple levels + anchors_list = images_to_levels(all_anchors, num_level_anchors) + labels_list = images_to_levels(all_labels, num_level_anchors) + label_weights_list = images_to_levels(all_label_weights, + num_level_anchors) + bbox_targets_list = images_to_levels(all_bbox_targets, + num_level_anchors) + bbox_weights_list = images_to_levels(all_bbox_weights, + num_level_anchors) + return (anchors_list, labels_list, label_weights_list, + bbox_targets_list, bbox_weights_list, avg_factor) + + def _get_targets_single(self, + flat_anchors: Tensor, + valid_flags: Tensor, + num_level_anchors: List[int], + gt_instances: InstanceData, + img_meta: dict, + gt_instances_ignore: Optional[InstanceData] = None, + unmap_outputs: bool = True) -> tuple: + """Compute regression, classification targets for anchors in a single + image. + + Args: + flat_anchors (Tensor): Multi-level anchors of the image, which are + concatenated into a single tensor of shape (num_anchors ,4) + valid_flags (Tensor): Multi level valid flags of the image, + which are concatenated into a single tensor of + shape (num_anchors,). + num_level_anchors (List[int]): Number of anchors of each scale + level. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for current image. + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Returns: + tuple: N is the number of total anchors in the image. + labels (Tensor): Labels of all anchors in the image with shape + (N,). + label_weights (Tensor): Label weights of all anchor in the + image with shape (N,). + bbox_targets (Tensor): BBox targets of all anchors in the + image with shape (N, 4). + bbox_weights (Tensor): BBox weights of all anchors in the + image with shape (N, 4) + pos_inds (Tensor): Indices of positive anchor with shape + (num_pos,). 
+ neg_inds (Tensor): Indices of negative anchor with shape + (num_neg,). + sampling_result (:obj:`SamplingResult`): Sampling results. + """ + inside_flags = anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], + self.train_cfg['allowed_border']) + if not inside_flags.any(): + raise ValueError( + 'There is no valid anchor inside the image boundary. Please ' + 'check the image size and anchor sizes, or set ' + '``allowed_border`` to -1 to skip the condition.') + # assign gt and sample anchors + anchors = flat_anchors[inside_flags, :] + + num_level_anchors_inside = self.get_num_level_anchors_inside( + num_level_anchors, inside_flags) + pred_instances = InstanceData(priors=anchors) + assign_result = self.assigner.assign(pred_instances, + num_level_anchors_inside, + gt_instances, gt_instances_ignore) + + sampling_result = self.sampler.sample(assign_result, pred_instances, + gt_instances) + + num_valid_anchors = anchors.shape[0] + bbox_targets = torch.zeros_like(anchors) + bbox_weights = torch.zeros_like(anchors) + labels = anchors.new_full((num_valid_anchors, ), + self.num_classes, + dtype=torch.long) + label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + if self.reg_decoded_bbox: + pos_bbox_targets = sampling_result.pos_gt_bboxes + else: + pos_bbox_targets = self.bbox_coder.encode( + sampling_result.pos_priors, sampling_result.pos_gt_bboxes) + + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1.0 + + labels[pos_inds] = sampling_result.pos_gt_labels + if self.train_cfg['pos_weight'] <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.train_cfg['pos_weight'] + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_anchors.size(0) + anchors = unmap(anchors, num_total_anchors, inside_flags) + labels = unmap( + labels, num_total_anchors, inside_flags, fill=self.num_classes) + label_weights = unmap(label_weights, num_total_anchors, + inside_flags) + bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) + bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) + + return (anchors, labels, label_weights, bbox_targets, bbox_weights, + pos_inds, neg_inds, sampling_result) + + def get_num_level_anchors_inside(self, num_level_anchors, inside_flags): + """Get the number of valid anchors in every level.""" + + split_inside_flags = torch.split(inside_flags, num_level_anchors) + num_level_anchors_inside = [ + int(flags.sum()) for flags in split_inside_flags + ] + return num_level_anchors_inside diff --git a/mmdet/models/dense_heads/atss_vlfusion_head.py b/mmdet/models/dense_heads/atss_vlfusion_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c5cd28b4a040ba447130aed07629f6312f95dcf3 --- /dev/null +++ b/mmdet/models/dense_heads/atss_vlfusion_head.py @@ -0,0 +1,949 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
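+# Review note: this head follows the GLIP-style design, where region-word
+# alignment logits of shape (batch, num_priors, num_words) are converted
+# back to per-class scores by averaging the logits of the word positions
+# belonging to each class phrase. A minimal sketch of that mapping
+# (illustrative only; the token positions below are made up):
+#
+#   positive_map = {1: [2, 3]}          # class 1 <-> word tokens 2 and 3
+#   logits = torch.rand(1, 100, 5)      # (batch, num_priors, num_words)
+#   scores = convert_grounding_to_cls_scores(logits, [positive_map])
+#   assert scores.shape == (1, 100, 1)  # (batch, num_priors, num_classes)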
+import copy +import math +from typing import Callable, List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Scale +from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2d +from mmengine.config import ConfigDict +from mmengine.model import BaseModel +from mmengine.structures import InstanceData +from torch import Tensor + +try: + from transformers import BertConfig +except ImportError: + BertConfig = None + +from mmdet.registry import MODELS +from mmdet.structures.bbox import cat_boxes +from mmdet.utils import InstanceList, OptInstanceList, reduce_mean +from ..utils import (BertEncoderLayer, VLFuse, filter_scores_and_topk, + permute_and_flatten, select_single_mlvl, + unpack_gt_instances) +from ..utils.vlfuse_helper import MAX_CLAMP_VALUE +from .atss_head import ATSSHead + + +def convert_grounding_to_cls_scores(logits: Tensor, + positive_maps: List[dict]) -> Tensor: + """Convert logits to class scores.""" + assert len(positive_maps) == logits.shape[0] # batch size + + scores = torch.zeros(logits.shape[0], logits.shape[1], + len(positive_maps[0])).to(logits.device) + if positive_maps is not None: + if all(x == positive_maps[0] for x in positive_maps): + # only need to compute once + positive_map = positive_maps[0] + for label_j in positive_map: + scores[:, :, label_j - + 1] = logits[:, :, + torch.LongTensor(positive_map[label_j] + )].mean(-1) + else: + for i, positive_map in enumerate(positive_maps): + for label_j in positive_map: + scores[i, :, label_j - 1] = logits[ + i, :, torch.LongTensor(positive_map[label_j])].mean(-1) + return scores + + +class Conv3x3Norm(nn.Module): + """Conv3x3 and norm.""" + + def __init__(self, + in_channels: int, + out_channels: int, + stride: int, + groups: int = 1, + use_dcn: bool = False, + norm_type: Optional[Union[Sequence, str]] = None): + super().__init__() + + if use_dcn: + self.conv = ModulatedDeformConv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + groups=groups) + else: + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + groups=groups) + + if isinstance(norm_type, Sequence): + assert len(norm_type) == 2 + assert norm_type[0] == 'gn' + gn_group = norm_type[1] + norm_type = norm_type[0] + + if norm_type == 'bn': + bn_op = nn.BatchNorm2d(out_channels) + elif norm_type == 'gn': + bn_op = nn.GroupNorm( + num_groups=gn_group, num_channels=out_channels) + if norm_type is not None: + self.bn = bn_op + else: + self.bn = None + + def forward(self, x, **kwargs): + x = self.conv(x, **kwargs) + if self.bn: + x = self.bn(x) + return x + + +class DyReLU(nn.Module): + """Dynamic ReLU.""" + + def __init__(self, + in_channels: int, + out_channels: int, + expand_ratio: int = 4): + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.expand_ratio = expand_ratio + self.out_channels = out_channels + + self.fc = nn.Sequential( + nn.Linear(in_channels, in_channels // expand_ratio), + nn.ReLU(inplace=True), + nn.Linear(in_channels // expand_ratio, + out_channels * self.expand_ratio), + nn.Hardsigmoid(inplace=True)) + + def forward(self, x) -> Tensor: + x_out = x + b, c, h, w = x.size() + x = self.avg_pool(x).view(b, c) + x = self.fc(x).view(b, -1, 1, 1) + + a1, b1, a2, b2 = torch.split(x, self.out_channels, dim=1) + a1 = (a1 - 0.5) * 2 + 1.0 + a2 = (a2 - 0.5) * 2 + b1 = b1 - 0.5 + b2 = b2 - 0.5 + out = torch.max(x_out * a1 + b1, x_out * a2 + b2) + return out + + +class 
DyConv(nn.Module): + """Dynamic Convolution.""" + + def __init__(self, + conv_func: Callable, + in_channels: int, + out_channels: int, + use_dyfuse: bool = True, + use_dyrelu: bool = False, + use_dcn: bool = False): + super().__init__() + + self.dyconvs = nn.ModuleList() + self.dyconvs.append(conv_func(in_channels, out_channels, 1)) + self.dyconvs.append(conv_func(in_channels, out_channels, 1)) + self.dyconvs.append(conv_func(in_channels, out_channels, 2)) + + if use_dyfuse: + self.attnconv = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, 1, kernel_size=1), + nn.ReLU(inplace=True)) + self.h_sigmoid = nn.Hardsigmoid(inplace=True) + else: + self.attnconv = None + + if use_dyrelu: + self.relu = DyReLU(in_channels, out_channels) + else: + self.relu = nn.ReLU() + + if use_dcn: + self.offset = nn.Conv2d( + in_channels, 27, kernel_size=3, stride=1, padding=1) + else: + self.offset = None + + self.init_weights() + + def init_weights(self): + for m in self.dyconvs.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight.data, 0, 0.01) + if m.bias is not None: + m.bias.data.zero_() + if self.attnconv is not None: + for m in self.attnconv.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight.data, 0, 0.01) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, inputs: dict) -> dict: + visual_feats = inputs['visual'] + + out_vis_feats = [] + for level, feature in enumerate(visual_feats): + + offset_conv_args = {} + if self.offset is not None: + offset_mask = self.offset(feature) + offset = offset_mask[:, :18, :, :] + mask = offset_mask[:, 18:, :, :].sigmoid() + offset_conv_args = dict(offset=offset, mask=mask) + + temp_feats = [self.dyconvs[1](feature, **offset_conv_args)] + + if level > 0: + temp_feats.append(self.dyconvs[2](visual_feats[level - 1], + **offset_conv_args)) + if level < len(visual_feats) - 1: + temp_feats.append( + F.upsample_bilinear( + self.dyconvs[0](visual_feats[level + 1], + **offset_conv_args), + size=[feature.size(2), + feature.size(3)])) + mean_feats = torch.mean( + torch.stack(temp_feats), dim=0, keepdim=False) + + if self.attnconv is not None: + attn_feat = [] + res_feat = [] + for feat in temp_feats: + res_feat.append(feat) + attn_feat.append(self.attnconv(feat)) + + res_feat = torch.stack(res_feat) + spa_pyr_attn = self.h_sigmoid(torch.stack(attn_feat)) + + mean_feats = torch.mean( + res_feat * spa_pyr_attn, dim=0, keepdim=False) + + out_vis_feats.append(mean_feats) + + out_vis_feats = [self.relu(item) for item in out_vis_feats] + + features_dict = {'visual': out_vis_feats, 'lang': inputs['lang']} + + return features_dict + + +class VLFusionModule(BaseModel): + """Visual-lang Fusion Module.""" + + def __init__(self, + in_channels: int, + feat_channels: int, + num_base_priors: int, + early_fuse: bool = False, + num_dyhead_blocks: int = 6, + lang_model_name: str = 'bert-base-uncased', + use_dyrelu: bool = True, + use_dyfuse: bool = True, + use_dcn: bool = True, + use_checkpoint: bool = False, + **kwargs) -> None: + super().__init__(**kwargs) + if BertConfig is None: + raise RuntimeError( + 'transformers is not installed, please install it by: ' + 'pip install transformers.') + self.in_channels = in_channels + self.feat_channels = feat_channels + self.num_base_priors = num_base_priors + self.early_fuse = early_fuse + self.num_dyhead_blocks = num_dyhead_blocks + self.use_dyrelu = use_dyrelu + self.use_dyfuse = use_dyfuse + self.use_dcn = use_dcn + self.use_checkpoint = use_checkpoint + + self.lang_cfg = 
BertConfig.from_pretrained(lang_model_name) + self.lang_dim = self.lang_cfg.hidden_size + self._init_layers() + + def _init_layers(self) -> None: + """Initialize layers of the model.""" + bias_value = -math.log((1 - 0.01) / 0.01) + + dyhead_tower = [] + for i in range(self.num_dyhead_blocks): + if self.early_fuse: + # cross-modality fusion + dyhead_tower.append(VLFuse(use_checkpoint=self.use_checkpoint)) + # lang branch + dyhead_tower.append( + BertEncoderLayer( + self.lang_cfg, + clamp_min_for_underflow=True, + clamp_max_for_overflow=True)) + + # vision branch + dyhead_tower.append( + DyConv( + lambda i, o, s: Conv3x3Norm( + i, o, s, use_dcn=self.use_dcn, norm_type=['gn', 16]), + self.in_channels if i == 0 else self.feat_channels, + self.feat_channels, + use_dyrelu=(self.use_dyrelu + and self.in_channels == self.feat_channels) + if i == 0 else self.use_dyrelu, + use_dyfuse=(self.use_dyfuse + and self.in_channels == self.feat_channels) + if i == 0 else self.use_dyfuse, + use_dcn=(self.use_dcn + and self.in_channels == self.feat_channels) + if i == 0 else self.use_dcn, + )) + + self.add_module('dyhead_tower', nn.Sequential(*dyhead_tower)) + + self.bbox_pred = nn.Conv2d( + self.feat_channels, self.num_base_priors * 4, kernel_size=1) + self.centerness = nn.Conv2d( + self.feat_channels, self.num_base_priors * 1, kernel_size=1) + self.dot_product_projection_text = nn.Linear( + self.lang_dim, + self.num_base_priors * self.feat_channels, + bias=True) + self.log_scale = nn.Parameter(torch.Tensor([0.0]), requires_grad=True) + self.bias_lang = nn.Parameter( + torch.zeros(self.lang_dim), requires_grad=True) + self.bias0 = nn.Parameter( + torch.Tensor([bias_value]), requires_grad=True) + self.scales = nn.ModuleList([Scale(1.0) for _ in range(5)]) + + def forward(self, visual_feats: Tuple[Tensor], + language_feats: dict) -> Tuple: + feat_inputs = {'visual': visual_feats, 'lang': language_feats} + dyhead_tower = self.dyhead_tower(feat_inputs) + + if self.early_fuse: + embedding = dyhead_tower['lang']['hidden'] + else: + embedding = language_feats['embedded'] + + embedding = F.normalize(embedding, p=2, dim=-1) + dot_product_proj_tokens = self.dot_product_projection_text(embedding / + 2.0) + dot_product_proj_tokens_bias = torch.matmul( + embedding, self.bias_lang) + self.bias0 + + bbox_preds = [] + centerness = [] + cls_logits = [] + + for i, feature in enumerate(visual_feats): + visual = dyhead_tower['visual'][i] + B, C, H, W = visual.shape + + bbox_pred = self.scales[i](self.bbox_pred(visual)) + bbox_preds.append(bbox_pred) + centerness.append(self.centerness(visual)) + + dot_product_proj_queries = permute_and_flatten( + visual, B, self.num_base_priors, C, H, W) + + bias = dot_product_proj_tokens_bias.unsqueeze(1).repeat( + 1, self.num_base_priors, 1) + dot_product_logit = ( + torch.matmul(dot_product_proj_queries, + dot_product_proj_tokens.transpose(-1, -2)) / + self.log_scale.exp()) + bias + dot_product_logit = torch.clamp( + dot_product_logit, max=MAX_CLAMP_VALUE) + dot_product_logit = torch.clamp( + dot_product_logit, min=-MAX_CLAMP_VALUE) + cls_logits.append(dot_product_logit) + + return bbox_preds, centerness, cls_logits + + +@MODELS.register_module() +class ATSSVLFusionHead(ATSSHead): + """ATSS head with visual-language fusion module. + + Args: + early_fuse (bool): Whether to fuse visual and language features + Defaults to False. + use_checkpoint (bool): Whether to use checkpoint. Defaults to False. + num_dyhead_blocks (int): Number of dynamic head blocks. Defaults to 6. 
+ lang_model_name (str): Name of the language model. + Defaults to 'bert-base-uncased'. + """ + + def __init__(self, + *args, + early_fuse: bool = False, + use_checkpoint: bool = False, + num_dyhead_blocks: int = 6, + lang_model_name: str = 'bert-base-uncased', + init_cfg=None, + **kwargs): + super().__init__(*args, **kwargs, init_cfg=init_cfg) + self.head = VLFusionModule( + in_channels=self.in_channels, + feat_channels=self.feat_channels, + num_base_priors=self.num_base_priors, + early_fuse=early_fuse, + use_checkpoint=use_checkpoint, + num_dyhead_blocks=num_dyhead_blocks, + lang_model_name=lang_model_name) + self.text_masks = None + + def _init_layers(self) -> None: + """No need to initialize the ATSS head layer.""" + pass + + def forward(self, visual_feats: Tuple[Tensor], + language_feats: dict) -> Tuple[Tensor]: + """Forward function.""" + bbox_preds, centerness, cls_logits = self.head(visual_feats, + language_feats) + return cls_logits, bbox_preds, centerness + + def loss(self, visual_feats: Tuple[Tensor], language_feats: dict, + batch_data_samples): + outputs = unpack_gt_instances(batch_data_samples) + (batch_gt_instances, batch_gt_instances_ignore, + batch_img_metas) = outputs + + outs = self(visual_feats, language_feats) + self.text_masks = language_feats['masks'] + loss_inputs = outs + (batch_gt_instances, batch_img_metas, + batch_gt_instances_ignore) + losses = self.loss_by_feat(*loss_inputs) + return losses + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + centernesses: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W) + centernesses (list[Tensor]): Centerness for each scale + level with shape (N, num_anchors * 1, H, W) + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
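+
+        Unlike ``ATSSHead.loss_by_feat``, all levels are concatenated
+        along the prior dimension before the loss is computed, roughly
+        (an illustrative sketch of the reshaping done below)::
+
+            # each level: (N, num_anchors * 4, H, W) -> (N, H*W*num_anchors, 4)
+            bbox_preds = torch.cat([
+                p.permute(0, 2, 3, 1).reshape(p.size(0), -1, 4)
+                for p in bbox_preds], dim=1)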
+ """ + featmap_sizes = [featmap.size()[-2:] for featmap in bbox_preds] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + + (anchor_list, labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, avg_factor) = cls_reg_targets + avg_factor = reduce_mean( + torch.tensor(avg_factor, dtype=torch.float, device=device)).item() + + anchors = torch.cat(anchor_list, dim=1) + labels = torch.cat(labels_list, dim=1) + label_weights = torch.cat(label_weights_list, dim=1) + bbox_targets = torch.cat(bbox_targets_list, dim=1) + cls_scores = torch.cat(cls_scores, dim=1) + + centernesses_ = [] + bbox_preds_ = [] + for bbox_pred, centerness in zip(bbox_preds, centernesses): + centernesses_.append( + centerness.permute(0, 2, 3, + 1).reshape(cls_scores.size(0), -1, 1)) + bbox_preds_.append( + bbox_pred.permute(0, 2, 3, + 1).reshape(cls_scores.size(0), -1, 4)) + bbox_preds = torch.cat(bbox_preds_, dim=1) + centernesses = torch.cat(centernesses_, dim=1) + + losses_cls, losses_bbox, loss_centerness, bbox_avg_factor = \ + self._loss_by_feat( + anchors, + cls_scores, + bbox_preds, + centernesses, + labels, + label_weights, + bbox_targets, + avg_factor=avg_factor) + + bbox_avg_factor = reduce_mean(bbox_avg_factor).clamp_(min=1).item() + losses_bbox = losses_bbox / bbox_avg_factor + return dict( + loss_cls=losses_cls, + loss_bbox=losses_bbox, + loss_centerness=loss_centerness) + + def _loss_by_feat(self, anchors: Tensor, cls_score: Tensor, + bbox_pred: Tensor, centerness: Tensor, labels: Tensor, + label_weights: Tensor, bbox_targets: Tensor, + avg_factor: float) -> dict: + """Calculate the loss of all scale level based on the features + extracted by the detection head. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + + anchors = anchors.reshape(-1, 4) + + # ===== this change ===== + pos_inds = (labels.sum(-1) > 0).reshape(-1) + + # Loss is not computed for the padded regions of the text. 
+ assert (self.text_masks.dim() == 2) + text_mask = (self.text_masks > 0).unsqueeze(1) + text_mask = text_mask.repeat(1, cls_score.size(1), 1) + cls_score = torch.masked_select(cls_score, text_mask).contiguous() + labels = torch.masked_select(labels, text_mask) + label_weights = label_weights[..., + None].repeat(1, 1, text_mask.size(-1)) + label_weights = torch.masked_select(label_weights, text_mask) + + bbox_pred = bbox_pred.reshape(-1, 4) + centerness = centerness.reshape(-1) + bbox_targets = bbox_targets.reshape(-1, 4) + labels = labels.reshape(-1) + label_weights = label_weights.reshape(-1) + + # classification loss + loss_cls = self.loss_cls( + cls_score, labels, label_weights, avg_factor=avg_factor) + + if pos_inds.sum() > 0: + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + pos_anchors = anchors[pos_inds] + pos_centerness = centerness[pos_inds] + + centerness_targets = self.centerness_target( + pos_anchors, pos_bbox_targets) + + if torch.isnan(centerness_targets).any(): + print('=====Centerness includes NaN=====') + mask = ~torch.isnan(centerness_targets) + centerness_targets = centerness_targets[mask] + pos_centerness = pos_centerness[mask] + pos_anchors = pos_anchors[mask] + pos_bbox_targets = pos_bbox_targets[mask] + pos_bbox_pred = pos_bbox_pred[mask] + + if pos_bbox_targets.shape[0] == 0: + loss_bbox = bbox_pred.sum() * 0 + loss_centerness = centerness.sum() * 0 + centerness_targets = bbox_targets.new_tensor(0.) + return loss_cls, loss_bbox, loss_centerness, \ + centerness_targets.sum() + + # The decoding process takes the offset into consideration. + pos_anchors[:, 2:] += 1 + pos_decode_bbox_pred = self.bbox_coder.decode( + pos_anchors, pos_bbox_pred) + + # regression loss + loss_bbox = self.loss_bbox( + pos_decode_bbox_pred, + pos_bbox_targets, + weight=centerness_targets, + avg_factor=1.0) + + # centerness loss + loss_centerness = self.loss_centerness( + pos_centerness, centerness_targets, avg_factor=avg_factor) + else: + loss_bbox = bbox_pred.sum() * 0 + loss_centerness = centerness.sum() * 0 + centerness_targets = bbox_targets.new_tensor(0.) + + return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum() + + def _get_targets_single(self, + flat_anchors: Tensor, + valid_flags: Tensor, + num_level_anchors: List[int], + gt_instances: InstanceData, + img_meta: dict, + gt_instances_ignore: Optional[InstanceData] = None, + unmap_outputs: bool = True) -> tuple: + """Compute regression, classification targets for anchors in a single + image. + + Args: + flat_anchors (Tensor): Multi-level anchors of the image, which are + concatenated into a single tensor of shape (num_anchors ,4) + valid_flags (Tensor): Multi level valid flags of the image, + which are concatenated into a single tensor of + shape (num_anchors,). + num_level_anchors (List[int]): Number of anchors of each scale + level. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for current image. + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Returns: + tuple: N is the number of total anchors in the image. + labels (Tensor): Labels of all anchors in the image with shape + (N,). 
+ label_weights (Tensor): Label weights of all anchor in the + image with shape (N,). + bbox_targets (Tensor): BBox targets of all anchors in the + image with shape (N, 4). + bbox_weights (Tensor): BBox weights of all anchors in the + image with shape (N, 4) + pos_inds (Tensor): Indices of positive anchor with shape + (num_pos,). + neg_inds (Tensor): Indices of negative anchor with shape + (num_neg,). + sampling_result (:obj:`SamplingResult`): Sampling results. + """ + anchors = flat_anchors + # Align the official implementation + anchors[:, 2:] -= 1 + + num_level_anchors_inside = num_level_anchors + pred_instances = InstanceData(priors=anchors) + assign_result = self.assigner.assign(pred_instances, + num_level_anchors_inside, + gt_instances, gt_instances_ignore) + + sampling_result = self.sampler.sample(assign_result, pred_instances, + gt_instances) + + num_valid_anchors = anchors.shape[0] + bbox_targets = torch.zeros_like(anchors) + bbox_weights = torch.zeros_like(anchors) + + # ===== this change ===== + labels = anchors.new_full((num_valid_anchors, self.feat_channels), + 0, + dtype=torch.float32) + label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + if self.reg_decoded_bbox: + pos_bbox_targets = sampling_result.pos_gt_bboxes + else: + pos_bbox_targets = self.bbox_coder.encode( + sampling_result.pos_priors, sampling_result.pos_gt_bboxes) + + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1.0 + + # ===== this change ===== + labels[pos_inds] = gt_instances.positive_maps[ + sampling_result.pos_assigned_gt_inds] + if self.train_cfg['pos_weight'] <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.train_cfg['pos_weight'] + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + return (anchors, labels, label_weights, bbox_targets, bbox_weights, + pos_inds, neg_inds, sampling_result) + + def centerness_target(self, anchors: Tensor, gts: Tensor) -> Tensor: + """Calculate the centerness between anchors and gts. + + Only calculate pos centerness targets, otherwise there may be nan. + + Args: + anchors (Tensor): Anchors with shape (N, 4), "xyxy" format. + gts (Tensor): Ground truth bboxes with shape (N, 4), "xyxy" format. + + Returns: + Tensor: Centerness between anchors and gts. + """ + anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2 + anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2 + l_ = anchors_cx - gts[:, 0] + t_ = anchors_cy - gts[:, 1] + r_ = gts[:, 2] - anchors_cx + b_ = gts[:, 3] - anchors_cy + + left_right = torch.stack([l_, r_], dim=1) + top_bottom = torch.stack([t_, b_], dim=1) + centerness = torch.sqrt( + (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * + (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])) + # assert not torch.isnan(centerness).any() + return centerness + + def predict(self, + visual_feats: Tuple[Tensor], + language_feats: dict, + batch_data_samples, + rescale: bool = True): + """Perform forward propagation of the detection head and predict + detection results on the features of the upstream network. + + Args: + visual_feats (tuple[Tensor]): Multi-level visual features from the + upstream network, each is a 4D-tensor. + language_feats (dict): Language features from the upstream network. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. 
+ rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[obj:`InstanceData`]: Detection results of each image + after the post process. + """ + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + batch_token_positive_maps = [ + data_samples.token_positive_map + for data_samples in batch_data_samples + ] + outs = self(visual_feats, language_feats) + + predictions = self.predict_by_feat( + *outs, + batch_img_metas=batch_img_metas, + batch_token_positive_maps=batch_token_positive_maps, + rescale=rescale) + return predictions + + def predict_by_feat(self, + cls_logits: List[Tensor], + bbox_preds: List[Tensor], + score_factors: List[Tensor], + batch_img_metas: Optional[List[dict]] = None, + batch_token_positive_maps: Optional[List[dict]] = None, + cfg: Optional[ConfigDict] = None, + rescale: bool = False, + with_nms: bool = True) -> InstanceList: + """Transform a batch of output features extracted from the head into + bbox results. + + Note: When score_factors is not None, the cls_scores are + usually multiplied by it then obtain the real score used in NMS, + such as CenterNess in FCOS, IoU branch in ATSS. + + Args: + cls_logits (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + score_factors (list[Tensor], optional): Score factor for + all scale level, each is a 4D-tensor, has shape + (batch_size, num_priors * 1, H, W). Defaults to None. + batch_img_metas (list[dict], Optional): Batch image meta info. + Defaults to None. + batch_token_positive_maps (list[dict], Optional): Batch token + positive map. Defaults to None. + cfg (ConfigDict, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + list[:obj:`InstanceData`]: Object detection results of each image + after the post process. Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
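+
+            A ``token_positive_maps`` entry maps a class index to the word
+            positions of its phrase in the prompt, e.g. (illustrative; the
+            token indices below depend on the tokenizer and are made up)::
+
+                # prompt 'person. bicycle.' -> class 1 covers token
+                # position [1], class 2 covers token position [3]
+                batch_token_positive_maps = [{1: [1], 2: [3]}]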
+ """ + assert len(bbox_preds) == len(score_factors) + num_levels = len(bbox_preds) + + featmap_sizes = [bbox_preds[i].shape[-2:] for i in range(num_levels)] + mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, + dtype=bbox_preds[0].dtype, + device=bbox_preds[0].device) + + result_list = [] + + for img_id in range(len(batch_img_metas)): + img_meta = batch_img_metas[img_id] + token_positive_maps = batch_token_positive_maps[img_id] + bbox_pred_list = select_single_mlvl( + bbox_preds, img_id, detach=True) + score_factor_list = select_single_mlvl( + score_factors, img_id, detach=True) + cls_logit_list = select_single_mlvl( + cls_logits, img_id, detach=True) + + results = self._predict_by_feat_single( + bbox_pred_list=bbox_pred_list, + score_factor_list=score_factor_list, + cls_logit_list=cls_logit_list, + mlvl_priors=mlvl_priors, + token_positive_maps=token_positive_maps, + img_meta=img_meta, + cfg=cfg, + rescale=rescale, + with_nms=with_nms) + result_list.append(results) + return result_list + + def _predict_by_feat_single(self, + bbox_pred_list: List[Tensor], + score_factor_list: List[Tensor], + cls_logit_list: List[Tensor], + mlvl_priors: List[Tensor], + token_positive_maps: dict, + img_meta: dict, + cfg: ConfigDict, + rescale: bool = True, + with_nms: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image, each item has shape + (num_priors * 1, H, W). + cls_logit_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid. In all + anchor-based methods, it has shape (num_priors, 4). In + all anchor-free methods, it has shape (num_priors, 2) + when `with_stride=True`, otherwise it still has shape + (num_priors, 4). + token_positive_maps (dict): Token positive map. + img_meta (dict): Image meta info. + cfg (mmengine.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
+ """ + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + img_shape = img_meta['img_shape'] + nms_pre = cfg.get('nms_pre', -1) + score_thr = cfg.get('score_thr', 0) + + mlvl_bbox_preds = [] + mlvl_valid_priors = [] + mlvl_scores = [] + mlvl_labels = [] + + for level_idx, (bbox_pred, score_factor, cls_logit, priors) in \ + enumerate(zip(bbox_pred_list, + score_factor_list, cls_logit_list, mlvl_priors)): + bbox_pred = bbox_pred.permute(1, 2, 0).reshape( + -1, self.bbox_coder.encode_size) + score_factor = score_factor.permute(1, 2, 0).reshape(-1).sigmoid() + + scores = convert_grounding_to_cls_scores( + logits=cls_logit.sigmoid()[None], + positive_maps=[token_positive_maps])[0] + + results = filter_scores_and_topk( + scores, score_thr, nms_pre, + dict(bbox_pred=bbox_pred, priors=priors)) + + scores, labels, keep_idxs, filtered_results = results + + bbox_pred = filtered_results['bbox_pred'] + priors = filtered_results['priors'] + score_factor = score_factor[keep_idxs] + scores = torch.sqrt(scores * score_factor) + + mlvl_bbox_preds.append(bbox_pred) + mlvl_valid_priors.append(priors) + mlvl_scores.append(scores) + mlvl_labels.append(labels) + + bbox_pred = torch.cat(mlvl_bbox_preds) + priors = cat_boxes(mlvl_valid_priors) + bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape) + + results = InstanceData() + results.bboxes = bboxes + results.scores = torch.cat(mlvl_scores) + results.labels = torch.cat(mlvl_labels) + + predictions = self._bbox_post_process( + results=results, + cfg=cfg, + rescale=rescale, + with_nms=with_nms, + img_meta=img_meta) + + if len(predictions) > 0: + # Note: GLIP adopts a very strange bbox decoder logic, + # and if 1 is not added here, it will not align with + # the official mAP. + predictions.bboxes[:, 2:] = predictions.bboxes[:, 2:] + 1 + return predictions diff --git a/mmdet/models/dense_heads/autoassign_head.py b/mmdet/models/dense_heads/autoassign_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a2b30ff0d7d41205f0a92ede7b8eb10a234c5942 --- /dev/null +++ b/mmdet/models/dense_heads/autoassign_head.py @@ -0,0 +1,524 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Sequence, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Scale +from mmengine.model import bias_init_with_prob, normal_init +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures.bbox import bbox_overlaps +from mmdet.utils import InstanceList, OptInstanceList, reduce_mean +from ..task_modules.prior_generators import MlvlPointGenerator +from ..utils import levels_to_images, multi_apply +from .fcos_head import FCOSHead + +EPS = 1e-12 + + +class CenterPrior(nn.Module): + """Center Weighting module to adjust the category-specific prior + distributions. + + Args: + force_topk (bool): When no point falls into gt_bbox, forcibly + select the k points closest to the center to calculate + the center prior. Defaults to False. + topk (int): The number of points used to calculate the + center prior when no point falls in gt_bbox. Only work when + force_topk if True. Defaults to 9. + num_classes (int): The class number of dataset. Defaults to 80. + strides (Sequence[int]): The stride of each input feature map. + Defaults to (8, 16, 32, 64, 128). 
+ """ + + def __init__( + self, + force_topk: bool = False, + topk: int = 9, + num_classes: int = 80, + strides: Sequence[int] = (8, 16, 32, 64, 128) + ) -> None: + super().__init__() + self.mean = nn.Parameter(torch.zeros(num_classes, 2)) + self.sigma = nn.Parameter(torch.ones(num_classes, 2)) + self.strides = strides + self.force_topk = force_topk + self.topk = topk + + def forward(self, anchor_points_list: List[Tensor], + gt_instances: InstanceData, + inside_gt_bbox_mask: Tensor) -> Tuple[Tensor, Tensor]: + """Get the center prior of each point on the feature map for each + instance. + + Args: + anchor_points_list (list[Tensor]): list of coordinate + of points on feature map. Each with shape + (num_points, 2). + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes`` and ``labels`` + attributes. + inside_gt_bbox_mask (Tensor): Tensor of bool type, + with shape of (num_points, num_gt), each + value is used to mark whether this point falls + within a certain gt. + + Returns: + tuple[Tensor, Tensor]: + + - center_prior_weights(Tensor): Float tensor with shape of \ + (num_points, num_gt). Each value represents the center \ + weighting coefficient. + - inside_gt_bbox_mask (Tensor): Tensor of bool type, with shape \ + of (num_points, num_gt), each value is used to mark whether this \ + point falls within a certain gt or is the topk nearest points for \ + a specific gt_bbox. + """ + gt_bboxes = gt_instances.bboxes + labels = gt_instances.labels + + inside_gt_bbox_mask = inside_gt_bbox_mask.clone() + num_gts = len(labels) + num_points = sum([len(item) for item in anchor_points_list]) + if num_gts == 0: + return gt_bboxes.new_zeros(num_points, + num_gts), inside_gt_bbox_mask + center_prior_list = [] + for slvl_points, stride in zip(anchor_points_list, self.strides): + # slvl_points: points from single level in FPN, has shape (h*w, 2) + # single_level_points has shape (h*w, num_gt, 2) + single_level_points = slvl_points[:, None, :].expand( + (slvl_points.size(0), len(gt_bboxes), 2)) + gt_center_x = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2) + gt_center_y = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2) + gt_center = torch.stack((gt_center_x, gt_center_y), dim=1) + gt_center = gt_center[None] + # instance_center has shape (1, num_gt, 2) + instance_center = self.mean[labels][None] + # instance_sigma has shape (1, num_gt, 2) + instance_sigma = self.sigma[labels][None] + # distance has shape (num_points, num_gt, 2) + distance = (((single_level_points - gt_center) / float(stride) - + instance_center)**2) + center_prior = torch.exp(-distance / + (2 * instance_sigma**2)).prod(dim=-1) + center_prior_list.append(center_prior) + center_prior_weights = torch.cat(center_prior_list, dim=0) + + if self.force_topk: + gt_inds_no_points_inside = torch.nonzero( + inside_gt_bbox_mask.sum(0) == 0).reshape(-1) + if gt_inds_no_points_inside.numel(): + topk_center_index = \ + center_prior_weights[:, gt_inds_no_points_inside].topk( + self.topk, + dim=0)[1] + temp_mask = inside_gt_bbox_mask[:, gt_inds_no_points_inside] + inside_gt_bbox_mask[:, gt_inds_no_points_inside] = \ + torch.scatter(temp_mask, + dim=0, + index=topk_center_index, + src=torch.ones_like( + topk_center_index, + dtype=torch.bool)) + + center_prior_weights[~inside_gt_bbox_mask] = 0 + return center_prior_weights, inside_gt_bbox_mask + + +@MODELS.register_module() +class AutoAssignHead(FCOSHead): + """AutoAssignHead head used in AutoAssign. + + More details can be found in the `paper + `_ . 
+ + Args: + force_topk (bool): Used in center prior initialization to + handle extremely small gt. Default is False. + topk (int): The number of points used to calculate the + center prior when no point falls in gt_bbox. Only work when + force_topk if True. Defaults to 9. + pos_loss_weight (float): The loss weight of positive loss + and with default value 0.25. + neg_loss_weight (float): The loss weight of negative loss + and with default value 0.75. + center_loss_weight (float): The loss weight of center prior + loss and with default value 0.75. + """ + + def __init__(self, + *args, + force_topk: bool = False, + topk: int = 9, + pos_loss_weight: float = 0.25, + neg_loss_weight: float = 0.75, + center_loss_weight: float = 0.75, + **kwargs) -> None: + super().__init__(*args, conv_bias=True, **kwargs) + self.center_prior = CenterPrior( + force_topk=force_topk, + topk=topk, + num_classes=self.num_classes, + strides=self.strides) + self.pos_loss_weight = pos_loss_weight + self.neg_loss_weight = neg_loss_weight + self.center_loss_weight = center_loss_weight + self.prior_generator = MlvlPointGenerator(self.strides, offset=0) + + def init_weights(self) -> None: + """Initialize weights of the head. + + In particular, we have special initialization for classified conv's and + regression conv's bias + """ + + super(AutoAssignHead, self).init_weights() + bias_cls = bias_init_with_prob(0.02) + normal_init(self.conv_cls, std=0.01, bias=bias_cls) + normal_init(self.conv_reg, std=0.01, bias=4.0) + + def forward_single(self, x: Tensor, scale: Scale, + stride: int) -> Tuple[Tensor, Tensor, Tensor]: + """Forward features of a single scale level. + + Args: + x (Tensor): FPN feature maps of the specified stride. + scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize + the bbox prediction. + stride (int): The corresponding stride for feature maps, only + used to normalize the bbox prediction when self.norm_on_bbox + is True. + + Returns: + tuple[Tensor, Tensor, Tensor]: scores for each class, bbox + predictions and centerness predictions of input feature maps. + """ + cls_score, bbox_pred, cls_feat, reg_feat = super( + FCOSHead, self).forward_single(x) + centerness = self.conv_centerness(reg_feat) + # scale the bbox_pred of different level + # float to avoid overflow when enabling FP16 + bbox_pred = scale(bbox_pred).float() + # bbox_pred needed for gradient computation has been modified + # by F.relu(bbox_pred) when run with PyTorch 1.10. So replace + # F.relu(bbox_pred) with bbox_pred.clamp(min=0) + bbox_pred = bbox_pred.clamp(min=0) + bbox_pred *= stride + return cls_score, bbox_pred, centerness + + def get_pos_loss_single(self, cls_score: Tensor, objectness: Tensor, + reg_loss: Tensor, gt_instances: InstanceData, + center_prior_weights: Tensor) -> Tuple[Tensor]: + """Calculate the positive loss of all points in gt_bboxes. + + Args: + cls_score (Tensor): All category scores for each point on + the feature map. The shape is (num_points, num_class). + objectness (Tensor): Foreground probability of all points, + has shape (num_points, 1). + reg_loss (Tensor): The regression loss of each gt_bbox and each + prediction box, has shape of (num_points, num_gt). + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes`` and ``labels`` + attributes. + center_prior_weights (Tensor): Float tensor with shape + of (num_points, num_gt). Each value represents + the center weighting coefficient. 
+
+        Returns:
+            tuple[Tensor]:
+
+                - pos_loss (Tensor): The positive loss of all points in the \
+                gt_bboxes.
+        """
+        gt_labels = gt_instances.labels
+        # p_loc: localization confidence
+        p_loc = torch.exp(-reg_loss)
+        # p_cls: classification confidence
+        p_cls = (cls_score * objectness)[:, gt_labels]
+        # p_pos: joint confidence indicator
+        p_pos = p_cls * p_loc
+
+        # 3 is a hyper-parameter to control the contributions of high and
+        # low confidence locations towards positive losses.
+        confidence_weight = torch.exp(p_pos * 3)
+        p_pos_weight = (confidence_weight * center_prior_weights) / (
+            (confidence_weight * center_prior_weights).sum(
+                0, keepdim=True)).clamp(min=EPS)
+        reweighted_p_pos = (p_pos * p_pos_weight).sum(0)
+        pos_loss = F.binary_cross_entropy(
+            reweighted_p_pos,
+            torch.ones_like(reweighted_p_pos),
+            reduction='none')
+        pos_loss = pos_loss.sum() * self.pos_loss_weight
+        return pos_loss,
+
+    def get_neg_loss_single(self, cls_score: Tensor, objectness: Tensor,
+                            gt_instances: InstanceData, ious: Tensor,
+                            inside_gt_bbox_mask: Tensor) -> Tuple[Tensor]:
+        """Calculate the negative loss of all points in feature map.
+
+        Args:
+            cls_score (Tensor): All category scores for each point on
+                the feature map. The shape is (num_points, num_class).
+            objectness (Tensor): Foreground probability of all points,
+                with shape (num_points, 1).
+            gt_instances (:obj:`InstanceData`): Ground truth of instance
+                annotations. It should include ``bboxes`` and ``labels``
+                attributes.
+            ious (Tensor): Float tensor with shape of (num_points, num_gt).
+                Each value represents the IoU between pred_bbox and
+                gt_bboxes.
+            inside_gt_bbox_mask (Tensor): Tensor of bool type,
+                with shape of (num_points, num_gt), each
+                value is used to mark whether this point falls
+                within a certain gt.
+
+        Returns:
+            tuple[Tensor]:
+
+                - neg_loss (Tensor): The negative loss of all points in the \
+                feature map.
+        """
+        gt_labels = gt_instances.labels
+        num_gts = len(gt_labels)
+        joint_conf = (cls_score * objectness)
+        p_neg_weight = torch.ones_like(joint_conf)
+        if num_gts > 0:
+            # the order of dimensions would affect the value of
+            # p_neg_weight; we strictly follow the original
+            # implementation.
+            inside_gt_bbox_mask = inside_gt_bbox_mask.permute(1, 0)
+            ious = ious.permute(1, 0)
+
+            foreground_idxs = torch.nonzero(inside_gt_bbox_mask, as_tuple=True)
+            temp_weight = (1 / (1 - ious[foreground_idxs]).clamp_(EPS))
+
+            def normalize(x):
+                return (x - x.min() + EPS) / (x.max() - x.min() + EPS)
+
+            for instance_idx in range(num_gts):
+                idxs = foreground_idxs[0] == instance_idx
+                if idxs.any():
+                    temp_weight[idxs] = normalize(temp_weight[idxs])
+
+            p_neg_weight[foreground_idxs[1],
+                         gt_labels[foreground_idxs[0]]] = 1 - temp_weight
+
+        logits = (joint_conf * p_neg_weight)
+        neg_loss = (
+            logits**2 * F.binary_cross_entropy(
+                logits, torch.zeros_like(logits), reduction='none'))
+        neg_loss = neg_loss.sum() * self.neg_loss_weight
+        return neg_loss,
+
+    def loss_by_feat(
+        self,
+        cls_scores: List[Tensor],
+        bbox_preds: List[Tensor],
+        objectnesses: List[Tensor],
+        batch_gt_instances: InstanceList,
+        batch_img_metas: List[dict],
+        batch_gt_instances_ignore: OptInstanceList = None
+    ) -> Dict[str, Tensor]:
+        """Calculate the loss based on the features extracted by the detection
+        head.
+
+        Args:
+            cls_scores (list[Tensor]): Box scores for each scale level,
+                each is a 4D-tensor, the channel number is
+                num_points * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale + level, each is a 4D-tensor, the channel number is + num_points * 4. + objectnesses (list[Tensor]): objectness for each scale level, each + is a 4D-tensor, the channel number is num_points * 1. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + + assert len(cls_scores) == len(bbox_preds) == len(objectnesses) + all_num_gt = sum([len(item) for item in batch_gt_instances]) + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + all_level_points = self.prior_generator.grid_priors( + featmap_sizes, + dtype=bbox_preds[0].dtype, + device=bbox_preds[0].device) + inside_gt_bbox_mask_list, bbox_targets_list = self.get_targets( + all_level_points, batch_gt_instances) + + center_prior_weight_list = [] + temp_inside_gt_bbox_mask_list = [] + for gt_instances, inside_gt_bbox_mask in zip(batch_gt_instances, + inside_gt_bbox_mask_list): + center_prior_weight, inside_gt_bbox_mask = \ + self.center_prior(all_level_points, gt_instances, + inside_gt_bbox_mask) + center_prior_weight_list.append(center_prior_weight) + temp_inside_gt_bbox_mask_list.append(inside_gt_bbox_mask) + inside_gt_bbox_mask_list = temp_inside_gt_bbox_mask_list + mlvl_points = torch.cat(all_level_points, dim=0) + bbox_preds = levels_to_images(bbox_preds) + cls_scores = levels_to_images(cls_scores) + objectnesses = levels_to_images(objectnesses) + + reg_loss_list = [] + ious_list = [] + num_points = len(mlvl_points) + + for bbox_pred, encoded_targets, inside_gt_bbox_mask in zip( + bbox_preds, bbox_targets_list, inside_gt_bbox_mask_list): + temp_num_gt = encoded_targets.size(1) + expand_mlvl_points = mlvl_points[:, None, :].expand( + num_points, temp_num_gt, 2).reshape(-1, 2) + encoded_targets = encoded_targets.reshape(-1, 4) + expand_bbox_pred = bbox_pred[:, None, :].expand( + num_points, temp_num_gt, 4).reshape(-1, 4) + decoded_bbox_preds = self.bbox_coder.decode( + expand_mlvl_points, expand_bbox_pred) + decoded_target_preds = self.bbox_coder.decode( + expand_mlvl_points, encoded_targets) + with torch.no_grad(): + ious = bbox_overlaps( + decoded_bbox_preds, decoded_target_preds, is_aligned=True) + ious = ious.reshape(num_points, temp_num_gt) + if temp_num_gt: + ious = ious.max( + dim=-1, keepdim=True).values.repeat(1, temp_num_gt) + else: + ious = ious.new_zeros(num_points, temp_num_gt) + ious[~inside_gt_bbox_mask] = 0 + ious_list.append(ious) + loss_bbox = self.loss_bbox( + decoded_bbox_preds, + decoded_target_preds, + weight=None, + reduction_override='none') + reg_loss_list.append(loss_bbox.reshape(num_points, temp_num_gt)) + + cls_scores = [item.sigmoid() for item in cls_scores] + objectnesses = [item.sigmoid() for item in objectnesses] + pos_loss_list, = multi_apply(self.get_pos_loss_single, cls_scores, + objectnesses, reg_loss_list, + batch_gt_instances, + center_prior_weight_list) + pos_avg_factor = reduce_mean( + bbox_pred.new_tensor(all_num_gt)).clamp_(min=1) + pos_loss = sum(pos_loss_list) / pos_avg_factor + + neg_loss_list, = multi_apply(self.get_neg_loss_single, cls_scores, + objectnesses, 
batch_gt_instances, + ious_list, inside_gt_bbox_mask_list) + neg_avg_factor = sum(item.data.sum() + for item in center_prior_weight_list) + neg_avg_factor = reduce_mean(neg_avg_factor).clamp_(min=1) + neg_loss = sum(neg_loss_list) / neg_avg_factor + + center_loss = [] + for i in range(len(batch_img_metas)): + + if inside_gt_bbox_mask_list[i].any(): + center_loss.append( + len(batch_gt_instances[i]) / + center_prior_weight_list[i].sum().clamp_(min=EPS)) + # when width or height of gt_bbox is smaller than stride of p3 + else: + center_loss.append(center_prior_weight_list[i].sum() * 0) + + center_loss = torch.stack(center_loss).mean() * self.center_loss_weight + + # avoid dead lock in DDP + if all_num_gt == 0: + pos_loss = bbox_preds[0].sum() * 0 + dummy_center_prior_loss = self.center_prior.mean.sum( + ) * 0 + self.center_prior.sigma.sum() * 0 + center_loss = objectnesses[0].sum() * 0 + dummy_center_prior_loss + + loss = dict( + loss_pos=pos_loss, loss_neg=neg_loss, loss_center=center_loss) + + return loss + + def get_targets( + self, points: List[Tensor], batch_gt_instances: InstanceList + ) -> Tuple[List[Tensor], List[Tensor]]: + """Compute regression targets and each point inside or outside gt_bbox + in multiple images. + + Args: + points (list[Tensor]): Points of all fpn level, each has shape + (num_points, 2). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + + Returns: + tuple(list[Tensor], list[Tensor]): + + - inside_gt_bbox_mask_list (list[Tensor]): Each Tensor is with \ + bool type and shape of (num_points, num_gt), each value is used \ + to mark whether this point falls within a certain gt. + - concat_lvl_bbox_targets (list[Tensor]): BBox targets of each \ + level. Each tensor has shape (num_points, num_gt, 4). + """ + + concat_points = torch.cat(points, dim=0) + # the number of points per img, per lvl + inside_gt_bbox_mask_list, bbox_targets_list = multi_apply( + self._get_targets_single, batch_gt_instances, points=concat_points) + return inside_gt_bbox_mask_list, bbox_targets_list + + def _get_targets_single(self, gt_instances: InstanceData, + points: Tensor) -> Tuple[Tensor, Tensor]: + """Compute regression targets and each point inside or outside gt_bbox + for a single image. + + Args: + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes`` and ``labels`` + attributes. + points (Tensor): Points of all fpn level, has shape + (num_points, 2). + + Returns: + tuple[Tensor, Tensor]: Containing the following Tensors: + + - inside_gt_bbox_mask (Tensor): Bool tensor with shape \ + (num_points, num_gt), each value is used to mark whether this \ + point falls within a certain gt. + - bbox_targets (Tensor): BBox targets of each points with each \ + gt_bboxes, has shape (num_points, num_gt, 4). 
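+
+        Example (a hand-worked sketch of the (left, top, right, bottom)
+        target for a single point and gt box; the numbers are made up):
+
+            >>> point = torch.tensor([50., 40.])
+            >>> gt = torch.tensor([30., 20., 90., 100.])  # x1, y1, x2, y2
+            >>> left, top = point[0] - gt[0], point[1] - gt[1]  # 20., 20.
+            >>> right, bottom = gt[2] - point[0], gt[3] - point[1]  # 40., 60.
+            >>> # all four offsets are positive, so the point lies inside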
+ """ + gt_bboxes = gt_instances.bboxes + num_points = points.size(0) + num_gts = gt_bboxes.size(0) + gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4) + xs, ys = points[:, 0], points[:, 1] + xs = xs[:, None] + ys = ys[:, None] + left = xs - gt_bboxes[..., 0] + right = gt_bboxes[..., 2] - xs + top = ys - gt_bboxes[..., 1] + bottom = gt_bboxes[..., 3] - ys + bbox_targets = torch.stack((left, top, right, bottom), -1) + if num_gts: + inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0 + else: + inside_gt_bbox_mask = bbox_targets.new_zeros((num_points, num_gts), + dtype=torch.bool) + + return inside_gt_bbox_mask, bbox_targets diff --git a/mmdet/models/dense_heads/base_dense_head.py b/mmdet/models/dense_heads/base_dense_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d0a4469e02c469d029cc2791289dbf41554d6a53 --- /dev/null +++ b/mmdet/models/dense_heads/base_dense_head.py @@ -0,0 +1,583 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from abc import ABCMeta, abstractmethod +from inspect import signature +from typing import List, Optional, Tuple + +import torch +from mmcv.ops import batched_nms +from mmengine.config import ConfigDict +from mmengine.model import BaseModule, constant_init +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.structures import SampleList +from mmdet.structures.bbox import (cat_boxes, get_box_tensor, get_box_wh, + scale_boxes) +from mmdet.utils import InstanceList, OptMultiConfig +from ..test_time_augs import merge_aug_results +from ..utils import (filter_scores_and_topk, select_single_mlvl, + unpack_gt_instances) + + +class BaseDenseHead(BaseModule, metaclass=ABCMeta): + """Base class for DenseHeads. + + 1. The ``init_weights`` method is used to initialize densehead's + model parameters. After detector initialization, ``init_weights`` + is triggered when ``detector.init_weights()`` is called externally. + + 2. The ``loss`` method is used to calculate the loss of densehead, + which includes two steps: (1) the densehead model performs forward + propagation to obtain the feature maps (2) The ``loss_by_feat`` method + is called based on the feature maps to calculate the loss. + + .. code:: text + + loss(): forward() -> loss_by_feat() + + 3. The ``predict`` method is used to predict detection results, + which includes two steps: (1) the densehead model performs forward + propagation to obtain the feature maps (2) The ``predict_by_feat`` method + is called based on the feature maps to predict detection results including + post-processing. + + .. code:: text + + predict(): forward() -> predict_by_feat() + + 4. The ``loss_and_predict`` method is used to return loss and detection + results at the same time. It will call densehead's ``forward``, + ``loss_by_feat`` and ``predict_by_feat`` methods in order. If one-stage is + used as RPN, the densehead needs to return both losses and predictions. + This predictions is used as the proposal of roihead. + + .. code:: text + + loss_and_predict(): forward() -> loss_by_feat() -> predict_by_feat() + """ + + def __init__(self, init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + # `_raw_positive_infos` will be used in `get_positive_infos`, which + # can get positive information. 
+        self._raw_positive_infos = dict()
+
+    def init_weights(self) -> None:
+        """Initialize the weights."""
+        super().init_weights()
+        # avoid init_cfg overwriting the initialization of `conv_offset`
+        for m in self.modules():
+            # DeformConv2dPack, ModulatedDeformConv2dPack
+            if hasattr(m, 'conv_offset'):
+                constant_init(m.conv_offset, 0)
+
+    def get_positive_infos(self) -> InstanceList:
+        """Get positive information from sampling results.
+
+        Returns:
+            list[:obj:`InstanceData`]: Positive information of each image,
+            usually including positive bboxes, positive labels, positive
+            priors, etc.
+        """
+        if len(self._raw_positive_infos) == 0:
+            return None
+
+        sampling_results = self._raw_positive_infos.get(
+            'sampling_results', None)
+        assert sampling_results is not None
+        positive_infos = []
+        for sampling_result in sampling_results:
+            pos_info = InstanceData()
+            pos_info.bboxes = sampling_result.pos_gt_bboxes
+            pos_info.labels = sampling_result.pos_gt_labels
+            pos_info.priors = sampling_result.pos_priors
+            pos_info.pos_assigned_gt_inds = \
+                sampling_result.pos_assigned_gt_inds
+            pos_info.pos_inds = sampling_result.pos_inds
+            positive_infos.append(pos_info)
+        return positive_infos
+
+    def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict:
+        """Perform forward propagation and loss calculation of the detection
+        head on the features of the upstream network.
+
+        Args:
+            x (tuple[Tensor]): Features from the upstream network, each is
+                a 4D-tensor.
+            batch_data_samples (List[:obj:`DetDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+
+        Returns:
+            dict: A dictionary of loss components.
+        """
+        outs = self(x)
+
+        outputs = unpack_gt_instances(batch_data_samples)
+        (batch_gt_instances, batch_gt_instances_ignore,
+         batch_img_metas) = outputs
+
+        loss_inputs = outs + (batch_gt_instances, batch_img_metas,
+                              batch_gt_instances_ignore)
+        losses = self.loss_by_feat(*loss_inputs)
+        return losses
+
+    @abstractmethod
+    def loss_by_feat(self, **kwargs) -> dict:
+        """Calculate the loss based on the features extracted by the detection
+        head."""
+        pass
+
+    def loss_and_predict(
+        self,
+        x: Tuple[Tensor],
+        batch_data_samples: SampleList,
+        proposal_cfg: Optional[ConfigDict] = None
+    ) -> Tuple[dict, InstanceList]:
+        """Perform forward propagation of the head, then calculate loss and
+        predictions from the features and data samples.
+
+        Args:
+            x (tuple[Tensor]): Features from FPN.
+            batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
+                the meta information of each image and corresponding
+                annotations.
+            proposal_cfg (ConfigDict, optional): Test / postprocessing
+                configuration, if None, test_cfg would be used.
+                Defaults to None.
+
+        Returns:
+            tuple: the return value is a tuple that contains:
+
+            - losses (dict[str, Tensor]): A dictionary of loss components.
+            - predictions (list[:obj:`InstanceData`]): Detection
+              results of each image after the post process.
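+
+        Example (a rough usage sketch for the one-stage-as-RPN case;
+        ``head``, ``feats``, ``data_samples`` and ``proposal_cfg`` are
+        assumed to exist):
+
+            >>> losses, proposals = head.loss_and_predict(
+            ...     feats, data_samples, proposal_cfg=proposal_cfg)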
+ """ + outputs = unpack_gt_instances(batch_data_samples) + (batch_gt_instances, batch_gt_instances_ignore, + batch_img_metas) = outputs + + outs = self(x) + + loss_inputs = outs + (batch_gt_instances, batch_img_metas, + batch_gt_instances_ignore) + losses = self.loss_by_feat(*loss_inputs) + + predictions = self.predict_by_feat( + *outs, batch_img_metas=batch_img_metas, cfg=proposal_cfg) + return losses, predictions + + def predict(self, + x: Tuple[Tensor], + batch_data_samples: SampleList, + rescale: bool = False) -> InstanceList: + """Perform forward propagation of the detection head and predict + detection results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[obj:`InstanceData`]: Detection results of each image + after the post process. + """ + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + + outs = self(x) + + predictions = self.predict_by_feat( + *outs, batch_img_metas=batch_img_metas, rescale=rescale) + return predictions + + def predict_by_feat(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + score_factors: Optional[List[Tensor]] = None, + batch_img_metas: Optional[List[dict]] = None, + cfg: Optional[ConfigDict] = None, + rescale: bool = False, + with_nms: bool = True) -> InstanceList: + """Transform a batch of output features extracted from the head into + bbox results. + + Note: When score_factors is not None, the cls_scores are + usually multiplied by it then obtain the real score used in NMS, + such as CenterNess in FCOS, IoU branch in ATSS. + + Args: + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + score_factors (list[Tensor], optional): Score factor for + all scale level, each is a 4D-tensor, has shape + (batch_size, num_priors * 1, H, W). Defaults to None. + batch_img_metas (list[dict], Optional): Batch image meta info. + Defaults to None. + cfg (ConfigDict, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + list[:obj:`InstanceData`]: Object detection results of each image + after the post process. Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + assert len(cls_scores) == len(bbox_preds) + + if score_factors is None: + # e.g. Retina, FreeAnchor, Foveabox, etc. + with_score_factors = False + else: + # e.g. FCOS, PAA, ATSS, AutoAssign, etc. 
+ with_score_factors = True + assert len(cls_scores) == len(score_factors) + + num_levels = len(cls_scores) + + featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] + mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, + dtype=cls_scores[0].dtype, + device=cls_scores[0].device) + + result_list = [] + + for img_id in range(len(batch_img_metas)): + img_meta = batch_img_metas[img_id] + cls_score_list = select_single_mlvl( + cls_scores, img_id, detach=True) + bbox_pred_list = select_single_mlvl( + bbox_preds, img_id, detach=True) + if with_score_factors: + score_factor_list = select_single_mlvl( + score_factors, img_id, detach=True) + else: + score_factor_list = [None for _ in range(num_levels)] + + results = self._predict_by_feat_single( + cls_score_list=cls_score_list, + bbox_pred_list=bbox_pred_list, + score_factor_list=score_factor_list, + mlvl_priors=mlvl_priors, + img_meta=img_meta, + cfg=cfg, + rescale=rescale, + with_nms=with_nms) + result_list.append(results) + return result_list + + def _predict_by_feat_single(self, + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + score_factor_list: List[Tensor], + mlvl_priors: List[Tensor], + img_meta: dict, + cfg: ConfigDict, + rescale: bool = False, + with_nms: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image, each item has shape + (num_priors * 1, H, W). + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid. In all + anchor-based methods, it has shape (num_priors, 4). In + all anchor-free methods, it has shape (num_priors, 2) + when `with_stride=True`, otherwise it still has shape + (num_priors, 4). + img_meta (dict): Image meta info. + cfg (mmengine.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + if score_factor_list[0] is None: + # e.g. Retina, FreeAnchor, etc. + with_score_factors = False + else: + # e.g. FCOS, PAA, ATSS, etc. 
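+            # Score factors arrive here as raw logits; they are mapped
+            # through sigmoid below before being attached to ``results``.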
+ with_score_factors = True + + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + img_shape = img_meta['img_shape'] + nms_pre = cfg.get('nms_pre', -1) + + mlvl_bbox_preds = [] + mlvl_valid_priors = [] + mlvl_scores = [] + mlvl_labels = [] + if with_score_factors: + mlvl_score_factors = [] + else: + mlvl_score_factors = None + for level_idx, (cls_score, bbox_pred, score_factor, priors) in \ + enumerate(zip(cls_score_list, bbox_pred_list, + score_factor_list, mlvl_priors)): + + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + + dim = self.bbox_coder.encode_size + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim) + if with_score_factors: + score_factor = score_factor.permute(1, 2, + 0).reshape(-1).sigmoid() + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + + # the `custom_cls_channels` parameter is derived from + # CrossEntropyCustomLoss and FocalCustomLoss, and is currently used + # in v3det. + if getattr(self.loss_cls, 'custom_cls_channels', False): + scores = self.loss_cls.get_activation(cls_score) + elif self.use_sigmoid_cls: + scores = cls_score.sigmoid() + else: + # remind that we set FG labels to [0, num_class-1] + # since mmdet v2.0 + # BG cat_id: num_class + scores = cls_score.softmax(-1)[:, :-1] + + # After https://github.com/open-mmlab/mmdetection/pull/6268/, + # this operation keeps fewer bboxes under the same `nms_pre`. + # There is no difference in performance for most models. If you + # find a slight drop in performance, you can set a larger + # `nms_pre` than before. + score_thr = cfg.get('score_thr', 0) + + results = filter_scores_and_topk( + scores, score_thr, nms_pre, + dict(bbox_pred=bbox_pred, priors=priors)) + scores, labels, keep_idxs, filtered_results = results + + bbox_pred = filtered_results['bbox_pred'] + priors = filtered_results['priors'] + + if with_score_factors: + score_factor = score_factor[keep_idxs] + + mlvl_bbox_preds.append(bbox_pred) + mlvl_valid_priors.append(priors) + mlvl_scores.append(scores) + mlvl_labels.append(labels) + + if with_score_factors: + mlvl_score_factors.append(score_factor) + + bbox_pred = torch.cat(mlvl_bbox_preds) + priors = cat_boxes(mlvl_valid_priors) + bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape) + + results = InstanceData() + results.bboxes = bboxes + results.scores = torch.cat(mlvl_scores) + results.labels = torch.cat(mlvl_labels) + if with_score_factors: + results.score_factors = torch.cat(mlvl_score_factors) + + return self._bbox_post_process( + results=results, + cfg=cfg, + rescale=rescale, + with_nms=with_nms, + img_meta=img_meta) + + def _bbox_post_process(self, + results: InstanceData, + cfg: ConfigDict, + rescale: bool = False, + with_nms: bool = True, + img_meta: Optional[dict] = None) -> InstanceData: + """bbox post-processing method. + + The boxes would be rescaled to the original image scale and do + the nms operation. Usually `with_nms` is False is used for aug test. + + Args: + results (:obj:`InstaceData`): Detection instance results, + each item has shape (num_bboxes, ). + cfg (ConfigDict): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default to False. + with_nms (bool): If True, do nms before return boxes. + Default to True. + img_meta (dict, optional): Image meta info. Defaults to None. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. 
+ + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + if rescale: + assert img_meta.get('scale_factor') is not None + scale_factor = [1 / s for s in img_meta['scale_factor']] + results.bboxes = scale_boxes(results.bboxes, scale_factor) + + if hasattr(results, 'score_factors'): + # TODO: Add sqrt operation in order to be consistent with + # the paper. + score_factors = results.pop('score_factors') + results.scores = results.scores * score_factors + + # filter small size bboxes + if cfg.get('min_bbox_size', -1) >= 0: + w, h = get_box_wh(results.bboxes) + valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size) + if not valid_mask.all(): + results = results[valid_mask] + + # TODO: deal with `with_nms` and `nms_cfg=None` in test_cfg + if with_nms and results.bboxes.numel() > 0: + bboxes = get_box_tensor(results.bboxes) + det_bboxes, keep_idxs = batched_nms(bboxes, results.scores, + results.labels, cfg.nms) + results = results[keep_idxs] + # some nms would reweight the score, such as softnms + results.scores = det_bboxes[:, -1] + results = results[:cfg.max_per_img] + + return results + + def aug_test(self, + aug_batch_feats, + aug_batch_img_metas, + rescale=False, + with_ori_nms=False, + **kwargs): + """Test function with test time augmentation. + + Args: + aug_batch_feats (list[tuple[Tensor]]): The outer list + indicates test-time augmentations and inner tuple + indicate the multi-level feats from + FPN, each Tensor should have a shape (B, C, H, W), + aug_batch_img_metas (list[list[dict]]): Meta information + of images under the different test-time augs + (multiscale, flip, etc.). The outer list indicate + the + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + with_ori_nms (bool): Whether execute the nms in original head. + Defaults to False. It will be `True` when the head is + adopted as `rpn_head`. + + Returns: + list(obj:`InstanceData`): Detection results of the + input images. Each item usually contains\ + following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance,) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances,). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
+ """ + # TODO: remove this for detr and deformdetr + sig_of_get_results = signature(self.get_results) + get_results_args = [ + p.name for p in sig_of_get_results.parameters.values() + ] + get_results_single_sig = signature(self._get_results_single) + get_results_single_sig_args = [ + p.name for p in get_results_single_sig.parameters.values() + ] + assert ('with_nms' in get_results_args) and \ + ('with_nms' in get_results_single_sig_args), \ + f'{self.__class__.__name__}' \ + 'does not support test-time augmentation ' + + num_imgs = len(aug_batch_img_metas[0]) + aug_batch_results = [] + for x, img_metas in zip(aug_batch_feats, aug_batch_img_metas): + outs = self.forward(x) + batch_instance_results = self.get_results( + *outs, + img_metas=img_metas, + cfg=self.test_cfg, + rescale=False, + with_nms=with_ori_nms, + **kwargs) + aug_batch_results.append(batch_instance_results) + + # after merging, bboxes will be rescaled to the original image + batch_results = merge_aug_results(aug_batch_results, + aug_batch_img_metas) + + final_results = [] + for img_id in range(num_imgs): + results = batch_results[img_id] + det_bboxes, keep_idxs = batched_nms(results.bboxes, results.scores, + results.labels, + self.test_cfg.nms) + results = results[keep_idxs] + # some nms operation may reweight the score such as softnms + results.scores = det_bboxes[:, -1] + results = results[:self.test_cfg.max_per_img] + if rescale: + # all results have been mapped to the original scale + # in `merge_aug_results`, so just pass + pass + else: + # map to the first aug image scale + scale_factor = results.bboxes.new_tensor( + aug_batch_img_metas[0][img_id]['scale_factor']) + results.bboxes = \ + results.bboxes * scale_factor + + final_results.append(results) + + return final_results diff --git a/mmdet/models/dense_heads/base_mask_head.py b/mmdet/models/dense_heads/base_mask_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7183d782829aa15bf12b9e2f7ade999c84d0593f --- /dev/null +++ b/mmdet/models/dense_heads/base_mask_head.py @@ -0,0 +1,128 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List, Tuple, Union + +from mmengine.model import BaseModule +from torch import Tensor + +from mmdet.structures import SampleList +from mmdet.utils import InstanceList, OptInstanceList, OptMultiConfig +from ..utils import unpack_gt_instances + + +class BaseMaskHead(BaseModule, metaclass=ABCMeta): + """Base class for mask heads used in One-Stage Instance Segmentation.""" + + def __init__(self, init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + + @abstractmethod + def loss_by_feat(self, *args, **kwargs): + """Calculate the loss based on the features extracted by the mask + head.""" + pass + + @abstractmethod + def predict_by_feat(self, *args, **kwargs): + """Transform a batch of output features extracted from the head into + mask results.""" + pass + + def loss(self, + x: Union[List[Tensor], Tuple[Tensor]], + batch_data_samples: SampleList, + positive_infos: OptInstanceList = None, + **kwargs) -> dict: + """Perform forward propagation and loss calculation of the mask head on + the features of the upstream network. + + Args: + x (list[Tensor] | tuple[Tensor]): Features from FPN. + Each has a shape (B, C, H, W). + batch_data_samples (list[:obj:`DetDataSample`]): Each item contains + the meta information of each image and corresponding + annotations. 
+ positive_infos (list[:obj:`InstanceData`], optional): Information + of positive samples. Used when the label assignment is + done outside the MaskHead, e.g., BboxHead in + YOLACT or CondInst, etc. When the label assignment is done in + MaskHead, it would be None, like SOLO or SOLOv2. All values + in it should have shape (num_positive_samples, *). + + + Returns: + dict: A dictionary of loss components. + """ + if positive_infos is None: + outs = self(x) + else: + outs = self(x, positive_infos) + + assert isinstance(outs, tuple), 'Forward results should be a tuple, ' \ + 'even if only one item is returned' + + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \ + = outputs + for gt_instances, img_metas in zip(batch_gt_instances, + batch_img_metas): + img_shape = img_metas['batch_input_shape'] + gt_masks = gt_instances.masks.pad(img_shape) + gt_instances.masks = gt_masks + + losses = self.loss_by_feat( + *outs, + batch_gt_instances=batch_gt_instances, + batch_img_metas=batch_img_metas, + positive_infos=positive_infos, + batch_gt_instances_ignore=batch_gt_instances_ignore, + **kwargs) + return losses + + def predict(self, + x: Tuple[Tensor], + batch_data_samples: SampleList, + rescale: bool = False, + results_list: OptInstanceList = None, + **kwargs) -> InstanceList: + """Test function without test-time augmentation. + + Args: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + results_list (list[obj:`InstanceData`], optional): Detection + results of each image after the post process. Only exist + if there is a `bbox_head`, like `YOLACT`, `CondInst`, etc. + + Returns: + list[obj:`InstanceData`]: Instance segmentation + results of each image after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance,) + - labels (Tensor): Has a shape (num_instances,). + - masks (Tensor): Processed mask results, has a + shape (num_instances, h, w). + """ + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + if results_list is None: + outs = self(x) + else: + outs = self(x, results_list) + + results_list = self.predict_by_feat( + *outs, + batch_img_metas=batch_img_metas, + rescale=rescale, + results_list=results_list, + **kwargs) + + return results_list diff --git a/mmdet/models/dense_heads/boxinst_head.py b/mmdet/models/dense_heads/boxinst_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7d6e8f7777a852cad89b709e59af2d8e12b343a6 --- /dev/null +++ b/mmdet/models/dense_heads/boxinst_head.py @@ -0,0 +1,252 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
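+# BoxInst supervises masks from box annotations via a projection term
+# plus a color-pairwise term. A sketch of the log-space identity used in
+# ``get_pairwise_affinity`` below for the probability that two
+# neighboring pixels receive the same label:
+#   P(same) = p_i * p_j + (1 - p_i) * (1 - p_j)
+#   log P(same) = logsumexp(log p_i + log p_j, log(1 - p_i) + log(1 - p_j))
+# Evaluating it through logsumexp avoids underflow when both products
+# are tiny.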
+from typing import List + +import torch +import torch.nn.functional as F +from mmengine import MessageHub +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import InstanceList +from ..utils.misc import unfold_wo_center +from .condinst_head import CondInstBboxHead, CondInstMaskHead + + +@MODELS.register_module() +class BoxInstBboxHead(CondInstBboxHead): + """BoxInst box head used in https://arxiv.org/abs/2012.02310.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + +@MODELS.register_module() +class BoxInstMaskHead(CondInstMaskHead): + """BoxInst mask head used in https://arxiv.org/abs/2012.02310. + + This head outputs the mask for BoxInst. + + Args: + pairwise_size (dict): The size of neighborhood for each pixel. + Defaults to 3. + pairwise_dilation (int): The dilation of neighborhood for each pixel. + Defaults to 2. + warmup_iters (int): Warmup iterations for pair-wise loss. + Defaults to 10000. + """ + + def __init__(self, + *arg, + pairwise_size: int = 3, + pairwise_dilation: int = 2, + warmup_iters: int = 10000, + **kwargs) -> None: + self.pairwise_size = pairwise_size + self.pairwise_dilation = pairwise_dilation + self.warmup_iters = warmup_iters + super().__init__(*arg, **kwargs) + + def get_pairwise_affinity(self, mask_logits: Tensor) -> Tensor: + """Compute the pairwise affinity for each pixel.""" + log_fg_prob = F.logsigmoid(mask_logits).unsqueeze(1) + log_bg_prob = F.logsigmoid(-mask_logits).unsqueeze(1) + + log_fg_prob_unfold = unfold_wo_center( + log_fg_prob, + kernel_size=self.pairwise_size, + dilation=self.pairwise_dilation) + log_bg_prob_unfold = unfold_wo_center( + log_bg_prob, + kernel_size=self.pairwise_size, + dilation=self.pairwise_dilation) + + # the probability of making the same prediction: + # p_i * p_j + (1 - p_i) * (1 - p_j) + # we compute the the probability in log space + # to avoid numerical instability + log_same_fg_prob = log_fg_prob[:, :, None] + log_fg_prob_unfold + log_same_bg_prob = log_bg_prob[:, :, None] + log_bg_prob_unfold + + # TODO: Figure out the difference between it and directly sum + max_ = torch.max(log_same_fg_prob, log_same_bg_prob) + log_same_prob = torch.log( + torch.exp(log_same_fg_prob - max_) + + torch.exp(log_same_bg_prob - max_)) + max_ + + return -log_same_prob[:, 0] + + def loss_by_feat(self, mask_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], positive_infos: InstanceList, + **kwargs) -> dict: + """Calculate the loss based on the features extracted by the mask head. + + Args: + mask_preds (list[Tensor]): List of predicted masks, each has + shape (num_classes, H, W). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``masks``, + and ``labels`` attributes. + batch_img_metas (list[dict]): Meta information of multiple images. + positive_infos (List[:obj:``InstanceData``]): Information of + positive samples of each image that are assigned in detection + head. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + assert positive_infos is not None, \ + 'positive_infos should not be None in `BoxInstMaskHead`' + losses = dict() + + loss_mask_project = 0. + loss_mask_pairwise = 0. + num_imgs = len(mask_preds) + total_pos = 0. + avg_fatcor = 0. 
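+
+        # The projection term below compares x/y max-projections of the
+        # predicted mask with those of the box mask; the pairwise term
+        # is normalized by the count of valid neighbor pairs accumulated
+        # in ``avg_fatcor``.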
+ + for idx in range(num_imgs): + (mask_pred, pos_mask_targets, pos_pairwise_masks, num_pos) = \ + self._get_targets_single( + mask_preds[idx], batch_gt_instances[idx], + positive_infos[idx]) + # mask loss + total_pos += num_pos + if num_pos == 0 or pos_mask_targets is None: + loss_project = mask_pred.new_zeros(1).mean() + loss_pairwise = mask_pred.new_zeros(1).mean() + avg_fatcor += 0. + else: + # compute the project term + loss_project_x = self.loss_mask( + mask_pred.max(dim=1, keepdim=True)[0], + pos_mask_targets.max(dim=1, keepdim=True)[0], + reduction_override='none').sum() + loss_project_y = self.loss_mask( + mask_pred.max(dim=2, keepdim=True)[0], + pos_mask_targets.max(dim=2, keepdim=True)[0], + reduction_override='none').sum() + loss_project = loss_project_x + loss_project_y + # compute the pairwise term + pairwise_affinity = self.get_pairwise_affinity(mask_pred) + avg_fatcor += pos_pairwise_masks.sum().clamp(min=1.0) + loss_pairwise = (pairwise_affinity * pos_pairwise_masks).sum() + + loss_mask_project += loss_project + loss_mask_pairwise += loss_pairwise + + if total_pos == 0: + total_pos += 1 # avoid nan + if avg_fatcor == 0: + avg_fatcor += 1 # avoid nan + loss_mask_project = loss_mask_project / total_pos + loss_mask_pairwise = loss_mask_pairwise / avg_fatcor + message_hub = MessageHub.get_current_instance() + iter = message_hub.get_info('iter') + warmup_factor = min(iter / float(self.warmup_iters), 1.0) + loss_mask_pairwise *= warmup_factor + + losses.update( + loss_mask_project=loss_mask_project, + loss_mask_pairwise=loss_mask_pairwise) + return losses + + def _get_targets_single(self, mask_preds: Tensor, + gt_instances: InstanceData, + positive_info: InstanceData): + """Compute targets for predictions of single image. + + Args: + mask_preds (Tensor): Predicted prototypes with shape + (num_classes, H, W). + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes``, ``labels``, + and ``masks`` attributes. + positive_info (:obj:`InstanceData`): Information of positive + samples that are assigned in detection head. It usually + contains following keys. + + - pos_assigned_gt_inds (Tensor): Assigner GT indexes of + positive proposals, has shape (num_pos, ) + - pos_inds (Tensor): Positive index of image, has + shape (num_pos, ). + - param_pred (Tensor): Positive param preditions + with shape (num_pos, num_params). + + Returns: + tuple: Usually returns a tuple containing learning targets. + + - mask_preds (Tensor): Positive predicted mask with shape + (num_pos, mask_h, mask_w). + - pos_mask_targets (Tensor): Positive mask targets with shape + (num_pos, mask_h, mask_w). + - pos_pairwise_masks (Tensor): Positive pairwise masks with + shape: (num_pos, num_neighborhood, mask_h, mask_w). + - num_pos (int): Positive numbers. 
+ """ + gt_bboxes = gt_instances.bboxes + device = gt_bboxes.device + # Note that gt_masks are generated by full box + # from BoxInstDataPreprocessor + gt_masks = gt_instances.masks.to_tensor( + dtype=torch.bool, device=device).float() + # Note that pairwise_masks are generated by image color similarity + # from BoxInstDataPreprocessor + pairwise_masks = gt_instances.pairwise_masks + pairwise_masks = pairwise_masks.to(device=device) + + # process with mask targets + pos_assigned_gt_inds = positive_info.get('pos_assigned_gt_inds') + scores = positive_info.get('scores') + centernesses = positive_info.get('centernesses') + num_pos = pos_assigned_gt_inds.size(0) + + if gt_masks.size(0) == 0 or num_pos == 0: + return mask_preds, None, None, 0 + # Since we're producing (near) full image masks, + # it'd take too much vram to backprop on every single mask. + # Thus we select only a subset. + if (self.max_masks_to_train != -1) and \ + (num_pos > self.max_masks_to_train): + perm = torch.randperm(num_pos) + select = perm[:self.max_masks_to_train] + mask_preds = mask_preds[select] + pos_assigned_gt_inds = pos_assigned_gt_inds[select] + num_pos = self.max_masks_to_train + elif self.topk_masks_per_img != -1: + unique_gt_inds = pos_assigned_gt_inds.unique() + num_inst_per_gt = max( + int(self.topk_masks_per_img / len(unique_gt_inds)), 1) + + keep_mask_preds = [] + keep_pos_assigned_gt_inds = [] + for gt_ind in unique_gt_inds: + per_inst_pos_inds = (pos_assigned_gt_inds == gt_ind) + mask_preds_per_inst = mask_preds[per_inst_pos_inds] + gt_inds_per_inst = pos_assigned_gt_inds[per_inst_pos_inds] + if sum(per_inst_pos_inds) > num_inst_per_gt: + per_inst_scores = scores[per_inst_pos_inds].sigmoid().max( + dim=1)[0] + per_inst_centerness = centernesses[ + per_inst_pos_inds].sigmoid().reshape(-1, ) + select = (per_inst_scores * per_inst_centerness).topk( + k=num_inst_per_gt, dim=0)[1] + mask_preds_per_inst = mask_preds_per_inst[select] + gt_inds_per_inst = gt_inds_per_inst[select] + keep_mask_preds.append(mask_preds_per_inst) + keep_pos_assigned_gt_inds.append(gt_inds_per_inst) + mask_preds = torch.cat(keep_mask_preds) + pos_assigned_gt_inds = torch.cat(keep_pos_assigned_gt_inds) + num_pos = pos_assigned_gt_inds.size(0) + + # Follow the origin implement + start = int(self.mask_out_stride // 2) + gt_masks = gt_masks[:, start::self.mask_out_stride, + start::self.mask_out_stride] + gt_masks = gt_masks.gt(0.5).float() + pos_mask_targets = gt_masks[pos_assigned_gt_inds] + pos_pairwise_masks = pairwise_masks[pos_assigned_gt_inds] + pos_pairwise_masks = pos_pairwise_masks * pos_mask_targets.unsqueeze(1) + + return (mask_preds, pos_mask_targets, pos_pairwise_masks, num_pos) diff --git a/mmdet/models/dense_heads/cascade_rpn_head.py b/mmdet/models/dense_heads/cascade_rpn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a8686cc2c9118094df34a04fdeabd87daa636707 --- /dev/null +++ b/mmdet/models/dense_heads/cascade_rpn_head.py @@ -0,0 +1,1110 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from __future__ import division
+import copy
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from mmcv.ops import DeformConv2d
+from mmengine.config import ConfigDict
+from mmengine.model import BaseModule, ModuleList
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmdet.registry import MODELS, TASK_UTILS
+from mmdet.structures import SampleList
+from mmdet.utils import (ConfigType, InstanceList, MultiConfig,
+                         OptInstanceList, OptMultiConfig)
+from ..task_modules.assigners import RegionAssigner
+from ..task_modules.samplers import PseudoSampler
+from ..utils import (images_to_levels, multi_apply, select_single_mlvl,
+                     unpack_gt_instances)
+from .base_dense_head import BaseDenseHead
+from .rpn_head import RPNHead
+
+
+class AdaptiveConv(BaseModule):
+    """AdaptiveConv used to adapt the sampling location with the anchors.
+
+    Args:
+        in_channels (int): Number of channels in the input feature map.
+        out_channels (int): Number of channels produced by the convolution.
+        kernel_size (int or tuple[int]): Size of the conv kernel.
+            Defaults to 3.
+        stride (int or tuple[int]): Stride of the convolution. Defaults to 1.
+        padding (int or tuple[int]): Zero-padding added to both sides of
+            the input. Defaults to 1.
+        dilation (int or tuple[int]): Spacing between kernel elements.
+            Defaults to 3.
+        groups (int): Number of blocked connections from input channels to
+            output channels. Defaults to 1.
+        bias (bool): If set True, adds a learnable bias to the output.
+            Defaults to False.
+        adapt_type (str): Type of adaptive conv, can be either ``offset``
+            (arbitrary anchors) or ``dilation`` (uniform anchors).
+            Defaults to ``dilation``.
+        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or \
+            list[dict]): Initialization config dict.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int]] = 3,
+        stride: Union[int, Tuple[int]] = 1,
+        padding: Union[int, Tuple[int]] = 1,
+        dilation: Union[int, Tuple[int]] = 3,
+        groups: int = 1,
+        bias: bool = False,
+        adapt_type: str = 'dilation',
+        init_cfg: MultiConfig = dict(
+            type='Normal', std=0.01, override=dict(name='conv'))
+    ) -> None:
+        super().__init__(init_cfg=init_cfg)
+        assert adapt_type in ['offset', 'dilation']
+        self.adapt_type = adapt_type
+
+        assert kernel_size == 3, 'Adaptive conv only supports kernel_size 3'
+        if self.adapt_type == 'offset':
+            assert stride == 1 and padding == 1 and groups == 1, \
+                'Adaptive conv offset mode only supports padding: 1, ' \
+                'stride: 1, groups: 1'
+            self.conv = DeformConv2d(
+                in_channels,
+                out_channels,
+                kernel_size,
+                padding=padding,
+                stride=stride,
+                groups=groups,
+                bias=bias)
+        else:
+            self.conv = nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size,
+                padding=dilation,
+                dilation=dilation)
+
+    def forward(self, x: Tensor, offset: Tensor) -> Tensor:
+        """Forward function."""
+        if self.adapt_type == 'offset':
+            N, _, H, W = x.shape
+            assert offset is not None
+            assert H * W == offset.shape[1]
+            # reshape [N, NA, 18] to (N, 18, H, W)
+            offset = offset.permute(0, 2, 1).reshape(N, -1, H, W)
+            offset = offset.contiguous()
+            x = self.conv(x, offset)
+        else:
+            assert offset is None
+            x = self.conv(x)
+        return x
+
+
+@MODELS.register_module()
+class StageCascadeRPNHead(RPNHead):
+    """Stage of CascadeRPNHead.
+
+    Args:
+        in_channels (int): Number of channels in the input feature map.
+        anchor_generator (:obj:`ConfigDict` or dict): anchor generator config.
+        adapt_cfg (:obj:`ConfigDict` or dict): adaptation config.
+        bridged_feature (bool): whether to update the rpn feature.
+            Defaults to False.
+        with_cls (bool): whether to use the classification branch.
+            Defaults to True.
+        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
+            list[dict], optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 in_channels: int,
+                 anchor_generator: ConfigType = dict(
+                     type='AnchorGenerator',
+                     scales=[8],
+                     ratios=[1.0],
+                     strides=[4, 8, 16, 32, 64]),
+                 adapt_cfg: ConfigType = dict(type='dilation', dilation=3),
+                 bridged_feature: bool = False,
+                 with_cls: bool = True,
+                 init_cfg: OptMultiConfig = None,
+                 **kwargs) -> None:
+        self.with_cls = with_cls
+        self.anchor_strides = anchor_generator['strides']
+        self.anchor_scales = anchor_generator['scales']
+        self.bridged_feature = bridged_feature
+        self.adapt_cfg = adapt_cfg
+        super().__init__(
+            in_channels=in_channels,
+            anchor_generator=anchor_generator,
+            init_cfg=init_cfg,
+            **kwargs)
+
+        # override sampling and sampler
+        if self.train_cfg:
+            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
+            # fall back to PseudoSampler when no sampler config is given
+            if self.train_cfg.get('sampler', None) is not None:
+                self.sampler = TASK_UTILS.build(
+                    self.train_cfg['sampler'], default_args=dict(context=self))
+            else:
+                self.sampler = PseudoSampler(context=self)
+
+        if init_cfg is None:
+            self.init_cfg = dict(
+                type='Normal', std=0.01, override=[dict(name='rpn_reg')])
+            if self.with_cls:
+                self.init_cfg['override'].append(dict(name='rpn_cls'))
+
+    def _init_layers(self) -> None:
+        """Init layers of a CascadeRPN stage."""
+        adapt_cfg = copy.deepcopy(self.adapt_cfg)
+        adapt_cfg['adapt_type'] = adapt_cfg.pop('type')
+        self.rpn_conv = AdaptiveConv(self.in_channels, self.feat_channels,
+                                     **adapt_cfg)
+        if self.with_cls:
+            self.rpn_cls = nn.Conv2d(self.feat_channels,
+                                     self.num_anchors * self.cls_out_channels,
+                                     1)
+        self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward_single(self, x: Tensor, offset: Tensor) -> Tuple[Tensor]:
+        """Forward function of a single scale level."""
+        bridged_x = x
+        x = self.relu(self.rpn_conv(x, offset))
+        if self.bridged_feature:
+            bridged_x = x  # update feature
+        cls_score = self.rpn_cls(x) if self.with_cls else None
+        bbox_pred = self.rpn_reg(x)
+        return bridged_x, cls_score, bbox_pred
+
+    def forward(
+        self,
+        feats: List[Tensor],
+        offset_list: Optional[List[Tensor]] = None) -> Tuple[List[Tensor]]:
+        """Forward function."""
+        if offset_list is None:
+            offset_list = [None for _ in range(len(feats))]
+        return multi_apply(self.forward_single, feats, offset_list)
+
+    def _region_targets_single(self, flat_anchors: Tensor, valid_flags: Tensor,
+                               gt_instances: InstanceData, img_meta: dict,
+                               gt_instances_ignore: InstanceData,
+                               featmap_sizes: List[Tuple[int, int]],
+                               num_level_anchors: List[int]) -> tuple:
+        """Get anchor targets based on region for a single level.
+
+        Args:
+            flat_anchors (Tensor): Multi-level anchors of the image, which are
+                concatenated into a single tensor of shape (num_anchors, 4)
+            valid_flags (Tensor): Multi level valid flags of the image,
+                which are concatenated into a single tensor of
+                shape (num_anchors, ).
+            gt_instances (:obj:`InstanceData`): Ground truth of instance
+                annotations. It should include ``bboxes`` and ``labels``
+                attributes.
+            img_meta (dict): Meta information for current image.
+            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
+                to be ignored during training. It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
+                Defaults to None.
+            featmap_sizes (list[Tuple[int, int]]): Feature map size of each
+                level.
+            num_level_anchors (list[int]): The number of anchors in each level.
+
+        Returns:
+            tuple:
+
+            - labels (Tensor): Labels of each level.
+            - label_weights (Tensor): Label weights of each level.
+            - bbox_targets (Tensor): BBox targets of each level.
+            - bbox_weights (Tensor): BBox weights of each level.
+            - pos_inds (Tensor): Indexes of positive samples.
+            - neg_inds (Tensor): Indexes of negative samples.
+            - sampling_result (:obj:`SamplingResult`): Sampling results.
+        """
+        pred_instances = InstanceData()
+        pred_instances.priors = flat_anchors
+        pred_instances.valid_flags = valid_flags
+
+        assign_result = self.assigner.assign(
+            pred_instances,
+            gt_instances,
+            img_meta,
+            featmap_sizes,
+            num_level_anchors,
+            self.anchor_scales[0],
+            self.anchor_strides,
+            gt_instances_ignore=gt_instances_ignore,
+            allowed_border=self.train_cfg['allowed_border'])
+        sampling_result = self.sampler.sample(assign_result, pred_instances,
+                                              gt_instances)
+
+        num_anchors = flat_anchors.shape[0]
+        bbox_targets = torch.zeros_like(flat_anchors)
+        bbox_weights = torch.zeros_like(flat_anchors)
+        labels = flat_anchors.new_zeros(num_anchors, dtype=torch.long)
+        label_weights = flat_anchors.new_zeros(num_anchors, dtype=torch.float)
+
+        pos_inds = sampling_result.pos_inds
+        neg_inds = sampling_result.neg_inds
+        if len(pos_inds) > 0:
+            if not self.reg_decoded_bbox:
+                pos_bbox_targets = self.bbox_coder.encode(
+                    sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
+            else:
+                pos_bbox_targets = sampling_result.pos_gt_bboxes
+            bbox_targets[pos_inds, :] = pos_bbox_targets
+            bbox_weights[pos_inds, :] = 1.0
+            labels[pos_inds] = sampling_result.pos_gt_labels
+            if self.train_cfg['pos_weight'] <= 0:
+                label_weights[pos_inds] = 1.0
+            else:
+                label_weights[pos_inds] = self.train_cfg['pos_weight']
+        if len(neg_inds) > 0:
+            label_weights[neg_inds] = 1.0
+
+        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
+                neg_inds, sampling_result)
+
+    def region_targets(
+        self,
+        anchor_list: List[List[Tensor]],
+        valid_flag_list: List[List[Tensor]],
+        featmap_sizes: List[Tuple[int, int]],
+        batch_gt_instances: InstanceList,
+        batch_img_metas: List[dict],
+        batch_gt_instances_ignore: OptInstanceList = None,
+        return_sampling_results: bool = False,
+    ) -> tuple:
+        """Compute regression and classification targets for anchors when
+        using RegionAssigner.
+
+        Args:
+            anchor_list (list[list[Tensor]]): Multi level anchors of each
+                image.
+            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
+                each image.
+            featmap_sizes (list[Tuple[int, int]]): Feature map size of each
+                level.
+            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+                gt_instance. It usually includes ``bboxes`` and ``labels``
+                attributes.
+            batch_img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
+                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
+                Defaults to None.
+            return_sampling_results (bool): Whether to return the sampling
+                results. Defaults to False.
+
+        Returns:
+            tuple:
+
+            - labels_list (list[Tensor]): Labels of each level.
+            - label_weights_list (list[Tensor]): Label weights of each
+              level.
+            - bbox_targets_list (list[Tensor]): BBox targets of each level.
+            - bbox_weights_list (list[Tensor]): BBox weights of each level.
+            - avg_factor (int): Average factor that is used to average
+              the loss. When using sampling method, avg_factor is usually
+              the sum of positive and negative priors. When using
+              ``PseudoSampler``, ``avg_factor`` is usually equal to the
+              number of positive priors.
+        """
+        num_imgs = len(batch_img_metas)
+        assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+        if batch_gt_instances_ignore is None:
+            batch_gt_instances_ignore = [None] * num_imgs
+
+        # anchor number of multi levels
+        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+        # concat all level anchors to a single tensor
+        concat_anchor_list = []
+        concat_valid_flag_list = []
+        for i in range(num_imgs):
+            assert len(anchor_list[i]) == len(valid_flag_list[i])
+            concat_anchor_list.append(torch.cat(anchor_list[i]))
+            concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))
+
+        # compute targets for each image
+        (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
+         pos_inds_list, neg_inds_list, sampling_results_list) = multi_apply(
+             self._region_targets_single,
+             concat_anchor_list,
+             concat_valid_flag_list,
+             batch_gt_instances,
+             batch_img_metas,
+             batch_gt_instances_ignore,
+             featmap_sizes=featmap_sizes,
+             num_level_anchors=num_level_anchors)
+        # no valid anchors
+        if any([labels is None for labels in all_labels]):
+            return None
+        # sampled anchors of all images
+        avg_factor = sum(
+            [results.avg_factor for results in sampling_results_list])
+        # split targets to a list w.r.t. multiple levels
+        labels_list = images_to_levels(all_labels, num_level_anchors)
+        label_weights_list = images_to_levels(all_label_weights,
+                                              num_level_anchors)
+        bbox_targets_list = images_to_levels(all_bbox_targets,
+                                             num_level_anchors)
+        bbox_weights_list = images_to_levels(all_bbox_weights,
+                                             num_level_anchors)
+        res = (labels_list, label_weights_list, bbox_targets_list,
+               bbox_weights_list, avg_factor)
+        if return_sampling_results:
+            res = res + (sampling_results_list, )
+        return res
+
+    def get_targets(
+        self,
+        anchor_list: List[List[Tensor]],
+        valid_flag_list: List[List[Tensor]],
+        featmap_sizes: List[Tuple[int, int]],
+        batch_gt_instances: InstanceList,
+        batch_img_metas: List[dict],
+        batch_gt_instances_ignore: OptInstanceList = None,
+        return_sampling_results: bool = False,
+    ) -> tuple:
+        """Compute regression and classification targets for anchors.
+
+        Args:
+            anchor_list (list[list[Tensor]]): Multi level anchors of each
+                image.
+            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
+                each image.
+            featmap_sizes (list[Tuple[int, int]]): Feature map size of each
+                level.
+            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+                gt_instance. It usually includes ``bboxes`` and ``labels``
+                attributes.
+            batch_img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
+                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
+                Defaults to None.
+            return_sampling_results (bool): Whether to return the sampling
+                results. Defaults to False.
+
+        Returns:
+            tuple:
+
+            - labels_list (list[Tensor]): Labels of each level.
+            - label_weights_list (list[Tensor]): Label weights of each
+              level.
+            - bbox_targets_list (list[Tensor]): BBox targets of each level.
+            - bbox_weights_list (list[Tensor]): BBox weights of each level.
+            - avg_factor (int): Average factor that is used to average
+              the loss. When using sampling method, avg_factor is usually
+              the sum of positive and negative priors. When using
+              ``PseudoSampler``, ``avg_factor`` is usually equal to the
+              number of positive priors.
+        """
+        if isinstance(self.assigner, RegionAssigner):
+            cls_reg_targets = self.region_targets(
+                anchor_list,
+                valid_flag_list,
+                featmap_sizes,
+                batch_gt_instances,
+                batch_img_metas,
+                batch_gt_instances_ignore=batch_gt_instances_ignore,
+                return_sampling_results=return_sampling_results)
+        else:
+            cls_reg_targets = super().get_targets(
+                anchor_list,
+                valid_flag_list,
+                batch_gt_instances,
+                batch_img_metas,
+                batch_gt_instances_ignore=batch_gt_instances_ignore,
+                return_sampling_results=return_sampling_results)
+        return cls_reg_targets
+
+    def anchor_offset(self, anchor_list: List[List[Tensor]],
+                      anchor_strides: List[int],
+                      featmap_sizes: List[Tuple[int, int]]) -> List[Tensor]:
+        """Get offset for deformable conv based on anchor shape.
+
+        NOTE: currently supports deformable kernel_size=3 and dilation=1.
+
+        Args:
+            anchor_list (list[list[tensor])): [NI, NLVL, NA, 4] list of
+                multi-level anchors
+            anchor_strides (list[int]): anchor stride of each level
+            featmap_sizes (list[Tuple[int, int]]): feature map size of
+                each level
+
+        Returns:
+            list[tensor]: offset of DeformConv kernel with shapes of
+            [NLVL, NA, 2, 18].
+        """
+
+        def _shape_offset(anchors, stride, ks=3, dilation=1):
+            # currently support kernel_size=3 and dilation=1
+            assert ks == 3 and dilation == 1
+            pad = (ks - 1) // 2
+            idx = torch.arange(-pad, pad + 1, dtype=dtype, device=device)
+            yy, xx = torch.meshgrid(idx, idx)  # return order matters
+            xx = xx.reshape(-1)
+            yy = yy.reshape(-1)
+            w = (anchors[:, 2] - anchors[:, 0]) / stride
+            h = (anchors[:, 3] - anchors[:, 1]) / stride
+            w = w / (ks - 1) - dilation
+            h = h / (ks - 1) - dilation
+            offset_x = w[:, None] * xx  # (NA, ks**2)
+            offset_y = h[:, None] * yy  # (NA, ks**2)
+            return offset_x, offset_y
+
+        def _ctr_offset(anchors, stride, featmap_size):
+            feat_h, feat_w = featmap_size
+            assert len(anchors) == feat_h * feat_w
+
+            x = (anchors[:, 0] + anchors[:, 2]) * 0.5
+            y = (anchors[:, 1] + anchors[:, 3]) * 0.5
+            # compute centers on feature map
+            x = x / stride
+            y = y / stride
+            # compute predefined centers
+            xx = torch.arange(0, feat_w, device=anchors.device)
+            yy = torch.arange(0, feat_h, device=anchors.device)
+            yy, xx = torch.meshgrid(yy, xx)
+            xx = xx.reshape(-1).type_as(x)
+            yy = yy.reshape(-1).type_as(y)
+
+            offset_x = x - xx  # (NA, )
+            offset_y = y - yy  # (NA, )
+            return offset_x, offset_y
+
+        num_imgs = len(anchor_list)
+        num_lvls = len(anchor_list[0])
+        dtype = anchor_list[0][0].dtype
+        device = anchor_list[0][0].device
+        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+
+        offset_list = []
+        for i in range(num_imgs):
+            mlvl_offset = []
+            for lvl in range(num_lvls):
+                c_offset_x, c_offset_y = _ctr_offset(anchor_list[i][lvl],
+                                                     anchor_strides[lvl],
+                                                     featmap_sizes[lvl])
+                s_offset_x, s_offset_y = _shape_offset(anchor_list[i][lvl],
+                                                       anchor_strides[lvl])
+
+                # offset = ctr_offset + shape_offset
+                offset_x = s_offset_x + c_offset_x[:, None]
+                offset_y = s_offset_y + c_offset_y[:, None]
+
+                # offset order (y0, x0, y1, x1, ..., y8, x8)
+                offset = torch.stack([offset_y, offset_x], dim=-1)
+                offset = offset.reshape(offset.size(0), -1)  # [NA, 2*ks**2]
+                mlvl_offset.append(offset)
+            offset_list.append(torch.cat(mlvl_offset))  # [totalNA, 2*ks**2]
+        offset_list = images_to_levels(offset_list, num_level_anchors)
+        return offset_list
+
+    def 
loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
+                            anchors: Tensor, labels: Tensor,
+                            label_weights: Tensor, bbox_targets: Tensor,
+                            bbox_weights: Tensor, avg_factor: int) -> tuple:
+        """Loss function on a single scale level."""
+        # classification loss
+        if self.with_cls:
+            labels = labels.reshape(-1)
+            label_weights = label_weights.reshape(-1)
+            cls_score = cls_score.permute(0, 2, 3,
+                                          1).reshape(-1, self.cls_out_channels)
+            loss_cls = self.loss_cls(
+                cls_score, labels, label_weights, avg_factor=avg_factor)
+        # regression loss
+        bbox_targets = bbox_targets.reshape(-1, 4)
+        bbox_weights = bbox_weights.reshape(-1, 4)
+        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+        if self.reg_decoded_bbox:
+            # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
+            # is applied directly on the decoded bounding boxes, it
+            # decodes the already encoded coordinates to absolute format.
+            anchors = anchors.reshape(-1, 4)
+            bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
+        loss_reg = self.loss_bbox(
+            bbox_pred, bbox_targets, bbox_weights, avg_factor=avg_factor)
+        if self.with_cls:
+            return loss_cls, loss_reg
+        return None, loss_reg
+
+    def loss_by_feat(
+        self,
+        anchor_list: List[List[Tensor]],
+        valid_flag_list: List[List[Tensor]],
+        cls_scores: List[Tensor],
+        bbox_preds: List[Tensor],
+        batch_gt_instances: InstanceList,
+        batch_img_metas: List[dict],
+        batch_gt_instances_ignore: OptInstanceList = None
+    ) -> Dict[str, Tensor]:
+        """Compute losses of the head.
+
+        Args:
+            anchor_list (list[list[Tensor]]): Multi level anchors of each
+                image.
+            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
+                each image. The outer list indicates images, and the inner list
+                corresponds to feature levels of the image. Each element of
+                the inner list is a tensor of shape (num_anchors, )
+            cls_scores (list[Tensor]): Box scores for each scale level.
+                Has shape (N, num_anchors * num_classes, H, W)
+            bbox_preds (list[Tensor]): Box energies / deltas for each scale
+                level with shape (N, num_anchors * 4, H, W)
+            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+                gt_instance. It usually includes ``bboxes`` and ``labels``
+                attributes.
+            batch_img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
+                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
+                Defaults to None.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        featmap_sizes = [featmap.size()[-2:] for featmap in bbox_preds]
+        cls_reg_targets = self.get_targets(
+            anchor_list,
+            valid_flag_list,
+            featmap_sizes,
+            batch_gt_instances,
+            batch_img_metas,
+            batch_gt_instances_ignore=batch_gt_instances_ignore,
+            return_sampling_results=True)
+        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+         avg_factor, sampling_results_list) = cls_reg_targets
+        if not sampling_results_list[0].avg_factor_with_neg:
+            # 200 is a hard-coded average factor,
+            # which follows guided anchoring.
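+            # For example (an assumed illustration, not measured values):
+            # with three levels of 32x32, 16x16 and 8x8 feature maps and one
+            # anchor per location, sum(numel) = 1024 + 256 + 64 = 1344, so
+            # avg_factor = 1344 / 200 = 6.72.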
+            avg_factor = sum([label.numel() for label in labels_list]) / 200.0
+
+        # change per image, per level anchor_list to per_level, per_image
+        mlvl_anchor_list = list(zip(*anchor_list))
+        # concat mlvl_anchor_list
+        mlvl_anchor_list = [
+            torch.cat(anchors, dim=0) for anchors in mlvl_anchor_list
+        ]
+
+        losses = multi_apply(
+            self.loss_by_feat_single,
+            cls_scores,
+            bbox_preds,
+            mlvl_anchor_list,
+            labels_list,
+            label_weights_list,
+            bbox_targets_list,
+            bbox_weights_list,
+            avg_factor=avg_factor)
+        if self.with_cls:
+            return dict(loss_rpn_cls=losses[0], loss_rpn_reg=losses[1])
+        return dict(loss_rpn_reg=losses[1])
+
+    def predict_by_feat(self,
+                        anchor_list: List[List[Tensor]],
+                        cls_scores: List[Tensor],
+                        bbox_preds: List[Tensor],
+                        batch_img_metas: List[dict],
+                        cfg: Optional[ConfigDict] = None,
+                        rescale: bool = False) -> InstanceList:
+        """Get proposal predictions. Overridden to enable taking
+        ``anchor_list`` from outside.
+
+        Args:
+            anchor_list (list[list[Tensor]]): Multi level anchors of each
+                image.
+            cls_scores (list[Tensor]): Classification scores for all
+                scale levels, each is a 4D-tensor, has shape
+                (batch_size, num_priors * num_classes, H, W).
+            bbox_preds (list[Tensor]): Box energies / deltas for all
+                scale levels, each is a 4D-tensor, has shape
+                (batch_size, num_priors * 4, H, W).
+            batch_img_metas (list[dict], Optional): Image meta info.
+            cfg (:obj:`ConfigDict`, optional): Test / postprocessing
+                configuration, if None, test_cfg would be used.
+            rescale (bool): If True, return boxes in original image space.
+                Defaults to False.
+
+        Returns:
+            list[:obj:`InstanceData`]: Object detection results of each image
+            after the post process. Each item usually contains following keys.
+
+            - scores (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 4),
+              the last dimension 4 arranged as (x1, y1, x2, y2).
+        """
+        assert len(cls_scores) == len(bbox_preds)
+
+        result_list = []
+        for img_id in range(len(batch_img_metas)):
+            cls_score_list = select_single_mlvl(cls_scores, img_id)
+            bbox_pred_list = select_single_mlvl(bbox_preds, img_id)
+            proposals = self._predict_by_feat_single(
+                cls_scores=cls_score_list,
+                bbox_preds=bbox_pred_list,
+                mlvl_anchors=anchor_list[img_id],
+                img_meta=batch_img_metas[img_id],
+                cfg=cfg,
+                rescale=rescale)
+            result_list.append(proposals)
+        return result_list
+
+    def _predict_by_feat_single(self,
+                                cls_scores: List[Tensor],
+                                bbox_preds: List[Tensor],
+                                mlvl_anchors: List[Tensor],
+                                img_meta: dict,
+                                cfg: ConfigDict,
+                                rescale: bool = False) -> InstanceData:
+        """Transform outputs of a single image into bbox predictions.
+
+        Args:
+            cls_scores (list[Tensor]): Box scores from all scale
+                levels of a single image, each item has shape
+                (num_anchors * num_classes, H, W).
+            bbox_preds (list[Tensor]): Box energies / deltas from
+                all scale levels of a single image, each item has
+                shape (num_anchors * 4, H, W).
+            mlvl_anchors (list[Tensor]): Box reference from all scale
+                levels of a single image, each item has shape
+                (num_total_anchors, 4).
+            img_meta (dict): Meta information of the current image, e.g.,
+                image size, scaling factor, etc.
+            cfg (:obj:`ConfigDict`): Test / postprocessing configuration,
+                if None, test_cfg would be used.
+            rescale (bool): If True, return boxes in original image space.
+                Defaults to False.
+
+        Returns:
+            :obj:`InstanceData`: Detection results of each image
+            after the post process.
+            Each item usually contains following keys.
+
+            - scores (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 4),
+              the last dimension 4 arranged as (x1, y1, x2, y2).
+        """
+        cfg = self.test_cfg if cfg is None else cfg
+        cfg = copy.deepcopy(cfg)
+        # bboxes from different levels should be independent during NMS,
+        # level_ids are used as labels for batched NMS to separate them
+        level_ids = []
+        mlvl_scores = []
+        mlvl_bbox_preds = []
+        mlvl_valid_anchors = []
+        nms_pre = cfg.get('nms_pre', -1)
+        for idx in range(len(cls_scores)):
+            rpn_cls_score = cls_scores[idx]
+            rpn_bbox_pred = bbox_preds[idx]
+            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
+            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
+            if self.use_sigmoid_cls:
+                rpn_cls_score = rpn_cls_score.reshape(-1)
+                scores = rpn_cls_score.sigmoid()
+            else:
+                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
+                # We set FG labels to [0, num_class-1] and BG label to
+                # num_class in RPN head since mmdet v2.5, which is unified to
+                # be consistent with other head since mmdet v2.0. In mmdet v2.0
+                # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
+                scores = rpn_cls_score.softmax(dim=1)[:, 0]
+            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+            anchors = mlvl_anchors[idx]
+
+            if 0 < nms_pre < scores.shape[0]:
+                # sort is faster than topk
+                # _, topk_inds = scores.topk(cfg.nms_pre)
+                ranked_scores, rank_inds = scores.sort(descending=True)
+                topk_inds = rank_inds[:nms_pre]
+                scores = ranked_scores[:nms_pre]
+                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
+                anchors = anchors[topk_inds, :]
+            mlvl_scores.append(scores)
+            mlvl_bbox_preds.append(rpn_bbox_pred)
+            mlvl_valid_anchors.append(anchors)
+            level_ids.append(
+                scores.new_full((scores.size(0), ), idx, dtype=torch.long))
+
+        anchors = torch.cat(mlvl_valid_anchors)
+        rpn_bbox_pred = torch.cat(mlvl_bbox_preds)
+        bboxes = self.bbox_coder.decode(
+            anchors, rpn_bbox_pred, max_shape=img_meta['img_shape'])
+
+        proposals = InstanceData()
+        proposals.bboxes = bboxes
+        proposals.scores = torch.cat(mlvl_scores)
+        proposals.level_ids = torch.cat(level_ids)
+
+        return self._bbox_post_process(
+            results=proposals, cfg=cfg, rescale=rescale, img_meta=img_meta)
+
+    def refine_bboxes(self, anchor_list: List[List[Tensor]],
+                      bbox_preds: List[Tensor],
+                      img_metas: List[dict]) -> List[List[Tensor]]:
+        """Refine bboxes through stages."""
+        num_levels = len(bbox_preds)
+        new_anchor_list = []
+        for img_id in range(len(img_metas)):
+            mlvl_anchors = []
+            for i in range(num_levels):
+                bbox_pred = bbox_preds[i][img_id].detach()
+                bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+                img_shape = img_metas[img_id]['img_shape']
+                bboxes = self.bbox_coder.decode(anchor_list[img_id][i],
+                                                bbox_pred, img_shape)
+                mlvl_anchors.append(bboxes)
+            new_anchor_list.append(mlvl_anchors)
+        return new_anchor_list
+
+    def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict:
+        """Perform forward propagation and loss calculation of the detection
+        head on the features of the upstream network.
+
+        Args:
+            x (tuple[Tensor]): Features from the upstream network, each is
+                a 4D-tensor.
+            batch_data_samples (List[:obj:`DetDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+
+        Returns:
+            dict: A dictionary of loss components.
+        """
+        outputs = unpack_gt_instances(batch_data_samples)
+        batch_gt_instances, _, batch_img_metas = outputs
+
+        featmap_sizes = [featmap.size()[-2:] for featmap in x]
+        device = x[0].device
+        anchor_list, valid_flag_list = self.get_anchors(
+            featmap_sizes, batch_img_metas, device=device)
+
+        if self.adapt_cfg['type'] == 'offset':
+            offset_list = self.anchor_offset(anchor_list, self.anchor_strides,
+                                             featmap_sizes)
+        else:
+            offset_list = None
+
+        x, cls_score, bbox_pred = self(x, offset_list)
+        rpn_loss_inputs = (anchor_list, valid_flag_list, cls_score, bbox_pred,
+                           batch_gt_instances, batch_img_metas)
+        losses = self.loss_by_feat(*rpn_loss_inputs)
+
+        return losses
+
+    def loss_and_predict(
+        self,
+        x: Tuple[Tensor],
+        batch_data_samples: SampleList,
+        proposal_cfg: Optional[ConfigDict] = None,
+    ) -> Tuple[dict, InstanceList]:
+        """Perform forward propagation of the head, then calculate loss and
+        predictions from the features and data samples.
+
+        Args:
+            x (tuple[Tensor]): Features from FPN.
+            batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
+                the meta information of each image and corresponding
+                annotations.
+            proposal_cfg (:obj:`ConfigDict`, optional): Test / postprocessing
+                configuration, if None, test_cfg would be used.
+                Defaults to None.
+
+        Returns:
+            tuple: the return value is a tuple containing:
+
+            - losses (dict[str, Tensor]): A dictionary of loss components.
+            - predictions (list[:obj:`InstanceData`]): Detection
+              results of each image after the post process.
+        """
+        outputs = unpack_gt_instances(batch_data_samples)
+        batch_gt_instances, _, batch_img_metas = outputs
+
+        featmap_sizes = [featmap.size()[-2:] for featmap in x]
+        device = x[0].device
+        anchor_list, valid_flag_list = self.get_anchors(
+            featmap_sizes, batch_img_metas, device=device)
+
+        if self.adapt_cfg['type'] == 'offset':
+            offset_list = self.anchor_offset(anchor_list, self.anchor_strides,
+                                             featmap_sizes)
+        else:
+            offset_list = None
+
+        x, cls_score, bbox_pred = self(x, offset_list)
+        rpn_loss_inputs = (anchor_list, valid_flag_list, cls_score, bbox_pred,
+                           batch_gt_instances, batch_img_metas)
+        losses = self.loss_by_feat(*rpn_loss_inputs)
+
+        predictions = self.predict_by_feat(
+            anchor_list,
+            cls_score,
+            bbox_pred,
+            batch_img_metas=batch_img_metas,
+            cfg=proposal_cfg)
+        return losses, predictions
+
+    def predict(self,
+                x: Tuple[Tensor],
+                batch_data_samples: SampleList,
+                rescale: bool = False) -> InstanceList:
+        """Perform forward propagation of the detection head and predict
+        detection results on the features of the upstream network.
+
+        Args:
+            x (tuple[Tensor]): Multi-level features from the
+                upstream network, each is a 4D-tensor.
+            batch_data_samples (List[:obj:`DetDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+            rescale (bool, optional): Whether to rescale the results.
+                Defaults to False.
+
+        Returns:
+            list[:obj:`InstanceData`]: Detection results of each image
+            after the post process.
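+
+        Examples:
+            A shape-only sketch (assumed sizes) of how the per-anchor
+            offsets produced by ``anchor_offset`` are laid out before
+            ``AdaptiveConv`` reshapes them for ``DeformConv2d``:
+
+            >>> import torch
+            >>> offset = torch.zeros(1, 4, 18)  # N=1, H*W=4, 2*3*3=18
+            >>> offset.permute(0, 2, 1).reshape(1, -1, 2, 2).shape
+            torch.Size([1, 18, 2, 2])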
+ """ + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + + featmap_sizes = [featmap.size()[-2:] for featmap in x] + device = x[0].device + anchor_list, _ = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + + if self.adapt_cfg['type'] == 'offset': + offset_list = self.anchor_offset(anchor_list, self.anchor_strides, + featmap_sizes) + else: + offset_list = None + + x, cls_score, bbox_pred = self(x, offset_list) + predictions = self.stages[-1].predict_by_feat( + anchor_list, + cls_score, + bbox_pred, + batch_img_metas=batch_img_metas, + rescale=rescale) + return predictions + + +@MODELS.register_module() +class CascadeRPNHead(BaseDenseHead): + """The CascadeRPNHead will predict more accurate region proposals, which is + required for two-stage detectors (such as Fast/Faster R-CNN). CascadeRPN + consists of a sequence of RPNStage to progressively improve the accuracy of + the detected proposals. + + More details can be found in ``https://arxiv.org/abs/1909.06720``. + + Args: + num_stages (int): number of CascadeRPN stages. + stages (list[:obj:`ConfigDict` or dict]): list of configs to build + the stages. + train_cfg (list[:obj:`ConfigDict` or dict]): list of configs at + training time each stage. + test_cfg (:obj:`ConfigDict` or dict): config at testing time. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or \ + list[dict]): Initialization config dict. + """ + + def __init__(self, + num_classes: int, + num_stages: int, + stages: List[ConfigType], + train_cfg: List[ConfigType], + test_cfg: ConfigType, + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + assert num_classes == 1, 'Only support num_classes == 1' + assert num_stages == len(stages) + self.num_stages = num_stages + # Be careful! Pretrained weights cannot be loaded when use + # nn.ModuleList + self.stages = ModuleList() + for i in range(len(stages)): + train_cfg_i = train_cfg[i] if train_cfg is not None else None + stages[i].update(train_cfg=train_cfg_i) + stages[i].update(test_cfg=test_cfg) + self.stages.append(MODELS.build(stages[i])) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def loss_by_feat(self): + """loss_by_feat() is implemented in StageCascadeRPNHead.""" + pass + + def predict_by_feat(self): + """predict_by_feat() is implemented in StageCascadeRPNHead.""" + pass + + def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict: + """Perform forward propagation and loss calculation of the detection + head on the features of the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components. 
+ """ + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, _, batch_img_metas = outputs + + featmap_sizes = [featmap.size()[-2:] for featmap in x] + device = x[0].device + anchor_list, valid_flag_list = self.stages[0].get_anchors( + featmap_sizes, batch_img_metas, device=device) + + losses = dict() + + for i in range(self.num_stages): + stage = self.stages[i] + + if stage.adapt_cfg['type'] == 'offset': + offset_list = stage.anchor_offset(anchor_list, + stage.anchor_strides, + featmap_sizes) + else: + offset_list = None + x, cls_score, bbox_pred = stage(x, offset_list) + rpn_loss_inputs = (anchor_list, valid_flag_list, cls_score, + bbox_pred, batch_gt_instances, batch_img_metas) + stage_loss = stage.loss_by_feat(*rpn_loss_inputs) + for name, value in stage_loss.items(): + losses['s{}.{}'.format(i, name)] = value + + # refine boxes + if i < self.num_stages - 1: + anchor_list = stage.refine_bboxes(anchor_list, bbox_pred, + batch_img_metas) + + return losses + + def loss_and_predict( + self, + x: Tuple[Tensor], + batch_data_samples: SampleList, + proposal_cfg: Optional[ConfigDict] = None, + ) -> Tuple[dict, InstanceList]: + """Perform forward propagation of the head, then calculate loss and + predictions from the features and data samples. + + Args: + x (tuple[Tensor]): Features from FPN. + batch_data_samples (list[:obj:`DetDataSample`]): Each item contains + the meta information of each image and corresponding + annotations. + proposal_cfg (ConfigDict, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + Defaults to None. + + Returns: + tuple: the return value is a tuple contains: + + - losses: (dict[str, Tensor]): A dictionary of loss components. + - predictions (list[:obj:`InstanceData`]): Detection + results of each image after the post process. + """ + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, _, batch_img_metas = outputs + + featmap_sizes = [featmap.size()[-2:] for featmap in x] + device = x[0].device + anchor_list, valid_flag_list = self.stages[0].get_anchors( + featmap_sizes, batch_img_metas, device=device) + + losses = dict() + + for i in range(self.num_stages): + stage = self.stages[i] + + if stage.adapt_cfg['type'] == 'offset': + offset_list = stage.anchor_offset(anchor_list, + stage.anchor_strides, + featmap_sizes) + else: + offset_list = None + x, cls_score, bbox_pred = stage(x, offset_list) + rpn_loss_inputs = (anchor_list, valid_flag_list, cls_score, + bbox_pred, batch_gt_instances, batch_img_metas) + stage_loss = stage.loss_by_feat(*rpn_loss_inputs) + for name, value in stage_loss.items(): + losses['s{}.{}'.format(i, name)] = value + + # refine boxes + if i < self.num_stages - 1: + anchor_list = stage.refine_bboxes(anchor_list, bbox_pred, + batch_img_metas) + + predictions = self.stages[-1].predict_by_feat( + anchor_list, + cls_score, + bbox_pred, + batch_img_metas=batch_img_metas, + cfg=proposal_cfg) + return losses, predictions + + def predict(self, + x: Tuple[Tensor], + batch_data_samples: SampleList, + rescale: bool = False) -> InstanceList: + """Perform forward propagation of the detection head and predict + detection results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool, optional): Whether to rescale the results. 
+                Defaults to False.
+
+        Returns:
+            list[:obj:`InstanceData`]: Detection results of each image
+            after the post process.
+        """
+        batch_img_metas = [
+            data_samples.metainfo for data_samples in batch_data_samples
+        ]
+
+        featmap_sizes = [featmap.size()[-2:] for featmap in x]
+        device = x[0].device
+        anchor_list, _ = self.stages[0].get_anchors(
+            featmap_sizes, batch_img_metas, device=device)
+
+        for i in range(self.num_stages):
+            stage = self.stages[i]
+            if stage.adapt_cfg['type'] == 'offset':
+                offset_list = stage.anchor_offset(anchor_list,
+                                                  stage.anchor_strides,
+                                                  featmap_sizes)
+            else:
+                offset_list = None
+            x, cls_score, bbox_pred = stage(x, offset_list)
+            if i < self.num_stages - 1:
+                anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
+                                                  batch_img_metas)
+
+        predictions = self.stages[-1].predict_by_feat(
+            anchor_list,
+            cls_score,
+            bbox_pred,
+            batch_img_metas=batch_img_metas,
+            rescale=rescale)
+        return predictions
diff --git a/mmdet/models/dense_heads/centernet_head.py b/mmdet/models/dense_heads/centernet_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..09f3e599eb176965e53f270014cbd326858b7c17
--- /dev/null
+++ b/mmdet/models/dense_heads/centernet_head.py
@@ -0,0 +1,447 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from mmcv.ops import batched_nms
+from mmengine.config import ConfigDict
+from mmengine.model import bias_init_with_prob, normal_init
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
+                         OptInstanceList, OptMultiConfig)
+from ..utils import (gaussian_radius, gen_gaussian_target, get_local_maximum,
+                     get_topk_from_heatmap, multi_apply,
+                     transpose_and_gather_feat)
+from .base_dense_head import BaseDenseHead
+
+
+@MODELS.register_module()
+class CenterNetHead(BaseDenseHead):
+    """Objects as Points head. CenterNetHead uses the center point to
+    indicate an object's position. Paper link:
+    https://arxiv.org/abs/1904.07850
+
+    Args:
+        in_channels (int): Number of channels in the input feature map.
+        feat_channels (int): Number of channels in the intermediate
+            feature map.
+        num_classes (int): Number of categories excluding the background
+            category.
+        loss_center_heatmap (:obj:`ConfigDict` or dict): Config of center
+            heatmap loss. Defaults to
+            dict(type='GaussianFocalLoss', loss_weight=1.0)
+        loss_wh (:obj:`ConfigDict` or dict): Config of wh loss. Defaults to
+            dict(type='L1Loss', loss_weight=0.1).
+        loss_offset (:obj:`ConfigDict` or dict): Config of offset loss.
+            Defaults to dict(type='L1Loss', loss_weight=1.0).
+        train_cfg (:obj:`ConfigDict` or dict, optional): Training config.
+            Unused in CenterNet, but kept for compatibility with
+            SingleStageDetector.
+        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config
+            of CenterNet.
+        init_cfg (:obj:`ConfigDict` or dict or list[dict] or
+            list[:obj:`ConfigDict`], optional): Initialization
+            config dict.
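+
+    Examples:
+        A shape-only sketch of one prediction branch as built by
+        ``_build_head`` (the channel sizes here are assumptions):
+
+        >>> import torch
+        >>> import torch.nn as nn
+        >>> branch = nn.Sequential(
+        ...     nn.Conv2d(16, 32, kernel_size=3, padding=1),
+        ...     nn.ReLU(inplace=True),
+        ...     nn.Conv2d(32, 2, kernel_size=1))
+        >>> branch(torch.rand(1, 16, 8, 8)).shape
+        torch.Size([1, 2, 8, 8])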
+ """ + + def __init__(self, + in_channels: int, + feat_channels: int, + num_classes: int, + loss_center_heatmap: ConfigType = dict( + type='GaussianFocalLoss', loss_weight=1.0), + loss_wh: ConfigType = dict(type='L1Loss', loss_weight=0.1), + loss_offset: ConfigType = dict( + type='L1Loss', loss_weight=1.0), + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + self.num_classes = num_classes + self.heatmap_head = self._build_head(in_channels, feat_channels, + num_classes) + self.wh_head = self._build_head(in_channels, feat_channels, 2) + self.offset_head = self._build_head(in_channels, feat_channels, 2) + + self.loss_center_heatmap = MODELS.build(loss_center_heatmap) + self.loss_wh = MODELS.build(loss_wh) + self.loss_offset = MODELS.build(loss_offset) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.fp16_enabled = False + + def _build_head(self, in_channels: int, feat_channels: int, + out_channels: int) -> nn.Sequential: + """Build head for each branch.""" + layer = nn.Sequential( + nn.Conv2d(in_channels, feat_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(feat_channels, out_channels, kernel_size=1)) + return layer + + def init_weights(self) -> None: + """Initialize weights of the head.""" + bias_init = bias_init_with_prob(0.1) + self.heatmap_head[-1].bias.data.fill_(bias_init) + for head in [self.wh_head, self.offset_head]: + for m in head.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001) + + def forward(self, x: Tuple[Tensor, ...]) -> Tuple[List[Tensor]]: + """Forward features. Notice CenterNet head does not use FPN. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + center_heatmap_preds (list[Tensor]): center predict heatmaps for + all levels, the channels number is num_classes. + wh_preds (list[Tensor]): wh predicts for all levels, the channels + number is 2. + offset_preds (list[Tensor]): offset predicts for all levels, the + channels number is 2. + """ + return multi_apply(self.forward_single, x) + + def forward_single(self, x: Tensor) -> Tuple[Tensor, ...]: + """Forward feature of a single level. + + Args: + x (Tensor): Feature of a single level. + + Returns: + center_heatmap_pred (Tensor): center predict heatmaps, the + channels number is num_classes. + wh_pred (Tensor): wh predicts, the channels number is 2. + offset_pred (Tensor): offset predicts, the channels number is 2. + """ + center_heatmap_pred = self.heatmap_head(x).sigmoid() + wh_pred = self.wh_head(x) + offset_pred = self.offset_head(x) + return center_heatmap_pred, wh_pred, offset_pred + + def loss_by_feat( + self, + center_heatmap_preds: List[Tensor], + wh_preds: List[Tensor], + offset_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Compute losses of the head. + + Args: + center_heatmap_preds (list[Tensor]): center predict heatmaps for + all levels with shape (B, num_classes, H, W). + wh_preds (list[Tensor]): wh predicts for all levels with + shape (B, 2, H, W). + offset_preds (list[Tensor]): offset predicts for all levels + with shape (B, 2, H, W). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. 
+            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
+                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
+                Defaults to None.
+
+        Returns:
+            dict[str, Tensor]: which has components below:
+
+            - loss_center_heatmap (Tensor): loss of center heatmap.
+            - loss_wh (Tensor): loss of wh heatmap.
+            - loss_offset (Tensor): loss of offset heatmap.
+        """
+        assert len(center_heatmap_preds) == len(wh_preds) == len(
+            offset_preds) == 1
+        center_heatmap_pred = center_heatmap_preds[0]
+        wh_pred = wh_preds[0]
+        offset_pred = offset_preds[0]
+
+        gt_bboxes = [
+            gt_instances.bboxes for gt_instances in batch_gt_instances
+        ]
+        gt_labels = [
+            gt_instances.labels for gt_instances in batch_gt_instances
+        ]
+        img_shape = batch_img_metas[0]['batch_input_shape']
+        target_result, avg_factor = self.get_targets(gt_bboxes, gt_labels,
+                                                     center_heatmap_pred.shape,
+                                                     img_shape)
+
+        center_heatmap_target = target_result['center_heatmap_target']
+        wh_target = target_result['wh_target']
+        offset_target = target_result['offset_target']
+        wh_offset_target_weight = target_result['wh_offset_target_weight']
+
+        # Since wh_target and offset_target have 2 channels each, the
+        # avg_factor used by loss_wh and loss_offset is twice the one
+        # used by loss_center_heatmap.
+        loss_center_heatmap = self.loss_center_heatmap(
+            center_heatmap_pred, center_heatmap_target, avg_factor=avg_factor)
+        loss_wh = self.loss_wh(
+            wh_pred,
+            wh_target,
+            wh_offset_target_weight,
+            avg_factor=avg_factor * 2)
+        loss_offset = self.loss_offset(
+            offset_pred,
+            offset_target,
+            wh_offset_target_weight,
+            avg_factor=avg_factor * 2)
+        return dict(
+            loss_center_heatmap=loss_center_heatmap,
+            loss_wh=loss_wh,
+            loss_offset=loss_offset)
+
+    def get_targets(self, gt_bboxes: List[Tensor], gt_labels: List[Tensor],
+                    feat_shape: tuple, img_shape: tuple) -> Tuple[dict, int]:
+        """Compute regression and classification targets in multiple images.
+
+        Args:
+            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels (list[Tensor]): class indices corresponding to each box.
+            feat_shape (tuple): feature map shape with value [B, _, H, W]
+            img_shape (tuple): image shape.
+
+        Returns:
+            tuple[dict, float]: The float value is the average factor, the
+            dict has components below:
+
+            - center_heatmap_target (Tensor): targets of center heatmap, \
+              shape (B, num_classes, H, W).
+            - wh_target (Tensor): targets of wh predict, shape \
+              (B, 2, H, W).
+            - offset_target (Tensor): targets of offset predict, shape \
+              (B, 2, H, W).
+            - wh_offset_target_weight (Tensor): weights of wh and offset \
+              predict, shape (B, 2, H, W).
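+
+        Examples:
+            A minimal sketch (with an assumed stride-4 feature map) of how
+            box centers are mapped to feature-map coordinates:
+
+            >>> import torch
+            >>> gt_bbox = torch.tensor([[8., 8., 24., 16.]])
+            >>> width_ratio = height_ratio = 0.25
+            >>> cx = (gt_bbox[:, [0]] + gt_bbox[:, [2]]) * width_ratio / 2
+            >>> cy = (gt_bbox[:, [1]] + gt_bbox[:, [3]]) * height_ratio / 2
+            >>> torch.cat((cx, cy), dim=1)
+            tensor([[4., 3.]])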
+ """ + img_h, img_w = img_shape[:2] + bs, _, feat_h, feat_w = feat_shape + + width_ratio = float(feat_w / img_w) + height_ratio = float(feat_h / img_h) + + center_heatmap_target = gt_bboxes[-1].new_zeros( + [bs, self.num_classes, feat_h, feat_w]) + wh_target = gt_bboxes[-1].new_zeros([bs, 2, feat_h, feat_w]) + offset_target = gt_bboxes[-1].new_zeros([bs, 2, feat_h, feat_w]) + wh_offset_target_weight = gt_bboxes[-1].new_zeros( + [bs, 2, feat_h, feat_w]) + + for batch_id in range(bs): + gt_bbox = gt_bboxes[batch_id] + gt_label = gt_labels[batch_id] + center_x = (gt_bbox[:, [0]] + gt_bbox[:, [2]]) * width_ratio / 2 + center_y = (gt_bbox[:, [1]] + gt_bbox[:, [3]]) * height_ratio / 2 + gt_centers = torch.cat((center_x, center_y), dim=1) + + for j, ct in enumerate(gt_centers): + ctx_int, cty_int = ct.int() + ctx, cty = ct + scale_box_h = (gt_bbox[j][3] - gt_bbox[j][1]) * height_ratio + scale_box_w = (gt_bbox[j][2] - gt_bbox[j][0]) * width_ratio + radius = gaussian_radius([scale_box_h, scale_box_w], + min_overlap=0.3) + radius = max(0, int(radius)) + ind = gt_label[j] + gen_gaussian_target(center_heatmap_target[batch_id, ind], + [ctx_int, cty_int], radius) + + wh_target[batch_id, 0, cty_int, ctx_int] = scale_box_w + wh_target[batch_id, 1, cty_int, ctx_int] = scale_box_h + + offset_target[batch_id, 0, cty_int, ctx_int] = ctx - ctx_int + offset_target[batch_id, 1, cty_int, ctx_int] = cty - cty_int + + wh_offset_target_weight[batch_id, :, cty_int, ctx_int] = 1 + + avg_factor = max(1, center_heatmap_target.eq(1).sum()) + target_result = dict( + center_heatmap_target=center_heatmap_target, + wh_target=wh_target, + offset_target=offset_target, + wh_offset_target_weight=wh_offset_target_weight) + return target_result, avg_factor + + def predict_by_feat(self, + center_heatmap_preds: List[Tensor], + wh_preds: List[Tensor], + offset_preds: List[Tensor], + batch_img_metas: Optional[List[dict]] = None, + rescale: bool = True, + with_nms: bool = False) -> InstanceList: + """Transform network output for a batch into bbox predictions. + + Args: + center_heatmap_preds (list[Tensor]): Center predict heatmaps for + all levels with shape (B, num_classes, H, W). + wh_preds (list[Tensor]): WH predicts for all levels with + shape (B, 2, H, W). + offset_preds (list[Tensor]): Offset predicts for all levels + with shape (B, 2, H, W). + batch_img_metas (list[dict], optional): Batch image meta info. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to True. + with_nms (bool): If True, do nms before return boxes. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Instance segmentation + results of each image after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
+ """ + assert len(center_heatmap_preds) == len(wh_preds) == len( + offset_preds) == 1 + result_list = [] + for img_id in range(len(batch_img_metas)): + result_list.append( + self._predict_by_feat_single( + center_heatmap_preds[0][img_id:img_id + 1, ...], + wh_preds[0][img_id:img_id + 1, ...], + offset_preds[0][img_id:img_id + 1, ...], + batch_img_metas[img_id], + rescale=rescale, + with_nms=with_nms)) + return result_list + + def _predict_by_feat_single(self, + center_heatmap_pred: Tensor, + wh_pred: Tensor, + offset_pred: Tensor, + img_meta: dict, + rescale: bool = True, + with_nms: bool = False) -> InstanceData: + """Transform outputs of a single image into bbox results. + + Args: + center_heatmap_pred (Tensor): Center heatmap for current level with + shape (1, num_classes, H, W). + wh_pred (Tensor): WH heatmap for current level with shape + (1, num_classes, H, W). + offset_pred (Tensor): Offset for current level with shape + (1, corner_offset_channels, H, W). + img_meta (dict): Meta information of current image, e.g., + image size, scaling factor, etc. + rescale (bool): If True, return boxes in original image space. + Defaults to True. + with_nms (bool): If True, do nms before return boxes. + Defaults to False. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + batch_det_bboxes, batch_labels = self._decode_heatmap( + center_heatmap_pred, + wh_pred, + offset_pred, + img_meta['batch_input_shape'], + k=self.test_cfg.topk, + kernel=self.test_cfg.local_maximum_kernel) + + det_bboxes = batch_det_bboxes.view([-1, 5]) + det_labels = batch_labels.view(-1) + + batch_border = det_bboxes.new_tensor(img_meta['border'])[..., + [2, 0, 2, 0]] + det_bboxes[..., :4] -= batch_border + + if rescale and 'scale_factor' in img_meta: + det_bboxes[..., :4] /= det_bboxes.new_tensor( + img_meta['scale_factor']).repeat((1, 2)) + + if with_nms: + det_bboxes, det_labels = self._bboxes_nms(det_bboxes, det_labels, + self.test_cfg) + results = InstanceData() + results.bboxes = det_bboxes[..., :4] + results.scores = det_bboxes[..., 4] + results.labels = det_labels + return results + + def _decode_heatmap(self, + center_heatmap_pred: Tensor, + wh_pred: Tensor, + offset_pred: Tensor, + img_shape: tuple, + k: int = 100, + kernel: int = 3) -> Tuple[Tensor, Tensor]: + """Transform outputs into detections raw bbox prediction. + + Args: + center_heatmap_pred (Tensor): center predict heatmap, + shape (B, num_classes, H, W). + wh_pred (Tensor): wh predict, shape (B, 2, H, W). + offset_pred (Tensor): offset predict, shape (B, 2, H, W). + img_shape (tuple): image shape in hw format. + k (int): Get top k center keypoints from heatmap. Defaults to 100. + kernel (int): Max pooling kernel for extract local maximum pixels. + Defaults to 3. 
+
+        Returns:
+            tuple[Tensor]: Decoded output of CenterNetHead, containing
+            the following Tensors:
+
+            - batch_bboxes (Tensor): Coords of each box with shape (B, k, 5)
+            - batch_topk_labels (Tensor): Categories of each box with \
+              shape (B, k)
+        """
+        height, width = center_heatmap_pred.shape[2:]
+        inp_h, inp_w = img_shape
+
+        center_heatmap_pred = get_local_maximum(
+            center_heatmap_pred, kernel=kernel)
+
+        *batch_dets, topk_ys, topk_xs = get_topk_from_heatmap(
+            center_heatmap_pred, k=k)
+        batch_scores, batch_index, batch_topk_labels = batch_dets
+
+        wh = transpose_and_gather_feat(wh_pred, batch_index)
+        offset = transpose_and_gather_feat(offset_pred, batch_index)
+        topk_xs = topk_xs + offset[..., 0]
+        topk_ys = topk_ys + offset[..., 1]
+        tl_x = (topk_xs - wh[..., 0] / 2) * (inp_w / width)
+        tl_y = (topk_ys - wh[..., 1] / 2) * (inp_h / height)
+        br_x = (topk_xs + wh[..., 0] / 2) * (inp_w / width)
+        br_y = (topk_ys + wh[..., 1] / 2) * (inp_h / height)
+
+        batch_bboxes = torch.stack([tl_x, tl_y, br_x, br_y], dim=2)
+        batch_bboxes = torch.cat((batch_bboxes, batch_scores[..., None]),
+                                 dim=-1)
+        return batch_bboxes, batch_topk_labels
+
+    def _bboxes_nms(self, bboxes: Tensor, labels: Tensor,
+                    cfg: ConfigDict) -> Tuple[Tensor, Tensor]:
+        """bboxes nms."""
+        if labels.numel() > 0:
+            max_num = cfg.max_per_img
+            bboxes, keep = batched_nms(bboxes[:, :4], bboxes[:,
+                                                             -1].contiguous(),
+                                       labels, cfg.nms)
+            if max_num > 0:
+                bboxes = bboxes[:max_num]
+                labels = labels[keep][:max_num]
+
+        return bboxes, labels
diff --git a/mmdet/models/dense_heads/centernet_update_head.py b/mmdet/models/dense_heads/centernet_update_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..00cfcb89806209c9416b1bd7e9a14d82a4911175
--- /dev/null
+++ b/mmdet/models/dense_heads/centernet_update_head.py
@@ -0,0 +1,624 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Sequence, Tuple
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import Scale
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.structures.bbox import bbox2distance
+from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
+                         OptInstanceList, reduce_mean)
+from ..utils import multi_apply
+from .anchor_free_head import AnchorFreeHead
+
+INF = 1000000000
+RangeType = Sequence[Tuple[int, int]]
+
+
+def _transpose(tensor_list: List[Tensor],
+               num_point_list: list) -> List[Tensor]:
+    """This function is used to transpose image-first tensors to level-first
+    ones."""
+    for img_idx in range(len(tensor_list)):
+        tensor_list[img_idx] = torch.split(
+            tensor_list[img_idx], num_point_list, dim=0)
+
+    tensors_level_first = []
+    for targets_per_level in zip(*tensor_list):
+        tensors_level_first.append(torch.cat(targets_per_level, dim=0))
+    return tensors_level_first
+
+
+@MODELS.register_module()
+class CenterNetUpdateHead(AnchorFreeHead):
+    """CenterNetUpdateHead is an improved version of CenterNet in CenterNet2.
+    Paper link: https://arxiv.org/abs/2103.07461.
+
+    Args:
+        num_classes (int): Number of categories excluding the background
+            category.
+        in_channels (int): Number of channels in the input feature map.
+        regress_ranges (Sequence[Tuple[int, int]]): Regress range of multiple
+            level points.
+        hm_min_radius (int): Heatmap target minimum radius of cls branch.
+            Defaults to 4.
+        hm_min_overlap (float): Heatmap target minimum overlap of cls branch.
+            Defaults to 0.8.
+ more_pos_thresh (float): The filtering threshold when the cls branch + adds more positive samples. Defaults to 0.2. + more_pos_topk (int): The maximum number of additional positive samples + added to each gt. Defaults to 9. + soft_weight_on_reg (bool): Whether to use the soft target of the + cls branch as the soft weight of the bbox branch. + Defaults to False. + loss_cls (:obj:`ConfigDict` or dict): Config of cls loss. Defaults to + dict(type='GaussianFocalLoss', loss_weight=1.0) + loss_bbox (:obj:`ConfigDict` or dict): Config of bbox loss. Defaults to + dict(type='GIoULoss', loss_weight=2.0). + norm_cfg (:obj:`ConfigDict` or dict, optional): dictionary to construct + and config norm layer. Defaults to + ``norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)``. + train_cfg (:obj:`ConfigDict` or dict, optional): Training config. + Unused in CenterNet. Reserved for compatibility with + SingleStageDetector. + test_cfg (:obj:`ConfigDict` or dict, optional): Testing config + of CenterNet. + """ + + def __init__(self, + num_classes: int, + in_channels: int, + regress_ranges: RangeType = ((0, 80), (64, 160), (128, 320), + (256, 640), (512, INF)), + hm_min_radius: int = 4, + hm_min_overlap: float = 0.8, + more_pos_thresh: float = 0.2, + more_pos_topk: int = 9, + soft_weight_on_reg: bool = False, + loss_cls: ConfigType = dict( + type='GaussianFocalLoss', + pos_weight=0.25, + neg_weight=0.75, + loss_weight=1.0), + loss_bbox: ConfigType = dict( + type='GIoULoss', loss_weight=2.0), + norm_cfg: OptConfigType = dict( + type='GN', num_groups=32, requires_grad=True), + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + **kwargs) -> None: + super().__init__( + num_classes=num_classes, + in_channels=in_channels, + loss_cls=loss_cls, + loss_bbox=loss_bbox, + norm_cfg=norm_cfg, + train_cfg=train_cfg, + test_cfg=test_cfg, + **kwargs) + self.soft_weight_on_reg = soft_weight_on_reg + self.hm_min_radius = hm_min_radius + self.more_pos_thresh = more_pos_thresh + self.more_pos_topk = more_pos_topk + self.delta = (1 - hm_min_overlap) / (1 + hm_min_overlap) + self.sigmoid_clamp = 0.0001 + + # GaussianFocalLoss must be sigmoid mode + self.use_sigmoid_cls = True + self.cls_out_channels = num_classes + + self.regress_ranges = regress_ranges + self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides]) + + def _init_predictor(self) -> None: + """Initialize predictor layers of the head.""" + self.conv_cls = nn.Conv2d( + self.feat_channels, self.num_classes, 3, padding=1) + self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1) + + def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor], List[Tensor]]: + """Forward features from the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: A tuple of each level outputs. + + - cls_scores (list[Tensor]): Box scores for each scale level, \ + each is a 4D-tensor, the channel number is num_classes. + - bbox_preds (list[Tensor]): Box energies / deltas for each \ + scale level, each is a 4D-tensor, the channel number is 4. + """ + return multi_apply(self.forward_single, x, self.scales, self.strides) + + def forward_single(self, x: Tensor, scale: Scale, + stride: int) -> Tuple[Tensor, Tensor]: + """Forward features of a single scale level. + + Args: + x (Tensor): FPN feature maps of the specified stride. + scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize + the bbox prediction. + stride (int): The corresponding stride for feature maps. 
+ + Returns: + tuple: scores for each class, bbox predictions of + input feature maps. + """ + cls_score, bbox_pred, _, _ = super().forward_single(x) + # scale the bbox_pred of different level + # float to avoid overflow when enabling FP16 + bbox_pred = scale(bbox_pred).float() + # bbox_pred needed for gradient computation has been modified + # by F.relu(bbox_pred) when run with PyTorch 1.10. So replace + # F.relu(bbox_pred) with bbox_pred.clamp(min=0) + bbox_pred = bbox_pred.clamp(min=0) + if not self.training: + bbox_pred *= stride + return cls_score, bbox_pred + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None + ) -> Dict[str, Tensor]: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level, + each is a 4D-tensor, the channel number is num_classes. + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level, each is a 4D-tensor, the channel number is 4. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + num_imgs = cls_scores[0].size(0) + assert len(cls_scores) == len(bbox_preds) + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + all_level_points = self.prior_generator.grid_priors( + featmap_sizes, + dtype=bbox_preds[0].dtype, + device=bbox_preds[0].device) + + # 1 flatten outputs + flatten_cls_scores = [ + cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels) + for cls_score in cls_scores + ] + flatten_bbox_preds = [ + bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) + for bbox_pred in bbox_preds + ] + flatten_cls_scores = torch.cat(flatten_cls_scores) + flatten_bbox_preds = torch.cat(flatten_bbox_preds) + + # repeat points to align with bbox_preds + flatten_points = torch.cat( + [points.repeat(num_imgs, 1) for points in all_level_points]) + + assert (torch.isfinite(flatten_bbox_preds).all().item()) + + # 2 calc reg and cls branch targets + cls_targets, bbox_targets = self.get_targets(all_level_points, + batch_gt_instances) + + # 3 add more pos index for cls branch + featmap_sizes = flatten_points.new_tensor(featmap_sizes) + pos_inds, cls_labels = self.add_cls_pos_inds(flatten_points, + flatten_bbox_preds, + featmap_sizes, + batch_gt_instances) + + # 4 calc cls loss + if pos_inds is None: + # num_gts=0 + num_pos_cls = bbox_preds[0].new_tensor(0, dtype=torch.float) + else: + num_pos_cls = bbox_preds[0].new_tensor( + len(pos_inds), dtype=torch.float) + num_pos_cls = max(reduce_mean(num_pos_cls), 1.0) + flatten_cls_scores = flatten_cls_scores.sigmoid().clamp( + min=self.sigmoid_clamp, max=1 - self.sigmoid_clamp) + cls_loss = self.loss_cls( + flatten_cls_scores, + cls_targets, + pos_inds=pos_inds, + pos_labels=cls_labels, + avg_factor=num_pos_cls) + + # 5 calc reg loss + pos_bbox_inds = torch.nonzero( + bbox_targets.max(dim=1)[0] >= 0).squeeze(1) + pos_bbox_preds = flatten_bbox_preds[pos_bbox_inds] + pos_bbox_targets = bbox_targets[pos_bbox_inds] + + 
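
For orientation, the flattening above turns the per-level (B, C, H, W) maps into one (sum(B*H_i*W_i), C) tensor, with rows ordered level by level and, within a level, image by image; `points.repeat(num_imgs, 1)` reproduces exactly that layout. A small sketch with illustrative sizes:

import torch

cls_scores = [torch.rand(2, 3, 8, 8), torch.rand(2, 3, 4, 4)]   # 2 imgs, 3 classes
flat = torch.cat(
    [s.permute(0, 2, 3, 1).reshape(-1, 3) for s in cls_scores])
assert flat.shape == (2 * 64 + 2 * 16, 3)                       # (160, 3)

points = [torch.rand(64, 2), torch.rand(16, 2)]                 # per-level priors
flat_points = torch.cat([p.repeat(2, 1) for p in points])
assert flat_points.shape == (160, 2)    # rows line up with the flattened scores
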
bbox_weight_map = cls_targets.max(dim=1)[0] + bbox_weight_map = bbox_weight_map[pos_bbox_inds] + bbox_weight_map = bbox_weight_map if self.soft_weight_on_reg \ + else torch.ones_like(bbox_weight_map) + num_pos_bbox = max(reduce_mean(bbox_weight_map.sum()), 1.0) + + if len(pos_bbox_inds) > 0: + pos_points = flatten_points[pos_bbox_inds] + pos_decoded_bbox_preds = self.bbox_coder.decode( + pos_points, pos_bbox_preds) + pos_decoded_target_preds = self.bbox_coder.decode( + pos_points, pos_bbox_targets) + bbox_loss = self.loss_bbox( + pos_decoded_bbox_preds, + pos_decoded_target_preds, + weight=bbox_weight_map, + avg_factor=num_pos_bbox) + else: + bbox_loss = flatten_bbox_preds.sum() * 0 + + return dict(loss_cls=cls_loss, loss_bbox=bbox_loss) + + def get_targets( + self, + points: List[Tensor], + batch_gt_instances: InstanceList, + ) -> Tuple[Tensor, Tensor]: + """Compute classification and bbox targets for points in multiple + images. + + Args: + points (list[Tensor]): Points of each fpn level, each has shape + (num_points, 2). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + + Returns: + tuple: Targets of each level. + + - concat_lvl_labels (Tensor): Labels of all level and batch. + - concat_lvl_bbox_targets (Tensor): BBox targets of all \ + level and batch. + """ + assert len(points) == len(self.regress_ranges) + + num_levels = len(points) + # the number of points per img, per lvl + num_points = [center.size(0) for center in points] + + # expand regress ranges to align with points + expanded_regress_ranges = [ + points[i].new_tensor(self.regress_ranges[i])[None].expand_as( + points[i]) for i in range(num_levels) + ] + # concat all levels points and regress ranges + concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0) + concat_points = torch.cat(points, dim=0) + concat_strides = torch.cat([ + concat_points.new_ones(num_points[i]) * self.strides[i] + for i in range(num_levels) + ]) + + # get labels and bbox_targets of each image + cls_targets_list, bbox_targets_list = multi_apply( + self._get_targets_single, + batch_gt_instances, + points=concat_points, + regress_ranges=concat_regress_ranges, + strides=concat_strides) + + bbox_targets_list = _transpose(bbox_targets_list, num_points) + cls_targets_list = _transpose(cls_targets_list, num_points) + concat_lvl_bbox_targets = torch.cat(bbox_targets_list, 0) + concat_lvl_cls_targets = torch.cat(cls_targets_list, dim=0) + return concat_lvl_cls_targets, concat_lvl_bbox_targets + + def _get_targets_single(self, gt_instances: InstanceData, points: Tensor, + regress_ranges: Tensor, + strides: Tensor) -> Tuple[Tensor, Tensor]: + """Compute classification and bbox targets for a single image.""" + num_points = points.size(0) + num_gts = len(gt_instances) + gt_bboxes = gt_instances.bboxes + gt_labels = gt_instances.labels + + if num_gts == 0: + return gt_labels.new_full((num_points, + self.num_classes), + self.num_classes), \ + gt_bboxes.new_full((num_points, 4), -1) + + # Calculate the regression tblr target corresponding to all points + points = points[:, None].expand(num_points, num_gts, 2) + gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4) + strides = strides[:, None, None].expand(num_points, num_gts, 2) + + bbox_target = bbox2distance(points, gt_bboxes) # M x N x 4 + + # condition1: inside a gt bbox + inside_gt_bbox_mask = bbox_target.min(dim=2)[0] > 0 # M x N + + # condition2: Calculate the nearest points from + # the upper, lower, left and 
right ranges from + # the center of the gt bbox + centers = ((gt_bboxes[..., [0, 1]] + gt_bboxes[..., [2, 3]]) / 2) + centers_discret = ((centers / strides).int() * strides).float() + \ + strides / 2 + + centers_discret_dist = points - centers_discret + dist_x = centers_discret_dist[..., 0].abs() + dist_y = centers_discret_dist[..., 1].abs() + inside_gt_center3x3_mask = (dist_x <= strides[..., 0]) & \ + (dist_y <= strides[..., 0]) + + # condition3: limit the regression range for each location + bbox_target_wh = bbox_target[..., :2] + bbox_target[..., 2:] + crit = (bbox_target_wh**2).sum(dim=2)**0.5 / 2 + inside_fpn_level_mask = (crit >= regress_ranges[:, [0]]) & \ + (crit <= regress_ranges[:, [1]]) + bbox_target_mask = inside_gt_bbox_mask & \ + inside_gt_center3x3_mask & \ + inside_fpn_level_mask + + # Calculate the distance weight map + gt_center_peak_mask = ((centers_discret_dist**2).sum(dim=2) == 0) + weighted_dist = ((points - centers)**2).sum(dim=2) # M x N + weighted_dist[gt_center_peak_mask] = 0 + + areas = (gt_bboxes[..., 2] - gt_bboxes[..., 0]) * ( + gt_bboxes[..., 3] - gt_bboxes[..., 1]) + radius = self.delta**2 * 2 * areas + radius = torch.clamp(radius, min=self.hm_min_radius**2) + weighted_dist = weighted_dist / radius + + # Calculate bbox_target + bbox_weighted_dist = weighted_dist.clone() + bbox_weighted_dist[bbox_target_mask == 0] = INF * 1.0 + min_dist, min_inds = bbox_weighted_dist.min(dim=1) + bbox_target = bbox_target[range(len(bbox_target)), + min_inds] # M x N x 4 --> M x 4 + bbox_target[min_dist == INF] = -INF + + # Convert to feature map scale + bbox_target /= strides[:, 0, :].repeat(1, 2) + + # Calculate cls_target + cls_target = self._create_heatmaps_from_dist(weighted_dist, gt_labels) + + return cls_target, bbox_target + + @torch.no_grad() + def add_cls_pos_inds( + self, flatten_points: Tensor, flatten_bbox_preds: Tensor, + featmap_sizes: Tensor, batch_gt_instances: InstanceList + ) -> Tuple[Optional[Tensor], Optional[Tensor]]: + """Provide additional adaptive positive samples to the classification + branch. + + Args: + flatten_points (Tensor): The point after flatten, including + batch image and all levels. The shape is (N, 2). + flatten_bbox_preds (Tensor): The bbox predicts after flatten, + including batch image and all levels. The shape is (N, 4). + featmap_sizes (Tensor): Feature map size of all layers. + The shape is (5, 2). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + + Returns: + tuple: + + - pos_inds (Tensor): Adaptively selected positive sample index. + - cls_labels (Tensor): Corresponding positive class label. 
+ """ + outputs = self._get_center3x3_region_index_targets( + batch_gt_instances, featmap_sizes) + cls_labels, fpn_level_masks, center3x3_inds, \ + center3x3_bbox_targets, center3x3_masks = outputs + + num_gts, total_level, K = cls_labels.shape[0], len( + self.strides), center3x3_masks.shape[-1] + + if num_gts == 0: + return None, None + + # The out-of-bounds index is forcibly set to 0 + # to prevent loss calculation errors + center3x3_inds[center3x3_masks == 0] = 0 + reg_pred_center3x3 = flatten_bbox_preds[center3x3_inds] + center3x3_points = flatten_points[center3x3_inds].view(-1, 2) + + center3x3_bbox_targets_expand = center3x3_bbox_targets.view( + -1, 4).clamp(min=0) + + pos_decoded_bbox_preds = self.bbox_coder.decode( + center3x3_points, reg_pred_center3x3.view(-1, 4)) + pos_decoded_target_preds = self.bbox_coder.decode( + center3x3_points, center3x3_bbox_targets_expand) + center3x3_bbox_loss = self.loss_bbox( + pos_decoded_bbox_preds, + pos_decoded_target_preds, + None, + reduction_override='none').view(num_gts, total_level, + K) / self.loss_bbox.loss_weight + + # Invalid index Loss set to infinity + center3x3_bbox_loss[center3x3_masks == 0] = INF + + # 4 is the center point of the sampled 9 points, the center point + # of gt bbox after discretization. + # The center point of gt bbox after discretization + # must be a positive sample, so we force its loss to be set to 0. + center3x3_bbox_loss.view(-1, K)[fpn_level_masks.view(-1), 4] = 0 + center3x3_bbox_loss = center3x3_bbox_loss.view(num_gts, -1) + + loss_thr = torch.kthvalue( + center3x3_bbox_loss, self.more_pos_topk, dim=1)[0] + + loss_thr[loss_thr > self.more_pos_thresh] = self.more_pos_thresh + new_pos = center3x3_bbox_loss < loss_thr.view(num_gts, 1) + pos_inds = center3x3_inds.view(num_gts, -1)[new_pos] + cls_labels = cls_labels.view(num_gts, + 1).expand(num_gts, + total_level * K)[new_pos] + return pos_inds, cls_labels + + def _create_heatmaps_from_dist(self, weighted_dist: Tensor, + cls_labels: Tensor) -> Tensor: + """Generate heatmaps of classification branch based on weighted + distance map.""" + heatmaps = weighted_dist.new_zeros( + (weighted_dist.shape[0], self.num_classes)) + for c in range(self.num_classes): + inds = (cls_labels == c) # N + if inds.int().sum() == 0: + continue + heatmaps[:, c] = torch.exp(-weighted_dist[:, inds].min(dim=1)[0]) + zeros = heatmaps[:, c] < 1e-4 + heatmaps[zeros, c] = 0 + return heatmaps + + def _get_center3x3_region_index_targets(self, + bacth_gt_instances: InstanceList, + shapes_per_level: Tensor) -> tuple: + """Get the center (and the 3x3 region near center) locations and target + of each objects.""" + cls_labels = [] + inside_fpn_level_masks = [] + center3x3_inds = [] + center3x3_masks = [] + center3x3_bbox_targets = [] + + total_levels = len(self.strides) + batch = len(bacth_gt_instances) + + shapes_per_level = shapes_per_level.long() + area_per_level = (shapes_per_level[:, 0] * shapes_per_level[:, 1]) + + # Select a total of 9 positions of 3x3 in the center of the gt bbox + # as candidate positive samples + K = 9 + dx = shapes_per_level.new_tensor([-1, 0, 1, -1, 0, 1, -1, 0, + 1]).view(1, 1, K) + dy = shapes_per_level.new_tensor([-1, -1, -1, 0, 0, 0, 1, 1, + 1]).view(1, 1, K) + + regress_ranges = shapes_per_level.new_tensor(self.regress_ranges).view( + len(self.regress_ranges), 2) # L x 2 + strides = shapes_per_level.new_tensor(self.strides) + + start_coord_pre_level = [] + _start = 0 + for level in range(total_levels): + start_coord_pre_level.append(_start) + _start = _start + batch 
+    def _get_center3x3_region_index_targets(self,
+                                            batch_gt_instances: InstanceList,
+                                            shapes_per_level: Tensor) -> tuple:
+        """Get the center (and the 3x3 region near center) locations and
+        targets of each object."""
+        cls_labels = []
+        inside_fpn_level_masks = []
+        center3x3_inds = []
+        center3x3_masks = []
+        center3x3_bbox_targets = []
+
+        total_levels = len(self.strides)
+        batch = len(batch_gt_instances)
+
+        shapes_per_level = shapes_per_level.long()
+        area_per_level = (shapes_per_level[:, 0] * shapes_per_level[:, 1])
+
+        # Select a total of 9 positions of 3x3 in the center of the gt bbox
+        # as candidate positive samples
+        K = 9
+        dx = shapes_per_level.new_tensor([-1, 0, 1, -1, 0, 1, -1, 0,
+                                          1]).view(1, 1, K)
+        dy = shapes_per_level.new_tensor([-1, -1, -1, 0, 0, 0, 1, 1,
+                                          1]).view(1, 1, K)
+
+        regress_ranges = shapes_per_level.new_tensor(self.regress_ranges).view(
+            len(self.regress_ranges), 2)  # L x 2
+        strides = shapes_per_level.new_tensor(self.strides)
+
+        start_coord_pre_level = []
+        _start = 0
+        for level in range(total_levels):
+            start_coord_pre_level.append(_start)
+            _start = _start + batch * area_per_level[level]
+        start_coord_pre_level = shapes_per_level.new_tensor(
+            start_coord_pre_level).view(1, total_levels, 1)
+        area_per_level = area_per_level.view(1, total_levels, 1)
+
+        for im_i in range(batch):
+            gt_instance = batch_gt_instances[im_i]
+            gt_bboxes = gt_instance.bboxes
+            gt_labels = gt_instance.labels
+            num_gts = gt_bboxes.shape[0]
+            if num_gts == 0:
+                continue
+
+            cls_labels.append(gt_labels)
+
+            gt_bboxes = gt_bboxes[:, None].expand(num_gts, total_levels, 4)
+            expanded_strides = strides[None, :,
+                                       None].expand(num_gts, total_levels, 2)
+            expanded_regress_ranges = regress_ranges[None].expand(
+                num_gts, total_levels, 2)
+            expanded_shapes_per_level = shapes_per_level[None].expand(
+                num_gts, total_levels, 2)
+
+            # calc reg_target
+            centers = ((gt_bboxes[..., [0, 1]] + gt_bboxes[..., [2, 3]]) / 2)
+            centers_inds = (centers / expanded_strides).long()
+            centers_discret = centers_inds * expanded_strides \
+                + expanded_strides // 2
+
+            bbox_target = bbox2distance(centers_discret,
+                                        gt_bboxes)  # M x N x 4
+
+            # calc inside_fpn_level_mask
+            bbox_target_wh = bbox_target[..., :2] + bbox_target[..., 2:]
+            crit = (bbox_target_wh**2).sum(dim=2)**0.5 / 2
+            inside_fpn_level_mask = \
+                (crit >= expanded_regress_ranges[..., 0]) & \
+                (crit <= expanded_regress_ranges[..., 1])
+
+            inside_gt_bbox_mask = bbox_target.min(dim=2)[0] >= 0
+            inside_fpn_level_mask = inside_gt_bbox_mask & inside_fpn_level_mask
+            inside_fpn_level_masks.append(inside_fpn_level_mask)
+
+            # calc center3x3_ind and mask
+            expand_ws = expanded_shapes_per_level[..., 1:2].expand(
+                num_gts, total_levels, K)
+            expand_hs = expanded_shapes_per_level[..., 0:1].expand(
+                num_gts, total_levels, K)
+            centers_inds_x = centers_inds[..., 0:1]
+            centers_inds_y = centers_inds[..., 1:2]
+
+            center3x3_idx = start_coord_pre_level + \
+                im_i * area_per_level + \
+                (centers_inds_y + dy) * expand_ws + \
+                (centers_inds_x + dx)
+            center3x3_mask = \
+                ((centers_inds_y + dy) < expand_hs) & \
+                ((centers_inds_y + dy) >= 0) & \
+                ((centers_inds_x + dx) < expand_ws) & \
+                ((centers_inds_x + dx) >= 0)
+
+            # recalc center3x3 region reg target
+            bbox_target = bbox_target / expanded_strides.repeat(1, 1, 2)
+            center3x3_bbox_target = bbox_target[..., None, :].expand(
+                num_gts, total_levels, K, 4).clone()
+            center3x3_bbox_target[..., 0] += dx
+            center3x3_bbox_target[..., 1] += dy
+            center3x3_bbox_target[..., 2] -= dx
+            center3x3_bbox_target[..., 3] -= dy
+            # update center3x3_mask
+            center3x3_mask = center3x3_mask & (
+                center3x3_bbox_target.min(dim=3)[0] >= 0)  # n x L x K
+
+            center3x3_inds.append(center3x3_idx)
+            center3x3_masks.append(center3x3_mask)
+            center3x3_bbox_targets.append(center3x3_bbox_target)
+
+        if len(inside_fpn_level_masks) > 0:
+            cls_labels = torch.cat(cls_labels, dim=0)
+            inside_fpn_level_masks = torch.cat(inside_fpn_level_masks, dim=0)
+            center3x3_inds = torch.cat(center3x3_inds, dim=0).long()
+            center3x3_bbox_targets = torch.cat(center3x3_bbox_targets, dim=0)
+            center3x3_masks = torch.cat(center3x3_masks, dim=0)
+        else:
+            cls_labels = shapes_per_level.new_zeros(0).long()
+            inside_fpn_level_masks = shapes_per_level.new_zeros(
+                (0, total_levels)).bool()
+            center3x3_inds = shapes_per_level.new_zeros(
+                (0, total_levels, K)).long()
+            center3x3_bbox_targets = shapes_per_level.new_zeros(
+                (0, total_levels, K, 4)).float()
+            center3x3_masks = shapes_per_level.new_zeros(
+                (0, total_levels, K)).bool()
+        return cls_labels, inside_fpn_level_masks, center3x3_inds, \
+            center3x3_bbox_targets, center3x3_masks
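
The `center3x3_idx` arithmetic above maps a (level, image, y, x) location into the flattened prediction tensor built in `loss_by_feat`, whose rows are ordered level-first, then image, then row-major spatial position. A worked example with two illustrative levels:

shapes = [(8, 8), (4, 4)]                 # (H, W) per level
batch = 2
areas = [h * w for h, w in shapes]
starts = [0, batch * areas[0]]            # start_coord_pre_level: [0, 128]

level, im_i, y, x = 1, 1, 3, 2
w = shapes[level][1]
flat_idx = starts[level] + im_i * areas[level] + y * w + x
assert flat_idx == 158                    # 128 + 16 + 3 * 4 + 2
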
diff --git a/mmdet/models/dense_heads/centripetal_head.py b/mmdet/models/dense_heads/centripetal_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..18f6601ff82394864d53351b10b40f51eb2aec6b
--- /dev/null
+++ b/mmdet/models/dense_heads/centripetal_head.py
@@ -0,0 +1,459 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple
+
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.ops import DeformConv2d
+from mmengine.model import normal_init
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.utils import (ConfigType, InstanceList, OptInstanceList,
+                         OptMultiConfig)
+from ..utils import multi_apply
+from .corner_head import CornerHead
+
+
+@MODELS.register_module()
+class CentripetalHead(CornerHead):
+    """Head of CentripetalNet: Pursuing High-quality Keypoint Pairs for Object
+    Detection.
+
+    CentripetalHead inherits from :class:`CornerHead`. It removes the
+    embedding branch and adds guiding shift and centripetal shift branches.
+    More details can be found in the `paper
+    <https://arxiv.org/abs/2003.09119>`_ .
+
+    Args:
+        num_classes (int): Number of categories excluding the background
+            category.
+        in_channels (int): Number of channels in the input feature map.
+        num_feat_levels (int): Levels of feature from the previous module.
+            2 for HourglassNet-104 and 1 for HourglassNet-52. HourglassNet-104
+            outputs the final feature and intermediate supervision feature and
+            HourglassNet-52 only outputs the final feature. Defaults to 2.
+        corner_emb_channels (int): Channel of embedding vector. Defaults to 1.
+        train_cfg (:obj:`ConfigDict` or dict, optional): Training config.
+            Useless in CornerHead, but we keep this variable for
+            SingleStageDetector.
+        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
+            CornerHead.
+        loss_heatmap (:obj:`ConfigDict` or dict): Config of corner heatmap
+            loss. Defaults to GaussianFocalLoss.
+        loss_embedding (:obj:`ConfigDict` or dict): Config of corner embedding
+            loss. Defaults to AssociativeEmbeddingLoss.
+        loss_offset (:obj:`ConfigDict` or dict): Config of corner offset loss.
+            Defaults to SmoothL1Loss.
+        loss_guiding_shift (:obj:`ConfigDict` or dict): Config of
+            guiding shift loss. Defaults to SmoothL1Loss.
+        loss_centripetal_shift (:obj:`ConfigDict` or dict): Config of
+            centripetal shift loss. Defaults to SmoothL1Loss.
+        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
+            the initialization.
+ """ + + def __init__(self, + *args, + centripetal_shift_channels: int = 2, + guiding_shift_channels: int = 2, + feat_adaption_conv_kernel: int = 3, + loss_guiding_shift: ConfigType = dict( + type='SmoothL1Loss', beta=1.0, loss_weight=0.05), + loss_centripetal_shift: ConfigType = dict( + type='SmoothL1Loss', beta=1.0, loss_weight=1), + init_cfg: OptMultiConfig = None, + **kwargs) -> None: + assert init_cfg is None, 'To prevent abnormal initialization ' \ + 'behavior, init_cfg is not allowed to be set' + assert centripetal_shift_channels == 2, ( + 'CentripetalHead only support centripetal_shift_channels == 2') + self.centripetal_shift_channels = centripetal_shift_channels + assert guiding_shift_channels == 2, ( + 'CentripetalHead only support guiding_shift_channels == 2') + self.guiding_shift_channels = guiding_shift_channels + self.feat_adaption_conv_kernel = feat_adaption_conv_kernel + super().__init__(*args, init_cfg=init_cfg, **kwargs) + self.loss_guiding_shift = MODELS.build(loss_guiding_shift) + self.loss_centripetal_shift = MODELS.build(loss_centripetal_shift) + + def _init_centripetal_layers(self) -> None: + """Initialize centripetal layers. + + Including feature adaption deform convs (feat_adaption), deform offset + prediction convs (dcn_off), guiding shift (guiding_shift) and + centripetal shift ( centripetal_shift). Each branch has two parts: + prefix `tl_` for top-left and `br_` for bottom-right. + """ + self.tl_feat_adaption = nn.ModuleList() + self.br_feat_adaption = nn.ModuleList() + self.tl_dcn_offset = nn.ModuleList() + self.br_dcn_offset = nn.ModuleList() + self.tl_guiding_shift = nn.ModuleList() + self.br_guiding_shift = nn.ModuleList() + self.tl_centripetal_shift = nn.ModuleList() + self.br_centripetal_shift = nn.ModuleList() + + for _ in range(self.num_feat_levels): + self.tl_feat_adaption.append( + DeformConv2d(self.in_channels, self.in_channels, + self.feat_adaption_conv_kernel, 1, 1)) + self.br_feat_adaption.append( + DeformConv2d(self.in_channels, self.in_channels, + self.feat_adaption_conv_kernel, 1, 1)) + + self.tl_guiding_shift.append( + self._make_layers( + out_channels=self.guiding_shift_channels, + in_channels=self.in_channels)) + self.br_guiding_shift.append( + self._make_layers( + out_channels=self.guiding_shift_channels, + in_channels=self.in_channels)) + + self.tl_dcn_offset.append( + ConvModule( + self.guiding_shift_channels, + self.feat_adaption_conv_kernel**2 * + self.guiding_shift_channels, + 1, + bias=False, + act_cfg=None)) + self.br_dcn_offset.append( + ConvModule( + self.guiding_shift_channels, + self.feat_adaption_conv_kernel**2 * + self.guiding_shift_channels, + 1, + bias=False, + act_cfg=None)) + + self.tl_centripetal_shift.append( + self._make_layers( + out_channels=self.centripetal_shift_channels, + in_channels=self.in_channels)) + self.br_centripetal_shift.append( + self._make_layers( + out_channels=self.centripetal_shift_channels, + in_channels=self.in_channels)) + + def _init_layers(self) -> None: + """Initialize layers for CentripetalHead. 
+ + Including two parts: CornerHead layers and CentripetalHead layers + """ + super()._init_layers() # using _init_layers in CornerHead + self._init_centripetal_layers() + + def init_weights(self) -> None: + super().init_weights() + for i in range(self.num_feat_levels): + normal_init(self.tl_feat_adaption[i], std=0.01) + normal_init(self.br_feat_adaption[i], std=0.01) + normal_init(self.tl_dcn_offset[i].conv, std=0.1) + normal_init(self.br_dcn_offset[i].conv, std=0.1) + _ = [x.conv.reset_parameters() for x in self.tl_guiding_shift[i]] + _ = [x.conv.reset_parameters() for x in self.br_guiding_shift[i]] + _ = [ + x.conv.reset_parameters() for x in self.tl_centripetal_shift[i] + ] + _ = [ + x.conv.reset_parameters() for x in self.br_centripetal_shift[i] + ] + + def forward_single(self, x: Tensor, lvl_ind: int) -> List[Tensor]: + """Forward feature of a single level. + + Args: + x (Tensor): Feature of a single level. + lvl_ind (int): Level index of current feature. + + Returns: + tuple[Tensor]: A tuple of CentripetalHead's output for current + feature level. Containing the following Tensors: + + - tl_heat (Tensor): Predicted top-left corner heatmap. + - br_heat (Tensor): Predicted bottom-right corner heatmap. + - tl_off (Tensor): Predicted top-left offset heatmap. + - br_off (Tensor): Predicted bottom-right offset heatmap. + - tl_guiding_shift (Tensor): Predicted top-left guiding shift + heatmap. + - br_guiding_shift (Tensor): Predicted bottom-right guiding + shift heatmap. + - tl_centripetal_shift (Tensor): Predicted top-left centripetal + shift heatmap. + - br_centripetal_shift (Tensor): Predicted bottom-right + centripetal shift heatmap. + """ + tl_heat, br_heat, _, _, tl_off, br_off, tl_pool, br_pool = super( + ).forward_single( + x, lvl_ind, return_pool=True) + + tl_guiding_shift = self.tl_guiding_shift[lvl_ind](tl_pool) + br_guiding_shift = self.br_guiding_shift[lvl_ind](br_pool) + + tl_dcn_offset = self.tl_dcn_offset[lvl_ind](tl_guiding_shift.detach()) + br_dcn_offset = self.br_dcn_offset[lvl_ind](br_guiding_shift.detach()) + + tl_feat_adaption = self.tl_feat_adaption[lvl_ind](tl_pool, + tl_dcn_offset) + br_feat_adaption = self.br_feat_adaption[lvl_ind](br_pool, + br_dcn_offset) + + tl_centripetal_shift = self.tl_centripetal_shift[lvl_ind]( + tl_feat_adaption) + br_centripetal_shift = self.br_centripetal_shift[lvl_ind]( + br_feat_adaption) + + result_list = [ + tl_heat, br_heat, tl_off, br_off, tl_guiding_shift, + br_guiding_shift, tl_centripetal_shift, br_centripetal_shift + ] + return result_list + + def loss_by_feat( + self, + tl_heats: List[Tensor], + br_heats: List[Tensor], + tl_offs: List[Tensor], + br_offs: List[Tensor], + tl_guiding_shifts: List[Tensor], + br_guiding_shifts: List[Tensor], + tl_centripetal_shifts: List[Tensor], + br_centripetal_shifts: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + tl_heats (list[Tensor]): Top-left corner heatmaps for each level + with shape (N, num_classes, H, W). + br_heats (list[Tensor]): Bottom-right corner heatmaps for each + level with shape (N, num_classes, H, W). + tl_offs (list[Tensor]): Top-left corner offsets for each level + with shape (N, corner_offset_channels, H, W). + br_offs (list[Tensor]): Bottom-right corner offsets for each level + with shape (N, corner_offset_channels, H, W). 
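
To make the data flow of `forward_single` above concrete: the guiding shift is predicted from the pooled corner feature, converted (detached) into deform-conv offsets, and the adapted feature then predicts the centripetal shift. A minimal single-branch sketch, assuming illustrative channel sizes and plain convs in place of `_make_layers` (requires mmcv's compiled ops):

import torch
import torch.nn as nn
from mmcv.ops import DeformConv2d

feat = torch.rand(1, 256, 32, 32)                  # pooled corner feature
guiding_head = nn.Conv2d(256, 2, 3, padding=1)     # stand-in for _make_layers
dcn_offset_head = nn.Conv2d(2, 3 * 3 * 2, 1, bias=False)
feat_adaption = DeformConv2d(256, 256, 3, 1, 1)
centripetal_head = nn.Conv2d(256, 2, 3, padding=1)

guiding_shift = guiding_head(feat)
offsets = dcn_offset_head(guiding_shift.detach())  # detach: adaption gradients
adapted = feat_adaption(feat, offsets)             # don't reach guiding branch
centripetal_shift = centripetal_head(adapted)      # (1, 2, 32, 32)
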
+ tl_guiding_shifts (list[Tensor]): Top-left guiding shifts for each + level with shape (N, guiding_shift_channels, H, W). + br_guiding_shifts (list[Tensor]): Bottom-right guiding shifts for + each level with shape (N, guiding_shift_channels, H, W). + tl_centripetal_shifts (list[Tensor]): Top-left centripetal shifts + for each level with shape (N, centripetal_shift_channels, H, + W). + br_centripetal_shifts (list[Tensor]): Bottom-right centripetal + shifts for each level with shape (N, + centripetal_shift_channels, H, W). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Specify which bounding boxes can be ignored when computing + the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. Containing the + following losses: + + - det_loss (list[Tensor]): Corner keypoint losses of all + feature levels. + - off_loss (list[Tensor]): Corner offset losses of all feature + levels. + - guiding_loss (list[Tensor]): Guiding shift losses of all + feature levels. + - centripetal_loss (list[Tensor]): Centripetal shift losses of + all feature levels. + """ + gt_bboxes = [ + gt_instances.bboxes for gt_instances in batch_gt_instances + ] + gt_labels = [ + gt_instances.labels for gt_instances in batch_gt_instances + ] + + targets = self.get_targets( + gt_bboxes, + gt_labels, + tl_heats[-1].shape, + batch_img_metas[0]['batch_input_shape'], + with_corner_emb=self.with_corner_emb, + with_guiding_shift=True, + with_centripetal_shift=True) + mlvl_targets = [targets for _ in range(self.num_feat_levels)] + [det_losses, off_losses, guiding_losses, centripetal_losses + ] = multi_apply(self.loss_by_feat_single, tl_heats, br_heats, tl_offs, + br_offs, tl_guiding_shifts, br_guiding_shifts, + tl_centripetal_shifts, br_centripetal_shifts, + mlvl_targets) + loss_dict = dict( + det_loss=det_losses, + off_loss=off_losses, + guiding_loss=guiding_losses, + centripetal_loss=centripetal_losses) + return loss_dict + + def loss_by_feat_single(self, tl_hmp: Tensor, br_hmp: Tensor, + tl_off: Tensor, br_off: Tensor, + tl_guiding_shift: Tensor, br_guiding_shift: Tensor, + tl_centripetal_shift: Tensor, + br_centripetal_shift: Tensor, + targets: dict) -> Tuple[Tensor, ...]: + """Calculate the loss of a single scale level based on the features + extracted by the detection head. + + Args: + tl_hmp (Tensor): Top-left corner heatmap for current level with + shape (N, num_classes, H, W). + br_hmp (Tensor): Bottom-right corner heatmap for current level with + shape (N, num_classes, H, W). + tl_off (Tensor): Top-left corner offset for current level with + shape (N, corner_offset_channels, H, W). + br_off (Tensor): Bottom-right corner offset for current level with + shape (N, corner_offset_channels, H, W). + tl_guiding_shift (Tensor): Top-left guiding shift for current level + with shape (N, guiding_shift_channels, H, W). + br_guiding_shift (Tensor): Bottom-right guiding shift for current + level with shape (N, guiding_shift_channels, H, W). + tl_centripetal_shift (Tensor): Top-left centripetal shift for + current level with shape (N, centripetal_shift_channels, H, W). + br_centripetal_shift (Tensor): Bottom-right centripetal shift for + current level with shape (N, centripetal_shift_channels, H, W). + targets (dict): Corner target generated by `get_targets`. 
+ + Returns: + tuple[torch.Tensor]: Losses of the head's different branches + containing the following losses: + + - det_loss (Tensor): Corner keypoint loss. + - off_loss (Tensor): Corner offset loss. + - guiding_loss (Tensor): Guiding shift loss. + - centripetal_loss (Tensor): Centripetal shift loss. + """ + targets['corner_embedding'] = None + + det_loss, _, _, off_loss = super().loss_by_feat_single( + tl_hmp, br_hmp, None, None, tl_off, br_off, targets) + + gt_tl_guiding_shift = targets['topleft_guiding_shift'] + gt_br_guiding_shift = targets['bottomright_guiding_shift'] + gt_tl_centripetal_shift = targets['topleft_centripetal_shift'] + gt_br_centripetal_shift = targets['bottomright_centripetal_shift'] + + gt_tl_heatmap = targets['topleft_heatmap'] + gt_br_heatmap = targets['bottomright_heatmap'] + # We only compute the offset loss at the real corner position. + # The value of real corner would be 1 in heatmap ground truth. + # The mask is computed in class agnostic mode and its shape is + # batch * 1 * width * height. + tl_mask = gt_tl_heatmap.eq(1).sum(1).gt(0).unsqueeze(1).type_as( + gt_tl_heatmap) + br_mask = gt_br_heatmap.eq(1).sum(1).gt(0).unsqueeze(1).type_as( + gt_br_heatmap) + + # Guiding shift loss + tl_guiding_loss = self.loss_guiding_shift( + tl_guiding_shift, + gt_tl_guiding_shift, + tl_mask, + avg_factor=tl_mask.sum()) + br_guiding_loss = self.loss_guiding_shift( + br_guiding_shift, + gt_br_guiding_shift, + br_mask, + avg_factor=br_mask.sum()) + guiding_loss = (tl_guiding_loss + br_guiding_loss) / 2.0 + # Centripetal shift loss + tl_centripetal_loss = self.loss_centripetal_shift( + tl_centripetal_shift, + gt_tl_centripetal_shift, + tl_mask, + avg_factor=tl_mask.sum()) + br_centripetal_loss = self.loss_centripetal_shift( + br_centripetal_shift, + gt_br_centripetal_shift, + br_mask, + avg_factor=br_mask.sum()) + centripetal_loss = (tl_centripetal_loss + br_centripetal_loss) / 2.0 + + return det_loss, off_loss, guiding_loss, centripetal_loss + + def predict_by_feat(self, + tl_heats: List[Tensor], + br_heats: List[Tensor], + tl_offs: List[Tensor], + br_offs: List[Tensor], + tl_guiding_shifts: List[Tensor], + br_guiding_shifts: List[Tensor], + tl_centripetal_shifts: List[Tensor], + br_centripetal_shifts: List[Tensor], + batch_img_metas: Optional[List[dict]] = None, + rescale: bool = False, + with_nms: bool = True) -> InstanceList: + """Transform a batch of output features extracted from the head into + bbox results. + + Args: + tl_heats (list[Tensor]): Top-left corner heatmaps for each level + with shape (N, num_classes, H, W). + br_heats (list[Tensor]): Bottom-right corner heatmaps for each + level with shape (N, num_classes, H, W). + tl_offs (list[Tensor]): Top-left corner offsets for each level + with shape (N, corner_offset_channels, H, W). + br_offs (list[Tensor]): Bottom-right corner offsets for each level + with shape (N, corner_offset_channels, H, W). + tl_guiding_shifts (list[Tensor]): Top-left guiding shifts for each + level with shape (N, guiding_shift_channels, H, W). Useless in + this function, we keep this arg because it's the raw output + from CentripetalHead. + br_guiding_shifts (list[Tensor]): Bottom-right guiding shifts for + each level with shape (N, guiding_shift_channels, H, W). + Useless in this function, we keep this arg because it's the + raw output from CentripetalHead. + tl_centripetal_shifts (list[Tensor]): Top-left centripetal shifts + for each level with shape (N, centripetal_shift_channels, H, + W). 
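
The `eq(1)` masks above restrict the offset and shift losses to pixels that are true corner locations in the ground-truth heatmap; the Gaussian-smoothed neighbourhood is excluded, since only the peak equals 1. A tiny sketch of the mask construction, with toy sizes:

import torch

gt_heatmap = torch.zeros(1, 3, 4, 4)      # (N, num_classes, H, W)
gt_heatmap[0, 2, 1, 1] = 1.0              # one gt corner of class 2
mask = gt_heatmap.eq(1).sum(1).gt(0).unsqueeze(1).type_as(gt_heatmap)
assert mask.shape == (1, 1, 4, 4) and mask.sum() == 1   # class-agnostic mask
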
+            br_centripetal_shifts (list[Tensor]): Bottom-right centripetal
+                shifts for each level with shape (N,
+                centripetal_shift_channels, H, W).
+            batch_img_metas (list[dict], optional): Batch image meta info.
+                Defaults to None.
+            rescale (bool): If True, return boxes in original image space.
+                Defaults to False.
+            with_nms (bool): If True, do nms before return boxes.
+                Defaults to True.
+
+        Returns:
+            list[:obj:`InstanceData`]: Object detection results of each image
+            after the post process. Each item usually contains following keys.
+
+            - scores (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 4),
+              the last dimension 4 arrange as (x1, y1, x2, y2).
+        """
+        assert tl_heats[-1].shape[0] == br_heats[-1].shape[0] == len(
+            batch_img_metas)
+        result_list = []
+        for img_id in range(len(batch_img_metas)):
+            result_list.append(
+                self._predict_by_feat_single(
+                    tl_heats[-1][img_id:img_id + 1, :],
+                    br_heats[-1][img_id:img_id + 1, :],
+                    tl_offs[-1][img_id:img_id + 1, :],
+                    br_offs[-1][img_id:img_id + 1, :],
+                    batch_img_metas[img_id],
+                    tl_emb=None,
+                    br_emb=None,
+                    tl_centripetal_shift=tl_centripetal_shifts[-1][
+                        img_id:img_id + 1, :],
+                    br_centripetal_shift=br_centripetal_shifts[-1][
+                        img_id:img_id + 1, :],
+                    rescale=rescale,
+                    with_nms=with_nms))
+
+        return result_list
diff --git a/mmdet/models/dense_heads/condinst_head.py b/mmdet/models/dense_heads/condinst_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..35a25e6339a8161314cb0523e7181f9d400023ac
--- /dev/null
+++ b/mmdet/models/dense_heads/condinst_head.py
@@ -0,0 +1,1226 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, Scale
+from mmengine.config import ConfigDict
+from mmengine.model import BaseModule, kaiming_init
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.structures.bbox import cat_boxes
+from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType,
+                         OptInstanceList, reduce_mean)
+from ..task_modules.prior_generators import MlvlPointGenerator
+from ..utils import (aligned_bilinear, filter_scores_and_topk, multi_apply,
+                     relative_coordinate_maps, select_single_mlvl)
+from ..utils.misc import empty_instances
+from .base_mask_head import BaseMaskHead
+from .fcos_head import FCOSHead
+
+INF = 1e8
+
+
+@MODELS.register_module()
+class CondInstBboxHead(FCOSHead):
+    """CondInst box head used in https://arxiv.org/abs/2003.05664.
+
+    Note that the CondInst box head is an extension of the FCOS head.
+    Two differences are described as follows:
+
+    1. CondInst box head predicts a set of params for each instance.
+    2. CondInst box head returns the pos_gt_inds and pos_inds.
+
+    Args:
+        num_params (int): Number of params for instance segmentation.
+    """
+
+    def __init__(self, *args, num_params: int = 169, **kwargs) -> None:
+        self.num_params = num_params
+        super().__init__(*args, **kwargs)
+
+    def _init_layers(self) -> None:
+        """Initialize layers of the head."""
+        super()._init_layers()
+        self.controller = nn.Conv2d(
+            self.feat_channels, self.num_params, 3, padding=1)
+
+    def forward_single(self, x: Tensor, scale: Scale,
+                       stride: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+        """Forward features of a single scale level.
+ + Args: + x (Tensor): FPN feature maps of the specified stride. + scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize + the bbox prediction. + stride (int): The corresponding stride for feature maps, only + used to normalize the bbox prediction when self.norm_on_bbox + is True. + + Returns: + tuple: scores for each class, bbox predictions, centerness + predictions and param predictions of input feature maps. + """ + cls_score, bbox_pred, cls_feat, reg_feat = \ + super(FCOSHead, self).forward_single(x) + if self.centerness_on_reg: + centerness = self.conv_centerness(reg_feat) + else: + centerness = self.conv_centerness(cls_feat) + # scale the bbox_pred of different level + # float to avoid overflow when enabling FP16 + bbox_pred = scale(bbox_pred).float() + if self.norm_on_bbox: + # bbox_pred needed for gradient computation has been modified + # by F.relu(bbox_pred) when run with PyTorch 1.10. So replace + # F.relu(bbox_pred) with bbox_pred.clamp(min=0) + bbox_pred = bbox_pred.clamp(min=0) + if not self.training: + bbox_pred *= stride + else: + bbox_pred = bbox_pred.exp() + param_pred = self.controller(reg_feat) + return cls_score, bbox_pred, centerness, param_pred + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + centernesses: List[Tensor], + param_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None + ) -> Dict[str, Tensor]: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level, + each is a 4D-tensor, the channel number is + num_points * num_classes. + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level, each is a 4D-tensor, the channel number is + num_points * 4. + centernesses (list[Tensor]): centerness for each scale level, each + is a 4D-tensor, the channel number is num_points * 1. + param_preds (List[Tensor]): param_pred for each scale level, each + is a 4D-tensor, the channel number is num_params. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + assert len(cls_scores) == len(bbox_preds) == len(centernesses) + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + # Need stride for rel coord compute + all_level_points_strides = self.prior_generator.grid_priors( + featmap_sizes, + dtype=bbox_preds[0].dtype, + device=bbox_preds[0].device, + with_stride=True) + all_level_points = [i[:, :2] for i in all_level_points_strides] + all_level_strides = [i[:, 2] for i in all_level_points_strides] + labels, bbox_targets, pos_inds_list, pos_gt_inds_list = \ + self.get_targets(all_level_points, batch_gt_instances) + + num_imgs = cls_scores[0].size(0) + # flatten cls_scores, bbox_preds and centerness + flatten_cls_scores = [ + cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels) + for cls_score in cls_scores + ] + flatten_bbox_preds = [ + bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) + for bbox_pred in bbox_preds + ] + flatten_centerness = [ + centerness.permute(0, 2, 3, 1).reshape(-1) + for centerness in centernesses + ] + flatten_cls_scores = torch.cat(flatten_cls_scores) + flatten_bbox_preds = torch.cat(flatten_bbox_preds) + flatten_centerness = torch.cat(flatten_centerness) + flatten_labels = torch.cat(labels) + flatten_bbox_targets = torch.cat(bbox_targets) + # repeat points to align with bbox_preds + flatten_points = torch.cat( + [points.repeat(num_imgs, 1) for points in all_level_points]) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = ((flatten_labels >= 0) + & (flatten_labels < bg_class_ind)).nonzero().reshape(-1) + num_pos = torch.tensor( + len(pos_inds), dtype=torch.float, device=bbox_preds[0].device) + num_pos = max(reduce_mean(num_pos), 1.0) + loss_cls = self.loss_cls( + flatten_cls_scores, flatten_labels, avg_factor=num_pos) + + pos_bbox_preds = flatten_bbox_preds[pos_inds] + pos_centerness = flatten_centerness[pos_inds] + pos_bbox_targets = flatten_bbox_targets[pos_inds] + pos_centerness_targets = self.centerness_target(pos_bbox_targets) + # centerness weighted iou loss + centerness_denorm = max( + reduce_mean(pos_centerness_targets.sum().detach()), 1e-6) + + if len(pos_inds) > 0: + pos_points = flatten_points[pos_inds] + pos_decoded_bbox_preds = self.bbox_coder.decode( + pos_points, pos_bbox_preds) + pos_decoded_target_preds = self.bbox_coder.decode( + pos_points, pos_bbox_targets) + loss_bbox = self.loss_bbox( + pos_decoded_bbox_preds, + pos_decoded_target_preds, + weight=pos_centerness_targets, + avg_factor=centerness_denorm) + loss_centerness = self.loss_centerness( + pos_centerness, pos_centerness_targets, avg_factor=num_pos) + else: + loss_bbox = pos_bbox_preds.sum() + loss_centerness = pos_centerness.sum() + + self._raw_positive_infos.update(cls_scores=cls_scores) + self._raw_positive_infos.update(centernesses=centernesses) + self._raw_positive_infos.update(param_preds=param_preds) + self._raw_positive_infos.update(all_level_points=all_level_points) + self._raw_positive_infos.update(all_level_strides=all_level_strides) + self._raw_positive_infos.update(pos_gt_inds_list=pos_gt_inds_list) + self._raw_positive_infos.update(pos_inds_list=pos_inds_list) + + return dict( + loss_cls=loss_cls, + loss_bbox=loss_bbox, + loss_centerness=loss_centerness) + + def get_targets( + self, points: List[Tensor], batch_gt_instances: InstanceList + ) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]: + """Compute regression, classification and centerness targets for points + in multiple images. 
+ + Args: + points (list[Tensor]): Points of each fpn level, each has shape + (num_points, 2). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + + Returns: + tuple: Targets of each level. + + - concat_lvl_labels (list[Tensor]): Labels of each level. + - concat_lvl_bbox_targets (list[Tensor]): BBox targets of each \ + level. + - pos_inds_list (list[Tensor]): pos_inds of each image. + - pos_gt_inds_list (List[Tensor]): pos_gt_inds of each image. + """ + assert len(points) == len(self.regress_ranges) + num_levels = len(points) + # expand regress ranges to align with points + expanded_regress_ranges = [ + points[i].new_tensor(self.regress_ranges[i])[None].expand_as( + points[i]) for i in range(num_levels) + ] + # concat all levels points and regress ranges + concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0) + concat_points = torch.cat(points, dim=0) + + # the number of points per img, per lvl + num_points = [center.size(0) for center in points] + + # get labels and bbox_targets of each image + labels_list, bbox_targets_list, pos_inds_list, pos_gt_inds_list = \ + multi_apply( + self._get_targets_single, + batch_gt_instances, + points=concat_points, + regress_ranges=concat_regress_ranges, + num_points_per_lvl=num_points) + + # split to per img, per level + labels_list = [labels.split(num_points, 0) for labels in labels_list] + bbox_targets_list = [ + bbox_targets.split(num_points, 0) + for bbox_targets in bbox_targets_list + ] + + # concat per level image + concat_lvl_labels = [] + concat_lvl_bbox_targets = [] + for i in range(num_levels): + concat_lvl_labels.append( + torch.cat([labels[i] for labels in labels_list])) + bbox_targets = torch.cat( + [bbox_targets[i] for bbox_targets in bbox_targets_list]) + if self.norm_on_bbox: + bbox_targets = bbox_targets / self.strides[i] + concat_lvl_bbox_targets.append(bbox_targets) + return (concat_lvl_labels, concat_lvl_bbox_targets, pos_inds_list, + pos_gt_inds_list) + + def _get_targets_single( + self, gt_instances: InstanceData, points: Tensor, + regress_ranges: Tensor, num_points_per_lvl: List[int] + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Compute regression and classification targets for a single image.""" + num_points = points.size(0) + num_gts = len(gt_instances) + gt_bboxes = gt_instances.bboxes + gt_labels = gt_instances.labels + gt_masks = gt_instances.get('masks', None) + + if num_gts == 0: + return gt_labels.new_full((num_points,), self.num_classes), \ + gt_bboxes.new_zeros((num_points, 4)), \ + gt_bboxes.new_zeros((0,), dtype=torch.int64), \ + gt_bboxes.new_zeros((0,), dtype=torch.int64) + + areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * ( + gt_bboxes[:, 3] - gt_bboxes[:, 1]) + # TODO: figure out why these two are different + # areas = areas[None].expand(num_points, num_gts) + areas = areas[None].repeat(num_points, 1) + regress_ranges = regress_ranges[:, None, :].expand( + num_points, num_gts, 2) + gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4) + xs, ys = points[:, 0], points[:, 1] + xs = xs[:, None].expand(num_points, num_gts) + ys = ys[:, None].expand(num_points, num_gts) + + left = xs - gt_bboxes[..., 0] + right = gt_bboxes[..., 2] - xs + top = ys - gt_bboxes[..., 1] + bottom = gt_bboxes[..., 3] - ys + bbox_targets = torch.stack((left, top, right, bottom), -1) + + if self.center_sampling: + # condition1: inside a `center bbox` + radius = self.center_sample_radius + # if gt_mask not None, use gt mask's centroid to 
determine + # the center region rather than gt_bbox center + if gt_masks is None: + center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2 + center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2 + else: + h, w = gt_masks.height, gt_masks.width + masks = gt_masks.to_tensor( + dtype=torch.bool, device=gt_bboxes.device) + yys = torch.arange( + 0, h, dtype=torch.float32, device=masks.device) + xxs = torch.arange( + 0, w, dtype=torch.float32, device=masks.device) + # m00/m10/m01 represent the moments of a contour + # centroid is computed by m00/m10 and m00/m01 + m00 = masks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6) + m10 = (masks * xxs).sum(dim=-1).sum(dim=-1) + m01 = (masks * yys[:, None]).sum(dim=-1).sum(dim=-1) + center_xs = m10 / m00 + center_ys = m01 / m00 + + center_xs = center_xs[None].expand(num_points, num_gts) + center_ys = center_ys[None].expand(num_points, num_gts) + center_gts = torch.zeros_like(gt_bboxes) + stride = center_xs.new_zeros(center_xs.shape) + + # project the points on current lvl back to the `original` sizes + lvl_begin = 0 + for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl): + lvl_end = lvl_begin + num_points_lvl + stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius + lvl_begin = lvl_end + + x_mins = center_xs - stride + y_mins = center_ys - stride + x_maxs = center_xs + stride + y_maxs = center_ys + stride + center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0], + x_mins, gt_bboxes[..., 0]) + center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1], + y_mins, gt_bboxes[..., 1]) + center_gts[..., 2] = torch.where(x_maxs > gt_bboxes[..., 2], + gt_bboxes[..., 2], x_maxs) + center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3], + gt_bboxes[..., 3], y_maxs) + + cb_dist_left = xs - center_gts[..., 0] + cb_dist_right = center_gts[..., 2] - xs + cb_dist_top = ys - center_gts[..., 1] + cb_dist_bottom = center_gts[..., 3] - ys + center_bbox = torch.stack( + (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1) + inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0 + else: + # condition1: inside a gt bbox + inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0 + + # condition2: limit the regression range for each location + max_regress_distance = bbox_targets.max(-1)[0] + inside_regress_range = ( + (max_regress_distance >= regress_ranges[..., 0]) + & (max_regress_distance <= regress_ranges[..., 1])) + + # if there are still more than one objects for a location, + # we choose the one with minimal area + areas[inside_gt_bbox_mask == 0] = INF + areas[inside_regress_range == 0] = INF + min_area, min_area_inds = areas.min(dim=1) + + labels = gt_labels[min_area_inds] + labels[min_area == INF] = self.num_classes # set as BG + bbox_targets = bbox_targets[range(num_points), min_area_inds] + + # return pos_inds & pos_gt_inds + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().reshape(-1) + pos_gt_inds = min_area_inds[labels < self.num_classes] + return labels, bbox_targets, pos_inds, pos_gt_inds + + def get_positive_infos(self) -> InstanceList: + """Get positive information from sampling results. + + Returns: + list[:obj:`InstanceData`]: Positive information of each image, + usually including positive bboxes, positive labels, positive + priors, etc. 
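
The moment computation in the center-sampling branch above replaces the bbox center with the mask centroid: m00 is the mask area and (m10/m00, m01/m00) the centroid. A standalone check on a toy mask:

import torch

mask = torch.zeros(5, 6, dtype=torch.bool)
mask[1:4, 2:5] = True                      # a 3x3 blob
xs = torch.arange(6, dtype=torch.float32)
ys = torch.arange(5, dtype=torch.float32)
m00 = mask.float().sum().clamp(min=1e-6)   # area = 9
m10 = (mask * xs).sum()                    # sum of x over foreground = 27
m01 = (mask * ys[:, None]).sum()           # sum of y over foreground = 18
cx, cy = m10 / m00, m01 / m00              # centroid (3.0, 2.0)
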
+ """ + assert len(self._raw_positive_infos) > 0 + + pos_gt_inds_list = self._raw_positive_infos['pos_gt_inds_list'] + pos_inds_list = self._raw_positive_infos['pos_inds_list'] + num_imgs = len(pos_gt_inds_list) + + cls_score_list = [] + centerness_list = [] + param_pred_list = [] + point_list = [] + stride_list = [] + for cls_score_per_lvl, centerness_per_lvl, param_pred_per_lvl,\ + point_per_lvl, stride_per_lvl in \ + zip(self._raw_positive_infos['cls_scores'], + self._raw_positive_infos['centernesses'], + self._raw_positive_infos['param_preds'], + self._raw_positive_infos['all_level_points'], + self._raw_positive_infos['all_level_strides']): + cls_score_per_lvl = \ + cls_score_per_lvl.permute( + 0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes) + centerness_per_lvl = \ + centerness_per_lvl.permute( + 0, 2, 3, 1).reshape(num_imgs, -1, 1) + param_pred_per_lvl = \ + param_pred_per_lvl.permute( + 0, 2, 3, 1).reshape(num_imgs, -1, self.num_params) + point_per_lvl = point_per_lvl.unsqueeze(0).repeat(num_imgs, 1, 1) + stride_per_lvl = stride_per_lvl.unsqueeze(0).repeat(num_imgs, 1) + + cls_score_list.append(cls_score_per_lvl) + centerness_list.append(centerness_per_lvl) + param_pred_list.append(param_pred_per_lvl) + point_list.append(point_per_lvl) + stride_list.append(stride_per_lvl) + cls_scores = torch.cat(cls_score_list, dim=1) + centernesses = torch.cat(centerness_list, dim=1) + param_preds = torch.cat(param_pred_list, dim=1) + all_points = torch.cat(point_list, dim=1) + all_strides = torch.cat(stride_list, dim=1) + + positive_infos = [] + for i, (pos_gt_inds, + pos_inds) in enumerate(zip(pos_gt_inds_list, pos_inds_list)): + pos_info = InstanceData() + pos_info.points = all_points[i][pos_inds] + pos_info.strides = all_strides[i][pos_inds] + pos_info.scores = cls_scores[i][pos_inds] + pos_info.centernesses = centernesses[i][pos_inds] + pos_info.param_preds = param_preds[i][pos_inds] + pos_info.pos_assigned_gt_inds = pos_gt_inds + pos_info.pos_inds = pos_inds + positive_infos.append(pos_info) + return positive_infos + + def predict_by_feat(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + score_factors: Optional[List[Tensor]] = None, + param_preds: Optional[List[Tensor]] = None, + batch_img_metas: Optional[List[dict]] = None, + cfg: Optional[ConfigDict] = None, + rescale: bool = False, + with_nms: bool = True) -> InstanceList: + """Transform a batch of output features extracted from the head into + bbox results. + + Note: When score_factors is not None, the cls_scores are + usually multiplied by it then obtain the real score used in NMS, + such as CenterNess in FCOS, IoU branch in ATSS. + + Args: + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + score_factors (list[Tensor], optional): Score factor for + all scale level, each is a 4D-tensor, has shape + (batch_size, num_priors * 1, H, W). Defaults to None. + param_preds (list[Tensor], optional): Params for all scale + level, each is a 4D-tensor, has shape + (batch_size, num_priors * num_params, H, W) + batch_img_metas (list[dict], Optional): Batch image meta info. + Defaults to None. + cfg (ConfigDict, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + Defaults to None. + rescale (bool): If True, return boxes in original image space. 
+                Defaults to False.
+            with_nms (bool): If True, do nms before return boxes.
+                Defaults to True.
+
+        Returns:
+            list[:obj:`InstanceData`]: Object detection results of each image
+            after the post process. Each item usually contains following keys.
+
+            - scores (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 4),
+              the last dimension 4 arrange as (x1, y1, x2, y2).
+        """
+        assert len(cls_scores) == len(bbox_preds)
+
+        if score_factors is None:
+            # e.g. Retina, FreeAnchor, Foveabox, etc.
+            with_score_factors = False
+        else:
+            # e.g. FCOS, PAA, ATSS, AutoAssign, etc.
+            with_score_factors = True
+            assert len(cls_scores) == len(score_factors)
+
+        num_levels = len(cls_scores)
+
+        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
+        all_level_points_strides = self.prior_generator.grid_priors(
+            featmap_sizes,
+            dtype=bbox_preds[0].dtype,
+            device=bbox_preds[0].device,
+            with_stride=True)
+        all_level_points = [i[:, :2] for i in all_level_points_strides]
+        all_level_strides = [i[:, 2] for i in all_level_points_strides]
+
+        result_list = []
+
+        for img_id in range(len(batch_img_metas)):
+            img_meta = batch_img_metas[img_id]
+            cls_score_list = select_single_mlvl(
+                cls_scores, img_id, detach=True)
+            bbox_pred_list = select_single_mlvl(
+                bbox_preds, img_id, detach=True)
+            if with_score_factors:
+                score_factor_list = select_single_mlvl(
+                    score_factors, img_id, detach=True)
+            else:
+                score_factor_list = [None for _ in range(num_levels)]
+            param_pred_list = select_single_mlvl(
+                param_preds, img_id, detach=True)
+
+            results = self._predict_by_feat_single(
+                cls_score_list=cls_score_list,
+                bbox_pred_list=bbox_pred_list,
+                score_factor_list=score_factor_list,
+                param_pred_list=param_pred_list,
+                mlvl_points=all_level_points,
+                mlvl_strides=all_level_strides,
+                img_meta=img_meta,
+                cfg=cfg,
+                rescale=rescale,
+                with_nms=with_nms)
+            result_list.append(results)
+        return result_list
+
+    def _predict_by_feat_single(self,
+                                cls_score_list: List[Tensor],
+                                bbox_pred_list: List[Tensor],
+                                score_factor_list: List[Tensor],
+                                param_pred_list: List[Tensor],
+                                mlvl_points: List[Tensor],
+                                mlvl_strides: List[Tensor],
+                                img_meta: dict,
+                                cfg: ConfigDict,
+                                rescale: bool = False,
+                                with_nms: bool = True) -> InstanceData:
+        """Transform a single image's features extracted from the head into
+        bbox results.
+
+        Args:
+            cls_score_list (list[Tensor]): Box scores from all scale
+                levels of a single image, each item has shape
+                (num_priors * num_classes, H, W).
+            bbox_pred_list (list[Tensor]): Box energies / deltas from
+                all scale levels of a single image, each item has shape
+                (num_priors * 4, H, W).
+            score_factor_list (list[Tensor]): Score factor from all scale
+                levels of a single image, each item has shape
+                (num_priors * 1, H, W).
+            param_pred_list (List[Tensor]): Param prediction from all scale
+                levels of a single image, each item has shape
+                (num_priors * num_params, H, W).
+            mlvl_points (list[Tensor]): Each element in the list is
+                the priors of a single level in feature pyramid.
+                It has shape (num_priors, 2)
+            mlvl_strides (List[Tensor]): Each element in the list is
+                the stride of a single level in feature pyramid.
+                It has shape (num_priors, 1)
+            img_meta (dict): Image meta info.
+            cfg (mmengine.Config): Test / postprocessing configuration,
+                if None, test_cfg would be used.
+            rescale (bool): If True, return boxes in original image space.
+                Defaults to False.
+ with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + if score_factor_list[0] is None: + # e.g. Retina, FreeAnchor, etc. + with_score_factors = False + else: + # e.g. FCOS, PAA, ATSS, etc. + with_score_factors = True + + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + img_shape = img_meta['img_shape'] + nms_pre = cfg.get('nms_pre', -1) + + mlvl_bbox_preds = [] + mlvl_param_preds = [] + mlvl_valid_points = [] + mlvl_valid_strides = [] + mlvl_scores = [] + mlvl_labels = [] + if with_score_factors: + mlvl_score_factors = [] + else: + mlvl_score_factors = None + for level_idx, (cls_score, bbox_pred, score_factor, + param_pred, points, strides) in \ + enumerate(zip(cls_score_list, bbox_pred_list, + score_factor_list, param_pred_list, + mlvl_points, mlvl_strides)): + + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + + dim = self.bbox_coder.encode_size + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim) + if with_score_factors: + score_factor = score_factor.permute(1, 2, + 0).reshape(-1).sigmoid() + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + if self.use_sigmoid_cls: + scores = cls_score.sigmoid() + else: + # remind that we set FG labels to [0, num_class-1] + # since mmdet v2.0 + # BG cat_id: num_class + scores = cls_score.softmax(-1)[:, :-1] + + param_pred = param_pred.permute(1, 2, + 0).reshape(-1, self.num_params) + + # After https://github.com/open-mmlab/mmdetection/pull/6268/, + # this operation keeps fewer bboxes under the same `nms_pre`. + # There is no difference in performance for most models. If you + # find a slight drop in performance, you can set a larger + # `nms_pre` than before. 
+            score_thr = cfg.get('score_thr', 0)
+
+            results = filter_scores_and_topk(
+                scores, score_thr, nms_pre,
+                dict(
+                    bbox_pred=bbox_pred,
+                    param_pred=param_pred,
+                    points=points,
+                    strides=strides))
+            scores, labels, keep_idxs, filtered_results = results
+
+            bbox_pred = filtered_results['bbox_pred']
+            param_pred = filtered_results['param_pred']
+            points = filtered_results['points']
+            strides = filtered_results['strides']
+
+            if with_score_factors:
+                score_factor = score_factor[keep_idxs]
+
+            mlvl_bbox_preds.append(bbox_pred)
+            mlvl_param_preds.append(param_pred)
+            mlvl_valid_points.append(points)
+            mlvl_valid_strides.append(strides)
+            mlvl_scores.append(scores)
+            mlvl_labels.append(labels)
+
+            if with_score_factors:
+                mlvl_score_factors.append(score_factor)
+
+        bbox_pred = torch.cat(mlvl_bbox_preds)
+        priors = cat_boxes(mlvl_valid_points)
+        bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape)
+
+        results = InstanceData()
+        results.bboxes = bboxes
+        results.scores = torch.cat(mlvl_scores)
+        results.labels = torch.cat(mlvl_labels)
+        results.param_preds = torch.cat(mlvl_param_preds)
+        results.points = torch.cat(mlvl_valid_points)
+        results.strides = torch.cat(mlvl_valid_strides)
+        if with_score_factors:
+            results.score_factors = torch.cat(mlvl_score_factors)
+
+        return self._bbox_post_process(
+            results=results,
+            cfg=cfg,
+            rescale=rescale,
+            with_nms=with_nms,
+            img_meta=img_meta)
+
+
+class MaskFeatModule(BaseModule):
+    """CondInst mask feature map branch used in \
+    https://arxiv.org/abs/2003.05664.
+
+    Args:
+        in_channels (int): Number of channels in the input feature map.
+        feat_channels (int): Number of hidden channels of the mask feature
+            map branch.
+        start_level (int): The starting feature map level from RPN that
+            will be used to predict the mask feature map.
+        end_level (int): The ending feature map level from RPN that
+            will be used to predict the mask feature map.
+        out_channels (int): Number of output channels of the mask feature
+            map branch. This is the channel count of the mask
+            feature map that is to be dynamically convolved with the
+            predicted kernel.
+        mask_stride (int): Downsample factor of the mask feature map output.
+            Defaults to 4.
+        num_stacked_convs (int): Number of convs in mask feature branch.
+        conv_cfg (dict): Config dict for convolution layer. Default: None.
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
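+
+    Examples:
+        A minimal illustrative sketch; all channel and spatial sizes below
+        are made up for the example and do not come from any real config.
+
+        >>> import torch
+        >>> mask_feat_head = MaskFeatModule(
+        ...     in_channels=16, feat_channels=32, start_level=0,
+        ...     end_level=2, out_channels=8)
+        >>> # three pyramid levels at strides 1x, 2x and 4x of each other
+        >>> feats = tuple(
+        ...     torch.rand(1, 16, 64 // 2**i, 64 // 2**i) for i in range(3))
+        >>> # all levels are upsampled to the first level and summed
+        >>> mask_feat_head(feats).shape
+        torch.Size([1, 8, 64, 64])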
+ """ + + def __init__(self, + in_channels: int, + feat_channels: int, + start_level: int, + end_level: int, + out_channels: int, + mask_stride: int = 4, + num_stacked_convs: int = 4, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + init_cfg: MultiConfig = [ + dict(type='Normal', layer='Conv2d', std=0.01) + ], + **kwargs) -> None: + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.feat_channels = feat_channels + self.start_level = start_level + self.end_level = end_level + self.mask_stride = mask_stride + self.num_stacked_convs = num_stacked_convs + assert start_level >= 0 and end_level >= start_level + self.out_channels = out_channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self._init_layers() + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self.convs_all_levels = nn.ModuleList() + for i in range(self.start_level, self.end_level + 1): + convs_per_level = nn.Sequential() + convs_per_level.add_module( + f'conv{i}', + ConvModule( + self.in_channels, + self.feat_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=False, + bias=False)) + self.convs_all_levels.append(convs_per_level) + + conv_branch = [] + for _ in range(self.num_stacked_convs): + conv_branch.append( + ConvModule( + self.feat_channels, + self.feat_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + bias=False)) + self.conv_branch = nn.Sequential(*conv_branch) + + self.conv_pred = nn.Conv2d( + self.feat_channels, self.out_channels, 1, stride=1) + + def init_weights(self) -> None: + """Initialize weights of the head.""" + super().init_weights() + kaiming_init(self.convs_all_levels, a=1, distribution='uniform') + kaiming_init(self.conv_branch, a=1, distribution='uniform') + kaiming_init(self.conv_pred, a=1, distribution='uniform') + + def forward(self, x: Tuple[Tensor]) -> Tensor: + """Forward features from the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + Tensor: The predicted mask feature map. + """ + inputs = x[self.start_level:self.end_level + 1] + assert len(inputs) == (self.end_level - self.start_level + 1) + feature_add_all_level = self.convs_all_levels[0](inputs[0]) + target_h, target_w = feature_add_all_level.size()[2:] + for i in range(1, len(inputs)): + input_p = inputs[i] + x_p = self.convs_all_levels[i](input_p) + h, w = x_p.size()[2:] + factor_h = target_h // h + factor_w = target_w // w + assert factor_h == factor_w + feature_per_level = aligned_bilinear(x_p, factor_h) + feature_add_all_level = feature_add_all_level + \ + feature_per_level + + feature_add_all_level = self.conv_branch(feature_add_all_level) + feature_pred = self.conv_pred(feature_add_all_level) + return feature_pred + + +@MODELS.register_module() +class CondInstMaskHead(BaseMaskHead): + """CondInst mask head used in https://arxiv.org/abs/1904.02689. + + This head outputs the mask for CondInst. + + Args: + mask_feature_head (dict): Config of CondInstMaskFeatHead. + num_layers (int): Number of dynamic conv layers. + feat_channels (int): Number of channels in the dynamic conv. + mask_out_stride (int): The stride of the mask feat. + size_of_interest (int): The size of the region used in rel coord. + max_masks_to_train (int): Maximum number of masks to train for + each image. + loss_segm (:obj:`ConfigDict` or dict, optional): Config of + segmentation loss. 
+ train_cfg (:obj:`ConfigDict` or dict, optional): Training config + of head. + test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of + head. + """ + + def __init__(self, + mask_feature_head: ConfigType, + num_layers: int = 3, + feat_channels: int = 8, + mask_out_stride: int = 4, + size_of_interest: int = 8, + max_masks_to_train: int = -1, + topk_masks_per_img: int = -1, + loss_mask: ConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None) -> None: + super().__init__() + self.mask_feature_head = MaskFeatModule(**mask_feature_head) + self.mask_feat_stride = self.mask_feature_head.mask_stride + self.in_channels = self.mask_feature_head.out_channels + self.num_layers = num_layers + self.feat_channels = feat_channels + self.size_of_interest = size_of_interest + self.mask_out_stride = mask_out_stride + self.max_masks_to_train = max_masks_to_train + self.topk_masks_per_img = topk_masks_per_img + self.prior_generator = MlvlPointGenerator([self.mask_feat_stride]) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.loss_mask = MODELS.build(loss_mask) + self._init_layers() + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + weight_nums, bias_nums = [], [] + for i in range(self.num_layers): + if i == 0: + weight_nums.append((self.in_channels + 2) * self.feat_channels) + bias_nums.append(self.feat_channels) + elif i == self.num_layers - 1: + weight_nums.append(self.feat_channels * 1) + bias_nums.append(1) + else: + weight_nums.append(self.feat_channels * self.feat_channels) + bias_nums.append(self.feat_channels) + + self.weight_nums = weight_nums + self.bias_nums = bias_nums + self.num_params = sum(weight_nums) + sum(bias_nums) + + def parse_dynamic_params( + self, params: Tensor) -> Tuple[List[Tensor], List[Tensor]]: + """parse the dynamic params for dynamic conv.""" + num_insts = params.size(0) + params_splits = list( + torch.split_with_sizes( + params, self.weight_nums + self.bias_nums, dim=1)) + weight_splits = params_splits[:self.num_layers] + bias_splits = params_splits[self.num_layers:] + for i in range(self.num_layers): + if i < self.num_layers - 1: + weight_splits[i] = weight_splits[i].reshape( + num_insts * self.in_channels, -1, 1, 1) + bias_splits[i] = bias_splits[i].reshape(num_insts * + self.in_channels) + else: + # out_channels x in_channels x 1 x 1 + weight_splits[i] = weight_splits[i].reshape( + num_insts * 1, -1, 1, 1) + bias_splits[i] = bias_splits[i].reshape(num_insts) + + return weight_splits, bias_splits + + def dynamic_conv_forward(self, features: Tensor, weights: List[Tensor], + biases: List[Tensor], num_insts: int) -> Tensor: + """dynamic forward, each layer follow a relu.""" + n_layers = len(weights) + x = features + for i, (w, b) in enumerate(zip(weights, biases)): + x = F.conv2d(x, w, bias=b, stride=1, padding=0, groups=num_insts) + if i < n_layers - 1: + x = F.relu(x) + return x + + def forward(self, x: tuple, positive_infos: InstanceList) -> tuple: + """Forward feature from the upstream network to get prototypes and + linearly combine the prototypes, using masks coefficients, into + instance masks. Finally, crop the instance masks with given bboxes. + + Args: + x (Tuple[Tensor]): Feature from the upstream network, which is + a 4D-tensor. + positive_infos (List[:obj:``InstanceData``]): Positive information + that calculate from detect head. 
+
+        Returns:
+            tuple: Predicted instance segmentation masks
+        """
+        mask_feats = self.mask_feature_head(x)
+        return multi_apply(self.forward_single, mask_feats, positive_infos)
+
+    def forward_single(self, mask_feat: Tensor,
+                       positive_info: InstanceData) -> Tensor:
+        """Forward features of each image."""
+        pos_param_preds = positive_info.get('param_preds')
+        pos_points = positive_info.get('points')
+        pos_strides = positive_info.get('strides')
+
+        num_inst = pos_param_preds.shape[0]
+        mask_feat = mask_feat[None].repeat(num_inst, 1, 1, 1)
+        _, _, H, W = mask_feat.size()
+        if num_inst == 0:
+            return (pos_param_preds.new_zeros((0, 1, H, W)), )
+
+        locations = self.prior_generator.single_level_grid_priors(
+            mask_feat.size()[2:], 0, device=mask_feat.device)
+
+        rel_coords = relative_coordinate_maps(locations, pos_points,
+                                              pos_strides,
+                                              self.size_of_interest,
+                                              mask_feat.size()[2:])
+        mask_head_inputs = torch.cat([rel_coords, mask_feat], dim=1)
+        mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W)
+
+        weights, biases = self.parse_dynamic_params(pos_param_preds)
+        mask_preds = self.dynamic_conv_forward(mask_head_inputs, weights,
+                                               biases, num_inst)
+        mask_preds = mask_preds.reshape(-1, H, W)
+        mask_preds = aligned_bilinear(
+            mask_preds.unsqueeze(0),
+            int(self.mask_feat_stride / self.mask_out_stride)).squeeze(0)
+
+        return (mask_preds, )
+
+    def loss_by_feat(self, mask_preds: List[Tensor],
+                     batch_gt_instances: InstanceList,
+                     batch_img_metas: List[dict], positive_infos: InstanceList,
+                     **kwargs) -> dict:
+        """Calculate the loss based on the features extracted by the mask head.
+
+        Args:
+            mask_preds (list[Tensor]): List of predicted masks, each has
+                shape (num_classes, H, W).
+            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+                gt_instance. It usually includes ``bboxes``, ``masks``,
+                and ``labels`` attributes.
+            batch_img_metas (list[dict]): Meta information of multiple images.
+            positive_infos (List[:obj:``InstanceData``]): Information of
+                positive samples of each image that are assigned in detection
+                head.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        assert positive_infos is not None, \
+            'positive_infos should not be None in `CondInstMaskHead`'
+        losses = dict()
+
+        loss_mask = 0.
+        num_imgs = len(mask_preds)
+        total_pos = 0
+
+        for idx in range(num_imgs):
+            (mask_pred, pos_mask_targets, num_pos) = \
+                self._get_targets_single(
+                    mask_preds[idx], batch_gt_instances[idx],
+                    positive_infos[idx])
+            # mask loss
+            total_pos += num_pos
+            if num_pos == 0 or pos_mask_targets is None:
+                loss = mask_pred.new_zeros(1).mean()
+            else:
+                loss = self.loss_mask(
+                    mask_pred, pos_mask_targets,
+                    reduction_override='none').sum()
+            loss_mask += loss
+
+        if total_pos == 0:
+            total_pos += 1  # avoid nan
+        loss_mask = loss_mask / total_pos
+        losses.update(loss_mask=loss_mask)
+        return losses
+
+    def _get_targets_single(self, mask_preds: Tensor,
+                            gt_instances: InstanceData,
+                            positive_info: InstanceData):
+        """Compute targets for predictions of a single image.
+
+        Args:
+            mask_preds (Tensor): Predicted prototypes with shape
+                (num_classes, H, W).
+            gt_instances (:obj:`InstanceData`): Ground truth of instance
+                annotations. It should include ``bboxes``, ``labels``,
+                and ``masks`` attributes.
+            positive_info (:obj:`InstanceData`): Information of positive
+                samples that are assigned in detection head. It usually
+                contains the following keys.
+
+                - pos_assigned_gt_inds (Tensor): Assigned GT indexes of
+                  positive proposals, has shape (num_pos, )
+                - pos_inds (Tensor): Positive index of image, has
+                  shape (num_pos, ).
+                - param_pred (Tensor): Positive param predictions
+                  with shape (num_pos, num_params).
+
+        Returns:
+            tuple: Usually returns a tuple containing learning targets.
+
+            - mask_preds (Tensor): Positive predicted mask with shape
+              (num_pos, mask_h, mask_w).
+            - pos_mask_targets (Tensor): Positive mask targets with shape
+              (num_pos, mask_h, mask_w).
+            - num_pos (int): Positive numbers.
+        """
+        gt_bboxes = gt_instances.bboxes
+        device = gt_bboxes.device
+        gt_masks = gt_instances.masks.to_tensor(
+            dtype=torch.bool, device=device).float()
+
+        # process with mask targets
+        pos_assigned_gt_inds = positive_info.get('pos_assigned_gt_inds')
+        scores = positive_info.get('scores')
+        centernesses = positive_info.get('centernesses')
+        num_pos = pos_assigned_gt_inds.size(0)
+
+        if gt_masks.size(0) == 0 or num_pos == 0:
+            return mask_preds, None, 0
+        # Since we're producing (near) full image masks,
+        # it'd take too much vram to backprop on every single mask.
+        # Thus we select only a subset.
+        if (self.max_masks_to_train != -1) and \
+                (num_pos > self.max_masks_to_train):
+            perm = torch.randperm(num_pos)
+            select = perm[:self.max_masks_to_train]
+            mask_preds = mask_preds[select]
+            pos_assigned_gt_inds = pos_assigned_gt_inds[select]
+            num_pos = self.max_masks_to_train
+        elif self.topk_masks_per_img != -1:
+            unique_gt_inds = pos_assigned_gt_inds.unique()
+            num_inst_per_gt = max(
+                int(self.topk_masks_per_img / len(unique_gt_inds)), 1)
+
+            keep_mask_preds = []
+            keep_pos_assigned_gt_inds = []
+            for gt_ind in unique_gt_inds:
+                per_inst_pos_inds = (pos_assigned_gt_inds == gt_ind)
+                mask_preds_per_inst = mask_preds[per_inst_pos_inds]
+                gt_inds_per_inst = pos_assigned_gt_inds[per_inst_pos_inds]
+                if sum(per_inst_pos_inds) > num_inst_per_gt:
+                    per_inst_scores = scores[per_inst_pos_inds].sigmoid().max(
+                        dim=1)[0]
+                    per_inst_centerness = centernesses[
+                        per_inst_pos_inds].sigmoid().reshape(-1, )
+                    select = (per_inst_scores * per_inst_centerness).topk(
+                        k=num_inst_per_gt, dim=0)[1]
+                    mask_preds_per_inst = mask_preds_per_inst[select]
+                    gt_inds_per_inst = gt_inds_per_inst[select]
+                keep_mask_preds.append(mask_preds_per_inst)
+                keep_pos_assigned_gt_inds.append(gt_inds_per_inst)
+            mask_preds = torch.cat(keep_mask_preds)
+            pos_assigned_gt_inds = torch.cat(keep_pos_assigned_gt_inds)
+            num_pos = pos_assigned_gt_inds.size(0)
+
+        # Follow the original implementation
+        start = int(self.mask_out_stride // 2)
+        gt_masks = gt_masks[:, start::self.mask_out_stride,
+                            start::self.mask_out_stride]
+        gt_masks = gt_masks.gt(0.5).float()
+        pos_mask_targets = gt_masks[pos_assigned_gt_inds]
+
+        return (mask_preds, pos_mask_targets, num_pos)
+
+    def predict_by_feat(self,
+                        mask_preds: List[Tensor],
+                        results_list: InstanceList,
+                        batch_img_metas: List[dict],
+                        rescale: bool = True,
+                        **kwargs) -> InstanceList:
+        """Transform a batch of output features extracted from the head into
+        mask results.
+
+        Args:
+            mask_preds (list[Tensor]): Predicted prototypes with shape
+                (num_classes, H, W).
+            results_list (List[:obj:``InstanceData``]): BBoxHead results.
+            batch_img_metas (list[dict]): Meta information of all images.
+            rescale (bool, optional): Whether to rescale the results.
+                Defaults to True.
+
+        Returns:
+            list[:obj:`InstanceData`]: Processed results of multiple
+            images. Each :obj:`InstanceData` usually contains
+            the following keys.
+
+            - scores (Tensor): Classification scores, has shape
+              (num_instance,).
+            - labels (Tensor): Has shape (num_instances,).
+            - masks (Tensor): Processed mask results, has
+              shape (num_instances, h, w).
+        """
+        assert len(mask_preds) == len(results_list) == len(batch_img_metas)
+
+        for img_id in range(len(batch_img_metas)):
+            img_meta = batch_img_metas[img_id]
+            results = results_list[img_id]
+            bboxes = results.bboxes
+            mask_pred = mask_preds[img_id]
+            if bboxes.shape[0] == 0 or mask_pred.shape[0] == 0:
+                results_list[img_id] = empty_instances(
+                    [img_meta],
+                    bboxes.device,
+                    task_type='mask',
+                    instance_results=[results])[0]
+            else:
+                im_mask = self._predict_by_feat_single(
+                    mask_preds=mask_pred,
+                    bboxes=bboxes,
+                    img_meta=img_meta,
+                    rescale=rescale)
+                results.masks = im_mask
+        return results_list
+
+    def _predict_by_feat_single(self,
+                                mask_preds: Tensor,
+                                bboxes: Tensor,
+                                img_meta: dict,
+                                rescale: bool,
+                                cfg: OptConfigType = None):
+        """Transform a single image's features extracted from the head into
+        mask results.
+
+        Args:
+            mask_preds (Tensor): Predicted prototypes, has shape
+                (num_instances, H, W).
+            bboxes (Tensor): Detected bboxes with shape (num_instances, 4);
+                rescaled in place when ``rescale`` is True.
+            img_meta (dict): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            rescale (bool): If rescale is False, then returned masks will
+                fit the scale of imgs[0].
+            cfg (dict, optional): Config used in test phase.
+                Defaults to None.
+
+        Returns:
+            Tensor: Processed mask results of a single image, a bool
+            tensor with shape (num_instances, h, w).
+        """
+        cfg = self.test_cfg if cfg is None else cfg
+        img_h, img_w = img_meta['img_shape'][:2]
+        ori_h, ori_w = img_meta['ori_shape'][:2]
+
+        mask_preds = mask_preds.sigmoid().unsqueeze(0)
+        mask_preds = aligned_bilinear(mask_preds, self.mask_out_stride)
+        mask_preds = mask_preds[:, :, :img_h, :img_w]
+        if rescale:  # rescale the bboxes in place
+            scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat(
+                (1, 2))
+            bboxes /= scale_factor
+
+            masks = F.interpolate(
+                mask_preds, (ori_h, ori_w),
+                mode='bilinear',
+                align_corners=False).squeeze(0) > cfg.mask_thr
+        else:
+            masks = mask_preds.squeeze(0) > cfg.mask_thr
+
+        return masks
diff --git a/mmdet/models/dense_heads/conditional_detr_head.py b/mmdet/models/dense_heads/conditional_detr_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc2df2c215667121c5fe329f369510ecd4666faf
--- /dev/null
+++ b/mmdet/models/dense_heads/conditional_detr_head.py
@@ -0,0 +1,168 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from mmengine.model import bias_init_with_prob
+from torch import Tensor
+
+from mmdet.models.layers.transformer import inverse_sigmoid
+from mmdet.registry import MODELS
+from mmdet.structures import SampleList
+from mmdet.utils import InstanceList
+from .detr_head import DETRHead
+
+
+@MODELS.register_module()
+class ConditionalDETRHead(DETRHead):
+    """Head of Conditional DETR. Conditional DETR: Conditional DETR for Fast
+    Training Convergence. More details can be found in the `paper
+    <https://arxiv.org/abs/2108.06152>`_ .
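+
+    Examples:
+        A minimal shape-level sketch. It assumes the default DETRHead
+        losses/assigner can be built from the registry; the sizes here
+        (4 classes, 10 queries, 2 decoder layers) are arbitrary and only
+        illustrate the forward contract.
+
+        >>> import torch
+        >>> self = ConditionalDETRHead(num_classes=4)
+        >>> hidden_states = torch.rand(2, 1, 10, 256)
+        >>> references = torch.rand(1, 10, 2)
+        >>> cls_scores, bbox_preds = self(hidden_states, references)
+        >>> cls_scores.shape, bbox_preds.shape
+        (torch.Size([2, 1, 10, 5]), torch.Size([2, 1, 10, 4]))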
+ """ + + def init_weights(self): + """Initialize weights of the transformer head.""" + super().init_weights() + # The initialization below for transformer head is very + # important as we use Focal_loss for loss_cls + if self.loss_cls.use_sigmoid: + bias_init = bias_init_with_prob(0.01) + nn.init.constant_(self.fc_cls.bias, bias_init) + + def forward(self, hidden_states: Tensor, + references: Tensor) -> Tuple[Tensor, Tensor]: + """"Forward function. + + Args: + hidden_states (Tensor): Features from transformer decoder. If + `return_intermediate_dec` is True output has shape + (num_decoder_layers, bs, num_queries, dim), else has shape (1, + bs, num_queries, dim) which only contains the last layer + outputs. + references (Tensor): References from transformer decoder, has + shape (bs, num_queries, 2). + Returns: + tuple[Tensor]: results of head containing the following tensor. + + - layers_cls_scores (Tensor): Outputs from the classification head, + shape (num_decoder_layers, bs, num_queries, cls_out_channels). + Note cls_out_channels should include background. + - layers_bbox_preds (Tensor): Sigmoid outputs from the regression + head with normalized coordinate format (cx, cy, w, h), has shape + (num_decoder_layers, bs, num_queries, 4). + """ + + references_unsigmoid = inverse_sigmoid(references) + layers_bbox_preds = [] + for layer_id in range(hidden_states.shape[0]): + tmp_reg_preds = self.fc_reg( + self.activate(self.reg_ffn(hidden_states[layer_id]))) + tmp_reg_preds[..., :2] += references_unsigmoid + outputs_coord = tmp_reg_preds.sigmoid() + layers_bbox_preds.append(outputs_coord) + layers_bbox_preds = torch.stack(layers_bbox_preds) + + layers_cls_scores = self.fc_cls(hidden_states) + return layers_cls_scores, layers_bbox_preds + + def loss(self, hidden_states: Tensor, references: Tensor, + batch_data_samples: SampleList) -> dict: + """Perform forward propagation and loss calculation of the detection + head on the features of the upstream network. + + Args: + hidden_states (Tensor): Features from the transformer decoder, has + shape (num_decoder_layers, bs, num_queries, dim). + references (Tensor): References from the transformer decoder, has + shape (num_decoder_layers, bs, num_queries, 2). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components. + """ + batch_gt_instances = [] + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + + outs = self(hidden_states, references) + loss_inputs = outs + (batch_gt_instances, batch_img_metas) + losses = self.loss_by_feat(*loss_inputs) + return losses + + def loss_and_predict( + self, hidden_states: Tensor, references: Tensor, + batch_data_samples: SampleList) -> Tuple[dict, InstanceList]: + """Perform forward propagation of the head, then calculate loss and + predictions from the features and data samples. Over-write because + img_metas are needed as inputs for bbox_head. + + Args: + hidden_states (Tensor): Features from the transformer decoder, has + shape (num_decoder_layers, bs, num_queries, dim). + references (Tensor): References from the transformer decoder, has + shape (num_decoder_layers, bs, num_queries, 2). + batch_data_samples (list[:obj:`DetDataSample`]): Each item contains + the meta information of each image and corresponding + annotations. 
+ + Returns: + tuple: The return value is a tuple contains: + + - losses: (dict[str, Tensor]): A dictionary of loss components. + - predictions (list[:obj:`InstanceData`]): Detection + results of each image after the post process. + """ + batch_gt_instances = [] + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + + outs = self(hidden_states, references) + loss_inputs = outs + (batch_gt_instances, batch_img_metas) + losses = self.loss_by_feat(*loss_inputs) + + predictions = self.predict_by_feat( + *outs, batch_img_metas=batch_img_metas) + return losses, predictions + + def predict(self, + hidden_states: Tensor, + references: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> InstanceList: + """Perform forward propagation of the detection head and predict + detection results on the features of the upstream network. Over-write + because img_metas are needed as inputs for bbox_head. + + Args: + hidden_states (Tensor): Features from the transformer decoder, has + shape (num_decoder_layers, bs, num_queries, dim). + references (Tensor): References from the transformer decoder, has + shape (num_decoder_layers, bs, num_queries, 2). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool, optional): Whether to rescale the results. + Defaults to True. + + Returns: + list[obj:`InstanceData`]: Detection results of each image + after the post process. + """ + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + + last_layer_hidden_state = hidden_states[-1].unsqueeze(0) + outs = self(last_layer_hidden_state, references) + + predictions = self.predict_by_feat( + *outs, batch_img_metas=batch_img_metas, rescale=rescale) + + return predictions diff --git a/mmdet/models/dense_heads/corner_head.py b/mmdet/models/dense_heads/corner_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0cec71d50947ff58224ae698ec9c2f9406b58efb --- /dev/null +++ b/mmdet/models/dense_heads/corner_head.py @@ -0,0 +1,1084 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from logging import warning +from math import ceil, log +from typing import List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmcv.ops import CornerPool, batched_nms +from mmengine.config import ConfigDict +from mmengine.model import BaseModule, bias_init_with_prob +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import (ConfigType, InstanceList, OptConfigType, + OptInstanceList, OptMultiConfig) +from ..utils import (gather_feat, gaussian_radius, gen_gaussian_target, + get_local_maximum, get_topk_from_heatmap, multi_apply, + transpose_and_gather_feat) +from .base_dense_head import BaseDenseHead + + +class BiCornerPool(BaseModule): + """Bidirectional Corner Pooling Module (TopLeft, BottomRight, etc.) + + Args: + in_channels (int): Input channels of module. + directions (list[str]): Directions of two CornerPools. + out_channels (int): Output channels of module. + feat_channels (int): Feature channels of module. + norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct + and config norm layer. + init_cfg (:obj:`ConfigDict` or dict, optional): the config to + control the initialization. 
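+
+    Examples:
+        A rough sketch only; the channel sizes are invented for illustration
+        and it assumes the ``CornerPool`` op from mmcv is usable on the
+        current device.
+
+        >>> import torch
+        >>> tl_pool = BiCornerPool(16, ['top', 'left'], feat_channels=32,
+        ...                        out_channels=16)
+        >>> x = torch.rand(1, 16, 8, 8)
+        >>> tl_pool(x).shape  # spatial size is preserved
+        torch.Size([1, 16, 8, 8])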
+ """ + + def __init__(self, + in_channels: int, + directions: List[int], + feat_channels: int = 128, + out_channels: int = 128, + norm_cfg: ConfigType = dict(type='BN', requires_grad=True), + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg) + self.direction1_conv = ConvModule( + in_channels, feat_channels, 3, padding=1, norm_cfg=norm_cfg) + self.direction2_conv = ConvModule( + in_channels, feat_channels, 3, padding=1, norm_cfg=norm_cfg) + + self.aftpool_conv = ConvModule( + feat_channels, + out_channels, + 3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=None) + + self.conv1 = ConvModule( + in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=None) + self.conv2 = ConvModule( + in_channels, out_channels, 3, padding=1, norm_cfg=norm_cfg) + + self.direction1_pool = CornerPool(directions[0]) + self.direction2_pool = CornerPool(directions[1]) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x: Tensor) -> Tensor: + """Forward features from the upstream network. + + Args: + x (tensor): Input feature of BiCornerPool. + + Returns: + conv2 (tensor): Output feature of BiCornerPool. + """ + direction1_conv = self.direction1_conv(x) + direction2_conv = self.direction2_conv(x) + direction1_feat = self.direction1_pool(direction1_conv) + direction2_feat = self.direction2_pool(direction2_conv) + aftpool_conv = self.aftpool_conv(direction1_feat + direction2_feat) + conv1 = self.conv1(x) + relu = self.relu(aftpool_conv + conv1) + conv2 = self.conv2(relu) + return conv2 + + +@MODELS.register_module() +class CornerHead(BaseDenseHead): + """Head of CornerNet: Detecting Objects as Paired Keypoints. + + Code is modified from the `official github repo + `_ . + + More details can be found in the `paper + `_ . + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + num_feat_levels (int): Levels of feature from the previous module. + 2 for HourglassNet-104 and 1 for HourglassNet-52. Because + HourglassNet-104 outputs the final feature and intermediate + supervision feature and HourglassNet-52 only outputs the final + feature. Defaults to 2. + corner_emb_channels (int): Channel of embedding vector. Defaults to 1. + train_cfg (:obj:`ConfigDict` or dict, optional): Training config. + Useless in CornerHead, but we keep this variable for + SingleStageDetector. + test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of + CornerHead. + loss_heatmap (:obj:`ConfigDict` or dict): Config of corner heatmap + loss. Defaults to GaussianFocalLoss. + loss_embedding (:obj:`ConfigDict` or dict): Config of corner embedding + loss. Defaults to AssociativeEmbeddingLoss. + loss_offset (:obj:`ConfigDict` or dict): Config of corner offset loss. + Defaults to SmoothL1Loss. + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. 
+ """ + + def __init__(self, + num_classes: int, + in_channels: int, + num_feat_levels: int = 2, + corner_emb_channels: int = 1, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + loss_heatmap: ConfigType = dict( + type='GaussianFocalLoss', + alpha=2.0, + gamma=4.0, + loss_weight=1), + loss_embedding: ConfigType = dict( + type='AssociativeEmbeddingLoss', + pull_weight=0.25, + push_weight=0.25), + loss_offset: ConfigType = dict( + type='SmoothL1Loss', beta=1.0, loss_weight=1), + init_cfg: OptMultiConfig = None) -> None: + assert init_cfg is None, 'To prevent abnormal initialization ' \ + 'behavior, init_cfg is not allowed to be set' + super().__init__(init_cfg=init_cfg) + self.num_classes = num_classes + self.in_channels = in_channels + self.corner_emb_channels = corner_emb_channels + self.with_corner_emb = self.corner_emb_channels > 0 + self.corner_offset_channels = 2 + self.num_feat_levels = num_feat_levels + self.loss_heatmap = MODELS.build( + loss_heatmap) if loss_heatmap is not None else None + self.loss_embedding = MODELS.build( + loss_embedding) if loss_embedding is not None else None + self.loss_offset = MODELS.build( + loss_offset) if loss_offset is not None else None + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + self._init_layers() + + def _make_layers(self, + out_channels: int, + in_channels: int = 256, + feat_channels: int = 256) -> nn.Sequential: + """Initialize conv sequential for CornerHead.""" + return nn.Sequential( + ConvModule(in_channels, feat_channels, 3, padding=1), + ConvModule( + feat_channels, out_channels, 1, norm_cfg=None, act_cfg=None)) + + def _init_corner_kpt_layers(self) -> None: + """Initialize corner keypoint layers. + + Including corner heatmap branch and corner offset branch. Each branch + has two parts: prefix `tl_` for top-left and `br_` for bottom-right. + """ + self.tl_pool, self.br_pool = nn.ModuleList(), nn.ModuleList() + self.tl_heat, self.br_heat = nn.ModuleList(), nn.ModuleList() + self.tl_off, self.br_off = nn.ModuleList(), nn.ModuleList() + + for _ in range(self.num_feat_levels): + self.tl_pool.append( + BiCornerPool( + self.in_channels, ['top', 'left'], + out_channels=self.in_channels)) + self.br_pool.append( + BiCornerPool( + self.in_channels, ['bottom', 'right'], + out_channels=self.in_channels)) + + self.tl_heat.append( + self._make_layers( + out_channels=self.num_classes, + in_channels=self.in_channels)) + self.br_heat.append( + self._make_layers( + out_channels=self.num_classes, + in_channels=self.in_channels)) + + self.tl_off.append( + self._make_layers( + out_channels=self.corner_offset_channels, + in_channels=self.in_channels)) + self.br_off.append( + self._make_layers( + out_channels=self.corner_offset_channels, + in_channels=self.in_channels)) + + def _init_corner_emb_layers(self) -> None: + """Initialize corner embedding layers. + + Only include corner embedding branch with two parts: prefix `tl_` for + top-left and `br_` for bottom-right. + """ + self.tl_emb, self.br_emb = nn.ModuleList(), nn.ModuleList() + + for _ in range(self.num_feat_levels): + self.tl_emb.append( + self._make_layers( + out_channels=self.corner_emb_channels, + in_channels=self.in_channels)) + self.br_emb.append( + self._make_layers( + out_channels=self.corner_emb_channels, + in_channels=self.in_channels)) + + def _init_layers(self) -> None: + """Initialize layers for CornerHead. 
+ + Including two parts: corner keypoint layers and corner embedding layers + """ + self._init_corner_kpt_layers() + if self.with_corner_emb: + self._init_corner_emb_layers() + + def init_weights(self) -> None: + super().init_weights() + bias_init = bias_init_with_prob(0.1) + for i in range(self.num_feat_levels): + # The initialization of parameters are different between + # nn.Conv2d and ConvModule. Our experiments show that + # using the original initialization of nn.Conv2d increases + # the final mAP by about 0.2% + self.tl_heat[i][-1].conv.reset_parameters() + self.tl_heat[i][-1].conv.bias.data.fill_(bias_init) + self.br_heat[i][-1].conv.reset_parameters() + self.br_heat[i][-1].conv.bias.data.fill_(bias_init) + self.tl_off[i][-1].conv.reset_parameters() + self.br_off[i][-1].conv.reset_parameters() + if self.with_corner_emb: + self.tl_emb[i][-1].conv.reset_parameters() + self.br_emb[i][-1].conv.reset_parameters() + + def forward(self, feats: Tuple[Tensor]) -> tuple: + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: Usually a tuple of corner heatmaps, offset heatmaps and + embedding heatmaps. + - tl_heats (list[Tensor]): Top-left corner heatmaps for all + levels, each is a 4D-tensor, the channels number is + num_classes. + - br_heats (list[Tensor]): Bottom-right corner heatmaps for all + levels, each is a 4D-tensor, the channels number is + num_classes. + - tl_embs (list[Tensor] | list[None]): Top-left embedding + heatmaps for all levels, each is a 4D-tensor or None. + If not None, the channels number is corner_emb_channels. + - br_embs (list[Tensor] | list[None]): Bottom-right embedding + heatmaps for all levels, each is a 4D-tensor or None. + If not None, the channels number is corner_emb_channels. + - tl_offs (list[Tensor]): Top-left offset heatmaps for all + levels, each is a 4D-tensor. The channels number is + corner_offset_channels. + - br_offs (list[Tensor]): Bottom-right offset heatmaps for all + levels, each is a 4D-tensor. The channels number is + corner_offset_channels. + """ + lvl_ind = list(range(self.num_feat_levels)) + return multi_apply(self.forward_single, feats, lvl_ind) + + def forward_single(self, + x: Tensor, + lvl_ind: int, + return_pool: bool = False) -> List[Tensor]: + """Forward feature of a single level. + + Args: + x (Tensor): Feature of a single level. + lvl_ind (int): Level index of current feature. + return_pool (bool): Return corner pool feature or not. + Defaults to False. + + Returns: + tuple[Tensor]: A tuple of CornerHead's output for current feature + level. Containing the following Tensors: + + - tl_heat (Tensor): Predicted top-left corner heatmap. + - br_heat (Tensor): Predicted bottom-right corner heatmap. + - tl_emb (Tensor | None): Predicted top-left embedding heatmap. + None for `self.with_corner_emb == False`. + - br_emb (Tensor | None): Predicted bottom-right embedding + heatmap. None for `self.with_corner_emb == False`. + - tl_off (Tensor): Predicted top-left offset heatmap. + - br_off (Tensor): Predicted bottom-right offset heatmap. + - tl_pool (Tensor): Top-left corner pool feature. Not must + have. + - br_pool (Tensor): Bottom-right corner pool feature. Not must + have. 
+ """ + tl_pool = self.tl_pool[lvl_ind](x) + tl_heat = self.tl_heat[lvl_ind](tl_pool) + br_pool = self.br_pool[lvl_ind](x) + br_heat = self.br_heat[lvl_ind](br_pool) + + tl_emb, br_emb = None, None + if self.with_corner_emb: + tl_emb = self.tl_emb[lvl_ind](tl_pool) + br_emb = self.br_emb[lvl_ind](br_pool) + + tl_off = self.tl_off[lvl_ind](tl_pool) + br_off = self.br_off[lvl_ind](br_pool) + + result_list = [tl_heat, br_heat, tl_emb, br_emb, tl_off, br_off] + if return_pool: + result_list.append(tl_pool) + result_list.append(br_pool) + + return result_list + + def get_targets(self, + gt_bboxes: List[Tensor], + gt_labels: List[Tensor], + feat_shape: Sequence[int], + img_shape: Sequence[int], + with_corner_emb: bool = False, + with_guiding_shift: bool = False, + with_centripetal_shift: bool = False) -> dict: + """Generate corner targets. + + Including corner heatmap, corner offset. + + Optional: corner embedding, corner guiding shift, centripetal shift. + + For CornerNet, we generate corner heatmap, corner offset and corner + embedding from this function. + + For CentripetalNet, we generate corner heatmap, corner offset, guiding + shift and centripetal shift from this function. + + Args: + gt_bboxes (list[Tensor]): Ground truth bboxes of each image, each + has shape (num_gt, 4). + gt_labels (list[Tensor]): Ground truth labels of each box, each has + shape (num_gt, ). + feat_shape (Sequence[int]): Shape of output feature, + [batch, channel, height, width]. + img_shape (Sequence[int]): Shape of input image, + [height, width, channel]. + with_corner_emb (bool): Generate corner embedding target or not. + Defaults to False. + with_guiding_shift (bool): Generate guiding shift target or not. + Defaults to False. + with_centripetal_shift (bool): Generate centripetal shift target or + not. Defaults to False. + + Returns: + dict: Ground truth of corner heatmap, corner offset, corner + embedding, guiding shift and centripetal shift. Containing the + following keys: + + - topleft_heatmap (Tensor): Ground truth top-left corner + heatmap. + - bottomright_heatmap (Tensor): Ground truth bottom-right + corner heatmap. + - topleft_offset (Tensor): Ground truth top-left corner offset. + - bottomright_offset (Tensor): Ground truth bottom-right corner + offset. + - corner_embedding (list[list[list[int]]]): Ground truth corner + embedding. Not must have. + - topleft_guiding_shift (Tensor): Ground truth top-left corner + guiding shift. Not must have. + - bottomright_guiding_shift (Tensor): Ground truth bottom-right + corner guiding shift. Not must have. + - topleft_centripetal_shift (Tensor): Ground truth top-left + corner centripetal shift. Not must have. + - bottomright_centripetal_shift (Tensor): Ground truth + bottom-right corner centripetal shift. Not must have. 
+ """ + batch_size, _, height, width = feat_shape + img_h, img_w = img_shape[:2] + + width_ratio = float(width / img_w) + height_ratio = float(height / img_h) + + gt_tl_heatmap = gt_bboxes[-1].new_zeros( + [batch_size, self.num_classes, height, width]) + gt_br_heatmap = gt_bboxes[-1].new_zeros( + [batch_size, self.num_classes, height, width]) + gt_tl_offset = gt_bboxes[-1].new_zeros([batch_size, 2, height, width]) + gt_br_offset = gt_bboxes[-1].new_zeros([batch_size, 2, height, width]) + + if with_corner_emb: + match = [] + + # Guiding shift is a kind of offset, from center to corner + if with_guiding_shift: + gt_tl_guiding_shift = gt_bboxes[-1].new_zeros( + [batch_size, 2, height, width]) + gt_br_guiding_shift = gt_bboxes[-1].new_zeros( + [batch_size, 2, height, width]) + # Centripetal shift is also a kind of offset, from center to corner + # and normalized by log. + if with_centripetal_shift: + gt_tl_centripetal_shift = gt_bboxes[-1].new_zeros( + [batch_size, 2, height, width]) + gt_br_centripetal_shift = gt_bboxes[-1].new_zeros( + [batch_size, 2, height, width]) + + for batch_id in range(batch_size): + # Ground truth of corner embedding per image is a list of coord set + corner_match = [] + for box_id in range(len(gt_labels[batch_id])): + left, top, right, bottom = gt_bboxes[batch_id][box_id] + center_x = (left + right) / 2.0 + center_y = (top + bottom) / 2.0 + label = gt_labels[batch_id][box_id] + + # Use coords in the feature level to generate ground truth + scale_left = left * width_ratio + scale_right = right * width_ratio + scale_top = top * height_ratio + scale_bottom = bottom * height_ratio + scale_center_x = center_x * width_ratio + scale_center_y = center_y * height_ratio + + # Int coords on feature map/ground truth tensor + left_idx = int(min(scale_left, width - 1)) + right_idx = int(min(scale_right, width - 1)) + top_idx = int(min(scale_top, height - 1)) + bottom_idx = int(min(scale_bottom, height - 1)) + + # Generate gaussian heatmap + scale_box_width = ceil(scale_right - scale_left) + scale_box_height = ceil(scale_bottom - scale_top) + radius = gaussian_radius((scale_box_height, scale_box_width), + min_overlap=0.3) + radius = max(0, int(radius)) + gt_tl_heatmap[batch_id, label] = gen_gaussian_target( + gt_tl_heatmap[batch_id, label], [left_idx, top_idx], + radius) + gt_br_heatmap[batch_id, label] = gen_gaussian_target( + gt_br_heatmap[batch_id, label], [right_idx, bottom_idx], + radius) + + # Generate corner offset + left_offset = scale_left - left_idx + top_offset = scale_top - top_idx + right_offset = scale_right - right_idx + bottom_offset = scale_bottom - bottom_idx + gt_tl_offset[batch_id, 0, top_idx, left_idx] = left_offset + gt_tl_offset[batch_id, 1, top_idx, left_idx] = top_offset + gt_br_offset[batch_id, 0, bottom_idx, right_idx] = right_offset + gt_br_offset[batch_id, 1, bottom_idx, + right_idx] = bottom_offset + + # Generate corner embedding + if with_corner_emb: + corner_match.append([[top_idx, left_idx], + [bottom_idx, right_idx]]) + # Generate guiding shift + if with_guiding_shift: + gt_tl_guiding_shift[batch_id, 0, top_idx, + left_idx] = scale_center_x - left_idx + gt_tl_guiding_shift[batch_id, 1, top_idx, + left_idx] = scale_center_y - top_idx + gt_br_guiding_shift[batch_id, 0, bottom_idx, + right_idx] = right_idx - scale_center_x + gt_br_guiding_shift[ + batch_id, 1, bottom_idx, + right_idx] = bottom_idx - scale_center_y + # Generate centripetal shift + if with_centripetal_shift: + gt_tl_centripetal_shift[batch_id, 0, top_idx, + left_idx] = 
log(scale_center_x - + scale_left) + gt_tl_centripetal_shift[batch_id, 1, top_idx, + left_idx] = log(scale_center_y - + scale_top) + gt_br_centripetal_shift[batch_id, 0, bottom_idx, + right_idx] = log(scale_right - + scale_center_x) + gt_br_centripetal_shift[batch_id, 1, bottom_idx, + right_idx] = log(scale_bottom - + scale_center_y) + + if with_corner_emb: + match.append(corner_match) + + target_result = dict( + topleft_heatmap=gt_tl_heatmap, + topleft_offset=gt_tl_offset, + bottomright_heatmap=gt_br_heatmap, + bottomright_offset=gt_br_offset) + + if with_corner_emb: + target_result.update(corner_embedding=match) + if with_guiding_shift: + target_result.update( + topleft_guiding_shift=gt_tl_guiding_shift, + bottomright_guiding_shift=gt_br_guiding_shift) + if with_centripetal_shift: + target_result.update( + topleft_centripetal_shift=gt_tl_centripetal_shift, + bottomright_centripetal_shift=gt_br_centripetal_shift) + + return target_result + + def loss_by_feat( + self, + tl_heats: List[Tensor], + br_heats: List[Tensor], + tl_embs: List[Tensor], + br_embs: List[Tensor], + tl_offs: List[Tensor], + br_offs: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + tl_heats (list[Tensor]): Top-left corner heatmaps for each level + with shape (N, num_classes, H, W). + br_heats (list[Tensor]): Bottom-right corner heatmaps for each + level with shape (N, num_classes, H, W). + tl_embs (list[Tensor]): Top-left corner embeddings for each level + with shape (N, corner_emb_channels, H, W). + br_embs (list[Tensor]): Bottom-right corner embeddings for each + level with shape (N, corner_emb_channels, H, W). + tl_offs (list[Tensor]): Top-left corner offsets for each level + with shape (N, corner_offset_channels, H, W). + br_offs (list[Tensor]): Bottom-right corner offsets for each level + with shape (N, corner_offset_channels, H, W). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Specify which bounding boxes can be ignored when computing + the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. Containing the + following losses: + + - det_loss (list[Tensor]): Corner keypoint losses of all + feature levels. + - pull_loss (list[Tensor]): Part one of AssociativeEmbedding + losses of all feature levels. + - push_loss (list[Tensor]): Part two of AssociativeEmbedding + losses of all feature levels. + - off_loss (list[Tensor]): Corner offset losses of all feature + levels. 
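+
+        Examples:
+            A minimal end-to-end sketch with an invented box and sizes; it
+            assumes the default corner losses can be built from the registry.
+
+            >>> import torch
+            >>> from mmengine.structures import InstanceData
+            >>> self = CornerHead(num_classes=2, in_channels=16,
+            ...                   num_feat_levels=1)
+            >>> outs = self([torch.rand(1, 16, 32, 32)])
+            >>> gt_instances = InstanceData(
+            ...     bboxes=torch.tensor([[10., 20., 100., 110.]]),
+            ...     labels=torch.tensor([0]))
+            >>> img_metas = [dict(batch_input_shape=(128, 128))]
+            >>> losses = self.loss_by_feat(*outs, [gt_instances], img_metas)
+            >>> sorted(losses.keys())
+            ['det_loss', 'off_loss', 'pull_loss', 'push_loss']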
+ """ + gt_bboxes = [ + gt_instances.bboxes for gt_instances in batch_gt_instances + ] + gt_labels = [ + gt_instances.labels for gt_instances in batch_gt_instances + ] + + targets = self.get_targets( + gt_bboxes, + gt_labels, + tl_heats[-1].shape, + batch_img_metas[0]['batch_input_shape'], + with_corner_emb=self.with_corner_emb) + mlvl_targets = [targets for _ in range(self.num_feat_levels)] + det_losses, pull_losses, push_losses, off_losses = multi_apply( + self.loss_by_feat_single, tl_heats, br_heats, tl_embs, br_embs, + tl_offs, br_offs, mlvl_targets) + loss_dict = dict(det_loss=det_losses, off_loss=off_losses) + if self.with_corner_emb: + loss_dict.update(pull_loss=pull_losses, push_loss=push_losses) + return loss_dict + + def loss_by_feat_single(self, tl_hmp: Tensor, br_hmp: Tensor, + tl_emb: Optional[Tensor], br_emb: Optional[Tensor], + tl_off: Tensor, br_off: Tensor, + targets: dict) -> Tuple[Tensor, ...]: + """Calculate the loss of a single scale level based on the features + extracted by the detection head. + + Args: + tl_hmp (Tensor): Top-left corner heatmap for current level with + shape (N, num_classes, H, W). + br_hmp (Tensor): Bottom-right corner heatmap for current level with + shape (N, num_classes, H, W). + tl_emb (Tensor, optional): Top-left corner embedding for current + level with shape (N, corner_emb_channels, H, W). + br_emb (Tensor, optional): Bottom-right corner embedding for + current level with shape (N, corner_emb_channels, H, W). + tl_off (Tensor): Top-left corner offset for current level with + shape (N, corner_offset_channels, H, W). + br_off (Tensor): Bottom-right corner offset for current level with + shape (N, corner_offset_channels, H, W). + targets (dict): Corner target generated by `get_targets`. + + Returns: + tuple[torch.Tensor]: Losses of the head's different branches + containing the following losses: + + - det_loss (Tensor): Corner keypoint loss. + - pull_loss (Tensor): Part one of AssociativeEmbedding loss. + - push_loss (Tensor): Part two of AssociativeEmbedding loss. + - off_loss (Tensor): Corner offset loss. + """ + gt_tl_hmp = targets['topleft_heatmap'] + gt_br_hmp = targets['bottomright_heatmap'] + gt_tl_off = targets['topleft_offset'] + gt_br_off = targets['bottomright_offset'] + gt_embedding = targets['corner_embedding'] + + # Detection loss + tl_det_loss = self.loss_heatmap( + tl_hmp.sigmoid(), + gt_tl_hmp, + avg_factor=max(1, + gt_tl_hmp.eq(1).sum())) + br_det_loss = self.loss_heatmap( + br_hmp.sigmoid(), + gt_br_hmp, + avg_factor=max(1, + gt_br_hmp.eq(1).sum())) + det_loss = (tl_det_loss + br_det_loss) / 2.0 + + # AssociativeEmbedding loss + if self.with_corner_emb and self.loss_embedding is not None: + pull_loss, push_loss = self.loss_embedding(tl_emb, br_emb, + gt_embedding) + else: + pull_loss, push_loss = None, None + + # Offset loss + # We only compute the offset loss at the real corner position. + # The value of real corner would be 1 in heatmap ground truth. + # The mask is computed in class agnostic mode and its shape is + # batch * 1 * width * height. 
+ tl_off_mask = gt_tl_hmp.eq(1).sum(1).gt(0).unsqueeze(1).type_as( + gt_tl_hmp) + br_off_mask = gt_br_hmp.eq(1).sum(1).gt(0).unsqueeze(1).type_as( + gt_br_hmp) + tl_off_loss = self.loss_offset( + tl_off, + gt_tl_off, + tl_off_mask, + avg_factor=max(1, tl_off_mask.sum())) + br_off_loss = self.loss_offset( + br_off, + gt_br_off, + br_off_mask, + avg_factor=max(1, br_off_mask.sum())) + + off_loss = (tl_off_loss + br_off_loss) / 2.0 + + return det_loss, pull_loss, push_loss, off_loss + + def predict_by_feat(self, + tl_heats: List[Tensor], + br_heats: List[Tensor], + tl_embs: List[Tensor], + br_embs: List[Tensor], + tl_offs: List[Tensor], + br_offs: List[Tensor], + batch_img_metas: Optional[List[dict]] = None, + rescale: bool = False, + with_nms: bool = True) -> InstanceList: + """Transform a batch of output features extracted from the head into + bbox results. + + Args: + tl_heats (list[Tensor]): Top-left corner heatmaps for each level + with shape (N, num_classes, H, W). + br_heats (list[Tensor]): Bottom-right corner heatmaps for each + level with shape (N, num_classes, H, W). + tl_embs (list[Tensor]): Top-left corner embeddings for each level + with shape (N, corner_emb_channels, H, W). + br_embs (list[Tensor]): Bottom-right corner embeddings for each + level with shape (N, corner_emb_channels, H, W). + tl_offs (list[Tensor]): Top-left corner offsets for each level + with shape (N, corner_offset_channels, H, W). + br_offs (list[Tensor]): Bottom-right corner offsets for each level + with shape (N, corner_offset_channels, H, W). + batch_img_metas (list[dict], optional): Batch image meta info. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + list[:obj:`InstanceData`]: Object detection results of each image + after the post process. Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + assert tl_heats[-1].shape[0] == br_heats[-1].shape[0] == len( + batch_img_metas) + result_list = [] + for img_id in range(len(batch_img_metas)): + result_list.append( + self._predict_by_feat_single( + tl_heats[-1][img_id:img_id + 1, :], + br_heats[-1][img_id:img_id + 1, :], + tl_offs[-1][img_id:img_id + 1, :], + br_offs[-1][img_id:img_id + 1, :], + batch_img_metas[img_id], + tl_emb=tl_embs[-1][img_id:img_id + 1, :], + br_emb=br_embs[-1][img_id:img_id + 1, :], + rescale=rescale, + with_nms=with_nms)) + + return result_list + + def _predict_by_feat_single(self, + tl_heat: Tensor, + br_heat: Tensor, + tl_off: Tensor, + br_off: Tensor, + img_meta: dict, + tl_emb: Optional[Tensor] = None, + br_emb: Optional[Tensor] = None, + tl_centripetal_shift: Optional[Tensor] = None, + br_centripetal_shift: Optional[Tensor] = None, + rescale: bool = False, + with_nms: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + tl_heat (Tensor): Top-left corner heatmap for current level with + shape (N, num_classes, H, W). + br_heat (Tensor): Bottom-right corner heatmap for current level + with shape (N, num_classes, H, W). + tl_off (Tensor): Top-left corner offset for current level with + shape (N, corner_offset_channels, H, W). 
+            br_off (Tensor): Bottom-right corner offset for current level with
+                shape (N, corner_offset_channels, H, W).
+            img_meta (dict): Meta information of current image, e.g.,
+                image size, scaling factor, etc.
+            tl_emb (Tensor): Top-left corner embedding for current level with
+                shape (N, corner_emb_channels, H, W).
+            br_emb (Tensor): Bottom-right corner embedding for current level
+                with shape (N, corner_emb_channels, H, W).
+            tl_centripetal_shift: Top-left corner's centripetal shift for
+                current level with shape (N, 2, H, W).
+            br_centripetal_shift: Bottom-right corner's centripetal shift for
+                current level with shape (N, 2, H, W).
+            rescale (bool): If True, return boxes in original image space.
+                Defaults to False.
+            with_nms (bool): If True, do nms before return boxes.
+                Defaults to True.
+
+        Returns:
+            :obj:`InstanceData`: Detection results of each image
+            after the post process.
+            Each item usually contains following keys.
+
+            - scores (Tensor): Classification scores, has a shape
+              (num_instances, )
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 4),
+              the last dimension 4 arranged as (x1, y1, x2, y2).
+        """
+        if isinstance(img_meta, (list, tuple)):
+            img_meta = img_meta[0]
+
+        batch_bboxes, batch_scores, batch_clses = self._decode_heatmap(
+            tl_heat=tl_heat.sigmoid(),
+            br_heat=br_heat.sigmoid(),
+            tl_off=tl_off,
+            br_off=br_off,
+            tl_emb=tl_emb,
+            br_emb=br_emb,
+            tl_centripetal_shift=tl_centripetal_shift,
+            br_centripetal_shift=br_centripetal_shift,
+            img_meta=img_meta,
+            k=self.test_cfg.corner_topk,
+            kernel=self.test_cfg.local_maximum_kernel,
+            distance_threshold=self.test_cfg.distance_threshold)
+
+        if rescale and 'scale_factor' in img_meta:
+            batch_bboxes /= batch_bboxes.new_tensor(
+                img_meta['scale_factor']).repeat((1, 2))
+
+        bboxes = batch_bboxes.view([-1, 4])
+        scores = batch_scores.view(-1)
+        clses = batch_clses.view(-1)
+
+        det_bboxes = torch.cat([bboxes, scores.unsqueeze(-1)], -1)
+        keepinds = (det_bboxes[:, -1] > -0.1)
+        det_bboxes = det_bboxes[keepinds]
+        det_labels = clses[keepinds]
+
+        if with_nms:
+            det_bboxes, det_labels = self._bboxes_nms(det_bboxes, det_labels,
+                                                      self.test_cfg)
+
+        results = InstanceData()
+        results.bboxes = det_bboxes[..., :4]
+        results.scores = det_bboxes[..., 4]
+        results.labels = det_labels
+        return results
+
+    def _bboxes_nms(self, bboxes: Tensor, labels: Tensor,
+                    cfg: ConfigDict) -> Tuple[Tensor, Tensor]:
+        """NMS on the detected bboxes."""
+        if 'nms_cfg' in cfg:
+            # `from logging import warning` imports `logging.warning`, which
+            # is a function and must be called directly (it has no `.warn`).
+            warning('nms_cfg in test_cfg will be deprecated. '
+                    'Please rename it as nms')
+            if 'nms' not in cfg:
+                cfg.nms = cfg.nms_cfg
+
+        if labels.numel() > 0:
+            max_num = cfg.max_per_img
+            bboxes, keep = batched_nms(bboxes[:, :4],
+                                       bboxes[:, -1].contiguous(),
+                                       labels, cfg.nms)
+            if max_num > 0:
+                bboxes = bboxes[:max_num]
+                labels = labels[keep][:max_num]
+
+        return bboxes, labels
+
+    def _decode_heatmap(self,
+                        tl_heat: Tensor,
+                        br_heat: Tensor,
+                        tl_off: Tensor,
+                        br_off: Tensor,
+                        tl_emb: Optional[Tensor] = None,
+                        br_emb: Optional[Tensor] = None,
+                        tl_centripetal_shift: Optional[Tensor] = None,
+                        br_centripetal_shift: Optional[Tensor] = None,
+                        img_meta: Optional[dict] = None,
+                        k: int = 100,
+                        kernel: int = 3,
+                        distance_threshold: float = 0.5,
+                        num_dets: int = 1000) -> Tuple[Tensor, Tensor, Tensor]:
+        """Transform outputs into detections raw bbox prediction.
+
+        Args:
+            tl_heat (Tensor): Top-left corner heatmap for current level with
+                shape (N, num_classes, H, W).
+ br_heat (Tensor): Bottom-right corner heatmap for current level + with shape (N, num_classes, H, W). + tl_off (Tensor): Top-left corner offset for current level with + shape (N, corner_offset_channels, H, W). + br_off (Tensor): Bottom-right corner offset for current level with + shape (N, corner_offset_channels, H, W). + tl_emb (Tensor, Optional): Top-left corner embedding for current + level with shape (N, corner_emb_channels, H, W). + br_emb (Tensor, Optional): Bottom-right corner embedding for + current level with shape (N, corner_emb_channels, H, W). + tl_centripetal_shift (Tensor, Optional): Top-left centripetal shift + for current level with shape (N, 2, H, W). + br_centripetal_shift (Tensor, Optional): Bottom-right centripetal + shift for current level with shape (N, 2, H, W). + img_meta (dict): Meta information of current image, e.g., + image size, scaling factor, etc. + k (int): Get top k corner keypoints from heatmap. + kernel (int): Max pooling kernel for extract local maximum pixels. + distance_threshold (float): Distance threshold. Top-left and + bottom-right corner keypoints with feature distance less than + the threshold will be regarded as keypoints from same object. + num_dets (int): Num of raw boxes before doing nms. + + Returns: + tuple[torch.Tensor]: Decoded output of CornerHead, containing the + following Tensors: + + - bboxes (Tensor): Coords of each box. + - scores (Tensor): Scores of each box. + - clses (Tensor): Categories of each box. + """ + with_embedding = tl_emb is not None and br_emb is not None + with_centripetal_shift = ( + tl_centripetal_shift is not None + and br_centripetal_shift is not None) + assert with_embedding + with_centripetal_shift == 1 + batch, _, height, width = tl_heat.size() + if torch.onnx.is_in_onnx_export(): + inp_h, inp_w = img_meta['pad_shape_for_onnx'][:2] + else: + inp_h, inp_w = img_meta['batch_input_shape'][:2] + + # perform nms on heatmaps + tl_heat = get_local_maximum(tl_heat, kernel=kernel) + br_heat = get_local_maximum(br_heat, kernel=kernel) + + tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = get_topk_from_heatmap( + tl_heat, k=k) + br_scores, br_inds, br_clses, br_ys, br_xs = get_topk_from_heatmap( + br_heat, k=k) + + # We use repeat instead of expand here because expand is a + # shallow-copy function. Thus it could cause unexpected testing result + # sometimes. Using expand will decrease about 10% mAP during testing + # compared to repeat. 
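+        # E.g. with k corners per heatmap, tl_xs/tl_ys become (batch, k, k)
+        # grids where row i repeats top-left corner i, while br_xs/br_ys
+        # repeat bottom-right corner j along columns; entry (i, j) therefore
+        # pairs top-left corner i with bottom-right corner j.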
+        tl_ys = tl_ys.view(batch, k, 1).repeat(1, 1, k)
+        tl_xs = tl_xs.view(batch, k, 1).repeat(1, 1, k)
+        br_ys = br_ys.view(batch, 1, k).repeat(1, k, 1)
+        br_xs = br_xs.view(batch, 1, k).repeat(1, k, 1)
+
+        tl_off = transpose_and_gather_feat(tl_off, tl_inds)
+        tl_off = tl_off.view(batch, k, 1, 2)
+        br_off = transpose_and_gather_feat(br_off, br_inds)
+        br_off = br_off.view(batch, 1, k, 2)
+
+        tl_xs = tl_xs + tl_off[..., 0]
+        tl_ys = tl_ys + tl_off[..., 1]
+        br_xs = br_xs + br_off[..., 0]
+        br_ys = br_ys + br_off[..., 1]
+
+        if with_centripetal_shift:
+            tl_centripetal_shift = transpose_and_gather_feat(
+                tl_centripetal_shift, tl_inds).view(batch, k, 1, 2).exp()
+            br_centripetal_shift = transpose_and_gather_feat(
+                br_centripetal_shift, br_inds).view(batch, 1, k, 2).exp()
+
+            tl_ctxs = tl_xs + tl_centripetal_shift[..., 0]
+            tl_ctys = tl_ys + tl_centripetal_shift[..., 1]
+            br_ctxs = br_xs - br_centripetal_shift[..., 0]
+            br_ctys = br_ys - br_centripetal_shift[..., 1]
+
+        # all possible boxes based on top k corners (ignoring class)
+        tl_xs *= (inp_w / width)
+        tl_ys *= (inp_h / height)
+        br_xs *= (inp_w / width)
+        br_ys *= (inp_h / height)
+
+        if with_centripetal_shift:
+            tl_ctxs *= (inp_w / width)
+            tl_ctys *= (inp_h / height)
+            br_ctxs *= (inp_w / width)
+            br_ctys *= (inp_h / height)
+
+        x_off, y_off = 0, 0  # no crop
+        if not torch.onnx.is_in_onnx_export():
+            # Since `RandomCenterCropPad` is done on CPU with numpy and is
+            # not dynamically traceable when exporting to ONNX, 'border'
+            # does not appear as a key in 'img_meta'. As a tmp solution,
+            # we move this 'border' handling to the postprocess after
+            # exporting to ONNX, which is handled in
+            # `mmdet/core/export/model_wrappers.py`. Although the pytorch
+            # and exported onnx models differ here, the difference can be
+            # ignored since comparable performance is achieved between them (e.g.
40.4 vs + # 40.6 on COCO val2017, for CornerNet without test-time flip) + if 'border' in img_meta: + x_off = img_meta['border'][2] + y_off = img_meta['border'][0] + + tl_xs -= x_off + tl_ys -= y_off + br_xs -= x_off + br_ys -= y_off + + zeros = tl_xs.new_zeros(*tl_xs.size()) + tl_xs = torch.where(tl_xs > 0.0, tl_xs, zeros) + tl_ys = torch.where(tl_ys > 0.0, tl_ys, zeros) + br_xs = torch.where(br_xs > 0.0, br_xs, zeros) + br_ys = torch.where(br_ys > 0.0, br_ys, zeros) + + bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3) + area_bboxes = ((br_xs - tl_xs) * (br_ys - tl_ys)).abs() + + if with_centripetal_shift: + tl_ctxs -= x_off + tl_ctys -= y_off + br_ctxs -= x_off + br_ctys -= y_off + + tl_ctxs *= tl_ctxs.gt(0.0).type_as(tl_ctxs) + tl_ctys *= tl_ctys.gt(0.0).type_as(tl_ctys) + br_ctxs *= br_ctxs.gt(0.0).type_as(br_ctxs) + br_ctys *= br_ctys.gt(0.0).type_as(br_ctys) + + ct_bboxes = torch.stack((tl_ctxs, tl_ctys, br_ctxs, br_ctys), + dim=3) + area_ct_bboxes = ((br_ctxs - tl_ctxs) * (br_ctys - tl_ctys)).abs() + + rcentral = torch.zeros_like(ct_bboxes) + # magic nums from paper section 4.1 + mu = torch.ones_like(area_bboxes) / 2.4 + mu[area_bboxes > 3500] = 1 / 2.1 # large bbox have smaller mu + + bboxes_center_x = (bboxes[..., 0] + bboxes[..., 2]) / 2 + bboxes_center_y = (bboxes[..., 1] + bboxes[..., 3]) / 2 + rcentral[..., 0] = bboxes_center_x - mu * (bboxes[..., 2] - + bboxes[..., 0]) / 2 + rcentral[..., 1] = bboxes_center_y - mu * (bboxes[..., 3] - + bboxes[..., 1]) / 2 + rcentral[..., 2] = bboxes_center_x + mu * (bboxes[..., 2] - + bboxes[..., 0]) / 2 + rcentral[..., 3] = bboxes_center_y + mu * (bboxes[..., 3] - + bboxes[..., 1]) / 2 + area_rcentral = ((rcentral[..., 2] - rcentral[..., 0]) * + (rcentral[..., 3] - rcentral[..., 1])).abs() + dists = area_ct_bboxes / area_rcentral + + tl_ctx_inds = (ct_bboxes[..., 0] <= rcentral[..., 0]) | ( + ct_bboxes[..., 0] >= rcentral[..., 2]) + tl_cty_inds = (ct_bboxes[..., 1] <= rcentral[..., 1]) | ( + ct_bboxes[..., 1] >= rcentral[..., 3]) + br_ctx_inds = (ct_bboxes[..., 2] <= rcentral[..., 0]) | ( + ct_bboxes[..., 2] >= rcentral[..., 2]) + br_cty_inds = (ct_bboxes[..., 3] <= rcentral[..., 1]) | ( + ct_bboxes[..., 3] >= rcentral[..., 3]) + + if with_embedding: + tl_emb = transpose_and_gather_feat(tl_emb, tl_inds) + tl_emb = tl_emb.view(batch, k, 1) + br_emb = transpose_and_gather_feat(br_emb, br_inds) + br_emb = br_emb.view(batch, 1, k) + dists = torch.abs(tl_emb - br_emb) + + tl_scores = tl_scores.view(batch, k, 1).repeat(1, 1, k) + br_scores = br_scores.view(batch, 1, k).repeat(1, k, 1) + + scores = (tl_scores + br_scores) / 2 # scores for all possible boxes + + # tl and br should have same class + tl_clses = tl_clses.view(batch, k, 1).repeat(1, 1, k) + br_clses = br_clses.view(batch, 1, k).repeat(1, k, 1) + cls_inds = (tl_clses != br_clses) + + # reject boxes based on distances + dist_inds = dists > distance_threshold + + # reject boxes based on widths and heights + width_inds = (br_xs <= tl_xs) + height_inds = (br_ys <= tl_ys) + + # No use `scores[cls_inds]`, instead we use `torch.where` here. + # Since only 1-D indices with type 'tensor(bool)' are supported + # when exporting to ONNX, any other bool indices with more dimensions + # (e.g. 
2-D bool tensors) are invalid as node inputs.
+        negative_scores = -1 * torch.ones_like(scores)
+        scores = torch.where(cls_inds, negative_scores, scores)
+        scores = torch.where(width_inds, negative_scores, scores)
+        scores = torch.where(height_inds, negative_scores, scores)
+        scores = torch.where(dist_inds, negative_scores, scores)
+
+        if with_centripetal_shift:
+            scores[tl_ctx_inds] = -1
+            scores[tl_cty_inds] = -1
+            scores[br_ctx_inds] = -1
+            scores[br_cty_inds] = -1
+
+        scores = scores.view(batch, -1)
+        scores, inds = torch.topk(scores, num_dets)
+        scores = scores.unsqueeze(2)
+
+        bboxes = bboxes.view(batch, -1, 4)
+        bboxes = gather_feat(bboxes, inds)
+
+        clses = tl_clses.contiguous().view(batch, -1, 1)
+        clses = gather_feat(clses, inds)
+
+        return bboxes, scores, clses
diff --git a/mmdet/models/dense_heads/dab_detr_head.py b/mmdet/models/dense_heads/dab_detr_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..892833ffce5f17f6f9e82e67b7d32c6b9c1bafc0
--- /dev/null
+++ b/mmdet/models/dense_heads/dab_detr_head.py
@@ -0,0 +1,106 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch.nn as nn
+from mmcv.cnn import Linear
+from mmengine.model import bias_init_with_prob, constant_init
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.structures import SampleList
+from mmdet.utils import InstanceList
+from ..layers import MLP, inverse_sigmoid
+from .conditional_detr_head import ConditionalDETRHead
+
+
+@MODELS.register_module()
+class DABDETRHead(ConditionalDETRHead):
+    """Head of DAB-DETR. DAB-DETR: Dynamic Anchor Boxes are Better Queries for
+    DETR.
+
+    More details can be found in the `paper
+    <https://arxiv.org/abs/2201.12329>`_ .
+    """
+
+    def _init_layers(self) -> None:
+        """Initialize layers of the transformer head."""
+        # cls branch
+        self.fc_cls = Linear(self.embed_dims, self.cls_out_channels)
+        # reg branch
+        self.fc_reg = MLP(self.embed_dims, self.embed_dims, 4, 3)
+
+    def init_weights(self) -> None:
+        """Initialize weights."""
+        if self.loss_cls.use_sigmoid:
+            bias_init = bias_init_with_prob(0.01)
+            nn.init.constant_(self.fc_cls.bias, bias_init)
+        constant_init(self.fc_reg.layers[-1], 0., bias=0.)
+
+    def forward(self, hidden_states: Tensor,
+                references: Tensor) -> Tuple[Tensor, Tensor]:
+        """Forward function.
+
+        Args:
+            hidden_states (Tensor): Features from transformer decoder. If
+                `return_intermediate_dec` is True output has shape
+                (num_decoder_layers, bs, num_queries, dim), else has shape (1,
+                bs, num_queries, dim) which only contains the last layer
+                outputs.
+            references (Tensor): References from transformer decoder. If
+                `return_intermediate_dec` is True output has shape
+                (num_decoder_layers, bs, num_queries, 2/4), else has shape
+                (1, bs, num_queries, 2/4) which only contains the last layer
+                reference.
+        Returns:
+            tuple[Tensor]: results of head containing the following tensors.
+
+            - layers_cls_scores (Tensor): Outputs from the classification
+              head, shape (num_decoder_layers, bs, num_queries,
+              cls_out_channels). Note cls_out_channels should include
+              background.
+            - layers_bbox_preds (Tensor): Sigmoid outputs from the regression
+              head with normalized coordinate format (cx, cy, w, h), has shape
+              (num_decoder_layers, bs, num_queries, 4).
+        """
+        layers_cls_scores = self.fc_cls(hidden_states)
+        references_before_sigmoid = inverse_sigmoid(references, eps=1e-3)
+        tmp_reg_preds = self.fc_reg(hidden_states)
+        tmp_reg_preds[..., :references_before_sigmoid.
+ size(-1)] += references_before_sigmoid + layers_bbox_preds = tmp_reg_preds.sigmoid() + return layers_cls_scores, layers_bbox_preds + + def predict(self, + hidden_states: Tensor, + references: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> InstanceList: + """Perform forward propagation of the detection head and predict + detection results on the features of the upstream network. Over-write + because img_metas are needed as inputs for bbox_head. + + Args: + hidden_states (Tensor): Feature from the transformer decoder, has + shape (num_decoder_layers, bs, num_queries, dim). + references (Tensor): references from the transformer decoder, has + shape (num_decoder_layers, bs, num_queries, 2/4). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool, optional): Whether to rescale the results. + Defaults to True. + + Returns: + list[obj:`InstanceData`]: Detection results of each image + after the post process. + """ + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + + last_layer_hidden_state = hidden_states[-1].unsqueeze(0) + last_layer_reference = references[-1].unsqueeze(0) + outs = self(last_layer_hidden_state, last_layer_reference) + + predictions = self.predict_by_feat( + *outs, batch_img_metas=batch_img_metas, rescale=rescale) + return predictions diff --git a/mmdet/models/dense_heads/ddod_head.py b/mmdet/models/dense_heads/ddod_head.py new file mode 100644 index 0000000000000000000000000000000000000000..64e91ff0135230a8d634c5964eb520e1461c872a --- /dev/null +++ b/mmdet/models/dense_heads/ddod_head.py @@ -0,0 +1,794 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, Scale +from mmengine.model import bias_init_with_prob, normal_init +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures.bbox import bbox_overlaps +from mmdet.utils import (ConfigType, InstanceList, OptConfigType, + OptInstanceList, reduce_mean) +from ..task_modules.prior_generators import anchor_inside_flags +from ..utils import images_to_levels, multi_apply, unmap +from .anchor_head import AnchorHead + +EPS = 1e-12 + + +@MODELS.register_module() +class DDODHead(AnchorHead): + """Detection Head of `DDOD `_. + + DDOD head decomposes conjunctions lying in most current one-stage + detectors via label assignment disentanglement, spatial feature + disentanglement, and pyramid supervision disentanglement. + + Args: + num_classes (int): Number of categories excluding the + background category. + in_channels (int): Number of channels in the input feature map. + stacked_convs (int): The number of stacked Conv. Defaults to 4. + conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + convolution layer. Defaults to None. + use_dcn (bool): Use dcn, Same as ATSS when False. Defaults to True. + norm_cfg (:obj:`ConfigDict` or dict): Normal config of ddod head. + Defaults to dict(type='GN', num_groups=32, requires_grad=True). + loss_iou (:obj:`ConfigDict` or dict): Config of IoU loss. Defaults to + dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0). 
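+
+    Example (an illustrative sketch, not from the original file; assumes
+    the default ``AnchorHead`` anchor generator and losses, with
+    ``train_cfg=None``):
+        >>> import torch
+        >>> self = DDODHead(num_classes=80, in_channels=256, use_dcn=False)
+        >>> feats = [torch.rand(1, 256, s, s) for s in [64, 32, 16, 8, 4]]
+        >>> cls_scores, bbox_preds, iou_preds = self.forward(feats)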
+ """ + + def __init__(self, + num_classes: int, + in_channels: int, + stacked_convs: int = 4, + conv_cfg: OptConfigType = None, + use_dcn: bool = True, + norm_cfg: ConfigType = dict( + type='GN', num_groups=32, requires_grad=True), + loss_iou: ConfigType = dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0), + **kwargs) -> None: + self.stacked_convs = stacked_convs + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.use_dcn = use_dcn + super().__init__(num_classes, in_channels, **kwargs) + + if self.train_cfg: + self.cls_assigner = TASK_UTILS.build(self.train_cfg['assigner']) + self.reg_assigner = TASK_UTILS.build( + self.train_cfg['reg_assigner']) + self.loss_iou = MODELS.build(loss_iou) + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self.relu = nn.ReLU(inplace=True) + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + self.cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=dict(type='DCN', deform_groups=1) + if i == 0 and self.use_dcn else self.conv_cfg, + norm_cfg=self.norm_cfg)) + self.reg_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=dict(type='DCN', deform_groups=1) + if i == 0 and self.use_dcn else self.conv_cfg, + norm_cfg=self.norm_cfg)) + self.atss_cls = nn.Conv2d( + self.feat_channels, + self.num_base_priors * self.cls_out_channels, + 3, + padding=1) + self.atss_reg = nn.Conv2d( + self.feat_channels, self.num_base_priors * 4, 3, padding=1) + self.atss_iou = nn.Conv2d( + self.feat_channels, self.num_base_priors * 1, 3, padding=1) + self.scales = nn.ModuleList( + [Scale(1.0) for _ in self.prior_generator.strides]) + + # we use the global list in loss + self.cls_num_pos_samples_per_level = [ + 0. for _ in range(len(self.prior_generator.strides)) + ] + self.reg_num_pos_samples_per_level = [ + 0. for _ in range(len(self.prior_generator.strides)) + ] + + def init_weights(self) -> None: + """Initialize weights of the head.""" + for m in self.cls_convs: + normal_init(m.conv, std=0.01) + for m in self.reg_convs: + normal_init(m.conv, std=0.01) + normal_init(self.atss_reg, std=0.01) + normal_init(self.atss_iou, std=0.01) + bias_cls = bias_init_with_prob(0.01) + normal_init(self.atss_cls, std=0.01, bias=bias_cls) + + def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor]]: + """Forward features from the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: A tuple of classification scores, bbox predictions, + and iou predictions. + + - cls_scores (list[Tensor]): Classification scores for all \ + scale levels, each is a 4D-tensor, the channels number is \ + num_base_priors * num_classes. + - bbox_preds (list[Tensor]): Box energies / deltas for all \ + scale levels, each is a 4D-tensor, the channels number is \ + num_base_priors * 4. + - iou_preds (list[Tensor]): IoU scores for all scale levels, \ + each is a 4D-tensor, the channels number is num_base_priors * 1. + """ + return multi_apply(self.forward_single, x, self.scales) + + def forward_single(self, x: Tensor, scale: Scale) -> Sequence[Tensor]: + """Forward feature of a single scale level. + + Args: + x (Tensor): Features of a single scale level. + scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize + the bbox prediction. 
+ + Returns: + tuple: + + - cls_score (Tensor): Cls scores for a single scale level \ + the channels number is num_base_priors * num_classes. + - bbox_pred (Tensor): Box energies / deltas for a single \ + scale level, the channels number is num_base_priors * 4. + - iou_pred (Tensor): Iou for a single scale level, the \ + channel number is (N, num_base_priors * 1, H, W). + """ + cls_feat = x + reg_feat = x + for cls_conv in self.cls_convs: + cls_feat = cls_conv(cls_feat) + for reg_conv in self.reg_convs: + reg_feat = reg_conv(reg_feat) + cls_score = self.atss_cls(cls_feat) + # we just follow atss, not apply exp in bbox_pred + bbox_pred = scale(self.atss_reg(reg_feat)).float() + iou_pred = self.atss_iou(reg_feat) + return cls_score, bbox_pred, iou_pred + + def loss_cls_by_feat_single(self, cls_score: Tensor, labels: Tensor, + label_weights: Tensor, + reweight_factor: List[float], + avg_factor: float) -> Tuple[Tensor]: + """Compute cls loss of a single scale level. + + Args: + cls_score (Tensor): Box scores for each scale level + Has shape (N, num_base_priors * num_classes, H, W). + labels (Tensor): Labels of each anchors with shape + (N, num_total_anchors). + label_weights (Tensor): Label weights of each anchor with shape + (N, num_total_anchors) + reweight_factor (List[float]): Reweight factor for cls and reg + loss. + avg_factor (float): Average factor that is used to average + the loss. When using sampling method, avg_factor is usually + the sum of positive and negative priors. When using + `PseudoSampler`, `avg_factor` is usually equal to the number + of positive priors. + + Returns: + Tuple[Tensor]: A tuple of loss components. + """ + cls_score = cls_score.permute(0, 2, 3, 1).reshape( + -1, self.cls_out_channels).contiguous() + labels = labels.reshape(-1) + label_weights = label_weights.reshape(-1) + loss_cls = self.loss_cls( + cls_score, labels, label_weights, avg_factor=avg_factor) + return reweight_factor * loss_cls, + + def loss_reg_by_feat_single(self, anchors: Tensor, bbox_pred: Tensor, + iou_pred: Tensor, labels, + label_weights: Tensor, bbox_targets: Tensor, + bbox_weights: Tensor, + reweight_factor: List[float], + avg_factor: float) -> Tuple[Tensor, Tensor]: + """Compute reg loss of a single scale level based on the features + extracted by the detection head. + + Args: + anchors (Tensor): Box reference for each scale level with shape + (N, num_total_anchors, 4). + bbox_pred (Tensor): Box energies / deltas for each scale + level with shape (N, num_base_priors * 4, H, W). + iou_pred (Tensor): Iou for a single scale level, the + channel number is (N, num_base_priors * 1, H, W). + labels (Tensor): Labels of each anchors with shape + (N, num_total_anchors). + label_weights (Tensor): Label weights of each anchor with shape + (N, num_total_anchors) + bbox_targets (Tensor): BBox regression targets of each anchor with + shape (N, num_total_anchors, 4). + bbox_weights (Tensor): BBox weights of all anchors in the + image with shape (N, 4) + reweight_factor (List[float]): Reweight factor for cls and reg + loss. + avg_factor (float): Average factor that is used to average + the loss. When using sampling method, avg_factor is usually + the sum of positive and negative priors. When using + `PseudoSampler`, `avg_factor` is usually equal to the number + of positive priors. + Returns: + Tuple[Tensor, Tensor]: A tuple of loss components. 
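+
+        Note:
+            For positive anchors, ``iou_targets`` are the IoUs between the
+            decoded box predictions and their assigned ground-truth boxes,
+            so the IoU branch is supervised with localization quality.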
+ """ + anchors = anchors.reshape(-1, 4) + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) + iou_pred = iou_pred.permute(0, 2, 3, 1).reshape(-1, ) + bbox_targets = bbox_targets.reshape(-1, 4) + bbox_weights = bbox_weights.reshape(-1, 4) + labels = labels.reshape(-1) + label_weights = label_weights.reshape(-1) + + iou_targets = label_weights.new_zeros(labels.shape) + iou_weights = label_weights.new_zeros(labels.shape) + iou_weights[(bbox_weights.sum(axis=1) > 0).nonzero( + as_tuple=False)] = 1. + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & + (labels < bg_class_ind)).nonzero(as_tuple=False).squeeze(1) + + if len(pos_inds) > 0: + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + pos_anchors = anchors[pos_inds] + + pos_decode_bbox_pred = self.bbox_coder.decode( + pos_anchors, pos_bbox_pred) + pos_decode_bbox_targets = self.bbox_coder.decode( + pos_anchors, pos_bbox_targets) + + # regression loss + loss_bbox = self.loss_bbox( + pos_decode_bbox_pred, + pos_decode_bbox_targets, + avg_factor=avg_factor) + + iou_targets[pos_inds] = bbox_overlaps( + pos_decode_bbox_pred.detach(), + pos_decode_bbox_targets, + is_aligned=True) + loss_iou = self.loss_iou( + iou_pred, iou_targets, iou_weights, avg_factor=avg_factor) + else: + loss_bbox = bbox_pred.sum() * 0 + loss_iou = iou_pred.sum() * 0 + + return reweight_factor * loss_bbox, reweight_factor * loss_iou + + def calc_reweight_factor(self, labels_list: List[Tensor]) -> List[float]: + """Compute reweight_factor for regression and classification loss.""" + # get pos samples for each level + bg_class_ind = self.num_classes + for ii, each_level_label in enumerate(labels_list): + pos_inds = ((each_level_label >= 0) & + (each_level_label < bg_class_ind)).nonzero( + as_tuple=False).squeeze(1) + self.cls_num_pos_samples_per_level[ii] += len(pos_inds) + # get reweight factor from 1 ~ 2 with bilinear interpolation + min_pos_samples = min(self.cls_num_pos_samples_per_level) + max_pos_samples = max(self.cls_num_pos_samples_per_level) + interval = 1. / (max_pos_samples - min_pos_samples + 1e-10) + reweight_factor_per_level = [] + for pos_samples in self.cls_num_pos_samples_per_level: + factor = 2. - (pos_samples - min_pos_samples) * interval + reweight_factor_per_level.append(factor) + return reweight_factor_per_level + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + iou_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_base_priors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_base_priors * 4, H, W) + iou_preds (list[Tensor]): Score factor for all scale level, + each is a 4D-tensor, has shape (batch_size, 1, H, W). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. 
+ + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + + # calculate common vars for cls and reg assigners at once + targets_com = self.process_predictions_and_anchors( + anchor_list, valid_flag_list, cls_scores, bbox_preds, + batch_img_metas, batch_gt_instances_ignore) + (anchor_list, valid_flag_list, num_level_anchors_list, cls_score_list, + bbox_pred_list, batch_gt_instances_ignore) = targets_com + + # classification branch assigner + cls_targets = self.get_cls_targets( + anchor_list, + valid_flag_list, + num_level_anchors_list, + cls_score_list, + bbox_pred_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + + (cls_anchor_list, labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, avg_factor) = cls_targets + + avg_factor = reduce_mean( + torch.tensor(avg_factor, dtype=torch.float, device=device)).item() + avg_factor = max(avg_factor, 1.0) + + reweight_factor_per_level = self.calc_reweight_factor(labels_list) + + cls_losses_cls, = multi_apply( + self.loss_cls_by_feat_single, + cls_scores, + labels_list, + label_weights_list, + reweight_factor_per_level, + avg_factor=avg_factor) + + # regression branch assigner + reg_targets = self.get_reg_targets( + anchor_list, + valid_flag_list, + num_level_anchors_list, + cls_score_list, + bbox_pred_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + + (reg_anchor_list, labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, avg_factor) = reg_targets + + avg_factor = reduce_mean( + torch.tensor(avg_factor, dtype=torch.float, device=device)).item() + avg_factor = max(avg_factor, 1.0) + + reweight_factor_per_level = self.calc_reweight_factor(labels_list) + + reg_losses_bbox, reg_losses_iou = multi_apply( + self.loss_reg_by_feat_single, + reg_anchor_list, + bbox_preds, + iou_preds, + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + reweight_factor_per_level, + avg_factor=avg_factor) + + return dict( + loss_cls=cls_losses_cls, + loss_bbox=reg_losses_bbox, + loss_iou=reg_losses_iou) + + def process_predictions_and_anchors( + self, + anchor_list: List[List[Tensor]], + valid_flag_list: List[List[Tensor]], + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> tuple: + """Compute common vars for regression and classification targets. + + Args: + anchor_list (List[List[Tensor]]): anchors of each image. + valid_flag_list (List[List[Tensor]]): Valid flags of each image. + cls_scores (List[Tensor]): Classification scores for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * num_classes. + bbox_preds (list[Tensor]): Box energies / deltas for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * 4. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Return: + tuple[Tensor]: A tuple of common loss vars. 
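+
+        Note:
+            Per-level anchors, valid flags, classification scores and bbox
+            predictions are each flattened and concatenated over all levels
+            for every image, so both assigners can consume one tensor per
+            image instead of per-level lists.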
+ """ + num_imgs = len(batch_img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + num_level_anchors_list = [num_level_anchors] * num_imgs + + anchor_list_ = [] + valid_flag_list_ = [] + # concat all level anchors and flags to a single tensor + for i in range(num_imgs): + assert len(anchor_list[i]) == len(valid_flag_list[i]) + anchor_list_.append(torch.cat(anchor_list[i])) + valid_flag_list_.append(torch.cat(valid_flag_list[i])) + + # compute targets for each image + if batch_gt_instances_ignore is None: + batch_gt_instances_ignore = [None for _ in range(num_imgs)] + + num_levels = len(cls_scores) + cls_score_list = [] + bbox_pred_list = [] + + mlvl_cls_score_list = [ + cls_score.permute(0, 2, 3, 1).reshape( + num_imgs, -1, self.num_base_priors * self.cls_out_channels) + for cls_score in cls_scores + ] + mlvl_bbox_pred_list = [ + bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, + self.num_base_priors * 4) + for bbox_pred in bbox_preds + ] + + for i in range(num_imgs): + mlvl_cls_tensor_list = [ + mlvl_cls_score_list[j][i] for j in range(num_levels) + ] + mlvl_bbox_tensor_list = [ + mlvl_bbox_pred_list[j][i] for j in range(num_levels) + ] + cat_mlvl_cls_score = torch.cat(mlvl_cls_tensor_list, dim=0) + cat_mlvl_bbox_pred = torch.cat(mlvl_bbox_tensor_list, dim=0) + cls_score_list.append(cat_mlvl_cls_score) + bbox_pred_list.append(cat_mlvl_bbox_pred) + return (anchor_list_, valid_flag_list_, num_level_anchors_list, + cls_score_list, bbox_pred_list, batch_gt_instances_ignore) + + def get_cls_targets(self, + anchor_list: List[Tensor], + valid_flag_list: List[Tensor], + num_level_anchors_list: List[int], + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None, + unmap_outputs: bool = True) -> tuple: + """Get cls targets for DDOD head. + + This method is almost the same as `AnchorHead.get_targets()`. + Besides returning the targets as the parent method does, + it also returns the anchors as the first element of the + returned tuple. + + Args: + anchor_list (list[Tensor]): anchors of each image. + valid_flag_list (list[Tensor]): Valid flags of each image. + num_level_anchors_list (list[Tensor]): Number of anchors of each + scale level of all image. + cls_score_list (list[Tensor]): Classification scores for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * num_classes. + bbox_pred_list (list[Tensor]): Box energies / deltas for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * 4. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Return: + tuple[Tensor]: A tuple of cls targets components. 
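+
+        Note:
+            This method forwards ``is_cls_assigner=True`` to
+            ``_get_targets_single``, so ``self.cls_assigner`` (rather than
+            ``self.reg_assigner``) matches anchors to ground truth here.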
+ """ + (all_anchors, all_labels, all_label_weights, all_bbox_targets, + all_bbox_weights, pos_inds_list, neg_inds_list, + sampling_results_list) = multi_apply( + self._get_targets_single, + anchor_list, + valid_flag_list, + cls_score_list, + bbox_pred_list, + num_level_anchors_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore, + unmap_outputs=unmap_outputs, + is_cls_assigner=True) + # Get `avg_factor` of all images, which calculate in `SamplingResult`. + # When using sampling method, avg_factor is usually the sum of + # positive and negative priors. When using `PseudoSampler`, + # `avg_factor` is usually equal to the number of positive priors. + avg_factor = sum( + [results.avg_factor for results in sampling_results_list]) + # split targets to a list w.r.t. multiple levels + anchors_list = images_to_levels(all_anchors, num_level_anchors_list[0]) + labels_list = images_to_levels(all_labels, num_level_anchors_list[0]) + label_weights_list = images_to_levels(all_label_weights, + num_level_anchors_list[0]) + bbox_targets_list = images_to_levels(all_bbox_targets, + num_level_anchors_list[0]) + bbox_weights_list = images_to_levels(all_bbox_weights, + num_level_anchors_list[0]) + return (anchors_list, labels_list, label_weights_list, + bbox_targets_list, bbox_weights_list, avg_factor) + + def get_reg_targets(self, + anchor_list: List[Tensor], + valid_flag_list: List[Tensor], + num_level_anchors_list: List[int], + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None, + unmap_outputs: bool = True) -> tuple: + """Get reg targets for DDOD head. + + This method is almost the same as `AnchorHead.get_targets()` when + is_cls_assigner is False. Besides returning the targets as the parent + method does, it also returns the anchors as the first element of the + returned tuple. + + Args: + anchor_list (list[Tensor]): anchors of each image. + valid_flag_list (list[Tensor]): Valid flags of each image. + num_level_anchors_list (list[Tensor]): Number of anchors of each + scale level of all image. + cls_score_list (list[Tensor]): Classification scores for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * num_classes. + bbox_pred_list (list[Tensor]): Box energies / deltas for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * 4. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Return: + tuple[Tensor]: A tuple of reg targets components. + """ + (all_anchors, all_labels, all_label_weights, all_bbox_targets, + all_bbox_weights, pos_inds_list, neg_inds_list, + sampling_results_list) = multi_apply( + self._get_targets_single, + anchor_list, + valid_flag_list, + cls_score_list, + bbox_pred_list, + num_level_anchors_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore, + unmap_outputs=unmap_outputs, + is_cls_assigner=False) + # Get `avg_factor` of all images, which calculate in `SamplingResult`. 
+ # When using sampling method, avg_factor is usually the sum of + # positive and negative priors. When using `PseudoSampler`, + # `avg_factor` is usually equal to the number of positive priors. + avg_factor = sum( + [results.avg_factor for results in sampling_results_list]) + # split targets to a list w.r.t. multiple levels + anchors_list = images_to_levels(all_anchors, num_level_anchors_list[0]) + labels_list = images_to_levels(all_labels, num_level_anchors_list[0]) + label_weights_list = images_to_levels(all_label_weights, + num_level_anchors_list[0]) + bbox_targets_list = images_to_levels(all_bbox_targets, + num_level_anchors_list[0]) + bbox_weights_list = images_to_levels(all_bbox_weights, + num_level_anchors_list[0]) + return (anchors_list, labels_list, label_weights_list, + bbox_targets_list, bbox_weights_list, avg_factor) + + def _get_targets_single(self, + flat_anchors: Tensor, + valid_flags: Tensor, + cls_scores: Tensor, + bbox_preds: Tensor, + num_level_anchors: List[int], + gt_instances: InstanceData, + img_meta: dict, + gt_instances_ignore: Optional[InstanceData] = None, + unmap_outputs: bool = True, + is_cls_assigner: bool = True) -> tuple: + """Compute regression, classification targets for anchors in a single + image. + + Args: + flat_anchors (Tensor): Multi-level anchors of the image, + which are concatenated into a single tensor of shape + (num_base_priors, 4). + valid_flags (Tensor): Multi level valid flags of the image, + which are concatenated into a single tensor of + shape (num_base_priors,). + cls_scores (Tensor): Classification scores for all scale + levels of the image. + bbox_preds (Tensor): Box energies / deltas for all scale + levels of the image. + num_level_anchors (List[int]): Number of anchors of each + scale level. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for current image. + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. Defaults to True. + is_cls_assigner (bool): Classification or regression. + Defaults to True. + + Returns: + tuple: N is the number of total anchors in the image. + - anchors (Tensor): all anchors in the image with shape (N, 4). + - labels (Tensor): Labels of all anchors in the image with \ + shape (N, ). + - label_weights (Tensor): Label weights of all anchor in the \ + image with shape (N, ). + - bbox_targets (Tensor): BBox targets of all anchors in the \ + image with shape (N, 4). + - bbox_weights (Tensor): BBox weights of all anchors in the \ + image with shape (N, 4) + - pos_inds (Tensor): Indices of positive anchor with shape \ + (num_pos, ). + - neg_inds (Tensor): Indices of negative anchor with shape \ + (num_neg, ). + - sampling_result (:obj:`SamplingResult`): Sampling results. + """ + inside_flags = anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], + self.train_cfg['allowed_border']) + if not inside_flags.any(): + raise ValueError( + 'There is no valid anchor inside the image boundary. 
Please ' + 'check the image size and anchor sizes, or set ' + '``allowed_border`` to -1 to skip the condition.') + # assign gt and sample anchors + anchors = flat_anchors[inside_flags, :] + + num_level_anchors_inside = self.get_num_level_anchors_inside( + num_level_anchors, inside_flags) + bbox_preds_valid = bbox_preds[inside_flags, :] + cls_scores_valid = cls_scores[inside_flags, :] + + assigner = self.cls_assigner if is_cls_assigner else self.reg_assigner + + # decode prediction out of assigner + bbox_preds_valid = self.bbox_coder.decode(anchors, bbox_preds_valid) + pred_instances = InstanceData( + priors=anchors, bboxes=bbox_preds_valid, scores=cls_scores_valid) + + assign_result = assigner.assign( + pred_instances=pred_instances, + num_level_priors=num_level_anchors_inside, + gt_instances=gt_instances, + gt_instances_ignore=gt_instances_ignore) + sampling_result = self.sampler.sample( + assign_result=assign_result, + pred_instances=pred_instances, + gt_instances=gt_instances) + + num_valid_anchors = anchors.shape[0] + bbox_targets = torch.zeros_like(anchors) + bbox_weights = torch.zeros_like(anchors) + labels = anchors.new_full((num_valid_anchors, ), + self.num_classes, + dtype=torch.long) + label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + pos_bbox_targets = self.bbox_coder.encode( + sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes) + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1.0 + + labels[pos_inds] = sampling_result.pos_gt_labels + if self.train_cfg['pos_weight'] <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.train_cfg['pos_weight'] + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_anchors.size(0) + anchors = unmap(anchors, num_total_anchors, inside_flags) + labels = unmap( + labels, num_total_anchors, inside_flags, fill=self.num_classes) + label_weights = unmap(label_weights, num_total_anchors, + inside_flags) + bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) + bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) + + return (anchors, labels, label_weights, bbox_targets, bbox_weights, + pos_inds, neg_inds, sampling_result) + + def get_num_level_anchors_inside(self, num_level_anchors: List[int], + inside_flags: Tensor) -> List[int]: + """Get the anchors of each scale level inside. + + Args: + num_level_anchors (list[int]): Number of anchors of each + scale level. + inside_flags (Tensor): Multi level inside flags of the image, + which are concatenated into a single tensor of + shape (num_base_priors,). + + Returns: + list[int]: Number of anchors of each scale level inside. + """ + split_inside_flags = torch.split(inside_flags, num_level_anchors) + num_level_anchors_inside = [ + int(flags.sum()) for flags in split_inside_flags + ] + return num_level_anchors_inside diff --git a/mmdet/models/dense_heads/ddq_detr_head.py b/mmdet/models/dense_heads/ddq_detr_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0580653ac264ea0a597eec76624ab7eb3c7f6a10 --- /dev/null +++ b/mmdet/models/dense_heads/ddq_detr_head.py @@ -0,0 +1,550 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy
+from typing import Dict, List, Tuple
+
+import torch
+from mmengine.model import bias_init_with_prob, constant_init
+from torch import Tensor, nn
+
+from mmdet.registry import MODELS
+from mmdet.structures import SampleList
+from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
+from mmdet.utils import InstanceList, OptInstanceList, reduce_mean
+from ..layers import inverse_sigmoid
+from ..losses import DDQAuxLoss
+from ..utils import multi_apply
+from .dino_head import DINOHead
+
+
+@MODELS.register_module()
+class DDQDETRHead(DINOHead):
+    r"""Head of DDQDETR: Dense Distinct Query for
+    End-to-End Object Detection.
+
+    Code is modified from the `official github repo
+    <https://github.com/jshilong/DDQ>`_.
+
+    More details can be found in the `paper
+    <https://arxiv.org/abs/2303.12776>`_ .
+
+    Args:
+        aux_num_pos (int): Number of positive targets assigned to a
+            predicted object. Defaults to 4.
+    """
+
+    def __init__(self, *args, aux_num_pos=4, **kwargs):
+        super(DDQDETRHead, self).__init__(*args, **kwargs)
+        self.aux_loss_for_dense = DDQAuxLoss(
+            train_cfg=dict(
+                assigner=dict(type='TopkHungarianAssigner', topk=aux_num_pos),
+                alpha=1,
+                beta=6))
+
+    def _init_layers(self) -> None:
+        """Initialize classification branch and regression branch of aux head
+        for dense queries."""
+        super(DDQDETRHead, self)._init_layers()
+        # If decoder `num_layers` = 6 and `as_two_stage` = True, then:
+        # 1) 6 main heads are required for
+        #    each decoder output of distinct queries.
+        # 2) 1 main head is required for `output_memory` of distinct queries.
+        # 3) 1 aux head is required for `output_memory` of dense queries,
+        #    which is done by the code below this comment.
+        # So 8 heads are required in sum.
+        # aux head for dense queries on encoder feature map
+        self.cls_branches.append(copy.deepcopy(self.cls_branches[-1]))
+        self.reg_branches.append(copy.deepcopy(self.reg_branches[-1]))
+
+        # If decoder `num_layers` = 6 and `as_two_stage` = True, then:
+        # 6 aux heads are required for each decoder output of dense queries.
+        # So 8 + 6 = 14 heads are required in sum.
+        # self.num_pred_layer is 7
+        # aux head for dense queries in decoder
+        self.aux_cls_branches = nn.ModuleList([
+            copy.deepcopy(self.cls_branches[-1])
+            for _ in range(self.num_pred_layer - 1)
+        ])
+        self.aux_reg_branches = nn.ModuleList([
+            copy.deepcopy(self.reg_branches[-1])
+            for _ in range(self.num_pred_layer - 1)
+        ])
+
+    def init_weights(self) -> None:
+        """Initialize weights of the DDQ DETR head."""
+        bias_init = bias_init_with_prob(0.01)
+        for m in self.cls_branches:
+            nn.init.constant_(m.bias, bias_init)
+        for m in self.aux_cls_branches:
+            nn.init.constant_(m.bias, bias_init)
+        for m in self.reg_branches:
+            constant_init(m[-1], 0, bias=0)
+        for m in self.reg_branches:
+            nn.init.constant_(m[-1].bias.data[2:], 0.0)
+
+        for m in self.aux_reg_branches:
+            constant_init(m[-1], 0, bias=0)
+
+        for m in self.aux_reg_branches:
+            nn.init.constant_(m[-1].bias.data[2:], 0.0)
+
+    def forward(self, hidden_states: Tensor,
+                references: List[Tensor]) -> Tuple[Tensor]:
+        """Forward function.
+
+        Args:
+            hidden_states (Tensor): Hidden states output from each decoder
+                layer, has shape (num_decoder_layers, bs, num_queries_total,
+                dim), where `num_queries_total` is the sum of
+                `num_denoising_queries`, `num_queries` and `num_dense_queries`
+                when `self.training` is `True`, else `num_queries`.
+            references (list[Tensor]): List of the reference from the decoder.
+ The first reference is the `init_reference` (initial) and the + other num_decoder_layers(6) references are `inter_references` + (intermediate). Each reference has shape (bs, + num_queries_total, 4) with the last dimension arranged as + (cx, cy, w, h). + + Returns: + tuple[Tensor]: results of head containing the following tensors. + + - all_layers_outputs_classes (Tensor): Outputs from the + classification head, has shape (num_decoder_layers, bs, + num_queries_total, cls_out_channels). + - all_layers_outputs_coords (Tensor): Sigmoid outputs from the + regression head with normalized coordinate format (cx, cy, w, + h), has shape (num_decoder_layers, bs, num_queries_total, 4) + with the last dimension arranged as (cx, cy, w, h). + """ + all_layers_outputs_classes = [] + all_layers_outputs_coords = [] + if self.training: + num_dense = self.cache_dict['num_dense_queries'] + for layer_id in range(hidden_states.shape[0]): + reference = inverse_sigmoid(references[layer_id]) + hidden_state = hidden_states[layer_id] + if self.training: + dense_hidden_state = hidden_state[:, -num_dense:] + hidden_state = hidden_state[:, :-num_dense] + + outputs_class = self.cls_branches[layer_id](hidden_state) + tmp_reg_preds = self.reg_branches[layer_id](hidden_state) + if self.training: + dense_outputs_class = self.aux_cls_branches[layer_id]( + dense_hidden_state) + dense_tmp_reg_preds = self.aux_reg_branches[layer_id]( + dense_hidden_state) + outputs_class = torch.cat([outputs_class, dense_outputs_class], + dim=1) + tmp_reg_preds = torch.cat([tmp_reg_preds, dense_tmp_reg_preds], + dim=1) + + if reference.shape[-1] == 4: + tmp_reg_preds += reference + else: + assert reference.shape[-1] == 2 + tmp_reg_preds[..., :2] += reference + outputs_coord = tmp_reg_preds.sigmoid() + all_layers_outputs_classes.append(outputs_class) + all_layers_outputs_coords.append(outputs_coord) + + all_layers_outputs_classes = torch.stack(all_layers_outputs_classes) + all_layers_outputs_coords = torch.stack(all_layers_outputs_coords) + + return all_layers_outputs_classes, all_layers_outputs_coords + + def loss(self, + hidden_states: Tensor, + references: List[Tensor], + enc_outputs_class: Tensor, + enc_outputs_coord: Tensor, + batch_data_samples: SampleList, + dn_meta: Dict[str, int], + aux_enc_outputs_class=None, + aux_enc_outputs_coord=None) -> dict: + """Perform forward propagation and loss calculation of the detection + head on the queries of the upstream network. + + Args: + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, bs, num_queries_total, + dim), where `num_queries_total` is the sum of + `num_denoising_queries`, `num_queries` and `num_dense_queries` + when `self.training` is `True`, else `num_queries`. + references (list[Tensor]): List of the reference from the decoder. + The first reference is the `init_reference` (initial) and the + other num_decoder_layers(6) references are `inter_references` + (intermediate). Each reference has shape (bs, + num_queries_total, 4) with the last dimension arranged as + (cx, cy, w, h). + enc_outputs_class (Tensor): The top k classification score of + each point on encoder feature map, has shape (bs, num_queries, + cls_out_channels). + enc_outputs_coord (Tensor): The proposal generated from points + with top k score, has shape (bs, num_queries, 4) with the + last dimension arranged as (cx, cy, w, h). + batch_data_samples (list[:obj:`DetDataSample`]): The Data + Samples. 
It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. It will be used for split outputs of + denoising and matching parts and loss calculation. + aux_enc_outputs_class (Tensor): The `dense_topk` classification + score of each point on encoder feature map, has shape (bs, + num_dense_queries, cls_out_channels). + It is `None` when `self.training` is `False`. + aux_enc_outputs_coord (Tensor): The proposal generated from points + with `dense_topk` score, has shape (bs, num_dense_queries, 4) + with the last dimension arranged as (cx, cy, w, h). + It is `None` when `self.training` is `False`. + + Returns: + dict: A dictionary of loss components. + """ + batch_gt_instances = [] + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + + outs = self(hidden_states, references) + loss_inputs = outs + (enc_outputs_class, enc_outputs_coord, + batch_gt_instances, batch_img_metas, dn_meta) + losses = self.loss_by_feat(*loss_inputs) + + aux_enc_outputs_coord = bbox_cxcywh_to_xyxy(aux_enc_outputs_coord) + aux_enc_outputs_coord_list = [] + for img_id in range(len(aux_enc_outputs_coord)): + det_bboxes = aux_enc_outputs_coord[img_id] + img_shape = batch_img_metas[img_id]['img_shape'] + det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1] + det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0] + aux_enc_outputs_coord_list.append(det_bboxes) + aux_enc_outputs_coord = torch.stack(aux_enc_outputs_coord_list) + aux_loss = self.aux_loss_for_dense.loss( + aux_enc_outputs_class.sigmoid(), aux_enc_outputs_coord, + [item.bboxes for item in batch_gt_instances], + [item.labels for item in batch_gt_instances], batch_img_metas) + for k, v in aux_loss.items(): + losses[f'aux_enc_{k}'] = v + + return losses + + def loss_by_feat( + self, + all_layers_cls_scores: Tensor, + all_layers_bbox_preds: Tensor, + enc_cls_scores: Tensor, + enc_bbox_preds: Tensor, + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + dn_meta: Dict[str, int], + batch_gt_instances_ignore: OptInstanceList = None + ) -> Dict[str, Tensor]: + """Loss function. + + Args: + all_layers_cls_scores (Tensor): Classification scores of all + decoder layers, has shape (num_decoder_layers, bs, + num_queries_total, cls_out_channels). + all_layers_bbox_preds (Tensor): Bbox coordinates of all decoder + layers. Each has shape (num_decoder_layers, bs, + num_queries_total, 4) with normalized coordinate format + (cx, cy, w, h). + enc_cls_scores (Tensor): The top k score of each point on + encoder feature map, has shape (bs, num_queries, + cls_out_channels). + enc_bbox_preds (Tensor): The proposal generated from points + with top k score, has shape (bs, num_queries, 4) with the + last dimension arranged as (cx, cy, w, h). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, + e.g., image size, scaling factor, etc. + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. It will be used for split outputs of + denoising and matching parts and loss calculation. 
+ batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + (all_layers_matching_cls_scores, all_layers_matching_bbox_preds, + all_layers_denoising_cls_scores, all_layers_denoising_bbox_preds) = \ + self.split_outputs( + all_layers_cls_scores, all_layers_bbox_preds, dn_meta) + + num_dense_queries = dn_meta['num_dense_queries'] + num_layer = all_layers_matching_bbox_preds.size(0) + dense_all_layers_matching_cls_scores = all_layers_matching_cls_scores[:, :, # noqa: E501 + -num_dense_queries:] # noqa: E501 + dense_all_layers_matching_bbox_preds = all_layers_matching_bbox_preds[:, :, # noqa: E501 + -num_dense_queries:] # noqa: E501 + + all_layers_matching_cls_scores = all_layers_matching_cls_scores[:, :, : # noqa: E501 + -num_dense_queries] # noqa: E501 + all_layers_matching_bbox_preds = all_layers_matching_bbox_preds[:, :, : # noqa: E501 + -num_dense_queries] # noqa: E501 + + loss_dict = self.loss_for_distinct_queries( + all_layers_matching_cls_scores, all_layers_matching_bbox_preds, + batch_gt_instances, batch_img_metas, batch_gt_instances_ignore) + + if enc_cls_scores is not None: + + enc_loss_cls, enc_losses_bbox, enc_losses_iou = \ + self.loss_by_feat_single( + enc_cls_scores, enc_bbox_preds, + batch_gt_instances=batch_gt_instances, + batch_img_metas=batch_img_metas) + loss_dict['enc_loss_cls'] = enc_loss_cls + loss_dict['enc_loss_bbox'] = enc_losses_bbox + loss_dict['enc_loss_iou'] = enc_losses_iou + + if all_layers_denoising_cls_scores is not None: + dn_losses_cls, dn_losses_bbox, dn_losses_iou = self.loss_dn( + all_layers_denoising_cls_scores, + all_layers_denoising_bbox_preds, + batch_gt_instances=batch_gt_instances, + batch_img_metas=batch_img_metas, + dn_meta=dn_meta) + loss_dict['dn_loss_cls'] = dn_losses_cls[-1] + loss_dict['dn_loss_bbox'] = dn_losses_bbox[-1] + loss_dict['dn_loss_iou'] = dn_losses_iou[-1] + for num_dec_layer, (loss_cls_i, loss_bbox_i, loss_iou_i) in \ + enumerate(zip(dn_losses_cls[:-1], dn_losses_bbox[:-1], + dn_losses_iou[:-1])): + loss_dict[f'd{num_dec_layer}.dn_loss_cls'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.dn_loss_bbox'] = loss_bbox_i + loss_dict[f'd{num_dec_layer}.dn_loss_iou'] = loss_iou_i + + for l_id in range(num_layer): + cls_scores = dense_all_layers_matching_cls_scores[l_id].sigmoid() + bbox_preds = dense_all_layers_matching_bbox_preds[l_id] + + bbox_preds = bbox_cxcywh_to_xyxy(bbox_preds) + bbox_preds_list = [] + for img_id in range(len(bbox_preds)): + det_bboxes = bbox_preds[img_id] + img_shape = batch_img_metas[img_id]['img_shape'] + det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1] + det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0] + bbox_preds_list.append(det_bboxes) + bbox_preds = torch.stack(bbox_preds_list) + aux_loss = self.aux_loss_for_dense.loss( + cls_scores, bbox_preds, + [item.bboxes for item in batch_gt_instances], + [item.labels for item in batch_gt_instances], batch_img_metas) + for k, v in aux_loss.items(): + loss_dict[f'{l_id}_aux_{k}'] = v + + return loss_dict + + def loss_for_distinct_queries( + self, + all_layers_cls_scores: Tensor, + all_layers_bbox_preds: Tensor, + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None + ) -> Dict[str, Tensor]: + """Calculate the loss of distinct queries, that is, excluding denoising + and dense 
queries. Only select the distinct queries in decoder for + loss. + + Args: + all_layers_cls_scores (Tensor): Classification scores of all + decoder layers, has shape (num_decoder_layers, bs, + num_queries, cls_out_channels). + all_layers_bbox_preds (Tensor): Bbox coordinates of all decoder + layers. It has shape (num_decoder_layers, bs, + num_queries, 4) with the last dimension arranged as + (cx, cy, w, h). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, + e.g., image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + assert batch_gt_instances_ignore is None, \ + f'{self.__class__.__name__} only supports ' \ + 'for batch_gt_instances_ignore setting to None.' + + losses_cls, losses_bbox, losses_iou = multi_apply( + self._loss_for_distinct_queries_single, + all_layers_cls_scores, + all_layers_bbox_preds, + [i for i in range(len(all_layers_bbox_preds))], + batch_gt_instances=batch_gt_instances, + batch_img_metas=batch_img_metas) + + loss_dict = dict() + # loss from the last decoder layer + loss_dict['loss_cls'] = losses_cls[-1] + loss_dict['loss_bbox'] = losses_bbox[-1] + loss_dict['loss_iou'] = losses_iou[-1] + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_bbox_i, loss_iou_i in \ + zip(losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1]): + loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i + loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i + num_dec_layer += 1 + return loss_dict + + def _loss_for_distinct_queries_single(self, cls_scores, bbox_preds, l_id, + batch_gt_instances, batch_img_metas): + """Calculate the loss for outputs from a single decoder layer of + distinct queries, that is, excluding denoising and dense queries. Only + select the distinct queries in decoder for loss. + + Args: + cls_scores (Tensor): Classification scores of a single + decoder layer, has shape (bs, num_queries, cls_out_channels). + bbox_preds (Tensor): Bbox coordinates of a single decoder + layer. It has shape (bs, num_queries, 4) with the last + dimension arranged as (cx, cy, w, h). + l_id (int): Decoder layer index for these outputs. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, + e.g., image size, scaling factor, etc. + + Returns: + Tuple[Tensor]: A tuple including `loss_cls`, `loss_box` and + `loss_iou`. 
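+
+        Note:
+            For ``l_id > 0`` only queries marked as distinct in
+            ``self.cache_dict['distinct_query_mask']`` contribute to the
+            loss; for the first decoder layer all queries are kept.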
+ """ + num_imgs = cls_scores.size(0) + if 0 < l_id: + batch_mask = [ + self.cache_dict['distinct_query_mask'][l_id - 1][ + img_id * self.cache_dict['num_heads']][0] + for img_id in range(num_imgs) + ] + else: + batch_mask = [ + torch.ones(len(cls_scores[i]), + device=cls_scores.device).bool() + for i in range(num_imgs) + ] + # only select the distinct queries in decoder for loss + cls_scores_list = [ + cls_scores[i][batch_mask[i]] for i in range(num_imgs) + ] + bbox_preds_list = [ + bbox_preds[i][batch_mask[i]] for i in range(num_imgs) + ] + cls_scores = torch.cat(cls_scores_list) + + cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list, + batch_gt_instances, batch_img_metas) + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + num_total_pos, num_total_neg) = cls_reg_targets + labels = torch.cat(labels_list, 0) + label_weights = torch.cat(label_weights_list, 0) + bbox_targets = torch.cat(bbox_targets_list, 0) + bbox_weights = torch.cat(bbox_weights_list, 0) + + # classification loss + cls_scores = cls_scores.reshape(-1, self.cls_out_channels) + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = num_total_pos * 1.0 + \ + num_total_neg * self.bg_cls_weight + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + cls_scores.new_tensor([cls_avg_factor])) + cls_avg_factor = max(cls_avg_factor, 1) + + loss_cls = self.loss_cls( + cls_scores, labels, label_weights, avg_factor=cls_avg_factor) + + # Compute the average number of gt boxes across all gpus, for + # normalization purposes + num_total_pos = loss_cls.new_tensor([num_total_pos]) + num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item() + + # construct factors used for rescale bboxes + factors = [] + for img_meta, bbox_pred in zip(batch_img_metas, bbox_preds_list): + img_h, img_w, = img_meta['img_shape'] + factor = bbox_pred.new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0).repeat( + bbox_pred.size(0), 1) + factors.append(factor) + factors = torch.cat(factors, 0) + + # DETR regress the relative position of boxes (cxcywh) in the image, + # thus the learning target is normalized by the image size. So here + # we need to re-scale them for calculating IoU loss + bbox_preds = torch.cat(bbox_preds_list) + bbox_preds = bbox_preds.reshape(-1, 4) + bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors + bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors + + # regression IoU loss, defaultly GIoU loss + loss_iou = self.loss_iou( + bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos) + + # regression L1 loss + loss_bbox = self.loss_bbox( + bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos) + return loss_cls, loss_bbox, loss_iou + + def predict_by_feat(self, + layer_cls_scores: Tensor, + layer_bbox_preds: Tensor, + batch_img_metas: List[dict], + rescale: bool = True) -> InstanceList: + """Transform a batch of output features extracted from the head into + bbox results. + + Args: + layer_cls_scores (Tensor): Classification scores of all + decoder layers, has shape (num_decoder_layers, bs, + num_queries, cls_out_channels). + layer_bbox_preds (Tensor): Bbox coordinates of all decoder layers. + Each has shape (num_decoder_layers, bs, num_queries, 4) + with normalized coordinate format (cx, cy, w, h). + batch_img_metas (list[dict]): Meta information of each image. + rescale (bool, optional): If `True`, return boxes in original + image space. Default `False`. 
+ + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. + """ + cls_scores = layer_cls_scores[-1] + bbox_preds = layer_bbox_preds[-1] + + num_imgs = cls_scores.size(0) + # index -1 selects the input query mask of the last decoder layer + + batch_mask = [ + self.cache_dict['distinct_query_mask'][-1][ + img_id * self.cache_dict['num_heads']][0] + for img_id in range(num_imgs) + ] + + result_list = [] + for img_id in range(len(batch_img_metas)): + cls_score = cls_scores[img_id][batch_mask[img_id]] + bbox_pred = bbox_preds[img_id][batch_mask[img_id]] + img_meta = batch_img_metas[img_id] + results = self._predict_by_feat_single(cls_score, bbox_pred, + img_meta, rescale) + result_list.append(results) + return result_list diff --git a/mmdet/models/dense_heads/deformable_detr_head.py b/mmdet/models/dense_heads/deformable_detr_head.py new file mode 100644 index 0000000000000000000000000000000000000000..adedd4aa6b533bcfece618eed4045c95bf0fdebb --- /dev/null +++ b/mmdet/models/dense_heads/deformable_detr_head.py @@ -0,0 +1,329 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import Linear +from mmengine.model import bias_init_with_prob, constant_init +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.utils import InstanceList, OptInstanceList +from ..layers import inverse_sigmoid +from .detr_head import DETRHead + + +@MODELS.register_module() +class DeformableDETRHead(DETRHead): + r"""Head of Deformable DETR. Deformable DETR: Deformable Transformers for + End-to-End Object Detection. + + Code is modified from the `official github repo + `_. + + More details can be found in the `paper + `_ . + + Args: + share_pred_layer (bool): Whether to share parameters for all the + prediction layers. Defaults to `False`. + num_pred_layer (int): The number of the prediction layers. + Defaults to 6. + as_two_stage (bool, optional): Whether to generate the proposal + from the outputs of encoder. Defaults to `False`.
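+
+    Example:
+        An editor's sketch of how this head is typically built through the
+        registry; the argument values below are placeholders, not the
+        settings of any particular config::
+
+            from mmdet.registry import MODELS
+
+            head = MODELS.build(
+                dict(
+                    type='DeformableDETRHead',
+                    num_classes=80,
+                    sync_cls_avg_factor=True,
+                    num_pred_layer=6,
+                    as_two_stage=False))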
+ """ + + def __init__(self, + *args, + share_pred_layer: bool = False, + num_pred_layer: int = 6, + as_two_stage: bool = False, + **kwargs) -> None: + self.share_pred_layer = share_pred_layer + self.num_pred_layer = num_pred_layer + self.as_two_stage = as_two_stage + + super().__init__(*args, **kwargs) + + def _init_layers(self) -> None: + """Initialize classification branch and regression branch of head.""" + fc_cls = Linear(self.embed_dims, self.cls_out_channels) + reg_branch = [] + for _ in range(self.num_reg_fcs): + reg_branch.append(Linear(self.embed_dims, self.embed_dims)) + reg_branch.append(nn.ReLU()) + reg_branch.append(Linear(self.embed_dims, 4)) + reg_branch = nn.Sequential(*reg_branch) + + if self.share_pred_layer: + self.cls_branches = nn.ModuleList( + [fc_cls for _ in range(self.num_pred_layer)]) + self.reg_branches = nn.ModuleList( + [reg_branch for _ in range(self.num_pred_layer)]) + else: + self.cls_branches = nn.ModuleList( + [copy.deepcopy(fc_cls) for _ in range(self.num_pred_layer)]) + self.reg_branches = nn.ModuleList([ + copy.deepcopy(reg_branch) for _ in range(self.num_pred_layer) + ]) + + def init_weights(self) -> None: + """Initialize weights of the Deformable DETR head.""" + if self.loss_cls.use_sigmoid: + bias_init = bias_init_with_prob(0.01) + for m in self.cls_branches: + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias, bias_init) + for m in self.reg_branches: + constant_init(m[-1], 0, bias=0) + nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], -2.0) + if self.as_two_stage: + for m in self.reg_branches: + nn.init.constant_(m[-1].bias.data[2:], 0.0) + + def forward(self, hidden_states: Tensor, + references: List[Tensor]) -> Tuple[Tensor, Tensor]: + """Forward function. + + Args: + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, bs, num_queries, dim). + references (list[Tensor]): List of the reference from the decoder. + The first reference is the `init_reference` (initial) and the + other num_decoder_layers(6) references are `inter_references` + (intermediate). The `init_reference` has shape (bs, + num_queries, 4) when `as_two_stage` of the detector is `True`, + otherwise (bs, num_queries, 2). Each `inter_reference` has + shape (bs, num_queries, 4) when `with_box_refine` of the + detector is `True`, otherwise (bs, num_queries, 2). The + coordinates are arranged as (cx, cy) when the last dimension is + 2, and (cx, cy, w, h) when it is 4. + + Returns: + tuple[Tensor]: results of head containing the following tensor. + + - all_layers_outputs_classes (Tensor): Outputs from the + classification head, has shape (num_decoder_layers, bs, + num_queries, cls_out_channels). + - all_layers_outputs_coords (Tensor): Sigmoid outputs from the + regression head with normalized coordinate format (cx, cy, w, + h), has shape (num_decoder_layers, bs, num_queries, 4) with the + last dimension arranged as (cx, cy, w, h). + """ + all_layers_outputs_classes = [] + all_layers_outputs_coords = [] + + for layer_id in range(hidden_states.shape[0]): + reference = inverse_sigmoid(references[layer_id]) + # NOTE The last reference will not be used. + hidden_state = hidden_states[layer_id] + outputs_class = self.cls_branches[layer_id](hidden_state) + tmp_reg_preds = self.reg_branches[layer_id](hidden_state) + if reference.shape[-1] == 4: + # When `layer` is 0 and `as_two_stage` of the detector + # is `True`, or when `layer` is greater than 0 and + # `with_box_refine` of the detector is `True`. 
+ tmp_reg_preds += reference + else: + # When `layer` is 0 and `as_two_stage` of the detector + # is `False`, or when `layer` is greater than 0 and + # `with_box_refine` of the detector is `False`. + assert reference.shape[-1] == 2 + tmp_reg_preds[..., :2] += reference + outputs_coord = tmp_reg_preds.sigmoid() + all_layers_outputs_classes.append(outputs_class) + all_layers_outputs_coords.append(outputs_coord) + + all_layers_outputs_classes = torch.stack(all_layers_outputs_classes) + all_layers_outputs_coords = torch.stack(all_layers_outputs_coords) + + return all_layers_outputs_classes, all_layers_outputs_coords + + def loss(self, hidden_states: Tensor, references: List[Tensor], + enc_outputs_class: Tensor, enc_outputs_coord: Tensor, + batch_data_samples: SampleList) -> dict: + """Perform forward propagation and loss calculation of the detection + head on the queries of the upstream network. + + Args: + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, num_queries, bs, dim). + references (list[Tensor]): List of the reference from the decoder. + The first reference is the `init_reference` (initial) and the + other num_decoder_layers(6) references are `inter_references` + (intermediate). The `init_reference` has shape (bs, + num_queries, 4) when `as_two_stage` of the detector is `True`, + otherwise (bs, num_queries, 2). Each `inter_reference` has + shape (bs, num_queries, 4) when `with_box_refine` of the + detector is `True`, otherwise (bs, num_queries, 2). The + coordinates are arranged as (cx, cy) when the last dimension is + 2, and (cx, cy, w, h) when it is 4. + enc_outputs_class (Tensor): The score of each point on encode + feature map, has shape (bs, num_feat_points, cls_out_channels). + Only when `as_two_stage` is `True` it would be passed in, + otherwise it would be `None`. + enc_outputs_coord (Tensor): The proposal generate from the encode + feature map, has shape (bs, num_feat_points, 4) with the last + dimension arranged as (cx, cy, w, h). Only when `as_two_stage` + is `True` it would be passed in, otherwise it would be `None`. + batch_data_samples (list[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components. + """ + batch_gt_instances = [] + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + + outs = self(hidden_states, references) + loss_inputs = outs + (enc_outputs_class, enc_outputs_coord, + batch_gt_instances, batch_img_metas) + losses = self.loss_by_feat(*loss_inputs) + return losses + + def loss_by_feat( + self, + all_layers_cls_scores: Tensor, + all_layers_bbox_preds: Tensor, + enc_cls_scores: Tensor, + enc_bbox_preds: Tensor, + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None + ) -> Dict[str, Tensor]: + """Loss function. + + Args: + all_layers_cls_scores (Tensor): Classification scores of all + decoder layers, has shape (num_decoder_layers, bs, num_queries, + cls_out_channels). + all_layers_bbox_preds (Tensor): Regression outputs of all decoder + layers. Each is a 4D-tensor with normalized coordinate format + (cx, cy, w, h) and has shape (num_decoder_layers, bs, + num_queries, 4) with the last dimension arranged as + (cx, cy, w, h). 
+ enc_cls_scores (Tensor): The score of each point on the encoder + feature map, has shape (bs, num_feat_points, cls_out_channels). + Only when `as_two_stage` is `True` it would be passed in, + otherwise it would be `None`. + enc_bbox_preds (Tensor): The proposals generated from the encoder + feature map, has shape (bs, num_feat_points, 4) with the last + dimension arranged as (cx, cy, w, h). Only when `as_two_stage` + is `True` it would be passed in, otherwise it would be `None`. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + loss_dict = super().loss_by_feat(all_layers_cls_scores, + all_layers_bbox_preds, + batch_gt_instances, batch_img_metas, + batch_gt_instances_ignore) + + # loss of proposals generated from the encoder feature map. + if enc_cls_scores is not None: + proposal_gt_instances = copy.deepcopy(batch_gt_instances) + for i in range(len(proposal_gt_instances)): + proposal_gt_instances[i].labels = torch.zeros_like( + proposal_gt_instances[i].labels) + enc_loss_cls, enc_losses_bbox, enc_losses_iou = \ + self.loss_by_feat_single( + enc_cls_scores, enc_bbox_preds, + batch_gt_instances=proposal_gt_instances, + batch_img_metas=batch_img_metas) + loss_dict['enc_loss_cls'] = enc_loss_cls + loss_dict['enc_loss_bbox'] = enc_losses_bbox + loss_dict['enc_loss_iou'] = enc_losses_iou + return loss_dict + + def predict(self, + hidden_states: Tensor, + references: List[Tensor], + batch_data_samples: SampleList, + rescale: bool = True) -> InstanceList: + """Perform forward propagation of the detection head and predict + detection results on the queries of the upstream network. + + Args: + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, num_queries, bs, dim). + references (list[Tensor]): List of the reference from the decoder. + The first reference is the `init_reference` (initial) and the + other num_decoder_layers(6) references are `inter_references` + (intermediate). The `init_reference` has shape (bs, + num_queries, 4) when `as_two_stage` of the detector is `True`, + otherwise (bs, num_queries, 2). Each `inter_reference` has + shape (bs, num_queries, 4) when `with_box_refine` of the + detector is `True`, otherwise (bs, num_queries, 2). The + coordinates are arranged as (cx, cy) when the last dimension is + 2, and (cx, cy, w, h) when it is 4. + batch_data_samples (list[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool, optional): If `True`, return boxes in original + image space. Defaults to `True`. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process.
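+
+        Example:
+            A shape-only sketch (editor's illustration, sizes invented)
+            of the inputs this method expects for a 6-layer decoder,
+            following the (num_decoder_layers, bs, num_queries, dim)
+            convention used by ``forward``::
+
+                import torch
+
+                # batch 2, 300 queries, embed dim 256
+                hidden_states = torch.randn(6, 2, 300, 256)
+                # init_reference plus 6 intermediate references; the last
+                # intermediate reference is not consumed by forward()
+                references = (
+                    [torch.rand(2, 300, 2)] + [torch.rand(2, 300, 4)] * 6)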
+ """ + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + + outs = self(hidden_states, references) + + predictions = self.predict_by_feat( + *outs, batch_img_metas=batch_img_metas, rescale=rescale) + return predictions + + def predict_by_feat(self, + all_layers_cls_scores: Tensor, + all_layers_bbox_preds: Tensor, + batch_img_metas: List[Dict], + rescale: bool = False) -> InstanceList: + """Transform a batch of output features extracted from the head into + bbox results. + + Args: + all_layers_cls_scores (Tensor): Classification scores of all + decoder layers, has shape (num_decoder_layers, bs, num_queries, + cls_out_channels). + all_layers_bbox_preds (Tensor): Regression outputs of all decoder + layers. Each is a 4D-tensor with normalized coordinate format + (cx, cy, w, h) and shape (num_decoder_layers, bs, num_queries, + 4) with the last dimension arranged as (cx, cy, w, h). + batch_img_metas (list[dict]): Meta information of each image. + rescale (bool, optional): If `True`, return boxes in original + image space. Default `False`. + + Returns: + list[obj:`InstanceData`]: Detection results of each image + after the post process. + """ + cls_scores = all_layers_cls_scores[-1] + bbox_preds = all_layers_bbox_preds[-1] + + result_list = [] + for img_id in range(len(batch_img_metas)): + cls_score = cls_scores[img_id] + bbox_pred = bbox_preds[img_id] + img_meta = batch_img_metas[img_id] + results = self._predict_by_feat_single(cls_score, bbox_pred, + img_meta, rescale) + result_list.append(results) + return result_list diff --git a/mmdet/models/dense_heads/dense_test_mixins.py b/mmdet/models/dense_heads/dense_test_mixins.py new file mode 100644 index 0000000000000000000000000000000000000000..a7526d48430d6bc6b82777980d0bef418e80b91c --- /dev/null +++ b/mmdet/models/dense_heads/dense_test_mixins.py @@ -0,0 +1,215 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import sys +import warnings +from inspect import signature + +import torch +from mmcv.ops import batched_nms +from mmengine.structures import InstanceData + +from mmdet.structures.bbox import bbox_mapping_back +from ..test_time_augs import merge_aug_proposals + +if sys.version_info >= (3, 7): + from mmdet.utils.contextmanagers import completed + + +class BBoxTestMixin(object): + """Mixin class for testing det bboxes via DenseHead.""" + + def simple_test_bboxes(self, feats, img_metas, rescale=False): + """Test det bboxes without test-time augmentation, can be applied in + DenseHead except for ``RPNHead`` and its variants, e.g., ``GARPNHead``, + etc. + + Args: + feats (tuple[torch.Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[obj:`InstanceData`]: Detection results of each + image after the post process. \ + Each item usually contains following keys. \ + + - scores (Tensor): Classification scores, has a shape + (num_instance,) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances,). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + warnings.warn('You are calling `simple_test_bboxes` in ' + '`dense_test_mixins`, but the `dense_test_mixins`' + 'will be deprecated soon. 
'`simple_test` instead.') + outs = self.forward(feats) + results_list = self.get_results( + *outs, img_metas=img_metas, rescale=rescale) + return results_list + + def aug_test_bboxes(self, feats, img_metas, rescale=False): + """Test det bboxes with test-time augmentation, can be applied in + DenseHead except for ``RPNHead`` and its variants, e.g., ``GARPNHead``, + etc. + + Args: + feats (list[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains features for all images in the batch. + img_metas (list[list[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. Each dict has image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[tuple[Tensor, Tensor]]: Each item in result_list is a 2-tuple. + The first item is ``bboxes`` with shape (n, 5), + where the 5 values are (tl_x, tl_y, br_x, br_y, score). + The shape of the second tensor in the tuple is ``labels`` + with shape (n,). The length of the list should always be 1. + """ + + warnings.warn('You are calling `aug_test_bboxes` in ' + '`dense_test_mixins`, but the `dense_test_mixins` ' + 'will be deprecated soon. Please use ' + '`aug_test` instead.') + # check with_nms argument + gb_sig = signature(self.get_results) + gb_args = [p.name for p in gb_sig.parameters.values()] + gbs_sig = signature(self._get_results_single) + gbs_args = [p.name for p in gbs_sig.parameters.values()] + assert ('with_nms' in gb_args) and ('with_nms' in gbs_args), \ + f'{self.__class__.__name__}' \ + ' does not support test-time augmentation' + + aug_bboxes = [] + aug_scores = [] + aug_labels = [] + for x, img_meta in zip(feats, img_metas): + # only one image in the batch + outs = self.forward(x) + bbox_outputs = self.get_results( + *outs, + img_metas=img_meta, + cfg=self.test_cfg, + rescale=False, + with_nms=False)[0] + aug_bboxes.append(bbox_outputs.bboxes) + aug_scores.append(bbox_outputs.scores) + if len(bbox_outputs) >= 3: + aug_labels.append(bbox_outputs.labels) + + # after merging, bboxes will be rescaled to the original image size + merged_bboxes, merged_scores = self.merge_aug_bboxes( + aug_bboxes, aug_scores, img_metas) + merged_labels = torch.cat(aug_labels, dim=0) if aug_labels else None + + if merged_bboxes.numel() == 0: + det_bboxes = torch.cat([merged_bboxes, merged_scores[:, None]], -1) + return [ + (det_bboxes, merged_labels), + ] + + det_bboxes, keep_idxs = batched_nms(merged_bboxes, merged_scores, + merged_labels, self.test_cfg.nms) + det_bboxes = det_bboxes[:self.test_cfg.max_per_img] + det_labels = merged_labels[keep_idxs][:self.test_cfg.max_per_img] + + if rescale: + _det_bboxes = det_bboxes + else: + _det_bboxes = det_bboxes.clone() + _det_bboxes[:, :4] *= det_bboxes.new_tensor( + img_metas[0][0]['scale_factor']) + + results = InstanceData() + results.bboxes = _det_bboxes[:, :4] + results.scores = _det_bboxes[:, 4] + results.labels = det_labels + return [results] + + def aug_test_rpn(self, feats, img_metas): + """Test with augmentation, only for ``RPNHead`` and its variants, + e.g., ``GARPNHead``, etc. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + img_metas (list[dict]): Meta info of each image. + + Returns: + list[Tensor]: Proposals of each image, each item has shape (n, 5), + where the 5 values are (tl_x, tl_y, br_x, br_y, score).
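+
+        Example:
+            A small sketch (editor's illustration, sizes invented) of the
+            meta-info transposition performed below: proposals are
+            collected per augmentation but must be merged per image::
+
+                # 3 augmentations x 2 images per GPU (hypothetical)
+                img_metas = [[{'aug': a, 'img': i} for i in range(2)]
+                             for a in range(3)]
+                aug_img_metas = [[img_metas[a][i] for a in range(3)]
+                                 for i in range(2)]
+                # aug_img_metas is now 2 images x 3 augmentations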
+ """ + samples_per_gpu = len(img_metas[0]) + aug_proposals = [[] for _ in range(samples_per_gpu)] + for x, img_meta in zip(feats, img_metas): + results_list = self.simple_test_rpn(x, img_meta) + for i, results in enumerate(results_list): + proposals = torch.cat( + [results.bboxes, results.scores[:, None]], dim=-1) + aug_proposals[i].append(proposals) + # reorganize the order of 'img_metas' to match the dimensions + # of 'aug_proposals' + aug_img_metas = [] + for i in range(samples_per_gpu): + aug_img_meta = [] + for j in range(len(img_metas)): + aug_img_meta.append(img_metas[j][i]) + aug_img_metas.append(aug_img_meta) + # after merging, proposals will be rescaled to the original image size + + merged_proposals = [] + for proposals, aug_img_meta in zip(aug_proposals, aug_img_metas): + merged_proposal = merge_aug_proposals(proposals, aug_img_meta, + self.test_cfg) + results = InstanceData() + results.bboxes = merged_proposal[:, :4] + results.scores = merged_proposal[:, 4] + merged_proposals.append(results) + return merged_proposals + + if sys.version_info >= (3, 7): + + async def async_simple_test_rpn(self, x, img_metas): + sleep_interval = self.test_cfg.pop('async_sleep_interval', 0.025) + async with completed( + __name__, 'rpn_head_forward', + sleep_interval=sleep_interval): + rpn_outs = self(x) + + proposal_list = self.get_results(*rpn_outs, img_metas=img_metas) + return proposal_list + + def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas): + """Merge augmented detection bboxes and scores. + + Args: + aug_bboxes (list[Tensor]): shape (n, 4*#class) + aug_scores (list[Tensor] or None): shape (n, #class) + img_shapes (list[Tensor]): shape (3, ). + + Returns: + tuple[Tensor]: ``bboxes`` with shape (n,4), where + 4 represent (tl_x, tl_y, br_x, br_y) + and ``scores`` with shape (n,). + """ + recovered_bboxes = [] + for bboxes, img_info in zip(aug_bboxes, img_metas): + img_shape = img_info[0]['img_shape'] + scale_factor = img_info[0]['scale_factor'] + flip = img_info[0]['flip'] + flip_direction = img_info[0]['flip_direction'] + bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip, + flip_direction) + recovered_bboxes.append(bboxes) + bboxes = torch.cat(recovered_bboxes, dim=0) + if aug_scores is None: + return bboxes + else: + scores = torch.cat(aug_scores, dim=0) + return bboxes, scores diff --git a/mmdet/models/dense_heads/detr_head.py b/mmdet/models/dense_heads/detr_head.py new file mode 100644 index 0000000000000000000000000000000000000000..9daeb4740057c1f07095ffbf97b73ea40fc93106 --- /dev/null +++ b/mmdet/models/dense_heads/detr_head.py @@ -0,0 +1,634 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Linear +from mmcv.cnn.bricks.transformer import FFN +from mmengine.model import BaseModule +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures import SampleList +from mmdet.structures.bbox import (bbox_cxcywh_to_xyxy, bbox_overlaps, + bbox_xyxy_to_cxcywh) +from mmdet.utils import (ConfigType, InstanceList, OptInstanceList, + OptMultiConfig, reduce_mean) +from ..losses import QualityFocalLoss +from ..utils import multi_apply + + +@MODELS.register_module() +class DETRHead(BaseModule): + r"""Head of DETR. DETR:End-to-End Object Detection with Transformers. + + More details can be found in the `paper + `_ . 
+ + Args: + num_classes (int): Number of categories excluding the background. + embed_dims (int): The dims of Transformer embedding. + num_reg_fcs (int): Number of fully-connected layers used in `FFN`, + which is then used for the regression head. Defaults to 2. + sync_cls_avg_factor (bool): Whether to sync the `avg_factor` of + all ranks. Defaults to `False`. + loss_cls (:obj:`ConfigDict` or dict): Config of the classification + loss. Defaults to `CrossEntropyLoss`. + loss_bbox (:obj:`ConfigDict` or dict): Config of the regression bbox + loss. Defaults to `L1Loss`. + loss_iou (:obj:`ConfigDict` or dict): Config of the regression iou + loss. Defaults to `GIoULoss`. + train_cfg (:obj:`ConfigDict` or dict): Training config of transformer + head. + test_cfg (:obj:`ConfigDict` or dict): Testing config of transformer + head. + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. + """ + + _version = 2 + + def __init__( + self, + num_classes: int, + embed_dims: int = 256, + num_reg_fcs: int = 2, + sync_cls_avg_factor: bool = False, + loss_cls: ConfigType = dict( + type='CrossEntropyLoss', + bg_cls_weight=0.1, + use_sigmoid=False, + loss_weight=1.0, + class_weight=1.0), + loss_bbox: ConfigType = dict(type='L1Loss', loss_weight=5.0), + loss_iou: ConfigType = dict(type='GIoULoss', loss_weight=2.0), + train_cfg: ConfigType = dict( + assigner=dict( + type='HungarianAssigner', + match_costs=[ + dict(type='ClassificationCost', weight=1.), + dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'), + dict(type='IoUCost', iou_mode='giou', weight=2.0) + ])), + test_cfg: ConfigType = dict(max_per_img=100), + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + self.bg_cls_weight = 0 + self.sync_cls_avg_factor = sync_cls_avg_factor + class_weight = loss_cls.get('class_weight', None) + if class_weight is not None and (self.__class__ is DETRHead): + assert isinstance(class_weight, float), 'Expected ' \ + 'class_weight to have type float. Found ' \ + f'{type(class_weight)}.' + # NOTE following the official DETR repo, bg_cls_weight means + # relative classification weight of the no-object class. + bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight) + assert isinstance(bg_cls_weight, float), 'Expected ' \ + 'bg_cls_weight to have type float. Found ' \ + f'{type(bg_cls_weight)}.' + class_weight = torch.ones(num_classes + 1) * class_weight + # set background class as the last index + class_weight[num_classes] = bg_cls_weight + loss_cls.update({'class_weight': class_weight}) + if 'bg_cls_weight' in loss_cls: + loss_cls.pop('bg_cls_weight') + self.bg_cls_weight = bg_cls_weight + + if train_cfg: + assert 'assigner' in train_cfg, 'assigner should be provided ' \ + 'when train_cfg is set.'
+ assigner = train_cfg['assigner'] + self.assigner = TASK_UTILS.build(assigner) + if train_cfg.get('sampler', None) is not None: + raise RuntimeError('DETR does not build a sampler.') + self.num_classes = num_classes + self.embed_dims = embed_dims + self.num_reg_fcs = num_reg_fcs + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.loss_cls = MODELS.build(loss_cls) + self.loss_bbox = MODELS.build(loss_bbox) + self.loss_iou = MODELS.build(loss_iou) + + if self.loss_cls.use_sigmoid: + self.cls_out_channels = num_classes + else: + self.cls_out_channels = num_classes + 1 + + self._init_layers() + + def _init_layers(self) -> None: + """Initialize layers of the transformer head.""" + # cls branch + self.fc_cls = Linear(self.embed_dims, self.cls_out_channels) + # reg branch + self.activate = nn.ReLU() + self.reg_ffn = FFN( + self.embed_dims, + self.embed_dims, + self.num_reg_fcs, + dict(type='ReLU', inplace=True), + dropout=0.0, + add_residual=False) + # NOTE the activations of reg_branch here is the same as + # those in transformer, but they are actually different + # in DAB-DETR (prelu in transformer and relu in reg_branch) + self.fc_reg = Linear(self.embed_dims, 4) + + def forward(self, hidden_states: Tensor) -> Tuple[Tensor]: + """Forward function. + + Args: + hidden_states (Tensor): Features from transformer decoder. If + `return_intermediate_dec` in detr.py is True, the output has + shape (num_decoder_layers, bs, num_queries, dim), else it has + shape (1, bs, num_queries, dim) which only contains the last + layer outputs. + Returns: + tuple[Tensor]: results of head containing the following tensor. + + - layers_cls_scores (Tensor): Outputs from the classification head, + shape (num_decoder_layers, bs, num_queries, cls_out_channels). + Note cls_out_channels should include background. + - layers_bbox_preds (Tensor): Sigmoid outputs from the regression + head with normalized coordinate format (cx, cy, w, h), has shape + (num_decoder_layers, bs, num_queries, 4). + """ + layers_cls_scores = self.fc_cls(hidden_states) + layers_bbox_preds = self.fc_reg( + self.activate(self.reg_ffn(hidden_states))).sigmoid() + return layers_cls_scores, layers_bbox_preds + + def loss(self, hidden_states: Tensor, + batch_data_samples: SampleList) -> dict: + """Perform forward propagation and loss calculation of the detection + head on the features of the upstream network. + + Args: + hidden_states (Tensor): Feature from the transformer decoder, has + shape (num_decoder_layers, bs, num_queries, cls_out_channels) + or (num_decoder_layers, num_queries, bs, cls_out_channels). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components. + """ + batch_gt_instances = [] + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + + outs = self(hidden_states) + loss_inputs = outs + (batch_gt_instances, batch_img_metas) + losses = self.loss_by_feat(*loss_inputs) + return losses + + def loss_by_feat( + self, + all_layers_cls_scores: Tensor, + all_layers_bbox_preds: Tensor, + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None + ) -> Dict[str, Tensor]: + """Loss function. + + Only outputs from the last feature level are used for computing + losses by default.
+ + Args: + all_layers_cls_scores (Tensor): Classification outputs + of each decoder layer. Each is a 4D-tensor, has shape + (num_decoder_layers, bs, num_queries, cls_out_channels). + all_layers_bbox_preds (Tensor): Sigmoid regression + outputs of each decoder layer. Each is a 4D-tensor with + normalized coordinate format (cx, cy, w, h) and shape + (num_decoder_layers, bs, num_queries, 4). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + assert batch_gt_instances_ignore is None, \ + f'{self.__class__.__name__} only supports ' \ + '`batch_gt_instances_ignore` set to None.' + + losses_cls, losses_bbox, losses_iou = multi_apply( + self.loss_by_feat_single, + all_layers_cls_scores, + all_layers_bbox_preds, + batch_gt_instances=batch_gt_instances, + batch_img_metas=batch_img_metas) + + loss_dict = dict() + # loss from the last decoder layer + loss_dict['loss_cls'] = losses_cls[-1] + loss_dict['loss_bbox'] = losses_bbox[-1] + loss_dict['loss_iou'] = losses_iou[-1] + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_bbox_i, loss_iou_i in \ + zip(losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1]): + loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i + loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i + num_dec_layer += 1 + return loss_dict + + def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor, + batch_gt_instances: InstanceList, + batch_img_metas: List[dict]) -> Tuple[Tensor]: + """Loss function for outputs from a single decoder layer of a single + feature level. + + Args: + cls_scores (Tensor): Box score logits from a single decoder layer + for all images, has shape (bs, num_queries, cls_out_channels). + bbox_preds (Tensor): Sigmoid outputs from a single decoder layer + for all images, with normalized coordinate (cx, cy, w, h) and + shape (bs, num_queries, 4). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + Tuple[Tensor]: A tuple including `loss_cls`, `loss_bbox` and + `loss_iou`.
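+
+        Example:
+            A worked instance (editor's illustration, numbers invented) of
+            the weighted ``cls_avg_factor`` computed below, mirroring the
+            official DETR repo::
+
+                num_total_pos, num_total_neg = 7, 293
+                bg_cls_weight = 0.1
+                cls_avg_factor = (num_total_pos * 1.0
+                                  + num_total_neg * bg_cls_weight)  # 36.3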
+ """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)] + cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list, + batch_gt_instances, batch_img_metas) + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + num_total_pos, num_total_neg) = cls_reg_targets + labels = torch.cat(labels_list, 0) + label_weights = torch.cat(label_weights_list, 0) + bbox_targets = torch.cat(bbox_targets_list, 0) + bbox_weights = torch.cat(bbox_weights_list, 0) + + # classification loss + cls_scores = cls_scores.reshape(-1, self.cls_out_channels) + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = num_total_pos * 1.0 + \ + num_total_neg * self.bg_cls_weight + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + cls_scores.new_tensor([cls_avg_factor])) + cls_avg_factor = max(cls_avg_factor, 1) + + if isinstance(self.loss_cls, QualityFocalLoss): + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().squeeze(1) + scores = label_weights.new_zeros(labels.shape) + pos_bbox_targets = bbox_targets[pos_inds] + pos_decode_bbox_targets = bbox_cxcywh_to_xyxy(pos_bbox_targets) + pos_bbox_pred = bbox_preds.reshape(-1, 4)[pos_inds] + pos_decode_bbox_pred = bbox_cxcywh_to_xyxy(pos_bbox_pred) + scores[pos_inds] = bbox_overlaps( + pos_decode_bbox_pred.detach(), + pos_decode_bbox_targets, + is_aligned=True) + loss_cls = self.loss_cls( + cls_scores, (labels, scores), + label_weights, + avg_factor=cls_avg_factor) + else: + loss_cls = self.loss_cls( + cls_scores, labels, label_weights, avg_factor=cls_avg_factor) + + # Compute the average number of gt boxes across all gpus, for + # normalization purposes + num_total_pos = loss_cls.new_tensor([num_total_pos]) + num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item() + + # construct factors used for rescale bboxes + factors = [] + for img_meta, bbox_pred in zip(batch_img_metas, bbox_preds): + img_h, img_w, = img_meta['img_shape'] + factor = bbox_pred.new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0).repeat( + bbox_pred.size(0), 1) + factors.append(factor) + factors = torch.cat(factors, 0) + + # DETR regress the relative position of boxes (cxcywh) in the image, + # thus the learning target is normalized by the image size. So here + # we need to re-scale them for calculating IoU loss + bbox_preds = bbox_preds.reshape(-1, 4) + bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors + bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors + + # regression IoU loss, defaultly GIoU loss + loss_iou = self.loss_iou( + bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos) + + # regression L1 loss + loss_bbox = self.loss_bbox( + bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos) + return loss_cls, loss_bbox, loss_iou + + def get_targets(self, cls_scores_list: List[Tensor], + bbox_preds_list: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict]) -> tuple: + """Compute regression and classification targets for a batch image. + + Outputs from a single decoder layer of a single feature level are used. + + Args: + cls_scores_list (list[Tensor]): Box score logits from a single + decoder layer for each image, has shape [num_queries, + cls_out_channels]. 
+ bbox_preds_list (list[Tensor]): Sigmoid outputs from a single + decoder layer for each image, with normalized coordinate + (cx, cy, w, h) and shape [num_queries, 4]. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + tuple: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels for all images. + - label_weights_list (list[Tensor]): Label weights for all images. + - bbox_targets_list (list[Tensor]): BBox targets for all images. + - bbox_weights_list (list[Tensor]): BBox weights for all images. + - num_total_pos (int): Number of positive samples in all images. + - num_total_neg (int): Number of negative samples in all images. + """ + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + pos_inds_list, + neg_inds_list) = multi_apply(self._get_targets_single, + cls_scores_list, bbox_preds_list, + batch_gt_instances, batch_img_metas) + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + return (labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, num_total_pos, num_total_neg) + + def _get_targets_single(self, cls_score: Tensor, bbox_pred: Tensor, + gt_instances: InstanceData, + img_meta: dict) -> tuple: + """Compute regression and classification targets for one image. + + Outputs from a single decoder layer of a single feature level are used. + + Args: + cls_score (Tensor): Box score logits from a single decoder layer + for one image. Shape [num_queries, cls_out_channels]. + bbox_pred (Tensor): Sigmoid outputs from a single decoder layer + for one image, with normalized coordinate (cx, cy, w, h) and + shape [num_queries, 4]. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should include ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for one image. + + Returns: + tuple[Tensor]: a tuple containing the following for one image. + + - labels (Tensor): Labels of each image. + - label_weights (Tensor): Label weights of each image. + - bbox_targets (Tensor): BBox targets of each image. + - bbox_weights (Tensor): BBox weights of each image. + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image.
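+
+        Example:
+            A self-contained sketch (editor's illustration, sizes invented)
+            of how assignment indices become label targets; ``gt_inds``
+            uses 0 for negatives and ``i`` for the (i-1)-th gt box::
+
+                import torch
+
+                num_classes, num_queries = 80, 6
+                gt_labels = torch.tensor([3, 17])
+                gt_inds = torch.tensor([0, 2, 0, 0, 1, 0])
+                pos_inds = torch.nonzero(gt_inds > 0).squeeze(-1)  # [1, 4]
+                # background (num_classes) everywhere, then fill positives
+                labels = gt_labels.new_full((num_queries, ), num_classes)
+                labels[pos_inds] = gt_labels[gt_inds[pos_inds] - 1]
+                # labels: [80, 17, 80, 80, 3, 80]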
+ """ + img_h, img_w = img_meta['img_shape'] + factor = bbox_pred.new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0) + num_bboxes = bbox_pred.size(0) + # convert bbox_pred from xywh, normalized to xyxy, unnormalized + bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred) + bbox_pred = bbox_pred * factor + + pred_instances = InstanceData(scores=cls_score, bboxes=bbox_pred) + # assigner and sampler + assign_result = self.assigner.assign( + pred_instances=pred_instances, + gt_instances=gt_instances, + img_meta=img_meta) + + gt_bboxes = gt_instances.bboxes + gt_labels = gt_instances.labels + pos_inds = torch.nonzero( + assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() + neg_inds = torch.nonzero( + assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique() + pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds.long(), :] + + # label targets + labels = gt_bboxes.new_full((num_bboxes, ), + self.num_classes, + dtype=torch.long) + labels[pos_inds] = gt_labels[pos_assigned_gt_inds] + label_weights = gt_bboxes.new_ones(num_bboxes) + + # bbox targets + bbox_targets = torch.zeros_like(bbox_pred, dtype=gt_bboxes.dtype) + bbox_weights = torch.zeros_like(bbox_pred, dtype=gt_bboxes.dtype) + bbox_weights[pos_inds] = 1.0 + + # DETR regress the relative position of boxes (cxcywh) in the image. + # Thus the learning target should be normalized by the image size, also + # the box format should be converted from defaultly x1y1x2y2 to cxcywh. + pos_gt_bboxes_normalized = pos_gt_bboxes / factor + pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized) + bbox_targets[pos_inds] = pos_gt_bboxes_targets + return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, + neg_inds) + + def loss_and_predict( + self, hidden_states: Tuple[Tensor], + batch_data_samples: SampleList) -> Tuple[dict, InstanceList]: + """Perform forward propagation of the head, then calculate loss and + predictions from the features and data samples. Over-write because + img_metas are needed as inputs for bbox_head. + + Args: + hidden_states (tuple[Tensor]): Feature from the transformer + decoder, has shape (num_decoder_layers, bs, num_queries, dim). + batch_data_samples (list[:obj:`DetDataSample`]): Each item contains + the meta information of each image and corresponding + annotations. + + Returns: + tuple: the return value is a tuple contains: + + - losses: (dict[str, Tensor]): A dictionary of loss components. + - predictions (list[:obj:`InstanceData`]): Detection + results of each image after the post process. + """ + batch_gt_instances = [] + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + + outs = self(hidden_states) + loss_inputs = outs + (batch_gt_instances, batch_img_metas) + losses = self.loss_by_feat(*loss_inputs) + + predictions = self.predict_by_feat( + *outs, batch_img_metas=batch_img_metas) + return losses, predictions + + def predict(self, + hidden_states: Tuple[Tensor], + batch_data_samples: SampleList, + rescale: bool = True) -> InstanceList: + """Perform forward propagation of the detection head and predict + detection results on the features of the upstream network. Over-write + because img_metas are needed as inputs for bbox_head. + + Args: + hidden_states (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. 
It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool, optional): Whether to rescale the results. + Defaults to True. + + Returns: + list[obj:`InstanceData`]: Detection results of each image + after the post process. + """ + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + + last_layer_hidden_state = hidden_states[-1].unsqueeze(0) + outs = self(last_layer_hidden_state) + + predictions = self.predict_by_feat( + *outs, batch_img_metas=batch_img_metas, rescale=rescale) + + return predictions + + def predict_by_feat(self, + layer_cls_scores: Tensor, + layer_bbox_preds: Tensor, + batch_img_metas: List[dict], + rescale: bool = True) -> InstanceList: + """Transform network outputs for a batch into bbox predictions. + + Args: + layer_cls_scores (Tensor): Classification outputs of the last or + all decoder layer. Each is a 4D-tensor, has shape + (num_decoder_layers, bs, num_queries, cls_out_channels). + layer_bbox_preds (Tensor): Sigmoid regression outputs of the last + or all decoder layer. Each is a 4D-tensor with normalized + coordinate format (cx, cy, w, h) and shape + (num_decoder_layers, bs, num_queries, 4). + batch_img_metas (list[dict]): Meta information of each image. + rescale (bool, optional): If `True`, return boxes in original + image space. Defaults to `True`. + + Returns: + list[:obj:`InstanceData`]: Object detection results of each image + after the post process. Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + # NOTE only using outputs from the last feature level, + # and only the outputs from the last decoder layer is used. + cls_scores = layer_cls_scores[-1] + bbox_preds = layer_bbox_preds[-1] + + result_list = [] + for img_id in range(len(batch_img_metas)): + cls_score = cls_scores[img_id] + bbox_pred = bbox_preds[img_id] + img_meta = batch_img_metas[img_id] + results = self._predict_by_feat_single(cls_score, bbox_pred, + img_meta, rescale) + result_list.append(results) + return result_list + + def _predict_by_feat_single(self, + cls_score: Tensor, + bbox_pred: Tensor, + img_meta: dict, + rescale: bool = True) -> InstanceData: + """Transform outputs from the last decoder layer into bbox predictions + for each image. + + Args: + cls_score (Tensor): Box score logits from the last decoder layer + for each image. Shape [num_queries, cls_out_channels]. + bbox_pred (Tensor): Sigmoid outputs from the last decoder layer + for each image, with coordinate format (cx, cy, w, h) and + shape [num_queries, 4]. + img_meta (dict): Image meta info. + rescale (bool): If True, return boxes in original image + space. Default True. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
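+
+        Example:
+            A sketch (editor's illustration, sizes invented) of the
+            sigmoid-branch top-k selection used below: scores are ranked
+            over the flattened (query, class) grid, then the query and
+            class indices are recovered::
+
+                import torch
+
+                num_queries, num_classes, max_per_img = 4, 3, 2
+                cls_score = torch.rand(num_queries, num_classes).sigmoid()
+                scores, indexes = cls_score.view(-1).topk(max_per_img)
+                det_labels = indexes % num_classes  # class within a query
+                bbox_index = indexes // num_classes  # which query fired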
+ """ + assert len(cls_score) == len(bbox_pred) # num_queries + max_per_img = self.test_cfg.get('max_per_img', len(cls_score)) + img_shape = img_meta['img_shape'] + # exclude background + if self.loss_cls.use_sigmoid: + cls_score = cls_score.sigmoid() + scores, indexes = cls_score.view(-1).topk(max_per_img) + det_labels = indexes % self.num_classes + bbox_index = indexes // self.num_classes + bbox_pred = bbox_pred[bbox_index] + else: + scores, det_labels = F.softmax(cls_score, dim=-1)[..., :-1].max(-1) + scores, bbox_index = scores.topk(max_per_img) + bbox_pred = bbox_pred[bbox_index] + det_labels = det_labels[bbox_index] + + det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred) + det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1] + det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0] + det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1]) + det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0]) + if rescale: + assert img_meta.get('scale_factor') is not None + det_bboxes /= det_bboxes.new_tensor( + img_meta['scale_factor']).repeat((1, 2)) + + results = InstanceData() + results.bboxes = det_bboxes + results.scores = scores + results.labels = det_labels + return results diff --git a/mmdet/models/dense_heads/dino_head.py b/mmdet/models/dense_heads/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..54f46d1474f97f2d183926a6dc68a0be79f7cef1 --- /dev/null +++ b/mmdet/models/dense_heads/dino_head.py @@ -0,0 +1,479 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Tuple + +import torch +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.structures.bbox import (bbox_cxcywh_to_xyxy, bbox_overlaps, + bbox_xyxy_to_cxcywh) +from mmdet.utils import InstanceList, OptInstanceList, reduce_mean +from ..losses import QualityFocalLoss +from ..utils import multi_apply +from .deformable_detr_head import DeformableDETRHead + + +@MODELS.register_module() +class DINOHead(DeformableDETRHead): + r"""Head of the DINO: DETR with Improved DeNoising Anchor Boxes + for End-to-End Object Detection + + Code is modified from the `official github repo + `_. + + More details can be found in the `paper + `_ . + """ + + def loss(self, hidden_states: Tensor, references: List[Tensor], + enc_outputs_class: Tensor, enc_outputs_coord: Tensor, + batch_data_samples: SampleList, dn_meta: Dict[str, int]) -> dict: + """Perform forward propagation and loss calculation of the detection + head on the queries of the upstream network. + + Args: + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, bs, num_queries_total, + dim), where `num_queries_total` is the sum of + `num_denoising_queries` and `num_matching_queries` when + `self.training` is `True`, else `num_matching_queries`. + references (list[Tensor]): List of the reference from the decoder. + The first reference is the `init_reference` (initial) and the + other num_decoder_layers(6) references are `inter_references` + (intermediate). The `init_reference` has shape (bs, + num_queries_total, 4) and each `inter_reference` has shape + (bs, num_queries, 4) with the last dimension arranged as + (cx, cy, w, h). + enc_outputs_class (Tensor): The score of each point on encode + feature map, has shape (bs, num_feat_points, cls_out_channels). 
+ enc_outputs_coord (Tensor): The proposals generated from the + encoder feature map, has shape (bs, num_feat_points, 4) with the + last dimension arranged as (cx, cy, w, h). + batch_data_samples (list[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. It will be used to split the outputs of + the denoising and matching parts and for loss calculation. + + Returns: + dict: A dictionary of loss components. + """ + batch_gt_instances = [] + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + + outs = self(hidden_states, references) + loss_inputs = outs + (enc_outputs_class, enc_outputs_coord, + batch_gt_instances, batch_img_metas, dn_meta) + losses = self.loss_by_feat(*loss_inputs) + return losses + + def loss_by_feat( + self, + all_layers_cls_scores: Tensor, + all_layers_bbox_preds: Tensor, + enc_cls_scores: Tensor, + enc_bbox_preds: Tensor, + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + dn_meta: Dict[str, int], + batch_gt_instances_ignore: OptInstanceList = None + ) -> Dict[str, Tensor]: + """Loss function. + + Args: + all_layers_cls_scores (Tensor): Classification scores of all + decoder layers, has shape (num_decoder_layers, bs, + num_queries_total, cls_out_channels), where + `num_queries_total` is the sum of `num_denoising_queries` + and `num_matching_queries`. + all_layers_bbox_preds (Tensor): Regression outputs of all decoder + layers. Each is a 4D-tensor with normalized coordinate format + (cx, cy, w, h) and has shape (num_decoder_layers, bs, + num_queries_total, 4). + enc_cls_scores (Tensor): The score of each point on the encoder + feature map, has shape (bs, num_feat_points, cls_out_channels). + enc_bbox_preds (Tensor): The proposals generated from the encoder + feature map, has shape (bs, num_feat_points, 4) with the last + dimension arranged as (cx, cy, w, h). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. It will be used to split the outputs of + the denoising and matching parts and for loss calculation. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + # extract denoising and matching part of outputs + (all_layers_matching_cls_scores, all_layers_matching_bbox_preds, + all_layers_denoising_cls_scores, all_layers_denoising_bbox_preds) = \ + self.split_outputs( + all_layers_cls_scores, all_layers_bbox_preds, dn_meta) + + loss_dict = super(DeformableDETRHead, self).loss_by_feat( + all_layers_matching_cls_scores, all_layers_matching_bbox_preds, + batch_gt_instances, batch_img_metas, batch_gt_instances_ignore) + # NOTE DETRHead.loss_by_feat but not DeformableDETRHead.loss_by_feat + # is called, because the encoder loss calculations are different + # between DINO and DeformableDETR.
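+        # Illustration (editor's note, sizes hypothetical): with 200
+        # denoising queries placed in front of 900 matching queries,
+        # `split_outputs` slices dim 2 of the (layers, bs, queries, ...)
+        # tensors as [:, :, :200] (denoising) and [:, :, 200:] (matching).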
+ + # loss of proposal generated from encode feature map. + if enc_cls_scores is not None: + # NOTE The enc_loss calculation of the DINO is + # different from that of Deformable DETR. + enc_loss_cls, enc_losses_bbox, enc_losses_iou = \ + self.loss_by_feat_single( + enc_cls_scores, enc_bbox_preds, + batch_gt_instances=batch_gt_instances, + batch_img_metas=batch_img_metas) + loss_dict['enc_loss_cls'] = enc_loss_cls + loss_dict['enc_loss_bbox'] = enc_losses_bbox + loss_dict['enc_loss_iou'] = enc_losses_iou + + if all_layers_denoising_cls_scores is not None: + # calculate denoising loss from all decoder layers + dn_losses_cls, dn_losses_bbox, dn_losses_iou = self.loss_dn( + all_layers_denoising_cls_scores, + all_layers_denoising_bbox_preds, + batch_gt_instances=batch_gt_instances, + batch_img_metas=batch_img_metas, + dn_meta=dn_meta) + # collate denoising loss + loss_dict['dn_loss_cls'] = dn_losses_cls[-1] + loss_dict['dn_loss_bbox'] = dn_losses_bbox[-1] + loss_dict['dn_loss_iou'] = dn_losses_iou[-1] + for num_dec_layer, (loss_cls_i, loss_bbox_i, loss_iou_i) in \ + enumerate(zip(dn_losses_cls[:-1], dn_losses_bbox[:-1], + dn_losses_iou[:-1])): + loss_dict[f'd{num_dec_layer}.dn_loss_cls'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.dn_loss_bbox'] = loss_bbox_i + loss_dict[f'd{num_dec_layer}.dn_loss_iou'] = loss_iou_i + return loss_dict + + def loss_dn(self, all_layers_denoising_cls_scores: Tensor, + all_layers_denoising_bbox_preds: Tensor, + batch_gt_instances: InstanceList, batch_img_metas: List[dict], + dn_meta: Dict[str, int]) -> Tuple[List[Tensor]]: + """Calculate denoising loss. + + Args: + all_layers_denoising_cls_scores (Tensor): Classification scores of + all decoder layers in denoising part, has shape ( + num_decoder_layers, bs, num_denoising_queries, + cls_out_channels). + all_layers_denoising_bbox_preds (Tensor): Regression outputs of all + decoder layers in denoising part. Each is a 4D-tensor with + normalized coordinate format (cx, cy, w, h) and has shape + (num_decoder_layers, bs, num_denoising_queries, 4). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. It will be used for split outputs of + denoising and matching parts and loss calculation. + + Returns: + Tuple[List[Tensor]]: The loss_dn_cls, loss_dn_bbox, and loss_dn_iou + of each decoder layers. + """ + return multi_apply( + self._loss_dn_single, + all_layers_denoising_cls_scores, + all_layers_denoising_bbox_preds, + batch_gt_instances=batch_gt_instances, + batch_img_metas=batch_img_metas, + dn_meta=dn_meta) + + def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor, + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + dn_meta: Dict[str, int]) -> Tuple[Tensor]: + """Denoising loss for outputs from a single decoder layer. + + Args: + dn_cls_scores (Tensor): Classification scores of a single decoder + layer in denoising part, has shape (bs, num_denoising_queries, + cls_out_channels). + dn_bbox_preds (Tensor): Regression outputs of a single decoder + layer in denoising part. Each is a 4D-tensor with normalized + coordinate format (cx, cy, w, h) and has shape + (bs, num_denoising_queries, 4). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. 
It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. It will be used for split outputs of + denoising and matching parts and loss calculation. + + Returns: + Tuple[Tensor]: A tuple including `loss_cls`, `loss_box` and + `loss_iou`. + """ + cls_reg_targets = self.get_dn_targets(batch_gt_instances, + batch_img_metas, dn_meta) + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + num_total_pos, num_total_neg) = cls_reg_targets + labels = torch.cat(labels_list, 0) + label_weights = torch.cat(label_weights_list, 0) + bbox_targets = torch.cat(bbox_targets_list, 0) + bbox_weights = torch.cat(bbox_weights_list, 0) + + # classification loss + cls_scores = dn_cls_scores.reshape(-1, self.cls_out_channels) + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = \ + num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + cls_scores.new_tensor([cls_avg_factor])) + cls_avg_factor = max(cls_avg_factor, 1) + + if len(cls_scores) > 0: + if isinstance(self.loss_cls, QualityFocalLoss): + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().squeeze(1) + scores = label_weights.new_zeros(labels.shape) + pos_bbox_targets = bbox_targets[pos_inds] + pos_decode_bbox_targets = bbox_cxcywh_to_xyxy(pos_bbox_targets) + pos_bbox_pred = dn_bbox_preds.reshape(-1, 4)[pos_inds] + pos_decode_bbox_pred = bbox_cxcywh_to_xyxy(pos_bbox_pred) + scores[pos_inds] = bbox_overlaps( + pos_decode_bbox_pred.detach(), + pos_decode_bbox_targets, + is_aligned=True) + loss_cls = self.loss_cls( + cls_scores, (labels, scores), + weight=label_weights, + avg_factor=cls_avg_factor) + else: + loss_cls = self.loss_cls( + cls_scores, + labels, + label_weights, + avg_factor=cls_avg_factor) + else: + loss_cls = torch.zeros( + 1, dtype=cls_scores.dtype, device=cls_scores.device) + + # Compute the average number of gt boxes across all gpus, for + # normalization purposes + num_total_pos = loss_cls.new_tensor([num_total_pos]) + num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item() + + # construct factors used for rescale bboxes + factors = [] + for img_meta, bbox_pred in zip(batch_img_metas, dn_bbox_preds): + img_h, img_w = img_meta['img_shape'] + factor = bbox_pred.new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0).repeat( + bbox_pred.size(0), 1) + factors.append(factor) + factors = torch.cat(factors) + + # DETR regress the relative position of boxes (cxcywh) in the image, + # thus the learning target is normalized by the image size. 
So here + # we need to re-scale them for calculating IoU loss + bbox_preds = dn_bbox_preds.reshape(-1, 4) + bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors + bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors + + # regression IoU loss, defaultly GIoU loss + loss_iou = self.loss_iou( + bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos) + + # regression L1 loss + loss_bbox = self.loss_bbox( + bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos) + return loss_cls, loss_bbox, loss_iou + + def get_dn_targets(self, batch_gt_instances: InstanceList, + batch_img_metas: dict, dn_meta: Dict[str, + int]) -> tuple: + """Get targets in denoising part for a batch of images. + + Args: + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. It will be used for split outputs of + denoising and matching parts and loss calculation. + + Returns: + tuple: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels for all images. + - label_weights_list (list[Tensor]): Label weights for all images. + - bbox_targets_list (list[Tensor]): BBox targets for all images. + - bbox_weights_list (list[Tensor]): BBox weights for all images. + - num_total_pos (int): Number of positive samples in all images. + - num_total_neg (int): Number of negative samples in all images. + """ + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + pos_inds_list, neg_inds_list) = multi_apply( + self._get_dn_targets_single, + batch_gt_instances, + batch_img_metas, + dn_meta=dn_meta) + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + return (labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, num_total_pos, num_total_neg) + + def _get_dn_targets_single(self, gt_instances: InstanceData, + img_meta: dict, dn_meta: Dict[str, + int]) -> tuple: + """Get targets in denoising part for one image. + + Args: + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for one image. + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. It will be used for split outputs of + denoising and matching parts and loss calculation. + + Returns: + tuple[Tensor]: a tuple containing the following for one image. + + - labels (Tensor): Labels of each image. + - label_weights (Tensor]): Label weights of each image. + - bbox_targets (Tensor): BBox targets of each image. + - bbox_weights (Tensor): BBox weights of each image. + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image. 
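# --- Illustrative sketch (editor's addition, not part of this diff) ---
# A minimal standalone version of the rescaling step in `_loss_dn_single`
# above: DN box predictions are normalized (cx, cy, w, h), so they must be
# converted to absolute (x1, y1, x2, y2) before the IoU loss. The image
# size and predictions below are made up; `bbox_cxcywh_to_xyxy` is the
# mmdet helper this code already uses.
import torch
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy

# one 640x480 image with two denoising queries
factors = torch.tensor([[640., 480., 640., 480.],
                        [640., 480., 640., 480.]])
bbox_preds = torch.tensor([[0.50, 0.50, 0.20, 0.40],
                           [0.25, 0.25, 0.10, 0.10]])
bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
print(bboxes[0])  # tensor([256., 144., 384., 336.])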
+ """ + gt_bboxes = gt_instances.bboxes + gt_labels = gt_instances.labels + num_groups = dn_meta['num_denoising_groups'] + num_denoising_queries = dn_meta['num_denoising_queries'] + num_queries_each_group = int(num_denoising_queries / num_groups) + device = gt_bboxes.device + + if len(gt_labels) > 0: + t = torch.arange(len(gt_labels), dtype=torch.long, device=device) + t = t.unsqueeze(0).repeat(num_groups, 1) + pos_assigned_gt_inds = t.flatten() + pos_inds = torch.arange( + num_groups, dtype=torch.long, device=device) + pos_inds = pos_inds.unsqueeze(1) * num_queries_each_group + t + pos_inds = pos_inds.flatten() + else: + pos_inds = pos_assigned_gt_inds = \ + gt_bboxes.new_tensor([], dtype=torch.long) + + neg_inds = pos_inds + num_queries_each_group // 2 + + # label targets + labels = gt_bboxes.new_full((num_denoising_queries, ), + self.num_classes, + dtype=torch.long) + labels[pos_inds] = gt_labels[pos_assigned_gt_inds] + label_weights = gt_bboxes.new_ones(num_denoising_queries) + + # bbox targets + bbox_targets = torch.zeros(num_denoising_queries, 4, device=device) + bbox_weights = torch.zeros(num_denoising_queries, 4, device=device) + bbox_weights[pos_inds] = 1.0 + img_h, img_w = img_meta['img_shape'] + + # DETR regress the relative position of boxes (cxcywh) in the image. + # Thus the learning target should be normalized by the image size, also + # the box format should be converted from defaultly x1y1x2y2 to cxcywh. + factor = gt_bboxes.new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0) + gt_bboxes_normalized = gt_bboxes / factor + gt_bboxes_targets = bbox_xyxy_to_cxcywh(gt_bboxes_normalized) + bbox_targets[pos_inds] = gt_bboxes_targets.repeat([num_groups, 1]) + + return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, + neg_inds) + + @staticmethod + def split_outputs(all_layers_cls_scores: Tensor, + all_layers_bbox_preds: Tensor, + dn_meta: Dict[str, int]) -> Tuple[Tensor]: + """Split outputs of the denoising part and the matching part. + + For the total outputs of `num_queries_total` length, the former + `num_denoising_queries` outputs are from denoising queries, and + the rest `num_matching_queries` ones are from matching queries, + where `num_queries_total` is the sum of `num_denoising_queries` and + `num_matching_queries`. + + Args: + all_layers_cls_scores (Tensor): Classification scores of all + decoder layers, has shape (num_decoder_layers, bs, + num_queries_total, cls_out_channels). + all_layers_bbox_preds (Tensor): Regression outputs of all decoder + layers. Each is a 4D-tensor with normalized coordinate format + (cx, cy, w, h) and has shape (num_decoder_layers, bs, + num_queries_total, 4). + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. + + Returns: + Tuple[Tensor]: a tuple containing the following outputs. + + - all_layers_matching_cls_scores (Tensor): Classification scores + of all decoder layers in matching part, has shape + (num_decoder_layers, bs, num_matching_queries, cls_out_channels). + - all_layers_matching_bbox_preds (Tensor): Regression outputs of + all decoder layers in matching part. Each is a 4D-tensor with + normalized coordinate format (cx, cy, w, h) and has shape + (num_decoder_layers, bs, num_matching_queries, 4). + - all_layers_denoising_cls_scores (Tensor): Classification scores + of all decoder layers in denoising part, has shape + (num_decoder_layers, bs, num_denoising_queries, + cls_out_channels). 
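# --- Illustrative sketch (editor's addition, not part of this diff) ---
# How `_get_dn_targets_single` above lays out denoising queries: each
# group holds one positive slot per GT in its first half and the matching
# negative slots in its second half. The numbers (3 GTs, 2 groups) are
# made up.
import torch

num_gts, num_groups = 3, 2
num_queries_each_group = 2 * num_gts          # positives + negatives
t = torch.arange(num_gts).unsqueeze(0).repeat(num_groups, 1)
pos_inds = (torch.arange(num_groups).unsqueeze(1)
            * num_queries_each_group + t).flatten()
neg_inds = pos_inds + num_queries_each_group // 2
print(pos_inds.tolist())  # [0, 1, 2, 6, 7, 8]
print(neg_inds.tolist())  # [3, 4, 5, 9, 10, 11]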
+ - all_layers_denoising_bbox_preds (Tensor): Regression outputs of + all decoder layers in denoising part. Each is a 4D-tensor with + normalized coordinate format (cx, cy, w, h) and has shape + (num_decoder_layers, bs, num_denoising_queries, 4). + """ + num_denoising_queries = dn_meta['num_denoising_queries'] + if dn_meta is not None: + all_layers_denoising_cls_scores = \ + all_layers_cls_scores[:, :, : num_denoising_queries, :] + all_layers_denoising_bbox_preds = \ + all_layers_bbox_preds[:, :, : num_denoising_queries, :] + all_layers_matching_cls_scores = \ + all_layers_cls_scores[:, :, num_denoising_queries:, :] + all_layers_matching_bbox_preds = \ + all_layers_bbox_preds[:, :, num_denoising_queries:, :] + else: + all_layers_denoising_cls_scores = None + all_layers_denoising_bbox_preds = None + all_layers_matching_cls_scores = all_layers_cls_scores + all_layers_matching_bbox_preds = all_layers_bbox_preds + return (all_layers_matching_cls_scores, all_layers_matching_bbox_preds, + all_layers_denoising_cls_scores, + all_layers_denoising_bbox_preds) diff --git a/mmdet/models/dense_heads/embedding_rpn_head.py b/mmdet/models/dense_heads/embedding_rpn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..97e84fa83b892c0274615d582fe43a6693541617 --- /dev/null +++ b/mmdet/models/dense_heads/embedding_rpn_head.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn as nn +from mmengine.model import BaseModule +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures.bbox import bbox_cxcywh_to_xyxy +from mmdet.structures.det_data_sample import SampleList +from mmdet.utils import InstanceList, OptConfigType + + +@MODELS.register_module() +class EmbeddingRPNHead(BaseModule): + """RPNHead in the `Sparse R-CNN `_ . + + Unlike traditional RPNHead, this module does not need FPN input, but just + decode `init_proposal_bboxes` and expand the first dimension of + `init_proposal_bboxes` and `init_proposal_features` to the batch_size. + + Args: + num_proposals (int): Number of init_proposals. Defaults to 100. + proposal_feature_channel (int): Channel number of + init_proposal_feature. Defaults to 256. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict]): Initialization config dict. Defaults to None. + """ + + def __init__(self, + num_proposals: int = 100, + proposal_feature_channel: int = 256, + init_cfg: OptConfigType = None, + **kwargs) -> None: + # `**kwargs` is necessary to avoid some potential error. + assert init_cfg is None, 'To prevent abnormal initialization ' \ + 'behavior, init_cfg is not allowed to be set' + super().__init__(init_cfg=init_cfg) + self.num_proposals = num_proposals + self.proposal_feature_channel = proposal_feature_channel + self._init_layers() + + def _init_layers(self) -> None: + """Initialize a sparse set of proposal boxes and proposal features.""" + self.init_proposal_bboxes = nn.Embedding(self.num_proposals, 4) + self.init_proposal_features = nn.Embedding( + self.num_proposals, self.proposal_feature_channel) + + def init_weights(self) -> None: + """Initialize the init_proposal_bboxes as normalized. + + [c_x, c_y, w, h], and we initialize it to the size of the entire + image. 
+ """ + super().init_weights() + nn.init.constant_(self.init_proposal_bboxes.weight[:, :2], 0.5) + nn.init.constant_(self.init_proposal_bboxes.weight[:, 2:], 1) + + def _decode_init_proposals(self, x: List[Tensor], + batch_data_samples: SampleList) -> InstanceList: + """Decode init_proposal_bboxes according to the size of images and + expand dimension of init_proposal_features to batch_size. + + Args: + x (list[Tensor]): List of FPN features. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + List[:obj:`InstanceData`:] Detection results of each image. + Each item usually contains following keys. + + - proposals: Decoded proposal bboxes, + has shape (num_proposals, 4). + - features: init_proposal_features, expanded proposal + features, has shape + (num_proposals, proposal_feature_channel). + - imgs_whwh: Tensor with shape + (num_proposals, 4), the dimension means + [img_width, img_height, img_width, img_height]. + """ + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + + proposals = self.init_proposal_bboxes.weight.clone() + proposals = bbox_cxcywh_to_xyxy(proposals) + imgs_whwh = [] + for meta in batch_img_metas: + h, w = meta['img_shape'][:2] + imgs_whwh.append(x[0].new_tensor([[w, h, w, h]])) + imgs_whwh = torch.cat(imgs_whwh, dim=0) + imgs_whwh = imgs_whwh[:, None, :] + proposals = proposals * imgs_whwh + + rpn_results_list = [] + for idx in range(len(batch_img_metas)): + rpn_results = InstanceData() + rpn_results.bboxes = proposals[idx] + rpn_results.imgs_whwh = imgs_whwh[idx].repeat( + self.num_proposals, 1) + rpn_results.features = self.init_proposal_features.weight.clone() + rpn_results_list.append(rpn_results) + return rpn_results_list + + def loss(self, *args, **kwargs): + """Perform forward propagation and loss calculation of the detection + head on the features of the upstream network.""" + raise NotImplementedError( + 'EmbeddingRPNHead does not have `loss`, please use ' + '`predict` or `loss_and_predict` instead.') + + def predict(self, x: List[Tensor], batch_data_samples: SampleList, + **kwargs) -> InstanceList: + """Perform forward propagation of the detection head and predict + detection results on the features of the upstream network.""" + # `**kwargs` is necessary to avoid some potential error. + return self._decode_init_proposals( + x=x, batch_data_samples=batch_data_samples) + + def loss_and_predict(self, x: List[Tensor], batch_data_samples: SampleList, + **kwargs) -> tuple: + """Perform forward propagation of the head, then calculate loss and + predictions from the features and data samples.""" + # `**kwargs` is necessary to avoid some potential error. + predictions = self._decode_init_proposals( + x=x, batch_data_samples=batch_data_samples) + + return dict(), predictions diff --git a/mmdet/models/dense_heads/fcos_head.py b/mmdet/models/dense_heads/fcos_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ba4d4640010c7e8e7c6a4db3e0fce887b4105217 --- /dev/null +++ b/mmdet/models/dense_heads/fcos_head.py @@ -0,0 +1,476 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
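# --- Illustrative sketch (editor's addition, not part of this diff) ---
# What `EmbeddingRPNHead.init_weights` plus `_decode_init_proposals` above
# amount to for one image: every learned proposal starts out as the whole
# image. The 800x600 image size is arbitrary.
import torch
import torch.nn as nn
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy

init_proposal_bboxes = nn.Embedding(100, 4)
nn.init.constant_(init_proposal_bboxes.weight[:, :2], 0.5)  # centre
nn.init.constant_(init_proposal_bboxes.weight[:, 2:], 1)    # full w, h

proposals = bbox_cxcywh_to_xyxy(init_proposal_bboxes.weight.clone())
whwh = torch.tensor([800., 600., 800., 600.])
print((proposals * whwh)[0])  # tensor([  0.,   0., 800., 600.])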
+from typing import Dict, List, Tuple
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import Scale
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmdet.models.layers import NormedConv2d
+from mmdet.registry import MODELS
+from mmdet.utils import (ConfigType, InstanceList, MultiConfig,
+                         OptInstanceList, RangeType, reduce_mean)
+from ..utils import multi_apply
+from .anchor_free_head import AnchorFreeHead
+
+INF = 1e8
+
+
+@MODELS.register_module()
+class FCOSHead(AnchorFreeHead):
+    """Anchor-free head used in `FCOS <https://arxiv.org/abs/1904.01355>`_.
+
+    The FCOS head does not use anchor boxes. Instead, bounding boxes are
+    predicted at each pixel and a centerness measure is used to suppress
+    low-quality predictions.
+    Here norm_on_bbox, centerness_on_reg and dcn_on_last_conv are training
+    tricks used in the official repo, which bring remarkable mAP gains
+    of up to 4.9. Please see https://github.com/tianzhi0549/FCOS for
+    more detail.
+
+    Args:
+        num_classes (int): Number of categories excluding the background
+            category.
+        in_channels (int): Number of channels in the input feature map.
+        strides (Sequence[int] or Sequence[Tuple[int, int]]): Strides of
+            points in multiple feature levels. Defaults to (4, 8, 16, 32, 64).
+        regress_ranges (Sequence[Tuple[int, int]]): Regress range of multiple
+            level points.
+        center_sampling (bool): If true, use center sampling.
+            Defaults to False.
+        center_sample_radius (float): Radius of center sampling.
+            Defaults to 1.5.
+        norm_on_bbox (bool): If true, normalize the regression targets with
+            FPN strides. Defaults to False.
+        centerness_on_reg (bool): If true, position centerness on the
+            regression branch. Please refer to https://github.com/tianzhi0549/FCOS/issues/89#issuecomment-516877042.
+            Defaults to False.
+        conv_bias (bool or str): If specified as `auto`, it will be decided by
+            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
+            None, otherwise False. Defaults to "auto".
+        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
+        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
+        loss_centerness (:obj:`ConfigDict` or dict): Config of centerness
+            loss.
+        norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
+            config norm layer. Defaults to
+            ``norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)``.
+        cls_predictor_cfg (:obj:`ConfigDict` or dict): Dictionary to construct
+            and config conv_cls. Defaults to None.
+        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
+            dict]): Initialization config dict.
+ + Example: + >>> self = FCOSHead(11, 7) + >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]] + >>> cls_score, bbox_pred, centerness = self.forward(feats) + >>> assert len(cls_score) == len(self.scales) + """ # noqa: E501 + + def __init__(self, + num_classes: int, + in_channels: int, + regress_ranges: RangeType = ((-1, 64), (64, 128), (128, 256), + (256, 512), (512, INF)), + center_sampling: bool = False, + center_sample_radius: float = 1.5, + norm_on_bbox: bool = False, + centerness_on_reg: bool = False, + loss_cls: ConfigType = dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox: ConfigType = dict(type='IoULoss', loss_weight=1.0), + loss_centerness: ConfigType = dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0), + norm_cfg: ConfigType = dict( + type='GN', num_groups=32, requires_grad=True), + cls_predictor_cfg=None, + init_cfg: MultiConfig = dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', + name='conv_cls', + std=0.01, + bias_prob=0.01)), + **kwargs) -> None: + self.regress_ranges = regress_ranges + self.center_sampling = center_sampling + self.center_sample_radius = center_sample_radius + self.norm_on_bbox = norm_on_bbox + self.centerness_on_reg = centerness_on_reg + self.cls_predictor_cfg = cls_predictor_cfg + super().__init__( + num_classes=num_classes, + in_channels=in_channels, + loss_cls=loss_cls, + loss_bbox=loss_bbox, + norm_cfg=norm_cfg, + init_cfg=init_cfg, + **kwargs) + self.loss_centerness = MODELS.build(loss_centerness) + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + super()._init_layers() + self.conv_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1) + self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides]) + if self.cls_predictor_cfg is not None: + self.cls_predictor_cfg.pop('type') + self.conv_cls = NormedConv2d( + self.feat_channels, + self.cls_out_channels, + 1, + padding=0, + **self.cls_predictor_cfg) + + def forward( + self, x: Tuple[Tensor] + ) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]: + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: A tuple of each level outputs. + + - cls_scores (list[Tensor]): Box scores for each scale level, \ + each is a 4D-tensor, the channel number is \ + num_points * num_classes. + - bbox_preds (list[Tensor]): Box energies / deltas for each \ + scale level, each is a 4D-tensor, the channel number is \ + num_points * 4. + - centernesses (list[Tensor]): centerness for each scale level, \ + each is a 4D-tensor, the channel number is num_points * 1. + """ + return multi_apply(self.forward_single, x, self.scales, self.strides) + + def forward_single(self, x: Tensor, scale: Scale, + stride: int) -> Tuple[Tensor, Tensor, Tensor]: + """Forward features of a single scale level. + + Args: + x (Tensor): FPN feature maps of the specified stride. + scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize + the bbox prediction. + stride (int): The corresponding stride for feature maps, only + used to normalize the bbox prediction when self.norm_on_bbox + is True. + + Returns: + tuple: scores for each class, bbox predictions and centerness + predictions of input feature maps. 
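# --- Illustrative sketch (editor's addition, not part of this diff) ---
# `FCOSHead.forward` above relies on mmdet's `multi_apply`, which maps a
# per-level function over several sequences and transposes the results
# into one list per output. A tiny self-contained demo with made-up
# feature maps:
import torch
from mmdet.models.utils import multi_apply

def scale_one(feat, factor):
    return feat * factor, feat.shape[-1]

feats = [torch.ones(1, 1, s, s) for s in (8, 4)]
scaled, sizes = multi_apply(scale_one, feats, (2.0, 4.0))
print(sizes)  # [8, 4] -- per-level outputs grouped by position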
+ """ + cls_score, bbox_pred, cls_feat, reg_feat = super().forward_single(x) + if self.centerness_on_reg: + centerness = self.conv_centerness(reg_feat) + else: + centerness = self.conv_centerness(cls_feat) + # scale the bbox_pred of different level + # float to avoid overflow when enabling FP16 + bbox_pred = scale(bbox_pred).float() + if self.norm_on_bbox: + # bbox_pred needed for gradient computation has been modified + # by F.relu(bbox_pred) when run with PyTorch 1.10. So replace + # F.relu(bbox_pred) with bbox_pred.clamp(min=0) + bbox_pred = bbox_pred.clamp(min=0) + if not self.training: + bbox_pred *= stride + else: + bbox_pred = bbox_pred.exp() + return cls_score, bbox_pred, centerness + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + centernesses: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None + ) -> Dict[str, Tensor]: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level, + each is a 4D-tensor, the channel number is + num_points * num_classes. + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level, each is a 4D-tensor, the channel number is + num_points * 4. + centernesses (list[Tensor]): centerness for each scale level, each + is a 4D-tensor, the channel number is num_points * 1. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
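# --- Illustrative sketch (editor's addition, not part of this diff) ---
# The two bbox_pred transforms in `forward_single` above, side by side.
# With norm_on_bbox=True the head predicts distances in stride units
# (clamped non-negative, rescaled by the stride only at test time);
# otherwise exp() keeps raw distances positive. Shapes and the stride
# are made up.
import torch
from mmcv.cnn import Scale

scale = Scale(1.0)
stride = 8
bbox_pred = torch.randn(1, 4, 16, 16)

pred_norm_train = scale(bbox_pred).float().clamp(min=0)
pred_norm_test = pred_norm_train * stride   # absolute distances
pred_exp = scale(bbox_pred).float().exp()   # norm_on_bbox=False path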
+ """ + assert len(cls_scores) == len(bbox_preds) == len(centernesses) + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + all_level_points = self.prior_generator.grid_priors( + featmap_sizes, + dtype=bbox_preds[0].dtype, + device=bbox_preds[0].device) + labels, bbox_targets = self.get_targets(all_level_points, + batch_gt_instances) + + num_imgs = cls_scores[0].size(0) + # flatten cls_scores, bbox_preds and centerness + flatten_cls_scores = [ + cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels) + for cls_score in cls_scores + ] + flatten_bbox_preds = [ + bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) + for bbox_pred in bbox_preds + ] + flatten_centerness = [ + centerness.permute(0, 2, 3, 1).reshape(-1) + for centerness in centernesses + ] + flatten_cls_scores = torch.cat(flatten_cls_scores) + flatten_bbox_preds = torch.cat(flatten_bbox_preds) + flatten_centerness = torch.cat(flatten_centerness) + flatten_labels = torch.cat(labels) + flatten_bbox_targets = torch.cat(bbox_targets) + # repeat points to align with bbox_preds + flatten_points = torch.cat( + [points.repeat(num_imgs, 1) for points in all_level_points]) + + losses = dict() + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = ((flatten_labels >= 0) + & (flatten_labels < bg_class_ind)).nonzero().reshape(-1) + num_pos = torch.tensor( + len(pos_inds), dtype=torch.float, device=bbox_preds[0].device) + num_pos = max(reduce_mean(num_pos), 1.0) + loss_cls = self.loss_cls( + flatten_cls_scores, flatten_labels, avg_factor=num_pos) + + if getattr(self.loss_cls, 'custom_accuracy', False): + acc = self.loss_cls.get_accuracy(flatten_cls_scores, + flatten_labels) + losses.update(acc) + + pos_bbox_preds = flatten_bbox_preds[pos_inds] + pos_centerness = flatten_centerness[pos_inds] + pos_bbox_targets = flatten_bbox_targets[pos_inds] + pos_centerness_targets = self.centerness_target(pos_bbox_targets) + # centerness weighted iou loss + centerness_denorm = max( + reduce_mean(pos_centerness_targets.sum().detach()), 1e-6) + + if len(pos_inds) > 0: + pos_points = flatten_points[pos_inds] + pos_decoded_bbox_preds = self.bbox_coder.decode( + pos_points, pos_bbox_preds) + pos_decoded_target_preds = self.bbox_coder.decode( + pos_points, pos_bbox_targets) + loss_bbox = self.loss_bbox( + pos_decoded_bbox_preds, + pos_decoded_target_preds, + weight=pos_centerness_targets, + avg_factor=centerness_denorm) + loss_centerness = self.loss_centerness( + pos_centerness, pos_centerness_targets, avg_factor=num_pos) + else: + loss_bbox = pos_bbox_preds.sum() + loss_centerness = pos_centerness.sum() + + losses['loss_cls'] = loss_cls + losses['loss_bbox'] = loss_bbox + losses['loss_centerness'] = loss_centerness + + return losses + + def get_targets( + self, points: List[Tensor], batch_gt_instances: InstanceList + ) -> Tuple[List[Tensor], List[Tensor]]: + """Compute regression, classification and centerness targets for points + in multiple images. + + Args: + points (list[Tensor]): Points of each fpn level, each has shape + (num_points, 2). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + + Returns: + tuple: Targets of each level. + + - concat_lvl_labels (list[Tensor]): Labels of each level. + - concat_lvl_bbox_targets (list[Tensor]): BBox targets of each \ + level. 
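# --- Illustrative sketch (editor's addition, not part of this diff) ---
# The flattening pattern `loss_by_feat` above applies to every level:
# (N, C, H, W) maps are permuted to channel-last and collapsed so that
# all points of all images and levels share one axis. Sizes are made up.
import torch

num_imgs, num_classes = 2, 80
cls_scores = [torch.randn(num_imgs, num_classes, h, w)
              for h, w in [(100, 152), (50, 76)]]
flatten_cls_scores = torch.cat([
    s.permute(0, 2, 3, 1).reshape(-1, num_classes) for s in cls_scores
])
print(flatten_cls_scores.shape)  # torch.Size([38000, 80])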
+ """ + assert len(points) == len(self.regress_ranges) + num_levels = len(points) + # expand regress ranges to align with points + expanded_regress_ranges = [ + points[i].new_tensor(self.regress_ranges[i])[None].expand_as( + points[i]) for i in range(num_levels) + ] + # concat all levels points and regress ranges + concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0) + concat_points = torch.cat(points, dim=0) + + # the number of points per img, per lvl + num_points = [center.size(0) for center in points] + + # get labels and bbox_targets of each image + labels_list, bbox_targets_list = multi_apply( + self._get_targets_single, + batch_gt_instances, + points=concat_points, + regress_ranges=concat_regress_ranges, + num_points_per_lvl=num_points) + + # split to per img, per level + labels_list = [labels.split(num_points, 0) for labels in labels_list] + bbox_targets_list = [ + bbox_targets.split(num_points, 0) + for bbox_targets in bbox_targets_list + ] + + # concat per level image + concat_lvl_labels = [] + concat_lvl_bbox_targets = [] + for i in range(num_levels): + concat_lvl_labels.append( + torch.cat([labels[i] for labels in labels_list])) + bbox_targets = torch.cat( + [bbox_targets[i] for bbox_targets in bbox_targets_list]) + if self.norm_on_bbox: + bbox_targets = bbox_targets / self.strides[i] + concat_lvl_bbox_targets.append(bbox_targets) + return concat_lvl_labels, concat_lvl_bbox_targets + + def _get_targets_single( + self, gt_instances: InstanceData, points: Tensor, + regress_ranges: Tensor, + num_points_per_lvl: List[int]) -> Tuple[Tensor, Tensor]: + """Compute regression and classification targets for a single image.""" + num_points = points.size(0) + num_gts = len(gt_instances) + gt_bboxes = gt_instances.bboxes + gt_labels = gt_instances.labels + + if num_gts == 0: + return gt_labels.new_full((num_points,), self.num_classes), \ + gt_bboxes.new_zeros((num_points, 4)) + + areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * ( + gt_bboxes[:, 3] - gt_bboxes[:, 1]) + # TODO: figure out why these two are different + # areas = areas[None].expand(num_points, num_gts) + areas = areas[None].repeat(num_points, 1) + regress_ranges = regress_ranges[:, None, :].expand( + num_points, num_gts, 2) + gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4) + xs, ys = points[:, 0], points[:, 1] + xs = xs[:, None].expand(num_points, num_gts) + ys = ys[:, None].expand(num_points, num_gts) + + left = xs - gt_bboxes[..., 0] + right = gt_bboxes[..., 2] - xs + top = ys - gt_bboxes[..., 1] + bottom = gt_bboxes[..., 3] - ys + bbox_targets = torch.stack((left, top, right, bottom), -1) + + if self.center_sampling: + # condition1: inside a `center bbox` + radius = self.center_sample_radius + center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2 + center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2 + center_gts = torch.zeros_like(gt_bboxes) + stride = center_xs.new_zeros(center_xs.shape) + + # project the points on current lvl back to the `original` sizes + lvl_begin = 0 + for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl): + lvl_end = lvl_begin + num_points_lvl + stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius + lvl_begin = lvl_end + + x_mins = center_xs - stride + y_mins = center_ys - stride + x_maxs = center_xs + stride + y_maxs = center_ys + stride + center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0], + x_mins, gt_bboxes[..., 0]) + center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1], + y_mins, gt_bboxes[..., 1]) + center_gts[..., 2] = torch.where(x_maxs > 
gt_bboxes[..., 2], + gt_bboxes[..., 2], x_maxs) + center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3], + gt_bboxes[..., 3], y_maxs) + + cb_dist_left = xs - center_gts[..., 0] + cb_dist_right = center_gts[..., 2] - xs + cb_dist_top = ys - center_gts[..., 1] + cb_dist_bottom = center_gts[..., 3] - ys + center_bbox = torch.stack( + (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1) + inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0 + else: + # condition1: inside a gt bbox + inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0 + + # condition2: limit the regression range for each location + max_regress_distance = bbox_targets.max(-1)[0] + inside_regress_range = ( + (max_regress_distance >= regress_ranges[..., 0]) + & (max_regress_distance <= regress_ranges[..., 1])) + + # if there are still more than one objects for a location, + # we choose the one with minimal area + areas[inside_gt_bbox_mask == 0] = INF + areas[inside_regress_range == 0] = INF + min_area, min_area_inds = areas.min(dim=1) + + labels = gt_labels[min_area_inds] + labels[min_area == INF] = self.num_classes # set as BG + bbox_targets = bbox_targets[range(num_points), min_area_inds] + + return labels, bbox_targets + + def centerness_target(self, pos_bbox_targets: Tensor) -> Tensor: + """Compute centerness targets. + + Args: + pos_bbox_targets (Tensor): BBox targets of positive bboxes in shape + (num_pos, 4) + + Returns: + Tensor: Centerness target. + """ + # only calculate pos centerness targets, otherwise there may be nan + left_right = pos_bbox_targets[:, [0, 2]] + top_bottom = pos_bbox_targets[:, [1, 3]] + if len(left_right) == 0: + centerness_targets = left_right[..., 0] + else: + centerness_targets = ( + left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * ( + top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]) + return torch.sqrt(centerness_targets) diff --git a/mmdet/models/dense_heads/fovea_head.py b/mmdet/models/dense_heads/fovea_head.py new file mode 100644 index 0000000000000000000000000000000000000000..89353deac7f0189c1e464288521ee8e4238f0107 --- /dev/null +++ b/mmdet/models/dense_heads/fovea_head.py @@ -0,0 +1,509 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmcv.ops import DeformConv2d +from mmengine.config import ConfigDict +from mmengine.model import BaseModule +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import InstanceList, OptInstanceList, OptMultiConfig +from ..utils import filter_scores_and_topk, multi_apply +from .anchor_free_head import AnchorFreeHead + +INF = 1e8 + + +class FeatureAlign(BaseModule): + """Feature Align Module. + + Feature Align Module is implemented based on DCN v1. + It uses anchor shape prediction rather than feature map to + predict offsets of deform conv layer. + + Args: + in_channels (int): Number of channels in the input feature map. + out_channels (int): Number of channels in the output feature map. + kernel_size (int): Size of the convolution kernel. + ``norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)``. + deform_groups: (int): Group number of DCN in + FeatureAdaption module. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], optional): Initialization config dict. 
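# --- Illustrative sketch (editor's addition, not part of this diff) ---
# `centerness_target` above in miniature: the target is
# sqrt((min(l,r)/max(l,r)) * (min(t,b)/max(t,b))), i.e. 1.0 for a point
# at a box centre and decaying towards the border. Example targets are
# made up.
import torch

def centerness(ltrb):
    lr = ltrb[:, [0, 2]]
    tb = ltrb[:, [1, 3]]
    return torch.sqrt((lr.min(-1)[0] / lr.max(-1)[0]) *
                      (tb.min(-1)[0] / tb.max(-1)[0]))

targets = torch.tensor([[10., 10., 10., 10.],   # point at the box centre
                        [2., 10., 18., 10.]])   # off-centre horizontally
print(centerness(targets))  # tensor([1.0000, 0.3333])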
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + deform_groups: int = 4, + init_cfg: OptMultiConfig = dict( + type='Normal', + layer='Conv2d', + std=0.1, + override=dict(type='Normal', name='conv_adaption', std=0.01)) + ) -> None: + super().__init__(init_cfg=init_cfg) + offset_channels = kernel_size * kernel_size * 2 + self.conv_offset = nn.Conv2d( + 4, deform_groups * offset_channels, 1, bias=False) + self.conv_adaption = DeformConv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + deform_groups=deform_groups) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x: Tensor, shape: Tensor) -> Tensor: + """Forward function of feature align module. + + Args: + x (Tensor): Features from the upstream network. + shape (Tensor): Exponential of bbox predictions. + + Returns: + x (Tensor): The aligned features. + """ + offset = self.conv_offset(shape) + x = self.relu(self.conv_adaption(x, offset)) + return x + + +@MODELS.register_module() +class FoveaHead(AnchorFreeHead): + """Detection Head of `FoveaBox: Beyond Anchor-based Object Detector. + + `_. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + base_edge_list (list[int]): List of edges. + scale_ranges (list[tuple]): Range of scales. + sigma (float): Super parameter of ``FoveaHead``. + with_deform (bool): Whether use deform conv. + deform_groups (int): Deformable conv group size. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], optional): Initialization config dict. + """ + + def __init__(self, + num_classes: int, + in_channels: int, + base_edge_list: List[int] = (16, 32, 64, 128, 256), + scale_ranges: List[tuple] = ((8, 32), (16, 64), (32, 128), + (64, 256), (128, 512)), + sigma: float = 0.4, + with_deform: bool = False, + deform_groups: int = 4, + init_cfg: OptMultiConfig = dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', + name='conv_cls', + std=0.01, + bias_prob=0.01)), + **kwargs) -> None: + self.base_edge_list = base_edge_list + self.scale_ranges = scale_ranges + self.sigma = sigma + self.with_deform = with_deform + self.deform_groups = deform_groups + super().__init__( + num_classes=num_classes, + in_channels=in_channels, + init_cfg=init_cfg, + **kwargs) + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + # box branch + super()._init_reg_convs() + self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1) + + # cls branch + if not self.with_deform: + super()._init_cls_convs() + self.conv_cls = nn.Conv2d( + self.feat_channels, self.cls_out_channels, 3, padding=1) + else: + self.cls_convs = nn.ModuleList() + self.cls_convs.append( + ConvModule( + self.feat_channels, (self.feat_channels * 4), + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + bias=self.norm_cfg is None)) + self.cls_convs.append( + ConvModule((self.feat_channels * 4), (self.feat_channels * 4), + 1, + stride=1, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + bias=self.norm_cfg is None)) + self.feature_adaption = FeatureAlign( + self.feat_channels, + self.feat_channels, + kernel_size=3, + deform_groups=self.deform_groups) + self.conv_cls = nn.Conv2d( + int(self.feat_channels * 4), + self.cls_out_channels, + 3, + padding=1) + + def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Forward features of a single scale 
level. + + Args: + x (Tensor): FPN feature maps of the specified stride. + + Returns: + tuple: scores for each class and bbox predictions of input + feature maps. + """ + cls_feat = x + reg_feat = x + for reg_layer in self.reg_convs: + reg_feat = reg_layer(reg_feat) + bbox_pred = self.conv_reg(reg_feat) + if self.with_deform: + cls_feat = self.feature_adaption(cls_feat, bbox_pred.exp()) + for cls_layer in self.cls_convs: + cls_feat = cls_layer(cls_feat) + cls_score = self.conv_cls(cls_feat) + return cls_score, bbox_pred + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None + ) -> Dict[str, Tensor]: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level, + each is a 4D-tensor, the channel number is + num_priors * num_classes. + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level, each is a 4D-tensor, the channel number is + num_priors * 4. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + assert len(cls_scores) == len(bbox_preds) + + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + priors = self.prior_generator.grid_priors( + featmap_sizes, + dtype=bbox_preds[0].dtype, + device=bbox_preds[0].device) + num_imgs = cls_scores[0].size(0) + flatten_cls_scores = [ + cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels) + for cls_score in cls_scores + ] + flatten_bbox_preds = [ + bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) + for bbox_pred in bbox_preds + ] + flatten_cls_scores = torch.cat(flatten_cls_scores) + flatten_bbox_preds = torch.cat(flatten_bbox_preds) + flatten_labels, flatten_bbox_targets = self.get_targets( + batch_gt_instances, featmap_sizes, priors) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + pos_inds = ((flatten_labels >= 0) + & (flatten_labels < self.num_classes)).nonzero().view(-1) + num_pos = len(pos_inds) + + loss_cls = self.loss_cls( + flatten_cls_scores, flatten_labels, avg_factor=num_pos + num_imgs) + if num_pos > 0: + pos_bbox_preds = flatten_bbox_preds[pos_inds] + pos_bbox_targets = flatten_bbox_targets[pos_inds] + pos_weights = pos_bbox_targets.new_ones(pos_bbox_targets.size()) + loss_bbox = self.loss_bbox( + pos_bbox_preds, + pos_bbox_targets, + pos_weights, + avg_factor=num_pos) + else: + loss_bbox = torch.tensor( + 0, + dtype=flatten_bbox_preds.dtype, + device=flatten_bbox_preds.device) + return dict(loss_cls=loss_cls, loss_bbox=loss_bbox) + + def get_targets( + self, batch_gt_instances: InstanceList, featmap_sizes: List[tuple], + priors_list: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]: + """Compute regression and classification for priors in multiple images. + + Args: + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + featmap_sizes (list[tuple]): Size tuple of feature maps. 
+ priors_list (list[Tensor]): Priors list of each fpn level, each has + shape (num_priors, 2). + + Returns: + tuple: Targets of each level. + + - flatten_labels (list[Tensor]): Labels of each level. + - flatten_bbox_targets (list[Tensor]): BBox targets of each + level. + """ + label_list, bbox_target_list = multi_apply( + self._get_targets_single, + batch_gt_instances, + featmap_size_list=featmap_sizes, + priors_list=priors_list) + flatten_labels = [ + torch.cat([ + labels_level_img.flatten() for labels_level_img in labels_level + ]) for labels_level in zip(*label_list) + ] + flatten_bbox_targets = [ + torch.cat([ + bbox_targets_level_img.reshape(-1, 4) + for bbox_targets_level_img in bbox_targets_level + ]) for bbox_targets_level in zip(*bbox_target_list) + ] + flatten_labels = torch.cat(flatten_labels) + flatten_bbox_targets = torch.cat(flatten_bbox_targets) + return flatten_labels, flatten_bbox_targets + + def _get_targets_single(self, + gt_instances: InstanceData, + featmap_size_list: List[tuple] = None, + priors_list: List[Tensor] = None) -> tuple: + """Compute regression and classification targets for a single image. + + Args: + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes`` and ``labels`` + attributes. + featmap_size_list (list[tuple]): Size tuple of feature maps. + priors_list (list[Tensor]): Priors of each fpn level, each has + shape (num_priors, 2). + + Returns: + tuple: + + - label_list (list[Tensor]): Labels of all anchors in the image. + - box_target_list (list[Tensor]): BBox targets of all anchors in + the image. + """ + gt_bboxes_raw = gt_instances.bboxes + gt_labels_raw = gt_instances.labels + gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * + (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1])) + label_list = [] + bbox_target_list = [] + # for each pyramid, find the cls and box target + for base_len, (lower_bound, upper_bound), stride, featmap_size, \ + priors in zip(self.base_edge_list, self.scale_ranges, + self.strides, featmap_size_list, priors_list): + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + priors = priors.view(*featmap_size, 2) + x, y = priors[..., 0], priors[..., 1] + labels = gt_labels_raw.new_full(featmap_size, self.num_classes) + bbox_targets = gt_bboxes_raw.new_ones(featmap_size[0], + featmap_size[1], 4) + # scale assignment + hit_indices = ((gt_areas >= lower_bound) & + (gt_areas <= upper_bound)).nonzero().flatten() + if len(hit_indices) == 0: + label_list.append(labels) + bbox_target_list.append(torch.log(bbox_targets)) + continue + _, hit_index_order = torch.sort(-gt_areas[hit_indices]) + hit_indices = hit_indices[hit_index_order] + gt_bboxes = gt_bboxes_raw[hit_indices, :] / stride + gt_labels = gt_labels_raw[hit_indices] + half_w = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) + half_h = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) + # valid fovea area: left, right, top, down + pos_left = torch.ceil( + gt_bboxes[:, 0] + (1 - self.sigma) * half_w - 0.5).long(). \ + clamp(0, featmap_size[1] - 1) + pos_right = torch.floor( + gt_bboxes[:, 0] + (1 + self.sigma) * half_w - 0.5).long(). \ + clamp(0, featmap_size[1] - 1) + pos_top = torch.ceil( + gt_bboxes[:, 1] + (1 - self.sigma) * half_h - 0.5).long(). \ + clamp(0, featmap_size[0] - 1) + pos_down = torch.floor( + gt_bboxes[:, 1] + (1 + self.sigma) * half_h - 0.5).long(). 
\ + clamp(0, featmap_size[0] - 1) + for px1, py1, px2, py2, label, (gt_x1, gt_y1, gt_x2, gt_y2) in \ + zip(pos_left, pos_top, pos_right, pos_down, gt_labels, + gt_bboxes_raw[hit_indices, :]): + labels[py1:py2 + 1, px1:px2 + 1] = label + bbox_targets[py1:py2 + 1, px1:px2 + 1, 0] = \ + (x[py1:py2 + 1, px1:px2 + 1] - gt_x1) / base_len + bbox_targets[py1:py2 + 1, px1:px2 + 1, 1] = \ + (y[py1:py2 + 1, px1:px2 + 1] - gt_y1) / base_len + bbox_targets[py1:py2 + 1, px1:px2 + 1, 2] = \ + (gt_x2 - x[py1:py2 + 1, px1:px2 + 1]) / base_len + bbox_targets[py1:py2 + 1, px1:px2 + 1, 3] = \ + (gt_y2 - y[py1:py2 + 1, px1:px2 + 1]) / base_len + bbox_targets = bbox_targets.clamp(min=1. / 16, max=16.) + label_list.append(labels) + bbox_target_list.append(torch.log(bbox_targets)) + return label_list, bbox_target_list + + # Same as base_dense_head/_predict_by_feat_single except self._bbox_decode + def _predict_by_feat_single(self, + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + score_factor_list: List[Tensor], + mlvl_priors: List[Tensor], + img_meta: dict, + cfg: Optional[ConfigDict] = None, + rescale: bool = False, + with_nms: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image, each item has shape + (num_priors * 1, H, W). + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid, has shape + (num_priors, 2). + img_meta (dict): Image meta info. + cfg (ConfigDict, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + cfg = self.test_cfg if cfg is None else cfg + assert len(cls_score_list) == len(bbox_pred_list) + img_shape = img_meta['img_shape'] + nms_pre = cfg.get('nms_pre', -1) + + mlvl_bboxes = [] + mlvl_scores = [] + mlvl_labels = [] + for level_idx, (cls_score, bbox_pred, stride, base_len, priors) in \ + enumerate(zip(cls_score_list, bbox_pred_list, self.strides, + self.base_edge_list, mlvl_priors)): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) + + scores = cls_score.permute(1, 2, 0).reshape( + -1, self.cls_out_channels).sigmoid() + + # After https://github.com/open-mmlab/mmdetection/pull/6268/, + # this operation keeps fewer bboxes under the same `nms_pre`. + # There is no difference in performance for most models. If you + # find a slight drop in performance, you can set a larger + # `nms_pre` than before. 
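# --- Illustrative sketch (editor's addition, not part of this diff) ---
# The sigma-shrunk positive ("fovea") region from `_get_targets_single`
# above, for one GT box on one level: with sigma=0.4 only the central
# columns of the projected box become positives. Numbers are made up.
import torch

sigma, stride, featmap_w = 0.4, 8, 100
gt = torch.tensor([[32., 32., 96., 96.]]) / stride   # project to level
half_w = 0.5 * (gt[:, 2] - gt[:, 0])
pos_left = torch.ceil(
    gt[:, 0] + (1 - sigma) * half_w - 0.5).long().clamp(0, featmap_w - 1)
pos_right = torch.floor(
    gt[:, 0] + (1 + sigma) * half_w - 0.5).long().clamp(0, featmap_w - 1)
print(pos_left.item(), pos_right.item())  # 6 9 (box spans columns 4..12)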
+ results = filter_scores_and_topk( + scores, cfg.score_thr, nms_pre, + dict(bbox_pred=bbox_pred, priors=priors)) + scores, labels, _, filtered_results = results + + bbox_pred = filtered_results['bbox_pred'] + priors = filtered_results['priors'] + + bboxes = self._bbox_decode(priors, bbox_pred, base_len, img_shape) + + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + mlvl_labels.append(labels) + + results = InstanceData() + results.bboxes = torch.cat(mlvl_bboxes) + results.scores = torch.cat(mlvl_scores) + results.labels = torch.cat(mlvl_labels) + + return self._bbox_post_process( + results=results, + cfg=cfg, + rescale=rescale, + with_nms=with_nms, + img_meta=img_meta) + + def _bbox_decode(self, priors: Tensor, bbox_pred: Tensor, base_len: int, + max_shape: int) -> Tensor: + """Function to decode bbox. + + Args: + priors (Tensor): Center proiors of an image, has shape + (num_instances, 2). + bbox_preds (Tensor): Box energies / deltas for all instances, + has shape (batch_size, num_instances, 4). + base_len (int): The base length. + max_shape (int): The max shape of bbox. + + Returns: + Tensor: Decoded bboxes in (tl_x, tl_y, br_x, br_y) format. Has + shape (batch_size, num_instances, 4). + """ + bbox_pred = bbox_pred.exp() + + y = priors[:, 1] + x = priors[:, 0] + x1 = (x - base_len * bbox_pred[:, 0]). \ + clamp(min=0, max=max_shape[1] - 1) + y1 = (y - base_len * bbox_pred[:, 1]). \ + clamp(min=0, max=max_shape[0] - 1) + x2 = (x + base_len * bbox_pred[:, 2]). \ + clamp(min=0, max=max_shape[1] - 1) + y2 = (y + base_len * bbox_pred[:, 3]). \ + clamp(min=0, max=max_shape[0] - 1) + decoded_bboxes = torch.stack([x1, y1, x2, y2], -1) + return decoded_bboxes diff --git a/mmdet/models/dense_heads/free_anchor_retina_head.py b/mmdet/models/dense_heads/free_anchor_retina_head.py new file mode 100644 index 0000000000000000000000000000000000000000..df6fb9202c32735121bf7738e332fbfc5ac7e6bd --- /dev/null +++ b/mmdet/models/dense_heads/free_anchor_retina_head.py @@ -0,0 +1,312 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn.functional as F +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures.bbox import bbox_overlaps +from mmdet.utils import InstanceList, OptConfigType, OptInstanceList +from ..utils import multi_apply +from .retina_head import RetinaHead + +EPS = 1e-12 + + +@MODELS.register_module() +class FreeAnchorRetinaHead(RetinaHead): + """FreeAnchor RetinaHead used in https://arxiv.org/abs/1909.02466. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + stacked_convs (int): Number of conv layers in cls and reg tower. + Defaults to 4. + conv_cfg (:obj:`ConfigDict` or dict, optional): dictionary to + construct and config conv layer. Defaults to None. + norm_cfg (:obj:`ConfigDict` or dict, optional): dictionary to + construct and config norm layer. Defaults to + norm_cfg=dict(type='GN', num_groups=32, requires_grad=True). + pre_anchor_topk (int): Number of boxes that be token in each bag. + Defaults to 50 + bbox_thr (float): The threshold of the saturated linear function. + It is usually the same with the IoU threshold used in NMS. + Defaults to 0.6. + gamma (float): Gamma parameter in focal loss. Defaults to 2.0. + alpha (float): Alpha parameter in focal loss. Defaults to 0.5. 
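# --- Illustrative sketch (editor's addition, not part of this diff) ---
# `FoveaHead._bbox_decode` above, restated as a standalone function and
# run on one prior: zero logits decode to a box extending one base_len on
# each side of the prior. Inputs are made up.
import torch

def bbox_decode(priors, bbox_pred, base_len, max_shape):
    d = bbox_pred.exp()
    x, y = priors[:, 0], priors[:, 1]
    x1 = (x - base_len * d[:, 0]).clamp(min=0, max=max_shape[1] - 1)
    y1 = (y - base_len * d[:, 1]).clamp(min=0, max=max_shape[0] - 1)
    x2 = (x + base_len * d[:, 2]).clamp(min=0, max=max_shape[1] - 1)
    y2 = (y + base_len * d[:, 3]).clamp(min=0, max=max_shape[0] - 1)
    return torch.stack([x1, y1, x2, y2], -1)

priors = torch.tensor([[64., 64.]])
pred = torch.zeros(1, 4)  # exp(0) = 1 -> one base_len per side
print(bbox_decode(priors, pred, base_len=16, max_shape=(256, 256)))
# tensor([[48., 48., 80., 80.]])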
+ """ + + def __init__(self, + num_classes: int, + in_channels: int, + stacked_convs: int = 4, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + pre_anchor_topk: int = 50, + bbox_thr: float = 0.6, + gamma: float = 2.0, + alpha: float = 0.5, + **kwargs) -> None: + super().__init__( + num_classes=num_classes, + in_channels=in_channels, + stacked_convs=stacked_convs, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs) + + self.pre_anchor_topk = pre_anchor_topk + self.bbox_thr = bbox_thr + self.gamma = gamma + self.alpha = alpha + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + has shape (N, num_anchors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict: A dictionary of loss components. + """ + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + + anchor_list, _ = self.get_anchors( + featmap_sizes=featmap_sizes, + batch_img_metas=batch_img_metas, + device=device) + concat_anchor_list = [torch.cat(anchor) for anchor in anchor_list] + + # concatenate each level + cls_scores = [ + cls.permute(0, 2, 3, + 1).reshape(cls.size(0), -1, self.cls_out_channels) + for cls in cls_scores + ] + bbox_preds = [ + bbox_pred.permute(0, 2, 3, 1).reshape(bbox_pred.size(0), -1, 4) + for bbox_pred in bbox_preds + ] + cls_scores = torch.cat(cls_scores, dim=1) + cls_probs = torch.sigmoid(cls_scores) + bbox_preds = torch.cat(bbox_preds, dim=1) + + box_probs, positive_losses, num_pos_list = multi_apply( + self.positive_loss_single, cls_probs, bbox_preds, + concat_anchor_list, batch_gt_instances) + + num_pos = sum(num_pos_list) + positive_loss = torch.cat(positive_losses).sum() / max(1, num_pos) + + # box_prob: P{a_{j} \in A_{+}} + box_probs = torch.stack(box_probs, dim=0) + + # negative_loss: + # \sum_{j}{ FL((1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg})) } / n||B|| + negative_loss = self.negative_bag_loss(cls_probs, box_probs).sum() / \ + max(1, num_pos * self.pre_anchor_topk) + + # avoid the absence of gradients in regression subnet + # when no ground-truth in a batch + if num_pos == 0: + positive_loss = bbox_preds.sum() * 0 + + losses = { + 'positive_bag_loss': positive_loss, + 'negative_bag_loss': negative_loss + } + return losses + + def positive_loss_single(self, cls_prob: Tensor, bbox_pred: Tensor, + flat_anchors: Tensor, + gt_instances: InstanceData) -> tuple: + """Compute positive loss. + + Args: + cls_prob (Tensor): Classification probability of shape + (num_anchors, num_classes). + bbox_pred (Tensor): Box probability of shape (num_anchors, 4). 
+ flat_anchors (Tensor): Multi-level anchors of the image, which are + concatenated into a single tensor of shape (num_anchors, 4) + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes`` and ``labels`` + attributes. + + Returns: + tuple: + + - box_prob (Tensor): Box probability of shape (num_anchors, 4). + - positive_loss (Tensor): Positive loss of shape (num_pos, ). + - num_pos (int): positive samples indexes. + """ + + gt_bboxes = gt_instances.bboxes + gt_labels = gt_instances.labels + with torch.no_grad(): + if len(gt_bboxes) == 0: + image_box_prob = torch.zeros( + flat_anchors.size(0), + self.cls_out_channels).type_as(bbox_pred) + else: + # box_localization: a_{j}^{loc}, shape: [j, 4] + pred_boxes = self.bbox_coder.decode(flat_anchors, bbox_pred) + + # object_box_iou: IoU_{ij}^{loc}, shape: [i, j] + object_box_iou = bbox_overlaps(gt_bboxes, pred_boxes) + + # object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j] + t1 = self.bbox_thr + t2 = object_box_iou.max( + dim=1, keepdim=True).values.clamp(min=t1 + 1e-12) + object_box_prob = ((object_box_iou - t1) / (t2 - t1)).clamp( + min=0, max=1) + + # object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j] + num_obj = gt_labels.size(0) + indices = torch.stack( + [torch.arange(num_obj).type_as(gt_labels), gt_labels], + dim=0) + object_cls_box_prob = torch.sparse_coo_tensor( + indices, object_box_prob) + + # image_box_iou: P{a_{j} \in A_{+}}, shape: [c, j] + """ + from "start" to "end" implement: + image_box_iou = torch.sparse.max(object_cls_box_prob, + dim=0).t() + + """ + # start + box_cls_prob = torch.sparse.sum( + object_cls_box_prob, dim=0).to_dense() + + indices = torch.nonzero(box_cls_prob, as_tuple=False).t_() + if indices.numel() == 0: + image_box_prob = torch.zeros( + flat_anchors.size(0), + self.cls_out_channels).type_as(object_box_prob) + else: + nonzero_box_prob = torch.where( + (gt_labels.unsqueeze(dim=-1) == indices[0]), + object_box_prob[:, indices[1]], + torch.tensor( + [0]).type_as(object_box_prob)).max(dim=0).values + + # upmap to shape [j, c] + image_box_prob = torch.sparse_coo_tensor( + indices.flip([0]), + nonzero_box_prob, + size=(flat_anchors.size(0), + self.cls_out_channels)).to_dense() + # end + box_prob = image_box_prob + + # construct bags for objects + match_quality_matrix = bbox_overlaps(gt_bboxes, flat_anchors) + _, matched = torch.topk( + match_quality_matrix, self.pre_anchor_topk, dim=1, sorted=False) + del match_quality_matrix + + # matched_cls_prob: P_{ij}^{cls} + matched_cls_prob = torch.gather( + cls_prob[matched], 2, + gt_labels.view(-1, 1, 1).repeat(1, self.pre_anchor_topk, + 1)).squeeze(2) + + # matched_box_prob: P_{ij}^{loc} + matched_anchors = flat_anchors[matched] + matched_object_targets = self.bbox_coder.encode( + matched_anchors, + gt_bboxes.unsqueeze(dim=1).expand_as(matched_anchors)) + loss_bbox = self.loss_bbox( + bbox_pred[matched], + matched_object_targets, + reduction_override='none').sum(-1) + matched_box_prob = torch.exp(-loss_bbox) + + # positive_losses: {-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )} + num_pos = len(gt_bboxes) + positive_loss = self.positive_bag_loss(matched_cls_prob, + matched_box_prob) + + return box_prob, positive_loss, num_pos + + def positive_bag_loss(self, matched_cls_prob: Tensor, + matched_box_prob: Tensor) -> Tensor: + """Compute positive bag loss. + + :math:`-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )`. + + :math:`P_{ij}^{cls}`: matched_cls_prob, classification probability of matched samples. 
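# --- Illustrative sketch (editor's addition, not part of this diff) ---
# The saturated-linear matching probability from `positive_loss_single`
# above: IoUs at or below bbox_thr get probability 0, the best IoU per GT
# gets 1, and everything between scales linearly. IoU values are made up.
import torch

bbox_thr = 0.6
object_box_iou = torch.tensor([[0.50, 0.75, 0.90]])  # one GT, 3 anchors
t1 = bbox_thr
t2 = object_box_iou.max(dim=1, keepdim=True).values.clamp(min=t1 + 1e-12)
object_box_prob = ((object_box_iou - t1) / (t2 - t1)).clamp(min=0, max=1)
print(object_box_prob)  # tensor([[0.0000, 0.5000, 1.0000]])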
+
+        :math:`P_{ij}^{loc}`: matched_box_prob, box probability of matched samples.
+
+        Args:
+            matched_cls_prob (Tensor): Classification probability of matched
+                samples in shape (num_gt, pre_anchor_topk).
+            matched_box_prob (Tensor): BBox probability of matched samples,
+                in shape (num_gt, pre_anchor_topk).
+
+        Returns:
+            Tensor: Positive bag loss in shape (num_gt,).
+        """  # noqa: E501, W605
+        # bag_prob = Mean-max(matched_prob)
+        matched_prob = matched_cls_prob * matched_box_prob
+        weight = 1 / torch.clamp(1 - matched_prob, 1e-12, None)
+        weight /= weight.sum(dim=1).unsqueeze(dim=-1)
+        bag_prob = (weight * matched_prob).sum(dim=1)
+        # positive_bag_loss = -self.alpha * log(bag_prob)
+        return self.alpha * F.binary_cross_entropy(
+            bag_prob, torch.ones_like(bag_prob), reduction='none')
+
+    def negative_bag_loss(self, cls_prob: Tensor, box_prob: Tensor) -> Tensor:
+        """Compute negative bag loss.
+
+        :math:`FL((1 - P_{a_{j} \in A_{+}}) * (1 - P_{j}^{bg}))`.
+
+        :math:`P_{a_{j} \in A_{+}}`: Box probability of matched samples.
+
+        :math:`P_{j}^{bg}`: Classification probability of negative samples.
+
+        Args:
+            cls_prob (Tensor): Classification probability, in shape
+                (num_img, num_anchors, num_classes).
+            box_prob (Tensor): Box probability, in shape
+                (num_img, num_anchors, num_classes).
+
+        Returns:
+            Tensor: Negative bag loss in shape (num_img, num_anchors,
+            num_classes).
+        """  # noqa: E501, W605
+        prob = cls_prob * (1 - box_prob)
+        # There are some cases when neg_prob = 0.
+        # This will cause the neg_prob.log() to be inf without clamp.
+        prob = prob.clamp(min=EPS, max=1 - EPS)
+        negative_bag_loss = prob**self.gamma * F.binary_cross_entropy(
+            prob, torch.zeros_like(prob), reduction='none')
+        return (1 - self.alpha) * negative_bag_loss
diff --git a/mmdet/models/dense_heads/fsaf_head.py b/mmdet/models/dense_heads/fsaf_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a01c487406693253eb17b883cac9ed06cf95802
--- /dev/null
+++ b/mmdet/models/dense_heads/fsaf_head.py
@@ -0,0 +1,458 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.utils import InstanceList, OptInstanceList, OptMultiConfig
+from ..losses.accuracy import accuracy
+from ..losses.utils import weight_reduce_loss
+from ..task_modules.prior_generators import anchor_inside_flags
+from ..utils import images_to_levels, multi_apply, unmap
+from .retina_head import RetinaHead
+
+
+@MODELS.register_module()
+class FSAFHead(RetinaHead):
+    """Anchor-free head used in `FSAF <https://arxiv.org/abs/1903.00621>`_.
+
+    The head contains two subnetworks. The first classifies anchor boxes and
+    the second regresses deltas for the anchors (num_anchors is 1 for
+    anchor-free methods).
+
+    Args:
+        *args: Same as its base class in :class:`RetinaHead`
+        score_threshold (float, optional): The score_threshold to calculate
+            positive recall. If given, prediction scores lower than this value
+            are counted as incorrect predictions. Defaults to None.
+        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
+            dict]): Initialization config dict.
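# --- Illustrative sketch (editor's addition, not part of this diff) ---
# The Mean-max reduction inside `positive_bag_loss` above: a soft maximum
# over each anchor bag that approaches max() when one probability
# dominates and mean() when all probabilities are low. Bag values are
# made up.
import torch

def mean_max(matched_prob):
    weight = 1 / torch.clamp(1 - matched_prob, 1e-12, None)
    weight = weight / weight.sum(dim=1, keepdim=True)
    return (weight * matched_prob).sum(dim=1)

bags = torch.tensor([[0.9, 0.1, 0.1],
                     [0.3, 0.3, 0.3]])
print(mean_max(bags))  # ~tensor([0.7545, 0.3000])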
+ **kwargs: Same as its base class in :class:`RetinaHead` + + Example: + >>> import torch + >>> self = FSAFHead(11, 7) + >>> x = torch.rand(1, 7, 32, 32) + >>> cls_score, bbox_pred = self.forward_single(x) + >>> # Each anchor predicts a score for each class except background + >>> cls_per_anchor = cls_score.shape[1] / self.num_anchors + >>> box_per_anchor = bbox_pred.shape[1] / self.num_anchors + >>> assert cls_per_anchor == self.num_classes + >>> assert box_per_anchor == 4 + """ + + def __init__(self, + *args, + score_threshold: Optional[float] = None, + init_cfg: OptMultiConfig = None, + **kwargs) -> None: + # The positive bias in self.retina_reg conv is to prevent predicted \ + # bbox with 0 area + if init_cfg is None: + init_cfg = dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=[ + dict( + type='Normal', + name='retina_cls', + std=0.01, + bias_prob=0.01), + dict( + type='Normal', name='retina_reg', std=0.01, bias=0.25) + ]) + super().__init__(*args, init_cfg=init_cfg, **kwargs) + self.score_threshold = score_threshold + + def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Forward feature map of a single scale level. + + Args: + x (Tensor): Feature map of a single scale level. + + Returns: + tuple[Tensor, Tensor]: + + - cls_score (Tensor): Box scores for each scale level Has \ + shape (N, num_points * num_classes, H, W). + - bbox_pred (Tensor): Box energies / deltas for each scale \ + level with shape (N, num_points * 4, H, W). + """ + cls_score, bbox_pred = super().forward_single(x) + # relu: TBLR encoder only accepts positive bbox_pred + return cls_score, self.relu(bbox_pred) + + def _get_targets_single(self, + flat_anchors: Tensor, + valid_flags: Tensor, + gt_instances: InstanceData, + img_meta: dict, + gt_instances_ignore: Optional[InstanceData] = None, + unmap_outputs: bool = True) -> tuple: + """Compute regression and classification targets for anchors in a + single image. + + Most of the codes are the same with the base class :obj: `AnchorHead`, + except that it also collects and returns the matched gt index in the + image (from 0 to num_gt-1). If the anchor bbox is not matched to any + gt, the corresponding value in pos_gt_inds is -1. + + Args: + flat_anchors (Tensor): Multi-level anchors of the image, which are + concatenated into a single tensor of shape (num_anchors, 4) + valid_flags (Tensor): Multi level valid flags of the image, + which are concatenated into a single tensor of + shape (num_anchors, ). + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for current image. + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. Defaults to True. + """ + inside_flags = anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], + self.train_cfg['allowed_border']) + if not inside_flags.any(): + raise ValueError( + 'There is no valid anchor inside the image boundary. 
Please ' + 'check the image size and anchor sizes, or set ' + '``allowed_border`` to -1 to skip the condition.') + # Assign gt and sample anchors + anchors = flat_anchors[inside_flags.type(torch.bool), :] + + pred_instances = InstanceData(priors=anchors) + assign_result = self.assigner.assign(pred_instances, gt_instances, + gt_instances_ignore) + sampling_result = self.sampler.sample(assign_result, pred_instances, + gt_instances) + + num_valid_anchors = anchors.shape[0] + bbox_targets = torch.zeros_like(anchors) + bbox_weights = torch.zeros_like(anchors) + labels = anchors.new_full((num_valid_anchors, ), + self.num_classes, + dtype=torch.long) + label_weights = anchors.new_zeros( + (num_valid_anchors, self.cls_out_channels), dtype=torch.float) + pos_gt_inds = anchors.new_full((num_valid_anchors, ), + -1, + dtype=torch.long) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + if len(pos_inds) > 0: + if not self.reg_decoded_bbox: + pos_bbox_targets = self.bbox_coder.encode( + sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes) + else: + # When the regression loss (e.g. `IouLoss`, `GIouLoss`) + # is applied directly on the decoded bounding boxes, both + # the predicted boxes and regression targets should be with + # absolute coordinate format. + pos_bbox_targets = sampling_result.pos_gt_bboxes + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1.0 + # The assigned gt_index for each anchor. (0-based) + pos_gt_inds[pos_inds] = sampling_result.pos_assigned_gt_inds + labels[pos_inds] = sampling_result.pos_gt_labels + if self.train_cfg['pos_weight'] <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.train_cfg['pos_weight'] + + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # shadowed_labels is a tensor composed of tuples + # (anchor_inds, class_label) that indicate those anchors lying in the + # outer region of a gt or overlapped by another gt with a smaller + # area. + # + # Therefore, only the shadowed labels are ignored for loss calculation. + # the key `shadowed_labels` is defined in :obj:`CenterRegionAssigner` + shadowed_labels = assign_result.get_extra_property('shadowed_labels') + if shadowed_labels is not None and shadowed_labels.numel(): + if len(shadowed_labels.shape) == 2: + idx_, label_ = shadowed_labels[:, 0], shadowed_labels[:, 1] + assert (labels[idx_] != label_).all(), \ + 'One label cannot be both positive and ignored' + label_weights[idx_, label_] = 0 + else: + label_weights[shadowed_labels] = 0 + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_anchors.size(0) + labels = unmap( + labels, num_total_anchors, inside_flags, + fill=self.num_classes) # fill bg label + label_weights = unmap(label_weights, num_total_anchors, + inside_flags) + bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) + bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) + pos_gt_inds = unmap( + pos_gt_inds, num_total_anchors, inside_flags, fill=-1) + + return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, + neg_inds, sampling_result, pos_gt_inds) + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None + ) -> Dict[str, Tensor]: + """Compute loss of the head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_points * num_classes, H, W). 
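The `unmap` helper used at the end of `_get_targets_single` is worth seeing in isolation: it scatters per-valid-anchor targets back onto the full anchor set, padding filtered-out anchors with a fill value. A minimal stand-in (`unmap_demo` is a hypothetical name mirroring the call pattern above):

```python
import torch

def unmap_demo(data: torch.Tensor, count: int,
               inds: torch.Tensor, fill: float = 0) -> torch.Tensor:
    """Scatter per-valid-anchor values back onto `count` anchors,
    padding filtered-out positions with `fill` (mimics mmdet's `unmap`)."""
    new_size = (count, ) + data.size()[1:]
    ret = data.new_full(new_size, fill)
    ret[inds.type(torch.bool)] = data
    return ret

inside_flags = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool)
labels_inside = torch.tensor([3, 7, 7])      # one label per valid anchor
full = unmap_demo(labels_inside, 5, inside_flags, fill=80)  # 80 = bg class
print(full)  # tensor([ 3, 80,  7,  7, 80])
```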
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_points * 4, H, W). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + for i in range(len(bbox_preds)): # loop over fpn level + # avoid 0 area of the predicted bbox + bbox_preds[i] = bbox_preds[i].clamp(min=1e-4) + # TODO: It may directly use the base-class loss function. + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + batch_size = len(batch_img_metas) + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore, + return_sampling_results=True) + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + avg_factor, sampling_results_list, + pos_assigned_gt_inds_list) = cls_reg_targets + + num_gts = np.array(list(map(len, batch_gt_instances))) + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + # concat all level anchors and flags to a single tensor + concat_anchor_list = [] + for i in range(len(anchor_list)): + concat_anchor_list.append(torch.cat(anchor_list[i])) + all_anchor_list = images_to_levels(concat_anchor_list, + num_level_anchors) + losses_cls, losses_bbox = multi_apply( + self.loss_by_feat_single, + cls_scores, + bbox_preds, + all_anchor_list, + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + avg_factor=avg_factor) + + # `pos_assigned_gt_inds_list` (length: fpn_levels) stores the assigned + # gt index of each anchor bbox in each fpn level. + cum_num_gts = list(np.cumsum(num_gts)) # length of batch_size + for i, assign in enumerate(pos_assigned_gt_inds_list): + # loop over fpn levels + for j in range(1, batch_size): + # loop over batch size + # Convert gt indices in each img to those in the batch + assign[j][assign[j] >= 0] += int(cum_num_gts[j - 1]) + pos_assigned_gt_inds_list[i] = assign.flatten() + labels_list[i] = labels_list[i].flatten() + num_gts = num_gts.sum() # total number of gt in the batch + # The unique label index of each gt in the batch + label_sequence = torch.arange(num_gts, device=device) + # Collect the average loss of each gt in each level + with torch.no_grad(): + loss_levels, = multi_apply( + self.collect_loss_level_single, + losses_cls, + losses_bbox, + pos_assigned_gt_inds_list, + labels_seq=label_sequence) + # Shape: (fpn_levels, num_gts). Loss of each gt at each fpn level + loss_levels = torch.stack(loss_levels, dim=0) + # Locate the best fpn level for loss back-propagation + if loss_levels.numel() == 0: # zero gt + argmin = loss_levels.new_empty((num_gts, ), dtype=torch.long) + else: + _, argmin = loss_levels.min(dim=0) + + # Reweight the loss of each (anchor, label) pair, so that only those + # at the best gt level are back-propagated. 
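For intuition on the level selection referenced in the comment above: `collect_loss_level_single` yields a `(fpn_levels, num_gts)` loss table, and the argmin picks each gt's cheapest level; only that level keeps the gt's positive anchors. A toy version with made-up numbers:

```python
import torch

num_levels, num_gts = 5, 3
loss_levels = torch.rand(num_levels, num_gts)  # per-gt loss at each level
loss_levels[2, 0] = 0.01                       # make level 2 cheap for gt 0
_, argmin = loss_levels.min(dim=0)
print(argmin)  # e.g. tensor([2, ...]): gradient flows only at these levels
```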
+ losses_cls, losses_bbox, pos_inds = multi_apply( + self.reweight_loss_single, + losses_cls, + losses_bbox, + pos_assigned_gt_inds_list, + labels_list, + list(range(len(losses_cls))), + min_levels=argmin) + num_pos = torch.cat(pos_inds, 0).sum().float() + pos_recall = self.calculate_pos_recall(cls_scores, labels_list, + pos_inds) + + if num_pos == 0: # No gt + num_total_neg = sum( + [results.num_neg for results in sampling_results_list]) + avg_factor = num_pos + num_total_neg + else: + avg_factor = num_pos + for i in range(len(losses_cls)): + losses_cls[i] /= avg_factor + losses_bbox[i] /= avg_factor + return dict( + loss_cls=losses_cls, + loss_bbox=losses_bbox, + num_pos=num_pos / batch_size, + pos_recall=pos_recall) + + def calculate_pos_recall(self, cls_scores: List[Tensor], + labels_list: List[Tensor], + pos_inds: List[Tensor]) -> Tensor: + """Calculate positive recall with score threshold. + + Args: + cls_scores (list[Tensor]): Classification scores at all fpn levels. + Each tensor is in shape (N, num_classes * num_anchors, H, W) + labels_list (list[Tensor]): The label that each anchor is assigned + to. Shape (N * H * W * num_anchors, ) + pos_inds (list[Tensor]): List of bool tensors indicating whether + the anchor is assigned to a positive label. + Shape (N * H * W * num_anchors, ) + + Returns: + Tensor: A single float number indicating the positive recall. + """ + with torch.no_grad(): + num_class = self.num_classes + scores = [ + cls.permute(0, 2, 3, 1).reshape(-1, num_class)[pos] + for cls, pos in zip(cls_scores, pos_inds) + ] + labels = [ + label.reshape(-1)[pos] + for label, pos in zip(labels_list, pos_inds) + ] + scores = torch.cat(scores, dim=0) + labels = torch.cat(labels, dim=0) + if self.use_sigmoid_cls: + scores = scores.sigmoid() + else: + scores = scores.softmax(dim=1) + + return accuracy(scores, labels, thresh=self.score_threshold) + + def collect_loss_level_single(self, cls_loss: Tensor, reg_loss: Tensor, + assigned_gt_inds: Tensor, + labels_seq: Tensor) -> Tensor: + """Get the average loss in each FPN level w.r.t. each gt label. + + Args: + cls_loss (Tensor): Classification loss of each feature map pixel, + shape (num_anchor, num_class) + reg_loss (Tensor): Regression loss of each feature map pixel, + shape (num_anchor, 4) + assigned_gt_inds (Tensor): It indicates which gt the prior is + assigned to (0-based, -1: no assignment). shape (num_anchor), + labels_seq: The rank of labels. shape (num_gt) + + Returns: + Tensor: shape (num_gt), average loss of each gt in this level + """ + if len(reg_loss.shape) == 2: # iou loss has shape (num_prior, 4) + reg_loss = reg_loss.sum(dim=-1) # sum loss in tblr dims + if len(cls_loss.shape) == 2: + cls_loss = cls_loss.sum(dim=-1) # sum loss in class dims + loss = cls_loss + reg_loss + assert loss.size(0) == assigned_gt_inds.size(0) + # Default loss value is 1e6 for a layer where no anchor is positive + # to ensure it will not be chosen to back-propagate gradient + losses_ = loss.new_full(labels_seq.shape, 1e6) + for i, l in enumerate(labels_seq): + match = assigned_gt_inds == l + if match.any(): + losses_[i] = loss[match].mean() + return losses_, + + def reweight_loss_single(self, cls_loss: Tensor, reg_loss: Tensor, + assigned_gt_inds: Tensor, labels: Tensor, + level: int, min_levels: Tensor) -> tuple: + """Reweight loss values at each level. + + Reassign loss values at each level by masking those where the + pre-calculated loss is too large. Then return the reduced losses. 
+ + Args: + cls_loss (Tensor): Element-wise classification loss. + Shape: (num_anchors, num_classes) + reg_loss (Tensor): Element-wise regression loss. + Shape: (num_anchors, 4) + assigned_gt_inds (Tensor): The gt indices that each anchor bbox + is assigned to. -1 denotes a negative anchor, otherwise it is the + gt index (0-based). Shape: (num_anchors, ), + labels (Tensor): Label assigned to anchors. Shape: (num_anchors, ). + level (int): The current level index in the pyramid + (0-4 for RetinaNet) + min_levels (Tensor): The best-matching level for each gt. + Shape: (num_gts, ), + + Returns: + tuple: + + - cls_loss: Reduced corrected classification loss. Scalar. + - reg_loss: Reduced corrected regression loss. Scalar. + - pos_flags (Tensor): Corrected bool tensor indicating the \ + final positive anchors. Shape: (num_anchors, ). + """ + loc_weight = torch.ones_like(reg_loss) + cls_weight = torch.ones_like(cls_loss) + pos_flags = assigned_gt_inds >= 0 # positive pixel flag + pos_indices = torch.nonzero(pos_flags, as_tuple=False).flatten() + + if pos_flags.any(): # pos pixels exist + pos_assigned_gt_inds = assigned_gt_inds[pos_flags] + zeroing_indices = (min_levels[pos_assigned_gt_inds] != level) + neg_indices = pos_indices[zeroing_indices] + + if neg_indices.numel(): + pos_flags[neg_indices] = 0 + loc_weight[neg_indices] = 0 + # Only the weight corresponding to the label is + # zeroed out if not selected + zeroing_labels = labels[neg_indices] + assert (zeroing_labels >= 0).all() + cls_weight[neg_indices, zeroing_labels] = 0 + + # Weighted loss for both cls and reg loss + cls_loss = weight_reduce_loss(cls_loss, cls_weight, reduction='sum') + reg_loss = weight_reduce_loss(reg_loss, loc_weight, reduction='sum') + + return cls_loss, reg_loss, pos_flags diff --git a/mmdet/models/dense_heads/ga_retina_head.py b/mmdet/models/dense_heads/ga_retina_head.py new file mode 100644 index 0000000000000000000000000000000000000000..569910b365126e90638256f0d10addfa230fd141 --- /dev/null +++ b/mmdet/models/dense_heads/ga_retina_head.py @@ -0,0 +1,120 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
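Stepping back to `reweight_loss_single` just above: anchors whose gt prefers a different level lose their regression weight entirely, while in the classification branch only the column of their assigned label is zeroed. A toy trace under hypothetical shapes:

```python
import torch

level = 1
assigned_gt_inds = torch.tensor([-1, 0, 0, 1, -1, 1])  # -1 = negative anchor
labels = torch.tensor([4, 2, 2, 0, 4, 0])              # 4 = background class
min_levels = torch.tensor([1, 3])                      # best level per gt

loc_weight = torch.ones(6)
cls_weight = torch.ones(6, 4)
pos = assigned_gt_inds >= 0
pos_idx = pos.nonzero(as_tuple=False).flatten()
zero = min_levels[assigned_gt_inds[pos]] != level      # gt 1 prefers level 3
neg_idx = pos_idx[zero]
loc_weight[neg_idx] = 0                                # drop reg loss entirely
cls_weight[neg_idx, labels[neg_idx]] = 0               # only that class column
print(loc_weight, cls_weight[3])                       # -> [0., 1., 1., 1.]
```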
+from typing import Tuple + +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmcv.ops import MaskedConv2d +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import OptConfigType, OptMultiConfig +from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead + + +@MODELS.register_module() +class GARetinaHead(GuidedAnchorHead): + """Guided-Anchor-based RetinaNet head.""" + + def __init__(self, + num_classes: int, + in_channels: int, + stacked_convs: int = 4, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = None, + **kwargs) -> None: + if init_cfg is None: + init_cfg = dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=[ + dict( + type='Normal', + name='conv_loc', + std=0.01, + bias_prob=0.01), + dict( + type='Normal', + name='retina_cls', + std=0.01, + bias_prob=0.01) + ]) + self.stacked_convs = stacked_convs + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + super().__init__( + num_classes=num_classes, + in_channels=in_channels, + init_cfg=init_cfg, + **kwargs) + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self.relu = nn.ReLU(inplace=True) + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + self.cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + self.reg_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + + self.conv_loc = nn.Conv2d(self.feat_channels, 1, 1) + num_anchors = self.square_anchor_generator.num_base_priors[0] + self.conv_shape = nn.Conv2d(self.feat_channels, num_anchors * 2, 1) + self.feature_adaption_cls = FeatureAdaption( + self.feat_channels, + self.feat_channels, + kernel_size=3, + deform_groups=self.deform_groups) + self.feature_adaption_reg = FeatureAdaption( + self.feat_channels, + self.feat_channels, + kernel_size=3, + deform_groups=self.deform_groups) + self.retina_cls = MaskedConv2d( + self.feat_channels, + self.num_base_priors * self.cls_out_channels, + 3, + padding=1) + self.retina_reg = MaskedConv2d( + self.feat_channels, self.num_base_priors * 4, 3, padding=1) + + def forward_single(self, x: Tensor) -> Tuple[Tensor]: + """Forward feature map of a single scale level.""" + cls_feat = x + reg_feat = x + for cls_conv in self.cls_convs: + cls_feat = cls_conv(cls_feat) + for reg_conv in self.reg_convs: + reg_feat = reg_conv(reg_feat) + + loc_pred = self.conv_loc(cls_feat) + shape_pred = self.conv_shape(reg_feat) + + cls_feat = self.feature_adaption_cls(cls_feat, shape_pred) + reg_feat = self.feature_adaption_reg(reg_feat, shape_pred) + + if not self.training: + mask = loc_pred.sigmoid()[0] >= self.loc_filter_thr + else: + mask = None + cls_score = self.retina_cls(cls_feat, mask) + bbox_pred = self.retina_reg(reg_feat, mask) + return cls_score, bbox_pred, shape_pred, loc_pred diff --git a/mmdet/models/dense_heads/ga_rpn_head.py b/mmdet/models/dense_heads/ga_rpn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..9614463165533358b8465420a87dfa47e7de1177 --- /dev/null +++ b/mmdet/models/dense_heads/ga_rpn_head.py @@ -0,0 +1,222 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
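Before the RPN variant below, a plain-PyTorch sketch of the inference-time masking in `GARetinaHead.forward_single`: the location branch's sigmoid output gates where the cls/reg convolutions are evaluated. The real head uses mmcv's `MaskedConv2d`, which skips masked-out positions; multiplying by the mask here is just a slower way to visualize the effect, and the 0.01 threshold is a typical `loc_filter_thr` value rather than one fixed by this diff:

```python
import torch
import torch.nn as nn

loc_filter_thr = 0.01                          # assumed typical test_cfg value
feat = torch.randn(1, 256, 32, 32)
conv_loc = nn.Conv2d(256, 1, 1)
retina_cls = nn.Conv2d(256, 9, 3, padding=1)

loc_pred = conv_loc(feat)
mask = loc_pred.sigmoid()[0] >= loc_filter_thr  # (1, H, W) keep-mask
cls_score = retina_cls(feat)
cls_score = cls_score * mask.unsqueeze(0)       # emulate the masked conv
print(mask.float().mean())                      # fraction of locations kept
```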
+import copy +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.ops import nms +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptInstanceList +from .guided_anchor_head import GuidedAnchorHead + + +@MODELS.register_module() +class GARPNHead(GuidedAnchorHead): + """Guided-Anchor-based RPN head.""" + + def __init__(self, + in_channels: int, + num_classes: int = 1, + init_cfg: MultiConfig = dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', + name='conv_loc', + std=0.01, + bias_prob=0.01)), + **kwargs) -> None: + super().__init__( + num_classes=num_classes, + in_channels=in_channels, + init_cfg=init_cfg, + **kwargs) + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self.rpn_conv = nn.Conv2d( + self.in_channels, self.feat_channels, 3, padding=1) + super(GARPNHead, self)._init_layers() + + def forward_single(self, x: Tensor) -> Tuple[Tensor]: + """Forward feature of a single scale level.""" + + x = self.rpn_conv(x) + x = F.relu(x, inplace=True) + (cls_score, bbox_pred, shape_pred, + loc_pred) = super().forward_single(x) + return cls_score, bbox_pred, shape_pred, loc_pred + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + shape_preds: List[Tensor], + loc_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + has shape (N, num_anchors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + shape_preds (list[Tensor]): shape predictions for each scale + level with shape (N, 1, H, W). + loc_preds (list[Tensor]): location predictions for each scale + level with shape (N, num_anchors * 2, H, W). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict: A dictionary of loss components. + """ + losses = super().loss_by_feat( + cls_scores, + bbox_preds, + shape_preds, + loc_preds, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + return dict( + loss_rpn_cls=losses['loss_cls'], + loss_rpn_bbox=losses['loss_bbox'], + loss_anchor_shape=losses['loss_shape'], + loss_anchor_loc=losses['loss_loc']) + + def _predict_by_feat_single(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + mlvl_anchors: List[Tensor], + mlvl_masks: List[Tensor], + img_meta: dict, + cfg: ConfigType, + rescale: bool = False) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + cls_scores (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). 
+ bbox_preds (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + mlvl_anchors (list[Tensor]): Each element in the list is + the anchors of a single level in feature pyramid. it has + shape (num_priors, 4). + mlvl_masks (list[Tensor]): Each element in the list is location + masks of a single level. + img_meta (dict): Image meta info. + cfg (:obj:`ConfigDict` or dict): Test / postprocessing + configuration, if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), the last + dimension 4 arrange as (x1, y1, x2, y2). + """ + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + assert cfg.nms.get('type', 'nms') == 'nms', 'GARPNHead only support ' \ + 'naive nms.' + + mlvl_proposals = [] + for idx in range(len(cls_scores)): + rpn_cls_score = cls_scores[idx] + rpn_bbox_pred = bbox_preds[idx] + anchors = mlvl_anchors[idx] + mask = mlvl_masks[idx] + assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:] + # if no location is kept, end. + if mask.sum() == 0: + continue + rpn_cls_score = rpn_cls_score.permute(1, 2, 0) + if self.use_sigmoid_cls: + rpn_cls_score = rpn_cls_score.reshape(-1) + scores = rpn_cls_score.sigmoid() + else: + rpn_cls_score = rpn_cls_score.reshape(-1, 2) + # remind that we set FG labels to [0, num_class-1] + # since mmdet v2.0 + # BG cat_id: num_class + scores = rpn_cls_score.softmax(dim=1)[:, :-1] + # filter scores, bbox_pred w.r.t. mask. + # anchors are filtered in get_anchors() beforehand. + scores = scores[mask] + rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, + 4)[mask, :] + if scores.dim() == 0: + rpn_bbox_pred = rpn_bbox_pred.unsqueeze(0) + anchors = anchors.unsqueeze(0) + scores = scores.unsqueeze(0) + # filter anchors, bbox_pred, scores w.r.t. scores + if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre: + _, topk_inds = scores.topk(cfg.nms_pre) + rpn_bbox_pred = rpn_bbox_pred[topk_inds, :] + anchors = anchors[topk_inds, :] + scores = scores[topk_inds] + # get proposals w.r.t. 
anchors and rpn_bbox_pred + proposals = self.bbox_coder.decode( + anchors, rpn_bbox_pred, max_shape=img_meta['img_shape']) + # filter out too small bboxes + if cfg.min_bbox_size >= 0: + w = proposals[:, 2] - proposals[:, 0] + h = proposals[:, 3] - proposals[:, 1] + valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size) + if not valid_mask.all(): + proposals = proposals[valid_mask] + scores = scores[valid_mask] + + # NMS in current level + proposals, _ = nms(proposals, scores, cfg.nms.iou_threshold) + proposals = proposals[:cfg.nms_post, :] + mlvl_proposals.append(proposals) + proposals = torch.cat(mlvl_proposals, 0) + if cfg.get('nms_across_levels', False): + # NMS across multi levels + proposals, _ = nms(proposals[:, :4], proposals[:, -1], + cfg.nms.iou_threshold) + proposals = proposals[:cfg.max_per_img, :] + else: + scores = proposals[:, 4] + num = min(cfg.max_per_img, proposals.shape[0]) + _, topk_inds = scores.topk(num) + proposals = proposals[topk_inds, :] + + bboxes = proposals[:, :-1] + scores = proposals[:, -1] + if rescale: + assert img_meta.get('scale_factor') is not None + bboxes /= bboxes.new_tensor(img_meta['scale_factor']).repeat( + (1, 2)) + + results = InstanceData() + results.bboxes = bboxes + results.scores = scores + results.labels = scores.new_zeros(scores.size(0), dtype=torch.long) + return results diff --git a/mmdet/models/dense_heads/gfl_head.py b/mmdet/models/dense_heads/gfl_head.py new file mode 100644 index 0000000000000000000000000000000000000000..be43d9b4da39da602b3b87bd3c9739c67367615b --- /dev/null +++ b/mmdet/models/dense_heads/gfl_head.py @@ -0,0 +1,667 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, Scale +from mmengine.config import ConfigDict +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures.bbox import bbox_overlaps +from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType, + OptInstanceList, reduce_mean) +from ..task_modules.prior_generators import anchor_inside_flags +from ..task_modules.samplers import PseudoSampler +from ..utils import (filter_scores_and_topk, images_to_levels, multi_apply, + unmap) +from .anchor_head import AnchorHead + + +class Integral(nn.Module): + """A fixed layer for calculating integral result from distribution. + + This layer calculates the target location by :math: ``sum{P(y_i) * y_i}``, + P(y_i) denotes the softmax vector that represents the discrete distribution + y_i denotes the discrete set, usually {0, 1, 2, ..., reg_max} + + Args: + reg_max (int): The maximal value of the discrete set. Defaults to 16. + You may want to reset it according to your new dataset or related + settings. + """ + + def __init__(self, reg_max: int = 16) -> None: + super().__init__() + self.reg_max = reg_max + self.register_buffer('project', + torch.linspace(0, self.reg_max, self.reg_max + 1)) + + def forward(self, x: Tensor) -> Tensor: + """Forward feature from the regression head to get integral result of + bounding box location. + + Args: + x (Tensor): Features of the regression head, shape (N, 4*(n+1)), + n is self.reg_max. + + Returns: + x (Tensor): Integral result of box locations, i.e., distance + offsets from the box center in four directions, shape (N, 4). 
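The `Integral` module above is simply an expectation over a discretized distance distribution: softmax the 17 bin logits of each box side, then take the probability-weighted sum of the bin values. A standalone check of the same arithmetic (assuming the default ``reg_max=16``):

```python
import torch
import torch.nn.functional as F

reg_max = 16
project = torch.linspace(0, reg_max, reg_max + 1)      # bins {0, 1, ..., 16}

x = torch.randn(2, 4 * (reg_max + 1))                  # raw logits, 2 priors
probs = F.softmax(x.reshape(-1, reg_max + 1), dim=1)   # one dist per side
dist = F.linear(probs, project.unsqueeze(0)).reshape(-1, 4)
print(dist)  # expected (l, t, r, b) offsets, each in [0, reg_max]
```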
+ """ + x = F.softmax(x.reshape(-1, self.reg_max + 1), dim=1) + x = F.linear(x, self.project.type_as(x)).reshape(-1, 4) + return x + + +@MODELS.register_module() +class GFLHead(AnchorHead): + """Generalized Focal Loss: Learning Qualified and Distributed Bounding + Boxes for Dense Object Detection. + + GFL head structure is similar with ATSS, however GFL uses + 1) joint representation for classification and localization quality, and + 2) flexible General distribution for bounding box locations, + which are supervised by + Quality Focal Loss (QFL) and Distribution Focal Loss (DFL), respectively + + https://arxiv.org/abs/2006.04388 + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + stacked_convs (int): Number of conv layers in cls and reg tower. + Defaults to 4. + conv_cfg (:obj:`ConfigDict` or dict, optional): dictionary to construct + and config conv layer. Defaults to None. + norm_cfg (:obj:`ConfigDict` or dict): dictionary to construct and + config norm layer. Default: dict(type='GN', num_groups=32, + requires_grad=True). + loss_qfl (:obj:`ConfigDict` or dict): Config of Quality Focal Loss + (QFL). + bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder. Defaults + to 'DistancePointBBoxCoder'. + reg_max (int): Max value of integral set :math: ``{0, ..., reg_max}`` + in QFL setting. Defaults to 16. + init_cfg (:obj:`ConfigDict` or dict or list[dict] or + list[:obj:`ConfigDict`]): Initialization config dict. + Example: + >>> self = GFLHead(11, 7) + >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]] + >>> cls_quality_score, bbox_pred = self.forward(feats) + >>> assert len(cls_quality_score) == len(self.scales) + """ + + def __init__(self, + num_classes: int, + in_channels: int, + stacked_convs: int = 4, + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict( + type='GN', num_groups=32, requires_grad=True), + loss_dfl: ConfigType = dict( + type='DistributionFocalLoss', loss_weight=0.25), + bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'), + reg_max: int = 16, + init_cfg: MultiConfig = dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', + name='gfl_cls', + std=0.01, + bias_prob=0.01)), + **kwargs) -> None: + self.stacked_convs = stacked_convs + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.reg_max = reg_max + super().__init__( + num_classes=num_classes, + in_channels=in_channels, + bbox_coder=bbox_coder, + init_cfg=init_cfg, + **kwargs) + + if self.train_cfg: + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + if self.train_cfg.get('sampler', None) is not None: + self.sampler = TASK_UTILS.build( + self.train_cfg['sampler'], default_args=dict(context=self)) + else: + self.sampler = PseudoSampler(context=self) + + self.integral = Integral(self.reg_max) + self.loss_dfl = MODELS.build(loss_dfl) + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self.relu = nn.ReLU() + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + self.cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + self.reg_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + assert self.num_anchors == 1, 'anchor 
free version' + self.gfl_cls = nn.Conv2d( + self.feat_channels, self.cls_out_channels, 3, padding=1) + self.gfl_reg = nn.Conv2d( + self.feat_channels, 4 * (self.reg_max + 1), 3, padding=1) + self.scales = nn.ModuleList( + [Scale(1.0) for _ in self.prior_generator.strides]) + + def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor]]: + """Forward features from the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: Usually a tuple of classification scores and bbox prediction + + - cls_scores (list[Tensor]): Classification and quality (IoU) + joint scores for all scale levels, each is a 4D-tensor, + the channel number is num_classes. + - bbox_preds (list[Tensor]): Box distribution logits for all + scale levels, each is a 4D-tensor, the channel number is + 4*(n+1), n is max value of integral set. + """ + return multi_apply(self.forward_single, x, self.scales) + + def forward_single(self, x: Tensor, scale: Scale) -> Sequence[Tensor]: + """Forward feature of a single scale level. + + Args: + x (Tensor): Features of a single scale level. + scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize + the bbox prediction. + + Returns: + tuple: + + - cls_score (Tensor): Cls and quality joint scores for a single + scale level the channel number is num_classes. + - bbox_pred (Tensor): Box distribution logits for a single scale + level, the channel number is 4*(n+1), n is max value of + integral set. + """ + cls_feat = x + reg_feat = x + for cls_conv in self.cls_convs: + cls_feat = cls_conv(cls_feat) + for reg_conv in self.reg_convs: + reg_feat = reg_conv(reg_feat) + cls_score = self.gfl_cls(cls_feat) + bbox_pred = scale(self.gfl_reg(reg_feat)).float() + return cls_score, bbox_pred + + def anchor_center(self, anchors: Tensor) -> Tensor: + """Get anchor centers from anchors. + + Args: + anchors (Tensor): Anchor list with shape (N, 4), ``xyxy`` format. + + Returns: + Tensor: Anchor centers with shape (N, 2), ``xy`` format. + """ + anchors_cx = (anchors[..., 2] + anchors[..., 0]) / 2 + anchors_cy = (anchors[..., 3] + anchors[..., 1]) / 2 + return torch.stack([anchors_cx, anchors_cy], dim=-1) + + def loss_by_feat_single(self, anchors: Tensor, cls_score: Tensor, + bbox_pred: Tensor, labels: Tensor, + label_weights: Tensor, bbox_targets: Tensor, + stride: Tuple[int], avg_factor: int) -> dict: + """Calculate the loss of a single scale level based on the features + extracted by the detection head. + + Args: + anchors (Tensor): Box reference for each scale level with shape + (N, num_total_anchors, 4). + cls_score (Tensor): Cls and quality joint scores for each scale + level has shape (N, num_classes, H, W). + bbox_pred (Tensor): Box distribution logits for each scale + level with shape (N, 4*(n+1), H, W), n is max value of integral + set. + labels (Tensor): Labels of each anchors with shape + (N, num_total_anchors). + label_weights (Tensor): Label weights of each anchor with shape + (N, num_total_anchors) + bbox_targets (Tensor): BBox regression targets of each anchor with + shape (N, num_total_anchors, 4). + stride (Tuple[int]): Stride in this scale level. + avg_factor (int): Average factor that is used to average + the loss. When using sampling method, avg_factor is usually + the sum of positive and negative priors. When using + `PseudoSampler`, `avg_factor` is usually equal to the number + of positive priors. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + assert stride[0] == stride[1], 'h stride is not equal to w stride!' + anchors = anchors.reshape(-1, 4) + cls_score = cls_score.permute(0, 2, 3, + 1).reshape(-1, self.cls_out_channels) + bbox_pred = bbox_pred.permute(0, 2, 3, + 1).reshape(-1, 4 * (self.reg_max + 1)) + bbox_targets = bbox_targets.reshape(-1, 4) + labels = labels.reshape(-1) + label_weights = label_weights.reshape(-1) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().squeeze(1) + score = label_weights.new_zeros(labels.shape) + + if len(pos_inds) > 0: + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + pos_anchors = anchors[pos_inds] + pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0] + + weight_targets = cls_score.detach().sigmoid() + weight_targets = weight_targets.max(dim=1)[0][pos_inds] + pos_bbox_pred_corners = self.integral(pos_bbox_pred) + pos_decode_bbox_pred = self.bbox_coder.decode( + pos_anchor_centers, pos_bbox_pred_corners) + pos_decode_bbox_targets = pos_bbox_targets / stride[0] + score[pos_inds] = bbox_overlaps( + pos_decode_bbox_pred.detach(), + pos_decode_bbox_targets, + is_aligned=True) + pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1) + target_corners = self.bbox_coder.encode(pos_anchor_centers, + pos_decode_bbox_targets, + self.reg_max).reshape(-1) + + # regression loss + loss_bbox = self.loss_bbox( + pos_decode_bbox_pred, + pos_decode_bbox_targets, + weight=weight_targets, + avg_factor=1.0) + + # dfl loss + loss_dfl = self.loss_dfl( + pred_corners, + target_corners, + weight=weight_targets[:, None].expand(-1, 4).reshape(-1), + avg_factor=4.0) + else: + loss_bbox = bbox_pred.sum() * 0 + loss_dfl = bbox_pred.sum() * 0 + weight_targets = bbox_pred.new_tensor(0) + + # cls (qfl) loss + loss_cls = self.loss_cls( + cls_score, (labels, score), + weight=label_weights, + avg_factor=avg_factor) + + return loss_cls, loss_bbox, loss_dfl, weight_targets.sum() + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Cls and quality scores for each scale + level has shape (N, num_classes, H, W). + bbox_preds (list[Tensor]): Box distribution logits for each scale + level with shape (N, 4*(n+1), H, W), n is max value of integral + set. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + + (anchor_list, labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, avg_factor) = cls_reg_targets + + avg_factor = reduce_mean( + torch.tensor(avg_factor, dtype=torch.float, device=device)).item() + + losses_cls, losses_bbox, losses_dfl,\ + avg_factor = multi_apply( + self.loss_by_feat_single, + anchor_list, + cls_scores, + bbox_preds, + labels_list, + label_weights_list, + bbox_targets_list, + self.prior_generator.strides, + avg_factor=avg_factor) + + avg_factor = sum(avg_factor) + avg_factor = reduce_mean(avg_factor).clamp_(min=1).item() + losses_bbox = list(map(lambda x: x / avg_factor, losses_bbox)) + losses_dfl = list(map(lambda x: x / avg_factor, losses_dfl)) + return dict( + loss_cls=losses_cls, loss_bbox=losses_bbox, loss_dfl=losses_dfl) + + def _predict_by_feat_single(self, + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + score_factor_list: List[Tensor], + mlvl_priors: List[Tensor], + img_meta: dict, + cfg: ConfigDict, + rescale: bool = False, + with_nms: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image. GFL head does not need this value. + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid, has shape + (num_priors, 4). + img_meta (dict): Image meta info. + cfg (:obj: `ConfigDict`): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + tuple[Tensor]: Results of detected bboxes and labels. If with_nms + is False and mlvl_score_factor is None, return mlvl_bboxes and + mlvl_scores, else return mlvl_bboxes, mlvl_scores and + mlvl_score_factor. Usually with_nms is False is used for aug + test. If with_nms is True, then return the following format + + - det_bboxes (Tensor): Predicted bboxes with shape + [num_bboxes, 5], where the first 4 columns are bounding + box positions (tl_x, tl_y, br_x, br_y) and the 5-th + column are scores between 0 and 1. + - det_labels (Tensor): Predicted labels of the corresponding + box with shape [num_bboxes]. 
+ """ + cfg = self.test_cfg if cfg is None else cfg + img_shape = img_meta['img_shape'] + nms_pre = cfg.get('nms_pre', -1) + + mlvl_bboxes = [] + mlvl_scores = [] + mlvl_labels = [] + for level_idx, (cls_score, bbox_pred, stride, priors) in enumerate( + zip(cls_score_list, bbox_pred_list, + self.prior_generator.strides, mlvl_priors)): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + assert stride[0] == stride[1] + + bbox_pred = bbox_pred.permute(1, 2, 0) + bbox_pred = self.integral(bbox_pred) * stride[0] + + scores = cls_score.permute(1, 2, 0).reshape( + -1, self.cls_out_channels).sigmoid() + + # After https://github.com/open-mmlab/mmdetection/pull/6268/, + # this operation keeps fewer bboxes under the same `nms_pre`. + # There is no difference in performance for most models. If you + # find a slight drop in performance, you can set a larger + # `nms_pre` than before. + results = filter_scores_and_topk( + scores, cfg.score_thr, nms_pre, + dict(bbox_pred=bbox_pred, priors=priors)) + scores, labels, _, filtered_results = results + + bbox_pred = filtered_results['bbox_pred'] + priors = filtered_results['priors'] + + bboxes = self.bbox_coder.decode( + self.anchor_center(priors), bbox_pred, max_shape=img_shape) + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + mlvl_labels.append(labels) + + results = InstanceData() + results.bboxes = torch.cat(mlvl_bboxes) + results.scores = torch.cat(mlvl_scores) + results.labels = torch.cat(mlvl_labels) + + return self._bbox_post_process( + results=results, + cfg=cfg, + rescale=rescale, + with_nms=with_nms, + img_meta=img_meta) + + def get_targets(self, + anchor_list: List[Tensor], + valid_flag_list: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None, + unmap_outputs=True) -> tuple: + """Get targets for GFL head. + + This method is almost the same as `AnchorHead.get_targets()`. Besides + returning the targets as the parent method does, it also returns the + anchors as the first element of the returned tuple. + """ + num_imgs = len(batch_img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + num_level_anchors_list = [num_level_anchors] * num_imgs + + # concat all level anchors and flags to a single tensor + for i in range(num_imgs): + assert len(anchor_list[i]) == len(valid_flag_list[i]) + anchor_list[i] = torch.cat(anchor_list[i]) + valid_flag_list[i] = torch.cat(valid_flag_list[i]) + + # compute targets for each image + if batch_gt_instances_ignore is None: + batch_gt_instances_ignore = [None] * num_imgs + (all_anchors, all_labels, all_label_weights, all_bbox_targets, + all_bbox_weights, pos_inds_list, neg_inds_list, + sampling_results_list) = multi_apply( + self._get_targets_single, + anchor_list, + valid_flag_list, + num_level_anchors_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore, + unmap_outputs=unmap_outputs) + # Get `avg_factor` of all images, which calculate in `SamplingResult`. + # When using sampling method, avg_factor is usually the sum of + # positive and negative priors. When using `PseudoSampler`, + # `avg_factor` is usually equal to the number of positive priors. + avg_factor = sum( + [results.avg_factor for results in sampling_results_list]) + # split targets to a list w.r.t. 
multiple levels + anchors_list = images_to_levels(all_anchors, num_level_anchors) + labels_list = images_to_levels(all_labels, num_level_anchors) + label_weights_list = images_to_levels(all_label_weights, + num_level_anchors) + bbox_targets_list = images_to_levels(all_bbox_targets, + num_level_anchors) + bbox_weights_list = images_to_levels(all_bbox_weights, + num_level_anchors) + return (anchors_list, labels_list, label_weights_list, + bbox_targets_list, bbox_weights_list, avg_factor) + + def _get_targets_single(self, + flat_anchors: Tensor, + valid_flags: Tensor, + num_level_anchors: List[int], + gt_instances: InstanceData, + img_meta: dict, + gt_instances_ignore: Optional[InstanceData] = None, + unmap_outputs: bool = True) -> tuple: + """Compute regression, classification targets for anchors in a single + image. + + Args: + flat_anchors (Tensor): Multi-level anchors of the image, which are + concatenated into a single tensor of shape (num_anchors, 4) + valid_flags (Tensor): Multi level valid flags of the image, + which are concatenated into a single tensor of + shape (num_anchors,). + num_level_anchors (list[int]): Number of anchors of each scale + level. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for current image. + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. Defaults to True. + + Returns: + tuple: N is the number of total anchors in the image. + + - anchors (Tensor): All anchors in the image with shape (N, 4). + - labels (Tensor): Labels of all anchors in the image with + shape (N,). + - label_weights (Tensor): Label weights of all anchor in the + image with shape (N,). + - bbox_targets (Tensor): BBox targets of all anchors in the + image with shape (N, 4). + - bbox_weights (Tensor): BBox weights of all anchors in the + image with shape (N, 4). + - pos_inds (Tensor): Indices of positive anchor with shape + (num_pos,). + - neg_inds (Tensor): Indices of negative anchor with shape + (num_neg,). + - sampling_result (:obj:`SamplingResult`): Sampling results. + """ + inside_flags = anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], + self.train_cfg['allowed_border']) + if not inside_flags.any(): + raise ValueError( + 'There is no valid anchor inside the image boundary. 
Please ' + 'check the image size and anchor sizes, or set ' + '``allowed_border`` to -1 to skip the condition.') + # assign gt and sample anchors + anchors = flat_anchors[inside_flags, :] + num_level_anchors_inside = self.get_num_level_anchors_inside( + num_level_anchors, inside_flags) + pred_instances = InstanceData(priors=anchors) + assign_result = self.assigner.assign( + pred_instances=pred_instances, + num_level_priors=num_level_anchors_inside, + gt_instances=gt_instances, + gt_instances_ignore=gt_instances_ignore) + + sampling_result = self.sampler.sample( + assign_result=assign_result, + pred_instances=pred_instances, + gt_instances=gt_instances) + + num_valid_anchors = anchors.shape[0] + bbox_targets = torch.zeros_like(anchors) + bbox_weights = torch.zeros_like(anchors) + labels = anchors.new_full((num_valid_anchors, ), + self.num_classes, + dtype=torch.long) + label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + pos_bbox_targets = sampling_result.pos_gt_bboxes + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1.0 + + labels[pos_inds] = sampling_result.pos_gt_labels + if self.train_cfg['pos_weight'] <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.train_cfg['pos_weight'] + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_anchors.size(0) + anchors = unmap(anchors, num_total_anchors, inside_flags) + labels = unmap( + labels, num_total_anchors, inside_flags, fill=self.num_classes) + label_weights = unmap(label_weights, num_total_anchors, + inside_flags) + bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) + bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) + + return (anchors, labels, label_weights, bbox_targets, bbox_weights, + pos_inds, neg_inds, sampling_result) + + def get_num_level_anchors_inside(self, num_level_anchors: List[int], + inside_flags: Tensor) -> List[int]: + """Get the number of valid anchors in every level.""" + + split_inside_flags = torch.split(inside_flags, num_level_anchors) + num_level_anchors_inside = [ + int(flags.sum()) for flags in split_inside_flags + ] + return num_level_anchors_inside diff --git a/mmdet/models/dense_heads/grounding_dino_head.py b/mmdet/models/dense_heads/grounding_dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3aced6265554605b3102029a7c59e1d86cd9eb27 --- /dev/null +++ b/mmdet/models/dense_heads/grounding_dino_head.py @@ -0,0 +1,767 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import math +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmcv.cnn import Linear +from mmengine.model import constant_init +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.models.losses import QualityFocalLoss +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh +from mmdet.utils import InstanceList, reduce_mean +from ..layers import inverse_sigmoid +from .atss_vlfusion_head import convert_grounding_to_cls_scores +from .dino_head import DINOHead + + +class ContrastiveEmbed(nn.Module): + """text visual ContrastiveEmbed layer. + + Args: + max_text_len (int, optional): Maximum length of text. 
+        log_scale (Optional[Union[str, float]]): The initial value of a
+            learnable parameter to multiply with the similarity
+            matrix to normalize the output. Defaults to None.
+            - If set to 'auto', the similarity matrix will be normalized by
+              a fixed value ``sqrt(d_c)`` where ``d_c`` is the channel number.
+            - If set to 'none' or ``None``, there is no normalization applied.
+            - If set to a float number, the similarity matrix will be
+              multiplied by ``exp(log_scale)``, where ``log_scale`` is
+              learnable.
+        bias (bool, optional): Whether to add bias to the output.
+            If set to ``True``, a learnable bias that is initialized as -4.6
+            will be added to the output. Useful when training from scratch.
+            Defaults to False.
+    """
+
+    def __init__(self,
+                 max_text_len: int = 256,
+                 log_scale: Optional[Union[str, float]] = None,
+                 bias: bool = False):
+        super().__init__()
+        self.max_text_len = max_text_len
+        self.log_scale = log_scale
+        if isinstance(log_scale, float):
+            self.log_scale = nn.Parameter(
+                torch.Tensor([float(log_scale)]), requires_grad=True)
+        elif log_scale not in ['auto', 'none', None]:
+            raise ValueError(f'log_scale should be one of '
+                             f'"auto", "none", None, but got {log_scale}')
+
+        self.bias = None
+        if bias:
+            bias_value = -math.log((1 - 0.01) / 0.01)
+            self.bias = nn.Parameter(
+                torch.Tensor([bias_value]), requires_grad=True)
+
+    def forward(self, visual_feat: Tensor, text_feat: Tensor,
+                text_token_mask: Tensor) -> Tensor:
+        """Forward function.
+
+        Args:
+            visual_feat (Tensor): Visual features.
+            text_feat (Tensor): Text features.
+            text_token_mask (Tensor): A mask used for text feats.
+
+        Returns:
+            Tensor: Classification score.
+        """
+        res = visual_feat @ text_feat.transpose(-1, -2)
+        if isinstance(self.log_scale, nn.Parameter):
+            res = res * self.log_scale.exp()
+        elif self.log_scale == 'auto':
+            # NOTE: similar to the normalizer in self-attention
+            res = res / math.sqrt(visual_feat.shape[-1])
+        if self.bias is not None:
+            res = res + self.bias
+        res.masked_fill_(~text_token_mask[:, None, :], float('-inf'))
+
+        new_res = torch.full((*res.shape[:-1], self.max_text_len),
+                             float('-inf'),
+                             device=res.device)
+        new_res[..., :res.shape[-1]] = res
+
+        return new_res
+
+
+@MODELS.register_module()
+class GroundingDINOHead(DINOHead):
+    """Head of the Grounding DINO: Marrying DINO with Grounded Pre-Training
+    for Open-Set Object Detection.
+
+    Args:
+        contrastive_cfg (dict, optional): Contrastive config that contains
+            keys like ``max_text_len``. Defaults to dict(max_text_len=256).
+    """
+
+    def __init__(self, contrastive_cfg=dict(max_text_len=256), **kwargs):
+        self.contrastive_cfg = contrastive_cfg
+        self.max_text_len = contrastive_cfg.get('max_text_len', 256)
+        super().__init__(**kwargs)
+
+    def _init_layers(self) -> None:
+        """Initialize classification branch and regression branch of head."""
+        fc_cls = ContrastiveEmbed(**self.contrastive_cfg)
+        reg_branch = []
+        for _ in range(self.num_reg_fcs):
+            reg_branch.append(Linear(self.embed_dims, self.embed_dims))
+            reg_branch.append(nn.ReLU())
+        reg_branch.append(Linear(self.embed_dims, 4))
+        reg_branch = nn.Sequential(*reg_branch)
+
+        # NOTE: since fc_cls is a contrastive embedding without any
+        # trainable parameters, we do not need to copy it.
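A standalone run of the contrastive logits computed by ``ContrastiveEmbed.forward``: dot-product similarity between queries and text tokens, the ``'auto'`` scaling, masking of padded tokens, then padding out to ``max_text_len`` (toy shapes only):

```python
import math
import torch

bs, num_queries, d, len_text, max_text_len = 2, 5, 256, 7, 256
visual_feat = torch.randn(bs, num_queries, d)
text_feat = torch.randn(bs, len_text, d)
text_token_mask = torch.ones(bs, len_text, dtype=torch.bool)
text_token_mask[:, -2:] = False                  # pretend last 2 tokens are pad

res = visual_feat @ text_feat.transpose(-1, -2)  # (bs, num_queries, len_text)
res = res / math.sqrt(d)                         # the 'auto' normalization
res.masked_fill_(~text_token_mask[:, None, :], float('-inf'))

logits = torch.full((bs, num_queries, max_text_len), float('-inf'))
logits[..., :len_text] = res                     # pad out to max_text_len
print(logits.shape)                              # torch.Size([2, 5, 256])
```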
+ if self.share_pred_layer: + self.cls_branches = nn.ModuleList( + [fc_cls for _ in range(self.num_pred_layer)]) + self.reg_branches = nn.ModuleList( + [reg_branch for _ in range(self.num_pred_layer)]) + else: + self.cls_branches = nn.ModuleList( + [copy.deepcopy(fc_cls) for _ in range(self.num_pred_layer)]) + self.reg_branches = nn.ModuleList([ + copy.deepcopy(reg_branch) for _ in range(self.num_pred_layer) + ]) + + def init_weights(self) -> None: + """Initialize weights of the Deformable DETR head.""" + for m in self.reg_branches: + constant_init(m[-1], 0, bias=0) + nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], -2.0) + if self.as_two_stage: + for m in self.reg_branches: + nn.init.constant_(m[-1].bias.data[2:], 0.0) + + def _get_targets_single(self, cls_score: Tensor, bbox_pred: Tensor, + gt_instances: InstanceData, + img_meta: dict) -> tuple: + """Compute regression and classification targets for one image. + + Outputs from a single decoder layer of a single feature level are used. + + Args: + cls_score (Tensor): Box score logits from a single decoder layer + for one image. Shape [num_queries, cls_out_channels]. + bbox_pred (Tensor): Sigmoid outputs from a single decoder layer + for one image, with normalized coordinate (cx, cy, w, h) and + shape [num_queries, 4]. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for one image. + + Returns: + tuple[Tensor]: a tuple containing the following for one image. + + - labels (Tensor): Labels of each image. + - label_weights (Tensor]): Label weights of each image. + - bbox_targets (Tensor): BBox targets of each image. + - bbox_weights (Tensor): BBox weights of each image. + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image. + """ + img_h, img_w = img_meta['img_shape'] + factor = bbox_pred.new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0) + num_bboxes = bbox_pred.size(0) + # convert bbox_pred from xywh, normalized to xyxy, unnormalized + bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred) + bbox_pred = bbox_pred * factor + + pred_instances = InstanceData(scores=cls_score, bboxes=bbox_pred) + # assigner and sampler + assign_result = self.assigner.assign( + pred_instances=pred_instances, + gt_instances=gt_instances, + img_meta=img_meta) + gt_bboxes = gt_instances.bboxes + + pos_inds = torch.nonzero( + assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() + neg_inds = torch.nonzero( + assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique() + pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds.long(), :] + + # Major changes. The labels are 0-1 binary labels for each bbox + # and text tokens. + labels = gt_bboxes.new_full((num_bboxes, self.max_text_len), + 0, + dtype=torch.float32) + labels[pos_inds] = gt_instances.positive_maps[pos_assigned_gt_inds] + label_weights = gt_bboxes.new_ones(num_bboxes) + + # bbox targets + bbox_targets = torch.zeros_like(bbox_pred, dtype=gt_bboxes.dtype) + bbox_weights = torch.zeros_like(bbox_pred, dtype=gt_bboxes.dtype) + bbox_weights[pos_inds] = 1.0 + + # DETR regress the relative position of boxes (cxcywh) in the image. + # Thus the learning target should be normalized by the image size, also + # the box format should be converted from defaultly x1y1x2y2 to cxcywh. 
+ pos_gt_bboxes_normalized = pos_gt_bboxes / factor + pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized) + bbox_targets[pos_inds] = pos_gt_bboxes_targets + return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, + neg_inds) + + def forward( + self, + hidden_states: Tensor, + references: List[Tensor], + memory_text: Tensor, + text_token_mask: Tensor, + ) -> Tuple[Tensor]: + """Forward function. + + Args: + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, bs, num_queries, dim). + references (List[Tensor]): List of the reference from the decoder. + The first reference is the `init_reference` (initial) and the + other num_decoder_layers(6) references are `inter_references` + (intermediate). The `init_reference` has shape (bs, + num_queries, 4) when `as_two_stage` of the detector is `True`, + otherwise (bs, num_queries, 2). Each `inter_reference` has + shape (bs, num_queries, 4) when `with_box_refine` of the + detector is `True`, otherwise (bs, num_queries, 2). The + coordinates are arranged as (cx, cy) when the last dimension is + 2, and (cx, cy, w, h) when it is 4. + memory_text (Tensor): Memory text. It has shape (bs, len_text, + text_embed_dims). + text_token_mask (Tensor): Text token mask. It has shape (bs, + len_text). + + Returns: + tuple[Tensor]: results of head containing the following tensor. + + - all_layers_outputs_classes (Tensor): Outputs from the + classification head, has shape (num_decoder_layers, bs, + num_queries, cls_out_channels). + - all_layers_outputs_coords (Tensor): Sigmoid outputs from the + regression head with normalized coordinate format (cx, cy, w, + h), has shape (num_decoder_layers, bs, num_queries, 4) with the + last dimension arranged as (cx, cy, w, h). + """ + all_layers_outputs_classes = [] + all_layers_outputs_coords = [] + + for layer_id in range(hidden_states.shape[0]): + reference = inverse_sigmoid(references[layer_id]) + # NOTE The last reference will not be used. + hidden_state = hidden_states[layer_id] + outputs_class = self.cls_branches[layer_id](hidden_state, + memory_text, + text_token_mask) + tmp_reg_preds = self.reg_branches[layer_id](hidden_state) + if reference.shape[-1] == 4: + # When `layer` is 0 and `as_two_stage` of the detector + # is `True`, or when `layer` is greater than 0 and + # `with_box_refine` of the detector is `True`. + tmp_reg_preds += reference + else: + # When `layer` is 0 and `as_two_stage` of the detector + # is `False`, or when `layer` is greater than 0 and + # `with_box_refine` of the detector is `False`. + assert reference.shape[-1] == 2 + tmp_reg_preds[..., :2] += reference + outputs_coord = tmp_reg_preds.sigmoid() + all_layers_outputs_classes.append(outputs_class) + all_layers_outputs_coords.append(outputs_coord) + + all_layers_outputs_classes = torch.stack(all_layers_outputs_classes) + all_layers_outputs_coords = torch.stack(all_layers_outputs_coords) + + return all_layers_outputs_classes, all_layers_outputs_coords + + def predict(self, + hidden_states: Tensor, + references: List[Tensor], + memory_text: Tensor, + text_token_mask: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> InstanceList: + """Perform forward propagation and loss calculation of the detection + head on the queries of the upstream network. + + Args: + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, num_queries, bs, dim). + references (List[Tensor]): List of the reference from the decoder. 
+                The first reference is the `init_reference` (initial) and the
+                other num_decoder_layers(6) references are `inter_references`
+                (intermediate). The `init_reference` has shape (bs,
+                num_queries, 4) when `as_two_stage` of the detector is `True`,
+                otherwise (bs, num_queries, 2). Each `inter_reference` has
+                shape (bs, num_queries, 4) when `with_box_refine` of the
+                detector is `True`, otherwise (bs, num_queries, 2). The
+                coordinates are arranged as (cx, cy) when the last dimension
+                is 2, and (cx, cy, w, h) when it is 4.
+            memory_text (Tensor): Memory text. It has shape (bs, len_text,
+                text_embed_dims).
+            text_token_mask (Tensor): Text token mask. It has shape (bs,
+                len_text).
+            batch_data_samples (SampleList): The batch data samples. It
+                usually includes information such as `gt_instance`,
+                `gt_panoptic_seg` and `gt_sem_seg`.
+            rescale (bool, optional): If `True`, return boxes in original
+                image space. Defaults to `True`.
+
+        Returns:
+            InstanceList: Detection results of each image
+            after the post process.
+        """
+        batch_img_metas = [
+            data_samples.metainfo for data_samples in batch_data_samples
+        ]
+        batch_token_positive_maps = [
+            data_samples.token_positive_map
+            for data_samples in batch_data_samples
+        ]
+
+        outs = self(hidden_states, references, memory_text, text_token_mask)
+
+        predictions = self.predict_by_feat(
+            *outs,
+            batch_img_metas=batch_img_metas,
+            batch_token_positive_maps=batch_token_positive_maps,
+            rescale=rescale)
+        return predictions
+
+    def predict_by_feat(self,
+                        all_layers_cls_scores: Tensor,
+                        all_layers_bbox_preds: Tensor,
+                        batch_img_metas: List[Dict],
+                        batch_token_positive_maps: Optional[List[dict]] = None,
+                        rescale: bool = False) -> InstanceList:
+        """Transform a batch of output features extracted from the head into
+        bbox results.
+
+        Args:
+            all_layers_cls_scores (Tensor): Classification scores of all
+                decoder layers, has shape (num_decoder_layers, bs, num_queries,
+                cls_out_channels).
+            all_layers_bbox_preds (Tensor): Regression outputs of all decoder
+                layers. Each is a 4D-tensor with normalized coordinate format
+                (cx, cy, w, h) and shape (num_decoder_layers, bs, num_queries,
+                4).
+            batch_img_metas (List[Dict]): Meta information of each image,
+                e.g., image size, scaling factor, etc.
+            batch_token_positive_maps (list[dict], Optional): Batch token
+                positive map. Defaults to None.
+            rescale (bool): If True, return boxes in original image space.
+                Defaults to False.
+
+        Returns:
+            list[:obj:`InstanceData`]: Object detection results of each image
+            after the post process. Each item usually contains following keys.
+
+            - scores (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 4),
+              the last dimension 4 arranged as (x1, y1, x2, y2).
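+
+        Note:
+            Only the outputs of the last decoder layer are used for
+            prediction; see the ``[-1]`` indexing below.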
+ """ + cls_scores = all_layers_cls_scores[-1] + bbox_preds = all_layers_bbox_preds[-1] + result_list = [] + for img_id in range(len(batch_img_metas)): + cls_score = cls_scores[img_id] + bbox_pred = bbox_preds[img_id] + img_meta = batch_img_metas[img_id] + token_positive_maps = batch_token_positive_maps[img_id] + results = self._predict_by_feat_single(cls_score, bbox_pred, + token_positive_maps, + img_meta, rescale) + result_list.append(results) + return result_list + + def _predict_by_feat_single(self, + cls_score: Tensor, + bbox_pred: Tensor, + token_positive_maps: dict, + img_meta: dict, + rescale: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + cls_score (Tensor): Box score logits from the last decoder layer + for each image. Shape [num_queries, cls_out_channels]. + bbox_pred (Tensor): Sigmoid outputs from the last decoder layer + for each image, with coordinate format (cx, cy, w, h) and + shape [num_queries, 4]. + token_positive_maps (dict): Token positive map. + img_meta (dict): Image meta info. + rescale (bool, optional): If True, return boxes in original image + space. Default True. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + assert len(cls_score) == len(bbox_pred) # num_queries + max_per_img = self.test_cfg.get('max_per_img', len(cls_score)) + img_shape = img_meta['img_shape'] + + cls_score = convert_grounding_to_cls_scores( + logits=cls_score.sigmoid()[None], + positive_maps=[token_positive_maps])[0] + scores, indexes = cls_score.view(-1).topk(max_per_img) + num_classes = cls_score.shape[-1] + det_labels = indexes % num_classes + bbox_index = indexes // num_classes + bbox_pred = bbox_pred[bbox_index] + + det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred) + det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1] + det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0] + det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1]) + det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0]) + if rescale: + assert img_meta.get('scale_factor') is not None + det_bboxes /= det_bboxes.new_tensor( + img_meta['scale_factor']).repeat((1, 2)) + results = InstanceData() + results.bboxes = det_bboxes + results.scores = scores + results.labels = det_labels + return results + + def loss(self, hidden_states: Tensor, references: List[Tensor], + memory_text: Tensor, text_token_mask: Tensor, + enc_outputs_class: Tensor, enc_outputs_coord: Tensor, + batch_data_samples: SampleList, dn_meta: Dict[str, int]) -> dict: + """Perform forward propagation and loss calculation of the detection + head on the queries of the upstream network. + + Args: + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, bs, num_queries_total, + dim), where `num_queries_total` is the sum of + `num_denoising_queries` and `num_matching_queries` when + `self.training` is `True`, else `num_matching_queries`. + references (list[Tensor]): List of the reference from the decoder. + The first reference is the `init_reference` (initial) and the + other num_decoder_layers(6) references are `inter_references` + (intermediate). 
The `init_reference` has shape (bs, + num_queries_total, 4) and each `inter_reference` has shape + (bs, num_queries, 4) with the last dimension arranged as + (cx, cy, w, h). + memory_text (Tensor): Memory text. It has shape (bs, len_text, + text_embed_dims). + enc_outputs_class (Tensor): The score of each point on encode + feature map, has shape (bs, num_feat_points, cls_out_channels). + enc_outputs_coord (Tensor): The proposal generate from the + encode feature map, has shape (bs, num_feat_points, 4) with the + last dimension arranged as (cx, cy, w, h). + batch_data_samples (list[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. It will be used for split outputs of + denoising and matching parts and loss calculation. + + Returns: + dict: A dictionary of loss components. + """ + batch_gt_instances = [] + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + + outs = self(hidden_states, references, memory_text, text_token_mask) + self.text_masks = text_token_mask + loss_inputs = outs + (enc_outputs_class, enc_outputs_coord, + batch_gt_instances, batch_img_metas, dn_meta) + losses = self.loss_by_feat(*loss_inputs) + return losses + + def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor, + batch_gt_instances: InstanceList, + batch_img_metas: List[dict]) -> Tuple[Tensor]: + """Loss function for outputs from a single decoder layer of a single + feature level. + + Args: + cls_scores (Tensor): Box score logits from a single decoder layer + for all images, has shape (bs, num_queries, cls_out_channels). + bbox_preds (Tensor): Sigmoid outputs from a single decoder layer + for all images, with normalized coordinate (cx, cy, w, h) and + shape (bs, num_queries, 4). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + Tuple[Tensor]: A tuple including `loss_cls`, `loss_box` and + `loss_iou`. + """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)] + with torch.no_grad(): + cls_reg_targets = self.get_targets(cls_scores_list, + bbox_preds_list, + batch_gt_instances, + batch_img_metas) + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + num_total_pos, num_total_neg) = cls_reg_targets + labels = torch.stack(labels_list, 0) + label_weights = torch.stack(label_weights_list, 0) + bbox_targets = torch.cat(bbox_targets_list, 0) + bbox_weights = torch.cat(bbox_weights_list, 0) + + # ===== this change ===== + # Loss is not computed for the padded regions of the text. 
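+        # Illustration (not executed): with bs=2, len_text=5 and
+        # max_text_len=256, self.text_masks of shape (2, 5) is zero-padded
+        # below to (2, 256), broadcast over the query dimension, and used to
+        # select only the valid text logits from cls_scores.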
+ assert (self.text_masks.dim() == 2) + text_masks = self.text_masks.new_zeros( + (self.text_masks.size(0), self.max_text_len)) + text_masks[:, :self.text_masks.size(1)] = self.text_masks + text_mask = (text_masks > 0).unsqueeze(1) + text_mask = text_mask.repeat(1, cls_scores.size(1), 1) + cls_scores = torch.masked_select(cls_scores, text_mask).contiguous() + + labels = torch.masked_select(labels, text_mask) + label_weights = label_weights[..., + None].repeat(1, 1, text_mask.size(-1)) + label_weights = torch.masked_select(label_weights, text_mask) + + # classification loss + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = num_total_pos * 1.0 + \ + num_total_neg * self.bg_cls_weight + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + cls_scores.new_tensor([cls_avg_factor])) + cls_avg_factor = max(cls_avg_factor, 1) + + if isinstance(self.loss_cls, QualityFocalLoss): + raise NotImplementedError( + 'QualityFocalLoss for GroundingDINOHead is not supported yet.') + else: + loss_cls = self.loss_cls( + cls_scores, labels, label_weights, avg_factor=cls_avg_factor) + + # Compute the average number of gt boxes across all gpus, for + # normalization purposes + num_total_pos = loss_cls.new_tensor([num_total_pos]) + num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item() + + # construct factors used for rescale bboxes + factors = [] + for img_meta, bbox_pred in zip(batch_img_metas, bbox_preds): + img_h, img_w, = img_meta['img_shape'] + factor = bbox_pred.new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0).repeat( + bbox_pred.size(0), 1) + factors.append(factor) + factors = torch.cat(factors, 0) + + # DETR regress the relative position of boxes (cxcywh) in the image, + # thus the learning target is normalized by the image size. So here + # we need to re-scale them for calculating IoU loss + bbox_preds = bbox_preds.reshape(-1, 4) + bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors + bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors + + # regression IoU loss, defaultly GIoU loss + loss_iou = self.loss_iou( + bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos) + + # regression L1 loss + loss_bbox = self.loss_bbox( + bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos) + return loss_cls, loss_bbox, loss_iou + + def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor, + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + dn_meta: Dict[str, int]) -> Tuple[Tensor]: + """Denoising loss for outputs from a single decoder layer. + + Args: + dn_cls_scores (Tensor): Classification scores of a single decoder + layer in denoising part, has shape (bs, num_denoising_queries, + cls_out_channels). + dn_bbox_preds (Tensor): Regression outputs of a single decoder + layer in denoising part. Each is a 4D-tensor with normalized + coordinate format (cx, cy, w, h) and has shape + (bs, num_denoising_queries, 4). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. It will be used for split outputs of + denoising and matching parts and loss calculation. + + Returns: + Tuple[Tensor]: A tuple including `loss_cls`, `loss_box` and + `loss_iou`. 
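+
+        Note:
+            As in ``loss_by_feat_single``, the text padding mask built below
+            only affects the classification term; the L1 and GIoU box terms
+            are unaffected by it.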
+ """ + cls_reg_targets = self.get_dn_targets(batch_gt_instances, + batch_img_metas, dn_meta) + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + num_total_pos, num_total_neg) = cls_reg_targets + labels = torch.stack(labels_list, 0) + label_weights = torch.stack(label_weights_list, 0) + bbox_targets = torch.cat(bbox_targets_list, 0) + bbox_weights = torch.cat(bbox_weights_list, 0) + # ===== this change ===== + # Loss is not computed for the padded regions of the text. + assert (self.text_masks.dim() == 2) + text_masks = self.text_masks.new_zeros( + (self.text_masks.size(0), self.max_text_len)) + text_masks[:, :self.text_masks.size(1)] = self.text_masks + text_mask = (text_masks > 0).unsqueeze(1) + text_mask = text_mask.repeat(1, dn_cls_scores.size(1), 1) + cls_scores = torch.masked_select(dn_cls_scores, text_mask).contiguous() + labels = torch.masked_select(labels, text_mask) + label_weights = label_weights[..., + None].repeat(1, 1, text_mask.size(-1)) + label_weights = torch.masked_select(label_weights, text_mask) + # ======================= + + # classification loss + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = \ + num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + cls_scores.new_tensor([cls_avg_factor])) + cls_avg_factor = max(cls_avg_factor, 1) + + if len(cls_scores) > 0: + if isinstance(self.loss_cls, QualityFocalLoss): + raise NotImplementedError('QualityFocalLoss is not supported') + else: + loss_cls = self.loss_cls( + cls_scores, + labels, + label_weights, + avg_factor=cls_avg_factor) + else: + loss_cls = torch.zeros( + 1, dtype=cls_scores.dtype, device=cls_scores.device) + + # Compute the average number of gt boxes across all gpus, for + # normalization purposes + num_total_pos = loss_cls.new_tensor([num_total_pos]) + num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item() + + # construct factors used for rescale bboxes + factors = [] + for img_meta, bbox_pred in zip(batch_img_metas, dn_bbox_preds): + img_h, img_w = img_meta['img_shape'] + factor = bbox_pred.new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0).repeat( + bbox_pred.size(0), 1) + factors.append(factor) + factors = torch.cat(factors) + + # DETR regress the relative position of boxes (cxcywh) in the image, + # thus the learning target is normalized by the image size. So here + # we need to re-scale them for calculating IoU loss + bbox_preds = dn_bbox_preds.reshape(-1, 4) + bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors + bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors + + # regression IoU loss, defaultly GIoU loss + loss_iou = self.loss_iou( + bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos) + + # regression L1 loss + loss_bbox = self.loss_bbox( + bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos) + return loss_cls, loss_bbox, loss_iou + + def _get_dn_targets_single(self, gt_instances: InstanceData, + img_meta: dict, dn_meta: Dict[str, + int]) -> tuple: + """Get targets in denoising part for one image. + + Args: + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for one image. + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. 
It will be used for split outputs of + denoising and matching parts and loss calculation. + + Returns: + tuple[Tensor]: a tuple containing the following for one image. + + - labels (Tensor): Labels of each image. + - label_weights (Tensor]): Label weights of each image. + - bbox_targets (Tensor): BBox targets of each image. + - bbox_weights (Tensor): BBox weights of each image. + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image. + """ + gt_bboxes = gt_instances.bboxes + gt_labels = gt_instances.labels + num_groups = dn_meta['num_denoising_groups'] + num_denoising_queries = dn_meta['num_denoising_queries'] + num_queries_each_group = int(num_denoising_queries / num_groups) + device = gt_bboxes.device + + if len(gt_labels) > 0: + t = torch.arange(len(gt_labels), dtype=torch.long, device=device) + t = t.unsqueeze(0).repeat(num_groups, 1) + pos_assigned_gt_inds = t.flatten() + pos_inds = torch.arange( + num_groups, dtype=torch.long, device=device) + pos_inds = pos_inds.unsqueeze(1) * num_queries_each_group + t + pos_inds = pos_inds.flatten() + else: + pos_inds = pos_assigned_gt_inds = \ + gt_bboxes.new_tensor([], dtype=torch.long) + + neg_inds = pos_inds + num_queries_each_group // 2 + # label targets + # this change + labels = gt_bboxes.new_full((num_denoising_queries, self.max_text_len), + 0, + dtype=torch.float32) + labels[pos_inds] = gt_instances.positive_maps[pos_assigned_gt_inds] + label_weights = gt_bboxes.new_ones(num_denoising_queries) + + # bbox targets + bbox_targets = torch.zeros(num_denoising_queries, 4, device=device) + bbox_weights = torch.zeros(num_denoising_queries, 4, device=device) + bbox_weights[pos_inds] = 1.0 + img_h, img_w = img_meta['img_shape'] + + # DETR regress the relative position of boxes (cxcywh) in the image. + # Thus the learning target should be normalized by the image size, also + # the box format should be converted from defaultly x1y1x2y2 to cxcywh. + factor = gt_bboxes.new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0) + gt_bboxes_normalized = gt_bboxes / factor + gt_bboxes_targets = bbox_xyxy_to_cxcywh(gt_bboxes_normalized) + bbox_targets[pos_inds] = gt_bboxes_targets.repeat([num_groups, 1]) + + return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, + neg_inds) diff --git a/mmdet/models/dense_heads/guided_anchor_head.py b/mmdet/models/dense_heads/guided_anchor_head.py new file mode 100644 index 0000000000000000000000000000000000000000..59f6dd3336e66065dc88b702e925965d4089c72f --- /dev/null +++ b/mmdet/models/dense_heads/guided_anchor_head.py @@ -0,0 +1,994 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from mmcv.ops import DeformConv2d, MaskedConv2d +from mmengine.model import BaseModule +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType, + OptInstanceList) +from ..layers import multiclass_nms +from ..task_modules.prior_generators import anchor_inside_flags, calc_region +from ..task_modules.samplers import PseudoSampler +from ..utils import images_to_levels, multi_apply, unmap +from .anchor_head import AnchorHead + + +class FeatureAdaption(BaseModule): + """Feature Adaption Module. + + Feature Adaption Module is implemented based on DCN v1. 
+ It uses anchor shape prediction rather than feature map to + predict offsets of deform conv layer. + + Args: + in_channels (int): Number of channels in the input feature map. + out_channels (int): Number of channels in the output feature map. + kernel_size (int): Deformable conv kernel size. Defaults to 3. + deform_groups (int): Deformable conv group size. Defaults to 4. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or \ + list[dict], optional): Initialization config dict. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + deform_groups: int = 4, + init_cfg: MultiConfig = dict( + type='Normal', + layer='Conv2d', + std=0.1, + override=dict(type='Normal', name='conv_adaption', std=0.01)) + ) -> None: + super().__init__(init_cfg=init_cfg) + offset_channels = kernel_size * kernel_size * 2 + self.conv_offset = nn.Conv2d( + 2, deform_groups * offset_channels, 1, bias=False) + self.conv_adaption = DeformConv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + deform_groups=deform_groups) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x: Tensor, shape: Tensor) -> Tensor: + offset = self.conv_offset(shape.detach()) + x = self.relu(self.conv_adaption(x, offset)) + return x + + +@MODELS.register_module() +class GuidedAnchorHead(AnchorHead): + """Guided-Anchor-based head (GA-RPN, GA-RetinaNet, etc.). + + This GuidedAnchorHead will predict high-quality feature guided + anchors and locations where anchors will be kept in inference. + There are mainly 3 categories of bounding-boxes. + + - Sampled 9 pairs for target assignment. (approxes) + - The square boxes where the predicted anchors are based on. (squares) + - Guided anchors. + + Please refer to https://arxiv.org/abs/1901.03278 for more details. + + Args: + num_classes (int): Number of classes. + in_channels (int): Number of channels in the input feature map. + feat_channels (int): Number of hidden channels. Defaults to 256. + approx_anchor_generator (:obj:`ConfigDict` or dict): Config dict + for approx generator + square_anchor_generator (:obj:`ConfigDict` or dict): Config dict + for square generator + anchor_coder (:obj:`ConfigDict` or dict): Config dict for anchor coder + bbox_coder (:obj:`ConfigDict` or dict): Config dict for bbox coder + reg_decoded_bbox (bool): If true, the regression loss would be + applied directly on decoded bounding boxes, converting both + the predicted boxes and regression targets to absolute + coordinates format. Defaults to False. It should be `True` when + using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head. + deform_groups: (int): Group number of DCN in FeatureAdaption module. + Defaults to 4. + loc_filter_thr (float): Threshold to filter out unconcerned regions. + Defaults to 0.01. + loss_loc (:obj:`ConfigDict` or dict): Config of location loss. + loss_shape (:obj:`ConfigDict` or dict): Config of anchor shape loss. + loss_cls (:obj:`ConfigDict` or dict): Config of classification loss. + loss_bbox (:obj:`ConfigDict` or dict): Config of bbox regression loss. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or \ + list[dict], optional): Initialization config dict. 
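+
+    Note:
+        With the default generators below, each feature-map location has
+        ``3 scales x 3 ratios = 9`` approx anchors (``approxs_per_octave``)
+        and one square anchor of scale 8, from which a single guided anchor
+        is predicted.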
+ """ + + def __init__( + self, + num_classes: int, + in_channels: int, + feat_channels: int = 256, + approx_anchor_generator: ConfigType = dict( + type='AnchorGenerator', + octave_base_scale=8, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + square_anchor_generator: ConfigType = dict( + type='AnchorGenerator', + ratios=[1.0], + scales=[8], + strides=[4, 8, 16, 32, 64]), + anchor_coder: ConfigType = dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + bbox_coder: ConfigType = dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + reg_decoded_bbox: bool = False, + deform_groups: int = 4, + loc_filter_thr: float = 0.01, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + loss_loc: ConfigType = dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_shape: ConfigType = dict( + type='BoundedIoULoss', beta=0.2, loss_weight=1.0), + loss_cls: ConfigType = dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox: ConfigType = dict( + type='SmoothL1Loss', beta=1.0, loss_weight=1.0), + init_cfg: MultiConfig = dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', name='conv_loc', std=0.01, lbias_prob=0.01)) + ) -> None: + super(AnchorHead, self).__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.num_classes = num_classes + self.feat_channels = feat_channels + self.deform_groups = deform_groups + self.loc_filter_thr = loc_filter_thr + + # build approx_anchor_generator and square_anchor_generator + assert (approx_anchor_generator['octave_base_scale'] == + square_anchor_generator['scales'][0]) + assert (approx_anchor_generator['strides'] == + square_anchor_generator['strides']) + self.approx_anchor_generator = TASK_UTILS.build( + approx_anchor_generator) + self.square_anchor_generator = TASK_UTILS.build( + square_anchor_generator) + self.approxs_per_octave = self.approx_anchor_generator \ + .num_base_priors[0] + + self.reg_decoded_bbox = reg_decoded_bbox + + # one anchor per location + self.num_base_priors = self.square_anchor_generator.num_base_priors[0] + + self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) + self.loc_focal_loss = loss_loc['type'] in ['FocalLoss'] + if self.use_sigmoid_cls: + self.cls_out_channels = self.num_classes + else: + self.cls_out_channels = self.num_classes + 1 + + # build bbox_coder + self.anchor_coder = TASK_UTILS.build(anchor_coder) + self.bbox_coder = TASK_UTILS.build(bbox_coder) + + # build losses + self.loss_loc = MODELS.build(loss_loc) + self.loss_shape = MODELS.build(loss_shape) + self.loss_cls = MODELS.build(loss_cls) + self.loss_bbox = MODELS.build(loss_bbox) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + if self.train_cfg: + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + # use PseudoSampler when no sampler in train_cfg + if train_cfg.get('sampler', None) is not None: + self.sampler = TASK_UTILS.build( + self.train_cfg['sampler'], default_args=dict(context=self)) + else: + self.sampler = PseudoSampler() + + self.ga_assigner = TASK_UTILS.build(self.train_cfg['ga_assigner']) + if train_cfg.get('ga_sampler', None) is not None: + self.ga_sampler = TASK_UTILS.build( + self.train_cfg['ga_sampler'], + default_args=dict(context=self)) + else: + self.ga_sampler = PseudoSampler() + + self._init_layers() + + def _init_layers(self) -> None: + """Initialize layers of 
the head.""" + self.relu = nn.ReLU(inplace=True) + self.conv_loc = nn.Conv2d(self.in_channels, 1, 1) + self.conv_shape = nn.Conv2d(self.in_channels, self.num_base_priors * 2, + 1) + self.feature_adaption = FeatureAdaption( + self.in_channels, + self.feat_channels, + kernel_size=3, + deform_groups=self.deform_groups) + self.conv_cls = MaskedConv2d( + self.feat_channels, self.num_base_priors * self.cls_out_channels, + 1) + self.conv_reg = MaskedConv2d(self.feat_channels, + self.num_base_priors * 4, 1) + + def forward_single(self, x: Tensor) -> Tuple[Tensor]: + """Forward feature of a single scale level.""" + loc_pred = self.conv_loc(x) + shape_pred = self.conv_shape(x) + x = self.feature_adaption(x, shape_pred) + # masked conv is only used during inference for speed-up + if not self.training: + mask = loc_pred.sigmoid()[0] >= self.loc_filter_thr + else: + mask = None + cls_score = self.conv_cls(x, mask) + bbox_pred = self.conv_reg(x, mask) + return cls_score, bbox_pred, shape_pred, loc_pred + + def forward(self, x: List[Tensor]) -> Tuple[List[Tensor]]: + """Forward features from the upstream network.""" + return multi_apply(self.forward_single, x) + + def get_sampled_approxs(self, + featmap_sizes: List[Tuple[int, int]], + batch_img_metas: List[dict], + device: str = 'cuda') -> tuple: + """Get sampled approxs and inside flags according to feature map sizes. + + Args: + featmap_sizes (list[tuple]): Multi-level feature map sizes. + batch_img_metas (list[dict]): Image meta info. + device (str): device for returned tensors + + Returns: + tuple: approxes of each image, inside flags of each image + """ + num_imgs = len(batch_img_metas) + + # since feature map sizes of all images are the same, we only compute + # approxes for one time + multi_level_approxs = self.approx_anchor_generator.grid_priors( + featmap_sizes, device=device) + approxs_list = [multi_level_approxs for _ in range(num_imgs)] + + # for each image, we compute inside flags of multi level approxes + inside_flag_list = [] + for img_id, img_meta in enumerate(batch_img_metas): + multi_level_flags = [] + multi_level_approxs = approxs_list[img_id] + + # obtain valid flags for each approx first + multi_level_approx_flags = self.approx_anchor_generator \ + .valid_flags(featmap_sizes, + img_meta['pad_shape'], + device=device) + + for i, flags in enumerate(multi_level_approx_flags): + approxs = multi_level_approxs[i] + inside_flags_list = [] + for j in range(self.approxs_per_octave): + split_valid_flags = flags[j::self.approxs_per_octave] + split_approxs = approxs[j::self.approxs_per_octave, :] + inside_flags = anchor_inside_flags( + split_approxs, split_valid_flags, + img_meta['img_shape'][:2], + self.train_cfg['allowed_border']) + inside_flags_list.append(inside_flags) + # inside_flag for a position is true if any anchor in this + # position is true + inside_flags = ( + torch.stack(inside_flags_list, 0).sum(dim=0) > 0) + multi_level_flags.append(inside_flags) + inside_flag_list.append(multi_level_flags) + return approxs_list, inside_flag_list + + def get_anchors(self, + featmap_sizes: List[Tuple[int, int]], + shape_preds: List[Tensor], + loc_preds: List[Tensor], + batch_img_metas: List[dict], + use_loc_filter: bool = False, + device: str = 'cuda') -> tuple: + """Get squares according to feature map sizes and guided anchors. + + Args: + featmap_sizes (list[tuple]): Multi-level feature map sizes. + shape_preds (list[tensor]): Multi-level shape predictions. + loc_preds (list[tensor]): Multi-level location predictions. 
+ batch_img_metas (list[dict]): Image meta info. + use_loc_filter (bool): Use loc filter or not. Defaults to False + device (str): device for returned tensors. + Defaults to `cuda`. + + Returns: + tuple: square approxs of each image, guided anchors of each image, + loc masks of each image. + """ + num_imgs = len(batch_img_metas) + num_levels = len(featmap_sizes) + + # since feature map sizes of all images are the same, we only compute + # squares for one time + multi_level_squares = self.square_anchor_generator.grid_priors( + featmap_sizes, device=device) + squares_list = [multi_level_squares for _ in range(num_imgs)] + + # for each image, we compute multi level guided anchors + guided_anchors_list = [] + loc_mask_list = [] + for img_id, img_meta in enumerate(batch_img_metas): + multi_level_guided_anchors = [] + multi_level_loc_mask = [] + for i in range(num_levels): + squares = squares_list[img_id][i] + shape_pred = shape_preds[i][img_id] + loc_pred = loc_preds[i][img_id] + guided_anchors, loc_mask = self._get_guided_anchors_single( + squares, + shape_pred, + loc_pred, + use_loc_filter=use_loc_filter) + multi_level_guided_anchors.append(guided_anchors) + multi_level_loc_mask.append(loc_mask) + guided_anchors_list.append(multi_level_guided_anchors) + loc_mask_list.append(multi_level_loc_mask) + return squares_list, guided_anchors_list, loc_mask_list + + def _get_guided_anchors_single( + self, + squares: Tensor, + shape_pred: Tensor, + loc_pred: Tensor, + use_loc_filter: bool = False) -> Tuple[Tensor]: + """Get guided anchors and loc masks for a single level. + + Args: + squares (tensor): Squares of a single level. + shape_pred (tensor): Shape predictions of a single level. + loc_pred (tensor): Loc predictions of a single level. + use_loc_filter (list[tensor]): Use loc filter or not. + Defaults to False. + + Returns: + tuple: guided anchors, location masks + """ + # calculate location filtering mask + loc_pred = loc_pred.sigmoid().detach() + if use_loc_filter: + loc_mask = loc_pred >= self.loc_filter_thr + else: + loc_mask = loc_pred >= 0.0 + mask = loc_mask.permute(1, 2, 0).expand(-1, -1, self.num_base_priors) + mask = mask.contiguous().view(-1) + # calculate guided anchors + squares = squares[mask] + anchor_deltas = shape_pred.permute(1, 2, 0).contiguous().view( + -1, 2).detach()[mask] + bbox_deltas = anchor_deltas.new_full(squares.size(), 0) + bbox_deltas[:, 2:] = anchor_deltas + guided_anchors = self.anchor_coder.decode( + squares, bbox_deltas, wh_ratio_clip=1e-6) + return guided_anchors, mask + + def ga_loc_targets(self, batch_gt_instances: InstanceList, + featmap_sizes: List[Tuple[int, int]]) -> tuple: + """Compute location targets for guided anchoring. + + Each feature map is divided into positive, negative and ignore regions. + - positive regions: target 1, weight 1 + - ignore regions: target 0, weight 0 + - negative regions: target 0, weight 0.1 + + Args: + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + featmap_sizes (list[tuple]): Multi level sizes of each feature + maps. + + Returns: + tuple: Returns a tuple containing location targets. + """ + anchor_scale = self.approx_anchor_generator.octave_base_scale + anchor_strides = self.approx_anchor_generator.strides + # Currently only supports same stride in x and y direction. 
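+        # e.g. per-level strides [(4, 4), (8, 8), ...] are reduced to
+        # scalars [4, 8, ...] below.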
+ for stride in anchor_strides: + assert (stride[0] == stride[1]) + anchor_strides = [stride[0] for stride in anchor_strides] + + center_ratio = self.train_cfg['center_ratio'] + ignore_ratio = self.train_cfg['ignore_ratio'] + img_per_gpu = len(batch_gt_instances) + num_lvls = len(featmap_sizes) + r1 = (1 - center_ratio) / 2 + r2 = (1 - ignore_ratio) / 2 + all_loc_targets = [] + all_loc_weights = [] + all_ignore_map = [] + for lvl_id in range(num_lvls): + h, w = featmap_sizes[lvl_id] + loc_targets = torch.zeros( + img_per_gpu, + 1, + h, + w, + device=batch_gt_instances[0].bboxes.device, + dtype=torch.float32) + loc_weights = torch.full_like(loc_targets, -1) + ignore_map = torch.zeros_like(loc_targets) + all_loc_targets.append(loc_targets) + all_loc_weights.append(loc_weights) + all_ignore_map.append(ignore_map) + for img_id in range(img_per_gpu): + gt_bboxes = batch_gt_instances[img_id].bboxes + scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) * + (gt_bboxes[:, 3] - gt_bboxes[:, 1])) + min_anchor_size = scale.new_full( + (1, ), float(anchor_scale * anchor_strides[0])) + # assign gt bboxes to different feature levels w.r.t. their scales + target_lvls = torch.floor( + torch.log2(scale) - torch.log2(min_anchor_size) + 0.5) + target_lvls = target_lvls.clamp(min=0, max=num_lvls - 1).long() + for gt_id in range(gt_bboxes.size(0)): + lvl = target_lvls[gt_id].item() + # rescaled to corresponding feature map + gt_ = gt_bboxes[gt_id, :4] / anchor_strides[lvl] + # calculate ignore regions + ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region( + gt_, r2, featmap_sizes[lvl]) + # calculate positive (center) regions + ctr_x1, ctr_y1, ctr_x2, ctr_y2 = calc_region( + gt_, r1, featmap_sizes[lvl]) + all_loc_targets[lvl][img_id, 0, ctr_y1:ctr_y2 + 1, + ctr_x1:ctr_x2 + 1] = 1 + all_loc_weights[lvl][img_id, 0, ignore_y1:ignore_y2 + 1, + ignore_x1:ignore_x2 + 1] = 0 + all_loc_weights[lvl][img_id, 0, ctr_y1:ctr_y2 + 1, + ctr_x1:ctr_x2 + 1] = 1 + # calculate ignore map on nearby low level feature + if lvl > 0: + d_lvl = lvl - 1 + # rescaled to corresponding feature map + gt_ = gt_bboxes[gt_id, :4] / anchor_strides[d_lvl] + ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region( + gt_, r2, featmap_sizes[d_lvl]) + all_ignore_map[d_lvl][img_id, 0, ignore_y1:ignore_y2 + 1, + ignore_x1:ignore_x2 + 1] = 1 + # calculate ignore map on nearby high level feature + if lvl < num_lvls - 1: + u_lvl = lvl + 1 + # rescaled to corresponding feature map + gt_ = gt_bboxes[gt_id, :4] / anchor_strides[u_lvl] + ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region( + gt_, r2, featmap_sizes[u_lvl]) + all_ignore_map[u_lvl][img_id, 0, ignore_y1:ignore_y2 + 1, + ignore_x1:ignore_x2 + 1] = 1 + for lvl_id in range(num_lvls): + # ignore negative regions w.r.t. ignore map + all_loc_weights[lvl_id][(all_loc_weights[lvl_id] < 0) + & (all_ignore_map[lvl_id] > 0)] = 0 + # set negative regions with weight 0.1 + all_loc_weights[lvl_id][all_loc_weights[lvl_id] < 0] = 0.1 + # loc average factor to balance loss + loc_avg_factor = sum( + [t.size(0) * t.size(-1) * t.size(-2) + for t in all_loc_targets]) / 200 + return all_loc_targets, all_loc_weights, loc_avg_factor + + def _ga_shape_target_single(self, + flat_approxs: Tensor, + inside_flags: Tensor, + flat_squares: Tensor, + gt_instances: InstanceData, + gt_instances_ignore: Optional[InstanceData], + img_meta: dict, + unmap_outputs: bool = True) -> tuple: + """Compute guided anchoring targets. 
+ + This function returns sampled anchors and gt bboxes directly + rather than calculates regression targets. + + Args: + flat_approxs (Tensor): flat approxs of a single image, + shape (n, 4) + inside_flags (Tensor): inside flags of a single image, + shape (n, ). + flat_squares (Tensor): flat squares of a single image, + shape (approxs_per_octave * n, 4) + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes`` and ``labels`` + attributes. + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` attribute + data that is ignored during training and testing. + img_meta (dict): Meta info of a single image. + unmap_outputs (bool): unmap outputs or not. + + Returns: + tuple: Returns a tuple containing shape targets of each image. + """ + if not inside_flags.any(): + raise ValueError( + 'There is no valid anchor inside the image boundary. Please ' + 'check the image size and anchor sizes, or set ' + '``allowed_border`` to -1 to skip the condition.') + # assign gt and sample anchors + num_square = flat_squares.size(0) + approxs = flat_approxs.view(num_square, self.approxs_per_octave, 4) + approxs = approxs[inside_flags, ...] + squares = flat_squares[inside_flags, :] + + pred_instances = InstanceData() + pred_instances.priors = squares + pred_instances.approxs = approxs + + assign_result = self.ga_assigner.assign( + pred_instances=pred_instances, + gt_instances=gt_instances, + gt_instances_ignore=gt_instances_ignore) + sampling_result = self.ga_sampler.sample( + assign_result=assign_result, + pred_instances=pred_instances, + gt_instances=gt_instances) + + bbox_anchors = torch.zeros_like(squares) + bbox_gts = torch.zeros_like(squares) + bbox_weights = torch.zeros_like(squares) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + bbox_anchors[pos_inds, :] = sampling_result.pos_bboxes + bbox_gts[pos_inds, :] = sampling_result.pos_gt_bboxes + bbox_weights[pos_inds, :] = 1.0 + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_squares.size(0) + bbox_anchors = unmap(bbox_anchors, num_total_anchors, inside_flags) + bbox_gts = unmap(bbox_gts, num_total_anchors, inside_flags) + bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) + + return (bbox_anchors, bbox_gts, bbox_weights, pos_inds, neg_inds, + sampling_result) + + def ga_shape_targets(self, + approx_list: List[List[Tensor]], + inside_flag_list: List[List[Tensor]], + square_list: List[List[Tensor]], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None, + unmap_outputs: bool = True) -> tuple: + """Compute guided anchoring targets. + + Args: + approx_list (list[list[Tensor]]): Multi level approxs of each + image. + inside_flag_list (list[list[Tensor]]): Multi level inside flags + of each image. + square_list (list[list[Tensor]]): Multi level squares of each + image. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): unmap outputs or not. Defaults to None. 
+ + Returns: + tuple: Returns a tuple containing shape targets. + """ + num_imgs = len(batch_img_metas) + assert len(approx_list) == len(inside_flag_list) == len( + square_list) == num_imgs + # anchor number of multi levels + num_level_squares = [squares.size(0) for squares in square_list[0]] + # concat all level anchors and flags to a single tensor + inside_flag_flat_list = [] + approx_flat_list = [] + square_flat_list = [] + for i in range(num_imgs): + assert len(square_list[i]) == len(inside_flag_list[i]) + inside_flag_flat_list.append(torch.cat(inside_flag_list[i])) + approx_flat_list.append(torch.cat(approx_list[i])) + square_flat_list.append(torch.cat(square_list[i])) + + # compute targets for each image + if batch_gt_instances_ignore is None: + batch_gt_instances_ignore = [None for _ in range(num_imgs)] + (all_bbox_anchors, all_bbox_gts, all_bbox_weights, pos_inds_list, + neg_inds_list, sampling_results_list) = multi_apply( + self._ga_shape_target_single, + approx_flat_list, + inside_flag_flat_list, + square_flat_list, + batch_gt_instances, + batch_gt_instances_ignore, + batch_img_metas, + unmap_outputs=unmap_outputs) + # sampled anchors of all images + avg_factor = sum( + [results.avg_factor for results in sampling_results_list]) + # split targets to a list w.r.t. multiple levels + bbox_anchors_list = images_to_levels(all_bbox_anchors, + num_level_squares) + bbox_gts_list = images_to_levels(all_bbox_gts, num_level_squares) + bbox_weights_list = images_to_levels(all_bbox_weights, + num_level_squares) + return (bbox_anchors_list, bbox_gts_list, bbox_weights_list, + avg_factor) + + def loss_shape_single(self, shape_pred: Tensor, bbox_anchors: Tensor, + bbox_gts: Tensor, anchor_weights: Tensor, + avg_factor: int) -> Tensor: + """Compute shape loss in single level.""" + shape_pred = shape_pred.permute(0, 2, 3, 1).contiguous().view(-1, 2) + bbox_anchors = bbox_anchors.contiguous().view(-1, 4) + bbox_gts = bbox_gts.contiguous().view(-1, 4) + anchor_weights = anchor_weights.contiguous().view(-1, 4) + bbox_deltas = bbox_anchors.new_full(bbox_anchors.size(), 0) + bbox_deltas[:, 2:] += shape_pred + # filter out negative samples to speed-up weighted_bounded_iou_loss + inds = torch.nonzero( + anchor_weights[:, 0] > 0, as_tuple=False).squeeze(1) + bbox_deltas_ = bbox_deltas[inds] + bbox_anchors_ = bbox_anchors[inds] + bbox_gts_ = bbox_gts[inds] + anchor_weights_ = anchor_weights[inds] + pred_anchors_ = self.anchor_coder.decode( + bbox_anchors_, bbox_deltas_, wh_ratio_clip=1e-6) + loss_shape = self.loss_shape( + pred_anchors_, bbox_gts_, anchor_weights_, avg_factor=avg_factor) + return loss_shape + + def loss_loc_single(self, loc_pred: Tensor, loc_target: Tensor, + loc_weight: Tensor, avg_factor: float) -> Tensor: + """Compute location loss in single level.""" + loss_loc = self.loss_loc( + loc_pred.reshape(-1, 1), + loc_target.reshape(-1).long(), + loc_weight.reshape(-1), + avg_factor=avg_factor) + return loss_loc + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + shape_preds: List[Tensor], + loc_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + has shape (N, num_anchors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). 
+ shape_preds (list[Tensor]): shape predictions for each scale + level with shape (N, 1, H, W). + loc_preds (list[Tensor]): location predictions for each scale + level with shape (N, num_anchors * 2, H, W). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict: A dictionary of loss components. + """ + + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.approx_anchor_generator.num_levels + + device = cls_scores[0].device + + # get loc targets + loc_targets, loc_weights, loc_avg_factor = self.ga_loc_targets( + batch_gt_instances, featmap_sizes) + + # get sampled approxes + approxs_list, inside_flag_list = self.get_sampled_approxs( + featmap_sizes, batch_img_metas, device=device) + # get squares and guided anchors + squares_list, guided_anchors_list, _ = self.get_anchors( + featmap_sizes, + shape_preds, + loc_preds, + batch_img_metas, + device=device) + + # get shape targets + shape_targets = self.ga_shape_targets(approxs_list, inside_flag_list, + squares_list, batch_gt_instances, + batch_img_metas) + (bbox_anchors_list, bbox_gts_list, anchor_weights_list, + ga_avg_factor) = shape_targets + + # get anchor targets + cls_reg_targets = self.get_targets( + guided_anchors_list, + inside_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + avg_factor) = cls_reg_targets + + # anchor number of multi levels + num_level_anchors = [ + anchors.size(0) for anchors in guided_anchors_list[0] + ] + # concat all level anchors to a single tensor + concat_anchor_list = [] + for i in range(len(guided_anchors_list)): + concat_anchor_list.append(torch.cat(guided_anchors_list[i])) + all_anchor_list = images_to_levels(concat_anchor_list, + num_level_anchors) + + # get classification and bbox regression losses + losses_cls, losses_bbox = multi_apply( + self.loss_by_feat_single, + cls_scores, + bbox_preds, + all_anchor_list, + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + avg_factor=avg_factor) + + # get anchor location loss + losses_loc = [] + for i in range(len(loc_preds)): + loss_loc = self.loss_loc_single( + loc_preds[i], + loc_targets[i], + loc_weights[i], + avg_factor=loc_avg_factor) + losses_loc.append(loss_loc) + + # get anchor shape loss + losses_shape = [] + for i in range(len(shape_preds)): + loss_shape = self.loss_shape_single( + shape_preds[i], + bbox_anchors_list[i], + bbox_gts_list[i], + anchor_weights_list[i], + avg_factor=ga_avg_factor) + losses_shape.append(loss_shape) + + return dict( + loss_cls=losses_cls, + loss_bbox=losses_bbox, + loss_shape=losses_shape, + loss_loc=losses_loc) + + def predict_by_feat(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + shape_preds: List[Tensor], + loc_preds: List[Tensor], + batch_img_metas: List[dict], + cfg: OptConfigType = None, + rescale: bool = False) -> InstanceList: + """Transform a batch of output features extracted from the head into + bbox results. 
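+
+        During inference the location filter is enabled
+        (``use_loc_filter=not self.training``), so only positions whose
+        predicted objectness reaches ``loc_filter_thr`` keep guided anchors.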
+ + Args: + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + shape_preds (list[Tensor]): shape predictions for each scale + level with shape (N, 1, H, W). + loc_preds (list[Tensor]): location predictions for each scale + level with shape (N, num_anchors * 2, H, W). + batch_img_metas (list[dict], Optional): Batch image meta info. + Defaults to None. + cfg (ConfigDict, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Object detection results of each image + after the post process. Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), the last + dimension 4 arrange as (x1, y1, x2, y2). + """ + assert len(cls_scores) == len(bbox_preds) == len(shape_preds) == len( + loc_preds) + num_levels = len(cls_scores) + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + device = cls_scores[0].device + # get guided anchors + _, guided_anchors, loc_masks = self.get_anchors( + featmap_sizes, + shape_preds, + loc_preds, + batch_img_metas, + use_loc_filter=not self.training, + device=device) + result_list = [] + for img_id in range(len(batch_img_metas)): + cls_score_list = [ + cls_scores[i][img_id].detach() for i in range(num_levels) + ] + bbox_pred_list = [ + bbox_preds[i][img_id].detach() for i in range(num_levels) + ] + guided_anchor_list = [ + guided_anchors[img_id][i].detach() for i in range(num_levels) + ] + loc_mask_list = [ + loc_masks[img_id][i].detach() for i in range(num_levels) + ] + proposals = self._predict_by_feat_single( + cls_scores=cls_score_list, + bbox_preds=bbox_pred_list, + mlvl_anchors=guided_anchor_list, + mlvl_masks=loc_mask_list, + img_meta=batch_img_metas[img_id], + cfg=cfg, + rescale=rescale) + result_list.append(proposals) + return result_list + + def _predict_by_feat_single(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + mlvl_anchors: List[Tensor], + mlvl_masks: List[Tensor], + img_meta: dict, + cfg: ConfigType, + rescale: bool = False) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + cls_scores (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + mlvl_anchors (list[Tensor]): Each element in the list is + the anchors of a single level in feature pyramid. it has + shape (num_priors, 4). + mlvl_masks (list[Tensor]): Each element in the list is location + masks of a single level. + img_meta (dict): Image meta info. + cfg (:obj:`ConfigDict` or dict): Test / postprocessing + configuration, if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. 
+ + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), the last + dimension 4 arrange as (x1, y1, x2, y2). + """ + cfg = self.test_cfg if cfg is None else cfg + assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors) + mlvl_bbox_preds = [] + mlvl_valid_anchors = [] + mlvl_scores = [] + for cls_score, bbox_pred, anchors, mask in zip(cls_scores, bbox_preds, + mlvl_anchors, + mlvl_masks): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + # if no location is kept, end. + if mask.sum() == 0: + continue + # reshape scores and bbox_pred + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + if self.use_sigmoid_cls: + scores = cls_score.sigmoid() + else: + scores = cls_score.softmax(-1) + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) + # filter scores, bbox_pred w.r.t. mask. + # anchors are filtered in get_anchors() beforehand. + scores = scores[mask, :] + bbox_pred = bbox_pred[mask, :] + if scores.dim() == 0: + anchors = anchors.unsqueeze(0) + scores = scores.unsqueeze(0) + bbox_pred = bbox_pred.unsqueeze(0) + # filter anchors, bbox_pred, scores w.r.t. scores + nms_pre = cfg.get('nms_pre', -1) + if nms_pre > 0 and scores.shape[0] > nms_pre: + if self.use_sigmoid_cls: + max_scores, _ = scores.max(dim=1) + else: + # remind that we set FG labels to [0, num_class-1] + # since mmdet v2.0 + # BG cat_id: num_class + max_scores, _ = scores[:, :-1].max(dim=1) + _, topk_inds = max_scores.topk(nms_pre) + anchors = anchors[topk_inds, :] + bbox_pred = bbox_pred[topk_inds, :] + scores = scores[topk_inds, :] + + mlvl_bbox_preds.append(bbox_pred) + mlvl_valid_anchors.append(anchors) + mlvl_scores.append(scores) + + mlvl_bbox_preds = torch.cat(mlvl_bbox_preds) + mlvl_anchors = torch.cat(mlvl_valid_anchors) + mlvl_scores = torch.cat(mlvl_scores) + mlvl_bboxes = self.bbox_coder.decode( + mlvl_anchors, mlvl_bbox_preds, max_shape=img_meta['img_shape']) + + if rescale: + assert img_meta.get('scale_factor') is not None + mlvl_bboxes /= mlvl_bboxes.new_tensor( + img_meta['scale_factor']).repeat((1, 2)) + + if self.use_sigmoid_cls: + # Add a dummy background class to the backend when using sigmoid + # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 + # BG cat_id: num_class + padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) + mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) + # multi class NMS + det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores, + cfg.score_thr, cfg.nms, + cfg.max_per_img) + + results = InstanceData() + results.bboxes = det_bboxes[:, :-1] + results.scores = det_bboxes[:, -1] + results.labels = det_labels + return results diff --git a/mmdet/models/dense_heads/lad_head.py b/mmdet/models/dense_heads/lad_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d1218e1f88206704d4f414d151ccd34a189ac5d0 --- /dev/null +++ b/mmdet/models/dense_heads/lad_head.py @@ -0,0 +1,226 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
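+# A minimal usage sketch (hypothetical ``teacher``/``student`` detectors,
+# not part of this file) of the distillation flow implemented below:
+#
+#     with torch.no_grad():
+#         outs_teacher = teacher.bbox_head(teacher.extract_feat(inputs))
+#         la_results = teacher.bbox_head.get_label_assignment(
+#             *outs_teacher, batch_gt_instances, batch_img_metas)
+#     x = student.extract_feat(inputs)
+#     losses = student.bbox_head.loss(x, la_results, batch_data_samples)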
+from typing import List, Optional
+
+import torch
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.structures import SampleList
+from mmdet.structures.bbox import bbox_overlaps
+from mmdet.utils import InstanceList, OptInstanceList
+from ..utils import levels_to_images, multi_apply, unpack_gt_instances
+from .paa_head import PAAHead
+
+
+@MODELS.register_module()
+class LADHead(PAAHead):
+    """Label Assignment Head from the paper: `Improving Object Detection by
+    Label Assignment Distillation <https://arxiv.org/abs/2108.10520>`_"""
+
+    def get_label_assignment(
+            self,
+            cls_scores: List[Tensor],
+            bbox_preds: List[Tensor],
+            iou_preds: List[Tensor],
+            batch_gt_instances: InstanceList,
+            batch_img_metas: List[dict],
+            batch_gt_instances_ignore: OptInstanceList = None) -> tuple:
+        """Get label assignment (from teacher).
+
+        Args:
+            cls_scores (list[Tensor]): Box scores for each scale level.
+                Has shape (N, num_anchors * num_classes, H, W).
+            bbox_preds (list[Tensor]): Box energies / deltas for each scale
+                level with shape (N, num_anchors * 4, H, W).
+            iou_preds (list[Tensor]): IoU predictions for each scale
+                level with shape (N, num_anchors * 1, H, W).
+            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+                gt_instance. It usually includes ``bboxes`` and ``labels``
+                attributes.
+            batch_img_metas (list[dict]): Meta information of each image,
+                e.g., image size, scaling factor, etc.
+            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
+                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
+                Defaults to None.
+
+        Returns:
+            tuple: Returns a tuple containing label assignment variables.
+
+            - labels (Tensor): Labels of all anchors, each with
+              shape (num_anchors,).
+            - labels_weight (Tensor): Label weights of all anchors, each
+              with shape (num_anchors,).
+            - bboxes_target (Tensor): BBox targets of all anchors, each
+              with shape (num_anchors, 4).
+            - bboxes_weight (Tensor): BBox weights of all anchors, each
+              with shape (num_anchors, 4).
+            - pos_inds_flatten (Tensor): Indices of all positive samples
+              over all anchors.
+            - pos_anchors (Tensor): Positive anchors.
+            - num_pos (int): Number of positive anchors.
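+
+        Example:
+            Illustrative sketch only; ``teacher_head``, ``student_head``,
+            the teacher's multi-level outputs and ``batch_data_samples``
+            are assumed to exist:
+
+            >>> label_assignment_results = teacher_head.get_label_assignment(
+            ...     cls_scores, bbox_preds, iou_preds,
+            ...     batch_gt_instances, batch_img_metas)
+            >>> losses = student_head.loss(x, label_assignment_results,
+            ...                            batch_data_samples)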
+ """ + + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore, + ) + (labels, labels_weight, bboxes_target, bboxes_weight, pos_inds, + pos_gt_index) = cls_reg_targets + cls_scores = levels_to_images(cls_scores) + cls_scores = [ + item.reshape(-1, self.cls_out_channels) for item in cls_scores + ] + bbox_preds = levels_to_images(bbox_preds) + bbox_preds = [item.reshape(-1, 4) for item in bbox_preds] + pos_losses_list, = multi_apply(self.get_pos_loss, anchor_list, + cls_scores, bbox_preds, labels, + labels_weight, bboxes_target, + bboxes_weight, pos_inds) + + with torch.no_grad(): + reassign_labels, reassign_label_weight, \ + reassign_bbox_weights, num_pos = multi_apply( + self.paa_reassign, + pos_losses_list, + labels, + labels_weight, + bboxes_weight, + pos_inds, + pos_gt_index, + anchor_list) + num_pos = sum(num_pos) + # convert all tensor list to a flatten tensor + labels = torch.cat(reassign_labels, 0).view(-1) + flatten_anchors = torch.cat( + [torch.cat(item, 0) for item in anchor_list]) + labels_weight = torch.cat(reassign_label_weight, 0).view(-1) + bboxes_target = torch.cat(bboxes_target, + 0).view(-1, bboxes_target[0].size(-1)) + + pos_inds_flatten = ((labels >= 0) + & + (labels < self.num_classes)).nonzero().reshape(-1) + + if num_pos: + pos_anchors = flatten_anchors[pos_inds_flatten] + else: + pos_anchors = None + + label_assignment_results = (labels, labels_weight, bboxes_target, + bboxes_weight, pos_inds_flatten, + pos_anchors, num_pos) + return label_assignment_results + + def loss(self, x: List[Tensor], label_assignment_results: tuple, + batch_data_samples: SampleList) -> dict: + """Forward train with the available label assignment (student receives + from teacher). + + Args: + x (list[Tensor]): Features from FPN. + label_assignment_results (tuple): As the outputs defined in the + function `self.get_label_assignment`. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + losses: (dict[str, Tensor]): A dictionary of loss components. + """ + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \ + = outputs + + outs = self(x) + loss_inputs = outs + (batch_gt_instances, batch_img_metas) + losses = self.loss_by_feat( + *loss_inputs, + batch_gt_instances_ignore=batch_gt_instances_ignore, + label_assignment_results=label_assignment_results) + return losses + + def loss_by_feat(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + iou_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None, + label_assignment_results: Optional[tuple] = None) -> dict: + """Compute losses of the head. 
+
+        Args:
+            cls_scores (list[Tensor]): Box scores for each scale level.
+                Has shape (N, num_anchors * num_classes, H, W).
+            bbox_preds (list[Tensor]): Box energies / deltas for each scale
+                level with shape (N, num_anchors * 4, H, W).
+            iou_preds (list[Tensor]): IoU predictions for each scale
+                level with shape (N, num_anchors * 1, H, W).
+            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+                gt_instance. It usually includes ``bboxes`` and ``labels``
+                attributes.
+            batch_img_metas (list[dict]): Meta information of each image,
+                e.g., image size, scaling factor, etc.
+            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
+                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
+                Defaults to None.
+            label_assignment_results (tuple, optional): As the outputs defined
+                in the function `self.get_label_assignment`.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+
+        (labels, labels_weight, bboxes_target, bboxes_weight, pos_inds_flatten,
+         pos_anchors, num_pos) = label_assignment_results
+
+        cls_scores = levels_to_images(cls_scores)
+        cls_scores = [
+            item.reshape(-1, self.cls_out_channels) for item in cls_scores
+        ]
+        bbox_preds = levels_to_images(bbox_preds)
+        bbox_preds = [item.reshape(-1, 4) for item in bbox_preds]
+        iou_preds = levels_to_images(iou_preds)
+        iou_preds = [item.reshape(-1, 1) for item in iou_preds]
+
+        # convert all tensor list to a flatten tensor
+        cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1))
+        bbox_preds = torch.cat(bbox_preds, 0).view(-1, bbox_preds[0].size(-1))
+        iou_preds = torch.cat(iou_preds, 0).view(-1, iou_preds[0].size(-1))
+
+        losses_cls = self.loss_cls(
+            cls_scores,
+            labels,
+            labels_weight,
+            avg_factor=max(num_pos, len(batch_img_metas)))  # avoid num_pos=0
+        if num_pos:
+            pos_bbox_pred = self.bbox_coder.decode(
+                pos_anchors, bbox_preds[pos_inds_flatten])
+            pos_bbox_target = bboxes_target[pos_inds_flatten]
+            iou_target = bbox_overlaps(
+                pos_bbox_pred.detach(), pos_bbox_target, is_aligned=True)
+            losses_iou = self.loss_centerness(
+                iou_preds[pos_inds_flatten],
+                iou_target.unsqueeze(-1),
+                avg_factor=num_pos)
+            losses_bbox = self.loss_bbox(
+                pos_bbox_pred, pos_bbox_target, avg_factor=num_pos)
+        else:
+            losses_iou = iou_preds.sum() * 0
+            losses_bbox = bbox_preds.sum() * 0
+
+        return dict(
+            loss_cls=losses_cls, loss_bbox=losses_bbox, loss_iou=losses_iou)
diff --git a/mmdet/models/dense_heads/ld_head.py b/mmdet/models/dense_heads/ld_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..2558fac97ee26ff89c5fa1b386f5ce68c3ad384d
--- /dev/null
+++ b/mmdet/models/dense_heads/ld_head.py
@@ -0,0 +1,257 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple
+
+import torch
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.structures import SampleList
+from mmdet.structures.bbox import bbox_overlaps
+from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean
+from ..utils import multi_apply, unpack_gt_instances
+from .gfl_head import GFLHead
+
+
+@MODELS.register_module()
+class LDHead(GFLHead):
+    """Localization distillation Head.
+
+    It utilizes the learned bbox distributions to transfer the localization
+    dark knowledge from teacher to student. Original paper: `Localization
+    Distillation for Object Detection. <https://arxiv.org/abs/2102.12252>`_
+
+    Args:
+        num_classes (int): Number of categories excluding the background
+            category.
+        in_channels (int): Number of channels in the input feature map.
+        loss_ld (:obj:`ConfigDict` or dict): Config of Localization
+            Distillation Loss (LD), T is the temperature for distillation.
+    """
+
+    def __init__(self,
+                 num_classes: int,
+                 in_channels: int,
+                 loss_ld: ConfigType = dict(
+                     type='LocalizationDistillationLoss',
+                     loss_weight=0.25,
+                     T=10),
+                 **kwargs) -> None:
+
+        super().__init__(
+            num_classes=num_classes, in_channels=in_channels, **kwargs)
+        self.loss_ld = MODELS.build(loss_ld)
+
+    def loss_by_feat_single(self, anchors: Tensor, cls_score: Tensor,
+                            bbox_pred: Tensor, labels: Tensor,
+                            label_weights: Tensor, bbox_targets: Tensor,
+                            stride: Tuple[int], soft_targets: Tensor,
+                            avg_factor: int):
+        """Calculate the loss of a single scale level based on the features
+        extracted by the detection head.
+
+        Args:
+            anchors (Tensor): Box reference for each scale level with shape
+                (N, num_total_anchors, 4).
+            cls_score (Tensor): Cls and quality joint scores for each scale
+                level has shape (N, num_classes, H, W).
+            bbox_pred (Tensor): Box distribution logits for each scale
+                level with shape (N, 4*(n+1), H, W), n is max value of integral
+                set.
+            labels (Tensor): Labels of each anchor with shape
+                (N, num_total_anchors).
+            label_weights (Tensor): Label weights of each anchor with shape
+                (N, num_total_anchors).
+            bbox_targets (Tensor): BBox regression targets of each anchor with
+                shape (N, num_total_anchors, 4).
+            stride (tuple): Stride in this scale level.
+            soft_targets (Tensor): Soft BBox regression targets.
+            avg_factor (int): Average factor that is used to average
+                the loss. When using sampling method, avg_factor is usually
+                the sum of positive and negative priors. When using
+                `PseudoSampler`, `avg_factor` is usually equal to the number
+                of positive priors.
+
+        Returns:
+            tuple[Tensor]: Loss components and weight targets.
+        """
+        assert stride[0] == stride[1], 'h stride is not equal to w stride!'
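+        # The reshapes below flatten per-level maps into per-anchor rows:
+        # `bbox_pred` and `soft_targets` each carry 4 * (reg_max + 1) logits
+        # per anchor, i.e. one discrete distribution over integral offsets
+        # for every box side (as in GFL). Positive anchors are decoded with
+        # `self.integral` for the IoU-based regression loss, while the LD
+        # loss distills the teacher's `soft_targets` distributions into the
+        # student's `pred_corners`.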
+ anchors = anchors.reshape(-1, 4) + cls_score = cls_score.permute(0, 2, 3, + 1).reshape(-1, self.cls_out_channels) + bbox_pred = bbox_pred.permute(0, 2, 3, + 1).reshape(-1, 4 * (self.reg_max + 1)) + soft_targets = soft_targets.permute(0, 2, 3, + 1).reshape(-1, + 4 * (self.reg_max + 1)) + + bbox_targets = bbox_targets.reshape(-1, 4) + labels = labels.reshape(-1) + label_weights = label_weights.reshape(-1) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().squeeze(1) + score = label_weights.new_zeros(labels.shape) + + if len(pos_inds) > 0: + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + pos_anchors = anchors[pos_inds] + pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0] + + weight_targets = cls_score.detach().sigmoid() + weight_targets = weight_targets.max(dim=1)[0][pos_inds] + pos_bbox_pred_corners = self.integral(pos_bbox_pred) + pos_decode_bbox_pred = self.bbox_coder.decode( + pos_anchor_centers, pos_bbox_pred_corners) + pos_decode_bbox_targets = pos_bbox_targets / stride[0] + score[pos_inds] = bbox_overlaps( + pos_decode_bbox_pred.detach(), + pos_decode_bbox_targets, + is_aligned=True) + pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1) + pos_soft_targets = soft_targets[pos_inds] + soft_corners = pos_soft_targets.reshape(-1, self.reg_max + 1) + + target_corners = self.bbox_coder.encode(pos_anchor_centers, + pos_decode_bbox_targets, + self.reg_max).reshape(-1) + + # regression loss + loss_bbox = self.loss_bbox( + pos_decode_bbox_pred, + pos_decode_bbox_targets, + weight=weight_targets, + avg_factor=1.0) + + # dfl loss + loss_dfl = self.loss_dfl( + pred_corners, + target_corners, + weight=weight_targets[:, None].expand(-1, 4).reshape(-1), + avg_factor=4.0) + + # ld loss + loss_ld = self.loss_ld( + pred_corners, + soft_corners, + weight=weight_targets[:, None].expand(-1, 4).reshape(-1), + avg_factor=4.0) + + else: + loss_ld = bbox_pred.sum() * 0 + loss_bbox = bbox_pred.sum() * 0 + loss_dfl = bbox_pred.sum() * 0 + weight_targets = bbox_pred.new_tensor(0) + + # cls (qfl) loss + loss_cls = self.loss_cls( + cls_score, (labels, score), + weight=label_weights, + avg_factor=avg_factor) + + return loss_cls, loss_bbox, loss_dfl, loss_ld, weight_targets.sum() + + def loss(self, x: List[Tensor], out_teacher: Tuple[Tensor], + batch_data_samples: SampleList) -> dict: + """ + Args: + x (list[Tensor]): Features from FPN. + out_teacher (tuple[Tensor]): The output of teacher. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + tuple[dict, list]: The loss components and proposals of each image. + + - losses (dict[str, Tensor]): A dictionary of loss components. + - proposal_list (list[Tensor]): Proposals of each image. 
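+
+        Example:
+            Illustrative sketch only; ``teacher_model`` and the student-side
+            inputs are assumed to exist:
+
+            >>> with torch.no_grad():
+            ...     out_teacher = teacher_model.bbox_head(teacher_x)
+            >>> losses = self.loss(x, out_teacher, batch_data_samples)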
+ """ + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \ + = outputs + + outs = self(x) + soft_targets = out_teacher[1] + loss_inputs = outs + (batch_gt_instances, batch_img_metas, + soft_targets) + losses = self.loss_by_feat( + *loss_inputs, batch_gt_instances_ignore=batch_gt_instances_ignore) + + return losses + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + soft_targets: List[Tensor], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Cls and quality scores for each scale + level has shape (N, num_classes, H, W). + bbox_preds (list[Tensor]): Box distribution logits for each scale + level with shape (N, 4*(n+1), H, W), n is max value of integral + set. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + soft_targets (list[Tensor]): Soft BBox regression targets. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + + (anchor_list, labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, avg_factor) = cls_reg_targets + + avg_factor = reduce_mean( + torch.tensor(avg_factor, dtype=torch.float, device=device)).item() + + losses_cls, losses_bbox, losses_dfl, losses_ld, \ + avg_factor = multi_apply( + self.loss_by_feat_single, + anchor_list, + cls_scores, + bbox_preds, + labels_list, + label_weights_list, + bbox_targets_list, + self.prior_generator.strides, + soft_targets, + avg_factor=avg_factor) + + avg_factor = sum(avg_factor) + 1e-6 + avg_factor = reduce_mean(avg_factor).item() + losses_bbox = [x / avg_factor for x in losses_bbox] + losses_dfl = [x / avg_factor for x in losses_dfl] + return dict( + loss_cls=losses_cls, + loss_bbox=losses_bbox, + loss_dfl=losses_dfl, + loss_ld=losses_ld) diff --git a/mmdet/models/dense_heads/mask2former_head.py b/mmdet/models/dense_heads/mask2former_head.py new file mode 100644 index 0000000000000000000000000000000000000000..12d47c655255f92819646b8ea304b9736ec30660 --- /dev/null +++ b/mmdet/models/dense_heads/mask2former_head.py @@ -0,0 +1,459 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
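+#
+# Note on the flow implemented below: the head refines a fixed set of query
+# embeddings through masked-attention decoder layers. After every layer,
+# `_forward_head` maps the queries to per-query class logits and mask
+# logits, and the thresholded mask (`sigmoid() < 0.5`) becomes the attention
+# mask that restricts cross-attention at the next layer.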
+import copy
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import Conv2d
+from mmcv.ops import point_sample
+from mmengine.model import ModuleList, caffe2_xavier_init
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmdet.registry import MODELS, TASK_UTILS
+from mmdet.structures import SampleList
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig, reduce_mean
+from ..layers import Mask2FormerTransformerDecoder, SinePositionalEncoding
+from ..utils import get_uncertain_point_coords_with_randomness
+from .anchor_free_head import AnchorFreeHead
+from .maskformer_head import MaskFormerHead
+
+
+@MODELS.register_module()
+class Mask2FormerHead(MaskFormerHead):
+    """Implements the Mask2Former head.
+
+    See `Masked-attention Mask Transformer for Universal Image
+    Segmentation <https://arxiv.org/abs/2112.01527>`_ for details.
+
+    Args:
+        in_channels (list[int]): Number of channels in the input feature map.
+        feat_channels (int): Number of channels for features.
+        out_channels (int): Number of channels for output.
+        num_things_classes (int): Number of things.
+        num_stuff_classes (int): Number of stuff.
+        num_queries (int): Number of queries in Transformer decoder.
+        pixel_decoder (:obj:`ConfigDict` or dict): Config for pixel
+            decoder. Defaults to None.
+        enforce_decoder_input_project (bool, optional): Whether to add
+            a layer to change the embed_dim of transformer encoder in
+            pixel decoder to the embed_dim of transformer decoder.
+            Defaults to False.
+        transformer_decoder (:obj:`ConfigDict` or dict): Config for
+            transformer decoder. Defaults to None.
+        positional_encoding (:obj:`ConfigDict` or dict): Config for
+            transformer decoder position encoding. Defaults to
+            dict(num_feats=128, normalize=True).
+        loss_cls (:obj:`ConfigDict` or dict): Config of the classification
+            loss. Defaults to None.
+        loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss.
+            Defaults to None.
+        loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss.
+            Defaults to None.
+        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
+            Mask2Former head.
+        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
+            Mask2Former head.
+        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
+            dict], optional): Initialization config dict. Defaults to None.
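+
+    Example:
+        Illustrative only; ``head``, FPN features ``x`` and
+        ``batch_data_samples`` are assumed to exist:
+
+        >>> cls_pred_list, mask_pred_list = head(x, batch_data_samples)
+        >>> # one prediction ahead of the decoder plus one per decoder layer
+        >>> len(cls_pred_list) == head.num_transformer_decoder_layers + 1
+        True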
+ """ + + def __init__(self, + in_channels: List[int], + feat_channels: int, + out_channels: int, + num_things_classes: int = 80, + num_stuff_classes: int = 53, + num_queries: int = 100, + num_transformer_feat_level: int = 3, + pixel_decoder: ConfigType = ..., + enforce_decoder_input_project: bool = False, + transformer_decoder: ConfigType = ..., + positional_encoding: ConfigType = dict( + num_feats=128, normalize=True), + loss_cls: ConfigType = dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 133 + [0.1]), + loss_mask: ConfigType = dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice: ConfigType = dict( + type='DiceLoss', + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = None, + **kwargs) -> None: + super(AnchorFreeHead, self).__init__(init_cfg=init_cfg) + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = self.num_things_classes + self.num_stuff_classes + self.num_queries = num_queries + self.num_transformer_feat_level = num_transformer_feat_level + self.num_heads = transformer_decoder.layer_cfg.cross_attn_cfg.num_heads + self.num_transformer_decoder_layers = transformer_decoder.num_layers + assert pixel_decoder.encoder.layer_cfg. \ + self_attn_cfg.num_levels == num_transformer_feat_level + pixel_decoder_ = copy.deepcopy(pixel_decoder) + pixel_decoder_.update( + in_channels=in_channels, + feat_channels=feat_channels, + out_channels=out_channels) + self.pixel_decoder = MODELS.build(pixel_decoder_) + self.transformer_decoder = Mask2FormerTransformerDecoder( + **transformer_decoder) + self.decoder_embed_dims = self.transformer_decoder.embed_dims + + self.decoder_input_projs = ModuleList() + # from low resolution to high resolution + for _ in range(num_transformer_feat_level): + if (self.decoder_embed_dims != feat_channels + or enforce_decoder_input_project): + self.decoder_input_projs.append( + Conv2d( + feat_channels, self.decoder_embed_dims, kernel_size=1)) + else: + self.decoder_input_projs.append(nn.Identity()) + self.decoder_positional_encoding = SinePositionalEncoding( + **positional_encoding) + self.query_embed = nn.Embedding(self.num_queries, feat_channels) + self.query_feat = nn.Embedding(self.num_queries, feat_channels) + # from low resolution to high resolution + self.level_embed = nn.Embedding(self.num_transformer_feat_level, + feat_channels) + + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + self.mask_embed = nn.Sequential( + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, out_channels)) + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + self.sampler = TASK_UTILS.build( + self.train_cfg['sampler'], default_args=dict(context=self)) + self.num_points = self.train_cfg.get('num_points', 12544) + self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0) + self.importance_sample_ratio = self.train_cfg.get( + 'importance_sample_ratio', 0.75) + + self.class_weight = loss_cls.class_weight + self.loss_cls = MODELS.build(loss_cls) + self.loss_mask = MODELS.build(loss_mask) + self.loss_dice = MODELS.build(loss_dice) + + def 
init_weights(self) -> None: + for m in self.decoder_input_projs: + if isinstance(m, Conv2d): + caffe2_xavier_init(m, bias=0) + + self.pixel_decoder.init_weights() + + for p in self.transformer_decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor, + gt_instances: InstanceData, + img_meta: dict) -> Tuple[Tensor]: + """Compute classification and mask targets for one image. + + Args: + cls_score (Tensor): Mask score logits from a single decoder layer + for one image. Shape (num_queries, cls_out_channels). + mask_pred (Tensor): Mask logits for a single decoder layer for one + image. Shape (num_queries, h, w). + gt_instances (:obj:`InstanceData`): It contains ``labels`` and + ``masks``. + img_meta (dict): Image informtation. + + Returns: + tuple[Tensor]: A tuple containing the following for one image. + + - labels (Tensor): Labels of each image. \ + shape (num_queries, ). + - label_weights (Tensor): Label weights of each image. \ + shape (num_queries, ). + - mask_targets (Tensor): Mask targets of each image. \ + shape (num_queries, h, w). + - mask_weights (Tensor): Mask weights of each image. \ + shape (num_queries, ). + - pos_inds (Tensor): Sampled positive indices for each \ + image. + - neg_inds (Tensor): Sampled negative indices for each \ + image. + - sampling_result (:obj:`SamplingResult`): Sampling results. + """ + gt_labels = gt_instances.labels + gt_masks = gt_instances.masks + # sample points + num_queries = cls_score.shape[0] + num_gts = gt_labels.shape[0] + + point_coords = torch.rand((1, self.num_points, 2), + device=cls_score.device) + # shape (num_queries, num_points) + mask_points_pred = point_sample( + mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, + 1)).squeeze(1) + # shape (num_gts, num_points) + gt_points_masks = point_sample( + gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, + 1)).squeeze(1) + + sampled_gt_instances = InstanceData( + labels=gt_labels, masks=gt_points_masks) + sampled_pred_instances = InstanceData( + scores=cls_score, masks=mask_points_pred) + # assign and sample + assign_result = self.assigner.assign( + pred_instances=sampled_pred_instances, + gt_instances=sampled_gt_instances, + img_meta=img_meta) + pred_instances = InstanceData(scores=cls_score, masks=mask_pred) + sampling_result = self.sampler.sample( + assign_result=assign_result, + pred_instances=pred_instances, + gt_instances=gt_instances) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label target + labels = gt_labels.new_full((self.num_queries, ), + self.num_classes, + dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_labels.new_ones((self.num_queries, )) + + # mask target + mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] + mask_weights = mask_pred.new_zeros((self.num_queries, )) + mask_weights[pos_inds] = 1.0 + + return (labels, label_weights, mask_targets, mask_weights, pos_inds, + neg_inds, sampling_result) + + def _loss_by_feat_single(self, cls_scores: Tensor, mask_preds: Tensor, + batch_gt_instances: List[InstanceData], + batch_img_metas: List[dict]) -> Tuple[Tensor]: + """Loss function for outputs from a single decoder layer. + + Args: + cls_scores (Tensor): Mask score logits from a single decoder layer + for all images. Shape (batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. 
+ mask_preds (Tensor): Mask logits for a pixel decoder for all + images. Shape (batch_size, num_queries, h, w). + batch_gt_instances (list[obj:`InstanceData`]): each contains + ``labels`` and ``masks``. + batch_img_metas (list[dict]): List of image meta information. + + Returns: + tuple[Tensor]: Loss components for outputs from a single \ + decoder layer. + """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + mask_preds_list = [mask_preds[i] for i in range(num_imgs)] + (labels_list, label_weights_list, mask_targets_list, mask_weights_list, + avg_factor) = self.get_targets(cls_scores_list, mask_preds_list, + batch_gt_instances, batch_img_metas) + # shape (batch_size, num_queries) + labels = torch.stack(labels_list, dim=0) + # shape (batch_size, num_queries) + label_weights = torch.stack(label_weights_list, dim=0) + # shape (num_total_gts, h, w) + mask_targets = torch.cat(mask_targets_list, dim=0) + # shape (batch_size, num_queries) + mask_weights = torch.stack(mask_weights_list, dim=0) + + # classfication loss + # shape (batch_size * num_queries, ) + cls_scores = cls_scores.flatten(0, 1) + labels = labels.flatten(0, 1) + label_weights = label_weights.flatten(0, 1) + + class_weight = cls_scores.new_tensor(self.class_weight) + loss_cls = self.loss_cls( + cls_scores, + labels, + label_weights, + avg_factor=class_weight[labels].sum()) + + num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor])) + num_total_masks = max(num_total_masks, 1) + + # extract positive ones + # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) + mask_preds = mask_preds[mask_weights > 0] + + if mask_targets.shape[0] == 0: + # zero match + loss_dice = mask_preds.sum() + loss_mask = mask_preds.sum() + return loss_cls, loss_mask, loss_dice + + with torch.no_grad(): + points_coords = get_uncertain_point_coords_with_randomness( + mask_preds.unsqueeze(1), None, self.num_points, + self.oversample_ratio, self.importance_sample_ratio) + # shape (num_total_gts, h, w) -> (num_total_gts, num_points) + mask_point_targets = point_sample( + mask_targets.unsqueeze(1).float(), points_coords).squeeze(1) + # shape (num_queries, h, w) -> (num_queries, num_points) + mask_point_preds = point_sample( + mask_preds.unsqueeze(1), points_coords).squeeze(1) + + # dice loss + loss_dice = self.loss_dice( + mask_point_preds, mask_point_targets, avg_factor=num_total_masks) + + # mask loss + # shape (num_queries, num_points) -> (num_queries * num_points, ) + mask_point_preds = mask_point_preds.reshape(-1) + # shape (num_total_gts, num_points) -> (num_total_gts * num_points, ) + mask_point_targets = mask_point_targets.reshape(-1) + loss_mask = self.loss_mask( + mask_point_preds, + mask_point_targets, + avg_factor=num_total_masks * self.num_points) + + return loss_cls, loss_mask, loss_dice + + def _forward_head(self, decoder_out: Tensor, mask_feature: Tensor, + attn_mask_target_size: Tuple[int, int]) -> Tuple[Tensor]: + """Forward for head part which is called after every decoder layer. + + Args: + decoder_out (Tensor): in shape (batch_size, num_queries, c). + mask_feature (Tensor): in shape (batch_size, c, h, w). + attn_mask_target_size (tuple[int, int]): target attention + mask size. + + Returns: + tuple: A tuple contain three elements. + + - cls_pred (Tensor): Classification scores in shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred (Tensor): Mask scores in shape \ + (batch_size, num_queries,h, w). 
+ - attn_mask (Tensor): Attention mask in shape \ + (batch_size * num_heads, num_queries, h, w). + """ + decoder_out = self.transformer_decoder.post_norm(decoder_out) + # shape (num_queries, batch_size, c) + cls_pred = self.cls_embed(decoder_out) + # shape (num_queries, batch_size, c) + mask_embed = self.mask_embed(decoder_out) + # shape (num_queries, batch_size, h, w) + mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature) + attn_mask = F.interpolate( + mask_pred, + attn_mask_target_size, + mode='bilinear', + align_corners=False) + # shape (num_queries, batch_size, h, w) -> + # (batch_size * num_head, num_queries, h, w) + attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat( + (1, self.num_heads, 1, 1)).flatten(0, 1) + attn_mask = attn_mask.sigmoid() < 0.5 + attn_mask = attn_mask.detach() + + return cls_pred, mask_pred, attn_mask + + def forward(self, x: List[Tensor], + batch_data_samples: SampleList) -> Tuple[List[Tensor]]: + """Forward function. + + Args: + x (list[Tensor]): Multi scale Features from the + upstream network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + tuple[list[Tensor]]: A tuple contains two elements. + + - cls_pred_list (list[Tensor)]: Classification logits \ + for each decoder layer. Each is a 3D-tensor with shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred_list (list[Tensor]): Mask logits for each \ + decoder layer. Each with shape (batch_size, num_queries, \ + h, w). + """ + batch_size = x[0].shape[0] + mask_features, multi_scale_memorys = self.pixel_decoder(x) + # multi_scale_memorys (from low resolution to high resolution) + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) + # shape (batch_size, c, h, w) -> (batch_size, h*w, c) + decoder_input = decoder_input.flatten(2).permute(0, 2, 1) + level_embed = self.level_embed.weight[i].view(1, 1, -1) + decoder_input = decoder_input + level_embed + # shape (batch_size, c, h, w) -> (batch_size, h*w, c) + mask = decoder_input.new_zeros( + (batch_size, ) + multi_scale_memorys[i].shape[-2:], + dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding( + mask) + decoder_positional_encoding = decoder_positional_encoding.flatten( + 2).permute(0, 2, 1) + decoder_inputs.append(decoder_input) + decoder_positional_encodings.append(decoder_positional_encoding) + # shape (num_queries, c) -> (batch_size, num_queries, c) + query_feat = self.query_feat.weight.unsqueeze(0).repeat( + (batch_size, 1, 1)) + query_embed = self.query_embed.weight.unsqueeze(0).repeat( + (batch_size, 1, 1)) + + cls_pred_list = [] + mask_pred_list = [] + cls_pred, mask_pred, attn_mask = self._forward_head( + query_feat, mask_features, multi_scale_memorys[0].shape[-2:]) + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % self.num_transformer_feat_level + # if a mask is all True(all background), then set it all False. 
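+            # (an attn_mask row that is all True would leave its query with
+            # no attendable key, so the masked softmax would produce NaNs;
+            # such rows are reset to all False, i.e. the query falls back to
+            # attending over the full feature map)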
+            mask_sum = (attn_mask.sum(-1) != attn_mask.shape[-1]).unsqueeze(-1)
+            attn_mask = attn_mask & mask_sum
+            # cross_attn + self_attn
+            layer = self.transformer_decoder.layers[i]
+            query_feat = layer(
+                query=query_feat,
+                key=decoder_inputs[level_idx],
+                value=decoder_inputs[level_idx],
+                query_pos=query_embed,
+                key_pos=decoder_positional_encodings[level_idx],
+                cross_attn_mask=attn_mask,
+                query_key_padding_mask=None,
+                # here we do not apply masking on padded region
+                key_padding_mask=None)
+            cls_pred, mask_pred, attn_mask = self._forward_head(
+                query_feat, mask_features, multi_scale_memorys[
+                    (i + 1) % self.num_transformer_feat_level].shape[-2:])
+
+            cls_pred_list.append(cls_pred)
+            mask_pred_list.append(mask_pred)
+
+        return cls_pred_list, mask_pred_list
diff --git a/mmdet/models/dense_heads/maskformer_head.py b/mmdet/models/dense_heads/maskformer_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..24c0655ee1c36e0110cf6578d1c095c50a297d81
--- /dev/null
+++ b/mmdet/models/dense_heads/maskformer_head.py
@@ -0,0 +1,601 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import Conv2d
+from mmengine.model import caffe2_xavier_init
+from mmengine.structures import InstanceData, PixelData
+from torch import Tensor
+
+from mmdet.models.layers.pixel_decoder import PixelDecoder
+from mmdet.registry import MODELS, TASK_UTILS
+from mmdet.structures import SampleList
+from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
+                         OptMultiConfig, reduce_mean)
+from ..layers import DetrTransformerDecoder, SinePositionalEncoding
+from ..utils import multi_apply, preprocess_panoptic_gt
+from .anchor_free_head import AnchorFreeHead
+
+
+@MODELS.register_module()
+class MaskFormerHead(AnchorFreeHead):
+    """Implements the MaskFormer head.
+
+    See `Per-Pixel Classification is Not All You Need for Semantic
+    Segmentation <https://arxiv.org/abs/2107.06278>`_ for details.
+
+    Args:
+        in_channels (list[int]): Number of channels in the input feature map.
+        feat_channels (int): Number of channels for feature.
+        out_channels (int): Number of channels for output.
+        num_things_classes (int): Number of things.
+        num_stuff_classes (int): Number of stuff.
+        num_queries (int): Number of queries in Transformer.
+        pixel_decoder (:obj:`ConfigDict` or dict): Config for pixel
+            decoder.
+        enforce_decoder_input_project (bool): Whether to add a layer
+            to change the embed_dim of transformer encoder in pixel decoder to
+            the embed_dim of transformer decoder. Defaults to False.
+        transformer_decoder (:obj:`ConfigDict` or dict): Config for
+            transformer decoder.
+        positional_encoding (:obj:`ConfigDict` or dict): Config for
+            transformer decoder position encoding.
+        loss_cls (:obj:`ConfigDict` or dict): Config of the classification
+            loss. Defaults to `CrossEntropyLoss`.
+        loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss.
+            Defaults to `FocalLoss`.
+        loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss.
+            Defaults to `DiceLoss`.
+        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
+            MaskFormer head.
+        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
+            MaskFormer head.
+        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
+            dict], optional): Initialization config dict. Defaults to None.
+ """ + + def __init__(self, + in_channels: List[int], + feat_channels: int, + out_channels: int, + num_things_classes: int = 80, + num_stuff_classes: int = 53, + num_queries: int = 100, + pixel_decoder: ConfigType = ..., + enforce_decoder_input_project: bool = False, + transformer_decoder: ConfigType = ..., + positional_encoding: ConfigType = dict( + num_feats=128, normalize=True), + loss_cls: ConfigType = dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0, + class_weight=[1.0] * 133 + [0.1]), + loss_mask: ConfigType = dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=20.0), + loss_dice: ConfigType = dict( + type='DiceLoss', + use_sigmoid=True, + activate=True, + naive_dice=True, + loss_weight=1.0), + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = None, + **kwargs) -> None: + super(AnchorFreeHead, self).__init__(init_cfg=init_cfg) + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = self.num_things_classes + self.num_stuff_classes + self.num_queries = num_queries + + pixel_decoder.update( + in_channels=in_channels, + feat_channels=feat_channels, + out_channels=out_channels) + self.pixel_decoder = MODELS.build(pixel_decoder) + self.transformer_decoder = DetrTransformerDecoder( + **transformer_decoder) + self.decoder_embed_dims = self.transformer_decoder.embed_dims + if type(self.pixel_decoder) == PixelDecoder and ( + self.decoder_embed_dims != in_channels[-1] + or enforce_decoder_input_project): + self.decoder_input_proj = Conv2d( + in_channels[-1], self.decoder_embed_dims, kernel_size=1) + else: + self.decoder_input_proj = nn.Identity() + self.decoder_pe = SinePositionalEncoding(**positional_encoding) + self.query_embed = nn.Embedding(self.num_queries, out_channels) + + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + self.mask_embed = nn.Sequential( + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, out_channels)) + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + self.assigner = TASK_UTILS.build(train_cfg['assigner']) + self.sampler = TASK_UTILS.build( + train_cfg['sampler'], default_args=dict(context=self)) + + self.class_weight = loss_cls.class_weight + self.loss_cls = MODELS.build(loss_cls) + self.loss_mask = MODELS.build(loss_mask) + self.loss_dice = MODELS.build(loss_dice) + + def init_weights(self) -> None: + if isinstance(self.decoder_input_proj, Conv2d): + caffe2_xavier_init(self.decoder_input_proj, bias=0) + + self.pixel_decoder.init_weights() + + for p in self.transformer_decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def preprocess_gt( + self, batch_gt_instances: InstanceList, + batch_gt_semantic_segs: List[Optional[PixelData]]) -> InstanceList: + """Preprocess the ground truth for all images. + + Args: + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``labels``, each is + ground truth labels of each bbox, with shape (num_gts, ) + and ``masks``, each is ground truth masks of each instances + of a image, shape (num_gts, h, w). + gt_semantic_seg (list[Optional[PixelData]]): Ground truth of + semantic segmentation, each with the shape (1, h, w). + [0, num_thing_class - 1] means things, + [num_thing_class, num_class-1] means stuff, + 255 means VOID. It's None when training instance segmentation. 
+ + Returns: + list[obj:`InstanceData`]: each contains the following keys + + - labels (Tensor): Ground truth class indices\ + for a image, with shape (n, ), n is the sum of\ + number of stuff type and number of instance in a image. + - masks (Tensor): Ground truth mask for a\ + image, with shape (n, h, w). + """ + num_things_list = [self.num_things_classes] * len(batch_gt_instances) + num_stuff_list = [self.num_stuff_classes] * len(batch_gt_instances) + gt_labels_list = [ + gt_instances['labels'] for gt_instances in batch_gt_instances + ] + gt_masks_list = [ + gt_instances['masks'] for gt_instances in batch_gt_instances + ] + gt_semantic_segs = [ + None if gt_semantic_seg is None else gt_semantic_seg.sem_seg + for gt_semantic_seg in batch_gt_semantic_segs + ] + targets = multi_apply(preprocess_panoptic_gt, gt_labels_list, + gt_masks_list, gt_semantic_segs, num_things_list, + num_stuff_list) + labels, masks = targets + batch_gt_instances = [ + InstanceData(labels=label, masks=mask) + for label, mask in zip(labels, masks) + ] + return batch_gt_instances + + def get_targets( + self, + cls_scores_list: List[Tensor], + mask_preds_list: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + return_sampling_results: bool = False + ) -> Tuple[List[Union[Tensor, int]]]: + """Compute classification and mask targets for all images for a decoder + layer. + + Args: + cls_scores_list (list[Tensor]): Mask score logits from a single + decoder layer for all images. Each with shape (num_queries, + cls_out_channels). + mask_preds_list (list[Tensor]): Mask logits from a single decoder + layer for all images. Each with shape (num_queries, h, w). + batch_gt_instances (list[obj:`InstanceData`]): each contains + ``labels`` and ``masks``. + batch_img_metas (list[dict]): List of image meta information. + return_sampling_results (bool): Whether to return the sampling + results. Defaults to False. + + Returns: + tuple: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels of all images.\ + Each with shape (num_queries, ). + - label_weights_list (list[Tensor]): Label weights\ + of all images. Each with shape (num_queries, ). + - mask_targets_list (list[Tensor]): Mask targets of\ + all images. Each with shape (num_queries, h, w). + - mask_weights_list (list[Tensor]): Mask weights of\ + all images. Each with shape (num_queries, ). + - avg_factor (int): Average factor that is used to average\ + the loss. When using sampling method, avg_factor is + usually the sum of positive and negative priors. When + using `MaskPseudoSampler`, `avg_factor` is usually equal + to the number of positive priors. + + additional_returns: This function enables user-defined returns from + `self._get_targets_single`. These returns are currently refined + to properties at each feature map (i.e. having HxW dimension). + The results will be concatenated after the end. 
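+
+        Example:
+            Illustrative sketch of the per-image decomposition; all inputs
+            are assumed to exist:
+
+            >>> (labels_list, label_weights_list, mask_targets_list,
+            ...  mask_weights_list, avg_factor) = self.get_targets(
+            ...     cls_scores_list, mask_preds_list, batch_gt_instances,
+            ...     batch_img_metas)
+            >>> # labels_list[i] has shape (num_queries, ) for image i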
+ """ + results = multi_apply(self._get_targets_single, cls_scores_list, + mask_preds_list, batch_gt_instances, + batch_img_metas) + (labels_list, label_weights_list, mask_targets_list, mask_weights_list, + pos_inds_list, neg_inds_list, sampling_results_list) = results[:7] + rest_results = list(results[7:]) + + avg_factor = sum( + [results.avg_factor for results in sampling_results_list]) + + res = (labels_list, label_weights_list, mask_targets_list, + mask_weights_list, avg_factor) + if return_sampling_results: + res = res + (sampling_results_list) + + return res + tuple(rest_results) + + def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor, + gt_instances: InstanceData, + img_meta: dict) -> Tuple[Tensor]: + """Compute classification and mask targets for one image. + + Args: + cls_score (Tensor): Mask score logits from a single decoder layer + for one image. Shape (num_queries, cls_out_channels). + mask_pred (Tensor): Mask logits for a single decoder layer for one + image. Shape (num_queries, h, w). + gt_instances (:obj:`InstanceData`): It contains ``labels`` and + ``masks``. + img_meta (dict): Image informtation. + + Returns: + tuple: a tuple containing the following for one image. + + - labels (Tensor): Labels of each image. + shape (num_queries, ). + - label_weights (Tensor): Label weights of each image. + shape (num_queries, ). + - mask_targets (Tensor): Mask targets of each image. + shape (num_queries, h, w). + - mask_weights (Tensor): Mask weights of each image. + shape (num_queries, ). + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image. + - sampling_result (:obj:`SamplingResult`): Sampling results. + """ + gt_masks = gt_instances.masks + gt_labels = gt_instances.labels + + target_shape = mask_pred.shape[-2:] + if gt_masks.shape[0] > 0: + gt_masks_downsampled = F.interpolate( + gt_masks.unsqueeze(1).float(), target_shape, + mode='nearest').squeeze(1).long() + else: + gt_masks_downsampled = gt_masks + + pred_instances = InstanceData(scores=cls_score, masks=mask_pred) + downsampled_gt_instances = InstanceData( + labels=gt_labels, masks=gt_masks_downsampled) + # assign and sample + assign_result = self.assigner.assign( + pred_instances=pred_instances, + gt_instances=downsampled_gt_instances, + img_meta=img_meta) + sampling_result = self.sampler.sample( + assign_result=assign_result, + pred_instances=pred_instances, + gt_instances=gt_instances) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label target + labels = gt_labels.new_full((self.num_queries, ), + self.num_classes, + dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_labels.new_ones(self.num_queries) + + # mask target + mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] + mask_weights = mask_pred.new_zeros((self.num_queries, )) + mask_weights[pos_inds] = 1.0 + + return (labels, label_weights, mask_targets, mask_weights, pos_inds, + neg_inds, sampling_result) + + def loss_by_feat(self, all_cls_scores: Tensor, all_mask_preds: Tensor, + batch_gt_instances: List[InstanceData], + batch_img_metas: List[dict]) -> Dict[str, Tensor]: + """Loss function. + + Args: + all_cls_scores (Tensor): Classification scores for all decoder + layers with shape (num_decoder, batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. 
+ all_mask_preds (Tensor): Mask scores for all decoder layers with + shape (num_decoder, batch_size, num_queries, h, w). + batch_gt_instances (list[obj:`InstanceData`]): each contains + ``labels`` and ``masks``. + batch_img_metas (list[dict]): List of image meta information. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + num_dec_layers = len(all_cls_scores) + batch_gt_instances_list = [ + batch_gt_instances for _ in range(num_dec_layers) + ] + img_metas_list = [batch_img_metas for _ in range(num_dec_layers)] + losses_cls, losses_mask, losses_dice = multi_apply( + self._loss_by_feat_single, all_cls_scores, all_mask_preds, + batch_gt_instances_list, img_metas_list) + + loss_dict = dict() + # loss from the last decoder layer + loss_dict['loss_cls'] = losses_cls[-1] + loss_dict['loss_mask'] = losses_mask[-1] + loss_dict['loss_dice'] = losses_dice[-1] + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_mask_i, loss_dice_i in zip( + losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]): + loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i + loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i + num_dec_layer += 1 + return loss_dict + + def _loss_by_feat_single(self, cls_scores: Tensor, mask_preds: Tensor, + batch_gt_instances: List[InstanceData], + batch_img_metas: List[dict]) -> Tuple[Tensor]: + """Loss function for outputs from a single decoder layer. + + Args: + cls_scores (Tensor): Mask score logits from a single decoder layer + for all images. Shape (batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. + mask_preds (Tensor): Mask logits for a pixel decoder for all + images. Shape (batch_size, num_queries, h, w). + batch_gt_instances (list[obj:`InstanceData`]): each contains + ``labels`` and ``masks``. + batch_img_metas (list[dict]): List of image meta information. + + Returns: + tuple[Tensor]: Loss components for outputs from a single decoder\ + layer. 
+ """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + mask_preds_list = [mask_preds[i] for i in range(num_imgs)] + + (labels_list, label_weights_list, mask_targets_list, mask_weights_list, + avg_factor) = self.get_targets(cls_scores_list, mask_preds_list, + batch_gt_instances, batch_img_metas) + # shape (batch_size, num_queries) + labels = torch.stack(labels_list, dim=0) + # shape (batch_size, num_queries) + label_weights = torch.stack(label_weights_list, dim=0) + # shape (num_total_gts, h, w) + mask_targets = torch.cat(mask_targets_list, dim=0) + # shape (batch_size, num_queries) + mask_weights = torch.stack(mask_weights_list, dim=0) + + # classfication loss + # shape (batch_size * num_queries, ) + cls_scores = cls_scores.flatten(0, 1) + labels = labels.flatten(0, 1) + label_weights = label_weights.flatten(0, 1) + + class_weight = cls_scores.new_tensor(self.class_weight) + loss_cls = self.loss_cls( + cls_scores, + labels, + label_weights, + avg_factor=class_weight[labels].sum()) + + num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor])) + num_total_masks = max(num_total_masks, 1) + + # extract positive ones + # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) + mask_preds = mask_preds[mask_weights > 0] + target_shape = mask_targets.shape[-2:] + + if mask_targets.shape[0] == 0: + # zero match + loss_dice = mask_preds.sum() + loss_mask = mask_preds.sum() + return loss_cls, loss_mask, loss_dice + + # upsample to shape of target + # shape (num_total_gts, h, w) + mask_preds = F.interpolate( + mask_preds.unsqueeze(1), + target_shape, + mode='bilinear', + align_corners=False).squeeze(1) + + # dice loss + loss_dice = self.loss_dice( + mask_preds, mask_targets, avg_factor=num_total_masks) + + # mask loss + # FocalLoss support input of shape (n, num_class) + h, w = mask_preds.shape[-2:] + # shape (num_total_gts, h, w) -> (num_total_gts * h * w, 1) + mask_preds = mask_preds.reshape(-1, 1) + # shape (num_total_gts, h, w) -> (num_total_gts * h * w) + mask_targets = mask_targets.reshape(-1) + # target is (1 - mask_targets) !!! + loss_mask = self.loss_mask( + mask_preds, 1 - mask_targets, avg_factor=num_total_masks * h * w) + + return loss_cls, loss_mask, loss_dice + + def forward(self, x: Tuple[Tensor], + batch_data_samples: SampleList) -> Tuple[Tensor]: + """Forward function. + + Args: + x (tuple[Tensor]): Features from the upstream network, each + is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + tuple[Tensor]: a tuple contains two elements. + + - all_cls_scores (Tensor): Classification scores for each\ + scale level. Each is a 4D-tensor with shape\ + (num_decoder, batch_size, num_queries, cls_out_channels).\ + Note `cls_out_channels` should includes background. + - all_mask_preds (Tensor): Mask scores for each decoder\ + layer. Each with shape (num_decoder, batch_size,\ + num_queries, h, w). 
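+
+        Example:
+            Illustrative shapes only; ``head``, ``feats`` and
+            ``batch_data_samples`` are assumed to exist:
+
+            >>> all_cls_scores, all_mask_preds = head(feats,
+            ...                                       batch_data_samples)
+            >>> # all_cls_scores: (num_decoder, batch_size, num_queries,
+            >>> #                  num_classes + 1)
+            >>> # all_mask_preds: (num_decoder, batch_size, num_queries, h, w)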
+ """ + batch_img_metas = [ + data_sample.metainfo for data_sample in batch_data_samples + ] + batch_size = x[0].shape[0] + input_img_h, input_img_w = batch_img_metas[0]['batch_input_shape'] + padding_mask = x[-1].new_ones((batch_size, input_img_h, input_img_w), + dtype=torch.float32) + for i in range(batch_size): + img_h, img_w = batch_img_metas[i]['img_shape'] + padding_mask[i, :img_h, :img_w] = 0 + padding_mask = F.interpolate( + padding_mask.unsqueeze(1), size=x[-1].shape[-2:], + mode='nearest').to(torch.bool).squeeze(1) + # when backbone is swin, memory is output of last stage of swin. + # when backbone is r50, memory is output of tranformer encoder. + mask_features, memory = self.pixel_decoder(x, batch_img_metas) + pos_embed = self.decoder_pe(padding_mask) + memory = self.decoder_input_proj(memory) + # shape (batch_size, c, h, w) -> (batch_size, h*w, c) + memory = memory.flatten(2).permute(0, 2, 1) + pos_embed = pos_embed.flatten(2).permute(0, 2, 1) + # shape (batch_size, h * w) + padding_mask = padding_mask.flatten(1) + # shape = (num_queries, embed_dims) + query_embed = self.query_embed.weight + # shape = (batch_size, num_queries, embed_dims) + query_embed = query_embed.unsqueeze(0).repeat(batch_size, 1, 1) + target = torch.zeros_like(query_embed) + # shape (num_decoder, num_queries, batch_size, embed_dims) + out_dec = self.transformer_decoder( + query=target, + key=memory, + value=memory, + query_pos=query_embed, + key_pos=pos_embed, + key_padding_mask=padding_mask) + + # cls_scores + all_cls_scores = self.cls_embed(out_dec) + + # mask_preds + mask_embed = self.mask_embed(out_dec) + all_mask_preds = torch.einsum('lbqc,bchw->lbqhw', mask_embed, + mask_features) + + return all_cls_scores, all_mask_preds + + def loss( + self, + x: Tuple[Tensor], + batch_data_samples: SampleList, + ) -> Dict[str, Tensor]: + """Perform forward propagation and loss calculation of the panoptic + head on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the upstream + network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + batch_img_metas = [] + batch_gt_instances = [] + batch_gt_semantic_segs = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + if 'gt_sem_seg' in data_sample: + batch_gt_semantic_segs.append(data_sample.gt_sem_seg) + else: + batch_gt_semantic_segs.append(None) + + # forward + all_cls_scores, all_mask_preds = self(x, batch_data_samples) + + # preprocess ground truth + batch_gt_instances = self.preprocess_gt(batch_gt_instances, + batch_gt_semantic_segs) + + # loss + losses = self.loss_by_feat(all_cls_scores, all_mask_preds, + batch_gt_instances, batch_img_metas) + + return losses + + def predict(self, x: Tuple[Tensor], + batch_data_samples: SampleList) -> Tuple[Tensor]: + """Test without augmentaton. + + Args: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + tuple[Tensor]: A tuple contains two tensors. + + - mask_cls_results (Tensor): Mask classification logits,\ + shape (batch_size, num_queries, cls_out_channels). 
+                Note `cls_out_channels` should include background.
+            - mask_pred_results (Tensor): Mask logits, shape \
+                (batch_size, num_queries, h, w).
+        """
+        batch_img_metas = [
+            data_sample.metainfo for data_sample in batch_data_samples
+        ]
+        all_cls_scores, all_mask_preds = self(x, batch_data_samples)
+        mask_cls_results = all_cls_scores[-1]
+        mask_pred_results = all_mask_preds[-1]
+
+        # upsample masks
+        img_shape = batch_img_metas[0]['batch_input_shape']
+        mask_pred_results = F.interpolate(
+            mask_pred_results,
+            size=(img_shape[0], img_shape[1]),
+            mode='bilinear',
+            align_corners=False)
+
+        return mask_cls_results, mask_pred_results
diff --git a/mmdet/models/dense_heads/nasfcos_head.py b/mmdet/models/dense_heads/nasfcos_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..14ee62a7910d90a108fefb2acef00c91ab83ecc8
--- /dev/null
+++ b/mmdet/models/dense_heads/nasfcos_head.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch.nn as nn
+from mmcv.cnn import ConvModule, Scale
+
+from mmdet.models.dense_heads.fcos_head import FCOSHead
+from mmdet.registry import MODELS
+from mmdet.utils import OptMultiConfig
+
+
+@MODELS.register_module()
+class NASFCOSHead(FCOSHead):
+    """Anchor-free head used in `NASFCOS
+    <https://arxiv.org/abs/1906.04423>`_.
+
+    It is quite similar to the FCOS head, except for the searched structure
+    of the classification and bbox regression branches, where a structure of
+    "dconv3x3, conv3x3, dconv3x3, conv1x1" is utilized instead.
+
+    Args:
+        num_classes (int): Number of categories excluding the background
+            category.
+        in_channels (int): Number of channels in the input feature map.
+        strides (Sequence[int] or Sequence[Tuple[int, int]]): Strides of points
+            in multiple feature levels. Defaults to (4, 8, 16, 32, 64).
+        regress_ranges (Sequence[Tuple[int, int]]): Regress range of multiple
+            level points.
+        center_sampling (bool): If true, use center sampling.
+            Defaults to False.
+        center_sample_radius (float): Radius of center sampling.
+            Defaults to 1.5.
+        norm_on_bbox (bool): If true, normalize the regression targets with
+            FPN strides. Defaults to False.
+        centerness_on_reg (bool): If true, position centerness on the
+            regress branch. Please refer to https://github.com/tianzhi0549/FCOS/issues/89#issuecomment-516877042.
+            Defaults to False.
+        conv_bias (bool or str): If specified as `auto`, it will be decided by
+            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
+            None, otherwise False. Defaults to "auto".
+        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
+        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
+        loss_centerness (:obj:`ConfigDict`, or dict): Config of centerness
+            loss.
+        norm_cfg (:obj:`ConfigDict` or dict): dictionary to construct and
+            config norm layer. Defaults to
+            ``norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)``.
+        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
+            dict], optional): Initialization config dict.
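+
+    Example:
+        Illustrative only; assumes default configs and an mmcv build with
+        DCNv2 support:
+
+        >>> head = NASFCOSHead(num_classes=80, in_channels=256)
+        >>> # cls_convs / reg_convs follow the searched
+        >>> # "dconv3x3, conv3x3, dconv3x3, conv1x1" arrangement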
+ """ # noqa: E501 + + def __init__(self, + *args, + init_cfg: OptMultiConfig = None, + **kwargs) -> None: + if init_cfg is None: + init_cfg = [ + dict(type='Caffe2Xavier', layer=['ConvModule', 'Conv2d']), + dict( + type='Normal', + std=0.01, + override=[ + dict(name='conv_reg'), + dict(name='conv_centerness'), + dict( + name='conv_cls', + type='Normal', + std=0.01, + bias_prob=0.01) + ]), + ] + super().__init__(*args, init_cfg=init_cfg, **kwargs) + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + dconv3x3_config = dict( + type='DCNv2', + kernel_size=3, + use_bias=True, + deform_groups=2, + padding=1) + conv3x3_config = dict(type='Conv', kernel_size=3, padding=1) + conv1x1_config = dict(type='Conv', kernel_size=1) + + self.arch_config = [ + dconv3x3_config, conv3x3_config, dconv3x3_config, conv1x1_config + ] + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + for i, op_ in enumerate(self.arch_config): + op = copy.deepcopy(op_) + chn = self.in_channels if i == 0 else self.feat_channels + assert isinstance(op, dict) + use_bias = op.pop('use_bias', False) + padding = op.pop('padding', 0) + kernel_size = op.pop('kernel_size') + module = ConvModule( + chn, + self.feat_channels, + kernel_size, + stride=1, + padding=padding, + norm_cfg=self.norm_cfg, + bias=use_bias, + conv_cfg=op) + + self.cls_convs.append(copy.deepcopy(module)) + self.reg_convs.append(copy.deepcopy(module)) + + self.conv_cls = nn.Conv2d( + self.feat_channels, self.cls_out_channels, 3, padding=1) + self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1) + self.conv_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1) + + self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides]) diff --git a/mmdet/models/dense_heads/paa_head.py b/mmdet/models/dense_heads/paa_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3c1f453d2788b354970254e8875068e824c370d4 --- /dev/null +++ b/mmdet/models/dense_heads/paa_head.py @@ -0,0 +1,730 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +import numpy as np +import torch +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures.bbox import bbox_overlaps +from mmdet.utils import (ConfigType, InstanceList, OptConfigType, + OptInstanceList) +from ..layers import multiclass_nms +from ..utils import levels_to_images, multi_apply +from . import ATSSHead + +EPS = 1e-12 +try: + import sklearn.mixture as skm +except ImportError: + skm = None + + +@MODELS.register_module() +class PAAHead(ATSSHead): + """Head of PAAAssignment: Probabilistic Anchor Assignment with IoU + Prediction for Object Detection. + + Code is modified from the `official github repo + `_. + + More details can be found in the `paper + `_ . + + Args: + topk (int): Select topk samples with smallest loss in + each level. + score_voting (bool): Whether to use score voting in post-process. + covariance_type : String describing the type of covariance parameters + to be used in :class:`sklearn.mixture.GaussianMixture`. + It must be one of: + + - 'full': each component has its own general covariance matrix + - 'tied': all components share the same general covariance matrix + - 'diag': each component has its own diagonal covariance matrix + - 'spherical': each component has its own single variance + Default: 'diag'. From 'full' to 'spherical', the gmm fitting + process is faster yet the performance could be influenced. 
For most + cases, 'diag' should be a good choice. + """ + + def __init__(self, + *args, + topk: int = 9, + score_voting: bool = True, + covariance_type: str = 'diag', + **kwargs): + # topk used in paa reassign process + self.topk = topk + self.with_score_voting = score_voting + self.covariance_type = covariance_type + super().__init__(*args, **kwargs) + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + iou_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W) + iou_preds (list[Tensor]): iou_preds for each scale + level with shape (N, num_anchors * 1, H, W) + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss gmm_assignment. + """ + + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore, + ) + (labels, labels_weight, bboxes_target, bboxes_weight, pos_inds, + pos_gt_index) = cls_reg_targets + cls_scores = levels_to_images(cls_scores) + cls_scores = [ + item.reshape(-1, self.cls_out_channels) for item in cls_scores + ] + bbox_preds = levels_to_images(bbox_preds) + bbox_preds = [item.reshape(-1, 4) for item in bbox_preds] + iou_preds = levels_to_images(iou_preds) + iou_preds = [item.reshape(-1, 1) for item in iou_preds] + pos_losses_list, = multi_apply(self.get_pos_loss, anchor_list, + cls_scores, bbox_preds, labels, + labels_weight, bboxes_target, + bboxes_weight, pos_inds) + + with torch.no_grad(): + reassign_labels, reassign_label_weight, \ + reassign_bbox_weights, num_pos = multi_apply( + self.paa_reassign, + pos_losses_list, + labels, + labels_weight, + bboxes_weight, + pos_inds, + pos_gt_index, + anchor_list) + num_pos = sum(num_pos) + # convert all tensor list to a flatten tensor + cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1)) + bbox_preds = torch.cat(bbox_preds, 0).view(-1, bbox_preds[0].size(-1)) + iou_preds = torch.cat(iou_preds, 0).view(-1, iou_preds[0].size(-1)) + labels = torch.cat(reassign_labels, 0).view(-1) + flatten_anchors = torch.cat( + [torch.cat(item, 0) for item in anchor_list]) + labels_weight = torch.cat(reassign_label_weight, 0).view(-1) + bboxes_target = torch.cat(bboxes_target, + 0).view(-1, bboxes_target[0].size(-1)) + + pos_inds_flatten = ((labels >= 0) + & + (labels < self.num_classes)).nonzero().reshape(-1) + + losses_cls = self.loss_cls( + cls_scores, + labels, + labels_weight, + 
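+            # Fall back to the batch size so the denominator below is never
+            # zero when PAA reassignment leaves no positive samples.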
avg_factor=max(num_pos, len(batch_img_metas))) # avoid num_pos=0 + if num_pos: + pos_bbox_pred = self.bbox_coder.decode( + flatten_anchors[pos_inds_flatten], + bbox_preds[pos_inds_flatten]) + pos_bbox_target = bboxes_target[pos_inds_flatten] + iou_target = bbox_overlaps( + pos_bbox_pred.detach(), pos_bbox_target, is_aligned=True) + losses_iou = self.loss_centerness( + iou_preds[pos_inds_flatten], + iou_target.unsqueeze(-1), + avg_factor=num_pos) + losses_bbox = self.loss_bbox( + pos_bbox_pred, + pos_bbox_target, + iou_target.clamp(min=EPS), + avg_factor=iou_target.sum()) + else: + losses_iou = iou_preds.sum() * 0 + losses_bbox = bbox_preds.sum() * 0 + + return dict( + loss_cls=losses_cls, loss_bbox=losses_bbox, loss_iou=losses_iou) + + def get_pos_loss(self, anchors: List[Tensor], cls_score: Tensor, + bbox_pred: Tensor, label: Tensor, label_weight: Tensor, + bbox_target: dict, bbox_weight: Tensor, + pos_inds: Tensor) -> Tensor: + """Calculate loss of all potential positive samples obtained from first + match process. + + Args: + anchors (list[Tensor]): Anchors of each scale. + cls_score (Tensor): Box scores of single image with shape + (num_anchors, num_classes) + bbox_pred (Tensor): Box energies / deltas of single image + with shape (num_anchors, 4) + label (Tensor): classification target of each anchor with + shape (num_anchors,) + label_weight (Tensor): Classification loss weight of each + anchor with shape (num_anchors). + bbox_target (dict): Regression target of each anchor with + shape (num_anchors, 4). + bbox_weight (Tensor): Bbox weight of each anchor with shape + (num_anchors, 4). + pos_inds (Tensor): Index of all positive samples got from + first assign process. + + Returns: + Tensor: Losses of all positive samples in single image. + """ + if not len(pos_inds): + return cls_score.new([]), + anchors_all_level = torch.cat(anchors, 0) + pos_scores = cls_score[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + pos_label = label[pos_inds] + pos_label_weight = label_weight[pos_inds] + pos_bbox_target = bbox_target[pos_inds] + pos_bbox_weight = bbox_weight[pos_inds] + pos_anchors = anchors_all_level[pos_inds] + pos_bbox_pred = self.bbox_coder.decode(pos_anchors, pos_bbox_pred) + + # to keep loss dimension + loss_cls = self.loss_cls( + pos_scores, + pos_label, + pos_label_weight, + avg_factor=1.0, + reduction_override='none') + + loss_bbox = self.loss_bbox( + pos_bbox_pred, + pos_bbox_target, + pos_bbox_weight, + avg_factor=1.0, # keep same loss weight before reassign + reduction_override='none') + + loss_cls = loss_cls.sum(-1) + pos_loss = loss_bbox + loss_cls + return pos_loss, + + def paa_reassign(self, pos_losses: Tensor, label: Tensor, + label_weight: Tensor, bbox_weight: Tensor, + pos_inds: Tensor, pos_gt_inds: Tensor, + anchors: List[Tensor]) -> tuple: + """Fit loss to GMM distribution and separate positive, ignore, negative + samples again with GMM model. + + Args: + pos_losses (Tensor): Losses of all positive samples in + single image. + label (Tensor): classification target of each anchor with + shape (num_anchors,) + label_weight (Tensor): Classification loss weight of each + anchor with shape (num_anchors). + bbox_weight (Tensor): Bbox weight of each anchor with shape + (num_anchors, 4). + pos_inds (Tensor): Index of all positive samples got from + first assign process. + pos_gt_inds (Tensor): Gt_index of all positive samples got + from first assign process. + anchors (list[Tensor]): Anchors of each scale. 
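The reassignment described here starts by keeping, per pyramid level and per ground truth, the `topk` lowest-loss candidates before any GMM fitting; a toy sketch of that selection (illustrative values):

```python
import torch

# Keep the `topk` lowest-loss candidates for one GT at one pyramid level,
# mirroring the selection inside paa_reassign.
pos_losses = torch.tensor([0.30, 1.20, 0.10, 0.80])
topk = 2
values, topk_inds = pos_losses.topk(
    min(topk, pos_losses.numel()), largest=False)
# values -> tensor([0.1000, 0.3000]); these feed the two-component GMM fit
```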
+ + Returns: + tuple: Usually returns a tuple containing learning targets. + + - label (Tensor): classification target of each anchor after + paa assign, with shape (num_anchors,) + - label_weight (Tensor): Classification loss weight of each + anchor after paa assign, with shape (num_anchors). + - bbox_weight (Tensor): Bbox weight of each anchor with shape + (num_anchors, 4). + - num_pos (int): The number of positive samples after paa + assign. + """ + if not len(pos_inds): + return label, label_weight, bbox_weight, 0 + label = label.clone() + label_weight = label_weight.clone() + bbox_weight = bbox_weight.clone() + num_gt = pos_gt_inds.max() + 1 + num_level = len(anchors) + num_anchors_each_level = [item.size(0) for item in anchors] + num_anchors_each_level.insert(0, 0) + inds_level_interval = np.cumsum(num_anchors_each_level) + pos_level_mask = [] + for i in range(num_level): + mask = (pos_inds >= inds_level_interval[i]) & ( + pos_inds < inds_level_interval[i + 1]) + pos_level_mask.append(mask) + pos_inds_after_paa = [label.new_tensor([])] + ignore_inds_after_paa = [label.new_tensor([])] + for gt_ind in range(num_gt): + pos_inds_gmm = [] + pos_loss_gmm = [] + gt_mask = pos_gt_inds == gt_ind + for level in range(num_level): + level_mask = pos_level_mask[level] + level_gt_mask = level_mask & gt_mask + value, topk_inds = pos_losses[level_gt_mask].topk( + min(level_gt_mask.sum(), self.topk), largest=False) + pos_inds_gmm.append(pos_inds[level_gt_mask][topk_inds]) + pos_loss_gmm.append(value) + pos_inds_gmm = torch.cat(pos_inds_gmm) + pos_loss_gmm = torch.cat(pos_loss_gmm) + # fix gmm need at least two sample + if len(pos_inds_gmm) < 2: + continue + device = pos_inds_gmm.device + pos_loss_gmm, sort_inds = pos_loss_gmm.sort() + pos_inds_gmm = pos_inds_gmm[sort_inds] + pos_loss_gmm = pos_loss_gmm.view(-1, 1).cpu().numpy() + min_loss, max_loss = pos_loss_gmm.min(), pos_loss_gmm.max() + means_init = np.array([min_loss, max_loss]).reshape(2, 1) + weights_init = np.array([0.5, 0.5]) + precisions_init = np.array([1.0, 1.0]).reshape(2, 1, 1) # full + if self.covariance_type == 'spherical': + precisions_init = precisions_init.reshape(2) + elif self.covariance_type == 'diag': + precisions_init = precisions_init.reshape(2, 1) + elif self.covariance_type == 'tied': + precisions_init = np.array([[1.0]]) + if skm is None: + raise ImportError('Please run "pip install sklearn" ' + 'to install sklearn first.') + gmm = skm.GaussianMixture( + 2, + weights_init=weights_init, + means_init=means_init, + precisions_init=precisions_init, + covariance_type=self.covariance_type) + gmm.fit(pos_loss_gmm) + gmm_assignment = gmm.predict(pos_loss_gmm) + scores = gmm.score_samples(pos_loss_gmm) + gmm_assignment = torch.from_numpy(gmm_assignment).to(device) + scores = torch.from_numpy(scores).to(device) + + pos_inds_temp, ignore_inds_temp = self.gmm_separation_scheme( + gmm_assignment, scores, pos_inds_gmm) + pos_inds_after_paa.append(pos_inds_temp) + ignore_inds_after_paa.append(ignore_inds_temp) + + pos_inds_after_paa = torch.cat(pos_inds_after_paa) + ignore_inds_after_paa = torch.cat(ignore_inds_after_paa) + reassign_mask = (pos_inds.unsqueeze(1) != pos_inds_after_paa).all(1) + reassign_ids = pos_inds[reassign_mask] + label[reassign_ids] = self.num_classes + label_weight[ignore_inds_after_paa] = 0 + bbox_weight[reassign_ids] = 0 + num_pos = len(pos_inds_after_paa) + return label, label_weight, bbox_weight, num_pos + + def gmm_separation_scheme(self, gmm_assignment: Tensor, scores: Tensor, + pos_inds_gmm: Tensor) -> 
Tuple[Tensor, Tensor]:
+        """A general separation scheme for the GMM model.
+
+        It separates a GMM distribution of candidate samples into three
+        parts: 0, 1 and uncertain areas, and you can implement other
+        separation schemes by rewriting this function.
+
+        Args:
+            gmm_assignment (Tensor): The prediction of GMM which is of shape
+                (num_samples,). The 0/1 value indicates the distribution
+                that each sample comes from.
+            scores (Tensor): The probability of each sample coming from the
+                fitted GMM distribution. The tensor is of shape
+                (num_samples,).
+            pos_inds_gmm (Tensor): All the indices of samples which are used
+                to fit the GMM model. The tensor is of shape (num_samples,).
+
+        Returns:
+            tuple[Tensor, Tensor]: The indices of positive and ignored
+            samples.
+
+            - pos_inds_temp (Tensor): Indices of positive samples.
+            - ignore_inds_temp (Tensor): Indices of ignored samples.
+        """
+        # The implementation is (c) in Fig.3 in the original paper instead
+        # of (b). You can refer to issues such as
+        # https://github.com/kkhoot/PAA/issues/8 and
+        # https://github.com/kkhoot/PAA/issues/9.
+        fgs = gmm_assignment == 0
+        pos_inds_temp = fgs.new_tensor([], dtype=torch.long)
+        ignore_inds_temp = fgs.new_tensor([], dtype=torch.long)
+        if fgs.nonzero().numel():
+            _, pos_thr_ind = scores[fgs].topk(1)
+            pos_inds_temp = pos_inds_gmm[fgs][:pos_thr_ind + 1]
+            ignore_inds_temp = pos_inds_gmm.new_tensor([])
+        return pos_inds_temp, ignore_inds_temp
+
+    def get_targets(self,
+                    anchor_list: List[List[Tensor]],
+                    valid_flag_list: List[List[Tensor]],
+                    batch_gt_instances: InstanceList,
+                    batch_img_metas: List[dict],
+                    batch_gt_instances_ignore: OptInstanceList = None,
+                    unmap_outputs: bool = True) -> tuple:
+        """Get targets for PAA head.
+
+        This method is almost the same as `AnchorHead.get_targets()`. We
+        directly return the results from `_get_targets_single` instead of
+        mapping them to levels with the `images_to_levels` function.
+
+        Args:
+            anchor_list (list[list[Tensor]]): Multi level anchors of each
+                image. The outer list indicates images, and the inner list
+                corresponds to feature levels of the image. Each element of
+                the inner list is a tensor of shape (num_anchors, 4).
+            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
+                each image. The outer list indicates images, and the inner
+                list corresponds to feature levels of the image. Each element
+                of the inner list is a tensor of shape (num_anchors, ).
+            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+                gt_instance. It usually includes ``bboxes`` and ``labels``
+                attributes.
+            batch_img_metas (list[dict]): Meta information of each image,
+                e.g., image size, scaling factor, etc.
+            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
+                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
+                Defaults to None.
+            unmap_outputs (bool): Whether to map outputs back to the original
+                set of anchors. Defaults to True.
+
+        Returns:
+            tuple: Usually returns a tuple containing learning targets.
+
+            - labels (list[Tensor]): Labels of all anchors, each with
+              shape (num_anchors,).
+            - label_weights (list[Tensor]): Label weights of all anchors,
+              each with shape (num_anchors,).
+            - bbox_targets (list[Tensor]): BBox targets of all anchors,
+              each with shape (num_anchors, 4).
+            - bbox_weights (list[Tensor]): BBox weights of all anchors,
+              each with shape (num_anchors, 4).
+            - pos_inds (list[Tensor]): Contains all indices of positive
+              samples in all anchors.
+ - gt_inds (list[Tensor]): Contains all gt_index of positive + sample in all anchor. + """ + + num_imgs = len(batch_img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + concat_anchor_list = [] + concat_valid_flag_list = [] + for i in range(num_imgs): + assert len(anchor_list[i]) == len(valid_flag_list[i]) + concat_anchor_list.append(torch.cat(anchor_list[i])) + concat_valid_flag_list.append(torch.cat(valid_flag_list[i])) + + # compute targets for each image + if batch_gt_instances_ignore is None: + batch_gt_instances_ignore = [None] * num_imgs + results = multi_apply( + self._get_targets_single, + concat_anchor_list, + concat_valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore, + unmap_outputs=unmap_outputs) + + (labels, label_weights, bbox_targets, bbox_weights, valid_pos_inds, + valid_neg_inds, sampling_result) = results + + # Due to valid flag of anchors, we have to calculate the real pos_inds + # in origin anchor set. + pos_inds = [] + for i, single_labels in enumerate(labels): + pos_mask = (0 <= single_labels) & ( + single_labels < self.num_classes) + pos_inds.append(pos_mask.nonzero().view(-1)) + + gt_inds = [item.pos_assigned_gt_inds for item in sampling_result] + return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, + gt_inds) + + def _get_targets_single(self, + flat_anchors: Tensor, + valid_flags: Tensor, + gt_instances: InstanceData, + img_meta: dict, + gt_instances_ignore: Optional[InstanceData] = None, + unmap_outputs: bool = True) -> tuple: + """Compute regression and classification targets for anchors in a + single image. + + This method is same as `AnchorHead._get_targets_single()`. + """ + assert unmap_outputs, 'We must map outputs back to the original' \ + 'set of anchors in PAAhead' + return super(ATSSHead, self)._get_targets_single( + flat_anchors, + valid_flags, + gt_instances, + img_meta, + gt_instances_ignore, + unmap_outputs=True) + + def predict_by_feat(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + score_factors: Optional[List[Tensor]] = None, + batch_img_metas: Optional[List[dict]] = None, + cfg: OptConfigType = None, + rescale: bool = False, + with_nms: bool = True) -> InstanceList: + """Transform a batch of output features extracted from the head into + bbox results. + + This method is same as `BaseDenseHead.get_results()`. + """ + assert with_nms, 'PAA only supports "with_nms=True" now and it ' \ + 'means PAAHead does not support ' \ + 'test-time augmentation' + return super().predict_by_feat( + cls_scores=cls_scores, + bbox_preds=bbox_preds, + score_factors=score_factors, + batch_img_metas=batch_img_metas, + cfg=cfg, + rescale=rescale, + with_nms=with_nms) + + def _predict_by_feat_single(self, + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + score_factor_list: List[Tensor], + mlvl_priors: List[Tensor], + img_meta: dict, + cfg: OptConfigType = None, + rescale: bool = False, + with_nms: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factors from all scale + levels of a single image, each item has shape + (num_priors * 1, H, W). 
+ mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid, has shape + (num_priors, 4). + img_meta (dict): Image meta info. + cfg (:obj:`ConfigDict` or dict, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + cfg = self.test_cfg if cfg is None else cfg + img_shape = img_meta['img_shape'] + nms_pre = cfg.get('nms_pre', -1) + + mlvl_bboxes = [] + mlvl_scores = [] + mlvl_score_factors = [] + for level_idx, (cls_score, bbox_pred, score_factor, priors) in \ + enumerate(zip(cls_score_list, bbox_pred_list, + score_factor_list, mlvl_priors)): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + + scores = cls_score.permute(1, 2, 0).reshape( + -1, self.cls_out_channels).sigmoid() + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) + score_factor = score_factor.permute(1, 2, 0).reshape(-1).sigmoid() + + if 0 < nms_pre < scores.shape[0]: + max_scores, _ = (scores * + score_factor[:, None]).sqrt().max(dim=1) + _, topk_inds = max_scores.topk(nms_pre) + priors = priors[topk_inds, :] + bbox_pred = bbox_pred[topk_inds, :] + scores = scores[topk_inds, :] + score_factor = score_factor[topk_inds] + + bboxes = self.bbox_coder.decode( + priors, bbox_pred, max_shape=img_shape) + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + mlvl_score_factors.append(score_factor) + + results = InstanceData() + results.bboxes = torch.cat(mlvl_bboxes) + results.scores = torch.cat(mlvl_scores) + results.score_factors = torch.cat(mlvl_score_factors) + + return self._bbox_post_process(results, cfg, rescale, with_nms, + img_meta) + + def _bbox_post_process(self, + results: InstanceData, + cfg: ConfigType, + rescale: bool = False, + with_nms: bool = True, + img_meta: Optional[dict] = None): + """bbox post-processing method. + + The boxes would be rescaled to the original image scale and do + the nms operation. Usually with_nms is False is used for aug test. + + Args: + results (:obj:`InstaceData`): Detection instance results, + each item has shape (num_bboxes, ). + cfg (:obj:`ConfigDict` or dict): Test / postprocessing + configuration, if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + img_meta (dict, optional): Image meta info. Defaults to None. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
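The pre-NMS filter inside `_predict_by_feat_single` above ranks priors by the geometric mean of the classification score and the IoU score; a self-contained sketch with toy sizes:

```python
import torch

# Rank priors by sqrt(cls_score * iou_score) and keep the top nms_pre,
# as in the per-level loop above (already-sigmoided toy scores).
scores = torch.rand(100, 20)      # (num_priors, num_classes)
score_factor = torch.rand(100)    # IoU-branch output
nms_pre = 10

max_scores, _ = (scores * score_factor[:, None]).sqrt().max(dim=1)
_, topk_inds = max_scores.topk(nms_pre)
scores = scores[topk_inds, :]
score_factor = score_factor[topk_inds]
```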
+ """ + if rescale: + results.bboxes /= results.bboxes.new_tensor( + img_meta['scale_factor']).repeat((1, 2)) + # Add a dummy background class to the backend when using sigmoid + # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 + # BG cat_id: num_class + padding = results.scores.new_zeros(results.scores.shape[0], 1) + mlvl_scores = torch.cat([results.scores, padding], dim=1) + + mlvl_nms_scores = (mlvl_scores * results.score_factors[:, None]).sqrt() + det_bboxes, det_labels = multiclass_nms( + results.bboxes, + mlvl_nms_scores, + cfg.score_thr, + cfg.nms, + cfg.max_per_img, + score_factors=None) + if self.with_score_voting and len(det_bboxes) > 0: + det_bboxes, det_labels = self.score_voting(det_bboxes, det_labels, + results.bboxes, + mlvl_nms_scores, + cfg.score_thr) + nms_results = InstanceData() + nms_results.bboxes = det_bboxes[:, :-1] + nms_results.scores = det_bboxes[:, -1] + nms_results.labels = det_labels + return nms_results + + def score_voting(self, det_bboxes: Tensor, det_labels: Tensor, + mlvl_bboxes: Tensor, mlvl_nms_scores: Tensor, + score_thr: float) -> Tuple[Tensor, Tensor]: + """Implementation of score voting method works on each remaining boxes + after NMS procedure. + + Args: + det_bboxes (Tensor): Remaining boxes after NMS procedure, + with shape (k, 5), each dimension means + (x1, y1, x2, y2, score). + det_labels (Tensor): The label of remaining boxes, with shape + (k, 1),Labels are 0-based. + mlvl_bboxes (Tensor): All boxes before the NMS procedure, + with shape (num_anchors,4). + mlvl_nms_scores (Tensor): The scores of all boxes which is used + in the NMS procedure, with shape (num_anchors, num_class) + score_thr (float): The score threshold of bboxes. + + Returns: + tuple: Usually returns a tuple containing voting results. + + - det_bboxes_voted (Tensor): Remaining boxes after + score voting procedure, with shape (k, 5), each + dimension means (x1, y1, x2, y2, score). + - det_labels_voted (Tensor): Label of remaining bboxes + after voting, with shape (num_anchors,). 
+ """ + candidate_mask = mlvl_nms_scores > score_thr + candidate_mask_nonzeros = candidate_mask.nonzero(as_tuple=False) + candidate_inds = candidate_mask_nonzeros[:, 0] + candidate_labels = candidate_mask_nonzeros[:, 1] + candidate_bboxes = mlvl_bboxes[candidate_inds] + candidate_scores = mlvl_nms_scores[candidate_mask] + det_bboxes_voted = [] + det_labels_voted = [] + for cls in range(self.cls_out_channels): + candidate_cls_mask = candidate_labels == cls + if not candidate_cls_mask.any(): + continue + candidate_cls_scores = candidate_scores[candidate_cls_mask] + candidate_cls_bboxes = candidate_bboxes[candidate_cls_mask] + det_cls_mask = det_labels == cls + det_cls_bboxes = det_bboxes[det_cls_mask].view( + -1, det_bboxes.size(-1)) + det_candidate_ious = bbox_overlaps(det_cls_bboxes[:, :4], + candidate_cls_bboxes) + for det_ind in range(len(det_cls_bboxes)): + single_det_ious = det_candidate_ious[det_ind] + pos_ious_mask = single_det_ious > 0.01 + pos_ious = single_det_ious[pos_ious_mask] + pos_bboxes = candidate_cls_bboxes[pos_ious_mask] + pos_scores = candidate_cls_scores[pos_ious_mask] + pis = (torch.exp(-(1 - pos_ious)**2 / 0.025) * + pos_scores)[:, None] + voted_box = torch.sum( + pis * pos_bboxes, dim=0) / torch.sum( + pis, dim=0) + voted_score = det_cls_bboxes[det_ind][-1:][None, :] + det_bboxes_voted.append( + torch.cat((voted_box[None, :], voted_score), dim=1)) + det_labels_voted.append(cls) + + det_bboxes_voted = torch.cat(det_bboxes_voted, dim=0) + det_labels_voted = det_labels.new_tensor(det_labels_voted) + return det_bboxes_voted, det_labels_voted diff --git a/mmdet/models/dense_heads/pisa_retinanet_head.py b/mmdet/models/dense_heads/pisa_retinanet_head.py new file mode 100644 index 0000000000000000000000000000000000000000..85fd54f5be3605d0994c2a2d4d9d7deac4c0f284 --- /dev/null +++ b/mmdet/models/dense_heads/pisa_retinanet_head.py @@ -0,0 +1,154 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import InstanceList, OptInstanceList +from ..losses import carl_loss, isr_p +from ..utils import images_to_levels +from .retina_head import RetinaHead + + +@MODELS.register_module() +class PISARetinaHead(RetinaHead): + """PISA Retinanet Head. + + The head owns the same structure with Retinanet Head, but differs in two + aspects: + 1. Importance-based Sample Reweighting Positive (ISR-P) is applied to + change the positive loss weights. + 2. Classification-aware regression loss is adopted as a third loss. + """ + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W) + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. 
+ + Returns: + dict: Loss dict, comprise classification loss, regression loss and + carl loss. + """ + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore, + return_sampling_results=True) + if cls_reg_targets is None: + return None + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + avg_factor, sampling_results_list) = cls_reg_targets + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + # concat all level anchors and flags to a single tensor + concat_anchor_list = [] + for i in range(len(anchor_list)): + concat_anchor_list.append(torch.cat(anchor_list[i])) + all_anchor_list = images_to_levels(concat_anchor_list, + num_level_anchors) + + num_imgs = len(batch_img_metas) + flatten_cls_scores = [ + cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, label_channels) + for cls_score in cls_scores + ] + flatten_cls_scores = torch.cat( + flatten_cls_scores, dim=1).reshape(-1, + flatten_cls_scores[0].size(-1)) + flatten_bbox_preds = [ + bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) + for bbox_pred in bbox_preds + ] + flatten_bbox_preds = torch.cat( + flatten_bbox_preds, dim=1).view(-1, flatten_bbox_preds[0].size(-1)) + flatten_labels = torch.cat(labels_list, dim=1).reshape(-1) + flatten_label_weights = torch.cat( + label_weights_list, dim=1).reshape(-1) + flatten_anchors = torch.cat(all_anchor_list, dim=1).reshape(-1, 4) + flatten_bbox_targets = torch.cat( + bbox_targets_list, dim=1).reshape(-1, 4) + flatten_bbox_weights = torch.cat( + bbox_weights_list, dim=1).reshape(-1, 4) + + # Apply ISR-P + isr_cfg = self.train_cfg.get('isr', None) + if isr_cfg is not None: + all_targets = (flatten_labels, flatten_label_weights, + flatten_bbox_targets, flatten_bbox_weights) + with torch.no_grad(): + all_targets = isr_p( + flatten_cls_scores, + flatten_bbox_preds, + all_targets, + flatten_anchors, + sampling_results_list, + bbox_coder=self.bbox_coder, + loss_cls=self.loss_cls, + num_class=self.num_classes, + **self.train_cfg['isr']) + (flatten_labels, flatten_label_weights, flatten_bbox_targets, + flatten_bbox_weights) = all_targets + + # For convenience we compute loss once instead separating by fpn level, + # so that we don't need to separate the weights by level again. 
+ # The result should be the same + losses_cls = self.loss_cls( + flatten_cls_scores, + flatten_labels, + flatten_label_weights, + avg_factor=avg_factor) + losses_bbox = self.loss_bbox( + flatten_bbox_preds, + flatten_bbox_targets, + flatten_bbox_weights, + avg_factor=avg_factor) + loss_dict = dict(loss_cls=losses_cls, loss_bbox=losses_bbox) + + # CARL Loss + carl_cfg = self.train_cfg.get('carl', None) + if carl_cfg is not None: + loss_carl = carl_loss( + flatten_cls_scores, + flatten_labels, + flatten_bbox_preds, + flatten_bbox_targets, + self.loss_bbox, + **self.train_cfg['carl'], + avg_factor=avg_factor, + sigmoid=True, + num_class=self.num_classes) + loss_dict.update(loss_carl) + + return loss_dict diff --git a/mmdet/models/dense_heads/pisa_ssd_head.py b/mmdet/models/dense_heads/pisa_ssd_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ec09cb40a9c95d3f9889d736b80dfccef07f6fd1 --- /dev/null +++ b/mmdet/models/dense_heads/pisa_ssd_head.py @@ -0,0 +1,182 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Union + +import torch +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import InstanceList, OptInstanceList +from ..losses import CrossEntropyLoss, SmoothL1Loss, carl_loss, isr_p +from ..utils import multi_apply +from .ssd_head import SSDHead + + +# TODO: add loss evaluator for SSD +@MODELS.register_module() +class PISASSDHead(SSDHead): + """Implementation of `PISA SSD head `_ + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (Sequence[int]): Number of channels in the input feature + map. + stacked_convs (int): Number of conv layers in cls and reg tower. + Defaults to 0. + feat_channels (int): Number of hidden channels when stacked_convs + > 0. Defaults to 256. + use_depthwise (bool): Whether to use DepthwiseSeparableConv. + Defaults to False. + conv_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct + and config conv layer. Defaults to None. + norm_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct + and config norm layer. Defaults to None. + act_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct + and config activation layer. Defaults to None. + anchor_generator (:obj:`ConfigDict` or dict): Config dict for anchor + generator. + bbox_coder (:obj:`ConfigDict` or dict): Config of bounding box coder. + reg_decoded_bbox (bool): If true, the regression loss would be + applied directly on decoded bounding boxes, converting both + the predicted boxes and regression targets to absolute + coordinates format. Defaults to False. It should be `True` when + using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head. + train_cfg (:obj:`ConfigDict` or dict, Optional): Training config of + anchor head. + test_cfg (:obj:`ConfigDict` or dict, Optional): Testing config of + anchor head. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], Optional): Initialization config dict. + """ # noqa: W605 + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None + ) -> Dict[str, Union[List[Tensor], Tensor]]: + """Compute losses of the head. 
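Both PISA heads delegate to `carl_loss`; a hedged toy sketch of the classification-aware regression idea (the weighting function and its `bias`/`k` parameters are assumptions for illustration, not `carl_loss`'s exact form):

```python
import torch

# Scale each positive's regression loss by a function of its class score,
# so boxes the classifier is confident about are pushed to localize better.
cls_prob = torch.tensor([0.9, 0.5, 0.2])  # prob of the assigned GT class
reg_loss = torch.tensor([1.0, 1.0, 1.0])
bias, k = 0.2, 1.0                        # hypothetical CARL parameters
carl_w = bias + (1 - bias) * cls_prob ** k
loss_carl = (carl_w * reg_loss).sum() / max(carl_w.sum().item(), 1e-6)
```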
+ + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W) + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Union[List[Tensor], Tensor]]: A dictionary of loss + components. the dict has components below: + + - loss_cls (list[Tensor]): A list containing each feature map \ + classification loss. + - loss_bbox (list[Tensor]): A list containing each feature map \ + regression loss. + - loss_carl (Tensor): The loss of CARL. + """ + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore, + unmap_outputs=False, + return_sampling_results=True) + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + avg_factor, sampling_results_list) = cls_reg_targets + + num_images = len(batch_img_metas) + all_cls_scores = torch.cat([ + s.permute(0, 2, 3, 1).reshape( + num_images, -1, self.cls_out_channels) for s in cls_scores + ], 1) + all_labels = torch.cat(labels_list, -1).view(num_images, -1) + all_label_weights = torch.cat(label_weights_list, + -1).view(num_images, -1) + all_bbox_preds = torch.cat([ + b.permute(0, 2, 3, 1).reshape(num_images, -1, 4) + for b in bbox_preds + ], -2) + all_bbox_targets = torch.cat(bbox_targets_list, + -2).view(num_images, -1, 4) + all_bbox_weights = torch.cat(bbox_weights_list, + -2).view(num_images, -1, 4) + + # concat all level anchors to a single tensor + all_anchors = [] + for i in range(num_images): + all_anchors.append(torch.cat(anchor_list[i])) + + isr_cfg = self.train_cfg.get('isr', None) + all_targets = (all_labels.view(-1), all_label_weights.view(-1), + all_bbox_targets.view(-1, + 4), all_bbox_weights.view(-1, 4)) + # apply ISR-P + if isr_cfg is not None: + all_targets = isr_p( + all_cls_scores.view(-1, all_cls_scores.size(-1)), + all_bbox_preds.view(-1, 4), + all_targets, + torch.cat(all_anchors), + sampling_results_list, + loss_cls=CrossEntropyLoss(), + bbox_coder=self.bbox_coder, + **self.train_cfg['isr'], + num_class=self.num_classes) + (new_labels, new_label_weights, new_bbox_targets, + new_bbox_weights) = all_targets + all_labels = new_labels.view(all_labels.shape) + all_label_weights = new_label_weights.view(all_label_weights.shape) + all_bbox_targets = new_bbox_targets.view(all_bbox_targets.shape) + all_bbox_weights = new_bbox_weights.view(all_bbox_weights.shape) + + # add CARL loss + carl_loss_cfg = self.train_cfg.get('carl', None) + if carl_loss_cfg is not None: + loss_carl = carl_loss( + all_cls_scores.view(-1, all_cls_scores.size(-1)), + all_targets[0], + all_bbox_preds.view(-1, 4), + all_targets[2], + SmoothL1Loss(beta=1.), + **self.train_cfg['carl'], + avg_factor=avg_factor, + 
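+                # Normalize the CARL term with the same `avg_factor` as the
+                # classification and regression losses so its scale stays
+                # comparable.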
num_class=self.num_classes) + + # check NaN and Inf + assert torch.isfinite(all_cls_scores).all().item(), \ + 'classification scores become infinite or NaN!' + assert torch.isfinite(all_bbox_preds).all().item(), \ + 'bbox predications become infinite or NaN!' + + losses_cls, losses_bbox = multi_apply( + self.loss_by_feat_single, + all_cls_scores, + all_bbox_preds, + all_anchors, + all_labels, + all_label_weights, + all_bbox_targets, + all_bbox_weights, + avg_factor=avg_factor) + loss_dict = dict(loss_cls=losses_cls, loss_bbox=losses_bbox) + if carl_loss_cfg is not None: + loss_dict.update(loss_carl) + return loss_dict diff --git a/mmdet/models/dense_heads/reppoints_head.py b/mmdet/models/dense_heads/reppoints_head.py new file mode 100644 index 0000000000000000000000000000000000000000..22f3e3401a4abd9cc35b41d24efe23e5655a905e --- /dev/null +++ b/mmdet/models/dense_heads/reppoints_head.py @@ -0,0 +1,885 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Sequence, Tuple + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmcv.ops import DeformConv2d +from mmengine.config import ConfigDict +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptInstanceList +from ..task_modules.prior_generators import MlvlPointGenerator +from ..task_modules.samplers import PseudoSampler +from ..utils import (filter_scores_and_topk, images_to_levels, multi_apply, + unmap) +from .anchor_free_head import AnchorFreeHead + + +@MODELS.register_module() +class RepPointsHead(AnchorFreeHead): + """RepPoint head. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + point_feat_channels (int): Number of channels of points features. + num_points (int): Number of points. + gradient_mul (float): The multiplier to gradients from + points refinement and recognition. + point_strides (Sequence[int]): points strides. + point_base_scale (int): bbox scale for assigning labels. + loss_cls (:obj:`ConfigDict` or dict): Config of classification loss. + loss_bbox_init (:obj:`ConfigDict` or dict): Config of initial points + loss. + loss_bbox_refine (:obj:`ConfigDict` or dict): Config of points loss in + refinement. + use_grid_points (bool): If we use bounding box representation, the + reppoints is represented as grid points on the bounding box. + center_init (bool): Whether to use center point assignment. + transform_method (str): The methods to transform RepPoints to bbox. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict]): Initialization config dict. 
+ """ # noqa: W605 + + def __init__(self, + num_classes: int, + in_channels: int, + point_feat_channels: int = 256, + num_points: int = 9, + gradient_mul: float = 0.1, + point_strides: Sequence[int] = [8, 16, 32, 64, 128], + point_base_scale: int = 4, + loss_cls: ConfigType = dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox_init: ConfigType = dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5), + loss_bbox_refine: ConfigType = dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0), + use_grid_points: bool = False, + center_init: bool = True, + transform_method: str = 'moment', + moment_mul: float = 0.01, + init_cfg: MultiConfig = dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', + name='reppoints_cls_out', + std=0.01, + bias_prob=0.01)), + **kwargs) -> None: + self.num_points = num_points + self.point_feat_channels = point_feat_channels + self.use_grid_points = use_grid_points + self.center_init = center_init + + # we use deform conv to extract points features + self.dcn_kernel = int(np.sqrt(num_points)) + self.dcn_pad = int((self.dcn_kernel - 1) / 2) + assert self.dcn_kernel * self.dcn_kernel == num_points, \ + 'The points number should be a square number.' + assert self.dcn_kernel % 2 == 1, \ + 'The points number should be an odd square number.' + dcn_base = np.arange(-self.dcn_pad, + self.dcn_pad + 1).astype(np.float64) + dcn_base_y = np.repeat(dcn_base, self.dcn_kernel) + dcn_base_x = np.tile(dcn_base, self.dcn_kernel) + dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape( + (-1)) + self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1) + + super().__init__( + num_classes=num_classes, + in_channels=in_channels, + loss_cls=loss_cls, + init_cfg=init_cfg, + **kwargs) + + self.gradient_mul = gradient_mul + self.point_base_scale = point_base_scale + self.point_strides = point_strides + self.prior_generator = MlvlPointGenerator( + self.point_strides, offset=0.) 
+ + if self.train_cfg: + self.init_assigner = TASK_UTILS.build( + self.train_cfg['init']['assigner']) + self.refine_assigner = TASK_UTILS.build( + self.train_cfg['refine']['assigner']) + + if self.train_cfg.get('sampler', None) is not None: + self.sampler = TASK_UTILS.build( + self.train_cfg['sampler'], default_args=dict(context=self)) + else: + self.sampler = PseudoSampler(context=self) + + self.transform_method = transform_method + if self.transform_method == 'moment': + self.moment_transfer = nn.Parameter( + data=torch.zeros(2), requires_grad=True) + self.moment_mul = moment_mul + + self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) + if self.use_sigmoid_cls: + self.cls_out_channels = self.num_classes + else: + self.cls_out_channels = self.num_classes + 1 + self.loss_bbox_init = MODELS.build(loss_bbox_init) + self.loss_bbox_refine = MODELS.build(loss_bbox_refine) + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self.relu = nn.ReLU(inplace=True) + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + self.cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + self.reg_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + pts_out_dim = 4 if self.use_grid_points else 2 * self.num_points + self.reppoints_cls_conv = DeformConv2d(self.feat_channels, + self.point_feat_channels, + self.dcn_kernel, 1, + self.dcn_pad) + self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels, + self.cls_out_channels, 1, 1, 0) + self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels, + self.point_feat_channels, 3, + 1, 1) + self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels, + pts_out_dim, 1, 1, 0) + self.reppoints_pts_refine_conv = DeformConv2d(self.feat_channels, + self.point_feat_channels, + self.dcn_kernel, 1, + self.dcn_pad) + self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels, + pts_out_dim, 1, 1, 0) + + def points2bbox(self, pts: Tensor, y_first: bool = True) -> Tensor: + """Converting the points set into bounding box. + + Args: + pts (Tensor): the input points sets (fields), each points + set (fields) is represented as 2n scalar. + y_first (bool): if y_first=True, the point set is + represented as [y1, x1, y2, x2 ... yn, xn], otherwise + the point set is represented as + [x1, y1, x2, y2 ... xn, yn]. Defaults to True. + + Returns: + Tensor: each points set is converting to a bbox [x1, y1, x2, y2]. + """ + pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:]) + pts_y = pts_reshape[:, :, 0, ...] if y_first else pts_reshape[:, :, 1, + ...] + pts_x = pts_reshape[:, :, 1, ...] if y_first else pts_reshape[:, :, 0, + ...] + if self.transform_method == 'minmax': + bbox_left = pts_x.min(dim=1, keepdim=True)[0] + bbox_right = pts_x.max(dim=1, keepdim=True)[0] + bbox_up = pts_y.min(dim=1, keepdim=True)[0] + bbox_bottom = pts_y.max(dim=1, keepdim=True)[0] + bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom], + dim=1) + elif self.transform_method == 'partial_minmax': + pts_y = pts_y[:, :4, ...] + pts_x = pts_x[:, :4, ...] 
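+            # 'partial_minmax' reduces over only the first four points; the
+            # min/max below otherwise matches the 'minmax' branch.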
+ bbox_left = pts_x.min(dim=1, keepdim=True)[0] + bbox_right = pts_x.max(dim=1, keepdim=True)[0] + bbox_up = pts_y.min(dim=1, keepdim=True)[0] + bbox_bottom = pts_y.max(dim=1, keepdim=True)[0] + bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom], + dim=1) + elif self.transform_method == 'moment': + pts_y_mean = pts_y.mean(dim=1, keepdim=True) + pts_x_mean = pts_x.mean(dim=1, keepdim=True) + pts_y_std = torch.std(pts_y - pts_y_mean, dim=1, keepdim=True) + pts_x_std = torch.std(pts_x - pts_x_mean, dim=1, keepdim=True) + moment_transfer = (self.moment_transfer * self.moment_mul) + ( + self.moment_transfer.detach() * (1 - self.moment_mul)) + moment_width_transfer = moment_transfer[0] + moment_height_transfer = moment_transfer[1] + half_width = pts_x_std * torch.exp(moment_width_transfer) + half_height = pts_y_std * torch.exp(moment_height_transfer) + bbox = torch.cat([ + pts_x_mean - half_width, pts_y_mean - half_height, + pts_x_mean + half_width, pts_y_mean + half_height + ], + dim=1) + else: + raise NotImplementedError + return bbox + + def gen_grid_from_reg(self, reg: Tensor, + previous_boxes: Tensor) -> Tuple[Tensor]: + """Base on the previous bboxes and regression values, we compute the + regressed bboxes and generate the grids on the bboxes. + + Args: + reg (Tensor): the regression value to previous bboxes. + previous_boxes (Tensor): previous bboxes. + + Returns: + Tuple[Tensor]: generate grids on the regressed bboxes. + """ + b, _, h, w = reg.shape + bxy = (previous_boxes[:, :2, ...] + previous_boxes[:, 2:, ...]) / 2. + bwh = (previous_boxes[:, 2:, ...] - + previous_boxes[:, :2, ...]).clamp(min=1e-6) + grid_topleft = bxy + bwh * reg[:, :2, ...] - 0.5 * bwh * torch.exp( + reg[:, 2:, ...]) + grid_wh = bwh * torch.exp(reg[:, 2:, ...]) + grid_left = grid_topleft[:, [0], ...] + grid_top = grid_topleft[:, [1], ...] + grid_width = grid_wh[:, [0], ...] + grid_height = grid_wh[:, [1], ...] + intervel = torch.linspace(0., 1., self.dcn_kernel).view( + 1, self.dcn_kernel, 1, 1).type_as(reg) + grid_x = grid_left + grid_width * intervel + grid_x = grid_x.unsqueeze(1).repeat(1, self.dcn_kernel, 1, 1, 1) + grid_x = grid_x.view(b, -1, h, w) + grid_y = grid_top + grid_height * intervel + grid_y = grid_y.unsqueeze(2).repeat(1, 1, self.dcn_kernel, 1, 1) + grid_y = grid_y.view(b, -1, h, w) + grid_yx = torch.stack([grid_y, grid_x], dim=2) + grid_yx = grid_yx.view(b, -1, h, w) + regressed_bbox = torch.cat([ + grid_left, grid_top, grid_left + grid_width, grid_top + grid_height + ], 1) + return grid_yx, regressed_bbox + + def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor]: + return multi_apply(self.forward_single, feats) + + def forward_single(self, x: Tensor) -> Tuple[Tensor]: + """Forward feature map of a single FPN level.""" + dcn_base_offset = self.dcn_base_offset.type_as(x) + # If we use center_init, the initial reppoints is from center points. + # If we use bounding bbox representation, the initial reppoints is + # from regular grid placed on a pre-defined bbox. 
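+        # In either case, the predicted offsets are later turned into
+        # `dcn_offset`, the deviation from the regular DCN sampling grid.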
+ if self.use_grid_points or not self.center_init: + scale = self.point_base_scale / 2 + points_init = dcn_base_offset / dcn_base_offset.max() * scale + bbox_init = x.new_tensor([-scale, -scale, scale, + scale]).view(1, 4, 1, 1) + else: + points_init = 0 + cls_feat = x + pts_feat = x + for cls_conv in self.cls_convs: + cls_feat = cls_conv(cls_feat) + for reg_conv in self.reg_convs: + pts_feat = reg_conv(pts_feat) + # initialize reppoints + pts_out_init = self.reppoints_pts_init_out( + self.relu(self.reppoints_pts_init_conv(pts_feat))) + if self.use_grid_points: + pts_out_init, bbox_out_init = self.gen_grid_from_reg( + pts_out_init, bbox_init.detach()) + else: + pts_out_init = pts_out_init + points_init + # refine and classify reppoints + pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach( + ) + self.gradient_mul * pts_out_init + dcn_offset = pts_out_init_grad_mul - dcn_base_offset + cls_out = self.reppoints_cls_out( + self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset))) + pts_out_refine = self.reppoints_pts_refine_out( + self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset))) + if self.use_grid_points: + pts_out_refine, bbox_out_refine = self.gen_grid_from_reg( + pts_out_refine, bbox_out_init.detach()) + else: + pts_out_refine = pts_out_refine + pts_out_init.detach() + + if self.training: + return cls_out, pts_out_init, pts_out_refine + else: + return cls_out, self.points2bbox(pts_out_refine) + + def get_points(self, featmap_sizes: List[Tuple[int]], + batch_img_metas: List[dict], device: str) -> tuple: + """Get points according to feature map sizes. + + Args: + featmap_sizes (list[tuple]): Multi-level feature map sizes. + batch_img_metas (list[dict]): Image meta info. + + Returns: + tuple: points of each image, valid flags of each image + """ + num_imgs = len(batch_img_metas) + + # since feature map sizes of all images are the same, we only compute + # points center for one time + multi_level_points = self.prior_generator.grid_priors( + featmap_sizes, device=device, with_stride=True) + points_list = [[point.clone() for point in multi_level_points] + for _ in range(num_imgs)] + + # for each image, we compute valid flags of multi level grids + valid_flag_list = [] + for img_id, img_meta in enumerate(batch_img_metas): + multi_level_flags = self.prior_generator.valid_flags( + featmap_sizes, img_meta['pad_shape'], device=device) + valid_flag_list.append(multi_level_flags) + + return points_list, valid_flag_list + + def centers_to_bboxes(self, point_list: List[Tensor]) -> List[Tensor]: + """Get bboxes according to center points. + + Only used in :class:`MaxIoUAssigner`. 
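Looking back at the 'moment' branch of `points2bbox` above: the point set is summarized by its per-axis mean and standard deviation, and learned log-scale moments turn the std into box half-sizes. A toy sketch (fixed stand-ins for the learned `moment_transfer` parameters):

```python
import torch

# One toy point set per axis; moment_w/h stand in for the head's learned
# `moment_transfer` values.
pts_x = torch.tensor([[0.0, 1.0, 2.0, 3.0]])
pts_y = torch.tensor([[0.0, 2.0, 4.0, 6.0]])
moment_w = torch.tensor(0.5)
moment_h = torch.tensor(0.5)

cx, cy = pts_x.mean(1), pts_y.mean(1)
half_w = pts_x.std(1) * torch.exp(moment_w)
half_h = pts_y.std(1) * torch.exp(moment_h)
bbox = torch.stack([cx - half_w, cy - half_h, cx + half_w, cy + half_h], 1)
# bbox: [[x1, y1, x2, y2]] for the single toy instance
```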
+ """ + bbox_list = [] + for i_img, point in enumerate(point_list): + bbox = [] + for i_lvl in range(len(self.point_strides)): + scale = self.point_base_scale * self.point_strides[i_lvl] * 0.5 + bbox_shift = torch.Tensor([-scale, -scale, scale, + scale]).view(1, 4).type_as(point[0]) + bbox_center = torch.cat( + [point[i_lvl][:, :2], point[i_lvl][:, :2]], dim=1) + bbox.append(bbox_center + bbox_shift) + bbox_list.append(bbox) + return bbox_list + + def offset_to_pts(self, center_list: List[Tensor], + pred_list: List[Tensor]) -> List[Tensor]: + """Change from point offset to point coordinate.""" + pts_list = [] + for i_lvl in range(len(self.point_strides)): + pts_lvl = [] + for i_img in range(len(center_list)): + pts_center = center_list[i_img][i_lvl][:, :2].repeat( + 1, self.num_points) + pts_shift = pred_list[i_lvl][i_img] + yx_pts_shift = pts_shift.permute(1, 2, 0).view( + -1, 2 * self.num_points) + y_pts_shift = yx_pts_shift[..., 0::2] + x_pts_shift = yx_pts_shift[..., 1::2] + xy_pts_shift = torch.stack([x_pts_shift, y_pts_shift], -1) + xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1) + pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center + pts_lvl.append(pts) + pts_lvl = torch.stack(pts_lvl, 0) + pts_list.append(pts_lvl) + return pts_list + + def _get_targets_single(self, + flat_proposals: Tensor, + valid_flags: Tensor, + gt_instances: InstanceData, + gt_instances_ignore: InstanceData, + stage: str = 'init', + unmap_outputs: bool = True) -> tuple: + """Compute corresponding GT box and classification targets for + proposals. + + Args: + flat_proposals (Tensor): Multi level points of a image. + valid_flags (Tensor): Multi level valid flags of a image. + gt_instances (InstanceData): It usually includes ``bboxes`` and + ``labels`` attributes. + gt_instances_ignore (InstanceData): It includes ``bboxes`` + attribute data that is ignored during training and testing. + stage (str): 'init' or 'refine'. Generate target for + init stage or refine stage. Defaults to 'init'. + unmap_outputs (bool): Whether to map outputs back to + the original set of anchors. Defaults to True. + + Returns: + tuple: + + - labels (Tensor): Labels of each level. + - label_weights (Tensor): Label weights of each level. + - bbox_targets (Tensor): BBox targets of each level. + - bbox_weights (Tensor): BBox weights of each level. + - pos_inds (Tensor): positive samples indexes. + - neg_inds (Tensor): negative samples indexes. + - sampling_result (:obj:`SamplingResult`): Sampling results. + """ + inside_flags = valid_flags + if not inside_flags.any(): + raise ValueError( + 'There is no valid proposal inside the image boundary. 
Please ' + 'check the image size.') + # assign gt and sample proposals + proposals = flat_proposals[inside_flags, :] + pred_instances = InstanceData(priors=proposals) + + if stage == 'init': + assigner = self.init_assigner + pos_weight = self.train_cfg['init']['pos_weight'] + else: + assigner = self.refine_assigner + pos_weight = self.train_cfg['refine']['pos_weight'] + + assign_result = assigner.assign(pred_instances, gt_instances, + gt_instances_ignore) + sampling_result = self.sampler.sample(assign_result, pred_instances, + gt_instances) + + num_valid_proposals = proposals.shape[0] + bbox_gt = proposals.new_zeros([num_valid_proposals, 4]) + pos_proposals = torch.zeros_like(proposals) + proposals_weights = proposals.new_zeros([num_valid_proposals, 4]) + labels = proposals.new_full((num_valid_proposals, ), + self.num_classes, + dtype=torch.long) + label_weights = proposals.new_zeros( + num_valid_proposals, dtype=torch.float) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + bbox_gt[pos_inds, :] = sampling_result.pos_gt_bboxes + pos_proposals[pos_inds, :] = proposals[pos_inds, :] + proposals_weights[pos_inds, :] = 1.0 + + labels[pos_inds] = sampling_result.pos_gt_labels + if pos_weight <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = pos_weight + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # map up to original set of proposals + if unmap_outputs: + num_total_proposals = flat_proposals.size(0) + labels = unmap( + labels, + num_total_proposals, + inside_flags, + fill=self.num_classes) # fill bg label + label_weights = unmap(label_weights, num_total_proposals, + inside_flags) + bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags) + pos_proposals = unmap(pos_proposals, num_total_proposals, + inside_flags) + proposals_weights = unmap(proposals_weights, num_total_proposals, + inside_flags) + + return (labels, label_weights, bbox_gt, pos_proposals, + proposals_weights, pos_inds, neg_inds, sampling_result) + + def get_targets(self, + proposals_list: List[Tensor], + valid_flag_list: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None, + stage: str = 'init', + unmap_outputs: bool = True, + return_sampling_results: bool = False) -> tuple: + """Compute corresponding GT box and classification targets for + proposals. + + Args: + proposals_list (list[Tensor]): Multi level points/bboxes of each + image. + valid_flag_list (list[Tensor]): Multi level valid flags of each + image. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + stage (str): 'init' or 'refine'. Generate target for init stage or + refine stage. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + return_sampling_results (bool): Whether to return the sampling + results. Defaults to False. + + Returns: + tuple: + + - labels_list (list[Tensor]): Labels of each level. + - label_weights_list (list[Tensor]): Label weights of each + level. + - bbox_gt_list (list[Tensor]): Ground truth bbox of each level. 
+ - proposals_list (list[Tensor]): Proposals(points/bboxes) of + each level. + - proposal_weights_list (list[Tensor]): Proposal weights of + each level. + - avg_factor (int): Average factor that is used to average + the loss. When using sampling method, avg_factor is usually + the sum of positive and negative priors. When using + `PseudoSampler`, `avg_factor` is usually equal to the number + of positive priors. + """ + assert stage in ['init', 'refine'] + num_imgs = len(batch_img_metas) + assert len(proposals_list) == len(valid_flag_list) == num_imgs + + # points number of multi levels + num_level_proposals = [points.size(0) for points in proposals_list[0]] + + # concat all level points and flags to a single tensor + for i in range(num_imgs): + assert len(proposals_list[i]) == len(valid_flag_list[i]) + proposals_list[i] = torch.cat(proposals_list[i]) + valid_flag_list[i] = torch.cat(valid_flag_list[i]) + + if batch_gt_instances_ignore is None: + batch_gt_instances_ignore = [None] * num_imgs + + (all_labels, all_label_weights, all_bbox_gt, all_proposals, + all_proposal_weights, pos_inds_list, neg_inds_list, + sampling_results_list) = multi_apply( + self._get_targets_single, + proposals_list, + valid_flag_list, + batch_gt_instances, + batch_gt_instances_ignore, + stage=stage, + unmap_outputs=unmap_outputs) + + # sampled points of all images + avg_refactor = sum( + [results.avg_factor for results in sampling_results_list]) + labels_list = images_to_levels(all_labels, num_level_proposals) + label_weights_list = images_to_levels(all_label_weights, + num_level_proposals) + bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals) + proposals_list = images_to_levels(all_proposals, num_level_proposals) + proposal_weights_list = images_to_levels(all_proposal_weights, + num_level_proposals) + res = (labels_list, label_weights_list, bbox_gt_list, proposals_list, + proposal_weights_list, avg_refactor) + if return_sampling_results: + res = res + (sampling_results_list, ) + + return res + + def loss_by_feat_single(self, cls_score: Tensor, pts_pred_init: Tensor, + pts_pred_refine: Tensor, labels: Tensor, + label_weights, bbox_gt_init: Tensor, + bbox_weights_init: Tensor, bbox_gt_refine: Tensor, + bbox_weights_refine: Tensor, stride: int, + avg_factor_init: int, + avg_factor_refine: int) -> Tuple[Tensor]: + """Calculate the loss of a single scale level based on the features + extracted by the detection head. + + Args: + cls_score (Tensor): Box scores for each scale level + Has shape (N, num_classes, h_i, w_i). + pts_pred_init (Tensor): Points of shape + (batch_size, h_i * w_i, num_points * 2). + pts_pred_refine (Tensor): Points refined of shape + (batch_size, h_i * w_i, num_points * 2). + labels (Tensor): Ground truth class indices with shape + (batch_size, h_i * w_i). + label_weights (Tensor): Label weights of shape + (batch_size, h_i * w_i). + bbox_gt_init (Tensor): BBox regression targets in the init stage + of shape (batch_size, h_i * w_i, 4). + bbox_weights_init (Tensor): BBox regression loss weights in the + init stage of shape (batch_size, h_i * w_i, 4). + bbox_gt_refine (Tensor): BBox regression targets in the refine + stage of shape (batch_size, h_i * w_i, 4). + bbox_weights_refine (Tensor): BBox regression loss weights in the + refine stage of shape (batch_size, h_i * w_i, 4). + stride (int): Point stride. + avg_factor_init (int): Average factor that is used to average + the loss in the init stage. 
+ avg_factor_refine (int): Average factor that is used to average + the loss in the refine stage. + + Returns: + Tuple[Tensor]: loss components. + """ + # classification loss + labels = labels.reshape(-1) + label_weights = label_weights.reshape(-1) + cls_score = cls_score.permute(0, 2, 3, + 1).reshape(-1, self.cls_out_channels) + cls_score = cls_score.contiguous() + loss_cls = self.loss_cls( + cls_score, labels, label_weights, avg_factor=avg_factor_refine) + + # points loss + bbox_gt_init = bbox_gt_init.reshape(-1, 4) + bbox_weights_init = bbox_weights_init.reshape(-1, 4) + bbox_pred_init = self.points2bbox( + pts_pred_init.reshape(-1, 2 * self.num_points), y_first=False) + bbox_gt_refine = bbox_gt_refine.reshape(-1, 4) + bbox_weights_refine = bbox_weights_refine.reshape(-1, 4) + bbox_pred_refine = self.points2bbox( + pts_pred_refine.reshape(-1, 2 * self.num_points), y_first=False) + normalize_term = self.point_base_scale * stride + loss_pts_init = self.loss_bbox_init( + bbox_pred_init / normalize_term, + bbox_gt_init / normalize_term, + bbox_weights_init, + avg_factor=avg_factor_init) + loss_pts_refine = self.loss_bbox_refine( + bbox_pred_refine / normalize_term, + bbox_gt_refine / normalize_term, + bbox_weights_refine, + avg_factor=avg_factor_refine) + return loss_cls, loss_pts_init, loss_pts_refine + + def loss_by_feat( + self, + cls_scores: List[Tensor], + pts_preds_init: List[Tensor], + pts_preds_refine: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None + ) -> Dict[str, Tensor]: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level, + each is a 4D-tensor, of shape (batch_size, num_classes, h, w). + pts_preds_init (list[Tensor]): Points for each scale level, each is + a 3D-tensor, of shape (batch_size, h_i * w_i, num_points * 2). + pts_preds_refine (list[Tensor]): Points refined for each scale + level, each is a 3D-tensor, of shape + (batch_size, h_i * w_i, num_points * 2). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
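+
+        Note:
+            A rough usage sketch (names are illustrative; the three
+            prediction lists come from ``self.forward``)::
+
+                cls_scores, pts_preds_init, pts_preds_refine = self(feats)
+                losses = self.loss_by_feat(cls_scores, pts_preds_init,
+                                           pts_preds_refine,
+                                           batch_gt_instances,
+                                           batch_img_metas)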
+ """ + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + device = cls_scores[0].device + + # target for initial stage + center_list, valid_flag_list = self.get_points(featmap_sizes, + batch_img_metas, device) + pts_coordinate_preds_init = self.offset_to_pts(center_list, + pts_preds_init) + if self.train_cfg['init']['assigner']['type'] == 'PointAssigner': + # Assign target for center list + candidate_list = center_list + else: + # transform center list to bbox list and + # assign target for bbox list + bbox_list = self.centers_to_bboxes(center_list) + candidate_list = bbox_list + cls_reg_targets_init = self.get_targets( + proposals_list=candidate_list, + valid_flag_list=valid_flag_list, + batch_gt_instances=batch_gt_instances, + batch_img_metas=batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore, + stage='init', + return_sampling_results=False) + (*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init, + avg_factor_init) = cls_reg_targets_init + + # target for refinement stage + center_list, valid_flag_list = self.get_points(featmap_sizes, + batch_img_metas, device) + pts_coordinate_preds_refine = self.offset_to_pts( + center_list, pts_preds_refine) + bbox_list = [] + for i_img, center in enumerate(center_list): + bbox = [] + for i_lvl in range(len(pts_preds_refine)): + bbox_preds_init = self.points2bbox( + pts_preds_init[i_lvl].detach()) + bbox_shift = bbox_preds_init * self.point_strides[i_lvl] + bbox_center = torch.cat( + [center[i_lvl][:, :2], center[i_lvl][:, :2]], dim=1) + bbox.append(bbox_center + + bbox_shift[i_img].permute(1, 2, 0).reshape(-1, 4)) + bbox_list.append(bbox) + cls_reg_targets_refine = self.get_targets( + proposals_list=bbox_list, + valid_flag_list=valid_flag_list, + batch_gt_instances=batch_gt_instances, + batch_img_metas=batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore, + stage='refine', + return_sampling_results=False) + (labels_list, label_weights_list, bbox_gt_list_refine, + candidate_list_refine, bbox_weights_list_refine, + avg_factor_refine) = cls_reg_targets_refine + + # compute loss + losses_cls, losses_pts_init, losses_pts_refine = multi_apply( + self.loss_by_feat_single, + cls_scores, + pts_coordinate_preds_init, + pts_coordinate_preds_refine, + labels_list, + label_weights_list, + bbox_gt_list_init, + bbox_weights_list_init, + bbox_gt_list_refine, + bbox_weights_list_refine, + self.point_strides, + avg_factor_init=avg_factor_init, + avg_factor_refine=avg_factor_refine) + loss_dict_all = { + 'loss_cls': losses_cls, + 'loss_pts_init': losses_pts_init, + 'loss_pts_refine': losses_pts_refine + } + return loss_dict_all + + # Same as base_dense_head/_get_bboxes_single except self._bbox_decode + def _predict_by_feat_single(self, + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + score_factor_list: List[Tensor], + mlvl_priors: List[Tensor], + img_meta: dict, + cfg: ConfigDict, + rescale: bool = False, + with_nms: bool = True) -> InstanceData: + """Transform outputs of a single image into bbox predictions. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image. RepPoints head does not need + this value. 
+ mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid, has shape + (num_priors, 2). + img_meta (dict): Image meta info. + cfg (:obj:`ConfigDict`): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + cfg = self.test_cfg if cfg is None else cfg + assert len(cls_score_list) == len(bbox_pred_list) + img_shape = img_meta['img_shape'] + nms_pre = cfg.get('nms_pre', -1) + + mlvl_bboxes = [] + mlvl_scores = [] + mlvl_labels = [] + for level_idx, (cls_score, bbox_pred, priors) in enumerate( + zip(cls_score_list, bbox_pred_list, mlvl_priors)): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) + + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + if self.use_sigmoid_cls: + scores = cls_score.sigmoid() + else: + scores = cls_score.softmax(-1)[:, :-1] + + # After https://github.com/open-mmlab/mmdetection/pull/6268/, + # this operation keeps fewer bboxes under the same `nms_pre`. + # There is no difference in performance for most models. If you + # find a slight drop in performance, you can set a larger + # `nms_pre` than before. + results = filter_scores_and_topk( + scores, cfg.score_thr, nms_pre, + dict(bbox_pred=bbox_pred, priors=priors)) + scores, labels, _, filtered_results = results + + bbox_pred = filtered_results['bbox_pred'] + priors = filtered_results['priors'] + + bboxes = self._bbox_decode(priors, bbox_pred, + self.point_strides[level_idx], + img_shape) + + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + mlvl_labels.append(labels) + + results = InstanceData() + results.bboxes = torch.cat(mlvl_bboxes) + results.scores = torch.cat(mlvl_scores) + results.labels = torch.cat(mlvl_labels) + + return self._bbox_post_process( + results=results, + cfg=cfg, + rescale=rescale, + with_nms=with_nms, + img_meta=img_meta) + + def _bbox_decode(self, points: Tensor, bbox_pred: Tensor, stride: int, + max_shape: Tuple[int, int]) -> Tensor: + """Decode the prediction to bounding box. + + Args: + points (Tensor): shape (h_i * w_i, 2). + bbox_pred (Tensor): shape (h_i * w_i, 4). + stride (int): Stride for bbox_pred in different level. + max_shape (Tuple[int, int]): image shape. + + Returns: + Tensor: Bounding boxes decoded. 
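+
+        Example:
+            >>> # Illustrative numbers only; ``self`` is assumed to be a
+            >>> # constructed head instance. A point at (8., 8.) with
+            >>> # stride 8 and prediction (-1., -1., 1., 1.) decodes to
+            >>> # (0., 0., 16., 16.) and is then clamped to ``max_shape``.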
+ """ + bbox_pos_center = torch.cat([points[:, :2], points[:, :2]], dim=1) + bboxes = bbox_pred * stride + bbox_pos_center + x1 = bboxes[:, 0].clamp(min=0, max=max_shape[1]) + y1 = bboxes[:, 1].clamp(min=0, max=max_shape[0]) + x2 = bboxes[:, 2].clamp(min=0, max=max_shape[1]) + y2 = bboxes[:, 3].clamp(min=0, max=max_shape[0]) + decoded_bboxes = torch.stack([x1, y1, x2, y2], dim=-1) + return decoded_bboxes diff --git a/mmdet/models/dense_heads/retina_head.py b/mmdet/models/dense_heads/retina_head.py new file mode 100644 index 0000000000000000000000000000000000000000..be3ae74d81ba38609646f0d0406098ecbdcef688 --- /dev/null +++ b/mmdet/models/dense_heads/retina_head.py @@ -0,0 +1,120 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmdet.registry import MODELS +from .anchor_head import AnchorHead + + +@MODELS.register_module() +class RetinaHead(AnchorHead): + r"""An anchor-based head used in `RetinaNet + `_. + + The head contains two subnetworks. The first classifies anchor boxes and + the second regresses deltas for the anchors. + + Example: + >>> import torch + >>> self = RetinaHead(11, 7) + >>> x = torch.rand(1, 7, 32, 32) + >>> cls_score, bbox_pred = self.forward_single(x) + >>> # Each anchor predicts a score for each class except background + >>> cls_per_anchor = cls_score.shape[1] / self.num_anchors + >>> box_per_anchor = bbox_pred.shape[1] / self.num_anchors + >>> assert cls_per_anchor == (self.num_classes) + >>> assert box_per_anchor == 4 + """ + + def __init__(self, + num_classes, + in_channels, + stacked_convs=4, + conv_cfg=None, + norm_cfg=None, + anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + init_cfg=dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', + name='retina_cls', + std=0.01, + bias_prob=0.01)), + **kwargs): + assert stacked_convs >= 0, \ + '`stacked_convs` must be non-negative integers, ' \ + f'but got {stacked_convs} instead.' + self.stacked_convs = stacked_convs + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + super(RetinaHead, self).__init__( + num_classes, + in_channels, + anchor_generator=anchor_generator, + init_cfg=init_cfg, + **kwargs) + + def _init_layers(self): + """Initialize layers of the head.""" + self.relu = nn.ReLU(inplace=True) + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + in_channels = self.in_channels + for i in range(self.stacked_convs): + self.cls_convs.append( + ConvModule( + in_channels, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + self.reg_convs.append( + ConvModule( + in_channels, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + in_channels = self.feat_channels + self.retina_cls = nn.Conv2d( + in_channels, + self.num_base_priors * self.cls_out_channels, + 3, + padding=1) + reg_dim = self.bbox_coder.encode_size + self.retina_reg = nn.Conv2d( + in_channels, self.num_base_priors * reg_dim, 3, padding=1) + + def forward_single(self, x): + """Forward feature of a single scale level. + + Args: + x (Tensor): Features of a single scale level. + + Returns: + tuple: + cls_score (Tensor): Cls scores for a single scale level + the channels number is num_anchors * num_classes. + bbox_pred (Tensor): Box energies / deltas for a single scale + level, the channels number is num_anchors * 4. 
+ """ + cls_feat = x + reg_feat = x + for cls_conv in self.cls_convs: + cls_feat = cls_conv(cls_feat) + for reg_conv in self.reg_convs: + reg_feat = reg_conv(reg_feat) + cls_score = self.retina_cls(cls_feat) + bbox_pred = self.retina_reg(reg_feat) + return cls_score, bbox_pred diff --git a/mmdet/models/dense_heads/retina_sepbn_head.py b/mmdet/models/dense_heads/retina_sepbn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..681a39983a08670adaa3e24a4099c4f26bc967ce --- /dev/null +++ b/mmdet/models/dense_heads/retina_sepbn_head.py @@ -0,0 +1,127 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import bias_init_with_prob, normal_init +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import OptConfigType, OptMultiConfig +from .anchor_head import AnchorHead + + +@MODELS.register_module() +class RetinaSepBNHead(AnchorHead): + """"RetinaHead with separate BN. + + In RetinaHead, conv/norm layers are shared across different FPN levels, + while in RetinaSepBNHead, conv layers are shared across different FPN + levels, but BN layers are separated. + """ + + def __init__(self, + num_classes: int, + num_ins: int, + in_channels: int, + stacked_convs: int = 4, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = None, + **kwargs) -> None: + assert init_cfg is None, 'To prevent abnormal initialization ' \ + 'behavior, init_cfg is not allowed to be set' + self.stacked_convs = stacked_convs + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.num_ins = num_ins + super().__init__( + num_classes=num_classes, + in_channels=in_channels, + init_cfg=init_cfg, + **kwargs) + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self.relu = nn.ReLU(inplace=True) + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + for i in range(self.num_ins): + cls_convs = nn.ModuleList() + reg_convs = nn.ModuleList() + for j in range(self.stacked_convs): + chn = self.in_channels if j == 0 else self.feat_channels + cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + reg_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + self.cls_convs.append(cls_convs) + self.reg_convs.append(reg_convs) + for i in range(self.stacked_convs): + for j in range(1, self.num_ins): + self.cls_convs[j][i].conv = self.cls_convs[0][i].conv + self.reg_convs[j][i].conv = self.reg_convs[0][i].conv + self.retina_cls = nn.Conv2d( + self.feat_channels, + self.num_base_priors * self.cls_out_channels, + 3, + padding=1) + self.retina_reg = nn.Conv2d( + self.feat_channels, self.num_base_priors * 4, 3, padding=1) + + def init_weights(self) -> None: + """Initialize weights of the head.""" + super().init_weights() + for m in self.cls_convs[0]: + normal_init(m.conv, std=0.01) + for m in self.reg_convs[0]: + normal_init(m.conv, std=0.01) + bias_cls = bias_init_with_prob(0.01) + normal_init(self.retina_cls, std=0.01, bias=bias_cls) + normal_init(self.retina_reg, std=0.01) + + def forward(self, feats: Tuple[Tensor]) -> tuple: + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. 
+ + Returns: + tuple: Usually a tuple of classification scores and bbox prediction + + - cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, the channels number is + num_anchors * num_classes. + - bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, the channels number is + num_anchors * 4. + """ + cls_scores = [] + bbox_preds = [] + for i, x in enumerate(feats): + cls_feat = feats[i] + reg_feat = feats[i] + for cls_conv in self.cls_convs[i]: + cls_feat = cls_conv(cls_feat) + for reg_conv in self.reg_convs[i]: + reg_feat = reg_conv(reg_feat) + cls_score = self.retina_cls(cls_feat) + bbox_pred = self.retina_reg(reg_feat) + cls_scores.append(cls_score) + bbox_preds.append(bbox_pred) + return cls_scores, bbox_preds diff --git a/mmdet/models/dense_heads/rpn_head.py b/mmdet/models/dense_heads/rpn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..6b544009d2ffc4c3c9065707a0a8a72c577eb432 --- /dev/null +++ b/mmdet/models/dense_heads/rpn_head.py @@ -0,0 +1,302 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmcv.ops import batched_nms +from mmengine.config import ConfigDict +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures.bbox import (cat_boxes, empty_box_as, get_box_tensor, + get_box_wh, scale_boxes) +from mmdet.utils import InstanceList, MultiConfig, OptInstanceList +from .anchor_head import AnchorHead + + +@MODELS.register_module() +class RPNHead(AnchorHead): + """Implementation of RPN head. + + Args: + in_channels (int): Number of channels in the input feature map. + num_classes (int): Number of categories excluding the background + category. Defaults to 1. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or \ + list[dict]): Initialization config dict. + num_convs (int): Number of convolution layers in the head. + Defaults to 1. + """ # noqa: W605 + + def __init__(self, + in_channels: int, + num_classes: int = 1, + init_cfg: MultiConfig = dict( + type='Normal', layer='Conv2d', std=0.01), + num_convs: int = 1, + **kwargs) -> None: + self.num_convs = num_convs + assert num_classes == 1 + super().__init__( + num_classes=num_classes, + in_channels=in_channels, + init_cfg=init_cfg, + **kwargs) + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + if self.num_convs > 1: + rpn_convs = [] + for i in range(self.num_convs): + if i == 0: + in_channels = self.in_channels + else: + in_channels = self.feat_channels + # use ``inplace=False`` to avoid error: one of the variables + # needed for gradient computation has been modified by an + # inplace operation. + rpn_convs.append( + ConvModule( + in_channels, + self.feat_channels, + 3, + padding=1, + inplace=False)) + self.rpn_conv = nn.Sequential(*rpn_convs) + else: + self.rpn_conv = nn.Conv2d( + self.in_channels, self.feat_channels, 3, padding=1) + self.rpn_cls = nn.Conv2d(self.feat_channels, + self.num_base_priors * self.cls_out_channels, + 1) + reg_dim = self.bbox_coder.encode_size + self.rpn_reg = nn.Conv2d(self.feat_channels, + self.num_base_priors * reg_dim, 1) + + def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Forward feature of a single scale level. + + Args: + x (Tensor): Features of a single scale level. 
+ + Returns: + tuple: + cls_score (Tensor): Cls scores for a single scale level \ + the channels number is num_base_priors * num_classes. + bbox_pred (Tensor): Box energies / deltas for a single scale \ + level, the channels number is num_base_priors * 4. + """ + x = self.rpn_conv(x) + x = F.relu(x) + rpn_cls_score = self.rpn_cls(x) + rpn_bbox_pred = self.rpn_reg(x) + return rpn_cls_score, rpn_bbox_pred + + def loss_by_feat(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) \ + -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level, + has shape (N, num_anchors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + batch_gt_instances (list[obj:InstanceData]): Batch of gt_instance. + It usually includes ``bboxes`` and ``labels`` attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[obj:InstanceData], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + losses = super().loss_by_feat( + cls_scores, + bbox_preds, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + return dict( + loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox']) + + def _predict_by_feat_single(self, + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + score_factor_list: List[Tensor], + mlvl_priors: List[Tensor], + img_meta: dict, + cfg: ConfigDict, + rescale: bool = False, + with_nms: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Be compatible with + BaseDenseHead. Not used in RPNHead. + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid. In all + anchor-based methods, it has shape (num_priors, 4). In + all anchor-free methods, it has shape (num_priors, 2) + when `with_stride=True`, otherwise it still has shape + (num_priors, 4). + img_meta (dict): Image meta info. + cfg (ConfigDict, optional): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
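+
+        Note:
+            Candidates are filtered to the top ``nms_pre`` per level *before*
+            boxes are decoded, and each kept candidate is tagged with a
+            ``level_ids`` entry so that ``_bbox_post_process`` can run NMS
+            separately per level.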
+ """ + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + img_shape = img_meta['img_shape'] + nms_pre = cfg.get('nms_pre', -1) + + mlvl_bbox_preds = [] + mlvl_valid_priors = [] + mlvl_scores = [] + level_ids = [] + for level_idx, (cls_score, bbox_pred, priors) in \ + enumerate(zip(cls_score_list, bbox_pred_list, + mlvl_priors)): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + + reg_dim = self.bbox_coder.encode_size + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, reg_dim) + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + if self.use_sigmoid_cls: + scores = cls_score.sigmoid() + else: + # remind that we set FG labels to [0] since mmdet v2.0 + # BG cat_id: 1 + scores = cls_score.softmax(-1)[:, :-1] + + scores = torch.squeeze(scores) + if 0 < nms_pre < scores.shape[0]: + # sort is faster than topk + # _, topk_inds = scores.topk(cfg.nms_pre) + ranked_scores, rank_inds = scores.sort(descending=True) + topk_inds = rank_inds[:nms_pre] + scores = ranked_scores[:nms_pre] + bbox_pred = bbox_pred[topk_inds, :] + priors = priors[topk_inds] + + mlvl_bbox_preds.append(bbox_pred) + mlvl_valid_priors.append(priors) + mlvl_scores.append(scores) + + # use level id to implement the separate level nms + level_ids.append( + scores.new_full((scores.size(0), ), + level_idx, + dtype=torch.long)) + + bbox_pred = torch.cat(mlvl_bbox_preds) + priors = cat_boxes(mlvl_valid_priors) + bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape) + + results = InstanceData() + results.bboxes = bboxes + results.scores = torch.cat(mlvl_scores) + results.level_ids = torch.cat(level_ids) + + return self._bbox_post_process( + results=results, cfg=cfg, rescale=rescale, img_meta=img_meta) + + def _bbox_post_process(self, + results: InstanceData, + cfg: ConfigDict, + rescale: bool = False, + with_nms: bool = True, + img_meta: Optional[dict] = None) -> InstanceData: + """bbox post-processing method. + + The boxes would be rescaled to the original image scale and do + the nms operation. + + Args: + results (:obj:`InstaceData`): Detection instance results, + each item has shape (num_bboxes, ). + cfg (ConfigDict): Test / postprocessing configuration. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Default to True. + img_meta (dict, optional): Image meta info. Defaults to None. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
+ """ + assert with_nms, '`with_nms` must be True in RPNHead' + if rescale: + assert img_meta.get('scale_factor') is not None + scale_factor = [1 / s for s in img_meta['scale_factor']] + results.bboxes = scale_boxes(results.bboxes, scale_factor) + + # filter small size bboxes + if cfg.get('min_bbox_size', -1) >= 0: + w, h = get_box_wh(results.bboxes) + valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size) + if not valid_mask.all(): + results = results[valid_mask] + + if results.bboxes.numel() > 0: + bboxes = get_box_tensor(results.bboxes) + det_bboxes, keep_idxs = batched_nms(bboxes, results.scores, + results.level_ids, cfg.nms) + results = results[keep_idxs] + # some nms would reweight the score, such as softnms + results.scores = det_bboxes[:, -1] + results = results[:cfg.max_per_img] + # TODO: This would unreasonably show the 0th class label + # in visualization + results.labels = results.scores.new_zeros( + len(results), dtype=torch.long) + del results.level_ids + else: + # To avoid some potential error + results_ = InstanceData() + results_.bboxes = empty_box_as(results.bboxes) + results_.scores = results.scores.new_zeros(0) + results_.labels = results.scores.new_zeros(0) + results = results_ + return results diff --git a/mmdet/models/dense_heads/rtmdet_head.py b/mmdet/models/dense_heads/rtmdet_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ae0ee6d2f35a0fa46ba0b8de21054433d0420b65 --- /dev/null +++ b/mmdet/models/dense_heads/rtmdet_head.py @@ -0,0 +1,692 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule, Scale, is_norm +from mmengine.model import bias_init_with_prob, constant_init, normal_init +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures.bbox import distance2bbox +from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean +from ..layers.transformer import inverse_sigmoid +from ..task_modules import anchor_inside_flags +from ..utils import (images_to_levels, multi_apply, sigmoid_geometric_mean, + unmap) +from .atss_head import ATSSHead + + +@MODELS.register_module() +class RTMDetHead(ATSSHead): + """Detection Head of RTMDet. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + with_objectness (bool): Whether to add an objectness branch. + Defaults to True. + act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer. 
+ Default: dict(type='ReLU') + """ + + def __init__(self, + num_classes: int, + in_channels: int, + with_objectness: bool = True, + act_cfg: ConfigType = dict(type='ReLU'), + **kwargs) -> None: + self.act_cfg = act_cfg + self.with_objectness = with_objectness + super().__init__(num_classes, in_channels, **kwargs) + if self.train_cfg: + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + + def _init_layers(self): + """Initialize layers of the head.""" + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + self.cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.reg_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + pred_pad_size = self.pred_kernel_size // 2 + self.rtm_cls = nn.Conv2d( + self.feat_channels, + self.num_base_priors * self.cls_out_channels, + self.pred_kernel_size, + padding=pred_pad_size) + self.rtm_reg = nn.Conv2d( + self.feat_channels, + self.num_base_priors * 4, + self.pred_kernel_size, + padding=pred_pad_size) + if self.with_objectness: + self.rtm_obj = nn.Conv2d( + self.feat_channels, + 1, + self.pred_kernel_size, + padding=pred_pad_size) + + self.scales = nn.ModuleList( + [Scale(1.0) for _ in self.prior_generator.strides]) + + def init_weights(self) -> None: + """Initialize weights of the head.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, mean=0, std=0.01) + if is_norm(m): + constant_init(m, 1) + bias_cls = bias_init_with_prob(0.01) + normal_init(self.rtm_cls, std=0.01, bias=bias_cls) + normal_init(self.rtm_reg, std=0.01) + if self.with_objectness: + normal_init(self.rtm_obj, std=0.01, bias=bias_cls) + + def forward(self, feats: Tuple[Tensor, ...]) -> tuple: + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: Usually a tuple of classification scores and bbox prediction + - cls_scores (list[Tensor]): Classification scores for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * num_classes. + - bbox_preds (list[Tensor]): Box energies / deltas for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * 4. + """ + + cls_scores = [] + bbox_preds = [] + for idx, (x, scale, stride) in enumerate( + zip(feats, self.scales, self.prior_generator.strides)): + cls_feat = x + reg_feat = x + + for cls_layer in self.cls_convs: + cls_feat = cls_layer(cls_feat) + cls_score = self.rtm_cls(cls_feat) + + for reg_layer in self.reg_convs: + reg_feat = reg_layer(reg_feat) + + if self.with_objectness: + objectness = self.rtm_obj(reg_feat) + cls_score = inverse_sigmoid( + sigmoid_geometric_mean(cls_score, objectness)) + + reg_dist = scale(self.rtm_reg(reg_feat).exp()).float() * stride[0] + + cls_scores.append(cls_score) + bbox_preds.append(reg_dist) + return tuple(cls_scores), tuple(bbox_preds) + + def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor, + labels: Tensor, label_weights: Tensor, + bbox_targets: Tensor, assign_metrics: Tensor, + stride: List[int]): + """Compute loss of a single scale level. + + Args: + cls_score (Tensor): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W). 
+ bbox_pred (Tensor): Decoded bboxes for each scale + level with shape (N, num_anchors * 4, H, W). + labels (Tensor): Labels of each anchors with shape + (N, num_total_anchors). + label_weights (Tensor): Label weights of each anchor with shape + (N, num_total_anchors). + bbox_targets (Tensor): BBox regression targets of each anchor with + shape (N, num_total_anchors, 4). + assign_metrics (Tensor): Assign metrics with shape + (N, num_total_anchors). + stride (List[int]): Downsample stride of the feature map. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + assert stride[0] == stride[1], 'h stride is not equal to w stride!' + cls_score = cls_score.permute(0, 2, 3, 1).reshape( + -1, self.cls_out_channels).contiguous() + bbox_pred = bbox_pred.reshape(-1, 4) + bbox_targets = bbox_targets.reshape(-1, 4) + labels = labels.reshape(-1) + assign_metrics = assign_metrics.reshape(-1) + label_weights = label_weights.reshape(-1) + targets = (labels, assign_metrics) + + loss_cls = self.loss_cls( + cls_score, targets, label_weights, avg_factor=1.0) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().squeeze(1) + + if len(pos_inds) > 0: + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + + pos_decode_bbox_pred = pos_bbox_pred + pos_decode_bbox_targets = pos_bbox_targets + + # regression loss + pos_bbox_weight = assign_metrics[pos_inds] + + loss_bbox = self.loss_bbox( + pos_decode_bbox_pred, + pos_decode_bbox_targets, + weight=pos_bbox_weight, + avg_factor=1.0) + else: + loss_bbox = bbox_pred.sum() * 0 + pos_bbox_weight = bbox_targets.new_tensor(0.) + + return loss_cls, loss_bbox, assign_metrics.sum(), pos_bbox_weight.sum() + + def loss_by_feat(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None): + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Decoded box for each scale + level with shape (N, num_anchors * 4, H, W) in + [tl_x, tl_y, br_x, br_y] format. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
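+
+        Note:
+            The incoming ``bbox_preds`` are distance predictions from
+            ``forward``; the body below decodes them against the level
+            anchors with ``distance2bbox`` before target assignment, and
+            both loss terms are normalized by assign-metric sums reduced
+            across processes via ``reduce_mean``.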
+ """ + num_imgs = len(batch_img_metas) + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + flatten_cls_scores = torch.cat([ + cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, + self.cls_out_channels) + for cls_score in cls_scores + ], 1) + decoded_bboxes = [] + for anchor, bbox_pred in zip(anchor_list[0], bbox_preds): + anchor = anchor.reshape(-1, 4) + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) + bbox_pred = distance2bbox(anchor, bbox_pred) + decoded_bboxes.append(bbox_pred) + + flatten_bboxes = torch.cat(decoded_bboxes, 1) + + cls_reg_targets = self.get_targets( + flatten_cls_scores, + flatten_bboxes, + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + (anchor_list, labels_list, label_weights_list, bbox_targets_list, + assign_metrics_list, sampling_results_list) = cls_reg_targets + + losses_cls, losses_bbox,\ + cls_avg_factors, bbox_avg_factors = multi_apply( + self.loss_by_feat_single, + cls_scores, + decoded_bboxes, + labels_list, + label_weights_list, + bbox_targets_list, + assign_metrics_list, + self.prior_generator.strides) + + cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item() + losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls)) + + bbox_avg_factor = reduce_mean( + sum(bbox_avg_factors)).clamp_(min=1).item() + losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox)) + return dict(loss_cls=losses_cls, loss_bbox=losses_bbox) + + def get_targets(self, + cls_scores: Tensor, + bbox_preds: Tensor, + anchor_list: List[List[Tensor]], + valid_flag_list: List[List[Tensor]], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None, + unmap_outputs=True): + """Compute regression and classification targets for anchors in + multiple images. + + Args: + cls_scores (Tensor): Classification predictions of images, + a 3D-Tensor with shape [num_imgs, num_priors, num_classes]. + bbox_preds (Tensor): Decoded bboxes predictions of one image, + a 3D-Tensor with shape [num_imgs, num_priors, 4] in [tl_x, + tl_y, br_x, br_y] format. + anchor_list (list[list[Tensor]]): Multi level anchors of each + image. The outer list indicates images, and the inner list + corresponds to feature levels of the image. Each element of + the inner list is a tensor of shape (num_anchors, 4). + valid_flag_list (list[list[Tensor]]): Multi level valid flags of + each image. The outer list indicates images, and the inner list + corresponds to feature levels of the image. Each element of + the inner list is a tensor of shape (num_anchors, ) + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. Defaults to True. + + Returns: + tuple: a tuple containing learning targets. + + - anchors_list (list[list[Tensor]]): Anchors of each level. 
+ - labels_list (list[Tensor]): Labels of each level. + - label_weights_list (list[Tensor]): Label weights of each + level. + - bbox_targets_list (list[Tensor]): BBox targets of each level. + - assign_metrics_list (list[Tensor]): alignment metrics of each + level. + """ + num_imgs = len(batch_img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + + # concat all level anchors and flags to a single tensor + for i in range(num_imgs): + assert len(anchor_list[i]) == len(valid_flag_list[i]) + anchor_list[i] = torch.cat(anchor_list[i]) + valid_flag_list[i] = torch.cat(valid_flag_list[i]) + + # compute targets for each image + if batch_gt_instances_ignore is None: + batch_gt_instances_ignore = [None] * num_imgs + # anchor_list: list(b * [-1, 4]) + (all_anchors, all_labels, all_label_weights, all_bbox_targets, + all_assign_metrics, sampling_results_list) = multi_apply( + self._get_targets_single, + cls_scores.detach(), + bbox_preds.detach(), + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore, + unmap_outputs=unmap_outputs) + # no valid anchors + if any([labels is None for labels in all_labels]): + return None + + # split targets to a list w.r.t. multiple levels + anchors_list = images_to_levels(all_anchors, num_level_anchors) + labels_list = images_to_levels(all_labels, num_level_anchors) + label_weights_list = images_to_levels(all_label_weights, + num_level_anchors) + bbox_targets_list = images_to_levels(all_bbox_targets, + num_level_anchors) + assign_metrics_list = images_to_levels(all_assign_metrics, + num_level_anchors) + + return (anchors_list, labels_list, label_weights_list, + bbox_targets_list, assign_metrics_list, sampling_results_list) + + def _get_targets_single(self, + cls_scores: Tensor, + bbox_preds: Tensor, + flat_anchors: Tensor, + valid_flags: Tensor, + gt_instances: InstanceData, + img_meta: dict, + gt_instances_ignore: Optional[InstanceData] = None, + unmap_outputs=True): + """Compute regression, classification targets for anchors in a single + image. + + Args: + cls_scores (list(Tensor)): Box scores for each image. + bbox_preds (list(Tensor)): Box energies / deltas for each image. + flat_anchors (Tensor): Multi-level anchors of the image, which are + concatenated into a single tensor of shape (num_anchors ,4) + valid_flags (Tensor): Multi level valid flags of the image, + which are concatenated into a single tensor of + shape (num_anchors,). + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for current image. + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. Defaults to True. + + Returns: + tuple: N is the number of total anchors in the image. + + - anchors (Tensor): All anchors in the image with shape (N, 4). + - labels (Tensor): Labels of all anchors in the image with shape + (N,). + - label_weights (Tensor): Label weights of all anchor in the + image with shape (N,). + - bbox_targets (Tensor): BBox targets of all anchors in the + image with shape (N, 4). 
+            - norm_alignment_metrics (Tensor): Normalized alignment metrics
+              of all priors in the image with shape (N,).
+        """
+        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+                                           img_meta['img_shape'][:2],
+                                           self.train_cfg['allowed_border'])
+        if not inside_flags.any():
+            # No valid anchors: return one ``None`` per element of the
+            # 6-tuple returned below.
+            return (None, ) * 6
+        # assign gt and sample anchors
+        anchors = flat_anchors[inside_flags, :]
+
+        pred_instances = InstanceData(
+            scores=cls_scores[inside_flags, :],
+            bboxes=bbox_preds[inside_flags, :],
+            priors=anchors)
+
+        assign_result = self.assigner.assign(pred_instances, gt_instances,
+                                             gt_instances_ignore)
+
+        sampling_result = self.sampler.sample(assign_result, pred_instances,
+                                              gt_instances)
+
+        num_valid_anchors = anchors.shape[0]
+        bbox_targets = torch.zeros_like(anchors)
+        labels = anchors.new_full((num_valid_anchors, ),
+                                  self.num_classes,
+                                  dtype=torch.long)
+        label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
+        assign_metrics = anchors.new_zeros(
+            num_valid_anchors, dtype=torch.float)
+
+        pos_inds = sampling_result.pos_inds
+        neg_inds = sampling_result.neg_inds
+        if len(pos_inds) > 0:
+            # point-based
+            pos_bbox_targets = sampling_result.pos_gt_bboxes
+            bbox_targets[pos_inds, :] = pos_bbox_targets
+
+            labels[pos_inds] = sampling_result.pos_gt_labels
+            if self.train_cfg['pos_weight'] <= 0:
+                label_weights[pos_inds] = 1.0
+            else:
+                label_weights[pos_inds] = self.train_cfg['pos_weight']
+        if len(neg_inds) > 0:
+            label_weights[neg_inds] = 1.0
+
+        class_assigned_gt_inds = torch.unique(
+            sampling_result.pos_assigned_gt_inds)
+        for gt_inds in class_assigned_gt_inds:
+            gt_class_inds = pos_inds[sampling_result.pos_assigned_gt_inds ==
+                                     gt_inds]
+            assign_metrics[gt_class_inds] = assign_result.max_overlaps[
+                gt_class_inds]
+
+        # map up to original set of anchors
+        if unmap_outputs:
+            num_total_anchors = flat_anchors.size(0)
+            anchors = unmap(anchors, num_total_anchors, inside_flags)
+            labels = unmap(
+                labels, num_total_anchors, inside_flags, fill=self.num_classes)
+            label_weights = unmap(label_weights, num_total_anchors,
+                                  inside_flags)
+            bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+            assign_metrics = unmap(assign_metrics, num_total_anchors,
+                                   inside_flags)
+        return (anchors, labels, label_weights, bbox_targets, assign_metrics,
+                sampling_result)
+
+    def get_anchors(self,
+                    featmap_sizes: List[tuple],
+                    batch_img_metas: List[dict],
+                    device: Union[torch.device, str] = 'cuda') \
+            -> Tuple[List[List[Tensor]], List[List[Tensor]]]:
+        """Get anchors according to feature map sizes.
+
+        Args:
+            featmap_sizes (list[tuple]): Multi-level feature map sizes.
+            batch_img_metas (list[dict]): Image meta info.
+            device (torch.device or str): Device for returned tensors.
+                Defaults to cuda.
+
+        Returns:
+            tuple:
+
+            - anchor_list (list[list[Tensor]]): Anchors of each image.
+            - valid_flag_list (list[list[Tensor]]): Valid flags of each
+              image.
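+
+        Example:
+            >>> # Illustrative only: with two 640x640 images, strides
+            >>> # (8, 16, 32) and one prior per location, each image gets
+            >>> # [6400, 1600, 400] priors across the three levels.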
+ """ + num_imgs = len(batch_img_metas) + + # since feature map sizes of all images are the same, we only compute + # anchors for one time + multi_level_anchors = self.prior_generator.grid_priors( + featmap_sizes, device=device, with_stride=True) + anchor_list = [multi_level_anchors for _ in range(num_imgs)] + + # for each image, we compute valid flags of multi level anchors + valid_flag_list = [] + for img_id, img_meta in enumerate(batch_img_metas): + multi_level_flags = self.prior_generator.valid_flags( + featmap_sizes, img_meta['pad_shape'], device) + valid_flag_list.append(multi_level_flags) + return anchor_list, valid_flag_list + + +@MODELS.register_module() +class RTMDetSepBNHead(RTMDetHead): + """RTMDetHead with separated BN layers and shared conv layers. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + share_conv (bool): Whether to share conv layers between stages. + Defaults to True. + use_depthwise (bool): Whether to use depthwise separable convolution in + head. Defaults to False. + norm_cfg (:obj:`ConfigDict` or dict)): Config dict for normalization + layer. Defaults to dict(type='BN', momentum=0.03, eps=0.001). + act_cfg (:obj:`ConfigDict` or dict)): Config dict for activation layer. + Defaults to dict(type='SiLU'). + pred_kernel_size (int): Kernel size of prediction layer. Defaults to 1. + """ + + def __init__(self, + num_classes: int, + in_channels: int, + share_conv: bool = True, + use_depthwise: bool = False, + norm_cfg: ConfigType = dict( + type='BN', momentum=0.03, eps=0.001), + act_cfg: ConfigType = dict(type='SiLU'), + pred_kernel_size: int = 1, + exp_on_reg=False, + **kwargs) -> None: + self.share_conv = share_conv + self.exp_on_reg = exp_on_reg + self.use_depthwise = use_depthwise + super().__init__( + num_classes, + in_channels, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + pred_kernel_size=pred_kernel_size, + **kwargs) + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + conv = DepthwiseSeparableConvModule \ + if self.use_depthwise else ConvModule + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + + self.rtm_cls = nn.ModuleList() + self.rtm_reg = nn.ModuleList() + if self.with_objectness: + self.rtm_obj = nn.ModuleList() + for n in range(len(self.prior_generator.strides)): + cls_convs = nn.ModuleList() + reg_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + cls_convs.append( + conv( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + reg_convs.append( + conv( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.cls_convs.append(cls_convs) + self.reg_convs.append(reg_convs) + + self.rtm_cls.append( + nn.Conv2d( + self.feat_channels, + self.num_base_priors * self.cls_out_channels, + self.pred_kernel_size, + padding=self.pred_kernel_size // 2)) + self.rtm_reg.append( + nn.Conv2d( + self.feat_channels, + self.num_base_priors * 4, + self.pred_kernel_size, + padding=self.pred_kernel_size // 2)) + if self.with_objectness: + self.rtm_obj.append( + nn.Conv2d( + self.feat_channels, + 1, + self.pred_kernel_size, + padding=self.pred_kernel_size // 2)) + + if self.share_conv: + for n in range(len(self.prior_generator.strides)): + for i in range(self.stacked_convs): + 
self.cls_convs[n][i].conv = self.cls_convs[0][i].conv + self.reg_convs[n][i].conv = self.reg_convs[0][i].conv + + def init_weights(self) -> None: + """Initialize weights of the head.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, mean=0, std=0.01) + if is_norm(m): + constant_init(m, 1) + bias_cls = bias_init_with_prob(0.01) + for rtm_cls, rtm_reg in zip(self.rtm_cls, self.rtm_reg): + normal_init(rtm_cls, std=0.01, bias=bias_cls) + normal_init(rtm_reg, std=0.01) + if self.with_objectness: + for rtm_obj in self.rtm_obj: + normal_init(rtm_obj, std=0.01, bias=bias_cls) + + def forward(self, feats: Tuple[Tensor, ...]) -> tuple: + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: Usually a tuple of classification scores and bbox prediction + + - cls_scores (tuple[Tensor]): Classification scores for all scale + levels, each is a 4D-tensor, the channels number is + num_anchors * num_classes. + - bbox_preds (tuple[Tensor]): Box energies / deltas for all scale + levels, each is a 4D-tensor, the channels number is + num_anchors * 4. + """ + + cls_scores = [] + bbox_preds = [] + for idx, (x, stride) in enumerate( + zip(feats, self.prior_generator.strides)): + cls_feat = x + reg_feat = x + + for cls_layer in self.cls_convs[idx]: + cls_feat = cls_layer(cls_feat) + cls_score = self.rtm_cls[idx](cls_feat) + + for reg_layer in self.reg_convs[idx]: + reg_feat = reg_layer(reg_feat) + + if self.with_objectness: + objectness = self.rtm_obj[idx](reg_feat) + cls_score = inverse_sigmoid( + sigmoid_geometric_mean(cls_score, objectness)) + if self.exp_on_reg: + reg_dist = self.rtm_reg[idx](reg_feat).exp() * stride[0] + else: + reg_dist = self.rtm_reg[idx](reg_feat) * stride[0] + cls_scores.append(cls_score) + bbox_preds.append(reg_dist) + return tuple(cls_scores), tuple(bbox_preds) diff --git a/mmdet/models/dense_heads/rtmdet_ins_head.py b/mmdet/models/dense_heads/rtmdet_ins_head.py new file mode 100644 index 0000000000000000000000000000000000000000..261a57fe485245dcbe41696c9237258f829ca25a --- /dev/null +++ b/mmdet/models/dense_heads/rtmdet_ins_head.py @@ -0,0 +1,1034 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, is_norm +from mmcv.ops import batched_nms +from mmengine.model import (BaseModule, bias_init_with_prob, constant_init, + normal_init) +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.models.layers.transformer import inverse_sigmoid +from mmdet.models.utils import (filter_scores_and_topk, multi_apply, + select_single_mlvl, sigmoid_geometric_mean) +from mmdet.registry import MODELS +from mmdet.structures.bbox import (cat_boxes, distance2bbox, get_box_tensor, + get_box_wh, scale_boxes) +from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean +from .rtmdet_head import RTMDetHead + + +@MODELS.register_module() +class RTMDetInsHead(RTMDetHead): + """Detection Head of RTMDet-Ins. + + Args: + num_prototypes (int): Number of mask prototype features extracted + from the mask head. Defaults to 8. + dyconv_channels (int): Channel of the dynamic conv layers. + Defaults to 8. + num_dyconvs (int): Number of the dynamic convolution layers. + Defaults to 3. + mask_loss_stride (int): Down sample stride of the masks for loss + computation. 
Defaults to 4. + loss_mask (:obj:`ConfigDict` or dict): Config dict for mask loss. + """ + + def __init__(self, + *args, + num_prototypes: int = 8, + dyconv_channels: int = 8, + num_dyconvs: int = 3, + mask_loss_stride: int = 4, + loss_mask=dict( + type='DiceLoss', + loss_weight=2.0, + eps=5e-6, + reduction='mean'), + **kwargs) -> None: + self.num_prototypes = num_prototypes + self.num_dyconvs = num_dyconvs + self.dyconv_channels = dyconv_channels + self.mask_loss_stride = mask_loss_stride + super().__init__(*args, **kwargs) + self.loss_mask = MODELS.build(loss_mask) + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + super()._init_layers() + # a branch to predict kernels of dynamic convs + self.kernel_convs = nn.ModuleList() + # calculate num dynamic parameters + weight_nums, bias_nums = [], [] + for i in range(self.num_dyconvs): + if i == 0: + weight_nums.append( + # mask prototype and coordinate features + (self.num_prototypes + 2) * self.dyconv_channels) + bias_nums.append(self.dyconv_channels * 1) + elif i == self.num_dyconvs - 1: + weight_nums.append(self.dyconv_channels * 1) + bias_nums.append(1) + else: + weight_nums.append(self.dyconv_channels * self.dyconv_channels) + bias_nums.append(self.dyconv_channels * 1) + self.weight_nums = weight_nums + self.bias_nums = bias_nums + self.num_gen_params = sum(weight_nums) + sum(bias_nums) + + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + self.kernel_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + pred_pad_size = self.pred_kernel_size // 2 + self.rtm_kernel = nn.Conv2d( + self.feat_channels, + self.num_gen_params, + self.pred_kernel_size, + padding=pred_pad_size) + self.mask_head = MaskFeatModule( + in_channels=self.in_channels, + feat_channels=self.feat_channels, + stacked_convs=4, + num_levels=len(self.prior_generator.strides), + num_prototypes=self.num_prototypes, + act_cfg=self.act_cfg, + norm_cfg=self.norm_cfg) + + def forward(self, feats: Tuple[Tensor, ...]) -> tuple: + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: Usually a tuple of classification scores and bbox prediction + - cls_scores (list[Tensor]): Classification scores for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * num_classes. + - bbox_preds (list[Tensor]): Box energies / deltas for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * 4. + - kernel_preds (list[Tensor]): Dynamic conv kernels for all scale + levels, each is a 4D-tensor, the channels number is + num_gen_params. + - mask_feat (Tensor): Output feature of the mask head. Each is a + 4D-tensor, the channels number is num_prototypes. 
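+
+        Note:
+            The kernel branch emits ``num_gen_params`` channels per prior;
+            these are later split into the weights and biases of
+            ``num_dyconvs`` dynamic convs applied over ``mask_feat`` (plus
+            coordinate features) to produce instance masks.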
+ """ + mask_feat = self.mask_head(feats) + + cls_scores = [] + bbox_preds = [] + kernel_preds = [] + for idx, (x, scale, stride) in enumerate( + zip(feats, self.scales, self.prior_generator.strides)): + cls_feat = x + reg_feat = x + kernel_feat = x + + for cls_layer in self.cls_convs: + cls_feat = cls_layer(cls_feat) + cls_score = self.rtm_cls(cls_feat) + + for kernel_layer in self.kernel_convs: + kernel_feat = kernel_layer(kernel_feat) + kernel_pred = self.rtm_kernel(kernel_feat) + + for reg_layer in self.reg_convs: + reg_feat = reg_layer(reg_feat) + + if self.with_objectness: + objectness = self.rtm_obj(reg_feat) + cls_score = inverse_sigmoid( + sigmoid_geometric_mean(cls_score, objectness)) + + reg_dist = scale(self.rtm_reg(reg_feat)) * stride[0] + + cls_scores.append(cls_score) + bbox_preds.append(reg_dist) + kernel_preds.append(kernel_pred) + return tuple(cls_scores), tuple(bbox_preds), tuple( + kernel_preds), mask_feat + + def predict_by_feat(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + kernel_preds: List[Tensor], + mask_feat: Tensor, + score_factors: Optional[List[Tensor]] = None, + batch_img_metas: Optional[List[dict]] = None, + cfg: Optional[ConfigType] = None, + rescale: bool = False, + with_nms: bool = True) -> InstanceList: + """Transform a batch of output features extracted from the head into + bbox results. + + Note: When score_factors is not None, the cls_scores are + usually multiplied by it then obtain the real score used in NMS, + such as CenterNess in FCOS, IoU branch in ATSS. + + Args: + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + kernel_preds (list[Tensor]): Kernel predictions of dynamic + convs for all scale levels, each is a 4D-tensor, has shape + (batch_size, num_params, H, W). + mask_feat (Tensor): Mask prototype features extracted from the + mask head, has shape (batch_size, num_prototypes, H, W). + score_factors (list[Tensor], optional): Score factor for + all scale level, each is a 4D-tensor, has shape + (batch_size, num_priors * 1, H, W). Defaults to None. + batch_img_metas (list[dict], Optional): Batch image meta info. + Defaults to None. + cfg (ConfigDict, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + list[:obj:`InstanceData`]: Object detection results of each image + after the post process. Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, h, w). + """ + assert len(cls_scores) == len(bbox_preds) + + if score_factors is None: + # e.g. Retina, FreeAnchor, Foveabox, etc. + with_score_factors = False + else: + # e.g. FCOS, PAA, ATSS, AutoAssign, etc. 
+ with_score_factors = True + assert len(cls_scores) == len(score_factors) + + num_levels = len(cls_scores) + + featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] + mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, + dtype=cls_scores[0].dtype, + device=cls_scores[0].device, + with_stride=True) + + result_list = [] + + for img_id in range(len(batch_img_metas)): + img_meta = batch_img_metas[img_id] + cls_score_list = select_single_mlvl( + cls_scores, img_id, detach=True) + bbox_pred_list = select_single_mlvl( + bbox_preds, img_id, detach=True) + kernel_pred_list = select_single_mlvl( + kernel_preds, img_id, detach=True) + if with_score_factors: + score_factor_list = select_single_mlvl( + score_factors, img_id, detach=True) + else: + score_factor_list = [None for _ in range(num_levels)] + + results = self._predict_by_feat_single( + cls_score_list=cls_score_list, + bbox_pred_list=bbox_pred_list, + kernel_pred_list=kernel_pred_list, + mask_feat=mask_feat[img_id], + score_factor_list=score_factor_list, + mlvl_priors=mlvl_priors, + img_meta=img_meta, + cfg=cfg, + rescale=rescale, + with_nms=with_nms) + result_list.append(results) + return result_list + + def _predict_by_feat_single(self, + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + kernel_pred_list: List[Tensor], + mask_feat: Tensor, + score_factor_list: List[Tensor], + mlvl_priors: List[Tensor], + img_meta: dict, + cfg: ConfigType, + rescale: bool = False, + with_nms: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox and mask results. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + kernel_preds (list[Tensor]): Kernel predictions of dynamic + convs for all scale levels of a single image, each is a + 4D-tensor, has shape (num_params, H, W). + mask_feat (Tensor): Mask prototype features of a single image + extracted from the mask head, has shape (num_prototypes, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image, each item has shape + (num_priors * 1, H, W). + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid. In all + anchor-based methods, it has shape (num_priors, 4). In + all anchor-free methods, it has shape (num_priors, 2) + when `with_stride=True`, otherwise it still has shape + (num_priors, 4). + img_meta (dict): Image meta info. + cfg (mmengine.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, h, w). + """ + if score_factor_list[0] is None: + # e.g. Retina, FreeAnchor, etc. + with_score_factors = False + else: + # e.g. FCOS, PAA, ATSS, etc. 
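# Rough semantics of filter_scores_and_topk as called in the level loop
# below (a simplified sketch, not the mmdet implementation): drop
# (prior, class) pairs below score_thr, keep at most `topk` survivors by
# score, and gather the auxiliary tensors with the surviving prior indices.
import torch

def filter_scores_and_topk_sketch(scores, score_thr, topk, results):
    valid_mask = scores > score_thr          # (num_priors, num_classes)
    valid_idxs = torch.nonzero(valid_mask)   # (num_valid, 2), row-major
    flat_scores = scores[valid_mask]
    num_topk = valid_idxs.size(0) if topk < 0 else min(topk,
                                                       valid_idxs.size(0))
    flat_scores, sort_idxs = flat_scores.sort(descending=True)
    flat_scores = flat_scores[:num_topk]
    topk_idxs = valid_idxs[sort_idxs[:num_topk]]
    keep_idxs, labels = topk_idxs.unbind(dim=1)
    filtered = {k: v[keep_idxs] for k, v in results.items()}
    return flat_scores, labels, keep_idxs, filtered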
+            with_score_factors = True
+
+        cfg = self.test_cfg if cfg is None else cfg
+        cfg = copy.deepcopy(cfg)
+        img_shape = img_meta['img_shape']
+        nms_pre = cfg.get('nms_pre', -1)
+
+        mlvl_bbox_preds = []
+        mlvl_kernels = []
+        mlvl_valid_priors = []
+        mlvl_scores = []
+        mlvl_labels = []
+        if with_score_factors:
+            mlvl_score_factors = []
+        else:
+            mlvl_score_factors = None
+
+        for level_idx, (cls_score, bbox_pred, kernel_pred,
+                        score_factor, priors) in \
+                enumerate(zip(cls_score_list, bbox_pred_list, kernel_pred_list,
+                              score_factor_list, mlvl_priors)):
+
+            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+
+            dim = self.bbox_coder.encode_size
+            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim)
+            if with_score_factors:
+                score_factor = score_factor.permute(1, 2,
+                                                    0).reshape(-1).sigmoid()
+            cls_score = cls_score.permute(1, 2,
+                                          0).reshape(-1, self.cls_out_channels)
+            kernel_pred = kernel_pred.permute(1, 2, 0).reshape(
+                -1, self.num_gen_params)
+            if self.use_sigmoid_cls:
+                scores = cls_score.sigmoid()
+            else:
+                # note: FG labels are in [0, num_class - 1] since mmdet v2.0,
+                # and the BG cat_id is num_class
+                scores = cls_score.softmax(-1)[:, :-1]
+
+            # After https://github.com/open-mmlab/mmdetection/pull/6268/,
+            # this operation keeps fewer bboxes under the same `nms_pre`.
+            # There is no difference in performance for most models. If you
+            # find a slight drop in performance, you can set a larger
+            # `nms_pre` than before.
+            score_thr = cfg.get('score_thr', 0)
+
+            results = filter_scores_and_topk(
+                scores, score_thr, nms_pre,
+                dict(
+                    bbox_pred=bbox_pred,
+                    priors=priors,
+                    kernel_pred=kernel_pred))
+            scores, labels, keep_idxs, filtered_results = results
+
+            bbox_pred = filtered_results['bbox_pred']
+            priors = filtered_results['priors']
+            kernel_pred = filtered_results['kernel_pred']
+
+            if with_score_factors:
+                score_factor = score_factor[keep_idxs]
+
+            mlvl_bbox_preds.append(bbox_pred)
+            mlvl_valid_priors.append(priors)
+            mlvl_scores.append(scores)
+            mlvl_labels.append(labels)
+            mlvl_kernels.append(kernel_pred)
+
+            if with_score_factors:
+                mlvl_score_factors.append(score_factor)
+
+        bbox_pred = torch.cat(mlvl_bbox_preds)
+        priors = cat_boxes(mlvl_valid_priors)
+        bboxes = self.bbox_coder.decode(
+            priors[..., :2], bbox_pred, max_shape=img_shape)
+
+        results = InstanceData()
+        results.bboxes = bboxes
+        results.priors = priors
+        results.scores = torch.cat(mlvl_scores)
+        results.labels = torch.cat(mlvl_labels)
+        results.kernels = torch.cat(mlvl_kernels)
+        if with_score_factors:
+            results.score_factors = torch.cat(mlvl_score_factors)
+
+        return self._bbox_mask_post_process(
+            results=results,
+            mask_feat=mask_feat,
+            cfg=cfg,
+            rescale=rescale,
+            with_nms=with_nms,
+            img_meta=img_meta)
+
+    def _bbox_mask_post_process(
+            self,
+            results: InstanceData,
+            mask_feat: Tensor,
+            cfg: ConfigType,
+            rescale: bool = False,
+            with_nms: bool = True,
+            img_meta: Optional[dict] = None) -> InstanceData:
+        """bbox and mask post-processing method.
+
+        The boxes are rescaled to the original image scale and NMS is
+        applied. `with_nms` is usually False when this is used for aug test.
+
+        Args:
+            results (:obj:`InstanceData`): Detection instance results,
+                each item has shape (num_bboxes, ).
+            mask_feat (Tensor): Mask prototype features of a single image.
+            cfg (ConfigDict): Test / postprocessing configuration,
+                if None, test_cfg would be used.
+            rescale (bool): If True, return boxes in original image space.
+                Defaults to False.
+            with_nms (bool): If True, do nms before return boxes.
+                Defaults to True.
+            img_meta (dict, optional): Image meta info. Defaults to None.
+ + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, h, w). + """ + stride = self.prior_generator.strides[0][0] + if rescale: + assert img_meta.get('scale_factor') is not None + scale_factor = [1 / s for s in img_meta['scale_factor']] + results.bboxes = scale_boxes(results.bboxes, scale_factor) + + if hasattr(results, 'score_factors'): + # TODO: Add sqrt operation in order to be consistent with + # the paper. + score_factors = results.pop('score_factors') + results.scores = results.scores * score_factors + + # filter small size bboxes + if cfg.get('min_bbox_size', -1) >= 0: + w, h = get_box_wh(results.bboxes) + valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size) + if not valid_mask.all(): + results = results[valid_mask] + + # TODO: deal with `with_nms` and `nms_cfg=None` in test_cfg + assert with_nms, 'with_nms must be True for RTMDet-Ins' + if results.bboxes.numel() > 0: + bboxes = get_box_tensor(results.bboxes) + det_bboxes, keep_idxs = batched_nms(bboxes, results.scores, + results.labels, cfg.nms) + results = results[keep_idxs] + # some nms would reweight the score, such as softnms + results.scores = det_bboxes[:, -1] + results = results[:cfg.max_per_img] + + # process masks + mask_logits = self._mask_predict_by_feat_single( + mask_feat, results.kernels, results.priors) + + mask_logits = F.interpolate( + mask_logits.unsqueeze(0), scale_factor=stride, mode='bilinear') + if rescale: + ori_h, ori_w = img_meta['ori_shape'][:2] + mask_logits = F.interpolate( + mask_logits, + size=[ + math.ceil(mask_logits.shape[-2] * scale_factor[0]), + math.ceil(mask_logits.shape[-1] * scale_factor[1]) + ], + mode='bilinear', + align_corners=False)[..., :ori_h, :ori_w] + masks = mask_logits.sigmoid().squeeze(0) + masks = masks > cfg.mask_thr_binary + results.masks = masks + else: + h, w = img_meta['ori_shape'][:2] if rescale else img_meta[ + 'img_shape'][:2] + results.masks = torch.zeros( + size=(results.bboxes.shape[0], h, w), + dtype=torch.bool, + device=results.bboxes.device) + + return results + + def parse_dynamic_params(self, flatten_kernels: Tensor) -> tuple: + """split kernel head prediction to conv weight and bias.""" + n_inst = flatten_kernels.size(0) + n_layers = len(self.weight_nums) + params_splits = list( + torch.split_with_sizes( + flatten_kernels, self.weight_nums + self.bias_nums, dim=1)) + weight_splits = params_splits[:n_layers] + bias_splits = params_splits[n_layers:] + for i in range(n_layers): + if i < n_layers - 1: + weight_splits[i] = weight_splits[i].reshape( + n_inst * self.dyconv_channels, -1, 1, 1) + bias_splits[i] = bias_splits[i].reshape(n_inst * + self.dyconv_channels) + else: + weight_splits[i] = weight_splits[i].reshape(n_inst, -1, 1, 1) + bias_splits[i] = bias_splits[i].reshape(n_inst) + + return weight_splits, bias_splits + + def _mask_predict_by_feat_single(self, mask_feat: Tensor, kernels: Tensor, + priors: Tensor) -> Tensor: + """Generate mask logits from mask features with dynamic convs. + + Args: + mask_feat (Tensor): Mask prototype features. + Has shape (num_prototypes, H, W). + kernels (Tensor): Kernel parameters for each instance. 
Has shape (num_instance, num_params)
+            priors (Tensor): Center priors for each instance.
+                Has shape (num_instance, 4).
+        Returns:
+            Tensor: Instance segmentation masks for each instance.
+                Has shape (num_instance, H, W).
+        """
+        num_inst = priors.shape[0]
+        h, w = mask_feat.size()[-2:]
+        if num_inst < 1:
+            return torch.empty(
+                size=(num_inst, h, w),
+                dtype=mask_feat.dtype,
+                device=mask_feat.device)
+        if len(mask_feat.shape) < 4:
+            # `unsqueeze` is not in-place, so the result must be assigned back
+            mask_feat = mask_feat.unsqueeze(0)
+
+        coord = self.prior_generator.single_level_grid_priors(
+            (h, w), level_idx=0, device=mask_feat.device).reshape(1, -1, 2)
+        points = priors[:, :2].reshape(-1, 1, 2)
+        strides = priors[:, 2:].reshape(-1, 1, 2)
+        relative_coord = (points - coord).permute(0, 2, 1) / (
+            strides[..., 0].reshape(-1, 1, 1) * 8)
+        relative_coord = relative_coord.reshape(num_inst, 2, h, w)
+
+        mask_feat = torch.cat(
+            [relative_coord,
+             mask_feat.repeat(num_inst, 1, 1, 1)], dim=1)
+        weights, biases = self.parse_dynamic_params(kernels)
+
+        n_layers = len(weights)
+        x = mask_feat.reshape(1, -1, h, w)
+        for i, (weight, bias) in enumerate(zip(weights, biases)):
+            x = F.conv2d(
+                x, weight, bias=bias, stride=1, padding=0, groups=num_inst)
+            if i < n_layers - 1:
+                x = F.relu(x)
+        x = x.reshape(num_inst, h, w)
+        return x
+
+    def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor,
+                          sampling_results_list: list,
+                          batch_gt_instances: InstanceList) -> Tensor:
+        """Compute instance segmentation loss.
+
+        Args:
+            mask_feats (Tensor): Mask prototype features extracted from
+                the mask head. Has shape (N, num_prototypes, H, W)
+            flatten_kernels (Tensor): Kernels of the dynamic conv layers.
+                Has shape (N, num_instances, num_params)
+            sampling_results_list (list[:obj:`SamplingResults`]): Batch of
+                assignment results.
+            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+                gt_instance. It usually includes ``bboxes`` and ``labels``
+                attributes.
+
+        Returns:
+            Tensor: The mask loss tensor.
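# Toy check of the grouped-conv trick used above: stacking every instance's
# (coordinate + prototype) maps along the channel dim of a single batch
# element and convolving with groups=num_inst is equivalent to running each
# instance's tiny dynamic conv separately. Shapes assume 2 instances,
# 8 prototypes + 2 coordinate channels, and dyconv_channels=8.
import torch
import torch.nn.functional as F

num_inst, in_c, h, w = 2, 10, 16, 16
x = torch.randn(1, num_inst * in_c, h, w)
weight = torch.randn(num_inst * 8, in_c, 1, 1)  # per-instance 1x1 kernels
bias = torch.randn(num_inst * 8)
out = F.conv2d(x, weight, bias=bias, groups=num_inst)
assert out.shape == (1, num_inst * 8, h, w)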
+ """ + batch_pos_mask_logits = [] + pos_gt_masks = [] + for idx, (mask_feat, kernels, sampling_results, + gt_instances) in enumerate( + zip(mask_feats, flatten_kernels, sampling_results_list, + batch_gt_instances)): + pos_priors = sampling_results.pos_priors + pos_inds = sampling_results.pos_inds + pos_kernels = kernels[pos_inds] # n_pos, num_gen_params + pos_mask_logits = self._mask_predict_by_feat_single( + mask_feat, pos_kernels, pos_priors) + if gt_instances.masks.numel() == 0: + gt_masks = torch.empty_like(gt_instances.masks) + else: + gt_masks = gt_instances.masks[ + sampling_results.pos_assigned_gt_inds, :] + batch_pos_mask_logits.append(pos_mask_logits) + pos_gt_masks.append(gt_masks) + + pos_gt_masks = torch.cat(pos_gt_masks, 0) + batch_pos_mask_logits = torch.cat(batch_pos_mask_logits, 0) + + # avg_factor + num_pos = batch_pos_mask_logits.shape[0] + num_pos = reduce_mean(mask_feats.new_tensor([num_pos + ])).clamp_(min=1).item() + + if batch_pos_mask_logits.shape[0] == 0: + return mask_feats.sum() * 0 + + scale = self.prior_generator.strides[0][0] // self.mask_loss_stride + # upsample pred masks + batch_pos_mask_logits = F.interpolate( + batch_pos_mask_logits.unsqueeze(0), + scale_factor=scale, + mode='bilinear', + align_corners=False).squeeze(0) + # downsample gt masks + pos_gt_masks = pos_gt_masks[:, self.mask_loss_stride // + 2::self.mask_loss_stride, + self.mask_loss_stride // + 2::self.mask_loss_stride] + + loss_mask = self.loss_mask( + batch_pos_mask_logits, + pos_gt_masks, + weight=None, + avg_factor=num_pos) + + return loss_mask + + def loss_by_feat(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + kernel_preds: List[Tensor], + mask_feat: Tensor, + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None): + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Decoded box for each scale + level with shape (N, num_anchors * 4, H, W) in + [tl_x, tl_y, br_x, br_y] format. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + num_imgs = len(batch_img_metas) + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + flatten_cls_scores = torch.cat([ + cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, + self.cls_out_channels) + for cls_score in cls_scores + ], 1) + flatten_kernels = torch.cat([ + kernel_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, + self.num_gen_params) + for kernel_pred in kernel_preds + ], 1) + decoded_bboxes = [] + for anchor, bbox_pred in zip(anchor_list[0], bbox_preds): + anchor = anchor.reshape(-1, 4) + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) + bbox_pred = distance2bbox(anchor, bbox_pred) + decoded_bboxes.append(bbox_pred) + + flatten_bboxes = torch.cat(decoded_bboxes, 1) + for gt_instances in batch_gt_instances: + gt_instances.masks = gt_instances.masks.to_tensor( + dtype=torch.bool, device=device) + + cls_reg_targets = self.get_targets( + flatten_cls_scores, + flatten_bboxes, + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + (anchor_list, labels_list, label_weights_list, bbox_targets_list, + assign_metrics_list, sampling_results_list) = cls_reg_targets + + losses_cls, losses_bbox,\ + cls_avg_factors, bbox_avg_factors = multi_apply( + self.loss_by_feat_single, + cls_scores, + decoded_bboxes, + labels_list, + label_weights_list, + bbox_targets_list, + assign_metrics_list, + self.prior_generator.strides) + + cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item() + losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls)) + + bbox_avg_factor = reduce_mean( + sum(bbox_avg_factors)).clamp_(min=1).item() + losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox)) + + loss_mask = self.loss_mask_by_feat(mask_feat, flatten_kernels, + sampling_results_list, + batch_gt_instances) + loss = dict( + loss_cls=losses_cls, loss_bbox=losses_bbox, loss_mask=loss_mask) + return loss + + +class MaskFeatModule(BaseModule): + """Mask feature head used in RTMDet-Ins. + + Args: + in_channels (int): Number of channels in the input feature map. + feat_channels (int): Number of hidden channels of the mask feature + map branch. + num_levels (int): The starting feature map level from RPN that + will be used to predict the mask feature map. + num_prototypes (int): Number of output channel of the mask feature + map branch. This is the channel count of the mask + feature map that to be dynamically convolved with the predicted + kernel. + stacked_convs (int): Number of convs in mask feature branch. + act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer. + Default: dict(type='ReLU', inplace=True) + norm_cfg (dict): Config dict for normalization layer. Default: None. 
+ """ + + def __init__( + self, + in_channels: int, + feat_channels: int = 256, + stacked_convs: int = 4, + num_levels: int = 3, + num_prototypes: int = 8, + act_cfg: ConfigType = dict(type='ReLU', inplace=True), + norm_cfg: ConfigType = dict(type='BN') + ) -> None: + super().__init__(init_cfg=None) + self.num_levels = num_levels + self.fusion_conv = nn.Conv2d(num_levels * in_channels, in_channels, 1) + convs = [] + for i in range(stacked_convs): + in_c = in_channels if i == 0 else feat_channels + convs.append( + ConvModule( + in_c, + feat_channels, + 3, + padding=1, + act_cfg=act_cfg, + norm_cfg=norm_cfg)) + self.stacked_convs = nn.Sequential(*convs) + self.projection = nn.Conv2d( + feat_channels, num_prototypes, kernel_size=1) + + def forward(self, features: Tuple[Tensor, ...]) -> Tensor: + # multi-level feature fusion + fusion_feats = [features[0]] + size = features[0].shape[-2:] + for i in range(1, self.num_levels): + f = F.interpolate(features[i], size=size, mode='bilinear') + fusion_feats.append(f) + fusion_feats = torch.cat(fusion_feats, dim=1) + fusion_feats = self.fusion_conv(fusion_feats) + # pred mask feats + mask_features = self.stacked_convs(fusion_feats) + mask_features = self.projection(mask_features) + return mask_features + + +@MODELS.register_module() +class RTMDetInsSepBNHead(RTMDetInsHead): + """Detection Head of RTMDet-Ins with sep-bn layers. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + share_conv (bool): Whether to share conv layers between stages. + Defaults to True. + norm_cfg (:obj:`ConfigDict` or dict)): Config dict for normalization + layer. Defaults to dict(type='BN'). + act_cfg (:obj:`ConfigDict` or dict)): Config dict for activation layer. + Defaults to dict(type='SiLU', inplace=True). + pred_kernel_size (int): Kernel size of prediction layer. Defaults to 1. 
+ """ + + def __init__(self, + num_classes: int, + in_channels: int, + share_conv: bool = True, + with_objectness: bool = False, + norm_cfg: ConfigType = dict(type='BN', requires_grad=True), + act_cfg: ConfigType = dict(type='SiLU', inplace=True), + pred_kernel_size: int = 1, + **kwargs) -> None: + self.share_conv = share_conv + super().__init__( + num_classes, + in_channels, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + pred_kernel_size=pred_kernel_size, + with_objectness=with_objectness, + **kwargs) + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + self.kernel_convs = nn.ModuleList() + + self.rtm_cls = nn.ModuleList() + self.rtm_reg = nn.ModuleList() + self.rtm_kernel = nn.ModuleList() + self.rtm_obj = nn.ModuleList() + + # calculate num dynamic parameters + weight_nums, bias_nums = [], [] + for i in range(self.num_dyconvs): + if i == 0: + weight_nums.append( + (self.num_prototypes + 2) * self.dyconv_channels) + bias_nums.append(self.dyconv_channels) + elif i == self.num_dyconvs - 1: + weight_nums.append(self.dyconv_channels) + bias_nums.append(1) + else: + weight_nums.append(self.dyconv_channels * self.dyconv_channels) + bias_nums.append(self.dyconv_channels) + self.weight_nums = weight_nums + self.bias_nums = bias_nums + self.num_gen_params = sum(weight_nums) + sum(bias_nums) + pred_pad_size = self.pred_kernel_size // 2 + + for n in range(len(self.prior_generator.strides)): + cls_convs = nn.ModuleList() + reg_convs = nn.ModuleList() + kernel_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + reg_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + kernel_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.cls_convs.append(cls_convs) + self.reg_convs.append(cls_convs) + self.kernel_convs.append(kernel_convs) + + self.rtm_cls.append( + nn.Conv2d( + self.feat_channels, + self.num_base_priors * self.cls_out_channels, + self.pred_kernel_size, + padding=pred_pad_size)) + self.rtm_reg.append( + nn.Conv2d( + self.feat_channels, + self.num_base_priors * 4, + self.pred_kernel_size, + padding=pred_pad_size)) + self.rtm_kernel.append( + nn.Conv2d( + self.feat_channels, + self.num_gen_params, + self.pred_kernel_size, + padding=pred_pad_size)) + if self.with_objectness: + self.rtm_obj.append( + nn.Conv2d( + self.feat_channels, + 1, + self.pred_kernel_size, + padding=pred_pad_size)) + + if self.share_conv: + for n in range(len(self.prior_generator.strides)): + for i in range(self.stacked_convs): + self.cls_convs[n][i].conv = self.cls_convs[0][i].conv + self.reg_convs[n][i].conv = self.reg_convs[0][i].conv + + self.mask_head = MaskFeatModule( + in_channels=self.in_channels, + feat_channels=self.feat_channels, + stacked_convs=4, + num_levels=len(self.prior_generator.strides), + num_prototypes=self.num_prototypes, + act_cfg=self.act_cfg, + norm_cfg=self.norm_cfg) + + def init_weights(self) -> None: + """Initialize weights of the head.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, mean=0, std=0.01) + if is_norm(m): + 
constant_init(m, 1) + bias_cls = bias_init_with_prob(0.01) + for rtm_cls, rtm_reg, rtm_kernel in zip(self.rtm_cls, self.rtm_reg, + self.rtm_kernel): + normal_init(rtm_cls, std=0.01, bias=bias_cls) + normal_init(rtm_reg, std=0.01, bias=1) + if self.with_objectness: + for rtm_obj in self.rtm_obj: + normal_init(rtm_obj, std=0.01, bias=bias_cls) + + def forward(self, feats: Tuple[Tensor, ...]) -> tuple: + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: Usually a tuple of classification scores and bbox prediction + - cls_scores (list[Tensor]): Classification scores for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * num_classes. + - bbox_preds (list[Tensor]): Box energies / deltas for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * 4. + - kernel_preds (list[Tensor]): Dynamic conv kernels for all scale + levels, each is a 4D-tensor, the channels number is + num_gen_params. + - mask_feat (Tensor): Output feature of the mask head. Each is a + 4D-tensor, the channels number is num_prototypes. + """ + mask_feat = self.mask_head(feats) + + cls_scores = [] + bbox_preds = [] + kernel_preds = [] + for idx, (x, stride) in enumerate( + zip(feats, self.prior_generator.strides)): + cls_feat = x + reg_feat = x + kernel_feat = x + + for cls_layer in self.cls_convs[idx]: + cls_feat = cls_layer(cls_feat) + cls_score = self.rtm_cls[idx](cls_feat) + + for kernel_layer in self.kernel_convs[idx]: + kernel_feat = kernel_layer(kernel_feat) + kernel_pred = self.rtm_kernel[idx](kernel_feat) + + for reg_layer in self.reg_convs[idx]: + reg_feat = reg_layer(reg_feat) + + if self.with_objectness: + objectness = self.rtm_obj[idx](reg_feat) + cls_score = inverse_sigmoid( + sigmoid_geometric_mean(cls_score, objectness)) + + reg_dist = F.relu(self.rtm_reg[idx](reg_feat)) * stride[0] + + cls_scores.append(cls_score) + bbox_preds.append(reg_dist) + kernel_preds.append(kernel_pred) + return tuple(cls_scores), tuple(bbox_preds), tuple( + kernel_preds), mask_feat diff --git a/mmdet/models/dense_heads/sabl_retina_head.py b/mmdet/models/dense_heads/sabl_retina_head.py new file mode 100644 index 0000000000000000000000000000000000000000..8cd1b71cc2c80035a0378180da70caddf853375d --- /dev/null +++ b/mmdet/models/dense_heads/sabl_retina_head.py @@ -0,0 +1,706 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.config import ConfigDict +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType, + OptInstanceList) +from ..task_modules.samplers import PseudoSampler +from ..utils import (filter_scores_and_topk, images_to_levels, multi_apply, + unmap) +from .base_dense_head import BaseDenseHead +from .guided_anchor_head import GuidedAnchorHead + + +@MODELS.register_module() +class SABLRetinaHead(BaseDenseHead): + """Side-Aware Boundary Localization (SABL) for RetinaNet. + + The anchor generation, assigning and sampling in SABLRetinaHead + are the same as GuidedAnchorHead for guided anchoring. + + Please refer to https://arxiv.org/abs/1912.04260 for more details. + + Args: + num_classes (int): Number of classes. 
+ in_channels (int): Number of channels in the input feature map. + stacked_convs (int): Number of Convs for classification and + regression branches. Defaults to 4. + feat_channels (int): Number of hidden channels. Defaults to 256. + approx_anchor_generator (:obj:`ConfigType` or dict): Config dict for + approx generator. + square_anchor_generator (:obj:`ConfigDict` or dict): Config dict for + square generator. + conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + ConvModule. Defaults to None. + norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + Norm Layer. Defaults to None. + bbox_coder (:obj:`ConfigDict` or dict): Config dict for bbox coder. + reg_decoded_bbox (bool): If true, the regression loss would be + applied directly on decoded bounding boxes, converting both + the predicted boxes and regression targets to absolute + coordinates format. Default False. It should be ``True`` when + using ``IoULoss``, ``GIoULoss``, or ``DIoULoss`` in the bbox head. + train_cfg (:obj:`ConfigDict` or dict, optional): Training config of + SABLRetinaHead. + test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of + SABLRetinaHead. + loss_cls (:obj:`ConfigDict` or dict): Config of classification loss. + loss_bbox_cls (:obj:`ConfigDict` or dict): Config of classification + loss for bbox branch. + loss_bbox_reg (:obj:`ConfigDict` or dict): Config of regression loss + for bbox branch. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], optional): Initialization config dict. + """ + + def __init__( + self, + num_classes: int, + in_channels: int, + stacked_convs: int = 4, + feat_channels: int = 256, + approx_anchor_generator: ConfigType = dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + square_anchor_generator: ConfigType = dict( + type='AnchorGenerator', + ratios=[1.0], + scales=[4], + strides=[8, 16, 32, 64, 128]), + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + bbox_coder: ConfigType = dict( + type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0), + reg_decoded_bbox: bool = False, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + loss_cls: ConfigType = dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox_cls: ConfigType = dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5), + loss_bbox_reg: ConfigType = dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5), + init_cfg: MultiConfig = dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', name='retina_cls', std=0.01, bias_prob=0.01)) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.num_classes = num_classes + self.feat_channels = feat_channels + self.num_buckets = bbox_coder['num_buckets'] + self.side_num = int(np.ceil(self.num_buckets / 2)) + + assert (approx_anchor_generator['octave_base_scale'] == + square_anchor_generator['scales'][0]) + assert (approx_anchor_generator['strides'] == + square_anchor_generator['strides']) + + self.approx_anchor_generator = TASK_UTILS.build( + approx_anchor_generator) + self.square_anchor_generator = TASK_UTILS.build( + square_anchor_generator) + self.approxs_per_octave = ( + self.approx_anchor_generator.num_base_priors[0]) + + # one anchor per location + self.num_base_priors = self.square_anchor_generator.num_base_priors[0] + + self.stacked_convs = stacked_convs + 
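# Quick arithmetic for the SABL output widths above: with the default
# BucketingBBoxCoder(num_buckets=14), side_num = ceil(14 / 2) = 7, so
# retina_bbox_cls and retina_bbox_reg each output side_num * 4 = 28
# channels (7 buckets per side, 4 sides per box).
import numpy as np

num_buckets = 14
side_num = int(np.ceil(num_buckets / 2))
assert side_num * 4 == 28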
self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.reg_decoded_bbox = reg_decoded_bbox + + self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) + if self.use_sigmoid_cls: + self.cls_out_channels = num_classes + else: + self.cls_out_channels = num_classes + 1 + + self.bbox_coder = TASK_UTILS.build(bbox_coder) + self.loss_cls = MODELS.build(loss_cls) + self.loss_bbox_cls = MODELS.build(loss_bbox_cls) + self.loss_bbox_reg = MODELS.build(loss_bbox_reg) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + if self.train_cfg: + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + # use PseudoSampler when sampling is False + if 'sampler' in self.train_cfg: + self.sampler = TASK_UTILS.build( + self.train_cfg['sampler'], default_args=dict(context=self)) + else: + self.sampler = PseudoSampler(context=self) + + self._init_layers() + + def _init_layers(self) -> None: + self.relu = nn.ReLU(inplace=True) + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + self.cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + self.reg_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + self.retina_cls = nn.Conv2d( + self.feat_channels, self.cls_out_channels, 3, padding=1) + self.retina_bbox_reg = nn.Conv2d( + self.feat_channels, self.side_num * 4, 3, padding=1) + self.retina_bbox_cls = nn.Conv2d( + self.feat_channels, self.side_num * 4, 3, padding=1) + + def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]: + cls_feat = x + reg_feat = x + for cls_conv in self.cls_convs: + cls_feat = cls_conv(cls_feat) + for reg_conv in self.reg_convs: + reg_feat = reg_conv(reg_feat) + cls_score = self.retina_cls(cls_feat) + bbox_cls_pred = self.retina_bbox_cls(reg_feat) + bbox_reg_pred = self.retina_bbox_reg(reg_feat) + bbox_pred = (bbox_cls_pred, bbox_reg_pred) + return cls_score, bbox_pred + + def forward(self, feats: List[Tensor]) -> Tuple[List[Tensor]]: + return multi_apply(self.forward_single, feats) + + def get_anchors( + self, + featmap_sizes: List[tuple], + img_metas: List[dict], + device: Union[torch.device, str] = 'cuda' + ) -> Tuple[List[List[Tensor]], List[List[Tensor]]]: + """Get squares according to feature map sizes and guided anchors. + + Args: + featmap_sizes (list[tuple]): Multi-level feature map sizes. + img_metas (list[dict]): Image meta info. + device (torch.device | str): device for returned tensors + + Returns: + tuple: square approxs of each image + """ + num_imgs = len(img_metas) + + # since feature map sizes of all images are the same, we only compute + # squares for one time + multi_level_squares = self.square_anchor_generator.grid_priors( + featmap_sizes, device=device) + squares_list = [multi_level_squares for _ in range(num_imgs)] + + return squares_list + + def get_targets(self, + approx_list: List[List[Tensor]], + inside_flag_list: List[List[Tensor]], + square_list: List[List[Tensor]], + batch_gt_instances: InstanceList, + batch_img_metas, + batch_gt_instances_ignore: OptInstanceList = None, + unmap_outputs=True) -> tuple: + """Compute bucketing targets. + + Args: + approx_list (list[list[Tensor]]): Multi level approxs of each + image. + inside_flag_list (list[list[Tensor]]): Multi level inside flags of + each image. 
+ square_list (list[list[Tensor]]): Multi level squares of each + image. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. Defaults to True. + + Returns: + tuple: Returns a tuple containing learning targets. + + - labels_list (list[Tensor]): Labels of each level. + - label_weights_list (list[Tensor]): Label weights of each level. + - bbox_cls_targets_list (list[Tensor]): BBox cls targets of \ + each level. + - bbox_cls_weights_list (list[Tensor]): BBox cls weights of \ + each level. + - bbox_reg_targets_list (list[Tensor]): BBox reg targets of \ + each level. + - bbox_reg_weights_list (list[Tensor]): BBox reg weights of \ + each level. + - num_total_pos (int): Number of positive samples in all images. + - num_total_neg (int): Number of negative samples in all images. + """ + num_imgs = len(batch_img_metas) + assert len(approx_list) == len(inside_flag_list) == len( + square_list) == num_imgs + # anchor number of multi levels + num_level_squares = [squares.size(0) for squares in square_list[0]] + # concat all level anchors and flags to a single tensor + inside_flag_flat_list = [] + approx_flat_list = [] + square_flat_list = [] + for i in range(num_imgs): + assert len(square_list[i]) == len(inside_flag_list[i]) + inside_flag_flat_list.append(torch.cat(inside_flag_list[i])) + approx_flat_list.append(torch.cat(approx_list[i])) + square_flat_list.append(torch.cat(square_list[i])) + + # compute targets for each image + if batch_gt_instances_ignore is None: + batch_gt_instances_ignore = [None for _ in range(num_imgs)] + (all_labels, all_label_weights, all_bbox_cls_targets, + all_bbox_cls_weights, all_bbox_reg_targets, all_bbox_reg_weights, + pos_inds_list, neg_inds_list, sampling_results_list) = multi_apply( + self._get_targets_single, + approx_flat_list, + inside_flag_flat_list, + square_flat_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore, + unmap_outputs=unmap_outputs) + + # sampled anchors of all images + avg_factor = sum( + [results.avg_factor for results in sampling_results_list]) + # split targets to a list w.r.t. multiple levels + labels_list = images_to_levels(all_labels, num_level_squares) + label_weights_list = images_to_levels(all_label_weights, + num_level_squares) + bbox_cls_targets_list = images_to_levels(all_bbox_cls_targets, + num_level_squares) + bbox_cls_weights_list = images_to_levels(all_bbox_cls_weights, + num_level_squares) + bbox_reg_targets_list = images_to_levels(all_bbox_reg_targets, + num_level_squares) + bbox_reg_weights_list = images_to_levels(all_bbox_reg_weights, + num_level_squares) + return (labels_list, label_weights_list, bbox_cls_targets_list, + bbox_cls_weights_list, bbox_reg_targets_list, + bbox_reg_weights_list, avg_factor) + + def _get_targets_single(self, + flat_approxs: Tensor, + inside_flags: Tensor, + flat_squares: Tensor, + gt_instances: InstanceData, + img_meta: dict, + gt_instances_ignore: Optional[InstanceData] = None, + unmap_outputs: bool = True) -> tuple: + """Compute regression and classification targets for anchors in a + single image. 
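# Sketch of what images_to_levels does for the per-image target lists in
# get_targets above (a simplified stand-in, not the mmdet helper itself):
# stack the per-image tensors, then split the anchor dimension back into
# one chunk per pyramid level.
import torch

def images_to_levels_sketch(per_image_targets, num_level_anchors):
    stacked = torch.stack(per_image_targets, 0)  # (num_imgs, num_anchors, ...)
    level_targets, start = [], 0
    for n in num_level_anchors:
        level_targets.append(stacked[:, start:start + n])
        start += n
    return level_targets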
+ + Args: + flat_approxs (Tensor): flat approxs of a single image, + shape (n, 4) + inside_flags (Tensor): inside flags of a single image, + shape (n, ). + flat_squares (Tensor): flat squares of a single image, + shape (approxs_per_octave * n, 4) + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for current image. + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. Defaults to True. + + Returns: + tuple: + + - labels_list (Tensor): Labels in a single image. + - label_weights (Tensor): Label weights in a single image. + - bbox_cls_targets (Tensor): BBox cls targets in a single image. + - bbox_cls_weights (Tensor): BBox cls weights in a single image. + - bbox_reg_targets (Tensor): BBox reg targets in a single image. + - bbox_reg_weights (Tensor): BBox reg weights in a single image. + - num_total_pos (int): Number of positive samples in a single \ + image. + - num_total_neg (int): Number of negative samples in a single \ + image. + - sampling_result (:obj:`SamplingResult`): Sampling result object. + """ + if not inside_flags.any(): + raise ValueError( + 'There is no valid anchor inside the image boundary. Please ' + 'check the image size and anchor sizes, or set ' + '``allowed_border`` to -1 to skip the condition.') + # assign gt and sample anchors + num_square = flat_squares.size(0) + approxs = flat_approxs.view(num_square, self.approxs_per_octave, 4) + approxs = approxs[inside_flags, ...] + squares = flat_squares[inside_flags, :] + + pred_instances = InstanceData() + pred_instances.priors = squares + pred_instances.approxs = approxs + assign_result = self.assigner.assign(pred_instances, gt_instances, + gt_instances_ignore) + sampling_result = self.sampler.sample(assign_result, pred_instances, + gt_instances) + + num_valid_squares = squares.shape[0] + bbox_cls_targets = squares.new_zeros( + (num_valid_squares, self.side_num * 4)) + bbox_cls_weights = squares.new_zeros( + (num_valid_squares, self.side_num * 4)) + bbox_reg_targets = squares.new_zeros( + (num_valid_squares, self.side_num * 4)) + bbox_reg_weights = squares.new_zeros( + (num_valid_squares, self.side_num * 4)) + labels = squares.new_full((num_valid_squares, ), + self.num_classes, + dtype=torch.long) + label_weights = squares.new_zeros(num_valid_squares, dtype=torch.float) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + (pos_bbox_reg_targets, pos_bbox_reg_weights, pos_bbox_cls_targets, + pos_bbox_cls_weights) = self.bbox_coder.encode( + sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes) + + bbox_cls_targets[pos_inds, :] = pos_bbox_cls_targets + bbox_reg_targets[pos_inds, :] = pos_bbox_reg_targets + bbox_cls_weights[pos_inds, :] = pos_bbox_cls_weights + bbox_reg_weights[pos_inds, :] = pos_bbox_reg_weights + labels[pos_inds] = sampling_result.pos_gt_labels + if self.train_cfg['pos_weight'] <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.train_cfg['pos_weight'] + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_squares.size(0) + labels = unmap( + labels, num_total_anchors, inside_flags, 
fill=self.num_classes) + label_weights = unmap(label_weights, num_total_anchors, + inside_flags) + bbox_cls_targets = unmap(bbox_cls_targets, num_total_anchors, + inside_flags) + bbox_cls_weights = unmap(bbox_cls_weights, num_total_anchors, + inside_flags) + bbox_reg_targets = unmap(bbox_reg_targets, num_total_anchors, + inside_flags) + bbox_reg_weights = unmap(bbox_reg_weights, num_total_anchors, + inside_flags) + return (labels, label_weights, bbox_cls_targets, bbox_cls_weights, + bbox_reg_targets, bbox_reg_weights, pos_inds, neg_inds, + sampling_result) + + def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor, + labels: Tensor, label_weights: Tensor, + bbox_cls_targets: Tensor, bbox_cls_weights: Tensor, + bbox_reg_targets: Tensor, bbox_reg_weights: Tensor, + avg_factor: float) -> Tuple[Tensor]: + """Calculate the loss of a single scale level based on the features + extracted by the detection head. + + Args: + cls_score (Tensor): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W). + bbox_pred (Tensor): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + labels (Tensor): Labels in a single image. + label_weights (Tensor): Label weights in a single level. + bbox_cls_targets (Tensor): BBox cls targets in a single level. + bbox_cls_weights (Tensor): BBox cls weights in a single level. + bbox_reg_targets (Tensor): BBox reg targets in a single level. + bbox_reg_weights (Tensor): BBox reg weights in a single level. + avg_factor (int): Average factor that is used to average the loss. + + Returns: + tuple: loss components. + """ + # classification loss + labels = labels.reshape(-1) + label_weights = label_weights.reshape(-1) + cls_score = cls_score.permute(0, 2, 3, + 1).reshape(-1, self.cls_out_channels) + loss_cls = self.loss_cls( + cls_score, labels, label_weights, avg_factor=avg_factor) + # regression loss + bbox_cls_targets = bbox_cls_targets.reshape(-1, self.side_num * 4) + bbox_cls_weights = bbox_cls_weights.reshape(-1, self.side_num * 4) + bbox_reg_targets = bbox_reg_targets.reshape(-1, self.side_num * 4) + bbox_reg_weights = bbox_reg_weights.reshape(-1, self.side_num * 4) + (bbox_cls_pred, bbox_reg_pred) = bbox_pred + bbox_cls_pred = bbox_cls_pred.permute(0, 2, 3, 1).reshape( + -1, self.side_num * 4) + bbox_reg_pred = bbox_reg_pred.permute(0, 2, 3, 1).reshape( + -1, self.side_num * 4) + loss_bbox_cls = self.loss_bbox_cls( + bbox_cls_pred, + bbox_cls_targets.long(), + bbox_cls_weights, + avg_factor=avg_factor * 4 * self.side_num) + loss_bbox_reg = self.loss_bbox_reg( + bbox_reg_pred, + bbox_reg_targets, + bbox_reg_weights, + avg_factor=avg_factor * 4 * self.bbox_coder.offset_topk) + return loss_cls, loss_bbox_cls, loss_bbox_reg + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + has shape (N, num_anchors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. 
+ batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict: A dictionary of loss components. + """ + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.approx_anchor_generator.num_levels + + device = cls_scores[0].device + + # get sampled approxes + approxs_list, inside_flag_list = GuidedAnchorHead.get_sampled_approxs( + self, featmap_sizes, batch_img_metas, device=device) + + square_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + + cls_reg_targets = self.get_targets( + approxs_list, + inside_flag_list, + square_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + (labels_list, label_weights_list, bbox_cls_targets_list, + bbox_cls_weights_list, bbox_reg_targets_list, bbox_reg_weights_list, + avg_factor) = cls_reg_targets + + losses_cls, losses_bbox_cls, losses_bbox_reg = multi_apply( + self.loss_by_feat_single, + cls_scores, + bbox_preds, + labels_list, + label_weights_list, + bbox_cls_targets_list, + bbox_cls_weights_list, + bbox_reg_targets_list, + bbox_reg_weights_list, + avg_factor=avg_factor) + return dict( + loss_cls=losses_cls, + loss_bbox_cls=losses_bbox_cls, + loss_bbox_reg=losses_bbox_reg) + + def predict_by_feat(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_img_metas: List[dict], + cfg: Optional[ConfigDict] = None, + rescale: bool = False, + with_nms: bool = True) -> InstanceList: + """Transform a batch of output features extracted from the head into + bbox results. + + Note: When score_factors is not None, the cls_scores are + usually multiplied by it then obtain the real score used in NMS, + such as CenterNess in FCOS, IoU branch in ATSS. + + Args: + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + batch_img_metas (list[dict], Optional): Batch image meta info. + cfg (:obj:`ConfigDict`, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + list[:obj:`InstanceData`]: Object detection results of each image + after the post process. Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
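# Hedged illustration of the bucketed box representation consumed above:
# each box side carries side_num bucket-confidence logits plus side_num
# in-bucket offsets, so forward_single() returns bbox_pred as a
# (bbox_cls_pred, bbox_reg_pred) pair with side_num * 4 channels each.
import torch

side_num = 7
bbox_cls_pred = torch.randn(100, side_num * 4)  # bucket confidences
bbox_reg_pred = torch.randn(100, side_num * 4)  # in-bucket offsets
bbox_pred = (bbox_cls_pred, bbox_reg_pred)
assert all(p.shape[-1] == 28 for p in bbox_pred)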
+ """ + assert len(cls_scores) == len(bbox_preds) + num_levels = len(cls_scores) + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + + device = cls_scores[0].device + mlvl_anchors = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + result_list = [] + for img_id in range(len(batch_img_metas)): + cls_score_list = [ + cls_scores[i][img_id].detach() for i in range(num_levels) + ] + bbox_cls_pred_list = [ + bbox_preds[i][0][img_id].detach() for i in range(num_levels) + ] + bbox_reg_pred_list = [ + bbox_preds[i][1][img_id].detach() for i in range(num_levels) + ] + proposals = self._predict_by_feat_single( + cls_scores=cls_score_list, + bbox_cls_preds=bbox_cls_pred_list, + bbox_reg_preds=bbox_reg_pred_list, + mlvl_anchors=mlvl_anchors[img_id], + img_meta=batch_img_metas[img_id], + cfg=cfg, + rescale=rescale, + with_nms=with_nms) + result_list.append(proposals) + return result_list + + def _predict_by_feat_single(self, + cls_scores: List[Tensor], + bbox_cls_preds: List[Tensor], + bbox_reg_preds: List[Tensor], + mlvl_anchors: List[Tensor], + img_meta: dict, + cfg: ConfigDict, + rescale: bool = False, + with_nms: bool = True) -> InstanceData: + cfg = self.test_cfg if cfg is None else cfg + nms_pre = cfg.get('nms_pre', -1) + + mlvl_bboxes = [] + mlvl_scores = [] + mlvl_confids = [] + mlvl_labels = [] + assert len(cls_scores) == len(bbox_cls_preds) == len( + bbox_reg_preds) == len(mlvl_anchors) + for cls_score, bbox_cls_pred, bbox_reg_pred, anchors in zip( + cls_scores, bbox_cls_preds, bbox_reg_preds, mlvl_anchors): + assert cls_score.size()[-2:] == bbox_cls_pred.size( + )[-2:] == bbox_reg_pred.size()[-2::] + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + if self.use_sigmoid_cls: + scores = cls_score.sigmoid() + else: + scores = cls_score.softmax(-1)[:, :-1] + bbox_cls_pred = bbox_cls_pred.permute(1, 2, 0).reshape( + -1, self.side_num * 4) + bbox_reg_pred = bbox_reg_pred.permute(1, 2, 0).reshape( + -1, self.side_num * 4) + + # After https://github.com/open-mmlab/mmdetection/pull/6268/, + # this operation keeps fewer bboxes under the same `nms_pre`. + # There is no difference in performance for most models. If you + # find a slight drop in performance, you can set a larger + # `nms_pre` than before. 
+            results = filter_scores_and_topk(
+                scores, cfg.score_thr, nms_pre,
+                dict(
+                    anchors=anchors,
+                    bbox_cls_pred=bbox_cls_pred,
+                    bbox_reg_pred=bbox_reg_pred))
+            scores, labels, _, filtered_results = results
+
+            anchors = filtered_results['anchors']
+            bbox_cls_pred = filtered_results['bbox_cls_pred']
+            bbox_reg_pred = filtered_results['bbox_reg_pred']
+
+            bbox_preds = [
+                bbox_cls_pred.contiguous(),
+                bbox_reg_pred.contiguous()
+            ]
+            bboxes, confids = self.bbox_coder.decode(
+                anchors.contiguous(),
+                bbox_preds,
+                max_shape=img_meta['img_shape'])
+
+            mlvl_bboxes.append(bboxes)
+            mlvl_scores.append(scores)
+            mlvl_confids.append(confids)
+            mlvl_labels.append(labels)
+
+        results = InstanceData()
+        results.bboxes = torch.cat(mlvl_bboxes)
+        results.scores = torch.cat(mlvl_scores)
+        results.score_factors = torch.cat(mlvl_confids)
+        results.labels = torch.cat(mlvl_labels)
+
+        return self._bbox_post_process(
+            results=results,
+            cfg=cfg,
+            rescale=rescale,
+            with_nms=with_nms,
+            img_meta=img_meta)
diff --git a/mmdet/models/dense_heads/solo_head.py b/mmdet/models/dense_heads/solo_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cf338451358b01899faa4b299d33fafd7262d21
--- /dev/null
+++ b/mmdet/models/dense_heads/solo_head.py
@@ -0,0 +1,1263 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple
+
+import mmcv
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmdet.models.utils.misc import floordiv
+from mmdet.registry import MODELS
+from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptConfigType
+from ..layers import mask_matrix_nms
+from ..utils import center_of_mass, generate_coordinate, multi_apply
+from .base_mask_head import BaseMaskHead
+
+
+@MODELS.register_module()
+class SOLOHead(BaseMaskHead):
+    """SOLO mask head used in `SOLO: Segmenting Objects by Locations
+    <https://arxiv.org/abs/1912.04488>`_.
+
+    Args:
+        num_classes (int): Number of categories excluding the background
+            category.
+        in_channels (int): Number of channels in the input feature map.
+        feat_channels (int): Number of hidden channels. Used in child classes.
+            Defaults to 256.
+        stacked_convs (int): Number of stacking convs of the head.
+            Defaults to 4.
+        strides (tuple): Downsample factor of each feature map.
+        scale_ranges (tuple[tuple[int, int]]): Scale range of multiple
+            level masks, in the format [(min1, max1), (min2, max2), ...].
+            A range of (16, 64) assigns instances whose scale (the square
+            root of the instance's area) falls between 16 and 64 to this
+            level.
+        pos_scale (float): Constant scale factor to control the center region.
+        num_grids (list[int]): Divide the image into uniform grids; each
+            feature map uses a different grid value. The number of output
+            channels per level is grid ** 2. Defaults to [40, 36, 24, 16, 12].
+        cls_down_index (int): The index of the downsample operation in the
+            classification branch. Defaults to 0.
+        loss_mask (dict): Config of mask loss.
+        loss_cls (dict): Config of classification loss.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+            Defaults to norm_cfg=dict(type='GN', num_groups=32,
+            requires_grad=True).
+        train_cfg (dict): Training config of head.
+        test_cfg (dict): Testing config of head.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
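# Worked numbers for the SOLO grids above: every level predicts one mask
# channel per grid cell, so with the default num_grids the conv_mask_list
# output channels are grid ** 2 per level.
num_grids = [40, 36, 24, 16, 12]
assert [g ** 2 for g in num_grids] == [1600, 1296, 576, 256, 144]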
+ """ + + def __init__( + self, + num_classes: int, + in_channels: int, + feat_channels: int = 256, + stacked_convs: int = 4, + strides: tuple = (4, 8, 16, 32, 64), + scale_ranges: tuple = ((8, 32), (16, 64), (32, 128), (64, 256), (128, + 512)), + pos_scale: float = 0.2, + num_grids: list = [40, 36, 24, 16, 12], + cls_down_index: int = 0, + loss_mask: ConfigType = dict( + type='DiceLoss', use_sigmoid=True, loss_weight=3.0), + loss_cls: ConfigType = dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + norm_cfg: ConfigType = dict( + type='GN', num_groups=32, requires_grad=True), + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + init_cfg: MultiConfig = [ + dict(type='Normal', layer='Conv2d', std=0.01), + dict( + type='Normal', + std=0.01, + bias_prob=0.01, + override=dict(name='conv_mask_list')), + dict( + type='Normal', + std=0.01, + bias_prob=0.01, + override=dict(name='conv_cls')) + ] + ) -> None: + super().__init__(init_cfg=init_cfg) + self.num_classes = num_classes + self.cls_out_channels = self.num_classes + self.in_channels = in_channels + self.feat_channels = feat_channels + self.stacked_convs = stacked_convs + self.strides = strides + self.num_grids = num_grids + # number of FPN feats + self.num_levels = len(strides) + assert self.num_levels == len(scale_ranges) == len(num_grids) + self.scale_ranges = scale_ranges + self.pos_scale = pos_scale + + self.cls_down_index = cls_down_index + self.loss_cls = MODELS.build(loss_cls) + self.loss_mask = MODELS.build(loss_mask) + self.norm_cfg = norm_cfg + self.init_cfg = init_cfg + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self._init_layers() + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self.mask_convs = nn.ModuleList() + self.cls_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels + 2 if i == 0 else self.feat_channels + self.mask_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + norm_cfg=self.norm_cfg)) + chn = self.in_channels if i == 0 else self.feat_channels + self.cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + norm_cfg=self.norm_cfg)) + self.conv_mask_list = nn.ModuleList() + for num_grid in self.num_grids: + self.conv_mask_list.append( + nn.Conv2d(self.feat_channels, num_grid**2, 1)) + + self.conv_cls = nn.Conv2d( + self.feat_channels, self.cls_out_channels, 3, padding=1) + + def resize_feats(self, x: Tuple[Tensor]) -> List[Tensor]: + """Downsample the first feat and upsample last feat in feats. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + list[Tensor]: Features after resizing, each is a 4D-tensor. + """ + out = [] + for i in range(len(x)): + if i == 0: + out.append( + F.interpolate(x[0], scale_factor=0.5, mode='bilinear')) + elif i == len(x) - 1: + out.append( + F.interpolate( + x[i], size=x[i - 1].shape[-2:], mode='bilinear')) + else: + out.append(x[i]) + return out + + def forward(self, x: Tuple[Tensor]) -> tuple: + """Forward features from the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: A tuple of classification scores and mask prediction. + + - mlvl_mask_preds (list[Tensor]): Multi-level mask prediction. + Each element in the list has shape + (batch_size, num_grids**2 ,h ,w). + - mlvl_cls_preds (list[Tensor]): Multi-level scores. 
+ Each element in the list has shape + (batch_size, num_classes, num_grids ,num_grids). + """ + assert len(x) == self.num_levels + feats = self.resize_feats(x) + mlvl_mask_preds = [] + mlvl_cls_preds = [] + for i in range(self.num_levels): + x = feats[i] + mask_feat = x + cls_feat = x + # generate and concat the coordinate + coord_feat = generate_coordinate(mask_feat.size(), + mask_feat.device) + mask_feat = torch.cat([mask_feat, coord_feat], 1) + + for mask_layer in (self.mask_convs): + mask_feat = mask_layer(mask_feat) + + mask_feat = F.interpolate( + mask_feat, scale_factor=2, mode='bilinear') + mask_preds = self.conv_mask_list[i](mask_feat) + + # cls branch + for j, cls_layer in enumerate(self.cls_convs): + if j == self.cls_down_index: + num_grid = self.num_grids[i] + cls_feat = F.interpolate( + cls_feat, size=num_grid, mode='bilinear') + cls_feat = cls_layer(cls_feat) + + cls_pred = self.conv_cls(cls_feat) + + if not self.training: + feat_wh = feats[0].size()[-2:] + upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2) + mask_preds = F.interpolate( + mask_preds.sigmoid(), size=upsampled_size, mode='bilinear') + cls_pred = cls_pred.sigmoid() + # get local maximum + local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1) + keep_mask = local_max[:, :, :-1, :-1] == cls_pred + cls_pred = cls_pred * keep_mask + + mlvl_mask_preds.append(mask_preds) + mlvl_cls_preds.append(cls_pred) + return mlvl_mask_preds, mlvl_cls_preds + + def loss_by_feat(self, mlvl_mask_preds: List[Tensor], + mlvl_cls_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], **kwargs) -> dict: + """Calculate the loss based on the features extracted by the mask head. + + Args: + mlvl_mask_preds (list[Tensor]): Multi-level mask prediction. + Each element in the list has shape + (batch_size, num_grids**2 ,h ,w). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``masks``, + and ``labels`` attributes. + batch_img_metas (list[dict]): Meta information of multiple images. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
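+
+        The "outside list" regrouping performed below is conceptually a
+        transpose of nested lists; a pure-Python sketch with dummy
+        strings in place of tensors:
+
+            >>> per_image = [['i0_l0', 'i0_l1'], ['i1_l0', 'i1_l1']]
+            >>> [list(lvl) for lvl in zip(*per_image)]
+            [['i0_l0', 'i1_l0'], ['i0_l1', 'i1_l1']]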
+ """ + num_levels = self.num_levels + num_imgs = len(batch_img_metas) + + featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds] + + # `BoolTensor` in `pos_masks` represent + # whether the corresponding point is + # positive + pos_mask_targets, labels, pos_masks = multi_apply( + self._get_targets_single, + batch_gt_instances, + featmap_sizes=featmap_sizes) + + # change from the outside list meaning multi images + # to the outside list meaning multi levels + mlvl_pos_mask_targets = [[] for _ in range(num_levels)] + mlvl_pos_mask_preds = [[] for _ in range(num_levels)] + mlvl_pos_masks = [[] for _ in range(num_levels)] + mlvl_labels = [[] for _ in range(num_levels)] + for img_id in range(num_imgs): + assert num_levels == len(pos_mask_targets[img_id]) + for lvl in range(num_levels): + mlvl_pos_mask_targets[lvl].append( + pos_mask_targets[img_id][lvl]) + mlvl_pos_mask_preds[lvl].append( + mlvl_mask_preds[lvl][img_id, pos_masks[img_id][lvl], ...]) + mlvl_pos_masks[lvl].append(pos_masks[img_id][lvl].flatten()) + mlvl_labels[lvl].append(labels[img_id][lvl].flatten()) + + # cat multiple image + temp_mlvl_cls_preds = [] + for lvl in range(num_levels): + mlvl_pos_mask_targets[lvl] = torch.cat( + mlvl_pos_mask_targets[lvl], dim=0) + mlvl_pos_mask_preds[lvl] = torch.cat( + mlvl_pos_mask_preds[lvl], dim=0) + mlvl_pos_masks[lvl] = torch.cat(mlvl_pos_masks[lvl], dim=0) + mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0) + temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute( + 0, 2, 3, 1).reshape(-1, self.cls_out_channels)) + + num_pos = sum(item.sum() for item in mlvl_pos_masks) + # dice loss + loss_mask = [] + for pred, target in zip(mlvl_pos_mask_preds, mlvl_pos_mask_targets): + if pred.size()[0] == 0: + loss_mask.append(pred.sum().unsqueeze(0)) + continue + loss_mask.append( + self.loss_mask(pred, target, reduction_override='none')) + if num_pos > 0: + loss_mask = torch.cat(loss_mask).sum() / num_pos + else: + loss_mask = torch.cat(loss_mask).mean() + + flatten_labels = torch.cat(mlvl_labels) + flatten_cls_preds = torch.cat(temp_mlvl_cls_preds) + loss_cls = self.loss_cls( + flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1) + return dict(loss_mask=loss_mask, loss_cls=loss_cls) + + def _get_targets_single(self, + gt_instances: InstanceData, + featmap_sizes: Optional[list] = None) -> tuple: + """Compute targets for predictions of single image. + + Args: + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes``, ``labels``, + and ``masks`` attributes. + featmap_sizes (list[:obj:`torch.size`]): Size of each + feature map from feature pyramid, each element + means (feat_h, feat_w). Defaults to None. + + Returns: + Tuple: Usually returns a tuple containing targets for predictions. + + - mlvl_pos_mask_targets (list[Tensor]): Each element represent + the binary mask targets for positive points in this + level, has shape (num_pos, out_h, out_w). + - mlvl_labels (list[Tensor]): Each element is + classification labels for all + points in this level, has shape + (num_grid, num_grid). + - mlvl_pos_masks (list[Tensor]): Each element is + a `BoolTensor` to represent whether the + corresponding point in single level + is positive, has shape (num_grid **2). 
+ """ + gt_labels = gt_instances.labels + device = gt_labels.device + + gt_bboxes = gt_instances.bboxes + gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) * + (gt_bboxes[:, 3] - gt_bboxes[:, 1])) + + gt_masks = gt_instances.masks.to_tensor( + dtype=torch.bool, device=device) + + mlvl_pos_mask_targets = [] + mlvl_labels = [] + mlvl_pos_masks = [] + for (lower_bound, upper_bound), stride, featmap_size, num_grid \ + in zip(self.scale_ranges, self.strides, + featmap_sizes, self.num_grids): + + mask_target = torch.zeros( + [num_grid**2, featmap_size[0], featmap_size[1]], + dtype=torch.uint8, + device=device) + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + labels = torch.zeros([num_grid, num_grid], + dtype=torch.int64, + device=device) + self.num_classes + pos_mask = torch.zeros([num_grid**2], + dtype=torch.bool, + device=device) + + gt_inds = ((gt_areas >= lower_bound) & + (gt_areas <= upper_bound)).nonzero().flatten() + if len(gt_inds) == 0: + mlvl_pos_mask_targets.append( + mask_target.new_zeros(0, featmap_size[0], featmap_size[1])) + mlvl_labels.append(labels) + mlvl_pos_masks.append(pos_mask) + continue + hit_gt_bboxes = gt_bboxes[gt_inds] + hit_gt_labels = gt_labels[gt_inds] + hit_gt_masks = gt_masks[gt_inds, ...] + + pos_w_ranges = 0.5 * (hit_gt_bboxes[:, 2] - + hit_gt_bboxes[:, 0]) * self.pos_scale + pos_h_ranges = 0.5 * (hit_gt_bboxes[:, 3] - + hit_gt_bboxes[:, 1]) * self.pos_scale + + # Make sure hit_gt_masks has a value + valid_mask_flags = hit_gt_masks.sum(dim=-1).sum(dim=-1) > 0 + output_stride = stride / 2 + + for gt_mask, gt_label, pos_h_range, pos_w_range, \ + valid_mask_flag in \ + zip(hit_gt_masks, hit_gt_labels, pos_h_ranges, + pos_w_ranges, valid_mask_flags): + if not valid_mask_flag: + continue + upsampled_size = (featmap_sizes[0][0] * 4, + featmap_sizes[0][1] * 4) + center_h, center_w = center_of_mass(gt_mask) + + coord_w = int( + floordiv((center_w / upsampled_size[1]), (1. / num_grid), + rounding_mode='trunc')) + coord_h = int( + floordiv((center_h / upsampled_size[0]), (1. / num_grid), + rounding_mode='trunc')) + + # left, top, right, down + top_box = max( + 0, + int( + floordiv( + (center_h - pos_h_range) / upsampled_size[0], + (1. / num_grid), + rounding_mode='trunc'))) + down_box = min( + num_grid - 1, + int( + floordiv( + (center_h + pos_h_range) / upsampled_size[0], + (1. / num_grid), + rounding_mode='trunc'))) + left_box = max( + 0, + int( + floordiv( + (center_w - pos_w_range) / upsampled_size[1], + (1. / num_grid), + rounding_mode='trunc'))) + right_box = min( + num_grid - 1, + int( + floordiv( + (center_w + pos_w_range) / upsampled_size[1], + (1. / num_grid), + rounding_mode='trunc'))) + + top = max(top_box, coord_h - 1) + down = min(down_box, coord_h + 1) + left = max(coord_w - 1, left_box) + right = min(right_box, coord_w + 1) + + labels[top:(down + 1), left:(right + 1)] = gt_label + # ins + gt_mask = np.uint8(gt_mask.cpu().numpy()) + # Follow the original implementation, F.interpolate is + # different from cv2 and opencv + gt_mask = mmcv.imrescale(gt_mask, scale=1. / output_stride) + gt_mask = torch.from_numpy(gt_mask).to(device=device) + + for i in range(top, down + 1): + for j in range(left, right + 1): + index = int(i * num_grid + j) + mask_target[index, :gt_mask.shape[0], :gt_mask. 
+ shape[1]] = gt_mask + pos_mask[index] = True + mlvl_pos_mask_targets.append(mask_target[pos_mask]) + mlvl_labels.append(labels) + mlvl_pos_masks.append(pos_mask) + return mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks + + def predict_by_feat(self, mlvl_mask_preds: List[Tensor], + mlvl_cls_scores: List[Tensor], + batch_img_metas: List[dict], **kwargs) -> InstanceList: + """Transform a batch of output features extracted from the head into + mask results. + + Args: + mlvl_mask_preds (list[Tensor]): Multi-level mask prediction. + Each element in the list has shape + (batch_size, num_grids**2 ,h ,w). + mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element + in the list has shape + (batch_size, num_classes, num_grids ,num_grids). + batch_img_metas (list[dict]): Meta information of all images. + + Returns: + list[:obj:`InstanceData`]: Processed results of multiple + images.Each :obj:`InstanceData` usually contains + following keys. + + - scores (Tensor): Classification scores, has shape + (num_instance,). + - labels (Tensor): Has shape (num_instances,). + - masks (Tensor): Processed mask results, has + shape (num_instances, h, w). + """ + mlvl_cls_scores = [ + item.permute(0, 2, 3, 1) for item in mlvl_cls_scores + ] + assert len(mlvl_mask_preds) == len(mlvl_cls_scores) + num_levels = len(mlvl_cls_scores) + + results_list = [] + for img_id in range(len(batch_img_metas)): + cls_pred_list = [ + mlvl_cls_scores[lvl][img_id].view(-1, self.cls_out_channels) + for lvl in range(num_levels) + ] + mask_pred_list = [ + mlvl_mask_preds[lvl][img_id] for lvl in range(num_levels) + ] + + cls_pred_list = torch.cat(cls_pred_list, dim=0) + mask_pred_list = torch.cat(mask_pred_list, dim=0) + img_meta = batch_img_metas[img_id] + + results = self._predict_by_feat_single( + cls_pred_list, mask_pred_list, img_meta=img_meta) + results_list.append(results) + + return results_list + + def _predict_by_feat_single(self, + cls_scores: Tensor, + mask_preds: Tensor, + img_meta: dict, + cfg: OptConfigType = None) -> InstanceData: + """Transform a single image's features extracted from the head into + mask results. + + Args: + cls_scores (Tensor): Classification score of all points + in single image, has shape (num_points, num_classes). + mask_preds (Tensor): Mask prediction of all points in + single image, has shape (num_points, feat_h, feat_w). + img_meta (dict): Meta information of corresponding image. + cfg (dict, optional): Config used in test phase. + Defaults to None. + + Returns: + :obj:`InstanceData`: Processed results of single image. + it usually contains following keys. + + - scores (Tensor): Classification scores, has shape + (num_instance,). + - labels (Tensor): Has shape (num_instances,). + - masks (Tensor): Processed mask results, has + shape (num_instances, h, w). 
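+
+        A note on the index bookkeeping below: with the default
+        ``num_grids = [40, 36, 24, 16, 12]``, the flattened point-index
+        boundaries per level are the cumulative squared grid sizes:
+
+            >>> import torch
+            >>> torch.tensor([40, 36, 24, 16, 12]).pow(2).cumsum(0)
+            tensor([1600, 2896, 3472, 3728, 3872])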
+ """ + + def empty_results(cls_scores, ori_shape): + """Generate a empty results.""" + results = InstanceData() + results.scores = cls_scores.new_ones(0) + results.masks = cls_scores.new_zeros(0, *ori_shape) + results.labels = cls_scores.new_ones(0) + results.bboxes = cls_scores.new_zeros(0, 4) + return results + + cfg = self.test_cfg if cfg is None else cfg + assert len(cls_scores) == len(mask_preds) + + featmap_size = mask_preds.size()[-2:] + + h, w = img_meta['img_shape'][:2] + upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4) + + score_mask = (cls_scores > cfg.score_thr) + cls_scores = cls_scores[score_mask] + if len(cls_scores) == 0: + return empty_results(cls_scores, img_meta['ori_shape'][:2]) + + inds = score_mask.nonzero() + cls_labels = inds[:, 1] + + # Filter the mask mask with an area is smaller than + # stride of corresponding feature level + lvl_interval = cls_labels.new_tensor(self.num_grids).pow(2).cumsum(0) + strides = cls_scores.new_ones(lvl_interval[-1]) + strides[:lvl_interval[0]] *= self.strides[0] + for lvl in range(1, self.num_levels): + strides[lvl_interval[lvl - + 1]:lvl_interval[lvl]] *= self.strides[lvl] + strides = strides[inds[:, 0]] + mask_preds = mask_preds[inds[:, 0]] + + masks = mask_preds > cfg.mask_thr + sum_masks = masks.sum((1, 2)).float() + keep = sum_masks > strides + if keep.sum() == 0: + return empty_results(cls_scores, img_meta['ori_shape'][:2]) + masks = masks[keep] + mask_preds = mask_preds[keep] + sum_masks = sum_masks[keep] + cls_scores = cls_scores[keep] + cls_labels = cls_labels[keep] + + # maskness. + mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks + cls_scores *= mask_scores + + scores, labels, _, keep_inds = mask_matrix_nms( + masks, + cls_labels, + cls_scores, + mask_area=sum_masks, + nms_pre=cfg.nms_pre, + max_num=cfg.max_per_img, + kernel=cfg.kernel, + sigma=cfg.sigma, + filter_thr=cfg.filter_thr) + # mask_matrix_nms may return an empty Tensor + if len(keep_inds) == 0: + return empty_results(cls_scores, img_meta['ori_shape'][:2]) + mask_preds = mask_preds[keep_inds] + mask_preds = F.interpolate( + mask_preds.unsqueeze(0), size=upsampled_size, + mode='bilinear')[:, :, :h, :w] + mask_preds = F.interpolate( + mask_preds, size=img_meta['ori_shape'][:2], + mode='bilinear').squeeze(0) + masks = mask_preds > cfg.mask_thr + + results = InstanceData() + results.masks = masks + results.labels = labels + results.scores = scores + # create an empty bbox in InstanceData to avoid bugs when + # calculating metrics. + results.bboxes = results.scores.new_zeros(len(scores), 4) + return results + + +@MODELS.register_module() +class DecoupledSOLOHead(SOLOHead): + """Decoupled SOLO mask head used in `SOLO: Segmenting Objects by Locations. + + `_ + + Args: + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ """ + + def __init__(self, + *args, + init_cfg: MultiConfig = [ + dict(type='Normal', layer='Conv2d', std=0.01), + dict( + type='Normal', + std=0.01, + bias_prob=0.01, + override=dict(name='conv_mask_list_x')), + dict( + type='Normal', + std=0.01, + bias_prob=0.01, + override=dict(name='conv_mask_list_y')), + dict( + type='Normal', + std=0.01, + bias_prob=0.01, + override=dict(name='conv_cls')) + ], + **kwargs) -> None: + super().__init__(*args, init_cfg=init_cfg, **kwargs) + + def _init_layers(self) -> None: + self.mask_convs_x = nn.ModuleList() + self.mask_convs_y = nn.ModuleList() + self.cls_convs = nn.ModuleList() + + for i in range(self.stacked_convs): + chn = self.in_channels + 1 if i == 0 else self.feat_channels + self.mask_convs_x.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + norm_cfg=self.norm_cfg)) + self.mask_convs_y.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + norm_cfg=self.norm_cfg)) + + chn = self.in_channels if i == 0 else self.feat_channels + self.cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + norm_cfg=self.norm_cfg)) + + self.conv_mask_list_x = nn.ModuleList() + self.conv_mask_list_y = nn.ModuleList() + for num_grid in self.num_grids: + self.conv_mask_list_x.append( + nn.Conv2d(self.feat_channels, num_grid, 3, padding=1)) + self.conv_mask_list_y.append( + nn.Conv2d(self.feat_channels, num_grid, 3, padding=1)) + self.conv_cls = nn.Conv2d( + self.feat_channels, self.cls_out_channels, 3, padding=1) + + def forward(self, x: Tuple[Tensor]) -> Tuple: + """Forward features from the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: A tuple of classification scores and mask prediction. + + - mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction + from x branch. Each element in the list has shape + (batch_size, num_grids ,h ,w). + - mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction + from y branch. Each element in the list has shape + (batch_size, num_grids ,h ,w). + - mlvl_cls_preds (list[Tensor]): Multi-level scores. + Each element in the list has shape + (batch_size, num_classes, num_grids ,num_grids). 
+ """ + assert len(x) == self.num_levels + feats = self.resize_feats(x) + mask_preds_x = [] + mask_preds_y = [] + cls_preds = [] + for i in range(self.num_levels): + x = feats[i] + mask_feat = x + cls_feat = x + # generate and concat the coordinate + coord_feat = generate_coordinate(mask_feat.size(), + mask_feat.device) + mask_feat_x = torch.cat([mask_feat, coord_feat[:, 0:1, ...]], 1) + mask_feat_y = torch.cat([mask_feat, coord_feat[:, 1:2, ...]], 1) + + for mask_layer_x, mask_layer_y in \ + zip(self.mask_convs_x, self.mask_convs_y): + mask_feat_x = mask_layer_x(mask_feat_x) + mask_feat_y = mask_layer_y(mask_feat_y) + + mask_feat_x = F.interpolate( + mask_feat_x, scale_factor=2, mode='bilinear') + mask_feat_y = F.interpolate( + mask_feat_y, scale_factor=2, mode='bilinear') + + mask_pred_x = self.conv_mask_list_x[i](mask_feat_x) + mask_pred_y = self.conv_mask_list_y[i](mask_feat_y) + + # cls branch + for j, cls_layer in enumerate(self.cls_convs): + if j == self.cls_down_index: + num_grid = self.num_grids[i] + cls_feat = F.interpolate( + cls_feat, size=num_grid, mode='bilinear') + cls_feat = cls_layer(cls_feat) + + cls_pred = self.conv_cls(cls_feat) + + if not self.training: + feat_wh = feats[0].size()[-2:] + upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2) + mask_pred_x = F.interpolate( + mask_pred_x.sigmoid(), + size=upsampled_size, + mode='bilinear') + mask_pred_y = F.interpolate( + mask_pred_y.sigmoid(), + size=upsampled_size, + mode='bilinear') + cls_pred = cls_pred.sigmoid() + # get local maximum + local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1) + keep_mask = local_max[:, :, :-1, :-1] == cls_pred + cls_pred = cls_pred * keep_mask + + mask_preds_x.append(mask_pred_x) + mask_preds_y.append(mask_pred_y) + cls_preds.append(cls_pred) + return mask_preds_x, mask_preds_y, cls_preds + + def loss_by_feat(self, mlvl_mask_preds_x: List[Tensor], + mlvl_mask_preds_y: List[Tensor], + mlvl_cls_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], **kwargs) -> dict: + """Calculate the loss based on the features extracted by the mask head. + + Args: + mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction + from x branch. Each element in the list has shape + (batch_size, num_grids ,h ,w). + mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction + from y branch. Each element in the list has shape + (batch_size, num_grids ,h ,w). + mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element + in the list has shape + (batch_size, num_classes, num_grids ,num_grids). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``masks``, + and ``labels`` attributes. + batch_img_metas (list[dict]): Meta information of multiple images. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + num_levels = self.num_levels + num_imgs = len(batch_img_metas) + featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds_x] + + pos_mask_targets, labels, xy_pos_indexes = multi_apply( + self._get_targets_single, + batch_gt_instances, + featmap_sizes=featmap_sizes) + + # change from the outside list meaning multi images + # to the outside list meaning multi levels + mlvl_pos_mask_targets = [[] for _ in range(num_levels)] + mlvl_pos_mask_preds_x = [[] for _ in range(num_levels)] + mlvl_pos_mask_preds_y = [[] for _ in range(num_levels)] + mlvl_labels = [[] for _ in range(num_levels)] + for img_id in range(num_imgs): + + for lvl in range(num_levels): + mlvl_pos_mask_targets[lvl].append( + pos_mask_targets[img_id][lvl]) + mlvl_pos_mask_preds_x[lvl].append( + mlvl_mask_preds_x[lvl][img_id, + xy_pos_indexes[img_id][lvl][:, 1]]) + mlvl_pos_mask_preds_y[lvl].append( + mlvl_mask_preds_y[lvl][img_id, + xy_pos_indexes[img_id][lvl][:, 0]]) + mlvl_labels[lvl].append(labels[img_id][lvl].flatten()) + + # cat multiple image + temp_mlvl_cls_preds = [] + for lvl in range(num_levels): + mlvl_pos_mask_targets[lvl] = torch.cat( + mlvl_pos_mask_targets[lvl], dim=0) + mlvl_pos_mask_preds_x[lvl] = torch.cat( + mlvl_pos_mask_preds_x[lvl], dim=0) + mlvl_pos_mask_preds_y[lvl] = torch.cat( + mlvl_pos_mask_preds_y[lvl], dim=0) + mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0) + temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute( + 0, 2, 3, 1).reshape(-1, self.cls_out_channels)) + + num_pos = 0. + # dice loss + loss_mask = [] + for pred_x, pred_y, target in \ + zip(mlvl_pos_mask_preds_x, + mlvl_pos_mask_preds_y, mlvl_pos_mask_targets): + num_masks = pred_x.size(0) + if num_masks == 0: + # make sure can get grad + loss_mask.append((pred_x.sum() + pred_y.sum()).unsqueeze(0)) + continue + num_pos += num_masks + pred_mask = pred_y.sigmoid() * pred_x.sigmoid() + loss_mask.append( + self.loss_mask(pred_mask, target, reduction_override='none')) + if num_pos > 0: + loss_mask = torch.cat(loss_mask).sum() / num_pos + else: + loss_mask = torch.cat(loss_mask).mean() + + # cate + flatten_labels = torch.cat(mlvl_labels) + flatten_cls_preds = torch.cat(temp_mlvl_cls_preds) + + loss_cls = self.loss_cls( + flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1) + return dict(loss_mask=loss_mask, loss_cls=loss_cls) + + def _get_targets_single(self, + gt_instances: InstanceData, + featmap_sizes: Optional[list] = None) -> tuple: + """Compute targets for predictions of single image. + + Args: + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes``, ``labels``, + and ``masks`` attributes. + featmap_sizes (list[:obj:`torch.size`]): Size of each + feature map from feature pyramid, each element + means (feat_h, feat_w). Defaults to None. + + Returns: + Tuple: Usually returns a tuple containing targets for predictions. + + - mlvl_pos_mask_targets (list[Tensor]): Each element represent + the binary mask targets for positive points in this + level, has shape (num_pos, out_h, out_w). + - mlvl_labels (list[Tensor]): Each element is + classification labels for all + points in this level, has shape + (num_grid, num_grid). + - mlvl_xy_pos_indexes (list[Tensor]): Each element + in the list contains the index of positive samples in + corresponding level, has shape (num_pos, 2), last + dimension 2 present (index_x, index_y). 
+ """ + mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks = \ + super()._get_targets_single(gt_instances, + featmap_sizes=featmap_sizes) + + mlvl_xy_pos_indexes = [(item - self.num_classes).nonzero() + for item in mlvl_labels] + + return mlvl_pos_mask_targets, mlvl_labels, mlvl_xy_pos_indexes + + def predict_by_feat(self, mlvl_mask_preds_x: List[Tensor], + mlvl_mask_preds_y: List[Tensor], + mlvl_cls_scores: List[Tensor], + batch_img_metas: List[dict], **kwargs) -> InstanceList: + """Transform a batch of output features extracted from the head into + mask results. + + Args: + mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction + from x branch. Each element in the list has shape + (batch_size, num_grids ,h ,w). + mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction + from y branch. Each element in the list has shape + (batch_size, num_grids ,h ,w). + mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element + in the list has shape + (batch_size, num_classes ,num_grids ,num_grids). + batch_img_metas (list[dict]): Meta information of all images. + + Returns: + list[:obj:`InstanceData`]: Processed results of multiple + images.Each :obj:`InstanceData` usually contains + following keys. + + - scores (Tensor): Classification scores, has shape + (num_instance,). + - labels (Tensor): Has shape (num_instances,). + - masks (Tensor): Processed mask results, has + shape (num_instances, h, w). + """ + mlvl_cls_scores = [ + item.permute(0, 2, 3, 1) for item in mlvl_cls_scores + ] + assert len(mlvl_mask_preds_x) == len(mlvl_cls_scores) + num_levels = len(mlvl_cls_scores) + + results_list = [] + for img_id in range(len(batch_img_metas)): + cls_pred_list = [ + mlvl_cls_scores[i][img_id].view( + -1, self.cls_out_channels).detach() + for i in range(num_levels) + ] + mask_pred_list_x = [ + mlvl_mask_preds_x[i][img_id] for i in range(num_levels) + ] + mask_pred_list_y = [ + mlvl_mask_preds_y[i][img_id] for i in range(num_levels) + ] + + cls_pred_list = torch.cat(cls_pred_list, dim=0) + mask_pred_list_x = torch.cat(mask_pred_list_x, dim=0) + mask_pred_list_y = torch.cat(mask_pred_list_y, dim=0) + img_meta = batch_img_metas[img_id] + + results = self._predict_by_feat_single( + cls_pred_list, + mask_pred_list_x, + mask_pred_list_y, + img_meta=img_meta) + results_list.append(results) + return results_list + + def _predict_by_feat_single(self, + cls_scores: Tensor, + mask_preds_x: Tensor, + mask_preds_y: Tensor, + img_meta: dict, + cfg: OptConfigType = None) -> InstanceData: + """Transform a single image's features extracted from the head into + mask results. + + Args: + cls_scores (Tensor): Classification score of all points + in single image, has shape (num_points, num_classes). + mask_preds_x (Tensor): Mask prediction of x branch of + all points in single image, has shape + (sum_num_grids, feat_h, feat_w). + mask_preds_y (Tensor): Mask prediction of y branch of + all points in single image, has shape + (sum_num_grids, feat_h, feat_w). + img_meta (dict): Meta information of corresponding image. + cfg (dict): Config used in test phase. + + Returns: + :obj:`InstanceData`: Processed results of single image. + it usually contains following keys. + + - scores (Tensor): Classification scores, has shape + (num_instance,). + - labels (Tensor): Has shape (num_instances,). + - masks (Tensor): Processed mask results, has + shape (num_instances, h, w). 
+ """ + + def empty_results(cls_scores, ori_shape): + """Generate a empty results.""" + results = InstanceData() + results.scores = cls_scores.new_ones(0) + results.masks = cls_scores.new_zeros(0, *ori_shape) + results.labels = cls_scores.new_ones(0) + results.bboxes = cls_scores.new_zeros(0, 4) + return results + + cfg = self.test_cfg if cfg is None else cfg + + featmap_size = mask_preds_x.size()[-2:] + + h, w = img_meta['img_shape'][:2] + upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4) + + score_mask = (cls_scores > cfg.score_thr) + cls_scores = cls_scores[score_mask] + inds = score_mask.nonzero() + lvl_interval = inds.new_tensor(self.num_grids).pow(2).cumsum(0) + num_all_points = lvl_interval[-1] + lvl_start_index = inds.new_ones(num_all_points) + num_grids = inds.new_ones(num_all_points) + seg_size = inds.new_tensor(self.num_grids).cumsum(0) + mask_lvl_start_index = inds.new_ones(num_all_points) + strides = inds.new_ones(num_all_points) + + lvl_start_index[:lvl_interval[0]] *= 0 + mask_lvl_start_index[:lvl_interval[0]] *= 0 + num_grids[:lvl_interval[0]] *= self.num_grids[0] + strides[:lvl_interval[0]] *= self.strides[0] + + for lvl in range(1, self.num_levels): + lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \ + lvl_interval[lvl - 1] + mask_lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \ + seg_size[lvl - 1] + num_grids[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \ + self.num_grids[lvl] + strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \ + self.strides[lvl] + + lvl_start_index = lvl_start_index[inds[:, 0]] + mask_lvl_start_index = mask_lvl_start_index[inds[:, 0]] + num_grids = num_grids[inds[:, 0]] + strides = strides[inds[:, 0]] + + y_lvl_offset = (inds[:, 0] - lvl_start_index) // num_grids + x_lvl_offset = (inds[:, 0] - lvl_start_index) % num_grids + y_inds = mask_lvl_start_index + y_lvl_offset + x_inds = mask_lvl_start_index + x_lvl_offset + + cls_labels = inds[:, 1] + mask_preds = mask_preds_x[x_inds, ...] * mask_preds_y[y_inds, ...] + + masks = mask_preds > cfg.mask_thr + sum_masks = masks.sum((1, 2)).float() + keep = sum_masks > strides + if keep.sum() == 0: + return empty_results(cls_scores, img_meta['ori_shape'][:2]) + + masks = masks[keep] + mask_preds = mask_preds[keep] + sum_masks = sum_masks[keep] + cls_scores = cls_scores[keep] + cls_labels = cls_labels[keep] + + # maskness. + mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks + cls_scores *= mask_scores + + scores, labels, _, keep_inds = mask_matrix_nms( + masks, + cls_labels, + cls_scores, + mask_area=sum_masks, + nms_pre=cfg.nms_pre, + max_num=cfg.max_per_img, + kernel=cfg.kernel, + sigma=cfg.sigma, + filter_thr=cfg.filter_thr) + # mask_matrix_nms may return an empty Tensor + if len(keep_inds) == 0: + return empty_results(cls_scores, img_meta['ori_shape'][:2]) + mask_preds = mask_preds[keep_inds] + mask_preds = F.interpolate( + mask_preds.unsqueeze(0), size=upsampled_size, + mode='bilinear')[:, :, :h, :w] + mask_preds = F.interpolate( + mask_preds, size=img_meta['ori_shape'][:2], + mode='bilinear').squeeze(0) + masks = mask_preds > cfg.mask_thr + + results = InstanceData() + results.masks = masks + results.labels = labels + results.scores = scores + # create an empty bbox in InstanceData to avoid bugs when + # calculating metrics. 
+ results.bboxes = results.scores.new_zeros(len(scores), 4) + + return results + + +@MODELS.register_module() +class DecoupledSOLOLightHead(DecoupledSOLOHead): + """Decoupled Light SOLO mask head used in `SOLO: Segmenting Objects by + Locations `_ + + Args: + with_dcn (bool): Whether use dcn in mask_convs and cls_convs, + Defaults to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, + *args, + dcn_cfg: OptConfigType = None, + init_cfg: MultiConfig = [ + dict(type='Normal', layer='Conv2d', std=0.01), + dict( + type='Normal', + std=0.01, + bias_prob=0.01, + override=dict(name='conv_mask_list_x')), + dict( + type='Normal', + std=0.01, + bias_prob=0.01, + override=dict(name='conv_mask_list_y')), + dict( + type='Normal', + std=0.01, + bias_prob=0.01, + override=dict(name='conv_cls')) + ], + **kwargs) -> None: + assert dcn_cfg is None or isinstance(dcn_cfg, dict) + self.dcn_cfg = dcn_cfg + super().__init__(*args, init_cfg=init_cfg, **kwargs) + + def _init_layers(self) -> None: + self.mask_convs = nn.ModuleList() + self.cls_convs = nn.ModuleList() + + for i in range(self.stacked_convs): + if self.dcn_cfg is not None \ + and i == self.stacked_convs - 1: + conv_cfg = self.dcn_cfg + else: + conv_cfg = None + + chn = self.in_channels + 2 if i == 0 else self.feat_channels + self.mask_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=self.norm_cfg)) + + chn = self.in_channels if i == 0 else self.feat_channels + self.cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=self.norm_cfg)) + + self.conv_mask_list_x = nn.ModuleList() + self.conv_mask_list_y = nn.ModuleList() + for num_grid in self.num_grids: + self.conv_mask_list_x.append( + nn.Conv2d(self.feat_channels, num_grid, 3, padding=1)) + self.conv_mask_list_y.append( + nn.Conv2d(self.feat_channels, num_grid, 3, padding=1)) + self.conv_cls = nn.Conv2d( + self.feat_channels, self.cls_out_channels, 3, padding=1) + + def forward(self, x: Tuple[Tensor]) -> Tuple: + """Forward features from the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: A tuple of classification scores and mask prediction. + + - mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction + from x branch. Each element in the list has shape + (batch_size, num_grids ,h ,w). + - mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction + from y branch. Each element in the list has shape + (batch_size, num_grids ,h ,w). + - mlvl_cls_preds (list[Tensor]): Multi-level scores. + Each element in the list has shape + (batch_size, num_classes, num_grids ,num_grids). 
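+
+        Note:
+            Unlike :class:`DecoupledSOLOHead`, the light head runs a
+            single shared mask tower (with both coordinate channels
+            concatenated) and only splits into the x- and y-branches at
+            the final prediction convs, roughly halving the mask-branch
+            convolution cost.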
+ """ + assert len(x) == self.num_levels + feats = self.resize_feats(x) + mask_preds_x = [] + mask_preds_y = [] + cls_preds = [] + for i in range(self.num_levels): + x = feats[i] + mask_feat = x + cls_feat = x + # generate and concat the coordinate + coord_feat = generate_coordinate(mask_feat.size(), + mask_feat.device) + mask_feat = torch.cat([mask_feat, coord_feat], 1) + + for mask_layer in self.mask_convs: + mask_feat = mask_layer(mask_feat) + + mask_feat = F.interpolate( + mask_feat, scale_factor=2, mode='bilinear') + + mask_pred_x = self.conv_mask_list_x[i](mask_feat) + mask_pred_y = self.conv_mask_list_y[i](mask_feat) + + # cls branch + for j, cls_layer in enumerate(self.cls_convs): + if j == self.cls_down_index: + num_grid = self.num_grids[i] + cls_feat = F.interpolate( + cls_feat, size=num_grid, mode='bilinear') + cls_feat = cls_layer(cls_feat) + + cls_pred = self.conv_cls(cls_feat) + + if not self.training: + feat_wh = feats[0].size()[-2:] + upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2) + mask_pred_x = F.interpolate( + mask_pred_x.sigmoid(), + size=upsampled_size, + mode='bilinear') + mask_pred_y = F.interpolate( + mask_pred_y.sigmoid(), + size=upsampled_size, + mode='bilinear') + cls_pred = cls_pred.sigmoid() + # get local maximum + local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1) + keep_mask = local_max[:, :, :-1, :-1] == cls_pred + cls_pred = cls_pred * keep_mask + + mask_preds_x.append(mask_pred_x) + mask_preds_y.append(mask_pred_y) + cls_preds.append(cls_pred) + return mask_preds_x, mask_preds_y, cls_preds diff --git a/mmdet/models/dense_heads/solov2_head.py b/mmdet/models/dense_heads/solov2_head.py new file mode 100644 index 0000000000000000000000000000000000000000..35b9df0c45148cb18e8afb659b10dd0b9e866b99 --- /dev/null +++ b/mmdet/models/dense_heads/solov2_head.py @@ -0,0 +1,799 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import List, Optional, Tuple + +import mmcv +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.models.utils.misc import floordiv +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptConfigType +from ..layers import mask_matrix_nms +from ..utils import center_of_mass, generate_coordinate, multi_apply +from .solo_head import SOLOHead + + +class MaskFeatModule(BaseModule): + """SOLOv2 mask feature map branch used in `SOLOv2: Dynamic and Fast + Instance Segmentation. `_ + + Args: + in_channels (int): Number of channels in the input feature map. + feat_channels (int): Number of hidden channels of the mask feature + map branch. + start_level (int): The starting feature map level from RPN that + will be used to predict the mask feature map. + end_level (int): The ending feature map level from rpn that + will be used to predict the mask feature map. + out_channels (int): Number of output channels of the mask feature + map branch. This is the channel count of the mask + feature map that to be dynamically convolved with the predicted + kernel. + mask_stride (int): Downsample factor of the mask feature map output. + Defaults to 4. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Config dict for normalization layer. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ """ + + def __init__( + self, + in_channels: int, + feat_channels: int, + start_level: int, + end_level: int, + out_channels: int, + mask_stride: int = 4, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + init_cfg: MultiConfig = [ + dict(type='Normal', layer='Conv2d', std=0.01) + ] + ) -> None: + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.feat_channels = feat_channels + self.start_level = start_level + self.end_level = end_level + self.mask_stride = mask_stride + assert start_level >= 0 and end_level >= start_level + self.out_channels = out_channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self._init_layers() + self.fp16_enabled = False + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self.convs_all_levels = nn.ModuleList() + for i in range(self.start_level, self.end_level + 1): + convs_per_level = nn.Sequential() + if i == 0: + convs_per_level.add_module( + f'conv{i}', + ConvModule( + self.in_channels, + self.feat_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=False)) + self.convs_all_levels.append(convs_per_level) + continue + + for j in range(i): + if j == 0: + if i == self.end_level: + chn = self.in_channels + 2 + else: + chn = self.in_channels + convs_per_level.add_module( + f'conv{j}', + ConvModule( + chn, + self.feat_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=False)) + convs_per_level.add_module( + f'upsample{j}', + nn.Upsample( + scale_factor=2, + mode='bilinear', + align_corners=False)) + continue + + convs_per_level.add_module( + f'conv{j}', + ConvModule( + self.feat_channels, + self.feat_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=False)) + convs_per_level.add_module( + f'upsample{j}', + nn.Upsample( + scale_factor=2, mode='bilinear', align_corners=False)) + + self.convs_all_levels.append(convs_per_level) + + self.conv_pred = ConvModule( + self.feat_channels, + self.out_channels, + 1, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) + + def forward(self, x: Tuple[Tensor]) -> Tensor: + """Forward features from the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + Tensor: The predicted mask feature map. + """ + inputs = x[self.start_level:self.end_level + 1] + assert len(inputs) == (self.end_level - self.start_level + 1) + feature_add_all_level = self.convs_all_levels[0](inputs[0]) + for i in range(1, len(inputs)): + input_p = inputs[i] + if i == len(inputs) - 1: + coord_feat = generate_coordinate(input_p.size(), + input_p.device) + input_p = torch.cat([input_p, coord_feat], 1) + + feature_add_all_level = feature_add_all_level + \ + self.convs_all_levels[i](input_p) + + feature_pred = self.conv_pred(feature_add_all_level) + return feature_pred + + +@MODELS.register_module() +class SOLOV2Head(SOLOHead): + """SOLOv2 mask head used in `SOLOv2: Dynamic and Fast Instance + Segmentation. `_ + + Args: + mask_feature_head (dict): Config of SOLOv2MaskFeatHead. + dynamic_conv_size (int): Dynamic Conv kernel size. Defaults to 1. + dcn_cfg (dict): Dcn conv configurations in kernel_convs and cls_conv. + Defaults to None. + dcn_apply_to_all_conv (bool): Whether to use dcn in every layer of + kernel_convs and cls_convs, or only the last layer. It shall be set + `True` for the normal version of SOLOv2 and `False` for the + light-weight version. Defaults to True. 
+ init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, + *args, + mask_feature_head: ConfigType, + dynamic_conv_size: int = 1, + dcn_cfg: OptConfigType = None, + dcn_apply_to_all_conv: bool = True, + init_cfg: MultiConfig = [ + dict(type='Normal', layer='Conv2d', std=0.01), + dict( + type='Normal', + std=0.01, + bias_prob=0.01, + override=dict(name='conv_cls')) + ], + **kwargs) -> None: + assert dcn_cfg is None or isinstance(dcn_cfg, dict) + self.dcn_cfg = dcn_cfg + self.with_dcn = dcn_cfg is not None + self.dcn_apply_to_all_conv = dcn_apply_to_all_conv + self.dynamic_conv_size = dynamic_conv_size + mask_out_channels = mask_feature_head.get('out_channels') + self.kernel_out_channels = \ + mask_out_channels * self.dynamic_conv_size * self.dynamic_conv_size + + super().__init__(*args, init_cfg=init_cfg, **kwargs) + + # update the in_channels of mask_feature_head + if mask_feature_head.get('in_channels', None) is not None: + if mask_feature_head.in_channels != self.in_channels: + warnings.warn('The `in_channels` of SOLOv2MaskFeatHead and ' + 'SOLOv2Head should be same, changing ' + 'mask_feature_head.in_channels to ' + f'{self.in_channels}') + mask_feature_head.update(in_channels=self.in_channels) + else: + mask_feature_head.update(in_channels=self.in_channels) + + self.mask_feature_head = MaskFeatModule(**mask_feature_head) + self.mask_stride = self.mask_feature_head.mask_stride + self.fp16_enabled = False + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self.cls_convs = nn.ModuleList() + self.kernel_convs = nn.ModuleList() + conv_cfg = None + for i in range(self.stacked_convs): + if self.with_dcn: + if self.dcn_apply_to_all_conv: + conv_cfg = self.dcn_cfg + elif i == self.stacked_convs - 1: + # light head + conv_cfg = self.dcn_cfg + + chn = self.in_channels + 2 if i == 0 else self.feat_channels + self.kernel_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=self.norm_cfg, + bias=self.norm_cfg is None)) + + chn = self.in_channels if i == 0 else self.feat_channels + self.cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=self.norm_cfg, + bias=self.norm_cfg is None)) + + self.conv_cls = nn.Conv2d( + self.feat_channels, self.cls_out_channels, 3, padding=1) + + self.conv_kernel = nn.Conv2d( + self.feat_channels, self.kernel_out_channels, 3, padding=1) + + def forward(self, x): + """Forward features from the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: A tuple of classification scores, mask prediction, + and mask features. + + - mlvl_kernel_preds (list[Tensor]): Multi-level dynamic kernel + prediction. The kernel is used to generate instance + segmentation masks by dynamic convolution. Each element in + the list has shape + (batch_size, kernel_out_channels, num_grids, num_grids). + - mlvl_cls_preds (list[Tensor]): Multi-level scores. Each + element in the list has shape + (batch_size, num_classes, num_grids, num_grids). + - mask_feats (Tensor): Unified mask feature map used to + generate instance segmentation masks by dynamic convolution. + Has shape (batch_size, mask_out_channels, h, w). 
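+
+        The dynamic convolution that later turns predicted kernels into
+        instance masks, sketched with made-up sizes
+        (``mask_out_channels=8``, ``dynamic_conv_size=1``, three kernels):
+
+            >>> import torch
+            >>> import torch.nn.functional as F
+            >>> mask_feats = torch.rand(1, 8, 32, 32)
+            >>> kernels = torch.rand(3, 8)
+            >>> F.conv2d(mask_feats, kernels.view(3, 8, 1, 1)).shape
+            torch.Size([1, 3, 32, 32])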
+ """ + assert len(x) == self.num_levels + mask_feats = self.mask_feature_head(x) + ins_kernel_feats = self.resize_feats(x) + mlvl_kernel_preds = [] + mlvl_cls_preds = [] + for i in range(self.num_levels): + ins_kernel_feat = ins_kernel_feats[i] + # ins branch + # concat coord + coord_feat = generate_coordinate(ins_kernel_feat.size(), + ins_kernel_feat.device) + ins_kernel_feat = torch.cat([ins_kernel_feat, coord_feat], 1) + + # kernel branch + kernel_feat = ins_kernel_feat + kernel_feat = F.interpolate( + kernel_feat, + size=self.num_grids[i], + mode='bilinear', + align_corners=False) + + cate_feat = kernel_feat[:, :-2, :, :] + + kernel_feat = kernel_feat.contiguous() + for i, kernel_conv in enumerate(self.kernel_convs): + kernel_feat = kernel_conv(kernel_feat) + kernel_pred = self.conv_kernel(kernel_feat) + + # cate branch + cate_feat = cate_feat.contiguous() + for i, cls_conv in enumerate(self.cls_convs): + cate_feat = cls_conv(cate_feat) + cate_pred = self.conv_cls(cate_feat) + + mlvl_kernel_preds.append(kernel_pred) + mlvl_cls_preds.append(cate_pred) + + return mlvl_kernel_preds, mlvl_cls_preds, mask_feats + + def _get_targets_single(self, + gt_instances: InstanceData, + featmap_sizes: Optional[list] = None) -> tuple: + """Compute targets for predictions of single image. + + Args: + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes``, ``labels``, + and ``masks`` attributes. + featmap_sizes (list[:obj:`torch.size`]): Size of each + feature map from feature pyramid, each element + means (feat_h, feat_w). Defaults to None. + + Returns: + Tuple: Usually returns a tuple containing targets for predictions. + + - mlvl_pos_mask_targets (list[Tensor]): Each element represent + the binary mask targets for positive points in this + level, has shape (num_pos, out_h, out_w). + - mlvl_labels (list[Tensor]): Each element is + classification labels for all + points in this level, has shape + (num_grid, num_grid). + - mlvl_pos_masks (list[Tensor]): Each element is + a `BoolTensor` to represent whether the + corresponding point in single level + is positive, has shape (num_grid **2). + - mlvl_pos_indexes (list[list]): Each element + in the list contains the positive index in + corresponding level, has shape (num_pos). + """ + gt_labels = gt_instances.labels + device = gt_labels.device + + gt_bboxes = gt_instances.bboxes + gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) * + (gt_bboxes[:, 3] - gt_bboxes[:, 1])) + gt_masks = gt_instances.masks.to_tensor( + dtype=torch.bool, device=device) + + mlvl_pos_mask_targets = [] + mlvl_pos_indexes = [] + mlvl_labels = [] + mlvl_pos_masks = [] + for (lower_bound, upper_bound), num_grid \ + in zip(self.scale_ranges, self.num_grids): + mask_target = [] + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + pos_index = [] + labels = torch.zeros([num_grid, num_grid], + dtype=torch.int64, + device=device) + self.num_classes + pos_mask = torch.zeros([num_grid**2], + dtype=torch.bool, + device=device) + + gt_inds = ((gt_areas >= lower_bound) & + (gt_areas <= upper_bound)).nonzero().flatten() + if len(gt_inds) == 0: + mlvl_pos_mask_targets.append( + torch.zeros([0, featmap_sizes[0], featmap_sizes[1]], + dtype=torch.uint8, + device=device)) + mlvl_labels.append(labels) + mlvl_pos_masks.append(pos_mask) + mlvl_pos_indexes.append([]) + continue + hit_gt_bboxes = gt_bboxes[gt_inds] + hit_gt_labels = gt_labels[gt_inds] + hit_gt_masks = gt_masks[gt_inds, ...] 
+ + pos_w_ranges = 0.5 * (hit_gt_bboxes[:, 2] - + hit_gt_bboxes[:, 0]) * self.pos_scale + pos_h_ranges = 0.5 * (hit_gt_bboxes[:, 3] - + hit_gt_bboxes[:, 1]) * self.pos_scale + + # Make sure hit_gt_masks has a value + valid_mask_flags = hit_gt_masks.sum(dim=-1).sum(dim=-1) > 0 + + for gt_mask, gt_label, pos_h_range, pos_w_range, \ + valid_mask_flag in \ + zip(hit_gt_masks, hit_gt_labels, pos_h_ranges, + pos_w_ranges, valid_mask_flags): + if not valid_mask_flag: + continue + upsampled_size = (featmap_sizes[0] * self.mask_stride, + featmap_sizes[1] * self.mask_stride) + center_h, center_w = center_of_mass(gt_mask) + + coord_w = int( + floordiv((center_w / upsampled_size[1]), (1. / num_grid), + rounding_mode='trunc')) + coord_h = int( + floordiv((center_h / upsampled_size[0]), (1. / num_grid), + rounding_mode='trunc')) + + # left, top, right, down + top_box = max( + 0, + int( + floordiv( + (center_h - pos_h_range) / upsampled_size[0], + (1. / num_grid), + rounding_mode='trunc'))) + down_box = min( + num_grid - 1, + int( + floordiv( + (center_h + pos_h_range) / upsampled_size[0], + (1. / num_grid), + rounding_mode='trunc'))) + left_box = max( + 0, + int( + floordiv( + (center_w - pos_w_range) / upsampled_size[1], + (1. / num_grid), + rounding_mode='trunc'))) + right_box = min( + num_grid - 1, + int( + floordiv( + (center_w + pos_w_range) / upsampled_size[1], + (1. / num_grid), + rounding_mode='trunc'))) + + top = max(top_box, coord_h - 1) + down = min(down_box, coord_h + 1) + left = max(coord_w - 1, left_box) + right = min(right_box, coord_w + 1) + + labels[top:(down + 1), left:(right + 1)] = gt_label + # ins + gt_mask = np.uint8(gt_mask.cpu().numpy()) + # Follow the original implementation, F.interpolate is + # different from cv2 and opencv + gt_mask = mmcv.imrescale(gt_mask, scale=1. / self.mask_stride) + gt_mask = torch.from_numpy(gt_mask).to(device=device) + + for i in range(top, down + 1): + for j in range(left, right + 1): + index = int(i * num_grid + j) + this_mask_target = torch.zeros( + [featmap_sizes[0], featmap_sizes[1]], + dtype=torch.uint8, + device=device) + this_mask_target[:gt_mask.shape[0], :gt_mask. + shape[1]] = gt_mask + mask_target.append(this_mask_target) + pos_mask[index] = True + pos_index.append(index) + if len(mask_target) == 0: + mask_target = torch.zeros( + [0, featmap_sizes[0], featmap_sizes[1]], + dtype=torch.uint8, + device=device) + else: + mask_target = torch.stack(mask_target, 0) + mlvl_pos_mask_targets.append(mask_target) + mlvl_labels.append(labels) + mlvl_pos_masks.append(pos_mask) + mlvl_pos_indexes.append(pos_index) + return (mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks, + mlvl_pos_indexes) + + def loss_by_feat(self, mlvl_kernel_preds: List[Tensor], + mlvl_cls_preds: List[Tensor], mask_feats: Tensor, + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], **kwargs) -> dict: + """Calculate the loss based on the features extracted by the mask head. + + Args: + mlvl_kernel_preds (list[Tensor]): Multi-level dynamic kernel + prediction. The kernel is used to generate instance + segmentation masks by dynamic convolution. Each element in the + list has shape + (batch_size, kernel_out_channels, num_grids, num_grids). + mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element + in the list has shape + (batch_size, num_classes, num_grids, num_grids). + mask_feats (Tensor): Unified mask feature map used to generate + instance segmentation masks by dynamic convolution. Has shape + (batch_size, mask_out_channels, h, w). 
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``masks``, + and ``labels`` attributes. + batch_img_metas (list[dict]): Meta information of multiple images. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + featmap_sizes = mask_feats.size()[-2:] + + pos_mask_targets, labels, pos_masks, pos_indexes = multi_apply( + self._get_targets_single, + batch_gt_instances, + featmap_sizes=featmap_sizes) + + mlvl_mask_targets = [ + torch.cat(lvl_mask_targets, 0) + for lvl_mask_targets in zip(*pos_mask_targets) + ] + + mlvl_pos_kernel_preds = [] + for lvl_kernel_preds, lvl_pos_indexes in zip(mlvl_kernel_preds, + zip(*pos_indexes)): + lvl_pos_kernel_preds = [] + for img_lvl_kernel_preds, img_lvl_pos_indexes in zip( + lvl_kernel_preds, lvl_pos_indexes): + img_lvl_pos_kernel_preds = img_lvl_kernel_preds.view( + img_lvl_kernel_preds.shape[0], -1)[:, img_lvl_pos_indexes] + lvl_pos_kernel_preds.append(img_lvl_pos_kernel_preds) + mlvl_pos_kernel_preds.append(lvl_pos_kernel_preds) + + # make multilevel mlvl_mask_pred + mlvl_mask_preds = [] + for lvl_pos_kernel_preds in mlvl_pos_kernel_preds: + lvl_mask_preds = [] + for img_id, img_lvl_pos_kernel_pred in enumerate( + lvl_pos_kernel_preds): + if img_lvl_pos_kernel_pred.size()[-1] == 0: + continue + img_mask_feats = mask_feats[[img_id]] + h, w = img_mask_feats.shape[-2:] + num_kernel = img_lvl_pos_kernel_pred.shape[1] + img_lvl_mask_pred = F.conv2d( + img_mask_feats, + img_lvl_pos_kernel_pred.permute(1, 0).view( + num_kernel, -1, self.dynamic_conv_size, + self.dynamic_conv_size), + stride=1).view(-1, h, w) + lvl_mask_preds.append(img_lvl_mask_pred) + if len(lvl_mask_preds) == 0: + lvl_mask_preds = None + else: + lvl_mask_preds = torch.cat(lvl_mask_preds, 0) + mlvl_mask_preds.append(lvl_mask_preds) + # dice loss + num_pos = 0 + for img_pos_masks in pos_masks: + for lvl_img_pos_masks in img_pos_masks: + # Fix `Tensor` object has no attribute `count_nonzero()` + # in PyTorch 1.6, the type of `lvl_img_pos_masks` + # should be `torch.bool`. + num_pos += lvl_img_pos_masks.nonzero().numel() + loss_mask = [] + for lvl_mask_preds, lvl_mask_targets in zip(mlvl_mask_preds, + mlvl_mask_targets): + if lvl_mask_preds is None: + continue + loss_mask.append( + self.loss_mask( + lvl_mask_preds, + lvl_mask_targets, + reduction_override='none')) + if num_pos > 0: + loss_mask = torch.cat(loss_mask).sum() / num_pos + else: + loss_mask = mask_feats.sum() * 0 + + # cate + flatten_labels = [ + torch.cat( + [img_lvl_labels.flatten() for img_lvl_labels in lvl_labels]) + for lvl_labels in zip(*labels) + ] + flatten_labels = torch.cat(flatten_labels) + + flatten_cls_preds = [ + lvl_cls_preds.permute(0, 2, 3, 1).reshape(-1, self.num_classes) + for lvl_cls_preds in mlvl_cls_preds + ] + flatten_cls_preds = torch.cat(flatten_cls_preds) + + loss_cls = self.loss_cls( + flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1) + return dict(loss_mask=loss_mask, loss_cls=loss_cls) + + def predict_by_feat(self, mlvl_kernel_preds: List[Tensor], + mlvl_cls_scores: List[Tensor], mask_feats: Tensor, + batch_img_metas: List[dict], **kwargs) -> InstanceList: + """Transform a batch of output features extracted from the head into + mask results. + + Args: + mlvl_kernel_preds (list[Tensor]): Multi-level dynamic kernel + prediction. The kernel is used to generate instance + segmentation masks by dynamic convolution. Each element in the + list has shape + (batch_size, kernel_out_channels, num_grids, num_grids). 
+ mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element + in the list has shape + (batch_size, num_classes, num_grids, num_grids). + mask_feats (Tensor): Unified mask feature map used to generate + instance segmentation masks by dynamic convolution. Has shape + (batch_size, mask_out_channels, h, w). + batch_img_metas (list[dict]): Meta information of all images. + + Returns: + list[:obj:`InstanceData`]: Processed results of multiple + images.Each :obj:`InstanceData` usually contains + following keys. + + - scores (Tensor): Classification scores, has shape + (num_instance,). + - labels (Tensor): Has shape (num_instances,). + - masks (Tensor): Processed mask results, has + shape (num_instances, h, w). + """ + num_levels = len(mlvl_cls_scores) + assert len(mlvl_kernel_preds) == len(mlvl_cls_scores) + + for lvl in range(num_levels): + cls_scores = mlvl_cls_scores[lvl] + cls_scores = cls_scores.sigmoid() + local_max = F.max_pool2d(cls_scores, 2, stride=1, padding=1) + keep_mask = local_max[:, :, :-1, :-1] == cls_scores + cls_scores = cls_scores * keep_mask + mlvl_cls_scores[lvl] = cls_scores.permute(0, 2, 3, 1) + + result_list = [] + for img_id in range(len(batch_img_metas)): + img_cls_pred = [ + mlvl_cls_scores[lvl][img_id].view(-1, self.cls_out_channels) + for lvl in range(num_levels) + ] + img_mask_feats = mask_feats[[img_id]] + img_kernel_pred = [ + mlvl_kernel_preds[lvl][img_id].permute(1, 2, 0).view( + -1, self.kernel_out_channels) for lvl in range(num_levels) + ] + img_cls_pred = torch.cat(img_cls_pred, dim=0) + img_kernel_pred = torch.cat(img_kernel_pred, dim=0) + result = self._predict_by_feat_single( + img_kernel_pred, + img_cls_pred, + img_mask_feats, + img_meta=batch_img_metas[img_id]) + result_list.append(result) + return result_list + + def _predict_by_feat_single(self, + kernel_preds: Tensor, + cls_scores: Tensor, + mask_feats: Tensor, + img_meta: dict, + cfg: OptConfigType = None) -> InstanceData: + """Transform a single image's features extracted from the head into + mask results. + + Args: + kernel_preds (Tensor): Dynamic kernel prediction of all points + in single image, has shape + (num_points, kernel_out_channels). + cls_scores (Tensor): Classification score of all points + in single image, has shape (num_points, num_classes). + mask_feats (Tensor): Mask prediction of all points in + single image, has shape (num_points, feat_h, feat_w). + img_meta (dict): Meta information of corresponding image. + cfg (dict, optional): Config used in test phase. + Defaults to None. + + Returns: + :obj:`InstanceData`: Processed results of single image. + it usually contains following keys. + + - scores (Tensor): Classification scores, has shape + (num_instance,). + - labels (Tensor): Has shape (num_instances,). + - masks (Tensor): Processed mask results, has + shape (num_instances, h, w). + """ + + def empty_results(cls_scores, ori_shape): + """Generate a empty results.""" + results = InstanceData() + results.scores = cls_scores.new_ones(0) + results.masks = cls_scores.new_zeros(0, *ori_shape) + results.labels = cls_scores.new_ones(0) + results.bboxes = cls_scores.new_zeros(0, 4) + return results + + cfg = self.test_cfg if cfg is None else cfg + assert len(kernel_preds) == len(cls_scores) + + featmap_size = mask_feats.size()[-2:] + + # overall info + h, w = img_meta['img_shape'][:2] + upsampled_size = (featmap_size[0] * self.mask_stride, + featmap_size[1] * self.mask_stride) + + # process. 
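+        # The steps below mirror SOLO inference: threshold class scores,
+        # gather the matching dynamic kernels, convolve them over the
+        # shared mask feature map, drop masks smaller than their level's
+        # stride, rescore by maskness, then run matrix NMS.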
+ score_mask = (cls_scores > cfg.score_thr) + cls_scores = cls_scores[score_mask] + if len(cls_scores) == 0: + return empty_results(cls_scores, img_meta['ori_shape'][:2]) + + # cate_labels & kernel_preds + inds = score_mask.nonzero() + cls_labels = inds[:, 1] + kernel_preds = kernel_preds[inds[:, 0]] + + # trans vector. + lvl_interval = cls_labels.new_tensor(self.num_grids).pow(2).cumsum(0) + strides = kernel_preds.new_ones(lvl_interval[-1]) + + strides[:lvl_interval[0]] *= self.strides[0] + for lvl in range(1, self.num_levels): + strides[lvl_interval[lvl - + 1]:lvl_interval[lvl]] *= self.strides[lvl] + strides = strides[inds[:, 0]] + + # mask encoding. + kernel_preds = kernel_preds.view( + kernel_preds.size(0), -1, self.dynamic_conv_size, + self.dynamic_conv_size) + mask_preds = F.conv2d( + mask_feats, kernel_preds, stride=1).squeeze(0).sigmoid() + # mask. + masks = mask_preds > cfg.mask_thr + sum_masks = masks.sum((1, 2)).float() + keep = sum_masks > strides + if keep.sum() == 0: + return empty_results(cls_scores, img_meta['ori_shape'][:2]) + masks = masks[keep] + mask_preds = mask_preds[keep] + sum_masks = sum_masks[keep] + cls_scores = cls_scores[keep] + cls_labels = cls_labels[keep] + + # maskness. + mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks + cls_scores *= mask_scores + + scores, labels, _, keep_inds = mask_matrix_nms( + masks, + cls_labels, + cls_scores, + mask_area=sum_masks, + nms_pre=cfg.nms_pre, + max_num=cfg.max_per_img, + kernel=cfg.kernel, + sigma=cfg.sigma, + filter_thr=cfg.filter_thr) + if len(keep_inds) == 0: + return empty_results(cls_scores, img_meta['ori_shape'][:2]) + mask_preds = mask_preds[keep_inds] + mask_preds = F.interpolate( + mask_preds.unsqueeze(0), + size=upsampled_size, + mode='bilinear', + align_corners=False)[:, :, :h, :w] + mask_preds = F.interpolate( + mask_preds, + size=img_meta['ori_shape'][:2], + mode='bilinear', + align_corners=False).squeeze(0) + masks = mask_preds > cfg.mask_thr + + results = InstanceData() + results.masks = masks + results.labels = labels + results.scores = scores + # create an empty bbox in InstanceData to avoid bugs when + # calculating metrics. + results.bboxes = results.scores.new_zeros(len(scores), 4) + + return results diff --git a/mmdet/models/dense_heads/ssd_head.py b/mmdet/models/dense_heads/ssd_head.py new file mode 100644 index 0000000000000000000000000000000000000000..950df29110d914cc888bc16c6cbf1856f604a1de --- /dev/null +++ b/mmdet/models/dense_heads/ssd_head.py @@ -0,0 +1,362 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptInstanceList +from ..losses import smooth_l1_loss +from ..task_modules.samplers import PseudoSampler +from ..utils import multi_apply +from .anchor_head import AnchorHead + + +# TODO: add loss evaluator for SSD +@MODELS.register_module() +class SSDHead(AnchorHead): + """Implementation of `SSD head `_ + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (Sequence[int]): Number of channels in the input feature + map. + stacked_convs (int): Number of conv layers in cls and reg tower. + Defaults to 0. + feat_channels (int): Number of hidden channels when stacked_convs + > 0. Defaults to 256. 
+ use_depthwise (bool): Whether to use DepthwiseSeparableConv. + Defaults to False. + conv_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct + and config conv layer. Defaults to None. + norm_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct + and config norm layer. Defaults to None. + act_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct + and config activation layer. Defaults to None. + anchor_generator (:obj:`ConfigDict` or dict): Config dict for anchor + generator. + bbox_coder (:obj:`ConfigDict` or dict): Config of bounding box coder. + reg_decoded_bbox (bool): If true, the regression loss would be + applied directly on decoded bounding boxes, converting both + the predicted boxes and regression targets to absolute + coordinates format. Defaults to False. It should be `True` when + using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head. + train_cfg (:obj:`ConfigDict` or dict, Optional): Training config of + anchor head. + test_cfg (:obj:`ConfigDict` or dict, Optional): Testing config of + anchor head. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], Optional): Initialization config dict. + """ # noqa: W605 + + def __init__( + self, + num_classes: int = 80, + in_channels: Sequence[int] = (512, 1024, 512, 256, 256, 256), + stacked_convs: int = 0, + feat_channels: int = 256, + use_depthwise: bool = False, + conv_cfg: Optional[ConfigType] = None, + norm_cfg: Optional[ConfigType] = None, + act_cfg: Optional[ConfigType] = None, + anchor_generator: ConfigType = dict( + type='SSDAnchorGenerator', + scale_major=False, + input_size=300, + strides=[8, 16, 32, 64, 100, 300], + ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]), + basesize_ratio_range=(0.1, 0.9)), + bbox_coder: ConfigType = dict( + type='DeltaXYWHBBoxCoder', + clip_border=True, + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0], + ), + reg_decoded_bbox: bool = False, + train_cfg: Optional[ConfigType] = None, + test_cfg: Optional[ConfigType] = None, + init_cfg: MultiConfig = dict( + type='Xavier', layer='Conv2d', distribution='uniform', bias=0) + ) -> None: + super(AnchorHead, self).__init__(init_cfg=init_cfg) + self.num_classes = num_classes + self.in_channels = in_channels + self.stacked_convs = stacked_convs + self.feat_channels = feat_channels + self.use_depthwise = use_depthwise + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.cls_out_channels = num_classes + 1 # add background class + self.prior_generator = TASK_UTILS.build(anchor_generator) + + # Usually the numbers of anchors for each level are the same + # except SSD detectors. 
So it is an int in the most dense + # heads but a list of int in SSDHead + self.num_base_priors = self.prior_generator.num_base_priors + + self._init_layers() + + self.bbox_coder = TASK_UTILS.build(bbox_coder) + self.reg_decoded_bbox = reg_decoded_bbox + self.use_sigmoid_cls = False + self.cls_focal_loss = False + self.train_cfg = train_cfg + self.test_cfg = test_cfg + if self.train_cfg: + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + if self.train_cfg.get('sampler', None) is not None: + self.sampler = TASK_UTILS.build( + self.train_cfg['sampler'], default_args=dict(context=self)) + else: + self.sampler = PseudoSampler(context=self) + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + # TODO: Use registry to choose ConvModule type + conv = DepthwiseSeparableConvModule \ + if self.use_depthwise else ConvModule + + for channel, num_base_priors in zip(self.in_channels, + self.num_base_priors): + cls_layers = [] + reg_layers = [] + in_channel = channel + # build stacked conv tower, not used in default ssd + for i in range(self.stacked_convs): + cls_layers.append( + conv( + in_channel, + self.feat_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + reg_layers.append( + conv( + in_channel, + self.feat_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + in_channel = self.feat_channels + # SSD-Lite head + if self.use_depthwise: + cls_layers.append( + ConvModule( + in_channel, + in_channel, + 3, + padding=1, + groups=in_channel, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + reg_layers.append( + ConvModule( + in_channel, + in_channel, + 3, + padding=1, + groups=in_channel, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + cls_layers.append( + nn.Conv2d( + in_channel, + num_base_priors * self.cls_out_channels, + kernel_size=1 if self.use_depthwise else 3, + padding=0 if self.use_depthwise else 1)) + reg_layers.append( + nn.Conv2d( + in_channel, + num_base_priors * 4, + kernel_size=1 if self.use_depthwise else 3, + padding=0 if self.use_depthwise else 1)) + self.cls_convs.append(nn.Sequential(*cls_layers)) + self.reg_convs.append(nn.Sequential(*reg_layers)) + + def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor], List[Tensor]]: + """Forward features from the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple[list[Tensor], list[Tensor]]: A tuple of cls_scores list and + bbox_preds list. + + - cls_scores (list[Tensor]): Classification scores for all scale \ + levels, each is a 4D-tensor, the channels number is \ + num_anchors * num_classes. + - bbox_preds (list[Tensor]): Box energies / deltas for all scale \ + levels, each is a 4D-tensor, the channels number is \ + num_anchors * 4. + """ + cls_scores = [] + bbox_preds = [] + for feat, reg_conv, cls_conv in zip(x, self.reg_convs, self.cls_convs): + cls_scores.append(cls_conv(feat)) + bbox_preds.append(reg_conv(feat)) + return cls_scores, bbox_preds + + def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor, + anchor: Tensor, labels: Tensor, + label_weights: Tensor, bbox_targets: Tensor, + bbox_weights: Tensor, + avg_factor: int) -> Tuple[Tensor, Tensor]: + """Compute loss of a single image. 
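+
+        Classification uses online hard negative mining: per-anchor cross
+        entropy losses are computed without reduction, and only the hardest
+        negatives, at most ``neg_pos_ratio`` (from ``train_cfg``) times the
+        number of positives, are kept alongside all positive losses.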
+
+        Args:
+            cls_score (Tensor): Box scores for each image.
+                Has shape (num_total_anchors, num_classes).
+            bbox_pred (Tensor): Box energies / deltas for each image
+                with shape (num_total_anchors, 4).
+            anchor (Tensor): Box reference for each scale level with shape
+                (num_total_anchors, 4).
+            labels (Tensor): Labels of each anchor with shape
+                (num_total_anchors,).
+            label_weights (Tensor): Label weights of each anchor with shape
+                (num_total_anchors,).
+            bbox_targets (Tensor): BBox regression targets of each anchor with
+                shape (num_total_anchors, 4).
+            bbox_weights (Tensor): BBox regression loss weights of each anchor
+                with shape (num_total_anchors, 4).
+            avg_factor (int): Average factor that is used to average
+                the loss. When using sampling method, avg_factor is usually
+                the sum of positive and negative priors. When using
+                `PseudoSampler`, `avg_factor` is usually equal to the number
+                of positive priors.
+
+        Returns:
+            Tuple[Tensor, Tensor]: A tuple of cls loss and bbox loss of one
+            image.
+        """
+
+        loss_cls_all = F.cross_entropy(
+            cls_score, labels, reduction='none') * label_weights
+        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+        pos_inds = ((labels >= 0) & (labels < self.num_classes)).nonzero(
+            as_tuple=False).reshape(-1)
+        neg_inds = (labels == self.num_classes).nonzero(
+            as_tuple=False).view(-1)
+
+        num_pos_samples = pos_inds.size(0)
+        num_neg_samples = self.train_cfg['neg_pos_ratio'] * num_pos_samples
+        if num_neg_samples > neg_inds.size(0):
+            num_neg_samples = neg_inds.size(0)
+        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
+        loss_cls_pos = loss_cls_all[pos_inds].sum()
+        loss_cls_neg = topk_loss_cls_neg.sum()
+        loss_cls = (loss_cls_pos + loss_cls_neg) / avg_factor
+
+        if self.reg_decoded_bbox:
+            # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
+            # is applied directly on the decoded bounding boxes, it
+            # decodes the already encoded coordinates to absolute format.
+            bbox_pred = self.bbox_coder.decode(anchor, bbox_pred)
+
+        loss_bbox = smooth_l1_loss(
+            bbox_pred,
+            bbox_targets,
+            bbox_weights,
+            beta=self.train_cfg['smoothl1_beta'],
+            avg_factor=avg_factor)
+        return loss_cls[None], loss_bbox
+
+    def loss_by_feat(
+            self,
+            cls_scores: List[Tensor],
+            bbox_preds: List[Tensor],
+            batch_gt_instances: InstanceList,
+            batch_img_metas: List[dict],
+            batch_gt_instances_ignore: OptInstanceList = None
+    ) -> Dict[str, List[Tensor]]:
+        """Compute losses of the head.
+
+        Args:
+            cls_scores (list[Tensor]): Box scores for each scale level
+                Has shape (N, num_anchors * num_classes, H, W)
+            bbox_preds (list[Tensor]): Box energies / deltas for each scale
+                level with shape (N, num_anchors * 4, H, W)
+            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+                gt_instance. It usually includes ``bboxes`` and ``labels``
+                attributes.
+            batch_img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
+                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
+                Defaults to None.
+
+        Returns:
+            dict[str, list[Tensor]]: A dictionary of loss components. The dict
+            has components below:
+
+            - loss_cls (list[Tensor]): A list containing each image's \
+              classification loss.
+            - loss_bbox (list[Tensor]): A list containing each image's \
+              regression loss.
+ """ + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore, + unmap_outputs=True) + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + avg_factor) = cls_reg_targets + + num_images = len(batch_img_metas) + all_cls_scores = torch.cat([ + s.permute(0, 2, 3, 1).reshape( + num_images, -1, self.cls_out_channels) for s in cls_scores + ], 1) + all_labels = torch.cat(labels_list, -1).view(num_images, -1) + all_label_weights = torch.cat(label_weights_list, + -1).view(num_images, -1) + all_bbox_preds = torch.cat([ + b.permute(0, 2, 3, 1).reshape(num_images, -1, 4) + for b in bbox_preds + ], -2) + all_bbox_targets = torch.cat(bbox_targets_list, + -2).view(num_images, -1, 4) + all_bbox_weights = torch.cat(bbox_weights_list, + -2).view(num_images, -1, 4) + + # concat all level anchors to a single tensor + all_anchors = [] + for i in range(num_images): + all_anchors.append(torch.cat(anchor_list[i])) + + losses_cls, losses_bbox = multi_apply( + self.loss_by_feat_single, + all_cls_scores, + all_bbox_preds, + all_anchors, + all_labels, + all_label_weights, + all_bbox_targets, + all_bbox_weights, + avg_factor=avg_factor) + return dict(loss_cls=losses_cls, loss_bbox=losses_bbox) diff --git a/mmdet/models/dense_heads/tood_head.py b/mmdet/models/dense_heads/tood_head.py new file mode 100644 index 0000000000000000000000000000000000000000..8c59598d89289df6d1a87c7b6fde112429ac8f45 --- /dev/null +++ b/mmdet/models/dense_heads/tood_head.py @@ -0,0 +1,805 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, Scale +from mmcv.ops import deform_conv2d +from mmengine import MessageHub +from mmengine.config import ConfigDict +from mmengine.model import bias_init_with_prob, normal_init +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures.bbox import distance2bbox +from mmdet.utils import (ConfigType, InstanceList, OptConfigType, + OptInstanceList, reduce_mean) +from ..task_modules.prior_generators import anchor_inside_flags +from ..utils import (filter_scores_and_topk, images_to_levels, multi_apply, + sigmoid_geometric_mean, unmap) +from .atss_head import ATSSHead + + +class TaskDecomposition(nn.Module): + """Task decomposition module in task-aligned predictor of TOOD. + + Args: + feat_channels (int): Number of feature channels in TOOD head. + stacked_convs (int): Number of conv layers in TOOD head. + la_down_rate (int): Downsample rate of layer attention. + Defaults to 8. + conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + convolution layer. Defaults to None. + norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + normalization layer. Defaults to None. 
+ """ + + def __init__(self, + feat_channels: int, + stacked_convs: int, + la_down_rate: int = 8, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None) -> None: + super().__init__() + self.feat_channels = feat_channels + self.stacked_convs = stacked_convs + self.in_channels = self.feat_channels * self.stacked_convs + self.norm_cfg = norm_cfg + self.layer_attention = nn.Sequential( + nn.Conv2d(self.in_channels, self.in_channels // la_down_rate, 1), + nn.ReLU(inplace=True), + nn.Conv2d( + self.in_channels // la_down_rate, + self.stacked_convs, + 1, + padding=0), nn.Sigmoid()) + + self.reduction_conv = ConvModule( + self.in_channels, + self.feat_channels, + 1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + bias=norm_cfg is None) + + def init_weights(self) -> None: + """Initialize the parameters.""" + for m in self.layer_attention.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001) + normal_init(self.reduction_conv.conv, std=0.01) + + def forward(self, + feat: Tensor, + avg_feat: Optional[Tensor] = None) -> Tensor: + """Forward function of task decomposition module.""" + b, c, h, w = feat.shape + if avg_feat is None: + avg_feat = F.adaptive_avg_pool2d(feat, (1, 1)) + weight = self.layer_attention(avg_feat) + + # here we first compute the product between layer attention weight and + # conv weight, and then compute the convolution between new conv weight + # and feature map, in order to save memory and FLOPs. + conv_weight = weight.reshape( + b, 1, self.stacked_convs, + 1) * self.reduction_conv.conv.weight.reshape( + 1, self.feat_channels, self.stacked_convs, self.feat_channels) + conv_weight = conv_weight.reshape(b, self.feat_channels, + self.in_channels) + feat = feat.reshape(b, self.in_channels, h * w) + feat = torch.bmm(conv_weight, feat).reshape(b, self.feat_channels, h, + w) + if self.norm_cfg is not None: + feat = self.reduction_conv.norm(feat) + feat = self.reduction_conv.activate(feat) + + return feat + + +@MODELS.register_module() +class TOODHead(ATSSHead): + """TOODHead used in `TOOD: Task-aligned One-stage Object Detection. + + `_. + + TOOD uses Task-aligned head (T-head) and is optimized by Task Alignment + Learning (TAL). + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + num_dcn (int): Number of deformable convolution in the head. + Defaults to 0. + anchor_type (str): If set to ``anchor_free``, the head will use centers + to regress bboxes. If set to ``anchor_based``, the head will + regress bboxes based on anchors. Defaults to ``anchor_free``. + initial_loss_cls (:obj:`ConfigDict` or dict): Config of initial loss. 
+ + Example: + >>> self = TOODHead(11, 7) + >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]] + >>> cls_score, bbox_pred = self.forward(feats) + >>> assert len(cls_score) == len(self.scales) + """ + + def __init__(self, + num_classes: int, + in_channels: int, + num_dcn: int = 0, + anchor_type: str = 'anchor_free', + initial_loss_cls: ConfigType = dict( + type='FocalLoss', + use_sigmoid=True, + activated=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + **kwargs) -> None: + assert anchor_type in ['anchor_free', 'anchor_based'] + self.num_dcn = num_dcn + self.anchor_type = anchor_type + super().__init__( + num_classes=num_classes, in_channels=in_channels, **kwargs) + + if self.train_cfg: + self.initial_epoch = self.train_cfg['initial_epoch'] + self.initial_assigner = TASK_UTILS.build( + self.train_cfg['initial_assigner']) + self.initial_loss_cls = MODELS.build(initial_loss_cls) + self.assigner = self.initial_assigner + self.alignment_assigner = TASK_UTILS.build( + self.train_cfg['assigner']) + self.alpha = self.train_cfg['alpha'] + self.beta = self.train_cfg['beta'] + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self.relu = nn.ReLU(inplace=True) + self.inter_convs = nn.ModuleList() + for i in range(self.stacked_convs): + if i < self.num_dcn: + conv_cfg = dict(type='DCNv2', deform_groups=4) + else: + conv_cfg = self.conv_cfg + chn = self.in_channels if i == 0 else self.feat_channels + self.inter_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=self.norm_cfg)) + + self.cls_decomp = TaskDecomposition(self.feat_channels, + self.stacked_convs, + self.stacked_convs * 8, + self.conv_cfg, self.norm_cfg) + self.reg_decomp = TaskDecomposition(self.feat_channels, + self.stacked_convs, + self.stacked_convs * 8, + self.conv_cfg, self.norm_cfg) + + self.tood_cls = nn.Conv2d( + self.feat_channels, + self.num_base_priors * self.cls_out_channels, + 3, + padding=1) + self.tood_reg = nn.Conv2d( + self.feat_channels, self.num_base_priors * 4, 3, padding=1) + + self.cls_prob_module = nn.Sequential( + nn.Conv2d(self.feat_channels * self.stacked_convs, + self.feat_channels // 4, 1), nn.ReLU(inplace=True), + nn.Conv2d(self.feat_channels // 4, 1, 3, padding=1)) + self.reg_offset_module = nn.Sequential( + nn.Conv2d(self.feat_channels * self.stacked_convs, + self.feat_channels // 4, 1), nn.ReLU(inplace=True), + nn.Conv2d(self.feat_channels // 4, 4 * 2, 3, padding=1)) + + self.scales = nn.ModuleList( + [Scale(1.0) for _ in self.prior_generator.strides]) + + def init_weights(self) -> None: + """Initialize weights of the head.""" + bias_cls = bias_init_with_prob(0.01) + for m in self.inter_convs: + normal_init(m.conv, std=0.01) + for m in self.cls_prob_module: + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.01) + for m in self.reg_offset_module: + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001) + normal_init(self.cls_prob_module[-1], std=0.01, bias=bias_cls) + + self.cls_decomp.init_weights() + self.reg_decomp.init_weights() + + normal_init(self.tood_cls, std=0.01, bias=bias_cls) + normal_init(self.tood_reg, std=0.01) + + def forward(self, feats: Tuple[Tensor]) -> Tuple[List[Tensor]]: + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. 
+ + Returns: + tuple: Usually a tuple of classification scores and bbox prediction + cls_scores (list[Tensor]): Classification scores for all scale + levels, each is a 4D-tensor, the channels number is + num_anchors * num_classes. + bbox_preds (list[Tensor]): Decoded box for all scale levels, + each is a 4D-tensor, the channels number is + num_anchors * 4. In [tl_x, tl_y, br_x, br_y] format. + """ + cls_scores = [] + bbox_preds = [] + for idx, (x, scale, stride) in enumerate( + zip(feats, self.scales, self.prior_generator.strides)): + b, c, h, w = x.shape + anchor = self.prior_generator.single_level_grid_priors( + (h, w), idx, device=x.device) + anchor = torch.cat([anchor for _ in range(b)]) + # extract task interactive features + inter_feats = [] + for inter_conv in self.inter_convs: + x = inter_conv(x) + inter_feats.append(x) + feat = torch.cat(inter_feats, 1) + + # task decomposition + avg_feat = F.adaptive_avg_pool2d(feat, (1, 1)) + cls_feat = self.cls_decomp(feat, avg_feat) + reg_feat = self.reg_decomp(feat, avg_feat) + + # cls prediction and alignment + cls_logits = self.tood_cls(cls_feat) + cls_prob = self.cls_prob_module(feat) + cls_score = sigmoid_geometric_mean(cls_logits, cls_prob) + + # reg prediction and alignment + if self.anchor_type == 'anchor_free': + reg_dist = scale(self.tood_reg(reg_feat).exp()).float() + reg_dist = reg_dist.permute(0, 2, 3, 1).reshape(-1, 4) + reg_bbox = distance2bbox( + self.anchor_center(anchor) / stride[0], + reg_dist).reshape(b, h, w, 4).permute(0, 3, 1, + 2) # (b, c, h, w) + elif self.anchor_type == 'anchor_based': + reg_dist = scale(self.tood_reg(reg_feat)).float() + reg_dist = reg_dist.permute(0, 2, 3, 1).reshape(-1, 4) + reg_bbox = self.bbox_coder.decode(anchor, reg_dist).reshape( + b, h, w, 4).permute(0, 3, 1, 2) / stride[0] + else: + raise NotImplementedError( + f'Unknown anchor type: {self.anchor_type}.' + f'Please use `anchor_free` or `anchor_based`.') + reg_offset = self.reg_offset_module(feat) + bbox_pred = self.deform_sampling(reg_bbox.contiguous(), + reg_offset.contiguous()) + + # After deform_sampling, some boxes will become invalid (The + # left-top point is at the right or bottom of the right-bottom + # point), which will make the GIoULoss negative. + invalid_bbox_idx = (bbox_pred[:, [0]] > bbox_pred[:, [2]]) | \ + (bbox_pred[:, [1]] > bbox_pred[:, [3]]) + invalid_bbox_idx = invalid_bbox_idx.expand_as(bbox_pred) + bbox_pred = torch.where(invalid_bbox_idx, reg_bbox, bbox_pred) + + cls_scores.append(cls_score) + bbox_preds.append(bbox_pred) + return tuple(cls_scores), tuple(bbox_preds) + + def deform_sampling(self, feat: Tensor, offset: Tensor) -> Tensor: + """Sampling the feature x according to offset. + + Args: + feat (Tensor): Feature + offset (Tensor): Spatial offset for feature sampling + """ + # it is an equivalent implementation of bilinear interpolation + b, c, h, w = feat.shape + weight = feat.new_ones(c, 1, 1, 1) + y = deform_conv2d(feat, offset, weight, 1, 0, 1, c, c) + return y + + def anchor_center(self, anchors: Tensor) -> Tensor: + """Get anchor centers from anchors. + + Args: + anchors (Tensor): Anchor list with shape (N, 4), "xyxy" format. + + Returns: + Tensor: Anchor centers with shape (N, 2), "xy" format. 
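+
+        Example (illustrative; the center is simply the midpoint of the
+        two corners):
+            >>> import torch
+            >>> anchors = torch.tensor([[0., 0., 4., 4.], [2., 2., 6., 10.]])
+            >>> (anchors[:, :2] + anchors[:, 2:]) / 2
+            tensor([[2., 2.],
+                    [4., 6.]])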
+ """ + anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2 + anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2 + return torch.stack([anchors_cx, anchors_cy], dim=-1) + + def loss_by_feat_single(self, anchors: Tensor, cls_score: Tensor, + bbox_pred: Tensor, labels: Tensor, + label_weights: Tensor, bbox_targets: Tensor, + alignment_metrics: Tensor, + stride: Tuple[int, int]) -> dict: + """Calculate the loss of a single scale level based on the features + extracted by the detection head. + + Args: + anchors (Tensor): Box reference for each scale level with shape + (N, num_total_anchors, 4). + cls_score (Tensor): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W). + bbox_pred (Tensor): Decoded bboxes for each scale + level with shape (N, num_anchors * 4, H, W). + labels (Tensor): Labels of each anchors with shape + (N, num_total_anchors). + label_weights (Tensor): Label weights of each anchor with shape + (N, num_total_anchors). + bbox_targets (Tensor): BBox regression targets of each anchor with + shape (N, num_total_anchors, 4). + alignment_metrics (Tensor): Alignment metrics with shape + (N, num_total_anchors). + stride (Tuple[int, int]): Downsample stride of the feature map. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + assert stride[0] == stride[1], 'h stride is not equal to w stride!' + anchors = anchors.reshape(-1, 4) + cls_score = cls_score.permute(0, 2, 3, 1).reshape( + -1, self.cls_out_channels).contiguous() + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) + bbox_targets = bbox_targets.reshape(-1, 4) + labels = labels.reshape(-1) + alignment_metrics = alignment_metrics.reshape(-1) + label_weights = label_weights.reshape(-1) + targets = labels if self.epoch < self.initial_epoch else ( + labels, alignment_metrics) + cls_loss_func = self.initial_loss_cls \ + if self.epoch < self.initial_epoch else self.loss_cls + + loss_cls = cls_loss_func( + cls_score, targets, label_weights, avg_factor=1.0) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().squeeze(1) + + if len(pos_inds) > 0: + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + pos_anchors = anchors[pos_inds] + + pos_decode_bbox_pred = pos_bbox_pred + pos_decode_bbox_targets = pos_bbox_targets / stride[0] + + # regression loss + pos_bbox_weight = self.centerness_target( + pos_anchors, pos_bbox_targets + ) if self.epoch < self.initial_epoch else alignment_metrics[ + pos_inds] + + loss_bbox = self.loss_bbox( + pos_decode_bbox_pred, + pos_decode_bbox_targets, + weight=pos_bbox_weight, + avg_factor=1.0) + else: + loss_bbox = bbox_pred.sum() * 0 + pos_bbox_weight = bbox_targets.new_tensor(0.) + + return loss_cls, loss_bbox, alignment_metrics.sum( + ), pos_bbox_weight.sum() + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Decoded box for each scale + level with shape (N, num_anchors * 4, H, W) in + [tl_x, tl_y, br_x, br_y] format. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. 
It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + num_imgs = len(batch_img_metas) + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + + flatten_cls_scores = torch.cat([ + cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, + self.cls_out_channels) + for cls_score in cls_scores + ], 1) + flatten_bbox_preds = torch.cat([ + bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) * stride[0] + for bbox_pred, stride in zip(bbox_preds, + self.prior_generator.strides) + ], 1) + + cls_reg_targets = self.get_targets( + flatten_cls_scores, + flatten_bbox_preds, + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + (anchor_list, labels_list, label_weights_list, bbox_targets_list, + alignment_metrics_list) = cls_reg_targets + + losses_cls, losses_bbox, \ + cls_avg_factors, bbox_avg_factors = multi_apply( + self.loss_by_feat_single, + anchor_list, + cls_scores, + bbox_preds, + labels_list, + label_weights_list, + bbox_targets_list, + alignment_metrics_list, + self.prior_generator.strides) + + cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item() + losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls)) + + bbox_avg_factor = reduce_mean( + sum(bbox_avg_factors)).clamp_(min=1).item() + losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox)) + return dict(loss_cls=losses_cls, loss_bbox=losses_bbox) + + def _predict_by_feat_single(self, + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + score_factor_list: List[Tensor], + mlvl_priors: List[Tensor], + img_meta: dict, + cfg: Optional[ConfigDict] = None, + rescale: bool = False, + with_nms: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image, each item has shape + (num_priors * 1, H, W). + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid. In all + anchor-based methods, it has shape (num_priors, 4). In + all anchor-free methods, it has shape (num_priors, 2) + when `with_stride=True`, otherwise it still has shape + (num_priors, 4). + img_meta (dict): Image meta info. + cfg (:obj:`ConfigDict`, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + tuple[Tensor]: Results of detected bboxes and labels. 
If with_nms
+            is False and mlvl_score_factor is None, return mlvl_bboxes and
+            mlvl_scores, else return mlvl_bboxes, mlvl_scores and
+            mlvl_score_factor. Usually ``with_nms=False`` is used for aug
+            test. If with_nms is True, the following format is returned
+
+            - det_bboxes (Tensor): Predicted bboxes with shape \
+              [num_bboxes, 5], where the first 4 columns are bounding \
+              box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
+              column contains scores between 0 and 1.
+            - det_labels (Tensor): Predicted labels of the corresponding \
+              box with shape [num_bboxes].
+        """
+
+        cfg = self.test_cfg if cfg is None else cfg
+        nms_pre = cfg.get('nms_pre', -1)
+
+        mlvl_bboxes = []
+        mlvl_scores = []
+        mlvl_labels = []
+        for cls_score, bbox_pred, priors, stride in zip(
+                cls_score_list, bbox_pred_list, mlvl_priors,
+                self.prior_generator.strides):
+            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+
+            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) * stride[0]
+            scores = cls_score.permute(1, 2,
+                                       0).reshape(-1, self.cls_out_channels)
+
+            # After https://github.com/open-mmlab/mmdetection/pull/6268/,
+            # this operation keeps fewer bboxes under the same `nms_pre`.
+            # There is no difference in performance for most models. If you
+            # find a slight drop in performance, you can set a larger
+            # `nms_pre` than before.
+            results = filter_scores_and_topk(
+                scores, cfg.score_thr, nms_pre,
+                dict(bbox_pred=bbox_pred, priors=priors))
+            scores, labels, keep_idxs, filtered_results = results
+
+            bboxes = filtered_results['bbox_pred']
+
+            mlvl_bboxes.append(bboxes)
+            mlvl_scores.append(scores)
+            mlvl_labels.append(labels)
+
+        results = InstanceData()
+        results.bboxes = torch.cat(mlvl_bboxes)
+        results.scores = torch.cat(mlvl_scores)
+        results.labels = torch.cat(mlvl_labels)
+
+        return self._bbox_post_process(
+            results=results,
+            cfg=cfg,
+            rescale=rescale,
+            with_nms=with_nms,
+            img_meta=img_meta)
+
+    def get_targets(self,
+                    cls_scores: List[List[Tensor]],
+                    bbox_preds: List[List[Tensor]],
+                    anchor_list: List[List[Tensor]],
+                    valid_flag_list: List[List[Tensor]],
+                    batch_gt_instances: InstanceList,
+                    batch_img_metas: List[dict],
+                    batch_gt_instances_ignore: OptInstanceList = None,
+                    unmap_outputs: bool = True) -> tuple:
+        """Compute regression and classification targets for anchors in
+        multiple images.
+
+        Args:
+            cls_scores (list[list[Tensor]]): Classification predictions of
+                images, a 3D-Tensor with shape [num_imgs, num_priors,
+                num_classes].
+            bbox_preds (list[list[Tensor]]): Decoded bboxes predictions of one
+                image, a 3D-Tensor with shape [num_imgs, num_priors, 4] in
+                [tl_x, tl_y, br_x, br_y] format.
+            anchor_list (list[list[Tensor]]): Multi level anchors of each
+                image. The outer list indicates images, and the inner list
+                corresponds to feature levels of the image. Each element of
+                the inner list is a tensor of shape (num_anchors, 4).
+            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
+                each image. The outer list indicates images, and the inner list
+                corresponds to feature levels of the image. Each element of
+                the inner list is a tensor of shape (num_anchors,).
+            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+                gt_instance. It usually includes ``bboxes`` and ``labels``
+                attributes.
+            batch_img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
+                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
+ Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Returns: + tuple: a tuple containing learning targets. + + - anchors_list (list[list[Tensor]]): Anchors of each level. + - labels_list (list[Tensor]): Labels of each level. + - label_weights_list (list[Tensor]): Label weights of each + level. + - bbox_targets_list (list[Tensor]): BBox targets of each level. + - norm_alignment_metrics_list (list[Tensor]): Normalized + alignment metrics of each level. + """ + num_imgs = len(batch_img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + num_level_anchors_list = [num_level_anchors] * num_imgs + + # concat all level anchors and flags to a single tensor + for i in range(num_imgs): + assert len(anchor_list[i]) == len(valid_flag_list[i]) + anchor_list[i] = torch.cat(anchor_list[i]) + valid_flag_list[i] = torch.cat(valid_flag_list[i]) + + # compute targets for each image + if batch_gt_instances_ignore is None: + batch_gt_instances_ignore = [None] * num_imgs + # anchor_list: list(b * [-1, 4]) + + # get epoch information from message hub + message_hub = MessageHub.get_current_instance() + self.epoch = message_hub.get_info('epoch') + + if self.epoch < self.initial_epoch: + (all_anchors, all_labels, all_label_weights, all_bbox_targets, + all_bbox_weights, pos_inds_list, neg_inds_list, + sampling_result) = multi_apply( + super()._get_targets_single, + anchor_list, + valid_flag_list, + num_level_anchors_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore, + unmap_outputs=unmap_outputs) + all_assign_metrics = [ + weight[..., 0] for weight in all_bbox_weights + ] + else: + (all_anchors, all_labels, all_label_weights, all_bbox_targets, + all_assign_metrics) = multi_apply( + self._get_targets_single, + cls_scores, + bbox_preds, + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore, + unmap_outputs=unmap_outputs) + + # split targets to a list w.r.t. multiple levels + anchors_list = images_to_levels(all_anchors, num_level_anchors) + labels_list = images_to_levels(all_labels, num_level_anchors) + label_weights_list = images_to_levels(all_label_weights, + num_level_anchors) + bbox_targets_list = images_to_levels(all_bbox_targets, + num_level_anchors) + norm_alignment_metrics_list = images_to_levels(all_assign_metrics, + num_level_anchors) + + return (anchors_list, labels_list, label_weights_list, + bbox_targets_list, norm_alignment_metrics_list) + + def _get_targets_single(self, + cls_scores: Tensor, + bbox_preds: Tensor, + flat_anchors: Tensor, + valid_flags: Tensor, + gt_instances: InstanceData, + img_meta: dict, + gt_instances_ignore: Optional[InstanceData] = None, + unmap_outputs: bool = True) -> tuple: + """Compute regression, classification targets for anchors in a single + image. + + Args: + cls_scores (Tensor): Box scores for each image. + bbox_preds (Tensor): Box energies / deltas for each image. + flat_anchors (Tensor): Multi-level anchors of the image, which are + concatenated into a single tensor of shape (num_anchors ,4) + valid_flags (Tensor): Multi level valid flags of the image, + which are concatenated into a single tensor of + shape (num_anchors,). + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for current image. 
+ gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Returns: + tuple: N is the number of total anchors in the image. + anchors (Tensor): All anchors in the image with shape (N, 4). + labels (Tensor): Labels of all anchors in the image with shape + (N,). + label_weights (Tensor): Label weights of all anchor in the + image with shape (N,). + bbox_targets (Tensor): BBox targets of all anchors in the + image with shape (N, 4). + norm_alignment_metrics (Tensor): Normalized alignment metrics + of all priors in the image with shape (N,). + """ + inside_flags = anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], + self.train_cfg['allowed_border']) + if not inside_flags.any(): + raise ValueError( + 'There is no valid anchor inside the image boundary. Please ' + 'check the image size and anchor sizes, or set ' + '``allowed_border`` to -1 to skip the condition.') + # assign gt and sample anchors + anchors = flat_anchors[inside_flags, :] + pred_instances = InstanceData( + priors=anchors, + scores=cls_scores[inside_flags, :], + bboxes=bbox_preds[inside_flags, :]) + assign_result = self.alignment_assigner.assign(pred_instances, + gt_instances, + gt_instances_ignore, + self.alpha, self.beta) + assign_ious = assign_result.max_overlaps + assign_metrics = assign_result.assign_metrics + + sampling_result = self.sampler.sample(assign_result, pred_instances, + gt_instances) + + num_valid_anchors = anchors.shape[0] + bbox_targets = torch.zeros_like(anchors) + labels = anchors.new_full((num_valid_anchors, ), + self.num_classes, + dtype=torch.long) + label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) + norm_alignment_metrics = anchors.new_zeros( + num_valid_anchors, dtype=torch.float) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + # point-based + pos_bbox_targets = sampling_result.pos_gt_bboxes + bbox_targets[pos_inds, :] = pos_bbox_targets + + labels[pos_inds] = sampling_result.pos_gt_labels + if self.train_cfg['pos_weight'] <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.train_cfg['pos_weight'] + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + class_assigned_gt_inds = torch.unique( + sampling_result.pos_assigned_gt_inds) + for gt_inds in class_assigned_gt_inds: + gt_class_inds = pos_inds[sampling_result.pos_assigned_gt_inds == + gt_inds] + pos_alignment_metrics = assign_metrics[gt_class_inds] + pos_ious = assign_ious[gt_class_inds] + pos_norm_alignment_metrics = pos_alignment_metrics / ( + pos_alignment_metrics.max() + 10e-8) * pos_ious.max() + norm_alignment_metrics[gt_class_inds] = pos_norm_alignment_metrics + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_anchors.size(0) + anchors = unmap(anchors, num_total_anchors, inside_flags) + labels = unmap( + labels, num_total_anchors, inside_flags, fill=self.num_classes) + label_weights = unmap(label_weights, num_total_anchors, + inside_flags) + bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) + norm_alignment_metrics = unmap(norm_alignment_metrics, + num_total_anchors, inside_flags) + return (anchors, labels, label_weights, bbox_targets, + norm_alignment_metrics) diff --git a/mmdet/models/dense_heads/vfnet_head.py 
b/mmdet/models/dense_heads/vfnet_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..430b06d085d94760d56a7ea083eaf23bd32b1f53
--- /dev/null
+++ b/mmdet/models/dense_heads/vfnet_head.py
@@ -0,0 +1,722 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, Scale
+from mmcv.ops import DeformConv2d
+from torch import Tensor
+
+from mmdet.registry import MODELS, TASK_UTILS
+from mmdet.structures.bbox import bbox_overlaps
+from mmdet.utils import (ConfigType, InstanceList, MultiConfig,
+                         OptInstanceList, RangeType, reduce_mean)
+from ..task_modules.prior_generators import MlvlPointGenerator
+from ..task_modules.samplers import PseudoSampler
+from ..utils import multi_apply
+from .atss_head import ATSSHead
+from .fcos_head import FCOSHead
+
+INF = 1e8
+
+
+@MODELS.register_module()
+class VFNetHead(ATSSHead, FCOSHead):
+    """Head of `VarifocalNet (VFNet): An IoU-aware Dense Object
+    Detector <https://arxiv.org/abs/2008.13367>`_.
+
+    The VFNet predicts IoU-aware classification scores which mix the
+    object presence confidence and object localization accuracy as the
+    detection score. It is built on the FCOS architecture and uses ATSS
+    for defining positive/negative training examples. The VFNet is trained
+    with Varifocal Loss and employs star-shaped deformable convolution to
+    extract features for a bbox.
+
+    Args:
+        num_classes (int): Number of categories excluding the background
+            category.
+        in_channels (int): Number of channels in the input feature map.
+        regress_ranges (Sequence[Tuple[int, int]]): Regress range of multiple
+            level points.
+        center_sampling (bool): If true, use center sampling. Defaults to False.
+        center_sample_radius (float): Radius of center sampling. Defaults to 1.5.
+        sync_num_pos (bool): If true, synchronize the number of positive
+            examples across GPUs. Defaults to True.
+        gradient_mul (float): The multiplier to gradients from bbox refinement
+            and recognition. Defaults to 0.1.
+        bbox_norm_type (str): The bbox normalization type, 'reg_denom' or
+            'stride'. Defaults to 'reg_denom'.
+        loss_cls_fl (:obj:`ConfigDict` or dict): Config of focal loss.
+        use_vfl (bool): If true, use varifocal loss for training.
+            Defaults to True.
+        loss_cls (:obj:`ConfigDict` or dict): Config of varifocal loss.
+        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss,
+            GIoU Loss.
+        loss_bbox_refine (:obj:`ConfigDict` or dict): Config of localization
+            refinement loss, GIoU Loss.
+        norm_cfg (:obj:`ConfigDict` or dict): dictionary to construct and
+            config norm layer. Defaults to norm_cfg=dict(type='GN',
+            num_groups=32, requires_grad=True).
+        use_atss (bool): If true, use ATSS to define positive/negative
+            examples. Defaults to True.
+        anchor_generator (:obj:`ConfigDict` or dict): Config of anchor
+            generator for ATSS.
+        init_cfg (:obj:`ConfigDict` or dict or list[dict] or
+            list[:obj:`ConfigDict`]): Initialization config dict.
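+
+    With ``bbox_norm_type='reg_denom'`` the raw bbox prediction of each level
+    is normalized by that level's regression range, while ``'stride'``
+    normalizes it by the feature-map stride (see ``forward_single``).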
+ + Example: + >>> self = VFNetHead(11, 7) + >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]] + >>> cls_score, bbox_pred, bbox_pred_refine= self.forward(feats) + >>> assert len(cls_score) == len(self.scales) + """ # noqa: E501 + + def __init__(self, + num_classes: int, + in_channels: int, + regress_ranges: RangeType = ((-1, 64), (64, 128), (128, 256), + (256, 512), (512, INF)), + center_sampling: bool = False, + center_sample_radius: float = 1.5, + sync_num_pos: bool = True, + gradient_mul: float = 0.1, + bbox_norm_type: str = 'reg_denom', + loss_cls_fl: ConfigType = dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + use_vfl: bool = True, + loss_cls: ConfigType = dict( + type='VarifocalLoss', + use_sigmoid=True, + alpha=0.75, + gamma=2.0, + iou_weighted=True, + loss_weight=1.0), + loss_bbox: ConfigType = dict( + type='GIoULoss', loss_weight=1.5), + loss_bbox_refine: ConfigType = dict( + type='GIoULoss', loss_weight=2.0), + norm_cfg: ConfigType = dict( + type='GN', num_groups=32, requires_grad=True), + use_atss: bool = True, + reg_decoded_bbox: bool = True, + anchor_generator: ConfigType = dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + center_offset=0.0, + strides=[8, 16, 32, 64, 128]), + init_cfg: MultiConfig = dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', + name='vfnet_cls', + std=0.01, + bias_prob=0.01)), + **kwargs) -> None: + # dcn base offsets, adapted from reppoints_head.py + self.num_dconv_points = 9 + self.dcn_kernel = int(np.sqrt(self.num_dconv_points)) + self.dcn_pad = int((self.dcn_kernel - 1) / 2) + dcn_base = np.arange(-self.dcn_pad, + self.dcn_pad + 1).astype(np.float64) + dcn_base_y = np.repeat(dcn_base, self.dcn_kernel) + dcn_base_x = np.tile(dcn_base, self.dcn_kernel) + dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape( + (-1)) + self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1) + + super(FCOSHead, self).__init__( + num_classes=num_classes, + in_channels=in_channels, + norm_cfg=norm_cfg, + init_cfg=init_cfg, + **kwargs) + self.regress_ranges = regress_ranges + self.reg_denoms = [ + regress_range[-1] for regress_range in regress_ranges + ] + self.reg_denoms[-1] = self.reg_denoms[-2] * 2 + self.center_sampling = center_sampling + self.center_sample_radius = center_sample_radius + self.sync_num_pos = sync_num_pos + self.bbox_norm_type = bbox_norm_type + self.gradient_mul = gradient_mul + self.use_vfl = use_vfl + if self.use_vfl: + self.loss_cls = MODELS.build(loss_cls) + else: + self.loss_cls = MODELS.build(loss_cls_fl) + self.loss_bbox = MODELS.build(loss_bbox) + self.loss_bbox_refine = MODELS.build(loss_bbox_refine) + + # for getting ATSS targets + self.use_atss = use_atss + self.reg_decoded_bbox = reg_decoded_bbox + self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) + + self.anchor_center_offset = anchor_generator['center_offset'] + + self.num_base_priors = self.prior_generator.num_base_priors[0] + + if self.train_cfg: + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + if self.train_cfg.get('sampler', None) is not None: + self.sampler = TASK_UTILS.build( + self.train_cfg['sampler'], default_args=dict(context=self)) + else: + self.sampler = PseudoSampler() + # only be used in `get_atss_targets` when `use_atss` is True + self.atss_prior_generator = TASK_UTILS.build(anchor_generator) + + self.fcos_prior_generator = MlvlPointGenerator( + anchor_generator['strides'], + 
self.anchor_center_offset if self.use_atss else 0.5) + + # In order to reuse the `get_bboxes` in `BaseDenseHead. + # Only be used in testing phase. + self.prior_generator = self.fcos_prior_generator + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + super(FCOSHead, self)._init_cls_convs() + super(FCOSHead, self)._init_reg_convs() + self.relu = nn.ReLU() + self.vfnet_reg_conv = ConvModule( + self.feat_channels, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + bias=self.conv_bias) + self.vfnet_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1) + self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides]) + + self.vfnet_reg_refine_dconv = DeformConv2d( + self.feat_channels, + self.feat_channels, + self.dcn_kernel, + 1, + padding=self.dcn_pad) + self.vfnet_reg_refine = nn.Conv2d(self.feat_channels, 4, 3, padding=1) + self.scales_refine = nn.ModuleList([Scale(1.0) for _ in self.strides]) + + self.vfnet_cls_dconv = DeformConv2d( + self.feat_channels, + self.feat_channels, + self.dcn_kernel, + 1, + padding=self.dcn_pad) + self.vfnet_cls = nn.Conv2d( + self.feat_channels, self.cls_out_channels, 3, padding=1) + + def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor]]: + """Forward features from the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: + + - cls_scores (list[Tensor]): Box iou-aware scores for each scale + level, each is a 4D-tensor, the channel number is + num_points * num_classes. + - bbox_preds (list[Tensor]): Box offsets for each + scale level, each is a 4D-tensor, the channel number is + num_points * 4. + - bbox_preds_refine (list[Tensor]): Refined Box offsets for + each scale level, each is a 4D-tensor, the channel + number is num_points * 4. + """ + return multi_apply(self.forward_single, x, self.scales, + self.scales_refine, self.strides, self.reg_denoms) + + def forward_single(self, x: Tensor, scale: Scale, scale_refine: Scale, + stride: int, reg_denom: int) -> tuple: + """Forward features of a single scale level. + + Args: + x (Tensor): FPN feature maps of the specified stride. + scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize + the bbox prediction. + scale_refine (:obj: `mmcv.cnn.Scale`): Learnable scale module to + resize the refined bbox prediction. + stride (int): The corresponding stride for feature maps, + used to normalize the bbox prediction when + bbox_norm_type = 'stride'. + reg_denom (int): The corresponding regression range for feature + maps, only used to normalize the bbox prediction when + bbox_norm_type = 'reg_denom'. + + Returns: + tuple: iou-aware cls scores for each box, bbox predictions and + refined bbox predictions of input feature maps. 
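+
+            During training all three outputs are returned; at test time
+            only the iou-aware cls scores and the refined bbox predictions
+            are returned.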
+ """ + cls_feat = x + reg_feat = x + + for cls_layer in self.cls_convs: + cls_feat = cls_layer(cls_feat) + + for reg_layer in self.reg_convs: + reg_feat = reg_layer(reg_feat) + + # predict the bbox_pred of different level + reg_feat_init = self.vfnet_reg_conv(reg_feat) + if self.bbox_norm_type == 'reg_denom': + bbox_pred = scale( + self.vfnet_reg(reg_feat_init)).float().exp() * reg_denom + elif self.bbox_norm_type == 'stride': + bbox_pred = scale( + self.vfnet_reg(reg_feat_init)).float().exp() * stride + else: + raise NotImplementedError + + # compute star deformable convolution offsets + # converting dcn_offset to reg_feat.dtype thus VFNet can be + # trained with FP16 + dcn_offset = self.star_dcn_offset(bbox_pred, self.gradient_mul, + stride).to(reg_feat.dtype) + + # refine the bbox_pred + reg_feat = self.relu(self.vfnet_reg_refine_dconv(reg_feat, dcn_offset)) + bbox_pred_refine = scale_refine( + self.vfnet_reg_refine(reg_feat)).float().exp() + bbox_pred_refine = bbox_pred_refine * bbox_pred.detach() + + # predict the iou-aware cls score + cls_feat = self.relu(self.vfnet_cls_dconv(cls_feat, dcn_offset)) + cls_score = self.vfnet_cls(cls_feat) + + if self.training: + return cls_score, bbox_pred, bbox_pred_refine + else: + return cls_score, bbox_pred_refine + + def star_dcn_offset(self, bbox_pred: Tensor, gradient_mul: float, + stride: int) -> Tensor: + """Compute the star deformable conv offsets. + + Args: + bbox_pred (Tensor): Predicted bbox distance offsets (l, r, t, b). + gradient_mul (float): Gradient multiplier. + stride (int): The corresponding stride for feature maps, + used to project the bbox onto the feature map. + + Returns: + Tensor: The offsets for deformable convolution. + """ + dcn_base_offset = self.dcn_base_offset.type_as(bbox_pred) + bbox_pred_grad_mul = (1 - gradient_mul) * bbox_pred.detach() + \ + gradient_mul * bbox_pred + # map to the feature map scale + bbox_pred_grad_mul = bbox_pred_grad_mul / stride + N, C, H, W = bbox_pred.size() + + x1 = bbox_pred_grad_mul[:, 0, :, :] + y1 = bbox_pred_grad_mul[:, 1, :, :] + x2 = bbox_pred_grad_mul[:, 2, :, :] + y2 = bbox_pred_grad_mul[:, 3, :, :] + bbox_pred_grad_mul_offset = bbox_pred.new_zeros( + N, 2 * self.num_dconv_points, H, W) + bbox_pred_grad_mul_offset[:, 0, :, :] = -1.0 * y1 # -y1 + bbox_pred_grad_mul_offset[:, 1, :, :] = -1.0 * x1 # -x1 + bbox_pred_grad_mul_offset[:, 2, :, :] = -1.0 * y1 # -y1 + bbox_pred_grad_mul_offset[:, 4, :, :] = -1.0 * y1 # -y1 + bbox_pred_grad_mul_offset[:, 5, :, :] = x2 # x2 + bbox_pred_grad_mul_offset[:, 7, :, :] = -1.0 * x1 # -x1 + bbox_pred_grad_mul_offset[:, 11, :, :] = x2 # x2 + bbox_pred_grad_mul_offset[:, 12, :, :] = y2 # y2 + bbox_pred_grad_mul_offset[:, 13, :, :] = -1.0 * x1 # -x1 + bbox_pred_grad_mul_offset[:, 14, :, :] = y2 # y2 + bbox_pred_grad_mul_offset[:, 16, :, :] = y2 # y2 + bbox_pred_grad_mul_offset[:, 17, :, :] = x2 # x2 + dcn_offset = bbox_pred_grad_mul_offset - dcn_base_offset + + return dcn_offset + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + bbox_preds_refine: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Compute loss of the head. + + Args: + cls_scores (list[Tensor]): Box iou-aware scores for each scale + level, each is a 4D-tensor, the channel number is + num_points * num_classes. + bbox_preds (list[Tensor]): Box offsets for each + scale level, each is a 4D-tensor, the channel number is + num_points * 4. 
+ bbox_preds_refine (list[Tensor]): Refined Box offsets for + each scale level, each is a 4D-tensor, the channel + number is num_points * 4. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + assert len(cls_scores) == len(bbox_preds) == len(bbox_preds_refine) + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + all_level_points = self.fcos_prior_generator.grid_priors( + featmap_sizes, bbox_preds[0].dtype, bbox_preds[0].device) + labels, label_weights, bbox_targets, bbox_weights = self.get_targets( + cls_scores, + all_level_points, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + + num_imgs = cls_scores[0].size(0) + # flatten cls_scores, bbox_preds and bbox_preds_refine + flatten_cls_scores = [ + cls_score.permute(0, 2, 3, + 1).reshape(-1, + self.cls_out_channels).contiguous() + for cls_score in cls_scores + ] + flatten_bbox_preds = [ + bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4).contiguous() + for bbox_pred in bbox_preds + ] + flatten_bbox_preds_refine = [ + bbox_pred_refine.permute(0, 2, 3, 1).reshape(-1, 4).contiguous() + for bbox_pred_refine in bbox_preds_refine + ] + flatten_cls_scores = torch.cat(flatten_cls_scores) + flatten_bbox_preds = torch.cat(flatten_bbox_preds) + flatten_bbox_preds_refine = torch.cat(flatten_bbox_preds_refine) + flatten_labels = torch.cat(labels) + flatten_bbox_targets = torch.cat(bbox_targets) + # repeat points to align with bbox_preds + flatten_points = torch.cat( + [points.repeat(num_imgs, 1) for points in all_level_points]) + + # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = torch.where( + ((flatten_labels >= 0) & (flatten_labels < bg_class_ind)) > 0)[0] + num_pos = len(pos_inds) + + pos_bbox_preds = flatten_bbox_preds[pos_inds] + pos_bbox_preds_refine = flatten_bbox_preds_refine[pos_inds] + pos_labels = flatten_labels[pos_inds] + + # sync num_pos across all gpus + if self.sync_num_pos: + num_pos_avg_per_gpu = reduce_mean( + pos_inds.new_tensor(num_pos).float()).item() + num_pos_avg_per_gpu = max(num_pos_avg_per_gpu, 1.0) + else: + num_pos_avg_per_gpu = num_pos + + pos_bbox_targets = flatten_bbox_targets[pos_inds] + pos_points = flatten_points[pos_inds] + + pos_decoded_bbox_preds = self.bbox_coder.decode( + pos_points, pos_bbox_preds) + pos_decoded_target_preds = self.bbox_coder.decode( + pos_points, pos_bbox_targets) + iou_targets_ini = bbox_overlaps( + pos_decoded_bbox_preds, + pos_decoded_target_preds.detach(), + is_aligned=True).clamp(min=1e-6) + bbox_weights_ini = iou_targets_ini.clone().detach() + bbox_avg_factor_ini = reduce_mean( + bbox_weights_ini.sum()).clamp_(min=1).item() + + pos_decoded_bbox_preds_refine = \ + self.bbox_coder.decode(pos_points, pos_bbox_preds_refine) + iou_targets_rf = bbox_overlaps( + pos_decoded_bbox_preds_refine, + pos_decoded_target_preds.detach(), + is_aligned=True).clamp(min=1e-6) + bbox_weights_rf = iou_targets_rf.clone().detach() + bbox_avg_factor_rf = reduce_mean( + bbox_weights_rf.sum()).clamp_(min=1).item() + + if num_pos > 0: + loss_bbox 
= self.loss_bbox( + pos_decoded_bbox_preds, + pos_decoded_target_preds.detach(), + weight=bbox_weights_ini, + avg_factor=bbox_avg_factor_ini) + + loss_bbox_refine = self.loss_bbox_refine( + pos_decoded_bbox_preds_refine, + pos_decoded_target_preds.detach(), + weight=bbox_weights_rf, + avg_factor=bbox_avg_factor_rf) + + # build IoU-aware cls_score targets + if self.use_vfl: + pos_ious = iou_targets_rf.clone().detach() + cls_iou_targets = torch.zeros_like(flatten_cls_scores) + cls_iou_targets[pos_inds, pos_labels] = pos_ious + else: + loss_bbox = pos_bbox_preds.sum() * 0 + loss_bbox_refine = pos_bbox_preds_refine.sum() * 0 + if self.use_vfl: + cls_iou_targets = torch.zeros_like(flatten_cls_scores) + + if self.use_vfl: + loss_cls = self.loss_cls( + flatten_cls_scores, + cls_iou_targets, + avg_factor=num_pos_avg_per_gpu) + else: + loss_cls = self.loss_cls( + flatten_cls_scores, + flatten_labels, + weight=label_weights, + avg_factor=num_pos_avg_per_gpu) + + return dict( + loss_cls=loss_cls, + loss_bbox=loss_bbox, + loss_bbox_rf=loss_bbox_refine) + + def get_targets( + self, + cls_scores: List[Tensor], + mlvl_points: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> tuple: + """A wrapper for computing ATSS and FCOS targets for points in multiple + images. + + Args: + cls_scores (list[Tensor]): Box iou-aware scores for each scale + level with shape (N, num_points * num_classes, H, W). + mlvl_points (list[Tensor]): Points of each fpn level, each has + shape (num_points, 2). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + tuple: + + - labels_list (list[Tensor]): Labels of each level. + - label_weights (Tensor/None): Label weights of all levels. + - bbox_targets_list (list[Tensor]): Regression targets of each + level, (l, t, r, b). + - bbox_weights (Tensor/None): Bbox weights of all levels. + """ + if self.use_atss: + return self.get_atss_targets(cls_scores, mlvl_points, + batch_gt_instances, batch_img_metas, + batch_gt_instances_ignore) + else: + self.norm_on_bbox = False + return self.get_fcos_targets(mlvl_points, batch_gt_instances) + + def _get_targets_single(self, *args, **kwargs): + """Avoid ambiguity in multiple inheritance.""" + if self.use_atss: + return ATSSHead._get_targets_single(self, *args, **kwargs) + else: + return FCOSHead._get_targets_single(self, *args, **kwargs) + + def get_fcos_targets(self, points: List[Tensor], + batch_gt_instances: InstanceList) -> tuple: + """Compute FCOS regression and classification targets for points in + multiple images. + + Args: + points (list[Tensor]): Points of each fpn level, each has shape + (num_points, 2). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + + Returns: + tuple: + + - labels (list[Tensor]): Labels of each level. + - label_weights: None, to be compatible with ATSS targets. + - bbox_targets (list[Tensor]): BBox targets of each level. + - bbox_weights: None, to be compatible with ATSS targets. 
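+
+        The ``None`` weights are placeholders for interface compatibility;
+        ``loss_by_feat`` instead weights the bbox losses by IoU-based
+        weights computed from the predictions.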
+ """ + labels, bbox_targets = FCOSHead.get_targets(self, points, + batch_gt_instances) + label_weights = None + bbox_weights = None + return labels, label_weights, bbox_targets, bbox_weights + + def get_anchors(self, + featmap_sizes: List[Tuple], + batch_img_metas: List[dict], + device: str = 'cuda') -> tuple: + """Get anchors according to feature map sizes. + + Args: + featmap_sizes (list[tuple]): Multi-level feature map sizes. + batch_img_metas (list[dict]): Image meta info. + device (str): Device for returned tensors + + Returns: + tuple: + + - anchor_list (list[Tensor]): Anchors of each image. + - valid_flag_list (list[Tensor]): Valid flags of each image. + """ + num_imgs = len(batch_img_metas) + + # since feature map sizes of all images are the same, we only compute + # anchors for one time + multi_level_anchors = self.atss_prior_generator.grid_priors( + featmap_sizes, device=device) + anchor_list = [multi_level_anchors for _ in range(num_imgs)] + + # for each image, we compute valid flags of multi level anchors + valid_flag_list = [] + for img_id, img_meta in enumerate(batch_img_metas): + multi_level_flags = self.atss_prior_generator.valid_flags( + featmap_sizes, img_meta['pad_shape'], device=device) + valid_flag_list.append(multi_level_flags) + + return anchor_list, valid_flag_list + + def get_atss_targets( + self, + cls_scores: List[Tensor], + mlvl_points: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> tuple: + """A wrapper for computing ATSS targets for points in multiple images. + + Args: + cls_scores (list[Tensor]): Box iou-aware scores for each scale + level with shape (N, num_points * num_classes, H, W). + mlvl_points (list[Tensor]): Points of each fpn level, each has + shape (num_points, 2). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + tuple: + + - labels_list (list[Tensor]): Labels of each level. + - label_weights (Tensor): Label weights of all levels. + - bbox_targets_list (list[Tensor]): Regression targets of each + level, (l, t, r, b). + - bbox_weights (Tensor): Bbox weights of all levels. 
+ """ + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len( + featmap_sizes + ) == self.atss_prior_generator.num_levels == \ + self.fcos_prior_generator.num_levels + + device = cls_scores[0].device + + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + + cls_reg_targets = ATSSHead.get_targets( + self, + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore, + unmap_outputs=True) + + (anchor_list, labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, avg_factor) = cls_reg_targets + + bbox_targets_list = [ + bbox_targets.reshape(-1, 4) for bbox_targets in bbox_targets_list + ] + + num_imgs = len(batch_img_metas) + # transform bbox_targets (x1, y1, x2, y2) into (l, t, r, b) format + bbox_targets_list = self.transform_bbox_targets( + bbox_targets_list, mlvl_points, num_imgs) + + labels_list = [labels.reshape(-1) for labels in labels_list] + label_weights_list = [ + label_weights.reshape(-1) for label_weights in label_weights_list + ] + bbox_weights_list = [ + bbox_weights.reshape(-1) for bbox_weights in bbox_weights_list + ] + label_weights = torch.cat(label_weights_list) + bbox_weights = torch.cat(bbox_weights_list) + return labels_list, label_weights, bbox_targets_list, bbox_weights + + def transform_bbox_targets(self, decoded_bboxes: List[Tensor], + mlvl_points: List[Tensor], + num_imgs: int) -> List[Tensor]: + """Transform bbox_targets (x1, y1, x2, y2) into (l, t, r, b) format. + + Args: + decoded_bboxes (list[Tensor]): Regression targets of each level, + in the form of (x1, y1, x2, y2). + mlvl_points (list[Tensor]): Points of each fpn level, each has + shape (num_points, 2). + num_imgs (int): the number of images in a batch. + + Returns: + bbox_targets (list[Tensor]): Regression targets of each level in + the form of (l, t, r, b). + """ + # TODO: Re-implemented in Class PointCoder + assert len(decoded_bboxes) == len(mlvl_points) + num_levels = len(decoded_bboxes) + mlvl_points = [points.repeat(num_imgs, 1) for points in mlvl_points] + bbox_targets = [] + for i in range(num_levels): + bbox_target = self.bbox_coder.encode(mlvl_points[i], + decoded_bboxes[i]) + bbox_targets.append(bbox_target) + + return bbox_targets + + def _load_from_state_dict(self, state_dict: dict, prefix: str, + local_metadata: dict, strict: bool, + missing_keys: Union[List[str], str], + unexpected_keys: Union[List[str], str], + error_msgs: Union[List[str], str]) -> None: + """Override the method in the parent class to avoid changing para's + name.""" + pass diff --git a/mmdet/models/dense_heads/yolact_head.py b/mmdet/models/dense_heads/yolact_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3390c136a31bee81134667eb28ad8829ddb84cc3 --- /dev/null +++ b/mmdet/models/dense_heads/yolact_head.py @@ -0,0 +1,1193 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy +from typing import List, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule, ModuleList +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import (ConfigType, InstanceList, OptConfigType, + OptInstanceList, OptMultiConfig) +from ..layers import fast_nms +from ..utils import images_to_levels, multi_apply, select_single_mlvl +from ..utils.misc import empty_instances +from .anchor_head import AnchorHead +from .base_mask_head import BaseMaskHead + + +@MODELS.register_module() +class YOLACTHead(AnchorHead): + """YOLACT box head used in https://arxiv.org/abs/1904.02689. + + Note that YOLACT head is a light version of RetinaNet head. + Four differences are described as follows: + + 1. YOLACT box head has three-times fewer anchors. + 2. YOLACT box head shares the convs for box and cls branches. + 3. YOLACT box head uses OHEM instead of Focal loss. + 4. YOLACT box head predicts a set of mask coefficients for each box. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + anchor_generator (:obj:`ConfigDict` or dict): Config dict for + anchor generator + loss_cls (:obj:`ConfigDict` or dict): Config of classification loss. + loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss. + num_head_convs (int): Number of the conv layers shared by + box and cls branches. + num_protos (int): Number of the mask coefficients. + use_ohem (bool): If true, ``loss_single_OHEM`` will be used for + cls loss calculation. If false, ``loss_single`` will be used. + conv_cfg (:obj:`ConfigDict` or dict, optional): Dictionary to + construct and config conv layer. + norm_cfg (:obj:`ConfigDict` or dict, optional): Dictionary to + construct and config norm layer. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or + list[dict], optional): Initialization config dict. 
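+
+ Example (a back-of-envelope check of the per-branch output channels,
+ assuming the default softmax classification so ``cls_out_channels`` is
+ ``num_classes + 1``; all numbers below are invented):
+
+ >>> num_base_priors, num_classes, num_protos = 3, 80, 32
+ >>> num_base_priors * (num_classes + 1)  # conv_cls output channels
+ 243
+ >>> num_base_priors * 4  # conv_reg output channels
+ 12
+ >>> num_base_priors * num_protos  # conv_coeff output channels
+ 96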
+ """ + + def __init__(self, + num_classes: int, + in_channels: int, + anchor_generator: ConfigType = dict( + type='AnchorGenerator', + octave_base_scale=3, + scales_per_octave=1, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + loss_cls: ConfigType = dict( + type='CrossEntropyLoss', + use_sigmoid=False, + reduction='none', + loss_weight=1.0), + loss_bbox: ConfigType = dict( + type='SmoothL1Loss', beta=1.0, loss_weight=1.5), + num_head_convs: int = 1, + num_protos: int = 32, + use_ohem: bool = True, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = dict( + type='Xavier', + distribution='uniform', + bias=0, + layer='Conv2d'), + **kwargs) -> None: + self.num_head_convs = num_head_convs + self.num_protos = num_protos + self.use_ohem = use_ohem + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + super().__init__( + num_classes=num_classes, + in_channels=in_channels, + loss_cls=loss_cls, + loss_bbox=loss_bbox, + anchor_generator=anchor_generator, + init_cfg=init_cfg, + **kwargs) + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self.relu = nn.ReLU(inplace=True) + self.head_convs = ModuleList() + for i in range(self.num_head_convs): + chn = self.in_channels if i == 0 else self.feat_channels + self.head_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + self.conv_cls = nn.Conv2d( + self.feat_channels, + self.num_base_priors * self.cls_out_channels, + 3, + padding=1) + self.conv_reg = nn.Conv2d( + self.feat_channels, self.num_base_priors * 4, 3, padding=1) + self.conv_coeff = nn.Conv2d( + self.feat_channels, + self.num_base_priors * self.num_protos, + 3, + padding=1) + + def forward_single(self, x: Tensor) -> tuple: + """Forward feature of a single scale level. + + Args: + x (Tensor): Features of a single scale level. + + Returns: + tuple: + + - cls_score (Tensor): Cls scores for a single scale level + the channels number is num_anchors * num_classes. + - bbox_pred (Tensor): Box energies / deltas for a single scale + level, the channels number is num_anchors * 4. + - coeff_pred (Tensor): Mask coefficients for a single scale + level, the channels number is num_anchors * num_protos. + """ + for head_conv in self.head_convs: + x = head_conv(x) + cls_score = self.conv_cls(x) + bbox_pred = self.conv_reg(x) + coeff_pred = self.conv_coeff(x).tanh() + return cls_score, bbox_pred, coeff_pred + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + coeff_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the bbox head. + + When ``self.use_ohem == True``, it functions like ``SSDHead.loss``, + otherwise, it follows ``AnchorHead.loss``. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + has shape (N, num_anchors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + coeff_preds (list[Tensor]): Mask coefficients for each scale + level with shape (N, num_anchors * num_protos, H, W) + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. 
+ batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
+ Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+ data that is ignored during training and testing.
+ Defaults to None.
+
+ Returns:
+ dict: A dictionary of loss components.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+
+ device = cls_scores[0].device
+
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, batch_img_metas, device=device)
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ batch_gt_instances,
+ batch_img_metas,
+ batch_gt_instances_ignore=batch_gt_instances_ignore,
+ unmap_outputs=not self.use_ohem,
+ return_sampling_results=True)
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ avg_factor, sampling_results) = cls_reg_targets
+
+ if self.use_ohem:
+ num_images = len(batch_img_metas)
+ all_cls_scores = torch.cat([
+ s.permute(0, 2, 3, 1).reshape(
+ num_images, -1, self.cls_out_channels) for s in cls_scores
+ ], 1)
+ all_labels = torch.cat(labels_list, -1).view(num_images, -1)
+ all_label_weights = torch.cat(label_weights_list,
+ -1).view(num_images, -1)
+ all_bbox_preds = torch.cat([
+ b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
+ for b in bbox_preds
+ ], -2)
+ all_bbox_targets = torch.cat(bbox_targets_list,
+ -2).view(num_images, -1, 4)
+ all_bbox_weights = torch.cat(bbox_weights_list,
+ -2).view(num_images, -1, 4)
+
+ # concat all level anchors to a single tensor
+ all_anchors = []
+ for i in range(num_images):
+ all_anchors.append(torch.cat(anchor_list[i]))
+
+ # check NaN and Inf
+ assert torch.isfinite(all_cls_scores).all().item(), \
+ 'classification scores become infinite or NaN!'
+ assert torch.isfinite(all_bbox_preds).all().item(), \
+ 'bbox predictions become infinite or NaN!'
+
+ losses_cls, losses_bbox = multi_apply(
+ self.OHEMloss_by_feat_single,
+ all_cls_scores,
+ all_bbox_preds,
+ all_anchors,
+ all_labels,
+ all_label_weights,
+ all_bbox_targets,
+ all_bbox_weights,
+ avg_factor=avg_factor)
+ else:
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ # concat all level anchors and flags to a single tensor
+ concat_anchor_list = []
+ for i in range(len(anchor_list)):
+ concat_anchor_list.append(torch.cat(anchor_list[i]))
+ all_anchor_list = images_to_levels(concat_anchor_list,
+ num_level_anchors)
+ losses_cls, losses_bbox = multi_apply(
+ self.loss_by_feat_single,
+ cls_scores,
+ bbox_preds,
+ all_anchor_list,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ avg_factor=avg_factor)
+ losses = dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
+ # update `_raw_positive_infos`, which will be used when calling
+ # `get_positive_infos`.
+ self._raw_positive_infos.update(coeff_preds=coeff_preds)
+ return losses
+
+ def OHEMloss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
+ anchors: Tensor, labels: Tensor,
+ label_weights: Tensor, bbox_targets: Tensor,
+ bbox_weights: Tensor,
+ avg_factor: int) -> tuple:
+ """Compute loss of a single image. Similar to
+ func:``SSDHead.loss_by_feat_single``
+
+ Args:
+ cls_score (Tensor): Box scores for each image.
+ Has shape (num_total_anchors, num_classes).
+ bbox_pred (Tensor): Box energies / deltas for each image
+ level with shape (num_total_anchors, 4).
+ anchors (Tensor): Box reference for each scale level with shape
+ (num_total_anchors, 4).
+ labels (Tensor): Labels of each anchors with shape + (num_total_anchors,). + label_weights (Tensor): Label weights of each anchor with shape + (num_total_anchors,) + bbox_targets (Tensor): BBox regression targets of each anchor with + shape (num_total_anchors, 4). + bbox_weights (Tensor): BBox regression loss weights of each anchor + with shape (num_total_anchors, 4). + avg_factor (int): Average factor that is used to average + the loss. When using sampling method, avg_factor is usually + the sum of positive and negative priors. When using + `PseudoSampler`, `avg_factor` is usually equal to the number + of positive priors. + + Returns: + Tuple[Tensor, Tensor]: A tuple of cls loss and bbox loss of one + feature map. + """ + + loss_cls_all = self.loss_cls(cls_score, labels, label_weights) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + pos_inds = ((labels >= 0) & (labels < self.num_classes)).nonzero( + as_tuple=False).reshape(-1) + neg_inds = (labels == self.num_classes).nonzero( + as_tuple=False).view(-1) + + num_pos_samples = pos_inds.size(0) + if num_pos_samples == 0: + num_neg_samples = neg_inds.size(0) + else: + num_neg_samples = self.train_cfg['neg_pos_ratio'] * \ + num_pos_samples + if num_neg_samples > neg_inds.size(0): + num_neg_samples = neg_inds.size(0) + topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples) + loss_cls_pos = loss_cls_all[pos_inds].sum() + loss_cls_neg = topk_loss_cls_neg.sum() + loss_cls = (loss_cls_pos + loss_cls_neg) / avg_factor + if self.reg_decoded_bbox: + # When the regression loss (e.g. `IouLoss`, `GIouLoss`) + # is applied directly on the decoded bounding boxes, it + # decodes the already encoded coordinates to absolute format. + bbox_pred = self.bbox_coder.decode(anchors, bbox_pred) + loss_bbox = self.loss_bbox( + bbox_pred, bbox_targets, bbox_weights, avg_factor=avg_factor) + return loss_cls[None], loss_bbox + + def get_positive_infos(self) -> InstanceList: + """Get positive information from sampling results. + + Returns: + list[:obj:`InstanceData`]: Positive Information of each image, + usually including positive bboxes, positive labels, positive + priors, positive coeffs, etc. + """ + assert len(self._raw_positive_infos) > 0 + sampling_results = self._raw_positive_infos['sampling_results'] + num_imgs = len(sampling_results) + + coeff_pred_list = [] + for coeff_pred_per_level in self._raw_positive_infos['coeff_preds']: + coeff_pred_per_level = \ + coeff_pred_per_level.permute( + 0, 2, 3, 1).reshape(num_imgs, -1, self.num_protos) + coeff_pred_list.append(coeff_pred_per_level) + coeff_preds = torch.cat(coeff_pred_list, dim=1) + + pos_info_list = [] + for idx, sampling_result in enumerate(sampling_results): + pos_info = InstanceData() + coeff_preds_single = coeff_preds[idx] + pos_info.pos_assigned_gt_inds = \ + sampling_result.pos_assigned_gt_inds + pos_info.pos_inds = sampling_result.pos_inds + pos_info.coeffs = coeff_preds_single[sampling_result.pos_inds] + pos_info.bboxes = sampling_result.pos_gt_bboxes + pos_info_list.append(pos_info) + return pos_info_list + + def predict_by_feat(self, + cls_scores, + bbox_preds, + coeff_preds, + batch_img_metas, + cfg=None, + rescale=True, + **kwargs): + """Similar to func:``AnchorHead.get_bboxes``, but additionally + processes coeff_preds. 
+ + Args: + cls_scores (list[Tensor]): Box scores for each scale level + with shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W) + coeff_preds (list[Tensor]): Mask coefficients for each scale + level with shape (N, num_anchors * num_protos, H, W) + batch_img_metas (list[dict]): Batch image meta info. + cfg (:obj:`Config` | None): Test / postprocessing configuration, + if None, test_cfg would be used + rescale (bool): If True, return boxes in original image space. + Defaults to True. + + Returns: + list[:obj:`InstanceData`]: Object detection results of each image + after the post process. Each item usually contains following keys. + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - coeffs (Tensor): the predicted mask coefficients of + instance inside the corresponding box has a shape + (n, num_protos). + """ + assert len(cls_scores) == len(bbox_preds) + num_levels = len(cls_scores) + + device = cls_scores[0].device + featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] + mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, device=device) + + result_list = [] + for img_id in range(len(batch_img_metas)): + img_meta = batch_img_metas[img_id] + cls_score_list = select_single_mlvl(cls_scores, img_id) + bbox_pred_list = select_single_mlvl(bbox_preds, img_id) + coeff_pred_list = select_single_mlvl(coeff_preds, img_id) + results = self._predict_by_feat_single( + cls_score_list=cls_score_list, + bbox_pred_list=bbox_pred_list, + coeff_preds_list=coeff_pred_list, + mlvl_priors=mlvl_priors, + img_meta=img_meta, + cfg=cfg, + rescale=rescale) + result_list.append(results) + return result_list + + def _predict_by_feat_single(self, + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + coeff_preds_list: List[Tensor], + mlvl_priors: List[Tensor], + img_meta: dict, + cfg: ConfigType, + rescale: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. Similar to func:``AnchorHead._predict_by_feat_single``, + but additionally processes coeff_preds_list and uses fast NMS instead + of traditional NMS. + + Args: + cls_score_list (list[Tensor]): Box scores for a single scale level + Has shape (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas for a single + scale level with shape (num_priors * 4, H, W). + coeff_preds_list (list[Tensor]): Mask coefficients for a single + scale level with shape (num_priors * num_protos, H, W). + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid, + has shape (num_priors, 4). + img_meta (dict): Image meta info. + cfg (mmengine.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
+ - coeffs (Tensor): the predicted mask coefficients of + instance inside the corresponding box has a shape + (n, num_protos). + """ + assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_priors) + + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + img_shape = img_meta['img_shape'] + nms_pre = cfg.get('nms_pre', -1) + + mlvl_bbox_preds = [] + mlvl_valid_priors = [] + mlvl_scores = [] + mlvl_coeffs = [] + for cls_score, bbox_pred, coeff_pred, priors in \ + zip(cls_score_list, bbox_pred_list, + coeff_preds_list, mlvl_priors): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + if self.use_sigmoid_cls: + scores = cls_score.sigmoid() + else: + scores = cls_score.softmax(-1) + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) + coeff_pred = coeff_pred.permute(1, 2, + 0).reshape(-1, self.num_protos) + + if 0 < nms_pre < scores.shape[0]: + # Get maximum scores for foreground classes. + if self.use_sigmoid_cls: + max_scores, _ = scores.max(dim=1) + else: + # remind that we set FG labels to [0, num_class-1] + # since mmdet v2.0 + # BG cat_id: num_class + max_scores, _ = scores[:, :-1].max(dim=1) + _, topk_inds = max_scores.topk(nms_pre) + priors = priors[topk_inds, :] + bbox_pred = bbox_pred[topk_inds, :] + scores = scores[topk_inds, :] + coeff_pred = coeff_pred[topk_inds, :] + + mlvl_bbox_preds.append(bbox_pred) + mlvl_valid_priors.append(priors) + mlvl_scores.append(scores) + mlvl_coeffs.append(coeff_pred) + + bbox_pred = torch.cat(mlvl_bbox_preds) + priors = torch.cat(mlvl_valid_priors) + multi_bboxes = self.bbox_coder.decode( + priors, bbox_pred, max_shape=img_shape) + + multi_scores = torch.cat(mlvl_scores) + multi_coeffs = torch.cat(mlvl_coeffs) + + return self._bbox_post_process( + multi_bboxes=multi_bboxes, + multi_scores=multi_scores, + multi_coeffs=multi_coeffs, + cfg=cfg, + rescale=rescale, + img_meta=img_meta) + + def _bbox_post_process(self, + multi_bboxes: Tensor, + multi_scores: Tensor, + multi_coeffs: Tensor, + cfg: ConfigType, + rescale: bool = False, + img_meta: Optional[dict] = None, + **kwargs) -> InstanceData: + """bbox post-processing method. + + The boxes would be rescaled to the original image scale and do + the nms operation. Usually `with_nms` is False is used for aug test. + + Args: + multi_bboxes (Tensor): Predicted bbox that concat all levels. + multi_scores (Tensor): Bbox scores that concat all levels. + multi_coeffs (Tensor): Mask coefficients that concat all levels. + cfg (ConfigDict): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default to False. + img_meta (dict, optional): Image meta info. Defaults to None. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - coeffs (Tensor): the predicted mask coefficients of + instance inside the corresponding box has a shape + (n, num_protos). 
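+
+ Example (a hedged sketch of the background-padding step performed
+ below when ``use_sigmoid_cls`` is True; the sizes are invented):
+
+ >>> import torch
+ >>> multi_scores = torch.rand(5, 80)  # 5 boxes, 80 foreground classes
+ >>> padding = multi_scores.new_zeros(multi_scores.shape[0], 1)
+ >>> torch.cat([multi_scores, padding], dim=1).shape
+ torch.Size([5, 81])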
+ """ + if rescale: + assert img_meta.get('scale_factor') is not None + multi_bboxes /= multi_bboxes.new_tensor( + img_meta['scale_factor']).repeat((1, 2)) + # mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) + + if self.use_sigmoid_cls: + # Add a dummy background class to the backend when using sigmoid + # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 + # BG cat_id: num_class + + padding = multi_scores.new_zeros(multi_scores.shape[0], 1) + multi_scores = torch.cat([multi_scores, padding], dim=1) + det_bboxes, det_labels, det_coeffs = fast_nms( + multi_bboxes, multi_scores, multi_coeffs, cfg.score_thr, + cfg.iou_thr, cfg.top_k, cfg.max_per_img) + results = InstanceData() + results.bboxes = det_bboxes[:, :4] + results.scores = det_bboxes[:, -1] + results.labels = det_labels + results.coeffs = det_coeffs + return results + + +@MODELS.register_module() +class YOLACTProtonet(BaseMaskHead): + """YOLACT mask head used in https://arxiv.org/abs/1904.02689. + + This head outputs the mask prototypes for YOLACT. + + Args: + in_channels (int): Number of channels in the input feature map. + proto_channels (tuple[int]): Output channels of protonet convs. + proto_kernel_sizes (tuple[int]): Kernel sizes of protonet convs. + include_last_relu (bool): If keep the last relu of protonet. + num_protos (int): Number of prototypes. + num_classes (int): Number of categories excluding the background + category. + loss_mask_weight (float): Reweight the mask loss by this factor. + max_masks_to_train (int): Maximum number of masks to train for + each image. + with_seg_branch (bool): Whether to apply a semantic segmentation + branch and calculate loss during training to increase + performance with no speed penalty. Defaults to True. + loss_segm (:obj:`ConfigDict` or dict, optional): Config of + semantic segmentation loss. + train_cfg (:obj:`ConfigDict` or dict, optional): Training config + of head. + test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of + head. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or + list[dict], optional): Initialization config dict. 
+ """ + + def __init__( + self, + num_classes: int, + in_channels: int = 256, + proto_channels: tuple = (256, 256, 256, None, 256, 32), + proto_kernel_sizes: tuple = (3, 3, 3, -2, 3, 1), + include_last_relu: bool = True, + num_protos: int = 32, + loss_mask_weight: float = 1.0, + max_masks_to_train: int = 100, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + with_seg_branch: bool = True, + loss_segm: ConfigType = dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + init_cfg=dict( + type='Xavier', + distribution='uniform', + override=dict(name='protonet')) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.proto_channels = proto_channels + self.proto_kernel_sizes = proto_kernel_sizes + self.include_last_relu = include_last_relu + + # Segmentation branch + self.with_seg_branch = with_seg_branch + self.segm_branch = SegmentationModule( + num_classes=num_classes, in_channels=in_channels) \ + if with_seg_branch else None + self.loss_segm = MODELS.build(loss_segm) if with_seg_branch else None + + self.loss_mask_weight = loss_mask_weight + self.num_protos = num_protos + self.num_classes = num_classes + self.max_masks_to_train = max_masks_to_train + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self._init_layers() + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + # Possible patterns: + # ( 256, 3) -> conv + # ( 256,-2) -> deconv + # (None,-2) -> bilinear interpolate + in_channels = self.in_channels + protonets = ModuleList() + for num_channels, kernel_size in zip(self.proto_channels, + self.proto_kernel_sizes): + if kernel_size > 0: + layer = nn.Conv2d( + in_channels, + num_channels, + kernel_size, + padding=kernel_size // 2) + else: + if num_channels is None: + layer = InterpolateModule( + scale_factor=-kernel_size, + mode='bilinear', + align_corners=False) + else: + layer = nn.ConvTranspose2d( + in_channels, + num_channels, + -kernel_size, + padding=kernel_size // 2) + protonets.append(layer) + protonets.append(nn.ReLU(inplace=True)) + in_channels = num_channels if num_channels is not None \ + else in_channels + if not self.include_last_relu: + protonets = protonets[:-1] + self.protonet = nn.Sequential(*protonets) + + def forward(self, x: tuple, positive_infos: InstanceList) -> tuple: + """Forward feature from the upstream network to get prototypes and + linearly combine the prototypes, using masks coefficients, into + instance masks. Finally, crop the instance masks with given bboxes. + + Args: + x (Tuple[Tensor]): Feature from the upstream network, which is + a 4D-tensor. + positive_infos (List[:obj:``InstanceData``]): Positive information + that calculate from detect head. + + Returns: + tuple: Predicted instance segmentation masks and + semantic segmentation map. + """ + # YOLACT used single feature map to get segmentation masks + single_x = x[0] + + # YOLACT segmentation branch, if not training or segmentation branch + # is None, will not process the forward function. 
+ if self.segm_branch is not None and self.training: + segm_preds = self.segm_branch(single_x) + else: + segm_preds = None + # YOLACT mask head + prototypes = self.protonet(single_x) + prototypes = prototypes.permute(0, 2, 3, 1).contiguous() + + num_imgs = single_x.size(0) + + mask_pred_list = [] + for idx in range(num_imgs): + cur_prototypes = prototypes[idx] + pos_coeffs = positive_infos[idx].coeffs + + # Linearly combine the prototypes with the mask coefficients + mask_preds = cur_prototypes @ pos_coeffs.t() + mask_preds = torch.sigmoid(mask_preds) + mask_pred_list.append(mask_preds) + return mask_pred_list, segm_preds + + def loss_by_feat(self, mask_preds: List[Tensor], segm_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], positive_infos: InstanceList, + **kwargs) -> dict: + """Calculate the loss based on the features extracted by the mask head. + + Args: + mask_preds (list[Tensor]): List of predicted prototypes, each has + shape (num_classes, H, W). + segm_preds (Tensor): Predicted semantic segmentation map with + shape (N, num_classes, H, W) + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``masks``, + and ``labels`` attributes. + batch_img_metas (list[dict]): Meta information of multiple images. + positive_infos (List[:obj:``InstanceData``]): Information of + positive samples of each image that are assigned in detection + head. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + assert positive_infos is not None, \ + 'positive_infos should not be None in `YOLACTProtonet`' + losses = dict() + + # crop + croped_mask_pred = self.crop_mask_preds(mask_preds, batch_img_metas, + positive_infos) + + loss_mask = [] + loss_segm = [] + num_imgs, _, mask_h, mask_w = segm_preds.size() + assert num_imgs == len(croped_mask_pred) + segm_avg_factor = num_imgs * mask_h * mask_w + total_pos = 0 + + if self.segm_branch is not None: + assert segm_preds is not None + + for idx in range(num_imgs): + img_meta = batch_img_metas[idx] + + (mask_preds, pos_mask_targets, segm_targets, num_pos, + gt_bboxes_for_reweight) = self._get_targets_single( + croped_mask_pred[idx], segm_preds[idx], + batch_gt_instances[idx], positive_infos[idx]) + + # segmentation loss + if self.with_seg_branch: + if segm_targets is None: + loss = segm_preds[idx].sum() * 0. + else: + loss = self.loss_segm( + segm_preds[idx], + segm_targets, + avg_factor=segm_avg_factor) + loss_segm.append(loss) + # mask loss + total_pos += num_pos + if num_pos == 0 or pos_mask_targets is None: + loss = mask_preds.sum() * 0. + else: + mask_preds = torch.clamp(mask_preds, 0, 1) + loss = F.binary_cross_entropy( + mask_preds, pos_mask_targets, + reduction='none') * self.loss_mask_weight + + h, w = img_meta['img_shape'][:2] + gt_bboxes_width = (gt_bboxes_for_reweight[:, 2] - + gt_bboxes_for_reweight[:, 0]) / w + gt_bboxes_height = (gt_bboxes_for_reweight[:, 3] - + gt_bboxes_for_reweight[:, 1]) / h + loss = loss.mean(dim=(1, + 2)) / gt_bboxes_width / gt_bboxes_height + loss = torch.sum(loss) + loss_mask.append(loss) + + if total_pos == 0: + total_pos += 1 # avoid nan + loss_mask = [x / total_pos for x in loss_mask] + + losses.update(loss_mask=loss_mask) + if self.with_seg_branch: + losses.update(loss_segm=loss_segm) + + return losses + + def _get_targets_single(self, mask_preds: Tensor, segm_pred: Tensor, + gt_instances: InstanceData, + positive_info: InstanceData): + """Compute targets for predictions of single image. 
+ + Args: + mask_preds (Tensor): Predicted prototypes with shape + (num_classes, H, W). + segm_pred (Tensor): Predicted semantic segmentation map + with shape (num_classes, H, W). + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes``, ``labels``, + and ``masks`` attributes. + positive_info (:obj:`InstanceData`): Information of positive + samples that are assigned in detection head. It usually + contains following keys. + + - pos_assigned_gt_inds (Tensor): Assigner GT indexes of + positive proposals, has shape (num_pos, ) + - pos_inds (Tensor): Positive index of image, has + shape (num_pos, ). + - coeffs (Tensor): Positive mask coefficients + with shape (num_pos, num_protos). + - bboxes (Tensor): Positive bboxes with shape + (num_pos, 4) + + Returns: + tuple: Usually returns a tuple containing learning targets. + + - mask_preds (Tensor): Positive predicted mask with shape + (num_pos, mask_h, mask_w). + - pos_mask_targets (Tensor): Positive mask targets with shape + (num_pos, mask_h, mask_w). + - segm_targets (Tensor): Semantic segmentation targets with shape + (num_classes, segm_h, segm_w). + - num_pos (int): Positive numbers. + - gt_bboxes_for_reweight (Tensor): GT bboxes that match to the + positive priors has shape (num_pos, 4). + """ + gt_bboxes = gt_instances.bboxes + gt_labels = gt_instances.labels + device = gt_bboxes.device + gt_masks = gt_instances.masks.to_tensor( + dtype=torch.bool, device=device).float() + if gt_masks.size(0) == 0: + return mask_preds, None, None, 0, None + + # process with semantic segmentation targets + if segm_pred is not None: + num_classes, segm_h, segm_w = segm_pred.size() + with torch.no_grad(): + downsampled_masks = F.interpolate( + gt_masks.unsqueeze(0), (segm_h, segm_w), + mode='bilinear', + align_corners=False).squeeze(0) + downsampled_masks = downsampled_masks.gt(0.5).float() + segm_targets = torch.zeros_like(segm_pred, requires_grad=False) + for obj_idx in range(downsampled_masks.size(0)): + segm_targets[gt_labels[obj_idx] - 1] = torch.max( + segm_targets[gt_labels[obj_idx] - 1], + downsampled_masks[obj_idx]) + else: + segm_targets = None + # process with mask targets + pos_assigned_gt_inds = positive_info.pos_assigned_gt_inds + num_pos = pos_assigned_gt_inds.size(0) + # Since we're producing (near) full image masks, + # it'd take too much vram to backprop on every single mask. + # Thus we select only a subset. + if num_pos > self.max_masks_to_train: + perm = torch.randperm(num_pos) + select = perm[:self.max_masks_to_train] + mask_preds = mask_preds[select] + pos_assigned_gt_inds = pos_assigned_gt_inds[select] + num_pos = self.max_masks_to_train + + gt_bboxes_for_reweight = gt_bboxes[pos_assigned_gt_inds] + + mask_h, mask_w = mask_preds.shape[-2:] + gt_masks = F.interpolate( + gt_masks.unsqueeze(0), (mask_h, mask_w), + mode='bilinear', + align_corners=False).squeeze(0) + gt_masks = gt_masks.gt(0.5).float() + pos_mask_targets = gt_masks[pos_assigned_gt_inds] + + return (mask_preds, pos_mask_targets, segm_targets, num_pos, + gt_bboxes_for_reweight) + + def crop_mask_preds(self, mask_preds: List[Tensor], + batch_img_metas: List[dict], + positive_infos: InstanceList) -> list: + """Crop predicted masks by zeroing out everything not in the predicted + bbox. + + Args: + mask_preds (list[Tensor]): Predicted prototypes with shape + (num_classes, H, W). + batch_img_metas (list[dict]): Meta information of multiple images. 
+ positive_infos (List[:obj:``InstanceData``]): Positive + information that calculate from detect head. + + Returns: + list: The cropped masks. + """ + croped_mask_preds = [] + for img_meta, mask_preds, cur_info in zip(batch_img_metas, mask_preds, + positive_infos): + bboxes_for_cropping = copy.deepcopy(cur_info.bboxes) + h, w = img_meta['img_shape'][:2] + bboxes_for_cropping[:, 0::2] /= w + bboxes_for_cropping[:, 1::2] /= h + mask_preds = self.crop_single(mask_preds, bboxes_for_cropping) + mask_preds = mask_preds.permute(2, 0, 1).contiguous() + croped_mask_preds.append(mask_preds) + return croped_mask_preds + + def crop_single(self, + masks: Tensor, + boxes: Tensor, + padding: int = 1) -> Tensor: + """Crop single predicted masks by zeroing out everything not in the + predicted bbox. + + Args: + masks (Tensor): Predicted prototypes, has shape [H, W, N]. + boxes (Tensor): Bbox coords in relative point form with + shape [N, 4]. + padding (int): Image padding size. + + Return: + Tensor: The cropped masks. + """ + h, w, n = masks.size() + x1, x2 = self.sanitize_coordinates( + boxes[:, 0], boxes[:, 2], w, padding, cast=False) + y1, y2 = self.sanitize_coordinates( + boxes[:, 1], boxes[:, 3], h, padding, cast=False) + + rows = torch.arange( + w, device=masks.device, dtype=x1.dtype).view(1, -1, + 1).expand(h, w, n) + cols = torch.arange( + h, device=masks.device, dtype=x1.dtype).view(-1, 1, + 1).expand(h, w, n) + + masks_left = rows >= x1.view(1, 1, -1) + masks_right = rows < x2.view(1, 1, -1) + masks_up = cols >= y1.view(1, 1, -1) + masks_down = cols < y2.view(1, 1, -1) + + crop_mask = masks_left * masks_right * masks_up * masks_down + + return masks * crop_mask.float() + + def sanitize_coordinates(self, + x1: Tensor, + x2: Tensor, + img_size: int, + padding: int = 0, + cast: bool = True) -> tuple: + """Sanitizes the input coordinates so that x1 < x2, x1 != x2, x1 >= 0, + and x2 <= image_size. Also converts from relative to absolute + coordinates and casts the results to long tensors. + + Warning: this does things in-place behind the scenes so + copy if necessary. + + Args: + x1 (Tensor): shape (N, ). + x2 (Tensor): shape (N, ). + img_size (int): Size of the input image. + padding (int): x1 >= padding, x2 <= image_size-padding. + cast (bool): If cast is false, the result won't be cast to longs. + + Returns: + tuple: + + - x1 (Tensor): Sanitized _x1. + - x2 (Tensor): Sanitized _x2. + """ + x1 = x1 * img_size + x2 = x2 * img_size + if cast: + x1 = x1.long() + x2 = x2.long() + x1 = torch.min(x1, x2) + x2 = torch.max(x1, x2) + x1 = torch.clamp(x1 - padding, min=0) + x2 = torch.clamp(x2 + padding, max=img_size) + return x1, x2 + + def predict_by_feat(self, + mask_preds: List[Tensor], + segm_preds: Tensor, + results_list: InstanceList, + batch_img_metas: List[dict], + rescale: bool = True, + **kwargs) -> InstanceList: + """Transform a batch of output features extracted from the head into + mask results. + + Args: + mask_preds (list[Tensor]): Predicted prototypes with shape + (num_classes, H, W). + results_list (List[:obj:``InstanceData``]): BBoxHead results. + batch_img_metas (list[dict]): Meta information of all images. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Processed results of multiple + images.Each :obj:`InstanceData` usually contains + following keys. + + - scores (Tensor): Classification scores, has shape + (num_instance,). + - labels (Tensor): Has shape (num_instances,). 
+ - masks (Tensor): Processed mask results, has
+ shape (num_instances, h, w).
+ """
+ assert len(mask_preds) == len(results_list) == len(batch_img_metas)
+
+ croped_mask_pred = self.crop_mask_preds(mask_preds, batch_img_metas,
+ results_list)
+
+ for img_id in range(len(batch_img_metas)):
+ img_meta = batch_img_metas[img_id]
+ results = results_list[img_id]
+ bboxes = results.bboxes
+ mask_preds = croped_mask_pred[img_id]
+ if bboxes.shape[0] == 0 or mask_preds.shape[0] == 0:
+ results_list[img_id] = empty_instances(
+ [img_meta],
+ bboxes.device,
+ task_type='mask',
+ instance_results=[results])[0]
+ else:
+ im_mask = self._predict_by_feat_single(
+ mask_preds=croped_mask_pred[img_id],
+ bboxes=bboxes,
+ img_meta=img_meta,
+ rescale=rescale)
+ results.masks = im_mask
+ return results_list
+
+ def _predict_by_feat_single(self,
+ mask_preds: Tensor,
+ bboxes: Tensor,
+ img_meta: dict,
+ rescale: bool,
+ cfg: OptConfigType = None):
+ """Transform a single image's features extracted from the head into
+ mask results.
+
+ Args:
+ mask_preds (Tensor): Cropped instance masks for a single image,
+ has shape (num_instances, H, W).
+ bboxes (Tensor): Bbox coords in relative point form with
+ shape [N, 4].
+ img_meta (dict): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ rescale (bool): If rescale is False, then returned masks will
+ fit the scale of imgs[0].
+ cfg (dict, optional): Config used in test phase.
+ Defaults to None.
+
+ Returns:
+ Tensor: Processed mask results, has shape
+ (num_instances, h, w).
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat(
+ (1, 2))
+ img_h, img_w = img_meta['ori_shape'][:2]
+ if rescale: # rescale the bboxes in place
+ scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat(
+ (1, 2))
+ bboxes /= scale_factor
+ else:
+ w_scale, h_scale = scale_factor[0, 0], scale_factor[0, 1]
+ img_h = np.round(img_h * h_scale.item()).astype(np.int32)
+ img_w = np.round(img_w * w_scale.item()).astype(np.int32)
+
+ masks = F.interpolate(
+ mask_preds.unsqueeze(0), (img_h, img_w),
+ mode='bilinear',
+ align_corners=False).squeeze(0) > cfg.mask_thr
+
+ if cfg.mask_thr_binary < 0:
+ # for visualization and debugging
+ masks = (masks * 255).to(dtype=torch.uint8)
+
+ return masks
+
+
+class SegmentationModule(BaseModule):
+ """YOLACT segmentation branch used in `YOLACT <https://arxiv.org/abs/1904.02689>`_.
+
+ In mmdet v2.x `segm_loss` is calculated in YOLACTSegmHead, while in
+ mmdet v3.x `SegmentationModule` is used to obtain the predicted semantic
+ segmentation map and `segm_loss` is calculated in YOLACTProtonet.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
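+
+ Example (a minimal sketch of what this module computes; the feature
+ and class sizes below are invented):
+
+ >>> import torch
+ >>> import torch.nn as nn
+ >>> segm_conv = nn.Conv2d(256, 80, kernel_size=1)  # as in _init_layers
+ >>> segm_conv(torch.rand(2, 256, 69, 69)).shape
+ torch.Size([2, 80, 69, 69])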
+ """ + + def __init__( + self, + num_classes: int, + in_channels: int = 256, + init_cfg: ConfigType = dict( + type='Xavier', + distribution='uniform', + override=dict(name='segm_conv')) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.num_classes = num_classes + self._init_layers() + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + self.segm_conv = nn.Conv2d( + self.in_channels, self.num_classes, kernel_size=1) + + def forward(self, x: Tensor) -> Tensor: + """Forward feature from the upstream network. + + Args: + x (Tensor): Feature from the upstream network, which is + a 4D-tensor. + + Returns: + Tensor: Predicted semantic segmentation map with shape + (N, num_classes, H, W). + """ + return self.segm_conv(x) + + +class InterpolateModule(BaseModule): + """This is a module version of F.interpolate. + + Any arguments you give it just get passed along for the ride. + """ + + def __init__(self, *args, init_cfg=None, **kwargs) -> None: + super().__init__(init_cfg=init_cfg) + self.args = args + self.kwargs = kwargs + + def forward(self, x: Tensor) -> Tensor: + """Forward features from the upstream network. + + Args: + x (Tensor): Feature from the upstream network, which is + a 4D-tensor. + + Returns: + Tensor: A 4D-tensor feature map. + """ + return F.interpolate(x, *self.args, **self.kwargs) diff --git a/mmdet/models/dense_heads/yolo_head.py b/mmdet/models/dense_heads/yolo_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0f63afbbc94353e16e4c67ec5bc0b6cd1200de07 --- /dev/null +++ b/mmdet/models/dense_heads/yolo_head.py @@ -0,0 +1,527 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) 2019 Western Digital Corporation or its affiliates. + +import copy +import warnings +from typing import List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, is_norm +from mmengine.model import bias_init_with_prob, constant_init, normal_init +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.utils import (ConfigType, InstanceList, OptConfigType, + OptInstanceList) +from ..task_modules.samplers import PseudoSampler +from ..utils import filter_scores_and_topk, images_to_levels, multi_apply +from .base_dense_head import BaseDenseHead + + +@MODELS.register_module() +class YOLOV3Head(BaseDenseHead): + """YOLOV3Head Paper link: https://arxiv.org/abs/1804.02767. + + Args: + num_classes (int): The number of object classes (w/o background) + in_channels (Sequence[int]): Number of input channels per scale. + out_channels (Sequence[int]): The number of output channels per scale + before the final 1x1 layer. Default: (1024, 512, 256). + anchor_generator (:obj:`ConfigDict` or dict): Config dict for anchor + generator. + bbox_coder (:obj:`ConfigDict` or dict): Config of bounding box coder. + featmap_strides (Sequence[int]): The stride of each scale. + Should be in descending order. Defaults to (32, 16, 8). + one_hot_smoother (float): Set a non-zero value to enable label-smooth + Defaults to 0. + conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + convolution layer. Defaults to None. + norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and + config norm layer. Defaults to dict(type='BN', requires_grad=True). + act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer. 
+ Defaults to dict(type='LeakyReLU', negative_slope=0.1). + loss_cls (:obj:`ConfigDict` or dict): Config of classification loss. + loss_conf (:obj:`ConfigDict` or dict): Config of confidence loss. + loss_xy (:obj:`ConfigDict` or dict): Config of xy coordinate loss. + loss_wh (:obj:`ConfigDict` or dict): Config of wh coordinate loss. + train_cfg (:obj:`ConfigDict` or dict, optional): Training config of + YOLOV3 head. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of + YOLOV3 head. Defaults to None. + """ + + def __init__(self, + num_classes: int, + in_channels: Sequence[int], + out_channels: Sequence[int] = (1024, 512, 256), + anchor_generator: ConfigType = dict( + type='YOLOAnchorGenerator', + base_sizes=[[(116, 90), (156, 198), (373, 326)], + [(30, 61), (62, 45), (59, 119)], + [(10, 13), (16, 30), (33, 23)]], + strides=[32, 16, 8]), + bbox_coder: ConfigType = dict(type='YOLOBBoxCoder'), + featmap_strides: Sequence[int] = (32, 16, 8), + one_hot_smoother: float = 0., + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict(type='BN', requires_grad=True), + act_cfg: ConfigType = dict( + type='LeakyReLU', negative_slope=0.1), + loss_cls: ConfigType = dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0), + loss_conf: ConfigType = dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0), + loss_xy: ConfigType = dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0), + loss_wh: ConfigType = dict(type='MSELoss', loss_weight=1.0), + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None) -> None: + super().__init__(init_cfg=None) + # Check params + assert (len(in_channels) == len(out_channels) == len(featmap_strides)) + + self.num_classes = num_classes + self.in_channels = in_channels + self.out_channels = out_channels + self.featmap_strides = featmap_strides + self.train_cfg = train_cfg + self.test_cfg = test_cfg + if self.train_cfg: + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + if train_cfg.get('sampler', None) is not None: + self.sampler = TASK_UTILS.build( + self.train_cfg['sampler'], context=self) + else: + self.sampler = PseudoSampler() + + self.one_hot_smoother = one_hot_smoother + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.bbox_coder = TASK_UTILS.build(bbox_coder) + + self.prior_generator = TASK_UTILS.build(anchor_generator) + + self.loss_cls = MODELS.build(loss_cls) + self.loss_conf = MODELS.build(loss_conf) + self.loss_xy = MODELS.build(loss_xy) + self.loss_wh = MODELS.build(loss_wh) + + self.num_base_priors = self.prior_generator.num_base_priors[0] + assert len( + self.prior_generator.num_base_priors) == len(featmap_strides) + self._init_layers() + + @property + def num_levels(self) -> int: + """int: number of feature map levels""" + return len(self.featmap_strides) + + @property + def num_attrib(self) -> int: + """int: number of attributes in pred_map, bboxes (4) + + objectness (1) + num_classes""" + + return 5 + self.num_classes + + def _init_layers(self) -> None: + """initialize conv layers in YOLOv3 head.""" + self.convs_bridge = nn.ModuleList() + self.convs_pred = nn.ModuleList() + for i in range(self.num_levels): + conv_bridge = ConvModule( + self.in_channels[i], + self.out_channels[i], + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + conv_pred = nn.Conv2d(self.out_channels[i], + self.num_base_priors * self.num_attrib, 1) + + self.convs_bridge.append(conv_bridge) + 
self.convs_pred.append(conv_pred)
+
+ def init_weights(self) -> None:
+ """Initialize weights."""
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ normal_init(m, mean=0, std=0.01)
+ if is_norm(m):
+ constant_init(m, 1)
+
+ # Use prior in model initialization to improve stability
+ for conv_pred, stride in zip(self.convs_pred, self.featmap_strides):
+ bias = conv_pred.bias.reshape(self.num_base_priors, -1)
+ # init objectness with prior of 8 objects per feature map
+ # refer to https://github.com/ultralytics/yolov3
+ nn.init.constant_(bias.data[:, 4],
+ bias_init_with_prob(8 / (608 / stride)**2))
+ nn.init.constant_(bias.data[:, 5:], bias_init_with_prob(0.01))
+
+ def forward(self, x: Tuple[Tensor, ...]) -> tuple:
+ """Forward features from the upstream network.
+
+ Args:
+ x (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple[Tensor]: A tuple of multi-level prediction maps, each is a
+ 4D-tensor of shape (batch_size, 5+num_classes, height, width).
+ """
+
+ assert len(x) == self.num_levels
+ pred_maps = []
+ for i in range(self.num_levels):
+ feat = x[i]
+ feat = self.convs_bridge[i](feat)
+ pred_map = self.convs_pred[i](feat)
+ pred_maps.append(pred_map)
+
+ return tuple(pred_maps),
+
+ def predict_by_feat(self,
+ pred_maps: Sequence[Tensor],
+ batch_img_metas: Optional[List[dict]],
+ cfg: OptConfigType = None,
+ rescale: bool = False,
+ with_nms: bool = True) -> InstanceList:
+ """Transform a batch of output features extracted from the head into
+ bbox results. It has been accelerated since PR #5991.
+
+ Args:
+ pred_maps (Sequence[Tensor]): Raw predictions for a batch of
+ images.
+ batch_img_metas (list[dict], Optional): Batch image meta info.
+ Defaults to None.
+ cfg (:obj:`ConfigDict` or dict, optional): Test / postprocessing
+ configuration, if None, test_cfg would be used.
+ Defaults to None.
+ rescale (bool): If True, return boxes in original image space.
+ Defaults to False.
+ with_nms (bool): If True, do nms before returning boxes.
+ Defaults to True.
+
+ Returns:
+ list[:obj:`InstanceData`]: Object detection results of each image
+ after the post process. Each item usually contains following keys.
+
+ - scores (Tensor): Classification scores, has a shape
+ (num_instance, )
+ - labels (Tensor): Labels of bboxes, has a shape
+ (num_instances, ).
+ - bboxes (Tensor): Has a shape (num_instances, 4),
+ the last dimension 4 arranged as (x1, y1, x2, y2).
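+
+ Example (illustrative only; shows how the flattened ``num_attrib``
+ dimension is sliced in the body below, with invented sizes):
+
+ >>> import torch
+ >>> num_attrib = 85  # 4 bbox + 1 objectness + 80 classes
+ >>> flat = torch.rand(2, 507, num_attrib)  # 2 imgs, 13*13*3 priors
+ >>> bbox, obj, cls = flat[..., :4], flat[..., 4], flat[..., 5:]
+ >>> bbox.shape, obj.shape, cls.shape
+ (torch.Size([2, 507, 4]), torch.Size([2, 507]), torch.Size([2, 507, 80]))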
+ """ + assert len(pred_maps) == self.num_levels + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + + num_imgs = len(batch_img_metas) + featmap_sizes = [pred_map.shape[-2:] for pred_map in pred_maps] + + mlvl_anchors = self.prior_generator.grid_priors( + featmap_sizes, device=pred_maps[0].device) + flatten_preds = [] + flatten_strides = [] + for pred, stride in zip(pred_maps, self.featmap_strides): + pred = pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, + self.num_attrib) + pred[..., :2].sigmoid_() + flatten_preds.append(pred) + flatten_strides.append( + pred.new_tensor(stride).expand(pred.size(1))) + + flatten_preds = torch.cat(flatten_preds, dim=1) + flatten_bbox_preds = flatten_preds[..., :4] + flatten_objectness = flatten_preds[..., 4].sigmoid() + flatten_cls_scores = flatten_preds[..., 5:].sigmoid() + flatten_anchors = torch.cat(mlvl_anchors) + flatten_strides = torch.cat(flatten_strides) + flatten_bboxes = self.bbox_coder.decode(flatten_anchors, + flatten_bbox_preds, + flatten_strides.unsqueeze(-1)) + results_list = [] + for (bboxes, scores, objectness, + img_meta) in zip(flatten_bboxes, flatten_cls_scores, + flatten_objectness, batch_img_metas): + # Filtering out all predictions with conf < conf_thr + conf_thr = cfg.get('conf_thr', -1) + if conf_thr > 0: + conf_inds = objectness >= conf_thr + bboxes = bboxes[conf_inds, :] + scores = scores[conf_inds, :] + objectness = objectness[conf_inds] + + score_thr = cfg.get('score_thr', 0) + nms_pre = cfg.get('nms_pre', -1) + scores, labels, keep_idxs, _ = filter_scores_and_topk( + scores, score_thr, nms_pre) + + results = InstanceData( + scores=scores, + labels=labels, + bboxes=bboxes[keep_idxs], + score_factors=objectness[keep_idxs], + ) + results = self._bbox_post_process( + results=results, + cfg=cfg, + rescale=rescale, + with_nms=with_nms, + img_meta=img_meta) + results_list.append(results) + return results_list + + def loss_by_feat( + self, + pred_maps: Sequence[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + pred_maps (list[Tensor]): Prediction map for each scale level, + shape (N, num_anchors * num_attrib, H, W) + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict: A dictionary of loss components. 
+ """ + num_imgs = len(batch_img_metas) + device = pred_maps[0][0].device + + featmap_sizes = [ + pred_maps[i].shape[-2:] for i in range(self.num_levels) + ] + mlvl_anchors = self.prior_generator.grid_priors( + featmap_sizes, device=device) + anchor_list = [mlvl_anchors for _ in range(num_imgs)] + + responsible_flag_list = [] + for img_id in range(num_imgs): + responsible_flag_list.append( + self.responsible_flags(featmap_sizes, + batch_gt_instances[img_id].bboxes, + device)) + + target_maps_list, neg_maps_list = self.get_targets( + anchor_list, responsible_flag_list, batch_gt_instances) + + losses_cls, losses_conf, losses_xy, losses_wh = multi_apply( + self.loss_by_feat_single, pred_maps, target_maps_list, + neg_maps_list) + + return dict( + loss_cls=losses_cls, + loss_conf=losses_conf, + loss_xy=losses_xy, + loss_wh=losses_wh) + + def loss_by_feat_single(self, pred_map: Tensor, target_map: Tensor, + neg_map: Tensor) -> tuple: + """Calculate the loss of a single scale level based on the features + extracted by the detection head. + + Args: + pred_map (Tensor): Raw predictions for a single level. + target_map (Tensor): The Ground-Truth target for a single level. + neg_map (Tensor): The negative masks for a single level. + + Returns: + tuple: + loss_cls (Tensor): Classification loss. + loss_conf (Tensor): Confidence loss. + loss_xy (Tensor): Regression loss of x, y coordinate. + loss_wh (Tensor): Regression loss of w, h coordinate. + """ + + num_imgs = len(pred_map) + pred_map = pred_map.permute(0, 2, 3, + 1).reshape(num_imgs, -1, self.num_attrib) + neg_mask = neg_map.float() + pos_mask = target_map[..., 4] + pos_and_neg_mask = neg_mask + pos_mask + pos_mask = pos_mask.unsqueeze(dim=-1) + if torch.max(pos_and_neg_mask) > 1.: + warnings.warn('There is overlap between pos and neg sample.') + pos_and_neg_mask = pos_and_neg_mask.clamp(min=0., max=1.) + + pred_xy = pred_map[..., :2] + pred_wh = pred_map[..., 2:4] + pred_conf = pred_map[..., 4] + pred_label = pred_map[..., 5:] + + target_xy = target_map[..., :2] + target_wh = target_map[..., 2:4] + target_conf = target_map[..., 4] + target_label = target_map[..., 5:] + + loss_cls = self.loss_cls(pred_label, target_label, weight=pos_mask) + loss_conf = self.loss_conf( + pred_conf, target_conf, weight=pos_and_neg_mask) + loss_xy = self.loss_xy(pred_xy, target_xy, weight=pos_mask) + loss_wh = self.loss_wh(pred_wh, target_wh, weight=pos_mask) + + return loss_cls, loss_conf, loss_xy, loss_wh + + def get_targets(self, anchor_list: List[List[Tensor]], + responsible_flag_list: List[List[Tensor]], + batch_gt_instances: List[InstanceData]) -> tuple: + """Compute target maps for anchors in multiple images. + + Args: + anchor_list (list[list[Tensor]]): Multi level anchors of each + image. The outer list indicates images, and the inner list + corresponds to feature levels of the image. Each element of + the inner list is a tensor of shape (num_total_anchors, 4). + responsible_flag_list (list[list[Tensor]]): Multi level responsible + flags of each image. Each element is a tensor of shape + (num_total_anchors, ) + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + + Returns: + tuple: Usually returns a tuple containing learning targets. + - target_map_list (list[Tensor]): Target map of each level. + - neg_map_list (list[Tensor]): Negative map of each level. 
+ """ + num_imgs = len(anchor_list) + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + + results = multi_apply(self._get_targets_single, anchor_list, + responsible_flag_list, batch_gt_instances) + + all_target_maps, all_neg_maps = results + assert num_imgs == len(all_target_maps) == len(all_neg_maps) + target_maps_list = images_to_levels(all_target_maps, num_level_anchors) + neg_maps_list = images_to_levels(all_neg_maps, num_level_anchors) + + return target_maps_list, neg_maps_list + + def _get_targets_single(self, anchors: List[Tensor], + responsible_flags: List[Tensor], + gt_instances: InstanceData) -> tuple: + """Generate matching bounding box prior and converted GT. + + Args: + anchors (List[Tensor]): Multi-level anchors of the image. + responsible_flags (List[Tensor]): Multi-level responsible flags of + anchors + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes`` and ``labels`` + attributes. + + Returns: + tuple: + target_map (Tensor): Predication target map of each + scale level, shape (num_total_anchors, + 5+num_classes) + neg_map (Tensor): Negative map of each scale level, + shape (num_total_anchors,) + """ + gt_bboxes = gt_instances.bboxes + gt_labels = gt_instances.labels + anchor_strides = [] + for i in range(len(anchors)): + anchor_strides.append( + torch.tensor(self.featmap_strides[i], + device=gt_bboxes.device).repeat(len(anchors[i]))) + concat_anchors = torch.cat(anchors) + concat_responsible_flags = torch.cat(responsible_flags) + + anchor_strides = torch.cat(anchor_strides) + assert len(anchor_strides) == len(concat_anchors) == \ + len(concat_responsible_flags) + pred_instances = InstanceData( + priors=concat_anchors, responsible_flags=concat_responsible_flags) + + assign_result = self.assigner.assign(pred_instances, gt_instances) + sampling_result = self.sampler.sample(assign_result, pred_instances, + gt_instances) + + target_map = concat_anchors.new_zeros( + concat_anchors.size(0), self.num_attrib) + + target_map[sampling_result.pos_inds, :4] = self.bbox_coder.encode( + sampling_result.pos_priors, sampling_result.pos_gt_bboxes, + anchor_strides[sampling_result.pos_inds]) + + target_map[sampling_result.pos_inds, 4] = 1 + + gt_labels_one_hot = F.one_hot( + gt_labels, num_classes=self.num_classes).float() + if self.one_hot_smoother != 0: # label smooth + gt_labels_one_hot = gt_labels_one_hot * ( + 1 - self.one_hot_smoother + ) + self.one_hot_smoother / self.num_classes + target_map[sampling_result.pos_inds, 5:] = gt_labels_one_hot[ + sampling_result.pos_assigned_gt_inds] + + neg_map = concat_anchors.new_zeros( + concat_anchors.size(0), dtype=torch.uint8) + neg_map[sampling_result.neg_inds] = 1 + + return target_map, neg_map + + def responsible_flags(self, featmap_sizes: List[tuple], gt_bboxes: Tensor, + device: str) -> List[Tensor]: + """Generate responsible anchor flags of grid cells in multiple scales. + + Args: + featmap_sizes (List[tuple]): List of feature map sizes in multiple + feature levels. + gt_bboxes (Tensor): Ground truth boxes, shape (n, 4). + device (str): Device where the anchors will be put on. 
+
+        Returns:
+            List[Tensor]: Responsible flags of anchors in multiple levels.
+        """
+        assert self.num_levels == len(featmap_sizes)
+        multi_level_responsible_flags = []
+        for i in range(self.num_levels):
+            anchor_stride = self.prior_generator.strides[i]
+            feat_h, feat_w = featmap_sizes[i]
+            gt_cx = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) * 0.5).to(device)
+            gt_cy = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) * 0.5).to(device)
+            gt_grid_x = torch.floor(gt_cx / anchor_stride[0]).long()
+            gt_grid_y = torch.floor(gt_cy / anchor_stride[1]).long()
+            # row major indexing
+            gt_bboxes_grid_idx = gt_grid_y * feat_w + gt_grid_x
+
+            responsible_grid = torch.zeros(
+                feat_h * feat_w, dtype=torch.uint8, device=device)
+            responsible_grid[gt_bboxes_grid_idx] = 1
+
+            responsible_grid = responsible_grid[:, None].expand(
+                responsible_grid.size(0),
+                self.prior_generator.num_base_priors[i]).contiguous().view(-1)
+
+            multi_level_responsible_flags.append(responsible_grid)
+        return multi_level_responsible_flags
diff --git a/mmdet/models/dense_heads/yolof_head.py b/mmdet/models/dense_heads/yolof_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5e5e6b7a92861bcd2ba3824df1f94270ba51160
--- /dev/null
+++ b/mmdet/models/dense_heads/yolof_head.py
@@ -0,0 +1,399 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, is_norm
+from mmengine.model import bias_init_with_prob, constant_init, normal_init
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean
+from ..task_modules.prior_generators import anchor_inside_flags
+from ..utils import levels_to_images, multi_apply, unmap
+from .anchor_head import AnchorHead
+
+INF = 1e8
+
+
+@MODELS.register_module()
+class YOLOFHead(AnchorHead):
+    """Detection Head of `YOLOF <https://arxiv.org/abs/2103.09460>`_.
+
+    Args:
+        num_classes (int): The number of object classes (w/o background).
+        in_channels (list[int]): The number of input channels per scale.
+        num_cls_convs (int): The number of convolutions of cls branch.
+            Defaults to 2.
+        num_reg_convs (int): The number of convolutions of reg branch.
+            Defaults to 4.
+        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
+            layer. Defaults to ``dict(type='BN', requires_grad=True)``.
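+
+        Example:
+            A minimal construction sketch (argument values are illustrative,
+            not a recommended configuration):
+
+            >>> # head = YOLOFHead(num_classes=80, in_channels=512)
+            >>> # normalized_cls_score, bbox_reg = head.forward_single(feat)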
+ """ + + def __init__(self, + num_classes: int, + in_channels: List[int], + num_cls_convs: int = 2, + num_reg_convs: int = 4, + norm_cfg: ConfigType = dict(type='BN', requires_grad=True), + **kwargs) -> None: + self.num_cls_convs = num_cls_convs + self.num_reg_convs = num_reg_convs + self.norm_cfg = norm_cfg + super().__init__( + num_classes=num_classes, in_channels=in_channels, **kwargs) + + def _init_layers(self) -> None: + cls_subnet = [] + bbox_subnet = [] + for i in range(self.num_cls_convs): + cls_subnet.append( + ConvModule( + self.in_channels, + self.in_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg)) + for i in range(self.num_reg_convs): + bbox_subnet.append( + ConvModule( + self.in_channels, + self.in_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg)) + self.cls_subnet = nn.Sequential(*cls_subnet) + self.bbox_subnet = nn.Sequential(*bbox_subnet) + self.cls_score = nn.Conv2d( + self.in_channels, + self.num_base_priors * self.num_classes, + kernel_size=3, + stride=1, + padding=1) + self.bbox_pred = nn.Conv2d( + self.in_channels, + self.num_base_priors * 4, + kernel_size=3, + stride=1, + padding=1) + self.object_pred = nn.Conv2d( + self.in_channels, + self.num_base_priors, + kernel_size=3, + stride=1, + padding=1) + + def init_weights(self) -> None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, mean=0, std=0.01) + if is_norm(m): + constant_init(m, 1) + + # Use prior in model initialization to improve stability + bias_cls = bias_init_with_prob(0.01) + torch.nn.init.constant_(self.cls_score.bias, bias_cls) + + def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Forward feature of a single scale level. + + Args: + x (Tensor): Features of a single scale level. + + Returns: + tuple: + normalized_cls_score (Tensor): Normalized Cls scores for a \ + single scale level, the channels number is \ + num_base_priors * num_classes. + bbox_reg (Tensor): Box energies / deltas for a single scale \ + level, the channels number is num_base_priors * 4. + """ + cls_score = self.cls_score(self.cls_subnet(x)) + N, _, H, W = cls_score.shape + cls_score = cls_score.view(N, -1, self.num_classes, H, W) + + reg_feat = self.bbox_subnet(x) + bbox_reg = self.bbox_pred(reg_feat) + objectness = self.object_pred(reg_feat) + + # implicit objectness + objectness = objectness.view(N, -1, 1, H, W) + normalized_cls_score = cls_score + objectness - torch.log( + 1. + torch.clamp(cls_score.exp(), max=INF) + + torch.clamp(objectness.exp(), max=INF)) + normalized_cls_score = normalized_cls_score.view(N, -1, H, W) + return normalized_cls_score, bbox_reg + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + has shape (N, num_anchors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. 
It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
+                Defaults to None.
+
+        Returns:
+            dict: A dictionary of loss components.
+        """
+        assert len(cls_scores) == 1
+        assert self.prior_generator.num_levels == 1
+
+        device = cls_scores[0].device
+        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+        anchor_list, valid_flag_list = self.get_anchors(
+            featmap_sizes, batch_img_metas, device=device)
+
+        # The output level is always 1
+        anchor_list = [anchors[0] for anchors in anchor_list]
+        valid_flag_list = [valid_flags[0] for valid_flags in valid_flag_list]
+
+        cls_scores_list = levels_to_images(cls_scores)
+        bbox_preds_list = levels_to_images(bbox_preds)
+
+        cls_reg_targets = self.get_targets(
+            cls_scores_list,
+            bbox_preds_list,
+            anchor_list,
+            valid_flag_list,
+            batch_gt_instances,
+            batch_img_metas,
+            batch_gt_instances_ignore=batch_gt_instances_ignore)
+        if cls_reg_targets is None:
+            return None
+        (batch_labels, batch_label_weights, avg_factor, batch_bbox_weights,
+         batch_pos_predicted_boxes, batch_target_boxes) = cls_reg_targets
+
+        flatten_labels = batch_labels.reshape(-1)
+        batch_label_weights = batch_label_weights.reshape(-1)
+        cls_score = cls_scores[0].permute(0, 2, 3,
+                                          1).reshape(-1, self.cls_out_channels)
+
+        avg_factor = reduce_mean(
+            torch.tensor(avg_factor, dtype=torch.float, device=device)).item()
+
+        # classification loss
+        loss_cls = self.loss_cls(
+            cls_score,
+            flatten_labels,
+            batch_label_weights,
+            avg_factor=avg_factor)
+
+        # regression loss
+        if batch_pos_predicted_boxes.shape[0] == 0:
+            # no pos sample
+            loss_bbox = batch_pos_predicted_boxes.sum() * 0
+        else:
+            loss_bbox = self.loss_bbox(
+                batch_pos_predicted_boxes,
+                batch_target_boxes,
+                batch_bbox_weights.float(),
+                avg_factor=avg_factor)
+
+        return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)
+
+    def get_targets(self,
+                    cls_scores_list: List[Tensor],
+                    bbox_preds_list: List[Tensor],
+                    anchor_list: List[Tensor],
+                    valid_flag_list: List[Tensor],
+                    batch_gt_instances: InstanceList,
+                    batch_img_metas: List[dict],
+                    batch_gt_instances_ignore: OptInstanceList = None,
+                    unmap_outputs: bool = True):
+        """Compute regression and classification targets for anchors in
+        multiple images.
+
+        Args:
+            cls_scores_list (list[Tensor]): Classification scores of each
+                image, each a 2D-tensor of shape
+                (h * w, num_anchors * num_classes).
+            bbox_preds_list (list[Tensor]): Bbox preds of each image,
+                each a 2D-tensor of shape (h * w, num_anchors * 4).
+            anchor_list (list[Tensor]): Anchors of each image. Each element
+                is a tensor of shape (h * w * num_anchors, 4).
+            valid_flag_list (list[Tensor]): Valid flags of each image. Each
+                element is a tensor of shape (h * w * num_anchors, ).
+            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+                gt_instance. It usually includes ``bboxes`` and ``labels``
+                attributes.
+            batch_img_metas (list[dict]): Meta information of each image,
+                e.g., image size, scaling factor, etc.
+            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
+                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
+                Defaults to None.
+            unmap_outputs (bool): Whether to map outputs back to the original
+                set of anchors.
+
+        Returns:
+            tuple: Usually returns a tuple containing learning targets.
+
+            - batch_labels (Tensor): Labels of all images, a tensor
of shape (batch, h * w * num_anchors).
+            - batch_label_weights (Tensor): Label weights of all images,
+              a tensor of shape (batch, h * w * num_anchors).
+            - avg_factor (int): Averaging factor of the loss, accumulated
+              over all images (see `SamplingResult`).
+
+            additional_returns: This function enables user-defined returns
+            from `self._get_targets_single`. These returns are currently
+            refined to properties at each feature map (i.e. having HxW
+            dimension). The results will be concatenated at the end.
+        """
+        num_imgs = len(batch_img_metas)
+        assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+        # compute targets for each image
+        if batch_gt_instances_ignore is None:
+            batch_gt_instances_ignore = [None] * num_imgs
+        results = multi_apply(
+            self._get_targets_single,
+            bbox_preds_list,
+            anchor_list,
+            valid_flag_list,
+            batch_gt_instances,
+            batch_img_metas,
+            batch_gt_instances_ignore,
+            unmap_outputs=unmap_outputs)
+        (all_labels, all_label_weights, pos_inds, neg_inds,
+         sampling_results_list) = results[:5]
+        # Get `avg_factor` of all images, which is calculated in
+        # `SamplingResult`. When using sampling method, avg_factor is usually
+        # the sum of positive and negative priors. When using `PseudoSampler`,
+        # `avg_factor` is usually equal to the number of positive priors.
+        avg_factor = sum(
+            [results.avg_factor for results in sampling_results_list])
+        rest_results = list(results[5:])  # user-added return values
+
+        batch_labels = torch.stack(all_labels, 0)
+        batch_label_weights = torch.stack(all_label_weights, 0)
+
+        res = (batch_labels, batch_label_weights, avg_factor)
+        for i, rests in enumerate(rest_results):  # user-added return values
+            rest_results[i] = torch.cat(rests, 0)
+
+        return res + tuple(rest_results)
+
+    def _get_targets_single(self,
+                            bbox_preds: Tensor,
+                            flat_anchors: Tensor,
+                            valid_flags: Tensor,
+                            gt_instances: InstanceData,
+                            img_meta: dict,
+                            gt_instances_ignore: Optional[InstanceData] = None,
+                            unmap_outputs: bool = True) -> tuple:
+        """Compute regression and classification targets for anchors in a
+        single image.
+
+        Args:
+            bbox_preds (Tensor): Bbox prediction of the image, with shape
+                (h * w, 4).
+            flat_anchors (Tensor): Anchors of the image, with shape
+                (h * w * num_anchors, 4).
+            valid_flags (Tensor): Valid flags of the image, with shape
+                (h * w * num_anchors, ).
+            gt_instances (:obj:`InstanceData`): Ground truth of instance
+                annotations. It should include ``bboxes`` and ``labels``
+                attributes.
+            img_meta (dict): Meta information for current image.
+            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
+                to be ignored during training. It includes ``bboxes``
+                attribute data that is ignored during training and testing.
+                Defaults to None.
+            unmap_outputs (bool): Whether to map outputs back to the original
+                set of anchors.
+
+        Returns:
+            tuple:
+                labels (Tensor): Labels of image, with shape
+                    (h * w * num_anchors, ).
+                label_weights (Tensor): Label weights of image, with shape
+                    (h * w * num_anchors, ).
+                pos_inds (Tensor): Pos index of image.
+                neg_inds (Tensor): Neg index of image.
+                sampling_result (:obj:`SamplingResult`): Sampling result.
+                pos_bbox_weights (Tensor): The weight used to calculate
+                    the bbox branch loss, with shape (num, ).
+                pos_predicted_boxes (Tensor): Predicted boxes used to
+                    calculate the bbox branch loss, with shape (num, 4).
+                pos_target_boxes (Tensor): Target boxes used to calculate
+                    the bbox branch loss, with shape (num, 4).
+        """
+        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+                                           img_meta['img_shape'][:2],
+                                           self.train_cfg['allowed_border'])
+        if not inside_flags.any():
+            raise ValueError(
+                'There is no valid anchor inside the image boundary. Please '
+                'check the image size and anchor sizes, or set '
+                '``allowed_border`` to -1 to skip the condition.')
+
+        # assign gt and sample anchors
+        anchors = flat_anchors[inside_flags, :]
+        bbox_preds = bbox_preds.reshape(-1, 4)
+        bbox_preds = bbox_preds[inside_flags, :]
+
+        # decoded bbox
+        decoder_bbox_preds = self.bbox_coder.decode(anchors, bbox_preds)
+        pred_instances = InstanceData(
+            priors=anchors, decoder_priors=decoder_bbox_preds)
+        assign_result = self.assigner.assign(pred_instances, gt_instances,
+                                             gt_instances_ignore)
+
+        pos_bbox_weights = assign_result.get_extra_property('pos_idx')
+        pos_predicted_boxes = assign_result.get_extra_property(
+            'pos_predicted_boxes')
+        pos_target_boxes = assign_result.get_extra_property('target_boxes')
+
+        sampling_result = self.sampler.sample(assign_result, pred_instances,
+                                              gt_instances)
+        num_valid_anchors = anchors.shape[0]
+        labels = anchors.new_full((num_valid_anchors, ),
+                                  self.num_classes,
+                                  dtype=torch.long)
+        label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
+
+        pos_inds = sampling_result.pos_inds
+        neg_inds = sampling_result.neg_inds
+        if len(pos_inds) > 0:
+            labels[pos_inds] = sampling_result.pos_gt_labels
+            if self.train_cfg['pos_weight'] <= 0:
+                label_weights[pos_inds] = 1.0
+            else:
+                label_weights[pos_inds] = self.train_cfg['pos_weight']
+        if len(neg_inds) > 0:
+            label_weights[neg_inds] = 1.0
+
+        # map up to original set of anchors
+        if unmap_outputs:
+            num_total_anchors = flat_anchors.size(0)
+            labels = unmap(
+                labels, num_total_anchors, inside_flags,
+                fill=self.num_classes)  # fill bg label
+            label_weights = unmap(label_weights, num_total_anchors,
+                                  inside_flags)
+
+        return (labels, label_weights, pos_inds, neg_inds, sampling_result,
+                pos_bbox_weights, pos_predicted_boxes, pos_target_boxes)
diff --git a/mmdet/models/dense_heads/yolox_head.py b/mmdet/models/dense_heads/yolox_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..00fe1e42766e4ca0052cf31d2e940dfab73fb200
--- /dev/null
+++ b/mmdet/models/dense_heads/yolox_head.py
@@ -0,0 +1,618 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import List, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+from mmcv.ops.nms import batched_nms
+from mmengine.config import ConfigDict
+from mmengine.model import bias_init_with_prob
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmdet.registry import MODELS, TASK_UTILS
+from mmdet.structures.bbox import bbox_xyxy_to_cxcywh
+from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
+                         OptMultiConfig, reduce_mean)
+from ..task_modules.prior_generators import MlvlPointGenerator
+from ..task_modules.samplers import PseudoSampler
+from ..utils import multi_apply
+from .base_dense_head import BaseDenseHead
+
+
+@MODELS.register_module()
+class YOLOXHead(BaseDenseHead):
+    """YOLOXHead head used in `YOLOX <https://arxiv.org/abs/2107.08430>`_.
+
+    Args:
+        num_classes (int): Number of categories excluding the background
+            category.
+        in_channels (int): Number of channels in the input feature map.
+        feat_channels (int): Number of hidden channels in stacking convs.
+            Defaults to 256.
+        stacked_convs (int): Number of stacking convs of the head.
+            Defaults to 2.
+        strides (Sequence[int]): Downsample factor of each feature map.
+            Defaults to (8, 16, 32).
+        use_depthwise (bool): Whether to use depthwise separable convolution
+            in blocks. Defaults to False.
+        dcn_on_last_conv (bool): If true, use dcn in the last layer of
+            towers. Defaults to False.
+        conv_bias (bool or str): If specified as `auto`, it will be decided by
+            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
+            None, otherwise False. Defaults to "auto".
+        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
+            convolution layer. Defaults to None.
+        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
+            layer. Defaults to dict(type='BN', momentum=0.03, eps=0.001).
+        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
+            Defaults to dict(type='Swish').
+        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
+        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
+        loss_obj (:obj:`ConfigDict` or dict): Config of objectness loss.
+        loss_l1 (:obj:`ConfigDict` or dict): Config of L1 loss.
+        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
+            anchor head. Defaults to None.
+        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
+            anchor head. Defaults to None.
+        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
+            list[dict], optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        in_channels: int,
+        feat_channels: int = 256,
+        stacked_convs: int = 2,
+        strides: Sequence[int] = (8, 16, 32),
+        use_depthwise: bool = False,
+        dcn_on_last_conv: bool = False,
+        conv_bias: Union[bool, str] = 'auto',
+        conv_cfg: OptConfigType = None,
+        norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
+        act_cfg: ConfigType = dict(type='Swish'),
+        loss_cls: ConfigType = dict(
+            type='CrossEntropyLoss',
+            use_sigmoid=True,
+            reduction='sum',
+            loss_weight=1.0),
+        loss_bbox: ConfigType = dict(
+            type='IoULoss',
+            mode='square',
+            eps=1e-16,
+            reduction='sum',
+            loss_weight=5.0),
+        loss_obj: ConfigType = dict(
+            type='CrossEntropyLoss',
+            use_sigmoid=True,
+            reduction='sum',
+            loss_weight=1.0),
+        loss_l1: ConfigType = dict(
+            type='L1Loss', reduction='sum', loss_weight=1.0),
+        train_cfg: OptConfigType = None,
+        test_cfg: OptConfigType = None,
+        init_cfg: OptMultiConfig = dict(
+            type='Kaiming',
+            layer='Conv2d',
+            a=math.sqrt(5),
+            distribution='uniform',
+            mode='fan_in',
+            nonlinearity='leaky_relu')
+    ) -> None:
+
+        super().__init__(init_cfg=init_cfg)
+        self.num_classes = num_classes
+        self.cls_out_channels = num_classes
+        self.in_channels = in_channels
+        self.feat_channels = feat_channels
+        self.stacked_convs = stacked_convs
+        self.strides = strides
+        self.use_depthwise = use_depthwise
+        self.dcn_on_last_conv = dcn_on_last_conv
+        assert conv_bias == 'auto' or isinstance(conv_bias, bool)
+        self.conv_bias = conv_bias
+        self.use_sigmoid_cls = True
+
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+
+        self.loss_cls: nn.Module = MODELS.build(loss_cls)
+        self.loss_bbox: nn.Module = MODELS.build(loss_bbox)
+        self.loss_obj: nn.Module = MODELS.build(loss_obj)
+
+        self.use_l1 = False  # This flag will be modified by hooks.
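+        # (In the usual mmdet YOLOX schedule, a mode-switch hook sets
+        # ``use_l1 = True`` for the final training epochs.)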
+ self.loss_l1: nn.Module = MODELS.build(loss_l1) + + self.prior_generator = MlvlPointGenerator(strides, offset=0) + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + + if self.train_cfg: + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + # YOLOX does not support sampling + self.sampler = PseudoSampler() + + self._init_layers() + + def _init_layers(self) -> None: + """Initialize heads for all level feature maps.""" + self.multi_level_cls_convs = nn.ModuleList() + self.multi_level_reg_convs = nn.ModuleList() + self.multi_level_conv_cls = nn.ModuleList() + self.multi_level_conv_reg = nn.ModuleList() + self.multi_level_conv_obj = nn.ModuleList() + for _ in self.strides: + self.multi_level_cls_convs.append(self._build_stacked_convs()) + self.multi_level_reg_convs.append(self._build_stacked_convs()) + conv_cls, conv_reg, conv_obj = self._build_predictor() + self.multi_level_conv_cls.append(conv_cls) + self.multi_level_conv_reg.append(conv_reg) + self.multi_level_conv_obj.append(conv_obj) + + def _build_stacked_convs(self) -> nn.Sequential: + """Initialize conv layers of a single level head.""" + conv = DepthwiseSeparableConvModule \ + if self.use_depthwise else ConvModule + stacked_convs = [] + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + if self.dcn_on_last_conv and i == self.stacked_convs - 1: + conv_cfg = dict(type='DCNv2') + else: + conv_cfg = self.conv_cfg + stacked_convs.append( + conv( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + bias=self.conv_bias)) + return nn.Sequential(*stacked_convs) + + def _build_predictor(self) -> Tuple[nn.Module, nn.Module, nn.Module]: + """Initialize predictor layers of a single level head.""" + conv_cls = nn.Conv2d(self.feat_channels, self.cls_out_channels, 1) + conv_reg = nn.Conv2d(self.feat_channels, 4, 1) + conv_obj = nn.Conv2d(self.feat_channels, 1, 1) + return conv_cls, conv_reg, conv_obj + + def init_weights(self) -> None: + """Initialize weights of the head.""" + super(YOLOXHead, self).init_weights() + # Use prior in model initialization to improve stability + bias_init = bias_init_with_prob(0.01) + for conv_cls, conv_obj in zip(self.multi_level_conv_cls, + self.multi_level_conv_obj): + conv_cls.bias.data.fill_(bias_init) + conv_obj.bias.data.fill_(bias_init) + + def forward_single(self, x: Tensor, cls_convs: nn.Module, + reg_convs: nn.Module, conv_cls: nn.Module, + conv_reg: nn.Module, + conv_obj: nn.Module) -> Tuple[Tensor, Tensor, Tensor]: + """Forward feature of a single scale level.""" + + cls_feat = cls_convs(x) + reg_feat = reg_convs(x) + + cls_score = conv_cls(cls_feat) + bbox_pred = conv_reg(reg_feat) + objectness = conv_obj(reg_feat) + + return cls_score, bbox_pred, objectness + + def forward(self, x: Tuple[Tensor]) -> Tuple[List]: + """Forward features from the upstream network. + + Args: + x (Tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + Returns: + Tuple[List]: A tuple of multi-level classification scores, bbox + predictions, and objectnesses. 
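+
+        Example:
+            A shape sketch (illustrative; assumes ``num_classes=80`` and
+            ``strides=(8, 16, 32)`` on a 256x256 input):
+
+            >>> # cls_scores[i]: (N, 80, 32, 32), (N, 80, 16, 16), (N, 80, 8, 8)
+            >>> # bbox_preds[i]: (N, 4, ...); objectnesses[i]: (N, 1, ...)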
+ """ + + return multi_apply(self.forward_single, x, self.multi_level_cls_convs, + self.multi_level_reg_convs, + self.multi_level_conv_cls, + self.multi_level_conv_reg, + self.multi_level_conv_obj) + + def predict_by_feat(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + objectnesses: Optional[List[Tensor]], + batch_img_metas: Optional[List[dict]] = None, + cfg: Optional[ConfigDict] = None, + rescale: bool = False, + with_nms: bool = True) -> List[InstanceData]: + """Transform a batch of output features extracted by the head into + bbox results. + Args: + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + objectnesses (list[Tensor], Optional): Score factor for + all scale level, each is a 4D-tensor, has shape + (batch_size, 1, H, W). + batch_img_metas (list[dict], Optional): Batch image meta info. + Defaults to None. + cfg (ConfigDict, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + list[:obj:`InstanceData`]: Object detection results of each image + after the post process. Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + assert len(cls_scores) == len(bbox_preds) == len(objectnesses) + cfg = self.test_cfg if cfg is None else cfg + + num_imgs = len(batch_img_metas) + featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores] + mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, + dtype=cls_scores[0].dtype, + device=cls_scores[0].device, + with_stride=True) + + # flatten cls_scores, bbox_preds and objectness + flatten_cls_scores = [ + cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, + self.cls_out_channels) + for cls_score in cls_scores + ] + flatten_bbox_preds = [ + bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) + for bbox_pred in bbox_preds + ] + flatten_objectness = [ + objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1) + for objectness in objectnesses + ] + + flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid() + flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1) + flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid() + flatten_priors = torch.cat(mlvl_priors) + + flatten_bboxes = self._bbox_decode(flatten_priors, flatten_bbox_preds) + + result_list = [] + for img_id, img_meta in enumerate(batch_img_metas): + max_scores, labels = torch.max(flatten_cls_scores[img_id], 1) + valid_mask = flatten_objectness[ + img_id] * max_scores >= cfg.score_thr + results = InstanceData( + bboxes=flatten_bboxes[img_id][valid_mask], + scores=max_scores[valid_mask] * + flatten_objectness[img_id][valid_mask], + labels=labels[valid_mask]) + + result_list.append( + self._bbox_post_process( + results=results, + cfg=cfg, + rescale=rescale, + with_nms=with_nms, + img_meta=img_meta)) + + return result_list + + def _bbox_decode(self, priors: Tensor, bbox_preds: Tensor) -> Tensor: + """Decode regression results 
(delta_x, delta_y, w, h) to bboxes (tl_x,
+        tl_y, br_x, br_y).
+
+        Args:
+            priors (Tensor): Center priors of an image, with shape
+                (num_priors, 4) in [x, y, stride_w, stride_h] format.
+            bbox_preds (Tensor): Box energies / deltas for all instances,
+                has shape (batch_size, num_instances, 4).
+
+        Returns:
+            Tensor: Decoded bboxes in (tl_x, tl_y, br_x, br_y) format. Has
+            shape (batch_size, num_instances, 4).
+        """
+        xys = (bbox_preds[..., :2] * priors[:, 2:]) + priors[:, :2]
+        whs = bbox_preds[..., 2:].exp() * priors[:, 2:]
+
+        tl_x = (xys[..., 0] - whs[..., 0] / 2)
+        tl_y = (xys[..., 1] - whs[..., 1] / 2)
+        br_x = (xys[..., 0] + whs[..., 0] / 2)
+        br_y = (xys[..., 1] + whs[..., 1] / 2)
+
+        decoded_bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1)
+        return decoded_bboxes
+
+    def _bbox_post_process(self,
+                           results: InstanceData,
+                           cfg: ConfigDict,
+                           rescale: bool = False,
+                           with_nms: bool = True,
+                           img_meta: Optional[dict] = None) -> InstanceData:
+        """bbox post-processing method.
+
+        The boxes are rescaled to the original image scale and NMS is
+        applied. Usually `with_nms` is False when used for aug test.
+
+        Args:
+            results (:obj:`InstanceData`): Detection instance results,
+                each item has shape (num_bboxes, ).
+            cfg (mmengine.Config): Test / postprocessing configuration,
+                if None, test_cfg would be used.
+            rescale (bool): If True, return boxes in original image space.
+                Defaults to False.
+            with_nms (bool): If True, do nms before return boxes.
+                Defaults to True.
+            img_meta (dict, optional): Image meta info. Defaults to None.
+
+        Returns:
+            :obj:`InstanceData`: Detection results of each image
+            after the post process.
+            Each item usually contains following keys.
+
+            - scores (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 4),
+              the last dimension 4 arranged as (x1, y1, x2, y2).
+        """
+
+        if rescale:
+            assert img_meta.get('scale_factor') is not None
+            results.bboxes /= results.bboxes.new_tensor(
+                img_meta['scale_factor']).repeat((1, 2))
+
+        if with_nms and results.bboxes.numel() > 0:
+            det_bboxes, keep_idxs = batched_nms(results.bboxes, results.scores,
+                                                results.labels, cfg.nms)
+            results = results[keep_idxs]
+            # some nms would reweight the score, such as softnms
+            results.scores = det_bboxes[:, -1]
+        return results
+
+    def loss_by_feat(
+            self,
+            cls_scores: Sequence[Tensor],
+            bbox_preds: Sequence[Tensor],
+            objectnesses: Sequence[Tensor],
+            batch_gt_instances: Sequence[InstanceData],
+            batch_img_metas: Sequence[dict],
+            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
+        """Calculate the loss based on the features extracted by the detection
+        head.
+
+        Args:
+            cls_scores (Sequence[Tensor]): Box scores for each scale level,
+                each is a 4D-tensor, the channel number is
+                num_priors * num_classes.
+            bbox_preds (Sequence[Tensor]): Box energies / deltas for each
+                scale level, each is a 4D-tensor, the channel number is
+                num_priors * 4.
+            objectnesses (Sequence[Tensor]): Score factor for
+                all scale level, each is a 4D-tensor, has shape
+                (batch_size, 1, H, W).
+            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+                gt_instance. It usually includes ``bboxes`` and ``labels``
+                attributes.
+            batch_img_metas (list[dict]): Meta information of each image,
+                e.g., image size, scaling factor, etc.
+            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
+                Batch of gt_instances_ignore.
It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + Returns: + dict[str, Tensor]: A dictionary of losses. + """ + num_imgs = len(batch_img_metas) + if batch_gt_instances_ignore is None: + batch_gt_instances_ignore = [None] * num_imgs + + featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores] + mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, + dtype=cls_scores[0].dtype, + device=cls_scores[0].device, + with_stride=True) + + flatten_cls_preds = [ + cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, + self.cls_out_channels) + for cls_pred in cls_scores + ] + flatten_bbox_preds = [ + bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) + for bbox_pred in bbox_preds + ] + flatten_objectness = [ + objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1) + for objectness in objectnesses + ] + + flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1) + flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1) + flatten_objectness = torch.cat(flatten_objectness, dim=1) + flatten_priors = torch.cat(mlvl_priors) + flatten_bboxes = self._bbox_decode(flatten_priors, flatten_bbox_preds) + + (pos_masks, cls_targets, obj_targets, bbox_targets, l1_targets, + num_fg_imgs) = multi_apply( + self._get_targets_single, + flatten_priors.unsqueeze(0).repeat(num_imgs, 1, 1), + flatten_cls_preds.detach(), flatten_bboxes.detach(), + flatten_objectness.detach(), batch_gt_instances, batch_img_metas, + batch_gt_instances_ignore) + + # The experimental results show that 'reduce_mean' can improve + # performance on the COCO dataset. + num_pos = torch.tensor( + sum(num_fg_imgs), + dtype=torch.float, + device=flatten_cls_preds.device) + num_total_samples = max(reduce_mean(num_pos), 1.0) + + pos_masks = torch.cat(pos_masks, 0) + cls_targets = torch.cat(cls_targets, 0) + obj_targets = torch.cat(obj_targets, 0) + bbox_targets = torch.cat(bbox_targets, 0) + if self.use_l1: + l1_targets = torch.cat(l1_targets, 0) + + loss_obj = self.loss_obj(flatten_objectness.view(-1, 1), + obj_targets) / num_total_samples + if num_pos > 0: + loss_cls = self.loss_cls( + flatten_cls_preds.view(-1, self.num_classes)[pos_masks], + cls_targets) / num_total_samples + loss_bbox = self.loss_bbox( + flatten_bboxes.view(-1, 4)[pos_masks], + bbox_targets) / num_total_samples + else: + # Avoid cls and reg branch not participating in the gradient + # propagation when there is no ground-truth in the images. + # For more details, please refer to + # https://github.com/open-mmlab/mmdetection/issues/7298 + loss_cls = flatten_cls_preds.sum() * 0 + loss_bbox = flatten_bboxes.sum() * 0 + + loss_dict = dict( + loss_cls=loss_cls, loss_bbox=loss_bbox, loss_obj=loss_obj) + + if self.use_l1: + if num_pos > 0: + loss_l1 = self.loss_l1( + flatten_bbox_preds.view(-1, 4)[pos_masks], + l1_targets) / num_total_samples + else: + # Avoid cls and reg branch not participating in the gradient + # propagation when there is no ground-truth in the images. + # For more details, please refer to + # https://github.com/open-mmlab/mmdetection/issues/7298 + loss_l1 = flatten_bbox_preds.sum() * 0 + loss_dict.update(loss_l1=loss_l1) + + return loss_dict + + @torch.no_grad() + def _get_targets_single( + self, + priors: Tensor, + cls_preds: Tensor, + decoded_bboxes: Tensor, + objectness: Tensor, + gt_instances: InstanceData, + img_meta: dict, + gt_instances_ignore: Optional[InstanceData] = None) -> tuple: + """Compute classification, regression, and objectness targets for + priors in a single image. 
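+
+        (This method runs under ``torch.no_grad()``: the label assignment,
+        a SimOTA-style assigner in the typical YOLOX config, is treated as
+        a fixed target, so no gradients flow through it.)
+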
+        Args:
+            priors (Tensor): All priors of one image, a 2D-Tensor with shape
+                [num_priors, 4] in [cx, cy, stride_w, stride_h] format.
+            cls_preds (Tensor): Classification predictions of one image,
+                a 2D-Tensor with shape [num_priors, num_classes].
+            decoded_bboxes (Tensor): Decoded bboxes predictions of one image,
+                a 2D-Tensor with shape [num_priors, 4] in [tl_x, tl_y,
+                br_x, br_y] format.
+            objectness (Tensor): Objectness predictions of one image,
+                a 1D-Tensor with shape [num_priors].
+            gt_instances (:obj:`InstanceData`): Ground truth of instance
+                annotations. It should include ``bboxes`` and ``labels``
+                attributes.
+            img_meta (dict): Meta information for current image.
+            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
+                to be ignored during training. It includes ``bboxes``
+                attribute data that is ignored during training and testing.
+                Defaults to None.
+        Returns:
+            tuple:
+                foreground_mask (list[Tensor]): Binary mask of foreground
+                    targets.
+                cls_target (list[Tensor]): Classification targets of an image.
+                obj_target (list[Tensor]): Objectness targets of an image.
+                bbox_target (list[Tensor]): BBox targets of an image.
+                l1_target (Tensor): BBox L1 targets of an image.
+                num_pos_per_img (int): Number of positive samples in an image.
+        """
+
+        num_priors = priors.size(0)
+        num_gts = len(gt_instances)
+        # No target
+        if num_gts == 0:
+            cls_target = cls_preds.new_zeros((0, self.num_classes))
+            bbox_target = cls_preds.new_zeros((0, 4))
+            l1_target = cls_preds.new_zeros((0, 4))
+            obj_target = cls_preds.new_zeros((num_priors, 1))
+            foreground_mask = cls_preds.new_zeros(num_priors).bool()
+            return (foreground_mask, cls_target, obj_target, bbox_target,
+                    l1_target, 0)
+
+        # YOLOX uses center priors with 0.5 offset to assign targets,
+        # but uses center priors without offset to regress bboxes.
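+        # For example (illustrative numbers): a stride-8 prior at grid cell
+        # (i, j) is (8*i, 8*j, 8, 8); its offset prior used for assignment
+        # is the cell centre (8*i + 4, 8*j + 4, 8, 8).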
+ offset_priors = torch.cat( + [priors[:, :2] + priors[:, 2:] * 0.5, priors[:, 2:]], dim=-1) + + scores = cls_preds.sigmoid() * objectness.unsqueeze(1).sigmoid() + pred_instances = InstanceData( + bboxes=decoded_bboxes, scores=scores.sqrt_(), priors=offset_priors) + assign_result = self.assigner.assign( + pred_instances=pred_instances, + gt_instances=gt_instances, + gt_instances_ignore=gt_instances_ignore) + + sampling_result = self.sampler.sample(assign_result, pred_instances, + gt_instances) + pos_inds = sampling_result.pos_inds + num_pos_per_img = pos_inds.size(0) + + pos_ious = assign_result.max_overlaps[pos_inds] + # IOU aware classification score + cls_target = F.one_hot(sampling_result.pos_gt_labels, + self.num_classes) * pos_ious.unsqueeze(-1) + obj_target = torch.zeros_like(objectness).unsqueeze(-1) + obj_target[pos_inds] = 1 + bbox_target = sampling_result.pos_gt_bboxes + l1_target = cls_preds.new_zeros((num_pos_per_img, 4)) + if self.use_l1: + l1_target = self._get_l1_target(l1_target, bbox_target, + priors[pos_inds]) + foreground_mask = torch.zeros_like(objectness).to(torch.bool) + foreground_mask[pos_inds] = 1 + return (foreground_mask, cls_target, obj_target, bbox_target, + l1_target, num_pos_per_img) + + def _get_l1_target(self, + l1_target: Tensor, + gt_bboxes: Tensor, + priors: Tensor, + eps: float = 1e-8) -> Tensor: + """Convert gt bboxes to center offset and log width height.""" + gt_cxcywh = bbox_xyxy_to_cxcywh(gt_bboxes) + l1_target[:, :2] = (gt_cxcywh[:, :2] - priors[:, :2]) / priors[:, 2:] + l1_target[:, 2:] = torch.log(gt_cxcywh[:, 2:] / priors[:, 2:] + eps) + return l1_target diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5a06d2813c810504e12592506be9347111d6696 --- /dev/null +++ b/mmdet/models/detectors/__init__.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .atss import ATSS
+from .autoassign import AutoAssign
+from .base import BaseDetector
+from .base_detr import DetectionTransformer
+from .boxinst import BoxInst
+from .cascade_rcnn import CascadeRCNN
+from .centernet import CenterNet
+from .condinst import CondInst
+from .conditional_detr import ConditionalDETR
+from .cornernet import CornerNet
+from .crowddet import CrowdDet
+from .d2_wrapper import Detectron2Wrapper
+from .dab_detr import DABDETR
+from .ddod import DDOD
+from .ddq_detr import DDQDETR
+from .deformable_detr import DeformableDETR
+from .detr import DETR
+from .dino import DINO
+from .fast_rcnn import FastRCNN
+from .faster_rcnn import FasterRCNN
+from .fcos import FCOS
+from .fovea import FOVEA
+from .fsaf import FSAF
+from .gfl import GFL
+from .glip import GLIP
+from .grid_rcnn import GridRCNN
+from .grounding_dino import GroundingDINO
+from .htc import HybridTaskCascade
+from .kd_one_stage import KnowledgeDistillationSingleStageDetector
+from .lad import LAD
+from .mask2former import Mask2Former
+from .mask_rcnn import MaskRCNN
+from .mask_scoring_rcnn import MaskScoringRCNN
+from .maskformer import MaskFormer
+from .nasfcos import NASFCOS
+from .paa import PAA
+from .panoptic_fpn import PanopticFPN
+from .panoptic_two_stage_segmentor import TwoStagePanopticSegmentor
+from .point_rend import PointRend
+from .queryinst import QueryInst
+from .reppoints_detector import RepPointsDetector
+from .retinanet import RetinaNet
+from .rpn import RPN
+from .rtmdet import RTMDet
+from .scnet import SCNet
+from .semi_base import SemiBaseDetector
+from .single_stage import SingleStageDetector
+from .soft_teacher import SoftTeacher
+from .solo import SOLO
+from .solov2 import SOLOv2
+from .sparse_rcnn import SparseRCNN
+from .tood import TOOD
+from .trident_faster_rcnn import TridentFasterRCNN
+from .two_stage import TwoStageDetector
+from .vfnet import VFNet
+from .yolact import YOLACT
+from .yolo import YOLOV3
+from .yolof import YOLOF
+from .yolox import YOLOX
+
+__all__ = [
+    'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
+    'KnowledgeDistillationSingleStageDetector', 'FastRCNN', 'FasterRCNN',
+    'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade', 'RetinaNet', 'FCOS',
+    'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector', 'FOVEA', 'FSAF',
+    'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA', 'YOLOV3', 'YOLACT',
+    'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN', 'SCNet', 'SOLO',
+    'SOLOv2', 'DeformableDETR', 'AutoAssign', 'YOLOF', 'CenterNet', 'YOLOX',
+    'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD', 'TOOD',
+    'MaskFormer', 'DDOD', 'Mask2Former', 'SemiBaseDetector', 'SoftTeacher',
+    'RTMDet', 'Detectron2Wrapper', 'CrowdDet', 'CondInst', 'BoxInst',
+    'DetectionTransformer', 'ConditionalDETR', 'DINO', 'DABDETR', 'GLIP',
+    'DDQDETR', 'GroundingDINO'
+]
diff --git a/mmdet/models/detectors/atss.py b/mmdet/models/detectors/atss.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bfcc728dc4cc33c0b705a2ab22a4e3f4ad7386d
--- /dev/null
+++ b/mmdet/models/detectors/atss.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmdet.registry import MODELS
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+from .single_stage import SingleStageDetector
+
+
+@MODELS.register_module()
+class ATSS(SingleStageDetector):
+    """Implementation of `ATSS <https://arxiv.org/abs/1912.02424>`_
+
+    Args:
+        backbone (:obj:`ConfigDict` or dict): The backbone module.
+        neck (:obj:`ConfigDict` or dict): The neck module.
+        bbox_head (:obj:`ConfigDict` or dict): The bbox head module.
+        train_cfg (:obj:`ConfigDict` or dict, optional): The training config
+            of ATSS. Defaults to None.
+        test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
+            of ATSS. Defaults to None.
+        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
+            :class:`DetDataPreprocessor` to process the input data.
+            Defaults to None.
+        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
+            the initialization. Defaults to None.
+    """
+
+    def __init__(self,
+                 backbone: ConfigType,
+                 neck: ConfigType,
+                 bbox_head: ConfigType,
+                 train_cfg: OptConfigType = None,
+                 test_cfg: OptConfigType = None,
+                 data_preprocessor: OptConfigType = None,
+                 init_cfg: OptMultiConfig = None) -> None:
+        super().__init__(
+            backbone=backbone,
+            neck=neck,
+            bbox_head=bbox_head,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            data_preprocessor=data_preprocessor,
+            init_cfg=init_cfg)
diff --git a/mmdet/models/detectors/autoassign.py b/mmdet/models/detectors/autoassign.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b3570fe6e0c3812a72bc677038bb4e76b05576
--- /dev/null
+++ b/mmdet/models/detectors/autoassign.py
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmdet.registry import MODELS
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+from .single_stage import SingleStageDetector
+
+
+@MODELS.register_module()
+class AutoAssign(SingleStageDetector):
+    """Implementation of `AutoAssign: Differentiable Label Assignment for
+    Dense Object Detection <https://arxiv.org/abs/2007.03496>`_
+
+    Args:
+        backbone (:obj:`ConfigDict` or dict): The backbone config.
+        neck (:obj:`ConfigDict` or dict): The neck config.
+        bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
+        train_cfg (:obj:`ConfigDict` or dict, optional): The training config
+            of AutoAssign. Defaults to None.
+        test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
+            of AutoAssign. Defaults to None.
+        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
+            :class:`DetDataPreprocessor` to process the input data.
+            Defaults to None.
+        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
+            list[dict], optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 backbone: ConfigType,
+                 neck: ConfigType,
+                 bbox_head: ConfigType,
+                 train_cfg: OptConfigType = None,
+                 test_cfg: OptConfigType = None,
+                 data_preprocessor: OptConfigType = None,
+                 init_cfg: OptMultiConfig = None):
+        super().__init__(
+            backbone=backbone,
+            neck=neck,
+            bbox_head=bbox_head,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            data_preprocessor=data_preprocessor,
+            init_cfg=init_cfg)
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a193b0ca9ca3d2b42fda452004d5c97421f426c
--- /dev/null
+++ b/mmdet/models/detectors/base.py
@@ -0,0 +1,156 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod +from typing import Dict, List, Tuple, Union + +import torch +from mmengine.model import BaseModel +from torch import Tensor + +from mmdet.structures import DetDataSample, OptSampleList, SampleList +from mmdet.utils import InstanceList, OptConfigType, OptMultiConfig +from ..utils import samplelist_boxtype2tensor + +ForwardResults = Union[Dict[str, torch.Tensor], List[DetDataSample], + Tuple[torch.Tensor], torch.Tensor] + + +class BaseDetector(BaseModel, metaclass=ABCMeta): + """Base class for detectors. + + Args: + data_preprocessor (dict or ConfigDict, optional): The pre-process + config of :class:`BaseDataPreprocessor`. it usually includes, + ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. + init_cfg (dict or ConfigDict, optional): the config to control the + initialization. Defaults to None. + """ + + def __init__(self, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + @property + def with_neck(self) -> bool: + """bool: whether the detector has a neck""" + return hasattr(self, 'neck') and self.neck is not None + + # TODO: these properties need to be carefully handled + # for both single stage & two stage detectors + @property + def with_shared_head(self) -> bool: + """bool: whether the detector has a shared head in the RoI Head""" + return hasattr(self, 'roi_head') and self.roi_head.with_shared_head + + @property + def with_bbox(self) -> bool: + """bool: whether the detector has a bbox head""" + return ((hasattr(self, 'roi_head') and self.roi_head.with_bbox) + or (hasattr(self, 'bbox_head') and self.bbox_head is not None)) + + @property + def with_mask(self) -> bool: + """bool: whether the detector has a mask head""" + return ((hasattr(self, 'roi_head') and self.roi_head.with_mask) + or (hasattr(self, 'mask_head') and self.mask_head is not None)) + + def forward(self, + inputs: torch.Tensor, + data_samples: OptSampleList = None, + mode: str = 'tensor') -> ForwardResults: + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor or tuple of + tensor without any post-processing, same as a common nn.Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DetDataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle either back propagation or + parameter update, which are supposed to be done in :meth:`train_step`. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (list[:obj:`DetDataSample`], optional): A batch of + data samples that contain annotations and predictions. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="predict"``, return a list of :obj:`DetDataSample`. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + return self.predict(inputs, data_samples) + elif mode == 'tensor': + return self._forward(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}". 
' + 'Only supports loss, predict and tensor mode') + + @abstractmethod + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> Union[dict, tuple]: + """Calculate losses from a batch of inputs and data samples.""" + pass + + @abstractmethod + def predict(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing.""" + pass + + @abstractmethod + def _forward(self, + batch_inputs: Tensor, + batch_data_samples: OptSampleList = None): + """Network forward process. + + Usually includes backbone, neck and head forward without any post- + processing. + """ + pass + + @abstractmethod + def extract_feat(self, batch_inputs: Tensor): + """Extract features from images.""" + pass + + def add_pred_to_datasample(self, data_samples: SampleList, + results_list: InstanceList) -> SampleList: + """Add predictions to `DetDataSample`. + + Args: + data_samples (list[:obj:`DetDataSample`], optional): A batch of + data samples that contain annotations and predictions. + results_list (list[:obj:`InstanceData`]): Detection results of + each image. + + Returns: + list[:obj:`DetDataSample`]: Detection results of the + input images. Each DetDataSample usually contain + 'pred_instances'. And the ``pred_instances`` usually + contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + for data_sample, pred_instances in zip(data_samples, results_list): + data_sample.pred_instances = pred_instances + samplelist_boxtype2tensor(data_samples) + return data_samples diff --git a/mmdet/models/detectors/base_detr.py b/mmdet/models/detectors/base_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..88f00ec7408c389a1eb06beac6b383007f80b893 --- /dev/null +++ b/mmdet/models/detectors/base_detr.py @@ -0,0 +1,332 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import Dict, List, Tuple, Union + +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import OptSampleList, SampleList +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .base import BaseDetector + + +@MODELS.register_module() +class DetectionTransformer(BaseDetector, metaclass=ABCMeta): + r"""Base class for Detection Transformer. + + In Detection Transformer, an encoder is used to process output features of + neck, then several queries interact with the encoder features using a + decoder and do the regression and classification with the bounding box + head. + + Args: + backbone (:obj:`ConfigDict` or dict): Config of the backbone. + neck (:obj:`ConfigDict` or dict, optional): Config of the neck. + Defaults to None. + encoder (:obj:`ConfigDict` or dict, optional): Config of the + Transformer encoder. Defaults to None. + decoder (:obj:`ConfigDict` or dict, optional): Config of the + Transformer decoder. Defaults to None. + bbox_head (:obj:`ConfigDict` or dict, optional): Config for the + bounding box head module. Defaults to None. + positional_encoding (:obj:`ConfigDict` or dict, optional): Config + of the positional encoding module. Defaults to None. + num_queries (int, optional): Number of decoder query in Transformer. + Defaults to 100. 
+ train_cfg (:obj:`ConfigDict` or dict, optional): Training config of + the bounding box head module. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of + the bounding box head module. Defaults to None. + data_preprocessor (dict or ConfigDict, optional): The pre-process + config of :class:`BaseDataPreprocessor`. it usually includes, + ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. + Defaults to None. + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. + """ + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + encoder: OptConfigType = None, + decoder: OptConfigType = None, + bbox_head: OptConfigType = None, + positional_encoding: OptConfigType = None, + num_queries: int = 100, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + # process args + bbox_head.update(train_cfg=train_cfg) + bbox_head.update(test_cfg=test_cfg) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.encoder = encoder + self.decoder = decoder + self.positional_encoding = positional_encoding + self.num_queries = num_queries + + # init model layers + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + self.bbox_head = MODELS.build(bbox_head) + self._init_layers() + + @abstractmethod + def _init_layers(self) -> None: + """Initialize layers except for backbone, neck and bbox_head.""" + pass + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> Union[dict, list]: + """Calculate losses from a batch of inputs and data samples. + + Args: + batch_inputs (Tensor): Input images of shape (bs, dim, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components + """ + img_feats = self.extract_feat(batch_inputs) + head_inputs_dict = self.forward_transformer(img_feats, + batch_data_samples) + losses = self.bbox_head.loss( + **head_inputs_dict, batch_data_samples=batch_data_samples) + + return losses + + def predict(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W). + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + rescale (bool): Whether to rescale the results. + Defaults to True. + + Returns: + list[:obj:`DetDataSample`]: Detection results of the input images. + Each DetDataSample usually contain 'pred_instances'. And the + `pred_instances` usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
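+
+        Example:
+            A hedged usage sketch (``detector`` and ``batch`` are assumed to
+            be a built DETR-like model and a preprocessed data batch):
+
+            >>> # results = detector.predict(batch['inputs'],
+            >>> #                            batch['data_samples'])
+            >>> # results[0].pred_instances.bboxes.shape  # (num_instances, 4)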
+ """ + img_feats = self.extract_feat(batch_inputs) + head_inputs_dict = self.forward_transformer(img_feats, + batch_data_samples) + results_list = self.bbox_head.predict( + **head_inputs_dict, + rescale=rescale, + batch_data_samples=batch_data_samples) + batch_data_samples = self.add_pred_to_datasample( + batch_data_samples, results_list) + return batch_data_samples + + def _forward( + self, + batch_inputs: Tensor, + batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]: + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. + + Args: + batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W). + batch_data_samples (List[:obj:`DetDataSample`], optional): The + batch data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + tuple[Tensor]: A tuple of features from ``bbox_head`` forward. + """ + img_feats = self.extract_feat(batch_inputs) + head_inputs_dict = self.forward_transformer(img_feats, + batch_data_samples) + results = self.bbox_head.forward(**head_inputs_dict) + return results + + def forward_transformer(self, + img_feats: Tuple[Tensor], + batch_data_samples: OptSampleList = None) -> Dict: + """Forward process of Transformer, which includes four steps: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'. We + summarized the parameters flow of the existing DETR-like detector, + which can be illustrated as follow: + + .. code:: text + + img_feats & batch_data_samples + | + V + +-----------------+ + | pre_transformer | + +-----------------+ + | | + | V + | +-----------------+ + | | forward_encoder | + | +-----------------+ + | | + | V + | +---------------+ + | | pre_decoder | + | +---------------+ + | | | + V V | + +-----------------+ | + | forward_decoder | | + +-----------------+ | + | | + V V + head_inputs_dict + + Args: + img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each + feature map has shape (bs, dim, H, W). + batch_data_samples (list[:obj:`DetDataSample`], optional): The + batch data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + dict: The dictionary of bbox_head function inputs, which always + includes the `hidden_states` of the decoder output and may contain + `references` including the initial and intermediate references. + """ + encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer( + img_feats, batch_data_samples) + + encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict) + + tmp_dec_in, head_inputs_dict = self.pre_decoder(**encoder_outputs_dict) + decoder_inputs_dict.update(tmp_dec_in) + + decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict) + head_inputs_dict.update(decoder_outputs_dict) + return head_inputs_dict + + def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]: + """Extract features. + + Args: + batch_inputs (Tensor): Image tensor, has shape (bs, dim, H, W). + + Returns: + tuple[Tensor]: Tuple of feature maps from neck. Each feature map + has shape (bs, dim, H, W). + """ + x = self.backbone(batch_inputs) + if self.with_neck: + x = self.neck(x) + return x + + @abstractmethod + def pre_transformer( + self, + img_feats: Tuple[Tensor], + batch_data_samples: OptSampleList = None) -> Tuple[Dict, Dict]: + """Process image features before feeding them to the transformer. + + Args: + img_feats (tuple[Tensor]): Tuple of feature maps from neck. 
Each
+                feature map has shape (bs, dim, H, W).
+            batch_data_samples (list[:obj:`DetDataSample`], optional): The
+                batch data samples. It usually includes information such
+                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+                Defaults to None.
+
+        Returns:
+            tuple[dict, dict]: The first dict contains the inputs of the
+            encoder and the second dict contains the inputs of the decoder.
+
+            - encoder_inputs_dict (dict): The keyword args dictionary of
+              `self.forward_encoder()`, which includes 'feat', 'feat_mask',
+              'feat_pos', and other algorithm-specific arguments.
+            - decoder_inputs_dict (dict): The keyword args dictionary of
+              `self.forward_decoder()`, which includes 'memory_mask', and
+              other algorithm-specific arguments.
+        """
+        pass
+
+    @abstractmethod
+    def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
+                        feat_pos: Tensor, **kwargs) -> Dict:
+        """Forward with Transformer encoder.
+
+        Args:
+            feat (Tensor): Sequential features, has shape
+                (bs, num_feat_points, dim).
+            feat_mask (Tensor): ByteTensor, the padding mask of the features,
+                has shape (bs, num_feat_points).
+            feat_pos (Tensor): The positional embeddings of the features, has
+                shape (bs, num_feat_points, dim).
+
+        Returns:
+            dict: The dictionary of encoder outputs, which includes the
+            `memory` of the encoder output and other algorithm-specific
+            arguments.
+        """
+        pass
+
+    @abstractmethod
+    def pre_decoder(self, memory: Tensor, **kwargs) -> Tuple[Dict, Dict]:
+        """Prepare intermediate variables before entering Transformer decoder,
+        such as `query`, `query_pos`, and `reference_points`.
+
+        Args:
+            memory (Tensor): The output embeddings of the Transformer encoder,
+                has shape (bs, num_feat_points, dim).
+
+        Returns:
+            tuple[dict, dict]: The first dict contains the inputs of the
+            decoder and the second dict contains the inputs of the bbox_head
+            function.
+
+            - decoder_inputs_dict (dict): The keyword args dictionary of
+              `self.forward_decoder()`, which includes 'query', 'query_pos',
+              'memory', and other algorithm-specific arguments.
+            - head_inputs_dict (dict): The keyword args dictionary of the
+              bbox_head functions, which is usually empty, or includes
+              `enc_outputs_class` and `enc_outputs_coord` when the detector
+              supports 'two stage' or 'query selection' strategies.
+        """
+        pass
+
+    @abstractmethod
+    def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
+                        **kwargs) -> Dict:
+        """Forward with Transformer decoder.
+
+        Args:
+            query (Tensor): The queries of decoder inputs, has shape
+                (bs, num_queries, dim).
+            query_pos (Tensor): The positional queries of decoder inputs,
+                has shape (bs, num_queries, dim).
+            memory (Tensor): The output embeddings of the Transformer encoder,
+                has shape (bs, num_feat_points, dim).
+
+        Returns:
+            dict: The dictionary of decoder outputs, which includes the
+            `hidden_states` of the decoder output, `references` including
+            the initial and intermediate reference_points, and other
+            algorithm-specific arguments.
+        """
+        pass
diff --git a/mmdet/models/detectors/boxinst.py b/mmdet/models/detectors/boxinst.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca6b0bdd90a2a7e78f429a6822dbde6f809426da
--- /dev/null
+++ b/mmdet/models/detectors/boxinst.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
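+# NOTE: BoxInst and the sibling detector classes in the following files are
+# thin wrappers: they only forward constructor arguments to a shared base
+# class, and concrete behavior comes from the sub-modules chosen through the
+# MODELS registry. A minimal build sketch (the sub-module configs here are
+# hypothetical placeholders for illustration, not values shipped in this
+# diff):
+#
+#     from mmdet.registry import MODELS
+#     detector = MODELS.build(
+#         dict(type='BoxInst',
+#              backbone=dict(type='ResNet', depth=50),
+#              neck=dict(type='FPN', in_channels=[256, 512, 1024, 2048],
+#                        out_channels=256, num_outs=5),
+#              bbox_head=dict(type='BoxInstBboxHead', num_classes=80),
+#              mask_head=dict(type='BoxInstMaskHead')))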
+from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage_instance_seg import SingleStageInstanceSegmentor + + +@MODELS.register_module() +class BoxInst(SingleStageInstanceSegmentor): + """Implementation of `BoxInst `_""" + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + mask_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + mask_head=mask_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..ecf733ff104b99436fcc74130b0ccea12a0fa6d0 --- /dev/null +++ b/mmdet/models/detectors/cascade_rcnn.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .two_stage import TwoStageDetector + + +@MODELS.register_module() +class CascadeRCNN(TwoStageDetector): + r"""Implementation of `Cascade R-CNN: Delving into High Quality Object + Detection `_""" + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + rpn_head: OptConfigType = None, + roi_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + rpn_head=rpn_head, + roi_head=roi_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/centernet.py b/mmdet/models/detectors/centernet.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6622d6280227ecba9ede4aabf72c22a764e11d --- /dev/null +++ b/mmdet/models/detectors/centernet.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage import SingleStageDetector + + +@MODELS.register_module() +class CenterNet(SingleStageDetector): + """Implementation of CenterNet(Objects as Points) + + . + """ + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/condinst.py b/mmdet/models/detectors/condinst.py new file mode 100644 index 0000000000000000000000000000000000000000..ed2dc99eea3faf7b03a3970d46a372d28eb89fe1 --- /dev/null +++ b/mmdet/models/detectors/condinst.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage_instance_seg import SingleStageInstanceSegmentor + + +@MODELS.register_module() +class CondInst(SingleStageInstanceSegmentor): + """Implementation of `CondInst `_""" + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + mask_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + mask_head=mask_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/conditional_detr.py b/mmdet/models/detectors/conditional_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..d57868e63a2ece085a7e5b67ee93c921ba334830 --- /dev/null +++ b/mmdet/models/detectors/conditional_detr.py @@ -0,0 +1,74 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict + +import torch.nn as nn +from torch import Tensor + +from mmdet.registry import MODELS +from ..layers import (ConditionalDetrTransformerDecoder, + DetrTransformerEncoder, SinePositionalEncoding) +from .detr import DETR + + +@MODELS.register_module() +class ConditionalDETR(DETR): + r"""Implementation of `Conditional DETR for Fast Training Convergence. + + `_. + + Code is modified from the `official github repo + `_. + """ + + def _init_layers(self) -> None: + """Initialize layers except for backbone, neck and bbox_head.""" + self.positional_encoding = SinePositionalEncoding( + **self.positional_encoding) + self.encoder = DetrTransformerEncoder(**self.encoder) + self.decoder = ConditionalDetrTransformerDecoder(**self.decoder) + self.embed_dims = self.encoder.embed_dims + # NOTE The embed_dims is typically passed from the inside out. + # For example in DETR, The embed_dims is passed as + # self_attn -> the first encoder layer -> encoder -> detector. + self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims) + + num_feats = self.positional_encoding.num_feats + assert num_feats * 2 == self.embed_dims, \ + f'embed_dims should be exactly 2 times of num_feats. ' \ + f'Found {self.embed_dims} and {num_feats}.' + + def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor, + memory_mask: Tensor, memory_pos: Tensor) -> Dict: + """Forward with Transformer decoder. + + Args: + query (Tensor): The queries of decoder inputs, has shape + (bs, num_queries, dim). + query_pos (Tensor): The positional queries of decoder inputs, + has shape (bs, num_queries, dim). + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat_points). + memory_pos (Tensor): The positional embeddings of memory, has + shape (bs, num_feat_points, dim). + + Returns: + dict: The dictionary of decoder outputs, which includes the + `hidden_states` and `references` of the decoder output. 
+ + - hidden_states (Tensor): Has shape + (num_decoder_layers, bs, num_queries, dim) + - references (Tensor): Has shape + (bs, num_queries, 2) + """ + + hidden_states, references = self.decoder( + query=query, + key=memory, + query_pos=query_pos, + key_pos=memory_pos, + key_padding_mask=memory_mask) + head_inputs_dict = dict( + hidden_states=hidden_states, references=references) + return head_inputs_dict diff --git a/mmdet/models/detectors/cornernet.py b/mmdet/models/detectors/cornernet.py new file mode 100644 index 0000000000000000000000000000000000000000..946af4dbe6ae339d44f8db265ff7f11b9e02d239 --- /dev/null +++ b/mmdet/models/detectors/cornernet.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage import SingleStageDetector + + +@MODELS.register_module() +class CornerNet(SingleStageDetector): + """CornerNet. + + This detector is the implementation of the paper `CornerNet: Detecting + Objects as Paired Keypoints `_ . + """ + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/crowddet.py b/mmdet/models/detectors/crowddet.py new file mode 100644 index 0000000000000000000000000000000000000000..4f43bc08aa95756324381ee4182f001a008613c8 --- /dev/null +++ b/mmdet/models/detectors/crowddet.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .two_stage import TwoStageDetector + + +@MODELS.register_module() +class CrowdDet(TwoStageDetector): + """Implementation of `CrowdDet `_ + + Args: + backbone (:obj:`ConfigDict` or dict): The backbone config. + rpn_head (:obj:`ConfigDict` or dict): The rpn config. + roi_head (:obj:`ConfigDict` or dict): The roi config. + train_cfg (:obj:`ConfigDict` or dict, optional): The training config + of FCOS. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): The testing config + of FCOS. Defaults to None. + neck (:obj:`ConfigDict` or dict): The neck config. + data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of + :class:`DetDataPreprocessor` to process the input data. + Defaults to None. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or + list[dict], optional): Initialization config dict. + Defaults to None. 
+ """ + + def __init__(self, + backbone: ConfigType, + rpn_head: ConfigType, + roi_head: ConfigType, + train_cfg: ConfigType, + test_cfg: ConfigType, + neck: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + rpn_head=rpn_head, + roi_head=roi_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + init_cfg=init_cfg, + data_preprocessor=data_preprocessor) diff --git a/mmdet/models/detectors/d2_wrapper.py b/mmdet/models/detectors/d2_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..3a2daa413e8fe0397ec37008d781ce449e7a26fd --- /dev/null +++ b/mmdet/models/detectors/d2_wrapper.py @@ -0,0 +1,291 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +from mmengine.config import ConfigDict +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.structures.bbox import BaseBoxes +from mmdet.structures.mask import BitmapMasks, PolygonMasks +from mmdet.utils import ConfigType +from .base import BaseDetector + +try: + import detectron2 + from detectron2.config import get_cfg + from detectron2.modeling import build_model + from detectron2.structures.masks import BitMasks as D2_BitMasks + from detectron2.structures.masks import PolygonMasks as D2_PolygonMasks + from detectron2.utils.events import EventStorage +except ImportError: + detectron2 = None + + +def _to_cfgnode_list(cfg: ConfigType, + config_list: list = [], + father_name: str = 'MODEL') -> tuple: + """Convert the key and value of mmengine.ConfigDict into a list. + + Args: + cfg (ConfigDict): The detectron2 model config. + config_list (list): A list contains the key and value of ConfigDict. + Defaults to []. + father_name (str): The father name add before the key. + Defaults to "MODEL". + + Returns: + tuple: + + - config_list: A list contains the key and value of ConfigDict. + - father_name (str): The father name add before the key. + Defaults to "MODEL". + """ + for key, value in cfg.items(): + name = f'{father_name}.{key.upper()}' + if isinstance(value, ConfigDict) or isinstance(value, dict): + config_list, fater_name = \ + _to_cfgnode_list(value, config_list, name) + else: + config_list.append(name) + config_list.append(value) + + return config_list, father_name + + +def convert_d2_pred_to_datasample(data_samples: SampleList, + d2_results_list: list) -> SampleList: + """Convert the Detectron2's result to DetDataSample. + + Args: + data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + d2_results_list (list): The list of the results of Detectron2's model. + + Returns: + list[:obj:`DetDataSample`]: Detection results of the + input images. Each DetDataSample usually contain + 'pred_instances'. And the ``pred_instances`` usually + contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
+ """ + assert len(data_samples) == len(d2_results_list) + for data_sample, d2_results in zip(data_samples, d2_results_list): + d2_instance = d2_results['instances'] + + results = InstanceData() + results.bboxes = d2_instance.pred_boxes.tensor + results.scores = d2_instance.scores + results.labels = d2_instance.pred_classes + + if d2_instance.has('pred_masks'): + results.masks = d2_instance.pred_masks + data_sample.pred_instances = results + + return data_samples + + +@MODELS.register_module() +class Detectron2Wrapper(BaseDetector): + """Wrapper of a Detectron2 model. Input/output formats of this class follow + MMDetection's convention, so a Detectron2 model can be trained and + evaluated in MMDetection. + + Args: + detector (:obj:`ConfigDict` or dict): The module config of + Detectron2. + bgr_to_rgb (bool): whether to convert image from BGR to RGB. + Defaults to False. + rgb_to_bgr (bool): whether to convert image from RGB to BGR. + Defaults to False. + """ + + def __init__(self, + detector: ConfigType, + bgr_to_rgb: bool = False, + rgb_to_bgr: bool = False) -> None: + if detectron2 is None: + raise ImportError('Please install Detectron2 first') + assert not (bgr_to_rgb and rgb_to_bgr), ( + '`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time') + super().__init__() + self._channel_conversion = rgb_to_bgr or bgr_to_rgb + cfgnode_list, _ = _to_cfgnode_list(detector) + self.cfg = get_cfg() + self.cfg.merge_from_list(cfgnode_list) + self.d2_model = build_model(self.cfg) + self.storage = EventStorage() + + def init_weights(self) -> None: + """Initialization Backbone. + + NOTE: The initialization of other layers are in Detectron2, + if users want to change the initialization way, please + change the code in Detectron2. + """ + from detectron2.checkpoint import DetectionCheckpointer + checkpointer = DetectionCheckpointer(model=self.d2_model) + checkpointer.load(self.cfg.MODEL.WEIGHTS, checkpointables=[]) + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> Union[dict, tuple]: + """Calculate losses from a batch of inputs and data samples. + + The inputs will first convert to the Detectron2 type and feed into + D2 models. + + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components. + """ + d2_batched_inputs = self._convert_to_d2_inputs( + batch_inputs=batch_inputs, + batch_data_samples=batch_data_samples, + training=True) + + with self.storage as storage: # noqa + losses = self.d2_model(d2_batched_inputs) + # storage contains some training information, such as cls_accuracy. + # you can use storage.latest() to get the detail information + return losses + + def predict(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + The inputs will first convert to the Detectron2 type and feed into + D2 models. And the results will convert back to the MMDet type. + + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. 
+ + + Returns: + list[:obj:`DetDataSample`]: Detection results of the + input images. Each DetDataSample usually contain + 'pred_instances'. And the ``pred_instances`` usually + contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + d2_batched_inputs = self._convert_to_d2_inputs( + batch_inputs=batch_inputs, + batch_data_samples=batch_data_samples, + training=False) + # results in detectron2 has already rescale + d2_results_list = self.d2_model(d2_batched_inputs) + batch_data_samples = convert_d2_pred_to_datasample( + data_samples=batch_data_samples, d2_results_list=d2_results_list) + + return batch_data_samples + + def _forward(self, *args, **kwargs): + """Network forward process. + + Usually includes backbone, neck and head forward without any post- + processing. + """ + raise NotImplementedError( + f'`_forward` is not implemented in {self.__class__.__name__}') + + def extract_feat(self, *args, **kwargs): + """Extract features from images. + + `extract_feat` will not be used in obj:``Detectron2Wrapper``. + """ + pass + + def _convert_to_d2_inputs(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + training=True) -> list: + """Convert inputs type to support Detectron2's model. + + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + training (bool): Whether to enable training time processing. + + Returns: + list[dict]: A list of dict, which will be fed into Detectron2's + model. And the dict usually contains following keys. + + - image (Tensor): Image in (C, H, W) format. + - instances (Instances): GT Instance. + - height (int): the output height resolution of the model + - width (int): the output width resolution of the model + """ + from detectron2.data.detection_utils import filter_empty_instances + from detectron2.structures import Boxes, Instances + + batched_d2_inputs = [] + for image, data_samples in zip(batch_inputs, batch_data_samples): + d2_inputs = dict() + # deal with metainfo + meta_info = data_samples.metainfo + d2_inputs['file_name'] = meta_info['img_path'] + d2_inputs['height'], d2_inputs['width'] = meta_info['ori_shape'] + d2_inputs['image_id'] = meta_info['img_id'] + # deal with image + if self._channel_conversion: + image = image[[2, 1, 0], ...] 
+ d2_inputs['image'] = image + # deal with gt_instances + gt_instances = data_samples.gt_instances + d2_instances = Instances(meta_info['img_shape']) + + gt_boxes = gt_instances.bboxes + # TODO: use mmdet.structures.box.get_box_tensor after PR 8658 + # has merged + if isinstance(gt_boxes, BaseBoxes): + gt_boxes = gt_boxes.tensor + d2_instances.gt_boxes = Boxes(gt_boxes) + + d2_instances.gt_classes = gt_instances.labels + if gt_instances.get('masks', None) is not None: + gt_masks = gt_instances.masks + if isinstance(gt_masks, PolygonMasks): + d2_instances.gt_masks = D2_PolygonMasks(gt_masks.masks) + elif isinstance(gt_masks, BitmapMasks): + d2_instances.gt_masks = D2_BitMasks(gt_masks.masks) + else: + raise TypeError('The type of `gt_mask` can be ' + '`PolygonMasks` or `BitMasks`, but get ' + f'{type(gt_masks)}.') + # convert to cpu and convert back to cuda to avoid + # some potential error + if training: + device = gt_boxes.device + d2_instances = filter_empty_instances( + d2_instances.to('cpu')).to(device) + d2_inputs['instances'] = d2_instances + batched_d2_inputs.append(d2_inputs) + + return batched_d2_inputs diff --git a/mmdet/models/detectors/dab_detr.py b/mmdet/models/detectors/dab_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..b61301cf6660924f0832f4068841a4664797c585 --- /dev/null +++ b/mmdet/models/detectors/dab_detr.py @@ -0,0 +1,139 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Tuple + +from mmengine.model import uniform_init +from torch import Tensor, nn + +from mmdet.registry import MODELS +from ..layers import SinePositionalEncoding +from ..layers.transformer import (DABDetrTransformerDecoder, + DABDetrTransformerEncoder, inverse_sigmoid) +from .detr import DETR + + +@MODELS.register_module() +class DABDETR(DETR): + r"""Implementation of `DAB-DETR: + Dynamic Anchor Boxes are Better Queries for DETR. + + `_. + + Code is modified from the `official github repo + `_. + + Args: + with_random_refpoints (bool): Whether to randomly initialize query + embeddings and not update them during training. + Defaults to False. + num_patterns (int): Inspired by Anchor-DETR. Defaults to 0. + """ + + def __init__(self, + *args, + with_random_refpoints: bool = False, + num_patterns: int = 0, + **kwargs) -> None: + self.with_random_refpoints = with_random_refpoints + assert isinstance(num_patterns, int), \ + f'num_patterns should be int but {num_patterns}.' + self.num_patterns = num_patterns + + super().__init__(*args, **kwargs) + + def _init_layers(self) -> None: + """Initialize layers except for backbone, neck and bbox_head.""" + self.positional_encoding = SinePositionalEncoding( + **self.positional_encoding) + self.encoder = DABDetrTransformerEncoder(**self.encoder) + self.decoder = DABDetrTransformerDecoder(**self.decoder) + self.embed_dims = self.encoder.embed_dims + self.query_dim = self.decoder.query_dim + self.query_embedding = nn.Embedding(self.num_queries, self.query_dim) + if self.num_patterns > 0: + self.patterns = nn.Embedding(self.num_patterns, self.embed_dims) + + num_feats = self.positional_encoding.num_feats + assert num_feats * 2 == self.embed_dims, \ + f'embed_dims should be exactly 2 times of num_feats. ' \ + f'Found {self.embed_dims} and {num_feats}.' 
+ + def init_weights(self) -> None: + """Initialize weights for Transformer and other components.""" + super(DABDETR, self).init_weights() + if self.with_random_refpoints: + uniform_init(self.query_embedding) + self.query_embedding.weight.data[:, :2] = \ + inverse_sigmoid(self.query_embedding.weight.data[:, :2]) + self.query_embedding.weight.data[:, :2].requires_grad = False + + def pre_decoder(self, memory: Tensor) -> Tuple[Dict, Dict]: + """Prepare intermediate variables before entering Transformer decoder, + such as `query`, `query_pos`. + + Args: + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + + Returns: + tuple[dict, dict]: The first dict contains the inputs of decoder + and the second dict contains the inputs of the bbox_head function. + + - decoder_inputs_dict (dict): The keyword args dictionary of + `self.forward_decoder()`, which includes 'query', 'query_pos', + 'memory' and 'reg_branches'. + - head_inputs_dict (dict): The keyword args dictionary of the + bbox_head functions, which is usually empty, or includes + `enc_outputs_class` and `enc_outputs_class` when the detector + support 'two stage' or 'query selection' strategies. + """ + batch_size = memory.size(0) + query_pos = self.query_embedding.weight + query_pos = query_pos.unsqueeze(0).repeat(batch_size, 1, 1) + if self.num_patterns == 0: + query = query_pos.new_zeros(batch_size, self.num_queries, + self.embed_dims) + else: + query = self.patterns.weight[:, None, None, :]\ + .repeat(1, self.num_queries, batch_size, 1)\ + .view(-1, batch_size, self.embed_dims)\ + .permute(1, 0, 2) + query_pos = query_pos.repeat(1, self.num_patterns, 1) + + decoder_inputs_dict = dict( + query_pos=query_pos, query=query, memory=memory) + head_inputs_dict = dict() + return decoder_inputs_dict, head_inputs_dict + + def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor, + memory_mask: Tensor, memory_pos: Tensor) -> Dict: + """Forward with Transformer decoder. + + Args: + query (Tensor): The queries of decoder inputs, has shape + (bs, num_queries, dim). + query_pos (Tensor): The positional queries of decoder inputs, + has shape (bs, num_queries, dim). + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat_points). + memory_pos (Tensor): The positional embeddings of memory, has + shape (bs, num_feat_points, dim). + + Returns: + dict: The dictionary of decoder outputs, which includes the + `hidden_states` and `references` of the decoder output. + """ + + hidden_states, references = self.decoder( + query=query, + key=memory, + query_pos=query_pos, + key_pos=memory_pos, + key_padding_mask=memory_mask, + reg_branches=self.bbox_head. + fc_reg # iterative refinement for anchor boxes + ) + head_inputs_dict = dict( + hidden_states=hidden_states, references=references) + return head_inputs_dict diff --git a/mmdet/models/detectors/ddod.py b/mmdet/models/detectors/ddod.py new file mode 100644 index 0000000000000000000000000000000000000000..3503a40c8eb6d6c0496ea0f31740acecf774113a --- /dev/null +++ b/mmdet/models/detectors/ddod.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmdet.registry import MODELS
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+from .single_stage import SingleStageDetector
+
+
+@MODELS.register_module()
+class DDOD(SingleStageDetector):
+    """Implementation of `DDOD <https://arxiv.org/abs/2107.02963>`_.
+
+    Args:
+        backbone (:obj:`ConfigDict` or dict): The backbone module.
+        neck (:obj:`ConfigDict` or dict): The neck module.
+        bbox_head (:obj:`ConfigDict` or dict): The bbox head module.
+        train_cfg (:obj:`ConfigDict` or dict, optional): The training config
+            of DDOD. Defaults to None.
+        test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
+            of DDOD. Defaults to None.
+        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
+            :class:`DetDataPreprocessor` to process the input data.
+            Defaults to None.
+        init_cfg (:obj:`ConfigDict` or dict, optional): The config to control
+            the initialization. Defaults to None.
+    """
+
+    def __init__(self,
+                 backbone: ConfigType,
+                 neck: ConfigType,
+                 bbox_head: ConfigType,
+                 train_cfg: OptConfigType = None,
+                 test_cfg: OptConfigType = None,
+                 data_preprocessor: OptConfigType = None,
+                 init_cfg: OptMultiConfig = None) -> None:
+        super().__init__(
+            backbone=backbone,
+            neck=neck,
+            bbox_head=bbox_head,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            data_preprocessor=data_preprocessor,
+            init_cfg=init_cfg)
diff --git a/mmdet/models/detectors/ddq_detr.py b/mmdet/models/detectors/ddq_detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..57d4959d50ddd7a761d5e5c7a29d1f7f233f838a
--- /dev/null
+++ b/mmdet/models/detectors/ddq_detr.py
@@ -0,0 +1,274 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Tuple
+
+import torch
+from mmcv.ops import MultiScaleDeformableAttention, batched_nms
+from torch import Tensor, nn
+from torch.nn.init import normal_
+
+from mmdet.registry import MODELS
+from mmdet.structures import OptSampleList
+from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
+from mmdet.utils import OptConfigType
+from ..layers import DDQTransformerDecoder
+from ..utils import align_tensor
+from .deformable_detr import DeformableDETR
+from .dino import DINO
+
+
+@MODELS.register_module()
+class DDQDETR(DINO):
+    r"""Implementation of `Dense Distinct Query for
+    End-to-End Object Detection <https://arxiv.org/abs/2303.12776>`_
+
+    Code is modified from the `official github repo
+    <https://github.com/jshilong/DDQ>`_.
+
+    Args:
+        dense_topk_ratio (float): Ratio of num_dense queries to num_queries.
+            Defaults to 1.5.
+        dqs_cfg (:obj:`ConfigDict` or dict, optional): Config of
+            Distinct Queries Selection. Defaults to nms with
+            `iou_threshold` = 0.8.
+ """ + + def __init__(self, + *args, + dense_topk_ratio: float = 1.5, + dqs_cfg: OptConfigType = dict(type='nms', iou_threshold=0.8), + **kwargs): + self.dense_topk_ratio = dense_topk_ratio + self.decoder_cfg = kwargs['decoder'] + self.dqs_cfg = dqs_cfg + super().__init__(*args, **kwargs) + + # a share dict in all moduls + # pass some intermediate results and config parameters + cache_dict = dict() + for m in self.modules(): + m.cache_dict = cache_dict + # first element is the start index of matching queries + # second element is the number of matching queries + self.cache_dict['dis_query_info'] = [0, 0] + + # mask for distinct queries in each decoder layer + self.cache_dict['distinct_query_mask'] = [] + # pass to decoder do the dqs + self.cache_dict['cls_branches'] = self.bbox_head.cls_branches + # Used to construct the attention mask after dqs + self.cache_dict['num_heads'] = self.encoder.layers[ + 0].self_attn.num_heads + # pass to decoder to do the dqs + self.cache_dict['dqs_cfg'] = self.dqs_cfg + + def _init_layers(self) -> None: + """Initialize layers except for backbone, neck and bbox_head.""" + super(DDQDETR, self)._init_layers() + self.decoder = DDQTransformerDecoder(**self.decoder_cfg) + self.query_embedding = None + self.query_map = nn.Linear(self.embed_dims, self.embed_dims) + + def init_weights(self) -> None: + """Initialize weights for Transformer and other components.""" + super(DeformableDETR, self).init_weights() + for coder in self.encoder, self.decoder: + for p in coder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MultiScaleDeformableAttention): + m.init_weights() + nn.init.xavier_uniform_(self.memory_trans_fc.weight) + normal_(self.level_embed) + + def pre_decoder( + self, + memory: Tensor, + memory_mask: Tensor, + spatial_shapes: Tensor, + batch_data_samples: OptSampleList = None, + ) -> Tuple[Dict]: + """Prepare intermediate variables before entering Transformer decoder, + such as `query`, `memory`, and `reference_points`. + + Args: + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat_points). Will only be used when + `as_two_stage` is `True`. + spatial_shapes (Tensor): Spatial shapes of features in all levels. + With shape (num_levels, 2), last dimension represents (h, w). + Will only be used when `as_two_stage` is `True`. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + tuple[dict]: The decoder_inputs_dict and head_inputs_dict. + + - decoder_inputs_dict (dict): The keyword dictionary args of + `self.forward_decoder()`, which includes 'query', 'memory', + `reference_points`, and `dn_mask`. The reference points of + decoder input here are 4D boxes, although it has `points` + in its name. + - head_inputs_dict (dict): The keyword dictionary args of the + bbox_head functions, which includes `topk_score`, `topk_coords`, + `dense_topk_score`, `dense_topk_coords`, + and `dn_meta`, when `self.training` is `True`, else is empty. 
+ """ + bs, _, c = memory.shape + output_memory, output_proposals = self.gen_encoder_output_proposals( + memory, memory_mask, spatial_shapes) + enc_outputs_class = self.bbox_head.cls_branches[ + self.decoder.num_layers]( + output_memory) + enc_outputs_coord_unact = self.bbox_head.reg_branches[ + self.decoder.num_layers](output_memory) + output_proposals + + if self.training: + # aux dense branch particularly in DDQ DETR, which doesn't exist + # in DINO. + # -1 is the aux head for the encoder + dense_enc_outputs_class = self.bbox_head.cls_branches[-1]( + output_memory) + dense_enc_outputs_coord_unact = self.bbox_head.reg_branches[-1]( + output_memory) + output_proposals + + topk = self.num_queries + dense_topk = int(topk * self.dense_topk_ratio) + + proposals = enc_outputs_coord_unact.sigmoid() + proposals = bbox_cxcywh_to_xyxy(proposals) + scores = enc_outputs_class.max(-1)[0].sigmoid() + + if self.training: + # aux dense branch particularly in DDQ DETR, which doesn't exist + # in DINO. + dense_proposals = dense_enc_outputs_coord_unact.sigmoid() + dense_proposals = bbox_cxcywh_to_xyxy(dense_proposals) + dense_scores = dense_enc_outputs_class.max(-1)[0].sigmoid() + + num_imgs = len(scores) + topk_score = [] + topk_coords_unact = [] + # Distinct query. + query = [] + + dense_topk_score = [] + dense_topk_coords_unact = [] + dense_query = [] + + for img_id in range(num_imgs): + single_proposals = proposals[img_id] + single_scores = scores[img_id] + + # `batched_nms` of class scores and bbox coordinations is used + # particularly by DDQ DETR for region proposal generation, + # instead of `topk` of class scores by DINO. + _, keep_idxs = batched_nms( + single_proposals, single_scores, + torch.ones(len(single_scores), device=single_scores.device), + self.cache_dict['dqs_cfg']) + + if self.training: + # aux dense branch particularly in DDQ DETR, which doesn't + # exist in DINO. + dense_single_proposals = dense_proposals[img_id] + dense_single_scores = dense_scores[img_id] + # sort according the score + # Only sort by classification score, neither nms nor topk is + # required. So input parameter `nms_cfg` = None. + _, dense_keep_idxs = batched_nms( + dense_single_proposals, dense_single_scores, + torch.ones( + len(dense_single_scores), + device=dense_single_scores.device), None) + + dense_topk_score.append(dense_enc_outputs_class[img_id] + [dense_keep_idxs][:dense_topk]) + dense_topk_coords_unact.append( + dense_enc_outputs_coord_unact[img_id][dense_keep_idxs] + [:dense_topk]) + + topk_score.append(enc_outputs_class[img_id][keep_idxs][:topk]) + + # Instead of initializing the content part with transformed + # coordinates in Deformable DETR, we fuse the feature map + # embedding of distinct positions as the content part, which + # makes the initial queries more distinct. + topk_coords_unact.append( + enc_outputs_coord_unact[img_id][keep_idxs][:topk]) + + map_memory = self.query_map(memory[img_id].detach()) + query.append(map_memory[keep_idxs][:topk]) + if self.training: + # aux dense branch particularly in DDQ DETR, which doesn't + # exist in DINO. 
+ dense_query.append(map_memory[dense_keep_idxs][:dense_topk]) + + topk_score = align_tensor(topk_score, topk) + topk_coords_unact = align_tensor(topk_coords_unact, topk) + query = align_tensor(query, topk) + if self.training: + dense_topk_score = align_tensor(dense_topk_score) + dense_topk_coords_unact = align_tensor(dense_topk_coords_unact) + + dense_query = align_tensor(dense_query) + num_dense_queries = dense_query.size(1) + if self.training: + query = torch.cat([query, dense_query], dim=1) + topk_coords_unact = torch.cat( + [topk_coords_unact, dense_topk_coords_unact], dim=1) + + topk_coords = topk_coords_unact.sigmoid() + if self.training: + dense_topk_coords = topk_coords[:, -num_dense_queries:] + topk_coords = topk_coords[:, :-num_dense_queries] + + topk_coords_unact = topk_coords_unact.detach() + + if self.training: + dn_label_query, dn_bbox_query, dn_mask, dn_meta = \ + self.dn_query_generator(batch_data_samples) + query = torch.cat([dn_label_query, query], dim=1) + reference_points = torch.cat([dn_bbox_query, topk_coords_unact], + dim=1) + + # Update `dn_mask` to add mask for dense queries. + ori_size = dn_mask.size(-1) + new_size = dn_mask.size(-1) + num_dense_queries + new_dn_mask = dn_mask.new_ones((new_size, new_size)).bool() + dense_mask = torch.zeros(num_dense_queries, + num_dense_queries).bool() + self.cache_dict['dis_query_info'] = [dn_label_query.size(1), topk] + + new_dn_mask[ori_size:, ori_size:] = dense_mask + new_dn_mask[:ori_size, :ori_size] = dn_mask + dn_meta['num_dense_queries'] = num_dense_queries + dn_mask = new_dn_mask + self.cache_dict['num_dense_queries'] = num_dense_queries + self.decoder.aux_reg_branches = self.bbox_head.aux_reg_branches + + else: + self.cache_dict['dis_query_info'] = [0, topk] + reference_points = topk_coords_unact + dn_mask, dn_meta = None, None + + reference_points = reference_points.sigmoid() + + decoder_inputs_dict = dict( + query=query, + memory=memory, + reference_points=reference_points, + dn_mask=dn_mask) + head_inputs_dict = dict( + enc_outputs_class=topk_score, + enc_outputs_coord=topk_coords, + aux_enc_outputs_class=dense_topk_score, + aux_enc_outputs_coord=dense_topk_coords, + dn_meta=dn_meta) if self.training else dict() + + return decoder_inputs_dict, head_inputs_dict diff --git a/mmdet/models/detectors/deformable_detr.py b/mmdet/models/detectors/deformable_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..0eb5cd2f95204542d5a9ace1a6d92e0b858c139f --- /dev/null +++ b/mmdet/models/detectors/deformable_detr.py @@ -0,0 +1,572 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, Tuple + +import torch +import torch.nn.functional as F +from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention +from mmengine.model import xavier_init +from torch import Tensor, nn +from torch.nn.init import normal_ + +from mmdet.registry import MODELS +from mmdet.structures import OptSampleList +from mmdet.utils import OptConfigType +from ..layers import (DeformableDetrTransformerDecoder, + DeformableDetrTransformerEncoder, SinePositionalEncoding) +from .base_detr import DetectionTransformer + + +@MODELS.register_module() +class DeformableDETR(DetectionTransformer): + r"""Implementation of `Deformable DETR: Deformable Transformers for + End-to-End Object Detection `_ + + Code is modified from the `official github repo + `_. + + Args: + decoder (:obj:`ConfigDict` or dict, optional): Config of the + Transformer decoder. Defaults to None. 
+ bbox_head (:obj:`ConfigDict` or dict, optional): Config for the + bounding box head module. Defaults to None. + with_box_refine (bool, optional): Whether to refine the references + in the decoder. Defaults to `False`. + as_two_stage (bool, optional): Whether to generate the proposal + from the outputs of encoder. Defaults to `False`. + num_feature_levels (int, optional): Number of feature levels. + Defaults to 4. + """ + + def __init__(self, + *args, + decoder: OptConfigType = None, + bbox_head: OptConfigType = None, + with_box_refine: bool = False, + as_two_stage: bool = False, + num_feature_levels: int = 4, + **kwargs) -> None: + self.with_box_refine = with_box_refine + self.as_two_stage = as_two_stage + self.num_feature_levels = num_feature_levels + + if bbox_head is not None: + assert 'share_pred_layer' not in bbox_head and \ + 'num_pred_layer' not in bbox_head and \ + 'as_two_stage' not in bbox_head, \ + 'The two keyword args `share_pred_layer`, `num_pred_layer`, ' \ + 'and `as_two_stage are set in `detector.__init__()`, users ' \ + 'should not set them in `bbox_head` config.' + # The last prediction layer is used to generate proposal + # from encode feature map when `as_two_stage` is `True`. + # And all the prediction layers should share parameters + # when `with_box_refine` is `True`. + bbox_head['share_pred_layer'] = not with_box_refine + bbox_head['num_pred_layer'] = (decoder['num_layers'] + 1) \ + if self.as_two_stage else decoder['num_layers'] + bbox_head['as_two_stage'] = as_two_stage + + super().__init__(*args, decoder=decoder, bbox_head=bbox_head, **kwargs) + + def _init_layers(self) -> None: + """Initialize layers except for backbone, neck and bbox_head.""" + self.positional_encoding = SinePositionalEncoding( + **self.positional_encoding) + self.encoder = DeformableDetrTransformerEncoder(**self.encoder) + self.decoder = DeformableDetrTransformerDecoder(**self.decoder) + self.embed_dims = self.encoder.embed_dims + if not self.as_two_stage: + self.query_embedding = nn.Embedding(self.num_queries, + self.embed_dims * 2) + # NOTE The query_embedding will be split into query and query_pos + # in self.pre_decoder, hence, the embed_dims are doubled. + + num_feats = self.positional_encoding.num_feats + assert num_feats * 2 == self.embed_dims, \ + 'embed_dims should be exactly 2 times of num_feats. ' \ + f'Found {self.embed_dims} and {num_feats}.' + + self.level_embed = nn.Parameter( + torch.Tensor(self.num_feature_levels, self.embed_dims)) + + if self.as_two_stage: + self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims) + self.memory_trans_norm = nn.LayerNorm(self.embed_dims) + self.pos_trans_fc = nn.Linear(self.embed_dims * 2, + self.embed_dims * 2) + self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2) + else: + self.reference_points_fc = nn.Linear(self.embed_dims, 2) + + def init_weights(self) -> None: + """Initialize weights for Transformer and other components.""" + super().init_weights() + for coder in self.encoder, self.decoder: + for p in coder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MultiScaleDeformableAttention): + m.init_weights() + if self.as_two_stage: + nn.init.xavier_uniform_(self.memory_trans_fc.weight) + nn.init.xavier_uniform_(self.pos_trans_fc.weight) + else: + xavier_init( + self.reference_points_fc, distribution='uniform', bias=0.) 
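+        # `level_embed` is a raw nn.Parameter rather than a sub-module, so
+        # it is initialized here directly instead of by the xavier loop above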
+ normal_(self.level_embed) + + def pre_transformer( + self, + mlvl_feats: Tuple[Tensor], + batch_data_samples: OptSampleList = None) -> Tuple[Dict]: + """Process image features before feeding them to the transformer. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + mlvl_feats (tuple[Tensor]): Multi-level features that may have + different resolutions, output from neck. Each feature has + shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'. + batch_data_samples (list[:obj:`DetDataSample`], optional): The + batch data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + tuple[dict]: The first dict contains the inputs of encoder and the + second dict contains the inputs of decoder. + + - encoder_inputs_dict (dict): The keyword args dictionary of + `self.forward_encoder()`, which includes 'feat', 'feat_mask', + and 'feat_pos'. + - decoder_inputs_dict (dict): The keyword args dictionary of + `self.forward_decoder()`, which includes 'memory_mask'. + """ + batch_size = mlvl_feats[0].size(0) + + # construct binary masks for the transformer. + assert batch_data_samples is not None + batch_input_shape = batch_data_samples[0].batch_input_shape + input_img_h, input_img_w = batch_input_shape + img_shape_list = [sample.img_shape for sample in batch_data_samples] + same_shape_flag = all([ + s[0] == input_img_h and s[1] == input_img_w for s in img_shape_list + ]) + # support torch2onnx without feeding masks + if torch.onnx.is_in_onnx_export() or same_shape_flag: + mlvl_masks = [] + mlvl_pos_embeds = [] + for feat in mlvl_feats: + mlvl_masks.append(None) + mlvl_pos_embeds.append( + self.positional_encoding(None, input=feat)) + else: + masks = mlvl_feats[0].new_ones( + (batch_size, input_img_h, input_img_w)) + for img_id in range(batch_size): + img_h, img_w = img_shape_list[img_id] + masks[img_id, :img_h, :img_w] = 0 + # NOTE following the official DETR repo, non-zero + # values representing ignored positions, while + # zero values means valid positions. 
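+            # e.g. with batch_input_shape (800, 800) and an image of shape
+            # (608, 416), masks[img_id] is 0 over the top-left 608x416
+            # region (valid pixels) and 1 over the right/bottom padding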
+ + mlvl_masks = [] + mlvl_pos_embeds = [] + for feat in mlvl_feats: + mlvl_masks.append( + F.interpolate(masks[None], size=feat.shape[-2:]).to( + torch.bool).squeeze(0)) + mlvl_pos_embeds.append( + self.positional_encoding(mlvl_masks[-1])) + + feat_flatten = [] + lvl_pos_embed_flatten = [] + mask_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate( + zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): + batch_size, c, h, w = feat.shape + spatial_shape = torch._shape_as_tensor(feat)[2:].to(feat.device) + # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c] + feat = feat.view(batch_size, c, -1).permute(0, 2, 1) + pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1) + lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) + # [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl] + if mask is not None: + mask = mask.flatten(1) + + feat_flatten.append(feat) + lvl_pos_embed_flatten.append(lvl_pos_embed) + mask_flatten.append(mask) + spatial_shapes.append(spatial_shape) + + # (bs, num_feat_points, dim) + feat_flatten = torch.cat(feat_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + # (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl) + if mask_flatten[0] is not None: + mask_flatten = torch.cat(mask_flatten, 1) + else: + mask_flatten = None + + # (num_level, 2) + spatial_shapes = torch.cat(spatial_shapes).view(-1, 2) + level_start_index = torch.cat(( + spatial_shapes.new_zeros((1, )), # (num_level) + spatial_shapes.prod(1).cumsum(0)[:-1])) + if mlvl_masks[0] is not None: + valid_ratios = torch.stack( # (bs, num_level, 2) + [self.get_valid_ratio(m) for m in mlvl_masks], 1) + else: + valid_ratios = mlvl_feats[0].new_ones(batch_size, len(mlvl_feats), + 2) + + encoder_inputs_dict = dict( + feat=feat_flatten, + feat_mask=mask_flatten, + feat_pos=lvl_pos_embed_flatten, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios) + decoder_inputs_dict = dict( + memory_mask=mask_flatten, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios) + return encoder_inputs_dict, decoder_inputs_dict + + def forward_encoder(self, feat: Tensor, feat_mask: Tensor, + feat_pos: Tensor, spatial_shapes: Tensor, + level_start_index: Tensor, + valid_ratios: Tensor) -> Dict: + """Forward with Transformer encoder. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + feat (Tensor): Sequential features, has shape (bs, num_feat_points, + dim). + feat_mask (Tensor): ByteTensor, the padding mask of the features, + has shape (bs, num_feat_points). + feat_pos (Tensor): The positional embeddings of the features, has + shape (bs, num_feat_points, dim). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + + Returns: + dict: The dictionary of encoder outputs, which includes the + `memory` of the encoder output. 
+ """ + memory = self.encoder( + query=feat, + query_pos=feat_pos, + key_padding_mask=feat_mask, # for self_attn + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios) + encoder_outputs_dict = dict( + memory=memory, + memory_mask=feat_mask, + spatial_shapes=spatial_shapes) + return encoder_outputs_dict + + def pre_decoder(self, memory: Tensor, memory_mask: Tensor, + spatial_shapes: Tensor) -> Tuple[Dict, Dict]: + """Prepare intermediate variables before entering Transformer decoder, + such as `query`, `query_pos`, and `reference_points`. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat_points). It will only be used when + `as_two_stage` is `True`. + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + It will only be used when `as_two_stage` is `True`. + + Returns: + tuple[dict, dict]: The decoder_inputs_dict and head_inputs_dict. + + - decoder_inputs_dict (dict): The keyword dictionary args of + `self.forward_decoder()`, which includes 'query', 'query_pos', + 'memory', and `reference_points`. The reference_points of + decoder input here are 4D boxes when `as_two_stage` is `True`, + otherwise 2D points, although it has `points` in its name. + The reference_points in encoder is always 2D points. + - head_inputs_dict (dict): The keyword dictionary args of the + bbox_head functions, which includes `enc_outputs_class` and + `enc_outputs_coord`. They are both `None` when 'as_two_stage' + is `False`. The dict is empty when `self.training` is `False`. + """ + batch_size, _, c = memory.shape + if self.as_two_stage: + output_memory, output_proposals = \ + self.gen_encoder_output_proposals( + memory, memory_mask, spatial_shapes) + enc_outputs_class = self.bbox_head.cls_branches[ + self.decoder.num_layers]( + output_memory) + enc_outputs_coord_unact = self.bbox_head.reg_branches[ + self.decoder.num_layers](output_memory) + output_proposals + enc_outputs_coord = enc_outputs_coord_unact.sigmoid() + # We only use the first channel in enc_outputs_class as foreground, + # the other (num_classes - 1) channels are actually not used. + # Its targets are set to be 0s, which indicates the first + # class (foreground) because we use [0, num_classes - 1] to + # indicate class labels, background class is indicated by + # num_classes (similar convention in RPN). + # See https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/deformable_detr_head.py#L241 # noqa + # This follows the official implementation of Deformable DETR. 
+ topk_proposals = torch.topk( + enc_outputs_class[..., 0], self.num_queries, dim=1)[1] + topk_coords_unact = torch.gather( + enc_outputs_coord_unact, 1, + topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + pos_trans_out = self.pos_trans_fc( + self.get_proposal_pos_embed(topk_coords_unact)) + pos_trans_out = self.pos_trans_norm(pos_trans_out) + query_pos, query = torch.split(pos_trans_out, c, dim=2) + else: + enc_outputs_class, enc_outputs_coord = None, None + query_embed = self.query_embedding.weight + query_pos, query = torch.split(query_embed, c, dim=1) + query_pos = query_pos.unsqueeze(0).expand(batch_size, -1, -1) + query = query.unsqueeze(0).expand(batch_size, -1, -1) + reference_points = self.reference_points_fc(query_pos).sigmoid() + + decoder_inputs_dict = dict( + query=query, + query_pos=query_pos, + memory=memory, + reference_points=reference_points) + head_inputs_dict = dict( + enc_outputs_class=enc_outputs_class, + enc_outputs_coord=enc_outputs_coord) if self.training else dict() + return decoder_inputs_dict, head_inputs_dict + + def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor, + memory_mask: Tensor, reference_points: Tensor, + spatial_shapes: Tensor, level_start_index: Tensor, + valid_ratios: Tensor) -> Dict: + """Forward with Transformer decoder. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + query (Tensor): The queries of decoder inputs, has shape + (bs, num_queries, dim). + query_pos (Tensor): The positional queries of decoder inputs, + has shape (bs, num_queries, dim). + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat_points). + reference_points (Tensor): The initial reference, has shape + (bs, num_queries, 4) with the last dimension arranged as + (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has + shape (bs, num_queries, 2) with the last dimension arranged as + (cx, cy). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + + Returns: + dict: The dictionary of decoder outputs, which includes the + `hidden_states` of the decoder output and `references` including + the initial and intermediate reference_points. 
+ """ + inter_states, inter_references = self.decoder( + query=query, + value=memory, + query_pos=query_pos, + key_padding_mask=memory_mask, # for cross_attn + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=self.bbox_head.reg_branches + if self.with_box_refine else None) + references = [reference_points, *inter_references] + decoder_outputs_dict = dict( + hidden_states=inter_states, references=references) + return decoder_outputs_dict + + @staticmethod + def get_valid_ratio(mask: Tensor) -> Tensor: + """Get the valid radios of feature map in a level. + + .. code:: text + + |---> valid_W <---| + ---+-----------------+-----+--- + A | | | A + | | | | | + | | | | | + valid_H | | | | + | | | | H + | | | | | + V | | | | + ---+-----------------+ | | + | | V + +-----------------------+--- + |---------> W <---------| + + The valid_ratios are defined as: + r_h = valid_H / H, r_w = valid_W / W + They are the factors to re-normalize the relative coordinates of the + image to the relative coordinates of the current level feature map. + + Args: + mask (Tensor): Binary mask of a feature map, has shape (bs, H, W). + + Returns: + Tensor: valid ratios [r_w, r_h] of a feature map, has shape (1, 2). + """ + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def gen_encoder_output_proposals( + self, memory: Tensor, memory_mask: Tensor, + spatial_shapes: Tensor) -> Tuple[Tensor, Tensor]: + """Generate proposals from encoded memory. The function will only be + used when `as_two_stage` is `True`. + + Args: + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat_points). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + + Returns: + tuple: A tuple of transformed memory and proposals. + + - output_memory (Tensor): The transformed memory for obtaining + top-k proposals, has shape (bs, num_feat_points, dim). + - output_proposals (Tensor): The inverse-normalized proposal, has + shape (batch_size, num_keys, 4) with the last dimension arranged + as (cx, cy, w, h). 
+ """ + + bs = memory.size(0) + proposals = [] + _cur = 0 # start index in the sequence of the current level + for lvl, HW in enumerate(spatial_shapes): + H, W = HW + + if memory_mask is not None: + mask_flatten_ = memory_mask[:, _cur:(_cur + H * W)].view( + bs, H, W, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], + 1).unsqueeze(-1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], + 1).unsqueeze(-1) + scale = torch.cat([valid_W, valid_H], 1).view(bs, 1, 1, 2) + else: + if not isinstance(HW, torch.Tensor): + HW = memory.new_tensor(HW) + scale = HW.unsqueeze(0).flip(dims=[0, 1]).view(1, 1, 1, 2) + grid_y, grid_x = torch.meshgrid( + torch.linspace( + 0, H - 1, H, dtype=torch.float32, device=memory.device), + torch.linspace( + 0, W - 1, W, dtype=torch.float32, device=memory.device)) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) + proposal = torch.cat((grid, wh), -1).view(bs, -1, 4) + proposals.append(proposal) + _cur += (H * W) + output_proposals = torch.cat(proposals, 1) + # do not use `all` to make it exportable to onnx + output_proposals_valid = ( + (output_proposals > 0.01) & (output_proposals < 0.99)).sum( + -1, keepdim=True) == output_proposals.shape[-1] + # inverse_sigmoid + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + if memory_mask is not None: + output_proposals = output_proposals.masked_fill( + memory_mask.unsqueeze(-1), float('inf')) + output_proposals = output_proposals.masked_fill( + ~output_proposals_valid, float('inf')) + + output_memory = memory + if memory_mask is not None: + output_memory = output_memory.masked_fill( + memory_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, + float(0)) + output_memory = self.memory_trans_fc(output_memory) + output_memory = self.memory_trans_norm(output_memory) + # [bs, sum(hw), 2] + return output_memory, output_proposals + + @staticmethod + def get_proposal_pos_embed(proposals: Tensor, + num_pos_feats: int = 128, + temperature: int = 10000) -> Tensor: + """Get the position embedding of the proposal. + + Args: + proposals (Tensor): Not normalized proposals, has shape + (bs, num_queries, 4) with the last dimension arranged as + (cx, cy, w, h). + num_pos_feats (int, optional): The feature dimension for each + position along x, y, w, and h-axis. Note the final returned + dimension for each position is 4 times of num_pos_feats. + Default to 128. + temperature (int, optional): The temperature used for scaling the + position embedding. Defaults to 10000. + + Returns: + Tensor: The position embedding of proposal, has shape + (bs, num_queries, num_pos_feats * 4), with the last dimension + arranged as (cx, cy, w, h) + """ + scale = 2 * math.pi + dim_t = torch.arange( + num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), + dim=4).flatten(2) + return pos diff --git a/mmdet/models/detectors/detr.py b/mmdet/models/detectors/detr.py new file mode 100644 index 0000000000000000000000000000000000000000..7895e9ecb4eb66cb75d173c191c2128c3f55c197 --- /dev/null +++ b/mmdet/models/detectors/detr.py @@ -0,0 +1,225 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Dict, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from mmdet.registry import MODELS +from mmdet.structures import OptSampleList +from ..layers import (DetrTransformerDecoder, DetrTransformerEncoder, + SinePositionalEncoding) +from .base_detr import DetectionTransformer + + +@MODELS.register_module() +class DETR(DetectionTransformer): + r"""Implementation of `DETR: End-to-End Object Detection with Transformers. + + `_. + + Code is modified from the `official github repo + `_. + """ + + def _init_layers(self) -> None: + """Initialize layers except for backbone, neck and bbox_head.""" + self.positional_encoding = SinePositionalEncoding( + **self.positional_encoding) + self.encoder = DetrTransformerEncoder(**self.encoder) + self.decoder = DetrTransformerDecoder(**self.decoder) + self.embed_dims = self.encoder.embed_dims + # NOTE The embed_dims is typically passed from the inside out. + # For example in DETR, The embed_dims is passed as + # self_attn -> the first encoder layer -> encoder -> detector. + self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims) + + num_feats = self.positional_encoding.num_feats + assert num_feats * 2 == self.embed_dims, \ + 'embed_dims should be exactly 2 times of num_feats. ' \ + f'Found {self.embed_dims} and {num_feats}.' + + def init_weights(self) -> None: + """Initialize weights for Transformer and other components.""" + super().init_weights() + for coder in self.encoder, self.decoder: + for p in coder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def pre_transformer( + self, + img_feats: Tuple[Tensor], + batch_data_samples: OptSampleList = None) -> Tuple[Dict, Dict]: + """Prepare the inputs of the Transformer. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + img_feats (Tuple[Tensor]): Tuple of features output from the neck, + has shape (bs, c, h, w). + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such as + `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + tuple[dict, dict]: The first dict contains the inputs of encoder + and the second dict contains the inputs of decoder. + + - encoder_inputs_dict (dict): The keyword args dictionary of + `self.forward_encoder()`, which includes 'feat', 'feat_mask', + and 'feat_pos'. + - decoder_inputs_dict (dict): The keyword args dictionary of + `self.forward_decoder()`, which includes 'memory_mask', + and 'memory_pos'. + """ + + feat = img_feats[-1] # NOTE img_feats contains only one feature. + batch_size, feat_dim, _, _ = feat.shape + # construct binary masks which for the transformer. 
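# NOTE (editor, not part of the patch): a toy illustration of the masking
# convention built below, with made-up sizes. Non-zero entries mark padded
# (ignored) positions and zeros mark valid pixels:
#
#     import torch
#     import torch.nn.functional as F
#     batch_h, batch_w = 8, 8           # padded batch shape (hypothetical)
#     img_shapes = [(8, 8), (6, 4)]     # per-image valid regions (hypothetical)
#     masks = torch.ones(2, batch_h, batch_w)
#     for i, (h, w) in enumerate(img_shapes):
#         masks[i, :h, :w] = 0
#     # downsample to the feature resolution, e.g. a 2x2 feature map
#     feat_mask = F.interpolate(
#         masks.unsqueeze(1), size=(2, 2)).to(torch.bool).squeeze(1)  # (2, 2, 2)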
+        assert batch_data_samples is not None
+        batch_input_shape = batch_data_samples[0].batch_input_shape
+        input_img_h, input_img_w = batch_input_shape
+        img_shape_list = [sample.img_shape for sample in batch_data_samples]
+        same_shape_flag = all([
+            s[0] == input_img_h and s[1] == input_img_w for s in img_shape_list
+        ])
+        if torch.onnx.is_in_onnx_export() or same_shape_flag:
+            masks = None
+            # [batch_size, embed_dim, h, w]
+            pos_embed = self.positional_encoding(masks, input=feat)
+        else:
+            masks = feat.new_ones((batch_size, input_img_h, input_img_w))
+            for img_id in range(batch_size):
+                img_h, img_w = img_shape_list[img_id]
+                masks[img_id, :img_h, :img_w] = 0
+            # NOTE following the official DETR repo, non-zero values represent
+            # ignored positions, while zero values mean valid positions.
+
+            masks = F.interpolate(
+                masks.unsqueeze(1),
+                size=feat.shape[-2:]).to(torch.bool).squeeze(1)
+            # [batch_size, embed_dim, h, w]
+            pos_embed = self.positional_encoding(masks)
+
+        # use `view` instead of `flatten` for dynamically exporting to ONNX
+        # [bs, c, h, w] -> [bs, h*w, c]
+        feat = feat.view(batch_size, feat_dim, -1).permute(0, 2, 1)
+        pos_embed = pos_embed.view(batch_size, feat_dim, -1).permute(0, 2, 1)
+        # [bs, h, w] -> [bs, h*w]
+        if masks is not None:
+            masks = masks.view(batch_size, -1)
+
+        # prepare transformer_inputs_dict
+        encoder_inputs_dict = dict(
+            feat=feat, feat_mask=masks, feat_pos=pos_embed)
+        decoder_inputs_dict = dict(memory_mask=masks, memory_pos=pos_embed)
+        return encoder_inputs_dict, decoder_inputs_dict
+
+    def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
+                        feat_pos: Tensor) -> Dict:
+        """Forward with Transformer encoder.
+
+        The forward procedure of the transformer is defined as:
+        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
+        More details can be found at `TransformerDetector.forward_transformer`
+        in `mmdet/detector/base_detr.py`.
+
+        Args:
+            feat (Tensor): Sequential features, has shape
+                (bs, num_feat_points, dim).
+            feat_mask (Tensor): ByteTensor, the padding mask of the features,
+                has shape (bs, num_feat_points).
+            feat_pos (Tensor): The positional embeddings of the features, has
+                shape (bs, num_feat_points, dim).
+
+        Returns:
+            dict: The dictionary of encoder outputs, which includes the
+            `memory` of the encoder output.
+        """
+        memory = self.encoder(
+            query=feat, query_pos=feat_pos,
+            key_padding_mask=feat_mask)  # for self_attn
+        encoder_outputs_dict = dict(memory=memory)
+        return encoder_outputs_dict
+
+    def pre_decoder(self, memory: Tensor) -> Tuple[Dict, Dict]:
+        """Prepare intermediate variables before entering Transformer decoder,
+        such as `query` and `query_pos`.
+
+        The forward procedure of the transformer is defined as:
+        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
+        More details can be found at `TransformerDetector.forward_transformer`
+        in `mmdet/detector/base_detr.py`.
+
+        Args:
+            memory (Tensor): The output embeddings of the Transformer encoder,
+                has shape (bs, num_feat_points, dim).
+
+        Returns:
+            tuple[dict, dict]: The first dict contains the inputs of decoder
+            and the second dict contains the inputs of the bbox_head function.
+
+            - decoder_inputs_dict (dict): The keyword args dictionary of
+              `self.forward_decoder()`, which includes 'query', 'query_pos',
+              and 'memory'.
+            - head_inputs_dict (dict): The keyword args dictionary of the
+              bbox_head functions, which is usually empty, or includes
+              `enc_outputs_class` and `enc_outputs_coord` when the detector
+              supports 'two stage' or 'query selection' strategies.
+ """ + + batch_size = memory.size(0) # (bs, num_feat_points, dim) + query_pos = self.query_embedding.weight + # (num_queries, dim) -> (bs, num_queries, dim) + query_pos = query_pos.unsqueeze(0).repeat(batch_size, 1, 1) + query = torch.zeros_like(query_pos) + + decoder_inputs_dict = dict( + query_pos=query_pos, query=query, memory=memory) + head_inputs_dict = dict() + return decoder_inputs_dict, head_inputs_dict + + def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor, + memory_mask: Tensor, memory_pos: Tensor) -> Dict: + """Forward with Transformer decoder. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + query (Tensor): The queries of decoder inputs, has shape + (bs, num_queries, dim). + query_pos (Tensor): The positional queries of decoder inputs, + has shape (bs, num_queries, dim). + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat_points). + memory_pos (Tensor): The positional embeddings of memory, has + shape (bs, num_feat_points, dim). + + Returns: + dict: The dictionary of decoder outputs, which includes the + `hidden_states` of the decoder output. + + - hidden_states (Tensor): Has shape + (num_decoder_layers, bs, num_queries, dim) + """ + + hidden_states = self.decoder( + query=query, + key=memory, + value=memory, + query_pos=query_pos, + key_pos=memory_pos, + key_padding_mask=memory_mask) # for cross_attn + + head_inputs_dict = dict(hidden_states=hidden_states) + return head_inputs_dict diff --git a/mmdet/models/detectors/dino.py b/mmdet/models/detectors/dino.py new file mode 100644 index 0000000000000000000000000000000000000000..ade47f531d27246511cafc2997a07d58677538a7 --- /dev/null +++ b/mmdet/models/detectors/dino.py @@ -0,0 +1,287 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Tuple + +import torch +from torch import Tensor, nn +from torch.nn.init import normal_ + +from mmdet.registry import MODELS +from mmdet.structures import OptSampleList +from mmdet.utils import OptConfigType +from ..layers import (CdnQueryGenerator, DeformableDetrTransformerEncoder, + DinoTransformerDecoder, SinePositionalEncoding) +from .deformable_detr import DeformableDETR, MultiScaleDeformableAttention + + +@MODELS.register_module() +class DINO(DeformableDETR): + r"""Implementation of `DINO: DETR with Improved DeNoising Anchor Boxes + for End-to-End Object Detection `_ + + Code is modified from the `official github repo + `_. + + Args: + dn_cfg (:obj:`ConfigDict` or dict, optional): Config of denoising + query generator. Defaults to `None`. + """ + + def __init__(self, *args, dn_cfg: OptConfigType = None, **kwargs) -> None: + super().__init__(*args, **kwargs) + assert self.as_two_stage, 'as_two_stage must be True for DINO' + assert self.with_box_refine, 'with_box_refine must be True for DINO' + + if dn_cfg is not None: + assert 'num_classes' not in dn_cfg and \ + 'num_queries' not in dn_cfg and \ + 'hidden_dim' not in dn_cfg, \ + 'The three keyword args `num_classes`, `embed_dims`, and ' \ + '`num_matching_queries` are set in `detector.__init__()`, ' \ + 'users should not set them in `dn_cfg` config.' 
+ dn_cfg['num_classes'] = self.bbox_head.num_classes + dn_cfg['embed_dims'] = self.embed_dims + dn_cfg['num_matching_queries'] = self.num_queries + self.dn_query_generator = CdnQueryGenerator(**dn_cfg) + + def _init_layers(self) -> None: + """Initialize layers except for backbone, neck and bbox_head.""" + self.positional_encoding = SinePositionalEncoding( + **self.positional_encoding) + self.encoder = DeformableDetrTransformerEncoder(**self.encoder) + self.decoder = DinoTransformerDecoder(**self.decoder) + self.embed_dims = self.encoder.embed_dims + self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims) + # NOTE In DINO, the query_embedding only contains content + # queries, while in Deformable DETR, the query_embedding + # contains both content and spatial queries, and in DETR, + # it only contains spatial queries. + + num_feats = self.positional_encoding.num_feats + assert num_feats * 2 == self.embed_dims, \ + f'embed_dims should be exactly 2 times of num_feats. ' \ + f'Found {self.embed_dims} and {num_feats}.' + + self.level_embed = nn.Parameter( + torch.Tensor(self.num_feature_levels, self.embed_dims)) + self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims) + self.memory_trans_norm = nn.LayerNorm(self.embed_dims) + + def init_weights(self) -> None: + """Initialize weights for Transformer and other components.""" + super(DeformableDETR, self).init_weights() + for coder in self.encoder, self.decoder: + for p in coder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MultiScaleDeformableAttention): + m.init_weights() + nn.init.xavier_uniform_(self.memory_trans_fc.weight) + nn.init.xavier_uniform_(self.query_embedding.weight) + normal_(self.level_embed) + + def forward_transformer( + self, + img_feats: Tuple[Tensor], + batch_data_samples: OptSampleList = None, + ) -> Dict: + """Forward process of Transformer. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + The difference is that the ground truth in `batch_data_samples` is + required for the `pre_decoder` to prepare the query of DINO. + Additionally, DINO inherits the `pre_transformer` method and the + `forward_encoder` method of DeformableDETR. More details about the + two methods can be found in `mmdet/detector/deformable_detr.py`. + + Args: + img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each + feature map has shape (bs, dim, H, W). + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + dict: The dictionary of bbox_head function inputs, which always + includes the `hidden_states` of the decoder output and may contain + `references` including the initial and intermediate references. 
+ """ + encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer( + img_feats, batch_data_samples) + + encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict) + + tmp_dec_in, head_inputs_dict = self.pre_decoder( + **encoder_outputs_dict, batch_data_samples=batch_data_samples) + decoder_inputs_dict.update(tmp_dec_in) + + decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict) + head_inputs_dict.update(decoder_outputs_dict) + return head_inputs_dict + + def pre_decoder( + self, + memory: Tensor, + memory_mask: Tensor, + spatial_shapes: Tensor, + batch_data_samples: OptSampleList = None, + ) -> Tuple[Dict]: + """Prepare intermediate variables before entering Transformer decoder, + such as `query`, `query_pos`, and `reference_points`. + + Args: + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat_points). Will only be used when + `as_two_stage` is `True`. + spatial_shapes (Tensor): Spatial shapes of features in all levels. + With shape (num_levels, 2), last dimension represents (h, w). + Will only be used when `as_two_stage` is `True`. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + tuple[dict]: The decoder_inputs_dict and head_inputs_dict. + + - decoder_inputs_dict (dict): The keyword dictionary args of + `self.forward_decoder()`, which includes 'query', 'memory', + `reference_points`, and `dn_mask`. The reference points of + decoder input here are 4D boxes, although it has `points` + in its name. + - head_inputs_dict (dict): The keyword dictionary args of the + bbox_head functions, which includes `topk_score`, `topk_coords`, + and `dn_meta` when `self.training` is `True`, else is empty. + """ + bs, _, c = memory.shape + cls_out_features = self.bbox_head.cls_branches[ + self.decoder.num_layers].out_features + + output_memory, output_proposals = self.gen_encoder_output_proposals( + memory, memory_mask, spatial_shapes) + enc_outputs_class = self.bbox_head.cls_branches[ + self.decoder.num_layers]( + output_memory) + enc_outputs_coord_unact = self.bbox_head.reg_branches[ + self.decoder.num_layers](output_memory) + output_proposals + + # NOTE The DINO selects top-k proposals according to scores of + # multi-class classification, while DeformDETR, where the input + # is `enc_outputs_class[..., 0]` selects according to scores of + # binary classification. 
+ topk_indices = torch.topk( + enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1)[1] + topk_score = torch.gather( + enc_outputs_class, 1, + topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features)) + topk_coords_unact = torch.gather( + enc_outputs_coord_unact, 1, + topk_indices.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords = topk_coords_unact.sigmoid() + topk_coords_unact = topk_coords_unact.detach() + + query = self.query_embedding.weight[:, None, :] + query = query.repeat(1, bs, 1).transpose(0, 1) + if self.training: + dn_label_query, dn_bbox_query, dn_mask, dn_meta = \ + self.dn_query_generator(batch_data_samples) + query = torch.cat([dn_label_query, query], dim=1) + reference_points = torch.cat([dn_bbox_query, topk_coords_unact], + dim=1) + else: + reference_points = topk_coords_unact + dn_mask, dn_meta = None, None + reference_points = reference_points.sigmoid() + + decoder_inputs_dict = dict( + query=query, + memory=memory, + reference_points=reference_points, + dn_mask=dn_mask) + # NOTE DINO calculates encoder losses on scores and coordinates + # of selected top-k encoder queries, while DeformDETR is of all + # encoder queries. + head_inputs_dict = dict( + enc_outputs_class=topk_score, + enc_outputs_coord=topk_coords, + dn_meta=dn_meta) if self.training else dict() + return decoder_inputs_dict, head_inputs_dict + + def forward_decoder(self, + query: Tensor, + memory: Tensor, + memory_mask: Tensor, + reference_points: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + valid_ratios: Tensor, + dn_mask: Optional[Tensor] = None, + **kwargs) -> Dict: + """Forward with Transformer decoder. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + query (Tensor): The queries of decoder inputs, has shape + (bs, num_queries_total, dim), where `num_queries_total` is the + sum of `num_denoising_queries` and `num_matching_queries` when + `self.training` is `True`, else `num_matching_queries`. + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat_points). + reference_points (Tensor): The initial reference, has shape + (bs, num_queries_total, 4) with the last dimension arranged as + (cx, cy, w, h). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + dn_mask (Tensor, optional): The attention mask to prevent + information leakage from different denoising groups and + matching parts, will be used as `self_attn_mask` of the + `self.decoder`, has shape (num_queries_total, + num_queries_total). + It is `None` when `self.training` is `False`. + + Returns: + dict: The dictionary of decoder outputs, which includes the + `hidden_states` of the decoder output and `references` including + the initial and intermediate reference_points. 
+ """ + inter_states, references = self.decoder( + query=query, + value=memory, + key_padding_mask=memory_mask, + self_attn_mask=dn_mask, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=self.bbox_head.reg_branches, + **kwargs) + + if len(query) == self.num_queries: + # NOTE: This is to make sure label_embeding can be involved to + # produce loss even if there is no denoising query (no ground truth + # target in this GPU), otherwise, this will raise runtime error in + # distributed training. + inter_states[0] += \ + self.dn_query_generator.label_embedding.weight[0, 0] * 0.0 + + decoder_outputs_dict = dict( + hidden_states=inter_states, references=list(references)) + return decoder_outputs_dict diff --git a/mmdet/models/detectors/fast_rcnn.py b/mmdet/models/detectors/fast_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5b39050fdc2989eb5c870704e1c1417987d53d46 --- /dev/null +++ b/mmdet/models/detectors/fast_rcnn.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .two_stage import TwoStageDetector + + +@MODELS.register_module() +class FastRCNN(TwoStageDetector): + """Implementation of `Fast R-CNN `_""" + + def __init__(self, + backbone: ConfigType, + roi_head: ConfigType, + train_cfg: ConfigType, + test_cfg: ConfigType, + neck: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + roi_head=roi_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + init_cfg=init_cfg, + data_preprocessor=data_preprocessor) diff --git a/mmdet/models/detectors/faster_rcnn.py b/mmdet/models/detectors/faster_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..36109e3200a2d8e7d8a1032f7028e47a7699fb6a --- /dev/null +++ b/mmdet/models/detectors/faster_rcnn.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .two_stage import TwoStageDetector + + +@MODELS.register_module() +class FasterRCNN(TwoStageDetector): + """Implementation of `Faster R-CNN `_""" + + def __init__(self, + backbone: ConfigType, + rpn_head: ConfigType, + roi_head: ConfigType, + train_cfg: ConfigType, + test_cfg: ConfigType, + neck: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + rpn_head=rpn_head, + roi_head=roi_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + init_cfg=init_cfg, + data_preprocessor=data_preprocessor) diff --git a/mmdet/models/detectors/fcos.py b/mmdet/models/detectors/fcos.py new file mode 100644 index 0000000000000000000000000000000000000000..c628059313ac80644ec2ba2c806e7baf2e418a41 --- /dev/null +++ b/mmdet/models/detectors/fcos.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage import SingleStageDetector + + +@MODELS.register_module() +class FCOS(SingleStageDetector): + """Implementation of `FCOS `_ + + Args: + backbone (:obj:`ConfigDict` or dict): The backbone config. + neck (:obj:`ConfigDict` or dict): The neck config. 
+ bbox_head (:obj:`ConfigDict` or dict): The bbox head config. + train_cfg (:obj:`ConfigDict` or dict, optional): The training config + of FCOS. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): The testing config + of FCOS. Defaults to None. + data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of + :class:`DetDataPreprocessor` to process the input data. + Defaults to None. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or + list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/fovea.py b/mmdet/models/detectors/fovea.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4f21caa239147e3b81e66280aa1da043715b42 --- /dev/null +++ b/mmdet/models/detectors/fovea.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage import SingleStageDetector + + +@MODELS.register_module() +class FOVEA(SingleStageDetector): + """Implementation of `FoveaBox `_ + Args: + backbone (:obj:`ConfigDict` or dict): The backbone config. + neck (:obj:`ConfigDict` or dict): The neck config. + bbox_head (:obj:`ConfigDict` or dict): The bbox head config. + train_cfg (:obj:`ConfigDict` or dict, optional): The training config + of FOVEA. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): The testing config + of FOVEA. Defaults to None. + data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of + :class:`DetDataPreprocessor` to process the input data. + Defaults to None. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or + list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/fsaf.py b/mmdet/models/detectors/fsaf.py new file mode 100644 index 0000000000000000000000000000000000000000..01b40273341f2a85cfa427f8adfc945a1b7da58a --- /dev/null +++ b/mmdet/models/detectors/fsaf.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
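Editor's note (not part of the patch): the detector files in this patch are thin wrappers that only wire configs into `SingleStageDetector`/`TwoStageDetector`; composition happens through the registry. A hedged sketch of how such a wrapper is typically instantiated (config values are illustrative and abridged, not taken from this patch):

from mmdet.registry import MODELS
from mmdet.utils import register_all_modules

register_all_modules()  # populate the registry with mmdet's components

# Illustrative FCOS-style config: each nested dict is recursively built
# into a module by MODELS.build via its `type` key.
cfg = dict(
    type='FCOS',
    backbone=dict(type='ResNet', depth=50, num_stages=4,
                  out_indices=(0, 1, 2, 3)),
    neck=dict(type='FPN', in_channels=[256, 512, 1024, 2048],
              out_channels=256, start_level=1,
              add_extra_convs='on_output', num_outs=5),
    bbox_head=dict(type='FCOSHead', num_classes=80, in_channels=256,
                   strides=[8, 16, 32, 64, 128]))
detector = MODELS.build(cfg)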
+from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage import SingleStageDetector + + +@MODELS.register_module() +class FSAF(SingleStageDetector): + """Implementation of `FSAF `_""" + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/gfl.py b/mmdet/models/detectors/gfl.py new file mode 100644 index 0000000000000000000000000000000000000000..c26821af68c224d4b55a1ca3d2be4c6e1d1b155d --- /dev/null +++ b/mmdet/models/detectors/gfl.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage import SingleStageDetector + + +@MODELS.register_module() +class GFL(SingleStageDetector): + """Implementation of `GFL `_ + + Args: + backbone (:obj:`ConfigDict` or dict): The backbone module. + neck (:obj:`ConfigDict` or dict): The neck module. + bbox_head (:obj:`ConfigDict` or dict): The bbox head module. + train_cfg (:obj:`ConfigDict` or dict, optional): The training config + of GFL. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): The testing config + of GFL. Defaults to None. + data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of + :class:`DetDataPreprocessor` to process the input data. + Defaults to None. + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. + """ + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/glip.py b/mmdet/models/detectors/glip.py new file mode 100644 index 0000000000000000000000000000000000000000..e076a55fe20926d9d5f2d95cac645e2db2251045 --- /dev/null +++ b/mmdet/models/detectors/glip.py @@ -0,0 +1,403 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +import warnings +from typing import Tuple, Union + +import torch +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage import SingleStageDetector + + +def find_noun_phrases(caption: str) -> list: + """Find noun phrases in a caption using nltk. + Args: + caption (str): The caption to analyze. + + Returns: + list: List of noun phrases found in the caption. 
+
+    Examples:
+        >>> caption = 'There is two cat and a remote in the picture'
+        >>> find_noun_phrases(caption)  # ['cat', 'a remote', 'the picture']
+    """
+    try:
+        import nltk
+        nltk.download('punkt')
+        nltk.download('averaged_perceptron_tagger')
+    except ImportError:
+        raise RuntimeError('nltk is not installed, please install it by: '
+                           'pip install nltk.')
+
+    caption = caption.lower()
+    tokens = nltk.word_tokenize(caption)
+    pos_tags = nltk.pos_tag(tokens)
+
+    grammar = 'NP: {<DT>?<JJ.*>*<NN.*>+}'
+    cp = nltk.RegexpParser(grammar)
+    result = cp.parse(pos_tags)
+
+    noun_phrases = []
+    for subtree in result.subtrees():
+        if subtree.label() == 'NP':
+            noun_phrases.append(' '.join(t[0] for t in subtree.leaves()))
+
+    return noun_phrases
+
+
+def remove_punctuation(text: str) -> str:
+    """Remove punctuation from a text.
+
+    Args:
+        text (str): The input text.
+
+    Returns:
+        str: The text with punctuation removed.
+    """
+    punctuation = [
+        '|', ':', ';', '@', '(', ')', '[', ']', '{', '}', '^', '\'', '\"', '’',
+        '`', '?', '$', '%', '#', '!', '&', '*', '+', ',', '.'
+    ]
+    for p in punctuation:
+        text = text.replace(p, '')
+    return text.strip()
+
+
+def run_ner(caption: str) -> Tuple[list, list]:
+    """Run NER on a caption and return the tokens and noun phrases.
+
+    Args:
+        caption (str): The input caption.
+
+    Returns:
+        Tuple[List, List]: A tuple containing the tokens and noun phrases.
+            - tokens_positive (List): A list of token positions.
+            - noun_phrases (List): A list of noun phrases.
+    """
+    noun_phrases = find_noun_phrases(caption)
+    noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
+    noun_phrases = [phrase for phrase in noun_phrases if phrase != '']
+    relevant_phrases = noun_phrases
+    labels = noun_phrases
+
+    tokens_positive = []
+    for entity, label in zip(relevant_phrases, labels):
+        try:
+            # search all occurrences and mark them as different entities
+            # TODO: Not Robust
+            for m in re.finditer(entity, caption.lower()):
+                tokens_positive.append([[m.start(), m.end()]])
+        except Exception:
+            print('noun entities:', noun_phrases)
+            print('entity:', entity)
+            print('caption:', caption.lower())
+    return tokens_positive, noun_phrases
+
+
+def create_positive_map(tokenized,
+                        tokens_positive: list,
+                        max_num_entities: int = 256) -> Tensor:
+    """construct a map such that positive_map[i,j] = True
+    if box i is associated to token j
+
+    Args:
+        tokenized: The tokenized input.
+        tokens_positive (list): A list of token ranges
+            associated with positive boxes.
+        max_num_entities (int, optional): The maximum number of entities.
+            Defaults to 256.
+
+    Returns:
+        torch.Tensor: The positive map.
+
+    Raises:
+        Exception: If an error occurs during token-to-char mapping.
+    """
+    positive_map = torch.zeros((len(tokens_positive), max_num_entities),
+                               dtype=torch.float)
+
+    for j, tok_list in enumerate(tokens_positive):
+        for (beg, end) in tok_list:
+            try:
+                beg_pos = tokenized.char_to_token(beg)
+                end_pos = tokenized.char_to_token(end - 1)
+            except Exception as e:
+                print('beg:', beg, 'end:', end)
+                print('token_positive:', tokens_positive)
+                raise e
+            if beg_pos is None:
+                try:
+                    beg_pos = tokenized.char_to_token(beg + 1)
+                    if beg_pos is None:
+                        beg_pos = tokenized.char_to_token(beg + 2)
+                except Exception:
+                    beg_pos = None
+            if end_pos is None:
+                try:
+                    end_pos = tokenized.char_to_token(end - 2)
+                    if end_pos is None:
+                        end_pos = tokenized.char_to_token(end - 3)
+                except Exception:
+                    end_pos = None
+            if beg_pos is None or end_pos is None:
+                continue
+
+            assert beg_pos is not None and end_pos is not None
+            positive_map[j, beg_pos:end_pos + 1].fill_(1)
+    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
+
+
+def create_positive_map_label_to_token(positive_map: Tensor,
+                                       plus: int = 0) -> dict:
+    """Create a dictionary mapping the label to the token.
+
+    Args:
+        positive_map (Tensor): The positive map tensor.
+        plus (int, optional): Value added to the label for indexing.
+            Defaults to 0.
+
+    Returns:
+        dict: The dictionary mapping the label to the token.
+ """ + positive_map_label_to_token = {} + for i in range(len(positive_map)): + positive_map_label_to_token[i + plus] = torch.nonzero( + positive_map[i], as_tuple=True)[0].tolist() + return positive_map_label_to_token + + +@MODELS.register_module() +class GLIP(SingleStageDetector): + """Implementation of `GLIP `_ + Args: + backbone (:obj:`ConfigDict` or dict): The backbone config. + neck (:obj:`ConfigDict` or dict): The neck config. + bbox_head (:obj:`ConfigDict` or dict): The bbox head config. + language_model (:obj:`ConfigDict` or dict): The language model config. + train_cfg (:obj:`ConfigDict` or dict, optional): The training config + of GLIP. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): The testing config + of GLIP. Defaults to None. + data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of + :class:`DetDataPreprocessor` to process the input data. + Defaults to None. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or + list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + language_model: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + self.language_model = MODELS.build(language_model) + + self._special_tokens = '. ' + + def get_tokens_and_prompts( + self, + original_caption: Union[str, list, tuple], + custom_entities: bool = False) -> Tuple[dict, str, list, list]: + """Get the tokens positive and prompts for the caption.""" + if isinstance(original_caption, (list, tuple)) or custom_entities: + if custom_entities and isinstance(original_caption, str): + original_caption = original_caption.strip(self._special_tokens) + original_caption = original_caption.split(self._special_tokens) + original_caption = list( + filter(lambda x: len(x) > 0, original_caption)) + + caption_string = '' + tokens_positive = [] + for idx, word in enumerate(original_caption): + tokens_positive.append( + [[len(caption_string), + len(caption_string) + len(word)]]) + caption_string += word + if idx != len(original_caption) - 1: + caption_string += self._special_tokens + tokenized = self.language_model.tokenizer([caption_string], + return_tensors='pt') + entities = original_caption + else: + original_caption = original_caption.strip(self._special_tokens) + tokenized = self.language_model.tokenizer([original_caption], + return_tensors='pt') + tokens_positive, noun_phrases = run_ner(original_caption) + entities = noun_phrases + caption_string = original_caption + + return tokenized, caption_string, tokens_positive, entities + + def get_positive_map(self, tokenized, tokens_positive): + positive_map = create_positive_map(tokenized, tokens_positive) + positive_map_label_to_token = create_positive_map_label_to_token( + positive_map, plus=1) + return positive_map_label_to_token, positive_map + + def get_tokens_positive_and_prompts( + self, + original_caption: Union[str, list, tuple], + custom_entities: bool = False) -> Tuple[dict, str, Tensor, list]: + tokenized, caption_string, tokens_positive, entities = \ + self.get_tokens_and_prompts( + original_caption, custom_entities) + positive_map_label_to_token, positive_map = self.get_positive_map( + tokenized, 
tokens_positive) + return positive_map_label_to_token, caption_string, \ + positive_map, entities + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> Union[dict, list]: + # TODO: Only open vocabulary tasks are supported for training now. + text_prompts = [ + data_samples.text for data_samples in batch_data_samples + ] + + gt_labels = [ + data_samples.gt_instances.labels + for data_samples in batch_data_samples + ] + + new_text_prompts = [] + positive_maps = [] + if len(set(text_prompts)) == 1: + # All the text prompts are the same, + # so there is no need to calculate them multiple times. + tokenized, caption_string, tokens_positive, _ = \ + self.get_tokens_and_prompts( + text_prompts[0], True) + new_text_prompts = [caption_string] * len(batch_inputs) + for gt_label in gt_labels: + new_tokens_positive = [ + tokens_positive[label] for label in gt_label + ] + _, positive_map = self.get_positive_map( + tokenized, new_tokens_positive) + positive_maps.append(positive_map) + else: + for text_prompt, gt_label in zip(text_prompts, gt_labels): + tokenized, caption_string, tokens_positive, _ = \ + self.get_tokens_and_prompts( + text_prompt, True) + new_tokens_positive = [ + tokens_positive[label] for label in gt_label + ] + _, positive_map = self.get_positive_map( + tokenized, new_tokens_positive) + positive_maps.append(positive_map) + new_text_prompts.append(caption_string) + + language_dict_features = self.language_model(new_text_prompts) + for i, data_samples in enumerate(batch_data_samples): + # .bool().float() is very important + positive_map = positive_maps[i].to( + batch_inputs.device).bool().float() + data_samples.gt_instances.positive_maps = positive_map + + visual_features = self.extract_feat(batch_inputs) + + losses = self.bbox_head.loss(visual_features, language_dict_features, + batch_data_samples) + return losses + + def predict(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): Whether to rescale the results. + Defaults to True. + + Returns: + list[:obj:`DetDataSample`]: Detection results of the + input images. Each DetDataSample usually contain + 'pred_instances'. And the ``pred_instances`` usually + contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - label_names (List[str]): Label names of bboxes. + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + text_prompts = [ + data_samples.text for data_samples in batch_data_samples + ] + + if 'custom_entities' in batch_data_samples[0]: + # Assuming that the `custom_entities` flag + # inside a batch is always the same. For single image inference + custom_entities = batch_data_samples[0].custom_entities + else: + custom_entities = False + + if len(set(text_prompts)) == 1: + # All the text prompts are the same, + # so there is no need to calculate them multiple times. 
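# NOTE (editor, not part of the patch): a worked toy example of the prompt
# bookkeeping used here. With self._special_tokens = '. ' and custom
# entities ['bench', 'car'], get_tokens_and_prompts() builds the caption
# 'bench. car' with tokens_positive = [[[0, 5]], [[7, 10]]] (the character
# span of each entity), and get_positive_map() converts those character
# spans into a (num_entities, max_num_entities) 0/1 matrix over token
# positions, normalized per row.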
+ _positive_maps_and_prompts = [ + self.get_tokens_positive_and_prompts(text_prompts[0], + custom_entities) + ] * len(batch_inputs) + else: + _positive_maps_and_prompts = [ + self.get_tokens_positive_and_prompts(text_prompt, + custom_entities) + for text_prompt in text_prompts + ] + + token_positive_maps, text_prompts, _, entities = zip( + *_positive_maps_and_prompts) + + language_dict_features = self.language_model(list(text_prompts)) + + for i, data_samples in enumerate(batch_data_samples): + data_samples.token_positive_map = token_positive_maps[i] + + visual_features = self.extract_feat(batch_inputs) + + results_list = self.bbox_head.predict( + visual_features, + language_dict_features, + batch_data_samples, + rescale=rescale) + + for data_sample, pred_instances, entity in zip(batch_data_samples, + results_list, entities): + if len(pred_instances) > 0: + label_names = [] + for labels in pred_instances.labels: + if labels >= len(entity): + warnings.warn( + 'The unexpected output indicates an issue with ' + 'named entity recognition. You can try ' + 'setting custom_entities=True and running ' + 'again to see if it helps.') + label_names.append('unobject') + else: + label_names.append(entity[labels]) + # for visualization + pred_instances.label_names = label_names + data_sample.pred_instances = pred_instances + return batch_data_samples diff --git a/mmdet/models/detectors/grid_rcnn.py b/mmdet/models/detectors/grid_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..7bcb5b033edc620f1cf61b986c345961b719e6f1 --- /dev/null +++ b/mmdet/models/detectors/grid_rcnn.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .two_stage import TwoStageDetector + + +@MODELS.register_module() +class GridRCNN(TwoStageDetector): + """Grid R-CNN. + + This detector is the implementation of: + - Grid R-CNN (https://arxiv.org/abs/1811.12030) + - Grid R-CNN Plus: Faster and Better (https://arxiv.org/abs/1906.05688) + """ + + def __init__(self, + backbone: ConfigType, + rpn_head: ConfigType, + roi_head: ConfigType, + train_cfg: ConfigType, + test_cfg: ConfigType, + neck: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + rpn_head=rpn_head, + roi_head=roi_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/grounding_dino.py b/mmdet/models/detectors/grounding_dino.py new file mode 100644 index 0000000000000000000000000000000000000000..69d398bec8f84a2d062875a0f955467b7c363926 --- /dev/null +++ b/mmdet/models/detectors/grounding_dino.py @@ -0,0 +1,384 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Dict, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import OptSampleList, SampleList +from ..layers import SinePositionalEncoding +from ..layers.transformer.grounding_dino_layers import ( + GroundingDinoTransformerDecoder, GroundingDinoTransformerEncoder) +from .dino import DINO +from .glip import (create_positive_map, create_positive_map_label_to_token, + run_ner) + + +@MODELS.register_module() +class GroundingDINO(DINO): + """Implementation of `Grounding DINO: Marrying DINO with Grounded Pre- + Training for Open-Set Object Detection. 
+ + `_ + + Code is modified from the `official github repo + `_. + """ + + def __init__(self, language_model, *args, **kwargs) -> None: + + self.language_model_cfg = language_model + self._special_tokens = '. ' + super().__init__(*args, **kwargs) + + def _init_layers(self) -> None: + """Initialize layers except for backbone, neck and bbox_head.""" + self.positional_encoding = SinePositionalEncoding( + **self.positional_encoding) + self.encoder = GroundingDinoTransformerEncoder(**self.encoder) + self.decoder = GroundingDinoTransformerDecoder(**self.decoder) + self.embed_dims = self.encoder.embed_dims + self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims) + num_feats = self.positional_encoding.num_feats + assert num_feats * 2 == self.embed_dims, \ + f'embed_dims should be exactly 2 times of num_feats. ' \ + f'Found {self.embed_dims} and {num_feats}.' + + self.level_embed = nn.Parameter( + torch.Tensor(self.num_feature_levels, self.embed_dims)) + self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims) + self.memory_trans_norm = nn.LayerNorm(self.embed_dims) + + # text modules + self.language_model = MODELS.build(self.language_model_cfg) + self.text_feat_map = nn.Linear( + self.language_model.language_backbone.body.language_dim, + self.embed_dims, + bias=True) + + def init_weights(self) -> None: + """Initialize weights for Transformer and other components.""" + super().init_weights() + nn.init.constant_(self.text_feat_map.bias.data, 0) + nn.init.xavier_uniform_(self.text_feat_map.weight.data) + + def get_tokens_and_prompts( + self, + original_caption: Union[str, list, tuple], + custom_entities: bool = False) -> Tuple[dict, str, list]: + """Get the tokens positive and prompts for the caption.""" + if isinstance(original_caption, (list, tuple)) or custom_entities: + if custom_entities and isinstance(original_caption, str): + original_caption = original_caption.strip(self._special_tokens) + original_caption = original_caption.split(self._special_tokens) + original_caption = list( + filter(lambda x: len(x) > 0, original_caption)) + + caption_string = '' + tokens_positive = [] + for idx, word in enumerate(original_caption): + tokens_positive.append( + [[len(caption_string), + len(caption_string) + len(word)]]) + caption_string += word + caption_string += self._special_tokens + # NOTE: Tokenizer in Grounding DINO is different from + # that in GLIP. The tokenizer in GLIP will pad the + # caption_string to max_length, while the tokenizer + # in Grounding DINO will not. + tokenized = self.language_model.tokenizer( + [caption_string], + padding='max_length' + if self.language_model.pad_to_max else 'longest', + return_tensors='pt') + entities = original_caption + else: + if not original_caption.endswith('.'): + original_caption = original_caption + self._special_tokens + # NOTE: Tokenizer in Grounding DINO is different from + # that in GLIP. The tokenizer in GLIP will pad the + # caption_string to max_length, while the tokenizer + # in Grounding DINO will not. 
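# NOTE (editor, not part of the patch): a toy illustration of the padding
# difference described above, with a hypothetical HuggingFace-style
# tokenizer `tok` and an assumed model maximum of 256 tokens:
#
#     tok(['bench. car.'], padding='longest', return_tensors='pt')
#     # -> input_ids of shape (1, ~6): only as long as the caption needs
#     tok(['bench. car.'], padding='max_length', return_tensors='pt')
#     # -> input_ids of shape (1, 256): padded to the model maximum (GLIP-style)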
+ tokenized = self.language_model.tokenizer( + [original_caption], + padding='max_length' + if self.language_model.pad_to_max else 'longest', + return_tensors='pt') + tokens_positive, noun_phrases = run_ner(original_caption) + entities = noun_phrases + caption_string = original_caption + + return tokenized, caption_string, tokens_positive, entities + + def get_positive_map(self, tokenized, tokens_positive): + positive_map = create_positive_map(tokenized, tokens_positive) + positive_map_label_to_token = create_positive_map_label_to_token( + positive_map, plus=1) + return positive_map_label_to_token, positive_map + + def get_tokens_positive_and_prompts( + self, + original_caption: Union[str, list, tuple], + custom_entities: bool = False) -> Tuple[dict, str, Tensor, list]: + """Get the tokens positive and prompts for the caption. + + Args: + original_caption (str): The original caption, e.g. 'bench . car .' + custom_entities (bool, optional): Whether to use custom entities. + If ``True``, the ``original_caption`` should be a list of + strings, each of which is a word. Defaults to False. + + Returns: + Tuple[dict, str, dict, str]: The dict is a mapping from each entity + id, which is numbered from 1, to its positive token id. + The str represents the prompts. + """ + tokenized, caption_string, tokens_positive, entities = \ + self.get_tokens_and_prompts( + original_caption, custom_entities) + positive_map_label_to_token, positive_map = self.get_positive_map( + tokenized, tokens_positive) + return positive_map_label_to_token, caption_string, \ + positive_map, entities + + def forward_transformer( + self, + img_feats: Tuple[Tensor], + text_dict: Dict, + batch_data_samples: OptSampleList = None, + ) -> Dict: + encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer( + img_feats, batch_data_samples) + + encoder_outputs_dict = self.forward_encoder( + **encoder_inputs_dict, text_dict=text_dict) + + tmp_dec_in, head_inputs_dict = self.pre_decoder( + **encoder_outputs_dict, batch_data_samples=batch_data_samples) + decoder_inputs_dict.update(tmp_dec_in) + + decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict) + head_inputs_dict.update(decoder_outputs_dict) + return head_inputs_dict + + def forward_encoder(self, feat: Tensor, feat_mask: Tensor, + feat_pos: Tensor, spatial_shapes: Tensor, + level_start_index: Tensor, valid_ratios: Tensor, + text_dict: Dict) -> Dict: + text_token_mask = text_dict['text_token_mask'] + memory, memory_text = self.encoder( + query=feat, + query_pos=feat_pos, + key_padding_mask=feat_mask, # for self_attn + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + # for text encoder + memory_text=text_dict['embedded'], + text_attention_mask=~text_token_mask, + position_ids=text_dict['position_ids'], + text_self_attention_masks=text_dict['masks']) + encoder_outputs_dict = dict( + memory=memory, + memory_mask=feat_mask, + spatial_shapes=spatial_shapes, + memory_text=memory_text, + text_token_mask=text_token_mask) + return encoder_outputs_dict + + def pre_decoder( + self, + memory: Tensor, + memory_mask: Tensor, + spatial_shapes: Tensor, + memory_text: Tensor, + text_token_mask: Tensor, + batch_data_samples: OptSampleList = None, + ) -> Tuple[Dict]: + bs, _, c = memory.shape + + output_memory, output_proposals = self.gen_encoder_output_proposals( + memory, memory_mask, spatial_shapes) + + enc_outputs_class = self.bbox_head.cls_branches[ + self.decoder.num_layers](output_memory, memory_text, + text_token_mask) + 
cls_out_features = self.bbox_head.cls_branches[ + self.decoder.num_layers].max_text_len + enc_outputs_coord_unact = self.bbox_head.reg_branches[ + self.decoder.num_layers](output_memory) + output_proposals + + # NOTE The DINO selects top-k proposals according to scores of + # multi-class classification, while DeformDETR, where the input + # is `enc_outputs_class[..., 0]` selects according to scores of + # binary classification. + topk_indices = torch.topk( + enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1)[1] + + topk_score = torch.gather( + enc_outputs_class, 1, + topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features)) + topk_coords_unact = torch.gather( + enc_outputs_coord_unact, 1, + topk_indices.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords = topk_coords_unact.sigmoid() + topk_coords_unact = topk_coords_unact.detach() + + query = self.query_embedding.weight[:, None, :] + query = query.repeat(1, bs, 1).transpose(0, 1) + if self.training: + dn_label_query, dn_bbox_query, dn_mask, dn_meta = \ + self.dn_query_generator(batch_data_samples) + query = torch.cat([dn_label_query, query], dim=1) + reference_points = torch.cat([dn_bbox_query, topk_coords_unact], + dim=1) + else: + reference_points = topk_coords_unact + dn_mask, dn_meta = None, None + reference_points = reference_points.sigmoid() + + decoder_inputs_dict = dict( + query=query, + memory=memory, + reference_points=reference_points, + dn_mask=dn_mask, + memory_text=memory_text, + text_attention_mask=~text_token_mask, + ) + # NOTE DINO calculates encoder losses on scores and coordinates + # of selected top-k encoder queries, while DeformDETR is of all + # encoder queries. + head_inputs_dict = dict( + enc_outputs_class=topk_score, + enc_outputs_coord=topk_coords, + dn_meta=dn_meta) if self.training else dict() + # append text_feats to head_inputs_dict + head_inputs_dict['memory_text'] = memory_text + head_inputs_dict['text_token_mask'] = text_token_mask + return decoder_inputs_dict, head_inputs_dict + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> Union[dict, list]: + # TODO: Only open vocabulary tasks are supported for training now. + text_prompts = [ + data_samples.text for data_samples in batch_data_samples + ] + + gt_labels = [ + data_samples.gt_instances.labels + for data_samples in batch_data_samples + ] + + new_text_prompts = [] + positive_maps = [] + if len(set(text_prompts)) == 1: + # All the text prompts are the same, + # so there is no need to calculate them multiple times. 
+ tokenized, caption_string, tokens_positive, _ = \ + self.get_tokens_and_prompts( + text_prompts[0], True) + new_text_prompts = [caption_string] * len(batch_inputs) + for gt_label in gt_labels: + new_tokens_positive = [ + tokens_positive[label] for label in gt_label + ] + _, positive_map = self.get_positive_map( + tokenized, new_tokens_positive) + positive_maps.append(positive_map) + else: + for text_prompt, gt_label in zip(text_prompts, gt_labels): + tokenized, caption_string, tokens_positive, _ = \ + self.get_tokens_and_prompts( + text_prompt, True) + new_tokens_positive = [ + tokens_positive[label] for label in gt_label + ] + _, positive_map = self.get_positive_map( + tokenized, new_tokens_positive) + positive_maps.append(positive_map) + new_text_prompts.append(caption_string) + + text_dict = self.language_model(new_text_prompts) + if self.text_feat_map is not None: + text_dict['embedded'] = self.text_feat_map(text_dict['embedded']) + + for i, data_samples in enumerate(batch_data_samples): + positive_map = positive_maps[i].to( + batch_inputs.device).bool().float() + text_token_mask = text_dict['text_token_mask'][i] + data_samples.gt_instances.positive_maps = positive_map + data_samples.gt_instances.text_token_mask = \ + text_token_mask.unsqueeze(0).repeat( + len(positive_map), 1) + + visual_features = self.extract_feat(batch_inputs) + head_inputs_dict = self.forward_transformer(visual_features, text_dict, + batch_data_samples) + + losses = self.bbox_head.loss( + **head_inputs_dict, batch_data_samples=batch_data_samples) + return losses + + def predict(self, batch_inputs, batch_data_samples, rescale: bool = True): + text_prompts = [ + data_samples.text for data_samples in batch_data_samples + ] + if 'custom_entities' in batch_data_samples[0]: + # Assuming that the `custom_entities` flag + # inside a batch is always the same. For single image inference + custom_entities = batch_data_samples[0].custom_entities + else: + custom_entities = False + if len(text_prompts) == 1: + # All the text prompts are the same, + # so there is no need to calculate them multiple times. + _positive_maps_and_prompts = [ + self.get_tokens_positive_and_prompts(text_prompts[0], + custom_entities) + ] * len(batch_inputs) + else: + _positive_maps_and_prompts = [ + self.get_tokens_positive_and_prompts(text_prompt, + custom_entities) + for text_prompt in text_prompts + ] + token_positive_maps, text_prompts, _, entities = zip( + *_positive_maps_and_prompts) + # extract text feats + text_dict = self.language_model(list(text_prompts)) + # text feature map layer + if self.text_feat_map is not None: + text_dict['embedded'] = self.text_feat_map(text_dict['embedded']) + + for i, data_samples in enumerate(batch_data_samples): + data_samples.token_positive_map = token_positive_maps[i] + + # image feature extraction + visual_feats = self.extract_feat(batch_inputs) + + head_inputs_dict = self.forward_transformer(visual_feats, text_dict, + batch_data_samples) + results_list = self.bbox_head.predict( + **head_inputs_dict, + rescale=rescale, + batch_data_samples=batch_data_samples) + for data_sample, pred_instances, entity in zip(batch_data_samples, + results_list, entities): + if len(pred_instances) > 0: + label_names = [] + for labels in pred_instances.labels: + if labels >= len(entity): + warnings.warn( + 'The unexpected output indicates an issue with ' + 'named entity recognition. 
You can try '
+                            'setting custom_entities=True and running '
+                            'again to see if it helps.')
+                        label_names.append('unobject')
+                    else:
+                        label_names.append(entity[labels])
+                # for visualization
+                pred_instances.label_names = label_names
+            data_sample.pred_instances = pred_instances
+        return batch_data_samples
diff --git a/mmdet/models/detectors/htc.py b/mmdet/models/detectors/htc.py
new file mode 100644
index 0000000000000000000000000000000000000000..22a2aa889a59fd0e0afeb95a7369028def6e4fa9
--- /dev/null
+++ b/mmdet/models/detectors/htc.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmdet.registry import MODELS
+from .cascade_rcnn import CascadeRCNN
+
+
+@MODELS.register_module()
+class HybridTaskCascade(CascadeRCNN):
+    """Implementation of `HTC <https://arxiv.org/abs/1901.07518>`_"""
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+
+    @property
+    def with_semantic(self) -> bool:
+        """bool: whether the detector has a semantic head"""
+        return self.roi_head.with_semantic
diff --git a/mmdet/models/detectors/kd_one_stage.py b/mmdet/models/detectors/kd_one_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a4a1bb564c0f6e4cabe32a5c01cfea252ecfb7d
--- /dev/null
+++ b/mmdet/models/detectors/kd_one_stage.py
@@ -0,0 +1,122 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import torch
+import torch.nn as nn
+from mmengine.config import Config
+from mmengine.runner import load_checkpoint
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.structures import SampleList
+from mmdet.utils import ConfigType, OptConfigType
+from .single_stage import SingleStageDetector
+
+
+@MODELS.register_module()
+class KnowledgeDistillationSingleStageDetector(SingleStageDetector):
+    r"""Implementation of `Distilling the Knowledge in a Neural Network.
+    <https://arxiv.org/abs/1503.02531>`_.
+
+    Args:
+        backbone (:obj:`ConfigDict` or dict): The backbone module.
+        neck (:obj:`ConfigDict` or dict): The neck module.
+        bbox_head (:obj:`ConfigDict` or dict): The bbox head module.
+        teacher_config (:obj:`ConfigDict` | dict | str | Path): Config file
+            path or the config object of teacher model.
+        teacher_ckpt (str, optional): Checkpoint path of teacher model.
+            If left as None, the model will not load any weights.
+            Defaults to None.
+        eval_teacher (bool): Set the train mode for teacher.
+            Defaults to True.
+        train_cfg (:obj:`ConfigDict` or dict, optional): The training config
+            of the student detector. Defaults to None.
+        test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
+            of the student detector. Defaults to None.
+        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
+            :class:`DetDataPreprocessor` to process the input data.
+            Defaults to None.
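+
+    Example:
+        A minimal construction sketch (the config path below is a
+        placeholder, not a file shipped with this repo)::
+
+            >>> from mmengine.config import Config
+            >>> from mmdet.registry import MODELS
+            >>> cfg = Config.fromfile('path/to/kd_student_config.py')
+            >>> # the teacher is built from ``teacher_config`` inside
+            >>> # ``__init__`` and its weights come from ``teacher_ckpt``;
+            >>> # only the student receives gradients
+            >>> model = MODELS.build(cfg.model)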
+    """
+
+    def __init__(
+        self,
+        backbone: ConfigType,
+        neck: ConfigType,
+        bbox_head: ConfigType,
+        teacher_config: Union[ConfigType, str, Path],
+        teacher_ckpt: Optional[str] = None,
+        eval_teacher: bool = True,
+        train_cfg: OptConfigType = None,
+        test_cfg: OptConfigType = None,
+        data_preprocessor: OptConfigType = None,
+    ) -> None:
+        super().__init__(
+            backbone=backbone,
+            neck=neck,
+            bbox_head=bbox_head,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            data_preprocessor=data_preprocessor)
+        self.eval_teacher = eval_teacher
+        # Build teacher model
+        if isinstance(teacher_config, (str, Path)):
+            teacher_config = Config.fromfile(teacher_config)
+        self.teacher_model = MODELS.build(teacher_config['model'])
+        if teacher_ckpt is not None:
+            load_checkpoint(
+                self.teacher_model, teacher_ckpt, map_location='cpu')
+
+    def loss(self, batch_inputs: Tensor,
+             batch_data_samples: SampleList) -> dict:
+        """
+        Args:
+            batch_inputs (Tensor): Input images of shape (N, C, H, W).
+                These should usually be mean centered and std scaled.
+            batch_data_samples (list[:obj:`DetDataSample`]): The batch
+                data samples. It usually includes information such
+                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        x = self.extract_feat(batch_inputs)
+        with torch.no_grad():
+            teacher_x = self.teacher_model.extract_feat(batch_inputs)
+            out_teacher = self.teacher_model.bbox_head(teacher_x)
+        losses = self.bbox_head.loss(x, out_teacher, batch_data_samples)
+        return losses
+
+    def cuda(self, device: Optional[str] = None) -> nn.Module:
+        """Since teacher_model is registered as a plain object, it is
+        necessary to move the teacher model to CUDA when ``cuda`` is
+        called."""
+        self.teacher_model.cuda(device=device)
+        return super().cuda(device=device)
+
+    def to(self, device: Optional[str] = None) -> nn.Module:
+        """Since teacher_model is registered as a plain object, it is
+        necessary to move the teacher model to the target device when
+        ``to`` is called."""
+        self.teacher_model.to(device=device)
+        return super().to(device=device)
+
+    def train(self, mode: bool = True) -> None:
+        """Set the same train mode for teacher and student model."""
+        if self.eval_teacher:
+            self.teacher_model.train(False)
+        else:
+            self.teacher_model.train(mode)
+        super().train(mode)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        """Set attribute, i.e. self.name = value.
+
+        This override prevents the teacher model from being registered as
+        an ``nn.Module``. The teacher module is stored as a plain object,
+        so the teacher parameters do not show up when calling
+        ``self.parameters``, ``self.modules`` or ``self.children``.
+        """
+        if name == 'teacher_model':
+            object.__setattr__(self, name, value)
+        else:
+            super().__setattr__(name, value)
diff --git a/mmdet/models/detectors/lad.py b/mmdet/models/detectors/lad.py
new file mode 100644
index 0000000000000000000000000000000000000000..008f898772988715c67783d9218ff39c4dd95d80
--- /dev/null
+++ b/mmdet/models/detectors/lad.py
@@ -0,0 +1,93 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional + +import torch +import torch.nn as nn +from mmengine.runner import load_checkpoint +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.utils import ConfigType, OptConfigType +from ..utils.misc import unpack_gt_instances +from .kd_one_stage import KnowledgeDistillationSingleStageDetector + + +@MODELS.register_module() +class LAD(KnowledgeDistillationSingleStageDetector): + """Implementation of `LAD `_.""" + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + teacher_backbone: ConfigType, + teacher_neck: ConfigType, + teacher_bbox_head: ConfigType, + teacher_ckpt: Optional[str] = None, + eval_teacher: bool = True, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None) -> None: + super(KnowledgeDistillationSingleStageDetector, self).__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor) + self.eval_teacher = eval_teacher + self.teacher_model = nn.Module() + self.teacher_model.backbone = MODELS.build(teacher_backbone) + if teacher_neck is not None: + self.teacher_model.neck = MODELS.build(teacher_neck) + teacher_bbox_head.update(train_cfg=train_cfg) + teacher_bbox_head.update(test_cfg=test_cfg) + self.teacher_model.bbox_head = MODELS.build(teacher_bbox_head) + if teacher_ckpt is not None: + load_checkpoint( + self.teacher_model, teacher_ckpt, map_location='cpu') + + @property + def with_teacher_neck(self) -> bool: + """bool: whether the detector has a teacher_neck""" + return hasattr(self.teacher_model, 'neck') and \ + self.teacher_model.neck is not None + + def extract_teacher_feat(self, batch_inputs: Tensor) -> Tensor: + """Directly extract teacher features from the backbone+neck.""" + x = self.teacher_model.backbone(batch_inputs) + if self.with_teacher_neck: + x = self.teacher_model.neck(x) + return x + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> dict: + """ + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \ + = outputs + # get label assignment from the teacher + with torch.no_grad(): + x_teacher = self.extract_teacher_feat(batch_inputs) + outs_teacher = self.teacher_model.bbox_head(x_teacher) + label_assignment_results = \ + self.teacher_model.bbox_head.get_label_assignment( + *outs_teacher, batch_gt_instances, batch_img_metas, + batch_gt_instances_ignore) + + # the student use the label assignment from the teacher to learn + x = self.extract_feat(batch_inputs) + losses = self.bbox_head.loss(x, label_assignment_results, + batch_data_samples) + return losses diff --git a/mmdet/models/detectors/mask2former.py b/mmdet/models/detectors/mask2former.py new file mode 100644 index 0000000000000000000000000000000000000000..4f38ef44e482039fdf7476d048eee5df2a96fd9b --- /dev/null +++ b/mmdet/models/detectors/mask2former.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .maskformer import MaskFormer + + +@MODELS.register_module() +class Mask2Former(MaskFormer): + r"""Implementation of `Masked-attention Mask + Transformer for Universal Image Segmentation + `_.""" + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + panoptic_head: OptConfigType = None, + panoptic_fusion_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super().__init__( + backbone=backbone, + neck=neck, + panoptic_head=panoptic_head, + panoptic_fusion_head=panoptic_fusion_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/mask_rcnn.py b/mmdet/models/detectors/mask_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..880ee1e8ac3926d618ef47985549d3214175ee73 --- /dev/null +++ b/mmdet/models/detectors/mask_rcnn.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import ConfigDict + +from mmdet.registry import MODELS +from mmdet.utils import OptConfigType, OptMultiConfig +from .two_stage import TwoStageDetector + + +@MODELS.register_module() +class MaskRCNN(TwoStageDetector): + """Implementation of `Mask R-CNN `_""" + + def __init__(self, + backbone: ConfigDict, + rpn_head: ConfigDict, + roi_head: ConfigDict, + train_cfg: ConfigDict, + test_cfg: ConfigDict, + neck: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + rpn_head=rpn_head, + roi_head=roi_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + init_cfg=init_cfg, + data_preprocessor=data_preprocessor) diff --git a/mmdet/models/detectors/mask_scoring_rcnn.py b/mmdet/models/detectors/mask_scoring_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..e09d3a1041f929113962e42bdf8b169e52dabe25 --- /dev/null +++ b/mmdet/models/detectors/mask_scoring_rcnn.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .two_stage import TwoStageDetector + + +@MODELS.register_module() +class MaskScoringRCNN(TwoStageDetector): + """Mask Scoring RCNN. + + https://arxiv.org/abs/1903.00241 + """ + + def __init__(self, + backbone: ConfigType, + rpn_head: ConfigType, + roi_head: ConfigType, + train_cfg: ConfigType, + test_cfg: ConfigType, + neck: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + rpn_head=rpn_head, + roi_head=roi_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/maskformer.py b/mmdet/models/detectors/maskformer.py new file mode 100644 index 0000000000000000000000000000000000000000..7493c00e1b87cf9b2fbd2c80f1e642f6eb2bea55 --- /dev/null +++ b/mmdet/models/detectors/maskformer.py @@ -0,0 +1,170 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Dict, List, Tuple + +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage import SingleStageDetector + + +@MODELS.register_module() +class MaskFormer(SingleStageDetector): + r"""Implementation of `Per-Pixel Classification is + NOT All You Need for Semantic Segmentation + `_.""" + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + panoptic_head: OptConfigType = None, + panoptic_fusion_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super(SingleStageDetector, self).__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + + panoptic_head_ = panoptic_head.deepcopy() + panoptic_head_.update(train_cfg=train_cfg) + panoptic_head_.update(test_cfg=test_cfg) + self.panoptic_head = MODELS.build(panoptic_head_) + + panoptic_fusion_head_ = panoptic_fusion_head.deepcopy() + panoptic_fusion_head_.update(test_cfg=test_cfg) + self.panoptic_fusion_head = MODELS.build(panoptic_fusion_head_) + + self.num_things_classes = self.panoptic_head.num_things_classes + self.num_stuff_classes = self.panoptic_head.num_stuff_classes + self.num_classes = self.panoptic_head.num_classes + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> Dict[str, Tensor]: + """ + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + x = self.extract_feat(batch_inputs) + losses = self.panoptic_head.loss(x, batch_data_samples) + return losses + + def predict(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): Whether to rescale the results. + Defaults to True. + + Returns: + list[:obj:`DetDataSample`]: Detection results of the + input images. Each DetDataSample usually contain + 'pred_instances' and `pred_panoptic_seg`. And the + ``pred_instances`` usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + + And the ``pred_panoptic_seg`` contains the following key + + - sem_seg (Tensor): panoptic segmentation mask, has a + shape (1, h, w). 
+        """
+        feats = self.extract_feat(batch_inputs)
+        mask_cls_results, mask_pred_results = self.panoptic_head.predict(
+            feats, batch_data_samples)
+        results_list = self.panoptic_fusion_head.predict(
+            mask_cls_results,
+            mask_pred_results,
+            batch_data_samples,
+            rescale=rescale)
+        results = self.add_pred_to_datasample(batch_data_samples,
+                                              results_list)
+
+        return results
+
+    def add_pred_to_datasample(self, data_samples: SampleList,
+                               results_list: List[dict]) -> SampleList:
+        """Add predictions to `DetDataSample`.
+
+        Args:
+            data_samples (list[:obj:`DetDataSample`], optional): A batch of
+                data samples that contain annotations and predictions.
+            results_list (List[dict]): Instance segmentation, semantic
+                segmentation and panoptic segmentation results.
+
+        Returns:
+            list[:obj:`DetDataSample`]: Detection results of the
+            input images. Each DetDataSample usually contains
+            'pred_instances' and `pred_panoptic_seg`. And the
+            ``pred_instances`` usually contains the following keys.
+
+            - scores (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 4),
+              the last dimension 4 arranged as (x1, y1, x2, y2).
+            - masks (Tensor): Has a shape (num_instances, H, W).
+
+            And the ``pred_panoptic_seg`` contains the following key
+
+            - sem_seg (Tensor): panoptic segmentation mask, has a
+              shape (1, h, w).
+        """
+        for data_sample, pred_results in zip(data_samples, results_list):
+            if 'pan_results' in pred_results:
+                data_sample.pred_panoptic_seg = pred_results['pan_results']
+
+            if 'ins_results' in pred_results:
+                data_sample.pred_instances = pred_results['ins_results']
+
+            assert 'sem_results' not in pred_results, 'semantic ' \
+                'segmentation results are not supported yet.'
+
+        return data_samples
+
+    def _forward(self, batch_inputs: Tensor,
+                 batch_data_samples: SampleList) -> Tuple[List[Tensor]]:
+        """Network forward process. Usually includes backbone, neck and head
+        forward without any post-processing.
+
+        Args:
+            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
+            batch_data_samples (list[:obj:`DetDataSample`]): The batch
+                data samples. It usually includes information such
+                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+
+        Returns:
+            tuple[List[Tensor]]: A tuple of features from ``panoptic_head``
+            forward.
+        """
+        feats = self.extract_feat(batch_inputs)
+        results = self.panoptic_head.forward(feats, batch_data_samples)
+        return results
diff --git a/mmdet/models/detectors/nasfcos.py b/mmdet/models/detectors/nasfcos.py
new file mode 100644
index 0000000000000000000000000000000000000000..da2b911bcfc6b0ba51b00d9b3948a3df7af2e74f
--- /dev/null
+++ b/mmdet/models/detectors/nasfcos.py
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmdet.registry import MODELS
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+from .single_stage import SingleStageDetector
+
+
+@MODELS.register_module()
+class NASFCOS(SingleStageDetector):
+    """Implementation of `NAS-FCOS: Fast Neural Architecture Search for Object
+    Detection. <https://arxiv.org/abs/1906.04423>`_
+
+    Args:
+        backbone (:obj:`ConfigDict` or dict): The backbone config.
+        neck (:obj:`ConfigDict` or dict): The neck config.
+        bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
+        train_cfg (:obj:`ConfigDict` or dict, optional): The training config
+            of NASFCOS. Defaults to None.
+        test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
+            of NASFCOS. Defaults to None.
+ data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of + :class:`DetDataPreprocessor` to process the input data. + Defaults to None. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or + list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/paa.py b/mmdet/models/detectors/paa.py new file mode 100644 index 0000000000000000000000000000000000000000..094306b2fbd18ba45536470ec80443e4ff793e67 --- /dev/null +++ b/mmdet/models/detectors/paa.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage import SingleStageDetector + + +@MODELS.register_module() +class PAA(SingleStageDetector): + """Implementation of `PAA `_ + + Args: + backbone (:obj:`ConfigDict` or dict): The backbone module. + neck (:obj:`ConfigDict` or dict): The neck module. + bbox_head (:obj:`ConfigDict` or dict): The bbox head module. + train_cfg (:obj:`ConfigDict` or dict, optional): The training config + of PAA. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): The testing config + of PAA. Defaults to None. + data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of + :class:`DetDataPreprocessor` to process the input data. + Defaults to None. + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. + """ + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/panoptic_fpn.py b/mmdet/models/detectors/panoptic_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..ae63ccc38931daa60b4e62f94dcf9f44574d3669 --- /dev/null +++ b/mmdet/models/detectors/panoptic_fpn.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .panoptic_two_stage_segmentor import TwoStagePanopticSegmentor + + +@MODELS.register_module() +class PanopticFPN(TwoStagePanopticSegmentor): + r"""Implementation of `Panoptic feature pyramid + networks `_""" + + def __init__( + self, + backbone: ConfigType, + neck: OptConfigType = None, + rpn_head: OptConfigType = None, + roi_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None, + # for panoptic segmentation + semantic_head: OptConfigType = None, + panoptic_fusion_head: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + rpn_head=rpn_head, + roi_head=roi_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg, + semantic_head=semantic_head, + panoptic_fusion_head=panoptic_fusion_head) diff --git a/mmdet/models/detectors/panoptic_two_stage_segmentor.py b/mmdet/models/detectors/panoptic_two_stage_segmentor.py new file mode 100644 index 0000000000000000000000000000000000000000..879edbe1ac6a0f482fdd740f4058e508e728414d --- /dev/null +++ b/mmdet/models/detectors/panoptic_two_stage_segmentor.py @@ -0,0 +1,234 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import List + +import torch +from mmengine.structures import PixelData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .two_stage import TwoStageDetector + + +@MODELS.register_module() +class TwoStagePanopticSegmentor(TwoStageDetector): + """Base class of Two-stage Panoptic Segmentor. + + As well as the components in TwoStageDetector, Panoptic Segmentor has extra + semantic_head and panoptic_fusion_head. 
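+
+    The prediction path roughly chains the components as follows (an
+    illustrative sketch of :meth:`predict`, not the exact call
+    signatures)::
+
+        >>> # x = self.extract_feat(imgs)
+        >>> # instance results: rpn_head.predict -> roi_head.predict
+        >>> # stuff results:    semantic_head.predict
+        >>> # final panoptic:   panoptic_fusion_head.predict(instances,
+        >>> #                                                seg_preds)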
+    """
+
+    def __init__(
+            self,
+            backbone: ConfigType,
+            neck: OptConfigType = None,
+            rpn_head: OptConfigType = None,
+            roi_head: OptConfigType = None,
+            train_cfg: OptConfigType = None,
+            test_cfg: OptConfigType = None,
+            data_preprocessor: OptConfigType = None,
+            init_cfg: OptMultiConfig = None,
+            # for panoptic segmentation
+            semantic_head: OptConfigType = None,
+            panoptic_fusion_head: OptConfigType = None) -> None:
+        super().__init__(
+            backbone=backbone,
+            neck=neck,
+            rpn_head=rpn_head,
+            roi_head=roi_head,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            data_preprocessor=data_preprocessor,
+            init_cfg=init_cfg)
+
+        if semantic_head is not None:
+            self.semantic_head = MODELS.build(semantic_head)
+
+        if panoptic_fusion_head is not None:
+            panoptic_cfg = test_cfg.panoptic if test_cfg is not None else None
+            panoptic_fusion_head_ = panoptic_fusion_head.deepcopy()
+            panoptic_fusion_head_.update(test_cfg=panoptic_cfg)
+            self.panoptic_fusion_head = MODELS.build(panoptic_fusion_head_)
+
+            self.num_things_classes = self.panoptic_fusion_head.\
+                num_things_classes
+            self.num_stuff_classes = self.panoptic_fusion_head.\
+                num_stuff_classes
+            self.num_classes = self.panoptic_fusion_head.num_classes
+
+    @property
+    def with_semantic_head(self) -> bool:
+        """bool: whether the detector has semantic head"""
+        return hasattr(self,
+                       'semantic_head') and self.semantic_head is not None
+
+    @property
+    def with_panoptic_fusion_head(self) -> bool:
+        """bool: whether the detector has panoptic fusion head"""
+        return hasattr(self, 'panoptic_fusion_head') and \
+            self.panoptic_fusion_head is not None
+
+    def loss(self, batch_inputs: Tensor,
+             batch_data_samples: SampleList) -> dict:
+        """
+        Args:
+            batch_inputs (Tensor): Input images of shape (N, C, H, W).
+                These should usually be mean centered and std scaled.
+            batch_data_samples (list[:obj:`DetDataSample`]): The batch
+                data samples. It usually includes information such
+                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+
+        Returns:
+            dict: A dictionary of loss components.
+        """
+        x = self.extract_feat(batch_inputs)
+
+        losses = dict()
+
+        # RPN forward and loss
+        if self.with_rpn:
+            proposal_cfg = self.train_cfg.get('rpn_proposal',
+                                              self.test_cfg.rpn)
+            rpn_data_samples = copy.deepcopy(batch_data_samples)
+            # set cat_id of gt_labels to 0 in RPN
+            for data_sample in rpn_data_samples:
+                data_sample.gt_instances.labels = \
+                    torch.zeros_like(data_sample.gt_instances.labels)
+
+            rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict(
+                x, rpn_data_samples, proposal_cfg=proposal_cfg)
+            # avoid the RPN loss keys clashing with the roi_head loss keys
+            keys = rpn_losses.keys()
+            for key in list(keys):
+                if 'loss' in key and 'rpn' not in key:
+                    rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
+            losses.update(rpn_losses)
+        else:
+            # TODO: Not supported currently; a check should be added for
+            # Fast R-CNN.
+            assert batch_data_samples[0].get('proposals', None) is not None
+            # use pre-defined proposals in InstanceData for the second stage
+            # to extract ROI features.
+            rpn_results_list = [
+                data_sample.proposals for data_sample in batch_data_samples
+            ]
+
+        roi_losses = self.roi_head.loss(x, rpn_results_list,
+                                        batch_data_samples)
+        losses.update(roi_losses)
+
+        semantic_loss = self.semantic_head.loss(x, batch_data_samples)
+        losses.update(semantic_loss)
+
+        return losses
+
+    def predict(self,
+                batch_inputs: Tensor,
+                batch_data_samples: SampleList,
+                rescale: bool = True) -> SampleList:
+        """Predict results from a batch of inputs and data samples with post-
+        processing.
+
+        Args:
+            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
+            batch_data_samples (List[:obj:`DetDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+            rescale (bool): Whether to rescale the results.
+                Defaults to True.
+
+        Returns:
+            List[:obj:`DetDataSample`]: Return the packed panoptic
+            segmentation results of input images. Each DetDataSample usually
+            contains 'pred_panoptic_seg'. And the 'pred_panoptic_seg' has a
+            key ``sem_seg``, which is a tensor of shape (1, h, w).
+        """
+        batch_img_metas = [
+            data_samples.metainfo for data_samples in batch_data_samples
+        ]
+
+        x = self.extract_feat(batch_inputs)
+
+        # If there are no pre-defined proposals, use RPN to get proposals
+        if batch_data_samples[0].get('proposals', None) is None:
+            rpn_results_list = self.rpn_head.predict(
+                x, batch_data_samples, rescale=False)
+        else:
+            rpn_results_list = [
+                data_sample.proposals for data_sample in batch_data_samples
+            ]
+
+        results_list = self.roi_head.predict(
+            x, rpn_results_list, batch_data_samples, rescale=rescale)
+
+        seg_preds = self.semantic_head.predict(x, batch_img_metas, rescale)
+
+        results_list = self.panoptic_fusion_head.predict(
+            results_list, seg_preds)
+
+        batch_data_samples = self.add_pred_to_datasample(
+            batch_data_samples, results_list)
+        return batch_data_samples
+
+    # TODO the code has not been verified and needs to be refactored later.
+    def _forward(self, batch_inputs: Tensor,
+                 batch_data_samples: SampleList) -> tuple:
+        """Network forward process. Usually includes backbone, neck and head
+        forward without any post-processing.
+
+        Args:
+            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
+
+        Returns:
+            tuple: A tuple of features from ``rpn_head``, ``roi_head`` and
+            ``semantic_head`` forward.
+        """
+        results = ()
+        x = self.extract_feat(batch_inputs)
+        rpn_outs = self.rpn_head.forward(x)
+        # keep the rpn outputs as a single nested element, mirroring the
+        # `(sem_outs['seg_preds'], )` pattern below
+        results = results + (rpn_outs, )
+
+        # If there are no pre-defined proposals, use RPN to get proposals
+        if batch_data_samples[0].get('proposals', None) is None:
+            batch_img_metas = [
+                data_samples.metainfo for data_samples in batch_data_samples
+            ]
+            rpn_results_list = self.rpn_head.predict_by_feat(
+                *rpn_outs, batch_img_metas=batch_img_metas, rescale=False)
+        else:
+            # TODO: Not checked currently.
+            rpn_results_list = [
+                data_sample.proposals for data_sample in batch_data_samples
+            ]
+
+        # roi_head
+        roi_outs = self.roi_head(x, rpn_results_list)
+        results = results + (roi_outs, )
+
+        # semantic_head
+        sem_outs = self.semantic_head.forward(x)
+        results = results + (sem_outs['seg_preds'], )
+
+        return results
+
+    def add_pred_to_datasample(self, data_samples: SampleList,
+                               results_list: List[PixelData]) -> SampleList:
+        """Add predictions to `DetDataSample`.
+
+        Args:
+            data_samples (list[:obj:`DetDataSample`]): The
+                annotation data of every sample.
+            results_list (List[PixelData]): Panoptic segmentation results of
+                each image.
+
+        Returns:
+            List[:obj:`DetDataSample`]: Return the packed panoptic
+            segmentation results of input images. Each DetDataSample usually
+            contains 'pred_panoptic_seg'. And the 'pred_panoptic_seg' has a
+            key ``sem_seg``, which is a tensor of shape (1, h, w).
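+
+        Example:
+            An illustrative consumption sketch (assumes ``results`` came
+            from :meth:`predict`)::
+
+                >>> # pan_mask = results[0].pred_panoptic_seg.sem_seg
+                >>> # pan_mask.shape  # (1, h, w)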
+ """ + + for data_sample, pred_panoptic_seg in zip(data_samples, results_list): + data_sample.pred_panoptic_seg = pred_panoptic_seg + return data_samples diff --git a/mmdet/models/detectors/point_rend.py b/mmdet/models/detectors/point_rend.py new file mode 100644 index 0000000000000000000000000000000000000000..5062ac0c945e79bd53e66e1642aec51113475cad --- /dev/null +++ b/mmdet/models/detectors/point_rend.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import ConfigDict + +from mmdet.registry import MODELS +from mmdet.utils import OptConfigType, OptMultiConfig +from .two_stage import TwoStageDetector + + +@MODELS.register_module() +class PointRend(TwoStageDetector): + """PointRend: Image Segmentation as Rendering + + This detector is the implementation of + `PointRend `_. + + """ + + def __init__(self, + backbone: ConfigDict, + rpn_head: ConfigDict, + roi_head: ConfigDict, + train_cfg: ConfigDict, + test_cfg: ConfigDict, + neck: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + rpn_head=rpn_head, + roi_head=roi_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + init_cfg=init_cfg, + data_preprocessor=data_preprocessor) diff --git a/mmdet/models/detectors/queryinst.py b/mmdet/models/detectors/queryinst.py new file mode 100644 index 0000000000000000000000000000000000000000..400ce20c01f5c3825e343f2d32accf740c5dd55c --- /dev/null +++ b/mmdet/models/detectors/queryinst.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .sparse_rcnn import SparseRCNN + + +@MODELS.register_module() +class QueryInst(SparseRCNN): + r"""Implementation of + `Instances as Queries `_""" + + def __init__(self, + backbone: ConfigType, + rpn_head: ConfigType, + roi_head: ConfigType, + train_cfg: ConfigType, + test_cfg: ConfigType, + neck: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + rpn_head=rpn_head, + roi_head=roi_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/reppoints_detector.py b/mmdet/models/detectors/reppoints_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..d86cec2ecda0671939e227c50f00379e81d3ac9c --- /dev/null +++ b/mmdet/models/detectors/reppoints_detector.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage import SingleStageDetector + + +@MODELS.register_module() +class RepPointsDetector(SingleStageDetector): + """RepPoints: Point Set Representation for Object Detection. 
+ + This detector is the implementation of: + - RepPoints detector (https://arxiv.org/pdf/1904.11490) + """ + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/retinanet.py b/mmdet/models/detectors/retinanet.py new file mode 100644 index 0000000000000000000000000000000000000000..03e3cb20e5bda603e9384d83688a56fa590e6de8 --- /dev/null +++ b/mmdet/models/detectors/retinanet.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage import SingleStageDetector + + +@MODELS.register_module() +class RetinaNet(SingleStageDetector): + """Implementation of `RetinaNet `_""" + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/rpn.py b/mmdet/models/detectors/rpn.py new file mode 100644 index 0000000000000000000000000000000000000000..72fe8521fcc9bc796801b2dd68269bb57aaab984 --- /dev/null +++ b/mmdet/models/detectors/rpn.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import warnings + +import torch +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage import SingleStageDetector + + +@MODELS.register_module() +class RPN(SingleStageDetector): + """Implementation of Region Proposal Network. + + Args: + backbone (:obj:`ConfigDict` or dict): The backbone config. + neck (:obj:`ConfigDict` or dict): The neck config. + bbox_head (:obj:`ConfigDict` or dict): The bbox head config. + train_cfg (:obj:`ConfigDict` or dict, optional): The training config. + test_cfg (:obj:`ConfigDict` or dict, optional): The testing config. + data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of + :class:`DetDataPreprocessor` to process the input data. + Defaults to None. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or + list[dict], optional): Initialization config dict. + Defaults to None. 
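+
+    Example:
+        A construction sketch (the config path is a placeholder). Note
+        that the head is passed as ``rpn_head`` but stored as
+        ``self.bbox_head``, with ``num_classes`` forced to 1::
+
+            >>> from mmengine.config import Config
+            >>> from mmdet.registry import MODELS
+            >>> cfg = Config.fromfile('path/to/rpn_config.py')
+            >>> rpn = MODELS.build(cfg.model)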
+ """ + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + rpn_head: ConfigType, + train_cfg: ConfigType, + test_cfg: ConfigType, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None, + **kwargs) -> None: + super(SingleStageDetector, self).__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.backbone = MODELS.build(backbone) + self.neck = MODELS.build(neck) if neck is not None else None + rpn_train_cfg = train_cfg['rpn'] if train_cfg is not None else None + rpn_head_num_classes = rpn_head.get('num_classes', 1) + if rpn_head_num_classes != 1: + warnings.warn('The `num_classes` should be 1 in RPN, but get ' + f'{rpn_head_num_classes}, please set ' + 'rpn_head.num_classes = 1 in your config file.') + rpn_head.update(num_classes=1) + rpn_head.update(train_cfg=rpn_train_cfg) + rpn_head.update(test_cfg=test_cfg['rpn']) + self.bbox_head = MODELS.build(rpn_head) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + x = self.extract_feat(batch_inputs) + + # set cat_id of gt_labels to 0 in RPN + rpn_data_samples = copy.deepcopy(batch_data_samples) + for data_sample in rpn_data_samples: + data_sample.gt_instances.labels = \ + torch.zeros_like(data_sample.gt_instances.labels) + + losses = self.bbox_head.loss(x, rpn_data_samples) + return losses diff --git a/mmdet/models/detectors/rtmdet.py b/mmdet/models/detectors/rtmdet.py new file mode 100644 index 0000000000000000000000000000000000000000..b43e053fc41a4b8400bbc0946fffedfa735b9451 --- /dev/null +++ b/mmdet/models/detectors/rtmdet.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.dist import get_world_size +from mmengine.logging import print_log + +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage import SingleStageDetector + + +@MODELS.register_module() +class RTMDet(SingleStageDetector): + """Implementation of RTMDet. + + Args: + backbone (:obj:`ConfigDict` or dict): The backbone module. + neck (:obj:`ConfigDict` or dict): The neck module. + bbox_head (:obj:`ConfigDict` or dict): The bbox head module. + train_cfg (:obj:`ConfigDict` or dict, optional): The training config + of ATSS. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): The testing config + of ATSS. Defaults to None. + data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of + :class:`DetDataPreprocessor` to process the input data. + Defaults to None. + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. + use_syncbn (bool): Whether to use SyncBatchNorm. Defaults to True. 
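+
+    Example:
+        A construction sketch (the config path is a placeholder)::
+
+            >>> from mmengine.config import Config
+            >>> from mmdet.registry import MODELS
+            >>> cfg = Config.fromfile('path/to/rtmdet_config.py')
+            >>> model = MODELS.build(cfg.model)
+            >>> # BatchNorm layers are converted to SyncBatchNorm only
+            >>> # when ``use_syncbn=True`` and the world size is > 1.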
+ """ + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None, + use_syncbn: bool = True) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + # TODO: Waiting for mmengine support + if use_syncbn and get_world_size() > 1: + torch.nn.SyncBatchNorm.convert_sync_batchnorm(self) + print_log('Using SyncBatchNorm()', 'current') diff --git a/mmdet/models/detectors/scnet.py b/mmdet/models/detectors/scnet.py new file mode 100644 index 0000000000000000000000000000000000000000..606a0203869f1731a21d811f06c4781f5cd90d8d --- /dev/null +++ b/mmdet/models/detectors/scnet.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from .cascade_rcnn import CascadeRCNN + + +@MODELS.register_module() +class SCNet(CascadeRCNN): + """Implementation of `SCNet `_""" + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) diff --git a/mmdet/models/detectors/semi_base.py b/mmdet/models/detectors/semi_base.py new file mode 100644 index 0000000000000000000000000000000000000000..f3f0c8c030830e188bf3ad245d5b3cb471ecb04f --- /dev/null +++ b/mmdet/models/detectors/semi_base.py @@ -0,0 +1,266 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor + +from mmdet.models.utils import (filter_gt_instances, rename_loss_dict, + reweight_loss_dict) +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.structures.bbox import bbox_project +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .base import BaseDetector + + +@MODELS.register_module() +class SemiBaseDetector(BaseDetector): + """Base class for semi-supervised detectors. + + Semi-supervised detectors typically consisting of a teacher model + updated by exponential moving average and a student model updated + by gradient descent. + + Args: + detector (:obj:`ConfigDict` or dict): The detector config. + semi_train_cfg (:obj:`ConfigDict` or dict, optional): + The semi-supervised training config. + semi_test_cfg (:obj:`ConfigDict` or dict, optional): + The semi-supervised testing config. + data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of + :class:`DetDataPreprocessor` to process the input data. + Defaults to None. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or + list[dict], optional): Initialization config dict. + Defaults to None. 
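+
+    Example:
+        The training entry point expects dict-of-branch inputs; the keys
+        below are the ones consumed by :meth:`loss` (the augmentation
+        roles are the typical usage, not enforced here)::
+
+            >>> # multi_batch_inputs = {
+            >>> #     'sup': sup_imgs,               # labeled
+            >>> #     'unsup_teacher': weak_imgs,    # unlabeled, weak aug
+            >>> #     'unsup_student': strong_imgs,  # unlabeled, strong aug
+            >>> # }
+            >>> # losses = model.loss(multi_batch_inputs,
+            >>> #                     multi_batch_data_samples)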
+ """ + + def __init__(self, + detector: ConfigType, + semi_train_cfg: OptConfigType = None, + semi_test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.student = MODELS.build(detector) + self.teacher = MODELS.build(detector) + self.semi_train_cfg = semi_train_cfg + self.semi_test_cfg = semi_test_cfg + if self.semi_train_cfg.get('freeze_teacher', True) is True: + self.freeze(self.teacher) + + @staticmethod + def freeze(model: nn.Module): + """Freeze the model.""" + model.eval() + for param in model.parameters(): + param.requires_grad = False + + def loss(self, multi_batch_inputs: Dict[str, Tensor], + multi_batch_data_samples: Dict[str, SampleList]) -> dict: + """Calculate losses from multi-branch inputs and data samples. + + Args: + multi_batch_inputs (Dict[str, Tensor]): The dict of multi-branch + input images, each value with shape (N, C, H, W). + Each value should usually be mean centered and std scaled. + multi_batch_data_samples (Dict[str, List[:obj:`DetDataSample`]]): + The dict of multi-branch data samples. + + Returns: + dict: A dictionary of loss components + """ + losses = dict() + losses.update(**self.loss_by_gt_instances( + multi_batch_inputs['sup'], multi_batch_data_samples['sup'])) + + origin_pseudo_data_samples, batch_info = self.get_pseudo_instances( + multi_batch_inputs['unsup_teacher'], + multi_batch_data_samples['unsup_teacher']) + multi_batch_data_samples[ + 'unsup_student'] = self.project_pseudo_instances( + origin_pseudo_data_samples, + multi_batch_data_samples['unsup_student']) + losses.update(**self.loss_by_pseudo_instances( + multi_batch_inputs['unsup_student'], + multi_batch_data_samples['unsup_student'], batch_info)) + return losses + + def loss_by_gt_instances(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and ground-truth data + samples. + + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components + """ + + losses = self.student.loss(batch_inputs, batch_data_samples) + sup_weight = self.semi_train_cfg.get('sup_weight', 1.) + return rename_loss_dict('sup_', reweight_loss_dict(losses, sup_weight)) + + def loss_by_pseudo_instances(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + batch_info: Optional[dict] = None) -> dict: + """Calculate losses from a batch of inputs and pseudo data samples. + + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`, + which are `pseudo_instance` or `pseudo_panoptic_seg` + or `pseudo_sem_seg` in fact. + batch_info (dict): Batch information of teacher model + forward propagation process. Defaults to None. 
+
+        Returns:
+            dict: A dictionary of loss components
+        """
+        batch_data_samples = filter_gt_instances(
+            batch_data_samples, score_thr=self.semi_train_cfg.cls_pseudo_thr)
+        losses = self.student.loss(batch_inputs, batch_data_samples)
+        pseudo_instances_num = sum([
+            len(data_samples.gt_instances)
+            for data_samples in batch_data_samples
+        ])
+        unsup_weight = self.semi_train_cfg.get(
+            'unsup_weight', 1.) if pseudo_instances_num > 0 else 0.
+        return rename_loss_dict('unsup_',
+                                reweight_loss_dict(losses, unsup_weight))
+
+    @torch.no_grad()
+    def get_pseudo_instances(
+            self, batch_inputs: Tensor, batch_data_samples: SampleList
+    ) -> Tuple[SampleList, Optional[dict]]:
+        """Get pseudo instances from teacher model."""
+        self.teacher.eval()
+        results_list = self.teacher.predict(
+            batch_inputs, batch_data_samples, rescale=False)
+        batch_info = {}
+        for data_samples, results in zip(batch_data_samples, results_list):
+            data_samples.gt_instances = results.pred_instances
+            data_samples.gt_instances.bboxes = bbox_project(
+                data_samples.gt_instances.bboxes,
+                torch.from_numpy(data_samples.homography_matrix).inverse().to(
+                    self.data_preprocessor.device), data_samples.ori_shape)
+        return batch_data_samples, batch_info
+
+    def project_pseudo_instances(self, batch_pseudo_instances: SampleList,
+                                 batch_data_samples: SampleList) -> SampleList:
+        """Project pseudo instances."""
+        for pseudo_instances, data_samples in zip(batch_pseudo_instances,
+                                                  batch_data_samples):
+            data_samples.gt_instances = copy.deepcopy(
+                pseudo_instances.gt_instances)
+            data_samples.gt_instances.bboxes = bbox_project(
+                data_samples.gt_instances.bboxes,
+                torch.tensor(data_samples.homography_matrix).to(
+                    self.data_preprocessor.device), data_samples.img_shape)
+        wh_thr = self.semi_train_cfg.get('min_pseudo_bbox_wh', (1e-2, 1e-2))
+        return filter_gt_instances(batch_data_samples, wh_thr=wh_thr)
+
+    def predict(self, batch_inputs: Tensor,
+                batch_data_samples: SampleList) -> SampleList:
+        """Predict results from a batch of inputs and data samples with post-
+        processing.
+
+        Args:
+            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
+            batch_data_samples (List[:obj:`DetDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+
+        Returns:
+            list[:obj:`DetDataSample`]: Detection results of the input
+            images. Each returned DetDataSample usually contains
+            'pred_instances', and ``pred_instances`` usually contains the
+            following keys.
+
+            - scores (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 4),
+              the last dimension 4 arranged as (x1, y1, x2, y2).
+            - masks (Tensor): Has a shape (num_instances, H, W).
+        """
+        if self.semi_test_cfg.get('predict_on', 'teacher') == 'teacher':
+            return self.teacher(
+                batch_inputs, batch_data_samples, mode='predict')
+        else:
+            return self.student(
+                batch_inputs, batch_data_samples, mode='predict')
+
+    def _forward(self, batch_inputs: Tensor,
+                 batch_data_samples: SampleList) -> SampleList:
+        """Network forward process. Usually includes backbone, neck and head
+        forward without any post-processing.
+
+        Args:
+            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
+
+        Returns:
+            tuple: Tensor-mode forward results of the teacher or student
+            model, depending on ``semi_test_cfg['forward_on']``.
+ """ + if self.semi_test_cfg.get('forward_on', 'teacher') == 'teacher': + return self.teacher( + batch_inputs, batch_data_samples, mode='tensor') + else: + return self.student( + batch_inputs, batch_data_samples, mode='tensor') + + def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]: + """Extract features. + + Args: + batch_inputs (Tensor): Image tensor with shape (N, C, H ,W). + + Returns: + tuple[Tensor]: Multi-level features that may have + different resolutions. + """ + if self.semi_test_cfg.get('extract_feat_on', 'teacher') == 'teacher': + return self.teacher.extract_feat(batch_inputs) + else: + return self.student.extract_feat(batch_inputs) + + def _load_from_state_dict(self, state_dict: dict, prefix: str, + local_metadata: dict, strict: bool, + missing_keys: Union[List[str], str], + unexpected_keys: Union[List[str], str], + error_msgs: Union[List[str], str]) -> None: + """Add teacher and student prefixes to model parameter names.""" + if not any([ + 'student' in key or 'teacher' in key + for key in state_dict.keys() + ]): + keys = list(state_dict.keys()) + state_dict.update({'teacher.' + k: state_dict[k] for k in keys}) + state_dict.update({'student.' + k: state_dict[k] for k in keys}) + for k in keys: + state_dict.pop(k) + return super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py new file mode 100644 index 0000000000000000000000000000000000000000..06c074085967bbc9040d93e5eb446b67a006087e --- /dev/null +++ b/mmdet/models/detectors/single_stage.py @@ -0,0 +1,149 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple, Union + +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import OptSampleList, SampleList +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .base import BaseDetector + + +@MODELS.register_module() +class SingleStageDetector(BaseDetector): + """Base class for single-stage detectors. + + Single-stage detectors directly and densely predict bounding boxes on the + output features of the backbone+neck. 
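+
+    Subclasses usually only forward configs to ``__init__``; ``loss``,
+    ``predict`` and ``_forward`` are inherited unchanged, e.g. (an
+    illustrative sketch)::
+
+        >>> # @MODELS.register_module()
+        >>> # class MyDetector(SingleStageDetector):
+        >>> #     pass  # only the config wiring differs, as in RetinaNet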
+ """ + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + bbox_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + bbox_head.update(train_cfg=train_cfg) + bbox_head.update(test_cfg=test_cfg) + self.bbox_head = MODELS.build(bbox_head) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def _load_from_state_dict(self, state_dict: dict, prefix: str, + local_metadata: dict, strict: bool, + missing_keys: Union[List[str], str], + unexpected_keys: Union[List[str], str], + error_msgs: Union[List[str], str]) -> None: + """Exchange bbox_head key to rpn_head key when loading two-stage + weights into single-stage model.""" + bbox_head_prefix = prefix + '.bbox_head' if prefix else 'bbox_head' + bbox_head_keys = [ + k for k in state_dict.keys() if k.startswith(bbox_head_prefix) + ] + rpn_head_prefix = prefix + '.rpn_head' if prefix else 'rpn_head' + rpn_head_keys = [ + k for k in state_dict.keys() if k.startswith(rpn_head_prefix) + ] + if len(bbox_head_keys) == 0 and len(rpn_head_keys) != 0: + for rpn_head_key in rpn_head_keys: + bbox_head_key = bbox_head_prefix + \ + rpn_head_key[len(rpn_head_prefix):] + state_dict[bbox_head_key] = state_dict.pop(rpn_head_key) + super()._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, + error_msgs) + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> Union[dict, list]: + """Calculate losses from a batch of inputs and data samples. + + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components. + """ + x = self.extract_feat(batch_inputs) + losses = self.bbox_head.loss(x, batch_data_samples) + return losses + + def predict(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): Whether to rescale the results. + Defaults to True. + + Returns: + list[:obj:`DetDataSample`]: Detection results of the + input images. Each DetDataSample usually contain + 'pred_instances'. And the ``pred_instances`` usually + contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
+ """ + x = self.extract_feat(batch_inputs) + results_list = self.bbox_head.predict( + x, batch_data_samples, rescale=rescale) + batch_data_samples = self.add_pred_to_datasample( + batch_data_samples, results_list) + return batch_data_samples + + def _forward( + self, + batch_inputs: Tensor, + batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]: + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (list[:obj:`DetDataSample`]): Each item contains + the meta information of each image and corresponding + annotations. + + Returns: + tuple[list]: A tuple of features from ``bbox_head`` forward. + """ + x = self.extract_feat(batch_inputs) + results = self.bbox_head.forward(x) + return results + + def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]: + """Extract features. + + Args: + batch_inputs (Tensor): Image tensor with shape (N, C, H ,W). + + Returns: + tuple[Tensor]: Multi-level features that may have + different resolutions. + """ + x = self.backbone(batch_inputs) + if self.with_neck: + x = self.neck(x) + return x diff --git a/mmdet/models/detectors/single_stage_instance_seg.py b/mmdet/models/detectors/single_stage_instance_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..acb5f0d2f8e4636b86b4b66cbf5c4916d0dae16f --- /dev/null +++ b/mmdet/models/detectors/single_stage_instance_seg.py @@ -0,0 +1,180 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Tuple + +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import OptSampleList, SampleList +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .base import BaseDetector + +INF = 1e8 + + +@MODELS.register_module() +class SingleStageInstanceSegmentor(BaseDetector): + """Base class for single-stage instance segmentors.""" + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + bbox_head: OptConfigType = None, + mask_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + else: + self.neck = None + if bbox_head is not None: + bbox_head.update(train_cfg=copy.deepcopy(train_cfg)) + bbox_head.update(test_cfg=copy.deepcopy(test_cfg)) + self.bbox_head = MODELS.build(bbox_head) + else: + self.bbox_head = None + + assert mask_head, f'`mask_head` must ' \ + f'be implemented in {self.__class__.__name__}' + mask_head.update(train_cfg=copy.deepcopy(train_cfg)) + mask_head.update(test_cfg=copy.deepcopy(test_cfg)) + self.mask_head = MODELS.build(mask_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]: + """Extract features. + + Args: + batch_inputs (Tensor): Image tensor with shape (N, C, H ,W). + + Returns: + tuple[Tensor]: Multi-level features that may have different + resolutions. + """ + x = self.backbone(batch_inputs) + if self.with_neck: + x = self.neck(x) + return x + + def _forward(self, + batch_inputs: Tensor, + batch_data_samples: OptSampleList = None, + **kwargs) -> tuple: + """Network forward process. 
Usually includes backbone, neck and head
+        forward without any post-processing.
+
+        Args:
+            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
+
+        Returns:
+            tuple: A tuple of features from ``bbox_head`` forward.
+        """
+        outs = ()
+        # backbone
+        x = self.extract_feat(batch_inputs)
+        # bbox_head
+        positive_infos = None
+        if self.with_bbox:
+            assert batch_data_samples is not None
+            bbox_outs = self.bbox_head.forward(x)
+            outs = outs + (bbox_outs, )
+            # It is necessary to use `bbox_head.loss` to update
+            # `_raw_positive_infos`, which will be used in
+            # `get_positive_infos`; positive_infos is then consumed by
+            # the mask head below.
+            _ = self.bbox_head.loss(x, batch_data_samples, **kwargs)
+            positive_infos = self.bbox_head.get_positive_infos()
+        # mask_head
+        if positive_infos is None:
+            mask_outs = self.mask_head.forward(x)
+        else:
+            mask_outs = self.mask_head.forward(x, positive_infos)
+        outs = outs + (mask_outs, )
+        return outs
+
+    def loss(self, batch_inputs: Tensor, batch_data_samples: SampleList,
+             **kwargs) -> dict:
+        """
+        Args:
+            batch_inputs (Tensor): Input images of shape (N, C, H, W).
+                These should usually be mean centered and std scaled.
+            batch_data_samples (list[:obj:`DetDataSample`]): The batch
+                data samples. It usually includes information such
+                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+
+        Returns:
+            dict: A dictionary of loss components.
+        """
+        x = self.extract_feat(batch_inputs)
+        losses = dict()
+
+        positive_infos = None
+        # CondInst and YOLACT have bbox_head
+        if self.with_bbox:
+            bbox_losses = self.bbox_head.loss(x, batch_data_samples, **kwargs)
+            losses.update(bbox_losses)
+            # get positive information from bbox head, which will be used
+            # in the following mask head.
+            positive_infos = self.bbox_head.get_positive_infos()
+
+        mask_loss = self.mask_head.loss(
+            x, batch_data_samples, positive_infos=positive_infos, **kwargs)
+        # avoid loss override
+        assert not set(mask_loss.keys()) & set(losses.keys())
+
+        losses.update(mask_loss)
+        return losses
+
+    def predict(self,
+                batch_inputs: Tensor,
+                batch_data_samples: SampleList,
+                rescale: bool = True,
+                **kwargs) -> SampleList:
+        """Perform forward propagation of the mask head and predict mask
+        results on the features of the upstream network.
+
+        Args:
+            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
+            batch_data_samples (List[:obj:`DetDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+            rescale (bool): Whether to rescale the results.
+                Defaults to True.
+
+        Returns:
+            list[:obj:`DetDataSample`]: Detection results of the
+            input images. Each DetDataSample usually contains
+            'pred_instances', and ``pred_instances`` usually contains
+            the following keys.
+
+            - scores (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 4),
+              the last dimension 4 arranged as (x1, y1, x2, y2).
+            - masks (Tensor): Has a shape (num_instances, H, W).
+        """
+        x = self.extract_feat(batch_inputs)
+        if self.with_bbox:
+            # the bbox branch does not need to be scaled to the original
+            # image scale, because the mask branch will scale both bbox
+            # and mask at the same time.
+ bbox_rescale = rescale if not self.with_mask else False + results_list = self.bbox_head.predict( + x, batch_data_samples, rescale=bbox_rescale) + else: + results_list = None + + results_list = self.mask_head.predict( + x, batch_data_samples, rescale=rescale, results_list=results_list) + + batch_data_samples = self.add_pred_to_datasample( + batch_data_samples, results_list) + return batch_data_samples diff --git a/mmdet/models/detectors/soft_teacher.py b/mmdet/models/detectors/soft_teacher.py new file mode 100644 index 0000000000000000000000000000000000000000..80853f1d8399c70008923067777a2581671ede0b --- /dev/null +++ b/mmdet/models/detectors/soft_teacher.py @@ -0,0 +1,378 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import List, Optional, Tuple + +import torch +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.models.utils import (filter_gt_instances, rename_loss_dict, + reweight_loss_dict) +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.structures.bbox import bbox2roi, bbox_project +from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig +from ..utils.misc import unpack_gt_instances +from .semi_base import SemiBaseDetector + + +@MODELS.register_module() +class SoftTeacher(SemiBaseDetector): + r"""Implementation of `End-to-End Semi-Supervised Object Detection + with Soft Teacher `_ + + Args: + detector (:obj:`ConfigDict` or dict): The detector config. + semi_train_cfg (:obj:`ConfigDict` or dict, optional): + The semi-supervised training config. + semi_test_cfg (:obj:`ConfigDict` or dict, optional): + The semi-supervised testing config. + data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of + :class:`DetDataPreprocessor` to process the input data. + Defaults to None. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or + list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + detector: ConfigType, + semi_train_cfg: OptConfigType = None, + semi_test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + detector=detector, + semi_train_cfg=semi_train_cfg, + semi_test_cfg=semi_test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + def loss_by_pseudo_instances(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + batch_info: Optional[dict] = None) -> dict: + """Calculate losses from a batch of inputs and pseudo data samples. + + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`, + which are `pseudo_instance` or `pseudo_panoptic_seg` + or `pseudo_sem_seg` in fact. + batch_info (dict): Batch information of teacher model + forward propagation process. Defaults to None. 
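
In `loss_by_pseudo_instances`, whose body follows, every pseudo-label loss is renamed with an `unsup_` prefix and scaled by `semi_train_cfg.unsup_weight`, so supervised and unsupervised terms stay separable in logs while contributing to one total. The real helpers live in `mmdet.models.utils`; a plausible minimal equivalent, for illustration only:

```python
import torch

def rename_loss_dict(prefix: str, losses: dict) -> dict:
    """Prefix every entry, e.g. 'loss_cls' -> 'unsup_loss_cls'."""
    return {prefix + k: v for k, v in losses.items()}

def reweight_loss_dict(losses: dict, weight: float) -> dict:
    """Scale loss terms only; entries without 'loss' in the key
    (metrics such as accuracies) pass through untouched."""
    return {k: v * weight if 'loss' in k else v for k, v in losses.items()}

losses = {'loss_cls': torch.tensor(1.0), 'acc': torch.tensor(90.0)}
print(rename_loss_dict('unsup_', reweight_loss_dict(losses, 0.5)))
# {'unsup_loss_cls': tensor(0.5000), 'unsup_acc': tensor(90.)}
```
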
+ + Returns: + dict: A dictionary of loss components + """ + + x = self.student.extract_feat(batch_inputs) + + losses = {} + rpn_losses, rpn_results_list = self.rpn_loss_by_pseudo_instances( + x, batch_data_samples) + losses.update(**rpn_losses) + losses.update(**self.rcnn_cls_loss_by_pseudo_instances( + x, rpn_results_list, batch_data_samples, batch_info)) + losses.update(**self.rcnn_reg_loss_by_pseudo_instances( + x, rpn_results_list, batch_data_samples)) + unsup_weight = self.semi_train_cfg.get('unsup_weight', 1.) + return rename_loss_dict('unsup_', + reweight_loss_dict(losses, unsup_weight)) + + @torch.no_grad() + def get_pseudo_instances( + self, batch_inputs: Tensor, batch_data_samples: SampleList + ) -> Tuple[SampleList, Optional[dict]]: + """Get pseudo instances from teacher model.""" + assert self.teacher.with_bbox, 'Bbox head must be implemented.' + x = self.teacher.extract_feat(batch_inputs) + + # If there are no pre-defined proposals, use RPN to get proposals + if batch_data_samples[0].get('proposals', None) is None: + rpn_results_list = self.teacher.rpn_head.predict( + x, batch_data_samples, rescale=False) + else: + rpn_results_list = [ + data_sample.proposals for data_sample in batch_data_samples + ] + + results_list = self.teacher.roi_head.predict( + x, rpn_results_list, batch_data_samples, rescale=False) + + for data_samples, results in zip(batch_data_samples, results_list): + data_samples.gt_instances = results + + batch_data_samples = filter_gt_instances( + batch_data_samples, + score_thr=self.semi_train_cfg.pseudo_label_initial_score_thr) + + reg_uncs_list = self.compute_uncertainty_with_aug( + x, batch_data_samples) + + for data_samples, reg_uncs in zip(batch_data_samples, reg_uncs_list): + data_samples.gt_instances['reg_uncs'] = reg_uncs + data_samples.gt_instances.bboxes = bbox_project( + data_samples.gt_instances.bboxes, + torch.from_numpy(data_samples.homography_matrix).inverse().to( + self.data_preprocessor.device), data_samples.ori_shape) + + batch_info = { + 'feat': x, + 'img_shape': [], + 'homography_matrix': [], + 'metainfo': [] + } + for data_samples in batch_data_samples: + batch_info['img_shape'].append(data_samples.img_shape) + batch_info['homography_matrix'].append( + torch.from_numpy(data_samples.homography_matrix).to( + self.data_preprocessor.device)) + batch_info['metainfo'].append(data_samples.metainfo) + return batch_data_samples, batch_info + + def rpn_loss_by_pseudo_instances(self, x: Tuple[Tensor], + batch_data_samples: SampleList) -> dict: + """Calculate rpn loss from a batch of inputs and pseudo data samples. + + Args: + x (tuple[Tensor]): Features from FPN. + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`, + which are `pseudo_instance` or `pseudo_panoptic_seg` + or `pseudo_sem_seg` in fact. 
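
`get_pseudo_instances` above maps teacher boxes back to the original image frame by pushing them through the inverse of the recorded `homography_matrix`. A minimal sketch of projecting x1y1x2y2 boxes through a 3x3 homography (corner transform plus re-normalization; the real `bbox_project` additionally clips to the target image shape):

```python
import torch

def project_boxes(boxes: torch.Tensor, H: torch.Tensor) -> torch.Tensor:
    """Map (N, 4) x1y1x2y2 boxes through a 3x3 homography H."""
    x1, y1, x2, y2 = boxes.unbind(dim=-1)
    corners = torch.stack([
        torch.stack([x1, y1], -1), torch.stack([x2, y1], -1),
        torch.stack([x1, y2], -1), torch.stack([x2, y2], -1),
    ], dim=1)                                        # (N, 4, 2)
    ones = corners.new_ones(*corners.shape[:-1], 1)
    pts = torch.cat([corners, ones], dim=-1) @ H.T   # (N, 4, 3)
    pts = pts[..., :2] / pts[..., 2:3]               # dehomogenize
    return torch.cat([pts.min(dim=1).values, pts.max(dim=1).values], dim=-1)

H = torch.tensor([[2., 0., 10.], [0., 2., 5.], [0., 0., 1.]])  # scale + shift
print(project_boxes(torch.tensor([[0., 0., 4., 4.]]), H))
# tensor([[10.,  5., 18., 13.]])
```
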
+ Returns: + dict: A dictionary of rpn loss components + """ + + rpn_data_samples = copy.deepcopy(batch_data_samples) + rpn_data_samples = filter_gt_instances( + rpn_data_samples, score_thr=self.semi_train_cfg.rpn_pseudo_thr) + proposal_cfg = self.student.train_cfg.get('rpn_proposal', + self.student.test_cfg.rpn) + # set cat_id of gt_labels to 0 in RPN + for data_sample in rpn_data_samples: + data_sample.gt_instances.labels = \ + torch.zeros_like(data_sample.gt_instances.labels) + + rpn_losses, rpn_results_list = self.student.rpn_head.loss_and_predict( + x, rpn_data_samples, proposal_cfg=proposal_cfg) + for key in rpn_losses.keys(): + if 'loss' in key and 'rpn' not in key: + rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key) + return rpn_losses, rpn_results_list + + def rcnn_cls_loss_by_pseudo_instances(self, x: Tuple[Tensor], + unsup_rpn_results_list: InstanceList, + batch_data_samples: SampleList, + batch_info: dict) -> dict: + """Calculate classification loss from a batch of inputs and pseudo data + samples. + + Args: + x (tuple[Tensor]): List of multi-level img features. + unsup_rpn_results_list (list[:obj:`InstanceData`]): + List of region proposals. + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`, + which are `pseudo_instance` or `pseudo_panoptic_seg` + or `pseudo_sem_seg` in fact. + batch_info (dict): Batch information of teacher model + forward propagation process. + + Returns: + dict[str, Tensor]: A dictionary of rcnn + classification loss components + """ + rpn_results_list = copy.deepcopy(unsup_rpn_results_list) + cls_data_samples = copy.deepcopy(batch_data_samples) + cls_data_samples = filter_gt_instances( + cls_data_samples, score_thr=self.semi_train_cfg.cls_pseudo_thr) + + outputs = unpack_gt_instances(cls_data_samples) + batch_gt_instances, batch_gt_instances_ignore, _ = outputs + + # assign gts and sample proposals + num_imgs = len(cls_data_samples) + sampling_results = [] + for i in range(num_imgs): + # rename rpn_results.bboxes to rpn_results.priors + rpn_results = rpn_results_list[i] + rpn_results.priors = rpn_results.pop('bboxes') + assign_result = self.student.roi_head.bbox_assigner.assign( + rpn_results, batch_gt_instances[i], + batch_gt_instances_ignore[i]) + sampling_result = self.student.roi_head.bbox_sampler.sample( + assign_result, + rpn_results, + batch_gt_instances[i], + feats=[lvl_feat[i][None] for lvl_feat in x]) + sampling_results.append(sampling_result) + + selected_bboxes = [res.priors for res in sampling_results] + rois = bbox2roi(selected_bboxes) + bbox_results = self.student.roi_head._bbox_forward(x, rois) + # cls_reg_targets is a tuple of labels, label_weights, + # and bbox_targets, bbox_weights + cls_reg_targets = self.student.roi_head.bbox_head.get_targets( + sampling_results, self.student.train_cfg.rcnn) + + selected_results_list = [] + for bboxes, data_samples, teacher_matrix, teacher_img_shape in zip( + selected_bboxes, batch_data_samples, + batch_info['homography_matrix'], batch_info['img_shape']): + student_matrix = torch.tensor( + data_samples.homography_matrix, device=teacher_matrix.device) + homography_matrix = teacher_matrix @ student_matrix.inverse() + projected_bboxes = bbox_project(bboxes, homography_matrix, + teacher_img_shape) + selected_results_list.append(InstanceData(bboxes=projected_bboxes)) + + with torch.no_grad(): + results_list = self.teacher.roi_head.predict_bbox( + batch_info['feat'], + batch_info['metainfo'], 
+ selected_results_list, + rcnn_test_cfg=None, + rescale=False) + bg_score = torch.cat( + [results.scores[:, -1] for results in results_list]) + # cls_reg_targets[0] is labels + neg_inds = cls_reg_targets[ + 0] == self.student.roi_head.bbox_head.num_classes + # cls_reg_targets[1] is label_weights + cls_reg_targets[1][neg_inds] = bg_score[neg_inds].detach() + + losses = self.student.roi_head.bbox_head.loss( + bbox_results['cls_score'], bbox_results['bbox_pred'], rois, + *cls_reg_targets) + # cls_reg_targets[1] is label_weights + losses['loss_cls'] = losses['loss_cls'] * len( + cls_reg_targets[1]) / max(sum(cls_reg_targets[1]), 1.0) + return losses + + def rcnn_reg_loss_by_pseudo_instances( + self, x: Tuple[Tensor], unsup_rpn_results_list: InstanceList, + batch_data_samples: SampleList) -> dict: + """Calculate rcnn regression loss from a batch of inputs and pseudo + data samples. + + Args: + x (tuple[Tensor]): List of multi-level img features. + unsup_rpn_results_list (list[:obj:`InstanceData`]): + List of region proposals. + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`, + which are `pseudo_instance` or `pseudo_panoptic_seg` + or `pseudo_sem_seg` in fact. + + Returns: + dict[str, Tensor]: A dictionary of rcnn + regression loss components + """ + rpn_results_list = copy.deepcopy(unsup_rpn_results_list) + reg_data_samples = copy.deepcopy(batch_data_samples) + for data_samples in reg_data_samples: + if data_samples.gt_instances.bboxes.shape[0] > 0: + data_samples.gt_instances = data_samples.gt_instances[ + data_samples.gt_instances.reg_uncs < + self.semi_train_cfg.reg_pseudo_thr] + roi_losses = self.student.roi_head.loss(x, rpn_results_list, + reg_data_samples) + return {'loss_bbox': roi_losses['loss_bbox']} + + def compute_uncertainty_with_aug( + self, x: Tuple[Tensor], + batch_data_samples: SampleList) -> List[Tensor]: + """Compute uncertainty with augmented bboxes. + + Args: + x (tuple[Tensor]): List of multi-level img features. + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`, + which are `pseudo_instance` or `pseudo_panoptic_seg` + or `pseudo_sem_seg` in fact. + + Returns: + list[Tensor]: A list of uncertainty for pseudo bboxes. 
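
The implementation that follows realizes this docstring by jittering each pseudo box `jitter_times` ways, re-running the teacher's box refinement, and taking the per-coordinate standard deviation, normalized by box width and height, as the uncertainty. The core statistic in isolation, on toy tensors:

```python
import torch

torch.manual_seed(0)
times, n = 10, 3                       # jitter_times, number of pseudo boxes
# jittered refinements of each box: (times, n, 4) in x1y1x2y2
jittered = torch.tensor([[100., 100., 200., 260.]]).repeat(times, n, 1)
jittered += torch.randn(times, n, 4) * 5.0

box_unc = jittered.std(dim=0)          # per-coordinate spread across jitters
boxes = jittered.mean(dim=0)           # consensus box
wh = (boxes[:, 2:4] - boxes[:, :2]).clamp(min=1.0)
# normalize x-coordinates by width and y-coordinates by height, then average
norm_unc = (box_unc / wh.repeat(1, 2)).mean(dim=-1)
print(norm_unc)                        # lower = more stable pseudo box
```
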
+ """ + auged_results_list = self.aug_box(batch_data_samples, + self.semi_train_cfg.jitter_times, + self.semi_train_cfg.jitter_scale) + # flatten + auged_results_list = [ + InstanceData(bboxes=auged.reshape(-1, auged.shape[-1])) + for auged in auged_results_list + ] + + self.teacher.roi_head.test_cfg = None + results_list = self.teacher.roi_head.predict( + x, auged_results_list, batch_data_samples, rescale=False) + self.teacher.roi_head.test_cfg = self.teacher.test_cfg.rcnn + + reg_channel = max( + [results.bboxes.shape[-1] for results in results_list]) // 4 + bboxes = [ + results.bboxes.reshape(self.semi_train_cfg.jitter_times, -1, + results.bboxes.shape[-1]) + if results.bboxes.numel() > 0 else results.bboxes.new_zeros( + self.semi_train_cfg.jitter_times, 0, 4 * reg_channel).float() + for results in results_list + ] + + box_unc = [bbox.std(dim=0) for bbox in bboxes] + bboxes = [bbox.mean(dim=0) for bbox in bboxes] + labels = [ + data_samples.gt_instances.labels + for data_samples in batch_data_samples + ] + if reg_channel != 1: + bboxes = [ + bbox.reshape(bbox.shape[0], reg_channel, + 4)[torch.arange(bbox.shape[0]), label] + for bbox, label in zip(bboxes, labels) + ] + box_unc = [ + unc.reshape(unc.shape[0], reg_channel, + 4)[torch.arange(unc.shape[0]), label] + for unc, label in zip(box_unc, labels) + ] + + box_shape = [(bbox[:, 2:4] - bbox[:, :2]).clamp(min=1.0) + for bbox in bboxes] + box_unc = [ + torch.mean( + unc / wh[:, None, :].expand(-1, 2, 2).reshape(-1, 4), dim=-1) + if wh.numel() > 0 else unc for unc, wh in zip(box_unc, box_shape) + ] + return box_unc + + @staticmethod + def aug_box(batch_data_samples, times, frac): + """Augment bboxes with jitter.""" + + def _aug_single(box): + box_scale = box[:, 2:4] - box[:, :2] + box_scale = ( + box_scale.clamp(min=1)[:, None, :].expand(-1, 2, + 2).reshape(-1, 4)) + aug_scale = box_scale * frac # [n,4] + + offset = ( + torch.randn(times, box.shape[0], 4, device=box.device) * + aug_scale[None, ...]) + new_box = box.clone()[None, ...].expand(times, box.shape[0], + -1) + offset + return new_box + + return [ + _aug_single(data_samples.gt_instances.bboxes) + for data_samples in batch_data_samples + ] diff --git a/mmdet/models/detectors/solo.py b/mmdet/models/detectors/solo.py new file mode 100644 index 0000000000000000000000000000000000000000..6bf47ba24941e09fd795b241a3f6aa0b67ae3380 --- /dev/null +++ b/mmdet/models/detectors/solo.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage_instance_seg import SingleStageInstanceSegmentor + + +@MODELS.register_module() +class SOLO(SingleStageInstanceSegmentor): + """`SOLO: Segmenting Objects by Locations + `_ + + """ + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + bbox_head: OptConfigType = None, + mask_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + mask_head=mask_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/solov2.py b/mmdet/models/detectors/solov2.py new file mode 100644 index 0000000000000000000000000000000000000000..1eefe4c532267be1480d13b8d73fc54bf694e81c --- /dev/null +++ b/mmdet/models/detectors/solov2.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage_instance_seg import SingleStageInstanceSegmentor + + +@MODELS.register_module() +class SOLOv2(SingleStageInstanceSegmentor): + """`SOLOv2: Dynamic and Fast Instance Segmentation + `_ + + """ + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + bbox_head: OptConfigType = None, + mask_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + mask_head=mask_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/sparse_rcnn.py b/mmdet/models/detectors/sparse_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..75442a69e472953854ded9fc8c30ac4ab30535d3 --- /dev/null +++ b/mmdet/models/detectors/sparse_rcnn.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .two_stage import TwoStageDetector + + +@MODELS.register_module() +class SparseRCNN(TwoStageDetector): + r"""Implementation of `Sparse R-CNN: End-to-End Object Detection with + Learnable Proposals `_""" + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + rpn_head: OptConfigType = None, + roi_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + rpn_head=rpn_head, + roi_head=roi_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + assert self.with_rpn, 'Sparse R-CNN and QueryInst ' \ + 'do not support external proposals' diff --git a/mmdet/models/detectors/tood.py b/mmdet/models/detectors/tood.py new file mode 100644 index 0000000000000000000000000000000000000000..38720482c5451471f5a66a6cf689dbed6100c9fa --- /dev/null +++ b/mmdet/models/detectors/tood.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage import SingleStageDetector + + +@MODELS.register_module() +class TOOD(SingleStageDetector): + r"""Implementation of `TOOD: Task-aligned One-stage Object Detection. + `_ + + Args: + backbone (:obj:`ConfigDict` or dict): The backbone module. + neck (:obj:`ConfigDict` or dict): The neck module. + bbox_head (:obj:`ConfigDict` or dict): The bbox head module. + train_cfg (:obj:`ConfigDict` or dict, optional): The training config + of TOOD. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): The testing config + of TOOD. Defaults to None. + data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of + :class:`DetDataPreprocessor` to process the input data. + Defaults to None. + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. + """ + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/trident_faster_rcnn.py b/mmdet/models/detectors/trident_faster_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..4244925beaebea820f836b41ab5463f5f499f4d0 --- /dev/null +++ b/mmdet/models/detectors/trident_faster_rcnn.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .faster_rcnn import FasterRCNN + + +@MODELS.register_module() +class TridentFasterRCNN(FasterRCNN): + """Implementation of `TridentNet `_""" + + def __init__(self, + backbone: ConfigType, + rpn_head: ConfigType, + roi_head: ConfigType, + train_cfg: ConfigType, + test_cfg: ConfigType, + neck: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + + super().__init__( + backbone=backbone, + neck=neck, + rpn_head=rpn_head, + roi_head=roi_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + assert self.backbone.num_branch == self.roi_head.num_branch + assert self.backbone.test_branch_idx == self.roi_head.test_branch_idx + self.num_branch = self.backbone.num_branch + self.test_branch_idx = self.backbone.test_branch_idx + + def _forward(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> tuple: + """copy the ``batch_data_samples`` to fit multi-branch.""" + num_branch = self.num_branch \ + if self.training or self.test_branch_idx == -1 else 1 + trident_data_samples = batch_data_samples * num_branch + return super()._forward( + batch_inputs=batch_inputs, batch_data_samples=trident_data_samples) + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> dict: + """copy the ``batch_data_samples`` to fit multi-branch.""" + num_branch = self.num_branch \ + if self.training or self.test_branch_idx == -1 else 1 + trident_data_samples = batch_data_samples * num_branch + return super().loss( + batch_inputs=batch_inputs, batch_data_samples=trident_data_samples) + + def predict(self, + batch_inputs: 
Tensor,
+                batch_data_samples: SampleList,
+                rescale: bool = True) -> SampleList:
+        """copy the ``batch_data_samples`` to fit multi-branch."""
+        num_branch = self.num_branch \
+            if self.training or self.test_branch_idx == -1 else 1
+        trident_data_samples = batch_data_samples * num_branch
+        return super().predict(
+            batch_inputs=batch_inputs,
+            batch_data_samples=trident_data_samples,
+            rescale=rescale)
+
+    # TODO need to refactor
+    def aug_test(self, imgs, img_metas, rescale=False):
+        """Test with augmentations.
+
+        If rescale is False, then returned bboxes and masks will fit the scale
+        of imgs[0].
+        """
+        x = self.extract_feats(imgs)
+        num_branch = (self.num_branch if self.test_branch_idx == -1 else 1)
+        trident_img_metas = [img_metas * num_branch for img_metas in img_metas]
+        proposal_list = self.rpn_head.aug_test_rpn(x, trident_img_metas)
+        return self.roi_head.aug_test(
+            x, proposal_list, img_metas, rescale=rescale)
diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e83df9eb5ce837636e10c4592fe26a7edce1657
--- /dev/null
+++ b/mmdet/models/detectors/two_stage.py
@@ -0,0 +1,243 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+from typing import List, Tuple, Union
+
+import torch
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.structures import SampleList
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+from .base import BaseDetector
+
+
+@MODELS.register_module()
+class TwoStageDetector(BaseDetector):
+    """Base class for two-stage detectors.
+
+    Two-stage detectors typically consist of a region proposal network and a
+    task-specific regression head.
+    """
+
+    def __init__(self,
+                 backbone: ConfigType,
+                 neck: OptConfigType = None,
+                 rpn_head: OptConfigType = None,
+                 roi_head: OptConfigType = None,
+                 train_cfg: OptConfigType = None,
+                 test_cfg: OptConfigType = None,
+                 data_preprocessor: OptConfigType = None,
+                 init_cfg: OptMultiConfig = None) -> None:
+        super().__init__(
+            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
+        self.backbone = MODELS.build(backbone)
+
+        if neck is not None:
+            self.neck = MODELS.build(neck)
+
+        if rpn_head is not None:
+            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
+            rpn_head_ = rpn_head.copy()
+            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
+            rpn_head_num_classes = rpn_head_.get('num_classes', None)
+            if rpn_head_num_classes is None:
+                rpn_head_.update(num_classes=1)
+            else:
+                if rpn_head_num_classes != 1:
+                    warnings.warn(
+                        'The `num_classes` should be 1 in RPN, but get '
+                        f'{rpn_head_num_classes}, please set '
+                        'rpn_head.num_classes = 1 in your config file.')
+                    rpn_head_.update(num_classes=1)
+            self.rpn_head = MODELS.build(rpn_head_)
+
+        if roi_head is not None:
+            # update train and test cfg here for now
+            # TODO: refactor assigner & sampler
+            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
+            roi_head.update(train_cfg=rcnn_train_cfg)
+            roi_head.update(test_cfg=test_cfg.rcnn)
+            self.roi_head = MODELS.build(roi_head)
+
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+
+    def _load_from_state_dict(self, state_dict: dict, prefix: str,
+                              local_metadata: dict, strict: bool,
+                              missing_keys: Union[List[str], str],
+                              unexpected_keys: Union[List[str], str],
+                              error_msgs: Union[List[str], str]) -> None:
+        """Exchange bbox_head key to rpn_head key when loading single-stage
+        weights into two-stage model."""
+
bbox_head_prefix = prefix + '.bbox_head' if prefix else 'bbox_head' + bbox_head_keys = [ + k for k in state_dict.keys() if k.startswith(bbox_head_prefix) + ] + rpn_head_prefix = prefix + '.rpn_head' if prefix else 'rpn_head' + rpn_head_keys = [ + k for k in state_dict.keys() if k.startswith(rpn_head_prefix) + ] + if len(bbox_head_keys) != 0 and len(rpn_head_keys) == 0: + for bbox_head_key in bbox_head_keys: + rpn_head_key = rpn_head_prefix + \ + bbox_head_key[len(bbox_head_prefix):] + state_dict[rpn_head_key] = state_dict.pop(bbox_head_key) + super()._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, + error_msgs) + + @property + def with_rpn(self) -> bool: + """bool: whether the detector has RPN""" + return hasattr(self, 'rpn_head') and self.rpn_head is not None + + @property + def with_roi_head(self) -> bool: + """bool: whether the detector has a RoI head""" + return hasattr(self, 'roi_head') and self.roi_head is not None + + def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]: + """Extract features. + + Args: + batch_inputs (Tensor): Image tensor with shape (N, C, H ,W). + + Returns: + tuple[Tensor]: Multi-level features that may have + different resolutions. + """ + x = self.backbone(batch_inputs) + if self.with_neck: + x = self.neck(x) + return x + + def _forward(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> tuple: + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (list[:obj:`DetDataSample`]): Each item contains + the meta information of each image and corresponding + annotations. + + Returns: + tuple: A tuple of features from ``rpn_head`` and ``roi_head`` + forward. + """ + results = () + x = self.extract_feat(batch_inputs) + + if self.with_rpn: + rpn_results_list = self.rpn_head.predict( + x, batch_data_samples, rescale=False) + else: + assert batch_data_samples[0].get('proposals', None) is not None + rpn_results_list = [ + data_sample.proposals for data_sample in batch_data_samples + ] + roi_outs = self.roi_head.forward(x, rpn_results_list, + batch_data_samples) + results = results + (roi_outs, ) + return results + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. 
+
+        Returns:
+            dict: A dictionary of loss components.
+        """
+        x = self.extract_feat(batch_inputs)
+
+        losses = dict()
+
+        # RPN forward and loss
+        if self.with_rpn:
+            proposal_cfg = self.train_cfg.get('rpn_proposal',
+                                              self.test_cfg.rpn)
+            rpn_data_samples = copy.deepcopy(batch_data_samples)
+            # set cat_id of gt_labels to 0 in RPN
+            for data_sample in rpn_data_samples:
+                data_sample.gt_instances.labels = \
+                    torch.zeros_like(data_sample.gt_instances.labels)
+
+            rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict(
+                x, rpn_data_samples, proposal_cfg=proposal_cfg)
+            # avoid getting the same names as the roi_head losses
+            keys = rpn_losses.keys()
+            for key in list(keys):
+                if 'loss' in key and 'rpn' not in key:
+                    rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
+            losses.update(rpn_losses)
+        else:
+            assert batch_data_samples[0].get('proposals', None) is not None
+            # use pre-defined proposals in InstanceData for the second stage
+            # to extract ROI features.
+            rpn_results_list = [
+                data_sample.proposals for data_sample in batch_data_samples
+            ]
+
+        roi_losses = self.roi_head.loss(x, rpn_results_list,
+                                        batch_data_samples)
+        losses.update(roi_losses)
+
+        return losses
+
+    def predict(self,
+                batch_inputs: Tensor,
+                batch_data_samples: SampleList,
+                rescale: bool = True) -> SampleList:
+        """Predict results from a batch of inputs and data samples with post-
+        processing.
+
+        Args:
+            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
+            batch_data_samples (List[:obj:`DetDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+            rescale (bool): Whether to rescale the results.
+                Defaults to True.
+
+        Returns:
+            list[:obj:`DetDataSample`]: Detection results of the
+            input images. Each returned DetDataSample usually contains
+            'pred_instances', and ``pred_instances`` usually contains
+            the following keys.
+
+            - scores (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 4),
+              the last dimension 4 arranged as (x1, y1, x2, y2).
+            - masks (Tensor): Has a shape (num_instances, H, W).
+        """
+
+        assert self.with_bbox, 'Bbox head must be implemented.'
+        x = self.extract_feat(batch_inputs)
+
+        # If there are no pre-defined proposals, use RPN to get proposals
+        if batch_data_samples[0].get('proposals', None) is None:
+            rpn_results_list = self.rpn_head.predict(
+                x, batch_data_samples, rescale=False)
+        else:
+            rpn_results_list = [
+                data_sample.proposals for data_sample in batch_data_samples
+            ]
+
+        results_list = self.roi_head.predict(
+            x, rpn_results_list, batch_data_samples, rescale=rescale)
+
+        batch_data_samples = self.add_pred_to_datasample(
+            batch_data_samples, results_list)
+        return batch_data_samples
diff --git a/mmdet/models/detectors/vfnet.py b/mmdet/models/detectors/vfnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a695513faa7d37756d7716cbca0e457060400518
--- /dev/null
+++ b/mmdet/models/detectors/vfnet.py
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmdet.registry import MODELS
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+from .single_stage import SingleStageDetector
+
+
+@MODELS.register_module()
+class VFNet(SingleStageDetector):
+    """Implementation of `VarifocalNet
+    (VFNet).`_
+
+    Args:
+        backbone (:obj:`ConfigDict` or dict): The backbone module.
+        neck (:obj:`ConfigDict` or dict): The neck module.
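
Both `loss` and `predict` above skip the RPN whenever `data_sample.proposals` is already populated. A short sketch of attaching pre-computed proposals to a data sample (a minimal sketch assuming mmdet and mmengine are installed; the box values are made up):

```python
import torch
from mmengine.structures import InstanceData
from mmdet.structures import DetDataSample

sample = DetDataSample()
# hypothetical pre-computed proposals in x1y1x2y2
sample.proposals = InstanceData(bboxes=torch.tensor([[0., 0., 50., 50.]]))

# this is exactly the branch condition used by loss()/predict() above
assert sample.get('proposals', None) is not None
```
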
+        bbox_head (:obj:`ConfigDict` or dict): The bbox head module.
+        train_cfg (:obj:`ConfigDict` or dict, optional): The training config
+            of VFNet. Defaults to None.
+        test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
+            of VFNet. Defaults to None.
+        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
+            :class:`DetDataPreprocessor` to process the input data.
+            Defaults to None.
+        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
+            the initialization. Defaults to None.
+    """
+
+    def __init__(self,
+                 backbone: ConfigType,
+                 neck: ConfigType,
+                 bbox_head: ConfigType,
+                 train_cfg: OptConfigType = None,
+                 test_cfg: OptConfigType = None,
+                 data_preprocessor: OptConfigType = None,
+                 init_cfg: OptMultiConfig = None) -> None:
+        super().__init__(
+            backbone=backbone,
+            neck=neck,
+            bbox_head=bbox_head,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            data_preprocessor=data_preprocessor,
+            init_cfg=init_cfg)
diff --git a/mmdet/models/detectors/yolact.py b/mmdet/models/detectors/yolact.py
new file mode 100644
index 0000000000000000000000000000000000000000..f15fb7b70263b0c4018751067771b1365af96f67
--- /dev/null
+++ b/mmdet/models/detectors/yolact.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmdet.registry import MODELS
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+from .single_stage_instance_seg import SingleStageInstanceSegmentor
+
+
+@MODELS.register_module()
+class YOLACT(SingleStageInstanceSegmentor):
+    """Implementation of `YOLACT `_"""
+
+    def __init__(self,
+                 backbone: ConfigType,
+                 neck: ConfigType,
+                 bbox_head: ConfigType,
+                 mask_head: ConfigType,
+                 train_cfg: OptConfigType = None,
+                 test_cfg: OptConfigType = None,
+                 data_preprocessor: OptConfigType = None,
+                 init_cfg: OptMultiConfig = None) -> None:
+        super().__init__(
+            backbone=backbone,
+            neck=neck,
+            bbox_head=bbox_head,
+            mask_head=mask_head,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            data_preprocessor=data_preprocessor,
+            init_cfg=init_cfg)
diff --git a/mmdet/models/detectors/yolo.py b/mmdet/models/detectors/yolo.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cb9a9cd250a2c26af22032b1ed4bb5a7a8af605
--- /dev/null
+++ b/mmdet/models/detectors/yolo.py
@@ -0,0 +1,45 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Copyright (c) 2019 Western Digital Corporation or its affiliates.
+
+from mmdet.registry import MODELS
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+from .single_stage import SingleStageDetector
+
+
+@MODELS.register_module()
+class YOLOV3(SingleStageDetector):
+    r"""Implementation of `Yolov3: An incremental improvement
+    `_
+
+    Args:
+        backbone (:obj:`ConfigDict` or dict): The backbone module.
+        neck (:obj:`ConfigDict` or dict): The neck module.
+        bbox_head (:obj:`ConfigDict` or dict): The bbox head module.
+        train_cfg (:obj:`ConfigDict` or dict, optional): The training config
+            of YOLOV3. Default: None.
+        test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
+            of YOLOV3. Default: None.
+        data_preprocessor (:obj:`ConfigDict` or dict, optional):
+            Model preprocessing config for processing the input data.
+            It usually includes ``to_rgb``, ``pad_size_divisor``,
+            ``pad_value``, ``mean`` and ``std``. Defaults to None.
+        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
+            the initialization. Defaults to None.
+ """ + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/yolof.py b/mmdet/models/detectors/yolof.py new file mode 100644 index 0000000000000000000000000000000000000000..c6d98b9134a7f422fa7ea1f1a1e0d548d36603e8 --- /dev/null +++ b/mmdet/models/detectors/yolof.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage import SingleStageDetector + + +@MODELS.register_module() +class YOLOF(SingleStageDetector): + r"""Implementation of `You Only Look One-level Feature + `_ + + Args: + backbone (:obj:`ConfigDict` or dict): The backbone module. + neck (:obj:`ConfigDict` or dict): The neck module. + bbox_head (:obj:`ConfigDict` or dict): The bbox head module. + train_cfg (:obj:`ConfigDict` or dict, optional): The training config + of YOLOF. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): The testing config + of YOLOF. Defaults to None. + data_preprocessor (:obj:`ConfigDict` or dict, optional): + Model preprocessing config for processing the input data. + it usually includes ``to_rgb``, ``pad_size_divisor``, + ``pad_value``, ``mean`` and ``std``. Defaults to None. + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. + """ + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/detectors/yolox.py b/mmdet/models/detectors/yolox.py new file mode 100644 index 0000000000000000000000000000000000000000..df9190c93f7b043910fbce3bd5ee8dc0ef7b5f68 --- /dev/null +++ b/mmdet/models/detectors/yolox.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage import SingleStageDetector + + +@MODELS.register_module() +class YOLOX(SingleStageDetector): + r"""Implementation of `YOLOX: Exceeding YOLO Series in 2021 + `_ + + Args: + backbone (:obj:`ConfigDict` or dict): The backbone config. + neck (:obj:`ConfigDict` or dict): The neck config. + bbox_head (:obj:`ConfigDict` or dict): The bbox head config. + train_cfg (:obj:`ConfigDict` or dict, optional): The training config + of YOLOX. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): The testing config + of YOLOX. Defaults to None. + data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of + :class:`DetDataPreprocessor` to process the input data. + Defaults to None. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or + list[dict], optional): Initialization config dict. + Defaults to None. 
+ """ + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/language_models/__init__.py b/mmdet/models/language_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70f1a22c7c01624ba3235f1737f8aea1e26a19fe --- /dev/null +++ b/mmdet/models/language_models/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .bert import BertModel + +__all__ = ['BertModel'] diff --git a/mmdet/models/language_models/bert.py b/mmdet/models/language_models/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..efb0f46bad6eb0734a324c32a7b05f2795604265 --- /dev/null +++ b/mmdet/models/language_models/bert.py @@ -0,0 +1,231 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict +from typing import Sequence + +import torch +from mmengine.model import BaseModel +from torch import nn + +try: + from transformers import AutoTokenizer, BertConfig + from transformers import BertModel as HFBertModel +except ImportError: + AutoTokenizer = None + HFBertModel = None + +from mmdet.registry import MODELS + + +def generate_masks_with_special_tokens_and_transfer_map( + tokenized, special_tokens_list): + """Generate attention mask between each pair of special tokens. + + Only token pairs in between two special tokens are attended to + and thus the attention mask for these pairs is positive. + + Args: + input_ids (torch.Tensor): input ids. Shape: [bs, num_token] + special_tokens_mask (list): special tokens mask. + + Returns: + Tuple(Tensor, Tensor): + - attention_mask is the attention mask between each tokens. + Only token pairs in between two special tokens are positive. + Shape: [bs, num_token, num_token]. + - position_ids is the position id of tokens within each valid sentence. + The id starts from 0 whenenver a special token is encountered. + Shape: [bs, num_token] + """ + input_ids = tokenized['input_ids'] + bs, num_token = input_ids.shape + # special_tokens_mask: + # bs, num_token. 1 for special tokens. 0 for normal tokens + special_tokens_mask = torch.zeros((bs, num_token), + device=input_ids.device).bool() + + for special_token in special_tokens_list: + special_tokens_mask |= input_ids == special_token + + # idxs: each row is a list of indices of special tokens + idxs = torch.nonzero(special_tokens_mask) + + # generate attention mask and positional ids + attention_mask = ( + torch.eye(num_token, + device=input_ids.device).bool().unsqueeze(0).repeat( + bs, 1, 1)) + position_ids = torch.zeros((bs, num_token), device=input_ids.device) + previous_col = 0 + for i in range(idxs.shape[0]): + row, col = idxs[i] + if (col == 0) or (col == num_token - 1): + attention_mask[row, col, col] = True + position_ids[row, col] = 0 + else: + attention_mask[row, previous_col + 1:col + 1, + previous_col + 1:col + 1] = True + position_ids[row, previous_col + 1:col + 1] = torch.arange( + 0, col - previous_col, device=input_ids.device) + previous_col = col + + return attention_mask, position_ids.to(torch.long) + + +@MODELS.register_module() +class BertModel(BaseModel): + """BERT model for language embedding only encoder. 
+
+    Args:
+        name (str, optional): name of the pretrained BERT model from
+            HuggingFace. Defaults to bert-base-uncased.
+        max_tokens (int, optional): maximum number of tokens to be
+            used for BERT. Defaults to 256.
+        pad_to_max (bool, optional): whether to pad the tokens to max_tokens.
+            Defaults to True.
+        use_sub_sentence_represent (bool, optional): whether to use the
+            sub-sentence representations introduced in `Grounding DINO
+            `. Defaults to False.
+        special_tokens_list (list, optional): special tokens used to split
+            sub-sentences. It cannot be None when
+            `use_sub_sentence_represent` is True. Defaults to None.
+        add_pooling_layer (bool, optional): whether to add a pooling layer
+            in the BERT encoder. Defaults to False.
+        num_layers_of_embedded (int, optional): number of hidden layers
+            averaged to form the embedding. Defaults to 1.
+        use_checkpoint (bool, optional): whether to use gradient
+            checkpointing. Defaults to False.
+    """
+
+    def __init__(self,
+                 name: str = 'bert-base-uncased',
+                 max_tokens: int = 256,
+                 pad_to_max: bool = True,
+                 use_sub_sentence_represent: bool = False,
+                 special_tokens_list: list = None,
+                 add_pooling_layer: bool = False,
+                 num_layers_of_embedded: int = 1,
+                 use_checkpoint: bool = False,
+                 **kwargs) -> None:
+
+        super().__init__(**kwargs)
+        self.max_tokens = max_tokens
+        self.pad_to_max = pad_to_max
+
+        if AutoTokenizer is None:
+            raise RuntimeError(
+                'transformers is not installed, please install it by: '
+                'pip install transformers.')
+
+        self.tokenizer = AutoTokenizer.from_pretrained(name)
+        self.language_backbone = nn.Sequential(
+            OrderedDict([('body',
+                          BertEncoder(
+                              name,
+                              add_pooling_layer=add_pooling_layer,
+                              num_layers_of_embedded=num_layers_of_embedded,
+                              use_checkpoint=use_checkpoint))]))
+
+        self.use_sub_sentence_represent = use_sub_sentence_represent
+        if self.use_sub_sentence_represent:
+            assert special_tokens_list is not None, \
+                'special_tokens should not be None ' \
+                'if use_sub_sentence_represent is True'
+            self.special_tokens = self.tokenizer.convert_tokens_to_ids(
+                special_tokens_list)
+
+    def forward(self, captions: Sequence[str], **kwargs) -> dict:
+        """Forward function."""
+        device = next(self.language_backbone.parameters()).device
+        tokenized = self.tokenizer.batch_encode_plus(
+            captions,
+            max_length=self.max_tokens,
+            padding='max_length' if self.pad_to_max else 'longest',
+            return_special_tokens_mask=True,
+            return_tensors='pt',
+            truncation=True).to(device)
+        input_ids = tokenized.input_ids
+        if self.use_sub_sentence_represent:
+            attention_mask, position_ids = \
+                generate_masks_with_special_tokens_and_transfer_map(
+                    tokenized, self.special_tokens)
+            token_type_ids = tokenized['token_type_ids']
+
+        else:
+            attention_mask = tokenized.attention_mask
+            position_ids = None
+            token_type_ids = None
+
+        tokenizer_input = {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'position_ids': position_ids,
+            'token_type_ids': token_type_ids
+        }
+        language_dict_features = self.language_backbone(tokenizer_input)
+        if self.use_sub_sentence_represent:
+            language_dict_features['position_ids'] = position_ids
+            language_dict_features[
+                'text_token_mask'] = tokenized.attention_mask.bool()
+        return language_dict_features
+
+
+class BertEncoder(nn.Module):
+    """BERT encoder for language embedding.
+
+    Args:
+        name (str): name of the pretrained BERT model from HuggingFace.
+            Defaults to bert-base-uncased.
+        add_pooling_layer (bool): whether to add a pooling layer.
+        num_layers_of_embedded (int): number of layers of the embedded model.
+            Defaults to 1.
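
To make the sub-sentence machinery above concrete, here is a standalone re-creation of `generate_masks_with_special_tokens_and_transfer_map` for a single row, using 101/102 as stand-in ids for [CLS]/[SEP]. Note that, as in the original, a special token sitting in the last position only marks itself, so the trailing span keeps default ids:

```python
import torch

# [CLS] tok tok [SEP] tok [SEP]
input_ids = torch.tensor([[101, 7, 8, 102, 9, 102]])
special = {101, 102}
bs, n = input_ids.shape

mask = torch.eye(n, dtype=torch.bool).unsqueeze(0).repeat(bs, 1, 1)
pos = torch.zeros(bs, n, dtype=torch.long)
prev = 0
for col in range(n):
    if input_ids[0, col].item() not in special:
        continue
    if col == 0 or col == n - 1:
        mask[0, col, col] = True      # first/last specials only see themselves
    else:
        # tokens between consecutive specials attend to each other ...
        mask[0, prev + 1:col + 1, prev + 1:col + 1] = True
        # ... and their position ids restart from 0 inside the span
        pos[0, prev + 1:col + 1] = torch.arange(0, col - prev)
    prev = col

print(pos)  # tensor([[0, 0, 1, 2, 0, 0]]) -- trailing span keeps defaults
```
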
+ use_checkpoint (bool): whether to use gradient checkpointing. + Defaults to False. + """ + + def __init__(self, + name: str, + add_pooling_layer: bool = False, + num_layers_of_embedded: int = 1, + use_checkpoint: bool = False): + super().__init__() + if BertConfig is None: + raise RuntimeError( + 'transformers is not installed, please install it by: ' + 'pip install transformers.') + config = BertConfig.from_pretrained(name) + config.gradient_checkpointing = use_checkpoint + # only encoder + self.model = HFBertModel.from_pretrained( + name, add_pooling_layer=add_pooling_layer, config=config) + self.language_dim = config.hidden_size + self.num_layers_of_embedded = num_layers_of_embedded + + def forward(self, x) -> dict: + mask = x['attention_mask'] + + outputs = self.model( + input_ids=x['input_ids'], + attention_mask=mask, + position_ids=x['position_ids'], + token_type_ids=x['token_type_ids'], + output_hidden_states=True, + ) + + # outputs has 13 layers, 1 input layer and 12 hidden layers + encoded_layers = outputs.hidden_states[1:] + features = torch.stack(encoded_layers[-self.num_layers_of_embedded:], + 1).mean(1) + # language embedding has shape [len(phrase), seq_len, language_dim] + features = features / self.num_layers_of_embedded + if mask.dim() == 2: + embedded = features * mask.unsqueeze(-1).float() + else: + embedded = features + + results = { + 'embedded': embedded, + 'masks': mask, + 'hidden': encoded_layers[-1] + } + return results diff --git a/mmdet/models/layers/__init__.py b/mmdet/models/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3c41f64d11bbdb7f2c8e128a2e28b2845159589 --- /dev/null +++ b/mmdet/models/layers/__init__.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .activations import SiLU +from .bbox_nms import fast_nms, multiclass_nms +from .brick_wrappers import (AdaptiveAvgPool2d, FrozenBatchNorm2d, + adaptive_avg_pool2d) +from .conv_upsample import ConvUpsample +from .csp_layer import CSPLayer +from .dropblock import DropBlock +from .ema import ExpMomentumEMA +from .inverted_residual import InvertedResidual +from .matrix_nms import mask_matrix_nms +from .msdeformattn_pixel_decoder import MSDeformAttnPixelDecoder +from .normed_predictor import NormedConv2d, NormedLinear +from .pixel_decoder import PixelDecoder, TransformerEncoderPixelDecoder +from .positional_encoding import (LearnedPositionalEncoding, + SinePositionalEncoding, + SinePositionalEncoding3D) +from .res_layer import ResLayer, SimplifiedBasicBlock +from .se_layer import ChannelAttention, DyReLU, SELayer +# yapf: disable +from .transformer import (MLP, AdaptivePadding, CdnQueryGenerator, + ConditionalAttention, + ConditionalDetrTransformerDecoder, + ConditionalDetrTransformerDecoderLayer, + DABDetrTransformerDecoder, + DABDetrTransformerDecoderLayer, + DABDetrTransformerEncoder, DDQTransformerDecoder, + DeformableDetrTransformerDecoder, + DeformableDetrTransformerDecoderLayer, + DeformableDetrTransformerEncoder, + DeformableDetrTransformerEncoderLayer, + DetrTransformerDecoder, DetrTransformerDecoderLayer, + DetrTransformerEncoder, DetrTransformerEncoderLayer, + DinoTransformerDecoder, DynamicConv, + Mask2FormerTransformerDecoder, + Mask2FormerTransformerDecoderLayer, + Mask2FormerTransformerEncoder, PatchEmbed, + PatchMerging, coordinate_to_encoding, + inverse_sigmoid, nchw_to_nlc, nlc_to_nchw) + +# yapf: enable + +__all__ = [ + 'fast_nms', 'multiclass_nms', 'mask_matrix_nms', 'DropBlock', + 'PixelDecoder', 
'TransformerEncoderPixelDecoder',
+    'MSDeformAttnPixelDecoder', 'ResLayer', 'PatchMerging',
+    'SinePositionalEncoding', 'LearnedPositionalEncoding', 'DynamicConv',
+    'SimplifiedBasicBlock', 'NormedLinear', 'NormedConv2d', 'InvertedResidual',
+    'SELayer', 'ConvUpsample', 'CSPLayer', 'adaptive_avg_pool2d',
+    'AdaptiveAvgPool2d', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw', 'DyReLU',
+    'ExpMomentumEMA', 'inverse_sigmoid', 'ChannelAttention', 'SiLU', 'MLP',
+    'DetrTransformerEncoderLayer', 'DetrTransformerDecoderLayer',
+    'DetrTransformerEncoder', 'DetrTransformerDecoder',
+    'DeformableDetrTransformerEncoder', 'DeformableDetrTransformerDecoder',
+    'DeformableDetrTransformerEncoderLayer',
+    'DeformableDetrTransformerDecoderLayer', 'AdaptivePadding',
+    'coordinate_to_encoding', 'ConditionalAttention',
+    'DABDetrTransformerDecoderLayer', 'DABDetrTransformerDecoder',
+    'DABDetrTransformerEncoder', 'DDQTransformerDecoder',
+    'ConditionalDetrTransformerDecoder',
+    'ConditionalDetrTransformerDecoderLayer', 'DinoTransformerDecoder',
+    'CdnQueryGenerator', 'Mask2FormerTransformerEncoder',
+    'Mask2FormerTransformerDecoderLayer', 'Mask2FormerTransformerDecoder',
+    'SinePositionalEncoding3D', 'FrozenBatchNorm2d'
+]
diff --git a/mmdet/models/layers/__pycache__/__init__.cpython-311.pyc b/mmdet/models/layers/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee35a0daaddf1c1f751acdfd077bace8fc97f392
Binary files /dev/null and b/mmdet/models/layers/__pycache__/__init__.cpython-311.pyc differ
diff --git a/mmdet/models/layers/__pycache__/activations.cpython-311.pyc b/mmdet/models/layers/__pycache__/activations.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c4066259e905aab8fce62c71574a84098ee45a73
Binary files /dev/null and b/mmdet/models/layers/__pycache__/activations.cpython-311.pyc differ
diff --git a/mmdet/models/layers/__pycache__/bbox_nms.cpython-311.pyc b/mmdet/models/layers/__pycache__/bbox_nms.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d60d8a9fc52e622a17a44d32612e998aba538610
Binary files /dev/null and b/mmdet/models/layers/__pycache__/bbox_nms.cpython-311.pyc differ
diff --git a/mmdet/models/layers/activations.py b/mmdet/models/layers/activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e73ef42180ccd3dddb4bcca224c0b4eb5da807c
--- /dev/null
+++ b/mmdet/models/layers/activations.py
@@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmengine.utils import digit_version
+
+from mmdet.registry import MODELS
+
+if digit_version(torch.__version__) >= digit_version('1.7.0'):
+    from torch.nn import SiLU
+else:
+
+    class SiLU(nn.Module):
+        """Sigmoid Weighted Linear Unit (SiLU)."""
+
+        def __init__(self, inplace=True):
+            # `inplace` is accepted for API compatibility but unused here.
+            super().__init__()
+
+        def forward(self, inputs) -> torch.Tensor:
+            return inputs * torch.sigmoid(inputs)
+
+
+MODELS.register_module(module=SiLU, name='SiLU')
diff --git a/mmdet/models/layers/bbox_nms.py b/mmdet/models/layers/bbox_nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd67a45f60ca98c354e095127ab7dbb9653deca5
--- /dev/null
+++ b/mmdet/models/layers/bbox_nms.py
@@ -0,0 +1,184 @@
+# Copyright (c) OpenMMLab. All rights reserved.
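
The fallback `SiLU` above is just the identity `x * sigmoid(x)`; on torch >= 1.7 the built-in implementation is imported instead. A quick numerical check of the equivalence (requires torch >= 1.7 for `F.silu`):

```python
import torch
import torch.nn.functional as F

x = torch.randn(5)
assert torch.allclose(x * torch.sigmoid(x), F.silu(x))
print('fallback matches built-in SiLU')
```
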
+from typing import Optional, Tuple, Union + +import torch +from mmcv.ops.nms import batched_nms +from torch import Tensor + +from mmdet.structures.bbox import bbox_overlaps +from mmdet.utils import ConfigType + + +def multiclass_nms( + multi_bboxes: Tensor, + multi_scores: Tensor, + score_thr: float, + nms_cfg: ConfigType, + max_num: int = -1, + score_factors: Optional[Tensor] = None, + return_inds: bool = False, + box_dim: int = 4 +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]: + """NMS for multi-class bboxes. + + Args: + multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) + multi_scores (Tensor): shape (n, #class), where the last column + contains scores of the background class, but this will be ignored. + score_thr (float): bbox threshold, bboxes with scores lower than it + will not be considered. + nms_cfg (Union[:obj:`ConfigDict`, dict]): a dict that contains + the arguments of nms operations. + max_num (int, optional): if there are more than max_num bboxes after + NMS, only top max_num will be kept. Default to -1. + score_factors (Tensor, optional): The factors multiplied to scores + before applying NMS. Default to None. + return_inds (bool, optional): Whether return the indices of kept + bboxes. Default to False. + box_dim (int): The dimension of boxes. Defaults to 4. + + Returns: + Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]: + (dets, labels, indices (optional)), tensors of shape (k, 5), + (k), and (k). Dets are boxes with scores. Labels are 0-based. + """ + num_classes = multi_scores.size(1) - 1 + # exclude background category + if multi_bboxes.shape[1] > box_dim: + bboxes = multi_bboxes.view(multi_scores.size(0), -1, box_dim) + else: + bboxes = multi_bboxes[:, None].expand( + multi_scores.size(0), num_classes, box_dim) + + scores = multi_scores[:, :-1] + + labels = torch.arange(num_classes, dtype=torch.long, device=scores.device) + labels = labels.view(1, -1).expand_as(scores) + + bboxes = bboxes.reshape(-1, box_dim) + scores = scores.reshape(-1) + labels = labels.reshape(-1) + + if not torch.onnx.is_in_onnx_export(): + # NonZero not supported in TensorRT + # remove low scoring boxes + valid_mask = scores > score_thr + # multiply score_factor after threshold to preserve more bboxes, improve + # mAP by 1% for YOLOv3 + if score_factors is not None: + # expand the shape to match original shape of score + score_factors = score_factors.view(-1, 1).expand( + multi_scores.size(0), num_classes) + score_factors = score_factors.reshape(-1) + scores = scores * score_factors + + if not torch.onnx.is_in_onnx_export(): + # NonZero not supported in TensorRT + inds = valid_mask.nonzero(as_tuple=False).squeeze(1) + bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds] + else: + # TensorRT NMS plugin has invalid output filled with -1 + # add dummy data to make detection output correct. 
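+        # The single all-zero row appended below is a placeholder: it keeps
+        # bboxes/scores/labels non-empty inside the exported graph so that
+        # downstream ops still receive correctly-shaped tensors, while the
+        # plugin's -1 entries are left for the consumer to filter out.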
+        bboxes = torch.cat([bboxes, bboxes.new_zeros(1, box_dim)], dim=0)
+        scores = torch.cat([scores, scores.new_zeros(1)], dim=0)
+        labels = torch.cat([labels, labels.new_zeros(1)], dim=0)
+
+    if bboxes.numel() == 0:
+        if torch.onnx.is_in_onnx_export():
+            raise RuntimeError('[ONNX Error] Can not record NMS '
+                               'as it has not been executed this time')
+        dets = torch.cat([bboxes, scores[:, None]], -1)
+        if return_inds:
+            return dets, labels, inds
+        else:
+            return dets, labels
+
+    dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)
+
+    if max_num > 0:
+        dets = dets[:max_num]
+        keep = keep[:max_num]
+
+    if return_inds:
+        return dets, labels[keep], inds[keep]
+    else:
+        return dets, labels[keep]
+
+
+def fast_nms(
+    multi_bboxes: Tensor,
+    multi_scores: Tensor,
+    multi_coeffs: Tensor,
+    score_thr: float,
+    iou_thr: float,
+    top_k: int,
+    max_num: int = -1
+) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
+    """Fast NMS in `YOLACT <https://arxiv.org/abs/1904.02689>`_.
+
+    Fast NMS allows already-removed detections to suppress other detections so
+    that every instance can be decided to be kept or discarded in parallel,
+    which is not possible in traditional NMS. This relaxation allows us to
+    implement Fast NMS entirely in standard GPU-accelerated matrix operations.
+
+    Args:
+        multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
+        multi_scores (Tensor): shape (n, #class+1), where the last column
+            contains scores of the background class, but this will be ignored.
+        multi_coeffs (Tensor): shape (n, #class*coeffs_dim).
+        score_thr (float): bbox threshold, bboxes with scores lower than it
+            will not be considered.
+        iou_thr (float): IoU threshold to be considered as conflicted.
+        top_k (int): if there are more than top_k bboxes before NMS,
+            only top top_k will be kept.
+        max_num (int): if there are more than max_num bboxes after NMS,
+            only top max_num will be kept. If -1, keep all the bboxes.
+            Default: -1.
+
+    Returns:
+        Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
+            (dets, labels, coefficients), tensors of shape (k, 5), (k, 1),
+            and (k, coeffs_dim). Dets are boxes with scores.
+            Labels are 0-based.
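+
+    Example:
+        >>> # Illustrative sketch only; the sizes and thresholds below are
+        >>> # arbitrary assumptions, not values from the YOLACT setup.
+        >>> import torch
+        >>> multi_bboxes = torch.rand(100, 4) * 100
+        >>> multi_scores = torch.rand(100, 81)
+        >>> multi_coeffs = torch.rand(100, 80 * 32)
+        >>> dets, labels, coeffs = fast_nms(
+        ...     multi_bboxes, multi_scores, multi_coeffs,
+        ...     score_thr=0.05, iou_thr=0.5, top_k=200, max_num=100)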
+ """ + + scores = multi_scores[:, :-1].t() # [#class, n] + scores, idx = scores.sort(1, descending=True) + + idx = idx[:, :top_k].contiguous() + scores = scores[:, :top_k] # [#class, topk] + num_classes, num_dets = idx.size() + boxes = multi_bboxes[idx.view(-1), :].view(num_classes, num_dets, 4) + coeffs = multi_coeffs[idx.view(-1), :].view(num_classes, num_dets, -1) + + iou = bbox_overlaps(boxes, boxes) # [#class, topk, topk] + iou.triu_(diagonal=1) + iou_max, _ = iou.max(dim=1) + + # Now just filter out the ones higher than the threshold + keep = iou_max <= iou_thr + + # Second thresholding introduces 0.2 mAP gain at negligible time cost + keep *= scores > score_thr + + # Assign each kept detection to its corresponding class + classes = torch.arange( + num_classes, device=boxes.device)[:, None].expand_as(keep) + classes = classes[keep] + + boxes = boxes[keep] + coeffs = coeffs[keep] + scores = scores[keep] + + # Only keep the top max_num highest scores across all classes + scores, idx = scores.sort(0, descending=True) + if max_num > 0: + idx = idx[:max_num] + scores = scores[:max_num] + + classes = classes[idx] + boxes = boxes[idx] + coeffs = coeffs[idx] + + cls_dets = torch.cat([boxes, scores[:, None]], dim=1) + return cls_dets, classes, coeffs diff --git a/mmdet/models/layers/brick_wrappers.py b/mmdet/models/layers/brick_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..5ecb8499de329132561dfedb8f55c36080787b31 --- /dev/null +++ b/mmdet/models/layers/brick_wrappers.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks.wrappers import NewEmptyTensorOp, obsolete_torch_version + +from mmdet.registry import MODELS + +if torch.__version__ == 'parrots': + TORCH_VERSION = torch.__version__ +else: + # torch.__version__ could be 1.3.1+cu92, we only need the first two + # for comparison + TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2]) + + +def adaptive_avg_pool2d(input, output_size): + """Handle empty batch dimension to adaptive_avg_pool2d. + + Args: + input (tensor): 4D tensor. + output_size (int, tuple[int,int]): the target output size. + """ + if input.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): + if isinstance(output_size, int): + output_size = [output_size, output_size] + output_size = [*input.shape[:2], *output_size] + empty = NewEmptyTensorOp.apply(input, output_size) + return empty + else: + return F.adaptive_avg_pool2d(input, output_size) + + +class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d): + """Handle empty batch dimension to AdaptiveAvgPool2d.""" + + def forward(self, x): + # PyTorch 1.9 does not support empty tensor inference yet + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): + output_size = self.output_size + if isinstance(output_size, int): + output_size = [output_size, output_size] + else: + output_size = [ + v if v is not None else d + for v, d in zip(output_size, + x.size()[-2:]) + ] + output_size = [*x.shape[:2], *output_size] + empty = NewEmptyTensorOp.apply(x, output_size) + return empty + + return super().forward(x) + + +# Modified from +# https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py#L13 # noqa +@MODELS.register_module('FrozenBN') +class FrozenBatchNorm2d(nn.Module): + """BatchNorm2d where the batch statistics and the affine parameters are + fixed. 
+ + It contains non-trainable buffers called + "weight" and "bias", "running_mean", "running_var", + initialized to perform identity transformation. + Args: + num_features (int): :math:`C` from an expected input of size + :math:`(N, C, H, W)`. + eps (float): a value added to the denominator for numerical stability. + Default: 1e-5 + """ + + def __init__(self, num_features, eps=1e-5, **kwargs): + super().__init__() + self.num_features = num_features + self.eps = eps + self.register_buffer('weight', torch.ones(num_features)) + self.register_buffer('bias', torch.zeros(num_features)) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features) - eps) + + def forward(self, x): + if x.requires_grad: + # When gradients are needed, F.batch_norm will use extra memory + # because its backward op computes gradients for weight/bias + # as well. + scale = self.weight * (self.running_var + self.eps).rsqrt() + bias = self.bias - self.running_mean * scale + scale = scale.reshape(1, -1, 1, 1) + bias = bias.reshape(1, -1, 1, 1) + out_dtype = x.dtype # may be half + return x * scale.to(out_dtype) + bias.to(out_dtype) + else: + # When gradients are not needed, F.batch_norm is a single fused op + # and provide more optimization opportunities. + return F.batch_norm( + x, + self.running_mean, + self.running_var, + self.weight, + self.bias, + training=False, + eps=self.eps, + ) + + def __repr__(self): + return 'FrozenBatchNorm2d(num_features={}, eps={})'.format( + self.num_features, self.eps) + + @classmethod + def convert_frozen_batchnorm(cls, module): + """Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm. + + Args: + module (torch.nn.Module): + Returns: + If module is BatchNorm/SyncBatchNorm, returns a new module. + Otherwise, in-place convert module and return it. + Similar to convert_sync_batchnorm in + https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py + """ + bn_module = nn.modules.batchnorm + bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm) + res = module + if isinstance(module, bn_module): + res = cls(module.num_features) + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = cls.convert_frozen_batchnorm(child) + if new_child is not child: + res.add_module(name, new_child) + return res diff --git a/mmdet/models/layers/conv_upsample.py b/mmdet/models/layers/conv_upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..32505875a2162330ed7d00455f088d08d94f679e --- /dev/null +++ b/mmdet/models/layers/conv_upsample.py @@ -0,0 +1,67 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule, ModuleList + + +class ConvUpsample(BaseModule): + """ConvUpsample performs 2x upsampling after Conv. + + There are several `ConvModule` layers. In the first few layers, upsampling + will be applied after each layer of convolution. The number of upsampling + must be no more than the number of ConvModule layers. + + Args: + in_channels (int): Number of channels in the input feature map. + inner_channels (int): Number of channels produced by the convolution. + num_layers (int): Number of convolution layers. 
+        num_upsample (int, optional): Number of upsampling layers. Must be no
+            more than num_layers. Upsampling will be applied after the first
+            ``num_upsample`` layers of convolution. Default: ``num_layers``.
+        conv_cfg (dict): Config dict for convolution layer. Default: None,
+            which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        init_cfg (dict): Config dict for initialization. Default: None.
+        kwargs (keyword arguments): Other arguments used in ConvModule.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 inner_channels,
+                 num_layers=1,
+                 num_upsample=None,
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 init_cfg=None,
+                 **kwargs):
+        super(ConvUpsample, self).__init__(init_cfg)
+        if num_upsample is None:
+            num_upsample = num_layers
+        assert num_upsample <= num_layers, \
+            f'num_upsample({num_upsample}) must be no more than ' \
+            f'num_layers({num_layers})'
+        self.num_layers = num_layers
+        self.num_upsample = num_upsample
+        self.conv = ModuleList()
+        for i in range(num_layers):
+            self.conv.append(
+                ConvModule(
+                    in_channels,
+                    inner_channels,
+                    3,
+                    padding=1,
+                    stride=1,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    **kwargs))
+            in_channels = inner_channels
+
+    def forward(self, x):
+        num_upsample = self.num_upsample
+        for i in range(self.num_layers):
+            x = self.conv[i](x)
+            if num_upsample > 0:
+                num_upsample -= 1
+                x = F.interpolate(
+                    x, scale_factor=2, mode='bilinear', align_corners=False)
+        return x
diff --git a/mmdet/models/layers/csp_layer.py b/mmdet/models/layers/csp_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8b547b8994862bfe14739033bb6b254ef886f29
--- /dev/null
+++ b/mmdet/models/layers/csp_layer.py
@@ -0,0 +1,246 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+from mmengine.model import BaseModule
+from torch import Tensor
+
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+from .se_layer import ChannelAttention
+
+
+class DarknetBottleneck(BaseModule):
+    """The basic bottleneck block used in Darknet.
+
+    Each ResBlock consists of two ConvModules and the input is added to the
+    final output. Each ConvModule is composed of Conv, BN, and LeakyReLU.
+    The first conv layer has a filter size of 1x1 and the second one has a
+    filter size of 3x3.
+
+    Args:
+        in_channels (int): The input channels of this Module.
+        out_channels (int): The output channels of this Module.
+        expansion (float): Expand ratio of the hidden channels; the hidden
+            channels are ``int(out_channels * expansion)``. Defaults to 0.5.
+        add_identity (bool): Whether to add identity to the out.
+            Defaults to True.
+        use_depthwise (bool): Whether to use depthwise separable convolution.
+            Defaults to False.
+        conv_cfg (dict): Config dict for convolution layer. Defaults to None,
+            which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
+        act_cfg (dict): Config dict for activation layer.
+            Defaults to dict(type='Swish').
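+
+    Example:
+        >>> # Illustrative sketch; channel and spatial sizes are arbitrary.
+        >>> import torch
+        >>> block = DarknetBottleneck(64, 64, expansion=0.5)
+        >>> out = block(torch.rand(1, 64, 32, 32))
+        >>> out.shape
+        torch.Size([1, 64, 32, 32])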
+ """ + + def __init__(self, + in_channels: int, + out_channels: int, + expansion: float = 0.5, + add_identity: bool = True, + use_depthwise: bool = False, + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict( + type='BN', momentum=0.03, eps=0.001), + act_cfg: ConfigType = dict(type='Swish'), + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + hidden_channels = int(out_channels * expansion) + conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule + self.conv1 = ConvModule( + in_channels, + hidden_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv2 = conv( + hidden_channels, + out_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.add_identity = \ + add_identity and in_channels == out_channels + + def forward(self, x: Tensor) -> Tensor: + """Forward function.""" + identity = x + out = self.conv1(x) + out = self.conv2(out) + + if self.add_identity: + return out + identity + else: + return out + + +class CSPNeXtBlock(BaseModule): + """The basic bottleneck block used in CSPNeXt. + + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + expansion (float): Expand ratio of the hidden channel. Defaults to 0.5. + add_identity (bool): Whether to add identity to the out. Only works + when in_channels == out_channels. Defaults to True. + use_depthwise (bool): Whether to use depthwise separable convolution. + Defaults to False. + kernel_size (int): The kernel size of the second convolution layer. + Defaults to 5. + conv_cfg (dict): Config dict for convolution layer. Defaults to None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Defaults to dict(type='BN', momentum=0.03, eps=0.001). + act_cfg (dict): Config dict for activation layer. + Defaults to dict(type='SiLU'). + init_cfg (:obj:`ConfigDict` or dict or list[dict] or + list[:obj:`ConfigDict`], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + expansion: float = 0.5, + add_identity: bool = True, + use_depthwise: bool = False, + kernel_size: int = 5, + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict( + type='BN', momentum=0.03, eps=0.001), + act_cfg: ConfigType = dict(type='SiLU'), + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + hidden_channels = int(out_channels * expansion) + conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule + self.conv1 = conv( + in_channels, + hidden_channels, + 3, + stride=1, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv2 = DepthwiseSeparableConvModule( + hidden_channels, + out_channels, + kernel_size, + stride=1, + padding=kernel_size // 2, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.add_identity = \ + add_identity and in_channels == out_channels + + def forward(self, x: Tensor) -> Tensor: + """Forward function.""" + identity = x + out = self.conv1(x) + out = self.conv2(out) + + if self.add_identity: + return out + identity + else: + return out + + +class CSPLayer(BaseModule): + """Cross Stage Partial Layer. + + Args: + in_channels (int): The input channels of the CSP layer. + out_channels (int): The output channels of the CSP layer. + expand_ratio (float): Ratio to adjust the number of channels of the + hidden layer. Defaults to 0.5. + num_blocks (int): Number of blocks. 
Defaults to 1.
+        add_identity (bool): Whether to add identity in blocks.
+            Defaults to True.
+        use_cspnext_block (bool): Whether to use CSPNeXt block.
+            Defaults to False.
+        use_depthwise (bool): Whether to use depthwise separable convolution in
+            blocks. Defaults to False.
+        channel_attention (bool): Whether to add channel attention in each
+            stage. Defaults to False.
+        conv_cfg (dict, optional): Config dict for convolution layer.
+            Defaults to None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
+        act_cfg (dict): Config dict for activation layer.
+            Defaults to dict(type='Swish').
+        init_cfg (:obj:`ConfigDict` or dict or list[dict] or
+            list[:obj:`ConfigDict`], optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 expand_ratio: float = 0.5,
+                 num_blocks: int = 1,
+                 add_identity: bool = True,
+                 use_depthwise: bool = False,
+                 use_cspnext_block: bool = False,
+                 channel_attention: bool = False,
+                 conv_cfg: OptConfigType = None,
+                 norm_cfg: ConfigType = dict(
+                     type='BN', momentum=0.03, eps=0.001),
+                 act_cfg: ConfigType = dict(type='Swish'),
+                 init_cfg: OptMultiConfig = None) -> None:
+        super().__init__(init_cfg=init_cfg)
+        block = CSPNeXtBlock if use_cspnext_block else DarknetBottleneck
+        mid_channels = int(out_channels * expand_ratio)
+        self.channel_attention = channel_attention
+        self.main_conv = ConvModule(
+            in_channels,
+            mid_channels,
+            1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        self.short_conv = ConvModule(
+            in_channels,
+            mid_channels,
+            1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        self.final_conv = ConvModule(
+            2 * mid_channels,
+            out_channels,
+            1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+
+        self.blocks = nn.Sequential(*[
+            block(
+                mid_channels,
+                mid_channels,
+                1.0,
+                add_identity,
+                use_depthwise,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg) for _ in range(num_blocks)
+        ])
+        if channel_attention:
+            self.attention = ChannelAttention(2 * mid_channels)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward function."""
+        x_short = self.short_conv(x)
+
+        x_main = self.main_conv(x)
+        x_main = self.blocks(x_main)
+
+        x_final = torch.cat((x_main, x_short), dim=1)
+
+        if self.channel_attention:
+            x_final = self.attention(x_final)
+        return self.final_conv(x_final)
diff --git a/mmdet/models/layers/dropblock.py b/mmdet/models/layers/dropblock.py
new file mode 100644
index 0000000000000000000000000000000000000000..7938199b761d637afdb1b2c62dbca01d1bf629eb
--- /dev/null
+++ b/mmdet/models/layers/dropblock.py
@@ -0,0 +1,86 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmdet.registry import MODELS
+
+eps = 1e-6
+
+
+@MODELS.register_module()
+class DropBlock(nn.Module):
+    """Randomly drop some regions of feature maps.
+
+    Please refer to the method proposed in `DropBlock
+    <https://arxiv.org/abs/1810.12890>`_ for details.
+
+    Args:
+        drop_prob (float): The probability of dropping each block.
+        block_size (int): The size of dropped blocks.
+        warmup_iters (int): The drop probability will linearly increase
+            from `0` to `drop_prob` during the first `warmup_iters` iterations.
+            Default: 2000.
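+
+    Example:
+        >>> # Illustrative sketch; the feature map is random and
+        >>> # warmup_iters=0 is chosen only so the drop takes effect at once.
+        >>> import torch
+        >>> drop_block = DropBlock(drop_prob=0.1, block_size=3,
+        ...                        warmup_iters=0)
+        >>> feat = torch.rand(1, 8, 16, 16)
+        >>> out = drop_block(feat)  # random regions dropped in train mode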
+ """ + + def __init__(self, drop_prob, block_size, warmup_iters=2000, **kwargs): + super(DropBlock, self).__init__() + assert block_size % 2 == 1 + assert 0 < drop_prob <= 1 + assert warmup_iters >= 0 + self.drop_prob = drop_prob + self.block_size = block_size + self.warmup_iters = warmup_iters + self.iter_cnt = 0 + + def forward(self, x): + """ + Args: + x (Tensor): Input feature map on which some areas will be randomly + dropped. + + Returns: + Tensor: The tensor after DropBlock layer. + """ + if not self.training: + return x + self.iter_cnt += 1 + N, C, H, W = list(x.shape) + gamma = self._compute_gamma((H, W)) + mask_shape = (N, C, H - self.block_size + 1, W - self.block_size + 1) + mask = torch.bernoulli(torch.full(mask_shape, gamma, device=x.device)) + + mask = F.pad(mask, [self.block_size // 2] * 4, value=0) + mask = F.max_pool2d( + input=mask, + stride=(1, 1), + kernel_size=(self.block_size, self.block_size), + padding=self.block_size // 2) + mask = 1 - mask + x = x * mask * mask.numel() / (eps + mask.sum()) + return x + + def _compute_gamma(self, feat_size): + """Compute the value of gamma according to paper. gamma is the + parameter of bernoulli distribution, which controls the number of + features to drop. + + gamma = (drop_prob * fm_area) / (drop_area * keep_area) + + Args: + feat_size (tuple[int, int]): The height and width of feature map. + + Returns: + float: The value of gamma. + """ + gamma = (self.drop_prob * feat_size[0] * feat_size[1]) + gamma /= ((feat_size[0] - self.block_size + 1) * + (feat_size[1] - self.block_size + 1)) + gamma /= (self.block_size**2) + factor = (1.0 if self.iter_cnt > self.warmup_iters else self.iter_cnt / + self.warmup_iters) + return gamma * factor + + def extra_repr(self): + return (f'drop_prob={self.drop_prob}, block_size={self.block_size}, ' + f'warmup_iters={self.warmup_iters}') diff --git a/mmdet/models/layers/ema.py b/mmdet/models/layers/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..bce503c4641f7391a7bd7d722c05f4e49bd07db9 --- /dev/null +++ b/mmdet/models/layers/ema.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Optional + +import torch +import torch.nn as nn +from mmengine.model import ExponentialMovingAverage +from torch import Tensor + +from mmdet.registry import MODELS + + +@MODELS.register_module() +class ExpMomentumEMA(ExponentialMovingAverage): + """Exponential moving average (EMA) with exponential momentum strategy, + which is used in YOLOX. + + Args: + model (nn.Module): The model to be averaged. + momentum (float): The momentum used for updating ema parameter. + Ema's parameter are updated with the formula: + `averaged_param = (1-momentum) * averaged_param + momentum * + source_param`. Defaults to 0.0002. + gamma (int): Use a larger momentum early in training and gradually + annealing to a smaller value to update the ema model smoothly. The + momentum is calculated as + `(1 - momentum) * exp(-(1 + steps) / gamma) + momentum`. + Defaults to 2000. + interval (int): Interval between two updates. Defaults to 1. + device (torch.device, optional): If provided, the averaged model will + be stored on the :attr:`device`. Defaults to None. + update_buffers (bool): if True, it will compute running averages for + both the parameters and the buffers of the model. Defaults to + False. 
+ """ + + def __init__(self, + model: nn.Module, + momentum: float = 0.0002, + gamma: int = 2000, + interval=1, + device: Optional[torch.device] = None, + update_buffers: bool = False) -> None: + super().__init__( + model=model, + momentum=momentum, + interval=interval, + device=device, + update_buffers=update_buffers) + assert gamma > 0, f'gamma must be greater than 0, but got {gamma}' + self.gamma = gamma + + def avg_func(self, averaged_param: Tensor, source_param: Tensor, + steps: int) -> None: + """Compute the moving average of the parameters using the exponential + momentum strategy. + + Args: + averaged_param (Tensor): The averaged parameters. + source_param (Tensor): The source parameters. + steps (int): The number of times the parameters have been + updated. + """ + momentum = (1 - self.momentum) * math.exp( + -float(1 + steps) / self.gamma) + self.momentum + averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum) diff --git a/mmdet/models/layers/inverted_residual.py b/mmdet/models/layers/inverted_residual.py new file mode 100644 index 0000000000000000000000000000000000000000..a174ccc8835a1ee720f9cdaa7c5be210f5be8113 --- /dev/null +++ b/mmdet/models/layers/inverted_residual.py @@ -0,0 +1,130 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule + +from .se_layer import SELayer + + +class InvertedResidual(BaseModule): + """Inverted Residual Block. + + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + mid_channels (int): The input channels of the depthwise convolution. + kernel_size (int): The kernel size of the depthwise convolution. + Default: 3. + stride (int): The stride of the depthwise convolution. Default: 1. + se_cfg (dict): Config dict for se layer. Default: None, which means no + se layer. + with_expand_conv (bool): Use expand conv or not. If set False, + mid_channels must be the same with in_channels. + Default: True. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + drop_path_rate (float): stochastic depth rate. Defaults to 0. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + stride=1, + se_cfg=None, + with_expand_conv=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + drop_path_rate=0., + with_cp=False, + init_cfg=None): + super(InvertedResidual, self).__init__(init_cfg) + self.with_res_shortcut = (stride == 1 and in_channels == out_channels) + assert stride in [1, 2], f'stride must in [1, 2]. ' \ + f'But received {stride}.' 
+ self.with_cp = with_cp + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0 else nn.Identity() + self.with_se = se_cfg is not None + self.with_expand_conv = with_expand_conv + + if self.with_se: + assert isinstance(se_cfg, dict) + if not self.with_expand_conv: + assert mid_channels == in_channels + + if self.with_expand_conv: + self.expand_conv = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.depthwise_conv = ConvModule( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=mid_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + if self.with_se: + self.se = SELayer(**se_cfg) + + self.linear_conv = ConvModule( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, x): + + def _inner_forward(x): + out = x + + if self.with_expand_conv: + out = self.expand_conv(out) + + out = self.depthwise_conv(out) + + if self.with_se: + out = self.se(out) + + out = self.linear_conv(out) + + if self.with_res_shortcut: + return x + self.drop_path(out) + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out diff --git a/mmdet/models/layers/matrix_nms.py b/mmdet/models/layers/matrix_nms.py new file mode 100644 index 0000000000000000000000000000000000000000..9dc8c4f74e28127fb69ccc684f0bdb2bd3943b20 --- /dev/null +++ b/mmdet/models/layers/matrix_nms.py @@ -0,0 +1,121 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +def mask_matrix_nms(masks, + labels, + scores, + filter_thr=-1, + nms_pre=-1, + max_num=-1, + kernel='gaussian', + sigma=2.0, + mask_area=None): + """Matrix NMS for multi-class masks. + + Args: + masks (Tensor): Has shape (num_instances, h, w) + labels (Tensor): Labels of corresponding masks, + has shape (num_instances,). + scores (Tensor): Mask scores of corresponding masks, + has shape (num_instances). + filter_thr (float): Score threshold to filter the masks + after matrix nms. Default: -1, which means do not + use filter_thr. + nms_pre (int): The max number of instances to do the matrix nms. + Default: -1, which means do not use nms_pre. + max_num (int, optional): If there are more than max_num masks after + matrix, only top max_num will be kept. Default: -1, which means + do not use max_num. + kernel (str): 'linear' or 'gaussian'. + sigma (float): std in gaussian method. + mask_area (Tensor): The sum of seg_masks. + + Returns: + tuple(Tensor): Processed mask results. + + - scores (Tensor): Updated scores, has shape (n,). + - labels (Tensor): Remained labels, has shape (n,). + - masks (Tensor): Remained masks, has shape (n, w, h). + - keep_inds (Tensor): The indices number of + the remaining mask in the input mask, has shape (n,). 
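+
+    Example:
+        >>> # Illustrative sketch with random masks; the sizes are arbitrary.
+        >>> import torch
+        >>> masks = torch.rand(5, 32, 32) > 0.5
+        >>> labels = torch.tensor([0, 0, 1, 1, 2])
+        >>> scores = torch.rand(5)
+        >>> scores, labels, masks, keep_inds = mask_matrix_nms(
+        ...     masks, labels, scores, kernel='gaussian', sigma=2.0)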
+ """ + assert len(labels) == len(masks) == len(scores) + if len(labels) == 0: + return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros( + 0, *masks.shape[-2:]), labels.new_zeros(0) + if mask_area is None: + mask_area = masks.sum((1, 2)).float() + else: + assert len(masks) == len(mask_area) + + # sort and keep top nms_pre + scores, sort_inds = torch.sort(scores, descending=True) + + keep_inds = sort_inds + if nms_pre > 0 and len(sort_inds) > nms_pre: + sort_inds = sort_inds[:nms_pre] + keep_inds = keep_inds[:nms_pre] + scores = scores[:nms_pre] + masks = masks[sort_inds] + mask_area = mask_area[sort_inds] + labels = labels[sort_inds] + + num_masks = len(labels) + flatten_masks = masks.reshape(num_masks, -1).float() + # inter. + inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0)) + expanded_mask_area = mask_area.expand(num_masks, num_masks) + # Upper triangle iou matrix. + iou_matrix = (inter_matrix / + (expanded_mask_area + expanded_mask_area.transpose(1, 0) - + inter_matrix)).triu(diagonal=1) + # label_specific matrix. + expanded_labels = labels.expand(num_masks, num_masks) + # Upper triangle label matrix. + label_matrix = (expanded_labels == expanded_labels.transpose( + 1, 0)).triu(diagonal=1) + + # IoU compensation + compensate_iou, _ = (iou_matrix * label_matrix).max(0) + compensate_iou = compensate_iou.expand(num_masks, + num_masks).transpose(1, 0) + + # IoU decay + decay_iou = iou_matrix * label_matrix + + # Calculate the decay_coefficient + if kernel == 'gaussian': + decay_matrix = torch.exp(-1 * sigma * (decay_iou**2)) + compensate_matrix = torch.exp(-1 * sigma * (compensate_iou**2)) + decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0) + elif kernel == 'linear': + decay_matrix = (1 - decay_iou) / (1 - compensate_iou) + decay_coefficient, _ = decay_matrix.min(0) + else: + raise NotImplementedError( + f'{kernel} kernel is not supported in matrix nms!') + # update the score. + scores = scores * decay_coefficient + + if filter_thr > 0: + keep = scores >= filter_thr + keep_inds = keep_inds[keep] + if not keep.any(): + return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros( + 0, *masks.shape[-2:]), labels.new_zeros(0) + masks = masks[keep] + scores = scores[keep] + labels = labels[keep] + + # sort and keep top max_num + scores, sort_inds = torch.sort(scores, descending=True) + keep_inds = keep_inds[sort_inds] + if max_num > 0 and len(sort_inds) > max_num: + sort_inds = sort_inds[:max_num] + keep_inds = keep_inds[:max_num] + scores = scores[:max_num] + masks = masks[sort_inds] + labels = labels[sort_inds] + + return scores, labels, masks, keep_inds diff --git a/mmdet/models/layers/msdeformattn_pixel_decoder.py b/mmdet/models/layers/msdeformattn_pixel_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a67dc3c4437f83ebe1c82d12b3ed91f429030ce7 --- /dev/null +++ b/mmdet/models/layers/msdeformattn_pixel_decoder.py @@ -0,0 +1,246 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d, ConvModule +from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention +from mmengine.model import (BaseModule, ModuleList, caffe2_xavier_init, + normal_init, xavier_init) +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptMultiConfig +from ..task_modules.prior_generators import MlvlPointGenerator +from .positional_encoding import SinePositionalEncoding +from .transformer import Mask2FormerTransformerEncoder + + +@MODELS.register_module() +class MSDeformAttnPixelDecoder(BaseModule): + """Pixel decoder with multi-scale deformable attention. + + Args: + in_channels (list[int] | tuple[int]): Number of channels in the + input feature maps. + strides (list[int] | tuple[int]): Output strides of feature from + backbone. + feat_channels (int): Number of channels for feature. + out_channels (int): Number of channels for output. + num_outs (int): Number of output scales. + norm_cfg (:obj:`ConfigDict` or dict): Config for normalization. + Defaults to dict(type='GN', num_groups=32). + act_cfg (:obj:`ConfigDict` or dict): Config for activation. + Defaults to dict(type='ReLU'). + encoder (:obj:`ConfigDict` or dict): Config for transformer + encoder. Defaults to None. + positional_encoding (:obj:`ConfigDict` or dict): Config for + transformer encoder position encoding. Defaults to + dict(num_feats=128, normalize=True). + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], optional): Initialization config dict. Defaults to None. + """ + + def __init__(self, + in_channels: Union[List[int], + Tuple[int]] = [256, 512, 1024, 2048], + strides: Union[List[int], Tuple[int]] = [4, 8, 16, 32], + feat_channels: int = 256, + out_channels: int = 256, + num_outs: int = 3, + norm_cfg: ConfigType = dict(type='GN', num_groups=32), + act_cfg: ConfigType = dict(type='ReLU'), + encoder: ConfigType = None, + positional_encoding: ConfigType = dict( + num_feats=128, normalize=True), + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + self.strides = strides + self.num_input_levels = len(in_channels) + self.num_encoder_levels = \ + encoder.layer_cfg.self_attn_cfg.num_levels + assert self.num_encoder_levels >= 1, \ + 'num_levels in attn_cfgs must be at least one' + input_conv_list = [] + # from top to down (low to high resolution) + for i in range(self.num_input_levels - 1, + self.num_input_levels - self.num_encoder_levels - 1, + -1): + input_conv = ConvModule( + in_channels[i], + feat_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=None, + bias=True) + input_conv_list.append(input_conv) + self.input_convs = ModuleList(input_conv_list) + + self.encoder = Mask2FormerTransformerEncoder(**encoder) + self.postional_encoding = SinePositionalEncoding(**positional_encoding) + # high resolution to low resolution + self.level_encoding = nn.Embedding(self.num_encoder_levels, + feat_channels) + + # fpn-like structure + self.lateral_convs = ModuleList() + self.output_convs = ModuleList() + self.use_bias = norm_cfg is None + # from top to down (low to high resolution) + # fpn for the rest features that didn't pass in encoder + for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, + -1): + lateral_conv = ConvModule( + in_channels[i], + feat_channels, + kernel_size=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=None) + output_conv = ConvModule( + 
feat_channels, + feat_channels, + kernel_size=3, + stride=1, + padding=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.lateral_convs.append(lateral_conv) + self.output_convs.append(output_conv) + + self.mask_feature = Conv2d( + feat_channels, out_channels, kernel_size=1, stride=1, padding=0) + + self.num_outs = num_outs + self.point_generator = MlvlPointGenerator(strides) + + def init_weights(self) -> None: + """Initialize weights.""" + for i in range(0, self.num_encoder_levels): + xavier_init( + self.input_convs[i].conv, + gain=1, + bias=0, + distribution='uniform') + + for i in range(0, self.num_input_levels - self.num_encoder_levels): + caffe2_xavier_init(self.lateral_convs[i].conv, bias=0) + caffe2_xavier_init(self.output_convs[i].conv, bias=0) + + caffe2_xavier_init(self.mask_feature, bias=0) + + normal_init(self.level_encoding, mean=0, std=1) + for p in self.encoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + # init_weights defined in MultiScaleDeformableAttention + for m in self.encoder.layers.modules(): + if isinstance(m, MultiScaleDeformableAttention): + m.init_weights() + + def forward(self, feats: List[Tensor]) -> Tuple[Tensor, Tensor]: + """ + Args: + feats (list[Tensor]): Feature maps of each level. Each has + shape of (batch_size, c, h, w). + + Returns: + tuple: A tuple containing the following: + + - mask_feature (Tensor): shape (batch_size, c, h, w). + - multi_scale_features (list[Tensor]): Multi scale \ + features, each in shape (batch_size, c, h, w). + """ + # generate padding mask for each level, for each image + batch_size = feats[0].shape[0] + encoder_input_list = [] + padding_mask_list = [] + level_positional_encoding_list = [] + spatial_shapes = [] + reference_points_list = [] + for i in range(self.num_encoder_levels): + level_idx = self.num_input_levels - i - 1 + feat = feats[level_idx] + feat_projected = self.input_convs[i](feat) + feat_hw = torch._shape_as_tensor(feat)[2:].to(feat.device) + + # no padding + padding_mask_resized = feat.new_zeros( + (batch_size, ) + feat.shape[-2:], dtype=torch.bool) + pos_embed = self.postional_encoding(padding_mask_resized) + level_embed = self.level_encoding.weight[i] + level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed + # (h_i * w_i, 2) + reference_points = self.point_generator.single_level_grid_priors( + feat.shape[-2:], level_idx, device=feat.device) + # normalize + feat_wh = feat_hw.unsqueeze(0).flip(dims=[0, 1]) + factor = feat_wh * self.strides[level_idx] + reference_points = reference_points / factor + + # shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c) + feat_projected = feat_projected.flatten(2).permute(0, 2, 1) + level_pos_embed = level_pos_embed.flatten(2).permute(0, 2, 1) + padding_mask_resized = padding_mask_resized.flatten(1) + + encoder_input_list.append(feat_projected) + padding_mask_list.append(padding_mask_resized) + level_positional_encoding_list.append(level_pos_embed) + spatial_shapes.append(feat_hw) + reference_points_list.append(reference_points) + # shape (batch_size, total_num_queries), + # total_num_queries=sum([., h_i * w_i,.]) + padding_masks = torch.cat(padding_mask_list, dim=1) + # shape (total_num_queries, batch_size, c) + encoder_inputs = torch.cat(encoder_input_list, dim=1) + level_positional_encodings = torch.cat( + level_positional_encoding_list, dim=1) + # shape (num_encoder_levels, 2), from low + # resolution to high resolution + num_queries_per_level = [e[0] * e[1] for e in spatial_shapes] + spatial_shapes = 
torch.cat(spatial_shapes).view(-1, 2) + # shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = torch.cat(reference_points_list, dim=0) + reference_points = reference_points[None, :, None].repeat( + batch_size, 1, self.num_encoder_levels, 1) + valid_radios = reference_points.new_ones( + (batch_size, self.num_encoder_levels, 2)) + # shape (num_total_queries, batch_size, c) + memory = self.encoder( + query=encoder_inputs, + query_pos=level_positional_encodings, + key_padding_mask=padding_masks, + spatial_shapes=spatial_shapes, + reference_points=reference_points, + level_start_index=level_start_index, + valid_ratios=valid_radios) + # (batch_size, c, num_total_queries) + memory = memory.permute(0, 2, 1) + + # from low resolution to high resolution + outs = torch.split(memory, num_queries_per_level, dim=-1) + outs = [ + x.reshape(batch_size, -1, spatial_shapes[i][0], + spatial_shapes[i][1]) for i, x in enumerate(outs) + ] + + for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, + -1): + x = feats[i] + cur_feat = self.lateral_convs[i](x) + y = cur_feat + F.interpolate( + outs[-1], + size=cur_feat.shape[-2:], + mode='bilinear', + align_corners=False) + y = self.output_convs[i](y) + outs.append(y) + multi_scale_features = outs[:self.num_outs] + + mask_feature = self.mask_feature(outs[-1]) + return mask_feature, multi_scale_features diff --git a/mmdet/models/layers/normed_predictor.py b/mmdet/models/layers/normed_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..592194b1dbbb8582f4c642bf29135573e1f8c3c8 --- /dev/null +++ b/mmdet/models/layers/normed_predictor.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.utils import digit_version +from torch import Tensor + +from mmdet.registry import MODELS + +MODELS.register_module('Linear', module=nn.Linear) + + +@MODELS.register_module(name='NormedLinear') +class NormedLinear(nn.Linear): + """Normalized Linear Layer. + + Args: + tempeature (float, optional): Tempeature term. Defaults to 20. + power (int, optional): Power term. Defaults to 1.0. + eps (float, optional): The minimal value of divisor to + keep numerical stability. Defaults to 1e-6. + """ + + def __init__(self, + *args, + tempearture: float = 20, + power: int = 1.0, + eps: float = 1e-6, + **kwargs) -> None: + super().__init__(*args, **kwargs) + self.tempearture = tempearture + self.power = power + self.eps = eps + self.init_weights() + + def init_weights(self) -> None: + """Initialize the weights.""" + nn.init.normal_(self.weight, mean=0, std=0.01) + if self.bias is not None: + nn.init.constant_(self.bias, 0) + + def forward(self, x: Tensor) -> Tensor: + """Forward function for `NormedLinear`.""" + weight_ = self.weight / ( + self.weight.norm(dim=1, keepdim=True).pow(self.power) + self.eps) + x_ = x / (x.norm(dim=1, keepdim=True).pow(self.power) + self.eps) + x_ = x_ * self.tempearture + + return F.linear(x_, weight_, self.bias) + + +@MODELS.register_module(name='NormedConv2d') +class NormedConv2d(nn.Conv2d): + """Normalized Conv2d Layer. + + Args: + tempeature (float, optional): Tempeature term. Defaults to 20. + power (int, optional): Power term. Defaults to 1.0. + eps (float, optional): The minimal value of divisor to + keep numerical stability. Defaults to 1e-6. + norm_over_kernel (bool, optional): Normalize over kernel. 
+ Defaults to False. + """ + + def __init__(self, + *args, + tempearture: float = 20, + power: int = 1.0, + eps: float = 1e-6, + norm_over_kernel: bool = False, + **kwargs) -> None: + super().__init__(*args, **kwargs) + self.tempearture = tempearture + self.power = power + self.norm_over_kernel = norm_over_kernel + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + """Forward function for `NormedConv2d`.""" + if not self.norm_over_kernel: + weight_ = self.weight / ( + self.weight.norm(dim=1, keepdim=True).pow(self.power) + + self.eps) + else: + weight_ = self.weight / ( + self.weight.view(self.weight.size(0), -1).norm( + dim=1, keepdim=True).pow(self.power)[..., None, None] + + self.eps) + x_ = x / (x.norm(dim=1, keepdim=True).pow(self.power) + self.eps) + x_ = x_ * self.tempearture + + if hasattr(self, 'conv2d_forward'): + x_ = self.conv2d_forward(x_, weight_) + else: + if digit_version(torch.__version__) >= digit_version('1.8'): + x_ = self._conv_forward(x_, weight_, self.bias) + else: + x_ = self._conv_forward(x_, weight_) + return x_ diff --git a/mmdet/models/layers/pixel_decoder.py b/mmdet/models/layers/pixel_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..fb61434045eb9996276518577800132e4a25eb3e --- /dev/null +++ b/mmdet/models/layers/pixel_decoder.py @@ -0,0 +1,249 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d, ConvModule +from mmengine.model import BaseModule, ModuleList, caffe2_xavier_init +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptMultiConfig +from .positional_encoding import SinePositionalEncoding +from .transformer import DetrTransformerEncoder + + +@MODELS.register_module() +class PixelDecoder(BaseModule): + """Pixel decoder with a structure like fpn. + + Args: + in_channels (list[int] | tuple[int]): Number of channels in the + input feature maps. + feat_channels (int): Number channels for feature. + out_channels (int): Number channels for output. + norm_cfg (:obj:`ConfigDict` or dict): Config for normalization. + Defaults to dict(type='GN', num_groups=32). + act_cfg (:obj:`ConfigDict` or dict): Config for activation. + Defaults to dict(type='ReLU'). + encoder (:obj:`ConfigDict` or dict): Config for transorformer + encoder.Defaults to None. + positional_encoding (:obj:`ConfigDict` or dict): Config for + transformer encoder position encoding. Defaults to + dict(type='SinePositionalEncoding', num_feats=128, + normalize=True). + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], optional): Initialization config dict. Defaults to None. 
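+
+    Example:
+        >>> # Illustrative sketch; the channel list matches a typical
+        >>> # ResNet-50 backbone but is only an assumption here.
+        >>> pixel_decoder = PixelDecoder(
+        ...     in_channels=[256, 512, 1024, 2048],
+        ...     feat_channels=256,
+        ...     out_channels=256)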
+ """ + + def __init__(self, + in_channels: Union[List[int], Tuple[int]], + feat_channels: int, + out_channels: int, + norm_cfg: ConfigType = dict(type='GN', num_groups=32), + act_cfg: ConfigType = dict(type='ReLU'), + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.num_inputs = len(in_channels) + self.lateral_convs = ModuleList() + self.output_convs = ModuleList() + self.use_bias = norm_cfg is None + for i in range(0, self.num_inputs - 1): + lateral_conv = ConvModule( + in_channels[i], + feat_channels, + kernel_size=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=None) + output_conv = ConvModule( + feat_channels, + feat_channels, + kernel_size=3, + stride=1, + padding=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.lateral_convs.append(lateral_conv) + self.output_convs.append(output_conv) + + self.last_feat_conv = ConvModule( + in_channels[-1], + feat_channels, + kernel_size=3, + padding=1, + stride=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.mask_feature = Conv2d( + feat_channels, out_channels, kernel_size=3, stride=1, padding=1) + + def init_weights(self) -> None: + """Initialize weights.""" + for i in range(0, self.num_inputs - 2): + caffe2_xavier_init(self.lateral_convs[i].conv, bias=0) + caffe2_xavier_init(self.output_convs[i].conv, bias=0) + + caffe2_xavier_init(self.mask_feature, bias=0) + caffe2_xavier_init(self.last_feat_conv, bias=0) + + def forward(self, feats: List[Tensor], + batch_img_metas: List[dict]) -> Tuple[Tensor, Tensor]: + """ + Args: + feats (list[Tensor]): Feature maps of each level. Each has + shape of (batch_size, c, h, w). + batch_img_metas (list[dict]): List of image information. + Pass in for creating more accurate padding mask. Not + used here. + + Returns: + tuple[Tensor, Tensor]: a tuple containing the following: + + - mask_feature (Tensor): Shape (batch_size, c, h, w). + - memory (Tensor): Output of last stage of backbone.\ + Shape (batch_size, c, h, w). + """ + y = self.last_feat_conv(feats[-1]) + for i in range(self.num_inputs - 2, -1, -1): + x = feats[i] + cur_feat = self.lateral_convs[i](x) + y = cur_feat + \ + F.interpolate(y, size=cur_feat.shape[-2:], mode='nearest') + y = self.output_convs[i](y) + + mask_feature = self.mask_feature(y) + memory = feats[-1] + return mask_feature, memory + + +@MODELS.register_module() +class TransformerEncoderPixelDecoder(PixelDecoder): + """Pixel decoder with transormer encoder inside. + + Args: + in_channels (list[int] | tuple[int]): Number of channels in the + input feature maps. + feat_channels (int): Number channels for feature. + out_channels (int): Number channels for output. + norm_cfg (:obj:`ConfigDict` or dict): Config for normalization. + Defaults to dict(type='GN', num_groups=32). + act_cfg (:obj:`ConfigDict` or dict): Config for activation. + Defaults to dict(type='ReLU'). + encoder (:obj:`ConfigDict` or dict): Config for transformer encoder. + Defaults to None. + positional_encoding (:obj:`ConfigDict` or dict): Config for + transformer encoder position encoding. Defaults to + dict(num_feats=128, normalize=True). + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], optional): Initialization config dict. Defaults to None. 
+ """ + + def __init__(self, + in_channels: Union[List[int], Tuple[int]], + feat_channels: int, + out_channels: int, + norm_cfg: ConfigType = dict(type='GN', num_groups=32), + act_cfg: ConfigType = dict(type='ReLU'), + encoder: ConfigType = None, + positional_encoding: ConfigType = dict( + num_feats=128, normalize=True), + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + in_channels=in_channels, + feat_channels=feat_channels, + out_channels=out_channels, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + init_cfg=init_cfg) + self.last_feat_conv = None + + self.encoder = DetrTransformerEncoder(**encoder) + self.encoder_embed_dims = self.encoder.embed_dims + assert self.encoder_embed_dims == feat_channels, 'embed_dims({}) of ' \ + 'tranformer encoder must equal to feat_channels({})'.format( + feat_channels, self.encoder_embed_dims) + self.positional_encoding = SinePositionalEncoding( + **positional_encoding) + self.encoder_in_proj = Conv2d( + in_channels[-1], feat_channels, kernel_size=1) + self.encoder_out_proj = ConvModule( + feat_channels, + feat_channels, + kernel_size=3, + stride=1, + padding=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def init_weights(self) -> None: + """Initialize weights.""" + for i in range(0, self.num_inputs - 2): + caffe2_xavier_init(self.lateral_convs[i].conv, bias=0) + caffe2_xavier_init(self.output_convs[i].conv, bias=0) + + caffe2_xavier_init(self.mask_feature, bias=0) + caffe2_xavier_init(self.encoder_in_proj, bias=0) + caffe2_xavier_init(self.encoder_out_proj.conv, bias=0) + + for p in self.encoder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feats: List[Tensor], + batch_img_metas: List[dict]) -> Tuple[Tensor, Tensor]: + """ + Args: + feats (list[Tensor]): Feature maps of each level. Each has + shape of (batch_size, c, h, w). + batch_img_metas (list[dict]): List of image information. Pass in + for creating more accurate padding mask. + + Returns: + tuple: a tuple containing the following: + + - mask_feature (Tensor): shape (batch_size, c, h, w). + - memory (Tensor): shape (batch_size, c, h, w). 
+ """ + feat_last = feats[-1] + bs, c, h, w = feat_last.shape + input_img_h, input_img_w = batch_img_metas[0]['batch_input_shape'] + padding_mask = feat_last.new_ones((bs, input_img_h, input_img_w), + dtype=torch.float32) + for i in range(bs): + img_h, img_w = batch_img_metas[i]['img_shape'] + padding_mask[i, :img_h, :img_w] = 0 + padding_mask = F.interpolate( + padding_mask.unsqueeze(1), + size=feat_last.shape[-2:], + mode='nearest').to(torch.bool).squeeze(1) + + pos_embed = self.positional_encoding(padding_mask) + feat_last = self.encoder_in_proj(feat_last) + # (batch_size, c, h, w) -> (batch_size, num_queries, c) + feat_last = feat_last.flatten(2).permute(0, 2, 1) + pos_embed = pos_embed.flatten(2).permute(0, 2, 1) + # (batch_size, h, w) -> (batch_size, h*w) + padding_mask = padding_mask.flatten(1) + memory = self.encoder( + query=feat_last, + query_pos=pos_embed, + key_padding_mask=padding_mask) + # (batch_size, num_queries, c) -> (batch_size, c, h, w) + memory = memory.permute(0, 2, 1).view(bs, self.encoder_embed_dims, h, + w) + y = self.encoder_out_proj(memory) + for i in range(self.num_inputs - 2, -1, -1): + x = feats[i] + cur_feat = self.lateral_convs[i](x) + y = cur_feat + \ + F.interpolate(y, size=cur_feat.shape[-2:], mode='nearest') + y = self.output_convs[i](y) + + mask_feature = self.mask_feature(y) + return mask_feature, memory diff --git a/mmdet/models/layers/positional_encoding.py b/mmdet/models/layers/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..87080d81a9f155839d453b8671103e5d51fbf88a --- /dev/null +++ b/mmdet/models/layers/positional_encoding.py @@ -0,0 +1,269 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Optional + +import torch +import torch.nn as nn +from mmengine.model import BaseModule +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import MultiConfig, OptMultiConfig + + +@MODELS.register_module() +class SinePositionalEncoding(BaseModule): + """Position encoding with sine and cosine functions. + + See `End-to-End Object Detection with Transformers + `_ for details. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. Note the final returned dimension + for each position is 2 times of this value. + temperature (int, optional): The temperature used for scaling + the position embedding. Defaults to 10000. + normalize (bool, optional): Whether to normalize the position + embedding. Defaults to False. + scale (float, optional): A scale factor that scales the position + embedding. The scale will be used only when `normalize` is True. + Defaults to 2*pi. + eps (float, optional): A value added to the denominator for + numerical stability. Defaults to 1e-6. + offset (float): offset add to embed when do the normalization. + Defaults to 0. + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ Defaults to None + """ + + def __init__(self, + num_feats: int, + temperature: int = 10000, + normalize: bool = False, + scale: float = 2 * math.pi, + eps: float = 1e-6, + offset: float = 0., + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + if normalize: + assert isinstance(scale, (float, int)), 'when normalize is set,' \ + 'scale should be provided and in float or int type, ' \ + f'found {type(scale)}' + self.num_feats = num_feats + self.temperature = temperature + self.normalize = normalize + self.scale = scale + self.eps = eps + self.offset = offset + + def forward(self, mask: Tensor, input: Optional[Tensor] = None) -> Tensor: + """Forward function for `SinePositionalEncoding`. + + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. + input (Tensor, optional): Input image/feature Tensor. + Shape [bs, c, h, w] + + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. + """ + assert not (mask is None and input is None) + + if mask is not None: + B, H, W = mask.size() + device = mask.device + # For convenience of exporting to ONNX, + # it's required to convert + # `masks` from bool to int. + mask = mask.to(torch.int) + not_mask = 1 - mask # logical_not + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + else: + # single image or batch image with no padding + B, _, H, W = input.shape + device = input.device + x_embed = torch.arange( + 1, W + 1, dtype=torch.float32, device=device) + x_embed = x_embed.view(1, 1, -1).repeat(B, H, 1) + y_embed = torch.arange( + 1, H + 1, dtype=torch.float32, device=device) + y_embed = y_embed.view(1, -1, 1).repeat(B, 1, W) + if self.normalize: + y_embed = (y_embed + self.offset) / \ + (y_embed[:, -1:, :] + self.eps) * self.scale + x_embed = (x_embed + self.offset) / \ + (x_embed[:, :, -1:] + self.eps) * self.scale + dim_t = torch.arange( + self.num_feats, dtype=torch.float32, device=device) + dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats) + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + # use `view` instead of `flatten` for dynamically exporting to ONNX + + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).view(B, H, W, -1) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).view(B, H, W, -1) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + def __repr__(self) -> str: + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f'(num_feats={self.num_feats}, ' + repr_str += f'temperature={self.temperature}, ' + repr_str += f'normalize={self.normalize}, ' + repr_str += f'scale={self.scale}, ' + repr_str += f'eps={self.eps})' + return repr_str + + +@MODELS.register_module() +class LearnedPositionalEncoding(BaseModule): + """Position embedding with learnable embedding weights. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. The final returned dimension for + each position is 2 times of this value. + row_num_embed (int, optional): The dictionary size of row embeddings. + Defaults to 50. + col_num_embed (int, optional): The dictionary size of col embeddings. + Defaults to 50. + init_cfg (dict or list[dict], optional): Initialization config dict. 
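+
+    Example:
+        >>> # Illustrative sketch; the mask shape is arbitrary.
+        >>> import torch
+        >>> pos_enc = LearnedPositionalEncoding(num_feats=128)
+        >>> mask = torch.zeros(1, 32, 32, dtype=torch.bool)
+        >>> pos = pos_enc(mask)  # shape (1, 256, 32, 32)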
+ """ + + def __init__( + self, + num_feats: int, + row_num_embed: int = 50, + col_num_embed: int = 50, + init_cfg: MultiConfig = dict(type='Uniform', layer='Embedding') + ) -> None: + super().__init__(init_cfg=init_cfg) + self.row_embed = nn.Embedding(row_num_embed, num_feats) + self.col_embed = nn.Embedding(col_num_embed, num_feats) + self.num_feats = num_feats + self.row_num_embed = row_num_embed + self.col_num_embed = col_num_embed + + def forward(self, mask: Tensor) -> Tensor: + """Forward function for `LearnedPositionalEncoding`. + + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. + + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. + """ + h, w = mask.shape[-2:] + x = torch.arange(w, device=mask.device) + y = torch.arange(h, device=mask.device) + x_embed = self.col_embed(x) + y_embed = self.row_embed(y) + pos = torch.cat( + (x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat( + 1, w, 1)), + dim=-1).permute(2, 0, + 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1) + return pos + + def __repr__(self) -> str: + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f'(num_feats={self.num_feats}, ' + repr_str += f'row_num_embed={self.row_num_embed}, ' + repr_str += f'col_num_embed={self.col_num_embed})' + return repr_str + + +@MODELS.register_module() +class SinePositionalEncoding3D(SinePositionalEncoding): + """Position encoding with sine and cosine functions. + + See `End-to-End Object Detection with Transformers + `_ for details. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. Note the final returned dimension + for each position is 2 times of this value. + temperature (int, optional): The temperature used for scaling + the position embedding. Defaults to 10000. + normalize (bool, optional): Whether to normalize the position + embedding. Defaults to False. + scale (float, optional): A scale factor that scales the position + embedding. The scale will be used only when `normalize` is True. + Defaults to 2*pi. + eps (float, optional): A value added to the denominator for + numerical stability. Defaults to 1e-6. + offset (float): offset add to embed when do the normalization. + Defaults to 0. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def forward(self, mask: Tensor) -> Tensor: + """Forward function for `SinePositionalEncoding3D`. + + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, t, h, w]. + + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. + """ + assert mask.dim() == 4,\ + f'{mask.shape} should be a 4-dimensional Tensor,' \ + f' got {mask.dim()}-dimensional Tensor instead ' + # For convenience of exporting to ONNX, it's required to convert + # `masks` from bool to int. 
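
LearnedPositionalEncoding, defined above, swaps the fixed sinusoids for row/column embedding tables, so the feature map must stay within row_num_embed/col_num_embed. A sketch under the same assumptions as the previous example:

import torch
from mmdet.models.layers import LearnedPositionalEncoding

pos_enc = LearnedPositionalEncoding(
    num_feats=128, row_num_embed=50, col_num_embed=50)
mask = torch.zeros((2, 32, 32))  # only shape and device are used here
pos = pos_enc(mask)              # 32 <= 50, so the table lookup is in range
assert pos.shape == (2, 256, 32, 32)
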
+ mask = mask.to(torch.int) + not_mask = 1 - mask # logical_not + z_embed = not_mask.cumsum(1, dtype=torch.float32) + y_embed = not_mask.cumsum(2, dtype=torch.float32) + x_embed = not_mask.cumsum(3, dtype=torch.float32) + if self.normalize: + z_embed = (z_embed + self.offset) / \ + (z_embed[:, -1:, :, :] + self.eps) * self.scale + y_embed = (y_embed + self.offset) / \ + (y_embed[:, :, -1:, :] + self.eps) * self.scale + x_embed = (x_embed + self.offset) / \ + (x_embed[:, :, :, -1:] + self.eps) * self.scale + dim_t = torch.arange( + self.num_feats, dtype=torch.float32, device=mask.device) + dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats) + + dim_t_z = torch.arange((self.num_feats * 2), + dtype=torch.float32, + device=mask.device) + dim_t_z = self.temperature**(2 * (dim_t_z // 2) / (self.num_feats * 2)) + + pos_x = x_embed[:, :, :, :, None] / dim_t + pos_y = y_embed[:, :, :, :, None] / dim_t + pos_z = z_embed[:, :, :, :, None] / dim_t_z + # use `view` instead of `flatten` for dynamically exporting to ONNX + B, T, H, W = mask.size() + pos_x = torch.stack( + (pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), + dim=5).view(B, T, H, W, -1) + pos_y = torch.stack( + (pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), + dim=5).view(B, T, H, W, -1) + pos_z = torch.stack( + (pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), + dim=5).view(B, T, H, W, -1) + pos = (torch.cat((pos_y, pos_x), dim=4) + pos_z).permute(0, 1, 4, 2, 3) + return pos diff --git a/mmdet/models/layers/res_layer.py b/mmdet/models/layers/res_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..ff24d3e8562d1c3c724b35f7dc10cafe48e47650 --- /dev/null +++ b/mmdet/models/layers/res_layer.py @@ -0,0 +1,195 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule, Sequential +from torch import Tensor +from torch import nn as nn + +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig + + +class ResLayer(Sequential): + """ResLayer to build ResNet style backbone. + + Args: + block (nn.Module): block used to build ResLayer. + inplanes (int): inplanes of block. + planes (int): planes of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Defaults to 1 + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Defaults to False + conv_cfg (dict): dictionary to construct and config conv layer. + Defaults to None + norm_cfg (dict): dictionary to construct and config norm layer. + Defaults to dict(type='BN') + downsample_first (bool): Downsample at the first block or last block. + False for Hourglass, True for ResNet. 
Defaults to True + """ + + def __init__(self, + block: BaseModule, + inplanes: int, + planes: int, + num_blocks: int, + stride: int = 1, + avg_down: bool = False, + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict(type='BN'), + downsample_first: bool = True, + **kwargs) -> None: + self.block = block + + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = [] + conv_stride = stride + if avg_down: + conv_stride = 1 + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=conv_stride, + bias=False), + build_norm_layer(norm_cfg, planes * block.expansion)[1] + ]) + downsample = nn.Sequential(*downsample) + + layers = [] + if downsample_first: + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + inplanes = planes * block.expansion + for _ in range(1, num_blocks): + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + + else: # downsample_first=False is for HourglassModule + for _ in range(num_blocks - 1): + layers.append( + block( + inplanes=inplanes, + planes=inplanes, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + super().__init__(*layers) + + +class SimplifiedBasicBlock(BaseModule): + """Simplified version of original basic residual block. This is used in + `SCNet `_. + + - Norm layer is now optional + - Last ReLU in forward function is removed + """ + expansion = 1 + + def __init__(self, + inplanes: int, + planes: int, + stride: int = 1, + dilation: int = 1, + downsample: Optional[Sequential] = None, + style: ConfigType = 'pytorch', + with_cp: bool = False, + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict(type='BN'), + dcn: OptConfigType = None, + plugins: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + assert not with_cp, 'Not implemented yet.' 
+ self.with_norm = norm_cfg is not None + with_bias = True if norm_cfg is None else False + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=with_bias) + if self.with_norm: + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, planes, postfix=1) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, planes, planes, 3, padding=1, bias=with_bias) + if self.with_norm: + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, planes, postfix=2) + self.add_module(self.norm2_name, norm2) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + @property + def norm1(self) -> Optional[BaseModule]: + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) if self.with_norm else None + + @property + def norm2(self) -> Optional[BaseModule]: + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) if self.with_norm else None + + def forward(self, x: Tensor) -> Tensor: + """Forward function for SimplifiedBasicBlock.""" + + identity = x + + out = self.conv1(x) + if self.with_norm: + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + if self.with_norm: + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out diff --git a/mmdet/models/layers/se_layer.py b/mmdet/models/layers/se_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..5598dabaf6f3b3a09f4348fcd65ff39897b7068f --- /dev/null +++ b/mmdet/models/layers/se_layer.py @@ -0,0 +1,162 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from mmengine.utils import digit_version, is_tuple_of +from torch import Tensor + +from mmdet.utils import MultiConfig, OptConfigType, OptMultiConfig + + +class SELayer(BaseModule): + """Squeeze-and-Excitation Module. + + Args: + channels (int): The input (and output) channels of the SE layer. + ratio (int): Squeeze ratio in SELayer, the intermediate channel will be + ``int(channels/ratio)``. Defaults to 16. + conv_cfg (None or dict): Config dict for convolution layer. + Defaults to None, which means using conv2d. + act_cfg (dict or Sequence[dict]): Config dict for activation layer. + If act_cfg is a dict, two activation layers will be configurated + by this dict. If act_cfg is a sequence of dicts, the first + activation layer will be configurated by the first dict and the + second activation layer will be configurated by the second dict. + Defaults to (dict(type='ReLU'), dict(type='Sigmoid')) + init_cfg (dict or list[dict], optional): Initialization config dict. 
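
ResLayer above is a stage builder: it wires the downsample shortcut into the first block and stacks the rest. A small sketch using mmdet's own BasicBlock (expansion 1); the channel numbers are illustrative:

import torch
from mmdet.models.backbones.resnet import BasicBlock
from mmdet.models.layers import ResLayer

stage = ResLayer(
    block=BasicBlock,
    inplanes=64,
    planes=128,
    num_blocks=2,
    stride=2)  # first block downsamples through the 1x1-conv shortcut

x = torch.randn(1, 64, 56, 56)
assert stage(x).shape == (1, 128, 28, 28)
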
+ Defaults to None + """ + + def __init__(self, + channels: int, + ratio: int = 16, + conv_cfg: OptConfigType = None, + act_cfg: MultiConfig = (dict(type='ReLU'), + dict(type='Sigmoid')), + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + if isinstance(act_cfg, dict): + act_cfg = (act_cfg, act_cfg) + assert len(act_cfg) == 2 + assert is_tuple_of(act_cfg, dict) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.conv1 = ConvModule( + in_channels=channels, + out_channels=int(channels / ratio), + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[0]) + self.conv2 = ConvModule( + in_channels=int(channels / ratio), + out_channels=channels, + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[1]) + + def forward(self, x: Tensor) -> Tensor: + """Forward function for SELayer.""" + out = self.global_avgpool(x) + out = self.conv1(out) + out = self.conv2(out) + return x * out + + +class DyReLU(BaseModule): + """Dynamic ReLU (DyReLU) module. + + See `Dynamic ReLU `_ for details. + Current implementation is specialized for task-aware attention in DyHead. + HSigmoid arguments in default act_cfg follow DyHead official code. + https://github.com/microsoft/DynamicHead/blob/master/dyhead/dyrelu.py + + Args: + channels (int): The input (and output) channels of DyReLU module. + ratio (int): Squeeze ratio in Squeeze-and-Excitation-like module, + the intermediate channel will be ``int(channels/ratio)``. + Defaults to 4. + conv_cfg (None or dict): Config dict for convolution layer. + Defaults to None, which means using conv2d. + act_cfg (dict or Sequence[dict]): Config dict for activation layer. + If act_cfg is a dict, two activation layers will be configurated + by this dict. If act_cfg is a sequence of dicts, the first + activation layer will be configurated by the first dict and the + second activation layer will be configurated by the second dict. + Defaults to (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0, + divisor=6.0)) + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None + """ + + def __init__(self, + channels: int, + ratio: int = 4, + conv_cfg: OptConfigType = None, + act_cfg: MultiConfig = (dict(type='ReLU'), + dict( + type='HSigmoid', + bias=3.0, + divisor=6.0)), + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + if isinstance(act_cfg, dict): + act_cfg = (act_cfg, act_cfg) + assert len(act_cfg) == 2 + assert is_tuple_of(act_cfg, dict) + self.channels = channels + self.expansion = 4 # for a1, b1, a2, b2 + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.conv1 = ConvModule( + in_channels=channels, + out_channels=int(channels / ratio), + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[0]) + self.conv2 = ConvModule( + in_channels=int(channels / ratio), + out_channels=channels * self.expansion, + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[1]) + + def forward(self, x: Tensor) -> Tensor: + """Forward function.""" + coeffs = self.global_avgpool(x) + coeffs = self.conv1(coeffs) + coeffs = self.conv2(coeffs) - 0.5 # value range: [-0.5, 0.5] + a1, b1, a2, b2 = torch.split(coeffs, self.channels, dim=1) + a1 = a1 * 2.0 + 1.0 # [-1.0, 1.0] + 1.0 + a2 = a2 * 2.0 # [-1.0, 1.0] + out = torch.max(x * a1 + b1, x * a2 + b2) + return out + + +class ChannelAttention(BaseModule): + """Channel attention Module. + + Args: + channels (int): The input (and output) channels of the attention layer. 
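
The DyReLU forward above squeezes pooled statistics into four per-channel coefficients and takes the max of two affine branches. A tensor-level sketch of that arithmetic with random stand-in coefficients (the module learns them from the pooled input):

import torch

channels = 8
x = torch.randn(2, channels, 4, 4)
# conv2's output minus 0.5, as in the forward above: range [-0.5, 0.5]
coeffs = torch.rand(2, 4 * channels, 1, 1) - 0.5
a1, b1, a2, b2 = torch.split(coeffs, channels, dim=1)
a1 = a1 * 2.0 + 1.0  # slope of branch 1, centred on identity (1.0)
a2 = a2 * 2.0        # slope of branch 2, centred on 0.0
out = torch.max(x * a1 + b1, x * a2 + b2)
assert out.shape == x.shape
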
+ init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None + """ + + def __init__(self, channels: int, init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True) + if digit_version(torch.__version__) < (1, 7, 0): + self.act = nn.Hardsigmoid() + else: + self.act = nn.Hardsigmoid(inplace=True) + + def forward(self, x: Tensor) -> Tensor: + """Forward function for ChannelAttention.""" + with torch.cuda.amp.autocast(enabled=False): + out = self.global_avgpool(x) + out = self.fc(out) + out = self.act(out) + return x * out diff --git a/mmdet/models/layers/transformer/__init__.py b/mmdet/models/layers/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..839d936412673d765cd9f89a44a366a64976bb9c --- /dev/null +++ b/mmdet/models/layers/transformer/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .conditional_detr_layers import (ConditionalDetrTransformerDecoder, + ConditionalDetrTransformerDecoderLayer) +from .dab_detr_layers import (DABDetrTransformerDecoder, + DABDetrTransformerDecoderLayer, + DABDetrTransformerEncoder) +from .ddq_detr_layers import DDQTransformerDecoder +from .deformable_detr_layers import (DeformableDetrTransformerDecoder, + DeformableDetrTransformerDecoderLayer, + DeformableDetrTransformerEncoder, + DeformableDetrTransformerEncoderLayer) +from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer, + DetrTransformerEncoder, DetrTransformerEncoderLayer) +from .dino_layers import CdnQueryGenerator, DinoTransformerDecoder +from .grounding_dino_layers import (GroundingDinoTransformerDecoder, + GroundingDinoTransformerDecoderLayer, + GroundingDinoTransformerEncoder) +from .mask2former_layers import (Mask2FormerTransformerDecoder, + Mask2FormerTransformerDecoderLayer, + Mask2FormerTransformerEncoder) +from .utils import (MLP, AdaptivePadding, ConditionalAttention, DynamicConv, + PatchEmbed, PatchMerging, coordinate_to_encoding, + inverse_sigmoid, nchw_to_nlc, nlc_to_nchw) + +__all__ = [ + 'nlc_to_nchw', 'nchw_to_nlc', 'AdaptivePadding', 'PatchEmbed', + 'PatchMerging', 'inverse_sigmoid', 'DynamicConv', 'MLP', + 'DetrTransformerEncoder', 'DetrTransformerDecoder', + 'DetrTransformerEncoderLayer', 'DetrTransformerDecoderLayer', + 'DeformableDetrTransformerEncoder', 'DeformableDetrTransformerDecoder', + 'DeformableDetrTransformerEncoderLayer', + 'DeformableDetrTransformerDecoderLayer', 'coordinate_to_encoding', + 'ConditionalAttention', 'DABDetrTransformerDecoderLayer', + 'DABDetrTransformerDecoder', 'DABDetrTransformerEncoder', + 'DDQTransformerDecoder', 'ConditionalDetrTransformerDecoder', + 'ConditionalDetrTransformerDecoderLayer', 'DinoTransformerDecoder', + 'CdnQueryGenerator', 'Mask2FormerTransformerEncoder', + 'Mask2FormerTransformerDecoderLayer', 'Mask2FormerTransformerDecoder', + 'GroundingDinoTransformerDecoderLayer', 'GroundingDinoTransformerEncoder', + 'GroundingDinoTransformerDecoder' +] diff --git a/mmdet/models/layers/transformer/conditional_detr_layers.py b/mmdet/models/layers/transformer/conditional_detr_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..6db12a1340c758996e8c0e96f0b21cbc6fa928c9 --- /dev/null +++ b/mmdet/models/layers/transformer/conditional_detr_layers.py @@ -0,0 +1,170 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
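
ChannelAttention above reduces to a single 1x1 conv plus a hard-sigmoid gate over pooled channel statistics. A usage sketch; the import path assumes mmdet's usual layer re-exports:

import torch
from mmdet.models.layers import ChannelAttention

ca = ChannelAttention(channels=64)
x = torch.randn(2, 64, 16, 16)
out = ca(x)  # per-channel gates in [0, 1] rescale the input
assert out.shape == x.shape
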
+import torch
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.bricks.transformer import FFN
+from torch import Tensor
+from torch.nn import ModuleList
+
+from .detr_layers import DetrTransformerDecoder, DetrTransformerDecoderLayer
+from .utils import MLP, ConditionalAttention, coordinate_to_encoding
+
+
+class ConditionalDetrTransformerDecoder(DetrTransformerDecoder):
+    """Decoder of Conditional DETR."""
+
+    def _init_layers(self) -> None:
+        """Initialize decoder layers and other layers."""
+        self.layers = ModuleList([
+            ConditionalDetrTransformerDecoderLayer(**self.layer_cfg)
+            for _ in range(self.num_layers)
+        ])
+        self.embed_dims = self.layers[0].embed_dims
+        self.post_norm = build_norm_layer(self.post_norm_cfg,
+                                          self.embed_dims)[1]
+        # conditional detr affine transformation
+        self.query_scale = MLP(self.embed_dims, self.embed_dims,
+                               self.embed_dims, 2)
+        self.ref_point_head = MLP(self.embed_dims, self.embed_dims, 2, 2)
+        # 'qpos_proj' is substituted with 'qpos_sine_proj' in every layer
+        # except the first decoder layer, so 'qpos_proj' should be deleted
+        # in the other layers.
+        for layer_id in range(self.num_layers - 1):
+            self.layers[layer_id + 1].cross_attn.qpos_proj = None
+
+    def forward(self,
+                query: Tensor,
+                key: Tensor = None,
+                query_pos: Tensor = None,
+                key_pos: Tensor = None,
+                key_padding_mask: Tensor = None):
+        """Forward function of decoder.
+
+        Args:
+            query (Tensor): The input query with shape
+                (bs, num_queries, dim).
+            key (Tensor): The input key with shape (bs, num_keys, dim). If
+                `None`, the `query` will be used. Defaults to `None`.
+            query_pos (Tensor): The positional encoding for `query`, with the
+                same shape as `query`. If not `None`, it will be added to
+                `query` before the forward function. Defaults to `None`.
+            key_pos (Tensor): The positional encoding for `key`, with the
+                same shape as `key`. If not `None`, it will be added to
+                `key` before the forward function. If `None`, and `query_pos`
+                has the same shape as `key`, then `query_pos` will be used
+                as `key_pos`. Defaults to `None`.
+            key_padding_mask (Tensor): ByteTensor with shape (bs, num_keys).
+                Defaults to `None`.
+        Returns:
+            List[Tensor]: Forwarded results with shape (num_decoder_layers,
+            bs, num_queries, dim) if `return_intermediate` is True, otherwise
+            with shape (1, bs, num_queries, dim). References with shape
+            (bs, num_queries, 2).
+ """ + reference_unsigmoid = self.ref_point_head( + query_pos) # [bs, num_queries, 2] + reference = reference_unsigmoid.sigmoid() + reference_xy = reference[..., :2] + intermediate = [] + for layer_id, layer in enumerate(self.layers): + if layer_id == 0: + pos_transformation = 1 + else: + pos_transformation = self.query_scale(query) + # get sine embedding for the query reference + ref_sine_embed = coordinate_to_encoding(coord_tensor=reference_xy) + # apply transformation + ref_sine_embed = ref_sine_embed * pos_transformation + query = layer( + query, + key=key, + query_pos=query_pos, + key_pos=key_pos, + key_padding_mask=key_padding_mask, + ref_sine_embed=ref_sine_embed, + is_first=(layer_id == 0)) + if self.return_intermediate: + intermediate.append(self.post_norm(query)) + + if self.return_intermediate: + return torch.stack(intermediate), reference + + query = self.post_norm(query) + return query.unsqueeze(0), reference + + +class ConditionalDetrTransformerDecoderLayer(DetrTransformerDecoderLayer): + """Implements decoder layer in Conditional DETR transformer.""" + + def _init_layers(self): + """Initialize self-attention, cross-attention, FFN, and + normalization.""" + self.self_attn = ConditionalAttention(**self.self_attn_cfg) + self.cross_attn = ConditionalAttention(**self.cross_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(3) + ] + self.norms = ModuleList(norms_list) + + def forward(self, + query: Tensor, + key: Tensor = None, + query_pos: Tensor = None, + key_pos: Tensor = None, + self_attn_masks: Tensor = None, + cross_attn_masks: Tensor = None, + key_padding_mask: Tensor = None, + ref_sine_embed: Tensor = None, + is_first: bool = False): + """ + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim) + key (Tensor, optional): The input key, has shape (bs, num_keys, + dim). If `None`, the `query` will be used. Defaults to `None`. + query_pos (Tensor, optional): The positional encoding for `query`, + has the same shape as `query`. If not `None`, it will be + added to `query` before forward function. Defaults to `None`. + ref_sine_embed (Tensor): The positional encoding for query in + cross attention, with the same shape as `x`. Defaults to None. + key_pos (Tensor, optional): The positional encoding for `key`, has + the same shape as `key`. If not None, it will be added to + `key` before forward function. If None, and `query_pos` has + the same shape as `key`, then `query_pos` will be used for + `key_pos`. Defaults to None. + self_attn_masks (Tensor, optional): ByteTensor mask, has shape + (num_queries, num_keys), Same in `nn.MultiheadAttention. + forward`. Defaults to None. + cross_attn_masks (Tensor, optional): ByteTensor mask, has shape + (num_queries, num_keys), Same in `nn.MultiheadAttention. + forward`. Defaults to None. + key_padding_mask (Tensor, optional): ByteTensor, has shape + (bs, num_keys). Defaults to None. + is_first (bool): A indicator to tell whether the current layer + is the first layer of the decoder. Defaults to False. + + Returns: + Tensor: Forwarded results, has shape (bs, num_queries, dim). 
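
The conditional decoder above re-derives its cross-attention positional query every layer: sigmoided reference points pass through a fixed sine encoding, which all layers after the first modulate with an MLP of the content query. A standalone sketch of that flow; dims are assumptions, and MLP and coordinate_to_encoding are the helpers this file imports:

import torch
from mmdet.models.layers.transformer import MLP, coordinate_to_encoding

embed_dims = 256
query = torch.randn(2, 100, embed_dims)  # content queries
reference_xy = torch.rand(2, 100, 2)     # sigmoided (cx, cy) references

query_scale = MLP(embed_dims, embed_dims, embed_dims, 2)
ref_sine_embed = coordinate_to_encoding(coord_tensor=reference_xy)
# layer 0 uses pos_transformation = 1; later layers scale elementwise
ref_sine_embed = ref_sine_embed * query_scale(query)
assert ref_sine_embed.shape == (2, 100, embed_dims)
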
+ """ + query = self.self_attn( + query=query, + key=query, + query_pos=query_pos, + key_pos=query_pos, + attn_mask=self_attn_masks) + query = self.norms[0](query) + query = self.cross_attn( + query=query, + key=key, + query_pos=query_pos, + key_pos=key_pos, + attn_mask=cross_attn_masks, + key_padding_mask=key_padding_mask, + ref_sine_embed=ref_sine_embed, + is_first=is_first) + query = self.norms[1](query) + query = self.ffn(query) + query = self.norms[2](query) + + return query diff --git a/mmdet/models/layers/transformer/dab_detr_layers.py b/mmdet/models/layers/transformer/dab_detr_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..b8a6e7724a1b1ca18f26dd10455f3e3a4d696460 --- /dev/null +++ b/mmdet/models/layers/transformer/dab_detr_layers.py @@ -0,0 +1,298 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN +from mmengine.model import ModuleList +from torch import Tensor + +from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer, + DetrTransformerEncoder, DetrTransformerEncoderLayer) +from .utils import (MLP, ConditionalAttention, coordinate_to_encoding, + inverse_sigmoid) + + +class DABDetrTransformerDecoderLayer(DetrTransformerDecoderLayer): + """Implements decoder layer in DAB-DETR transformer.""" + + def _init_layers(self): + """Initialize self-attention, cross-attention, FFN, normalization and + others.""" + self.self_attn = ConditionalAttention(**self.self_attn_cfg) + self.cross_attn = ConditionalAttention(**self.cross_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(3) + ] + self.norms = ModuleList(norms_list) + self.keep_query_pos = self.cross_attn.keep_query_pos + + def forward(self, + query: Tensor, + key: Tensor, + query_pos: Tensor, + key_pos: Tensor, + ref_sine_embed: Tensor = None, + self_attn_masks: Tensor = None, + cross_attn_masks: Tensor = None, + key_padding_mask: Tensor = None, + is_first: bool = False, + **kwargs) -> Tensor: + """ + Args: + query (Tensor): The input query with shape [bs, num_queries, + dim]. + key (Tensor): The key tensor with shape [bs, num_keys, + dim]. + query_pos (Tensor): The positional encoding for query in self + attention, with the same shape as `x`. + key_pos (Tensor): The positional encoding for `key`, with the + same shape as `key`. + ref_sine_embed (Tensor): The positional encoding for query in + cross attention, with the same shape as `x`. + Defaults to None. + self_attn_masks (Tensor): ByteTensor mask with shape [num_queries, + num_keys]. Same in `nn.MultiheadAttention.forward`. + Defaults to None. + cross_attn_masks (Tensor): ByteTensor mask with shape [num_queries, + num_keys]. Same in `nn.MultiheadAttention.forward`. + Defaults to None. + key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys]. + Defaults to None. + is_first (bool): A indicator to tell whether the current layer + is the first layer of the decoder. + Defaults to False. + + Returns: + Tensor: forwarded results with shape + [bs, num_queries, dim]. 
+ """ + + query = self.self_attn( + query=query, + key=query, + query_pos=query_pos, + key_pos=query_pos, + attn_mask=self_attn_masks, + **kwargs) + query = self.norms[0](query) + query = self.cross_attn( + query=query, + key=key, + query_pos=query_pos, + key_pos=key_pos, + ref_sine_embed=ref_sine_embed, + attn_mask=cross_attn_masks, + key_padding_mask=key_padding_mask, + is_first=is_first, + **kwargs) + query = self.norms[1](query) + query = self.ffn(query) + query = self.norms[2](query) + + return query + + +class DABDetrTransformerDecoder(DetrTransformerDecoder): + """Decoder of DAB-DETR. + + Args: + query_dim (int): The last dimension of query pos, + 4 for anchor format, 2 for point format. + Defaults to 4. + query_scale_type (str): Type of transformation applied + to content query. Defaults to `cond_elewise`. + with_modulated_hw_attn (bool): Whether to inject h&w info + during cross conditional attention. Defaults to True. + """ + + def __init__(self, + *args, + query_dim: int = 4, + query_scale_type: str = 'cond_elewise', + with_modulated_hw_attn: bool = True, + **kwargs): + + self.query_dim = query_dim + self.query_scale_type = query_scale_type + self.with_modulated_hw_attn = with_modulated_hw_attn + + super().__init__(*args, **kwargs) + + def _init_layers(self): + """Initialize decoder layers and other layers.""" + assert self.query_dim in [2, 4], \ + f'{"dab-detr only supports anchor prior or reference point prior"}' + assert self.query_scale_type in [ + 'cond_elewise', 'cond_scalar', 'fix_elewise' + ] + + self.layers = ModuleList([ + DABDetrTransformerDecoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + + embed_dims = self.layers[0].embed_dims + self.embed_dims = embed_dims + + self.post_norm = build_norm_layer(self.post_norm_cfg, embed_dims)[1] + if self.query_scale_type == 'cond_elewise': + self.query_scale = MLP(embed_dims, embed_dims, embed_dims, 2) + elif self.query_scale_type == 'cond_scalar': + self.query_scale = MLP(embed_dims, embed_dims, 1, 2) + elif self.query_scale_type == 'fix_elewise': + self.query_scale = nn.Embedding(self.num_layers, embed_dims) + else: + raise NotImplementedError('Unknown query_scale_type: {}'.format( + self.query_scale_type)) + + self.ref_point_head = MLP(self.query_dim // 2 * embed_dims, embed_dims, + embed_dims, 2) + + if self.with_modulated_hw_attn and self.query_dim == 4: + self.ref_anchor_head = MLP(embed_dims, embed_dims, 2, 2) + + self.keep_query_pos = self.layers[0].keep_query_pos + if not self.keep_query_pos: + for layer_id in range(self.num_layers - 1): + self.layers[layer_id + 1].cross_attn.qpos_proj = None + + def forward(self, + query: Tensor, + key: Tensor, + query_pos: Tensor, + key_pos: Tensor, + reg_branches: nn.Module, + key_padding_mask: Tensor = None, + **kwargs) -> List[Tensor]: + """Forward function of decoder. + + Args: + query (Tensor): The input query with shape (bs, num_queries, dim). + key (Tensor): The input key with shape (bs, num_keys, dim). + query_pos (Tensor): The positional encoding for `query`, with the + same shape as `query`. + key_pos (Tensor): The positional encoding for `key`, with the + same shape as `key`. + reg_branches (nn.Module): The regression branch for dynamically + updating references in each layer. + key_padding_mask (Tensor): ByteTensor with shape (bs, num_keys). + Defaults to `None`. + + Returns: + List[Tensor]: forwarded results with shape (num_decoder_layers, + bs, num_queries, dim) if `return_intermediate` is True, otherwise + with shape (1, bs, num_queries, dim). 
references with shape + (num_decoder_layers, bs, num_queries, 2/4). + """ + output = query + unsigmoid_references = query_pos + + reference_points = unsigmoid_references.sigmoid() + intermediate_reference_points = [reference_points] + + intermediate = [] + for layer_id, layer in enumerate(self.layers): + obj_center = reference_points[..., :self.query_dim] + ref_sine_embed = coordinate_to_encoding( + coord_tensor=obj_center, num_feats=self.embed_dims // 2) + query_pos = self.ref_point_head( + ref_sine_embed) # [bs, nq, 2c] -> [bs, nq, c] + # For the first decoder layer, do not apply transformation + if self.query_scale_type != 'fix_elewise': + if layer_id == 0: + pos_transformation = 1 + else: + pos_transformation = self.query_scale(output) + else: + pos_transformation = self.query_scale.weight[layer_id] + # apply transformation + ref_sine_embed = ref_sine_embed[ + ..., :self.embed_dims] * pos_transformation + # modulated height and weight attention + if self.with_modulated_hw_attn: + assert obj_center.size(-1) == 4 + ref_hw = self.ref_anchor_head(output).sigmoid() + ref_sine_embed[..., self.embed_dims // 2:] *= \ + (ref_hw[..., 0] / obj_center[..., 2]).unsqueeze(-1) + ref_sine_embed[..., : self.embed_dims // 2] *= \ + (ref_hw[..., 1] / obj_center[..., 3]).unsqueeze(-1) + + output = layer( + output, + key, + query_pos=query_pos, + ref_sine_embed=ref_sine_embed, + key_pos=key_pos, + key_padding_mask=key_padding_mask, + is_first=(layer_id == 0), + **kwargs) + # iter update + tmp_reg_preds = reg_branches(output) + tmp_reg_preds[..., :self.query_dim] += inverse_sigmoid( + reference_points) + new_reference_points = tmp_reg_preds[ + ..., :self.query_dim].sigmoid() + if layer_id != self.num_layers - 1: + intermediate_reference_points.append(new_reference_points) + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(self.post_norm(output)) + + output = self.post_norm(output) + + if self.return_intermediate: + return [ + torch.stack(intermediate), + torch.stack(intermediate_reference_points), + ] + else: + return [ + output.unsqueeze(0), + torch.stack(intermediate_reference_points) + ] + + +class DABDetrTransformerEncoder(DetrTransformerEncoder): + """Encoder of DAB-DETR.""" + + def _init_layers(self): + """Initialize encoder layers.""" + self.layers = ModuleList([ + DetrTransformerEncoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + embed_dims = self.layers[0].embed_dims + self.embed_dims = embed_dims + self.query_scale = MLP(embed_dims, embed_dims, embed_dims, 2) + + def forward(self, query: Tensor, query_pos: Tensor, + key_padding_mask: Tensor, **kwargs): + """Forward function of encoder. + + Args: + query (Tensor): Input queries of encoder, has shape + (bs, num_queries, dim). + query_pos (Tensor): The positional embeddings of the queries, has + shape (bs, num_feat_points, dim). + key_padding_mask (Tensor): ByteTensor, the key padding mask + of the queries, has shape (bs, num_feat_points). + + Returns: + Tensor: With shape (num_queries, bs, dim). 
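
DAB-DETR's modulated height and width attention above rescales the two halves of the positional sine embedding by the ratio between the predicted reference size and the anchor's own (w, h). A tensor sketch mirroring those lines; the shapes are assumptions:

import torch
from mmdet.models.layers.transformer import coordinate_to_encoding

embed_dims = 256
obj_center = torch.rand(2, 100, 4)  # sigmoided (cx, cy, w, h) anchors
ref_sine_embed = coordinate_to_encoding(
    coord_tensor=obj_center, num_feats=embed_dims // 2)  # (2, 100, 2 * embed_dims)
ref_sine_embed = ref_sine_embed[..., :embed_dims]  # keep the positional half
ref_hw = torch.rand(2, 100, 2)  # stand-in for ref_anchor_head(output).sigmoid()
ref_sine_embed[..., embed_dims // 2:] *= (ref_hw[..., 0] /
                                          obj_center[..., 2]).unsqueeze(-1)
ref_sine_embed[..., :embed_dims // 2] *= (ref_hw[..., 1] /
                                          obj_center[..., 3]).unsqueeze(-1)
assert ref_sine_embed.shape == (2, 100, embed_dims)
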
+        """
+
+        for layer in self.layers:
+            pos_scales = self.query_scale(query)
+            query = layer(
+                query,
+                query_pos=query_pos * pos_scales,
+                key_padding_mask=key_padding_mask,
+                **kwargs)
+
+        return query
diff --git a/mmdet/models/layers/transformer/ddq_detr_layers.py b/mmdet/models/layers/transformer/ddq_detr_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..57664c7ea2bdd17681ccdabe9140eb043a99e155
--- /dev/null
+++ b/mmdet/models/layers/transformer/ddq_detr_layers.py
@@ -0,0 +1,223 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch
+from mmcv.ops import batched_nms
+from torch import Tensor, nn
+
+from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
+from .deformable_detr_layers import DeformableDetrTransformerDecoder
+from .utils import MLP, coordinate_to_encoding, inverse_sigmoid
+
+
+class DDQTransformerDecoder(DeformableDetrTransformerDecoder):
+    """Transformer decoder of DDQ."""
+
+    def _init_layers(self) -> None:
+        """Initialize decoder layers."""
+        super()._init_layers()
+        self.ref_point_head = MLP(self.embed_dims * 2, self.embed_dims,
+                                  self.embed_dims, 2)
+        self.norm = nn.LayerNorm(self.embed_dims)
+
+    def select_distinct_queries(self, reference_points: Tensor, query: Tensor,
+                                self_attn_mask: Tensor, layer_index):
+        """Get an updated `self_attn_mask` for distinct queries selection;
+        it is used in the self-attention layers of the decoder.
+
+        Args:
+            reference_points (Tensor): The input reference of the decoder,
+                has shape (bs, num_queries, 4) with the last dimension
+                arranged as (cx, cy, w, h).
+            query (Tensor): The input query of the decoder, has shape
+                (bs, num_queries, dims).
+            self_attn_mask (Tensor): The input self attention mask of
+                the last decoder layer, has shape (bs, num_queries_total,
+                num_queries_total).
+            layer_index (int): Index of the last decoder layer, used to get
+                the classification scores of its output for distinct
+                queries selection.
+
+        Returns:
+            Tensor: `self_attn_mask` used in the self-attention layers
+                of the decoder, has shape (bs, num_queries_total,
+                num_queries_total).
+        """
+        num_imgs = len(reference_points)
+        dis_start, num_dis = self.cache_dict['dis_query_info']
+        # shape of self_attn_mask
+        # (batch * num_heads, num_queries, num_queries)
+        dis_mask = self_attn_mask[:, dis_start:dis_start + num_dis,
+                                  dis_start:dis_start + num_dis]
+        # cls_branches from DDQDETRHead
+        scores = self.cache_dict['cls_branches'][layer_index](
+            query[:, dis_start:dis_start + num_dis]).sigmoid().max(-1).values
+        proposals = reference_points[:, dis_start:dis_start + num_dis]
+        proposals = bbox_cxcywh_to_xyxy(proposals)
+
+        attn_mask_list = []
+        for img_id in range(num_imgs):
+            single_proposals = proposals[img_id]
+            single_scores = scores[img_id]
+            attn_mask = ~dis_mask[img_id * self.cache_dict['num_heads']][0]
+            # distinct query inds in this layer
+            ori_index = attn_mask.nonzero().view(-1)
+            _, keep_idxs = batched_nms(single_proposals[ori_index],
+                                       single_scores[ori_index],
+                                       torch.ones(len(ori_index)),
+                                       self.cache_dict['dqs_cfg'])
+
+            real_keep_index = ori_index[keep_idxs]
+
+            attn_mask = torch.ones_like(dis_mask[0]).bool()
+            # such an attn_mask gives the best result
+            # If index i is to be kept, then all cells in row or column
+            #   i should be kept in `attn_mask`. For example, if
+            #   `real_keep_index` = [1, 4], and `attn_mask` size = [8, 8],
+            #   then all cells at rows or columns [1, 4] should be kept, and
+            #   all the other cells should be masked out.
So the value of + # `attn_mask` should be: + # + # target\source 0 1 2 3 4 5 6 7 + # 0 [ 0 1 0 0 1 0 0 0 ] + # 1 [ 1 1 1 1 1 1 1 1 ] + # 2 [ 0 1 0 0 1 0 0 0 ] + # 3 [ 0 1 0 0 1 0 0 0 ] + # 4 [ 1 1 1 1 1 1 1 1 ] + # 5 [ 0 1 0 0 1 0 0 0 ] + # 6 [ 0 1 0 0 1 0 0 0 ] + # 7 [ 0 1 0 0 1 0 0 0 ] + attn_mask[real_keep_index] = False + attn_mask[:, real_keep_index] = False + + attn_mask = attn_mask[None].repeat(self.cache_dict['num_heads'], 1, + 1) + attn_mask_list.append(attn_mask) + attn_mask = torch.cat(attn_mask_list) + self_attn_mask = copy.deepcopy(self_attn_mask) + self_attn_mask[:, dis_start:dis_start + num_dis, + dis_start:dis_start + num_dis] = attn_mask + # will be used in loss and inference + self.cache_dict['distinct_query_mask'].append(~attn_mask) + return self_attn_mask + + def forward(self, query: Tensor, value: Tensor, key_padding_mask: Tensor, + self_attn_mask: Tensor, reference_points: Tensor, + spatial_shapes: Tensor, level_start_index: Tensor, + valid_ratios: Tensor, reg_branches: nn.ModuleList, + **kwargs) -> Tensor: + """Forward function of Transformer decoder. + + Args: + query (Tensor): The input query, has shape (bs, num_queries, + dims). + value (Tensor): The input values, has shape (bs, num_value, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` + input. ByteTensor, has shape (bs, num_value). + self_attn_mask (Tensor): The attention mask to prevent information + leakage from different denoising groups, distinct queries and + dense queries, has shape (num_queries_total, + num_queries_total). It will be updated for distinct queries + selection in this forward function. It is `None` when + `self.training` is `False`. + reference_points (Tensor): The initial reference, has shape + (bs, num_queries, 4) with the last dimension arranged as + (cx, cy, w, h). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + reg_branches: (obj:`nn.ModuleList`): Used for refining the + regression results. + + Returns: + tuple[Tensor]: Output queries and references of Transformer + decoder + + - query (Tensor): Output embeddings of the last decoder, has + shape (bs, num_queries, embed_dims) when `return_intermediate` + is `False`. Otherwise, Intermediate output embeddings of all + decoder layers, has shape (num_decoder_layers, bs, num_queries, + embed_dims). + - reference_points (Tensor): The reference of the last decoder + layer, has shape (bs, num_queries, 4) when `return_intermediate` + is `False`. Otherwise, Intermediate references of all decoder + layers, has shape (1 + num_decoder_layers, bs, num_queries, 4). + The coordinates are arranged as (cx, cy, w, h). 
+ """ + intermediate = [] + intermediate_reference_points = [reference_points] + self.cache_dict['distinct_query_mask'] = [] + if self_attn_mask is None: + self_attn_mask = torch.zeros((query.size(1), query.size(1)), + device=query.device).bool() + # shape is (batch*number_heads, num_queries, num_queries) + self_attn_mask = self_attn_mask[None].repeat( + len(query) * self.cache_dict['num_heads'], 1, 1) + for layer_index, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = \ + reference_points[:, :, None] * torch.cat( + [valid_ratios, valid_ratios], -1)[:, None] + else: + assert reference_points.shape[-1] == 2 + reference_points_input = \ + reference_points[:, :, None] * valid_ratios[:, None] + + query_sine_embed = coordinate_to_encoding( + reference_points_input[:, :, 0, :], + num_feats=self.embed_dims // 2) + query_pos = self.ref_point_head(query_sine_embed) + + query = layer( + query, + query_pos=query_pos, + value=value, + key_padding_mask=key_padding_mask, + self_attn_mask=self_attn_mask, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reference_points=reference_points_input, + **kwargs) + + if not self.training: + tmp = reg_branches[layer_index](query) + assert reference_points.shape[-1] == 4 + new_reference_points = tmp + inverse_sigmoid( + reference_points, eps=1e-3) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + if layer_index < (len(self.layers) - 1): + self_attn_mask = self.select_distinct_queries( + reference_points, query, self_attn_mask, layer_index) + + else: + num_dense = self.cache_dict['num_dense_queries'] + tmp = reg_branches[layer_index](query[:, :-num_dense]) + tmp_dense = self.aux_reg_branches[layer_index]( + query[:, -num_dense:]) + + tmp = torch.cat([tmp, tmp_dense], dim=1) + assert reference_points.shape[-1] == 4 + new_reference_points = tmp + inverse_sigmoid( + reference_points, eps=1e-3) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + if layer_index < (len(self.layers) - 1): + self_attn_mask = self.select_distinct_queries( + reference_points, query, self_attn_mask, layer_index) + + if self.return_intermediate: + intermediate.append(self.norm(query)) + intermediate_reference_points.append(new_reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack( + intermediate_reference_points) + + return query, reference_points diff --git a/mmdet/models/layers/transformer/deformable_detr_layers.py b/mmdet/models/layers/transformer/deformable_detr_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..da6325d61270eb3546a39d5487587bc0610434d6 --- /dev/null +++ b/mmdet/models/layers/transformer/deformable_detr_layers.py @@ -0,0 +1,265 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
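
In select_distinct_queries above, the NMS survivors define which rows and columns of the distinct-query block stay attendable; in the illustrated matrix, 1 marks cells that remain unmasked, i.e. False in the boolean mask. A standalone sketch of that mask construction:

import torch

num_queries = 8
real_keep_index = torch.tensor([1, 4])  # distinct queries kept by NMS

attn_mask = torch.ones(num_queries, num_queries, dtype=torch.bool)  # True = masked
attn_mask[real_keep_index] = False      # kept queries may attend everywhere
attn_mask[:, real_keep_index] = False   # and every query may attend to them

assert not attn_mask[1].any() and not attn_mask[:, 4].any()
assert attn_mask[0, 2] and attn_mask[5, 7]  # all other cells stay masked
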
+from typing import Optional, Tuple, Union + +import torch +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmcv.ops import MultiScaleDeformableAttention +from mmengine.model import ModuleList +from torch import Tensor, nn + +from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer, + DetrTransformerEncoder, DetrTransformerEncoderLayer) +from .utils import inverse_sigmoid + +try: + from fairscale.nn.checkpoint import checkpoint_wrapper +except Exception: + checkpoint_wrapper = None + + +class DeformableDetrTransformerEncoder(DetrTransformerEncoder): + """Transformer encoder of Deformable DETR.""" + + def _init_layers(self) -> None: + """Initialize encoder layers.""" + self.layers = ModuleList([ + DeformableDetrTransformerEncoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + + if self.num_cp > 0: + if checkpoint_wrapper is None: + raise NotImplementedError( + 'If you want to reduce GPU memory usage, \ + please install fairscale by executing the \ + following command: pip install fairscale.') + for i in range(self.num_cp): + self.layers[i] = checkpoint_wrapper(self.layers[i]) + + self.embed_dims = self.layers[0].embed_dims + + def forward(self, query: Tensor, query_pos: Tensor, + key_padding_mask: Tensor, spatial_shapes: Tensor, + level_start_index: Tensor, valid_ratios: Tensor, + **kwargs) -> Tensor: + """Forward function of Transformer encoder. + + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + query_pos (Tensor): The positional encoding for query, has shape + (bs, num_queries, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` + input. ByteTensor, has shape (bs, num_queries). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + + Returns: + Tensor: Output queries of Transformer encoder, which is also + called 'encoder output embeddings' or 'memory', has shape + (bs, num_queries, dim) + """ + reference_points = self.get_encoder_reference_points( + spatial_shapes, valid_ratios, device=query.device) + for layer in self.layers: + query = layer( + query=query, + query_pos=query_pos, + key_padding_mask=key_padding_mask, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reference_points=reference_points, + **kwargs) + return query + + @staticmethod + def get_encoder_reference_points( + spatial_shapes: Tensor, valid_ratios: Tensor, + device: Union[torch.device, str]) -> Tensor: + """Get the reference points used in encoder. + + Args: + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + device (obj:`device` or str): The device acquired by the + `reference_points`. + + Returns: + Tensor: Reference points used in decoder, has shape (bs, length, + num_levels, 2). 
+ """ + + reference_points_list = [] + for lvl, (H, W) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace( + 0.5, H - 0.5, H, dtype=torch.float32, device=device), + torch.linspace( + 0.5, W - 0.5, W, dtype=torch.float32, device=device)) + ref_y = ref_y.reshape(-1)[None] / ( + valid_ratios[:, None, lvl, 1] * H) + ref_x = ref_x.reshape(-1)[None] / ( + valid_ratios[:, None, lvl, 0] * W) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + # [bs, sum(hw), num_level, 2] + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + +class DeformableDetrTransformerDecoder(DetrTransformerDecoder): + """Transformer Decoder of Deformable DETR.""" + + def _init_layers(self) -> None: + """Initialize decoder layers.""" + self.layers = ModuleList([ + DeformableDetrTransformerDecoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + if self.post_norm_cfg is not None: + raise ValueError('There is not post_norm in ' + f'{self._get_name()}') + + def forward(self, + query: Tensor, + query_pos: Tensor, + value: Tensor, + key_padding_mask: Tensor, + reference_points: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + valid_ratios: Tensor, + reg_branches: Optional[nn.Module] = None, + **kwargs) -> Tuple[Tensor]: + """Forward function of Transformer decoder. + + Args: + query (Tensor): The input queries, has shape (bs, num_queries, + dim). + query_pos (Tensor): The input positional query, has shape + (bs, num_queries, dim). It will be added to `query` before + forward function. + value (Tensor): The input values, has shape (bs, num_value, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` + input. ByteTensor, has shape (bs, num_value). + reference_points (Tensor): The initial reference, has shape + (bs, num_queries, 4) with the last dimension arranged as + (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has + shape (bs, num_queries, 2) with the last dimension arranged + as (cx, cy). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + reg_branches: (obj:`nn.ModuleList`, optional): Used for refining + the regression results. Only would be passed when + `with_box_refine` is `True`, otherwise would be `None`. + + Returns: + tuple[Tensor]: Outputs of Deformable Transformer Decoder. + + - output (Tensor): Output embeddings of the last decoder, has + shape (num_queries, bs, embed_dims) when `return_intermediate` + is `False`. Otherwise, Intermediate output embeddings of all + decoder layers, has shape (num_decoder_layers, num_queries, bs, + embed_dims). + - reference_points (Tensor): The reference of the last decoder + layer, has shape (bs, num_queries, 4) when `return_intermediate` + is `False`. Otherwise, Intermediate references of all decoder + layers, has shape (num_decoder_layers, bs, num_queries, 4). 
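
get_encoder_reference_points above is just the normalized cell centres of each level, stretched by the per-image valid ratios. A sketch for a single 4x6 level with no padding:

import torch

H, W = 4, 6
valid_ratios = torch.ones(2, 1, 2)  # (bs, num_levels, 2); 1.0 = no padding
ref_y, ref_x = torch.meshgrid(
    torch.linspace(0.5, H - 0.5, H),
    torch.linspace(0.5, W - 0.5, W))
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, 0, 1] * H)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, 0, 0] * W)
ref = torch.stack((ref_x, ref_y), -1)
assert ref.shape == (2, H * W, 2)  # centre coordinates in (0, 1)
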
The + coordinates are arranged as (cx, cy, w, h) + """ + output = query + intermediate = [] + intermediate_reference_points = [] + for layer_id, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = \ + reference_points[:, :, None] * \ + torch.cat([valid_ratios, valid_ratios], -1)[:, None] + else: + assert reference_points.shape[-1] == 2 + reference_points_input = \ + reference_points[:, :, None] * \ + valid_ratios[:, None] + output = layer( + output, + query_pos=query_pos, + value=value, + key_padding_mask=key_padding_mask, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reference_points=reference_points_input, + **kwargs) + + if reg_branches is not None: + tmp_reg_preds = reg_branches[layer_id](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp_reg_preds + inverse_sigmoid( + reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp_reg_preds + new_reference_points[..., :2] = tmp_reg_preds[ + ..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack( + intermediate_reference_points) + + return output, reference_points + + +class DeformableDetrTransformerEncoderLayer(DetrTransformerEncoderLayer): + """Encoder layer of Deformable DETR.""" + + def _init_layers(self) -> None: + """Initialize self_attn, ffn, and norms.""" + self.self_attn = MultiScaleDeformableAttention(**self.self_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(2) + ] + self.norms = ModuleList(norms_list) + + +class DeformableDetrTransformerDecoderLayer(DetrTransformerDecoderLayer): + """Decoder layer of Deformable DETR.""" + + def _init_layers(self) -> None: + """Initialize self_attn, cross-attn, ffn, and norms.""" + self.self_attn = MultiheadAttention(**self.self_attn_cfg) + self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(3) + ] + self.norms = ModuleList(norms_list) diff --git a/mmdet/models/layers/transformer/detr_layers.py b/mmdet/models/layers/transformer/detr_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..6a83dd2faa660ed8f54bdd08271db1fcf6b53886 --- /dev/null +++ b/mmdet/models/layers/transformer/detr_layers.py @@ -0,0 +1,374 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +import torch +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmengine import ConfigDict +from mmengine.model import BaseModule, ModuleList +from torch import Tensor + +from mmdet.utils import ConfigType, OptConfigType + +try: + from fairscale.nn.checkpoint import checkpoint_wrapper +except Exception: + checkpoint_wrapper = None + + +class DetrTransformerEncoder(BaseModule): + """Encoder of DETR. + + Args: + num_layers (int): Number of encoder layers. 
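
With box refinement enabled, each decoder layer above updates its references in logit space: the regression delta is added to inverse_sigmoid of the current reference, squashed back to (0, 1), then detached so gradients stop at layer boundaries. A sketch of one such update:

import torch
from mmdet.models.layers.transformer import inverse_sigmoid

reference_points = torch.rand(2, 100, 4)  # (cx, cy, w, h) in (0, 1)
delta = torch.randn(2, 100, 4) * 0.1      # stand-in for reg_branches[layer](output)
new_reference_points = (delta + inverse_sigmoid(reference_points)).sigmoid()
reference_points = new_reference_points.detach()  # no grad across layers
assert ((reference_points > 0) & (reference_points < 1)).all()
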
+ layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder + layer. All the layers will share the same config. + num_cp (int): Number of checkpointing blocks in encoder layer. + Default to -1. + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. + """ + + def __init__(self, + num_layers: int, + layer_cfg: ConfigType, + num_cp: int = -1, + init_cfg: OptConfigType = None) -> None: + + super().__init__(init_cfg=init_cfg) + self.num_layers = num_layers + self.layer_cfg = layer_cfg + self.num_cp = num_cp + assert self.num_cp <= self.num_layers + self._init_layers() + + def _init_layers(self) -> None: + """Initialize encoder layers.""" + self.layers = ModuleList([ + DetrTransformerEncoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + + if self.num_cp > 0: + if checkpoint_wrapper is None: + raise NotImplementedError( + 'If you want to reduce GPU memory usage, \ + please install fairscale by executing the \ + following command: pip install fairscale.') + for i in range(self.num_cp): + self.layers[i] = checkpoint_wrapper(self.layers[i]) + + self.embed_dims = self.layers[0].embed_dims + + def forward(self, query: Tensor, query_pos: Tensor, + key_padding_mask: Tensor, **kwargs) -> Tensor: + """Forward function of encoder. + + Args: + query (Tensor): Input queries of encoder, has shape + (bs, num_queries, dim). + query_pos (Tensor): The positional embeddings of the queries, has + shape (bs, num_queries, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` + input. ByteTensor, has shape (bs, num_queries). + + Returns: + Tensor: Has shape (bs, num_queries, dim) if `batch_first` is + `True`, otherwise (num_queries, bs, dim). + """ + for layer in self.layers: + query = layer(query, query_pos, key_padding_mask, **kwargs) + return query + + +class DetrTransformerDecoder(BaseModule): + """Decoder of DETR. + + Args: + num_layers (int): Number of decoder layers. + layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder + layer. All the layers will share the same config. + post_norm_cfg (:obj:`ConfigDict` or dict, optional): Config of the + post normalization layer. Defaults to `LN`. + return_intermediate (bool, optional): Whether to return outputs of + intermediate layers. Defaults to `True`, + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. + """ + + def __init__(self, + num_layers: int, + layer_cfg: ConfigType, + post_norm_cfg: OptConfigType = dict(type='LN'), + return_intermediate: bool = True, + init_cfg: Union[dict, ConfigDict] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.layer_cfg = layer_cfg + self.num_layers = num_layers + self.post_norm_cfg = post_norm_cfg + self.return_intermediate = return_intermediate + self._init_layers() + + def _init_layers(self) -> None: + """Initialize decoder layers.""" + self.layers = ModuleList([ + DetrTransformerDecoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + self.post_norm = build_norm_layer(self.post_norm_cfg, + self.embed_dims)[1] + + def forward(self, query: Tensor, key: Tensor, value: Tensor, + query_pos: Tensor, key_pos: Tensor, key_padding_mask: Tensor, + **kwargs) -> Tensor: + """Forward function of decoder + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + key (Tensor): The input key, has shape (bs, num_keys, dim). 
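
DetrTransformerEncoder can be built directly with the nested cfg pattern its layers expect; the dicts below reuse the defaults declared in this file, and the batch/query sizes are illustrative:

import torch
from mmdet.models.layers.transformer import DetrTransformerEncoder

encoder = DetrTransformerEncoder(
    num_layers=6,
    layer_cfg=dict(
        self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
        ffn_cfg=dict(embed_dims=256, feedforward_channels=1024)))

query = torch.randn(2, 50, 256)  # (bs, num_queries, dim)
query_pos = torch.randn(2, 50, 256)
key_padding_mask = torch.zeros(2, 50, dtype=torch.bool)
memory = encoder(query, query_pos, key_padding_mask)
assert memory.shape == (2, 50, 256)
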
+ value (Tensor): The input value with the same shape as `key`. + query_pos (Tensor): The positional encoding for `query`, with the + same shape as `query`. + key_pos (Tensor): The positional encoding for `key`, with the + same shape as `key`. + key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` + input. ByteTensor, has shape (bs, num_value). + + Returns: + Tensor: The forwarded results will have shape + (num_decoder_layers, bs, num_queries, dim) if + `return_intermediate` is `True` else (1, bs, num_queries, dim). + """ + intermediate = [] + for layer in self.layers: + query = layer( + query, + key=key, + value=value, + query_pos=query_pos, + key_pos=key_pos, + key_padding_mask=key_padding_mask, + **kwargs) + if self.return_intermediate: + intermediate.append(self.post_norm(query)) + query = self.post_norm(query) + + if self.return_intermediate: + return torch.stack(intermediate) + + return query.unsqueeze(0) + + +class DetrTransformerEncoderLayer(BaseModule): + """Implements encoder layer in DETR transformer. + + Args: + self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self + attention. + ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN. + norm_cfg (:obj:`ConfigDict` or dict, optional): Config for + normalization layers. All the layers will share the same + config. Defaults to `LN`. + init_cfg (:obj:`ConfigDict` or dict, optional): Config to control + the initialization. Defaults to None. + """ + + def __init__(self, + self_attn_cfg: OptConfigType = dict( + embed_dims=256, num_heads=8, dropout=0.0), + ffn_cfg: OptConfigType = dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0., + act_cfg=dict(type='ReLU', inplace=True)), + norm_cfg: OptConfigType = dict(type='LN'), + init_cfg: OptConfigType = None) -> None: + + super().__init__(init_cfg=init_cfg) + + self.self_attn_cfg = self_attn_cfg + if 'batch_first' not in self.self_attn_cfg: + self.self_attn_cfg['batch_first'] = True + else: + assert self.self_attn_cfg['batch_first'] is True, 'First \ + dimension of all DETRs in mmdet is `batch`, \ + please set `batch_first` flag.' + + self.ffn_cfg = ffn_cfg + self.norm_cfg = norm_cfg + self._init_layers() + + def _init_layers(self) -> None: + """Initialize self-attention, FFN, and normalization.""" + self.self_attn = MultiheadAttention(**self.self_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(2) + ] + self.norms = ModuleList(norms_list) + + def forward(self, query: Tensor, query_pos: Tensor, + key_padding_mask: Tensor, **kwargs) -> Tensor: + """Forward function of an encoder layer. + + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + query_pos (Tensor): The positional encoding for query, with + the same shape as `query`. + key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` + input. ByteTensor. has shape (bs, num_queries). + Returns: + Tensor: forwarded results, has shape (bs, num_queries, dim). + """ + query = self.self_attn( + query=query, + key=query, + value=query, + query_pos=query_pos, + key_pos=query_pos, + key_padding_mask=key_padding_mask, + **kwargs) + query = self.norms[0](query) + query = self.ffn(query) + query = self.norms[1](query) + + return query + + +class DetrTransformerDecoderLayer(BaseModule): + """Implements decoder layer in DETR transformer. + + Args: + self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self + attention. 
+ cross_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for cross + attention. + ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN. + norm_cfg (:obj:`ConfigDict` or dict, optional): Config for + normalization layers. All the layers will share the same + config. Defaults to `LN`. + init_cfg (:obj:`ConfigDict` or dict, optional): Config to control + the initialization. Defaults to None. + """ + + def __init__(self, + self_attn_cfg: OptConfigType = dict( + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg: OptConfigType = dict( + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg: OptConfigType = dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0., + act_cfg=dict(type='ReLU', inplace=True), + ), + norm_cfg: OptConfigType = dict(type='LN'), + init_cfg: OptConfigType = None) -> None: + + super().__init__(init_cfg=init_cfg) + + self.self_attn_cfg = self_attn_cfg + self.cross_attn_cfg = cross_attn_cfg + if 'batch_first' not in self.self_attn_cfg: + self.self_attn_cfg['batch_first'] = True + else: + assert self.self_attn_cfg['batch_first'] is True, 'First \ + dimension of all DETRs in mmdet is `batch`, \ + please set `batch_first` flag.' + + if 'batch_first' not in self.cross_attn_cfg: + self.cross_attn_cfg['batch_first'] = True + else: + assert self.cross_attn_cfg['batch_first'] is True, 'First \ + dimension of all DETRs in mmdet is `batch`, \ + please set `batch_first` flag.' + + self.ffn_cfg = ffn_cfg + self.norm_cfg = norm_cfg + self._init_layers() + + def _init_layers(self) -> None: + """Initialize self-attention, FFN, and normalization.""" + self.self_attn = MultiheadAttention(**self.self_attn_cfg) + self.cross_attn = MultiheadAttention(**self.cross_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(3) + ] + self.norms = ModuleList(norms_list) + + def forward(self, + query: Tensor, + key: Tensor = None, + value: Tensor = None, + query_pos: Tensor = None, + key_pos: Tensor = None, + self_attn_mask: Tensor = None, + cross_attn_mask: Tensor = None, + key_padding_mask: Tensor = None, + **kwargs) -> Tensor: + """ + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + key (Tensor, optional): The input key, has shape (bs, num_keys, + dim). If `None`, the `query` will be used. Defaults to `None`. + value (Tensor, optional): The input value, has the same shape as + `key`, as in `nn.MultiheadAttention.forward`. If `None`, the + `key` will be used. Defaults to `None`. + query_pos (Tensor, optional): The positional encoding for `query`, + has the same shape as `query`. If not `None`, it will be added + to `query` before forward function. Defaults to `None`. + key_pos (Tensor, optional): The positional encoding for `key`, has + the same shape as `key`. If not `None`, it will be added to + `key` before forward function. If None, and `query_pos` has the + same shape as `key`, then `query_pos` will be used for + `key_pos`. Defaults to None. + self_attn_mask (Tensor, optional): ByteTensor mask, has shape + (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. + Defaults to None. + cross_attn_mask (Tensor, optional): ByteTensor mask, has shape + (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. + Defaults to None. + key_padding_mask (Tensor, optional): The `key_padding_mask` of + `self_attn` input. 
ByteTensor, has shape (bs, num_value). + Defaults to None. + + Returns: + Tensor: forwarded results, has shape (bs, num_queries, dim). + """ + + query = self.self_attn( + query=query, + key=query, + value=query, + query_pos=query_pos, + key_pos=query_pos, + attn_mask=self_attn_mask, + **kwargs) + query = self.norms[0](query) + query = self.cross_attn( + query=query, + key=key, + value=value, + query_pos=query_pos, + key_pos=key_pos, + attn_mask=cross_attn_mask, + key_padding_mask=key_padding_mask, + **kwargs) + query = self.norms[1](query) + query = self.ffn(query) + query = self.norms[2](query) + + return query diff --git a/mmdet/models/layers/transformer/dino_layers.py b/mmdet/models/layers/transformer/dino_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..64610d0a7c0121a88f5e4279b6f854924230237e --- /dev/null +++ b/mmdet/models/layers/transformer/dino_layers.py @@ -0,0 +1,562 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Tuple, Union + +import torch +from mmengine.model import BaseModule +from torch import Tensor, nn + +from mmdet.structures import SampleList +from mmdet.structures.bbox import bbox_xyxy_to_cxcywh +from mmdet.utils import OptConfigType +from .deformable_detr_layers import DeformableDetrTransformerDecoder +from .utils import MLP, coordinate_to_encoding, inverse_sigmoid + + +class DinoTransformerDecoder(DeformableDetrTransformerDecoder): + """Transformer decoder of DINO.""" + + def _init_layers(self) -> None: + """Initialize decoder layers.""" + super()._init_layers() + self.ref_point_head = MLP(self.embed_dims * 2, self.embed_dims, + self.embed_dims, 2) + self.norm = nn.LayerNorm(self.embed_dims) + + def forward(self, query: Tensor, value: Tensor, key_padding_mask: Tensor, + self_attn_mask: Tensor, reference_points: Tensor, + spatial_shapes: Tensor, level_start_index: Tensor, + valid_ratios: Tensor, reg_branches: nn.ModuleList, + **kwargs) -> Tuple[Tensor]: + """Forward function of Transformer decoder. + + Args: + query (Tensor): The input query, has shape (num_queries, bs, dim). + value (Tensor): The input values, has shape (num_value, bs, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` + input. ByteTensor, has shape (num_queries, bs). + self_attn_mask (Tensor): The attention mask to prevent information + leakage from different denoising groups and matching parts, has + shape (num_queries_total, num_queries_total). It is `None` when + `self.training` is `False`. + reference_points (Tensor): The initial reference, has shape + (bs, num_queries, 4) with the last dimension arranged as + (cx, cy, w, h). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + reg_branches: (obj:`nn.ModuleList`): Used for refining the + regression results. + + Returns: + tuple[Tensor]: Output queries and references of Transformer + decoder + + - query (Tensor): Output embeddings of the last decoder, has + shape (num_queries, bs, embed_dims) when `return_intermediate` + is `False`. 
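The decoder layer just shown follows the standard post-norm DETR ordering: self-attention over the queries, cross-attention into the encoder memory, then the FFN, each followed by LayerNorm on the residual sum. A self-contained torch-only skeleton of that ordering; plain `nn.MultiheadAttention` stands in for the mmcv wrapper, which folds the residual and positional handling inside itself:

```python
import torch
import torch.nn as nn

class MiniDecoderLayer(nn.Module):
    """Post-norm DETR decoder layer skeleton: self-attn, cross-attn, FFN."""

    def __init__(self, dim=256, heads=8):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(dim, 1024), nn.ReLU(inplace=True), nn.Linear(1024, dim))
        self.norms = nn.ModuleList(nn.LayerNorm(dim) for _ in range(3))

    def forward(self, query, memory, query_pos, memory_pos):
        # positional encodings are added to q/k only, never to the values
        q = k = query + query_pos
        query = self.norms[0](query + self.self_attn(q, k, query)[0])
        q = query + query_pos
        k = memory + memory_pos
        query = self.norms[1](query + self.cross_attn(q, k, memory)[0])
        return self.norms[2](query + self.ffn(query))

layer = MiniDecoderLayer()
out = layer(torch.randn(2, 100, 256), torch.randn(2, 1350, 256),
            torch.randn(2, 100, 256), torch.randn(2, 1350, 256))
assert out.shape == (2, 100, 256)
```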
Otherwise, Intermediate output embeddings of all + decoder layers, has shape (num_decoder_layers, num_queries, bs, + embed_dims). + - reference_points (Tensor): The reference of the last decoder + layer, has shape (bs, num_queries, 4) when `return_intermediate` + is `False`. Otherwise, Intermediate references of all decoder + layers, has shape (num_decoder_layers, bs, num_queries, 4). The + coordinates are arranged as (cx, cy, w, h) + """ + intermediate = [] + intermediate_reference_points = [reference_points] + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = \ + reference_points[:, :, None] * torch.cat( + [valid_ratios, valid_ratios], -1)[:, None] + else: + assert reference_points.shape[-1] == 2 + reference_points_input = \ + reference_points[:, :, None] * valid_ratios[:, None] + + query_sine_embed = coordinate_to_encoding( + reference_points_input[:, :, 0, :]) + query_pos = self.ref_point_head(query_sine_embed) + + query = layer( + query, + query_pos=query_pos, + value=value, + key_padding_mask=key_padding_mask, + self_attn_mask=self_attn_mask, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reference_points=reference_points_input, + **kwargs) + + if reg_branches is not None: + tmp = reg_branches[lid](query) + assert reference_points.shape[-1] == 4 + new_reference_points = tmp + inverse_sigmoid( + reference_points, eps=1e-3) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(self.norm(query)) + intermediate_reference_points.append(new_reference_points) + # NOTE this is for the "Look Forward Twice" module, + # in the DeformDETR, reference_points was appended. + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack( + intermediate_reference_points) + + return query, reference_points + + +class CdnQueryGenerator(BaseModule): + """Implement query generator of the Contrastive denoising (CDN) proposed in + `DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object + Detection `_ + + Code is modified from the `official github repo + `_. + + Args: + num_classes (int): Number of object classes. + embed_dims (int): The embedding dimensions of the generated queries. + num_matching_queries (int): The queries number of the matching part. + Used for generating dn_mask. + label_noise_scale (float): The scale of label noise, defaults to 0.5. + box_noise_scale (float): The scale of box noise, defaults to 1.0. + group_cfg (:obj:`ConfigDict` or dict, optional): The config of the + denoising queries grouping, includes `dynamic`, `num_dn_queries`, + and `num_groups`. Two grouping strategies, 'static dn groups' and + 'dynamic dn groups', are supported. When `dynamic` is `False`, + the `num_groups` should be set, and the number of denoising query + groups will always be `num_groups`. When `dynamic` is `True`, the + `num_dn_queries` should be set, and the group number will be + dynamic to ensure that the denoising queries number will not exceed + `num_dn_queries` to prevent large fluctuations of memory. Defaults + to `None`. 
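The "Look Forward Twice" note marks the one behavioral difference from Deformable DETR: DINO appends the undetached `new_reference_points` for supervision, while the reference handed to the next layer is still detached. A tiny demonstration of the two gradient paths, with toy shapes and `torch.logit` standing in for `inverse_sigmoid`:

```python
import torch

ref = torch.rand(1, 3, 4)
delta = torch.randn(1, 3, 4, requires_grad=True)  # stand-in for a reg branch output

new_ref = (delta + torch.logit(ref, eps=1e-3)).sigmoid()

# DINO ("look forward twice"): the *undetached* update is what gets supervised,
# so the box loss at this layer also reaches this layer's delta ...
assert new_ref.requires_grad
# ... while the reference passed on to the next layer is detached, exactly as
# in Deformable DETR, so gradients never chain across more than one layer
next_ref = new_ref.detach()
assert not next_ref.requires_grad
```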
+ """ + + def __init__(self, + num_classes: int, + embed_dims: int, + num_matching_queries: int, + label_noise_scale: float = 0.5, + box_noise_scale: float = 1.0, + group_cfg: OptConfigType = None) -> None: + super().__init__() + self.num_classes = num_classes + self.embed_dims = embed_dims + self.num_matching_queries = num_matching_queries + self.label_noise_scale = label_noise_scale + self.box_noise_scale = box_noise_scale + + # prepare grouping strategy + group_cfg = {} if group_cfg is None else group_cfg + self.dynamic_dn_groups = group_cfg.get('dynamic', True) + if self.dynamic_dn_groups: + if 'num_dn_queries' not in group_cfg: + warnings.warn("'num_dn_queries' should be set when using " + 'dynamic dn groups, use 100 as default.') + self.num_dn_queries = group_cfg.get('num_dn_queries', 100) + assert isinstance(self.num_dn_queries, int), \ + f'Expected the num_dn_queries to have type int, but got ' \ + f'{self.num_dn_queries}({type(self.num_dn_queries)}). ' + else: + assert 'num_groups' in group_cfg, \ + 'num_groups should be set when using static dn groups' + self.num_groups = group_cfg['num_groups'] + assert isinstance(self.num_groups, int), \ + f'Expected the num_groups to have type int, but got ' \ + f'{self.num_groups}({type(self.num_groups)}). ' + + # NOTE The original repo of DINO set the num_embeddings 92 for coco, + # 91 (0~90) of which represents target classes and the 92 (91) + # indicates `Unknown` class. However, the embedding of `unknown` class + # is not used in the original DINO. + # TODO: num_classes + 1 or num_classes ? + self.label_embedding = nn.Embedding(self.num_classes, self.embed_dims) + + def __call__(self, batch_data_samples: SampleList) -> tuple: + """Generate contrastive denoising (cdn) queries with ground truth. + + Descriptions of the Number Values in code and comments: + - num_target_total: the total target number of the input batch + samples. + - max_num_target: the max target number of the input batch samples. + - num_noisy_targets: the total targets number after adding noise, + i.e., num_target_total * num_groups * 2. + - num_denoising_queries: the length of the output batched queries, + i.e., max_num_target * num_groups * 2. + + NOTE The format of input bboxes in batch_data_samples is unnormalized + (x, y, x, y), and the output bbox queries are embedded by normalized + (cx, cy, w, h) format bboxes going through inverse_sigmoid. + + Args: + batch_data_samples (list[:obj:`DetDataSample`]): List of the batch + data samples, each includes `gt_instance` which has attributes + `bboxes` and `labels`. The `bboxes` has unnormalized coordinate + format (x, y, x, y). + + Returns: + tuple: The outputs of the dn query generator. + + - dn_label_query (Tensor): The output content queries for denoising + part, has shape (bs, num_denoising_queries, dim), where + `num_denoising_queries = max_num_target * num_groups * 2`. + - dn_bbox_query (Tensor): The output reference bboxes as positions + of queries for denoising part, which are embedded by normalized + (cx, cy, w, h) format bboxes going through inverse_sigmoid, has + shape (bs, num_denoising_queries, 4) with the last dimension + arranged as (cx, cy, w, h). + - attn_mask (Tensor): The attention mask to prevent information + leakage from different denoising groups and matching parts, + will be used as `self_attn_mask` of the `decoder`, has shape + (num_queries_total, num_queries_total), where `num_queries_total` + is the sum of `num_denoising_queries` and `num_matching_queries`. 
+          - dn_meta (Dict[str, int]): The dictionary saves information about
+            group collation, including 'num_denoising_queries' and
+            'num_denoising_groups'. It will be used to split the outputs of
+            the denoising and matching parts and for loss calculation.
+        """
+        # normalize bbox and collate ground truth (gt)
+        gt_labels_list = []
+        gt_bboxes_list = []
+        for sample in batch_data_samples:
+            img_h, img_w = sample.img_shape
+            bboxes = sample.gt_instances.bboxes
+            factor = bboxes.new_tensor([img_w, img_h, img_w,
+                                        img_h]).unsqueeze(0)
+            bboxes_normalized = bboxes / factor
+            gt_bboxes_list.append(bboxes_normalized)
+            gt_labels_list.append(sample.gt_instances.labels)
+        gt_labels = torch.cat(gt_labels_list)  # (num_target_total, )
+        gt_bboxes = torch.cat(gt_bboxes_list)  # (num_target_total, 4)
+
+        num_target_list = [len(bboxes) for bboxes in gt_bboxes_list]
+        max_num_target = max(num_target_list)
+        num_groups = self.get_num_groups(max_num_target)
+
+        dn_label_query = self.generate_dn_label_query(gt_labels, num_groups)
+        dn_bbox_query = self.generate_dn_bbox_query(gt_bboxes, num_groups)
+
+        # The `batch_idx` saves the batch index of the corresponding sample
+        # for each target, and has shape (num_target_total, ).
+        batch_idx = torch.cat([
+            torch.full_like(t.long(), i) for i, t in enumerate(gt_labels_list)
+        ])
+        dn_label_query, dn_bbox_query = self.collate_dn_queries(
+            dn_label_query, dn_bbox_query, batch_idx, len(batch_data_samples),
+            num_groups)
+
+        attn_mask = self.generate_dn_mask(
+            max_num_target, num_groups, device=dn_label_query.device)
+
+        dn_meta = dict(
+            num_denoising_queries=int(max_num_target * 2 * num_groups),
+            num_denoising_groups=num_groups)
+
+        return dn_label_query, dn_bbox_query, attn_mask, dn_meta
+
+    def get_num_groups(self, max_num_target: int = None) -> int:
+        """Calculate the number of denoising query groups.
+
+        Two grouping strategies, 'static dn groups' and 'dynamic dn groups',
+        are supported. When `self.dynamic_dn_groups` is `False`, the number
+        of denoising query groups will always be `self.num_groups`. When
+        `self.dynamic_dn_groups` is `True`, the group number will be dynamic,
+        ensuring that the number of denoising queries does not exceed
+        `self.num_dn_queries`, to prevent large fluctuations of memory.
+
+        NOTE The `num_groups` is shared by all samples in a batch. When the
+        target numbers of the samples vary, the denoising queries of the
+        samples containing fewer targets are padded to the max length.
+
+        Args:
+            max_num_target (int, optional): The max target number of the batch
+                samples. It will only be used when `self.dynamic_dn_groups` is
+                `True`. Defaults to `None`.
+
+        Returns:
+            int: The denoising group number of the current batch.
+        """
+        if self.dynamic_dn_groups:
+            assert max_num_target is not None, \
+                'max_num_target should be provided when using ' \
+                'dynamic dn groups'
+            if max_num_target == 0:
+                num_groups = 1
+            else:
+                num_groups = self.num_dn_queries // max_num_target
+        else:
+            num_groups = self.num_groups
+        if num_groups < 1:
+            num_groups = 1
+        return int(num_groups)
+
+    def generate_dn_label_query(self, gt_labels: Tensor,
+                                num_groups: int) -> Tensor:
+        """Generate noisy labels and their query embeddings.
+
+        The strategy for generating noisy labels is: randomly choose a
+        `self.label_noise_scale * 0.5` proportion of the labels and override
+        each of them with a random object category label.
+
+        NOTE Noise is not added to all of the labels.
Besides, the `self.label_noise_scale + * 0.5` arg is the ratio of the chosen positions, which is higher than + the actual proportion of noisy labels, because the labels to override + may be correct. And the gap becomes larger as the number of target + categories decreases. The users should notice this and modify the scale + arg or the corresponding logic according to specific dataset. + + Args: + gt_labels (Tensor): The concatenated gt labels of all samples + in the batch, has shape (num_target_total, ) where + `num_target_total = sum(num_target_list)`. + num_groups (int): The number of denoising query groups. + + Returns: + Tensor: The query embeddings of noisy labels, has shape + (num_noisy_targets, embed_dims), where `num_noisy_targets = + num_target_total * num_groups * 2`. + """ + assert self.label_noise_scale > 0 + gt_labels_expand = gt_labels.repeat(2 * num_groups, + 1).view(-1) # Note `* 2` # noqa + p = torch.rand_like(gt_labels_expand.float()) + chosen_indice = torch.nonzero(p < (self.label_noise_scale * 0.5)).view( + -1) # Note `* 0.5` + new_labels = torch.randint_like(chosen_indice, 0, self.num_classes) + noisy_labels_expand = gt_labels_expand.scatter(0, chosen_indice, + new_labels) + dn_label_query = self.label_embedding(noisy_labels_expand) + return dn_label_query + + def generate_dn_bbox_query(self, gt_bboxes: Tensor, + num_groups: int) -> Tensor: + """Generate noisy bboxes and their query embeddings. + + The strategy for generating noisy bboxes is as follow: + + .. code:: text + + +--------------------+ + | negative | + | +----------+ | + | | positive | | + | | +-----|----+------------+ + | | | | | | + | +----+-----+ | | + | | | | + +---------+----------+ | + | | + | gt bbox | + | | + | +---------+----------+ + | | | | + | | +----+-----+ | + | | | | | | + +-------------|--- +----+ | | + | | positive | | + | +----------+ | + | negative | + +--------------------+ + + The random noise is added to the top-left and down-right point + positions, hence, normalized (x, y, x, y) format of bboxes are + required. The noisy bboxes of positive queries have the points + both within the inner square, while those of negative queries + have the points both between the inner and outer squares. + + Besides, the length of outer square is twice as long as that of + the inner square, i.e., self.box_noise_scale * w_or_h / 2. + NOTE The noise is added to all the bboxes. Moreover, there is still + unconsidered case when one point is within the positive square and + the others is between the inner and outer squares. + + Args: + gt_bboxes (Tensor): The concatenated gt bboxes of all samples + in the batch, has shape (num_target_total, 4) with the last + dimension arranged as (cx, cy, w, h) where + `num_target_total = sum(num_target_list)`. + num_groups (int): The number of denoising query groups. + + Returns: + Tensor: The output noisy bboxes, which are embedded by normalized + (cx, cy, w, h) format bboxes going through inverse_sigmoid, has + shape (num_noisy_targets, 4) with the last dimension arranged as + (cx, cy, w, h), where + `num_noisy_targets = num_target_total * num_groups * 2`. 
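`generate_dn_label_query` reduces to a few tensor ops: tile the labels `2 * num_groups` times, pick a random `label_noise_scale * 0.5` fraction of the copies, and overwrite them with random class ids before embedding. A runnable sketch with toy values (the embedding here is freshly initialized, not the module's own):

```python
import torch
import torch.nn as nn

label_noise_scale, num_classes, num_groups = 0.5, 80, 2
gt_labels = torch.tensor([3, 17, 42])

# every target is duplicated 2 * num_groups times (positive + negative per group)
gt_labels_expand = gt_labels.repeat(2 * num_groups, 1).view(-1)  # 12 copies

# flip a label_noise_scale * 0.5 fraction of the copies to random classes
p = torch.rand_like(gt_labels_expand.float())
chosen = torch.nonzero(p < label_noise_scale * 0.5).view(-1)
new_labels = torch.randint_like(chosen, 0, num_classes)
noisy = gt_labels_expand.scatter(0, chosen, new_labels)

# the noisy ids become content queries through a label embedding table
content_query = nn.Embedding(num_classes, 256)(noisy)
assert content_query.shape == (2 * num_groups * len(gt_labels), 256)
```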
+ """ + assert self.box_noise_scale > 0 + device = gt_bboxes.device + + # expand gt_bboxes as groups + gt_bboxes_expand = gt_bboxes.repeat(2 * num_groups, 1) # xyxy + + # obtain index of negative queries in gt_bboxes_expand + positive_idx = torch.arange( + len(gt_bboxes), dtype=torch.long, device=device) + positive_idx = positive_idx.unsqueeze(0).repeat(num_groups, 1) + positive_idx += 2 * len(gt_bboxes) * torch.arange( + num_groups, dtype=torch.long, device=device)[:, None] + positive_idx = positive_idx.flatten() + negative_idx = positive_idx + len(gt_bboxes) + + # determine the sign of each element in the random part of the added + # noise to be positive or negative randomly. + rand_sign = torch.randint_like( + gt_bboxes_expand, low=0, high=2, + dtype=torch.float32) * 2.0 - 1.0 # [low, high), 1 or -1, randomly + + # calculate the random part of the added noise + rand_part = torch.rand_like(gt_bboxes_expand) # [0, 1) + rand_part[negative_idx] += 1.0 # pos: [0, 1); neg: [1, 2) + rand_part *= rand_sign # pos: (-1, 1); neg: (-2, -1] U [1, 2) + + # add noise to the bboxes + bboxes_whwh = bbox_xyxy_to_cxcywh(gt_bboxes_expand)[:, 2:].repeat(1, 2) + noisy_bboxes_expand = gt_bboxes_expand + torch.mul( + rand_part, bboxes_whwh) * self.box_noise_scale / 2 # xyxy + noisy_bboxes_expand = noisy_bboxes_expand.clamp(min=0.0, max=1.0) + noisy_bboxes_expand = bbox_xyxy_to_cxcywh(noisy_bboxes_expand) + + dn_bbox_query = inverse_sigmoid(noisy_bboxes_expand, eps=1e-3) + return dn_bbox_query + + def collate_dn_queries(self, input_label_query: Tensor, + input_bbox_query: Tensor, batch_idx: Tensor, + batch_size: int, num_groups: int) -> Tuple[Tensor]: + """Collate generated queries to obtain batched dn queries. + + The strategy for query collation is as follow: + + .. code:: text + + input_queries (num_target_total, query_dim) + P_A1 P_B1 P_B2 N_A1 N_B1 N_B2 P'A1 P'B1 P'B2 N'A1 N'B1 N'B2 + |________ group1 ________| |________ group2 ________| + | + V + P_A1 Pad0 N_A1 Pad0 P'A1 Pad0 N'A1 Pad0 + P_B1 P_B2 N_B1 N_B2 P'B1 P'B2 N'B1 N'B2 + |____ group1 ____| |____ group2 ____| + batched_queries (batch_size, max_num_target, query_dim) + + where query_dim is 4 for bbox and self.embed_dims for label. + Notation: _-group 1; '-group 2; + A-Sample1(has 1 target); B-sample2(has 2 targets) + + Args: + input_label_query (Tensor): The generated label queries of all + targets, has shape (num_target_total, embed_dims) where + `num_target_total = sum(num_target_list)`. + input_bbox_query (Tensor): The generated bbox queries of all + targets, has shape (num_target_total, 4) with the last + dimension arranged as (cx, cy, w, h). + batch_idx (Tensor): The batch index of the corresponding sample + for each target, has shape (num_target_total). + batch_size (int): The size of the input batch. + num_groups (int): The number of denoising query groups. + + Returns: + tuple[Tensor]: Output batched label and bbox queries. + - batched_label_query (Tensor): The output batched label queries, + has shape (batch_size, max_num_target, embed_dims). + - batched_bbox_query (Tensor): The output batched bbox queries, + has shape (batch_size, max_num_target, 4) with the last dimension + arranged as (cx, cy, w, h). 
+ """ + device = input_label_query.device + num_target_list = [ + torch.sum(batch_idx == idx) for idx in range(batch_size) + ] + max_num_target = max(num_target_list) + num_denoising_queries = int(max_num_target * 2 * num_groups) + + map_query_index = torch.cat([ + torch.arange(num_target, device=device) + for num_target in num_target_list + ]) + map_query_index = torch.cat([ + map_query_index + max_num_target * i for i in range(2 * num_groups) + ]).long() + batch_idx_expand = batch_idx.repeat(2 * num_groups, 1).view(-1) + mapper = (batch_idx_expand, map_query_index) + + batched_label_query = torch.zeros( + batch_size, num_denoising_queries, self.embed_dims, device=device) + batched_bbox_query = torch.zeros( + batch_size, num_denoising_queries, 4, device=device) + + batched_label_query[mapper] = input_label_query + batched_bbox_query[mapper] = input_bbox_query + return batched_label_query, batched_bbox_query + + def generate_dn_mask(self, max_num_target: int, num_groups: int, + device: Union[torch.device, str]) -> Tensor: + """Generate attention mask to prevent information leakage from + different denoising groups and matching parts. + + .. code:: text + + 0 0 0 0 1 1 1 1 0 0 0 0 0 + 0 0 0 0 1 1 1 1 0 0 0 0 0 + 0 0 0 0 1 1 1 1 0 0 0 0 0 + 0 0 0 0 1 1 1 1 0 0 0 0 0 + 1 1 1 1 0 0 0 0 0 0 0 0 0 + 1 1 1 1 0 0 0 0 0 0 0 0 0 + 1 1 1 1 0 0 0 0 0 0 0 0 0 + 1 1 1 1 0 0 0 0 0 0 0 0 0 + 1 1 1 1 1 1 1 1 0 0 0 0 0 + 1 1 1 1 1 1 1 1 0 0 0 0 0 + 1 1 1 1 1 1 1 1 0 0 0 0 0 + 1 1 1 1 1 1 1 1 0 0 0 0 0 + 1 1 1 1 1 1 1 1 0 0 0 0 0 + max_num_target |_| |_________| num_matching_queries + |_____________| num_denoising_queries + + 1 -> True (Masked), means 'can not see'. + 0 -> False (UnMasked), means 'can see'. + + Args: + max_num_target (int): The max target number of the input batch + samples. + num_groups (int): The number of denoising query groups. + device (obj:`device` or str): The device of generated mask. + + Returns: + Tensor: The attention mask to prevent information leakage from + different denoising groups and matching parts, will be used as + `self_attn_mask` of the `decoder`, has shape (num_queries_total, + num_queries_total), where `num_queries_total` is the sum of + `num_denoising_queries` and `num_matching_queries`. + """ + num_denoising_queries = int(max_num_target * 2 * num_groups) + num_queries_total = num_denoising_queries + self.num_matching_queries + attn_mask = torch.zeros( + num_queries_total, + num_queries_total, + device=device, + dtype=torch.bool) + # Make the matching part cannot see the denoising groups + attn_mask[num_denoising_queries:, :num_denoising_queries] = True + # Make the denoising groups cannot see each other + for i in range(num_groups): + # Mask rows of one group per step. + row_scope = slice(max_num_target * 2 * i, + max_num_target * 2 * (i + 1)) + left_scope = slice(max_num_target * 2 * i) + right_scope = slice(max_num_target * 2 * (i + 1), + num_denoising_queries) + attn_mask[row_scope, right_scope] = True + attn_mask[row_scope, left_scope] = True + return attn_mask diff --git a/mmdet/models/layers/transformer/grounding_dino_layers.py b/mmdet/models/layers/transformer/grounding_dino_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..3c285768f36af98075607b43e48e6f1018125ad1 --- /dev/null +++ b/mmdet/models/layers/transformer/grounding_dino_layers.py @@ -0,0 +1,270 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch
+import torch.nn as nn
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
+from mmcv.ops import MultiScaleDeformableAttention
+from mmengine.model import ModuleList
+from torch import Tensor
+
+from mmdet.models.utils.vlfuse_helper import SingleScaleBiAttentionBlock
+from mmdet.utils import ConfigType, OptConfigType
+from .deformable_detr_layers import (DeformableDetrTransformerDecoderLayer,
+                                     DeformableDetrTransformerEncoder,
+                                     DeformableDetrTransformerEncoderLayer)
+from .detr_layers import DetrTransformerEncoderLayer
+from .dino_layers import DinoTransformerDecoder
+from .utils import MLP, get_text_sine_pos_embed
+
+try:
+    from fairscale.nn.checkpoint import checkpoint_wrapper
+except Exception:
+    checkpoint_wrapper = None
+
+
+class GroundingDinoTransformerDecoderLayer(
+        DeformableDetrTransformerDecoderLayer):
+
+    def __init__(self,
+                 cross_attn_text_cfg: OptConfigType = dict(
+                     embed_dims=256,
+                     num_heads=8,
+                     dropout=0.0,
+                     batch_first=True),
+                 **kwargs) -> None:
+        """Decoder layer of Grounding DINO."""
+        self.cross_attn_text_cfg = cross_attn_text_cfg
+        if 'batch_first' not in self.cross_attn_text_cfg:
+            self.cross_attn_text_cfg['batch_first'] = True
+        super().__init__(**kwargs)
+
+    def _init_layers(self) -> None:
+        """Initialize self_attn, text cross-attn, image cross-attn, ffn, and
+        norms."""
+        self.self_attn = MultiheadAttention(**self.self_attn_cfg)
+        self.cross_attn_text = MultiheadAttention(**self.cross_attn_text_cfg)
+        self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg)
+        self.embed_dims = self.self_attn.embed_dims
+        self.ffn = FFN(**self.ffn_cfg)
+        norms_list = [
+            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
+            for _ in range(4)
+        ]
+        self.norms = ModuleList(norms_list)
+
+    def forward(self,
+                query: Tensor,
+                key: Tensor = None,
+                value: Tensor = None,
+                query_pos: Tensor = None,
+                key_pos: Tensor = None,
+                self_attn_mask: Tensor = None,
+                cross_attn_mask: Tensor = None,
+                key_padding_mask: Tensor = None,
+                memory_text: Tensor = None,
+                text_attention_mask: Tensor = None,
+                **kwargs) -> Tensor:
+        """Implements decoder layer in Grounding DINO transformer.
+
+        Args:
+            query (Tensor): The input query, has shape (bs, num_queries, dim).
+            key (Tensor, optional): The input key, has shape (bs, num_keys,
+                dim). If `None`, the `query` will be used. Defaults to `None`.
+            value (Tensor, optional): The input value, has the same shape as
+                `key`, as in `nn.MultiheadAttention.forward`. If `None`, the
+                `key` will be used. Defaults to `None`.
+            query_pos (Tensor, optional): The positional encoding for `query`,
+                has the same shape as `query`. If not `None`, it will be added
+                to `query` before forward function. Defaults to `None`.
+            key_pos (Tensor, optional): The positional encoding for `key`, has
+                the same shape as `key`. If not `None`, it will be added to
+                `key` before forward function. If None, and `query_pos` has the
+                same shape as `key`, then `query_pos` will be used for
+                `key_pos`. Defaults to None.
+            self_attn_mask (Tensor, optional): ByteTensor mask, has shape
+                (num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
+                Defaults to None.
+            cross_attn_mask (Tensor, optional): ByteTensor mask, has shape
+                (num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
+                Defaults to None.
+            key_padding_mask (Tensor, optional): The `key_padding_mask` of
+                `self_attn` input. ByteTensor, has shape (bs, num_value).
+                Defaults to None.
+            memory_text (Tensor): Memory text.
It has shape (bs, len_text, + text_embed_dims). + text_attention_mask (Tensor): Text token mask. It has shape (bs, + len_text). + + Returns: + Tensor: forwarded results, has shape (bs, num_queries, dim). + """ + # self attention + query = self.self_attn( + query=query, + key=query, + value=query, + query_pos=query_pos, + key_pos=query_pos, + attn_mask=self_attn_mask, + **kwargs) + query = self.norms[0](query) + # cross attention between query and text + query = self.cross_attn_text( + query=query, + query_pos=query_pos, + key=memory_text, + value=memory_text, + key_padding_mask=text_attention_mask) + query = self.norms[1](query) + # cross attention between query and image + query = self.cross_attn( + query=query, + key=key, + value=value, + query_pos=query_pos, + key_pos=key_pos, + attn_mask=cross_attn_mask, + key_padding_mask=key_padding_mask, + **kwargs) + query = self.norms[2](query) + query = self.ffn(query) + query = self.norms[3](query) + + return query + + +class GroundingDinoTransformerEncoder(DeformableDetrTransformerEncoder): + + def __init__(self, text_layer_cfg: ConfigType, + fusion_layer_cfg: ConfigType, **kwargs) -> None: + self.text_layer_cfg = text_layer_cfg + self.fusion_layer_cfg = fusion_layer_cfg + super().__init__(**kwargs) + + def _init_layers(self) -> None: + """Initialize encoder layers.""" + self.layers = ModuleList([ + DeformableDetrTransformerEncoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.text_layers = ModuleList([ + DetrTransformerEncoderLayer(**self.text_layer_cfg) + for _ in range(self.num_layers) + ]) + self.fusion_layers = ModuleList([ + SingleScaleBiAttentionBlock(**self.fusion_layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + if self.num_cp > 0: + if checkpoint_wrapper is None: + raise NotImplementedError( + 'If you want to reduce GPU memory usage, \ + please install fairscale by executing the \ + following command: pip install fairscale.') + for i in range(self.num_cp): + self.layers[i] = checkpoint_wrapper(self.layers[i]) + self.fusion_layers[i] = checkpoint_wrapper( + self.fusion_layers[i]) + + def forward(self, + query: Tensor, + query_pos: Tensor, + key_padding_mask: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + valid_ratios: Tensor, + memory_text: Tensor = None, + text_attention_mask: Tensor = None, + pos_text: Tensor = None, + text_self_attention_masks: Tensor = None, + position_ids: Tensor = None): + """Forward function of Transformer encoder. + + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + query_pos (Tensor): The positional encoding for query, has shape + (bs, num_queries, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` + input. ByteTensor, has shape (bs, num_queries). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + memory_text (Tensor, optional): Memory text. It has shape (bs, + len_text, text_embed_dims). + text_attention_mask (Tensor, optional): Text token mask. It has + shape (bs,len_text). + pos_text (Tensor, optional): The positional encoding for text. + Defaults to None. 
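Relative to the parent class, this decoder layer inserts one extra hop: self-attention, cross-attention into the text memory, cross-attention into the image memory, then the FFN, with four norms instead of three. A torch-only shape sketch; standard `nn.MultiheadAttention` stands in for `MultiScaleDeformableAttention` on the image side, and positional terms and masks are omitted:

```python
import torch
import torch.nn as nn

dim, heads = 256, 8
self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
cross_attn_text = nn.MultiheadAttention(dim, heads, batch_first=True)
cross_attn_img = nn.MultiheadAttention(dim, heads, batch_first=True)  # stand-in for MSDeformAttn
norms = nn.ModuleList(nn.LayerNorm(dim) for _ in range(4))
ffn = nn.Sequential(nn.Linear(dim, 2048), nn.ReLU(), nn.Linear(2048, dim))

query = torch.randn(2, 300, dim)
memory_text = torch.randn(2, 12, dim)    # (bs, len_text, dim)
memory_img = torch.randn(2, 1000, dim)   # flattened multi-scale image tokens

# same post-norm residual pattern, one attention stage per modality
query = norms[0](query + self_attn(query, query, query)[0])
query = norms[1](query + cross_attn_text(query, memory_text, memory_text)[0])
query = norms[2](query + cross_attn_img(query, memory_img, memory_img)[0])
query = norms[3](query + ffn(query))
assert query.shape == (2, 300, dim)
```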
+ text_self_attention_masks (Tensor, optional): Text self attention + mask. Defaults to None. + position_ids (Tensor, optional): Text position ids. + Defaults to None. + """ + output = query + reference_points = self.get_encoder_reference_points( + spatial_shapes, valid_ratios, device=query.device) + if self.text_layers: + # generate pos_text + bs, n_text, _ = memory_text.shape + if pos_text is None and position_ids is None: + pos_text = ( + torch.arange(n_text, + device=memory_text.device).float().unsqueeze( + 0).unsqueeze(-1).repeat(bs, 1, 1)) + pos_text = get_text_sine_pos_embed( + pos_text, num_pos_feats=256, exchange_xy=False) + if position_ids is not None: + pos_text = get_text_sine_pos_embed( + position_ids[..., None], + num_pos_feats=256, + exchange_xy=False) + + # main process + for layer_id, layer in enumerate(self.layers): + if self.fusion_layers: + output, memory_text = self.fusion_layers[layer_id]( + visual_feature=output, + lang_feature=memory_text, + attention_mask_v=key_padding_mask, + attention_mask_l=text_attention_mask, + ) + if self.text_layers: + text_num_heads = self.text_layers[ + layer_id].self_attn_cfg.num_heads + memory_text = self.text_layers[layer_id]( + query=memory_text, + query_pos=(pos_text if pos_text is not None else None), + attn_mask=~text_self_attention_masks.repeat( + text_num_heads, 1, 1), # note we use ~ for mask here + key_padding_mask=None, + ) + output = layer( + query=output, + query_pos=query_pos, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + key_padding_mask=key_padding_mask) + return output, memory_text + + +class GroundingDinoTransformerDecoder(DinoTransformerDecoder): + + def _init_layers(self) -> None: + """Initialize decoder layers.""" + self.layers = ModuleList([ + GroundingDinoTransformerDecoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + if self.post_norm_cfg is not None: + raise ValueError('There is not post_norm in ' + f'{self._get_name()}') + self.ref_point_head = MLP(self.embed_dims * 2, self.embed_dims, + self.embed_dims, 2) + self.norm = nn.LayerNorm(self.embed_dims) diff --git a/mmdet/models/layers/transformer/mask2former_layers.py b/mmdet/models/layers/transformer/mask2former_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..dcc604e277d91151334ed520d78e6a5a8f388036 --- /dev/null +++ b/mmdet/models/layers/transformer/mask2former_layers.py @@ -0,0 +1,135 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import build_norm_layer +from mmengine.model import ModuleList +from torch import Tensor + +from .deformable_detr_layers import DeformableDetrTransformerEncoder +from .detr_layers import DetrTransformerDecoder, DetrTransformerDecoderLayer + + +class Mask2FormerTransformerEncoder(DeformableDetrTransformerEncoder): + """Encoder in PixelDecoder of Mask2Former.""" + + def forward(self, query: Tensor, query_pos: Tensor, + key_padding_mask: Tensor, spatial_shapes: Tensor, + level_start_index: Tensor, valid_ratios: Tensor, + reference_points: Tensor, **kwargs) -> Tensor: + """Forward function of Transformer encoder. + + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + query_pos (Tensor): The positional encoding for query, has shape + (bs, num_queries, dim). If not None, it will be added to the + `query` before forward function. Defaults to None. + key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` + input. 
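When neither `pos_text` nor `position_ids` is given, the encoder above synthesizes token positions `0..n_text-1` and passes them through a sine embedding. A hand-rolled approximation of that default path; the repo's `get_text_sine_pos_embed` is not imported here, so treat the exact scaling details as an assumption:

```python
import math
import torch

def sine_pos_embed_1d(pos, num_feats=256, temperature=10000):
    # one sine/cosine interleaved embedding per token index, num_feats channels,
    # using the same dim_t recurrence as coordinate_to_encoding
    dim_t = torch.arange(num_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * (dim_t // 2) / num_feats)
    emb = pos[..., None] * 2 * math.pi / dim_t
    return torch.stack((emb[..., 0::2].sin(), emb[..., 1::2].cos()),
                       dim=-1).flatten(-2)

bs, n_text = 2, 12
pos_text = torch.arange(n_text).float().unsqueeze(0).repeat(bs, 1)  # token indices
print(sine_pos_embed_1d(pos_text).shape)  # torch.Size([2, 12, 256])
```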
ByteTensor, has shape (bs, num_queries). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + reference_points (Tensor): The initial reference, has shape + (bs, num_queries, 2) with the last dimension arranged + as (cx, cy). + + Returns: + Tensor: Output queries of Transformer encoder, which is also + called 'encoder output embeddings' or 'memory', has shape + (bs, num_queries, dim) + """ + for layer in self.layers: + query = layer( + query=query, + query_pos=query_pos, + key_padding_mask=key_padding_mask, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reference_points=reference_points, + **kwargs) + return query + + +class Mask2FormerTransformerDecoder(DetrTransformerDecoder): + """Decoder of Mask2Former.""" + + def _init_layers(self) -> None: + """Initialize decoder layers.""" + self.layers = ModuleList([ + Mask2FormerTransformerDecoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + self.post_norm = build_norm_layer(self.post_norm_cfg, + self.embed_dims)[1] + + +class Mask2FormerTransformerDecoderLayer(DetrTransformerDecoderLayer): + """Implements decoder layer in Mask2Former transformer.""" + + def forward(self, + query: Tensor, + key: Tensor = None, + value: Tensor = None, + query_pos: Tensor = None, + key_pos: Tensor = None, + self_attn_mask: Tensor = None, + cross_attn_mask: Tensor = None, + key_padding_mask: Tensor = None, + **kwargs) -> Tensor: + """ + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + key (Tensor, optional): The input key, has shape (bs, num_keys, + dim). If `None`, the `query` will be used. Defaults to `None`. + value (Tensor, optional): The input value, has the same shape as + `key`, as in `nn.MultiheadAttention.forward`. If `None`, the + `key` will be used. Defaults to `None`. + query_pos (Tensor, optional): The positional encoding for `query`, + has the same shape as `query`. If not `None`, it will be added + to `query` before forward function. Defaults to `None`. + key_pos (Tensor, optional): The positional encoding for `key`, has + the same shape as `key`. If not `None`, it will be added to + `key` before forward function. If None, and `query_pos` has the + same shape as `key`, then `query_pos` will be used for + `key_pos`. Defaults to None. + self_attn_mask (Tensor, optional): ByteTensor mask, has shape + (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. + Defaults to None. + cross_attn_mask (Tensor, optional): ByteTensor mask, has shape + (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. + Defaults to None. + key_padding_mask (Tensor, optional): The `key_padding_mask` of + `self_attn` input. ByteTensor, has shape (bs, num_value). + Defaults to None. + + Returns: + Tensor: forwarded results, has shape (bs, num_queries, dim). 
+ """ + + query = self.cross_attn( + query=query, + key=key, + value=value, + query_pos=query_pos, + key_pos=key_pos, + attn_mask=cross_attn_mask, + key_padding_mask=key_padding_mask, + **kwargs) + query = self.norms[0](query) + query = self.self_attn( + query=query, + key=query, + value=query, + query_pos=query_pos, + key_pos=query_pos, + attn_mask=self_attn_mask, + **kwargs) + query = self.norms[1](query) + query = self.ffn(query) + query = self.norms[2](query) + + return query diff --git a/mmdet/models/layers/transformer/utils.py b/mmdet/models/layers/transformer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6e43a172ca7175b23c82f60894faf38ec6c437e3 --- /dev/null +++ b/mmdet/models/layers/transformer/utils.py @@ -0,0 +1,915 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings +from typing import Optional, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F +from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer, + build_norm_layer) +from mmcv.cnn.bricks.drop import Dropout +from mmengine.model import BaseModule, ModuleList +from mmengine.utils import to_2tuple +from torch import Tensor, nn + +from mmdet.registry import MODELS +from mmdet.utils import OptConfigType, OptMultiConfig + + +def nlc_to_nchw(x: Tensor, hw_shape: Sequence[int]) -> Tensor: + """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, L, C] before conversion. + hw_shape (Sequence[int]): The height and width of output feature map. + + Returns: + Tensor: The output tensor of shape [N, C, H, W] after conversion. + """ + H, W = hw_shape + assert len(x.shape) == 3 + B, L, C = x.shape + assert L == H * W, 'The seq_len does not match H, W' + return x.transpose(1, 2).reshape(B, C, H, W).contiguous() + + +def nchw_to_nlc(x): + """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, C, H, W] before conversion. + + Returns: + Tensor: The output tensor of shape [N, L, C] after conversion. + """ + assert len(x.shape) == 4 + return x.flatten(2).transpose(1, 2).contiguous() + + +def coordinate_to_encoding(coord_tensor: Tensor, + num_feats: int = 128, + temperature: int = 10000, + scale: float = 2 * math.pi): + """Convert coordinate tensor to positional encoding. + + Args: + coord_tensor (Tensor): Coordinate tensor to be converted to + positional encoding. With the last dimension as 2 or 4. + num_feats (int, optional): The feature dimension for each position + along x-axis or y-axis. Note the final returned dimension + for each position is 2 times of this value. Defaults to 128. + temperature (int, optional): The temperature used for scaling + the position embedding. Defaults to 10000. + scale (float, optional): A scale factor that scales the position + embedding. The scale will be used only when `normalize` is True. + Defaults to 2*pi. + Returns: + Tensor: Returned encoded positional tensor. 
+ """ + dim_t = torch.arange( + num_feats, dtype=torch.float32, device=coord_tensor.device) + dim_t = temperature**(2 * (dim_t // 2) / num_feats) + x_embed = coord_tensor[..., 0] * scale + y_embed = coord_tensor[..., 1] * scale + pos_x = x_embed[..., None] / dim_t + pos_y = y_embed[..., None] / dim_t + pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), + dim=-1).flatten(2) + pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), + dim=-1).flatten(2) + if coord_tensor.size(-1) == 2: + pos = torch.cat((pos_y, pos_x), dim=-1) + elif coord_tensor.size(-1) == 4: + w_embed = coord_tensor[..., 2] * scale + pos_w = w_embed[..., None] / dim_t + pos_w = torch.stack((pos_w[..., 0::2].sin(), pos_w[..., 1::2].cos()), + dim=-1).flatten(2) + + h_embed = coord_tensor[..., 3] * scale + pos_h = h_embed[..., None] / dim_t + pos_h = torch.stack((pos_h[..., 0::2].sin(), pos_h[..., 1::2].cos()), + dim=-1).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=-1) + else: + raise ValueError('Unknown pos_tensor shape(-1):{}'.format( + coord_tensor.size(-1))) + return pos + + +def inverse_sigmoid(x: Tensor, eps: float = 1e-5) -> Tensor: + """Inverse function of sigmoid. + + Args: + x (Tensor): The tensor to do the inverse. + eps (float): EPS avoid numerical overflow. Defaults 1e-5. + Returns: + Tensor: The x has passed the inverse function of sigmoid, has the same + shape with input. + """ + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +class AdaptivePadding(nn.Module): + """Applies padding to input (if needed) so that input can get fully covered + by filter you specified. It support two modes "same" and "corner". The + "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around + input. The "corner" mode would pad zero to bottom right. + + Args: + kernel_size (int | tuple): Size of the kernel: + stride (int | tuple): Stride of the filter. Default: 1: + dilation (int | tuple): Spacing between kernel elements. + Default: 1 + padding (str): Support "same" and "corner", "corner" mode + would pad zero to bottom right, and "same" mode would + pad zero around input. Default: "corner". 
+ Example: + >>> kernel_size = 16 + >>> stride = 16 + >>> dilation = 1 + >>> input = torch.rand(1, 1, 15, 17) + >>> adap_pad = AdaptivePadding( + >>> kernel_size=kernel_size, + >>> stride=stride, + >>> dilation=dilation, + >>> padding="corner") + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + >>> input = torch.rand(1, 1, 16, 17) + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + """ + + def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'): + + super(AdaptivePadding, self).__init__() + + assert padding in ('same', 'corner') + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + padding = to_2tuple(padding) + dilation = to_2tuple(dilation) + + self.padding = padding + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + + def get_pad_shape(self, input_shape): + input_h, input_w = input_shape + kernel_h, kernel_w = self.kernel_size + stride_h, stride_w = self.stride + output_h = math.ceil(input_h / stride_h) + output_w = math.ceil(input_w / stride_w) + pad_h = max((output_h - 1) * stride_h + + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) + pad_w = max((output_w - 1) * stride_w + + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) + return pad_h, pad_w + + def forward(self, x): + pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) + if pad_h > 0 or pad_w > 0: + if self.padding == 'corner': + x = F.pad(x, [0, pad_w, 0, pad_h]) + elif self.padding == 'same': + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2 + ]) + return x + + +class PatchEmbed(BaseModule): + """Image to Patch Embedding. + + We use a conv layer to implement PatchEmbed. + + Args: + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + conv_type (str): The config dict for embedding + conv layer type selection. Default: "Conv2d. + kernel_size (int): The kernel_size of embedding conv. Default: 16. + stride (int): The slide stride of embedding conv. + Default: None (Would be set as `kernel_size`). + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int): The dilation rate of embedding conv. Default: 1. + bias (bool): Bias of embed conv. Default: True. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None. + input_size (int | tuple | None): The size of input, which will be + used to calculate the out size. Only work when `dynamic_size` + is False. Default: None. + init_cfg (`mmengine.ConfigDict`, optional): The Config for + initialization. Default: None. 
+ """ + + def __init__(self, + in_channels: int = 3, + embed_dims: int = 768, + conv_type: str = 'Conv2d', + kernel_size: int = 16, + stride: int = 16, + padding: Union[int, tuple, str] = 'corner', + dilation: int = 1, + bias: bool = True, + norm_cfg: OptConfigType = None, + input_size: Union[int, tuple] = None, + init_cfg: OptConfigType = None) -> None: + super(PatchEmbed, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + if stride is None: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of conv + padding = 0 + else: + self.adap_padding = None + padding = to_2tuple(padding) + + self.projection = build_conv_layer( + dict(type=conv_type), + in_channels=in_channels, + out_channels=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + else: + self.norm = None + + if input_size: + input_size = to_2tuple(input_size) + # `init_out_size` would be used outside to + # calculate the num_patches + # when `use_abs_pos_embed` outside + self.init_input_size = input_size + if self.adap_padding: + pad_h, pad_w = self.adap_padding.get_pad_shape(input_size) + input_h, input_w = input_size + input_h = input_h + pad_h + input_w = input_w + pad_w + input_size = (input_h, input_w) + + # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + h_out = (input_size[0] + 2 * padding[0] - dilation[0] * + (kernel_size[0] - 1) - 1) // stride[0] + 1 + w_out = (input_size[1] + 2 * padding[1] - dilation[1] * + (kernel_size[1] - 1) - 1) // stride[1] + 1 + self.init_out_size = (h_out, w_out) + else: + self.init_input_size = None + self.init_out_size = None + + def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int]]: + """ + Args: + x (Tensor): Has shape (B, C, H, W). In most case, C is 3. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, out_h * out_w, embed_dims) + - out_size (tuple[int]): Spatial shape of x, arrange as + (out_h, out_w). + """ + + if self.adap_padding: + x = self.adap_padding(x) + + x = self.projection(x) + out_size = (x.shape[2], x.shape[3]) + x = x.flatten(2).transpose(1, 2) + if self.norm is not None: + x = self.norm(x) + return x, out_size + + +class PatchMerging(BaseModule): + """Merge patch feature map. + + This layer groups feature map by kernel_size, and applies norm and linear + layers to the grouped feature map. Our implementation uses `nn.Unfold` to + merge patch, which is about 25% faster than original implementation. + Instead, we need to modify pretrained models for compatibility. + + Args: + in_channels (int): The num of input channels. + to gets fully covered by filter and stride you specified.. + Default: True. + out_channels (int): The num of output channels. + kernel_size (int | tuple, optional): the kernel size in the unfold + layer. Defaults to 2. + stride (int | tuple, optional): the stride of the sliding blocks in the + unfold layer. Default: None. (Would be set as `kernel_size`) + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". 
+ dilation (int | tuple, optional): dilation parameter in the unfold + layer. Default: 1. + bias (bool, optional): Whether to add bias in linear layer or not. + Defaults: False. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (dict, optional): The extra config for initialization. + Default: None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Optional[Union[int, tuple]] = 2, + stride: Optional[Union[int, tuple]] = None, + padding: Union[int, tuple, str] = 'corner', + dilation: Optional[Union[int, tuple]] = 1, + bias: Optional[bool] = False, + norm_cfg: OptConfigType = dict(type='LN'), + init_cfg: OptConfigType = None) -> None: + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + if stride: + stride = stride + else: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of unfold + padding = 0 + else: + self.adap_padding = None + + padding = to_2tuple(padding) + self.sampler = nn.Unfold( + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride) + + sample_dim = kernel_size[0] * kernel_size[1] * in_channels + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + def forward(self, x: Tensor, + input_size: Tuple[int]) -> Tuple[Tensor, Tuple[int]]: + """ + Args: + x (Tensor): Has shape (B, H*W, C_in). + input_size (tuple[int]): The spatial shape of x, arrange as (H, W). + Default: None. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) + - out_size (tuple[int]): Spatial shape of x, arrange as + (Merged_H, Merged_W). + """ + B, L, C = x.shape + assert isinstance(input_size, Sequence), f'Expect ' \ + f'input_size is ' \ + f'`Sequence` ' \ + f'but get {input_size}' + + H, W = input_size + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + + if self.adap_padding: + x = self.adap_padding(x) + H, W = x.shape[-2:] + + x = self.sampler(x) + # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) + + out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * + (self.sampler.kernel_size[0] - 1) - + 1) // self.sampler.stride[0] + 1 + out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * + (self.sampler.kernel_size[1] - 1) - + 1) // self.sampler.stride[1] + 1 + + output_size = (out_h, out_w) + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + x = self.norm(x) if self.norm else x + x = self.reduction(x) + return x, output_size + + +class ConditionalAttention(BaseModule): + """A wrapper of conditional attention, dropout and residual connection. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop: A Dropout layer after `nn.MultiheadAttention`. + Default: 0.0. 
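`PatchMerging` uses `nn.Unfold` to gather each `kernel_size` window into a single `4*C` vector, then normalizes and linearly projects it down, which is how Swin-style downsampling halves the resolution while roughly doubling the channels. A minimal sketch with toy sizes (norm and projection freshly initialized):

```python
import torch
import torch.nn as nn

B, H, W, C, C_out = 1, 8, 8, 96, 192
x = torch.randn(B, H * W, C)

sampler = nn.Unfold(kernel_size=2, stride=2)  # groups each 2x2 patch
norm = nn.LayerNorm(4 * C)
reduction = nn.Linear(4 * C, C_out, bias=False)

x = x.view(B, H, W, C).permute(0, 3, 1, 2)  # (B, C, H, W)
x = sampler(x)                              # (B, 4*C, H/2 * W/2)
x = x.transpose(1, 2)                       # (B, H/2 * W/2, 4*C)
x = reduction(norm(x))
assert x.shape == (B, (H // 2) * (W // 2), C_out)
```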
+        cross_attn (bool): Whether the attention module is for cross
+            attention. Default: False.
+        keep_query_pos (bool): Whether to transform query_pos before cross
+            attention. Default: False.
+        batch_first (bool): When it is True, Key, Query and Value are shape
+            of (batch, n, embed_dim), otherwise (n, batch, embed_dim).
+            Default: True.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self,
+                 embed_dims: int,
+                 num_heads: int,
+                 attn_drop: float = 0.,
+                 proj_drop: float = 0.,
+                 cross_attn: bool = False,
+                 keep_query_pos: bool = False,
+                 batch_first: bool = True,
+                 init_cfg: OptMultiConfig = None):
+        super().__init__(init_cfg=init_cfg)
+
+        assert batch_first is True, 'Setting `batch_first` ' \
+            'to False is NOT supported in ConditionalAttention. ' \
+            'The first dimension of all DETRs in mmdet is `batch`, ' \
+            'please set `batch_first` to True.'
+
+        self.cross_attn = cross_attn
+        self.keep_query_pos = keep_query_pos
+        self.embed_dims = embed_dims
+        self.num_heads = num_heads
+        self.attn_drop = Dropout(attn_drop)
+        self.proj_drop = Dropout(proj_drop)
+
+        self._init_layers()
+
+    def _init_layers(self):
+        """Initialize layers for qkv projection."""
+        embed_dims = self.embed_dims
+        self.qcontent_proj = Linear(embed_dims, embed_dims)
+        self.qpos_proj = Linear(embed_dims, embed_dims)
+        self.kcontent_proj = Linear(embed_dims, embed_dims)
+        self.kpos_proj = Linear(embed_dims, embed_dims)
+        self.v_proj = Linear(embed_dims, embed_dims)
+        if self.cross_attn:
+            self.qpos_sine_proj = Linear(embed_dims, embed_dims)
+        self.out_proj = Linear(embed_dims, embed_dims)
+
+        nn.init.constant_(self.out_proj.bias, 0.)
+
+    def forward_attn(self,
+                     query: Tensor,
+                     key: Tensor,
+                     value: Tensor,
+                     attn_mask: Tensor = None,
+                     key_padding_mask: Tensor = None) -> Tuple[Tensor]:
+        """Forward process for `ConditionalAttention`.
+
+        Args:
+            query (Tensor): The input query with shape [bs, num_queries,
+                embed_dims].
+            key (Tensor): The key tensor with shape [bs, num_keys,
+                embed_dims]. If None, the `query` will be used.
+                Defaults to None.
+            value (Tensor): The value tensor with the same shape as `key`,
+                as in `nn.MultiheadAttention.forward`. If None, the `key`
+                will be used. Defaults to None.
+            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
+                num_keys]. Same as in `nn.MultiheadAttention.forward`.
+                Defaults to None.
+            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
+                Defaults to None.
+
+        Returns:
+            Tuple[Tensor]: The attention output of shape :math:`(N, L, E)`,
+            where :math:`N` is the batch size, :math:`L` is the target
+            sequence length and :math:`E` is the embedding dimension
+            `embed_dim`, and the attention weights averaged over heads, of
+            shape :math:`(N, L, S)`, where :math:`S` is the source sequence
+            length.
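+
+        Example:
+            An illustrative self-attention call (shapes only; the values
+            are hypothetical)::
+
+                >>> attn = ConditionalAttention(embed_dims=256, num_heads=8)
+                >>> q = k = v = torch.randn(2, 100, 256)
+                >>> out, weights = attn.forward_attn(q, k, v)
+                >>> out.shape, weights.shape
+                (torch.Size([2, 100, 256]), torch.Size([2, 100, 100]))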
+ """ + assert key.size(1) == value.size(1), \ + f'{"key, value must have the same sequence length"}' + assert query.size(0) == key.size(0) == value.size(0), \ + f'{"batch size must be equal for query, key, value"}' + assert query.size(2) == key.size(2), \ + f'{"q_dims, k_dims must be equal"}' + assert value.size(2) == self.embed_dims, \ + f'{"v_dims must be equal to embed_dims"}' + + bs, tgt_len, hidden_dims = query.size() + _, src_len, _ = key.size() + head_dims = hidden_dims // self.num_heads + v_head_dims = self.embed_dims // self.num_heads + assert head_dims * self.num_heads == hidden_dims, \ + f'{"hidden_dims must be divisible by num_heads"}' + scaling = float(head_dims)**-0.5 + + q = query * scaling + k = key + v = value + + if attn_mask is not None: + assert attn_mask.dtype == torch.float32 or \ + attn_mask.dtype == torch.float64 or \ + attn_mask.dtype == torch.float16 or \ + attn_mask.dtype == torch.uint8 or \ + attn_mask.dtype == torch.bool, \ + 'Only float, byte, and bool types are supported for \ + attn_mask' + + if attn_mask.dtype == torch.uint8: + warnings.warn('Byte tensor for attn_mask is deprecated.\ + Use bool tensor instead.') + attn_mask = attn_mask.to(torch.bool) + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(1), key.size(1)]: + raise RuntimeError( + 'The size of the 2D attn_mask is not correct.') + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bs * self.num_heads, + query.size(1), + key.size(1) + ]: + raise RuntimeError( + 'The size of the 3D attn_mask is not correct.') + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim())) + # attn_mask's dim is 3 now. + + if key_padding_mask is not None and key_padding_mask.dtype == int: + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.contiguous().view(bs, tgt_len, self.num_heads, + head_dims).permute(0, 2, 1, 3).flatten(0, 1) + if k is not None: + k = k.contiguous().view(bs, src_len, self.num_heads, + head_dims).permute(0, 2, 1, + 3).flatten(0, 1) + if v is not None: + v = v.contiguous().view(bs, src_len, self.num_heads, + v_head_dims).permute(0, 2, 1, + 3).flatten(0, 1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bs + assert key_padding_mask.size(1) == src_len + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [ + bs * self.num_heads, tgt_len, src_len + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float('-inf')) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bs, self.num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.view( + bs * self.num_heads, tgt_len, src_len) + + attn_output_weights = F.softmax( + attn_output_weights - + attn_output_weights.max(dim=-1, keepdim=True)[0], + dim=-1) + attn_output_weights = self.attn_drop(attn_output_weights) + + attn_output = torch.bmm(attn_output_weights, v) + assert list( + attn_output.size()) == [bs * self.num_heads, tgt_len, v_head_dims] + attn_output = attn_output.view(bs, self.num_heads, tgt_len, + v_head_dims).permute(0, 2, 1, + 3).flatten(2) + attn_output = self.out_proj(attn_output) + + # average attention weights over heads + attn_output_weights = 
attn_output_weights.view(bs, self.num_heads, + tgt_len, src_len) + return attn_output, attn_output_weights.sum(dim=1) / self.num_heads + + def forward(self, + query: Tensor, + key: Tensor, + query_pos: Tensor = None, + ref_sine_embed: Tensor = None, + key_pos: Tensor = None, + attn_mask: Tensor = None, + key_padding_mask: Tensor = None, + is_first: bool = False) -> Tensor: + """Forward function for `ConditionalAttention`. + Args: + query (Tensor): The input query with shape [bs, num_queries, + embed_dims]. + key (Tensor): The key tensor with shape [bs, num_keys, + embed_dims]. + If None, the `query` will be used. Defaults to None. + query_pos (Tensor): The positional encoding for query in self + attention, with the same shape as `x`. If not None, it will + be added to `x` before forward function. + Defaults to None. + query_sine_embed (Tensor): The positional encoding for query in + cross attention, with the same shape as `x`. If not None, it + will be added to `x` before forward function. + Defaults to None. + key_pos (Tensor): The positional encoding for `key`, with the + same shape as `key`. Defaults to None. If not None, it will + be added to `key` before forward function. If None, and + `query_pos` has the same shape as `key`, then `query_pos` + will be used for `key_pos`. Defaults to None. + attn_mask (Tensor): ByteTensor mask with shape [num_queries, + num_keys]. Same in `nn.MultiheadAttention.forward`. + Defaults to None. + key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys]. + Defaults to None. + is_first (bool): A indicator to tell whether the current layer + is the first layer of the decoder. + Defaults to False. + Returns: + Tensor: forwarded results with shape + [bs, num_queries, embed_dims]. + """ + + if self.cross_attn: + q_content = self.qcontent_proj(query) + k_content = self.kcontent_proj(key) + v = self.v_proj(key) + + bs, nq, c = q_content.size() + _, hw, _ = k_content.size() + + k_pos = self.kpos_proj(key_pos) + if is_first or self.keep_query_pos: + q_pos = self.qpos_proj(query_pos) + q = q_content + q_pos + k = k_content + k_pos + else: + q = q_content + k = k_content + q = q.view(bs, nq, self.num_heads, c // self.num_heads) + query_sine_embed = self.qpos_sine_proj(ref_sine_embed) + query_sine_embed = query_sine_embed.view(bs, nq, self.num_heads, + c // self.num_heads) + q = torch.cat([q, query_sine_embed], dim=3).view(bs, nq, 2 * c) + k = k.view(bs, hw, self.num_heads, c // self.num_heads) + k_pos = k_pos.view(bs, hw, self.num_heads, c // self.num_heads) + k = torch.cat([k, k_pos], dim=3).view(bs, hw, 2 * c) + ca_output = self.forward_attn( + query=q, + key=k, + value=v, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask)[0] + query = query + self.proj_drop(ca_output) + else: + q_content = self.qcontent_proj(query) + q_pos = self.qpos_proj(query_pos) + k_content = self.kcontent_proj(query) + k_pos = self.kpos_proj(query_pos) + v = self.v_proj(query) + q = q_content if q_pos is None else q_content + q_pos + k = k_content if k_pos is None else k_content + k_pos + sa_output = self.forward_attn( + query=q, + key=k, + value=v, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask)[0] + query = query + self.proj_drop(sa_output) + + return query + + +class MLP(BaseModule): + """Very simple multi-layer perceptron (also called FFN) with relu. Mostly + used in DETR series detectors. + + Args: + input_dim (int): Feature dim of the input tensor. + hidden_dim (int): Feature dim of the hidden layer. + output_dim (int): Feature dim of the output tensor. 
+        num_layers (int): Number of FFN layers. Note that the last layer
+            contains only a Linear layer (no activation).
+    """
+
+    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
+                 num_layers: int) -> None:
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = ModuleList(
+            Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward function of MLP.
+
+        Args:
+            x (Tensor): The input feature, has shape
+                (num_queries, bs, input_dim).
+        Returns:
+            Tensor: The output feature, has shape
+                (num_queries, bs, output_dim).
+        """
+        for i, layer in enumerate(self.layers):
+            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
+
+
+@MODELS.register_module()
+class DynamicConv(BaseModule):
+    """Implements Dynamic Convolution.
+
+    This module generates the parameters for each sample and uses bmm to
+    implement 1x1 convolution. Code is modified from the
+    `official github repo `_ .
+
+    Args:
+        in_channels (int): The input feature channel.
+            Defaults to 256.
+        feat_channels (int): The inner feature channel.
+            Defaults to 64.
+        out_channels (int, optional): The output feature channel.
+            When not specified, it will be set to `in_channels`
+            by default.
+        input_feat_shape (int): The shape of the input feature.
+            Defaults to 7.
+        with_proj (bool): Whether to project the two-dimensional feature to
+            a one-dimensional feature. Default to True.
+        act_cfg (dict): The activation config for DynamicConv.
+        norm_cfg (dict): Config dict for normalization layer. Default
+            layer normalization.
+        init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self,
+                 in_channels: int = 256,
+                 feat_channels: int = 64,
+                 out_channels: Optional[int] = None,
+                 input_feat_shape: int = 7,
+                 with_proj: bool = True,
+                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
+                 norm_cfg: OptConfigType = dict(type='LN'),
+                 init_cfg: OptConfigType = None) -> None:
+        super(DynamicConv, self).__init__(init_cfg)
+        self.in_channels = in_channels
+        self.feat_channels = feat_channels
+        self.out_channels_raw = out_channels
+        self.input_feat_shape = input_feat_shape
+        self.with_proj = with_proj
+        self.act_cfg = act_cfg
+        self.norm_cfg = norm_cfg
+        self.out_channels = out_channels if out_channels else in_channels
+
+        self.num_params_in = self.in_channels * self.feat_channels
+        self.num_params_out = self.out_channels * self.feat_channels
+        self.dynamic_layer = nn.Linear(
+            self.in_channels, self.num_params_in + self.num_params_out)
+
+        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
+        self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]
+
+        self.activation = build_activation_layer(act_cfg)
+
+        num_output = self.out_channels * input_feat_shape**2
+        if self.with_proj:
+            self.fc_layer = nn.Linear(num_output, self.out_channels)
+            self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]
+
+    def forward(self, param_feature: Tensor, input_feature: Tensor) -> Tensor:
+        """Forward function for `DynamicConv`.
+
+        Args:
+            param_feature (Tensor): The feature that can be used
+                to generate the parameters, has shape
+                (num_all_proposals, in_channels).
+            input_feature (Tensor): Feature that
+                interacts with the parameters, has shape
+                (num_all_proposals, in_channels, H, W).
+
+        Returns:
+            Tensor: The output feature has shape
+            (num_all_proposals, out_channels).
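+
+        Example:
+            Illustrative shapes with the default configuration (the values
+            are hypothetical)::
+
+                >>> layer = DynamicConv(in_channels=256, feat_channels=64,
+                ...                     input_feat_shape=7)
+                >>> param_feat = torch.randn(10, 256)
+                >>> input_feat = torch.randn(10, 256, 7, 7)
+                >>> layer(param_feat, input_feat).shape
+                torch.Size([10, 256])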
+ """ + input_feature = input_feature.flatten(2).permute(2, 0, 1) + + input_feature = input_feature.permute(1, 0, 2) + parameters = self.dynamic_layer(param_feature) + + param_in = parameters[:, :self.num_params_in].view( + -1, self.in_channels, self.feat_channels) + param_out = parameters[:, -self.num_params_out:].view( + -1, self.feat_channels, self.out_channels) + + # input_feature has shape (num_all_proposals, H*W, in_channels) + # param_in has shape (num_all_proposals, in_channels, feat_channels) + # feature has shape (num_all_proposals, H*W, feat_channels) + features = torch.bmm(input_feature, param_in) + features = self.norm_in(features) + features = self.activation(features) + + # param_out has shape (batch_size, feat_channels, out_channels) + features = torch.bmm(features, param_out) + features = self.norm_out(features) + features = self.activation(features) + + if self.with_proj: + features = features.flatten(1) + features = self.fc_layer(features) + features = self.fc_norm(features) + features = self.activation(features) + + return features + + +def get_text_sine_pos_embed( + pos_tensor: torch.Tensor, + num_pos_feats: int = 128, + temperature: int = 10000, + exchange_xy: bool = True, +): + """generate sine position embedding from a position tensor + Args: + pos_tensor (torch.Tensor): shape: [..., n]. + num_pos_feats (int): projected shape for each float in the tensor. + temperature (int): temperature in the sine/cosine function. + exchange_xy (bool, optional): exchange pos x and pos y. For example, + input tensor is [x,y], the results will be [pos(y), pos(x)]. + Defaults to True. + Returns: + pos_embed (torch.Tensor): shape: [..., n*num_pos_feats]. + """ + scale = 2 * math.pi + dim_t = torch.arange( + num_pos_feats, dtype=torch.float32, device=pos_tensor.device) + dim_t = temperature**(2 * torch.div(dim_t, 2, rounding_mode='floor') / + num_pos_feats) + + def sine_func(x: torch.Tensor): + sin_x = x * scale / dim_t + sin_x = torch.stack((sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), + dim=3).flatten(2) + return sin_x + + pos_res = [ + sine_func(x) + for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1) + ] + if exchange_xy: + pos_res[0], pos_res[1] = pos_res[1], pos_res[0] + pos_res = torch.cat(pos_res, dim=-1) + return pos_res diff --git a/mmdet/models/losses/__init__.py b/mmdet/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c57a3a96879c6bd5eb61c300d316e2b4579b287 --- /dev/null +++ b/mmdet/models/losses/__init__.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .accuracy import Accuracy, accuracy
+from .ae_loss import AssociativeEmbeddingLoss
+from .balanced_l1_loss import BalancedL1Loss, balanced_l1_loss
+from .cross_entropy_loss import (CrossEntropyCustomLoss, CrossEntropyLoss,
+                                 binary_cross_entropy, cross_entropy,
+                                 mask_cross_entropy)
+from .ddq_detr_aux_loss import DDQAuxLoss
+from .dice_loss import DiceLoss
+from .eqlv2_loss import EQLV2Loss
+from .focal_loss import FocalCustomLoss, FocalLoss, sigmoid_focal_loss
+from .gaussian_focal_loss import GaussianFocalLoss
+from .gfocal_loss import DistributionFocalLoss, QualityFocalLoss
+from .ghm_loss import GHMC, GHMR
+from .iou_loss import (BoundedIoULoss, CIoULoss, DIoULoss, EIoULoss, GIoULoss,
+                       IoULoss, SIoULoss, bounded_iou_loss, iou_loss)
+from .kd_loss import KnowledgeDistillationKLDivLoss
+from .l2_loss import L2Loss
+from .margin_loss import MarginL2Loss
+from .mse_loss import MSELoss, mse_loss
+from .multipos_cross_entropy_loss import MultiPosCrossEntropyLoss
+from .pisa_loss import carl_loss, isr_p
+from .seesaw_loss import SeesawLoss
+from .smooth_l1_loss import L1Loss, SmoothL1Loss, l1_loss, smooth_l1_loss
+from .triplet_loss import TripletLoss
+from .utils import reduce_loss, weight_reduce_loss, weighted_loss
+from .varifocal_loss import VarifocalLoss
+
+__all__ = [
+    'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
+    'mask_cross_entropy', 'CrossEntropyLoss', 'sigmoid_focal_loss',
+    'FocalLoss', 'smooth_l1_loss', 'SmoothL1Loss', 'balanced_l1_loss',
+    'BalancedL1Loss', 'mse_loss', 'MSELoss', 'iou_loss', 'bounded_iou_loss',
+    'IoULoss', 'BoundedIoULoss', 'GIoULoss', 'DIoULoss', 'CIoULoss',
+    'EIoULoss', 'SIoULoss', 'GHMC', 'GHMR', 'reduce_loss',
+    'weight_reduce_loss', 'weighted_loss', 'L1Loss', 'l1_loss', 'isr_p',
+    'carl_loss', 'AssociativeEmbeddingLoss', 'GaussianFocalLoss',
+    'QualityFocalLoss', 'DistributionFocalLoss', 'VarifocalLoss',
+    'KnowledgeDistillationKLDivLoss', 'SeesawLoss', 'DiceLoss', 'EQLV2Loss',
+    'MarginL2Loss', 'MultiPosCrossEntropyLoss', 'L2Loss', 'TripletLoss',
+    'DDQAuxLoss', 'CrossEntropyCustomLoss', 'FocalCustomLoss'
+]
diff --git a/mmdet/models/losses/accuracy.py b/mmdet/models/losses/accuracy.py
new file mode 100644
index 0000000000000000000000000000000000000000..d68484e13965ced3bd6b104071d22657a9b3fde6
--- /dev/null
+++ b/mmdet/models/losses/accuracy.py
@@ -0,0 +1,77 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+
+def accuracy(pred, target, topk=1, thresh=None):
+    """Calculate accuracy according to the prediction and target.
+
+    Args:
+        pred (torch.Tensor): The model prediction, shape (N, num_class)
+        target (torch.Tensor): The target of each prediction, shape (N, )
+        topk (int | tuple[int], optional): If the predictions in ``topk``
+            match the target, the predictions will be regarded as
+            correct ones. Defaults to 1.
+        thresh (float, optional): If not None, predictions with scores under
+            this threshold are considered incorrect. Default to None.
+
+    Returns:
+        float | tuple[float]: If the input ``topk`` is a single integer,
+        the function will return a single float as accuracy. If
+        ``topk`` is a tuple containing multiple integers, the
+        function will return a tuple containing accuracies of
+        each ``topk`` number.
+    """
+    assert isinstance(topk, (int, tuple))
+    if isinstance(topk, int):
+        topk = (topk, )
+        return_single = True
+    else:
+        return_single = False
+
+    maxk = max(topk)
+    if pred.size(0) == 0:
+        accu = [pred.new_tensor(0.) for i in range(len(topk))]
+        return accu[0] if return_single else accu
+    assert pred.ndim == 2 and target.ndim == 1
+    assert pred.size(0) == target.size(0)
+    assert maxk <= pred.size(1), \
+        f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
+    pred_value, pred_label = pred.topk(maxk, dim=1)
+    pred_label = pred_label.t()  # transpose to shape (maxk, N)
+    correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
+    if thresh is not None:
+        # Only prediction values larger than thresh are counted as correct
+        correct = correct & (pred_value > thresh).t()
+    res = []
+    for k in topk:
+        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+        res.append(correct_k.mul_(100.0 / pred.size(0)))
+    return res[0] if return_single else res
+
+
+class Accuracy(nn.Module):
+
+    def __init__(self, topk=(1, ), thresh=None):
+        """Module to calculate the accuracy.
+
+        Args:
+            topk (tuple, optional): The criterion used to calculate the
+                accuracy. Defaults to (1,).
+            thresh (float, optional): If not None, predictions with scores
+                under this threshold are considered incorrect.
+                Default to None.
+        """
+        super().__init__()
+        self.topk = topk
+        self.thresh = thresh
+
+    def forward(self, pred, target):
+        """Forward function to calculate accuracy.
+
+        Args:
+            pred (torch.Tensor): Prediction of models.
+            target (torch.Tensor): Target for each prediction.
+
+        Returns:
+            tuple[float]: The accuracies under different topk criterions.
+        """
+        return accuracy(pred, target, self.topk, self.thresh)
diff --git a/mmdet/models/losses/ae_loss.py b/mmdet/models/losses/ae_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..2aa7d696be4b937a2d45545a8309aaa936fe5f22
--- /dev/null
+++ b/mmdet/models/losses/ae_loss.py
@@ -0,0 +1,101 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmdet.registry import MODELS
+
+
+def ae_loss_per_image(tl_preds, br_preds, match):
+    """Associative Embedding Loss for one image.
+
+    The Associative Embedding Loss consists of two parts: pull loss and push
+    loss. The pull loss pulls embedding vectors from the same object closer
+    to each other. The push loss pushes embedding vectors from different
+    objects apart, so that the gap between them is large enough.
+
+    During computation, there are usually 3 cases:
+    - no object in image: both pull loss and push loss will be 0.
+    - one object in image: push loss will be 0 and pull loss is computed
+      from the two corners of the only object.
+    - more than one object in image: pull loss is computed from the corner
+      pairs of each object, push loss is computed for each object against
+      all other objects. We use a confusion matrix with 0 on the diagonal
+      to compute the push loss.
+
+    Args:
+        tl_preds (tensor): Embedding feature map of the top-left corner.
+        br_preds (tensor): Embedding feature map of the bottom-right corner.
+        match (list): Downsampled coordinate pair of each ground truth box.
+    """
+
+    tl_list, br_list, me_list = [], [], []
+    if len(match) == 0:  # no object in image
+        pull_loss = tl_preds.sum() * 0.
+        push_loss = tl_preds.sum() * 0.
+    else:
+        for m in match:
+            [tl_y, tl_x], [br_y, br_x] = m
+            tl_e = tl_preds[:, tl_y, tl_x].view(-1, 1)
+            br_e = br_preds[:, br_y, br_x].view(-1, 1)
+            tl_list.append(tl_e)
+            br_list.append(br_e)
+            me_list.append((tl_e + br_e) / 2.0)
+
+        tl_list = torch.cat(tl_list)
+        br_list = torch.cat(br_list)
+        me_list = torch.cat(me_list)
+
+        assert tl_list.size() == br_list.size()
+
+        # N is object number in image, M is dimension of embedding vector
+        N, M = tl_list.size()
+
+        pull_loss = (tl_list - me_list).pow(2) + (br_list - me_list).pow(2)
+        pull_loss = pull_loss.sum() / N
+
+        margin = 1  # exp setting of CornerNet, details in section 3.3 of paper
+
+        # confusion matrix of push loss
+        conf_mat = me_list.expand((N, N, M)).permute(1, 0, 2) - me_list
+        conf_weight = 1 - torch.eye(N).type_as(me_list)
+        conf_mat = conf_weight * (margin - conf_mat.sum(-1).abs())
+
+        if N > 1:  # more than one object in current image
+            push_loss = F.relu(conf_mat).sum() / (N * (N - 1))
+        else:
+            push_loss = tl_preds.sum() * 0.
+
+    return pull_loss, push_loss
+
+
+@MODELS.register_module()
+class AssociativeEmbeddingLoss(nn.Module):
+    """Associative Embedding Loss.
+
+    More details can be found in
+    `Associative Embedding `_ and
+    `CornerNet `_ .
+    Code is modified from `kp_utils.py `_  # noqa: E501
+
+    Args:
+        pull_weight (float): Loss weight for corners from the same object.
+        push_weight (float): Loss weight for corners from different objects.
+    """
+
+    def __init__(self, pull_weight=0.25, push_weight=0.25):
+        super(AssociativeEmbeddingLoss, self).__init__()
+        self.pull_weight = pull_weight
+        self.push_weight = push_weight
+
+    def forward(self, pred, target, match):
+        """Forward function."""
+        batch = pred.size(0)
+        pull_all, push_all = 0.0, 0.0
+        for i in range(batch):
+            pull, push = ae_loss_per_image(pred[i], target[i], match[i])
+
+            pull_all += self.pull_weight * pull
+            push_all += self.push_weight * push
+
+        return pull_all, push_all
diff --git a/mmdet/models/losses/balanced_l1_loss.py b/mmdet/models/losses/balanced_l1_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..25adaab2239e871476d9d4e3cbb1a238c3043041
--- /dev/null
+++ b/mmdet/models/losses/balanced_l1_loss.py
@@ -0,0 +1,122 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.nn as nn
+
+from mmdet.registry import MODELS
+from .utils import weighted_loss
+
+
+@weighted_loss
+def balanced_l1_loss(pred,
+                     target,
+                     beta=1.0,
+                     alpha=0.5,
+                     gamma=1.5,
+                     reduction='mean'):
+    """Calculate balanced L1 loss.
+
+    Please see the `Libra R-CNN `_
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, 4).
+        target (torch.Tensor): The learning target of the prediction with
+            shape (N, 4).
+        beta (float): The loss is a piecewise function of prediction and
+            target and ``beta`` serves as a threshold for the difference
+            between the prediction and target. Defaults to 1.0.
+        alpha (float): The denominator ``alpha`` in the balanced L1 loss.
+            Defaults to 0.5.
+        gamma (float): The ``gamma`` in the balanced L1 loss.
+            Defaults to 1.5.
+        reduction (str, optional): The method that reduces the loss to a
+            scalar. Options are "none", "mean" and "sum".
+
+    Returns:
+        torch.Tensor: The calculated loss
+    """
+    assert beta > 0
+    if target.numel() == 0:
+        return pred.sum() * 0
+
+    assert pred.size() == target.size()
+
+    diff = torch.abs(pred - target)
+    b = np.e**(gamma / alpha) - 1
+    loss = torch.where(
+        diff < beta, alpha / b *
+        (b * diff + 1) * torch.log(b * diff / beta + 1) - alpha * diff,
+        gamma * diff + gamma / b - alpha * beta)
+
+    return loss
+
+
+@MODELS.register_module()
+class BalancedL1Loss(nn.Module):
+    """Balanced L1 Loss.
+
+    arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)
+
+    Args:
+        alpha (float): The denominator ``alpha`` in the balanced L1 loss.
+            Defaults to 0.5.
+        gamma (float): The ``gamma`` in the balanced L1 loss. Defaults to 1.5.
+        beta (float, optional): The loss is a piecewise function of prediction
+            and target. ``beta`` serves as a threshold for the difference
+            between the prediction and target. Defaults to 1.0.
+        reduction (str, optional): The method that reduces the loss to a
+            scalar. Options are "none", "mean" and "sum".
+        loss_weight (float, optional): The weight of the loss. Defaults to 1.0
+    """
+
+    def __init__(self,
+                 alpha=0.5,
+                 gamma=1.5,
+                 beta=1.0,
+                 reduction='mean',
+                 loss_weight=1.0):
+        super(BalancedL1Loss, self).__init__()
+        self.alpha = alpha
+        self.gamma = gamma
+        self.beta = beta
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        """Forward function of loss.
+
+        Args:
+            pred (torch.Tensor): The prediction with shape (N, 4).
+            target (torch.Tensor): The learning target of the prediction with
+                shape (N, 4).
+            weight (torch.Tensor, optional): Sample-wise loss weight with
+                shape (N, ).
+            avg_factor (int, optional): Average factor that is used to average
+                the loss. Defaults to None.
+            reduction_override (str, optional): The reduction method used to
+                override the original reduction method of the loss.
+                Options are "none", "mean" and "sum".
+
+        Returns:
+            torch.Tensor: The calculated loss
+        """
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+        loss_bbox = self.loss_weight * balanced_l1_loss(
+            pred,
+            target,
+            weight,
+            alpha=self.alpha,
+            gamma=self.gamma,
+            beta=self.beta,
+            reduction=reduction,
+            avg_factor=avg_factor,
+            **kwargs)
+        return loss_bbox
diff --git a/mmdet/models/losses/cross_entropy_loss.py b/mmdet/models/losses/cross_entropy_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..49fac7743ceddd2454f44b76c63d514de43b5aef
--- /dev/null
+++ b/mmdet/models/losses/cross_entropy_loss.py
@@ -0,0 +1,401 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmdet.registry import MODELS
+from .accuracy import accuracy
+from .utils import weight_reduce_loss
+
+
+def cross_entropy(pred,
+                  label,
+                  weight=None,
+                  reduction='mean',
+                  avg_factor=None,
+                  class_weight=None,
+                  ignore_index=-100,
+                  avg_non_ignore=False):
+    """Calculate the CrossEntropy loss.
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, C), C is the number
+            of classes.
+        label (torch.Tensor): The learning label of the prediction.
+        weight (torch.Tensor, optional): Sample-wise loss weight.
+        reduction (str, optional): The method used to reduce the loss.
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+        class_weight (list[float], optional): The weight for each class.
+        ignore_index (int | None): The label index to be ignored.
+            If None, it will be set to the default value. Default: -100.
+        avg_non_ignore (bool): The flag that decides whether the loss is
+            only averaged over non-ignored targets. Default: False.
+
+    Returns:
+        torch.Tensor: The calculated loss
+    """
+    # The default value of ignore_index is the same as F.cross_entropy
+    ignore_index = -100 if ignore_index is None else ignore_index
+    # element-wise losses
+    loss = F.cross_entropy(
+        pred,
+        label,
+        weight=class_weight,
+        reduction='none',
+        ignore_index=ignore_index)
+
+    # average loss over non-ignored elements
+    # pytorch's official cross_entropy averages loss over non-ignored elements
+    # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660  # noqa
+    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
+        avg_factor = label.numel() - (label == ignore_index).sum().item()
+
+    # apply weights and do the reduction
+    if weight is not None:
+        weight = weight.float()
+    loss = weight_reduce_loss(
+        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
+
+    return loss
+
+
+def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
+    """Expand onehot labels to match the size of prediction."""
+    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
+    valid_mask = (labels >= 0) & (labels != ignore_index)
+    inds = torch.nonzero(
+        valid_mask & (labels < label_channels), as_tuple=False)
+
+    if inds.numel() > 0:
+        bin_labels[inds, labels[inds]] = 1
+
+    valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
+                                               label_channels).float()
+    if label_weights is None:
+        bin_label_weights = valid_mask
+    else:
+        bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
+        bin_label_weights *= valid_mask
+
+    return bin_labels, bin_label_weights, valid_mask
+
+
+def binary_cross_entropy(pred,
+                         label,
+                         weight=None,
+                         reduction='mean',
+                         avg_factor=None,
+                         class_weight=None,
+                         ignore_index=-100,
+                         avg_non_ignore=False):
+    """Calculate the binary CrossEntropy loss.
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, 1) or (N, ).
+            When the shape of pred is (N, 1), label will be expanded to
+            one-hot format, and when the shape of pred is (N, ), label
+            will not be expanded to one-hot format.
+        label (torch.Tensor): The learning label of the prediction,
+            with shape (N, ).
+        weight (torch.Tensor, optional): Sample-wise loss weight.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are "none", "mean" and "sum".
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+        class_weight (list[float], optional): The weight for each class.
+        ignore_index (int | None): The label index to be ignored.
+            If None, it will be set to the default value. Default: -100.
+        avg_non_ignore (bool): The flag that decides whether the loss is
+            only averaged over non-ignored targets. Default: False.
+
+    Returns:
+        torch.Tensor: The calculated loss.
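+
+    Example:
+        An illustrative call with multi-class logits (hypothetical
+        values)::
+
+            >>> pred = torch.randn(4, 3)  # (N, C) logits
+            >>> label = torch.tensor([0, 1, 2, 1])
+            >>> binary_cross_entropy(pred, label).dim()
+            0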
+ """ + # The default value of ignore_index is the same as F.cross_entropy + ignore_index = -100 if ignore_index is None else ignore_index + + if pred.dim() != label.dim(): + label, weight, valid_mask = _expand_onehot_labels( + label, weight, pred.size(-1), ignore_index) + else: + # should mask out the ignored elements + valid_mask = ((label >= 0) & (label != ignore_index)).float() + if weight is not None: + # The inplace writing method will have a mismatched broadcast + # shape error if the weight and valid_mask dimensions + # are inconsistent such as (B,N,1) and (B,N,C). + weight = weight * valid_mask + else: + weight = valid_mask + + # average loss over non-ignored elements + if (avg_factor is None) and avg_non_ignore and reduction == 'mean': + avg_factor = valid_mask.sum().item() + + # weighted element-wise losses + weight = weight.float() + loss = F.binary_cross_entropy_with_logits( + pred, label.float(), pos_weight=class_weight, reduction='none') + # do the reduction for the weighted loss + loss = weight_reduce_loss( + loss, weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def mask_cross_entropy(pred, + target, + label, + reduction='mean', + avg_factor=None, + class_weight=None, + ignore_index=None, + **kwargs): + """Calculate the CrossEntropy loss for masks. + + Args: + pred (torch.Tensor): The prediction with shape (N, C, *), C is the + number of classes. The trailing * indicates arbitrary shape. + target (torch.Tensor): The learning label of the prediction. + label (torch.Tensor): ``label`` indicates the class label of the mask + corresponding object. This will be used to select the mask in the + of the class which the object belongs to when the mask prediction + if not class-agnostic. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (None): Placeholder, to be consistent with other loss. + Default: None. + + Returns: + torch.Tensor: The calculated loss + + Example: + >>> N, C = 3, 11 + >>> H, W = 2, 2 + >>> pred = torch.randn(N, C, H, W) * 1000 + >>> target = torch.rand(N, H, W) + >>> label = torch.randint(0, C, size=(N,)) + >>> reduction = 'mean' + >>> avg_factor = None + >>> class_weights = None + >>> loss = mask_cross_entropy(pred, target, label, reduction, + >>> avg_factor, class_weights) + >>> assert loss.shape == (1,) + """ + assert ignore_index is None, 'BCE loss does not support ignore_index' + # TODO: handle these two reserved arguments + assert reduction == 'mean' and avg_factor is None + num_rois = pred.size()[0] + inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) + pred_slice = pred[inds, label].squeeze(1) + return F.binary_cross_entropy_with_logits( + pred_slice, target, weight=class_weight, reduction='mean')[None] + + +@MODELS.register_module() +class CrossEntropyLoss(nn.Module): + + def __init__(self, + use_sigmoid=False, + use_mask=False, + reduction='mean', + class_weight=None, + ignore_index=None, + loss_weight=1.0, + avg_non_ignore=False): + """CrossEntropyLoss. + + Args: + use_sigmoid (bool, optional): Whether the prediction uses sigmoid + of softmax. Defaults to False. + use_mask (bool, optional): Whether to use mask cross entropy loss. + Defaults to False. + reduction (str, optional): . Defaults to 'mean'. + Options are "none", "mean" and "sum". 
+            class_weight (list[float], optional): Weight of each class.
+                Defaults to None.
+            ignore_index (int | None): The label index to be ignored.
+                Defaults to None.
+            loss_weight (float, optional): Weight of the loss.
+                Defaults to 1.0.
+            avg_non_ignore (bool): The flag that decides whether the loss is
+                only averaged over non-ignored targets. Default: False.
+        """
+        super(CrossEntropyLoss, self).__init__()
+        assert (use_sigmoid is False) or (use_mask is False)
+        self.use_sigmoid = use_sigmoid
+        self.use_mask = use_mask
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+        self.class_weight = class_weight
+        self.ignore_index = ignore_index
+        self.avg_non_ignore = avg_non_ignore
+        if ((ignore_index is not None) and not self.avg_non_ignore
+                and self.reduction == 'mean'):
+            warnings.warn(
+                'Default ``avg_non_ignore`` is False, if you would like to '
+                'ignore a certain label and average the loss over '
+                'non-ignored labels, which is the same as PyTorch official '
+                'cross_entropy, set ``avg_non_ignore=True``.')
+
+        if self.use_sigmoid:
+            self.cls_criterion = binary_cross_entropy
+        elif self.use_mask:
+            self.cls_criterion = mask_cross_entropy
+        else:
+            self.cls_criterion = cross_entropy
+
+    def extra_repr(self):
+        """Extra repr."""
+        s = f'avg_non_ignore={self.avg_non_ignore}'
+        return s
+
+    def forward(self,
+                cls_score,
+                label,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                ignore_index=None,
+                **kwargs):
+        """Forward function.
+
+        Args:
+            cls_score (torch.Tensor): The prediction.
+            label (torch.Tensor): The learning label of the prediction.
+            weight (torch.Tensor, optional): Sample-wise loss weight.
+            avg_factor (int, optional): Average factor that is used to
+                average the loss. Defaults to None.
+            reduction_override (str, optional): The method used to reduce the
+                loss. Options are "none", "mean" and "sum".
+            ignore_index (int | None): The label index to be ignored.
+                If not None, it will override the default value.
+                Default: None.
+        Returns:
+            torch.Tensor: The calculated loss.
+        """
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+        if ignore_index is None:
+            ignore_index = self.ignore_index
+
+        if self.class_weight is not None:
+            class_weight = cls_score.new_tensor(
+                self.class_weight, device=cls_score.device)
+        else:
+            class_weight = None
+        loss_cls = self.loss_weight * self.cls_criterion(
+            cls_score,
+            label,
+            weight,
+            class_weight=class_weight,
+            reduction=reduction,
+            avg_factor=avg_factor,
+            ignore_index=ignore_index,
+            avg_non_ignore=self.avg_non_ignore,
+            **kwargs)
+        return loss_cls
+
+
+@MODELS.register_module()
+class CrossEntropyCustomLoss(CrossEntropyLoss):
+
+    def __init__(self,
+                 use_sigmoid=False,
+                 use_mask=False,
+                 reduction='mean',
+                 num_classes=-1,
+                 class_weight=None,
+                 ignore_index=None,
+                 loss_weight=1.0,
+                 avg_non_ignore=False):
+        """CrossEntropyCustomLoss.
+
+        Args:
+            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+                instead of softmax. Defaults to False.
+            use_mask (bool, optional): Whether to use mask cross entropy loss.
+                Defaults to False.
+            reduction (str, optional): The method used to reduce the loss.
+                Options are "none", "mean" and "sum". Defaults to 'mean'.
+            num_classes (int): Number of classes to classify.
+            class_weight (list[float], optional): Weight of each class.
+                Defaults to None.
+            ignore_index (int | None): The label index to be ignored.
+                Defaults to None.
+            loss_weight (float, optional): Weight of the loss.
+                Defaults to 1.0.
+            avg_non_ignore (bool): The flag that decides whether the loss is
+                only averaged over non-ignored targets. Default: False.
+        """
+        super(CrossEntropyCustomLoss, self).__init__()
+        assert (use_sigmoid is False) or (use_mask is False)
+        self.use_sigmoid = use_sigmoid
+        self.use_mask = use_mask
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+        self.class_weight = class_weight
+        self.ignore_index = ignore_index
+        self.avg_non_ignore = avg_non_ignore
+        if ((ignore_index is not None) and not self.avg_non_ignore
+                and self.reduction == 'mean'):
+            warnings.warn(
+                'Default ``avg_non_ignore`` is False, if you would like to '
+                'ignore a certain label and average the loss over '
+                'non-ignored labels, which is the same as PyTorch official '
+                'cross_entropy, set ``avg_non_ignore=True``.')
+
+        if self.use_sigmoid:
+            self.cls_criterion = binary_cross_entropy
+        elif self.use_mask:
+            self.cls_criterion = mask_cross_entropy
+        else:
+            self.cls_criterion = cross_entropy
+
+        self.num_classes = num_classes
+
+        assert self.num_classes != -1
+
+        # custom output channels of the classifier
+        self.custom_cls_channels = True
+        # custom activation of cls_score
+        self.custom_activation = True
+        # custom accuracy of the classifier
+        self.custom_accuracy = True
+
+    def get_cls_channels(self, num_classes):
+        assert num_classes == self.num_classes
+        if not self.use_sigmoid:
+            return num_classes + 1
+        else:
+            return num_classes
+
+    def get_activation(self, cls_score):
+
+        fine_cls_score = cls_score[:, :self.num_classes]
+
+        if not self.use_sigmoid:
+            bg_score = cls_score[:, [-1]]
+            new_score = torch.cat([fine_cls_score, bg_score], dim=-1)
+            scores = F.softmax(new_score, dim=-1)
+        else:
+            score_classes = fine_cls_score.sigmoid()
+            score_neg = 1 - score_classes.sum(dim=1, keepdim=True)
+            score_neg = score_neg.clamp(min=0, max=1)
+            scores = torch.cat([score_classes, score_neg], dim=1)
+
+        return scores
+
+    def get_accuracy(self, cls_score, labels):
+
+        fine_cls_score = cls_score[:, :self.num_classes]
+
+        pos_inds = labels < self.num_classes
+        acc_classes = accuracy(fine_cls_score[pos_inds], labels[pos_inds])
+        acc = dict()
+        acc['acc_classes'] = acc_classes
+        return acc
diff --git a/mmdet/models/losses/ddq_detr_aux_loss.py b/mmdet/models/losses/ddq_detr_aux_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..41f1c7166e6c7d05c5414cd04ad3eb3cd467f1b6
--- /dev/null
+++ b/mmdet/models/losses/ddq_detr_aux_loss.py
@@ -0,0 +1,303 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmengine.structures import BaseDataElement
+
+from mmdet.models.utils import multi_apply
+from mmdet.registry import MODELS, TASK_UTILS
+from mmdet.utils import reduce_mean
+
+
+class DDQAuxLoss(nn.Module):
+    """DDQ auxiliary branches loss for dense queries.
+
+    Args:
+        loss_cls (dict):
+            Configuration of the classification loss function.
+        loss_bbox (dict):
+            Configuration of the bbox regression loss function.
+        train_cfg (dict):
+            Configuration of the gt targets assigner for each predicted bbox.
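+
+    Example:
+        Built from its default configuration (illustrative; assumes the
+        referenced losses and assigner are registered)::
+
+            >>> aux_loss = DDQAuxLoss()
+            >>> # losses = aux_loss.loss(cls_scores, bbox_preds,
+            >>> #                        gt_bboxes, gt_labels, img_metas)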
+ """ + + def __init__( + self, + loss_cls=dict( + type='QualityFocalLoss', + use_sigmoid=True, + activated=True, # use probability instead of logit as input + beta=2.0, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0), + train_cfg=dict( + assigner=dict(type='TopkHungarianAssigner', topk=8), + alpha=1, + beta=6), + ): + super(DDQAuxLoss, self).__init__() + self.train_cfg = train_cfg + self.loss_cls = MODELS.build(loss_cls) + self.loss_bbox = MODELS.build(loss_bbox) + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + + sampler_cfg = dict(type='PseudoSampler') + self.sampler = TASK_UTILS.build(sampler_cfg) + + def loss_single(self, cls_score, bbox_pred, labels, label_weights, + bbox_targets, alignment_metrics): + """Calculate auxiliary branches loss for dense queries for one image. + + Args: + cls_score (Tensor): Predicted normalized classification + scores for one image, has shape (num_dense_queries, + cls_out_channels). + bbox_pred (Tensor): Predicted unnormalized bbox coordinates + for one image, has shape (num_dense_queries, 4) with the + last dimension arranged as (x1, y1, x2, y2). + labels (Tensor): Labels for one image. + label_weights (Tensor): Label weights for one image. + bbox_targets (Tensor): Bbox targets for one image. + alignment_metrics (Tensor): Normalized alignment metrics for one + image. + + Returns: + tuple: A tuple of loss components and loss weights. + """ + bbox_targets = bbox_targets.reshape(-1, 4) + labels = labels.reshape(-1) + alignment_metrics = alignment_metrics.reshape(-1) + label_weights = label_weights.reshape(-1) + targets = (labels, alignment_metrics) + cls_loss_func = self.loss_cls + + loss_cls = cls_loss_func( + cls_score, targets, label_weights, avg_factor=1.0) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = cls_score.size(-1) + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().squeeze(1) + + if len(pos_inds) > 0: + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + + pos_decode_bbox_pred = pos_bbox_pred + pos_decode_bbox_targets = pos_bbox_targets + + # regression loss + pos_bbox_weight = alignment_metrics[pos_inds] + + loss_bbox = self.loss_bbox( + pos_decode_bbox_pred, + pos_decode_bbox_targets, + weight=pos_bbox_weight, + avg_factor=1.0) + else: + loss_bbox = bbox_pred.sum() * 0 + pos_bbox_weight = bbox_targets.new_tensor(0.) + + return loss_cls, loss_bbox, alignment_metrics.sum( + ), pos_bbox_weight.sum() + + def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas, + **kwargs): + """Calculate auxiliary branches loss for dense queries. + + Args: + cls_scores (Tensor): Predicted normalized classification + scores, has shape (bs, num_dense_queries, + cls_out_channels). + bbox_preds (Tensor): Predicted unnormalized bbox coordinates, + has shape (bs, num_dense_queries, 4) with the last + dimension arranged as (x1, y1, x2, y2). + gt_bboxes (list[Tensor]): List of unnormalized ground truth + bboxes for each image, each has shape (num_gt, 4) with the + last dimension arranged as (x1, y1, x2, y2). + NOTE: num_gt is dynamic for each image. + gt_labels (list[Tensor]): List of ground truth classification + index for each image, each has shape (num_gt,). + NOTE: num_gt is dynamic for each image. + img_metas (list[dict]): Meta information for one image, + e.g., image size, scaling factor, etc. + + Returns: + dict: A dictionary of loss components. 
+ """ + flatten_cls_scores = cls_scores + flatten_bbox_preds = bbox_preds + + cls_reg_targets = self.get_targets( + flatten_cls_scores, + flatten_bbox_preds, + gt_bboxes, + img_metas, + gt_labels_list=gt_labels, + ) + (labels_list, label_weights_list, bbox_targets_list, + alignment_metrics_list) = cls_reg_targets + + losses_cls, losses_bbox, \ + cls_avg_factors, bbox_avg_factors = multi_apply( + self.loss_single, + flatten_cls_scores, + flatten_bbox_preds, + labels_list, + label_weights_list, + bbox_targets_list, + alignment_metrics_list, + ) + + cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item() + losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls)) + + bbox_avg_factor = reduce_mean( + sum(bbox_avg_factors)).clamp_(min=1).item() + losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox)) + return dict(aux_loss_cls=losses_cls, aux_loss_bbox=losses_bbox) + + def get_targets(self, + cls_scores, + bbox_preds, + gt_bboxes_list, + img_metas, + gt_labels_list=None, + **kwargs): + """Compute regression and classification targets for a batch images. + + Args: + cls_scores (Tensor): Predicted normalized classification + scores, has shape (bs, num_dense_queries, + cls_out_channels). + bbox_preds (Tensor): Predicted unnormalized bbox coordinates, + has shape (bs, num_dense_queries, 4) with the last + dimension arranged as (x1, y1, x2, y2). + gt_bboxes_list (List[Tensor]): List of unnormalized ground truth + bboxes for each image, each has shape (num_gt, 4) with the + last dimension arranged as (x1, y1, x2, y2). + NOTE: num_gt is dynamic for each image. + img_metas (list[dict]): Meta information for one image, + e.g., image size, scaling factor, etc. + gt_labels_list (list[Tensor]): List of ground truth classification + index for each image, each has shape (num_gt,). + NOTE: num_gt is dynamic for each image. + Default: None. + + Returns: + tuple: a tuple containing the following targets. + + - all_labels (list[Tensor]): Labels for all images. + - all_label_weights (list[Tensor]): Label weights for all images. + - all_bbox_targets (list[Tensor]): Bbox targets for all images. + - all_assign_metrics (list[Tensor]): Normalized alignment metrics + for all images. + """ + (all_labels, all_label_weights, all_bbox_targets, + all_assign_metrics) = multi_apply(self._get_target_single, cls_scores, + bbox_preds, gt_bboxes_list, + gt_labels_list, img_metas) + + return (all_labels, all_label_weights, all_bbox_targets, + all_assign_metrics) + + def _get_target_single(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_meta, **kwargs): + """Compute regression and classification targets for one image. + + Args: + cls_scores (Tensor): Predicted normalized classification + scores for one image, has shape (num_dense_queries, + cls_out_channels). + bbox_preds (Tensor): Predicted unnormalized bbox coordinates + for one image, has shape (num_dense_queries, 4) with the + last dimension arranged as (x1, y1, x2, y2). + gt_bboxes (Tensor): Unnormalized ground truth + bboxes for one image, has shape (num_gt, 4) with the + last dimension arranged as (x1, y1, x2, y2). + NOTE: num_gt is dynamic for each image. + gt_labels (Tensor): Ground truth classification + index for the image, has shape (num_gt,). + NOTE: num_gt is dynamic for each image. + img_meta (dict): Meta information for one image. + + Returns: + tuple[Tensor]: a tuple containing the following for one image. + + - labels (Tensor): Labels for one image. + - label_weights (Tensor): Label weights for one image. 
+            - bbox_targets (Tensor): Bbox targets for one image.
+            - norm_alignment_metrics (Tensor): Normalized alignment
+              metrics for one image.
+        """
+        if len(gt_labels) == 0:
+            num_valid_anchors = len(cls_scores)
+            bbox_targets = torch.zeros_like(bbox_preds)
+            labels = bbox_preds.new_full((num_valid_anchors, ),
+                                         cls_scores.size(-1),
+                                         dtype=torch.long)
+            label_weights = bbox_preds.new_zeros(
+                num_valid_anchors, dtype=torch.float)
+            norm_alignment_metrics = bbox_preds.new_zeros(
+                num_valid_anchors, dtype=torch.float)
+            return (labels, label_weights, bbox_targets,
+                    norm_alignment_metrics)
+
+        assign_result = self.assigner.assign(cls_scores, bbox_preds, gt_bboxes,
+                                             gt_labels, img_meta)
+        assign_ious = assign_result.max_overlaps
+        assign_metrics = assign_result.assign_metrics
+
+        pred_instances = BaseDataElement()
+        gt_instances = BaseDataElement()
+
+        pred_instances.bboxes = bbox_preds
+        gt_instances.bboxes = gt_bboxes
+
+        pred_instances.priors = cls_scores
+        gt_instances.labels = gt_labels
+
+        sampling_result = self.sampler.sample(assign_result, pred_instances,
+                                              gt_instances)
+
+        num_valid_anchors = len(cls_scores)
+        bbox_targets = torch.zeros_like(bbox_preds)
+        labels = bbox_preds.new_full((num_valid_anchors, ),
+                                     cls_scores.size(-1),
+                                     dtype=torch.long)
+        label_weights = bbox_preds.new_zeros(
+            num_valid_anchors, dtype=torch.float)
+        norm_alignment_metrics = bbox_preds.new_zeros(
+            num_valid_anchors, dtype=torch.float)
+
+        pos_inds = sampling_result.pos_inds
+        neg_inds = sampling_result.neg_inds
+        if len(pos_inds) > 0:
+            # point-based
+            pos_bbox_targets = sampling_result.pos_gt_bboxes
+            bbox_targets[pos_inds, :] = pos_bbox_targets
+
+            if gt_labels is None:
+                # Only dense_heads gives gt_labels as None
+                # Foreground is the first class since v2.5.0
+                labels[pos_inds] = 0
+            else:
+                labels[pos_inds] = gt_labels[
+                    sampling_result.pos_assigned_gt_inds]
+
+            label_weights[pos_inds] = 1.0
+
+        if len(neg_inds) > 0:
+            label_weights[neg_inds] = 1.0
+
+        class_assigned_gt_inds = torch.unique(
+            sampling_result.pos_assigned_gt_inds)
+        for gt_inds in class_assigned_gt_inds:
+            gt_class_inds = sampling_result.pos_assigned_gt_inds == gt_inds
+            pos_alignment_metrics = assign_metrics[gt_class_inds]
+            pos_ious = assign_ious[gt_class_inds]
+            pos_norm_alignment_metrics = pos_alignment_metrics / (
+                pos_alignment_metrics.max() + 10e-8) * pos_ious.max()
+            norm_alignment_metrics[
+                pos_inds[gt_class_inds]] = pos_norm_alignment_metrics
+
+        return (labels, label_weights, bbox_targets, norm_alignment_metrics)
diff --git a/mmdet/models/losses/dice_loss.py b/mmdet/models/losses/dice_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d5cac1e9710a6a72fe0401db22b8b72cfe058f9
--- /dev/null
+++ b/mmdet/models/losses/dice_loss.py
@@ -0,0 +1,146 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from mmdet.registry import MODELS
+from .utils import weight_reduce_loss
+
+
+def dice_loss(pred,
+              target,
+              weight=None,
+              eps=1e-3,
+              reduction='mean',
+              naive_dice=False,
+              avg_factor=None):
+    """Calculate the dice loss. Two forms of dice loss are supported:
+
+    - the one proposed in `V-Net: Fully Convolutional Neural
+      Networks for Volumetric Medical Image Segmentation
+      `_.
+    - the dice loss in which the power of the number in the
+      denominator is the first power instead of the second
+      power.
+
+    Args:
+        pred (torch.Tensor): The prediction, has a shape (n, *)
+        target (torch.Tensor): The learning label of the prediction,
+            shape (n, *), same shape as pred.
+        weight (torch.Tensor, optional): The weight of loss for each
+            prediction, has a shape (n,). Defaults to None.
+        eps (float): Avoid dividing by zero. Default: 1e-3.
+        reduction (str, optional): The method used to reduce the loss into
+            a scalar. Defaults to 'mean'.
+            Options are "none", "mean" and "sum".
+        naive_dice (bool, optional): If False, use the dice
+            loss defined in the V-Net paper; otherwise, use the
+            naive dice loss in which the power of the number in the
+            denominator is the first power instead of the second
+            power. Defaults to False.
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+    """
+
+    input = pred.flatten(1)
+    target = target.flatten(1).float()
+
+    a = torch.sum(input * target, 1)
+    if naive_dice:
+        b = torch.sum(input, 1)
+        c = torch.sum(target, 1)
+        d = (2 * a + eps) / (b + c + eps)
+    else:
+        b = torch.sum(input * input, 1) + eps
+        c = torch.sum(target * target, 1) + eps
+        d = (2 * a) / (b + c)
+
+    loss = 1 - d
+    if weight is not None:
+        assert weight.ndim == loss.ndim
+        assert len(weight) == len(pred)
+    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+    return loss
+
+
+@MODELS.register_module()
+class DiceLoss(nn.Module):
+
+    def __init__(self,
+                 use_sigmoid=True,
+                 activate=True,
+                 reduction='mean',
+                 naive_dice=False,
+                 loss_weight=1.0,
+                 eps=1e-3):
+        """Compute dice loss.
+
+        Args:
+            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+                instead of softmax. Defaults to True.
+            activate (bool): Whether to activate the predictions inside the
+                loss. If False, the inside sigmoid operation is disabled and
+                the inputs are assumed to be already activated.
+                Defaults to True.
+            reduction (str, optional): The method used
+                to reduce the loss. Options are "none",
+                "mean" and "sum". Defaults to 'mean'.
+            naive_dice (bool, optional): If False, use the dice
+                loss defined in the V-Net paper; otherwise, use the
+                naive dice loss in which the power of the number in the
+                denominator is the first power instead of the second
+                power. Defaults to False.
+            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+            eps (float): Avoid dividing by zero. Defaults to 1e-3.
+        """
+
+        super(DiceLoss, self).__init__()
+        self.use_sigmoid = use_sigmoid
+        self.reduction = reduction
+        self.naive_dice = naive_dice
+        self.loss_weight = loss_weight
+        self.eps = eps
+        self.activate = activate
+
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                reduction_override=None,
+                avg_factor=None):
+        """Forward function.
+
+        Args:
+            pred (torch.Tensor): The prediction, has a shape (n, *).
+            target (torch.Tensor): The label of the prediction,
+                shape (n, *), same shape as pred.
+            weight (torch.Tensor, optional): The weight of loss for each
+                prediction, has a shape (n,). Defaults to None.
+            avg_factor (int, optional): Average factor that is used to average
+                the loss. Defaults to None.
+            reduction_override (str, optional): The reduction method used to
+                override the original reduction method of the loss.
+                Options are "none", "mean" and "sum".
+
+        Returns:
+            torch.Tensor: The calculated loss
+        """
+
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+
+        if self.activate:
+            if self.use_sigmoid:
+                pred = pred.sigmoid()
+            else:
+                raise NotImplementedError
+
+        loss = self.loss_weight * dice_loss(
+            pred,
+            target,
+            weight,
+            eps=self.eps,
+            reduction=reduction,
+            naive_dice=self.naive_dice,
+            avg_factor=avg_factor)
+
+        return loss
diff --git a/mmdet/models/losses/eqlv2_loss.py b/mmdet/models/losses/eqlv2_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea1f4a9a8f7c71119c2bed743d714a34ab4db82c
--- /dev/null
+++ b/mmdet/models/losses/eqlv2_loss.py
@@ -0,0 +1,173 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+from functools import partial
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.logging import print_log
+from torch import Tensor
+
+from mmdet.registry import MODELS
+
+
+@MODELS.register_module()
+class EQLV2Loss(nn.Module):
+
+    def __init__(self,
+                 use_sigmoid: bool = True,
+                 reduction: str = 'mean',
+                 class_weight: Optional[Tensor] = None,
+                 loss_weight: float = 1.0,
+                 num_classes: int = 1203,
+                 use_distributed: bool = False,
+                 mu: float = 0.8,
+                 alpha: float = 4.0,
+                 gamma: int = 12,
+                 vis_grad: bool = False,
+                 test_with_obj: bool = True) -> None:
+        """`Equalization Loss v2 `_
+
+        Args:
+            use_sigmoid (bool): EQLv2 uses the sigmoid function to transform
+                the predicted logits to an estimated probability distribution.
+            reduction (str, optional): The method used to reduce the loss into
+                a scalar. Defaults to 'mean'.
+            class_weight (Tensor, optional): The weight of loss for each
+                prediction. Defaults to None.
+            loss_weight (float, optional): The weight of the total EQLv2 loss.
+                Defaults to 1.0.
+            num_classes (int): 1203 for lvis v1.0, 1230 for lvis v0.5.
+            use_distributed (bool): EQLv2 will calculate the gradients
+                on all GPUs if there are any. Change to True if you are using
+                distributed training. Defaults to False.
+            mu (float, optional): Defaults to 0.8.
+            alpha (float, optional): A balance factor for the negative part of
+                the EQLv2 loss. Defaults to 4.0.
+            gamma (int, optional): The gamma for calculating the modulating
+                factor. Defaults to 12.
+            vis_grad (bool, optional): Defaults to False.
+            test_with_obj (bool, optional): Defaults to True.
+        """
+        super().__init__()
+        self.use_sigmoid = True
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+        self.class_weight = class_weight
+        self.num_classes = num_classes
+        self.group = True
+
+        # cfg for eqlv2
+        self.vis_grad = vis_grad
+        self.mu = mu
+        self.alpha = alpha
+        self.gamma = gamma
+        self.use_distributed = use_distributed
+
+        # initial variables
+        self.register_buffer('pos_grad', torch.zeros(self.num_classes))
+        self.register_buffer('neg_grad', torch.zeros(self.num_classes))
+        # At the beginning of training, we set a high value (e.g. 100)
+        # for the initial gradient ratio so that the weights for pos
+        # and neg gradients are 1.
+ self.register_buffer('pos_neg', torch.ones(self.num_classes) * 100) + + self.test_with_obj = test_with_obj + + def _func(x, gamma, mu): + return 1 / (1 + torch.exp(-gamma * (x - mu))) + + self.map_func = partial(_func, gamma=self.gamma, mu=self.mu) + + print_log( + f'build EQL v2, gamma: {gamma}, mu: {mu}, alpha: {alpha}', + logger='current', + level=logging.DEBUG) + + def forward(self, + cls_score: Tensor, + label: Tensor, + weight: Optional[Tensor] = None, + avg_factor: Optional[int] = None, + reduction_override: Optional[Tensor] = None) -> Tensor: + """`Equalization Loss v2 `_ + + Args: + cls_score (Tensor): The prediction with shape (N, C), C is the + number of classes. + label (Tensor): The ground truth label of the predicted target with + shape (N, C), C is the number of classes. + weight (Tensor, optional): The weight of loss for each prediction. + Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Options are "none", "mean" and "sum". + + Returns: + Tensor: The calculated loss + """ + self.n_i, self.n_c = cls_score.size() + self.gt_classes = label + self.pred_class_logits = cls_score + + def expand_label(pred, gt_classes): + target = pred.new_zeros(self.n_i, self.n_c) + target[torch.arange(self.n_i), gt_classes] = 1 + return target + + target = expand_label(cls_score, label) + + pos_w, neg_w = self.get_weight(cls_score) + + weight = pos_w * target + neg_w * (1 - target) + + cls_loss = F.binary_cross_entropy_with_logits( + cls_score, target, reduction='none') + cls_loss = torch.sum(cls_loss * weight) / self.n_i + + self.collect_grad(cls_score.detach(), target.detach(), weight.detach()) + + return self.loss_weight * cls_loss + + def get_channel_num(self, num_classes): + num_channel = num_classes + 1 + return num_channel + + def get_activation(self, pred): + pred = torch.sigmoid(pred) + n_i, n_c = pred.size() + bg_score = pred[:, -1].view(n_i, 1) + if self.test_with_obj: + pred[:, :-1] *= (1 - bg_score) + return pred + + def collect_grad(self, pred, target, weight): + prob = torch.sigmoid(pred) + grad = target * (prob - 1) + (1 - target) * prob + grad = torch.abs(grad) + + # do not collect grad for objectiveness branch [:-1] + pos_grad = torch.sum(grad * target * weight, dim=0)[:-1] + neg_grad = torch.sum(grad * (1 - target) * weight, dim=0)[:-1] + + if self.use_distributed: + dist.all_reduce(pos_grad) + dist.all_reduce(neg_grad) + + self.pos_grad += pos_grad + self.neg_grad += neg_grad + self.pos_neg = self.pos_grad / (self.neg_grad + 1e-10) + + def get_weight(self, pred): + neg_w = torch.cat([self.map_func(self.pos_neg), pred.new_ones(1)]) + pos_w = 1 + self.alpha * (1 - neg_w) + neg_w = neg_w.view(1, -1).expand(self.n_i, self.n_c) + pos_w = pos_w.view(1, -1).expand(self.n_i, self.n_c) + return pos_w, neg_w diff --git a/mmdet/models/losses/focal_loss.py b/mmdet/models/losses/focal_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..15bef293a591a7f4c099febdaa82abaf7fb4928a --- /dev/null +++ b/mmdet/models/losses/focal_loss.py @@ -0,0 +1,371 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
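Before the focal-loss file, a minimal usage sketch for the EQLV2Loss registered above (an illustrative sketch, assuming an LVIS-style head whose classifier emits num_classes + 1 logits, with the last channel acting as objectness):

import torch
from mmdet.registry import MODELS

loss_cls = MODELS.build(dict(type='EQLV2Loss', num_classes=1203))

cls_score = torch.randn(8, 1204)       # 1203 classes + 1 objectness logit
labels = torch.randint(0, 1204, (8,))  # index 1203 plays the background role
loss = loss_cls(cls_score, labels)     # scalar, scaled by loss_weight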
+import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss + +from mmdet.registry import MODELS +from .accuracy import accuracy +from .utils import weight_reduce_loss + + +# This method is only for debugging +def py_sigmoid_focal_loss(pred, + target, + weight=None, + gamma=2.0, + alpha=0.25, + reduction='mean', + avg_factor=None): + """PyTorch version of `Focal Loss `_. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the + number of classes + target (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float, optional): A balanced form for Focal Loss. + Defaults to 0.25. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + pred_sigmoid = pred.sigmoid() + target = target.type_as(pred) + # Actually, pt here denotes (1 - pt) in the Focal Loss paper + pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) + # Thus it's pt.pow(gamma) rather than (1 - pt).pow(gamma) + focal_weight = (alpha * target + (1 - alpha) * + (1 - target)) * pt.pow(gamma) + loss = F.binary_cross_entropy_with_logits( + pred, target, reduction='none') * focal_weight + if weight is not None: + if weight.shape != loss.shape: + if weight.size(0) == loss.size(0): + # For most cases, weight is of shape (num_priors, ), + # which means it does not have the second axis num_class + weight = weight.view(-1, 1) + else: + # Sometimes, weight per anchor per class is also needed. e.g. + # in FSAF. But it may be flattened of shape + # (num_priors x num_class, ), while loss is still of shape + # (num_priors, num_class). + assert weight.numel() == loss.numel() + weight = weight.view(loss.size(0), -1) + assert weight.ndim == loss.ndim + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +def py_focal_loss_with_prob(pred, + target, + weight=None, + gamma=2.0, + alpha=0.25, + reduction='mean', + avg_factor=None): + """PyTorch version of `Focal Loss `_. + Different from `py_sigmoid_focal_loss`, this function accepts probability + as input. + + Args: + pred (torch.Tensor): The prediction probability with shape (N, C), + C is the number of classes. + target (torch.Tensor): The learning label of the prediction. + The target shape support (N,C) or (N,), (N,C) means one-hot form. + weight (torch.Tensor, optional): Sample-wise loss weight. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float, optional): A balanced form for Focal Loss. + Defaults to 0.25. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. 
+ """ + if pred.dim() != target.dim(): + num_classes = pred.size(1) + target = F.one_hot(target, num_classes=num_classes + 1) + target = target[:, :num_classes] + + target = target.type_as(pred) + pt = (1 - pred) * target + pred * (1 - target) + focal_weight = (alpha * target + (1 - alpha) * + (1 - target)) * pt.pow(gamma) + loss = F.binary_cross_entropy( + pred, target, reduction='none') * focal_weight + if weight is not None: + if weight.shape != loss.shape: + if weight.size(0) == loss.size(0): + # For most cases, weight is of shape (num_priors, ), + # which means it does not have the second axis num_class + weight = weight.view(-1, 1) + else: + # Sometimes, weight per anchor per class is also needed. e.g. + # in FSAF. But it may be flattened of shape + # (num_priors x num_class, ), while loss is still of shape + # (num_priors, num_class). + assert weight.numel() == loss.numel() + weight = weight.view(loss.size(0), -1) + assert weight.ndim == loss.ndim + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +def sigmoid_focal_loss(pred, + target, + weight=None, + gamma=2.0, + alpha=0.25, + reduction='mean', + avg_factor=None): + r"""A wrapper of cuda version `Focal Loss + `_. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + target (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float, optional): A balanced form for Focal Loss. + Defaults to 0.25. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + # Function.apply does not accept keyword arguments, so the decorator + # "weighted_loss" is not applicable + loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma, + alpha, None, 'none') + if weight is not None: + if weight.shape != loss.shape: + if weight.size(0) == loss.size(0): + # For most cases, weight is of shape (num_priors, ), + # which means it does not have the second axis num_class + weight = weight.view(-1, 1) + else: + # Sometimes, weight per anchor per class is also needed. e.g. + # in FSAF. But it may be flattened of shape + # (num_priors x num_class, ), while loss is still of shape + # (num_priors, num_class). + assert weight.numel() == loss.numel() + weight = weight.view(loss.size(0), -1) + assert weight.ndim == loss.ndim + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@MODELS.register_module() +class FocalLoss(nn.Module): + + def __init__(self, + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + reduction='mean', + loss_weight=1.0, + activated=False): + """`Focal Loss `_ + + Args: + use_sigmoid (bool, optional): Whether to the prediction is + used for sigmoid or softmax. Defaults to True. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float, optional): A balanced form for Focal Loss. + Defaults to 0.25. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. Options are "none", "mean" and + "sum". + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + activated (bool, optional): Whether the input is activated. 
+ If True, it means the input has been activated and can be + treated as probabilities. Else, it should be treated as logits. + Defaults to False. + """ + super(FocalLoss, self).__init__() + assert use_sigmoid is True, 'Only sigmoid focal loss supported now.' + self.use_sigmoid = use_sigmoid + self.gamma = gamma + self.alpha = alpha + self.reduction = reduction + self.loss_weight = loss_weight + self.activated = activated + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None): + """Forward function. + + Args: + pred (torch.Tensor): The prediction. + target (torch.Tensor): The learning label of the prediction. + The target shape support (N,C) or (N,), (N,C) means + one-hot form. + weight (torch.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Options are "none", "mean" and "sum". + + Returns: + torch.Tensor: The calculated loss + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.use_sigmoid: + if self.activated: + calculate_loss_func = py_focal_loss_with_prob + else: + if pred.dim() == target.dim(): + # this means that target is already in One-Hot form. + calculate_loss_func = py_sigmoid_focal_loss + elif torch.cuda.is_available() and pred.is_cuda: + calculate_loss_func = sigmoid_focal_loss + else: + num_classes = pred.size(1) + target = F.one_hot(target, num_classes=num_classes + 1) + target = target[:, :num_classes] + calculate_loss_func = py_sigmoid_focal_loss + + loss_cls = self.loss_weight * calculate_loss_func( + pred, + target, + weight, + gamma=self.gamma, + alpha=self.alpha, + reduction=reduction, + avg_factor=avg_factor) + + else: + raise NotImplementedError + return loss_cls + + +@MODELS.register_module() +class FocalCustomLoss(nn.Module): + + def __init__(self, + use_sigmoid=True, + num_classes=-1, + gamma=2.0, + alpha=0.25, + reduction='mean', + loss_weight=1.0, + activated=False): + """`Focal Loss for V3Det `_ + + Args: + use_sigmoid (bool, optional): Whether to the prediction is + used for sigmoid or softmax. Defaults to True. + num_classes (int): Number of classes to classify. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float, optional): A balanced form for Focal Loss. + Defaults to 0.25. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. Options are "none", "mean" and + "sum". + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + activated (bool, optional): Whether the input is activated. + If True, it means the input has been activated and can be + treated as probabilities. Else, it should be treated as logits. + Defaults to False. + """ + super(FocalCustomLoss, self).__init__() + assert use_sigmoid is True, 'Only sigmoid focal loss supported now.' 
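A hand-computed check of the modulating factor applied by py_sigmoid_focal_loss above (a minimal sketch using the default gamma=2.0 and alpha=0.25):

import torch
import torch.nn.functional as F

pred = torch.tensor([[3.0], [0.0]])    # one easy, one uncertain positive
target = torch.tensor([[1.0], [1.0]])
gamma, alpha = 2.0, 0.25

p = pred.sigmoid()                            # [[0.95], [0.50]]
pt = (1 - p) * target + p * (1 - target)      # probability of the wrong side
focal_w = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(
    pred, target, reduction='none') * focal_w
# focal_w ~ 0.00056 for the easy sample vs 0.0625 for the uncertain one:
# well-classified examples are strongly down-weighted.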
+ self.use_sigmoid = use_sigmoid + self.num_classes = num_classes + self.gamma = gamma + self.alpha = alpha + self.reduction = reduction + self.loss_weight = loss_weight + self.activated = activated + + assert self.num_classes != -1 + + # custom output channels of the classifier + self.custom_cls_channels = True + # custom activation of cls_score + self.custom_activation = True + # custom accuracy of the classsifier + self.custom_accuracy = True + + def get_cls_channels(self, num_classes): + assert num_classes == self.num_classes + return num_classes + + def get_activation(self, cls_score): + + fine_cls_score = cls_score[:, :self.num_classes] + + score_classes = fine_cls_score.sigmoid() + + return score_classes + + def get_accuracy(self, cls_score, labels): + + fine_cls_score = cls_score[:, :self.num_classes] + + pos_inds = labels < self.num_classes + acc_classes = accuracy(fine_cls_score[pos_inds], labels[pos_inds]) + acc = dict() + acc['acc_classes'] = acc_classes + return acc + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None): + """Forward function. + + Args: + pred (torch.Tensor): The prediction. + target (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Options are "none", "mean" and "sum". + + Returns: + torch.Tensor: The calculated loss + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.use_sigmoid: + + num_classes = pred.size(1) + target = F.one_hot(target, num_classes=num_classes + 1) + target = target[:, :num_classes] + calculate_loss_func = py_sigmoid_focal_loss + + loss_cls = self.loss_weight * calculate_loss_func( + pred, + target, + weight, + gamma=self.gamma, + alpha=self.alpha, + reduction=reduction, + avg_factor=avg_factor) + + else: + raise NotImplementedError + return loss_cls diff --git a/mmdet/models/losses/gaussian_focal_loss.py b/mmdet/models/losses/gaussian_focal_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..14fa8da462a5e7cabde2166878a1b9f2ccc16d62 --- /dev/null +++ b/mmdet/models/losses/gaussian_focal_loss.py @@ -0,0 +1,186 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Union + +import torch.nn as nn +from torch import Tensor + +from mmdet.registry import MODELS +from .utils import weight_reduce_loss, weighted_loss + + +@weighted_loss +def gaussian_focal_loss(pred: Tensor, + gaussian_target: Tensor, + alpha: float = 2.0, + gamma: float = 4.0, + pos_weight: float = 1.0, + neg_weight: float = 1.0) -> Tensor: + """`Focal Loss `_ for targets in gaussian + distribution. + + Args: + pred (torch.Tensor): The prediction. + gaussian_target (torch.Tensor): The learning target of the prediction + in gaussian distribution. + alpha (float, optional): A balanced form for Focal Loss. + Defaults to 2.0. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 4.0. + pos_weight(float): Positive sample loss weight. Defaults to 1.0. + neg_weight(float): Negative sample loss weight. Defaults to 1.0. 
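+
+    Example (illustrative):
+        >>> import torch
+        >>> heatmap = torch.tensor([[0.0, 0.6, 1.0]])  # gaussian target
+        >>> pred = torch.tensor([[0.1, 0.4, 0.8]])     # probabilities
+        >>> # only the exact 1.0 counts as a positive; near-center pixels
+        >>> # get their negative loss damped by (1 - target)**gamma
+        >>> loss = gaussian_focal_loss(pred, heatmap)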
+ """ + eps = 1e-12 + pos_weights = gaussian_target.eq(1) + neg_weights = (1 - gaussian_target).pow(gamma) + pos_loss = -(pred + eps).log() * (1 - pred).pow(alpha) * pos_weights + neg_loss = -(1 - pred + eps).log() * pred.pow(alpha) * neg_weights + return pos_weight * pos_loss + neg_weight * neg_loss + + +def gaussian_focal_loss_with_pos_inds( + pred: Tensor, + gaussian_target: Tensor, + pos_inds: Tensor, + pos_labels: Tensor, + alpha: float = 2.0, + gamma: float = 4.0, + pos_weight: float = 1.0, + neg_weight: float = 1.0, + reduction: str = 'mean', + avg_factor: Optional[Union[int, float]] = None) -> Tensor: + """`Focal Loss `_ for targets in gaussian + distribution. + + Note: The index with a value of 1 in ``gaussian_target`` in the + ``gaussian_focal_loss`` function is a positive sample, but in + ``gaussian_focal_loss_with_pos_inds`` the positive sample is passed + in through the ``pos_inds`` parameter. + + Args: + pred (torch.Tensor): The prediction. The shape is (N, num_classes). + gaussian_target (torch.Tensor): The learning target of the prediction + in gaussian distribution. The shape is (N, num_classes). + pos_inds (torch.Tensor): The positive sample index. + The shape is (M, ). + pos_labels (torch.Tensor): The label corresponding to the positive + sample index. The shape is (M, ). + alpha (float, optional): A balanced form for Focal Loss. + Defaults to 2.0. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 4.0. + pos_weight(float): Positive sample loss weight. Defaults to 1.0. + neg_weight(float): Negative sample loss weight. Defaults to 1.0. + reduction (str): Options are "none", "mean" and "sum". + Defaults to 'mean`. + avg_factor (int, float, optional): Average factor that is used to + average the loss. Defaults to None. + """ + eps = 1e-12 + neg_weights = (1 - gaussian_target).pow(gamma) + + pos_pred_pix = pred[pos_inds] + pos_pred = pos_pred_pix.gather(1, pos_labels.unsqueeze(1)) + pos_loss = -(pos_pred + eps).log() * (1 - pos_pred).pow(alpha) + pos_loss = weight_reduce_loss(pos_loss, None, reduction, avg_factor) + + neg_loss = -(1 - pred + eps).log() * pred.pow(alpha) * neg_weights + neg_loss = weight_reduce_loss(neg_loss, None, reduction, avg_factor) + + return pos_weight * pos_loss + neg_weight * neg_loss + + +@MODELS.register_module() +class GaussianFocalLoss(nn.Module): + """GaussianFocalLoss is a variant of focal loss. + + More details can be found in the `paper + `_ + Code is modified from `kp_utils.py + `_ # noqa: E501 + Please notice that the target in GaussianFocalLoss is a gaussian heatmap, + not 0/1 binary target. + + Args: + alpha (float): Power of prediction. + gamma (float): Power of target for negative samples. + reduction (str): Options are "none", "mean" and "sum". + loss_weight (float): Loss weight of current loss. + pos_weight(float): Positive sample loss weight. Defaults to 1.0. + neg_weight(float): Negative sample loss weight. Defaults to 1.0. 
+ """ + + def __init__(self, + alpha: float = 2.0, + gamma: float = 4.0, + reduction: str = 'mean', + loss_weight: float = 1.0, + pos_weight: float = 1.0, + neg_weight: float = 1.0) -> None: + super().__init__() + self.alpha = alpha + self.gamma = gamma + self.reduction = reduction + self.loss_weight = loss_weight + self.pos_weight = pos_weight + self.neg_weight = neg_weight + + def forward(self, + pred: Tensor, + target: Tensor, + pos_inds: Optional[Tensor] = None, + pos_labels: Optional[Tensor] = None, + weight: Optional[Tensor] = None, + avg_factor: Optional[Union[int, float]] = None, + reduction_override: Optional[str] = None) -> Tensor: + """Forward function. + + If you want to manually determine which positions are + positive samples, you can set the pos_index and pos_label + parameter. Currently, only the CenterNet update version uses + the parameter. + + Args: + pred (torch.Tensor): The prediction. The shape is (N, num_classes). + target (torch.Tensor): The learning target of the prediction + in gaussian distribution. The shape is (N, num_classes). + pos_inds (torch.Tensor): The positive sample index. + Defaults to None. + pos_labels (torch.Tensor): The label corresponding to the positive + sample index. Defaults to None. + weight (torch.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, float, optional): Average factor that is used to + average the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if pos_inds is not None: + assert pos_labels is not None + # Only used by centernet update version + loss_reg = self.loss_weight * gaussian_focal_loss_with_pos_inds( + pred, + target, + pos_inds, + pos_labels, + alpha=self.alpha, + gamma=self.gamma, + pos_weight=self.pos_weight, + neg_weight=self.neg_weight, + reduction=reduction, + avg_factor=avg_factor) + else: + loss_reg = self.loss_weight * gaussian_focal_loss( + pred, + target, + weight, + alpha=self.alpha, + gamma=self.gamma, + pos_weight=self.pos_weight, + neg_weight=self.neg_weight, + reduction=reduction, + avg_factor=avg_factor) + return loss_reg diff --git a/mmdet/models/losses/gfocal_loss.py b/mmdet/models/losses/gfocal_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b3a1172207e859039ca5ed7e0604d8b787131c29 --- /dev/null +++ b/mmdet/models/losses/gfocal_loss.py @@ -0,0 +1,295 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmdet.models.losses.utils import weighted_loss +from mmdet.registry import MODELS + + +@weighted_loss +def quality_focal_loss(pred, target, beta=2.0): + r"""Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning + Qualified and Distributed Bounding Boxes for Dense Object Detection + `_. + + Args: + pred (torch.Tensor): Predicted joint representation of classification + and quality (IoU) estimation with shape (N, C), C is the number of + classes. + target (tuple([torch.Tensor])): Target category label with shape (N,) + and target quality label with shape (N,). + beta (float): The beta parameter for calculating the modulating factor. + Defaults to 2.0. + + Returns: + torch.Tensor: Loss tensor with shape (N,). 
+ """ + assert len(target) == 2, """target for QFL must be a tuple of two elements, + including category label and quality label, respectively""" + # label denotes the category id, score denotes the quality score + label, score = target + + # negatives are supervised by 0 quality score + pred_sigmoid = pred.sigmoid() + scale_factor = pred_sigmoid + zerolabel = scale_factor.new_zeros(pred.shape) + loss = F.binary_cross_entropy_with_logits( + pred, zerolabel, reduction='none') * scale_factor.pow(beta) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = pred.size(1) + pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1) + pos_label = label[pos].long() + # positives are supervised by bbox quality (IoU) score + scale_factor = score[pos] - pred_sigmoid[pos, pos_label] + loss[pos, pos_label] = F.binary_cross_entropy_with_logits( + pred[pos, pos_label], score[pos], + reduction='none') * scale_factor.abs().pow(beta) + + loss = loss.sum(dim=1, keepdim=False) + return loss + + +@weighted_loss +def quality_focal_loss_tensor_target(pred, target, beta=2.0, activated=False): + """`QualityFocal Loss `_ + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the + number of classes + target (torch.Tensor): The learning target of the iou-aware + classification score with shape (N, C), C is the number of classes. + beta (float): The beta parameter for calculating the modulating factor. + Defaults to 2.0. + activated (bool): Whether the input is activated. + If True, it means the input has been activated and can be + treated as probabilities. Else, it should be treated as logits. + Defaults to False. + """ + # pred and target should be of the same size + assert pred.size() == target.size() + if activated: + pred_sigmoid = pred + loss_function = F.binary_cross_entropy + else: + pred_sigmoid = pred.sigmoid() + loss_function = F.binary_cross_entropy_with_logits + + scale_factor = pred_sigmoid + target = target.type_as(pred) + + zerolabel = scale_factor.new_zeros(pred.shape) + loss = loss_function( + pred, zerolabel, reduction='none') * scale_factor.pow(beta) + + pos = (target != 0) + scale_factor = target[pos] - pred_sigmoid[pos] + loss[pos] = loss_function( + pred[pos], target[pos], + reduction='none') * scale_factor.abs().pow(beta) + + loss = loss.sum(dim=1, keepdim=False) + return loss + + +@weighted_loss +def quality_focal_loss_with_prob(pred, target, beta=2.0): + r"""Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning + Qualified and Distributed Bounding Boxes for Dense Object Detection + `_. + Different from `quality_focal_loss`, this function accepts probability + as input. + + Args: + pred (torch.Tensor): Predicted joint representation of classification + and quality (IoU) estimation with shape (N, C), C is the number of + classes. + target (tuple([torch.Tensor])): Target category label with shape (N,) + and target quality label with shape (N,). + beta (float): The beta parameter for calculating the modulating factor. + Defaults to 2.0. + + Returns: + torch.Tensor: Loss tensor with shape (N,). 
+ """ + assert len(target) == 2, """target for QFL must be a tuple of two elements, + including category label and quality label, respectively""" + # label denotes the category id, score denotes the quality score + label, score = target + + # negatives are supervised by 0 quality score + pred_sigmoid = pred + scale_factor = pred_sigmoid + zerolabel = scale_factor.new_zeros(pred.shape) + loss = F.binary_cross_entropy( + pred, zerolabel, reduction='none') * scale_factor.pow(beta) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = pred.size(1) + pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1) + pos_label = label[pos].long() + # positives are supervised by bbox quality (IoU) score + scale_factor = score[pos] - pred_sigmoid[pos, pos_label] + loss[pos, pos_label] = F.binary_cross_entropy( + pred[pos, pos_label], score[pos], + reduction='none') * scale_factor.abs().pow(beta) + + loss = loss.sum(dim=1, keepdim=False) + return loss + + +@weighted_loss +def distribution_focal_loss(pred, label): + r"""Distribution Focal Loss (DFL) is from `Generalized Focal Loss: Learning + Qualified and Distributed Bounding Boxes for Dense Object Detection + `_. + + Args: + pred (torch.Tensor): Predicted general distribution of bounding boxes + (before softmax) with shape (N, n+1), n is the max value of the + integral set `{0, ..., n}` in paper. + label (torch.Tensor): Target distance label for bounding boxes with + shape (N,). + + Returns: + torch.Tensor: Loss tensor with shape (N,). + """ + dis_left = label.long() + dis_right = dis_left + 1 + weight_left = dis_right.float() - label + weight_right = label - dis_left.float() + loss = F.cross_entropy(pred, dis_left, reduction='none') * weight_left \ + + F.cross_entropy(pred, dis_right, reduction='none') * weight_right + return loss + + +@MODELS.register_module() +class QualityFocalLoss(nn.Module): + r"""Quality Focal Loss (QFL) is a variant of `Generalized Focal Loss: + Learning Qualified and Distributed Bounding Boxes for Dense Object + Detection `_. + + Args: + use_sigmoid (bool): Whether sigmoid operation is conducted in QFL. + Defaults to True. + beta (float): The beta parameter for calculating the modulating factor. + Defaults to 2.0. + reduction (str): Options are "none", "mean" and "sum". + loss_weight (float): Loss weight of current loss. + activated (bool, optional): Whether the input is activated. + If True, it means the input has been activated and can be + treated as probabilities. Else, it should be treated as logits. + Defaults to False. + """ + + def __init__(self, + use_sigmoid=True, + beta=2.0, + reduction='mean', + loss_weight=1.0, + activated=False): + super(QualityFocalLoss, self).__init__() + assert use_sigmoid is True, 'Only sigmoid in QFL supported now.' + self.use_sigmoid = use_sigmoid + self.beta = beta + self.reduction = reduction + self.loss_weight = loss_weight + self.activated = activated + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None): + """Forward function. + + Args: + pred (torch.Tensor): Predicted joint representation of + classification and quality (IoU) estimation with shape (N, C), + C is the number of classes. + target (Union(tuple([torch.Tensor]),Torch.Tensor)): The type is + tuple, it should be included Target category label with + shape (N,) and target quality label with shape (N,).The type + is torch.Tensor, the target should be one-hot form with + soft weights. 
+ weight (torch.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.use_sigmoid: + if self.activated: + calculate_loss_func = quality_focal_loss_with_prob + else: + calculate_loss_func = quality_focal_loss + if isinstance(target, torch.Tensor): + # the target shape with (N,C) or (N,C,...), which means + # the target is one-hot form with soft weights. + calculate_loss_func = partial( + quality_focal_loss_tensor_target, activated=self.activated) + + loss_cls = self.loss_weight * calculate_loss_func( + pred, + target, + weight, + beta=self.beta, + reduction=reduction, + avg_factor=avg_factor) + else: + raise NotImplementedError + return loss_cls + + +@MODELS.register_module() +class DistributionFocalLoss(nn.Module): + r"""Distribution Focal Loss (DFL) is a variant of `Generalized Focal Loss: + Learning Qualified and Distributed Bounding Boxes for Dense Object + Detection `_. + + Args: + reduction (str): Options are `'none'`, `'mean'` and `'sum'`. + loss_weight (float): Loss weight of current loss. + """ + + def __init__(self, reduction='mean', loss_weight=1.0): + super(DistributionFocalLoss, self).__init__() + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None): + """Forward function. + + Args: + pred (torch.Tensor): Predicted general distribution of bounding + boxes (before softmax) with shape (N, n+1), n is the max value + of the integral set `{0, ..., n}` in paper. + target (torch.Tensor): Target distance label for bounding boxes + with shape (N,). + weight (torch.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + loss_cls = self.loss_weight * distribution_focal_loss( + pred, target, weight, reduction=reduction, avg_factor=avg_factor) + return loss_cls diff --git a/mmdet/models/losses/ghm_loss.py b/mmdet/models/losses/ghm_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a874c0038cc4a77769705a3a06a95a56d3e8dd2d --- /dev/null +++ b/mmdet/models/losses/ghm_loss.py @@ -0,0 +1,213 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
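Before the GHM losses, a worked expansion of what distribution_focal_loss above computes, as a minimal self-contained sketch that splits each fractional target across its two neighbouring integral bins:

import torch
import torch.nn.functional as F

pred = torch.randn(2, 17)            # logits over the integral set {0, ..., 16}
label = torch.tensor([4.3, 7.9])     # continuous regression targets

dis_left = label.long()              # [4, 7]
dis_right = dis_left + 1             # [5, 8]
w_left = dis_right.float() - label   # [0.7, 0.1]
w_right = label - dis_left.float()   # [0.3, 0.9]
loss = (F.cross_entropy(pred, dis_left, reduction='none') * w_left +
        F.cross_entropy(pred, dis_right, reduction='none') * w_right)
# matches distribution_focal_loss(pred, label, reduction='none')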
+import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmdet.registry import MODELS +from .utils import weight_reduce_loss + + +def _expand_onehot_labels(labels, label_weights, label_channels): + bin_labels = labels.new_full((labels.size(0), label_channels), 0) + inds = torch.nonzero( + (labels >= 0) & (labels < label_channels), as_tuple=False).squeeze() + if inds.numel() > 0: + bin_labels[inds, labels[inds]] = 1 + bin_label_weights = label_weights.view(-1, 1).expand( + label_weights.size(0), label_channels) + return bin_labels, bin_label_weights + + +# TODO: code refactoring to make it consistent with other losses +@MODELS.register_module() +class GHMC(nn.Module): + """GHM Classification Loss. + + Details of the theorem can be viewed in the paper + `Gradient Harmonized Single-stage Detector + `_. + + Args: + bins (int): Number of the unit regions for distribution calculation. + momentum (float): The parameter for moving average. + use_sigmoid (bool): Can only be true for BCE based loss now. + loss_weight (float): The weight of the total GHM-C loss. + reduction (str): Options are "none", "mean" and "sum". + Defaults to "mean" + """ + + def __init__(self, + bins=10, + momentum=0, + use_sigmoid=True, + loss_weight=1.0, + reduction='mean'): + super(GHMC, self).__init__() + self.bins = bins + self.momentum = momentum + edges = torch.arange(bins + 1).float() / bins + self.register_buffer('edges', edges) + self.edges[-1] += 1e-6 + if momentum > 0: + acc_sum = torch.zeros(bins) + self.register_buffer('acc_sum', acc_sum) + self.use_sigmoid = use_sigmoid + if not self.use_sigmoid: + raise NotImplementedError + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, + pred, + target, + label_weight, + reduction_override=None, + **kwargs): + """Calculate the GHM-C loss. + + Args: + pred (float tensor of size [batch_num, class_num]): + The direct prediction of classification fc layer. + target (float tensor of size [batch_num, class_num]): + Binary class target for each sample. + label_weight (float tensor of size [batch_num, class_num]): + the value is 1 if the sample is valid and 0 if ignored. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. + Returns: + The gradient harmonized loss. 
+ """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + # the target should be binary class label + if pred.dim() != target.dim(): + target, label_weight = _expand_onehot_labels( + target, label_weight, pred.size(-1)) + target, label_weight = target.float(), label_weight.float() + edges = self.edges + mmt = self.momentum + weights = torch.zeros_like(pred) + + # gradient length + g = torch.abs(pred.sigmoid().detach() - target) + + valid = label_weight > 0 + tot = max(valid.float().sum().item(), 1.0) + n = 0 # n valid bins + for i in range(self.bins): + inds = (g >= edges[i]) & (g < edges[i + 1]) & valid + num_in_bin = inds.sum().item() + if num_in_bin > 0: + if mmt > 0: + self.acc_sum[i] = mmt * self.acc_sum[i] \ + + (1 - mmt) * num_in_bin + weights[inds] = tot / self.acc_sum[i] + else: + weights[inds] = tot / num_in_bin + n += 1 + if n > 0: + weights = weights / n + + loss = F.binary_cross_entropy_with_logits( + pred, target, reduction='none') + loss = weight_reduce_loss( + loss, weights, reduction=reduction, avg_factor=tot) + return loss * self.loss_weight + + +# TODO: code refactoring to make it consistent with other losses +@MODELS.register_module() +class GHMR(nn.Module): + """GHM Regression Loss. + + Details of the theorem can be viewed in the paper + `Gradient Harmonized Single-stage Detector + `_. + + Args: + mu (float): The parameter for the Authentic Smooth L1 loss. + bins (int): Number of the unit regions for distribution calculation. + momentum (float): The parameter for moving average. + loss_weight (float): The weight of the total GHM-R loss. + reduction (str): Options are "none", "mean" and "sum". + Defaults to "mean" + """ + + def __init__(self, + mu=0.02, + bins=10, + momentum=0, + loss_weight=1.0, + reduction='mean'): + super(GHMR, self).__init__() + self.mu = mu + self.bins = bins + edges = torch.arange(bins + 1).float() / bins + self.register_buffer('edges', edges) + self.edges[-1] = 1e3 + self.momentum = momentum + if momentum > 0: + acc_sum = torch.zeros(bins) + self.register_buffer('acc_sum', acc_sum) + self.loss_weight = loss_weight + self.reduction = reduction + + # TODO: support reduction parameter + def forward(self, + pred, + target, + label_weight, + avg_factor=None, + reduction_override=None): + """Calculate the GHM-R loss. + + Args: + pred (float tensor of size [batch_num, 4 (* class_num)]): + The prediction of box regression layer. Channel number can be 4 + or 4 * class_num depending on whether it is class-agnostic. + target (float tensor of size [batch_num, 4 (* class_num)]): + The target regression values with the same size of pred. + label_weight (float tensor of size [batch_num, 4 (* class_num)]): + The weight of each sample, 0 if ignored. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. + Returns: + The gradient harmonized loss. 
+ """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + mu = self.mu + edges = self.edges + mmt = self.momentum + + # ASL1 loss + diff = pred - target + loss = torch.sqrt(diff * diff + mu * mu) - mu + + # gradient length + g = torch.abs(diff / torch.sqrt(mu * mu + diff * diff)).detach() + weights = torch.zeros_like(g) + + valid = label_weight > 0 + tot = max(label_weight.float().sum().item(), 1.0) + n = 0 # n: valid bins + for i in range(self.bins): + inds = (g >= edges[i]) & (g < edges[i + 1]) & valid + num_in_bin = inds.sum().item() + if num_in_bin > 0: + n += 1 + if mmt > 0: + self.acc_sum[i] = mmt * self.acc_sum[i] \ + + (1 - mmt) * num_in_bin + weights[inds] = tot / self.acc_sum[i] + else: + weights[inds] = tot / num_in_bin + if n > 0: + weights /= n + loss = weight_reduce_loss( + loss, weights, reduction=reduction, avg_factor=tot) + return loss * self.loss_weight diff --git a/mmdet/models/losses/iou_loss.py b/mmdet/models/losses/iou_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..c8a2b977868cef6f4039b49277bfc853ffc720bd --- /dev/null +++ b/mmdet/models/losses/iou_loss.py @@ -0,0 +1,926 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures.bbox import bbox_overlaps +from .utils import weighted_loss + + +@weighted_loss +def iou_loss(pred: Tensor, + target: Tensor, + linear: bool = False, + mode: str = 'log', + eps: float = 1e-6) -> Tensor: + """IoU loss. + + Computing the IoU loss between a set of predicted bboxes and target bboxes. + The loss is calculated as negative log of IoU. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): Corresponding gt bboxes, shape (n, 4). + linear (bool, optional): If True, use linear scale of loss instead of + log scale. Default: False. + mode (str): Loss scaling mode, including "linear", "square", and "log". + Default: 'log' + eps (float): Epsilon to avoid log(0). + + Return: + Tensor: Loss tensor. + """ + assert mode in ['linear', 'square', 'log'] + if linear: + mode = 'linear' + warnings.warn('DeprecationWarning: Setting "linear=True" in ' + 'iou_loss is deprecated, please use "mode=`linear`" ' + 'instead.') + # avoid fp16 overflow + if pred.dtype == torch.float16: + fp16 = True + pred = pred.to(torch.float32) + else: + fp16 = False + + ious = bbox_overlaps(pred, target, is_aligned=True).clamp(min=eps) + + if fp16: + ious = ious.to(torch.float16) + + if mode == 'linear': + loss = 1 - ious + elif mode == 'square': + loss = 1 - ious**2 + elif mode == 'log': + loss = -ious.log() + else: + raise NotImplementedError + return loss + + +@weighted_loss +def bounded_iou_loss(pred: Tensor, + target: Tensor, + beta: float = 0.2, + eps: float = 1e-3) -> Tensor: + """BIoULoss. + + This is an implementation of paper + `Improving Object Localization with Fitness NMS and Bounded IoU Loss. + `_. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): Corresponding gt bboxes, shape (n, 4). + beta (float, optional): Beta parameter in smoothl1. + eps (float, optional): Epsilon to avoid NaN values. + + Return: + Tensor: Loss tensor. 
+ """ + pred_ctrx = (pred[:, 0] + pred[:, 2]) * 0.5 + pred_ctry = (pred[:, 1] + pred[:, 3]) * 0.5 + pred_w = pred[:, 2] - pred[:, 0] + pred_h = pred[:, 3] - pred[:, 1] + with torch.no_grad(): + target_ctrx = (target[:, 0] + target[:, 2]) * 0.5 + target_ctry = (target[:, 1] + target[:, 3]) * 0.5 + target_w = target[:, 2] - target[:, 0] + target_h = target[:, 3] - target[:, 1] + + dx = target_ctrx - pred_ctrx + dy = target_ctry - pred_ctry + + loss_dx = 1 - torch.max( + (target_w - 2 * dx.abs()) / + (target_w + 2 * dx.abs() + eps), torch.zeros_like(dx)) + loss_dy = 1 - torch.max( + (target_h - 2 * dy.abs()) / + (target_h + 2 * dy.abs() + eps), torch.zeros_like(dy)) + loss_dw = 1 - torch.min(target_w / (pred_w + eps), pred_w / + (target_w + eps)) + loss_dh = 1 - torch.min(target_h / (pred_h + eps), pred_h / + (target_h + eps)) + # view(..., -1) does not work for empty tensor + loss_comb = torch.stack([loss_dx, loss_dy, loss_dw, loss_dh], + dim=-1).flatten(1) + + loss = torch.where(loss_comb < beta, 0.5 * loss_comb * loss_comb / beta, + loss_comb - 0.5 * beta) + return loss + + +@weighted_loss +def giou_loss(pred: Tensor, target: Tensor, eps: float = 1e-7) -> Tensor: + r"""`Generalized Intersection over Union: A Metric and A Loss for Bounding + Box Regression `_. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): Corresponding gt bboxes, shape (n, 4). + eps (float): Epsilon to avoid log(0). + + Return: + Tensor: Loss tensor. + """ + # avoid fp16 overflow + if pred.dtype == torch.float16: + fp16 = True + pred = pred.to(torch.float32) + else: + fp16 = False + + gious = bbox_overlaps(pred, target, mode='giou', is_aligned=True, eps=eps) + + if fp16: + gious = gious.to(torch.float16) + + loss = 1 - gious + return loss + + +@weighted_loss +def diou_loss(pred: Tensor, target: Tensor, eps: float = 1e-7) -> Tensor: + r"""Implementation of `Distance-IoU Loss: Faster and Better + Learning for Bounding Box Regression https://arxiv.org/abs/1911.08287`_. + + Code is modified from https://github.com/Zzh-tju/DIoU. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): Corresponding gt bboxes, shape (n, 4). + eps (float): Epsilon to avoid log(0). + + Return: + Tensor: Loss tensor. 
+ """ + # overlap + lt = torch.max(pred[:, :2], target[:, :2]) + rb = torch.min(pred[:, 2:], target[:, 2:]) + wh = (rb - lt).clamp(min=0) + overlap = wh[:, 0] * wh[:, 1] + + # union + ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1]) + ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1]) + union = ap + ag - overlap + eps + + # IoU + ious = overlap / union + + # enclose area + enclose_x1y1 = torch.min(pred[:, :2], target[:, :2]) + enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:]) + enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0) + + cw = enclose_wh[:, 0] + ch = enclose_wh[:, 1] + + c2 = cw**2 + ch**2 + eps + + b1_x1, b1_y1 = pred[:, 0], pred[:, 1] + b1_x2, b1_y2 = pred[:, 2], pred[:, 3] + b2_x1, b2_y1 = target[:, 0], target[:, 1] + b2_x2, b2_y2 = target[:, 2], target[:, 3] + + left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4 + right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4 + rho2 = left + right + + # DIoU + dious = ious - rho2 / c2 + loss = 1 - dious + return loss + + +@weighted_loss +def ciou_loss(pred: Tensor, target: Tensor, eps: float = 1e-7) -> Tensor: + r"""`Implementation of paper `Enhancing Geometric Factors into + Model Learning and Inference for Object Detection and Instance + Segmentation `_. + + Code is modified from https://github.com/Zzh-tju/CIoU. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): Corresponding gt bboxes, shape (n, 4). + eps (float): Epsilon to avoid log(0). + + Return: + Tensor: Loss tensor. + """ + # overlap + lt = torch.max(pred[:, :2], target[:, :2]) + rb = torch.min(pred[:, 2:], target[:, 2:]) + wh = (rb - lt).clamp(min=0) + overlap = wh[:, 0] * wh[:, 1] + + # union + ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1]) + ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1]) + union = ap + ag - overlap + eps + + # IoU + ious = overlap / union + + # enclose area + enclose_x1y1 = torch.min(pred[:, :2], target[:, :2]) + enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:]) + enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0) + + cw = enclose_wh[:, 0] + ch = enclose_wh[:, 1] + + c2 = cw**2 + ch**2 + eps + + b1_x1, b1_y1 = pred[:, 0], pred[:, 1] + b1_x2, b1_y2 = pred[:, 2], pred[:, 3] + b2_x1, b2_y1 = target[:, 0], target[:, 1] + b2_x2, b2_y2 = target[:, 2], target[:, 3] + + w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps + w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps + + left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4 + right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4 + rho2 = left + right + + factor = 4 / math.pi**2 + v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2) + + with torch.no_grad(): + alpha = (ious > 0.5).float() * v / (1 - ious + v) + + # CIoU + cious = ious - (rho2 / c2 + alpha * v) + loss = 1 - cious.clamp(min=-1.0, max=1.0) + return loss + + +@weighted_loss +def eiou_loss(pred: Tensor, + target: Tensor, + smooth_point: float = 0.1, + eps: float = 1e-7) -> Tensor: + r"""Implementation of paper `Extended-IoU Loss: A Systematic + IoU-Related Method: Beyond Simplified Regression for Better + Localization `_ + + Code is modified from https://github.com//ShiqiYu/libfacedetection.train. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): Corresponding gt bboxes, shape (n, 4). + smooth_point (float): hyperparameter, default is 0.1. + eps (float): Epsilon to avoid log(0). + + Return: + Tensor: Loss tensor. 
+ """ + px1, py1, px2, py2 = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3] + tx1, ty1, tx2, ty2 = target[:, 0], target[:, 1], target[:, 2], target[:, 3] + + # extent top left + ex1 = torch.min(px1, tx1) + ey1 = torch.min(py1, ty1) + + # intersection coordinates + ix1 = torch.max(px1, tx1) + iy1 = torch.max(py1, ty1) + ix2 = torch.min(px2, tx2) + iy2 = torch.min(py2, ty2) + + # extra + xmin = torch.min(ix1, ix2) + ymin = torch.min(iy1, iy2) + xmax = torch.max(ix1, ix2) + ymax = torch.max(iy1, iy2) + + # Intersection + intersection = (ix2 - ex1) * (iy2 - ey1) + (xmin - ex1) * (ymin - ey1) - ( + ix1 - ex1) * (ymax - ey1) - (xmax - ex1) * ( + iy1 - ey1) + # Union + union = (px2 - px1) * (py2 - py1) + (tx2 - tx1) * ( + ty2 - ty1) - intersection + eps + # IoU + ious = 1 - (intersection / union) + + # Smooth-EIoU + smooth_sign = (ious < smooth_point).detach().float() + loss = 0.5 * smooth_sign * (ious**2) / smooth_point + (1 - smooth_sign) * ( + ious - 0.5 * smooth_point) + return loss + + +@weighted_loss +def siou_loss(pred, target, eps=1e-7, neg_gamma=False): + r"""`Implementation of paper `SIoU Loss: More Powerful Learning + for Bounding Box Regression `_. + + Code is modified from https://github.com/meituan/YOLOv6. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): Corresponding gt bboxes, shape (n, 4). + eps (float): Eps to avoid log(0). + neg_gamma (bool): `True` follows original implementation in paper. + + Return: + Tensor: Loss tensor. + """ + # overlap + lt = torch.max(pred[:, :2], target[:, :2]) + rb = torch.min(pred[:, 2:], target[:, 2:]) + wh = (rb - lt).clamp(min=0) + overlap = wh[:, 0] * wh[:, 1] + + # union + ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1]) + ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1]) + union = ap + ag - overlap + eps + + # IoU + ious = overlap / union + + # enclose area + enclose_x1y1 = torch.min(pred[:, :2], target[:, :2]) + enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:]) + # modified clamp threshold zero to eps to avoid NaN + enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=eps) + + cw = enclose_wh[:, 0] + ch = enclose_wh[:, 1] + + b1_x1, b1_y1 = pred[:, 0], pred[:, 1] + b1_x2, b1_y2 = pred[:, 2], pred[:, 3] + b2_x1, b2_y1 = target[:, 0], target[:, 1] + b2_x2, b2_y2 = target[:, 2], target[:, 3] + + w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps + w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps + + # angle cost + s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps + s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps + + sigma = torch.pow(s_cw**2 + s_ch**2, 0.5) + + sin_alpha_1 = torch.abs(s_cw) / sigma + sin_alpha_2 = torch.abs(s_ch) / sigma + threshold = pow(2, 0.5) / 2 + sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1) + angle_cost = torch.cos(torch.asin(sin_alpha) * 2 - math.pi / 2) + + # distance cost + rho_x = (s_cw / cw)**2 + rho_y = (s_ch / ch)**2 + + # `neg_gamma=True` follows original implementation in paper + # but setting `neg_gamma=False` makes training more stable. 
+ gamma = angle_cost - 2 if neg_gamma else 2 - angle_cost + distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y) + + # shape cost + omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2) + omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2) + shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow( + 1 - torch.exp(-1 * omiga_h), 4) + + # SIoU + sious = ious - 0.5 * (distance_cost + shape_cost) + loss = 1 - sious.clamp(min=-1.0, max=1.0) + return loss + + +@MODELS.register_module() +class IoULoss(nn.Module): + """IoULoss. + + Computing the IoU loss between a set of predicted bboxes and target bboxes. + + Args: + linear (bool): If True, use linear scale of loss else determined + by mode. Default: False. + eps (float): Epsilon to avoid log(0). + reduction (str): Options are "none", "mean" and "sum". + loss_weight (float): Weight of loss. + mode (str): Loss scaling mode, including "linear", "square", and "log". + Default: 'log' + """ + + def __init__(self, + linear: bool = False, + eps: float = 1e-6, + reduction: str = 'mean', + loss_weight: float = 1.0, + mode: str = 'log') -> None: + super().__init__() + assert mode in ['linear', 'square', 'log'] + if linear: + mode = 'linear' + warnings.warn('DeprecationWarning: Setting "linear=True" in ' + 'IOULoss is deprecated, please use "mode=`linear`" ' + 'instead.') + self.mode = mode + self.linear = linear + self.eps = eps + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + avg_factor: Optional[int] = None, + reduction_override: Optional[str] = None, + **kwargs) -> Tensor: + """Forward function. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): The learning target of the prediction, + shape (n, 4). + weight (Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. Options are "none", "mean" and "sum". + + Return: + Tensor: Loss tensor. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if (weight is not None) and (not torch.any(weight > 0)) and ( + reduction != 'none'): + if pred.dim() == weight.dim() + 1: + weight = weight.unsqueeze(1) + return (pred * weight).sum() # 0 + if weight is not None and weight.dim() > 1: + # TODO: remove this in the future + # reduce the weight of shape (n, 4) to (n,) to match the + # iou_loss of shape (n,) + assert weight.shape == pred.shape + weight = weight.mean(-1) + loss = self.loss_weight * iou_loss( + pred, + target, + weight, + mode=self.mode, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss + + +@MODELS.register_module() +class BoundedIoULoss(nn.Module): + """BIoULoss. + + This is an implementation of paper + `Improving Object Localization with Fitness NMS and Bounded IoU Loss. + `_. + + Args: + beta (float, optional): Beta parameter in smoothl1. + eps (float, optional): Epsilon to avoid NaN values. + reduction (str): Options are "none", "mean" and "sum". + loss_weight (float): Weight of loss. 
+ """ + + def __init__(self, + beta: float = 0.2, + eps: float = 1e-3, + reduction: str = 'mean', + loss_weight: float = 1.0) -> None: + super().__init__() + self.beta = beta + self.eps = eps + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + avg_factor: Optional[int] = None, + reduction_override: Optional[str] = None, + **kwargs) -> Tensor: + """Forward function. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): The learning target of the prediction, + shape (n, 4). + weight (Optional[Tensor], optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (Optional[int], optional): Average factor that is used + to average the loss. Defaults to None. + reduction_override (Optional[str], optional): The reduction method + used to override the original reduction method of the loss. + Defaults to None. Options are "none", "mean" and "sum". + + Returns: + Tensor: Loss tensor. + """ + if weight is not None and not torch.any(weight > 0): + if pred.dim() == weight.dim() + 1: + weight = weight.unsqueeze(1) + return (pred * weight).sum() # 0 + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + loss = self.loss_weight * bounded_iou_loss( + pred, + target, + weight, + beta=self.beta, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss + + +@MODELS.register_module() +class GIoULoss(nn.Module): + r"""`Generalized Intersection over Union: A Metric and A Loss for Bounding + Box Regression `_. + + Args: + eps (float): Epsilon to avoid log(0). + reduction (str): Options are "none", "mean" and "sum". + loss_weight (float): Weight of loss. + """ + + def __init__(self, + eps: float = 1e-6, + reduction: str = 'mean', + loss_weight: float = 1.0) -> None: + super().__init__() + self.eps = eps + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + avg_factor: Optional[int] = None, + reduction_override: Optional[str] = None, + **kwargs) -> Tensor: + """Forward function. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): The learning target of the prediction, + shape (n, 4). + weight (Optional[Tensor], optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (Optional[int], optional): Average factor that is used + to average the loss. Defaults to None. + reduction_override (Optional[str], optional): The reduction method + used to override the original reduction method of the loss. + Defaults to None. Options are "none", "mean" and "sum". + + Returns: + Tensor: Loss tensor. 
+ """ + if weight is not None and not torch.any(weight > 0): + if pred.dim() == weight.dim() + 1: + weight = weight.unsqueeze(1) + return (pred * weight).sum() # 0 + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if weight is not None and weight.dim() > 1: + # TODO: remove this in the future + # reduce the weight of shape (n, 4) to (n,) to match the + # giou_loss of shape (n,) + assert weight.shape == pred.shape + weight = weight.mean(-1) + loss = self.loss_weight * giou_loss( + pred, + target, + weight, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss + + +@MODELS.register_module() +class DIoULoss(nn.Module): + r"""Implementation of `Distance-IoU Loss: Faster and Better + Learning for Bounding Box Regression https://arxiv.org/abs/1911.08287`_. + + Code is modified from https://github.com/Zzh-tju/DIoU. + + Args: + eps (float): Epsilon to avoid log(0). + reduction (str): Options are "none", "mean" and "sum". + loss_weight (float): Weight of loss. + """ + + def __init__(self, + eps: float = 1e-6, + reduction: str = 'mean', + loss_weight: float = 1.0) -> None: + super().__init__() + self.eps = eps + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + avg_factor: Optional[int] = None, + reduction_override: Optional[str] = None, + **kwargs) -> Tensor: + """Forward function. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): The learning target of the prediction, + shape (n, 4). + weight (Optional[Tensor], optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (Optional[int], optional): Average factor that is used + to average the loss. Defaults to None. + reduction_override (Optional[str], optional): The reduction method + used to override the original reduction method of the loss. + Defaults to None. Options are "none", "mean" and "sum". + + Returns: + Tensor: Loss tensor. + """ + if weight is not None and not torch.any(weight > 0): + if pred.dim() == weight.dim() + 1: + weight = weight.unsqueeze(1) + return (pred * weight).sum() # 0 + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if weight is not None and weight.dim() > 1: + # TODO: remove this in the future + # reduce the weight of shape (n, 4) to (n,) to match the + # giou_loss of shape (n,) + assert weight.shape == pred.shape + weight = weight.mean(-1) + loss = self.loss_weight * diou_loss( + pred, + target, + weight, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss + + +@MODELS.register_module() +class CIoULoss(nn.Module): + r"""`Implementation of paper `Enhancing Geometric Factors into + Model Learning and Inference for Object Detection and Instance + Segmentation `_. + + Code is modified from https://github.com/Zzh-tju/CIoU. + + Args: + eps (float): Epsilon to avoid log(0). + reduction (str): Options are "none", "mean" and "sum". + loss_weight (float): Weight of loss. 
+ """ + + def __init__(self, + eps: float = 1e-6, + reduction: str = 'mean', + loss_weight: float = 1.0) -> None: + super().__init__() + self.eps = eps + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + avg_factor: Optional[int] = None, + reduction_override: Optional[str] = None, + **kwargs) -> Tensor: + """Forward function. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): The learning target of the prediction, + shape (n, 4). + weight (Optional[Tensor], optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (Optional[int], optional): Average factor that is used + to average the loss. Defaults to None. + reduction_override (Optional[str], optional): The reduction method + used to override the original reduction method of the loss. + Defaults to None. Options are "none", "mean" and "sum". + + Returns: + Tensor: Loss tensor. + """ + if weight is not None and not torch.any(weight > 0): + if pred.dim() == weight.dim() + 1: + weight = weight.unsqueeze(1) + return (pred * weight).sum() # 0 + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if weight is not None and weight.dim() > 1: + # TODO: remove this in the future + # reduce the weight of shape (n, 4) to (n,) to match the + # giou_loss of shape (n,) + assert weight.shape == pred.shape + weight = weight.mean(-1) + loss = self.loss_weight * ciou_loss( + pred, + target, + weight, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss + + +@MODELS.register_module() +class EIoULoss(nn.Module): + r"""Implementation of paper `Extended-IoU Loss: A Systematic + IoU-Related Method: Beyond Simplified Regression for Better + Localization `_ + + Code is modified from https://github.com//ShiqiYu/libfacedetection.train. + + Args: + eps (float): Epsilon to avoid log(0). + reduction (str): Options are "none", "mean" and "sum". + loss_weight (float): Weight of loss. + smooth_point (float): hyperparameter, default is 0.1. + """ + + def __init__(self, + eps: float = 1e-6, + reduction: str = 'mean', + loss_weight: float = 1.0, + smooth_point: float = 0.1) -> None: + super().__init__() + self.eps = eps + self.reduction = reduction + self.loss_weight = loss_weight + self.smooth_point = smooth_point + + def forward(self, + pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + avg_factor: Optional[int] = None, + reduction_override: Optional[str] = None, + **kwargs) -> Tensor: + """Forward function. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): The learning target of the prediction, + shape (n, 4). + weight (Optional[Tensor], optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (Optional[int], optional): Average factor that is used + to average the loss. Defaults to None. + reduction_override (Optional[str], optional): The reduction method + used to override the original reduction method of the loss. + Defaults to None. Options are "none", "mean" and "sum". + + Returns: + Tensor: Loss tensor. 
+ """ + if weight is not None and not torch.any(weight > 0): + if pred.dim() == weight.dim() + 1: + weight = weight.unsqueeze(1) + return (pred * weight).sum() # 0 + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if weight is not None and weight.dim() > 1: + assert weight.shape == pred.shape + weight = weight.mean(-1) + loss = self.loss_weight * eiou_loss( + pred, + target, + weight, + smooth_point=self.smooth_point, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss + + +@MODELS.register_module() +class SIoULoss(nn.Module): + r"""`Implementation of paper `SIoU Loss: More Powerful Learning + for Bounding Box Regression `_. + + Code is modified from https://github.com/meituan/YOLOv6. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): Corresponding gt bboxes, shape (n, 4). + eps (float): Eps to avoid log(0). + neg_gamma (bool): `True` follows original implementation in paper. + + Return: + Tensor: Loss tensor. + """ + + def __init__(self, + eps: float = 1e-6, + reduction: str = 'mean', + loss_weight: float = 1.0, + neg_gamma: bool = False) -> None: + super().__init__() + self.eps = eps + self.reduction = reduction + self.loss_weight = loss_weight + self.neg_gamma = neg_gamma + + def forward(self, + pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + avg_factor: Optional[int] = None, + reduction_override: Optional[str] = None, + **kwargs) -> Tensor: + """Forward function. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): The learning target of the prediction, + shape (n, 4). + weight (Optional[Tensor], optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (Optional[int], optional): Average factor that is used + to average the loss. Defaults to None. + reduction_override (Optional[str], optional): The reduction method + used to override the original reduction method of the loss. + Defaults to None. Options are "none", "mean" and "sum". + + Returns: + Tensor: Loss tensor. + """ + if weight is not None and not torch.any(weight > 0): + if pred.dim() == weight.dim() + 1: + weight = weight.unsqueeze(1) + return (pred * weight).sum() # 0 + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if weight is not None and weight.dim() > 1: + # TODO: remove this in the future + # reduce the weight of shape (n, 4) to (n,) to match the + # giou_loss of shape (n,) + assert weight.shape == pred.shape + weight = weight.mean(-1) + loss = self.loss_weight * siou_loss( + pred, + target, + weight, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor, + neg_gamma=self.neg_gamma, + **kwargs) + return loss diff --git a/mmdet/models/losses/kd_loss.py b/mmdet/models/losses/kd_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0a7d5ef24a0b0d7d7390a27c7cd9cbfdbe61d823 --- /dev/null +++ b/mmdet/models/losses/kd_loss.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Optional
+
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from .utils import weighted_loss
+
+
+@weighted_loss
+def knowledge_distillation_kl_div_loss(pred: Tensor,
+ soft_label: Tensor,
+ T: int,
+ detach_target: bool = True) -> Tensor:
+ r"""Loss function for knowledge distilling using KL divergence.
+
+ Args:
+ pred (Tensor): Predicted logits with shape (N, n + 1).
+ soft_label (Tensor): Target logits with shape (N, n + 1).
+ T (int): Temperature for distillation.
+ detach_target (bool): Remove soft_label from automatic
+ differentiation. Defaults to True.
+
+ Returns:
+ Tensor: Loss tensor with shape (N,).
+ """
+ assert pred.size() == soft_label.size()
+ target = F.softmax(soft_label / T, dim=1)
+ if detach_target:
+ target = target.detach()
+
+ kd_loss = F.kl_div(
+ F.log_softmax(pred / T, dim=1), target, reduction='none').mean(1) * (
+ T * T)
+
+ return kd_loss
+
+
+@MODELS.register_module()
+class KnowledgeDistillationKLDivLoss(nn.Module):
+ """Loss function for knowledge distilling using KL divergence.
+
+ Args:
+ reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
+ loss_weight (float): Loss weight of current loss.
+ T (int): Temperature for distillation.
+ """
+
+ def __init__(self,
+ reduction: str = 'mean',
+ loss_weight: float = 1.0,
+ T: int = 10) -> None:
+ super().__init__()
+ assert T >= 1
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.T = T
+
+ def forward(self,
+ pred: Tensor,
+ soft_label: Tensor,
+ weight: Optional[Tensor] = None,
+ avg_factor: Optional[int] = None,
+ reduction_override: Optional[str] = None) -> Tensor:
+ """Forward function.
+
+ Args:
+ pred (Tensor): Predicted logits with shape (N, n + 1).
+ soft_label (Tensor): Target logits with shape (N, n + 1).
+ weight (Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+
+ Returns:
+ Tensor: Loss tensor.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+
+ loss_kd = self.loss_weight * knowledge_distillation_kl_div_loss(
+ pred,
+ soft_label,
+ weight,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ T=self.T)
+
+ return loss_kd
diff --git a/mmdet/models/losses/l2_loss.py b/mmdet/models/losses/l2_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..6210a3007b2c39540f022925cc93181c7328e42d
--- /dev/null
+++ b/mmdet/models/losses/l2_loss.py
@@ -0,0 +1,139 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+from mmengine.model import BaseModule
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from .utils import weighted_loss
+
+
+@weighted_loss
+def l2_loss(pred: Tensor, target: Tensor) -> Tensor:
+ """L2 loss.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+
+ Returns:
+ torch.Tensor: Calculated loss
+ """
+ assert pred.size() == target.size()
+ loss = torch.abs(pred - target)**2
+ return loss
+
+
+@MODELS.register_module()
+class L2Loss(BaseModule):
+ """L2 loss.
+
+ Args:
+ reduction (str, optional): The method to reduce the loss.
+ Options are "none", "mean" and "sum". + loss_weight (float, optional): The weight of loss. + """ + + def __init__(self, + neg_pos_ub: int = -1, + pos_margin: float = -1, + neg_margin: float = -1, + hard_mining: bool = False, + reduction: str = 'mean', + loss_weight: float = 1.0): + super(L2Loss, self).__init__() + self.neg_pos_ub = neg_pos_ub + self.pos_margin = pos_margin + self.neg_margin = neg_margin + self.hard_mining = hard_mining + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + avg_factor: Optional[float] = None, + reduction_override: Optional[str] = None) -> Tensor: + """Forward function. + + Args: + pred (torch.Tensor): The prediction. + target (torch.Tensor): The learning target of the prediction. + weight (torch.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (float, optional): Average factor that is used to + average the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + pred, weight, avg_factor = self.update_weight(pred, target, weight, + avg_factor) + loss_bbox = self.loss_weight * l2_loss( + pred, target, weight, reduction=reduction, avg_factor=avg_factor) + return loss_bbox + + def update_weight(self, pred: Tensor, target: Tensor, weight: Tensor, + avg_factor: float) -> Tuple[Tensor, Tensor, float]: + """Update the weight according to targets.""" + if weight is None: + weight = target.new_ones(target.size()) + + invalid_inds = weight <= 0 + target[invalid_inds] = -1 + pos_inds = target == 1 + neg_inds = target == 0 + + if self.pos_margin > 0: + pred[pos_inds] -= self.pos_margin + if self.neg_margin > 0: + pred[neg_inds] -= self.neg_margin + pred = torch.clamp(pred, min=0, max=1) + + num_pos = int((target == 1).sum()) + num_neg = int((target == 0).sum()) + if self.neg_pos_ub > 0 and num_neg / (num_pos + + 1e-6) > self.neg_pos_ub: + num_neg = num_pos * self.neg_pos_ub + neg_idx = torch.nonzero(target == 0, as_tuple=False) + + if self.hard_mining: + costs = l2_loss( + pred, target, reduction='none')[neg_idx[:, 0], + neg_idx[:, 1]].detach() + neg_idx = neg_idx[costs.topk(num_neg)[1], :] + else: + neg_idx = self.random_choice(neg_idx, num_neg) + + new_neg_inds = neg_inds.new_zeros(neg_inds.size()).bool() + new_neg_inds[neg_idx[:, 0], neg_idx[:, 1]] = True + + invalid_neg_inds = torch.logical_xor(neg_inds, new_neg_inds) + weight[invalid_neg_inds] = 0 + + avg_factor = (weight > 0).sum() + return pred, weight, avg_factor + + @staticmethod + def random_choice(gallery: Union[list, np.ndarray, Tensor], + num: int) -> np.ndarray: + """Random select some elements from the gallery. + + It seems that Pytorch's implementation is slower than numpy so we use + numpy to randperm the indices. 
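+
+ Example (illustrative; the selection is random, so only the shape
+ is deterministic):
+
+ >>> import numpy as np
+ >>> picked = L2Loss.random_choice(np.arange(10), 3)
+ >>> picked.shape
+ (3,)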
+ """ + assert len(gallery) >= num + if isinstance(gallery, list): + gallery = np.array(gallery) + cands = np.arange(len(gallery)) + np.random.shuffle(cands) + rand_inds = cands[:num] + if not isinstance(gallery, np.ndarray): + rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device) + return gallery[rand_inds] diff --git a/mmdet/models/losses/margin_loss.py b/mmdet/models/losses/margin_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0609e1db50edf89c8ae8b65709e8ab786f580366 --- /dev/null +++ b/mmdet/models/losses/margin_loss.py @@ -0,0 +1,152 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from mmengine.model import BaseModule +from torch import Tensor + +from mmdet.registry import MODELS +from .mse_loss import mse_loss + + +@MODELS.register_module() +class MarginL2Loss(BaseModule): + """L2 loss with margin. + + Args: + neg_pos_ub (int, optional): The upper bound of negative to positive + samples in hard mining. Defaults to -1. + pos_margin (float, optional): The similarity margin for positive + samples in hard mining. Defaults to -1. + neg_margin (float, optional): The similarity margin for negative + samples in hard mining. Defaults to -1. + hard_mining (bool, optional): Whether to use hard mining. Defaults to + False. + reduction (str, optional): The method to reduce the loss. + Options are "none", "mean" and "sum". Defaults to "mean". + loss_weight (float, optional): The weight of loss. Defaults to 1.0. + """ + + def __init__(self, + neg_pos_ub: int = -1, + pos_margin: float = -1, + neg_margin: float = -1, + hard_mining: bool = False, + reduction: str = 'mean', + loss_weight: float = 1.0): + super(MarginL2Loss, self).__init__() + self.neg_pos_ub = neg_pos_ub + self.pos_margin = pos_margin + self.neg_margin = neg_margin + self.hard_mining = hard_mining + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + avg_factor: Optional[float] = None, + reduction_override: Optional[str] = None) -> Tensor: + """Forward function. + + Args: + pred (torch.Tensor): The prediction. + target (torch.Tensor): The learning target of the prediction. + weight (torch.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (float, optional): Average factor that is used to + average the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + pred, weight, avg_factor = self.update_weight(pred, target, weight, + avg_factor) + loss_bbox = self.loss_weight * mse_loss( + pred, + target.float(), + weight.float(), + reduction=reduction, + avg_factor=avg_factor) + return loss_bbox + + def update_weight(self, pred: Tensor, target: Tensor, weight: Tensor, + avg_factor: float) -> Tuple[Tensor, Tensor, float]: + """Update the weight according to targets. + + Args: + pred (torch.Tensor): The prediction. + target (torch.Tensor): The learning target of the prediction. + weight (torch.Tensor): The weight of loss for each prediction. + avg_factor (float): Average factor that is used to average the + loss. + + Returns: + tuple[torch.Tensor]: The updated prediction, weight and average + factor. 
+ """ + if weight is None: + weight = target.new_ones(target.size()) + + invalid_inds = weight <= 0 + target[invalid_inds] = -1 + pos_inds = target == 1 + neg_inds = target == 0 + + if self.pos_margin > 0: + pred[pos_inds] -= self.pos_margin + if self.neg_margin > 0: + pred[neg_inds] -= self.neg_margin + pred = torch.clamp(pred, min=0, max=1) + + num_pos = int((target == 1).sum()) + num_neg = int((target == 0).sum()) + if self.neg_pos_ub > 0 and num_neg / (num_pos + + 1e-6) > self.neg_pos_ub: + num_neg = num_pos * self.neg_pos_ub + neg_idx = torch.nonzero(target == 0, as_tuple=False) + + if self.hard_mining: + costs = mse_loss( + pred, target.float(), + reduction='none')[neg_idx[:, 0], neg_idx[:, 1]].detach() + neg_idx = neg_idx[costs.topk(num_neg)[1], :] + else: + neg_idx = self.random_choice(neg_idx, num_neg) + + new_neg_inds = neg_inds.new_zeros(neg_inds.size()).bool() + new_neg_inds[neg_idx[:, 0], neg_idx[:, 1]] = True + + invalid_neg_inds = torch.logical_xor(neg_inds, new_neg_inds) + weight[invalid_neg_inds] = 0 + + avg_factor = (weight > 0).sum() + return pred, weight, avg_factor + + @staticmethod + def random_choice(gallery: Union[list, np.ndarray, Tensor], + num: int) -> np.ndarray: + """Random select some elements from the gallery. + + It seems that Pytorch's implementation is slower than numpy so we use + numpy to randperm the indices. + + Args: + gallery (list | np.ndarray | torch.Tensor): The gallery from + which to sample. + num (int): The number of elements to sample. + """ + assert len(gallery) >= num + if isinstance(gallery, list): + gallery = np.array(gallery) + cands = np.arange(len(gallery)) + np.random.shuffle(cands) + rand_inds = cands[:num] + if not isinstance(gallery, np.ndarray): + rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device) + return gallery[rand_inds] diff --git a/mmdet/models/losses/mse_loss.py b/mmdet/models/losses/mse_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..6048218ad36a8105e7fa182f40fae93ef7c9268f --- /dev/null +++ b/mmdet/models/losses/mse_loss.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from mmdet.registry import MODELS +from .utils import weighted_loss + + +@weighted_loss +def mse_loss(pred: Tensor, target: Tensor) -> Tensor: + """A Wrapper of MSE loss. + Args: + pred (Tensor): The prediction. + target (Tensor): The learning target of the prediction. + + Returns: + Tensor: loss Tensor + """ + return F.mse_loss(pred, target, reduction='none') + + +@MODELS.register_module() +class MSELoss(nn.Module): + """MSELoss. + + Args: + reduction (str, optional): The method that reduces the loss to a + scalar. Options are "none", "mean" and "sum". + loss_weight (float, optional): The weight of the loss. Defaults to 1.0 + """ + + def __init__(self, + reduction: str = 'mean', + loss_weight: float = 1.0) -> None: + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + avg_factor: Optional[int] = None, + reduction_override: Optional[str] = None) -> Tensor: + """Forward function of loss. + + Args: + pred (Tensor): The prediction. + target (Tensor): The learning target of the prediction. + weight (Tensor, optional): Weight of the loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. 
Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. + + Returns: + Tensor: The calculated loss. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + loss = self.loss_weight * mse_loss( + pred, target, weight, reduction=reduction, avg_factor=avg_factor) + return loss diff --git a/mmdet/models/losses/multipos_cross_entropy_loss.py b/mmdet/models/losses/multipos_cross_entropy_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a7d1561ed414b7c15412b5e746dff39ca0c53ba1 --- /dev/null +++ b/mmdet/models/losses/multipos_cross_entropy_loss.py @@ -0,0 +1,100 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +from mmengine.model import BaseModule +from torch import Tensor + +from mmdet.registry import MODELS +from .utils import weight_reduce_loss + + +@MODELS.register_module() +class MultiPosCrossEntropyLoss(BaseModule): + """multi-positive targets cross entropy loss. + + Args: + reduction (str, optional): The method to reduce the loss. + Options are "none", "mean" and "sum". Defaults to "mean". + loss_weight (float, optional): The weight of loss. Defaults to 1.0. + """ + + def __init__(self, reduction: str = 'mean', loss_weight: float = 1.0): + super(MultiPosCrossEntropyLoss, self).__init__() + self.reduction = reduction + self.loss_weight = loss_weight + + def multi_pos_cross_entropy(self, + pred: Tensor, + label: Tensor, + weight: Optional[Tensor] = None, + reduction: str = 'mean', + avg_factor: Optional[float] = None) -> Tensor: + """Multi-positive targets cross entropy loss. + + Args: + pred (torch.Tensor): The prediction. + label (torch.Tensor): The assigned label of the prediction. + weight (torch.Tensor): The element-wise weight. + reduction (str): Same as built-in losses of PyTorch. + avg_factor (float): Average factor when computing + the mean of losses. + + Returns: + torch.Tensor: Calculated loss + """ + + pos_inds = (label >= 1) + neg_inds = (label == 0) + pred_pos = pred * pos_inds.float() + pred_neg = pred * neg_inds.float() + # use -inf to mask out unwanted elements. + pred_pos[neg_inds] = pred_pos[neg_inds] + float('inf') + pred_neg[pos_inds] = pred_neg[pos_inds] + float('-inf') + + _pos_expand = torch.repeat_interleave(pred_pos, pred.shape[1], dim=1) + _neg_expand = pred_neg.repeat(1, pred.shape[1]) + + x = torch.nn.functional.pad((_neg_expand - _pos_expand), (0, 1), + 'constant', 0) + loss = torch.logsumexp(x, dim=1) + + # apply weights and do the reduction + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + def forward(self, + cls_score: Tensor, + label: Tensor, + weight: Optional[Tensor] = None, + avg_factor: Optional[float] = None, + reduction_override: Optional[str] = None, + **kwargs) -> Tensor: + """Forward function. + + Args: + cls_score (torch.Tensor): The classification score. + label (torch.Tensor): The assigned label of the prediction. + weight (torch.Tensor): The element-wise weight. + avg_factor (float): Average factor when computing + the mean of losses. + reduction_override (str): Same as built-in losses of PyTorch. 
+
+ Returns:
+ torch.Tensor: Calculated loss
+ """
+ assert cls_score.size() == label.size()
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ loss_cls = self.loss_weight * self.multi_pos_cross_entropy(
+ cls_score,
+ label,
+ weight,
+ reduction=reduction,
+ avg_factor=avg_factor)
+ return loss_cls
diff --git a/mmdet/models/losses/pisa_loss.py b/mmdet/models/losses/pisa_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b192aa0dbc7eb554755eb2f242eab0ea7f1fc650
--- /dev/null
+++ b/mmdet/models/losses/pisa_loss.py
@@ -0,0 +1,187 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from mmdet.structures.bbox import bbox_overlaps
+from ..task_modules.coders import BaseBBoxCoder
+from ..task_modules.samplers import SamplingResult
+
+
+def isr_p(cls_score: Tensor,
+ bbox_pred: Tensor,
+ bbox_targets: Tuple[Tensor],
+ rois: Tensor,
+ sampling_results: List[SamplingResult],
+ loss_cls: nn.Module,
+ bbox_coder: BaseBBoxCoder,
+ k: float = 2,
+ bias: float = 0,
+ num_class: int = 80) -> tuple:
+ """Importance-based Sample Reweighting (ISR_P), positive part.
+
+ Args:
+ cls_score (Tensor): Predicted classification scores.
+ bbox_pred (Tensor): Predicted bbox deltas.
+ bbox_targets (tuple[Tensor]): A tuple of bbox targets, which are
+ labels, label_weights, bbox_targets and bbox_weights,
+ respectively.
+ rois (Tensor): Anchors (single_stage) in shape (n, 4) or RoIs
+ (two_stage) in shape (n, 5).
+ sampling_results (:obj:`SamplingResult`): Sampling results.
+ loss_cls (:obj:`nn.Module`): Classification loss func of the head.
+ bbox_coder (:obj:`BaseBBoxCoder`): BBox coder of the head.
+ k (float): Power of the non-linear mapping. Defaults to 2.
+ bias (float): Shift of the non-linear mapping. Defaults to 0.
+ num_class (int): Number of classes, defaults to 80.
+ + Return: + tuple([Tensor]): labels, imp_based_label_weights, bbox_targets, + bbox_target_weights + """ + + labels, label_weights, bbox_targets, bbox_weights = bbox_targets + pos_label_inds = ((labels >= 0) & + (labels < num_class)).nonzero().reshape(-1) + pos_labels = labels[pos_label_inds] + + # if no positive samples, return the original targets + num_pos = float(pos_label_inds.size(0)) + if num_pos == 0: + return labels, label_weights, bbox_targets, bbox_weights + + # merge pos_assigned_gt_inds of per image to a single tensor + gts = list() + last_max_gt = 0 + for i in range(len(sampling_results)): + gt_i = sampling_results[i].pos_assigned_gt_inds + gts.append(gt_i + last_max_gt) + if len(gt_i) != 0: + last_max_gt = gt_i.max() + 1 + gts = torch.cat(gts) + assert len(gts) == num_pos + + cls_score = cls_score.detach() + bbox_pred = bbox_pred.detach() + + # For single stage detectors, rois here indicate anchors, in shape (N, 4) + # For two stage detectors, rois are in shape (N, 5) + if rois.size(-1) == 5: + pos_rois = rois[pos_label_inds][:, 1:] + else: + pos_rois = rois[pos_label_inds] + + if bbox_pred.size(-1) > 4: + bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4) + pos_delta_pred = bbox_pred[pos_label_inds, pos_labels].view(-1, 4) + else: + pos_delta_pred = bbox_pred[pos_label_inds].view(-1, 4) + + # compute iou of the predicted bbox and the corresponding GT + pos_delta_target = bbox_targets[pos_label_inds].view(-1, 4) + pos_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_pred) + target_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_target) + ious = bbox_overlaps(pos_bbox_pred, target_bbox_pred, is_aligned=True) + + pos_imp_weights = label_weights[pos_label_inds] + # Two steps to compute IoU-HLR. Samples are first sorted by IoU locally, + # then sorted again within the same-rank group + max_l_num = pos_labels.bincount().max() + for label in pos_labels.unique(): + l_inds = (pos_labels == label).nonzero().view(-1) + l_gts = gts[l_inds] + for t in l_gts.unique(): + t_inds = l_inds[l_gts == t] + t_ious = ious[t_inds] + _, t_iou_rank_idx = t_ious.sort(descending=True) + _, t_iou_rank = t_iou_rank_idx.sort() + ious[t_inds] += max_l_num - t_iou_rank.float() + l_ious = ious[l_inds] + _, l_iou_rank_idx = l_ious.sort(descending=True) + _, l_iou_rank = l_iou_rank_idx.sort() # IoU-HLR + # linearly map HLR to label weights + pos_imp_weights[l_inds] *= (max_l_num - l_iou_rank.float()) / max_l_num + + pos_imp_weights = (bias + pos_imp_weights * (1 - bias)).pow(k) + + # normalize to make the new weighted loss value equal to the original loss + pos_loss_cls = loss_cls( + cls_score[pos_label_inds], pos_labels, reduction_override='none') + if pos_loss_cls.dim() > 1: + ori_pos_loss_cls = pos_loss_cls * label_weights[pos_label_inds][:, + None] + new_pos_loss_cls = pos_loss_cls * pos_imp_weights[:, None] + else: + ori_pos_loss_cls = pos_loss_cls * label_weights[pos_label_inds] + new_pos_loss_cls = pos_loss_cls * pos_imp_weights + pos_loss_cls_ratio = ori_pos_loss_cls.sum() / new_pos_loss_cls.sum() + pos_imp_weights = pos_imp_weights * pos_loss_cls_ratio + label_weights[pos_label_inds] = pos_imp_weights + + bbox_targets = labels, label_weights, bbox_targets, bbox_weights + return bbox_targets + + +def carl_loss(cls_score: Tensor, + labels: Tensor, + bbox_pred: Tensor, + bbox_targets: Tensor, + loss_bbox: nn.Module, + k: float = 1, + bias: float = 0.2, + avg_factor: Optional[int] = None, + sigmoid: bool = False, + num_class: int = 80) -> dict: + """Classification-Aware Regression Loss (CARL). 
+
+ Args:
+ cls_score (Tensor): Predicted classification scores.
+ labels (Tensor): Targets of classification.
+ bbox_pred (Tensor): Predicted bbox deltas.
+ bbox_targets (Tensor): Target of bbox regression.
+ loss_bbox (func): Regression loss func of the head.
+ k (float): Power of the non-linear mapping. Defaults to 1.
+ bias (float): Shift of the non-linear mapping. Defaults to 0.2.
+ avg_factor (int, optional): Average factor used in regression loss.
+ sigmoid (bool): Activation of the classification score.
+ num_class (int): Number of classes, defaults to 80.
+
+ Returns:
+ dict: CARL loss dict.
+ """
+ pos_label_inds = ((labels >= 0) &
+ (labels < num_class)).nonzero().reshape(-1)
+ if pos_label_inds.numel() == 0:
+ return dict(loss_carl=cls_score.sum()[None] * 0.)
+ pos_labels = labels[pos_label_inds]
+
+ # multiply pos_cls_score with the corresponding bbox weight
+ # while keeping the gradient
+ if sigmoid:
+ pos_cls_score = cls_score.sigmoid()[pos_label_inds, pos_labels]
+ else:
+ pos_cls_score = cls_score.softmax(-1)[pos_label_inds, pos_labels]
+ carl_loss_weights = (bias + (1 - bias) * pos_cls_score).pow(k)
+
+ # normalize carl_loss_weight to make its sum equal to num positive
+ num_pos = float(pos_cls_score.size(0))
+ weight_ratio = num_pos / carl_loss_weights.sum()
+ carl_loss_weights *= weight_ratio
+
+ if avg_factor is None:
+ avg_factor = bbox_targets.size(0)
+ # if is class agnostic, bbox pred is in shape (N, 4)
+ # otherwise, bbox pred is in shape (N, #classes, 4)
+ if bbox_pred.size(-1) > 4:
+ bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)
+ pos_bbox_preds = bbox_pred[pos_label_inds, pos_labels]
+ else:
+ pos_bbox_preds = bbox_pred[pos_label_inds]
+ ori_loss_reg = loss_bbox(
+ pos_bbox_preds,
+ bbox_targets[pos_label_inds],
+ reduction_override='none') / avg_factor
+ loss_carl = (ori_loss_reg * carl_loss_weights[:, None]).sum()
+ return dict(loss_carl=loss_carl[None])
diff --git a/mmdet/models/losses/seesaw_loss.py b/mmdet/models/losses/seesaw_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dec62b0afdc01e848e0c7f53ba0b6b10b899ea4
--- /dev/null
+++ b/mmdet/models/losses/seesaw_loss.py
@@ -0,0 +1,278 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from .accuracy import accuracy
+from .cross_entropy_loss import cross_entropy
+from .utils import weight_reduce_loss
+
+
+def seesaw_ce_loss(cls_score: Tensor,
+ labels: Tensor,
+ label_weights: Tensor,
+ cum_samples: Tensor,
+ num_classes: int,
+ p: float,
+ q: float,
+ eps: float,
+ reduction: str = 'mean',
+ avg_factor: Optional[int] = None) -> Tensor:
+ """Calculate the Seesaw CrossEntropy loss.
+
+ Args:
+ cls_score (Tensor): The prediction with shape (N, C),
+ C is the number of classes.
+ labels (Tensor): The learning label of the prediction.
+ label_weights (Tensor): Sample-wise loss weight.
+ cum_samples (Tensor): Cumulative samples for each category.
+ num_classes (int): The number of classes.
+ p (float): The ``p`` in the mitigation factor.
+ q (float): The ``q`` in the compensation factor.
+ eps (float): The minimal value of divisor to smooth
+ the computation of the compensation factor.
+ reduction (str, optional): The method used to reduce the loss.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+
+ Returns:
+ Tensor: The calculated loss
+ """
+ assert cls_score.size(-1) == num_classes
+ assert len(cum_samples) == num_classes
+
+ onehot_labels = F.one_hot(labels, num_classes)
+ seesaw_weights = cls_score.new_ones(onehot_labels.size())
+
+ # mitigation factor
+ if p > 0:
+ sample_ratio_matrix = cum_samples[None, :].clamp(
+ min=1) / cum_samples[:, None].clamp(min=1)
+ index = (sample_ratio_matrix < 1.0).float()
+ sample_weights = sample_ratio_matrix.pow(p) * index + (1 - index)
+ mitigation_factor = sample_weights[labels.long(), :]
+ seesaw_weights = seesaw_weights * mitigation_factor
+
+ # compensation factor
+ if q > 0:
+ scores = F.softmax(cls_score.detach(), dim=1)
+ self_scores = scores[
+ torch.arange(0, len(scores)).to(scores.device).long(),
+ labels.long()]
+ score_matrix = scores / self_scores[:, None].clamp(min=eps)
+ index = (score_matrix > 1.0).float()
+ compensation_factor = score_matrix.pow(q) * index + (1 - index)
+ seesaw_weights = seesaw_weights * compensation_factor
+
+ cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels))
+
+ loss = F.cross_entropy(cls_score, labels, weight=None, reduction='none')
+
+ if label_weights is not None:
+ label_weights = label_weights.float()
+ loss = weight_reduce_loss(
+ loss, weight=label_weights, reduction=reduction, avg_factor=avg_factor)
+ return loss
+
+
+@MODELS.register_module()
+class SeesawLoss(nn.Module):
+ """
+ Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021)
+ arXiv: https://arxiv.org/abs/2008.10032
+
+ Args:
+ use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+ or softmax. Only False is supported.
+ p (float, optional): The ``p`` in the mitigation factor.
+ Defaults to 0.8.
+ q (float, optional): The ``q`` in the compensation factor.
+ Defaults to 2.0.
+ num_classes (int, optional): The number of classes.
+ Defaults to 1203 for the LVIS v1 dataset.
+ eps (float, optional): The minimal value of divisor to smooth
+ the computation of the compensation factor. Defaults to 1e-2.
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+ loss_weight (float, optional): The weight of the loss. Defaults to 1.0
+ return_dict (bool, optional): Whether to return the losses as a dict.
+ Defaults to True.
+ """
+
+ def __init__(self,
+ use_sigmoid: bool = False,
+ p: float = 0.8,
+ q: float = 2.0,
+ num_classes: int = 1203,
+ eps: float = 1e-2,
+ reduction: str = 'mean',
+ loss_weight: float = 1.0,
+ return_dict: bool = True) -> None:
+ super().__init__()
+ assert not use_sigmoid
+ self.use_sigmoid = False
+ self.p = p
+ self.q = q
+ self.num_classes = num_classes
+ self.eps = eps
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.return_dict = return_dict
+
+ self.cls_criterion = seesaw_ce_loss
+
+ # cumulative samples for each category
+ self.register_buffer(
+ 'cum_samples',
+ torch.zeros(self.num_classes + 1, dtype=torch.float))
+
+ # custom output channels of the classifier
+ self.custom_cls_channels = True
+ # custom activation of cls_score
+ self.custom_activation = True
+ # custom accuracy of the classifier
+ self.custom_accuracy = True
+
+ def _split_cls_score(self, cls_score: Tensor) -> Tuple[Tensor, Tensor]:
+ """Split cls_score into class scores and objectness scores.
+
+ Args:
+ cls_score (Tensor): The prediction with shape (N, C + 2).
+ + Returns: + Tuple[Tensor, Tensor]: The score for classes and objectness, + respectively + """ + # split cls_score to cls_score_classes and cls_score_objectness + assert cls_score.size(-1) == self.num_classes + 2 + cls_score_classes = cls_score[..., :-2] + cls_score_objectness = cls_score[..., -2:] + return cls_score_classes, cls_score_objectness + + def get_cls_channels(self, num_classes: int) -> int: + """Get custom classification channels. + + Args: + num_classes (int): The number of classes. + + Returns: + int: The custom classification channels. + """ + assert num_classes == self.num_classes + return num_classes + 2 + + def get_activation(self, cls_score: Tensor) -> Tensor: + """Get custom activation of cls_score. + + Args: + cls_score (Tensor): The prediction with shape (N, C + 2). + + Returns: + Tensor: The custom activation of cls_score with shape + (N, C + 1). + """ + cls_score_classes, cls_score_objectness = self._split_cls_score( + cls_score) + score_classes = F.softmax(cls_score_classes, dim=-1) + score_objectness = F.softmax(cls_score_objectness, dim=-1) + score_pos = score_objectness[..., [0]] + score_neg = score_objectness[..., [1]] + score_classes = score_classes * score_pos + scores = torch.cat([score_classes, score_neg], dim=-1) + return scores + + def get_accuracy(self, cls_score: Tensor, + labels: Tensor) -> Dict[str, Tensor]: + """Get custom accuracy w.r.t. cls_score and labels. + + Args: + cls_score (Tensor): The prediction with shape (N, C + 2). + labels (Tensor): The learning label of the prediction. + + Returns: + Dict [str, Tensor]: The accuracy for objectness and classes, + respectively. + """ + pos_inds = labels < self.num_classes + obj_labels = (labels == self.num_classes).long() + cls_score_classes, cls_score_objectness = self._split_cls_score( + cls_score) + acc_objectness = accuracy(cls_score_objectness, obj_labels) + acc_classes = accuracy(cls_score_classes[pos_inds], labels[pos_inds]) + acc = dict() + acc['acc_objectness'] = acc_objectness + acc['acc_classes'] = acc_classes + return acc + + def forward( + self, + cls_score: Tensor, + labels: Tensor, + label_weights: Optional[Tensor] = None, + avg_factor: Optional[int] = None, + reduction_override: Optional[str] = None + ) -> Union[Tensor, Dict[str, Tensor]]: + """Forward function. + + Args: + cls_score (Tensor): The prediction with shape (N, C + 2). + labels (Tensor): The learning label of the prediction. + label_weights (Tensor, optional): Sample-wise loss weight. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + + Returns: + Tensor | Dict [str, Tensor]: + if return_dict == False: The calculated loss | + if return_dict == True: The dict of calculated losses + for objectness and classes, respectively. 
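+
+ Example (a small sketch with a tiny 3-class setup instead of the
+ LVIS default; the logits are illustrative):
+
+ >>> import torch
+ >>> loss_fn = SeesawLoss(num_classes=3, return_dict=True)
+ >>> cls_score = torch.randn(4, 5)  # num_classes + 2 channels
+ >>> labels = torch.tensor([0, 1, 2, 3])  # 3 is the background index
+ >>> losses = loss_fn(cls_score, labels)
+ >>> sorted(losses.keys())
+ ['loss_cls_classes', 'loss_cls_objectness']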
+ """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + assert cls_score.size(-1) == self.num_classes + 2 + pos_inds = labels < self.num_classes + # 0 for pos, 1 for neg + obj_labels = (labels == self.num_classes).long() + + # accumulate the samples for each category + unique_labels = labels.unique() + for u_l in unique_labels: + inds_ = labels == u_l.item() + self.cum_samples[u_l] += inds_.sum() + + if label_weights is not None: + label_weights = label_weights.float() + else: + label_weights = labels.new_ones(labels.size(), dtype=torch.float) + + cls_score_classes, cls_score_objectness = self._split_cls_score( + cls_score) + # calculate loss_cls_classes (only need pos samples) + if pos_inds.sum() > 0: + loss_cls_classes = self.loss_weight * self.cls_criterion( + cls_score_classes[pos_inds], labels[pos_inds], + label_weights[pos_inds], self.cum_samples[:self.num_classes], + self.num_classes, self.p, self.q, self.eps, reduction, + avg_factor) + else: + loss_cls_classes = cls_score_classes[pos_inds].sum() + # calculate loss_cls_objectness + loss_cls_objectness = self.loss_weight * cross_entropy( + cls_score_objectness, obj_labels, label_weights, reduction, + avg_factor) + + if self.return_dict: + loss_cls = dict() + loss_cls['loss_cls_objectness'] = loss_cls_objectness + loss_cls['loss_cls_classes'] = loss_cls_classes + else: + loss_cls = loss_cls_classes + loss_cls_objectness + return loss_cls diff --git a/mmdet/models/losses/smooth_l1_loss.py b/mmdet/models/losses/smooth_l1_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..102f9780706172a44ade2ebe1709c7a1e847db7c --- /dev/null +++ b/mmdet/models/losses/smooth_l1_loss.py @@ -0,0 +1,165 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from mmdet.registry import MODELS +from .utils import weighted_loss + + +@weighted_loss +def smooth_l1_loss(pred: Tensor, target: Tensor, beta: float = 1.0) -> Tensor: + """Smooth L1 loss. + + Args: + pred (Tensor): The prediction. + target (Tensor): The learning target of the prediction. + beta (float, optional): The threshold in the piecewise function. + Defaults to 1.0. + + Returns: + Tensor: Calculated loss + """ + assert beta > 0 + if target.numel() == 0: + return pred.sum() * 0 + + assert pred.size() == target.size() + diff = torch.abs(pred - target) + loss = torch.where(diff < beta, 0.5 * diff * diff / beta, + diff - 0.5 * beta) + return loss + + +@weighted_loss +def l1_loss(pred: Tensor, target: Tensor) -> Tensor: + """L1 loss. + + Args: + pred (Tensor): The prediction. + target (Tensor): The learning target of the prediction. + + Returns: + Tensor: Calculated loss + """ + if target.numel() == 0: + return pred.sum() * 0 + + assert pred.size() == target.size() + loss = torch.abs(pred - target) + return loss + + +@MODELS.register_module() +class SmoothL1Loss(nn.Module): + """Smooth L1 loss. + + Args: + beta (float, optional): The threshold in the piecewise function. + Defaults to 1.0. + reduction (str, optional): The method to reduce the loss. + Options are "none", "mean" and "sum". Defaults to "mean". + loss_weight (float, optional): The weight of loss. 
+ """ + + def __init__(self, + beta: float = 1.0, + reduction: str = 'mean', + loss_weight: float = 1.0) -> None: + super().__init__() + self.beta = beta + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + avg_factor: Optional[int] = None, + reduction_override: Optional[str] = None, + **kwargs) -> Tensor: + """Forward function. + + Args: + pred (Tensor): The prediction. + target (Tensor): The learning target of the prediction. + weight (Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. + + Returns: + Tensor: Calculated loss + """ + if weight is not None and not torch.any(weight > 0): + if pred.dim() == weight.dim() + 1: + weight = weight.unsqueeze(1) + return (pred * weight).sum() + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + loss_bbox = self.loss_weight * smooth_l1_loss( + pred, + target, + weight, + beta=self.beta, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss_bbox + + +@MODELS.register_module() +class L1Loss(nn.Module): + """L1 loss. + + Args: + reduction (str, optional): The method to reduce the loss. + Options are "none", "mean" and "sum". + loss_weight (float, optional): The weight of loss. + """ + + def __init__(self, + reduction: str = 'mean', + loss_weight: float = 1.0) -> None: + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + avg_factor: Optional[int] = None, + reduction_override: Optional[str] = None) -> Tensor: + """Forward function. + + Args: + pred (Tensor): The prediction. + target (Tensor): The learning target of the prediction. + weight (Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. + + Returns: + Tensor: Calculated loss + """ + if weight is not None and not torch.any(weight > 0): + if pred.dim() == weight.dim() + 1: + weight = weight.unsqueeze(1) + return (pred * weight).sum() + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + loss_bbox = self.loss_weight * l1_loss( + pred, target, weight, reduction=reduction, avg_factor=avg_factor) + return loss_bbox diff --git a/mmdet/models/losses/triplet_loss.py b/mmdet/models/losses/triplet_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d9c9604b8c776d89e387ff13496c455ab89a37fb --- /dev/null +++ b/mmdet/models/losses/triplet_loss.py @@ -0,0 +1,88 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmdet.registry import MODELS + + +@MODELS.register_module() +class TripletLoss(BaseModule): + """Triplet loss with hard positive/negative mining. + + Reference: + Hermans et al. In Defense of the Triplet Loss for + Person Re-Identification. arXiv:1703.07737. 
+ Imported from `deep-person-reid
+ <https://github.com/KaiyangZhou/deep-person-reid>`_.
+ Args:
+ margin (float, optional): Margin for triplet loss. Defaults to 0.3.
+ loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+ hard_mining (bool, optional): Whether to perform hard mining.
+ Defaults to True.
+ """
+
+ def __init__(self,
+ margin: float = 0.3,
+ loss_weight: float = 1.0,
+ hard_mining: bool = True):
+ super(TripletLoss, self).__init__()
+ self.margin = margin
+ self.ranking_loss = nn.MarginRankingLoss(margin=margin)
+ self.loss_weight = loss_weight
+ self.hard_mining = hard_mining
+
+ def hard_mining_triplet_loss_forward(
+ self, inputs: torch.Tensor,
+ targets: torch.LongTensor) -> torch.Tensor:
+ """
+ Args:
+ inputs (torch.Tensor): feature matrix with shape
+ (batch_size, feat_dim).
+ targets (torch.LongTensor): ground truth labels with shape
+ (batch_size,).
+
+ Returns:
+ torch.Tensor: triplet loss with hard mining.
+ """
+
+ batch_size = inputs.size(0)
+
+ # Compute the pairwise Euclidean distance matrix
+ dist = torch.pow(inputs, 2).sum(
+ dim=1, keepdim=True).expand(batch_size, batch_size)
+ dist = dist + dist.t()
+ dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
+ dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
+
+ # For each anchor, find the furthest positive sample
+ # and nearest negative sample in the embedding space
+ mask = targets.expand(batch_size, batch_size).eq(
+ targets.expand(batch_size, batch_size).t())
+ dist_ap, dist_an = [], []
+ for i in range(batch_size):
+ dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
+ dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
+ dist_ap = torch.cat(dist_ap)
+ dist_an = torch.cat(dist_an)
+
+ # Compute ranking hinge loss
+ y = torch.ones_like(dist_an)
+ return self.loss_weight * self.ranking_loss(dist_an, dist_ap, y)
+
+ def forward(self, inputs: torch.Tensor,
+ targets: torch.LongTensor) -> torch.Tensor:
+ """
+ Args:
+ inputs (torch.Tensor): feature matrix with shape
+ (batch_size, feat_dim).
+ targets (torch.LongTensor): ground truth labels with shape
+ (batch_size,).
+
+ Returns:
+ torch.Tensor: triplet loss.
+ """
+ if self.hard_mining:
+ return self.hard_mining_triplet_loss_forward(inputs, targets)
+ else:
+ raise NotImplementedError()
diff --git a/mmdet/models/losses/utils.py b/mmdet/models/losses/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e6e7859f353f3e5456f0cfc1f66b4b0ad535427
--- /dev/null
+++ b/mmdet/models/losses/utils.py
@@ -0,0 +1,125 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+from typing import Callable, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+
+def reduce_loss(loss: Tensor, reduction: str) -> Tensor:
+ """Reduce loss as specified.
+
+ Args:
+ loss (Tensor): Elementwise loss tensor.
+ reduction (str): Options are "none", "mean" and "sum".
+
+ Return:
+ Tensor: Reduced loss tensor.
+ """
+ reduction_enum = F._Reduction.get_enum(reduction)
+ # none: 0, elementwise_mean: 1, sum: 2
+ if reduction_enum == 0:
+ return loss
+ elif reduction_enum == 1:
+ return loss.mean()
+ elif reduction_enum == 2:
+ return loss.sum()
+
+
+def weight_reduce_loss(loss: Tensor,
+ weight: Optional[Tensor] = None,
+ reduction: str = 'mean',
+ avg_factor: Optional[float] = None) -> Tensor:
+ """Apply element-wise weight and reduce loss.
+
+ Args:
+ loss (Tensor): Element-wise loss.
+ weight (Optional[Tensor], optional): Element-wise weights.
+ Defaults to None.
+ reduction (str, optional): Same as built-in losses of PyTorch.
+ Defaults to 'mean'.
+ avg_factor (Optional[float], optional): Average factor when + computing the mean of losses. Defaults to None. + + Returns: + Tensor: Processed loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + # Avoid causing ZeroDivisionError when avg_factor is 0.0, + # i.e., all labels of an image belong to ignore index. + eps = torch.finfo(torch.float32).eps + loss = loss.sum() / (avg_factor + eps) + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss(loss_func: Callable) -> Callable: + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, avg_factor=2) + tensor(1.5000) + """ + + @functools.wraps(loss_func) + def wrapper(pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + reduction: str = 'mean', + avg_factor: Optional[int] = None, + **kwargs) -> Tensor: + """ + Args: + pred (Tensor): The prediction. + target (Tensor): Target bboxes. + weight (Optional[Tensor], optional): The weight of loss for each + prediction. Defaults to None. + reduction (str, optional): Options are "none", "mean" and "sum". + Defaults to 'mean'. + avg_factor (Optional[int], optional): Average factor that is used + to average the loss. Defaults to None. + + Returns: + Tensor: Loss tensor. + """ + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + return wrapper diff --git a/mmdet/models/losses/varifocal_loss.py b/mmdet/models/losses/varifocal_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..58ab167352e1ae32566f5e731339966d5fd10759 --- /dev/null +++ b/mmdet/models/losses/varifocal_loss.py @@ -0,0 +1,141 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from mmdet.registry import MODELS +from .utils import weight_reduce_loss + + +def varifocal_loss(pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + alpha: float = 0.75, + gamma: float = 2.0, + iou_weighted: bool = True, + reduction: str = 'mean', + avg_factor: Optional[int] = None) -> Tensor: + """`Varifocal Loss `_ + + Args: + pred (Tensor): The prediction with shape (N, C), C is the + number of classes. 
+ target (Tensor): The learning target of the iou-aware + classification score with shape (N, C), C is the number of classes. + weight (Tensor, optional): The weight of loss for each + prediction. Defaults to None. + alpha (float, optional): A balance factor for the negative part of + Varifocal Loss, which is different from the alpha of Focal Loss. + Defaults to 0.75. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + iou_weighted (bool, optional): Whether to weight the loss of the + positive example with the iou target. Defaults to True. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. Options are "none", "mean" and + "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + + Returns: + Tensor: Loss tensor. + """ + # pred and target should be of the same size + assert pred.size() == target.size() + pred_sigmoid = pred.sigmoid() + target = target.type_as(pred) + if iou_weighted: + focal_weight = target * (target > 0.0).float() + \ + alpha * (pred_sigmoid - target).abs().pow(gamma) * \ + (target <= 0.0).float() + else: + focal_weight = (target > 0.0).float() + \ + alpha * (pred_sigmoid - target).abs().pow(gamma) * \ + (target <= 0.0).float() + loss = F.binary_cross_entropy_with_logits( + pred, target, reduction='none') * focal_weight + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@MODELS.register_module() +class VarifocalLoss(nn.Module): + + def __init__(self, + use_sigmoid: bool = True, + alpha: float = 0.75, + gamma: float = 2.0, + iou_weighted: bool = True, + reduction: str = 'mean', + loss_weight: float = 1.0) -> None: + """`Varifocal Loss `_ + + Args: + use_sigmoid (bool, optional): Whether the prediction is + used for sigmoid or softmax. Defaults to True. + alpha (float, optional): A balance factor for the negative part of + Varifocal Loss, which is different from the alpha of Focal + Loss. Defaults to 0.75. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + iou_weighted (bool, optional): Whether to weight the loss of the + positive examples with the iou target. Defaults to True. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. Options are "none", "mean" and + "sum". + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + """ + super().__init__() + assert use_sigmoid is True, \ + 'Only sigmoid varifocal loss supported now.' + assert alpha >= 0.0 + self.use_sigmoid = use_sigmoid + self.alpha = alpha + self.gamma = gamma + self.iou_weighted = iou_weighted + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + avg_factor: Optional[int] = None, + reduction_override: Optional[str] = None) -> Tensor: + """Forward function. + + Args: + pred (Tensor): The prediction with shape (N, C), C is the + number of classes. + target (Tensor): The learning target of the iou-aware + classification score with shape (N, C), C is + the number of classes. + weight (Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Options are "none", "mean" and "sum". 
+ + Returns: + Tensor: The calculated loss + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.use_sigmoid: + loss_cls = self.loss_weight * varifocal_loss( + pred, + target, + weight, + alpha=self.alpha, + gamma=self.gamma, + iou_weighted=self.iou_weighted, + reduction=reduction, + avg_factor=avg_factor) + else: + raise NotImplementedError + return loss_cls diff --git a/mmdet/models/mot/__init__.py b/mmdet/models/mot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd3c8d3ba53daad736e05b5d29a6abb377fd595 --- /dev/null +++ b/mmdet/models/mot/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseMOTModel +from .bytetrack import ByteTrack +from .deep_sort import DeepSORT +from .ocsort import OCSORT +from .qdtrack import QDTrack +from .strongsort import StrongSORT + +__all__ = [ + 'BaseMOTModel', 'ByteTrack', 'QDTrack', 'DeepSORT', 'StrongSORT', 'OCSORT' +] diff --git a/mmdet/models/mot/base.py b/mmdet/models/mot/base.py new file mode 100644 index 0000000000000000000000000000000000000000..9981417924af3970319b0cbe6a9cc8d8a1095451 --- /dev/null +++ b/mmdet/models/mot/base.py @@ -0,0 +1,147 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import Dict, List, Tuple, Union + +from mmengine.model import BaseModel +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import OptTrackSampleList, TrackSampleList +from mmdet.utils import OptConfigType, OptMultiConfig + + +@MODELS.register_module() +class BaseMOTModel(BaseModel, metaclass=ABCMeta): + """Base class for multiple object tracking. + + Args: + data_preprocessor (dict or ConfigDict, optional): The pre-process + config of :class:`TrackDataPreprocessor`. it usually includes, + ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. + init_cfg (dict or list[dict]): Initialization config dict. 
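+
+ Example (a rough sketch; ``detector_cfg`` and ``tracker_cfg`` are
+ hypothetical config dicts for a concrete subclass such as ByteTrack):
+
+ >>> model = ByteTrack(detector=detector_cfg, tracker=tracker_cfg)
+ >>> model.freeze_module('detector')  # detector stops receiving grads
+ >>> model.with_detector
+ True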
+ """ + + def __init__(self, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + def freeze_module(self, module: Union[List[str], Tuple[str], str]) -> None: + """Freeze module during training.""" + if isinstance(module, str): + modules = [module] + else: + if not (isinstance(module, list) or isinstance(module, tuple)): + raise TypeError('module must be a str or a list.') + else: + modules = module + for module in modules: + m = getattr(self, module) + m.eval() + for param in m.parameters(): + param.requires_grad = False + + @property + def with_detector(self) -> bool: + """bool: whether the framework has a detector.""" + return hasattr(self, 'detector') and self.detector is not None + + @property + def with_reid(self) -> bool: + """bool: whether the framework has a reid model.""" + return hasattr(self, 'reid') and self.reid is not None + + @property + def with_motion(self) -> bool: + """bool: whether the framework has a motion model.""" + return hasattr(self, 'motion') and self.motion is not None + + @property + def with_track_head(self) -> bool: + """bool: whether the framework has a track_head.""" + return hasattr(self, 'track_head') and self.track_head is not None + + @property + def with_tracker(self) -> bool: + """bool: whether the framework has a tracker.""" + return hasattr(self, 'tracker') and self.tracker is not None + + def forward(self, + inputs: Dict[str, Tensor], + data_samples: OptTrackSampleList = None, + mode: str = 'predict', + **kwargs): + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor or tuple of + tensor without any post-processing, same as a common nn.Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`TrackDataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (Dict[str, Tensor]): of shape (N, T, C, H, W) + encoding input images. Typically these should be mean centered + and std scaled. The N denotes batch size. The T denotes the + number of key/reference frames. + - img (Tensor) : The key images. + - ref_img (Tensor): The reference images. + data_samples (list[:obj:`TrackDataSample`], optional): The + annotation data of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to 'predict'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="predict"``, return a list of :obj:`TrackDataSample`. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'loss': + return self.loss(inputs, data_samples, **kwargs) + elif mode == 'predict': + return self.predict(inputs, data_samples, **kwargs) + elif mode == 'tensor': + return self._forward(inputs, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}". 
' + 'Only supports loss, predict and tensor mode') + + @abstractmethod + def loss(self, inputs: Dict[str, Tensor], data_samples: TrackSampleList, + **kwargs) -> Union[dict, tuple]: + """Calculate losses from a batch of inputs and data samples.""" + pass + + @abstractmethod + def predict(self, inputs: Dict[str, Tensor], data_samples: TrackSampleList, + **kwargs) -> TrackSampleList: + """Predict results from a batch of inputs and data samples with post- + processing.""" + pass + + def _forward(self, + inputs: Dict[str, Tensor], + data_samples: OptTrackSampleList = None, + **kwargs): + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. + + Args: + inputs (Dict[str, Tensor]): of shape (N, T, C, H, W). + data_samples (List[:obj:`TrackDataSample`], optional): The + Data Samples. It usually includes information such as + `gt_instance`. + + Returns: + tuple[list]: A tuple of features from ``head`` forward. + """ + raise NotImplementedError( + "_forward function (namely 'tensor' mode) is not supported now") diff --git a/mmdet/models/mot/bytetrack.py b/mmdet/models/mot/bytetrack.py new file mode 100644 index 0000000000000000000000000000000000000000..8a3bb867cb284aad9854de44b2942341a4a33be8 --- /dev/null +++ b/mmdet/models/mot/bytetrack.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional + +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import SampleList, TrackSampleList +from mmdet.utils import OptConfigType, OptMultiConfig +from .base import BaseMOTModel + + +@MODELS.register_module() +class ByteTrack(BaseMOTModel): + """ByteTrack: Multi-Object Tracking by Associating Every Detection Box. + + This multi object tracker is the implementation of `ByteTrack + `_. + + Args: + detector (dict): Configuration of detector. Defaults to None. + tracker (dict): Configuration of tracker. Defaults to None. + data_preprocessor (dict or ConfigDict, optional): The pre-process + config of :class:`TrackDataPreprocessor`. it usually includes, + ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. + init_cfg (dict or list[dict]): Configuration of initialization. + Defaults to None. + """ + + def __init__(self, + detector: Optional[dict] = None, + tracker: Optional[dict] = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super().__init__(data_preprocessor, init_cfg) + + if detector is not None: + self.detector = MODELS.build(detector) + + if tracker is not None: + self.tracker = MODELS.build(tracker) + + def loss(self, inputs: Tensor, data_samples: SampleList, **kwargs) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (Tensor): of shape (N, C, H, W) encoding + input images. Typically these should be mean centered and std + scaled. The N denotes batch size + data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance`. + + Returns: + dict: A dictionary of loss components. + """ + return self.detector.loss(inputs, data_samples, **kwargs) + + def predict(self, inputs: Dict[str, Tensor], data_samples: TrackSampleList, + **kwargs) -> TrackSampleList: + """Predict results from a video and data samples with post-processing. + + Args: + inputs (Tensor): of shape (N, T, C, H, W) encoding + input images. The N denotes batch size. + The T denotes the number of frames in a video. 
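The three-mode `forward` contract above is the core of mmengine-style models. A toy stand-in (plain `nn.Module`, illustrative names, not an mmengine model) showing the same dispatch:

```python
import torch
from torch import nn

# Toy stand-in for the loss/predict/tensor dispatch in BaseMOTModel.forward.
class TinyModel(nn.Module):
    def forward(self, inputs, data_samples=None, mode='predict'):
        if mode == 'loss':
            return {'loss': inputs.pow(2).mean()}   # dict of loss tensors
        elif mode == 'predict':
            return [inputs.argmax(dim=1)]           # list of "predictions"
        elif mode == 'tensor':
            return inputs                           # raw network output
        raise RuntimeError(f'Invalid mode "{mode}". '
                           'Only supports loss, predict and tensor mode')

m, x = TinyModel(), torch.randn(2, 3)
print(m(x, mode='loss'), m(x, mode='predict'), m(x, mode='tensor').shape)
```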
+            data_samples (list[:obj:`TrackDataSample`]): The batch
+                data samples. It usually includes information such
+                as `video_data_samples`.
+        Returns:
+            TrackSampleList: Tracking results of the inputs.
+        """
+        assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).'
+        assert inputs.size(0) == 1, \
+            'ByteTrack inference only supports ' \
+            '1 batch size per gpu for now.'
+
+        assert len(data_samples) == 1, \
+            'ByteTrack inference only supports 1 batch size per gpu for now.'
+
+        track_data_sample = data_samples[0]
+        video_len = len(track_data_sample)
+
+        for frame_id in range(video_len):
+            img_data_sample = track_data_sample[frame_id]
+            single_img = inputs[:, frame_id].contiguous()
+            # det_results List[DetDataSample]
+            det_results = self.detector.predict(single_img, [img_data_sample])
+            assert len(det_results) == 1, 'Batch inference is not supported.'
+
+            pred_track_instances = self.tracker.track(
+                data_sample=det_results[0], **kwargs)
+            img_data_sample.pred_track_instances = pred_track_instances
+
+        return [track_data_sample]
diff --git a/mmdet/models/mot/deep_sort.py b/mmdet/models/mot/deep_sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..70b30c7b07b2211fd0ad70767f479e57b6cd33f6
--- /dev/null
+++ b/mmdet/models/mot/deep_sort.py
@@ -0,0 +1,110 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.structures import TrackSampleList
+from mmdet.utils import OptConfigType
+from .base import BaseMOTModel
+
+
+@MODELS.register_module()
+class DeepSORT(BaseMOTModel):
+    """Simple online and realtime tracking with a deep association metric.
+
+    Details can be found at `DeepSORT`_.
+
+    Args:
+        detector (dict): Configuration of detector. Defaults to None.
+        reid (dict): Configuration of reid. Defaults to None.
+        tracker (dict): Configuration of tracker. Defaults to None.
+        data_preprocessor (dict or ConfigDict, optional): The pre-process
+            config of :class:`TrackDataPreprocessor`. It usually includes
+            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
+        init_cfg (dict or list[dict]): Configuration of initialization.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 detector: Optional[dict] = None,
+                 reid: Optional[dict] = None,
+                 tracker: Optional[dict] = None,
+                 data_preprocessor: OptConfigType = None,
+                 init_cfg: OptConfigType = None):
+        super().__init__(data_preprocessor, init_cfg)
+
+        if detector is not None:
+            self.detector = MODELS.build(detector)
+
+        if reid is not None:
+            self.reid = MODELS.build(reid)
+
+        if tracker is not None:
+            self.tracker = MODELS.build(tracker)
+
+        self.preprocess_cfg = data_preprocessor
+
+    def loss(self, inputs: Tensor, data_samples: TrackSampleList,
+             **kwargs) -> dict:
+        """Calculate losses from a batch of inputs and data samples."""
+        raise NotImplementedError(
+            'Please train the `detector` and `reid` models first, then run \
+            inference with SORT/DeepSORT.')
+
+    def predict(self,
+                inputs: Tensor,
+                data_samples: TrackSampleList,
+                rescale: bool = True,
+                **kwargs) -> TrackSampleList:
+        """Predict results from a video and data samples with post-processing.
+
+        Args:
+            inputs (Tensor): of shape (N, T, C, H, W) encoding
+                input images. The N denotes batch size.
+                The T denotes the number of key frames
+                and reference frames.
+            data_samples (list[:obj:`TrackDataSample`]): The batch
+                data samples. It usually includes information such
+                as `gt_instance`.
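The ByteTrack predict loop above fixes the shape contract: one video per GPU, tracked frame by frame. A shape-level sketch with stand-in detector/tracker callables (illustrative, not the real mmdet modules):

```python
import torch

# Shape-level sketch of ByteTrack.predict's per-frame loop; fake_detect and
# fake_track are illustrative stand-ins for self.detector / self.tracker.
def fake_detect(img):
    return {'bboxes': torch.rand(3, 4), 'scores': torch.rand(3)}

def fake_track(det, frame_id):
    return dict(det, ids=torch.arange(3), frame_id=frame_id)

inputs = torch.rand(1, 5, 3, 64, 64)                # (N=1, T=5, C, H, W)
assert inputs.dim() == 5 and inputs.size(0) == 1
tracks = []
for frame_id in range(inputs.size(1)):
    single_img = inputs[:, frame_id].contiguous()   # (1, C, H, W)
    tracks.append(fake_track(fake_detect(single_img), frame_id))
print(len(tracks), tracks[0]['ids'].tolist())
```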
+            rescale (bool, Optional): If False, then returned bboxes and masks
+                will fit the scale of img, otherwise, returned bboxes and masks
+                will fit the scale of original image shape. Defaults to True.
+
+        Returns:
+            TrackSampleList: List[TrackDataSample]
+            Tracking results of the input videos.
+            Each DetDataSample usually contains ``pred_track_instances``.
+        """
+        assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).'
+        assert inputs.size(0) == 1, \
+            'SORT/DeepSORT inference only supports ' \
+            '1 batch size per gpu for now.'
+
+        assert len(data_samples) == 1, \
+            'SORT/DeepSORT inference only supports ' \
+            '1 batch size per gpu for now.'
+
+        track_data_sample = data_samples[0]
+        video_len = len(track_data_sample)
+        if track_data_sample[0].frame_id == 0:
+            self.tracker.reset()
+
+        for frame_id in range(video_len):
+            img_data_sample = track_data_sample[frame_id]
+            single_img = inputs[:, frame_id].contiguous()
+            # det_results List[DetDataSample]
+            det_results = self.detector.predict(single_img, [img_data_sample])
+            assert len(det_results) == 1, 'Batch inference is not supported.'
+
+            pred_track_instances = self.tracker.track(
+                model=self,
+                img=single_img,
+                feats=None,
+                data_sample=det_results[0],
+                data_preprocessor=self.preprocess_cfg,
+                rescale=rescale,
+                **kwargs)
+            img_data_sample.pred_track_instances = pred_track_instances
+
+        return [track_data_sample]
diff --git a/mmdet/models/mot/ocsort.py b/mmdet/models/mot/ocsort.py
new file mode 100644
index 0000000000000000000000000000000000000000..abf4eb3b06e2b1b223fe948f30dac877248377e3
--- /dev/null
+++ b/mmdet/models/mot/ocsort.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import Dict, Optional
+
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.structures import TrackSampleList
+from mmdet.utils import OptConfigType, OptMultiConfig
+from .base import BaseMOTModel
+
+
+@MODELS.register_module()
+class OCSORT(BaseMOTModel):
+    """OC-SORT: Observation-Centric SORT: Rethinking SORT for Robust
+    Multi-Object Tracking.
+
+    This multi-object tracker is the implementation of `OC-SORT
+    `_.
+
+    Args:
+        detector (dict): Configuration of detector. Defaults to None.
+        tracker (dict): Configuration of tracker. Defaults to None.
+        data_preprocessor (dict, optional): Config of :class:`TrackDataPreprocessor`.
+        init_cfg (dict): Configuration of initialization. Defaults to None.
+    """
+
+    def __init__(self,
+                 detector: Optional[dict] = None,
+                 tracker: Optional[dict] = None,
+                 data_preprocessor: OptConfigType = None,
+                 init_cfg: OptMultiConfig = None):
+        super().__init__(data_preprocessor, init_cfg)
+
+        if detector is not None:
+            self.detector = MODELS.build(detector)
+
+        if tracker is not None:
+            self.tracker = MODELS.build(tracker)
+
+    def loss(self, inputs: Tensor, data_samples: TrackSampleList,
+             **kwargs) -> dict:
+        """Calculate losses from a batch of inputs and data samples."""
+        return self.detector.loss(inputs, data_samples, **kwargs)
+
+    def predict(self, inputs: Dict[str, Tensor], data_samples: TrackSampleList,
+                **kwargs) -> TrackSampleList:
+        """Predict results from a video and data samples with post-processing.
+
+        Args:
+            inputs (Tensor): of shape (N, T, C, H, W) encoding
+                input images. The N denotes batch size.
+                The T denotes the number of frames in a video.
+            data_samples (list[:obj:`TrackDataSample`]): The batch
+                data samples. It usually includes information such
+                as `video_data_samples`.
+        Returns:
+            TrackSampleList: Tracking results of the inputs.
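DeepSORT's "deep association metric" boils down to comparing re-ID embeddings between detections and tracks. A sketch of that cost computation, with random embeddings as stand-ins for the `reid` model's output; the real tracker additionally gates matches with a Kalman filter and solves the assignment optimally rather than greedily:

```python
import torch
import torch.nn.functional as F

# Cosine-distance association cost between detections and existing tracks;
# embeddings are random stand-ins for re-ID features.
det_emb = F.normalize(torch.randn(4, 128), dim=1)   # 4 detections
trk_emb = F.normalize(torch.randn(3, 128), dim=1)   # 3 tracks
cost = 1.0 - det_emb @ trk_emb.t()                  # (4, 3), lower = closer
print(cost.shape, cost.argmin(dim=1))               # greedy nearest track
```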
+ """ + assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).' + assert inputs.size(0) == 1, \ + 'OCSORT inference only support ' \ + '1 batch size per gpu for now.' + + assert len(data_samples) == 1, \ + 'OCSORT inference only support 1 batch size per gpu for now.' + + track_data_sample = data_samples[0] + video_len = len(track_data_sample) + + for frame_id in range(video_len): + img_data_sample = track_data_sample[frame_id] + single_img = inputs[:, frame_id].contiguous() + # det_results List[DetDataSample] + det_results = self.detector.predict(single_img, [img_data_sample]) + assert len(det_results) == 1, 'Batch inference is not supported.' + + pred_track_instances = self.tracker.track( + data_sample=det_results[0], **kwargs) + img_data_sample.pred_track_instances = pred_track_instances + + return [track_data_sample] diff --git a/mmdet/models/mot/qdtrack.py b/mmdet/models/mot/qdtrack.py new file mode 100644 index 0000000000000000000000000000000000000000..43d5dd60b8af8a6200e21a196c47d00dd2812a46 --- /dev/null +++ b/mmdet/models/mot/qdtrack.py @@ -0,0 +1,186 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Union + +import torch +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import TrackSampleList +from mmdet.utils import OptConfigType, OptMultiConfig +from .base import BaseMOTModel + + +@MODELS.register_module() +class QDTrack(BaseMOTModel): + """Quasi-Dense Similarity Learning for Multiple Object Tracking. + + This multi object tracker is the implementation of `QDTrack + `_. + + Args: + detector (dict): Configuration of detector. Defaults to None. + track_head (dict): Configuration of track head. Defaults to None. + tracker (dict): Configuration of tracker. Defaults to None. + freeze_detector (bool): If True, freeze the detector weights. + Defaults to False. + data_preprocessor (dict or ConfigDict, optional): The pre-process + config of :class:`TrackDataPreprocessor`. it usually includes, + ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. + init_cfg (dict or list[dict]): Configuration of initialization. + Defaults to None. + """ + + def __init__(self, + detector: Optional[dict] = None, + track_head: Optional[dict] = None, + tracker: Optional[dict] = None, + freeze_detector: bool = False, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super().__init__(data_preprocessor, init_cfg) + if detector is not None: + self.detector = MODELS.build(detector) + + if track_head is not None: + self.track_head = MODELS.build(track_head) + + if tracker is not None: + self.tracker = MODELS.build(tracker) + + self.freeze_detector = freeze_detector + if self.freeze_detector: + self.freeze_module('detector') + + def predict(self, + inputs: Tensor, + data_samples: TrackSampleList, + rescale: bool = True, + **kwargs) -> TrackSampleList: + """Predict results from a video and data samples with post- processing. + + Args: + inputs (Tensor): of shape (N, T, C, H, W) encoding + input images. The N denotes batch size. + The T denotes the number of frames in a video. + data_samples (list[:obj:`TrackDataSample`]): The batch + data samples. It usually includes information such + as `video_data_samples`. + rescale (bool, Optional): If False, then returned bboxes and masks + will fit the scale of img, otherwise, returned bboxes and masks + will fit the scale of original image shape. Defaults to True. + + Returns: + TrackSampleList: Tracking results of the inputs. 
+ """ + assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).' + assert inputs.size(0) == 1, \ + 'QDTrack inference only support 1 batch size per gpu for now.' + + assert len(data_samples) == 1, \ + 'QDTrack only support 1 batch size per gpu for now.' + + track_data_sample = data_samples[0] + video_len = len(track_data_sample) + if track_data_sample[0].frame_id == 0: + self.tracker.reset() + + for frame_id in range(video_len): + img_data_sample = track_data_sample[frame_id] + single_img = inputs[:, frame_id].contiguous() + x = self.detector.extract_feat(single_img) + rpn_results_list = self.detector.rpn_head.predict( + x, [img_data_sample]) + # det_results List[InstanceData] + det_results = self.detector.roi_head.predict( + x, rpn_results_list, [img_data_sample], rescale=rescale) + assert len(det_results) == 1, 'Batch inference is not supported.' + img_data_sample.pred_instances = det_results[0] + frame_pred_track_instances = self.tracker.track( + model=self, + img=single_img, + feats=x, + data_sample=img_data_sample, + **kwargs) + img_data_sample.pred_track_instances = frame_pred_track_instances + + return [track_data_sample] + + def loss(self, inputs: Tensor, data_samples: TrackSampleList, + **kwargs) -> Union[dict, tuple]: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (Dict[str, Tensor]): of shape (N, T, C, H, W) encoding + input images. Typically these should be mean centered and std + scaled. The N denotes batch size. The T denotes the number of + frames. + data_samples (list[:obj:`TrackDataSample`]): The batch + data samples. It usually includes information such + as `video_data_samples`. + + Returns: + dict: A dictionary of loss components. + """ + # modify the inputs shape to fit mmdet + assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).' + assert inputs.size(1) == 2, \ + 'QDTrack can only have 1 key frame and 1 reference frame.' + + # split the data_samples into two aspects: key frames and reference + # frames + ref_data_samples, key_data_samples = [], [] + key_frame_inds, ref_frame_inds = [], [] + # set cat_id of gt_labels to 0 in RPN + for track_data_sample in data_samples: + key_frame_inds.append(track_data_sample.key_frames_inds[0]) + ref_frame_inds.append(track_data_sample.ref_frames_inds[0]) + key_data_sample = track_data_sample.get_key_frames()[0] + key_data_sample.gt_instances.labels = \ + torch.zeros_like(key_data_sample.gt_instances.labels) + key_data_samples.append(key_data_sample) + ref_data_sample = track_data_sample.get_ref_frames()[0] + ref_data_samples.append(ref_data_sample) + + key_frame_inds = torch.tensor(key_frame_inds, dtype=torch.int64) + ref_frame_inds = torch.tensor(ref_frame_inds, dtype=torch.int64) + batch_inds = torch.arange(len(inputs)) + key_imgs = inputs[batch_inds, key_frame_inds].contiguous() + ref_imgs = inputs[batch_inds, ref_frame_inds].contiguous() + + x = self.detector.extract_feat(key_imgs) + ref_x = self.detector.extract_feat(ref_imgs) + + losses = dict() + # RPN head forward and loss + assert self.detector.with_rpn, \ + 'QDTrack only support detector with RPN.' + + proposal_cfg = self.detector.train_cfg.get('rpn_proposal', + self.detector.test_cfg.rpn) + rpn_losses, rpn_results_list = self.detector.rpn_head. 
\ + loss_and_predict(x, + key_data_samples, + proposal_cfg=proposal_cfg, + **kwargs) + ref_rpn_results_list = self.detector.rpn_head.predict( + ref_x, ref_data_samples, **kwargs) + + # avoid get same name with roi_head loss + keys = rpn_losses.keys() + for key in keys: + if 'loss' in key and 'rpn' not in key: + rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key) + losses.update(rpn_losses) + + # roi_head loss + losses_detect = self.detector.roi_head.loss(x, rpn_results_list, + key_data_samples, **kwargs) + losses.update(losses_detect) + + # tracking head loss + losses_track = self.track_head.loss(x, ref_x, rpn_results_list, + ref_rpn_results_list, data_samples, + **kwargs) + losses.update(losses_track) + + return losses diff --git a/mmdet/models/mot/strongsort.py b/mmdet/models/mot/strongsort.py new file mode 100644 index 0000000000000000000000000000000000000000..6129bf49972233206b3c05daa2174f99723d1b9d --- /dev/null +++ b/mmdet/models/mot/strongsort.py @@ -0,0 +1,129 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import numpy as np +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures import TrackSampleList +from mmdet.utils import OptConfigType +from .deep_sort import DeepSORT + + +@MODELS.register_module() +class StrongSORT(DeepSORT): + """StrongSORT: Make DeepSORT Great Again. + + Details can be found at `StrongSORT`_. + + Args: + detector (dict): Configuration of detector. Defaults to None. + reid (dict): Configuration of reid. Defaults to None + tracker (dict): Configuration of tracker. Defaults to None. + kalman (dict): Configuration of Kalman filter. Defaults to None. + cmc (dict): Configuration of camera model compensation. + Defaults to None. + data_preprocessor (dict or ConfigDict, optional): The pre-process + config of :class:`TrackDataPreprocessor`. it usually includes, + ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. + init_cfg (dict or list[dict]): Configuration of initialization. + Defaults to None. + """ + + def __init__(self, + detector: Optional[dict] = None, + reid: Optional[dict] = None, + cmc: Optional[dict] = None, + tracker: Optional[dict] = None, + postprocess_model: Optional[dict] = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptConfigType = None): + super().__init__(detector, reid, tracker, data_preprocessor, init_cfg) + + if cmc is not None: + self.cmc = TASK_UTILS.build(cmc) + + if postprocess_model is not None: + self.postprocess_model = TASK_UTILS.build(postprocess_model) + + @property + def with_cmc(self): + """bool: whether the framework has a camera model compensation + model. + """ + return hasattr(self, 'cmc') and self.cmc is not None + + def predict(self, + inputs: Tensor, + data_samples: TrackSampleList, + rescale: bool = True, + **kwargs) -> TrackSampleList: + """Predict results from a video and data samples with post- processing. + + Args: + inputs (Tensor): of shape (N, T, C, H, W) encoding + input images. The N denotes batch size. + The T denotes the number of key frames + and reference frames. + data_samples (list[:obj:`TrackDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance`. + rescale (bool, Optional): If False, then returned bboxes and masks + will fit the scale of img, otherwise, returned bboxes and masks + will fit the scale of original image shape. Defaults to True. 
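The key/reference split in `QDTrack.loss` above is a per-sample gather over the T=2 axis; isolated, with illustrative shapes, it looks like this:

```python
import torch

# Per-sample gather of key vs. reference frames from a (N, T=2, C, H, W)
# batch, mirroring the index arithmetic in QDTrack.loss.
inputs = torch.rand(2, 2, 3, 32, 32)
key_frame_inds = torch.tensor([0, 1])   # which T-index is the key frame
ref_frame_inds = torch.tensor([1, 0])   # the other index is the reference
batch_inds = torch.arange(len(inputs))
key_imgs = inputs[batch_inds, key_frame_inds].contiguous()
ref_imgs = inputs[batch_inds, ref_frame_inds].contiguous()
print(key_imgs.shape, ref_imgs.shape)   # both (2, 3, 32, 32)
```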
+ + Returns: + TrackSampleList: List[TrackDataSample] + Tracking results of the input videos. + Each DetDataSample usually contains ``pred_track_instances``. + """ + assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).' + assert inputs.size(0) == 1, \ + 'SORT/DeepSORT inference only support ' \ + '1 batch size per gpu for now.' + + assert len(data_samples) == 1, \ + 'SORT/DeepSORT inference only support ' \ + '1 batch size per gpu for now.' + + track_data_sample = data_samples[0] + video_len = len(track_data_sample) + + video_track_instances = [] + for frame_id in range(video_len): + img_data_sample = track_data_sample[frame_id] + single_img = inputs[:, frame_id].contiguous() + # det_results List[DetDataSample] + det_results = self.detector.predict(single_img, [img_data_sample]) + assert len(det_results) == 1, 'Batch inference is not supported.' + + pred_track_instances = self.tracker.track( + model=self, + img=single_img, + data_sample=det_results[0], + data_preprocessor=self.preprocess_cfg, + rescale=rescale, + **kwargs) + for i in range(len(pred_track_instances.instances_id)): + video_track_instances.append( + np.array([ + frame_id + 1, + pred_track_instances.instances_id[i].cpu(), + pred_track_instances.bboxes[i][0].cpu(), + pred_track_instances.bboxes[i][1].cpu(), + (pred_track_instances.bboxes[i][2] - + pred_track_instances.bboxes[i][0]).cpu(), + (pred_track_instances.bboxes[i][3] - + pred_track_instances.bboxes[i][1]).cpu(), + pred_track_instances.scores[i].cpu() + ])) + video_track_instances = np.array(video_track_instances).reshape(-1, 7) + video_track_instances = self.postprocess_model.forward( + video_track_instances) + for frame_id in range(video_len): + track_data_sample[frame_id].pred_track_instances = \ + InstanceData(bboxes=video_track_instances[ + video_track_instances[:, 0] == frame_id + 1, :]) + + return [track_data_sample] diff --git a/mmdet/models/necks/__init__.py b/mmdet/models/necks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..343fbfefbd871d00e855d1c3cf4b531345e4dcf1 --- /dev/null +++ b/mmdet/models/necks/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .bfp import BFP +from .channel_mapper import ChannelMapper +from .cspnext_pafpn import CSPNeXtPAFPN +from .ct_resnet_neck import CTResNetNeck +from .dilated_encoder import DilatedEncoder +from .dyhead import DyHead +from .fpg import FPG +from .fpn import FPN +from .fpn_carafe import FPN_CARAFE +from .fpn_dropblock import FPN_DropBlock +from .hrfpn import HRFPN +from .nas_fpn import NASFPN +from .nasfcos_fpn import NASFCOS_FPN +from .pafpn import PAFPN +from .rfp import RFP +from .ssd_neck import SSDNeck +from .ssh import SSH +from .yolo_neck import YOLOV3Neck +from .yolox_pafpn import YOLOXPAFPN + +__all__ = [ + 'FPN', 'BFP', 'ChannelMapper', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN', + 'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG', 'DilatedEncoder', + 'CTResNetNeck', 'SSDNeck', 'YOLOXPAFPN', 'DyHead', 'CSPNeXtPAFPN', 'SSH', + 'FPN_DropBlock' +] diff --git a/mmdet/models/necks/bfp.py b/mmdet/models/necks/bfp.py new file mode 100644 index 0000000000000000000000000000000000000000..401cdb0f552b06c9e8eb185c3e8ae0ba7112a9d8 --- /dev/null +++ b/mmdet/models/necks/bfp.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
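The array `StrongSORT.predict` accumulates above uses the MOTChallenge-style row layout `[frame, id, x1, y1, w, h, score]` with 1-based frame indices, which is what lets the post-processing model operate on the whole video before rows are sliced back per frame:

```python
import numpy as np

# MOTChallenge-style rows as built in StrongSORT.predict:
# [frame, id, x1, y1, w, h, score]; frame indices are 1-based.
rows = np.array([
    [1, 0, 10., 12., 30., 40., 0.9],
    [1, 1, 50., 60., 20., 25., 0.8],
    [2, 0, 12., 13., 30., 40., 0.9],
]).reshape(-1, 7)
frame_id = 0                                   # 0-based loop variable
per_frame = rows[rows[:, 0] == frame_id + 1]   # rows for the first frame
print(per_frame.shape)                         # (2, 7)
```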
+from typing import Tuple + +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmcv.cnn.bricks import NonLocal2d +from mmengine.model import BaseModule +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import OptConfigType, OptMultiConfig + + +@MODELS.register_module() +class BFP(BaseModule): + """BFP (Balanced Feature Pyramids) + + BFP takes multi-level features as inputs and gather them into a single one, + then refine the gathered feature and scatter the refined results to + multi-level features. This module is used in Libra R-CNN (CVPR 2019), see + the paper `Libra R-CNN: Towards Balanced Learning for Object Detection + `_ for details. + + Args: + in_channels (int): Number of input channels (feature maps of all levels + should have the same channels). + num_levels (int): Number of input feature levels. + refine_level (int): Index of integration and refine level of BSF in + multi-level features from bottom to top. + refine_type (str): Type of the refine op, currently support + [None, 'conv', 'non_local']. + conv_cfg (:obj:`ConfigDict` or dict, optional): The config dict for + convolution layers. + norm_cfg (:obj:`ConfigDict` or dict, optional): The config dict for + normalization layers. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or + dict], optional): Initialization config dict. + """ + + def __init__( + self, + in_channels: int, + num_levels: int, + refine_level: int = 2, + refine_type: str = None, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = dict( + type='Xavier', layer='Conv2d', distribution='uniform') + ) -> None: + super().__init__(init_cfg=init_cfg) + assert refine_type in [None, 'conv', 'non_local'] + + self.in_channels = in_channels + self.num_levels = num_levels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.refine_level = refine_level + self.refine_type = refine_type + assert 0 <= self.refine_level < self.num_levels + + if self.refine_type == 'conv': + self.refine = ConvModule( + self.in_channels, + self.in_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) + elif self.refine_type == 'non_local': + self.refine = NonLocal2d( + self.in_channels, + reduction=1, + use_scale=False, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) + + def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]: + """Forward function.""" + assert len(inputs) == self.num_levels + + # step 1: gather multi-level features by resize and average + feats = [] + gather_size = inputs[self.refine_level].size()[2:] + for i in range(self.num_levels): + if i < self.refine_level: + gathered = F.adaptive_max_pool2d( + inputs[i], output_size=gather_size) + else: + gathered = F.interpolate( + inputs[i], size=gather_size, mode='nearest') + feats.append(gathered) + + bsf = sum(feats) / len(feats) + + # step 2: refine gathered features + if self.refine_type is not None: + bsf = self.refine(bsf) + + # step 3: scatter refined features to multi-levels by a residual path + outs = [] + for i in range(self.num_levels): + out_size = inputs[i].size()[2:] + if i < self.refine_level: + residual = F.interpolate(bsf, size=out_size, mode='nearest') + else: + residual = F.adaptive_max_pool2d(bsf, output_size=out_size) + outs.append(residual + inputs[i]) + + return tuple(outs) diff --git a/mmdet/models/necks/channel_mapper.py b/mmdet/models/necks/channel_mapper.py new file mode 100644 index 
0000000000000000000000000000000000000000..74293618f2b8a649328ae4a5a0571809de9991dd --- /dev/null +++ b/mmdet/models/necks/channel_mapper.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple, Union + +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import OptConfigType, OptMultiConfig + + +@MODELS.register_module() +class ChannelMapper(BaseModule): + """Channel Mapper to reduce/increase channels of backbone features. + + This is used to reduce/increase channels of backbone features. + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + kernel_size (int, optional): kernel_size for reducing channels (used + at each scale). Default: 3. + conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + convolution layer. Default: None. + norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + normalization layer. Default: None. + act_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + activation layer in ConvModule. Default: dict(type='ReLU'). + bias (bool | str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise + False. Default: "auto". + num_outs (int, optional): Number of output feature maps. There would + be extra_convs when num_outs larger than the length of in_channels. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or dict], + optional): Initialization config dict. + Example: + >>> import torch + >>> in_channels = [2, 3, 5, 7] + >>> scales = [340, 170, 84, 43] + >>> inputs = [torch.rand(1, c, s, s) + ... for c, s in zip(in_channels, scales)] + >>> self = ChannelMapper(in_channels, 11, 3).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... 
print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 11, 340, 340]) + outputs[1].shape = torch.Size([1, 11, 170, 170]) + outputs[2].shape = torch.Size([1, 11, 84, 84]) + outputs[3].shape = torch.Size([1, 11, 43, 43]) + """ + + def __init__( + self, + in_channels: List[int], + out_channels: int, + kernel_size: int = 3, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + act_cfg: OptConfigType = dict(type='ReLU'), + bias: Union[bool, str] = 'auto', + num_outs: int = None, + init_cfg: OptMultiConfig = dict( + type='Xavier', layer='Conv2d', distribution='uniform') + ) -> None: + super().__init__(init_cfg=init_cfg) + assert isinstance(in_channels, list) + self.extra_convs = None + if num_outs is None: + num_outs = len(in_channels) + self.convs = nn.ModuleList() + for in_channel in in_channels: + self.convs.append( + ConvModule( + in_channel, + out_channels, + kernel_size, + padding=(kernel_size - 1) // 2, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=bias)) + if num_outs > len(in_channels): + self.extra_convs = nn.ModuleList() + for i in range(len(in_channels), num_outs): + if i == len(in_channels): + in_channel = in_channels[-1] + else: + in_channel = out_channels + self.extra_convs.append( + ConvModule( + in_channel, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=bias)) + + def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]: + """Forward function.""" + assert len(inputs) == len(self.convs) + outs = [self.convs[i](inputs[i]) for i in range(len(inputs))] + if self.extra_convs: + for i in range(len(self.extra_convs)): + if i == 0: + outs.append(self.extra_convs[0](inputs[-1])) + else: + outs.append(self.extra_convs[i](outs[-1])) + return tuple(outs) diff --git a/mmdet/models/necks/cspnext_pafpn.py b/mmdet/models/necks/cspnext_pafpn.py new file mode 100644 index 0000000000000000000000000000000000000000..a52ba72d9b3e48c4866fb16507bc2118eb23010e --- /dev/null +++ b/mmdet/models/necks/cspnext_pafpn.py @@ -0,0 +1,170 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Sequence, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from mmengine.model import BaseModule +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptMultiConfig +from ..layers import CSPLayer + + +@MODELS.register_module() +class CSPNeXtPAFPN(BaseModule): + """Path Aggregation Network with CSPNeXt blocks. + + Args: + in_channels (Sequence[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale) + num_csp_blocks (int): Number of bottlenecks in CSPLayer. + Defaults to 3. + use_depthwise (bool): Whether to use depthwise separable convolution in + blocks. Defaults to False. + expand_ratio (float): Ratio to adjust the number of channels of the + hidden layer. Default: 0.5 + upsample_cfg (dict): Config dict for interpolate layer. + Default: `dict(scale_factor=2, mode='nearest')` + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN') + act_cfg (dict): Config dict for activation layer. + Default: dict(type='Swish') + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
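Extending the ChannelMapper doctest above: when `num_outs` exceeds the number of input scales, stride-2 extra convs append coarser maps. A sketch assuming this mmdet tree is importable:

```python
import torch
from mmdet.models.necks import ChannelMapper

# With 4 input scales and num_outs=5, one extra stride-2 conv is built from
# the last input, so a fifth, half-resolution map is appended.
in_channels = [2, 3, 5, 7]
scales = [64, 32, 16, 8]
inputs = [torch.rand(1, c, s, s) for c, s in zip(in_channels, scales)]
neck = ChannelMapper(in_channels, 11, 3, num_outs=5).eval()
outs = neck(tuple(inputs))
print([o.shape[-1] for o in outs])   # [64, 32, 16, 8, 4]
```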
+ """ + + def __init__( + self, + in_channels: Sequence[int], + out_channels: int, + num_csp_blocks: int = 3, + use_depthwise: bool = False, + expand_ratio: float = 0.5, + upsample_cfg: ConfigType = dict(scale_factor=2, mode='nearest'), + conv_cfg: bool = None, + norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001), + act_cfg: ConfigType = dict(type='Swish'), + init_cfg: OptMultiConfig = dict( + type='Kaiming', + layer='Conv2d', + a=math.sqrt(5), + distribution='uniform', + mode='fan_in', + nonlinearity='leaky_relu') + ) -> None: + super().__init__(init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + + conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule + + # build top-down blocks + self.upsample = nn.Upsample(**upsample_cfg) + self.reduce_layers = nn.ModuleList() + self.top_down_blocks = nn.ModuleList() + for idx in range(len(in_channels) - 1, 0, -1): + self.reduce_layers.append( + ConvModule( + in_channels[idx], + in_channels[idx - 1], + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.top_down_blocks.append( + CSPLayer( + in_channels[idx - 1] * 2, + in_channels[idx - 1], + num_blocks=num_csp_blocks, + add_identity=False, + use_depthwise=use_depthwise, + use_cspnext_block=True, + expand_ratio=expand_ratio, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + # build bottom-up blocks + self.downsamples = nn.ModuleList() + self.bottom_up_blocks = nn.ModuleList() + for idx in range(len(in_channels) - 1): + self.downsamples.append( + conv( + in_channels[idx], + in_channels[idx], + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.bottom_up_blocks.append( + CSPLayer( + in_channels[idx] * 2, + in_channels[idx + 1], + num_blocks=num_csp_blocks, + add_identity=False, + use_depthwise=use_depthwise, + use_cspnext_block=True, + expand_ratio=expand_ratio, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + self.out_convs = nn.ModuleList() + for i in range(len(in_channels)): + self.out_convs.append( + conv( + in_channels[i], + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: + """ + Args: + inputs (tuple[Tensor]): input features. + + Returns: + tuple[Tensor]: YOLOXPAFPN features. 
+ """ + assert len(inputs) == len(self.in_channels) + + # top-down path + inner_outs = [inputs[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_heigh = inner_outs[0] + feat_low = inputs[idx - 1] + feat_heigh = self.reduce_layers[len(self.in_channels) - 1 - idx]( + feat_heigh) + inner_outs[0] = feat_heigh + + upsample_feat = self.upsample(feat_heigh) + + inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx]( + torch.cat([upsample_feat, feat_low], 1)) + inner_outs.insert(0, inner_out) + + # bottom-up path + outs = [inner_outs[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = outs[-1] + feat_height = inner_outs[idx + 1] + downsample_feat = self.downsamples[idx](feat_low) + out = self.bottom_up_blocks[idx]( + torch.cat([downsample_feat, feat_height], 1)) + outs.append(out) + + # out convs + for idx, conv in enumerate(self.out_convs): + outs[idx] = conv(outs[idx]) + + return tuple(outs) diff --git a/mmdet/models/necks/ct_resnet_neck.py b/mmdet/models/necks/ct_resnet_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..9109fe79290fafecd954f223d5365ef619c0c301 --- /dev/null +++ b/mmdet/models/necks/ct_resnet_neck.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Sequence, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmdet.registry import MODELS +from mmdet.utils import OptMultiConfig + + +@MODELS.register_module() +class CTResNetNeck(BaseModule): + """The neck used in `CenterNet `_ for + object classification and box regression. + + Args: + in_channels (int): Number of input channels. + num_deconv_filters (tuple[int]): Number of filters per stage. + num_deconv_kernels (tuple[int]): Number of kernels per stage. + use_dcn (bool): If True, use DCNv2. Defaults to True. + init_cfg (:obj:`ConfigDict` or dict or list[dict] or + list[:obj:`ConfigDict`], optional): Initialization + config dict. 
+ """ + + def __init__(self, + in_channels: int, + num_deconv_filters: Tuple[int, ...], + num_deconv_kernels: Tuple[int, ...], + use_dcn: bool = True, + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + assert len(num_deconv_filters) == len(num_deconv_kernels) + self.fp16_enabled = False + self.use_dcn = use_dcn + self.in_channels = in_channels + self.deconv_layers = self._make_deconv_layer(num_deconv_filters, + num_deconv_kernels) + + def _make_deconv_layer( + self, num_deconv_filters: Tuple[int, ...], + num_deconv_kernels: Tuple[int, ...]) -> nn.Sequential: + """use deconv layers to upsample backbone's output.""" + layers = [] + for i in range(len(num_deconv_filters)): + feat_channels = num_deconv_filters[i] + conv_module = ConvModule( + self.in_channels, + feat_channels, + 3, + padding=1, + conv_cfg=dict(type='DCNv2') if self.use_dcn else None, + norm_cfg=dict(type='BN')) + layers.append(conv_module) + upsample_module = ConvModule( + feat_channels, + feat_channels, + num_deconv_kernels[i], + stride=2, + padding=1, + conv_cfg=dict(type='deconv'), + norm_cfg=dict(type='BN')) + layers.append(upsample_module) + self.in_channels = feat_channels + + return nn.Sequential(*layers) + + def init_weights(self) -> None: + """Initialize the parameters.""" + for m in self.modules(): + if isinstance(m, nn.ConvTranspose2d): + # In order to be consistent with the source code, + # reset the ConvTranspose2d initialization parameters + m.reset_parameters() + # Simulated bilinear upsampling kernel + w = m.weight.data + f = math.ceil(w.size(2) / 2) + c = (2 * f - 1 - f % 2) / (2. * f) + for i in range(w.size(2)): + for j in range(w.size(3)): + w[0, 0, i, j] = \ + (1 - math.fabs(i / f - c)) * ( + 1 - math.fabs(j / f - c)) + for c in range(1, w.size(0)): + w[c, 0, :, :] = w[0, 0, :, :] + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + # self.use_dcn is False + elif not self.use_dcn and isinstance(m, nn.Conv2d): + # In order to be consistent with the source code, + # reset the Conv2d initialization parameters + m.reset_parameters() + + def forward(self, x: Sequence[torch.Tensor]) -> Tuple[torch.Tensor]: + """model forward.""" + assert isinstance(x, (list, tuple)) + outs = self.deconv_layers(x[-1]) + return outs, diff --git a/mmdet/models/necks/dilated_encoder.py b/mmdet/models/necks/dilated_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e9beb3ea9b4289da8d0100ae7759927f045829bb --- /dev/null +++ b/mmdet/models/necks/dilated_encoder.py @@ -0,0 +1,109 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule, is_norm +from mmengine.model import caffe2_xavier_init, constant_init, normal_init +from torch.nn import BatchNorm2d + +from mmdet.registry import MODELS + + +class Bottleneck(nn.Module): + """Bottleneck block for DilatedEncoder used in `YOLOF. + + `. + + The Bottleneck contains three ConvLayers and one residual connection. + + Args: + in_channels (int): The number of input channels. + mid_channels (int): The number of middle output channels. + dilation (int): Dilation rate. + norm_cfg (dict): Dictionary to construct and config norm layer. 
+ """ + + def __init__(self, + in_channels, + mid_channels, + dilation, + norm_cfg=dict(type='BN', requires_grad=True)): + super(Bottleneck, self).__init__() + self.conv1 = ConvModule( + in_channels, mid_channels, 1, norm_cfg=norm_cfg) + self.conv2 = ConvModule( + mid_channels, + mid_channels, + 3, + padding=dilation, + dilation=dilation, + norm_cfg=norm_cfg) + self.conv3 = ConvModule( + mid_channels, in_channels, 1, norm_cfg=norm_cfg) + + def forward(self, x): + identity = x + out = self.conv1(x) + out = self.conv2(out) + out = self.conv3(out) + out = out + identity + return out + + +@MODELS.register_module() +class DilatedEncoder(nn.Module): + """Dilated Encoder for YOLOF `. + + This module contains two types of components: + - the original FPN lateral convolution layer and fpn convolution layer, + which are 1x1 conv + 3x3 conv + - the dilated residual block + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + block_mid_channels (int): The number of middle block output channels + num_residual_blocks (int): The number of residual blocks. + block_dilations (list): The list of residual blocks dilation. + """ + + def __init__(self, in_channels, out_channels, block_mid_channels, + num_residual_blocks, block_dilations): + super(DilatedEncoder, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.block_mid_channels = block_mid_channels + self.num_residual_blocks = num_residual_blocks + self.block_dilations = block_dilations + self._init_layers() + + def _init_layers(self): + self.lateral_conv = nn.Conv2d( + self.in_channels, self.out_channels, kernel_size=1) + self.lateral_norm = BatchNorm2d(self.out_channels) + self.fpn_conv = nn.Conv2d( + self.out_channels, self.out_channels, kernel_size=3, padding=1) + self.fpn_norm = BatchNorm2d(self.out_channels) + encoder_blocks = [] + for i in range(self.num_residual_blocks): + dilation = self.block_dilations[i] + encoder_blocks.append( + Bottleneck( + self.out_channels, + self.block_mid_channels, + dilation=dilation)) + self.dilated_encoder_blocks = nn.Sequential(*encoder_blocks) + + def init_weights(self): + caffe2_xavier_init(self.lateral_conv) + caffe2_xavier_init(self.fpn_conv) + for m in [self.lateral_norm, self.fpn_norm]: + constant_init(m, 1) + for m in self.dilated_encoder_blocks.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, mean=0, std=0.01) + if is_norm(m): + constant_init(m, 1) + + def forward(self, feature): + out = self.lateral_norm(self.lateral_conv(feature[-1])) + out = self.fpn_norm(self.fpn_conv(out)) + return self.dilated_encoder_blocks(out), diff --git a/mmdet/models/necks/dyhead.py b/mmdet/models/necks/dyhead.py new file mode 100644 index 0000000000000000000000000000000000000000..5f5ae0b285c20558a0c7bcc59cbb7b214684eab2 --- /dev/null +++ b/mmdet/models/necks/dyhead.py @@ -0,0 +1,173 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2d +from mmengine.model import BaseModule, constant_init, normal_init + +from mmdet.registry import MODELS +from ..layers import DyReLU + +# Reference: +# https://github.com/microsoft/DynamicHead +# https://github.com/jshilong/SEPC + + +class DyDCNv2(nn.Module): + """ModulatedDeformConv2d with normalization layer used in DyHead. 
+ + This module cannot be configured with `conv_cfg=dict(type='DCNv2')` + because DyHead calculates offset and mask from middle-level feature. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int | tuple[int], optional): Stride of the convolution. + Default: 1. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='GN', num_groups=16, requires_grad=True). + """ + + def __init__(self, + in_channels, + out_channels, + stride=1, + norm_cfg=dict(type='GN', num_groups=16, requires_grad=True)): + super().__init__() + self.with_norm = norm_cfg is not None + bias = not self.with_norm + self.conv = ModulatedDeformConv2d( + in_channels, out_channels, 3, stride=stride, padding=1, bias=bias) + if self.with_norm: + self.norm = build_norm_layer(norm_cfg, out_channels)[1] + + def forward(self, x, offset, mask): + """Forward function.""" + x = self.conv(x.contiguous(), offset, mask) + if self.with_norm: + x = self.norm(x) + return x + + +class DyHeadBlock(nn.Module): + """DyHead Block with three types of attention. + + HSigmoid arguments in default act_cfg follow official code, not paper. + https://github.com/microsoft/DynamicHead/blob/master/dyhead/dyrelu.py + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + zero_init_offset (bool, optional): Whether to use zero init for + `spatial_conv_offset`. Default: True. + act_cfg (dict, optional): Config dict for the last activation layer of + scale-aware attention. Default: dict(type='HSigmoid', bias=3.0, + divisor=6.0). + """ + + def __init__(self, + in_channels, + out_channels, + zero_init_offset=True, + act_cfg=dict(type='HSigmoid', bias=3.0, divisor=6.0)): + super().__init__() + self.zero_init_offset = zero_init_offset + # (offset_x, offset_y, mask) * kernel_size_y * kernel_size_x + self.offset_and_mask_dim = 3 * 3 * 3 + self.offset_dim = 2 * 3 * 3 + + self.spatial_conv_high = DyDCNv2(in_channels, out_channels) + self.spatial_conv_mid = DyDCNv2(in_channels, out_channels) + self.spatial_conv_low = DyDCNv2(in_channels, out_channels, stride=2) + self.spatial_conv_offset = nn.Conv2d( + in_channels, self.offset_and_mask_dim, 3, padding=1) + self.scale_attn_module = nn.Sequential( + nn.AdaptiveAvgPool2d(1), nn.Conv2d(out_channels, 1, 1), + nn.ReLU(inplace=True), build_activation_layer(act_cfg)) + self.task_attn_module = DyReLU(out_channels) + self._init_weights() + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, 0, 0.01) + if self.zero_init_offset: + constant_init(self.spatial_conv_offset, 0) + + def forward(self, x): + """Forward function.""" + outs = [] + for level in range(len(x)): + # calculate offset and mask of DCNv2 from middle-level feature + offset_and_mask = self.spatial_conv_offset(x[level]) + offset = offset_and_mask[:, :self.offset_dim, :, :] + mask = offset_and_mask[:, self.offset_dim:, :, :].sigmoid() + + mid_feat = self.spatial_conv_mid(x[level], offset, mask) + sum_feat = mid_feat * self.scale_attn_module(mid_feat) + summed_levels = 1 + if level > 0: + low_feat = self.spatial_conv_low(x[level - 1], offset, mask) + sum_feat += low_feat * self.scale_attn_module(low_feat) + summed_levels += 1 + if level < len(x) - 1: + # this upsample order is weird, but faster than natural order + # https://github.com/microsoft/DynamicHead/issues/25 + high_feat = F.interpolate( + self.spatial_conv_high(x[level + 1], offset, mask), + size=x[level].shape[-2:], + 
mode='bilinear', + align_corners=True) + sum_feat += high_feat * self.scale_attn_module(high_feat) + summed_levels += 1 + outs.append(self.task_attn_module(sum_feat / summed_levels)) + + return outs + + +@MODELS.register_module() +class DyHead(BaseModule): + """DyHead neck consisting of multiple DyHead Blocks. + + See `Dynamic Head: Unifying Object Detection Heads with Attentions + `_ for details. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + num_blocks (int, optional): Number of DyHead Blocks. Default: 6. + zero_init_offset (bool, optional): Whether to use zero init for + `spatial_conv_offset`. Default: True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + num_blocks=6, + zero_init_offset=True, + init_cfg=None): + assert init_cfg is None, 'To prevent abnormal initialization ' \ + 'behavior, init_cfg is not allowed to be set' + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.zero_init_offset = zero_init_offset + + dyhead_blocks = [] + for i in range(num_blocks): + in_channels = self.in_channels if i == 0 else self.out_channels + dyhead_blocks.append( + DyHeadBlock( + in_channels, + self.out_channels, + zero_init_offset=zero_init_offset)) + self.dyhead_blocks = nn.Sequential(*dyhead_blocks) + + def forward(self, inputs): + """Forward function.""" + assert isinstance(inputs, (tuple, list)) + outs = self.dyhead_blocks(inputs) + return tuple(outs) diff --git a/mmdet/models/necks/fpg.py b/mmdet/models/necks/fpg.py new file mode 100644 index 0000000000000000000000000000000000000000..73ee799bb83645ab2556fe871dcd8b1c5bbff89e --- /dev/null +++ b/mmdet/models/necks/fpg.py @@ -0,0 +1,406 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmdet.registry import MODELS + + +class Transition(BaseModule): + """Base class for transition. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + """ + + def __init__(self, in_channels, out_channels, init_cfg=None): + super().__init__(init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + + def forward(x): + pass + + +class UpInterpolationConv(Transition): + """A transition used for up-sampling. + + Up-sample the input by interpolation then refines the feature by + a convolution layer. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + scale_factor (int): Up-sampling factor. Default: 2. + mode (int): Interpolation mode. Default: nearest. + align_corners (bool): Whether align corners when interpolation. + Default: None. + kernel_size (int): Kernel size for the conv. Default: 3. 
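The channel split in `DyHeadBlock.forward` above follows directly from the 3x3 deformable kernel: two offset channels (dx, dy) per tap plus one mask channel per tap, i.e. 18 + 9 = 27:

```python
import torch

# Offset/mask split as in DyHeadBlock: 3*3 kernel taps, (dx, dy) each -> 18
# offset channels; the remaining 9 channels are sigmoid-squashed masks.
offset_and_mask = torch.randn(1, 3 * 3 * 3, 10, 10)   # spatial_conv_offset out
offset_dim = 2 * 3 * 3
offset = offset_and_mask[:, :offset_dim]              # (1, 18, 10, 10)
mask = offset_and_mask[:, offset_dim:].sigmoid()      # (1, 9, 10, 10) in (0, 1)
print(offset.shape, mask.shape)
```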
+ """ + + def __init__(self, + in_channels, + out_channels, + scale_factor=2, + mode='nearest', + align_corners=None, + kernel_size=3, + init_cfg=None, + **kwargs): + super().__init__(in_channels, out_channels, init_cfg) + self.mode = mode + self.scale_factor = scale_factor + self.align_corners = align_corners + self.conv = ConvModule( + in_channels, + out_channels, + kernel_size, + padding=(kernel_size - 1) // 2, + **kwargs) + + def forward(self, x): + x = F.interpolate( + x, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners) + x = self.conv(x) + return x + + +class LastConv(Transition): + """A transition used for refining the output of the last stage. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + num_inputs (int): Number of inputs of the FPN features. + kernel_size (int): Kernel size for the conv. Default: 3. + """ + + def __init__(self, + in_channels, + out_channels, + num_inputs, + kernel_size=3, + init_cfg=None, + **kwargs): + super().__init__(in_channels, out_channels, init_cfg) + self.num_inputs = num_inputs + self.conv_out = ConvModule( + in_channels, + out_channels, + kernel_size, + padding=(kernel_size - 1) // 2, + **kwargs) + + def forward(self, inputs): + assert len(inputs) == self.num_inputs + return self.conv_out(inputs[-1]) + + +@MODELS.register_module() +class FPG(BaseModule): + """FPG. + + Implementation of `Feature Pyramid Grids (FPG) + `_. + This implementation only gives the basic structure stated in the paper. + But users can implement different type of transitions to fully explore the + the potential power of the structure of FPG. + + Args: + in_channels (int): Number of input channels (feature maps of all levels + should have the same channels). + out_channels (int): Number of output channels (used at each scale) + num_outs (int): Number of output scales. + stack_times (int): The number of times the pyramid architecture will + be stacked. + paths (list[str]): Specify the path order of each stack level. + Each element in the list should be either 'bu' (bottom-up) or + 'td' (top-down). + inter_channels (int): Number of inter channels. + same_up_trans (dict): Transition that goes down at the same stage. + same_down_trans (dict): Transition that goes up at the same stage. + across_lateral_trans (dict): Across-pathway same-stage + across_down_trans (dict): Across-pathway bottom-up connection. + across_up_trans (dict): Across-pathway top-down connection. + across_skip_trans (dict): Across-pathway skip connection. + output_trans (dict): Transition that trans the output of the + last stage. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + add_extra_convs (bool): It decides whether to add conv + layers on top of the original feature maps. Default to False. + If True, its actual mode is specified by `extra_convs_on_inputs`. + norm_cfg (dict): Config dict for normalization layer. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ """ + + transition_types = { + 'conv': ConvModule, + 'interpolation_conv': UpInterpolationConv, + 'last_conv': LastConv, + } + + def __init__(self, + in_channels, + out_channels, + num_outs, + stack_times, + paths, + inter_channels=None, + same_down_trans=None, + same_up_trans=dict( + type='conv', kernel_size=3, stride=2, padding=1), + across_lateral_trans=dict(type='conv', kernel_size=1), + across_down_trans=dict(type='conv', kernel_size=3), + across_up_trans=None, + across_skip_trans=dict(type='identity'), + output_trans=dict(type='last_conv', kernel_size=3), + start_level=0, + end_level=-1, + add_extra_convs=False, + norm_cfg=None, + skip_inds=None, + init_cfg=[ + dict(type='Caffe2Xavier', layer='Conv2d'), + dict( + type='Constant', + layer=[ + '_BatchNorm', '_InstanceNorm', 'GroupNorm', + 'LayerNorm' + ], + val=1.0) + ]): + super(FPG, self).__init__(init_cfg) + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + if inter_channels is None: + self.inter_channels = [out_channels for _ in range(num_outs)] + elif isinstance(inter_channels, int): + self.inter_channels = [inter_channels for _ in range(num_outs)] + else: + assert isinstance(inter_channels, list) + assert len(inter_channels) == num_outs + self.inter_channels = inter_channels + self.stack_times = stack_times + self.paths = paths + assert isinstance(paths, list) and len(paths) == stack_times + for d in paths: + assert d in ('bu', 'td') + + self.same_down_trans = same_down_trans + self.same_up_trans = same_up_trans + self.across_lateral_trans = across_lateral_trans + self.across_down_trans = across_down_trans + self.across_up_trans = across_up_trans + self.output_trans = output_trans + self.across_skip_trans = across_skip_trans + + self.with_bias = norm_cfg is None + # skip inds must be specified if across skip trans is not None + if self.across_skip_trans is not None: + skip_inds is not None + self.skip_inds = skip_inds + assert len(self.skip_inds[0]) <= self.stack_times + + if end_level == -1 or end_level == self.num_ins - 1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level is not the last level, no extra level is allowed + self.backbone_end_level = end_level + 1 + assert end_level < self.num_ins + assert num_outs == end_level - start_level + 1 + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + + # build lateral 1x1 convs to reduce channels + self.lateral_convs = nn.ModuleList() + for i in range(self.start_level, self.backbone_end_level): + l_conv = nn.Conv2d(self.in_channels[i], + self.inter_channels[i - self.start_level], 1) + self.lateral_convs.append(l_conv) + + extra_levels = num_outs - self.backbone_end_level + self.start_level + self.extra_downsamples = nn.ModuleList() + for i in range(extra_levels): + if self.add_extra_convs: + fpn_idx = self.backbone_end_level - self.start_level + i + extra_conv = nn.Conv2d( + self.inter_channels[fpn_idx - 1], + self.inter_channels[fpn_idx], + 3, + stride=2, + padding=1) + self.extra_downsamples.append(extra_conv) + else: + self.extra_downsamples.append(nn.MaxPool2d(1, stride=2)) + + self.fpn_transitions = nn.ModuleList() # stack times + for s in range(self.stack_times): + stage_trans = nn.ModuleList() # num of feature levels + for i in range(self.num_outs): + # same, across_lateral, across_down, across_up + trans = nn.ModuleDict() + if s in 
self.skip_inds[i]: + stage_trans.append(trans) + continue + # build same-stage down trans (used in bottom-up paths) + if i == 0 or self.same_up_trans is None: + same_up_trans = None + else: + same_up_trans = self.build_trans( + self.same_up_trans, self.inter_channels[i - 1], + self.inter_channels[i]) + trans['same_up'] = same_up_trans + # build same-stage up trans (used in top-down paths) + if i == self.num_outs - 1 or self.same_down_trans is None: + same_down_trans = None + else: + same_down_trans = self.build_trans( + self.same_down_trans, self.inter_channels[i + 1], + self.inter_channels[i]) + trans['same_down'] = same_down_trans + # build across lateral trans + across_lateral_trans = self.build_trans( + self.across_lateral_trans, self.inter_channels[i], + self.inter_channels[i]) + trans['across_lateral'] = across_lateral_trans + # build across down trans + if i == self.num_outs - 1 or self.across_down_trans is None: + across_down_trans = None + else: + across_down_trans = self.build_trans( + self.across_down_trans, self.inter_channels[i + 1], + self.inter_channels[i]) + trans['across_down'] = across_down_trans + # build across up trans + if i == 0 or self.across_up_trans is None: + across_up_trans = None + else: + across_up_trans = self.build_trans( + self.across_up_trans, self.inter_channels[i - 1], + self.inter_channels[i]) + trans['across_up'] = across_up_trans + if self.across_skip_trans is None: + across_skip_trans = None + else: + across_skip_trans = self.build_trans( + self.across_skip_trans, self.inter_channels[i - 1], + self.inter_channels[i]) + trans['across_skip'] = across_skip_trans + # build across_skip trans + stage_trans.append(trans) + self.fpn_transitions.append(stage_trans) + + self.output_transition = nn.ModuleList() # output levels + for i in range(self.num_outs): + trans = self.build_trans( + self.output_trans, + self.inter_channels[i], + self.out_channels, + num_inputs=self.stack_times + 1) + self.output_transition.append(trans) + + self.relu = nn.ReLU(inplace=True) + + def build_trans(self, cfg, in_channels, out_channels, **extra_args): + cfg_ = cfg.copy() + trans_type = cfg_.pop('type') + trans_cls = self.transition_types[trans_type] + return trans_cls(in_channels, out_channels, **cfg_, **extra_args) + + def fuse(self, fuse_dict): + out = None + for item in fuse_dict.values(): + if item is not None: + if out is None: + out = item + else: + out = out + item + return out + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + + # build all levels from original feature maps + feats = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + for downsample in self.extra_downsamples: + feats.append(downsample(feats[-1])) + + outs = [feats] + + for i in range(self.stack_times): + current_outs = outs[-1] + next_outs = [] + direction = self.paths[i] + for j in range(self.num_outs): + if i in self.skip_inds[j]: + next_outs.append(outs[-1][j]) + continue + # feature level + if direction == 'td': + lvl = self.num_outs - j - 1 + else: + lvl = j + # get transitions + if direction == 'td': + same_trans = self.fpn_transitions[i][lvl]['same_down'] + else: + same_trans = self.fpn_transitions[i][lvl]['same_up'] + across_lateral_trans = self.fpn_transitions[i][lvl][ + 'across_lateral'] + across_down_trans = self.fpn_transitions[i][lvl]['across_down'] + across_up_trans = self.fpn_transitions[i][lvl]['across_up'] + across_skip_trans = self.fpn_transitions[i][lvl]['across_skip'] + # init output + to_fuse = dict( 
+ same=None, lateral=None, across_up=None, across_down=None) + # same downsample/upsample + if same_trans is not None: + to_fuse['same'] = same_trans(next_outs[-1]) + # across lateral + if across_lateral_trans is not None: + to_fuse['lateral'] = across_lateral_trans( + current_outs[lvl]) + # across downsample + if lvl > 0 and across_up_trans is not None: + to_fuse['across_up'] = across_up_trans(current_outs[lvl - + 1]) + # across upsample + if (lvl < self.num_outs - 1 and across_down_trans is not None): + to_fuse['across_down'] = across_down_trans( + current_outs[lvl + 1]) + if across_skip_trans is not None: + to_fuse['across_skip'] = across_skip_trans(outs[0][lvl]) + x = self.fuse(to_fuse) + next_outs.append(x) + + if direction == 'td': + outs.append(next_outs[::-1]) + else: + outs.append(next_outs) + + # output trans + final_outs = [] + for i in range(self.num_outs): + lvl_out_list = [] + for s in range(len(outs)): + lvl_out_list.append(outs[s][i]) + lvl_out = self.output_transition[i](lvl_out_list) + final_outs.append(lvl_out) + + return final_outs diff --git a/mmdet/models/necks/fpn.py b/mmdet/models/necks/fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..67bd8879641f8539f329e6ffb94f88d25e417244 --- /dev/null +++ b/mmdet/models/necks/fpn.py @@ -0,0 +1,221 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple, Union + +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, MultiConfig, OptConfigType + + +@MODELS.register_module() +class FPN(BaseModule): + r"""Feature Pyramid Network. + + This is an implementation of paper `Feature Pyramid Networks for Object + Detection `_. + + Args: + in_channels (list[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + num_outs (int): Number of output scales. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Defaults to 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Defaults to -1, which means the + last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Defaults to False. + If True, it is equivalent to `add_extra_convs='on_input'`. + If str, it specifies the source feature map of the extra convs. + Only the following options are allowed + + - 'on_input': Last feat map of neck inputs (i.e. backbone feature). + - 'on_lateral': Last feature map after lateral convs. + - 'on_output': The last output feature map after fpn convs. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Defaults to False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Defaults to False. + conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + convolution layer. Defaults to None. + norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + normalization layer. Defaults to None. + act_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + activation layer in ConvModule. Defaults to None. + upsample_cfg (:obj:`ConfigDict` or dict, optional): Config dict + for interpolate layer. Defaults to dict(mode='nearest'). + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict]): Initialization config dict. 
+ + Example: + >>> import torch + >>> in_channels = [2, 3, 5, 7] + >>> scales = [340, 170, 84, 43] + >>> inputs = [torch.rand(1, c, s, s) + ... for c, s in zip(in_channels, scales)] + >>> self = FPN(in_channels, 11, len(in_channels)).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 11, 340, 340]) + outputs[1].shape = torch.Size([1, 11, 170, 170]) + outputs[2].shape = torch.Size([1, 11, 84, 84]) + outputs[3].shape = torch.Size([1, 11, 43, 43]) + """ + + def __init__( + self, + in_channels: List[int], + out_channels: int, + num_outs: int, + start_level: int = 0, + end_level: int = -1, + add_extra_convs: Union[bool, str] = False, + relu_before_extra_convs: bool = False, + no_norm_on_lateral: bool = False, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + act_cfg: OptConfigType = None, + upsample_cfg: ConfigType = dict(mode='nearest'), + init_cfg: MultiConfig = dict( + type='Xavier', layer='Conv2d', distribution='uniform') + ) -> None: + super().__init__(init_cfg=init_cfg) + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.relu_before_extra_convs = relu_before_extra_convs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.upsample_cfg = upsample_cfg.copy() + + if end_level == -1 or end_level == self.num_ins - 1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level is not the last level, no extra level is allowed + self.backbone_end_level = end_level + 1 + assert end_level < self.num_ins + assert num_outs == end_level - start_level + 1 + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' + assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') + elif add_extra_convs: # True + self.add_extra_convs = 'on_input' + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False) + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + # add extra conv layers (e.g., RetinaNet) + extra_levels = num_outs - self.backbone_end_level + self.start_level + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == 'on_input': + in_channels = self.in_channels[self.backbone_end_level - 1] + else: + in_channels = out_channels + extra_fpn_conv = ConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpn_convs.append(extra_fpn_conv) + + def forward(self, inputs: Tuple[Tensor]) -> tuple: + """Forward function. + + Args: + inputs (tuple[Tensor]): Features from the upstream network, each + is a 4D-tensor. + + Returns: + tuple: Feature maps, each is a 4D-tensor. 
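+
+        Example (a small sketch of the extra-level behaviour; channel and
+        spatial sizes are arbitrary assumptions):
+
+            >>> import torch
+            >>> self = FPN([8, 16], 4, num_outs=4,
+            ...            add_extra_convs='on_output').eval()
+            >>> inputs = [torch.rand(1, 8, 16, 16), torch.rand(1, 16, 8, 8)]
+            >>> [out.shape[-1] for out in self.forward(inputs)]
+            [16, 8, 4, 2]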
+ """ + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if 'scale_factor' in self.upsample_cfg: + # fix runtime error of "+=" inplace operation in PyTorch 1.10 + laterals[i - 1] = laterals[i - 1] + F.interpolate( + laterals[i], **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + F.interpolate( + laterals[i], size=prev_shape, **self.upsample_cfg) + + # build outputs + # part 1: from original levels + outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + return tuple(outs) diff --git a/mmdet/models/necks/fpn_carafe.py b/mmdet/models/necks/fpn_carafe.py new file mode 100644 index 0000000000000000000000000000000000000000..b393ff7c340c0c343fc4c91a4d87d341f66a3177 --- /dev/null +++ b/mmdet/models/necks/fpn_carafe.py @@ -0,0 +1,275 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule, build_upsample_layer +from mmcv.ops.carafe import CARAFEPack +from mmengine.model import BaseModule, ModuleList, xavier_init + +from mmdet.registry import MODELS + + +@MODELS.register_module() +class FPN_CARAFE(BaseModule): + """FPN_CARAFE is a more flexible implementation of FPN. It allows more + choice for upsample methods during the top-down pathway. + + It can reproduce the performance of ICCV 2019 paper + CARAFE: Content-Aware ReAssembly of FEatures + Please refer to https://arxiv.org/abs/1905.02188 for more details. + + Args: + in_channels (list[int]): Number of channels for each input feature map. + out_channels (int): Output channels of feature pyramids. + num_outs (int): Number of output stages. + start_level (int): Start level of feature pyramids. + (Default: 0) + end_level (int): End level of feature pyramids. + (Default: -1 indicates the last level). + norm_cfg (dict): Dictionary to construct and config norm layer. + activate (str): Type of activation function in ConvModule + (Default: None indicates w/o activation). + order (dict): Order of components in ConvModule. + upsample (str): Type of upsample layer. + upsample_cfg (dict): Dictionary to construct and config upsample layer. + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ Default: None + """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + norm_cfg=None, + act_cfg=None, + order=('conv', 'norm', 'act'), + upsample_cfg=dict( + type='carafe', + up_kernel=5, + up_group=1, + encoder_kernel=3, + encoder_dilation=1), + init_cfg=None): + assert init_cfg is None, 'To prevent abnormal initialization ' \ + 'behavior, init_cfg is not allowed to be set' + super(FPN_CARAFE, self).__init__(init_cfg) + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.with_bias = norm_cfg is None + self.upsample_cfg = upsample_cfg.copy() + self.upsample = self.upsample_cfg.get('type') + self.relu = nn.ReLU(inplace=False) + + self.order = order + assert order in [('conv', 'norm', 'act'), ('act', 'conv', 'norm')] + + assert self.upsample in [ + 'nearest', 'bilinear', 'deconv', 'pixel_shuffle', 'carafe', None + ] + if self.upsample in ['deconv', 'pixel_shuffle']: + assert hasattr( + self.upsample_cfg, + 'upsample_kernel') and self.upsample_cfg.upsample_kernel > 0 + self.upsample_kernel = self.upsample_cfg.pop('upsample_kernel') + + if end_level == -1 or end_level == self.num_ins - 1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level is not the last level, no extra level is allowed + self.backbone_end_level = end_level + 1 + assert end_level < self.num_ins + assert num_outs == end_level - start_level + 1 + self.start_level = start_level + self.end_level = end_level + + self.lateral_convs = ModuleList() + self.fpn_convs = ModuleList() + self.upsample_modules = ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + norm_cfg=norm_cfg, + bias=self.with_bias, + act_cfg=act_cfg, + inplace=False, + order=self.order) + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + norm_cfg=self.norm_cfg, + bias=self.with_bias, + act_cfg=act_cfg, + inplace=False, + order=self.order) + if i != self.backbone_end_level - 1: + upsample_cfg_ = self.upsample_cfg.copy() + if self.upsample == 'deconv': + upsample_cfg_.update( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=self.upsample_kernel, + stride=2, + padding=(self.upsample_kernel - 1) // 2, + output_padding=(self.upsample_kernel - 1) // 2) + elif self.upsample == 'pixel_shuffle': + upsample_cfg_.update( + in_channels=out_channels, + out_channels=out_channels, + scale_factor=2, + upsample_kernel=self.upsample_kernel) + elif self.upsample == 'carafe': + upsample_cfg_.update(channels=out_channels, scale_factor=2) + else: + # suppress warnings + align_corners = (None + if self.upsample == 'nearest' else False) + upsample_cfg_.update( + scale_factor=2, + mode=self.upsample, + align_corners=align_corners) + upsample_module = build_upsample_layer(upsample_cfg_) + self.upsample_modules.append(upsample_module) + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + # add extra conv layers (e.g., RetinaNet) + extra_out_levels = ( + num_outs - self.backbone_end_level + self.start_level) + if extra_out_levels >= 1: + for i in range(extra_out_levels): + in_channels = ( + self.in_channels[self.backbone_end_level - + 1] if i == 0 else out_channels) + extra_l_conv = ConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + 
norm_cfg=norm_cfg, + bias=self.with_bias, + act_cfg=act_cfg, + inplace=False, + order=self.order) + if self.upsample == 'deconv': + upsampler_cfg_ = dict( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=self.upsample_kernel, + stride=2, + padding=(self.upsample_kernel - 1) // 2, + output_padding=(self.upsample_kernel - 1) // 2) + elif self.upsample == 'pixel_shuffle': + upsampler_cfg_ = dict( + in_channels=out_channels, + out_channels=out_channels, + scale_factor=2, + upsample_kernel=self.upsample_kernel) + elif self.upsample == 'carafe': + upsampler_cfg_ = dict( + channels=out_channels, + scale_factor=2, + **self.upsample_cfg) + else: + # suppress warnings + align_corners = (None + if self.upsample == 'nearest' else False) + upsampler_cfg_ = dict( + scale_factor=2, + mode=self.upsample, + align_corners=align_corners) + upsampler_cfg_['type'] = self.upsample + upsample_module = build_upsample_layer(upsampler_cfg_) + extra_fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + norm_cfg=self.norm_cfg, + bias=self.with_bias, + act_cfg=act_cfg, + inplace=False, + order=self.order) + self.upsample_modules.append(upsample_module) + self.fpn_convs.append(extra_fpn_conv) + self.lateral_convs.append(extra_l_conv) + + # default init_weights for conv(msra) and norm in ConvModule + def init_weights(self): + """Initialize the weights of module.""" + super(FPN_CARAFE, self).init_weights() + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + xavier_init(m, distribution='uniform') + for m in self.modules(): + if isinstance(m, CARAFEPack): + m.init_weights() + + def slice_as(self, src, dst): + """Slice ``src`` as ``dst`` + + Note: + ``src`` should have the same or larger size than ``dst``. + + Args: + src (torch.Tensor): Tensors to be sliced. + dst (torch.Tensor): ``src`` will be sliced to have the same + size as ``dst``. + + Returns: + torch.Tensor: Sliced tensor. + """ + assert (src.size(2) >= dst.size(2)) and (src.size(3) >= dst.size(3)) + if src.size(2) == dst.size(2) and src.size(3) == dst.size(3): + return src + else: + return src[:, :, :dst.size(2), :dst.size(3)] + + def tensor_add(self, a, b): + """Add tensors ``a`` and ``b`` that might have different sizes.""" + if a.size() == b.size(): + c = a + b + else: + c = a + self.slice_as(b, a) + return c + + def forward(self, inputs): + """Forward function.""" + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [] + for i, lateral_conv in enumerate(self.lateral_convs): + if i <= self.backbone_end_level - self.start_level: + input = inputs[min(i + self.start_level, len(inputs) - 1)] + else: + input = laterals[-1] + lateral = lateral_conv(input) + laterals.append(lateral) + + # build top-down path + for i in range(len(laterals) - 1, 0, -1): + if self.upsample is not None: + upsample_feat = self.upsample_modules[i - 1](laterals[i]) + else: + upsample_feat = laterals[i] + laterals[i - 1] = self.tensor_add(laterals[i - 1], upsample_feat) + + # build outputs + num_conv_outs = len(self.fpn_convs) + outs = [] + for i in range(num_conv_outs): + out = self.fpn_convs[i](laterals[i]) + outs.append(out) + return tuple(outs) diff --git a/mmdet/models/necks/fpn_dropblock.py b/mmdet/models/necks/fpn_dropblock.py new file mode 100644 index 0000000000000000000000000000000000000000..473af924cdaaecf88aa4a0a6e1500511530b91a2 --- /dev/null +++ b/mmdet/models/necks/fpn_dropblock.py @@ -0,0 +1,90 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
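+"""FPN variant that applies DropBlock regularization to the merged
+top-down features.
+
+A minimal usage sketch (hedged: it assumes the ``DropBlock`` plugin is
+registered in mmdet's ``MODELS`` registry, which the default ``plugin``
+config below relies on; all sizes are arbitrary):
+
+    >>> import torch
+    >>> neck = FPN_DropBlock(in_channels=[8, 16], out_channels=4,
+    ...                      num_outs=2).train()
+    >>> outs = neck([torch.rand(1, 8, 16, 16), torch.rand(1, 16, 8, 8)])
+    >>> [out.shape[1] for out in outs]
+    [4, 4]
+"""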
+from typing import Optional, Tuple + +import torch.nn.functional as F +from torch import Tensor + +from mmdet.registry import MODELS +from .fpn import FPN + + +@MODELS.register_module() +class FPN_DropBlock(FPN): + + def __init__(self, + *args, + plugin: Optional[dict] = dict( + type='DropBlock', + drop_prob=0.3, + block_size=3, + warmup_iters=0), + **kwargs) -> None: + super().__init__(*args, **kwargs) + self.plugin = None + if plugin is not None: + self.plugin = MODELS.build(plugin) + + def forward(self, inputs: Tuple[Tensor]) -> tuple: + """Forward function. + + Args: + inputs (tuple[Tensor]): Features from the upstream network, each + is a 4D-tensor. + + Returns: + tuple: Feature maps, each is a 4D-tensor. + """ + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if 'scale_factor' in self.upsample_cfg: + # fix runtime error of "+=" inplace operation in PyTorch 1.10 + laterals[i - 1] = laterals[i - 1] + F.interpolate( + laterals[i], **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + F.interpolate( + laterals[i], size=prev_shape, **self.upsample_cfg) + + if self.plugin is not None: + laterals[i - 1] = self.plugin(laterals[i - 1]) + + # build outputs + # part 1: from original levels + outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + return tuple(outs) diff --git a/mmdet/models/necks/hrfpn.py b/mmdet/models/necks/hrfpn.py new file mode 100644 index 0000000000000000000000000000000000000000..d2627549b4cb8acc6833bc40425e459c28aa5c20 --- /dev/null +++ b/mmdet/models/necks/hrfpn.py @@ -0,0 +1,100 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch.utils.checkpoint import checkpoint + +from mmdet.registry import MODELS + + +@MODELS.register_module() +class HRFPN(BaseModule): + """HRFPN (High Resolution Feature Pyramids) + + paper: `High-Resolution Representations for Labeling Pixels and Regions + `_. + + Args: + in_channels (list): number of channels for each branch. + out_channels (int): output channels of feature pyramids. + num_outs (int): number of output stages. 
+ pooling_type (str): pooling for generating feature pyramids + from {MAX, AVG}. + conv_cfg (dict): dictionary to construct and config conv layer. + norm_cfg (dict): dictionary to construct and config norm layer. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + stride (int): stride of 3x3 convolutional layers + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, + in_channels, + out_channels, + num_outs=5, + pooling_type='AVG', + conv_cfg=None, + norm_cfg=None, + with_cp=False, + stride=1, + init_cfg=dict(type='Caffe2Xavier', layer='Conv2d')): + super(HRFPN, self).__init__(init_cfg) + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.reduction_conv = ConvModule( + sum(in_channels), + out_channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + act_cfg=None) + + self.fpn_convs = nn.ModuleList() + for i in range(self.num_outs): + self.fpn_convs.append( + ConvModule( + out_channels, + out_channels, + kernel_size=3, + padding=1, + stride=stride, + conv_cfg=self.conv_cfg, + act_cfg=None)) + + if pooling_type == 'MAX': + self.pooling = F.max_pool2d + else: + self.pooling = F.avg_pool2d + + def forward(self, inputs): + """Forward function.""" + assert len(inputs) == self.num_ins + outs = [inputs[0]] + for i in range(1, self.num_ins): + outs.append( + F.interpolate(inputs[i], scale_factor=2**i, mode='bilinear')) + out = torch.cat(outs, dim=1) + if out.requires_grad and self.with_cp: + out = checkpoint(self.reduction_conv, out) + else: + out = self.reduction_conv(out) + outs = [out] + for i in range(1, self.num_outs): + outs.append(self.pooling(out, kernel_size=2**i, stride=2**i)) + outputs = [] + + for i in range(self.num_outs): + if outs[i].requires_grad and self.with_cp: + tmp_out = checkpoint(self.fpn_convs[i], outs[i]) + else: + tmp_out = self.fpn_convs[i](outs[i]) + outputs.append(tmp_out) + return tuple(outputs) diff --git a/mmdet/models/necks/nas_fpn.py b/mmdet/models/necks/nas_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..8ec90cd6eed3aa65a3a192d332cbfd8c16d5bc36 --- /dev/null +++ b/mmdet/models/necks/nas_fpn.py @@ -0,0 +1,171 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmcv.ops.merge_cells import GlobalPoolingCell, SumCell +from mmengine.model import BaseModule, ModuleList +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import MultiConfig, OptConfigType + + +@MODELS.register_module() +class NASFPN(BaseModule): + """NAS-FPN. + + Implementation of `NAS-FPN: Learning Scalable Feature Pyramid Architecture + for Object Detection `_ + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale) + num_outs (int): Number of output scales. + stack_times (int): The number of times the pyramid architecture will + be stacked. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Defaults to 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Defaults to -1, which means the + last level. 
+ norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + normalization layer. Defaults to None. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict]): Initialization config dict. + """ + + def __init__( + self, + in_channels: List[int], + out_channels: int, + num_outs: int, + stack_times: int, + start_level: int = 0, + end_level: int = -1, + norm_cfg: OptConfigType = None, + init_cfg: MultiConfig = dict(type='Caffe2Xavier', layer='Conv2d') + ) -> None: + super().__init__(init_cfg=init_cfg) + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) # num of input feature levels + self.num_outs = num_outs # num of output feature levels + self.stack_times = stack_times + self.norm_cfg = norm_cfg + + if end_level == -1 or end_level == self.num_ins - 1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level is not the last level, no extra level is allowed + self.backbone_end_level = end_level + 1 + assert end_level < self.num_ins + assert num_outs == end_level - start_level + 1 + self.start_level = start_level + self.end_level = end_level + + # add lateral connections + self.lateral_convs = nn.ModuleList() + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + norm_cfg=norm_cfg, + act_cfg=None) + self.lateral_convs.append(l_conv) + + # add extra downsample layers (stride-2 pooling or conv) + extra_levels = num_outs - self.backbone_end_level + self.start_level + self.extra_downsamples = nn.ModuleList() + for i in range(extra_levels): + extra_conv = ConvModule( + out_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=None) + self.extra_downsamples.append( + nn.Sequential(extra_conv, nn.MaxPool2d(2, 2))) + + # add NAS FPN connections + self.fpn_stages = ModuleList() + for _ in range(self.stack_times): + stage = nn.ModuleDict() + # gp(p6, p4) -> p4_1 + stage['gp_64_4'] = GlobalPoolingCell( + in_channels=out_channels, + out_channels=out_channels, + out_norm_cfg=norm_cfg) + # sum(p4_1, p4) -> p4_2 + stage['sum_44_4'] = SumCell( + in_channels=out_channels, + out_channels=out_channels, + out_norm_cfg=norm_cfg) + # sum(p4_2, p3) -> p3_out + stage['sum_43_3'] = SumCell( + in_channels=out_channels, + out_channels=out_channels, + out_norm_cfg=norm_cfg) + # sum(p3_out, p4_2) -> p4_out + stage['sum_34_4'] = SumCell( + in_channels=out_channels, + out_channels=out_channels, + out_norm_cfg=norm_cfg) + # sum(p5, gp(p4_out, p3_out)) -> p5_out + stage['gp_43_5'] = GlobalPoolingCell(with_out_conv=False) + stage['sum_55_5'] = SumCell( + in_channels=out_channels, + out_channels=out_channels, + out_norm_cfg=norm_cfg) + # sum(p7, gp(p5_out, p4_2)) -> p7_out + stage['gp_54_7'] = GlobalPoolingCell(with_out_conv=False) + stage['sum_77_7'] = SumCell( + in_channels=out_channels, + out_channels=out_channels, + out_norm_cfg=norm_cfg) + # gp(p7_out, p5_out) -> p6_out + stage['gp_75_6'] = GlobalPoolingCell( + in_channels=out_channels, + out_channels=out_channels, + out_norm_cfg=norm_cfg) + self.fpn_stages.append(stage) + + def forward(self, inputs: Tuple[Tensor]) -> tuple: + """Forward function. + + Args: + inputs (tuple[Tensor]): Features from the upstream network, each + is a 4D-tensor. + + Returns: + tuple: Feature maps, each is a 4D-tensor. 
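+
+        Example (a shape-check sketch; the stage wiring expects exactly the
+        five levels P3-P7, and the power-of-two input sizes here are
+        arbitrary assumptions):
+
+            >>> import torch
+            >>> self = NASFPN([8, 16, 32], 8, num_outs=5, stack_times=1)
+            >>> inputs = [torch.rand(1, c, s, s)
+            ...           for c, s in zip([8, 16, 32], [32, 16, 8])]
+            >>> [out.shape[-1] for out in self.forward(inputs)]
+            [32, 16, 8, 4, 2]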
+ """ + # build P3-P5 + feats = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + # build P6-P7 on top of P5 + for downsample in self.extra_downsamples: + feats.append(downsample(feats[-1])) + + p3, p4, p5, p6, p7 = feats + + for stage in self.fpn_stages: + # gp(p6, p4) -> p4_1 + p4_1 = stage['gp_64_4'](p6, p4, out_size=p4.shape[-2:]) + # sum(p4_1, p4) -> p4_2 + p4_2 = stage['sum_44_4'](p4_1, p4, out_size=p4.shape[-2:]) + # sum(p4_2, p3) -> p3_out + p3 = stage['sum_43_3'](p4_2, p3, out_size=p3.shape[-2:]) + # sum(p3_out, p4_2) -> p4_out + p4 = stage['sum_34_4'](p3, p4_2, out_size=p4.shape[-2:]) + # sum(p5, gp(p4_out, p3_out)) -> p5_out + p5_tmp = stage['gp_43_5'](p4, p3, out_size=p5.shape[-2:]) + p5 = stage['sum_55_5'](p5, p5_tmp, out_size=p5.shape[-2:]) + # sum(p7, gp(p5_out, p4_2)) -> p7_out + p7_tmp = stage['gp_54_7'](p5, p4_2, out_size=p7.shape[-2:]) + p7 = stage['sum_77_7'](p7, p7_tmp, out_size=p7.shape[-2:]) + # gp(p7_out, p5_out) -> p6_out + p6 = stage['gp_75_6'](p7, p5, out_size=p6.shape[-2:]) + + return p3, p4, p5, p6, p7 diff --git a/mmdet/models/necks/nasfcos_fpn.py b/mmdet/models/necks/nasfcos_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..12d0848f7634bb0113e0b5a16b5b65ba8b7ebb9c --- /dev/null +++ b/mmdet/models/necks/nasfcos_fpn.py @@ -0,0 +1,170 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmcv.ops.merge_cells import ConcatCell +from mmengine.model import BaseModule, caffe2_xavier_init + +from mmdet.registry import MODELS + + +@MODELS.register_module() +class NASFCOS_FPN(BaseModule): + """FPN structure in NASFPN. + + Implementation of paper `NAS-FCOS: Fast Neural Architecture Search for + Object Detection `_ + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale) + num_outs (int): Number of output scales. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + add_extra_convs (bool): It decides whether to add conv + layers on top of the original feature maps. Default to False. + If True, its actual mode is specified by `extra_convs_on_inputs`. + conv_cfg (dict): dictionary to construct and config conv layer. + norm_cfg (dict): dictionary to construct and config norm layer. + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ Default: None + """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + start_level=1, + end_level=-1, + add_extra_convs=False, + conv_cfg=None, + norm_cfg=None, + init_cfg=None): + assert init_cfg is None, 'To prevent abnormal initialization ' \ + 'behavior, init_cfg is not allowed to be set' + super(NASFCOS_FPN, self).__init__(init_cfg) + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.norm_cfg = norm_cfg + self.conv_cfg = conv_cfg + + if end_level == -1 or end_level == self.num_ins - 1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level is not the last level, no extra level is allowed + self.backbone_end_level = end_level + 1 + assert end_level < self.num_ins + assert num_outs == end_level - start_level + 1 + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + + self.adapt_convs = nn.ModuleList() + for i in range(self.start_level, self.backbone_end_level): + adapt_conv = ConvModule( + in_channels[i], + out_channels, + 1, + stride=1, + padding=0, + bias=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU', inplace=False)) + self.adapt_convs.append(adapt_conv) + + # C2 is omitted according to the paper + extra_levels = num_outs - self.backbone_end_level + self.start_level + + def build_concat_cell(with_input1_conv, with_input2_conv): + cell_conv_cfg = dict( + kernel_size=1, padding=0, bias=False, groups=out_channels) + return ConcatCell( + in_channels=out_channels, + out_channels=out_channels, + with_out_conv=True, + out_conv_cfg=cell_conv_cfg, + out_norm_cfg=dict(type='BN'), + out_conv_order=('norm', 'act', 'conv'), + with_input1_conv=with_input1_conv, + with_input2_conv=with_input2_conv, + input_conv_cfg=conv_cfg, + input_norm_cfg=norm_cfg, + upsample_mode='nearest') + + # Denote c3=f0, c4=f1, c5=f2 for convince + self.fpn = nn.ModuleDict() + self.fpn['c22_1'] = build_concat_cell(True, True) + self.fpn['c22_2'] = build_concat_cell(True, True) + self.fpn['c32'] = build_concat_cell(True, False) + self.fpn['c02'] = build_concat_cell(True, False) + self.fpn['c42'] = build_concat_cell(True, True) + self.fpn['c36'] = build_concat_cell(True, True) + self.fpn['c61'] = build_concat_cell(True, True) # f9 + self.extra_downsamples = nn.ModuleList() + for i in range(extra_levels): + extra_act_cfg = None if i == 0 \ + else dict(type='ReLU', inplace=False) + self.extra_downsamples.append( + ConvModule( + out_channels, + out_channels, + 3, + stride=2, + padding=1, + act_cfg=extra_act_cfg, + order=('act', 'norm', 'conv'))) + + def forward(self, inputs): + """Forward function.""" + feats = [ + adapt_conv(inputs[i + self.start_level]) + for i, adapt_conv in enumerate(self.adapt_convs) + ] + + for (i, module_name) in enumerate(self.fpn): + idx_1, idx_2 = int(module_name[1]), int(module_name[2]) + res = self.fpn[module_name](feats[idx_1], feats[idx_2]) + feats.append(res) + + ret = [] + for (idx, input_idx) in zip([9, 8, 7], [1, 2, 3]): # add P3, P4, P5 + feats1, feats2 = feats[idx], feats[5] + feats2_resize = F.interpolate( + feats2, + size=feats1.size()[2:], + mode='bilinear', + align_corners=False) + + feats_sum = feats1 + feats2_resize + ret.append( + F.interpolate( + feats_sum, + size=inputs[input_idx].size()[2:], + mode='bilinear', + align_corners=False)) + + for submodule in self.extra_downsamples: + ret.append(submodule(ret[-1])) + + 
        return tuple(ret)
+
+    def init_weights(self):
+        """Initialize the weights of module."""
+        super(NASFCOS_FPN, self).init_weights()
+        for module in self.fpn.values():
+            # ConcatCell stores its output conv as `out_conv`
+            if hasattr(module, 'out_conv'):
+                caffe2_xavier_init(module.out_conv.conv)
+
+        for modules in [
+                self.adapt_convs.modules(),
+                self.extra_downsamples.modules()
+        ]:
+            for module in modules:
+                if isinstance(module, nn.Conv2d):
+                    caffe2_xavier_init(module)
diff --git a/mmdet/models/necks/pafpn.py b/mmdet/models/necks/pafpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..557638f48a629691f780d3e1466e234bbe987518
--- /dev/null
+++ b/mmdet/models/necks/pafpn.py
@@ -0,0 +1,157 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+
+from mmdet.registry import MODELS
+from .fpn import FPN
+
+
+@MODELS.register_module()
+class PAFPN(FPN):
+    """Path Aggregation Network for Instance Segmentation.
+
+    This is an implementation of the `PAFPN in Path Aggregation Network
+    <https://arxiv.org/abs/1803.01534>`_.
+
+    Args:
+        in_channels (List[int]): Number of input channels per scale.
+        out_channels (int): Number of output channels (used at each scale).
+        num_outs (int): Number of output scales.
+        start_level (int): Index of the start input backbone level used to
+            build the feature pyramid. Default: 0.
+        end_level (int): Index of the end input backbone level (exclusive) to
+            build the feature pyramid. Default: -1, which means the last level.
+        add_extra_convs (bool | str): If bool, it decides whether to add conv
+            layers on top of the original feature maps. Defaults to False.
+            If True, it is equivalent to `add_extra_convs='on_input'`.
+            If str, it specifies the source feature map of the extra convs.
+            Only the following options are allowed
+
+            - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
+            - 'on_lateral': Last feature map after lateral convs.
+            - 'on_output': The last output feature map after fpn convs.
+        relu_before_extra_convs (bool): Whether to apply relu before the extra
+            conv. Default: False.
+        no_norm_on_lateral (bool): Whether to apply norm on lateral.
+            Default: False.
+        conv_cfg (dict): Config dict for convolution layer. Default: None.
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        act_cfg (dict): Config dict for activation layer in ConvModule.
+            Default: None.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
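+
+    Example (a shape-check sketch; power-of-two sizes are chosen as an
+    assumption so that the stride-2 bottom-up convolutions line up):
+
+        >>> import torch
+        >>> in_channels = [2, 3, 5, 7]
+        >>> scales = [64, 32, 16, 8]
+        >>> inputs = [torch.rand(1, c, s, s)
+        ...           for c, s in zip(in_channels, scales)]
+        >>> self = PAFPN(in_channels, 11, len(in_channels)).eval()
+        >>> [out.shape[-1] for out in self.forward(inputs)]
+        [64, 32, 16, 8]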
+ """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + add_extra_convs=False, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + init_cfg=dict( + type='Xavier', layer='Conv2d', distribution='uniform')): + super(PAFPN, self).__init__( + in_channels, + out_channels, + num_outs, + start_level, + end_level, + add_extra_convs, + relu_before_extra_convs, + no_norm_on_lateral, + conv_cfg, + norm_cfg, + act_cfg, + init_cfg=init_cfg) + # add extra bottom up pathway + self.downsample_convs = nn.ModuleList() + self.pafpn_convs = nn.ModuleList() + for i in range(self.start_level + 1, self.backbone_end_level): + d_conv = ConvModule( + out_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + pafpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.downsample_convs.append(d_conv) + self.pafpn_convs.append(pafpn_conv) + + def forward(self, inputs): + """Forward function.""" + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + F.interpolate( + laterals[i], size=prev_shape, mode='nearest') + + # build outputs + # part 1: from original levels + inter_outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + + # part 2: add bottom-up path + for i in range(0, used_backbone_levels - 1): + inter_outs[i + 1] = inter_outs[i + 1] + \ + self.downsample_convs[i](inter_outs[i]) + + outs = [] + outs.append(inter_outs[0]) + outs.extend([ + self.pafpn_convs[i - 1](inter_outs[i]) + for i in range(1, used_backbone_levels) + ]) + + # part 3: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + orig = inputs[self.backbone_end_level - 1] + outs.append(self.fpn_convs[used_backbone_levels](orig)) + elif self.add_extra_convs == 'on_lateral': + outs.append(self.fpn_convs[used_backbone_levels]( + laterals[-1])) + elif self.add_extra_convs == 'on_output': + outs.append(self.fpn_convs[used_backbone_levels](outs[-1])) + else: + raise NotImplementedError + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + return tuple(outs) diff --git a/mmdet/models/necks/rfp.py b/mmdet/models/necks/rfp.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec9b3753c5031bb12a2b4c88733f13bf27c44e2 --- /dev/null +++ b/mmdet/models/necks/rfp.py @@ -0,0 +1,134 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule, ModuleList, constant_init, xavier_init + +from mmdet.registry import MODELS +from .fpn import FPN + + +class ASPP(BaseModule): + """ASPP (Atrous Spatial Pyramid Pooling) + + This is an implementation of the ASPP module used in DetectoRS + (https://arxiv.org/pdf/2006.02334.pdf) + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of channels produced by this module + dilations (tuple[int]): Dilations of the four branches. + Default: (1, 3, 6, 1) + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, + in_channels, + out_channels, + dilations=(1, 3, 6, 1), + init_cfg=dict(type='Kaiming', layer='Conv2d')): + super().__init__(init_cfg) + assert dilations[-1] == 1 + self.aspp = nn.ModuleList() + for dilation in dilations: + kernel_size = 3 if dilation > 1 else 1 + padding = dilation if dilation > 1 else 0 + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + dilation=dilation, + padding=padding, + bias=True) + self.aspp.append(conv) + self.gap = nn.AdaptiveAvgPool2d(1) + + def forward(self, x): + avg_x = self.gap(x) + out = [] + for aspp_idx in range(len(self.aspp)): + inp = avg_x if (aspp_idx == len(self.aspp) - 1) else x + out.append(F.relu_(self.aspp[aspp_idx](inp))) + out[-1] = out[-1].expand_as(out[-2]) + out = torch.cat(out, dim=1) + return out + + +@MODELS.register_module() +class RFP(FPN): + """RFP (Recursive Feature Pyramid) + + This is an implementation of RFP in `DetectoRS + `_. Different from standard FPN, the + input of RFP should be multi level features along with origin input image + of backbone. + + Args: + rfp_steps (int): Number of unrolled steps of RFP. + rfp_backbone (dict): Configuration of the backbone for RFP. + aspp_out_channels (int): Number of output channels of ASPP module. + aspp_dilations (tuple[int]): Dilation rates of four branches. + Default: (1, 3, 6, 1) + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + rfp_steps, + rfp_backbone, + aspp_out_channels, + aspp_dilations=(1, 3, 6, 1), + init_cfg=None, + **kwargs): + assert init_cfg is None, 'To prevent abnormal initialization ' \ + 'behavior, init_cfg is not allowed to be set' + super().__init__(init_cfg=init_cfg, **kwargs) + self.rfp_steps = rfp_steps + # Be careful! Pretrained weights cannot be loaded when use + # nn.ModuleList + self.rfp_modules = ModuleList() + for rfp_idx in range(1, rfp_steps): + rfp_module = MODELS.build(rfp_backbone) + self.rfp_modules.append(rfp_module) + self.rfp_aspp = ASPP(self.out_channels, aspp_out_channels, + aspp_dilations) + self.rfp_weight = nn.Conv2d( + self.out_channels, + 1, + kernel_size=1, + stride=1, + padding=0, + bias=True) + + def init_weights(self): + # Avoid using super().init_weights(), which may alter the default + # initialization of the modules in self.rfp_modules that have missing + # keys in the pretrained checkpoint. 
+ for convs in [self.lateral_convs, self.fpn_convs]: + for m in convs.modules(): + if isinstance(m, nn.Conv2d): + xavier_init(m, distribution='uniform') + for rfp_idx in range(self.rfp_steps - 1): + self.rfp_modules[rfp_idx].init_weights() + constant_init(self.rfp_weight, 0) + + def forward(self, inputs): + inputs = list(inputs) + assert len(inputs) == len(self.in_channels) + 1 # +1 for input image + img = inputs.pop(0) + # FPN forward + x = super().forward(tuple(inputs)) + for rfp_idx in range(self.rfp_steps - 1): + rfp_feats = [x[0]] + list( + self.rfp_aspp(x[i]) for i in range(1, len(x))) + x_idx = self.rfp_modules[rfp_idx].rfp_forward(img, rfp_feats) + # FPN forward + x_idx = super().forward(x_idx) + x_new = [] + for ft_idx in range(len(x_idx)): + add_weight = torch.sigmoid(self.rfp_weight(x_idx[ft_idx])) + x_new.append(add_weight * x_idx[ft_idx] + + (1 - add_weight) * x[ft_idx]) + x = x_new + return x diff --git a/mmdet/models/necks/ssd_neck.py b/mmdet/models/necks/ssd_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..17ba319370b988b9c7e2d98c2f10607ff8f8b5c3 --- /dev/null +++ b/mmdet/models/necks/ssd_neck.py @@ -0,0 +1,129 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from mmengine.model import BaseModule + +from mmdet.registry import MODELS + + +@MODELS.register_module() +class SSDNeck(BaseModule): + """Extra layers of SSD backbone to generate multi-scale feature maps. + + Args: + in_channels (Sequence[int]): Number of input channels per scale. + out_channels (Sequence[int]): Number of output channels per scale. + level_strides (Sequence[int]): Stride of 3x3 conv per level. + level_paddings (Sequence[int]): Padding size of 3x3 conv per level. + l2_norm_scale (float|None): L2 normalization layer init scale. + If None, not use L2 normalization on the first input feature. + last_kernel_size (int): Kernel size of the last conv layer. + Default: 3. + use_depthwise (bool): Whether to use DepthwiseSeparableConv. + Default: False. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: None. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. 
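+
+    Example (a minimal shape sketch; the channel tuples are illustrative
+    assumptions, with L2 normalization disabled for brevity):
+
+        >>> import torch
+        >>> self = SSDNeck(in_channels=(8, 16),
+        ...                out_channels=(8, 16, 32, 64),
+        ...                level_strides=(2, 2),
+        ...                level_paddings=(1, 1),
+        ...                l2_norm_scale=None)
+        >>> feats = (torch.rand(1, 8, 16, 16), torch.rand(1, 16, 8, 8))
+        >>> [out.shape[1] for out in self.forward(feats)]
+        [8, 16, 32, 64]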
+ """ + + def __init__(self, + in_channels, + out_channels, + level_strides, + level_paddings, + l2_norm_scale=20., + last_kernel_size=3, + use_depthwise=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + init_cfg=[ + dict( + type='Xavier', distribution='uniform', + layer='Conv2d'), + dict(type='Constant', val=1, layer='BatchNorm2d'), + ]): + super(SSDNeck, self).__init__(init_cfg) + assert len(out_channels) > len(in_channels) + assert len(out_channels) - len(in_channels) == len(level_strides) + assert len(level_strides) == len(level_paddings) + assert in_channels == out_channels[:len(in_channels)] + + if l2_norm_scale: + self.l2_norm = L2Norm(in_channels[0], l2_norm_scale) + self.init_cfg += [ + dict( + type='Constant', + val=self.l2_norm.scale, + override=dict(name='l2_norm')) + ] + + self.extra_layers = nn.ModuleList() + extra_layer_channels = out_channels[len(in_channels):] + second_conv = DepthwiseSeparableConvModule if \ + use_depthwise else ConvModule + + for i, (out_channel, stride, padding) in enumerate( + zip(extra_layer_channels, level_strides, level_paddings)): + kernel_size = last_kernel_size \ + if i == len(extra_layer_channels) - 1 else 3 + per_lvl_convs = nn.Sequential( + ConvModule( + out_channels[len(in_channels) - 1 + i], + out_channel // 2, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + second_conv( + out_channel // 2, + out_channel, + kernel_size, + stride=stride, + padding=padding, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.extra_layers.append(per_lvl_convs) + + def forward(self, inputs): + """Forward function.""" + outs = [feat for feat in inputs] + if hasattr(self, 'l2_norm'): + outs[0] = self.l2_norm(outs[0]) + + feat = outs[-1] + for layer in self.extra_layers: + feat = layer(feat) + outs.append(feat) + return tuple(outs) + + +class L2Norm(nn.Module): + + def __init__(self, n_dims, scale=20., eps=1e-10): + """L2 normalization layer. + + Args: + n_dims (int): Number of dimensions to be normalized + scale (float, optional): Defaults to 20.. + eps (float, optional): Used to avoid division by zero. + Defaults to 1e-10. + """ + super(L2Norm, self).__init__() + self.n_dims = n_dims + self.weight = nn.Parameter(torch.Tensor(self.n_dims)) + self.eps = eps + self.scale = scale + + def forward(self, x): + """Forward function.""" + # normalization layer convert to FP32 in FP16 training + x_float = x.float() + norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps + return (self.weight[None, :, None, None].float().expand_as(x_float) * + x_float / norm).type_as(x) diff --git a/mmdet/models/necks/ssh.py b/mmdet/models/necks/ssh.py new file mode 100644 index 0000000000000000000000000000000000000000..75a6561489d8d3634fc34829dafe819bbf066ed4 --- /dev/null +++ b/mmdet/models/necks/ssh.py @@ -0,0 +1,216 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig + + +class SSHContextModule(BaseModule): + """This is an implementation of `SSH context module` described in `SSH: + Single Stage Headless Face Detector. + + `_. + + Args: + in_channels (int): Number of input channels used at each scale. + out_channels (int): Number of output channels used at each scale. + conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + convolution layer. Defaults to None. 
+ norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization + layer. Defaults to dict(type='BN'). + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or + list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict(type='BN'), + init_cfg: OptMultiConfig = None): + super().__init__(init_cfg=init_cfg) + assert out_channels % 4 == 0 + + self.in_channels = in_channels + self.out_channels = out_channels + + self.conv5x5_1 = ConvModule( + self.in_channels, + self.out_channels // 4, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + ) + + self.conv5x5_2 = ConvModule( + self.out_channels // 4, + self.out_channels // 4, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + self.conv7x7_2 = ConvModule( + self.out_channels // 4, + self.out_channels // 4, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + ) + + self.conv7x7_3 = ConvModule( + self.out_channels // 4, + self.out_channels // 4, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None, + ) + + def forward(self, x: torch.Tensor) -> tuple: + conv5x5_1 = self.conv5x5_1(x) + conv5x5 = self.conv5x5_2(conv5x5_1) + conv7x7_2 = self.conv7x7_2(conv5x5_1) + conv7x7 = self.conv7x7_3(conv7x7_2) + + return (conv5x5, conv7x7) + + +class SSHDetModule(BaseModule): + """This is an implementation of `SSH detection module` described in `SSH: + Single Stage Headless Face Detector. + + `_. + + Args: + in_channels (int): Number of input channels used at each scale. + out_channels (int): Number of output channels used at each scale. + conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + convolution layer. Defaults to None. + norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization + layer. Defaults to dict(type='BN'). + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or + list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict(type='BN'), + init_cfg: OptMultiConfig = None): + super().__init__(init_cfg=init_cfg) + assert out_channels % 4 == 0 + + self.in_channels = in_channels + self.out_channels = out_channels + + self.conv3x3 = ConvModule( + self.in_channels, + self.out_channels // 2, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + self.context_module = SSHContextModule( + in_channels=self.in_channels, + out_channels=self.out_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + conv3x3 = self.conv3x3(x) + conv5x5, conv7x7 = self.context_module(x) + out = torch.cat([conv3x3, conv5x5, conv7x7], dim=1) + out = F.relu(out) + + return out + + +@MODELS.register_module() +class SSH(BaseModule): + """`SSH Neck` used in `SSH: Single Stage Headless Face Detector. + + `_. + + Args: + num_scales (int): The number of scales / stages. + in_channels (list[int]): The number of input channels per scale. + out_channels (list[int]): The number of output channels per scale. + conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + convolution layer. Defaults to None. + norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization + layer. Defaults to dict(type='BN'). 
+ init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or + list[dict], optional): Initialization config dict. + + Example: + >>> import torch + >>> in_channels = [8, 16, 32, 64] + >>> out_channels = [16, 32, 64, 128] + >>> scales = [340, 170, 84, 43] + >>> inputs = [torch.rand(1, c, s, s) + ... for c, s in zip(in_channels, scales)] + >>> self = SSH(num_scales=4, in_channels=in_channels, + ... out_channels=out_channels) + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 16, 340, 340]) + outputs[1].shape = torch.Size([1, 32, 170, 170]) + outputs[2].shape = torch.Size([1, 64, 84, 84]) + outputs[3].shape = torch.Size([1, 128, 43, 43]) + """ + + def __init__(self, + num_scales: int, + in_channels: List[int], + out_channels: List[int], + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict(type='BN'), + init_cfg: OptMultiConfig = dict( + type='Xavier', layer='Conv2d', distribution='uniform')): + super().__init__(init_cfg=init_cfg) + assert (num_scales == len(in_channels) == len(out_channels)) + self.num_scales = num_scales + self.in_channels = in_channels + self.out_channels = out_channels + + for idx in range(self.num_scales): + in_c, out_c = self.in_channels[idx], self.out_channels[idx] + self.add_module( + f'ssh_module{idx}', + SSHDetModule( + in_channels=in_c, + out_channels=out_c, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg)) + + def forward(self, inputs: Tuple[torch.Tensor]) -> tuple: + assert len(inputs) == self.num_scales + + outs = [] + for idx, x in enumerate(inputs): + ssh_module = getattr(self, f'ssh_module{idx}') + out = ssh_module(x) + outs.append(out) + + return tuple(outs) diff --git a/mmdet/models/necks/yolo_neck.py b/mmdet/models/necks/yolo_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..48a6b1a4897c85083aa1e1e7d692263f66de67c3 --- /dev/null +++ b/mmdet/models/necks/yolo_neck.py @@ -0,0 +1,145 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) 2019 Western Digital Corporation or its affiliates. +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig + + +class DetectionBlock(BaseModule): + """Detection block in YOLO neck. + + Let out_channels = n, the DetectionBlock contains: + Six ConvLayers, 1 Conv2D Layer and 1 YoloLayer. + The first 6 ConvLayers are formed the following way: + 1x1xn, 3x3x2n, 1x1xn, 3x3x2n, 1x1xn, 3x3x2n. + The Conv2D layer is 1x1x255. + Some block will have branch after the fifth ConvLayer. + The input channel is arbitrary (in_channels) + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True) + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1). + init_cfg (dict or list[dict], optional): Initialization config dict. 
+            Default: None
+    """
+
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 conv_cfg: OptConfigType = None,
+                 norm_cfg: ConfigType = dict(type='BN', requires_grad=True),
+                 act_cfg: ConfigType = dict(
+                     type='LeakyReLU', negative_slope=0.1),
+                 init_cfg: OptMultiConfig = None) -> None:
+        super(DetectionBlock, self).__init__(init_cfg)
+        double_out_channels = out_channels * 2
+
+        # shortcut
+        cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
+        self.conv1 = ConvModule(in_channels, out_channels, 1, **cfg)
+        self.conv2 = ConvModule(
+            out_channels, double_out_channels, 3, padding=1, **cfg)
+        self.conv3 = ConvModule(double_out_channels, out_channels, 1, **cfg)
+        self.conv4 = ConvModule(
+            out_channels, double_out_channels, 3, padding=1, **cfg)
+        self.conv5 = ConvModule(double_out_channels, out_channels, 1, **cfg)
+
+    def forward(self, x: Tensor) -> Tensor:
+        tmp = self.conv1(x)
+        tmp = self.conv2(tmp)
+        tmp = self.conv3(tmp)
+        tmp = self.conv4(tmp)
+        out = self.conv5(tmp)
+        return out
+
+
+@MODELS.register_module()
+class YOLOV3Neck(BaseModule):
+    """The neck of YOLOV3.
+
+    It can be treated as a simplified version of FPN: it takes the
+    multi-level feature maps from the Darknet backbone, applies upsampling
+    and concatenation, and outputs the refined feature maps consumed by the
+    YOLOV3 head.
+
+    Note:
+        ``in_channels`` and ``out_channels`` are ordered from high level to
+        low level, i.e. ``in_channels[0]`` corresponds to ``feats[-1]``.
+        ``forward`` therefore walks the input tuple in reversed order,
+        starting from the last (highest-level, lowest-resolution) feature
+        map.
+
+    Args:
+        num_scales (int): The number of scales / stages.
+        in_channels (List[int]): The number of input channels per scale.
+        out_channels (List[int]): The number of output channels per scale.
+        conv_cfg (dict, optional): Config dict for convolution layer.
+            Default: None.
+        norm_cfg (dict, optional): Dictionary to construct and config norm
+            layer. Default: dict(type='BN', requires_grad=True)
+        act_cfg (dict, optional): Config dict for activation layer.
+            Default: dict(type='LeakyReLU', negative_slope=0.1).
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None
+    """
+
+    def __init__(self,
+                 num_scales: int,
+                 in_channels: List[int],
+                 out_channels: List[int],
+                 conv_cfg: OptConfigType = None,
+                 norm_cfg: ConfigType = dict(type='BN', requires_grad=True),
+                 act_cfg: ConfigType = dict(
+                     type='LeakyReLU', negative_slope=0.1),
+                 init_cfg: OptMultiConfig = None) -> None:
+        super(YOLOV3Neck, self).__init__(init_cfg)
+        assert (num_scales == len(in_channels) == len(out_channels))
+        self.num_scales = num_scales
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        # shortcut
+        cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
+
+        # To support arbitrary scales, the code looks awful, but it works.
+        # A better solution is welcome.
+        self.detect1 = DetectionBlock(in_channels[0], out_channels[0], **cfg)
+        for i in range(1, self.num_scales):
+            in_c, out_c = self.in_channels[i], self.out_channels[i]
+            inter_c = out_channels[i - 1]
+            self.add_module(f'conv{i}', ConvModule(inter_c, out_c, 1, **cfg))
+            # in_c + out_c : High-lvl feats will be cat with low-lvl feats
+            self.add_module(f'detect{i+1}',
+                            DetectionBlock(in_c + out_c, out_c, **cfg))
+
+    def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor]:
+        assert len(feats) == self.num_scales
+
+        # processed from bottom (high-lvl) to top (low-lvl)
+        outs = []
+        out = self.detect1(feats[-1])
+        outs.append(out)
+
+        for i, x in enumerate(reversed(feats[:-1])):
+            conv = getattr(self, f'conv{i+1}')
+            tmp = conv(out)
+
+            # Cat with low-lvl feats
+            tmp = F.interpolate(tmp, scale_factor=2)
+            tmp = torch.cat((tmp, x), 1)
+
+            detect = getattr(self, f'detect{i+2}')
+            out = detect(tmp)
+            outs.append(out)
+
+        return tuple(outs)
diff --git a/mmdet/models/necks/yolox_pafpn.py b/mmdet/models/necks/yolox_pafpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ec3d12bfde8158c1a817fbf223a8eea94798667
--- /dev/null
+++ b/mmdet/models/necks/yolox_pafpn.py
@@ -0,0 +1,156 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+from mmengine.model import BaseModule
+
+from mmdet.registry import MODELS
+from ..layers import CSPLayer
+
+
+@MODELS.register_module()
+class YOLOXPAFPN(BaseModule):
+    """Path Aggregation Network used in YOLOX.
+
+    Args:
+        in_channels (List[int]): Number of input channels per scale.
+        out_channels (int): Number of output channels (used at each scale).
+        num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 3
+        use_depthwise (bool): Whether to use depthwise separable convolutions
+            in blocks. Default: False
+        upsample_cfg (dict): Config dict for interpolate layer.
+            Default: `dict(scale_factor=2, mode='nearest')`
+        conv_cfg (dict, optional): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN', momentum=0.03, eps=0.001)
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='Swish')
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
+ """ + + def __init__(self, + in_channels, + out_channels, + num_csp_blocks=3, + use_depthwise=False, + upsample_cfg=dict(scale_factor=2, mode='nearest'), + conv_cfg=None, + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='Swish'), + init_cfg=dict( + type='Kaiming', + layer='Conv2d', + a=math.sqrt(5), + distribution='uniform', + mode='fan_in', + nonlinearity='leaky_relu')): + super(YOLOXPAFPN, self).__init__(init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + + conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule + + # build top-down blocks + self.upsample = nn.Upsample(**upsample_cfg) + self.reduce_layers = nn.ModuleList() + self.top_down_blocks = nn.ModuleList() + for idx in range(len(in_channels) - 1, 0, -1): + self.reduce_layers.append( + ConvModule( + in_channels[idx], + in_channels[idx - 1], + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.top_down_blocks.append( + CSPLayer( + in_channels[idx - 1] * 2, + in_channels[idx - 1], + num_blocks=num_csp_blocks, + add_identity=False, + use_depthwise=use_depthwise, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + # build bottom-up blocks + self.downsamples = nn.ModuleList() + self.bottom_up_blocks = nn.ModuleList() + for idx in range(len(in_channels) - 1): + self.downsamples.append( + conv( + in_channels[idx], + in_channels[idx], + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.bottom_up_blocks.append( + CSPLayer( + in_channels[idx] * 2, + in_channels[idx + 1], + num_blocks=num_csp_blocks, + add_identity=False, + use_depthwise=use_depthwise, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + self.out_convs = nn.ModuleList() + for i in range(len(in_channels)): + self.out_convs.append( + ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, inputs): + """ + Args: + inputs (tuple[Tensor]): input features. + + Returns: + tuple[Tensor]: YOLOXPAFPN features. + """ + assert len(inputs) == len(self.in_channels) + + # top-down path + inner_outs = [inputs[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_heigh = inner_outs[0] + feat_low = inputs[idx - 1] + feat_heigh = self.reduce_layers[len(self.in_channels) - 1 - idx]( + feat_heigh) + inner_outs[0] = feat_heigh + + upsample_feat = self.upsample(feat_heigh) + + inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx]( + torch.cat([upsample_feat, feat_low], 1)) + inner_outs.insert(0, inner_out) + + # bottom-up path + outs = [inner_outs[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = outs[-1] + feat_height = inner_outs[idx + 1] + downsample_feat = self.downsamples[idx](feat_low) + out = self.bottom_up_blocks[idx]( + torch.cat([downsample_feat, feat_height], 1)) + outs.append(out) + + # out convs + for idx, conv in enumerate(self.out_convs): + outs[idx] = conv(outs[idx]) + + return tuple(outs) diff --git a/mmdet/models/reid/__init__.py b/mmdet/models/reid/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aca617f7dea0b8047891c666ddb684dbbd018c81 --- /dev/null +++ b/mmdet/models/reid/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .base_reid import BaseReID
+from .fc_module import FcModule
+from .gap import GlobalAveragePooling
+from .linear_reid_head import LinearReIDHead
+
+__all__ = ['BaseReID', 'GlobalAveragePooling', 'LinearReIDHead', 'FcModule']
diff --git a/mmdet/models/reid/base_reid.py b/mmdet/models/reid/base_reid.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c45964394aa1651f846f2a7e63da3ee70b78909
--- /dev/null
+++ b/mmdet/models/reid/base_reid.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional
+
+import torch
+
+try:
+    import mmpretrain
+    from mmpretrain.models.classifiers import ImageClassifier
+except ImportError:
+    mmpretrain = None
+    ImageClassifier = object
+
+from mmdet.registry import MODELS
+from mmdet.structures import ReIDDataSample
+
+
+@MODELS.register_module()
+class BaseReID(ImageClassifier):
+    """Base model for re-identification."""
+
+    def __init__(self, *args, **kwargs):
+        if mmpretrain is None:
+            raise RuntimeError('Please run "pip install openmim" and '
+                               'run "mim install mmpretrain" to '
+                               'install mmpretrain first.')
+        super().__init__(*args, **kwargs)
+
+    def forward(self,
+                inputs: torch.Tensor,
+                data_samples: Optional[List[ReIDDataSample]] = None,
+                mode: str = 'tensor'):
+        """The unified entry for a forward process in both training and test.
+
+        The method should accept three modes: "tensor", "predict" and "loss":
+
+        - "tensor": Forward the whole network and return tensor or tuple of
+          tensor without any post-processing, same as a common nn.Module.
+        - "predict": Forward and return the predictions, which are fully
+          processed to a list of :obj:`ReIDDataSample`.
+        - "loss": Forward and return a dict of losses according to the given
+          inputs and data samples.
+
+        Note that this method doesn't handle either back propagation or
+        optimizer updating, which are done in the :meth:`train_step`.
+
+        Args:
+            inputs (torch.Tensor): The input tensor with shape
+                (N, C, H, W) or (N, T, C, H, W).
+            data_samples (List[ReIDDataSample], optional): The annotation
+                data of every sample. It's required if ``mode="loss"``.
+                Defaults to None.
+            mode (str): Return what kind of value. Defaults to 'tensor'.
+
+        Returns:
+            The return type depends on ``mode``.
+
+            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
+            - If ``mode="predict"``, return a list of
+              :obj:`ReIDDataSample`.
+            - If ``mode="loss"``, return a dict of tensor.
+        """
+        if len(inputs.size()) == 5:
+            assert inputs.size(0) == 1
+            inputs = inputs[0]
+        return super().forward(inputs, data_samples, mode)
diff --git a/mmdet/models/reid/fc_module.py b/mmdet/models/reid/fc_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..76e7efd66e300a242bb250cc6ba5cc68ed722034
--- /dev/null
+++ b/mmdet/models/reid/fc_module.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import build_activation_layer, build_norm_layer
+from mmengine.model import BaseModule
+
+from mmdet.registry import MODELS
+
+
+@MODELS.register_module()
+class FcModule(BaseModule):
+    """Fully-connected layer module.
+
+    Args:
+        in_channels (int): Input channels.
+        out_channels (int): Output channels.
+        norm_cfg (dict, optional): Configuration of normalization method
+            after fc. Defaults to None.
+        act_cfg (dict, optional): Configuration of activation method after fc.
+            Defaults to dict(type='ReLU').
+        inplace (bool, optional): Whether to apply the activation module
+            in-place. Defaults to True.
+ init_cfg (dict, optional): Initialization config dict. + Defaults to dict(type='Kaiming', layer='Linear'). + """ + + def __init__(self, + in_channels: int, + out_channels: int, + norm_cfg: dict = None, + act_cfg: dict = dict(type='ReLU'), + inplace: bool = True, + init_cfg=dict(type='Kaiming', layer='Linear')): + super(FcModule, self).__init__(init_cfg) + assert norm_cfg is None or isinstance(norm_cfg, dict) + assert act_cfg is None or isinstance(act_cfg, dict) + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.inplace = inplace + + self.with_norm = norm_cfg is not None + self.with_activation = act_cfg is not None + + self.fc = nn.Linear(in_channels, out_channels) + # build normalization layers + if self.with_norm: + self.norm_name, norm = build_norm_layer(norm_cfg, out_channels) + self.add_module(self.norm_name, norm) + + # build activation layer + if self.with_activation: + act_cfg_ = act_cfg.copy() + # nn.Tanh has no 'inplace' argument + if act_cfg_['type'] not in [ + 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish' + ]: + act_cfg_.setdefault('inplace', inplace) + self.activate = build_activation_layer(act_cfg_) + + @property + def norm(self): + """Normalization.""" + return getattr(self, self.norm_name) + + def forward(self, x, activate=True, norm=True): + """Model forward.""" + x = self.fc(x) + if norm and self.with_norm: + x = self.norm(x) + if activate and self.with_activation: + x = self.activate(x) + return x diff --git a/mmdet/models/reid/gap.py b/mmdet/models/reid/gap.py new file mode 100644 index 0000000000000000000000000000000000000000..aadc25e7144f2ca9efb66b496bf8ffa5504619ff --- /dev/null +++ b/mmdet/models/reid/gap.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmdet.registry import MODELS + + +@MODELS.register_module() +class GlobalAveragePooling(BaseModule): + """Global Average Pooling neck. + + Note that we use `view` to remove extra channel after pooling. We do not + use `squeeze` as it will also remove the batch dimension when the tensor + has a batch dimension of size 1, which can lead to unexpected errors. + """ + + def __init__(self, kernel_size=None, stride=None): + super(GlobalAveragePooling, self).__init__() + if kernel_size is None and stride is None: + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + else: + self.gap = nn.AvgPool2d(kernel_size, stride) + + def forward(self, inputs): + if isinstance(inputs, tuple): + outs = tuple([self.gap(x) for x in inputs]) + outs = tuple([ + out.view(x.size(0), + torch.tensor(out.size()[1:]).prod()) + for out, x in zip(outs, inputs) + ]) + elif isinstance(inputs, torch.Tensor): + outs = self.gap(inputs) + outs = outs.view( + inputs.size(0), + torch.tensor(outs.size()[1:]).prod()) + else: + raise TypeError('neck inputs should be tuple or torch.tensor') + return outs diff --git a/mmdet/models/reid/linear_reid_head.py b/mmdet/models/reid/linear_reid_head.py new file mode 100644 index 0000000000000000000000000000000000000000..f35aaf6c2fc57b60e36017268e2a632df60ed342 --- /dev/null +++ b/mmdet/models/reid/linear_reid_head.py @@ -0,0 +1,202 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
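# A quick sketch of how the two small ReID building blocks above compose:
# GlobalAveragePooling flattens a backbone feature map and FcModule projects
# it. It assumes mmdet (with mmcv) is installed; the channel sizes are made
# up for illustration.
import torch
from mmdet.models.reid import FcModule, GlobalAveragePooling

gap = GlobalAveragePooling()               # AdaptiveAvgPool2d((1, 1)) + view
fc = FcModule(2048, 256, norm_cfg=dict(type='BN1d'))

feats = torch.rand(4, 2048, 7, 7)          # e.g. a ResNet stage-4 output
embedding = fc(gap(feats))                 # (4, 256): fc, then BN1d, then ReLU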
+import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +try: + import mmpretrain + from mmpretrain.evaluation.metrics import Accuracy +except ImportError: + mmpretrain = None + +from mmengine.model import BaseModule + +from mmdet.registry import MODELS +from mmdet.structures import ReIDDataSample +from .fc_module import FcModule + + +@MODELS.register_module() +class LinearReIDHead(BaseModule): + """Linear head for re-identification. + + Args: + num_fcs (int): Number of fcs. + in_channels (int): Number of channels in the input. + fc_channels (int): Number of channels in the fcs. + out_channels (int): Number of channels in the output. + norm_cfg (dict, optional): Configuration of normlization method + after fc. Defaults to None. + act_cfg (dict, optional): Configuration of activation method after fc. + Defaults to None. + num_classes (int, optional): Number of the identities. Default to None. + loss_cls (dict, optional): Cross entropy loss to train the ReID module. + Defaults to None. + loss_triplet (dict, optional): Triplet loss to train the ReID module. + Defaults to None. + topk (int | Tuple[int]): Top-k accuracy. Defaults to ``(1, )``. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to dict(type='Normal',layer='Linear', mean=0, std=0.01, + bias=0). + """ + + def __init__(self, + num_fcs: int, + in_channels: int, + fc_channels: int, + out_channels: int, + norm_cfg: Optional[dict] = None, + act_cfg: Optional[dict] = None, + num_classes: Optional[int] = None, + loss_cls: Optional[dict] = None, + loss_triplet: Optional[dict] = None, + topk: Union[int, Tuple[int]] = (1, ), + init_cfg: Union[dict, List[dict]] = dict( + type='Normal', layer='Linear', mean=0, std=0.01, bias=0)): + if mmpretrain is None: + raise RuntimeError('Please run "pip install openmim" and ' + 'run "mim install mmpretrain" to ' + 'install mmpretrain first.') + super(LinearReIDHead, self).__init__(init_cfg=init_cfg) + + assert isinstance(topk, (int, tuple)) + if isinstance(topk, int): + topk = (topk, ) + for _topk in topk: + assert _topk > 0, 'Top-k should be larger than 0' + self.topk = topk + + if loss_cls is None: + if isinstance(num_classes, int): + warnings.warn('Since cross entropy is not set, ' + 'the num_classes will be ignored.') + if loss_triplet is None: + raise ValueError('Please choose at least one loss in ' + 'triplet loss and cross entropy loss.') + elif not isinstance(num_classes, int): + raise TypeError('The num_classes must be a current number, ' + 'if there is cross entropy loss.') + self.loss_cls = MODELS.build(loss_cls) if loss_cls else None + self.loss_triplet = MODELS.build(loss_triplet) \ + if loss_triplet else None + + self.num_fcs = num_fcs + self.in_channels = in_channels + self.fc_channels = fc_channels + self.out_channels = out_channels + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.num_classes = num_classes + + self._init_layers() + + def _init_layers(self): + """Initialize fc layers.""" + self.fcs = nn.ModuleList() + for i in range(self.num_fcs): + in_channels = self.in_channels if i == 0 else self.fc_channels + self.fcs.append( + FcModule(in_channels, self.fc_channels, self.norm_cfg, + self.act_cfg)) + in_channels = self.in_channels if self.num_fcs == 0 else \ + self.fc_channels + self.fc_out = nn.Linear(in_channels, self.out_channels) + if self.loss_cls: + self.bn = nn.BatchNorm1d(self.out_channels) + self.classifier = nn.Linear(self.out_channels, self.num_classes) + + def forward(self, feats: 
Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + # Multiple stage inputs are acceptable + # but only the last stage will be used. + feats = feats[-1] + + for m in self.fcs: + feats = m(feats) + feats = self.fc_out(feats) + return feats + + def loss(self, feats: Tuple[torch.Tensor], + data_samples: List[ReIDDataSample]) -> dict: + """Calculate losses. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + data_samples (List[ReIDDataSample]): The annotation data of + every samples. + + Returns: + dict: a dictionary of loss components + """ + # The part can be traced by torch.fx + feats = self(feats) + + # The part can not be traced by torch.fx + losses = self.loss_by_feat(feats, data_samples) + return losses + + def loss_by_feat(self, feats: torch.Tensor, + data_samples: List[ReIDDataSample]) -> dict: + """Unpack data samples and compute loss.""" + losses = dict() + gt_label = torch.cat([i.gt_label.label for i in data_samples]) + gt_label = gt_label.to(feats.device) + + if self.loss_triplet: + losses['triplet_loss'] = self.loss_triplet(feats, gt_label) + + if self.loss_cls: + feats_bn = self.bn(feats) + cls_score = self.classifier(feats_bn) + losses['ce_loss'] = self.loss_cls(cls_score, gt_label) + acc = Accuracy.calculate(cls_score, gt_label, topk=self.topk) + losses.update( + {f'accuracy_top-{k}': a + for k, a in zip(self.topk, acc)}) + + return losses + + def predict( + self, + feats: Tuple[torch.Tensor], + data_samples: List[ReIDDataSample] = None) -> List[ReIDDataSample]: + """Inference without augmentation. + + Args: + feats (Tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used. + data_samples (List[ReIDDataSample], optional): The annotation + data of every samples. If not None, set ``pred_label`` of + the input data samples. Defaults to None. + + Returns: + List[ReIDDataSample]: A list of data samples which contains the + predicted results. + """ + # The part can be traced by torch.fx + feats = self(feats) + + # The part can not be traced by torch.fx + data_samples = self.predict_by_feat(feats, data_samples) + + return data_samples + + def predict_by_feat( + self, + feats: torch.Tensor, + data_samples: List[ReIDDataSample] = None) -> List[ReIDDataSample]: + """Add prediction features to data samples.""" + if data_samples is not None: + for data_sample, feat in zip(data_samples, feats): + data_sample.pred_feature = feat + else: + data_samples = [] + for feat in feats: + data_sample = ReIDDataSample() + data_sample.pred_feature = feat + data_samples.append(data_sample) + + return data_samples diff --git a/mmdet/models/roi_heads/__init__.py b/mmdet/models/roi_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bba5664cc5ae5229ddebcb42f7583364ca9f77d8 --- /dev/null +++ b/mmdet/models/roi_heads/__init__.py @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
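# To make the LinearReIDHead data flow above concrete, a small sketch
# follows. The sizes and the TripletLoss settings are illustrative only, and
# it assumes both mmdet and mmpretrain are installed (the head refuses to
# build without mmpretrain).
import torch
from mmdet.models.reid import LinearReIDHead

head = LinearReIDHead(
    num_fcs=1,
    in_channels=256,
    fc_channels=256,
    out_channels=128,
    loss_triplet=dict(type='TripletLoss', margin=0.3, loss_weight=1.0))

# forward() uses only the last stage of the (possibly multi-stage) tuple.
feats = (torch.rand(8, 256),)
embeddings = head(feats)                   # shape (8, 128)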
+from .base_roi_head import BaseRoIHead +from .bbox_heads import (BBoxHead, ConvFCBBoxHead, DIIHead, + DoubleConvFCBBoxHead, SABLHead, SCNetBBoxHead, + Shared2FCBBoxHead, Shared4Conv1FCBBoxHead) +from .cascade_roi_head import CascadeRoIHead +from .double_roi_head import DoubleHeadRoIHead +from .dynamic_roi_head import DynamicRoIHead +from .grid_roi_head import GridRoIHead +from .htc_roi_head import HybridTaskCascadeRoIHead +from .mask_heads import (CoarseMaskHead, FCNMaskHead, FeatureRelayHead, + FusedSemanticHead, GlobalContextHead, GridHead, + HTCMaskHead, MaskIoUHead, MaskPointHead, + SCNetMaskHead, SCNetSemanticHead) +from .mask_scoring_roi_head import MaskScoringRoIHead +from .multi_instance_roi_head import MultiInstanceRoIHead +from .pisa_roi_head import PISARoIHead +from .point_rend_roi_head import PointRendRoIHead +from .roi_extractors import (BaseRoIExtractor, GenericRoIExtractor, + SingleRoIExtractor) +from .scnet_roi_head import SCNetRoIHead +from .shared_heads import ResLayer +from .sparse_roi_head import SparseRoIHead +from .standard_roi_head import StandardRoIHead +from .trident_roi_head import TridentRoIHead + +__all__ = [ + 'BaseRoIHead', 'CascadeRoIHead', 'DoubleHeadRoIHead', 'MaskScoringRoIHead', + 'HybridTaskCascadeRoIHead', 'GridRoIHead', 'ResLayer', 'BBoxHead', + 'ConvFCBBoxHead', 'DIIHead', 'SABLHead', 'Shared2FCBBoxHead', + 'StandardRoIHead', 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', + 'FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead', 'GridHead', + 'MaskIoUHead', 'BaseRoIExtractor', 'GenericRoIExtractor', + 'SingleRoIExtractor', 'PISARoIHead', 'PointRendRoIHead', 'MaskPointHead', + 'CoarseMaskHead', 'DynamicRoIHead', 'SparseRoIHead', 'TridentRoIHead', + 'SCNetRoIHead', 'SCNetMaskHead', 'SCNetSemanticHead', 'SCNetBBoxHead', + 'FeatureRelayHead', 'GlobalContextHead', 'MultiInstanceRoIHead' +] diff --git a/mmdet/models/roi_heads/base_roi_head.py b/mmdet/models/roi_heads/base_roi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..405f80a73ecc5db7343d81ca55518160fcbc2b63 --- /dev/null +++ b/mmdet/models/roi_heads/base_roi_head.py @@ -0,0 +1,129 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
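# In practice the RoI heads exported above are assembled through the MODELS
# registry. A hedged sketch in the style of the stock Faster R-CNN configs
# follows; the numbers are illustrative and a full mmdet installation is
# assumed.
from mmdet.registry import MODELS

roi_head = MODELS.build(
    dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0))))
assert type(roi_head).__name__ == 'StandardRoIHead'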
+from abc import ABCMeta, abstractmethod +from typing import Tuple + +from mmengine.model import BaseModule +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.utils import InstanceList, OptConfigType, OptMultiConfig + + +class BaseRoIHead(BaseModule, metaclass=ABCMeta): + """Base class for RoIHeads.""" + + def __init__(self, + bbox_roi_extractor: OptMultiConfig = None, + bbox_head: OptMultiConfig = None, + mask_roi_extractor: OptMultiConfig = None, + mask_head: OptMultiConfig = None, + shared_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + if shared_head is not None: + self.shared_head = MODELS.build(shared_head) + + if bbox_head is not None: + self.init_bbox_head(bbox_roi_extractor, bbox_head) + + if mask_head is not None: + self.init_mask_head(mask_roi_extractor, mask_head) + + self.init_assigner_sampler() + + @property + def with_bbox(self) -> bool: + """bool: whether the RoI head contains a `bbox_head`""" + return hasattr(self, 'bbox_head') and self.bbox_head is not None + + @property + def with_mask(self) -> bool: + """bool: whether the RoI head contains a `mask_head`""" + return hasattr(self, 'mask_head') and self.mask_head is not None + + @property + def with_shared_head(self) -> bool: + """bool: whether the RoI head contains a `shared_head`""" + return hasattr(self, 'shared_head') and self.shared_head is not None + + @abstractmethod + def init_bbox_head(self, *args, **kwargs): + """Initialize ``bbox_head``""" + pass + + @abstractmethod + def init_mask_head(self, *args, **kwargs): + """Initialize ``mask_head``""" + pass + + @abstractmethod + def init_assigner_sampler(self, *args, **kwargs): + """Initialize assigner and sampler.""" + pass + + @abstractmethod + def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList, + batch_data_samples: SampleList): + """Perform forward propagation and loss calculation of the roi head on + the features of the upstream network.""" + + def predict(self, + x: Tuple[Tensor], + rpn_results_list: InstanceList, + batch_data_samples: SampleList, + rescale: bool = False) -> InstanceList: + """Perform forward propagation of the roi head and predict detection + results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Features from upstream network. Each + has shape (N, C, H, W). + rpn_results_list (list[:obj:`InstanceData`]): list of region + proposals. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): Whether to rescale the results to + the original image. Defaults to True. + + Returns: + list[obj:`InstanceData`]: Detection results of each image. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + assert self.with_bbox, 'Bbox head must be implemented.' 
+        batch_img_metas = [
+            data_samples.metainfo for data_samples in batch_data_samples
+        ]
+
+        # TODO: The nms_op in mmcv needs to be enhanced; the bbox results
+        # may differ when rescale is not applied in bbox_head.
+
+        # If it has the mask branch, the bbox branch does not need
+        # to be scaled to the original image scale, because the mask
+        # branch will scale both bbox and mask at the same time.
+        bbox_rescale = rescale if not self.with_mask else False
+        results_list = self.predict_bbox(
+            x,
+            batch_img_metas,
+            rpn_results_list,
+            rcnn_test_cfg=self.test_cfg,
+            rescale=bbox_rescale)
+
+        if self.with_mask:
+            results_list = self.predict_mask(
+                x, batch_img_metas, results_list, rescale=rescale)
+
+        return results_list
diff --git a/mmdet/models/roi_heads/bbox_heads/__init__.py b/mmdet/models/roi_heads/bbox_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9e742abfecfc9dfe37b78822407fc92e9d64cc3
--- /dev/null
+++ b/mmdet/models/roi_heads/bbox_heads/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .bbox_head import BBoxHead
+from .convfc_bbox_head import (ConvFCBBoxHead, Shared2FCBBoxHead,
+                               Shared4Conv1FCBBoxHead)
+from .dii_head import DIIHead
+from .double_bbox_head import DoubleConvFCBBoxHead
+from .multi_instance_bbox_head import MultiInstanceBBoxHead
+from .sabl_head import SABLHead
+from .scnet_bbox_head import SCNetBBoxHead
+
+__all__ = [
+    'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead',
+    'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'SABLHead', 'DIIHead',
+    'SCNetBBoxHead', 'MultiInstanceBBoxHead'
+]
diff --git a/mmdet/models/roi_heads/bbox_heads/bbox_head.py b/mmdet/models/roi_heads/bbox_heads/bbox_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b2e8aae0833ae0351b544099d79d296f082a76e
--- /dev/null
+++ b/mmdet/models/roi_heads/bbox_heads/bbox_head.py
@@ -0,0 +1,708 @@
+# Copyright (c) OpenMMLab. All rights reserved.
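# The abstract methods of BaseRoIHead above define the contract every
# concrete RoI head fills in. Below is a minimal, hypothetical subclass (not
# part of mmdet) that satisfies the interface; real heads such as
# StandardRoIHead implement the same hooks with actual extractors, heads and
# samplers.
from mmdet.models.roi_heads.base_roi_head import BaseRoIHead
from mmdet.registry import MODELS


class MinimalRoIHead(BaseRoIHead):
    """Hypothetical skeleton showing the hooks BaseRoIHead expects."""

    def init_bbox_head(self, bbox_roi_extractor, bbox_head):
        self.bbox_roi_extractor = MODELS.build(bbox_roi_extractor)
        self.bbox_head = MODELS.build(bbox_head)

    def init_mask_head(self, mask_roi_extractor, mask_head):
        pass  # no mask branch in this sketch

    def init_assigner_sampler(self):
        pass  # a real head builds its assigner/sampler from self.train_cfg

    def loss(self, x, rpn_results_list, batch_data_samples):
        return {}  # a real head returns bbox (and mask) losses here


head = MinimalRoIHead()
# With no bbox/mask heads configured, the convenience properties are False.
assert not head.with_bbox and not head.with_mask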
+from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.config import ConfigDict +from mmengine.model import BaseModule +from mmengine.structures import InstanceData +from torch import Tensor +from torch.nn.modules.utils import _pair + +from mmdet.models.layers import multiclass_nms +from mmdet.models.losses import accuracy +from mmdet.models.task_modules.samplers import SamplingResult +from mmdet.models.utils import empty_instances, multi_apply +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures.bbox import get_box_tensor, scale_boxes +from mmdet.utils import ConfigType, InstanceList, OptMultiConfig + + +@MODELS.register_module() +class BBoxHead(BaseModule): + """Simplest RoI head, with only two fc layers for classification and + regression respectively.""" + + def __init__(self, + with_avg_pool: bool = False, + with_cls: bool = True, + with_reg: bool = True, + roi_feat_size: int = 7, + in_channels: int = 256, + num_classes: int = 80, + bbox_coder: ConfigType = dict( + type='DeltaXYWHBBoxCoder', + clip_border=True, + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + predict_box_type: str = 'hbox', + reg_class_agnostic: bool = False, + reg_decoded_bbox: bool = False, + reg_predictor_cfg: ConfigType = dict(type='Linear'), + cls_predictor_cfg: ConfigType = dict(type='Linear'), + loss_cls: ConfigType = dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox: ConfigType = dict( + type='SmoothL1Loss', beta=1.0, loss_weight=1.0), + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + assert with_cls or with_reg + self.with_avg_pool = with_avg_pool + self.with_cls = with_cls + self.with_reg = with_reg + self.roi_feat_size = _pair(roi_feat_size) + self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1] + self.in_channels = in_channels + self.num_classes = num_classes + self.predict_box_type = predict_box_type + self.reg_class_agnostic = reg_class_agnostic + self.reg_decoded_bbox = reg_decoded_bbox + self.reg_predictor_cfg = reg_predictor_cfg + self.cls_predictor_cfg = cls_predictor_cfg + + self.bbox_coder = TASK_UTILS.build(bbox_coder) + self.loss_cls = MODELS.build(loss_cls) + self.loss_bbox = MODELS.build(loss_bbox) + + in_channels = self.in_channels + if self.with_avg_pool: + self.avg_pool = nn.AvgPool2d(self.roi_feat_size) + else: + in_channels *= self.roi_feat_area + if self.with_cls: + # need to add background class + if self.custom_cls_channels: + cls_channels = self.loss_cls.get_cls_channels(self.num_classes) + else: + cls_channels = num_classes + 1 + cls_predictor_cfg_ = self.cls_predictor_cfg.copy() + cls_predictor_cfg_.update( + in_features=in_channels, out_features=cls_channels) + self.fc_cls = MODELS.build(cls_predictor_cfg_) + if self.with_reg: + box_dim = self.bbox_coder.encode_size + out_dim_reg = box_dim if reg_class_agnostic else \ + box_dim * num_classes + reg_predictor_cfg_ = self.reg_predictor_cfg.copy() + if isinstance(reg_predictor_cfg_, (dict, ConfigDict)): + reg_predictor_cfg_.update( + in_features=in_channels, out_features=out_dim_reg) + self.fc_reg = MODELS.build(reg_predictor_cfg_) + self.debug_imgs = None + if init_cfg is None: + self.init_cfg = [] + if self.with_cls: + self.init_cfg += [ + dict( + type='Normal', std=0.01, override=dict(name='fc_cls')) + ] + if self.with_reg: + self.init_cfg += [ + dict( + type='Normal', std=0.001, override=dict(name='fc_reg')) + ] + + # TODO: Create 
a SeesawBBoxHead to simplify the logic in BBoxHead
+    @property
+    def custom_cls_channels(self) -> bool:
+        """Get custom_cls_channels from loss_cls."""
+        return getattr(self.loss_cls, 'custom_cls_channels', False)
+
+    # TODO: Create a SeesawBBoxHead to simplify the logic in BBoxHead
+    @property
+    def custom_activation(self) -> bool:
+        """Get custom_activation from loss_cls."""
+        return getattr(self.loss_cls, 'custom_activation', False)
+
+    # TODO: Create a SeesawBBoxHead to simplify the logic in BBoxHead
+    @property
+    def custom_accuracy(self) -> bool:
+        """Get custom_accuracy from loss_cls."""
+        return getattr(self.loss_cls, 'custom_accuracy', False)
+
+    def forward(self, x: Tuple[Tensor]) -> tuple:
+        """Forward features from the upstream network.
+
+        Args:
+            x (tuple[Tensor]): Features from the upstream network, each is
+                a 4D-tensor.
+
+        Returns:
+            tuple: A tuple of classification scores and bbox prediction.
+
+                - cls_score (Tensor): Classification scores with shape
+                  (num_boxes, num_classes + 1).
+                - bbox_pred (Tensor): Box energies / deltas with shape
+                  (num_boxes, num_classes * 4), or (num_boxes, 4) when
+                  ``reg_class_agnostic=True``.
+        """
+        if self.with_avg_pool:
+            if x.numel() > 0:
+                x = self.avg_pool(x)
+                x = x.view(x.size(0), -1)
+            else:
+                # avg_pool does not support empty tensor,
+                # so use torch.mean instead
+                x = torch.mean(x, dim=(-1, -2))
+        cls_score = self.fc_cls(x) if self.with_cls else None
+        bbox_pred = self.fc_reg(x) if self.with_reg else None
+        return cls_score, bbox_pred
+
+    def _get_targets_single(self, pos_priors: Tensor, neg_priors: Tensor,
+                            pos_gt_bboxes: Tensor, pos_gt_labels: Tensor,
+                            cfg: ConfigDict) -> tuple:
+        """Calculate the ground truth for proposals in a single image
+        according to the sampling results.
+
+        Args:
+            pos_priors (Tensor): Contains all the positive boxes,
+                has shape (num_pos, 4), the last dimension 4
+                represents [tl_x, tl_y, br_x, br_y].
+            neg_priors (Tensor): Contains all the negative boxes,
+                has shape (num_neg, 4), the last dimension 4
+                represents [tl_x, tl_y, br_x, br_y].
+            pos_gt_bboxes (Tensor): Contains gt_boxes for
+                all positive samples, has shape (num_pos, 4),
+                the last dimension 4
+                represents [tl_x, tl_y, br_x, br_y].
+            pos_gt_labels (Tensor): Contains gt_labels for
+                all positive samples, has shape (num_pos, ).
+            cfg (obj:`ConfigDict`): `train_cfg` of R-CNN.
+
+        Returns:
+            Tuple[Tensor]: Ground truth for proposals
+            in a single image. Containing the following Tensors:
+
+                - labels (Tensor): Gt_labels for all proposals, has
+                  shape (num_proposals,).
+                - label_weights (Tensor): Labels_weights for all
+                  proposals, has shape (num_proposals,).
+                - bbox_targets (Tensor): Regression target for all
+                  proposals, has shape (num_proposals, 4), the
+                  last dimension 4 represents [tl_x, tl_y, br_x, br_y].
+                - bbox_weights (Tensor): Regression weights for all
+                  proposals, has shape (num_proposals, 4).
+ """ + num_pos = pos_priors.size(0) + num_neg = neg_priors.size(0) + num_samples = num_pos + num_neg + + # original implementation uses new_zeros since BG are set to be 0 + # now use empty & fill because BG cat_id = num_classes, + # FG cat_id = [0, num_classes-1] + labels = pos_priors.new_full((num_samples, ), + self.num_classes, + dtype=torch.long) + reg_dim = pos_gt_bboxes.size(-1) if self.reg_decoded_bbox \ + else self.bbox_coder.encode_size + label_weights = pos_priors.new_zeros(num_samples) + bbox_targets = pos_priors.new_zeros(num_samples, reg_dim) + bbox_weights = pos_priors.new_zeros(num_samples, reg_dim) + if num_pos > 0: + labels[:num_pos] = pos_gt_labels + pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight + label_weights[:num_pos] = pos_weight + if not self.reg_decoded_bbox: + pos_bbox_targets = self.bbox_coder.encode( + pos_priors, pos_gt_bboxes) + else: + # When the regression loss (e.g. `IouLoss`, `GIouLoss`) + # is applied directly on the decoded bounding boxes, both + # the predicted boxes and regression targets should be with + # absolute coordinate format. + pos_bbox_targets = get_box_tensor(pos_gt_bboxes) + bbox_targets[:num_pos, :] = pos_bbox_targets + bbox_weights[:num_pos, :] = 1 + if num_neg > 0: + label_weights[-num_neg:] = 1.0 + + return labels, label_weights, bbox_targets, bbox_weights + + def get_targets(self, + sampling_results: List[SamplingResult], + rcnn_train_cfg: ConfigDict, + concat: bool = True) -> tuple: + """Calculate the ground truth for all samples in a batch according to + the sampling_results. + + Almost the same as the implementation in bbox_head, we passed + additional parameters pos_inds_list and neg_inds_list to + `_get_targets_single` function. + + Args: + sampling_results (List[obj:SamplingResult]): Assign results of + all images in a batch after sampling. + rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. + concat (bool): Whether to concatenate the results of all + the images in a single batch. + + Returns: + Tuple[Tensor]: Ground truth for proposals in a single image. + Containing the following list of Tensors: + + - labels (list[Tensor],Tensor): Gt_labels for all + proposals in a batch, each tensor in list has + shape (num_proposals,) when `concat=False`, otherwise + just a single tensor has shape (num_all_proposals,). + - label_weights (list[Tensor]): Labels_weights for + all proposals in a batch, each tensor in list has + shape (num_proposals,) when `concat=False`, otherwise + just a single tensor has shape (num_all_proposals,). + - bbox_targets (list[Tensor],Tensor): Regression target + for all proposals in a batch, each tensor in list + has shape (num_proposals, 4) when `concat=False`, + otherwise just a single tensor has shape + (num_all_proposals, 4), the last dimension 4 represents + [tl_x, tl_y, br_x, br_y]. + - bbox_weights (list[tensor],Tensor): Regression weights for + all proposals in a batch, each tensor in list has shape + (num_proposals, 4) when `concat=False`, otherwise just a + single tensor has shape (num_all_proposals, 4). 
+ """ + pos_priors_list = [res.pos_priors for res in sampling_results] + neg_priors_list = [res.neg_priors for res in sampling_results] + pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results] + pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results] + labels, label_weights, bbox_targets, bbox_weights = multi_apply( + self._get_targets_single, + pos_priors_list, + neg_priors_list, + pos_gt_bboxes_list, + pos_gt_labels_list, + cfg=rcnn_train_cfg) + + if concat: + labels = torch.cat(labels, 0) + label_weights = torch.cat(label_weights, 0) + bbox_targets = torch.cat(bbox_targets, 0) + bbox_weights = torch.cat(bbox_weights, 0) + return labels, label_weights, bbox_targets, bbox_weights + + def loss_and_target(self, + cls_score: Tensor, + bbox_pred: Tensor, + rois: Tensor, + sampling_results: List[SamplingResult], + rcnn_train_cfg: ConfigDict, + concat: bool = True, + reduction_override: Optional[str] = None) -> dict: + """Calculate the loss based on the features extracted by the bbox head. + + Args: + cls_score (Tensor): Classification prediction + results of all class, has shape + (batch_size * num_proposals_single_image, num_classes) + bbox_pred (Tensor): Regression prediction results, + has shape + (batch_size * num_proposals_single_image, 4), the last + dimension 4 represents [tl_x, tl_y, br_x, br_y]. + rois (Tensor): RoIs with the shape + (batch_size * num_proposals_single_image, 5) where the first + column indicates batch id of each RoI. + sampling_results (List[obj:SamplingResult]): Assign results of + all images in a batch after sampling. + rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. + concat (bool): Whether to concatenate the results of all + the images in a single batch. Defaults to True. + reduction_override (str, optional): The reduction + method used to override the original reduction + method of the loss. Options are "none", + "mean" and "sum". Defaults to None, + + Returns: + dict: A dictionary of loss and targets components. + The targets are only used for cascade rcnn. + """ + + cls_reg_targets = self.get_targets( + sampling_results, rcnn_train_cfg, concat=concat) + losses = self.loss( + cls_score, + bbox_pred, + rois, + *cls_reg_targets, + reduction_override=reduction_override) + + # cls_reg_targets is only for cascade rcnn + return dict(loss_bbox=losses, bbox_targets=cls_reg_targets) + + def loss(self, + cls_score: Tensor, + bbox_pred: Tensor, + rois: Tensor, + labels: Tensor, + label_weights: Tensor, + bbox_targets: Tensor, + bbox_weights: Tensor, + reduction_override: Optional[str] = None) -> dict: + """Calculate the loss based on the network predictions and targets. + + Args: + cls_score (Tensor): Classification prediction + results of all class, has shape + (batch_size * num_proposals_single_image, num_classes) + bbox_pred (Tensor): Regression prediction results, + has shape + (batch_size * num_proposals_single_image, 4), the last + dimension 4 represents [tl_x, tl_y, br_x, br_y]. + rois (Tensor): RoIs with the shape + (batch_size * num_proposals_single_image, 5) where the first + column indicates batch id of each RoI. + labels (Tensor): Gt_labels for all proposals in a batch, has + shape (batch_size * num_proposals_single_image, ). + label_weights (Tensor): Labels_weights for all proposals in a + batch, has shape (batch_size * num_proposals_single_image, ). 
+ bbox_targets (Tensor): Regression target for all proposals in a + batch, has shape (batch_size * num_proposals_single_image, 4), + the last dimension 4 represents [tl_x, tl_y, br_x, br_y]. + bbox_weights (Tensor): Regression weights for all proposals in a + batch, has shape (batch_size * num_proposals_single_image, 4). + reduction_override (str, optional): The reduction + method used to override the original reduction + method of the loss. Options are "none", + "mean" and "sum". Defaults to None, + + Returns: + dict: A dictionary of loss. + """ + + losses = dict() + + if cls_score is not None: + avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.) + if cls_score.numel() > 0: + loss_cls_ = self.loss_cls( + cls_score, + labels, + label_weights, + avg_factor=avg_factor, + reduction_override=reduction_override) + if isinstance(loss_cls_, dict): + losses.update(loss_cls_) + else: + losses['loss_cls'] = loss_cls_ + if self.custom_activation: + acc_ = self.loss_cls.get_accuracy(cls_score, labels) + losses.update(acc_) + else: + losses['acc'] = accuracy(cls_score, labels) + if bbox_pred is not None: + bg_class_ind = self.num_classes + # 0~self.num_classes-1 are FG, self.num_classes is BG + pos_inds = (labels >= 0) & (labels < bg_class_ind) + # do not perform bounding box regression for BG anymore. + if pos_inds.any(): + if self.reg_decoded_bbox: + # When the regression loss (e.g. `IouLoss`, + # `GIouLoss`, `DIouLoss`) is applied directly on + # the decoded bounding boxes, it decodes the + # already encoded coordinates to absolute format. + bbox_pred = self.bbox_coder.decode(rois[:, 1:], bbox_pred) + bbox_pred = get_box_tensor(bbox_pred) + if self.reg_class_agnostic: + pos_bbox_pred = bbox_pred.view( + bbox_pred.size(0), -1)[pos_inds.type(torch.bool)] + else: + pos_bbox_pred = bbox_pred.view( + bbox_pred.size(0), self.num_classes, + -1)[pos_inds.type(torch.bool), + labels[pos_inds.type(torch.bool)]] + losses['loss_bbox'] = self.loss_bbox( + pos_bbox_pred, + bbox_targets[pos_inds.type(torch.bool)], + bbox_weights[pos_inds.type(torch.bool)], + avg_factor=bbox_targets.size(0), + reduction_override=reduction_override) + else: + losses['loss_bbox'] = bbox_pred[pos_inds].sum() + + return losses + + def predict_by_feat(self, + rois: Tuple[Tensor], + cls_scores: Tuple[Tensor], + bbox_preds: Tuple[Tensor], + batch_img_metas: List[dict], + rcnn_test_cfg: Optional[ConfigDict] = None, + rescale: bool = False) -> InstanceList: + """Transform a batch of output features extracted from the head into + bbox results. + + Args: + rois (tuple[Tensor]): Tuple of boxes to be transformed. + Each has shape (num_boxes, 5). last dimension 5 arrange as + (batch_index, x1, y1, x2, y2). + cls_scores (tuple[Tensor]): Tuple of box scores, each has shape + (num_boxes, num_classes + 1). + bbox_preds (tuple[Tensor]): Tuple of box energies / deltas, each + has shape (num_boxes, num_classes * 4). + batch_img_metas (list[dict]): List of image information. + rcnn_test_cfg (obj:`ConfigDict`, optional): `test_cfg` of R-CNN. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Instance segmentation + results of each image after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). 
+ - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + assert len(cls_scores) == len(bbox_preds) + result_list = [] + for img_id in range(len(batch_img_metas)): + img_meta = batch_img_metas[img_id] + results = self._predict_by_feat_single( + roi=rois[img_id], + cls_score=cls_scores[img_id], + bbox_pred=bbox_preds[img_id], + img_meta=img_meta, + rescale=rescale, + rcnn_test_cfg=rcnn_test_cfg) + result_list.append(results) + + return result_list + + def _predict_by_feat_single( + self, + roi: Tensor, + cls_score: Tensor, + bbox_pred: Tensor, + img_meta: dict, + rescale: bool = False, + rcnn_test_cfg: Optional[ConfigDict] = None) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + roi (Tensor): Boxes to be transformed. Has shape (num_boxes, 5). + last dimension 5 arrange as (batch_index, x1, y1, x2, y2). + cls_score (Tensor): Box scores, has shape + (num_boxes, num_classes + 1). + bbox_pred (Tensor): Box energies / deltas. + has shape (num_boxes, num_classes * 4). + img_meta (dict): image information. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. + Defaults to None + + Returns: + :obj:`InstanceData`: Detection results of each image\ + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + results = InstanceData() + if roi.shape[0] == 0: + return empty_instances([img_meta], + roi.device, + task_type='bbox', + instance_results=[results], + box_type=self.predict_box_type, + use_box_type=False, + num_classes=self.num_classes, + score_per_cls=rcnn_test_cfg is None)[0] + + # some loss (Seesaw loss..) may have custom activation + if self.custom_cls_channels: + scores = self.loss_cls.get_activation(cls_score) + else: + scores = F.softmax( + cls_score, dim=-1) if cls_score is not None else None + + img_shape = img_meta['img_shape'] + num_rois = roi.size(0) + # bbox_pred would be None in some detector when with_reg is False, + # e.g. Grid R-CNN. + if bbox_pred is not None: + num_classes = 1 if self.reg_class_agnostic else self.num_classes + roi = roi.repeat_interleave(num_classes, dim=0) + bbox_pred = bbox_pred.view(-1, self.bbox_coder.encode_size) + bboxes = self.bbox_coder.decode( + roi[..., 1:], bbox_pred, max_shape=img_shape) + else: + bboxes = roi[:, 1:].clone() + if img_shape is not None and bboxes.size(-1) == 4: + bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1]) + bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0]) + + if rescale and bboxes.size(0) > 0: + assert img_meta.get('scale_factor') is not None + scale_factor = [1 / s for s in img_meta['scale_factor']] + bboxes = scale_boxes(bboxes, scale_factor) + + # Get the inside tensor when `bboxes` is a box type + bboxes = get_box_tensor(bboxes) + box_dim = bboxes.size(-1) + bboxes = bboxes.view(num_rois, -1) + + if rcnn_test_cfg is None: + # This means that it is aug test. + # It needs to return the raw results without nms. 
+            results.bboxes = bboxes
+            results.scores = scores
+        else:
+            det_bboxes, det_labels = multiclass_nms(
+                bboxes,
+                scores,
+                rcnn_test_cfg.score_thr,
+                rcnn_test_cfg.nms,
+                rcnn_test_cfg.max_per_img,
+                box_dim=box_dim)
+            results.bboxes = det_bboxes[:, :-1]
+            results.scores = det_bboxes[:, -1]
+            results.labels = det_labels
+        return results
+
+    def refine_bboxes(self, sampling_results: Union[List[SamplingResult],
+                                                    InstanceList],
+                      bbox_results: dict,
+                      batch_img_metas: List[dict]) -> InstanceList:
+        """Refine bboxes during training.
+
+        Args:
+            sampling_results (List[:obj:`SamplingResult`] or
+                List[:obj:`InstanceData`]): Sampling results.
+                :obj:`SamplingResult` is the real sampling results
+                calculated from bbox_head, while :obj:`InstanceData` is
+                fake sampling results, e.g., in Sparse R-CNN or QueryInst, etc.
+            bbox_results (dict): Usually a dictionary with keys:
+
+                - `cls_score` (Tensor): Classification scores.
+                - `bbox_pred` (Tensor): Box energies / deltas.
+                - `rois` (Tensor): RoIs with the shape (n, 5) where the first
+                  column indicates batch id of each RoI.
+                - `bbox_targets` (tuple): Ground truth for proposals in a
+                  single image. Containing the following list of Tensors:
+                  (labels, label_weights, bbox_targets, bbox_weights)
+            batch_img_metas (List[dict]): List of image information.
+
+        Returns:
+            list[:obj:`InstanceData`]: Refined bboxes of each image.
+
+        Example:
+            >>> # xdoctest: +REQUIRES(module:kwarray)
+            >>> import numpy as np
+            >>> import torch
+            >>> from mmdet.models.task_modules.samplers.sampling_result \
+            ...     import random_boxes
+            >>> from mmdet.models.task_modules.samplers import SamplingResult
+            >>> self = BBoxHead(reg_class_agnostic=True)
+            >>> n_roi = 2
+            >>> n_img = 4
+            >>> scale = 512
+            >>> rng = np.random.RandomState(0)
+            >>> batch_img_metas = [{'img_shape': (scale, scale)}
+            ...                    for _ in range(n_img)]
+            >>> sampling_results = [SamplingResult.random(rng=10)
+            ...                     for _ in range(n_img)]
+            >>> # Create rois in the expected format
+            >>> roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
+            >>> img_ids = torch.randint(0, n_img, (n_roi,))
+            >>> img_ids = img_ids.float()
+            >>> rois = torch.cat([img_ids[:, None], roi_boxes], dim=1)
+            >>> # Create other args
+            >>> labels = torch.randint(0, 81, (scale,)).long()
+            >>> bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
+            >>> cls_score = torch.randn((scale, 81))
+            >>> # For each image, pretend random positive boxes are gts
+            >>> bbox_targets = (labels, None, None, None)
+            >>> bbox_results = dict(rois=rois, bbox_pred=bbox_preds,
+            ...                     cls_score=cls_score,
+            ...                     bbox_targets=bbox_targets)
+            >>> bboxes_list = self.refine_bboxes(sampling_results,
+            ...                                  bbox_results,
+            ...                                  
batch_img_metas) + >>> print(bboxes_list) + """ + pos_is_gts = [res.pos_is_gt for res in sampling_results] + # bbox_targets is a tuple + labels = bbox_results['bbox_targets'][0] + cls_scores = bbox_results['cls_score'] + rois = bbox_results['rois'] + bbox_preds = bbox_results['bbox_pred'] + if self.custom_activation: + # TODO: Create a SeasawBBoxHead to simplified logic in BBoxHead + cls_scores = self.loss_cls.get_activation(cls_scores) + if cls_scores.numel() == 0: + return None + if cls_scores.shape[-1] == self.num_classes + 1: + # remove background class + cls_scores = cls_scores[:, :-1] + elif cls_scores.shape[-1] != self.num_classes: + raise ValueError('The last dim of `cls_scores` should equal to ' + '`num_classes` or `num_classes + 1`,' + f'but got {cls_scores.shape[-1]}.') + labels = torch.where(labels == self.num_classes, cls_scores.argmax(1), + labels) + + img_ids = rois[:, 0].long().unique(sorted=True) + assert img_ids.numel() <= len(batch_img_metas) + + results_list = [] + for i in range(len(batch_img_metas)): + inds = torch.nonzero( + rois[:, 0] == i, as_tuple=False).squeeze(dim=1) + num_rois = inds.numel() + + bboxes_ = rois[inds, 1:] + label_ = labels[inds] + bbox_pred_ = bbox_preds[inds] + img_meta_ = batch_img_metas[i] + pos_is_gts_ = pos_is_gts[i] + + bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_, + img_meta_) + # filter gt bboxes + pos_keep = 1 - pos_is_gts_ + keep_inds = pos_is_gts_.new_ones(num_rois) + keep_inds[:len(pos_is_gts_)] = pos_keep + results = InstanceData(bboxes=bboxes[keep_inds.type(torch.bool)]) + results_list.append(results) + + return results_list + + def regress_by_class(self, priors: Tensor, label: Tensor, + bbox_pred: Tensor, img_meta: dict) -> Tensor: + """Regress the bbox for the predicted class. Used in Cascade R-CNN. + + Args: + priors (Tensor): Priors from `rpn_head` or last stage + `bbox_head`, has shape (num_proposals, 4). + label (Tensor): Only used when `self.reg_class_agnostic` + is False, has shape (num_proposals, ). + bbox_pred (Tensor): Regression prediction of + current stage `bbox_head`. When `self.reg_class_agnostic` + is False, it has shape (n, num_classes * 4), otherwise + it has shape (n, 4). + img_meta (dict): Image meta info. + + Returns: + Tensor: Regressed bboxes, the same shape as input rois. + """ + reg_dim = self.bbox_coder.encode_size + if not self.reg_class_agnostic: + label = label * reg_dim + inds = torch.stack([label + i for i in range(reg_dim)], 1) + bbox_pred = torch.gather(bbox_pred, 1, inds) + assert bbox_pred.size()[1] == reg_dim + + max_shape = img_meta['img_shape'] + regressed_bboxes = self.bbox_coder.decode( + priors, bbox_pred, max_shape=max_shape) + return regressed_bboxes diff --git a/mmdet/models/roi_heads/bbox_heads/convfc_bbox_head.py b/mmdet/models/roi_heads/bbox_heads/convfc_bbox_head.py new file mode 100644 index 0000000000000000000000000000000000000000..cb6aadd86d34af3605d432492931442026432cc8 --- /dev/null +++ b/mmdet/models/roi_heads/bbox_heads/convfc_bbox_head.py @@ -0,0 +1,249 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple, Union + +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.config import ConfigDict +from torch import Tensor + +from mmdet.registry import MODELS +from .bbox_head import BBoxHead + + +@MODELS.register_module() +class ConvFCBBoxHead(BBoxHead): + r"""More general bbox head, with shared conv and fc layers and two optional + separated branches. + + .. 
code-block:: none + + /-> cls convs -> cls fcs -> cls + shared convs -> shared fcs + \-> reg convs -> reg fcs -> reg + """ # noqa: W605 + + def __init__(self, + num_shared_convs: int = 0, + num_shared_fcs: int = 0, + num_cls_convs: int = 0, + num_cls_fcs: int = 0, + num_reg_convs: int = 0, + num_reg_fcs: int = 0, + conv_out_channels: int = 256, + fc_out_channels: int = 1024, + conv_cfg: Optional[Union[dict, ConfigDict]] = None, + norm_cfg: Optional[Union[dict, ConfigDict]] = None, + init_cfg: Optional[Union[dict, ConfigDict]] = None, + *args, + **kwargs) -> None: + super().__init__(*args, init_cfg=init_cfg, **kwargs) + assert (num_shared_convs + num_shared_fcs + num_cls_convs + + num_cls_fcs + num_reg_convs + num_reg_fcs > 0) + if num_cls_convs > 0 or num_reg_convs > 0: + assert num_shared_fcs == 0 + if not self.with_cls: + assert num_cls_convs == 0 and num_cls_fcs == 0 + if not self.with_reg: + assert num_reg_convs == 0 and num_reg_fcs == 0 + self.num_shared_convs = num_shared_convs + self.num_shared_fcs = num_shared_fcs + self.num_cls_convs = num_cls_convs + self.num_cls_fcs = num_cls_fcs + self.num_reg_convs = num_reg_convs + self.num_reg_fcs = num_reg_fcs + self.conv_out_channels = conv_out_channels + self.fc_out_channels = fc_out_channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + # add shared convs and fcs + self.shared_convs, self.shared_fcs, last_layer_dim = \ + self._add_conv_fc_branch( + self.num_shared_convs, self.num_shared_fcs, self.in_channels, + True) + self.shared_out_channels = last_layer_dim + + # add cls specific branch + self.cls_convs, self.cls_fcs, self.cls_last_dim = \ + self._add_conv_fc_branch( + self.num_cls_convs, self.num_cls_fcs, self.shared_out_channels) + + # add reg specific branch + self.reg_convs, self.reg_fcs, self.reg_last_dim = \ + self._add_conv_fc_branch( + self.num_reg_convs, self.num_reg_fcs, self.shared_out_channels) + + if self.num_shared_fcs == 0 and not self.with_avg_pool: + if self.num_cls_fcs == 0: + self.cls_last_dim *= self.roi_feat_area + if self.num_reg_fcs == 0: + self.reg_last_dim *= self.roi_feat_area + + self.relu = nn.ReLU(inplace=True) + # reconstruct fc_cls and fc_reg since input channels are changed + if self.with_cls: + if self.custom_cls_channels: + cls_channels = self.loss_cls.get_cls_channels(self.num_classes) + else: + cls_channels = self.num_classes + 1 + cls_predictor_cfg_ = self.cls_predictor_cfg.copy() + cls_predictor_cfg_.update( + in_features=self.cls_last_dim, out_features=cls_channels) + self.fc_cls = MODELS.build(cls_predictor_cfg_) + if self.with_reg: + box_dim = self.bbox_coder.encode_size + out_dim_reg = box_dim if self.reg_class_agnostic else \ + box_dim * self.num_classes + reg_predictor_cfg_ = self.reg_predictor_cfg.copy() + if isinstance(reg_predictor_cfg_, (dict, ConfigDict)): + reg_predictor_cfg_.update( + in_features=self.reg_last_dim, out_features=out_dim_reg) + self.fc_reg = MODELS.build(reg_predictor_cfg_) + + if init_cfg is None: + # when init_cfg is None, + # It has been set to + # [[dict(type='Normal', std=0.01, override=dict(name='fc_cls'))], + # [dict(type='Normal', std=0.001, override=dict(name='fc_reg'))] + # after `super(ConvFCBBoxHead, self).__init__()` + # we only need to append additional configuration + # for `shared_fcs`, `cls_fcs` and `reg_fcs` + self.init_cfg += [ + dict( + type='Xavier', + distribution='uniform', + override=[ + dict(name='shared_fcs'), + dict(name='cls_fcs'), + dict(name='reg_fcs') + ]) + ] + + def _add_conv_fc_branch(self, + num_branch_convs: int, + 
num_branch_fcs: int, + in_channels: int, + is_shared: bool = False) -> tuple: + """Add shared or separable branch. + + convs -> avg pool (optional) -> fcs + """ + last_layer_dim = in_channels + # add branch specific conv layers + branch_convs = nn.ModuleList() + if num_branch_convs > 0: + for i in range(num_branch_convs): + conv_in_channels = ( + last_layer_dim if i == 0 else self.conv_out_channels) + branch_convs.append( + ConvModule( + conv_in_channels, + self.conv_out_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + last_layer_dim = self.conv_out_channels + # add branch specific fc layers + branch_fcs = nn.ModuleList() + if num_branch_fcs > 0: + # for shared branch, only consider self.with_avg_pool + # for separated branches, also consider self.num_shared_fcs + if (is_shared + or self.num_shared_fcs == 0) and not self.with_avg_pool: + last_layer_dim *= self.roi_feat_area + for i in range(num_branch_fcs): + fc_in_channels = ( + last_layer_dim if i == 0 else self.fc_out_channels) + branch_fcs.append( + nn.Linear(fc_in_channels, self.fc_out_channels)) + last_layer_dim = self.fc_out_channels + return branch_convs, branch_fcs, last_layer_dim + + def forward(self, x: Tuple[Tensor]) -> tuple: + """Forward features from the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: A tuple of classification scores and bbox prediction. + + - cls_score (Tensor): Classification scores for all \ + scale levels, each is a 4D-tensor, the channels number \ + is num_base_priors * num_classes. + - bbox_pred (Tensor): Box energies / deltas for all \ + scale levels, each is a 4D-tensor, the channels number \ + is num_base_priors * 4. + """ + # shared part + if self.num_shared_convs > 0: + for conv in self.shared_convs: + x = conv(x) + + if self.num_shared_fcs > 0: + if self.with_avg_pool: + x = self.avg_pool(x) + + x = x.flatten(1) + + for fc in self.shared_fcs: + x = self.relu(fc(x)) + # separate branches + x_cls = x + x_reg = x + + for conv in self.cls_convs: + x_cls = conv(x_cls) + if x_cls.dim() > 2: + if self.with_avg_pool: + x_cls = self.avg_pool(x_cls) + x_cls = x_cls.flatten(1) + for fc in self.cls_fcs: + x_cls = self.relu(fc(x_cls)) + + for conv in self.reg_convs: + x_reg = conv(x_reg) + if x_reg.dim() > 2: + if self.with_avg_pool: + x_reg = self.avg_pool(x_reg) + x_reg = x_reg.flatten(1) + for fc in self.reg_fcs: + x_reg = self.relu(fc(x_reg)) + + cls_score = self.fc_cls(x_cls) if self.with_cls else None + bbox_pred = self.fc_reg(x_reg) if self.with_reg else None + return cls_score, bbox_pred + + +@MODELS.register_module() +class Shared2FCBBoxHead(ConvFCBBoxHead): + + def __init__(self, fc_out_channels: int = 1024, *args, **kwargs) -> None: + super().__init__( + num_shared_convs=0, + num_shared_fcs=2, + num_cls_convs=0, + num_cls_fcs=0, + num_reg_convs=0, + num_reg_fcs=0, + fc_out_channels=fc_out_channels, + *args, + **kwargs) + + +@MODELS.register_module() +class Shared4Conv1FCBBoxHead(ConvFCBBoxHead): + + def __init__(self, fc_out_channels: int = 1024, *args, **kwargs) -> None: + super().__init__( + num_shared_convs=4, + num_shared_fcs=1, + num_cls_convs=0, + num_cls_fcs=0, + num_reg_convs=0, + num_reg_fcs=0, + fc_out_channels=fc_out_channels, + *args, + **kwargs) diff --git a/mmdet/models/roi_heads/bbox_heads/dii_head.py b/mmdet/models/roi_heads/bbox_heads/dii_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ae9a31bbeb2a8f1da62b457363fa05031d21925a --- /dev/null 
+++ b/mmdet/models/roi_heads/bbox_heads/dii_head.py
@@ -0,0 +1,422 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import build_activation_layer, build_norm_layer
+from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
+from mmengine.config import ConfigDict
+from mmengine.model import bias_init_with_prob
+from torch import Tensor
+
+from mmdet.models.losses import accuracy
+from mmdet.models.task_modules import SamplingResult
+from mmdet.models.utils import multi_apply
+from mmdet.registry import MODELS
+from mmdet.utils import ConfigType, OptConfigType, reduce_mean
+from .bbox_head import BBoxHead
+
+
+@MODELS.register_module()
+class DIIHead(BBoxHead):
+    r"""Dynamic Instance Interactive Head for `Sparse R-CNN: End-to-End Object
+    Detection with Learnable Proposals <https://arxiv.org/abs/2011.12450>`_
+
+    Args:
+        num_classes (int): Number of classes in the dataset.
+            Defaults to 80.
+        num_ffn_fcs (int): The number of fully-connected
+            layers in FFNs. Defaults to 2.
+        num_heads (int): The number of attention heads in
+            MultiheadAttention. Defaults to 8.
+        num_cls_fcs (int): The number of fully-connected
+            layers in the classification subnet. Defaults to 1.
+        num_reg_fcs (int): The number of fully-connected
+            layers in the regression subnet. Defaults to 3.
+        feedforward_channels (int): The hidden dimension
+            of FFNs. Defaults to 2048.
+        in_channels (int): Hidden channels of MultiheadAttention.
+            Defaults to 256.
+        dropout (float): Probability of dropping a channel.
+            Defaults to 0.0.
+        ffn_act_cfg (:obj:`ConfigDict` or dict): The activation config
+            for FFNs.
+        dynamic_conv_cfg (:obj:`ConfigDict` or dict): The convolution
+            config for DynamicConv.
+        loss_iou (:obj:`ConfigDict` or dict): The config for IoU or
+            GIoU loss.
+        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
+            dict]): Initialization config dict. Defaults to None.
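+
+    Example (editor's illustrative sketch, not part of the original patch;
+    assumes the default mmdet registries so that ``DynamicConv`` and the
+    default losses can be built):
+        >>> import torch
+        >>> self = DIIHead(num_classes=80)
+        >>> # 2 images with 100 proposals each and 7x7 RoI features
+        >>> roi_feat = torch.rand(2 * 100, 256, 7, 7)
+        >>> proposal_feat = torch.rand(2, 100, 256)
+        >>> cls_score, bbox_delta, obj_feat, attn_feats = self(
+        ...     roi_feat, proposal_feat)
+        >>> cls_score.shape
+        torch.Size([2, 100, 81])
+        >>> bbox_delta.shape
+        torch.Size([2, 100, 4])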
+ """ + + def __init__(self, + num_classes: int = 80, + num_ffn_fcs: int = 2, + num_heads: int = 8, + num_cls_fcs: int = 1, + num_reg_fcs: int = 3, + feedforward_channels: int = 2048, + in_channels: int = 256, + dropout: float = 0.0, + ffn_act_cfg: ConfigType = dict(type='ReLU', inplace=True), + dynamic_conv_cfg: ConfigType = dict( + type='DynamicConv', + in_channels=256, + feat_channels=64, + out_channels=256, + input_feat_shape=7, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN')), + loss_iou: ConfigType = dict(type='GIoULoss', loss_weight=2.0), + init_cfg: OptConfigType = None, + **kwargs) -> None: + assert init_cfg is None, 'To prevent abnormal initialization ' \ + 'behavior, init_cfg is not allowed to be set' + super().__init__( + num_classes=num_classes, + reg_decoded_bbox=True, + reg_class_agnostic=True, + init_cfg=init_cfg, + **kwargs) + self.loss_iou = MODELS.build(loss_iou) + self.in_channels = in_channels + self.fp16_enabled = False + self.attention = MultiheadAttention(in_channels, num_heads, dropout) + self.attention_norm = build_norm_layer(dict(type='LN'), in_channels)[1] + + self.instance_interactive_conv = MODELS.build(dynamic_conv_cfg) + self.instance_interactive_conv_dropout = nn.Dropout(dropout) + self.instance_interactive_conv_norm = build_norm_layer( + dict(type='LN'), in_channels)[1] + + self.ffn = FFN( + in_channels, + feedforward_channels, + num_ffn_fcs, + act_cfg=ffn_act_cfg, + dropout=dropout) + self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1] + + self.cls_fcs = nn.ModuleList() + for _ in range(num_cls_fcs): + self.cls_fcs.append( + nn.Linear(in_channels, in_channels, bias=False)) + self.cls_fcs.append( + build_norm_layer(dict(type='LN'), in_channels)[1]) + self.cls_fcs.append( + build_activation_layer(dict(type='ReLU', inplace=True))) + + # over load the self.fc_cls in BBoxHead + if self.loss_cls.use_sigmoid: + self.fc_cls = nn.Linear(in_channels, self.num_classes) + else: + self.fc_cls = nn.Linear(in_channels, self.num_classes + 1) + + self.reg_fcs = nn.ModuleList() + for _ in range(num_reg_fcs): + self.reg_fcs.append( + nn.Linear(in_channels, in_channels, bias=False)) + self.reg_fcs.append( + build_norm_layer(dict(type='LN'), in_channels)[1]) + self.reg_fcs.append( + build_activation_layer(dict(type='ReLU', inplace=True))) + # over load the self.fc_cls in BBoxHead + self.fc_reg = nn.Linear(in_channels, 4) + + assert self.reg_class_agnostic, 'DIIHead only ' \ + 'suppport `reg_class_agnostic=True` ' + assert self.reg_decoded_bbox, 'DIIHead only ' \ + 'suppport `reg_decoded_bbox=True`' + + def init_weights(self) -> None: + """Use xavier initialization for all weight parameter and set + classification head bias as a specific value when use focal loss.""" + super().init_weights() + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + else: + # adopt the default initialization for + # the weight and bias of the layer norm + pass + if self.loss_cls.use_sigmoid: + bias_init = bias_init_with_prob(0.01) + nn.init.constant_(self.fc_cls.bias, bias_init) + + def forward(self, roi_feat: Tensor, proposal_feat: Tensor) -> tuple: + """Forward function of Dynamic Instance Interactive Head. + + Args: + roi_feat (Tensor): Roi-pooling features with shape + (batch_size*num_proposals, feature_dimensions, + pooling_h , pooling_w). 
+ proposal_feat (Tensor): Intermediate feature get from + diihead in last stage, has shape + (batch_size, num_proposals, feature_dimensions) + + Returns: + tuple[Tensor]: Usually a tuple of classification scores + and bbox prediction and a intermediate feature. + + - cls_scores (Tensor): Classification scores for + all proposals, has shape + (batch_size, num_proposals, num_classes). + - bbox_preds (Tensor): Box energies / deltas for + all proposals, has shape + (batch_size, num_proposals, 4). + - obj_feat (Tensor): Object feature before classification + and regression subnet, has shape + (batch_size, num_proposal, feature_dimensions). + - attn_feats (Tensor): Intermediate feature. + """ + N, num_proposals = proposal_feat.shape[:2] + + # Self attention + proposal_feat = proposal_feat.permute(1, 0, 2) + proposal_feat = self.attention_norm(self.attention(proposal_feat)) + attn_feats = proposal_feat.permute(1, 0, 2) + + # instance interactive + proposal_feat = attn_feats.reshape(-1, self.in_channels) + proposal_feat_iic = self.instance_interactive_conv( + proposal_feat, roi_feat) + proposal_feat = proposal_feat + self.instance_interactive_conv_dropout( + proposal_feat_iic) + obj_feat = self.instance_interactive_conv_norm(proposal_feat) + + # FFN + obj_feat = self.ffn_norm(self.ffn(obj_feat)) + + cls_feat = obj_feat + reg_feat = obj_feat + + for cls_layer in self.cls_fcs: + cls_feat = cls_layer(cls_feat) + for reg_layer in self.reg_fcs: + reg_feat = reg_layer(reg_feat) + + cls_score = self.fc_cls(cls_feat).view( + N, num_proposals, self.num_classes + if self.loss_cls.use_sigmoid else self.num_classes + 1) + bbox_delta = self.fc_reg(reg_feat).view(N, num_proposals, 4) + + return cls_score, bbox_delta, obj_feat.view( + N, num_proposals, self.in_channels), attn_feats + + def loss_and_target(self, + cls_score: Tensor, + bbox_pred: Tensor, + sampling_results: List[SamplingResult], + rcnn_train_cfg: ConfigType, + imgs_whwh: Tensor, + concat: bool = True, + reduction_override: str = None) -> dict: + """Calculate the loss based on the features extracted by the DIIHead. + + Args: + cls_score (Tensor): Classification prediction + results of all class, has shape + (batch_size * num_proposals_single_image, num_classes) + bbox_pred (Tensor): Regression prediction results, has shape + (batch_size * num_proposals_single_image, 4), the last + dimension 4 represents [tl_x, tl_y, br_x, br_y]. + sampling_results (List[obj:SamplingResult]): Assign results of + all images in a batch after sampling. + rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. + imgs_whwh (Tensor): imgs_whwh (Tensor): Tensor with\ + shape (batch_size, num_proposals, 4), the last + dimension means + [img_width,img_height, img_width, img_height]. + concat (bool): Whether to concatenate the results of all + the images in a single batch. Defaults to True. + reduction_override (str, optional): The reduction + method used to override the original reduction + method of the loss. Options are "none", + "mean" and "sum". Defaults to None. + + Returns: + dict: A dictionary of loss and targets components. + The targets are only used for cascade rcnn. 
+ """ + cls_reg_targets = self.get_targets( + sampling_results=sampling_results, + rcnn_train_cfg=rcnn_train_cfg, + concat=concat) + (labels, label_weights, bbox_targets, bbox_weights) = cls_reg_targets + + losses = dict() + bg_class_ind = self.num_classes + # note in spare rcnn num_gt == num_pos + pos_inds = (labels >= 0) & (labels < bg_class_ind) + num_pos = pos_inds.sum().float() + avg_factor = reduce_mean(num_pos) + if cls_score is not None: + if cls_score.numel() > 0: + losses['loss_cls'] = self.loss_cls( + cls_score, + labels, + label_weights, + avg_factor=avg_factor, + reduction_override=reduction_override) + losses['pos_acc'] = accuracy(cls_score[pos_inds], + labels[pos_inds]) + if bbox_pred is not None: + # 0~self.num_classes-1 are FG, self.num_classes is BG + # do not perform bounding box regression for BG anymore. + if pos_inds.any(): + pos_bbox_pred = bbox_pred.reshape(bbox_pred.size(0), + 4)[pos_inds.type(torch.bool)] + imgs_whwh = imgs_whwh.reshape(bbox_pred.size(0), + 4)[pos_inds.type(torch.bool)] + losses['loss_bbox'] = self.loss_bbox( + pos_bbox_pred / imgs_whwh, + bbox_targets[pos_inds.type(torch.bool)] / imgs_whwh, + bbox_weights[pos_inds.type(torch.bool)], + avg_factor=avg_factor) + losses['loss_iou'] = self.loss_iou( + pos_bbox_pred, + bbox_targets[pos_inds.type(torch.bool)], + bbox_weights[pos_inds.type(torch.bool)], + avg_factor=avg_factor) + else: + losses['loss_bbox'] = bbox_pred.sum() * 0 + losses['loss_iou'] = bbox_pred.sum() * 0 + return dict(loss_bbox=losses, bbox_targets=cls_reg_targets) + + def _get_targets_single(self, pos_inds: Tensor, neg_inds: Tensor, + pos_priors: Tensor, neg_priors: Tensor, + pos_gt_bboxes: Tensor, pos_gt_labels: Tensor, + cfg: ConfigDict) -> tuple: + """Calculate the ground truth for proposals in the single image + according to the sampling results. + + Almost the same as the implementation in `bbox_head`, + we add pos_inds and neg_inds to select positive and + negative samples instead of selecting the first num_pos + as positive samples. + + Args: + pos_inds (Tensor): The length is equal to the + positive sample numbers contain all index + of the positive sample in the origin proposal set. + neg_inds (Tensor): The length is equal to the + negative sample numbers contain all index + of the negative sample in the origin proposal set. + pos_priors (Tensor): Contains all the positive boxes, + has shape (num_pos, 4), the last dimension 4 + represents [tl_x, tl_y, br_x, br_y]. + neg_priors (Tensor): Contains all the negative boxes, + has shape (num_neg, 4), the last dimension 4 + represents [tl_x, tl_y, br_x, br_y]. + pos_gt_bboxes (Tensor): Contains gt_boxes for + all positive samples, has shape (num_pos, 4), + the last dimension 4 + represents [tl_x, tl_y, br_x, br_y]. + pos_gt_labels (Tensor): Contains gt_labels for + all positive samples, has shape (num_pos, ). + cfg (obj:`ConfigDict`): `train_cfg` of R-CNN. + + Returns: + Tuple[Tensor]: Ground truth for proposals in a single image. + Containing the following Tensors: + + - labels(Tensor): Gt_labels for all proposals, has + shape (num_proposals,). + - label_weights(Tensor): Labels_weights for all proposals, has + shape (num_proposals,). + - bbox_targets(Tensor):Regression target for all proposals, has + shape (num_proposals, 4), the last dimension 4 + represents [tl_x, tl_y, br_x, br_y]. + - bbox_weights(Tensor):Regression weights for all proposals, + has shape (num_proposals, 4). 
+ """ + num_pos = pos_priors.size(0) + num_neg = neg_priors.size(0) + num_samples = num_pos + num_neg + + # original implementation uses new_zeros since BG are set to be 0 + # now use empty & fill because BG cat_id = num_classes, + # FG cat_id = [0, num_classes-1] + labels = pos_priors.new_full((num_samples, ), + self.num_classes, + dtype=torch.long) + label_weights = pos_priors.new_zeros(num_samples) + bbox_targets = pos_priors.new_zeros(num_samples, 4) + bbox_weights = pos_priors.new_zeros(num_samples, 4) + if num_pos > 0: + labels[pos_inds] = pos_gt_labels + pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight + label_weights[pos_inds] = pos_weight + if not self.reg_decoded_bbox: + pos_bbox_targets = self.bbox_coder.encode( + pos_priors, pos_gt_bboxes) + else: + pos_bbox_targets = pos_gt_bboxes + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1 + if num_neg > 0: + label_weights[neg_inds] = 1.0 + + return labels, label_weights, bbox_targets, bbox_weights + + def get_targets(self, + sampling_results: List[SamplingResult], + rcnn_train_cfg: ConfigDict, + concat: bool = True) -> tuple: + """Calculate the ground truth for all samples in a batch according to + the sampling_results. + + Almost the same as the implementation in bbox_head, we passed + additional parameters pos_inds_list and neg_inds_list to + `_get_targets_single` function. + + Args: + sampling_results (List[obj:SamplingResult]): Assign results of + all images in a batch after sampling. + rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. + concat (bool): Whether to concatenate the results of all + the images in a single batch. + + Returns: + Tuple[Tensor]: Ground truth for proposals in a single image. + Containing the following list of Tensors: + + - labels (list[Tensor],Tensor): Gt_labels for all + proposals in a batch, each tensor in list has + shape (num_proposals,) when `concat=False`, otherwise just + a single tensor has shape (num_all_proposals,). + - label_weights (list[Tensor]): Labels_weights for + all proposals in a batch, each tensor in list has shape + (num_proposals,) when `concat=False`, otherwise just a + single tensor has shape (num_all_proposals,). + - bbox_targets (list[Tensor],Tensor): Regression target + for all proposals in a batch, each tensor in list has + shape (num_proposals, 4) when `concat=False`, otherwise + just a single tensor has shape (num_all_proposals, 4), + the last dimension 4 represents [tl_x, tl_y, br_x, br_y]. + - bbox_weights (list[tensor],Tensor): Regression weights for + all proposals in a batch, each tensor in list has shape + (num_proposals, 4) when `concat=False`, otherwise just a + single tensor has shape (num_all_proposals, 4). 
+ """ + pos_inds_list = [res.pos_inds for res in sampling_results] + neg_inds_list = [res.neg_inds for res in sampling_results] + pos_priors_list = [res.pos_priors for res in sampling_results] + neg_priors_list = [res.neg_priors for res in sampling_results] + pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results] + pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results] + labels, label_weights, bbox_targets, bbox_weights = multi_apply( + self._get_targets_single, + pos_inds_list, + neg_inds_list, + pos_priors_list, + neg_priors_list, + pos_gt_bboxes_list, + pos_gt_labels_list, + cfg=rcnn_train_cfg) + if concat: + labels = torch.cat(labels, 0) + label_weights = torch.cat(label_weights, 0) + bbox_targets = torch.cat(bbox_targets, 0) + bbox_weights = torch.cat(bbox_weights, 0) + return labels, label_weights, bbox_targets, bbox_weights diff --git a/mmdet/models/roi_heads/bbox_heads/double_bbox_head.py b/mmdet/models/roi_heads/bbox_heads/double_bbox_head.py new file mode 100644 index 0000000000000000000000000000000000000000..076c35843375c7aef5e58786d55ebacd281d54a3 --- /dev/null +++ b/mmdet/models/roi_heads/bbox_heads/double_bbox_head.py @@ -0,0 +1,199 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule, ModuleList +from torch import Tensor + +from mmdet.models.backbones.resnet import Bottleneck +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, MultiConfig, OptConfigType, OptMultiConfig +from .bbox_head import BBoxHead + + +class BasicResBlock(BaseModule): + """Basic residual block. + + This block is a little different from the block in the ResNet backbone. + The kernel size of conv1 is 1 in this block while 3 in ResNet BasicBlock. + + Args: + in_channels (int): Channels of the input feature map. + out_channels (int): Channels of the output feature map. + conv_cfg (:obj:`ConfigDict` or dict, optional): The config dict + for convolution layers. + norm_cfg (:obj:`ConfigDict` or dict): The config dict for + normalization layers. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], optional): Initialization config dict. Defaults to None + """ + + def __init__(self, + in_channels: int, + out_channels: int, + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict(type='BN'), + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + + # main path + self.conv1 = ConvModule( + in_channels, + in_channels, + kernel_size=3, + padding=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + self.conv2 = ConvModule( + in_channels, + out_channels, + kernel_size=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + # identity path + self.conv_identity = ConvModule( + in_channels, + out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x: Tensor) -> Tensor: + """Forward function.""" + identity = x + + x = self.conv1(x) + x = self.conv2(x) + + identity = self.conv_identity(identity) + out = x + identity + + out = self.relu(out) + return out + + +@MODELS.register_module() +class DoubleConvFCBBoxHead(BBoxHead): + r"""Bbox head used in Double-Head R-CNN + + .. 
code-block:: none + + /-> cls + /-> shared convs -> + \-> reg + roi features + /-> cls + \-> shared fc -> + \-> reg + """ # noqa: W605 + + def __init__(self, + num_convs: int = 0, + num_fcs: int = 0, + conv_out_channels: int = 1024, + fc_out_channels: int = 1024, + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict(type='BN'), + init_cfg: MultiConfig = dict( + type='Normal', + override=[ + dict(type='Normal', name='fc_cls', std=0.01), + dict(type='Normal', name='fc_reg', std=0.001), + dict( + type='Xavier', + name='fc_branch', + distribution='uniform') + ]), + **kwargs) -> None: + kwargs.setdefault('with_avg_pool', True) + super().__init__(init_cfg=init_cfg, **kwargs) + assert self.with_avg_pool + assert num_convs > 0 + assert num_fcs > 0 + self.num_convs = num_convs + self.num_fcs = num_fcs + self.conv_out_channels = conv_out_channels + self.fc_out_channels = fc_out_channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + # increase the channel of input features + self.res_block = BasicResBlock(self.in_channels, + self.conv_out_channels) + + # add conv heads + self.conv_branch = self._add_conv_branch() + # add fc heads + self.fc_branch = self._add_fc_branch() + + out_dim_reg = 4 if self.reg_class_agnostic else 4 * self.num_classes + self.fc_reg = nn.Linear(self.conv_out_channels, out_dim_reg) + + self.fc_cls = nn.Linear(self.fc_out_channels, self.num_classes + 1) + self.relu = nn.ReLU() + + def _add_conv_branch(self) -> None: + """Add the fc branch which consists of a sequential of conv layers.""" + branch_convs = ModuleList() + for i in range(self.num_convs): + branch_convs.append( + Bottleneck( + inplanes=self.conv_out_channels, + planes=self.conv_out_channels // 4, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + return branch_convs + + def _add_fc_branch(self) -> None: + """Add the fc branch which consists of a sequential of fc layers.""" + branch_fcs = ModuleList() + for i in range(self.num_fcs): + fc_in_channels = ( + self.in_channels * + self.roi_feat_area if i == 0 else self.fc_out_channels) + branch_fcs.append(nn.Linear(fc_in_channels, self.fc_out_channels)) + return branch_fcs + + def forward(self, x_cls: Tensor, x_reg: Tensor) -> Tuple[Tensor]: + """Forward features from the upstream network. + + Args: + x_cls (Tensor): Classification features of rois + x_reg (Tensor): Regression features from the upstream network. + + Returns: + tuple: A tuple of classification scores and bbox prediction. + + - cls_score (Tensor): Classification score predictions of rois. + each roi predicts num_classes + 1 channels. + - bbox_pred (Tensor): BBox deltas predictions of rois. each roi + predicts 4 * num_classes channels. + """ + # conv head + x_conv = self.res_block(x_reg) + + for conv in self.conv_branch: + x_conv = conv(x_conv) + + if self.with_avg_pool: + x_conv = self.avg_pool(x_conv) + + x_conv = x_conv.view(x_conv.size(0), -1) + bbox_pred = self.fc_reg(x_conv) + + # fc head + x_fc = x_cls.view(x_cls.size(0), -1) + for fc in self.fc_branch: + x_fc = self.relu(fc(x_fc)) + + cls_score = self.fc_cls(x_fc) + + return cls_score, bbox_pred diff --git a/mmdet/models/roi_heads/bbox_heads/multi_instance_bbox_head.py b/mmdet/models/roi_heads/bbox_heads/multi_instance_bbox_head.py new file mode 100644 index 0000000000000000000000000000000000000000..38e57d2eddd580b13256da63c9bd8723be98e764 --- /dev/null +++ b/mmdet/models/roi_heads/bbox_heads/multi_instance_bbox_head.py @@ -0,0 +1,626 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
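+# Editor's note (illustrative, not part of the original patch): the head in
+# this file predicts `num_instance` boxes per proposal (fixed to 2 below) for
+# crowded scenes. The paired predictions are laid out side by side, e.g.
+# cls_score[:, 0:2] / bbox_pred[:, 0:4] for the first instance and
+# cls_score[:, 2:4] / bbox_pred[:, 4:8] for the second; `loss` evaluates the
+# EMD loss for both (prediction, target) pairings and keeps the cheaper one.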
+from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.config import ConfigDict +from mmengine.structures import InstanceData +from torch import Tensor, nn + +from mmdet.models.roi_heads.bbox_heads.bbox_head import BBoxHead +from mmdet.models.task_modules.samplers import SamplingResult +from mmdet.models.utils import empty_instances +from mmdet.registry import MODELS +from mmdet.structures.bbox import bbox_overlaps + + +@MODELS.register_module() +class MultiInstanceBBoxHead(BBoxHead): + r"""Bbox head used in CrowdDet. + + .. code-block:: none + + /-> cls convs_1 -> cls fcs_1 -> cls_1 + |-- + | \-> reg convs_1 -> reg fcs_1 -> reg_1 + | + | /-> cls convs_2 -> cls fcs_2 -> cls_2 + shared convs -> shared fcs |-- + | \-> reg convs_2 -> reg fcs_2 -> reg_2 + | + | ... + | + | /-> cls convs_k -> cls fcs_k -> cls_k + |-- + \-> reg convs_k -> reg fcs_k -> reg_k + + + Args: + num_instance (int): The number of branches after shared fcs. + Defaults to 2. + with_refine (bool): Whether to use refine module. Defaults to False. + num_shared_convs (int): The number of shared convs. Defaults to 0. + num_shared_fcs (int): The number of shared fcs. Defaults to 2. + num_cls_convs (int): The number of cls convs. Defaults to 0. + num_cls_fcs (int): The number of cls fcs. Defaults to 0. + num_reg_convs (int): The number of reg convs. Defaults to 0. + num_reg_fcs (int): The number of reg fcs. Defaults to 0. + conv_out_channels (int): The number of conv out channels. + Defaults to 256. + fc_out_channels (int): The number of fc out channels. Defaults to 1024. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ # noqa: W605 + + def __init__(self, + num_instance: int = 2, + with_refine: bool = False, + num_shared_convs: int = 0, + num_shared_fcs: int = 2, + num_cls_convs: int = 0, + num_cls_fcs: int = 0, + num_reg_convs: int = 0, + num_reg_fcs: int = 0, + conv_out_channels: int = 256, + fc_out_channels: int = 1024, + init_cfg: Optional[Union[dict, ConfigDict]] = None, + *args, + **kwargs) -> None: + super().__init__(*args, init_cfg=init_cfg, **kwargs) + assert (num_shared_convs + num_shared_fcs + num_cls_convs + + num_cls_fcs + num_reg_convs + num_reg_fcs > 0) + assert num_instance == 2, 'Currently only 2 instances are supported' + if num_cls_convs > 0 or num_reg_convs > 0: + assert num_shared_fcs == 0 + if not self.with_cls: + assert num_cls_convs == 0 and num_cls_fcs == 0 + if not self.with_reg: + assert num_reg_convs == 0 and num_reg_fcs == 0 + self.num_instance = num_instance + self.num_shared_convs = num_shared_convs + self.num_shared_fcs = num_shared_fcs + self.num_cls_convs = num_cls_convs + self.num_cls_fcs = num_cls_fcs + self.num_reg_convs = num_reg_convs + self.num_reg_fcs = num_reg_fcs + self.conv_out_channels = conv_out_channels + self.fc_out_channels = fc_out_channels + self.with_refine = with_refine + + # add shared convs and fcs + self.shared_convs, self.shared_fcs, last_layer_dim = \ + self._add_conv_fc_branch( + self.num_shared_convs, self.num_shared_fcs, self.in_channels, + True) + self.shared_out_channels = last_layer_dim + self.relu = nn.ReLU(inplace=True) + + if self.with_refine: + refine_model_cfg = { + 'type': 'Linear', + 'in_features': self.shared_out_channels + 20, + 'out_features': self.shared_out_channels + } + self.shared_fcs_ref = MODELS.build(refine_model_cfg) + self.fc_cls_ref = nn.ModuleList() + self.fc_reg_ref = nn.ModuleList() + + 
self.cls_convs = nn.ModuleList() + self.cls_fcs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + self.reg_fcs = nn.ModuleList() + self.cls_last_dim = list() + self.reg_last_dim = list() + self.fc_cls = nn.ModuleList() + self.fc_reg = nn.ModuleList() + for k in range(self.num_instance): + # add cls specific branch + cls_convs, cls_fcs, cls_last_dim = self._add_conv_fc_branch( + self.num_cls_convs, self.num_cls_fcs, self.shared_out_channels) + self.cls_convs.append(cls_convs) + self.cls_fcs.append(cls_fcs) + self.cls_last_dim.append(cls_last_dim) + + # add reg specific branch + reg_convs, reg_fcs, reg_last_dim = self._add_conv_fc_branch( + self.num_reg_convs, self.num_reg_fcs, self.shared_out_channels) + self.reg_convs.append(reg_convs) + self.reg_fcs.append(reg_fcs) + self.reg_last_dim.append(reg_last_dim) + + if self.num_shared_fcs == 0 and not self.with_avg_pool: + if self.num_cls_fcs == 0: + self.cls_last_dim *= self.roi_feat_area + if self.num_reg_fcs == 0: + self.reg_last_dim *= self.roi_feat_area + + if self.with_cls: + if self.custom_cls_channels: + cls_channels = self.loss_cls.get_cls_channels( + self.num_classes) + else: + cls_channels = self.num_classes + 1 + cls_predictor_cfg_ = self.cls_predictor_cfg.copy() # deepcopy + cls_predictor_cfg_.update( + in_features=self.cls_last_dim[k], + out_features=cls_channels) + self.fc_cls.append(MODELS.build(cls_predictor_cfg_)) + if self.with_refine: + self.fc_cls_ref.append(MODELS.build(cls_predictor_cfg_)) + + if self.with_reg: + out_dim_reg = (4 if self.reg_class_agnostic else 4 * + self.num_classes) + reg_predictor_cfg_ = self.reg_predictor_cfg.copy() + reg_predictor_cfg_.update( + in_features=self.reg_last_dim[k], out_features=out_dim_reg) + self.fc_reg.append(MODELS.build(reg_predictor_cfg_)) + if self.with_refine: + self.fc_reg_ref.append(MODELS.build(reg_predictor_cfg_)) + + if init_cfg is None: + # when init_cfg is None, + # It has been set to + # [[dict(type='Normal', std=0.01, override=dict(name='fc_cls'))], + # [dict(type='Normal', std=0.001, override=dict(name='fc_reg'))] + # after `super(ConvFCBBoxHead, self).__init__()` + # we only need to append additional configuration + # for `shared_fcs`, `cls_fcs` and `reg_fcs` + self.init_cfg += [ + dict( + type='Xavier', + distribution='uniform', + override=[ + dict(name='shared_fcs'), + dict(name='cls_fcs'), + dict(name='reg_fcs') + ]) + ] + + def _add_conv_fc_branch(self, + num_branch_convs: int, + num_branch_fcs: int, + in_channels: int, + is_shared: bool = False) -> tuple: + """Add shared or separable branch. 
+ + convs -> avg pool (optional) -> fcs + """ + last_layer_dim = in_channels + # add branch specific conv layers + branch_convs = nn.ModuleList() + if num_branch_convs > 0: + for i in range(num_branch_convs): + conv_in_channels = ( + last_layer_dim if i == 0 else self.conv_out_channels) + branch_convs.append( + ConvModule( + conv_in_channels, self.conv_out_channels, 3, + padding=1)) + last_layer_dim = self.conv_out_channels + # add branch specific fc layers + branch_fcs = nn.ModuleList() + if num_branch_fcs > 0: + # for shared branch, only consider self.with_avg_pool + # for separated branches, also consider self.num_shared_fcs + if (is_shared + or self.num_shared_fcs == 0) and not self.with_avg_pool: + last_layer_dim *= self.roi_feat_area + for i in range(num_branch_fcs): + fc_in_channels = ( + last_layer_dim if i == 0 else self.fc_out_channels) + branch_fcs.append( + nn.Linear(fc_in_channels, self.fc_out_channels)) + last_layer_dim = self.fc_out_channels + return branch_convs, branch_fcs, last_layer_dim + + def forward(self, x: Tuple[Tensor]) -> tuple: + """Forward features from the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: A tuple of classification scores and bbox prediction. + + - cls_score (Tensor): Classification scores for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * num_classes. + - bbox_pred (Tensor): Box energies / deltas for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * 4. + - cls_score_ref (Tensor): The cls_score after refine model. + - bbox_pred_ref (Tensor): The bbox_pred after refine model. + """ + # shared part + if self.num_shared_convs > 0: + for conv in self.shared_convs: + x = conv(x) + + if self.num_shared_fcs > 0: + if self.with_avg_pool: + x = self.avg_pool(x) + + x = x.flatten(1) + for fc in self.shared_fcs: + x = self.relu(fc(x)) + + x_cls = x + x_reg = x + # separate branches + cls_score = list() + bbox_pred = list() + for k in range(self.num_instance): + for conv in self.cls_convs[k]: + x_cls = conv(x_cls) + if x_cls.dim() > 2: + if self.with_avg_pool: + x_cls = self.avg_pool(x_cls) + x_cls = x_cls.flatten(1) + for fc in self.cls_fcs[k]: + x_cls = self.relu(fc(x_cls)) + + for conv in self.reg_convs[k]: + x_reg = conv(x_reg) + if x_reg.dim() > 2: + if self.with_avg_pool: + x_reg = self.avg_pool(x_reg) + x_reg = x_reg.flatten(1) + for fc in self.reg_fcs[k]: + x_reg = self.relu(fc(x_reg)) + + cls_score.append(self.fc_cls[k](x_cls) if self.with_cls else None) + bbox_pred.append(self.fc_reg[k](x_reg) if self.with_reg else None) + + if self.with_refine: + x_ref = x + cls_score_ref = list() + bbox_pred_ref = list() + for k in range(self.num_instance): + feat_ref = cls_score[k].softmax(dim=-1) + feat_ref = torch.cat((bbox_pred[k], feat_ref[:, 1][:, None]), + dim=1).repeat(1, 4) + feat_ref = torch.cat((x_ref, feat_ref), dim=1) + feat_ref = F.relu_(self.shared_fcs_ref(feat_ref)) + + cls_score_ref.append(self.fc_cls_ref[k](feat_ref)) + bbox_pred_ref.append(self.fc_reg_ref[k](feat_ref)) + + cls_score = torch.cat(cls_score, dim=1) + bbox_pred = torch.cat(bbox_pred, dim=1) + cls_score_ref = torch.cat(cls_score_ref, dim=1) + bbox_pred_ref = torch.cat(bbox_pred_ref, dim=1) + return cls_score, bbox_pred, cls_score_ref, bbox_pred_ref + + cls_score = torch.cat(cls_score, dim=1) + bbox_pred = torch.cat(bbox_pred, dim=1) + + return cls_score, bbox_pred + + def get_targets(self, + sampling_results: List[SamplingResult], + 
rcnn_train_cfg: ConfigDict, + concat: bool = True) -> tuple: + """Calculate the ground truth for all samples in a batch according to + the sampling_results. + + Almost the same as the implementation in bbox_head, we passed + additional parameters pos_inds_list and neg_inds_list to + `_get_targets_single` function. + + Args: + sampling_results (List[obj:SamplingResult]): Assign results of + all images in a batch after sampling. + rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. + concat (bool): Whether to concatenate the results of all + the images in a single batch. + + Returns: + Tuple[Tensor]: Ground truth for proposals in a single image. + Containing the following list of Tensors: + + - labels (list[Tensor],Tensor): Gt_labels for all proposals in a + batch, each tensor in list has shape (num_proposals,) when + `concat=False`, otherwise just a single tensor has shape + (num_all_proposals,). + - label_weights (list[Tensor]): Labels_weights for + all proposals in a batch, each tensor in list has shape + (num_proposals,) when `concat=False`, otherwise just a single + tensor has shape (num_all_proposals,). + - bbox_targets (list[Tensor],Tensor): Regression target for all + proposals in a batch, each tensor in list has shape + (num_proposals, 4) when `concat=False`, otherwise just a single + tensor has shape (num_all_proposals, 4), the last dimension 4 + represents [tl_x, tl_y, br_x, br_y]. + - bbox_weights (list[tensor],Tensor): Regression weights for + all proposals in a batch, each tensor in list has shape + (num_proposals, 4) when `concat=False`, otherwise just a + single tensor has shape (num_all_proposals, 4). + """ + labels = [] + bbox_targets = [] + bbox_weights = [] + label_weights = [] + for i in range(len(sampling_results)): + sample_bboxes = torch.cat([ + sampling_results[i].pos_gt_bboxes, + sampling_results[i].neg_gt_bboxes + ]) + sample_priors = sampling_results[i].priors + sample_priors = sample_priors.repeat(1, self.num_instance).reshape( + -1, 4) + sample_bboxes = sample_bboxes.reshape(-1, 4) + + if not self.reg_decoded_bbox: + _bbox_targets = self.bbox_coder.encode(sample_priors, + sample_bboxes) + else: + _bbox_targets = sample_priors + _bbox_targets = _bbox_targets.reshape(-1, self.num_instance * 4) + _bbox_weights = torch.ones(_bbox_targets.shape) + _labels = torch.cat([ + sampling_results[i].pos_gt_labels, + sampling_results[i].neg_gt_labels + ]) + _labels_weights = torch.ones(_labels.shape) + + bbox_targets.append(_bbox_targets) + bbox_weights.append(_bbox_weights) + labels.append(_labels) + label_weights.append(_labels_weights) + + if concat: + labels = torch.cat(labels, 0) + label_weights = torch.cat(label_weights, 0) + bbox_targets = torch.cat(bbox_targets, 0) + bbox_weights = torch.cat(bbox_weights, 0) + return labels, label_weights, bbox_targets, bbox_weights + + def loss(self, cls_score: Tensor, bbox_pred: Tensor, rois: Tensor, + labels: Tensor, label_weights: Tensor, bbox_targets: Tensor, + bbox_weights: Tensor, **kwargs) -> dict: + """Calculate the loss based on the network predictions and targets. + + Args: + cls_score (Tensor): Classification prediction results of all class, + has shape (batch_size * num_proposals_single_image, + (num_classes + 1) * k), k represents the number of prediction + boxes generated by each proposal box. + bbox_pred (Tensor): Regression prediction results, has shape + (batch_size * num_proposals_single_image, 4 * k), the last + dimension 4 represents [tl_x, tl_y, br_x, br_y]. 
+ rois (Tensor): RoIs with the shape + (batch_size * num_proposals_single_image, 5) where the first + column indicates batch id of each RoI. + labels (Tensor): Gt_labels for all proposals in a batch, has + shape (batch_size * num_proposals_single_image, k). + label_weights (Tensor): Labels_weights for all proposals in a + batch, has shape (batch_size * num_proposals_single_image, k). + bbox_targets (Tensor): Regression target for all proposals in a + batch, has shape (batch_size * num_proposals_single_image, + 4 * k), the last dimension 4 represents [tl_x, tl_y, br_x, + br_y]. + bbox_weights (Tensor): Regression weights for all proposals in a + batch, has shape (batch_size * num_proposals_single_image, + 4 * k). + + Returns: + dict: A dictionary of loss. + """ + losses = dict() + if bbox_pred.numel(): + loss_0 = self.emd_loss(bbox_pred[:, 0:4], cls_score[:, 0:2], + bbox_pred[:, 4:8], cls_score[:, 2:4], + bbox_targets, labels) + loss_1 = self.emd_loss(bbox_pred[:, 4:8], cls_score[:, 2:4], + bbox_pred[:, 0:4], cls_score[:, 0:2], + bbox_targets, labels) + loss = torch.cat([loss_0, loss_1], dim=1) + _, min_indices = loss.min(dim=1) + loss_emd = loss[torch.arange(loss.shape[0]), min_indices] + loss_emd = loss_emd.mean() + else: + loss_emd = bbox_pred.sum() + losses['loss_rcnn_emd'] = loss_emd + return losses + + def emd_loss(self, bbox_pred_0: Tensor, cls_score_0: Tensor, + bbox_pred_1: Tensor, cls_score_1: Tensor, targets: Tensor, + labels: Tensor) -> Tensor: + """Calculate the emd loss. + + Note: + This implementation is modified from https://github.com/Purkialo/ + CrowdDet/blob/master/lib/det_oprs/loss_opr.py + + Args: + bbox_pred_0 (Tensor): Part of regression prediction results, has + shape (batch_size * num_proposals_single_image, 4), the last + dimension 4 represents [tl_x, tl_y, br_x, br_y]. + cls_score_0 (Tensor): Part of classification prediction results, + has shape (batch_size * num_proposals_single_image, + (num_classes + 1)), where 1 represents the background. + bbox_pred_1 (Tensor): The other part of regression prediction + results, has shape (batch_size*num_proposals_single_image, 4). + cls_score_1 (Tensor):The other part of classification prediction + results, has shape (batch_size * num_proposals_single_image, + (num_classes + 1)). + targets (Tensor):Regression target for all proposals in a + batch, has shape (batch_size * num_proposals_single_image, + 4 * k), the last dimension 4 represents [tl_x, tl_y, br_x, + br_y], k represents the number of prediction boxes generated + by each proposal box. + labels (Tensor): Gt_labels for all proposals in a batch, has + shape (batch_size * num_proposals_single_image, k). + + Returns: + torch.Tensor: The calculated loss. 
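+
+        Example (editor's illustrative sketch of the pairing step that
+        :meth:`loss` performs on top of this function; values are
+        arbitrary):
+            >>> import torch
+            >>> # per-proposal losses for the two assignment orders
+            >>> loss_0 = torch.tensor([[0.7], [1.2]])
+            >>> loss_1 = torch.tensor([[0.9], [0.4]])
+            >>> loss = torch.cat([loss_0, loss_1], dim=1)
+            >>> _, min_indices = loss.min(dim=1)
+            >>> # keep the cheaper pairing for every proposal
+            >>> loss[torch.arange(loss.shape[0]), min_indices]
+            tensor([0.7000, 0.4000])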
+ """ + + bbox_pred = torch.cat([bbox_pred_0, bbox_pred_1], + dim=1).reshape(-1, bbox_pred_0.shape[-1]) + cls_score = torch.cat([cls_score_0, cls_score_1], + dim=1).reshape(-1, cls_score_0.shape[-1]) + targets = targets.reshape(-1, 4) + labels = labels.long().flatten() + + # masks + valid_masks = labels >= 0 + fg_masks = labels > 0 + + # multiple class + bbox_pred = bbox_pred.reshape(-1, self.num_classes, 4) + fg_gt_classes = labels[fg_masks] + bbox_pred = bbox_pred[fg_masks, fg_gt_classes - 1, :] + + # loss for regression + loss_bbox = self.loss_bbox(bbox_pred, targets[fg_masks]) + loss_bbox = loss_bbox.sum(dim=1) + + # loss for classification + labels = labels * valid_masks + loss_cls = self.loss_cls(cls_score, labels) + + loss_cls[fg_masks] = loss_cls[fg_masks] + loss_bbox + loss = loss_cls.reshape(-1, 2).sum(dim=1) + return loss.reshape(-1, 1) + + def _predict_by_feat_single( + self, + roi: Tensor, + cls_score: Tensor, + bbox_pred: Tensor, + img_meta: dict, + rescale: bool = False, + rcnn_test_cfg: Optional[ConfigDict] = None) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + roi (Tensor): Boxes to be transformed. Has shape (num_boxes, 5). + last dimension 5 arrange as (batch_index, x1, y1, x2, y2). + cls_score (Tensor): Box scores, has shape + (num_boxes, num_classes + 1). + bbox_pred (Tensor): Box energies / deltas. has shape + (num_boxes, num_classes * 4). + img_meta (dict): image information. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. + Defaults to None + + Returns: + :obj:`InstanceData`: Detection results of each image. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + + cls_score = cls_score.reshape(-1, self.num_classes + 1) + bbox_pred = bbox_pred.reshape(-1, 4) + roi = roi.repeat_interleave(self.num_instance, dim=0) + + results = InstanceData() + if roi.shape[0] == 0: + return empty_instances([img_meta], + roi.device, + task_type='bbox', + instance_results=[results])[0] + + scores = cls_score.softmax(dim=-1) if cls_score is not None else None + img_shape = img_meta['img_shape'] + bboxes = self.bbox_coder.decode( + roi[..., 1:], bbox_pred, max_shape=img_shape) + + if rescale and bboxes.size(0) > 0: + assert img_meta.get('scale_factor') is not None + scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat( + (1, 2)) + bboxes = (bboxes.view(bboxes.size(0), -1, 4) / scale_factor).view( + bboxes.size()[0], -1) + + if rcnn_test_cfg is None: + # This means that it is aug test. + # It needs to return the raw results without nms. 
+ results.bboxes = bboxes + results.scores = scores + else: + roi_idx = np.tile( + np.arange(bboxes.shape[0] / self.num_instance)[:, None], + (1, self.num_instance)).reshape(-1, 1)[:, 0] + roi_idx = torch.from_numpy(roi_idx).to(bboxes.device).reshape( + -1, 1) + bboxes = torch.cat([bboxes, roi_idx], dim=1) + det_bboxes, det_scores = self.set_nms( + bboxes, scores[:, 1], rcnn_test_cfg.score_thr, + rcnn_test_cfg.nms['iou_threshold'], rcnn_test_cfg.max_per_img) + + results.bboxes = det_bboxes[:, :-1] + results.scores = det_scores + results.labels = torch.zeros_like(det_scores) + + return results + + @staticmethod + def set_nms(bboxes: Tensor, + scores: Tensor, + score_thr: float, + iou_threshold: float, + max_num: int = -1) -> Tuple[Tensor, Tensor]: + """NMS for multi-instance prediction. Please refer to + https://github.com/Purkialo/CrowdDet for more details. + + Args: + bboxes (Tensor): predict bboxes. + scores (Tensor): The score of each predict bbox. + score_thr (float): bbox threshold, bboxes with scores lower than it + will not be considered. + iou_threshold (float): IoU threshold to be considered as + conflicted. + max_num (int, optional): if there are more than max_num bboxes + after NMS, only top max_num will be kept. Default to -1. + + Returns: + Tuple[Tensor, Tensor]: (bboxes, scores). + """ + + bboxes = bboxes[scores > score_thr] + scores = scores[scores > score_thr] + + ordered_scores, order = scores.sort(descending=True) + ordered_bboxes = bboxes[order] + roi_idx = ordered_bboxes[:, -1] + + keep = torch.ones(len(ordered_bboxes)) == 1 + ruler = torch.arange(len(ordered_bboxes)) + + keep = keep.to(bboxes.device) + ruler = ruler.to(bboxes.device) + + while ruler.shape[0] > 0: + basement = ruler[0] + ruler = ruler[1:] + idx = roi_idx[basement] + # calculate the body overlap + basement_bbox = ordered_bboxes[:, :4][basement].reshape(-1, 4) + ruler_bbox = ordered_bboxes[:, :4][ruler].reshape(-1, 4) + overlap = bbox_overlaps(basement_bbox, ruler_bbox) + indices = torch.where(overlap > iou_threshold)[1] + loc = torch.where(roi_idx[ruler][indices] == idx) + # the mask won't change in the step + mask = keep[ruler[indices][loc]] + keep[ruler[indices]] = False + keep[ruler[indices][loc][mask]] = True + ruler[~keep[ruler]] = -1 + ruler = ruler[ruler > 0] + + keep = keep[order.sort()[1]] + return bboxes[keep][:max_num, :], scores[keep][:max_num] diff --git a/mmdet/models/roi_heads/bbox_heads/sabl_head.py b/mmdet/models/roi_heads/bbox_heads/sabl_head.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9ee6aba9669514ec8ce7218e8c97e026830f6c --- /dev/null +++ b/mmdet/models/roi_heads/bbox_heads/sabl_head.py @@ -0,0 +1,684 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.config import ConfigDict +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.models.layers import multiclass_nms +from mmdet.models.losses import accuracy +from mmdet.models.task_modules import SamplingResult +from mmdet.models.utils import multi_apply +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig +from .bbox_head import BBoxHead + + +@MODELS.register_module() +class SABLHead(BBoxHead): + """Side-Aware Boundary Localization (SABL) for RoI-Head. 
+
+    Side-Aware features are extracted by conv layers
+    with an attention mechanism.
+    Boundary Localization with Bucketing and Bucketing Guided Rescoring
+    are implemented in BucketingBBoxCoder.
+
+    Please refer to https://arxiv.org/abs/1912.04260 for more details.
+
+    Args:
+        cls_in_channels (int): Input channels of cls RoI feature.
+            Defaults to 256.
+        reg_in_channels (int): Input channels of reg RoI feature.
+            Defaults to 256.
+        roi_feat_size (int): Size of RoI features. Defaults to 7.
+        reg_feat_up_ratio (int): Upsample ratio of reg features.
+            Defaults to 2.
+        reg_pre_kernel (int): Kernel of 2D conv layers before
+            attention pooling. Defaults to 3.
+        reg_post_kernel (int): Kernel of 1D conv layers after
+            attention pooling. Defaults to 3.
+        reg_pre_num (int): Number of pre convs. Defaults to 2.
+        reg_post_num (int): Number of post convs. Defaults to 1.
+        num_classes (int): Number of classes in the dataset. Defaults to 80.
+        cls_out_channels (int): Hidden channels in cls fcs. Defaults to 1024.
+        reg_offset_out_channels (int): Hidden and output channel
+            of the reg offset branch. Defaults to 256.
+        reg_cls_out_channels (int): Hidden and output channel
+            of the reg cls branch. Defaults to 256.
+        num_cls_fcs (int): Number of fcs for the cls branch. Defaults to 1.
+        num_reg_fcs (int): Number of fcs for the reg branch. Defaults to 0.
+        reg_class_agnostic (bool): Class agnostic regression or not.
+            Defaults to True.
+        norm_cfg (dict): Config of norm layers. Defaults to None.
+        bbox_coder (dict): Config of bbox coder. Defaults to
+            'BucketingBBoxCoder'.
+        loss_cls (dict): Config of classification loss.
+        loss_bbox_cls (dict): Config of classification loss for bbox branch.
+        loss_bbox_reg (dict): Config of regression loss for bbox branch.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Defaults to None.
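+
+    Example (editor's illustrative sketch, not part of the original patch;
+    shapes follow the defaults above):
+        >>> import torch
+        >>> self = SABLHead(num_classes=80)
+        >>> x = torch.rand(8, 256, 7, 7)  # RoI features
+        >>> cls_score, (bucket_cls_preds, bucket_offset_preds) = self(x)
+        >>> cls_score.shape
+        torch.Size([8, 81])
+        >>> # 14 buckets -> ceil(14 / 2) = 7 side buckets per boundary,
+        >>> # so 4 boundaries x 7 = 28 channels per prediction
+        >>> bucket_cls_preds.shape, bucket_offset_preds.shape
+        (torch.Size([8, 28]), torch.Size([8, 28]))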
+ """ + + def __init__(self, + num_classes: int, + cls_in_channels: int = 256, + reg_in_channels: int = 256, + roi_feat_size: int = 7, + reg_feat_up_ratio: int = 2, + reg_pre_kernel: int = 3, + reg_post_kernel: int = 3, + reg_pre_num: int = 2, + reg_post_num: int = 1, + cls_out_channels: int = 1024, + reg_offset_out_channels: int = 256, + reg_cls_out_channels: int = 256, + num_cls_fcs: int = 1, + num_reg_fcs: int = 0, + reg_class_agnostic: bool = True, + norm_cfg: OptConfigType = None, + bbox_coder: ConfigType = dict( + type='BucketingBBoxCoder', + num_buckets=14, + scale_factor=1.7), + loss_cls: ConfigType = dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox_cls: ConfigType = dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0), + loss_bbox_reg: ConfigType = dict( + type='SmoothL1Loss', beta=0.1, loss_weight=1.0), + init_cfg: OptMultiConfig = None) -> None: + super(BBoxHead, self).__init__(init_cfg=init_cfg) + self.cls_in_channels = cls_in_channels + self.reg_in_channels = reg_in_channels + self.roi_feat_size = roi_feat_size + self.reg_feat_up_ratio = int(reg_feat_up_ratio) + self.num_buckets = bbox_coder['num_buckets'] + assert self.reg_feat_up_ratio // 2 >= 1 + self.up_reg_feat_size = roi_feat_size * self.reg_feat_up_ratio + assert self.up_reg_feat_size == bbox_coder['num_buckets'] + self.reg_pre_kernel = reg_pre_kernel + self.reg_post_kernel = reg_post_kernel + self.reg_pre_num = reg_pre_num + self.reg_post_num = reg_post_num + self.num_classes = num_classes + self.cls_out_channels = cls_out_channels + self.reg_offset_out_channels = reg_offset_out_channels + self.reg_cls_out_channels = reg_cls_out_channels + self.num_cls_fcs = num_cls_fcs + self.num_reg_fcs = num_reg_fcs + self.reg_class_agnostic = reg_class_agnostic + assert self.reg_class_agnostic + self.norm_cfg = norm_cfg + + self.bbox_coder = TASK_UTILS.build(bbox_coder) + self.loss_cls = MODELS.build(loss_cls) + self.loss_bbox_cls = MODELS.build(loss_bbox_cls) + self.loss_bbox_reg = MODELS.build(loss_bbox_reg) + + self.cls_fcs = self._add_fc_branch(self.num_cls_fcs, + self.cls_in_channels, + self.roi_feat_size, + self.cls_out_channels) + + self.side_num = int(np.ceil(self.num_buckets / 2)) + + if self.reg_feat_up_ratio > 1: + self.upsample_x = nn.ConvTranspose1d( + reg_in_channels, + reg_in_channels, + self.reg_feat_up_ratio, + stride=self.reg_feat_up_ratio) + self.upsample_y = nn.ConvTranspose1d( + reg_in_channels, + reg_in_channels, + self.reg_feat_up_ratio, + stride=self.reg_feat_up_ratio) + + self.reg_pre_convs = nn.ModuleList() + for i in range(self.reg_pre_num): + reg_pre_conv = ConvModule( + reg_in_channels, + reg_in_channels, + kernel_size=reg_pre_kernel, + padding=reg_pre_kernel // 2, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU')) + self.reg_pre_convs.append(reg_pre_conv) + + self.reg_post_conv_xs = nn.ModuleList() + for i in range(self.reg_post_num): + reg_post_conv_x = ConvModule( + reg_in_channels, + reg_in_channels, + kernel_size=(1, reg_post_kernel), + padding=(0, reg_post_kernel // 2), + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU')) + self.reg_post_conv_xs.append(reg_post_conv_x) + self.reg_post_conv_ys = nn.ModuleList() + for i in range(self.reg_post_num): + reg_post_conv_y = ConvModule( + reg_in_channels, + reg_in_channels, + kernel_size=(reg_post_kernel, 1), + padding=(reg_post_kernel // 2, 0), + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU')) + self.reg_post_conv_ys.append(reg_post_conv_y) + + self.reg_conv_att_x = nn.Conv2d(reg_in_channels, 1, 1) + 
self.reg_conv_att_y = nn.Conv2d(reg_in_channels, 1, 1) + + self.fc_cls = nn.Linear(self.cls_out_channels, self.num_classes + 1) + self.relu = nn.ReLU(inplace=True) + + self.reg_cls_fcs = self._add_fc_branch(self.num_reg_fcs, + self.reg_in_channels, 1, + self.reg_cls_out_channels) + self.reg_offset_fcs = self._add_fc_branch(self.num_reg_fcs, + self.reg_in_channels, 1, + self.reg_offset_out_channels) + self.fc_reg_cls = nn.Linear(self.reg_cls_out_channels, 1) + self.fc_reg_offset = nn.Linear(self.reg_offset_out_channels, 1) + + if init_cfg is None: + self.init_cfg = [ + dict( + type='Xavier', + layer='Linear', + distribution='uniform', + override=[ + dict(type='Normal', name='reg_conv_att_x', std=0.01), + dict(type='Normal', name='reg_conv_att_y', std=0.01), + dict(type='Normal', name='fc_reg_cls', std=0.01), + dict(type='Normal', name='fc_cls', std=0.01), + dict(type='Normal', name='fc_reg_offset', std=0.001) + ]) + ] + if self.reg_feat_up_ratio > 1: + self.init_cfg += [ + dict( + type='Kaiming', + distribution='normal', + override=[ + dict(name='upsample_x'), + dict(name='upsample_y') + ]) + ] + + def _add_fc_branch(self, num_branch_fcs: int, in_channels: int, + roi_feat_size: int, + fc_out_channels: int) -> nn.ModuleList: + """build fc layers.""" + in_channels = in_channels * roi_feat_size * roi_feat_size + branch_fcs = nn.ModuleList() + for i in range(num_branch_fcs): + fc_in_channels = (in_channels if i == 0 else fc_out_channels) + branch_fcs.append(nn.Linear(fc_in_channels, fc_out_channels)) + return branch_fcs + + def cls_forward(self, cls_x: Tensor) -> Tensor: + """forward of classification fc layers.""" + cls_x = cls_x.view(cls_x.size(0), -1) + for fc in self.cls_fcs: + cls_x = self.relu(fc(cls_x)) + cls_score = self.fc_cls(cls_x) + return cls_score + + def attention_pool(self, reg_x: Tensor) -> tuple: + """Extract direction-specific features fx and fy with attention + methanism.""" + reg_fx = reg_x + reg_fy = reg_x + reg_fx_att = self.reg_conv_att_x(reg_fx).sigmoid() + reg_fy_att = self.reg_conv_att_y(reg_fy).sigmoid() + reg_fx_att = reg_fx_att / reg_fx_att.sum(dim=2).unsqueeze(2) + reg_fy_att = reg_fy_att / reg_fy_att.sum(dim=3).unsqueeze(3) + reg_fx = (reg_fx * reg_fx_att).sum(dim=2) + reg_fy = (reg_fy * reg_fy_att).sum(dim=3) + return reg_fx, reg_fy + + def side_aware_feature_extractor(self, reg_x: Tensor) -> tuple: + """Refine and extract side-aware features without split them.""" + for reg_pre_conv in self.reg_pre_convs: + reg_x = reg_pre_conv(reg_x) + reg_fx, reg_fy = self.attention_pool(reg_x) + + if self.reg_post_num > 0: + reg_fx = reg_fx.unsqueeze(2) + reg_fy = reg_fy.unsqueeze(3) + for i in range(self.reg_post_num): + reg_fx = self.reg_post_conv_xs[i](reg_fx) + reg_fy = self.reg_post_conv_ys[i](reg_fy) + reg_fx = reg_fx.squeeze(2) + reg_fy = reg_fy.squeeze(3) + if self.reg_feat_up_ratio > 1: + reg_fx = self.relu(self.upsample_x(reg_fx)) + reg_fy = self.relu(self.upsample_y(reg_fy)) + reg_fx = torch.transpose(reg_fx, 1, 2) + reg_fy = torch.transpose(reg_fy, 1, 2) + return reg_fx.contiguous(), reg_fy.contiguous() + + def reg_pred(self, x: Tensor, offset_fcs: nn.ModuleList, + cls_fcs: nn.ModuleList) -> tuple: + """Predict bucketing estimation (cls_pred) and fine regression (offset + pred) with side-aware features.""" + x_offset = x.view(-1, self.reg_in_channels) + x_cls = x.view(-1, self.reg_in_channels) + + for fc in offset_fcs: + x_offset = self.relu(fc(x_offset)) + for fc in cls_fcs: + x_cls = self.relu(fc(x_cls)) + offset_pred = self.fc_reg_offset(x_offset) + cls_pred = 
self.fc_reg_cls(x_cls) + + offset_pred = offset_pred.view(x.size(0), -1) + cls_pred = cls_pred.view(x.size(0), -1) + + return offset_pred, cls_pred + + def side_aware_split(self, feat: Tensor) -> Tensor: + """Split side-aware features aligned with orders of bucketing + targets.""" + l_end = int(np.ceil(self.up_reg_feat_size / 2)) + r_start = int(np.floor(self.up_reg_feat_size / 2)) + feat_fl = feat[:, :l_end] + feat_fr = feat[:, r_start:].flip(dims=(1, )) + feat_fl = feat_fl.contiguous() + feat_fr = feat_fr.contiguous() + feat = torch.cat([feat_fl, feat_fr], dim=-1) + return feat + + def bbox_pred_split(self, bbox_pred: tuple, + num_proposals_per_img: Sequence[int]) -> tuple: + """Split batch bbox prediction back to each image.""" + bucket_cls_preds, bucket_offset_preds = bbox_pred + bucket_cls_preds = bucket_cls_preds.split(num_proposals_per_img, 0) + bucket_offset_preds = bucket_offset_preds.split( + num_proposals_per_img, 0) + bbox_pred = tuple(zip(bucket_cls_preds, bucket_offset_preds)) + return bbox_pred + + def reg_forward(self, reg_x: Tensor) -> tuple: + """forward of regression branch.""" + outs = self.side_aware_feature_extractor(reg_x) + edge_offset_preds = [] + edge_cls_preds = [] + reg_fx = outs[0] + reg_fy = outs[1] + offset_pred_x, cls_pred_x = self.reg_pred(reg_fx, self.reg_offset_fcs, + self.reg_cls_fcs) + offset_pred_y, cls_pred_y = self.reg_pred(reg_fy, self.reg_offset_fcs, + self.reg_cls_fcs) + offset_pred_x = self.side_aware_split(offset_pred_x) + offset_pred_y = self.side_aware_split(offset_pred_y) + cls_pred_x = self.side_aware_split(cls_pred_x) + cls_pred_y = self.side_aware_split(cls_pred_y) + edge_offset_preds = torch.cat([offset_pred_x, offset_pred_y], dim=-1) + edge_cls_preds = torch.cat([cls_pred_x, cls_pred_y], dim=-1) + + return edge_cls_preds, edge_offset_preds + + def forward(self, x: Tensor) -> tuple: + """Forward features from the upstream network.""" + bbox_pred = self.reg_forward(x) + cls_score = self.cls_forward(x) + + return cls_score, bbox_pred + + def get_targets(self, + sampling_results: List[SamplingResult], + rcnn_train_cfg: ConfigDict, + concat: bool = True) -> tuple: + """Calculate the ground truth for all samples in a batch according to + the sampling_results.""" + pos_proposals = [res.pos_bboxes for res in sampling_results] + neg_proposals = [res.neg_bboxes for res in sampling_results] + pos_gt_bboxes = [res.pos_gt_bboxes for res in sampling_results] + pos_gt_labels = [res.pos_gt_labels for res in sampling_results] + cls_reg_targets = self.bucket_target( + pos_proposals, + neg_proposals, + pos_gt_bboxes, + pos_gt_labels, + rcnn_train_cfg, + concat=concat) + (labels, label_weights, bucket_cls_targets, bucket_cls_weights, + bucket_offset_targets, bucket_offset_weights) = cls_reg_targets + return (labels, label_weights, (bucket_cls_targets, + bucket_offset_targets), + (bucket_cls_weights, bucket_offset_weights)) + + def bucket_target(self, + pos_proposals_list: list, + neg_proposals_list: list, + pos_gt_bboxes_list: list, + pos_gt_labels_list: list, + rcnn_train_cfg: ConfigDict, + concat: bool = True) -> tuple: + """Compute bucketing estimation targets and fine regression targets for + a batch of images.""" + (labels, label_weights, bucket_cls_targets, bucket_cls_weights, + bucket_offset_targets, bucket_offset_weights) = multi_apply( + self._bucket_target_single, + pos_proposals_list, + neg_proposals_list, + pos_gt_bboxes_list, + pos_gt_labels_list, + cfg=rcnn_train_cfg) + + if concat: + labels = torch.cat(labels, 0) + label_weights = 
torch.cat(label_weights, 0)
+            bucket_cls_targets = torch.cat(bucket_cls_targets, 0)
+            bucket_cls_weights = torch.cat(bucket_cls_weights, 0)
+            bucket_offset_targets = torch.cat(bucket_offset_targets, 0)
+            bucket_offset_weights = torch.cat(bucket_offset_weights, 0)
+        return (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
+                bucket_offset_targets, bucket_offset_weights)
+
+    def _bucket_target_single(self, pos_proposals: Tensor,
+                              neg_proposals: Tensor, pos_gt_bboxes: Tensor,
+                              pos_gt_labels: Tensor, cfg: ConfigDict) -> tuple:
+        """Compute bucketing estimation targets and fine regression targets
+        for a single image.
+
+        Args:
+            pos_proposals (Tensor): positive proposals of a single image,
+                Shape (n_pos, 4).
+            neg_proposals (Tensor): negative proposals of a single image,
+                Shape (n_neg, 4).
+            pos_gt_bboxes (Tensor): gt bboxes assigned to positive proposals
+                of a single image, Shape (n_pos, 4).
+            pos_gt_labels (Tensor): gt labels assigned to positive proposals
+                of a single image, Shape (n_pos, ).
+            cfg (dict): Config of calculating targets
+
+        Returns:
+            tuple:
+
+            - labels (Tensor): Labels in a single image. Shape (n,).
+            - label_weights (Tensor): Label weights in a single image.
+              Shape (n,)
+            - bucket_cls_targets (Tensor): Bucket cls targets in
+              a single image. Shape (n, num_buckets*2).
+            - bucket_cls_weights (Tensor): Bucket cls weights in
+              a single image. Shape (n, num_buckets*2).
+            - bucket_offset_targets (Tensor): Bucket offset targets
+              in a single image. Shape (n, num_buckets*2).
+            - bucket_offset_weights (Tensor): Bucket offset weights
+              in a single image. Shape (n, num_buckets*2).
+        """
+        num_pos = pos_proposals.size(0)
+        num_neg = neg_proposals.size(0)
+        num_samples = num_pos + num_neg
+        labels = pos_gt_bboxes.new_full((num_samples, ),
+                                        self.num_classes,
+                                        dtype=torch.long)
+        label_weights = pos_proposals.new_zeros(num_samples)
+        bucket_cls_targets = pos_proposals.new_zeros(num_samples,
+                                                     4 * self.side_num)
+        bucket_cls_weights = pos_proposals.new_zeros(num_samples,
+                                                     4 * self.side_num)
+        bucket_offset_targets = pos_proposals.new_zeros(
+            num_samples, 4 * self.side_num)
+        bucket_offset_weights = pos_proposals.new_zeros(
+            num_samples, 4 * self.side_num)
+        if num_pos > 0:
+            labels[:num_pos] = pos_gt_labels
+            label_weights[:num_pos] = 1.0
+            (pos_bucket_offset_targets, pos_bucket_offset_weights,
+             pos_bucket_cls_targets,
+             pos_bucket_cls_weights) = self.bbox_coder.encode(
+                 pos_proposals, pos_gt_bboxes)
+            bucket_cls_targets[:num_pos, :] = pos_bucket_cls_targets
+            bucket_cls_weights[:num_pos, :] = pos_bucket_cls_weights
+            bucket_offset_targets[:num_pos, :] = pos_bucket_offset_targets
+            bucket_offset_weights[:num_pos, :] = pos_bucket_offset_weights
+        if num_neg > 0:
+            label_weights[-num_neg:] = 1.0
+        return (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
+                bucket_offset_targets, bucket_offset_weights)
+
+    def loss(self,
+             cls_score: Tensor,
+             bbox_pred: Tuple[Tensor, Tensor],
+             rois: Tensor,
+             labels: Tensor,
+             label_weights: Tensor,
+             bbox_targets: Tuple[Tensor, Tensor],
+             bbox_weights: Tuple[Tensor, Tensor],
+             reduction_override: Optional[str] = None) -> dict:
+        """Calculate the loss based on the network predictions and targets.
+
+        Args:
+            cls_score (Tensor): Classification prediction
+                results of all classes, has shape
+                (batch_size * num_proposals_single_image, num_classes)
+            bbox_pred (Tensor): A tuple of regression prediction results
+                containing `bucket_cls_preds` and `bucket_offset_preds`.
+            rois (Tensor): RoIs with the shape
+                (batch_size * num_proposals_single_image, 5) where the first
+                column indicates batch id of each RoI.
+            labels (Tensor): Gt_labels for all proposals in a batch, has
+                shape (batch_size * num_proposals_single_image, ).
+            label_weights (Tensor): Label weights for all proposals in a
+                batch, has shape (batch_size * num_proposals_single_image, ).
+            bbox_targets (Tuple[Tensor, Tensor]): A tuple of regression
+                targets containing `bucket_cls_targets` and
+                `bucket_offset_targets`.
+            bbox_weights (Tuple[Tensor, Tensor]): A tuple of regression
+                weights containing `bucket_cls_weights` and
+                `bucket_offset_weights`.
+            reduction_override (str, optional): The reduction
+                method used to override the original reduction
+                method of the loss. Options are "none",
+                "mean" and "sum". Defaults to None.
+
+        Returns:
+            dict: A dictionary of loss.
+        """
+        losses = dict()
+        if cls_score is not None:
+            avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
+            losses['loss_cls'] = self.loss_cls(
+                cls_score,
+                labels,
+                label_weights,
+                avg_factor=avg_factor,
+                reduction_override=reduction_override)
+            losses['acc'] = accuracy(cls_score, labels)
+
+        if bbox_pred is not None:
+            bucket_cls_preds, bucket_offset_preds = bbox_pred
+            bucket_cls_targets, bucket_offset_targets = bbox_targets
+            bucket_cls_weights, bucket_offset_weights = bbox_weights
+            # edge cls
+            bucket_cls_preds = bucket_cls_preds.view(-1, self.side_num)
+            bucket_cls_targets = bucket_cls_targets.view(-1, self.side_num)
+            bucket_cls_weights = bucket_cls_weights.view(-1, self.side_num)
+            losses['loss_bbox_cls'] = self.loss_bbox_cls(
+                bucket_cls_preds,
+                bucket_cls_targets,
+                bucket_cls_weights,
+                avg_factor=bucket_cls_targets.size(0),
+                reduction_override=reduction_override)
+
+            losses['loss_bbox_reg'] = self.loss_bbox_reg(
+                bucket_offset_preds,
+                bucket_offset_targets,
+                bucket_offset_weights,
+                avg_factor=bucket_offset_targets.size(0),
+                reduction_override=reduction_override)
+
+        return losses
+
+    def _predict_by_feat_single(
+            self,
+            roi: Tensor,
+            cls_score: Tensor,
+            bbox_pred: Tuple[Tensor, Tensor],
+            img_meta: dict,
+            rescale: bool = False,
+            rcnn_test_cfg: Optional[ConfigDict] = None) -> InstanceData:
+        """Transform a single image's features extracted from the head into
+        bbox results.
+
+        Args:
+            roi (Tensor): Boxes to be transformed. Has shape (num_boxes, 5).
+                The last dimension 5 is arranged as
+                (batch_index, x1, y1, x2, y2).
+            cls_score (Tensor): Box scores, has shape
+                (num_boxes, num_classes + 1).
+            bbox_pred (Tuple[Tensor, Tensor]): Box cls preds and offset preds.
+            img_meta (dict): image information.
+            rescale (bool): If True, return boxes in original image space.
+                Defaults to False.
+            rcnn_test_cfg (:obj:`ConfigDict`): `test_cfg` of Bbox Head.
+                Defaults to None.
+
+        Returns:
+            :obj:`InstanceData`: Detection results of each image.
+            Each item usually contains following keys.
+
+            - scores (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 4),
+              the last dimension 4 arranged as (x1, y1, x2, y2).
+ """ + results = InstanceData() + if isinstance(cls_score, list): + cls_score = sum(cls_score) / float(len(cls_score)) + scores = F.softmax(cls_score, dim=1) if cls_score is not None else None + img_shape = img_meta['img_shape'] + if bbox_pred is not None: + bboxes, confidences = self.bbox_coder.decode( + roi[:, 1:], bbox_pred, img_shape) + else: + bboxes = roi[:, 1:].clone() + confidences = None + if img_shape is not None: + bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1] - 1) + bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0] - 1) + + if rescale and bboxes.size(0) > 0: + assert img_meta.get('scale_factor') is not None + scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat( + (1, 2)) + bboxes = (bboxes.view(bboxes.size(0), -1, 4) / scale_factor).view( + bboxes.size()[0], -1) + + if rcnn_test_cfg is None: + results.bboxes = bboxes + results.scores = scores + else: + det_bboxes, det_labels = multiclass_nms( + bboxes, + scores, + rcnn_test_cfg.score_thr, + rcnn_test_cfg.nms, + rcnn_test_cfg.max_per_img, + score_factors=confidences) + results.bboxes = det_bboxes[:, :4] + results.scores = det_bboxes[:, -1] + results.labels = det_labels + return results + + def refine_bboxes(self, sampling_results: List[SamplingResult], + bbox_results: dict, + batch_img_metas: List[dict]) -> InstanceList: + """Refine bboxes during training. + + Args: + sampling_results (List[:obj:`SamplingResult`]): Sampling results. + bbox_results (dict): Usually is a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. + - `bbox_pred` (Tensor): Box energies / deltas. + - `rois` (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + - `bbox_targets` (tuple): Ground truth for proposals in a + single image. Containing the following list of Tensors: + (labels, label_weights, bbox_targets, bbox_weights) + batch_img_metas (List[dict]): List of image information. + + Returns: + list[:obj:`InstanceData`]: Refined bboxes of each image. + """ + pos_is_gts = [res.pos_is_gt for res in sampling_results] + # bbox_targets is a tuple + labels = bbox_results['bbox_targets'][0] + cls_scores = bbox_results['cls_score'] + rois = bbox_results['rois'] + bbox_preds = bbox_results['bbox_pred'] + + if cls_scores.numel() == 0: + return None + + labels = torch.where(labels == self.num_classes, + cls_scores[:, :-1].argmax(1), labels) + + img_ids = rois[:, 0].long().unique(sorted=True) + assert img_ids.numel() <= len(batch_img_metas) + + results_list = [] + for i in range(len(batch_img_metas)): + inds = torch.nonzero( + rois[:, 0] == i, as_tuple=False).squeeze(dim=1) + num_rois = inds.numel() + + bboxes_ = rois[inds, 1:] + label_ = labels[inds] + edge_cls_preds, edge_offset_preds = bbox_preds + edge_cls_preds_ = edge_cls_preds[inds] + edge_offset_preds_ = edge_offset_preds[inds] + bbox_pred_ = (edge_cls_preds_, edge_offset_preds_) + img_meta_ = batch_img_metas[i] + pos_is_gts_ = pos_is_gts[i] + + bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_, + img_meta_) + # filter gt bboxes + pos_keep = 1 - pos_is_gts_ + keep_inds = pos_is_gts_.new_ones(num_rois) + keep_inds[:len(pos_is_gts_)] = pos_keep + results = InstanceData(bboxes=bboxes[keep_inds.type(torch.bool)]) + results_list.append(results) + + return results_list + + def regress_by_class(self, rois: Tensor, label: Tensor, bbox_pred: tuple, + img_meta: dict) -> Tensor: + """Regress the bbox for the predicted class. Used in Cascade R-CNN. 
+
+        Args:
+            rois (Tensor): shape (n, 4) or (n, 5)
+            label (Tensor): shape (n, )
+            bbox_pred (Tuple[Tensor]): shape [(n, num_buckets * 2), \
+                (n, num_buckets * 2)]
+            img_meta (dict): Image meta info.
+
+        Returns:
+            Tensor: Regressed bboxes, the same shape as input rois.
+        """
+        assert rois.size(1) == 4 or rois.size(1) == 5
+
+        if rois.size(1) == 4:
+            new_rois, _ = self.bbox_coder.decode(rois, bbox_pred,
+                                                 img_meta['img_shape'])
+        else:
+            bboxes, _ = self.bbox_coder.decode(rois[:, 1:], bbox_pred,
+                                               img_meta['img_shape'])
+            new_rois = torch.cat((rois[:, [0]], bboxes), dim=1)
+
+        return new_rois
diff --git a/mmdet/models/roi_heads/bbox_heads/scnet_bbox_head.py b/mmdet/models/roi_heads/bbox_heads/scnet_bbox_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..790b08fb207970927c7925cb8b3fb365bc183dc4
--- /dev/null
+++ b/mmdet/models/roi_heads/bbox_heads/scnet_bbox_head.py
@@ -0,0 +1,101 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple, Union
+
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from .convfc_bbox_head import ConvFCBBoxHead
+
+
+@MODELS.register_module()
+class SCNetBBoxHead(ConvFCBBoxHead):
+    """BBox head for `SCNet `_.
+
+    This inherits ``ConvFCBBoxHead`` with a modified forward() function,
+    allowing us to get the intermediate shared feature.
+    """
+
+    def _forward_shared(self, x: Tensor) -> Tensor:
+        """Forward function for shared part.
+
+        Args:
+            x (Tensor): Input feature.
+
+        Returns:
+            Tensor: Shared feature.
+        """
+        if self.num_shared_convs > 0:
+            for conv in self.shared_convs:
+                x = conv(x)
+
+        if self.num_shared_fcs > 0:
+            if self.with_avg_pool:
+                x = self.avg_pool(x)
+
+            x = x.flatten(1)
+
+            for fc in self.shared_fcs:
+                x = self.relu(fc(x))
+
+        return x
+
+    def _forward_cls_reg(self, x: Tensor) -> Tuple[Tensor]:
+        """Forward function for classification and regression parts.
+
+        Args:
+            x (Tensor): Input feature.
+
+        Returns:
+            tuple[Tensor]:
+
+            - cls_score (Tensor): classification prediction.
+            - bbox_pred (Tensor): bbox prediction.
+        """
+        x_cls = x
+        x_reg = x
+
+        for conv in self.cls_convs:
+            x_cls = conv(x_cls)
+        if x_cls.dim() > 2:
+            if self.with_avg_pool:
+                x_cls = self.avg_pool(x_cls)
+            x_cls = x_cls.flatten(1)
+        for fc in self.cls_fcs:
+            x_cls = self.relu(fc(x_cls))
+
+        for conv in self.reg_convs:
+            x_reg = conv(x_reg)
+        if x_reg.dim() > 2:
+            if self.with_avg_pool:
+                x_reg = self.avg_pool(x_reg)
+            x_reg = x_reg.flatten(1)
+        for fc in self.reg_fcs:
+            x_reg = self.relu(fc(x_reg))
+
+        cls_score = self.fc_cls(x_cls) if self.with_cls else None
+        bbox_pred = self.fc_reg(x_reg) if self.with_reg else None
+
+        return cls_score, bbox_pred
+
+    def forward(
+            self,
+            x: Tensor,
+            return_shared_feat: bool = False) -> Union[Tensor, Tuple[Tensor]]:
+        """Forward function.
+
+        Args:
+            x (Tensor): input features
+            return_shared_feat (bool): If True, return cls-reg-shared feature.
+
+        Returns:
+            out (tuple[Tensor]): contain ``cls_score`` and ``bbox_pred``,
+                if ``return_shared_feat`` is True, append ``x_shared`` to the
+                returned tuple.
+        """
+        x_shared = self._forward_shared(x)
+        out = self._forward_cls_reg(x_shared)
+
+        if return_shared_feat:
+            out += (x_shared, )
+
+        return out
diff --git a/mmdet/models/roi_heads/cascade_roi_head.py b/mmdet/models/roi_heads/cascade_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..81db671113a63beb7849abdc0e432a738ee46f5e
--- /dev/null
+++ b/mmdet/models/roi_heads/cascade_roi_head.py
@@ -0,0 +1,568 @@
+# Copyright (c) OpenMMLab. All rights reserved.
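+# Illustrative usage sketch (an assumption added for clarity, not part of the
+# original file): ``CascadeRoIHead`` broadcasts a single ``bbox_head`` config
+# across ``num_stages`` stages or accepts one config per stage, and pairs one
+# entry of ``stage_loss_weights`` and ``train_cfg`` with each stage. All
+# names and values below are hypothetical placeholders:
+#
+#     roi_head = dict(
+#         type='CascadeRoIHead',
+#         num_stages=3,
+#         stage_loss_weights=[1, 0.5, 0.25],
+#         bbox_roi_extractor=shared_extractor_cfg,  # reused by every stage
+#         bbox_head=[stage0_head_cfg, stage1_head_cfg, stage2_head_cfg],
+#         train_cfg=[rcnn_cfg_s0, rcnn_cfg_s1, rcnn_cfg_s2],
+#         test_cfg=rcnn_test_cfg)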
+from typing import List, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from mmengine.model import ModuleList +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.models.task_modules.samplers import SamplingResult +from mmdet.models.test_time_augs import merge_aug_masks +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures import SampleList +from mmdet.structures.bbox import bbox2roi, get_box_tensor +from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType, + OptMultiConfig) +from ..utils.misc import empty_instances, unpack_gt_instances +from .base_roi_head import BaseRoIHead + + +@MODELS.register_module() +class CascadeRoIHead(BaseRoIHead): + """Cascade roi head including one bbox head and one mask head. + + https://arxiv.org/abs/1712.00726 + """ + + def __init__(self, + num_stages: int, + stage_loss_weights: Union[List[float], Tuple[float]], + bbox_roi_extractor: OptMultiConfig = None, + bbox_head: OptMultiConfig = None, + mask_roi_extractor: OptMultiConfig = None, + mask_head: OptMultiConfig = None, + shared_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + assert bbox_roi_extractor is not None + assert bbox_head is not None + assert shared_head is None, \ + 'Shared head is not supported in Cascade RCNN anymore' + + self.num_stages = num_stages + self.stage_loss_weights = stage_loss_weights + super().__init__( + bbox_roi_extractor=bbox_roi_extractor, + bbox_head=bbox_head, + mask_roi_extractor=mask_roi_extractor, + mask_head=mask_head, + shared_head=shared_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + init_cfg=init_cfg) + + def init_bbox_head(self, bbox_roi_extractor: MultiConfig, + bbox_head: MultiConfig) -> None: + """Initialize box head and box roi extractor. + + Args: + bbox_roi_extractor (:obj:`ConfigDict`, dict or list): + Config of box roi extractor. + bbox_head (:obj:`ConfigDict`, dict or list): Config + of box in box head. + """ + self.bbox_roi_extractor = ModuleList() + self.bbox_head = ModuleList() + if not isinstance(bbox_roi_extractor, list): + bbox_roi_extractor = [ + bbox_roi_extractor for _ in range(self.num_stages) + ] + if not isinstance(bbox_head, list): + bbox_head = [bbox_head for _ in range(self.num_stages)] + assert len(bbox_roi_extractor) == len(bbox_head) == self.num_stages + for roi_extractor, head in zip(bbox_roi_extractor, bbox_head): + self.bbox_roi_extractor.append(MODELS.build(roi_extractor)) + self.bbox_head.append(MODELS.build(head)) + + def init_mask_head(self, mask_roi_extractor: MultiConfig, + mask_head: MultiConfig) -> None: + """Initialize mask head and mask roi extractor. + + Args: + mask_head (dict): Config of mask in mask head. + mask_roi_extractor (:obj:`ConfigDict`, dict or list): + Config of mask roi extractor. 
+ """ + self.mask_head = nn.ModuleList() + if not isinstance(mask_head, list): + mask_head = [mask_head for _ in range(self.num_stages)] + assert len(mask_head) == self.num_stages + for head in mask_head: + self.mask_head.append(MODELS.build(head)) + if mask_roi_extractor is not None: + self.share_roi_extractor = False + self.mask_roi_extractor = ModuleList() + if not isinstance(mask_roi_extractor, list): + mask_roi_extractor = [ + mask_roi_extractor for _ in range(self.num_stages) + ] + assert len(mask_roi_extractor) == self.num_stages + for roi_extractor in mask_roi_extractor: + self.mask_roi_extractor.append(MODELS.build(roi_extractor)) + else: + self.share_roi_extractor = True + self.mask_roi_extractor = self.bbox_roi_extractor + + def init_assigner_sampler(self) -> None: + """Initialize assigner and sampler for each stage.""" + self.bbox_assigner = [] + self.bbox_sampler = [] + if self.train_cfg is not None: + for idx, rcnn_train_cfg in enumerate(self.train_cfg): + self.bbox_assigner.append( + TASK_UTILS.build(rcnn_train_cfg.assigner)) + self.current_stage = idx + self.bbox_sampler.append( + TASK_UTILS.build( + rcnn_train_cfg.sampler, + default_args=dict(context=self))) + + def _bbox_forward(self, stage: int, x: Tuple[Tensor], + rois: Tensor) -> dict: + """Box head forward function used in both training and testing. + + Args: + stage (int): The current stage in Cascade RoI Head. + x (tuple[Tensor]): List of multi-level img features. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + + Returns: + dict[str, Tensor]: Usually returns a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. + - `bbox_pred` (Tensor): Box energies / deltas. + - `bbox_feats` (Tensor): Extract bbox RoI features. + """ + bbox_roi_extractor = self.bbox_roi_extractor[stage] + bbox_head = self.bbox_head[stage] + bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs], + rois) + # do not support caffe_c4 model anymore + cls_score, bbox_pred = bbox_head(bbox_feats) + + bbox_results = dict( + cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats) + return bbox_results + + def bbox_loss(self, stage: int, x: Tuple[Tensor], + sampling_results: List[SamplingResult]) -> dict: + """Run forward function and calculate loss for box head in training. + + Args: + stage (int): The current stage in Cascade RoI Head. + x (tuple[Tensor]): List of multi-level img features. + sampling_results (list["obj:`SamplingResult`]): Sampling results. + + Returns: + dict: Usually returns a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. + - `bbox_pred` (Tensor): Box energies / deltas. + - `bbox_feats` (Tensor): Extract bbox RoI features. + - `loss_bbox` (dict): A dictionary of bbox loss components. + - `rois` (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + - `bbox_targets` (tuple): Ground truth for proposals in a + single image. 
Containing the following list of Tensors: + (labels, label_weights, bbox_targets, bbox_weights) + """ + bbox_head = self.bbox_head[stage] + rois = bbox2roi([res.priors for res in sampling_results]) + bbox_results = self._bbox_forward(stage, x, rois) + bbox_results.update(rois=rois) + + bbox_loss_and_target = bbox_head.loss_and_target( + cls_score=bbox_results['cls_score'], + bbox_pred=bbox_results['bbox_pred'], + rois=rois, + sampling_results=sampling_results, + rcnn_train_cfg=self.train_cfg[stage]) + bbox_results.update(bbox_loss_and_target) + + return bbox_results + + def _mask_forward(self, stage: int, x: Tuple[Tensor], + rois: Tensor) -> dict: + """Mask head forward function used in both training and testing. + + Args: + stage (int): The current stage in Cascade RoI Head. + x (tuple[Tensor]): Tuple of multi-level img features. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + + Returns: + dict: Usually returns a dictionary with keys: + + - `mask_preds` (Tensor): Mask prediction. + """ + mask_roi_extractor = self.mask_roi_extractor[stage] + mask_head = self.mask_head[stage] + mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs], + rois) + # do not support caffe_c4 model anymore + mask_preds = mask_head(mask_feats) + + mask_results = dict(mask_preds=mask_preds) + return mask_results + + def mask_loss(self, stage: int, x: Tuple[Tensor], + sampling_results: List[SamplingResult], + batch_gt_instances: InstanceList) -> dict: + """Run forward function and calculate loss for mask head in training. + + Args: + stage (int): The current stage in Cascade RoI Head. + x (tuple[Tensor]): Tuple of multi-level img features. + sampling_results (list["obj:`SamplingResult`]): Sampling results. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + + Returns: + dict: Usually returns a dictionary with keys: + + - `mask_preds` (Tensor): Mask prediction. + - `loss_mask` (dict): A dictionary of mask loss components. + """ + pos_rois = bbox2roi([res.pos_priors for res in sampling_results]) + mask_results = self._mask_forward(stage, x, pos_rois) + + mask_head = self.mask_head[stage] + + mask_loss_and_target = mask_head.loss_and_target( + mask_preds=mask_results['mask_preds'], + sampling_results=sampling_results, + batch_gt_instances=batch_gt_instances, + rcnn_train_cfg=self.train_cfg[stage]) + mask_results.update(mask_loss_and_target) + + return mask_results + + def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList, + batch_data_samples: SampleList) -> dict: + """Perform forward propagation and loss calculation of the detection + roi on the features of the upstream network. + + Args: + x (tuple[Tensor]): List of multi-level img features. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. 
+ + Returns: + dict[str, Tensor]: A dictionary of loss components + """ + # TODO: May add a new function in baseroihead + assert len(rpn_results_list) == len(batch_data_samples) + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \ + = outputs + + num_imgs = len(batch_data_samples) + losses = dict() + results_list = rpn_results_list + for stage in range(self.num_stages): + self.current_stage = stage + + stage_loss_weight = self.stage_loss_weights[stage] + + # assign gts and sample proposals + sampling_results = [] + if self.with_bbox or self.with_mask: + bbox_assigner = self.bbox_assigner[stage] + bbox_sampler = self.bbox_sampler[stage] + + for i in range(num_imgs): + results = results_list[i] + # rename rpn_results.bboxes to rpn_results.priors + results.priors = results.pop('bboxes') + + assign_result = bbox_assigner.assign( + results, batch_gt_instances[i], + batch_gt_instances_ignore[i]) + + sampling_result = bbox_sampler.sample( + assign_result, + results, + batch_gt_instances[i], + feats=[lvl_feat[i][None] for lvl_feat in x]) + sampling_results.append(sampling_result) + + # bbox head forward and loss + bbox_results = self.bbox_loss(stage, x, sampling_results) + + for name, value in bbox_results['loss_bbox'].items(): + losses[f's{stage}.{name}'] = ( + value * stage_loss_weight if 'loss' in name else value) + + # mask head forward and loss + if self.with_mask: + mask_results = self.mask_loss(stage, x, sampling_results, + batch_gt_instances) + for name, value in mask_results['loss_mask'].items(): + losses[f's{stage}.{name}'] = ( + value * stage_loss_weight if 'loss' in name else value) + + # refine bboxes + if stage < self.num_stages - 1: + bbox_head = self.bbox_head[stage] + with torch.no_grad(): + results_list = bbox_head.refine_bboxes( + sampling_results, bbox_results, batch_img_metas) + # Empty proposal + if results_list is None: + break + return losses + + def predict_bbox(self, + x: Tuple[Tensor], + batch_img_metas: List[dict], + rpn_results_list: InstanceList, + rcnn_test_cfg: ConfigType, + rescale: bool = False, + **kwargs) -> InstanceList: + """Perform forward propagation of the bbox head and predict detection + results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
+ """ + proposals = [res.bboxes for res in rpn_results_list] + num_proposals_per_img = tuple(len(p) for p in proposals) + rois = bbox2roi(proposals) + + if rois.shape[0] == 0: + return empty_instances( + batch_img_metas, + rois.device, + task_type='bbox', + box_type=self.bbox_head[-1].predict_box_type, + num_classes=self.bbox_head[-1].num_classes, + score_per_cls=rcnn_test_cfg is None) + + rois, cls_scores, bbox_preds = self._refine_roi( + x=x, + rois=rois, + batch_img_metas=batch_img_metas, + num_proposals_per_img=num_proposals_per_img, + **kwargs) + + results_list = self.bbox_head[-1].predict_by_feat( + rois=rois, + cls_scores=cls_scores, + bbox_preds=bbox_preds, + batch_img_metas=batch_img_metas, + rescale=rescale, + rcnn_test_cfg=rcnn_test_cfg) + return results_list + + def predict_mask(self, + x: Tuple[Tensor], + batch_img_metas: List[dict], + results_list: List[InstanceData], + rescale: bool = False) -> List[InstanceData]: + """Perform forward propagation of the mask head and predict detection + results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + results_list (list[:obj:`InstanceData`]): Detection results of + each image. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + bboxes = [res.bboxes for res in results_list] + mask_rois = bbox2roi(bboxes) + if mask_rois.shape[0] == 0: + results_list = empty_instances( + batch_img_metas, + mask_rois.device, + task_type='mask', + instance_results=results_list, + mask_thr_binary=self.test_cfg.mask_thr_binary) + return results_list + + num_mask_rois_per_img = [len(res) for res in results_list] + aug_masks = [] + for stage in range(self.num_stages): + mask_results = self._mask_forward(stage, x, mask_rois) + mask_preds = mask_results['mask_preds'] + # split batch mask prediction back to each image + mask_preds = mask_preds.split(num_mask_rois_per_img, 0) + aug_masks.append([m.sigmoid().detach() for m in mask_preds]) + + merged_masks = [] + for i in range(len(batch_img_metas)): + aug_mask = [mask[i] for mask in aug_masks] + merged_mask = merge_aug_masks(aug_mask, batch_img_metas[i]) + merged_masks.append(merged_mask) + results_list = self.mask_head[-1].predict_by_feat( + mask_preds=merged_masks, + results_list=results_list, + batch_img_metas=batch_img_metas, + rcnn_test_cfg=self.test_cfg, + rescale=rescale, + activate_map=True) + return results_list + + def _refine_roi(self, x: Tuple[Tensor], rois: Tensor, + batch_img_metas: List[dict], + num_proposals_per_img: Sequence[int], **kwargs) -> tuple: + """Multi-stage refinement of RoI. + + Args: + x (tuple[Tensor]): List of multi-level img features. + rois (Tensor): shape (n, 5), [batch_ind, x1, y1, x2, y2] + batch_img_metas (list[dict]): List of image information. + num_proposals_per_img (sequence[int]): number of proposals + in each image. + + Returns: + tuple: + + - rois (Tensor): Refined RoI. + - cls_scores (list[Tensor]): Average predicted + cls score per image. 
+ - bbox_preds (list[Tensor]): Bbox branch predictions + for the last stage of per image. + """ + # "ms" in variable names means multi-stage + ms_scores = [] + for stage in range(self.num_stages): + bbox_results = self._bbox_forward( + stage=stage, x=x, rois=rois, **kwargs) + + # split batch bbox prediction back to each image + cls_scores = bbox_results['cls_score'] + bbox_preds = bbox_results['bbox_pred'] + + rois = rois.split(num_proposals_per_img, 0) + cls_scores = cls_scores.split(num_proposals_per_img, 0) + ms_scores.append(cls_scores) + + # some detector with_reg is False, bbox_preds will be None + if bbox_preds is not None: + # TODO move this to a sabl_roi_head + # the bbox prediction of some detectors like SABL is not Tensor + if isinstance(bbox_preds, torch.Tensor): + bbox_preds = bbox_preds.split(num_proposals_per_img, 0) + else: + bbox_preds = self.bbox_head[stage].bbox_pred_split( + bbox_preds, num_proposals_per_img) + else: + bbox_preds = (None, ) * len(batch_img_metas) + + if stage < self.num_stages - 1: + bbox_head = self.bbox_head[stage] + if bbox_head.custom_activation: + cls_scores = [ + bbox_head.loss_cls.get_activation(s) + for s in cls_scores + ] + refine_rois_list = [] + for i in range(len(batch_img_metas)): + if rois[i].shape[0] > 0: + bbox_label = cls_scores[i][:, :-1].argmax(dim=1) + # Refactor `bbox_head.regress_by_class` to only accept + # box tensor without img_idx concatenated. + refined_bboxes = bbox_head.regress_by_class( + rois[i][:, 1:], bbox_label, bbox_preds[i], + batch_img_metas[i]) + refined_bboxes = get_box_tensor(refined_bboxes) + refined_rois = torch.cat( + [rois[i][:, [0]], refined_bboxes], dim=1) + refine_rois_list.append(refined_rois) + rois = torch.cat(refine_rois_list) + + # average scores of each image by stages + cls_scores = [ + sum([score[i] for score in ms_scores]) / float(len(ms_scores)) + for i in range(len(batch_img_metas)) + ] + return rois, cls_scores, bbox_preds + + def forward(self, x: Tuple[Tensor], rpn_results_list: InstanceList, + batch_data_samples: SampleList) -> tuple: + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. + + Args: + x (List[Tensor]): Multi-level features that may have different + resolutions. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + batch_data_samples (list[:obj:`DetDataSample`]): Each item contains + the meta information of each image and corresponding + annotations. + + Returns + tuple: A tuple of features from ``bbox_head`` and ``mask_head`` + forward. 
+ """ + results = () + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + proposals = [rpn_results.bboxes for rpn_results in rpn_results_list] + num_proposals_per_img = tuple(len(p) for p in proposals) + rois = bbox2roi(proposals) + # bbox head + if self.with_bbox: + rois, cls_scores, bbox_preds = self._refine_roi( + x, rois, batch_img_metas, num_proposals_per_img) + results = results + (cls_scores, bbox_preds) + # mask head + if self.with_mask: + aug_masks = [] + rois = torch.cat(rois) + for stage in range(self.num_stages): + mask_results = self._mask_forward(stage, x, rois) + mask_preds = mask_results['mask_preds'] + mask_preds = mask_preds.split(num_proposals_per_img, 0) + aug_masks.append([m.sigmoid().detach() for m in mask_preds]) + + merged_masks = [] + for i in range(len(batch_img_metas)): + aug_mask = [mask[i] for mask in aug_masks] + merged_mask = merge_aug_masks(aug_mask, batch_img_metas[i]) + merged_masks.append(merged_mask) + results = results + (merged_masks, ) + return results diff --git a/mmdet/models/roi_heads/double_roi_head.py b/mmdet/models/roi_heads/double_roi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..f9464ff55bafcca9f3545a3a72dde1eb3939cece --- /dev/null +++ b/mmdet/models/roi_heads/double_roi_head.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +from torch import Tensor + +from mmdet.registry import MODELS +from .standard_roi_head import StandardRoIHead + + +@MODELS.register_module() +class DoubleHeadRoIHead(StandardRoIHead): + """RoI head for `Double Head RCNN `_. + + Args: + reg_roi_scale_factor (float): The scale factor to extend the rois + used to extract the regression features. + """ + + def __init__(self, reg_roi_scale_factor: float, **kwargs): + super().__init__(**kwargs) + self.reg_roi_scale_factor = reg_roi_scale_factor + + def _bbox_forward(self, x: Tuple[Tensor], rois: Tensor) -> dict: + """Box head forward function used in both training and testing. + + Args: + x (tuple[Tensor]): List of multi-level img features. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + + Returns: + dict[str, Tensor]: Usually returns a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. + - `bbox_pred` (Tensor): Box energies / deltas. + - `bbox_feats` (Tensor): Extract bbox RoI features. + """ + bbox_cls_feats = self.bbox_roi_extractor( + x[:self.bbox_roi_extractor.num_inputs], rois) + bbox_reg_feats = self.bbox_roi_extractor( + x[:self.bbox_roi_extractor.num_inputs], + rois, + roi_scale_factor=self.reg_roi_scale_factor) + if self.with_shared_head: + bbox_cls_feats = self.shared_head(bbox_cls_feats) + bbox_reg_feats = self.shared_head(bbox_reg_feats) + cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats) + + bbox_results = dict( + cls_score=cls_score, + bbox_pred=bbox_pred, + bbox_feats=bbox_cls_feats) + return bbox_results diff --git a/mmdet/models/roi_heads/dynamic_roi_head.py b/mmdet/models/roi_heads/dynamic_roi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3c7f7bd2f68cab0fcdec725501f74b65274eb30e --- /dev/null +++ b/mmdet/models/roi_heads/dynamic_roi_head.py @@ -0,0 +1,163 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import List, Tuple + +import numpy as np +import torch +from torch import Tensor + +from mmdet.models.losses import SmoothL1Loss +from mmdet.models.task_modules.samplers import SamplingResult +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.structures.bbox import bbox2roi +from mmdet.utils import InstanceList +from ..utils.misc import unpack_gt_instances +from .standard_roi_head import StandardRoIHead + +EPS = 1e-15 + + +@MODELS.register_module() +class DynamicRoIHead(StandardRoIHead): + """RoI head for `Dynamic R-CNN `_.""" + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + assert isinstance(self.bbox_head.loss_bbox, SmoothL1Loss) + # the IoU history of the past `update_iter_interval` iterations + self.iou_history = [] + # the beta history of the past `update_iter_interval` iterations + self.beta_history = [] + + def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList, + batch_data_samples: SampleList) -> dict: + """Forward function for training. + + Args: + x (tuple[Tensor]): List of multi-level img features. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + assert len(rpn_results_list) == len(batch_data_samples) + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, batch_gt_instances_ignore, _ = outputs + + # assign gts and sample proposals + num_imgs = len(batch_data_samples) + sampling_results = [] + cur_iou = [] + for i in range(num_imgs): + # rename rpn_results.bboxes to rpn_results.priors + rpn_results = rpn_results_list[i] + rpn_results.priors = rpn_results.pop('bboxes') + + assign_result = self.bbox_assigner.assign( + rpn_results, batch_gt_instances[i], + batch_gt_instances_ignore[i]) + sampling_result = self.bbox_sampler.sample( + assign_result, + rpn_results, + batch_gt_instances[i], + feats=[lvl_feat[i][None] for lvl_feat in x]) + # record the `iou_topk`-th largest IoU in an image + iou_topk = min(self.train_cfg.dynamic_rcnn.iou_topk, + len(assign_result.max_overlaps)) + ious, _ = torch.topk(assign_result.max_overlaps, iou_topk) + cur_iou.append(ious[-1].item()) + sampling_results.append(sampling_result) + # average the current IoUs over images + cur_iou = np.mean(cur_iou) + self.iou_history.append(cur_iou) + + losses = dict() + # bbox head forward and loss + if self.with_bbox: + bbox_results = self.bbox_loss(x, sampling_results) + losses.update(bbox_results['loss_bbox']) + + # mask head forward and loss + if self.with_mask: + mask_results = self.mask_loss(x, sampling_results, + bbox_results['bbox_feats'], + batch_gt_instances) + losses.update(mask_results['loss_mask']) + + # update IoU threshold and SmoothL1 beta + update_iter_interval = self.train_cfg.dynamic_rcnn.update_iter_interval + if len(self.iou_history) % update_iter_interval == 0: + new_iou_thr, new_beta = self.update_hyperparameters() + + return losses + + def bbox_loss(self, x: Tuple[Tensor], + sampling_results: List[SamplingResult]) -> dict: + """Perform forward propagation and loss calculation of the bbox head on + the features of the upstream network. + + Args: + x (tuple[Tensor]): List of multi-level img features. + sampling_results (list["obj:`SamplingResult`]): Sampling results. 
+ + Returns: + dict[str, Tensor]: Usually returns a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. + - `bbox_pred` (Tensor): Box energies / deltas. + - `bbox_feats` (Tensor): Extract bbox RoI features. + - `loss_bbox` (dict): A dictionary of bbox loss components. + """ + rois = bbox2roi([res.priors for res in sampling_results]) + bbox_results = self._bbox_forward(x, rois) + + bbox_loss_and_target = self.bbox_head.loss_and_target( + cls_score=bbox_results['cls_score'], + bbox_pred=bbox_results['bbox_pred'], + rois=rois, + sampling_results=sampling_results, + rcnn_train_cfg=self.train_cfg) + bbox_results.update(loss_bbox=bbox_loss_and_target['loss_bbox']) + + # record the `beta_topk`-th smallest target + # `bbox_targets[2]` and `bbox_targets[3]` stand for bbox_targets + # and bbox_weights, respectively + bbox_targets = bbox_loss_and_target['bbox_targets'] + pos_inds = bbox_targets[3][:, 0].nonzero().squeeze(1) + num_pos = len(pos_inds) + num_imgs = len(sampling_results) + if num_pos > 0: + cur_target = bbox_targets[2][pos_inds, :2].abs().mean(dim=1) + beta_topk = min(self.train_cfg.dynamic_rcnn.beta_topk * num_imgs, + num_pos) + cur_target = torch.kthvalue(cur_target, beta_topk)[0].item() + self.beta_history.append(cur_target) + + return bbox_results + + def update_hyperparameters(self): + """Update hyperparameters like IoU thresholds for assigner and beta for + SmoothL1 loss based on the training statistics. + + Returns: + tuple[float]: the updated ``iou_thr`` and ``beta``. + """ + new_iou_thr = max(self.train_cfg.dynamic_rcnn.initial_iou, + np.mean(self.iou_history)) + self.iou_history = [] + self.bbox_assigner.pos_iou_thr = new_iou_thr + self.bbox_assigner.neg_iou_thr = new_iou_thr + self.bbox_assigner.min_pos_iou = new_iou_thr + if (not self.beta_history) or (np.median(self.beta_history) < EPS): + # avoid 0 or too small value for new_beta + new_beta = self.bbox_head.loss_bbox.beta + else: + new_beta = min(self.train_cfg.dynamic_rcnn.initial_beta, + np.median(self.beta_history)) + self.beta_history = [] + self.bbox_head.loss_bbox.beta = new_beta + return new_iou_thr, new_beta diff --git a/mmdet/models/roi_heads/grid_roi_head.py b/mmdet/models/roi_heads/grid_roi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..9eda7f01bcd4e44faca14b61ec4956ee2c372ad6 --- /dev/null +++ b/mmdet/models/roi_heads/grid_roi_head.py @@ -0,0 +1,280 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +import torch +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.structures.bbox import bbox2roi +from mmdet.utils import ConfigType, InstanceList +from ..task_modules.samplers import SamplingResult +from ..utils.misc import unpack_gt_instances +from .standard_roi_head import StandardRoIHead + + +@MODELS.register_module() +class GridRoIHead(StandardRoIHead): + """Implementation of `Grid RoI Head `_ + + Args: + grid_roi_extractor (:obj:`ConfigDict` or dict): Config of + roi extractor. 
+        grid_head (:obj:`ConfigDict` or dict): Config of grid head
+    """
+
+    def __init__(self, grid_roi_extractor: ConfigType, grid_head: ConfigType,
+                 **kwargs) -> None:
+        assert grid_head is not None
+        super().__init__(**kwargs)
+        if grid_roi_extractor is not None:
+            self.grid_roi_extractor = MODELS.build(grid_roi_extractor)
+            self.share_roi_extractor = False
+        else:
+            self.share_roi_extractor = True
+            self.grid_roi_extractor = self.bbox_roi_extractor
+        self.grid_head = MODELS.build(grid_head)
+
+    def _random_jitter(self,
+                       sampling_results: List[SamplingResult],
+                       batch_img_metas: List[dict],
+                       amplitude: float = 0.15) -> List[SamplingResult]:
+        """Randomly jitter positive proposals for training.
+
+        Args:
+            sampling_results (List[obj:SamplingResult]): Assign results of
+                all images in a batch after sampling.
+            batch_img_metas (list[dict]): List of image information.
+            amplitude (float): Amplitude of random offset. Defaults to 0.15.
+
+        Returns:
+            list[obj:SamplingResult]: SamplingResults after random jittering.
+        """
+        for sampling_result, img_meta in zip(sampling_results,
+                                             batch_img_metas):
+            bboxes = sampling_result.pos_priors
+            random_offsets = bboxes.new_empty(bboxes.shape[0], 4).uniform_(
+                -amplitude, amplitude)
+            # before jittering
+            cxcy = (bboxes[:, 2:4] + bboxes[:, :2]) / 2
+            wh = (bboxes[:, 2:4] - bboxes[:, :2]).abs()
+            # after jittering
+            new_cxcy = cxcy + wh * random_offsets[:, :2]
+            new_wh = wh * (1 + random_offsets[:, 2:])
+            # xywh to xyxy
+            new_x1y1 = (new_cxcy - new_wh / 2)
+            new_x2y2 = (new_cxcy + new_wh / 2)
+            new_bboxes = torch.cat([new_x1y1, new_x2y2], dim=1)
+            # clip bboxes
+            max_shape = img_meta['img_shape']
+            if max_shape is not None:
+                new_bboxes[:, 0::2].clamp_(min=0, max=max_shape[1] - 1)
+                new_bboxes[:, 1::2].clamp_(min=0, max=max_shape[0] - 1)
+
+            sampling_result.pos_priors = new_bboxes
+        return sampling_results
+
+    # TODO: Forward is incorrect and needs to be refactored.
+    def forward(self,
+                x: Tuple[Tensor],
+                rpn_results_list: InstanceList,
+                batch_data_samples: SampleList = None) -> tuple:
+        """Network forward process. Usually includes backbone, neck and head
+        forward without any post-processing.
+
+        Args:
+            x (Tuple[Tensor]): Multi-level features that may have different
+                resolutions.
+            rpn_results_list (list[:obj:`InstanceData`]): List of region
+                proposals.
+            batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
+                the meta information of each image and corresponding
+                annotations.
+
+        Returns:
+            tuple: A tuple of features from ``bbox_head`` and ``mask_head``
+            forward.
+ """ + results = () + proposals = [rpn_results.bboxes for rpn_results in rpn_results_list] + rois = bbox2roi(proposals) + # bbox head + if self.with_bbox: + bbox_results = self._bbox_forward(x, rois) + results = results + (bbox_results['cls_score'], ) + if self.bbox_head.with_reg: + results = results + (bbox_results['bbox_pred'], ) + + # grid head + grid_rois = rois[:100] + grid_feats = self.grid_roi_extractor( + x[:len(self.grid_roi_extractor.featmap_strides)], grid_rois) + if self.with_shared_head: + grid_feats = self.shared_head(grid_feats) + self.grid_head.test_mode = True + grid_preds = self.grid_head(grid_feats) + results = results + (grid_preds, ) + + # mask head + if self.with_mask: + mask_rois = rois[:100] + mask_results = self._mask_forward(x, mask_rois) + results = results + (mask_results['mask_preds'], ) + return results + + def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList, + batch_data_samples: SampleList, **kwargs) -> dict: + """Perform forward propagation and loss calculation of the detection + roi on the features of the upstream network. + + Args: + x (tuple[Tensor]): List of multi-level img features. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict[str, Tensor]: A dictionary of loss components + """ + assert len(rpn_results_list) == len(batch_data_samples) + outputs = unpack_gt_instances(batch_data_samples) + (batch_gt_instances, batch_gt_instances_ignore, + batch_img_metas) = outputs + + # assign gts and sample proposals + num_imgs = len(batch_data_samples) + sampling_results = [] + for i in range(num_imgs): + # rename rpn_results.bboxes to rpn_results.priors + rpn_results = rpn_results_list[i] + rpn_results.priors = rpn_results.pop('bboxes') + + assign_result = self.bbox_assigner.assign( + rpn_results, batch_gt_instances[i], + batch_gt_instances_ignore[i]) + sampling_result = self.bbox_sampler.sample( + assign_result, + rpn_results, + batch_gt_instances[i], + feats=[lvl_feat[i][None] for lvl_feat in x]) + sampling_results.append(sampling_result) + + losses = dict() + # bbox head loss + if self.with_bbox: + bbox_results = self.bbox_loss(x, sampling_results, batch_img_metas) + losses.update(bbox_results['loss_bbox']) + + # mask head forward and loss + if self.with_mask: + mask_results = self.mask_loss(x, sampling_results, + bbox_results['bbox_feats'], + batch_gt_instances) + losses.update(mask_results['loss_mask']) + + return losses + + def bbox_loss(self, + x: Tuple[Tensor], + sampling_results: List[SamplingResult], + batch_img_metas: Optional[List[dict]] = None) -> dict: + """Perform forward propagation and loss calculation of the bbox head on + the features of the upstream network. + + Args: + x (tuple[Tensor]): List of multi-level img features. + sampling_results (list[:obj:`SamplingResult`]): Sampling results. + batch_img_metas (list[dict], optional): Meta information of each + image, e.g., image size, scaling factor, etc. + + Returns: + dict[str, Tensor]: Usually returns a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. + - `bbox_pred` (Tensor): Box energies / deltas. + - `bbox_feats` (Tensor): Extract bbox RoI features. + - `loss_bbox` (dict): A dictionary of bbox loss components. 
+ """ + assert batch_img_metas is not None + bbox_results = super().bbox_loss(x, sampling_results) + + # Grid head forward and loss + sampling_results = self._random_jitter(sampling_results, + batch_img_metas) + pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results]) + + # GN in head does not support zero shape input + if pos_rois.shape[0] == 0: + return bbox_results + + grid_feats = self.grid_roi_extractor( + x[:self.grid_roi_extractor.num_inputs], pos_rois) + if self.with_shared_head: + grid_feats = self.shared_head(grid_feats) + # Accelerate training + max_sample_num_grid = self.train_cfg.get('max_num_grid', 192) + sample_idx = torch.randperm( + grid_feats.shape[0])[:min(grid_feats.shape[0], max_sample_num_grid + )] + grid_feats = grid_feats[sample_idx] + grid_pred = self.grid_head(grid_feats) + + loss_grid = self.grid_head.loss(grid_pred, sample_idx, + sampling_results, self.train_cfg) + + bbox_results['loss_bbox'].update(loss_grid) + return bbox_results + + def predict_bbox(self, + x: Tuple[Tensor], + batch_img_metas: List[dict], + rpn_results_list: InstanceList, + rcnn_test_cfg: ConfigType, + rescale: bool = False) -> InstanceList: + """Perform forward propagation of the bbox head and predict detection + results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + rcnn_test_cfg (:obj:`ConfigDict`): `test_cfg` of R-CNN. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape \ + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), the last \ + dimension 4 arrange as (x1, y1, x2, y2). + """ + results_list = super().predict_bbox( + x, + batch_img_metas=batch_img_metas, + rpn_results_list=rpn_results_list, + rcnn_test_cfg=rcnn_test_cfg, + rescale=False) + + grid_rois = bbox2roi([res.bboxes for res in results_list]) + if grid_rois.shape[0] != 0: + grid_feats = self.grid_roi_extractor( + x[:len(self.grid_roi_extractor.featmap_strides)], grid_rois) + if self.with_shared_head: + grid_feats = self.shared_head(grid_feats) + self.grid_head.test_mode = True + grid_preds = self.grid_head(grid_feats) + results_list = self.grid_head.predict_by_feat( + grid_preds=grid_preds, + results_list=results_list, + batch_img_metas=batch_img_metas, + rescale=rescale) + + return results_list diff --git a/mmdet/models/roi_heads/htc_roi_head.py b/mmdet/models/roi_heads/htc_roi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0fdd99ddd5ce4d9d42345d1f1d14ecbcae658124 --- /dev/null +++ b/mmdet/models/roi_heads/htc_roi_head.py @@ -0,0 +1,581 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
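+# Illustrative sketch (an assumption added for clarity, not part of the
+# original file): ``HybridTaskCascadeRoIHead`` extends the cascade head with
+# an optional semantic branch; ``semantic_fusion`` picks which branches the
+# pooled semantic feature is added to, ``interleaved`` lets the mask branch
+# see refined boxes, and ``mask_info_flow`` feeds mask features of earlier
+# stages into later ones. All config values below are hypothetical:
+#
+#     roi_head = dict(
+#         type='HybridTaskCascadeRoIHead',
+#         num_stages=3,
+#         stage_loss_weights=[1, 0.5, 0.25],
+#         semantic_roi_extractor=semantic_extractor_cfg,
+#         semantic_head=semantic_head_cfg,
+#         semantic_fusion=('bbox', 'mask'),  # fuse into both branches
+#         interleaved=True,
+#         mask_info_flow=True)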
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from mmdet.models.test_time_augs import merge_aug_masks
+from mmdet.registry import MODELS
+from mmdet.structures import SampleList
+from mmdet.structures.bbox import bbox2roi
+from mmdet.utils import InstanceList, OptConfigType
+from ..layers import adaptive_avg_pool2d
+from ..task_modules.samplers import SamplingResult
+from ..utils import empty_instances, unpack_gt_instances
+from .cascade_roi_head import CascadeRoIHead
+
+
+@MODELS.register_module()
+class HybridTaskCascadeRoIHead(CascadeRoIHead):
+    """Hybrid task cascade roi head including one bbox head and one mask head.
+
+    https://arxiv.org/abs/1901.07518
+
+    Args:
+        num_stages (int): Number of cascade stages.
+        stage_loss_weights (list[float]): Loss weight for every stage.
+        semantic_roi_extractor (:obj:`ConfigDict` or dict, optional):
+            Config of semantic roi extractor. Defaults to None.
+        semantic_head (:obj:`ConfigDict` or dict, optional):
+            Config of semantic head. Defaults to None.
+        semantic_fusion (tuple[str]): Which branches ('bbox', 'mask') the
+            pooled semantic feature is fused with.
+            Defaults to ('bbox', 'mask').
+        interleaved (bool): Whether to interleave the box branch and mask
+            branch. If True, the mask branch can take the refined bounding
+            box predictions. Defaults to True.
+        mask_info_flow (bool): Whether to turn on the mask information flow,
+            which means feeding the mask features of the preceding stage to
+            the current stage. Defaults to True.
+    """
+
+    def __init__(self,
+                 num_stages: int,
+                 stage_loss_weights: List[float],
+                 semantic_roi_extractor: OptConfigType = None,
+                 semantic_head: OptConfigType = None,
+                 semantic_fusion: Tuple[str] = ('bbox', 'mask'),
+                 interleaved: bool = True,
+                 mask_info_flow: bool = True,
+                 **kwargs) -> None:
+        super().__init__(
+            num_stages=num_stages,
+            stage_loss_weights=stage_loss_weights,
+            **kwargs)
+        assert self.with_bbox
+        assert not self.with_shared_head  # shared head is not supported
+
+        if semantic_head is not None:
+            self.semantic_roi_extractor = MODELS.build(semantic_roi_extractor)
+            self.semantic_head = MODELS.build(semantic_head)
+
+        self.semantic_fusion = semantic_fusion
+        self.interleaved = interleaved
+        self.mask_info_flow = mask_info_flow
+
+    # TODO move to base_roi_head later
+    @property
+    def with_semantic(self) -> bool:
+        """bool: whether the head has semantic head"""
+        return hasattr(self,
+                       'semantic_head') and self.semantic_head is not None
+
+    def _bbox_forward(
+            self,
+            stage: int,
+            x: Tuple[Tensor],
+            rois: Tensor,
+            semantic_feat: Optional[Tensor] = None) -> Dict[str, Tensor]:
+        """Box head forward function used in both training and testing.
+
+        Args:
+            stage (int): The current stage in Cascade RoI Head.
+            x (tuple[Tensor]): List of multi-level img features.
+            rois (Tensor): RoIs with the shape (n, 5) where the first
+                column indicates batch id of each RoI.
+            semantic_feat (Tensor, optional): Semantic feature. Defaults to
+                None.
+
+        Returns:
+            dict[str, Tensor]: Usually returns a dictionary with keys:
+
+            - `cls_score` (Tensor): Classification scores.
+            - `bbox_pred` (Tensor): Box energies / deltas.
+            - `bbox_feats` (Tensor): Extract bbox RoI features.
+ """ + bbox_roi_extractor = self.bbox_roi_extractor[stage] + bbox_head = self.bbox_head[stage] + bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs], + rois) + if self.with_semantic and 'bbox' in self.semantic_fusion: + bbox_semantic_feat = self.semantic_roi_extractor([semantic_feat], + rois) + if bbox_semantic_feat.shape[-2:] != bbox_feats.shape[-2:]: + bbox_semantic_feat = adaptive_avg_pool2d( + bbox_semantic_feat, bbox_feats.shape[-2:]) + bbox_feats += bbox_semantic_feat + cls_score, bbox_pred = bbox_head(bbox_feats) + + bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred) + return bbox_results + + def bbox_loss(self, + stage: int, + x: Tuple[Tensor], + sampling_results: List[SamplingResult], + semantic_feat: Optional[Tensor] = None) -> dict: + """Run forward function and calculate loss for box head in training. + + Args: + stage (int): The current stage in Cascade RoI Head. + x (tuple[Tensor]): List of multi-level img features. + sampling_results (list["obj:`SamplingResult`]): Sampling results. + semantic_feat (Tensor, optional): Semantic feature. Defaults to + None. + + Returns: + dict: Usually returns a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. + - `bbox_pred` (Tensor): Box energies / deltas. + - `bbox_feats` (Tensor): Extract bbox RoI features. + - `loss_bbox` (dict): A dictionary of bbox loss components. + - `rois` (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + - `bbox_targets` (tuple): Ground truth for proposals in a + single image. Containing the following list of Tensors: + (labels, label_weights, bbox_targets, bbox_weights) + """ + bbox_head = self.bbox_head[stage] + rois = bbox2roi([res.priors for res in sampling_results]) + bbox_results = self._bbox_forward( + stage, x, rois, semantic_feat=semantic_feat) + bbox_results.update(rois=rois) + + bbox_loss_and_target = bbox_head.loss_and_target( + cls_score=bbox_results['cls_score'], + bbox_pred=bbox_results['bbox_pred'], + rois=rois, + sampling_results=sampling_results, + rcnn_train_cfg=self.train_cfg[stage]) + bbox_results.update(bbox_loss_and_target) + return bbox_results + + def _mask_forward(self, + stage: int, + x: Tuple[Tensor], + rois: Tensor, + semantic_feat: Optional[Tensor] = None, + training: bool = True) -> Dict[str, Tensor]: + """Mask head forward function used only in training. + + Args: + stage (int): The current stage in Cascade RoI Head. + x (tuple[Tensor]): Tuple of multi-level img features. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + semantic_feat (Tensor, optional): Semantic feature. Defaults to + None. + training (bool): Mask Forward is different between training and + testing. If True, use the mask forward in training. + Defaults to True. + + Returns: + dict: Usually returns a dictionary with keys: + + - `mask_preds` (Tensor): Mask prediction. 
+ """ + mask_roi_extractor = self.mask_roi_extractor[stage] + mask_head = self.mask_head[stage] + mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs], + rois) + + # semantic feature fusion + # element-wise sum for original features and pooled semantic features + if self.with_semantic and 'mask' in self.semantic_fusion: + mask_semantic_feat = self.semantic_roi_extractor([semantic_feat], + rois) + if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]: + mask_semantic_feat = F.adaptive_avg_pool2d( + mask_semantic_feat, mask_feats.shape[-2:]) + mask_feats = mask_feats + mask_semantic_feat + + # mask information flow + # forward all previous mask heads to obtain last_feat, and fuse it + # with the normal mask feature + if training: + if self.mask_info_flow: + last_feat = None + for i in range(stage): + last_feat = self.mask_head[i]( + mask_feats, last_feat, return_logits=False) + mask_preds = mask_head( + mask_feats, last_feat, return_feat=False) + else: + mask_preds = mask_head(mask_feats, return_feat=False) + + mask_results = dict(mask_preds=mask_preds) + else: + aug_masks = [] + last_feat = None + for i in range(self.num_stages): + mask_head = self.mask_head[i] + if self.mask_info_flow: + mask_preds, last_feat = mask_head(mask_feats, last_feat) + else: + mask_preds = mask_head(mask_feats) + aug_masks.append(mask_preds) + + mask_results = dict(mask_preds=aug_masks) + + return mask_results + + def mask_loss(self, + stage: int, + x: Tuple[Tensor], + sampling_results: List[SamplingResult], + batch_gt_instances: InstanceList, + semantic_feat: Optional[Tensor] = None) -> dict: + """Run forward function and calculate loss for mask head in training. + + Args: + stage (int): The current stage in Cascade RoI Head. + x (tuple[Tensor]): Tuple of multi-level img features. + sampling_results (list["obj:`SamplingResult`]): Sampling results. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + semantic_feat (Tensor, optional): Semantic feature. Defaults to + None. + + Returns: + dict: Usually returns a dictionary with keys: + + - `mask_preds` (Tensor): Mask prediction. + - `loss_mask` (dict): A dictionary of mask loss components. + """ + pos_rois = bbox2roi([res.pos_priors for res in sampling_results]) + mask_results = self._mask_forward( + stage=stage, + x=x, + rois=pos_rois, + semantic_feat=semantic_feat, + training=True) + + mask_head = self.mask_head[stage] + mask_loss_and_target = mask_head.loss_and_target( + mask_preds=mask_results['mask_preds'], + sampling_results=sampling_results, + batch_gt_instances=batch_gt_instances, + rcnn_train_cfg=self.train_cfg[stage]) + mask_results.update(mask_loss_and_target) + + return mask_results + + def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList, + batch_data_samples: SampleList) -> dict: + """Perform forward propagation and loss calculation of the detection + roi on the features of the upstream network. + + Args: + x (tuple[Tensor]): List of multi-level img features. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. 
+ + Returns: + dict[str, Tensor]: A dictionary of loss components + """ + assert len(rpn_results_list) == len(batch_data_samples) + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \ + = outputs + + # semantic segmentation part + # 2 outputs: segmentation prediction and embedded features + losses = dict() + if self.with_semantic: + gt_semantic_segs = [ + data_sample.gt_sem_seg.sem_seg + for data_sample in batch_data_samples + ] + gt_semantic_segs = torch.stack(gt_semantic_segs) + semantic_pred, semantic_feat = self.semantic_head(x) + loss_seg = self.semantic_head.loss(semantic_pred, gt_semantic_segs) + losses['loss_semantic_seg'] = loss_seg + else: + semantic_feat = None + + results_list = rpn_results_list + num_imgs = len(batch_img_metas) + for stage in range(self.num_stages): + self.current_stage = stage + + stage_loss_weight = self.stage_loss_weights[stage] + + # assign gts and sample proposals + sampling_results = [] + bbox_assigner = self.bbox_assigner[stage] + bbox_sampler = self.bbox_sampler[stage] + for i in range(num_imgs): + results = results_list[i] + # rename rpn_results.bboxes to rpn_results.priors + if 'bboxes' in results: + results.priors = results.pop('bboxes') + + assign_result = bbox_assigner.assign( + results, batch_gt_instances[i], + batch_gt_instances_ignore[i]) + sampling_result = bbox_sampler.sample( + assign_result, + results, + batch_gt_instances[i], + feats=[lvl_feat[i][None] for lvl_feat in x]) + sampling_results.append(sampling_result) + + # bbox head forward and loss + bbox_results = self.bbox_loss( + stage=stage, + x=x, + sampling_results=sampling_results, + semantic_feat=semantic_feat) + + for name, value in bbox_results['loss_bbox'].items(): + losses[f's{stage}.{name}'] = ( + value * stage_loss_weight if 'loss' in name else value) + + # mask head forward and loss + if self.with_mask: + # interleaved execution: use regressed bboxes by the box branch + # to train the mask branch + if self.interleaved: + bbox_head = self.bbox_head[stage] + with torch.no_grad(): + results_list = bbox_head.refine_bboxes( + sampling_results, bbox_results, batch_img_metas) + # re-assign and sample 512 RoIs from 512 RoIs + sampling_results = [] + for i in range(num_imgs): + results = results_list[i] + # rename rpn_results.bboxes to rpn_results.priors + results.priors = results.pop('bboxes') + assign_result = bbox_assigner.assign( + results, batch_gt_instances[i], + batch_gt_instances_ignore[i]) + sampling_result = bbox_sampler.sample( + assign_result, + results, + batch_gt_instances[i], + feats=[lvl_feat[i][None] for lvl_feat in x]) + sampling_results.append(sampling_result) + mask_results = self.mask_loss( + stage=stage, + x=x, + sampling_results=sampling_results, + batch_gt_instances=batch_gt_instances, + semantic_feat=semantic_feat) + for name, value in mask_results['loss_mask'].items(): + losses[f's{stage}.{name}'] = ( + value * stage_loss_weight if 'loss' in name else value) + + # refine bboxes (same as Cascade R-CNN) + if stage < self.num_stages - 1 and not self.interleaved: + bbox_head = self.bbox_head[stage] + with torch.no_grad(): + results_list = bbox_head.refine_bboxes( + sampling_results=sampling_results, + bbox_results=bbox_results, + batch_img_metas=batch_img_metas) + + return losses + + def predict(self, + x: Tuple[Tensor], + rpn_results_list: InstanceList, + batch_data_samples: SampleList, + rescale: bool = False) -> InstanceList: + """Perform forward propagation of the roi head and predict detection + 
results on the features of the upstream network.
+
+        Args:
+            x (tuple[Tensor]): Features from upstream network. Each
+                has shape (N, C, H, W).
+            rpn_results_list (list[:obj:`InstanceData`]): List of region
+                proposals.
+            batch_data_samples (List[:obj:`DetDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+            rescale (bool): Whether to rescale the results to
+                the original image. Defaults to False.
+
+        Returns:
+            list[:obj:`InstanceData`]: Detection results of each image.
+            Each item usually contains following keys.
+
+            - scores (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 4),
+              the last dimension 4 arranged as (x1, y1, x2, y2).
+            - masks (Tensor): Has a shape (num_instances, H, W).
+        """
+        assert self.with_bbox, 'Bbox head must be implemented.'
+        batch_img_metas = [
+            data_samples.metainfo for data_samples in batch_data_samples
+        ]
+
+        if self.with_semantic:
+            _, semantic_feat = self.semantic_head(x)
+        else:
+            semantic_feat = None
+
+        # TODO: the nms_op in mmcv needs to be enhanced; the bbox results
+        # may differ when rescale is not applied in bbox_head
+
+        # If it has the mask branch, the bbox branch does not need
+        # to be scaled to the original image scale, because the mask
+        # branch will scale both bbox and mask at the same time.
+        bbox_rescale = rescale if not self.with_mask else False
+        results_list = self.predict_bbox(
+            x=x,
+            semantic_feat=semantic_feat,
+            batch_img_metas=batch_img_metas,
+            rpn_results_list=rpn_results_list,
+            rcnn_test_cfg=self.test_cfg,
+            rescale=bbox_rescale)
+
+        if self.with_mask:
+            results_list = self.predict_mask(
+                x=x,
+                semantic_feat=semantic_feat,
+                batch_img_metas=batch_img_metas,
+                results_list=results_list,
+                rescale=rescale)
+
+        return results_list
+
+    def predict_mask(self,
+                     x: Tuple[Tensor],
+                     semantic_feat: Tensor,
+                     batch_img_metas: List[dict],
+                     results_list: InstanceList,
+                     rescale: bool = False) -> InstanceList:
+        """Perform forward propagation of the mask head and predict detection
+        results on the features of the upstream network.
+
+        Args:
+            x (tuple[Tensor]): Feature maps of all scale levels.
+            semantic_feat (Tensor): Semantic feature.
+            batch_img_metas (list[dict]): List of image information.
+            results_list (list[:obj:`InstanceData`]): Detection results of
+                each image.
+            rescale (bool): If True, return boxes in original image space.
+                Defaults to False.
+
+        Returns:
+            list[:obj:`InstanceData`]: Detection results of each image
+            after the post process.
+            Each item usually contains following keys.
+
+            - scores (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 4),
+              the last dimension 4 arranged as (x1, y1, x2, y2).
+            - masks (Tensor): Has a shape (num_instances, H, W).
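+
+        Example:
+            >>> # A sketch of the stage-averaging this method relies on,
+            >>> # with stand-in sigmoid maps from three cascade stages
+            >>> # (``merge_aug_masks`` additionally honours any test-time
+            >>> # flips recorded in the image metas):
+            >>> import torch
+            >>> stage_masks = [torch.rand(2, 1, 28, 28) for _ in range(3)]
+            >>> merged = torch.stack(stage_masks).mean(dim=0)
+            >>> merged.shape
+            torch.Size([2, 1, 28, 28])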
+ """ + num_imgs = len(batch_img_metas) + bboxes = [res.bboxes for res in results_list] + mask_rois = bbox2roi(bboxes) + if mask_rois.shape[0] == 0: + results_list = empty_instances( + batch_img_metas=batch_img_metas, + device=mask_rois.device, + task_type='mask', + instance_results=results_list, + mask_thr_binary=self.test_cfg.mask_thr_binary) + return results_list + + num_mask_rois_per_img = [len(res) for res in results_list] + mask_results = self._mask_forward( + stage=-1, + x=x, + rois=mask_rois, + semantic_feat=semantic_heat, + training=False) + # split batch mask prediction back to each image + aug_masks = [[ + mask.sigmoid().detach() + for mask in mask_preds.split(num_mask_rois_per_img, 0) + ] for mask_preds in mask_results['mask_preds']] + + merged_masks = [] + for i in range(num_imgs): + aug_mask = [mask[i] for mask in aug_masks] + merged_mask = merge_aug_masks(aug_mask, batch_img_metas[i]) + merged_masks.append(merged_mask) + + results_list = self.mask_head[-1].predict_by_feat( + mask_preds=merged_masks, + results_list=results_list, + batch_img_metas=batch_img_metas, + rcnn_test_cfg=self.test_cfg, + rescale=rescale, + activate_map=True) + + return results_list + + def forward(self, x: Tuple[Tensor], rpn_results_list: InstanceList, + batch_data_samples: SampleList) -> tuple: + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. + + Args: + x (List[Tensor]): Multi-level features that may have different + resolutions. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + batch_data_samples (list[:obj:`DetDataSample`]): Each item contains + the meta information of each image and corresponding + annotations. + + Returns + tuple: A tuple of features from ``bbox_head`` and ``mask_head`` + forward. + """ + results = () + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + num_imgs = len(batch_img_metas) + + if self.with_semantic: + _, semantic_feat = self.semantic_head(x) + else: + semantic_feat = None + + proposals = [rpn_results.bboxes for rpn_results in rpn_results_list] + num_proposals_per_img = tuple(len(p) for p in proposals) + rois = bbox2roi(proposals) + # bbox head + if self.with_bbox: + rois, cls_scores, bbox_preds = self._refine_roi( + x=x, + rois=rois, + semantic_feat=semantic_feat, + batch_img_metas=batch_img_metas, + num_proposals_per_img=num_proposals_per_img) + results = results + (cls_scores, bbox_preds) + # mask head + if self.with_mask: + rois = torch.cat(rois) + mask_results = self._mask_forward( + stage=-1, + x=x, + rois=rois, + semantic_feat=semantic_feat, + training=False) + aug_masks = [[ + mask.sigmoid().detach() + for mask in mask_preds.split(num_proposals_per_img, 0) + ] for mask_preds in mask_results['mask_preds']] + + merged_masks = [] + for i in range(num_imgs): + aug_mask = [mask[i] for mask in aug_masks] + merged_mask = merge_aug_masks(aug_mask, batch_img_metas[i]) + merged_masks.append(merged_mask) + results = results + (merged_masks, ) + return results diff --git a/mmdet/models/roi_heads/mask_heads/__init__.py b/mmdet/models/roi_heads/mask_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..48a5d4227be41b8985403251e1803f78cf500636 --- /dev/null +++ b/mmdet/models/roi_heads/mask_heads/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .coarse_mask_head import CoarseMaskHead +from .dynamic_mask_head import DynamicMaskHead +from .fcn_mask_head import FCNMaskHead +from .feature_relay_head import FeatureRelayHead +from .fused_semantic_head import FusedSemanticHead +from .global_context_head import GlobalContextHead +from .grid_head import GridHead +from .htc_mask_head import HTCMaskHead +from .mask_point_head import MaskPointHead +from .maskiou_head import MaskIoUHead +from .scnet_mask_head import SCNetMaskHead +from .scnet_semantic_head import SCNetSemanticHead + +__all__ = [ + 'FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead', 'GridHead', + 'MaskIoUHead', 'CoarseMaskHead', 'MaskPointHead', 'SCNetMaskHead', + 'SCNetSemanticHead', 'GlobalContextHead', 'FeatureRelayHead', + 'DynamicMaskHead' +] diff --git a/mmdet/models/roi_heads/mask_heads/coarse_mask_head.py b/mmdet/models/roi_heads/mask_heads/coarse_mask_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1caa901228f2439492b82d1890eba468963eb28d --- /dev/null +++ b/mmdet/models/roi_heads/mask_heads/coarse_mask_head.py @@ -0,0 +1,110 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import ConvModule, Linear +from mmengine.model import ModuleList +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import MultiConfig +from .fcn_mask_head import FCNMaskHead + + +@MODELS.register_module() +class CoarseMaskHead(FCNMaskHead): + """Coarse mask head used in PointRend. + + Compared with standard ``FCNMaskHead``, ``CoarseMaskHead`` will downsample + the input feature map instead of upsample it. + + Args: + num_convs (int): Number of conv layers in the head. Defaults to 0. + num_fcs (int): Number of fc layers in the head. Defaults to 2. + fc_out_channels (int): Number of output channels of fc layer. + Defaults to 1024. + downsample_factor (int): The factor that feature map is downsampled by. + Defaults to 2. + init_cfg (dict or list[dict], optional): Initialization config dict. 
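+
+    Example:
+        >>> # Shape arithmetic for the default config (a sketch, assuming
+        >>> # roi_feat_size=14 and downsample_factor=2):
+        >>> roi_feat_size, downsample_factor, num_classes = 14, 2, 80
+        >>> output_size = (roi_feat_size // downsample_factor,) * 2
+        >>> output_size
+        (7, 7)
+        >>> # fc_logits emits num_classes * 7 * 7 values per RoI, later
+        >>> # reshaped to (N, num_classes, 7, 7).
+        >>> num_classes * output_size[0] * output_size[1]
+        3920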
+ """ + + def __init__(self, + num_convs: int = 0, + num_fcs: int = 2, + fc_out_channels: int = 1024, + downsample_factor: int = 2, + init_cfg: MultiConfig = dict( + type='Xavier', + override=[ + dict(name='fcs'), + dict(type='Constant', val=0.001, name='fc_logits') + ]), + *arg, + **kwarg) -> None: + super().__init__( + *arg, + num_convs=num_convs, + upsample_cfg=dict(type=None), + init_cfg=None, + **kwarg) + self.init_cfg = init_cfg + self.num_fcs = num_fcs + assert self.num_fcs > 0 + self.fc_out_channels = fc_out_channels + self.downsample_factor = downsample_factor + assert self.downsample_factor >= 1 + # remove conv_logit + delattr(self, 'conv_logits') + + if downsample_factor > 1: + downsample_in_channels = ( + self.conv_out_channels + if self.num_convs > 0 else self.in_channels) + self.downsample_conv = ConvModule( + downsample_in_channels, + self.conv_out_channels, + kernel_size=downsample_factor, + stride=downsample_factor, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) + else: + self.downsample_conv = None + + self.output_size = (self.roi_feat_size[0] // downsample_factor, + self.roi_feat_size[1] // downsample_factor) + self.output_area = self.output_size[0] * self.output_size[1] + + last_layer_dim = self.conv_out_channels * self.output_area + + self.fcs = ModuleList() + for i in range(num_fcs): + fc_in_channels = ( + last_layer_dim if i == 0 else self.fc_out_channels) + self.fcs.append(Linear(fc_in_channels, self.fc_out_channels)) + last_layer_dim = self.fc_out_channels + output_channels = self.num_classes * self.output_area + self.fc_logits = Linear(last_layer_dim, output_channels) + + def init_weights(self) -> None: + """Initialize weights.""" + super(FCNMaskHead, self).init_weights() + + def forward(self, x: Tensor) -> Tensor: + """Forward features from the upstream network. + + Args: + x (Tensor): Extract mask RoI features. + + Returns: + Tensor: Predicted foreground masks. + """ + for conv in self.convs: + x = conv(x) + + if self.downsample_conv is not None: + x = self.downsample_conv(x) + + x = x.flatten(1) + for fc in self.fcs: + x = self.relu(fc(x)) + mask_preds = self.fc_logits(x).view( + x.size(0), self.num_classes, *self.output_size) + return mask_preds diff --git a/mmdet/models/roi_heads/mask_heads/dynamic_mask_head.py b/mmdet/models/roi_heads/mask_heads/dynamic_mask_head.py new file mode 100644 index 0000000000000000000000000000000000000000..f33612b1b141668d0463435975c14a26fbe5a0cd --- /dev/null +++ b/mmdet/models/roi_heads/mask_heads/dynamic_mask_head.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn as nn +from mmengine.config import ConfigDict +from torch import Tensor + +from mmdet.models.task_modules import SamplingResult +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, InstanceList, OptConfigType, reduce_mean +from .fcn_mask_head import FCNMaskHead + + +@MODELS.register_module() +class DynamicMaskHead(FCNMaskHead): + r"""Dynamic Mask Head for + `Instances as Queries `_ + + Args: + num_convs (int): Number of convolution layer. + Defaults to 4. + roi_feat_size (int): The output size of RoI extractor, + Defaults to 14. + in_channels (int): Input feature channels. + Defaults to 256. + conv_kernel_size (int): Kernel size of convolution layers. + Defaults to 3. + conv_out_channels (int): Output channels of convolution layers. + Defaults to 256. + num_classes (int): Number of classes. 
+ Defaults to 80 + class_agnostic (int): Whether generate class agnostic prediction. + Defaults to False. + dropout (float): Probability of drop the channel. + Defaults to 0.0 + upsample_cfg (:obj:`ConfigDict` or dict): The config for + upsample layer. + conv_cfg (:obj:`ConfigDict` or dict, optional): The convolution + layer config. + norm_cfg (:obj:`ConfigDict` or dict, optional): The norm layer config. + dynamic_conv_cfg (:obj:`ConfigDict` or dict): The dynamic convolution + layer config. + loss_mask (:obj:`ConfigDict` or dict): The config for mask loss. + """ + + def __init__(self, + num_convs: int = 4, + roi_feat_size: int = 14, + in_channels: int = 256, + conv_kernel_size: int = 3, + conv_out_channels: int = 256, + num_classes: int = 80, + class_agnostic: bool = False, + upsample_cfg: ConfigType = dict( + type='deconv', scale_factor=2), + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + dynamic_conv_cfg: ConfigType = dict( + type='DynamicConv', + in_channels=256, + feat_channels=64, + out_channels=256, + input_feat_shape=14, + with_proj=False, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN')), + loss_mask: ConfigType = dict( + type='DiceLoss', loss_weight=8.0), + **kwargs) -> None: + super().__init__( + num_convs=num_convs, + roi_feat_size=roi_feat_size, + in_channels=in_channels, + conv_kernel_size=conv_kernel_size, + conv_out_channels=conv_out_channels, + num_classes=num_classes, + class_agnostic=class_agnostic, + upsample_cfg=upsample_cfg, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + loss_mask=loss_mask, + **kwargs) + assert class_agnostic is False, \ + 'DynamicMaskHead only support class_agnostic=False' + self.fp16_enabled = False + + self.instance_interactive_conv = MODELS.build(dynamic_conv_cfg) + + def init_weights(self) -> None: + """Use xavier initialization for all weight parameter and set + classification head bias as a specific value when use focal loss.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + nn.init.constant_(self.conv_logits.bias, 0.) + + def forward(self, roi_feat: Tensor, proposal_feat: Tensor) -> Tensor: + """Forward function of DynamicMaskHead. + + Args: + roi_feat (Tensor): Roi-pooling features with shape + (batch_size*num_proposals, feature_dimensions, + pooling_h , pooling_w). + proposal_feat (Tensor): Intermediate feature get from + diihead in last stage, has shape + (batch_size*num_proposals, feature_dimensions) + + Returns: + mask_preds (Tensor): Predicted foreground masks with shape + (batch_size*num_proposals, num_classes, pooling_h*2, pooling_w*2). + """ + + proposal_feat = proposal_feat.reshape(-1, self.in_channels) + proposal_feat_iic = self.instance_interactive_conv( + proposal_feat, roi_feat) + + x = proposal_feat_iic.permute(0, 2, 1).reshape(roi_feat.size()) + + for conv in self.convs: + x = conv(x) + if self.upsample is not None: + x = self.upsample(x) + if self.upsample_method == 'deconv': + x = self.relu(x) + mask_preds = self.conv_logits(x) + return mask_preds + + def loss_and_target(self, mask_preds: Tensor, + sampling_results: List[SamplingResult], + batch_gt_instances: InstanceList, + rcnn_train_cfg: ConfigDict) -> dict: + """Calculate the loss based on the features extracted by the mask head. + + Args: + mask_preds (Tensor): Predicted foreground masks, has shape + (num_pos, num_classes, h, w). + sampling_results (List[obj:SamplingResult]): Assign results of + all images in a batch after sampling. 
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. + + Returns: + dict: A dictionary of loss and targets components. + """ + mask_targets = self.get_targets( + sampling_results=sampling_results, + batch_gt_instances=batch_gt_instances, + rcnn_train_cfg=rcnn_train_cfg) + pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results]) + + num_pos = pos_labels.new_ones(pos_labels.size()).float().sum() + avg_factor = torch.clamp(reduce_mean(num_pos), min=1.).item() + loss = dict() + if mask_preds.size(0) == 0: + loss_mask = mask_preds.sum() + else: + loss_mask = self.loss_mask( + mask_preds[torch.arange(num_pos).long(), pos_labels, + ...].sigmoid(), + mask_targets, + avg_factor=avg_factor) + loss['loss_mask'] = loss_mask + return dict(loss_mask=loss, mask_targets=mask_targets) diff --git a/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3a089dfafcb69784f2fc266f0945e6d56b0466d3 --- /dev/null +++ b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py @@ -0,0 +1,474 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_conv_layer, build_upsample_layer +from mmcv.ops.carafe import CARAFEPack +from mmengine.config import ConfigDict +from mmengine.model import BaseModule, ModuleList +from mmengine.structures import InstanceData +from torch import Tensor +from torch.nn.modules.utils import _pair + +from mmdet.models.task_modules.samplers import SamplingResult +from mmdet.models.utils import empty_instances +from mmdet.registry import MODELS +from mmdet.structures.mask import mask_target +from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig + +BYTES_PER_FLOAT = 4 +# TODO: This memory limit may be too much or too little. It would be better to +# determine it based on available resources. 
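+# For a sense of scale: pasting 100 masks onto a 2000x3000 image needs
+# 100 * 2000 * 3000 * 4 bytes (~2.4 GB) of float32 scratch space, so a
+# 1 GB limit splits that work into 3 chunks.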
+GPU_MEM_LIMIT = 1024**3 # 1 GB memory limit + + +@MODELS.register_module() +class FCNMaskHead(BaseModule): + + def __init__(self, + num_convs: int = 4, + roi_feat_size: int = 14, + in_channels: int = 256, + conv_kernel_size: int = 3, + conv_out_channels: int = 256, + num_classes: int = 80, + class_agnostic: int = False, + upsample_cfg: ConfigType = dict( + type='deconv', scale_factor=2), + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + predictor_cfg: ConfigType = dict(type='Conv'), + loss_mask: ConfigType = dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0), + init_cfg: OptMultiConfig = None) -> None: + assert init_cfg is None, 'To prevent abnormal initialization ' \ + 'behavior, init_cfg is not allowed to be set' + super().__init__(init_cfg=init_cfg) + self.upsample_cfg = upsample_cfg.copy() + if self.upsample_cfg['type'] not in [ + None, 'deconv', 'nearest', 'bilinear', 'carafe' + ]: + raise ValueError( + f'Invalid upsample method {self.upsample_cfg["type"]}, ' + 'accepted methods are "deconv", "nearest", "bilinear", ' + '"carafe"') + self.num_convs = num_convs + # WARN: roi_feat_size is reserved and not used + self.roi_feat_size = _pair(roi_feat_size) + self.in_channels = in_channels + self.conv_kernel_size = conv_kernel_size + self.conv_out_channels = conv_out_channels + self.upsample_method = self.upsample_cfg.get('type') + self.scale_factor = self.upsample_cfg.pop('scale_factor', None) + self.num_classes = num_classes + self.class_agnostic = class_agnostic + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.predictor_cfg = predictor_cfg + self.loss_mask = MODELS.build(loss_mask) + + self.convs = ModuleList() + for i in range(self.num_convs): + in_channels = ( + self.in_channels if i == 0 else self.conv_out_channels) + padding = (self.conv_kernel_size - 1) // 2 + self.convs.append( + ConvModule( + in_channels, + self.conv_out_channels, + self.conv_kernel_size, + padding=padding, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg)) + upsample_in_channels = ( + self.conv_out_channels if self.num_convs > 0 else in_channels) + upsample_cfg_ = self.upsample_cfg.copy() + if self.upsample_method is None: + self.upsample = None + elif self.upsample_method == 'deconv': + upsample_cfg_.update( + in_channels=upsample_in_channels, + out_channels=self.conv_out_channels, + kernel_size=self.scale_factor, + stride=self.scale_factor) + self.upsample = build_upsample_layer(upsample_cfg_) + elif self.upsample_method == 'carafe': + upsample_cfg_.update( + channels=upsample_in_channels, scale_factor=self.scale_factor) + self.upsample = build_upsample_layer(upsample_cfg_) + else: + # suppress warnings + align_corners = (None + if self.upsample_method == 'nearest' else False) + upsample_cfg_.update( + scale_factor=self.scale_factor, + mode=self.upsample_method, + align_corners=align_corners) + self.upsample = build_upsample_layer(upsample_cfg_) + + out_channels = 1 if self.class_agnostic else self.num_classes + logits_in_channel = ( + self.conv_out_channels + if self.upsample_method == 'deconv' else upsample_in_channels) + self.conv_logits = build_conv_layer(self.predictor_cfg, + logits_in_channel, out_channels, 1) + self.relu = nn.ReLU(inplace=True) + self.debug_imgs = None + + def init_weights(self) -> None: + """Initialize the weights.""" + super().init_weights() + for m in [self.upsample, self.conv_logits]: + if m is None: + continue + elif isinstance(m, CARAFEPack): + m.init_weights() + elif hasattr(m, 'weight') and hasattr(m, 'bias'): + nn.init.kaiming_normal_( + 
m.weight, mode='fan_out', nonlinearity='relu') + nn.init.constant_(m.bias, 0) + + def forward(self, x: Tensor) -> Tensor: + """Forward features from the upstream network. + + Args: + x (Tensor): Extract mask RoI features. + + Returns: + Tensor: Predicted foreground masks. + """ + for conv in self.convs: + x = conv(x) + if self.upsample is not None: + x = self.upsample(x) + if self.upsample_method == 'deconv': + x = self.relu(x) + mask_preds = self.conv_logits(x) + return mask_preds + + def get_targets(self, sampling_results: List[SamplingResult], + batch_gt_instances: InstanceList, + rcnn_train_cfg: ConfigDict) -> Tensor: + """Calculate the ground truth for all samples in a batch according to + the sampling_results. + + Args: + sampling_results (List[obj:SamplingResult]): Assign results of + all images in a batch after sampling. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. + + Returns: + Tensor: Mask target of each positive proposals in the image. + """ + pos_proposals = [res.pos_priors for res in sampling_results] + pos_assigned_gt_inds = [ + res.pos_assigned_gt_inds for res in sampling_results + ] + gt_masks = [res.masks for res in batch_gt_instances] + mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds, + gt_masks, rcnn_train_cfg) + return mask_targets + + def loss_and_target(self, mask_preds: Tensor, + sampling_results: List[SamplingResult], + batch_gt_instances: InstanceList, + rcnn_train_cfg: ConfigDict) -> dict: + """Calculate the loss based on the features extracted by the mask head. + + Args: + mask_preds (Tensor): Predicted foreground masks, has shape + (num_pos, num_classes, h, w). + sampling_results (List[obj:SamplingResult]): Assign results of + all images in a batch after sampling. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. + + Returns: + dict: A dictionary of loss and targets components. + """ + mask_targets = self.get_targets( + sampling_results=sampling_results, + batch_gt_instances=batch_gt_instances, + rcnn_train_cfg=rcnn_train_cfg) + + pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results]) + + loss = dict() + if mask_preds.size(0) == 0: + loss_mask = mask_preds.sum() + else: + if self.class_agnostic: + loss_mask = self.loss_mask(mask_preds, mask_targets, + torch.zeros_like(pos_labels)) + else: + loss_mask = self.loss_mask(mask_preds, mask_targets, + pos_labels) + loss['loss_mask'] = loss_mask + # TODO: which algorithm requires mask_targets? + return dict(loss_mask=loss, mask_targets=mask_targets) + + def predict_by_feat(self, + mask_preds: Tuple[Tensor], + results_list: List[InstanceData], + batch_img_metas: List[dict], + rcnn_test_cfg: ConfigDict, + rescale: bool = False, + activate_map: bool = False) -> InstanceList: + """Transform a batch of output features extracted from the head into + mask results. + + Args: + mask_preds (tuple[Tensor]): Tuple of predicted foreground masks, + each has shape (n, num_classes, h, w). + results_list (list[:obj:`InstanceData`]): Detection results of + each image. + batch_img_metas (list[dict]): List of image information. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. + rescale (bool): If True, return boxes in original image space. + Defaults to False. 
+            activate_map (bool): Whether to get results with augmentation
+                testing. If True, the `mask_preds` will not be processed
+                with sigmoid. Defaults to False.
+
+        Returns:
+            list[:obj:`InstanceData`]: Detection results of each image
+            after the post process. Each item usually contains following keys.
+
+            - scores (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 4),
+              the last dimension 4 arranged as (x1, y1, x2, y2).
+            - masks (Tensor): Has a shape (num_instances, H, W).
+        """
+        assert len(mask_preds) == len(results_list) == len(batch_img_metas)
+
+        for img_id in range(len(batch_img_metas)):
+            img_meta = batch_img_metas[img_id]
+            results = results_list[img_id]
+            bboxes = results.bboxes
+            if bboxes.shape[0] == 0:
+                results_list[img_id] = empty_instances(
+                    [img_meta],
+                    bboxes.device,
+                    task_type='mask',
+                    instance_results=[results],
+                    mask_thr_binary=rcnn_test_cfg.mask_thr_binary)[0]
+            else:
+                im_mask = self._predict_by_feat_single(
+                    mask_preds=mask_preds[img_id],
+                    bboxes=bboxes,
+                    labels=results.labels,
+                    img_meta=img_meta,
+                    rcnn_test_cfg=rcnn_test_cfg,
+                    rescale=rescale,
+                    activate_map=activate_map)
+                results.masks = im_mask
+        return results_list
+
+    def _predict_by_feat_single(self,
+                                mask_preds: Tensor,
+                                bboxes: Tensor,
+                                labels: Tensor,
+                                img_meta: dict,
+                                rcnn_test_cfg: ConfigDict,
+                                rescale: bool = False,
+                                activate_map: bool = False) -> Tensor:
+        """Get segmentation masks from mask_preds and bboxes.
+
+        Args:
+            mask_preds (Tensor): Predicted foreground masks, has shape
+                (n, num_classes, h, w).
+            bboxes (Tensor): Predicted bboxes, has shape (n, 4)
+            labels (Tensor): Labels of bboxes, has shape (n, )
+            img_meta (dict): Image information.
+            rcnn_test_cfg (:obj:`ConfigDict`): `test_cfg` of Bbox Head.
+                Defaults to None.
+            rescale (bool): If True, return boxes in original image space.
+                Defaults to False.
+            activate_map (bool): Whether to get results with augmentation
+                testing. If True, the `mask_preds` will not be processed
+                with sigmoid. Defaults to False.
+
+        Returns:
+            Tensor: Encoded masks, has shape (n, img_h, img_w)
+
+        Example:
+            >>> from mmengine.config import Config
+            >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import *  # NOQA
+            >>> N = 7  # N = number of extracted ROIs
+            >>> C, H, W = 11, 32, 32
+            >>> # Create example instance of FCN Mask Head.
+            >>> self = FCNMaskHead(num_classes=C, num_convs=0)
+            >>> inputs = torch.rand(N, self.in_channels, H, W)
+            >>> mask_preds = self.forward(inputs)
+            >>> # Each input is associated with some bounding box
+            >>> bboxes = torch.Tensor([[1, 1, 42, 42 ]] * N)
+            >>> labels = torch.randint(0, C, size=(N,))
+            >>> rcnn_test_cfg = Config({'mask_thr_binary': 0, })
+            >>> ori_shape = (H * 4, W * 4)
+            >>> scale_factor = (1, 1)
+            >>> rescale = False
+            >>> img_meta = {'scale_factor': scale_factor,
+            ...             'ori_shape': ori_shape}
+            >>> # Encoded masks are a list for each category.
+            >>> encoded_masks = self._predict_by_feat_single(
+            ...     mask_preds, bboxes, labels,
+            ...     
img_meta, rcnn_test_cfg, rescale) + >>> assert encoded_masks.size()[0] == N + >>> assert encoded_masks.size()[1:] == ori_shape + """ + scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat( + (1, 2)) + img_h, img_w = img_meta['ori_shape'][:2] + device = bboxes.device + + if not activate_map: + mask_preds = mask_preds.sigmoid() + else: + # In AugTest, has been activated before + mask_preds = bboxes.new_tensor(mask_preds) + + if rescale: # in-placed rescale the bboxes + bboxes /= scale_factor + else: + w_scale, h_scale = scale_factor[0, 0], scale_factor[0, 1] + img_h = np.round(img_h * h_scale.item()).astype(np.int32) + img_w = np.round(img_w * w_scale.item()).astype(np.int32) + + N = len(mask_preds) + # The actual implementation split the input into chunks, + # and paste them chunk by chunk. + if device.type == 'cpu': + # CPU is most efficient when they are pasted one by one with + # skip_empty=True, so that it performs minimal number of + # operations. + num_chunks = N + else: + # GPU benefits from parallelism for larger chunks, + # but may have memory issue + # the types of img_w and img_h are np.int32, + # when the image resolution is large, + # the calculation of num_chunks will overflow. + # so we need to change the types of img_w and img_h to int. + # See https://github.com/open-mmlab/mmdetection/pull/5191 + num_chunks = int( + np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT / + GPU_MEM_LIMIT)) + assert (num_chunks <= + N), 'Default GPU_MEM_LIMIT is too small; try increasing it' + chunks = torch.chunk(torch.arange(N, device=device), num_chunks) + + threshold = rcnn_test_cfg.mask_thr_binary + im_mask = torch.zeros( + N, + img_h, + img_w, + device=device, + dtype=torch.bool if threshold >= 0 else torch.uint8) + + if not self.class_agnostic: + mask_preds = mask_preds[range(N), labels][:, None] + + for inds in chunks: + masks_chunk, spatial_inds = _do_paste_mask( + mask_preds[inds], + bboxes[inds], + img_h, + img_w, + skip_empty=device.type == 'cpu') + + if threshold >= 0: + masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool) + else: + # for visualization and debugging + masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8) + + im_mask[(inds, ) + spatial_inds] = masks_chunk + return im_mask + + +def _do_paste_mask(masks: Tensor, + boxes: Tensor, + img_h: int, + img_w: int, + skip_empty: bool = True) -> tuple: + """Paste instance masks according to boxes. + + This implementation is modified from + https://github.com/facebookresearch/detectron2/ + + Args: + masks (Tensor): N, 1, H, W + boxes (Tensor): N, 4 + img_h (int): Height of the image to be pasted. + img_w (int): Width of the image to be pasted. + skip_empty (bool): Only paste masks within the region that + tightly bound all boxes, and returns the results this region only. + An important optimization for CPU. + + Returns: + tuple: (Tensor, tuple). The first item is mask tensor, the second one + is the slice object. + + If skip_empty == False, the whole image will be pasted. It will + return a mask of shape (N, img_h, img_w) and an empty tuple. + + If skip_empty == True, only area around the mask will be pasted. + A mask of shape (N, h', w') and its start and end coordinates + in the original image will be returned. + """ + # On GPU, paste all masks together (up to chunk size) + # by using the entire image to sample the masks + # Compared to pasting them one by one, + # this has more operations but is faster on COCO-scale dataset. 
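+    # Coordinate convention used below: a pixel centre at integer image
+    # coordinate p inside a box spanning [x0, x1] is mapped to
+    # 2 * (p + 0.5 - x0) / (x1 - x0) - 1, i.e. into the [-1, 1] range
+    # that F.grid_sample expects.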
+ device = masks.device + if skip_empty: + x0_int, y0_int = torch.clamp( + boxes.min(dim=0).values.floor()[:2] - 1, + min=0).to(dtype=torch.int32) + x1_int = torch.clamp( + boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32) + y1_int = torch.clamp( + boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32) + else: + x0_int, y0_int = 0, 0 + x1_int, y1_int = img_w, img_h + x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1 + + N = masks.shape[0] + + img_y = torch.arange(y0_int, y1_int, device=device).to(torch.float32) + 0.5 + img_x = torch.arange(x0_int, x1_int, device=device).to(torch.float32) + 0.5 + img_y = (img_y - y0) / (y1 - y0) * 2 - 1 + img_x = (img_x - x0) / (x1 - x0) * 2 - 1 + # img_x, img_y have shapes (N, w), (N, h) + # IsInf op is not supported with ONNX<=1.7.0 + if not torch.onnx.is_in_onnx_export(): + if torch.isinf(img_x).any(): + inds = torch.where(torch.isinf(img_x)) + img_x[inds] = 0 + if torch.isinf(img_y).any(): + inds = torch.where(torch.isinf(img_y)) + img_y[inds] = 0 + + gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1)) + gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1)) + grid = torch.stack([gx, gy], dim=3) + + img_masks = F.grid_sample( + masks.to(dtype=torch.float32), grid, align_corners=False) + + if skip_empty: + return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int)) + else: + return img_masks[:, 0], () diff --git a/mmdet/models/roi_heads/mask_heads/feature_relay_head.py b/mmdet/models/roi_heads/mask_heads/feature_relay_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0c34561fa5fd749329eda164465ce9787278d357 --- /dev/null +++ b/mmdet/models/roi_heads/mask_heads/feature_relay_head.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch.nn as nn +from mmengine.model import BaseModule +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import MultiConfig + + +@MODELS.register_module() +class FeatureRelayHead(BaseModule): + """Feature Relay Head used in `SCNet `_. + + Args: + in_channels (int): number of input channels. Defaults to 256. + conv_out_channels (int): number of output channels before + classification layer. Defaults to 256. + roi_feat_size (int): roi feat size at box head. Default: 7. + scale_factor (int): scale factor to match roi feat size + at mask head. Defaults to 2. + init_cfg (:obj:`ConfigDict` or dict or list[dict] or + list[:obj:`ConfigDict`]): Initialization config dict. Defaults to + dict(type='Kaiming', layer='Linear'). + """ + + def __init__( + self, + in_channels: int = 1024, + out_conv_channels: int = 256, + roi_feat_size: int = 7, + scale_factor: int = 2, + init_cfg: MultiConfig = dict(type='Kaiming', layer='Linear') + ) -> None: + super().__init__(init_cfg=init_cfg) + assert isinstance(roi_feat_size, int) + + self.in_channels = in_channels + self.out_conv_channels = out_conv_channels + self.roi_feat_size = roi_feat_size + self.out_channels = (roi_feat_size**2) * out_conv_channels + self.scale_factor = scale_factor + self.fp16_enabled = False + + self.fc = nn.Linear(self.in_channels, self.out_channels) + self.upsample = nn.Upsample( + scale_factor=scale_factor, mode='bilinear', align_corners=True) + + def forward(self, x: Tensor) -> Optional[Tensor]: + """Forward function. + + Args: + x (Tensor): Input feature. + + Returns: + Optional[Tensor]: Output feature. When the first dim of input is + 0, None is returned. 
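+
+        Example:
+            >>> # Shape sketch with the default sizes (in_channels=1024,
+            >>> # out_conv_channels=256, roi_feat_size=7, scale_factor=2):
+            >>> import torch
+            >>> self = FeatureRelayHead()
+            >>> out = self.forward(torch.rand(3, 1024))
+            >>> tuple(out.shape)
+            (3, 256, 14, 14)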
+ """ + N, _ = x.shape + if N > 0: + out_C = self.out_conv_channels + out_HW = self.roi_feat_size + x = self.fc(x) + x = x.reshape(N, out_C, out_HW, out_HW) + x = self.upsample(x) + return x + return None diff --git a/mmdet/models/roi_heads/mask_heads/fused_semantic_head.py b/mmdet/models/roi_heads/mask_heads/fused_semantic_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d20beb2975a563f03e7b6b2afcef287cb41af05a --- /dev/null +++ b/mmdet/models/roi_heads/mask_heads/fused_semantic_head.py @@ -0,0 +1,144 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Tuple + +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.config import ConfigDict +from mmengine.model import BaseModule +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import MultiConfig, OptConfigType + + +@MODELS.register_module() +class FusedSemanticHead(BaseModule): + r"""Multi-level fused semantic segmentation head. + + .. code-block:: none + + in_1 -> 1x1 conv --- + | + in_2 -> 1x1 conv -- | + || + in_3 -> 1x1 conv - || + ||| /-> 1x1 conv (mask prediction) + in_4 -> 1x1 conv -----> 3x3 convs (*4) + | \-> 1x1 conv (feature) + in_5 -> 1x1 conv --- + """ # noqa: W605 + + def __init__( + self, + num_ins: int, + fusion_level: int, + seg_scale_factor=1 / 8, + num_convs: int = 4, + in_channels: int = 256, + conv_out_channels: int = 256, + num_classes: int = 183, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + ignore_label: int = None, + loss_weight: float = None, + loss_seg: ConfigDict = dict( + type='CrossEntropyLoss', ignore_index=255, loss_weight=0.2), + init_cfg: MultiConfig = dict( + type='Kaiming', override=dict(name='conv_logits')) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.num_ins = num_ins + self.fusion_level = fusion_level + self.seg_scale_factor = seg_scale_factor + self.num_convs = num_convs + self.in_channels = in_channels + self.conv_out_channels = conv_out_channels + self.num_classes = num_classes + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.fp16_enabled = False + + self.lateral_convs = nn.ModuleList() + for i in range(self.num_ins): + self.lateral_convs.append( + ConvModule( + self.in_channels, + self.in_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=False)) + + self.convs = nn.ModuleList() + for i in range(self.num_convs): + in_channels = self.in_channels if i == 0 else conv_out_channels + self.convs.append( + ConvModule( + in_channels, + conv_out_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + self.conv_embedding = ConvModule( + conv_out_channels, + conv_out_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) + self.conv_logits = nn.Conv2d(conv_out_channels, self.num_classes, 1) + if ignore_label: + loss_seg['ignore_index'] = ignore_label + if loss_weight: + loss_seg['loss_weight'] = loss_weight + if ignore_label or loss_weight: + warnings.warn('``ignore_label`` and ``loss_weight`` would be ' + 'deprecated soon. Please set ``ingore_index`` and ' + '``loss_weight`` in ``loss_seg`` instead.') + self.criterion = MODELS.build(loss_seg) + + def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor]: + """Forward function. + + Args: + feats (tuple[Tensor]): Multi scale feature maps. + + Returns: + tuple[Tensor]: + + - mask_preds (Tensor): Predicted mask logits. + - x (Tensor): Fused feature. 
+ """ + x = self.lateral_convs[self.fusion_level](feats[self.fusion_level]) + fused_size = tuple(x.shape[-2:]) + for i, feat in enumerate(feats): + if i != self.fusion_level: + feat = F.interpolate( + feat, size=fused_size, mode='bilinear', align_corners=True) + # fix runtime error of "+=" inplace operation in PyTorch 1.10 + x = x + self.lateral_convs[i](feat) + + for i in range(self.num_convs): + x = self.convs[i](x) + + mask_preds = self.conv_logits(x) + x = self.conv_embedding(x) + return mask_preds, x + + def loss(self, mask_preds: Tensor, labels: Tensor) -> Tensor: + """Loss function. + + Args: + mask_preds (Tensor): Predicted mask logits. + labels (Tensor): Ground truth. + + Returns: + Tensor: Semantic segmentation loss. + """ + labels = F.interpolate( + labels.float(), scale_factor=self.seg_scale_factor, mode='nearest') + labels = labels.squeeze(1).long() + loss_semantic_seg = self.criterion(mask_preds, labels) + return loss_semantic_seg diff --git a/mmdet/models/roi_heads/mask_heads/global_context_head.py b/mmdet/models/roi_heads/mask_heads/global_context_head.py new file mode 100644 index 0000000000000000000000000000000000000000..cb947ea582227d2b74112cbb930e1a3f85b77ff5 --- /dev/null +++ b/mmdet/models/roi_heads/mask_heads/global_context_head.py @@ -0,0 +1,127 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch import Tensor + +from mmdet.models.layers import ResLayer, SimplifiedBasicBlock +from mmdet.registry import MODELS +from mmdet.utils import MultiConfig, OptConfigType + + +@MODELS.register_module() +class GlobalContextHead(BaseModule): + """Global context head used in `SCNet `_. + + Args: + num_convs (int, optional): number of convolutional layer in GlbCtxHead. + Defaults to 4. + in_channels (int, optional): number of input channels. Defaults to 256. + conv_out_channels (int, optional): number of output channels before + classification layer. Defaults to 256. + num_classes (int, optional): number of classes. Defaults to 80. + loss_weight (float, optional): global context loss weight. + Defaults to 1. + conv_cfg (dict, optional): config to init conv layer. Defaults to None. + norm_cfg (dict, optional): config to init norm layer. Defaults to None. + conv_to_res (bool, optional): if True, 2 convs will be grouped into + 1 `SimplifiedBasicBlock` using a skip connection. + Defaults to False. + init_cfg (:obj:`ConfigDict` or dict or list[dict] or + list[:obj:`ConfigDict`]): Initialization config dict. Defaults to + dict(type='Normal', std=0.01, override=dict(name='fc')). 
+ """ + + def __init__( + self, + num_convs: int = 4, + in_channels: int = 256, + conv_out_channels: int = 256, + num_classes: int = 80, + loss_weight: float = 1.0, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + conv_to_res: bool = False, + init_cfg: MultiConfig = dict( + type='Normal', std=0.01, override=dict(name='fc')) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.num_convs = num_convs + self.in_channels = in_channels + self.conv_out_channels = conv_out_channels + self.num_classes = num_classes + self.loss_weight = loss_weight + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.conv_to_res = conv_to_res + self.fp16_enabled = False + + if self.conv_to_res: + num_res_blocks = num_convs // 2 + self.convs = ResLayer( + SimplifiedBasicBlock, + in_channels, + self.conv_out_channels, + num_res_blocks, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) + self.num_convs = num_res_blocks + else: + self.convs = nn.ModuleList() + for i in range(self.num_convs): + in_channels = self.in_channels if i == 0 else conv_out_channels + self.convs.append( + ConvModule( + in_channels, + conv_out_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(conv_out_channels, num_classes) + + self.criterion = nn.BCEWithLogitsLoss() + + def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor]: + """Forward function. + + Args: + feats (Tuple[Tensor]): Multi-scale feature maps. + + Returns: + Tuple[Tensor]: + + - mc_pred (Tensor): Multi-class prediction. + - x (Tensor): Global context feature. + """ + x = feats[-1] + for i in range(self.num_convs): + x = self.convs[i](x) + x = self.pool(x) + + # multi-class prediction + mc_pred = x.reshape(x.size(0), -1) + mc_pred = self.fc(mc_pred) + + return mc_pred, x + + def loss(self, pred: Tensor, labels: List[Tensor]) -> Tensor: + """Loss function. + + Args: + pred (Tensor): Logits. + labels (list[Tensor]): Grouth truths. + + Returns: + Tensor: Loss. + """ + labels = [lbl.unique() for lbl in labels] + targets = pred.new_zeros(pred.size()) + for i, label in enumerate(labels): + targets[i, label] = 1.0 + loss = self.loss_weight * self.criterion(pred, targets) + return loss diff --git a/mmdet/models/roi_heads/mask_heads/grid_head.py b/mmdet/models/roi_heads/mask_heads/grid_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d9514ae7bcfc1b7d5613fa0107e9bd087e13dd46 --- /dev/null +++ b/mmdet/models/roi_heads/mask_heads/grid_head.py @@ -0,0 +1,490 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.config import ConfigDict +from mmengine.model import BaseModule +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.models.task_modules.samplers import SamplingResult +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptConfigType + + +@MODELS.register_module() +class GridHead(BaseModule): + """Implementation of `Grid Head `_ + + Args: + grid_points (int): The number of grid points. Defaults to 9. + num_convs (int): The number of convolution layers. Defaults to 8. + roi_feat_size (int): RoI feature size. Default to 14. + in_channels (int): The channel number of inputs features. + Defaults to 256. + conv_kernel_size (int): The kernel size of convolution layers. + Defaults to 3. 
+ point_feat_channels (int): The number of channels of each point + features. Defaults to 64. + class_agnostic (bool): Whether use class agnostic classification. + If so, the output channels of logits will be 1. Defaults to False. + loss_grid (:obj:`ConfigDict` or dict): Config of grid loss. + conv_cfg (:obj:`ConfigDict` or dict, optional) dictionary to + construct and config conv layer. + norm_cfg (:obj:`ConfigDict` or dict): dictionary to construct and + config norm layer. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict]): Initialization config dict. + """ + + def __init__( + self, + grid_points: int = 9, + num_convs: int = 8, + roi_feat_size: int = 14, + in_channels: int = 256, + conv_kernel_size: int = 3, + point_feat_channels: int = 64, + deconv_kernel_size: int = 4, + class_agnostic: bool = False, + loss_grid: ConfigType = dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=15), + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict(type='GN', num_groups=36), + init_cfg: MultiConfig = [ + dict(type='Kaiming', layer=['Conv2d', 'Linear']), + dict( + type='Normal', + layer='ConvTranspose2d', + std=0.001, + override=dict( + type='Normal', + name='deconv2', + std=0.001, + bias=-np.log(0.99 / 0.01))) + ] + ) -> None: + super().__init__(init_cfg=init_cfg) + self.grid_points = grid_points + self.num_convs = num_convs + self.roi_feat_size = roi_feat_size + self.in_channels = in_channels + self.conv_kernel_size = conv_kernel_size + self.point_feat_channels = point_feat_channels + self.conv_out_channels = self.point_feat_channels * self.grid_points + self.class_agnostic = class_agnostic + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + if isinstance(norm_cfg, dict) and norm_cfg['type'] == 'GN': + assert self.conv_out_channels % norm_cfg['num_groups'] == 0 + + assert self.grid_points >= 4 + self.grid_size = int(np.sqrt(self.grid_points)) + if self.grid_size * self.grid_size != self.grid_points: + raise ValueError('grid_points must be a square number') + + # the predicted heatmap is half of whole_map_size + if not isinstance(self.roi_feat_size, int): + raise ValueError('Only square RoIs are supporeted in Grid R-CNN') + self.whole_map_size = self.roi_feat_size * 4 + + # compute point-wise sub-regions + self.sub_regions = self.calc_sub_regions() + + self.convs = [] + for i in range(self.num_convs): + in_channels = ( + self.in_channels if i == 0 else self.conv_out_channels) + stride = 2 if i == 0 else 1 + padding = (self.conv_kernel_size - 1) // 2 + self.convs.append( + ConvModule( + in_channels, + self.conv_out_channels, + self.conv_kernel_size, + stride=stride, + padding=padding, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + bias=True)) + self.convs = nn.Sequential(*self.convs) + + self.deconv1 = nn.ConvTranspose2d( + self.conv_out_channels, + self.conv_out_channels, + kernel_size=deconv_kernel_size, + stride=2, + padding=(deconv_kernel_size - 2) // 2, + groups=grid_points) + self.norm1 = nn.GroupNorm(grid_points, self.conv_out_channels) + self.deconv2 = nn.ConvTranspose2d( + self.conv_out_channels, + grid_points, + kernel_size=deconv_kernel_size, + stride=2, + padding=(deconv_kernel_size - 2) // 2, + groups=grid_points) + + # find the 4-neighbor of each grid point + self.neighbor_points = [] + grid_size = self.grid_size + for i in range(grid_size): # i-th column + for j in range(grid_size): # j-th row + neighbors = [] + if i > 0: # left: (i - 1, j) + neighbors.append((i - 1) * grid_size + j) + if j > 0: # up: (i, j - 1) + 
neighbors.append(i * grid_size + j - 1) + if j < grid_size - 1: # down: (i, j + 1) + neighbors.append(i * grid_size + j + 1) + if i < grid_size - 1: # right: (i + 1, j) + neighbors.append((i + 1) * grid_size + j) + self.neighbor_points.append(tuple(neighbors)) + # total edges in the grid + self.num_edges = sum([len(p) for p in self.neighbor_points]) + + self.forder_trans = nn.ModuleList() # first-order feature transition + self.sorder_trans = nn.ModuleList() # second-order feature transition + for neighbors in self.neighbor_points: + fo_trans = nn.ModuleList() + so_trans = nn.ModuleList() + for _ in range(len(neighbors)): + # each transition module consists of a 5x5 depth-wise conv and + # 1x1 conv. + fo_trans.append( + nn.Sequential( + nn.Conv2d( + self.point_feat_channels, + self.point_feat_channels, + 5, + stride=1, + padding=2, + groups=self.point_feat_channels), + nn.Conv2d(self.point_feat_channels, + self.point_feat_channels, 1))) + so_trans.append( + nn.Sequential( + nn.Conv2d( + self.point_feat_channels, + self.point_feat_channels, + 5, + 1, + 2, + groups=self.point_feat_channels), + nn.Conv2d(self.point_feat_channels, + self.point_feat_channels, 1))) + self.forder_trans.append(fo_trans) + self.sorder_trans.append(so_trans) + + self.loss_grid = MODELS.build(loss_grid) + + def forward(self, x: Tensor) -> Dict[str, Tensor]: + """Forward function of ``GridHead``. + + Args: + x (Tensor): RoI features, has shape + (num_rois, num_channels, roi_feat_size, roi_feat_size). + + Returns: + Dict[str, Tensor]: Return a dict including fused and unfused + heatmap. + """ + assert x.shape[-1] == x.shape[-2] == self.roi_feat_size + # RoI feature transformation, downsample 2x + x = self.convs(x) + + c = self.point_feat_channels + # first-order fusion + x_fo = [None for _ in range(self.grid_points)] + for i, points in enumerate(self.neighbor_points): + x_fo[i] = x[:, i * c:(i + 1) * c] + for j, point_idx in enumerate(points): + x_fo[i] = x_fo[i] + self.forder_trans[i][j]( + x[:, point_idx * c:(point_idx + 1) * c]) + + # second-order fusion + x_so = [None for _ in range(self.grid_points)] + for i, points in enumerate(self.neighbor_points): + x_so[i] = x[:, i * c:(i + 1) * c] + for j, point_idx in enumerate(points): + x_so[i] = x_so[i] + self.sorder_trans[i][j](x_fo[point_idx]) + + # predicted heatmap with fused features + x2 = torch.cat(x_so, dim=1) + x2 = self.deconv1(x2) + x2 = F.relu(self.norm1(x2), inplace=True) + heatmap = self.deconv2(x2) + + # predicted heatmap with original features (applicable during training) + if self.training: + x1 = x + x1 = self.deconv1(x1) + x1 = F.relu(self.norm1(x1), inplace=True) + heatmap_unfused = self.deconv2(x1) + else: + heatmap_unfused = heatmap + + return dict(fused=heatmap, unfused=heatmap_unfused) + + def calc_sub_regions(self) -> List[Tuple[float]]: + """Compute point specific representation regions. + + See `Grid R-CNN Plus <https://arxiv.org/abs/1906.05688>`_ for details.
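The 4-neighborhood bookkeeping above is easy to check in isolation; a standalone sketch for the default 3x3 grid (grid_points=9):

# same indexing scheme as in __init__, for grid_size = 3
grid_size = 3
neighbor_points = []
for i in range(grid_size):
    for j in range(grid_size):
        neighbors = []
        if i > 0:
            neighbors.append((i - 1) * grid_size + j)   # left
        if j > 0:
            neighbors.append(i * grid_size + j - 1)     # up
        if j < grid_size - 1:
            neighbors.append(i * grid_size + j + 1)     # down
        if i < grid_size - 1:
            neighbors.append((i + 1) * grid_size + j)   # right
        neighbor_points.append(tuple(neighbors))
print(neighbor_points[4])                    # center point: (1, 3, 5, 7)
print(sum(len(p) for p in neighbor_points))  # 24 directed edges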
+ """ + # to make it consistent with the original implementation, half_size + # is computed as 2 * quarter_size, which is smaller + half_size = self.whole_map_size // 4 * 2 + sub_regions = [] + for i in range(self.grid_points): + x_idx = i // self.grid_size + y_idx = i % self.grid_size + if x_idx == 0: + sub_x1 = 0 + elif x_idx == self.grid_size - 1: + sub_x1 = half_size + else: + ratio = x_idx / (self.grid_size - 1) - 0.25 + sub_x1 = max(int(ratio * self.whole_map_size), 0) + + if y_idx == 0: + sub_y1 = 0 + elif y_idx == self.grid_size - 1: + sub_y1 = half_size + else: + ratio = y_idx / (self.grid_size - 1) - 0.25 + sub_y1 = max(int(ratio * self.whole_map_size), 0) + sub_regions.append( + (sub_x1, sub_y1, sub_x1 + half_size, sub_y1 + half_size)) + return sub_regions + + def get_targets(self, sampling_results: List[SamplingResult], + rcnn_train_cfg: ConfigDict) -> Tensor: + """Calculate the ground truth for all samples in a batch according to + the sampling_results.". + + Args: + sampling_results (List[:obj:`SamplingResult`]): Assign results of + all images in a batch after sampling. + rcnn_train_cfg (:obj:`ConfigDict`): `train_cfg` of RCNN. + + Returns: + Tensor: Grid heatmap targets. + """ + # mix all samples (across images) together. + pos_bboxes = torch.cat([res.pos_bboxes for res in sampling_results], + dim=0).cpu() + pos_gt_bboxes = torch.cat( + [res.pos_gt_bboxes for res in sampling_results], dim=0).cpu() + assert pos_bboxes.shape == pos_gt_bboxes.shape + + # expand pos_bboxes to 2x of original size + x1 = pos_bboxes[:, 0] - (pos_bboxes[:, 2] - pos_bboxes[:, 0]) / 2 + y1 = pos_bboxes[:, 1] - (pos_bboxes[:, 3] - pos_bboxes[:, 1]) / 2 + x2 = pos_bboxes[:, 2] + (pos_bboxes[:, 2] - pos_bboxes[:, 0]) / 2 + y2 = pos_bboxes[:, 3] + (pos_bboxes[:, 3] - pos_bboxes[:, 1]) / 2 + pos_bboxes = torch.stack([x1, y1, x2, y2], dim=-1) + pos_bbox_ws = (pos_bboxes[:, 2] - pos_bboxes[:, 0]).unsqueeze(-1) + pos_bbox_hs = (pos_bboxes[:, 3] - pos_bboxes[:, 1]).unsqueeze(-1) + + num_rois = pos_bboxes.shape[0] + map_size = self.whole_map_size + # this is not the final target shape + targets = torch.zeros((num_rois, self.grid_points, map_size, map_size), + dtype=torch.float) + + # pre-compute interpolation factors for all grid points. + # the first item is the factor of x-dim, and the second is y-dim. 
+ # for a 9-point grid, factors are like (1, 0), (0.5, 0.5), (0, 1) + factors = [] + for j in range(self.grid_points): + x_idx = j // self.grid_size + y_idx = j % self.grid_size + factors.append((1 - x_idx / (self.grid_size - 1), + 1 - y_idx / (self.grid_size - 1))) + + radius = rcnn_train_cfg.pos_radius + radius2 = radius**2 + for i in range(num_rois): + # ignore small bboxes + if (pos_bbox_ws[i] <= self.grid_size + or pos_bbox_hs[i] <= self.grid_size): + continue + # for each grid point, mark a small circle as positive + for j in range(self.grid_points): + factor_x, factor_y = factors[j] + gridpoint_x = factor_x * pos_gt_bboxes[i, 0] + ( + 1 - factor_x) * pos_gt_bboxes[i, 2] + gridpoint_y = factor_y * pos_gt_bboxes[i, 1] + ( + 1 - factor_y) * pos_gt_bboxes[i, 3] + + cx = int((gridpoint_x - pos_bboxes[i, 0]) / pos_bbox_ws[i] * + map_size) + cy = int((gridpoint_y - pos_bboxes[i, 1]) / pos_bbox_hs[i] * + map_size) + + for x in range(cx - radius, cx + radius + 1): + for y in range(cy - radius, cy + radius + 1): + if x >= 0 and x < map_size and y >= 0 and y < map_size: + if (x - cx)**2 + (y - cy)**2 <= radius2: + targets[i, j, y, x] = 1 + # reduce the target heatmap size by half + # proposed in Grid R-CNN Plus (https://arxiv.org/abs/1906.05688). + sub_targets = [] + for i in range(self.grid_points): + sub_x1, sub_y1, sub_x2, sub_y2 = self.sub_regions[i] + sub_targets.append(targets[:, [i], sub_y1:sub_y2, sub_x1:sub_x2]) + sub_targets = torch.cat(sub_targets, dim=1) + sub_targets = sub_targets.to(sampling_results[0].pos_bboxes.device) + return sub_targets + + def loss(self, grid_pred: Dict[str, Tensor], sample_idx: Tensor, + sampling_results: List[SamplingResult], + rcnn_train_cfg: ConfigDict) -> dict: + """Calculate the loss based on the features extracted by the grid head. + + Args: + grid_pred (dict[str, Tensor]): Outputs of grid_head forward. + sample_idx (Tensor): The sampling index of ``grid_pred``. + sampling_results (List[:obj:`SamplingResult`]): Assign results of + all images in a batch after sampling. + rcnn_train_cfg (:obj:`ConfigDict`): `train_cfg` of RCNN. + + Returns: + dict: A dictionary of loss and targets components. + """ + grid_targets = self.get_targets(sampling_results, rcnn_train_cfg) + grid_targets = grid_targets[sample_idx] + + loss_fused = self.loss_grid(grid_pred['fused'], grid_targets) + loss_unfused = self.loss_grid(grid_pred['unfused'], grid_targets) + loss_grid = loss_fused + loss_unfused + return dict(loss_grid=loss_grid) + + def predict_by_feat(self, + grid_preds: Dict[str, Tensor], + results_list: List[InstanceData], + batch_img_metas: List[dict], + rescale: bool = False) -> InstanceList: + """Adjust the predicted bboxes from bbox head. + + Args: + grid_preds (dict[str, Tensor]): Dictionary returned by the + forward function. + results_list (list[:obj:`InstanceData`]): Detection results of + each image. + batch_img_metas (list[dict]): List of image information. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape \ + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), the last \ + dimension 4 arrange as (x1, y1, x2, y2).
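The inner loop of get_targets marks a small circle around each projected grid point; a toy sketch of that single step (map size, center, and radius are assumed values):

import torch

map_size, r = 56, 2   # assumed heatmap size and pos_radius
cx, cy = 20, 30       # assumed projected grid-point center
target = torch.zeros(map_size, map_size)
for x in range(cx - r, cx + r + 1):
    for y in range(cy - r, cy + r + 1):
        if 0 <= x < map_size and 0 <= y < map_size:
            if (x - cx) ** 2 + (y - cy) ** 2 <= r ** 2:
                target[y, x] = 1
print(int(target.sum()))  # 13 pixels fall inside the radius-2 circle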
+ """ + num_roi_per_img = tuple(res.bboxes.size(0) for res in results_list) + grid_preds = { + k: v.split(num_roi_per_img, 0) + for k, v in grid_preds.items() + } + + for i, results in enumerate(results_list): + if len(results) != 0: + bboxes = self._predict_by_feat_single( + grid_pred=grid_preds['fused'][i], + bboxes=results.bboxes, + img_meta=batch_img_metas[i], + rescale=rescale) + results.bboxes = bboxes + return results_list + + def _predict_by_feat_single(self, + grid_pred: Tensor, + bboxes: Tensor, + img_meta: dict, + rescale: bool = False) -> Tensor: + """Adjust ``bboxes`` according to ``grid_pred``. + + Args: + grid_pred (Tensor): Grid fused heatmap. + bboxes (Tensor): Predicted bboxes, has shape (n, 4) + img_meta (dict): image information. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + Tensor: adjusted bboxes. + """ + assert bboxes.size(0) == grid_pred.size(0) + grid_pred = grid_pred.sigmoid() + + R, c, h, w = grid_pred.shape + half_size = self.whole_map_size // 4 * 2 + assert h == w == half_size + assert c == self.grid_points + + # find the point with max scores in the half-sized heatmap + grid_pred = grid_pred.view(R * c, h * w) + pred_scores, pred_position = grid_pred.max(dim=1) + xs = pred_position % w + ys = pred_position // w + + # get the position in the whole heatmap instead of half-sized heatmap + for i in range(self.grid_points): + xs[i::self.grid_points] += self.sub_regions[i][0] + ys[i::self.grid_points] += self.sub_regions[i][1] + + # reshape to (num_rois, grid_points) + pred_scores, xs, ys = tuple( + map(lambda x: x.view(R, c), [pred_scores, xs, ys])) + + # get expanded pos_bboxes + widths = (bboxes[:, 2] - bboxes[:, 0]).unsqueeze(-1) + heights = (bboxes[:, 3] - bboxes[:, 1]).unsqueeze(-1) + x1 = (bboxes[:, 0, None] - widths / 2) + y1 = (bboxes[:, 1, None] - heights / 2) + # map the grid point to the absolute coordinates + abs_xs = (xs.float() + 0.5) / w * widths + x1 + abs_ys = (ys.float() + 0.5) / h * heights + y1 + + # get the grid points indices that fall on the bbox boundaries + x1_inds = [i for i in range(self.grid_size)] + y1_inds = [i * self.grid_size for i in range(self.grid_size)] + x2_inds = [ + self.grid_points - self.grid_size + i + for i in range(self.grid_size) + ] + y2_inds = [(i + 1) * self.grid_size - 1 for i in range(self.grid_size)] + + # voting of all grid points on some boundary + bboxes_x1 = (abs_xs[:, x1_inds] * pred_scores[:, x1_inds]).sum( + dim=1, keepdim=True) / ( + pred_scores[:, x1_inds].sum(dim=1, keepdim=True)) + bboxes_y1 = (abs_ys[:, y1_inds] * pred_scores[:, y1_inds]).sum( + dim=1, keepdim=True) / ( + pred_scores[:, y1_inds].sum(dim=1, keepdim=True)) + bboxes_x2 = (abs_xs[:, x2_inds] * pred_scores[:, x2_inds]).sum( + dim=1, keepdim=True) / ( + pred_scores[:, x2_inds].sum(dim=1, keepdim=True)) + bboxes_y2 = (abs_ys[:, y2_inds] * pred_scores[:, y2_inds]).sum( + dim=1, keepdim=True) / ( + pred_scores[:, y2_inds].sum(dim=1, keepdim=True)) + + bboxes = torch.cat([bboxes_x1, bboxes_y1, bboxes_x2, bboxes_y2], dim=1) + bboxes[:, [0, 2]].clamp_(min=0, max=img_meta['img_shape'][1]) + bboxes[:, [1, 3]].clamp_(min=0, max=img_meta['img_shape'][0]) + + if rescale: + assert img_meta.get('scale_factor') is not None + bboxes /= bboxes.new_tensor(img_meta['scale_factor']).repeat( + (1, 2)) + + return bboxes diff --git a/mmdet/models/roi_heads/mask_heads/htc_mask_head.py b/mmdet/models/roi_heads/mask_heads/htc_mask_head.py new file mode 100644 index 
0000000000000000000000000000000000000000..73ac1e6e5f115927e1a2accdd693aae512cac753 --- /dev/null +++ b/mmdet/models/roi_heads/mask_heads/htc_mask_head.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +from mmcv.cnn import ConvModule +from torch import Tensor + +from mmdet.registry import MODELS +from .fcn_mask_head import FCNMaskHead + + +@MODELS.register_module() +class HTCMaskHead(FCNMaskHead): + """Mask head for HTC. + + Args: + with_conv_res (bool): Whether to add a conv layer for ``res_feat``. + Defaults to True. + """ + + def __init__(self, with_conv_res: bool = True, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.with_conv_res = with_conv_res + if self.with_conv_res: + self.conv_res = ConvModule( + self.conv_out_channels, + self.conv_out_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) + + def forward(self, + x: Tensor, + res_feat: Optional[Tensor] = None, + return_logits: bool = True, + return_feat: bool = True) -> Union[Tensor, List[Tensor]]: + """ + Args: + x (Tensor): Feature map. + res_feat (Tensor, optional): Feature for residual connection. + Defaults to None. + return_logits (bool): Whether return mask logits. Defaults to True. + return_feat (bool): Whether return feature map. Defaults to True. + + Returns: + Union[Tensor, List[Tensor]]: The returned value is one of + res_feat, logits, or [logits, res_feat]. + """ + assert not (not return_logits and not return_feat) + if res_feat is not None: + assert self.with_conv_res + res_feat = self.conv_res(res_feat) + x = x + res_feat + for conv in self.convs: + x = conv(x) + res_feat = x + outs = [] + if return_logits: + x = self.upsample(x) + if self.upsample_method == 'deconv': + x = self.relu(x) + mask_preds = self.conv_logits(x) + outs.append(mask_preds) + if return_feat: + outs.append(res_feat) + return outs if len(outs) > 1 else outs[0] diff --git a/mmdet/models/roi_heads/mask_heads/mask_point_head.py b/mmdet/models/roi_heads/mask_heads/mask_point_head.py new file mode 100644 index 0000000000000000000000000000000000000000..2084f59f07b48bf2e5b05bb7af61172df8737478 --- /dev/null +++ b/mmdet/models/roi_heads/mask_heads/mask_point_head.py @@ -0,0 +1,284 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa + +from typing import List, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmcv.ops import point_sample, rel_roi_point_to_rel_img_point +from mmengine.model import BaseModule +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.models.task_modules.samplers import SamplingResult +from mmdet.models.utils import (get_uncertain_point_coords_with_randomness, + get_uncertainty) +from mmdet.registry import MODELS +from mmdet.structures.bbox import bbox2roi +from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptConfigType + + +@MODELS.register_module() +class MaskPointHead(BaseModule): + """A mask point head used in PointRend. + + ``MaskPointHead`` uses a shared multi-layer perceptron (equivalent to + nn.Conv1d) to predict the logit of input points. The fine-grained feature + and coarse feature will be concatenated together for prediction. + + Args: + num_fcs (int): Number of fc layers in the head. Defaults to 3. + in_channels (int): Number of input channels. Defaults to 256. + fc_channels (int): Number of fc channels.
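A minimal usage sketch of HTCMaskHead's residual chaining across cascade stages, assuming mmdet 3.x and its dependencies are installed (shapes follow the FCNMaskHead defaults):

import torch
from mmdet.models.roi_heads.mask_heads import HTCMaskHead  # assumes mmdet 3.x

# head built with FCNMaskHead defaults: 256 channels, 14x14 RoI feats, 80 classes
head = HTCMaskHead(with_conv_res=True)
x = torch.randn(2, 256, 14, 14)  # toy RoI features
# first cascade stage: no residual feature yet
logits, res_feat = head(x)
# later stage: fuse the previous stage's res_feat through the 1x1 conv_res
logits_next, _ = head(x, res_feat=res_feat)
print(logits.shape)  # torch.Size([2, 80, 28, 28]) after the 2x deconv upsample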
Defaults to 256. + num_classes (int): Number of classes for logits. Defaults to 80. + class_agnostic (bool): Whether to use class-agnostic classification. + If so, the output channels of logits will be 1. Defaults to False. + coarse_pred_each_layer (bool): Whether to concatenate coarse feature + with the output of each fc layer. Defaults to True. + conv_cfg (:obj:`ConfigDict` or dict): Dictionary to construct + and config conv layer. Defaults to dict(type='Conv1d'). + norm_cfg (:obj:`ConfigDict` or dict, optional): Dictionary to construct + and config norm layer. Defaults to None. + loss_point (:obj:`ConfigDict` or dict): Dictionary to construct and + config loss layer of point head. Defaults to + dict(type='CrossEntropyLoss', use_mask=True, loss_weight=1.0). + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], optional): Initialization config dict. + """ + + def __init__( + self, + num_classes: int, + num_fcs: int = 3, + in_channels: int = 256, + fc_channels: int = 256, + class_agnostic: bool = False, + coarse_pred_each_layer: bool = True, + conv_cfg: ConfigType = dict(type='Conv1d'), + norm_cfg: OptConfigType = None, + act_cfg: ConfigType = dict(type='ReLU'), + loss_point: ConfigType = dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0), + init_cfg: MultiConfig = dict( + type='Normal', std=0.001, override=dict(name='fc_logits')) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.num_fcs = num_fcs + self.in_channels = in_channels + self.fc_channels = fc_channels + self.num_classes = num_classes + self.class_agnostic = class_agnostic + self.coarse_pred_each_layer = coarse_pred_each_layer + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.loss_point = MODELS.build(loss_point) + + fc_in_channels = in_channels + num_classes + self.fcs = nn.ModuleList() + for _ in range(num_fcs): + fc = ConvModule( + fc_in_channels, + fc_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.fcs.append(fc) + fc_in_channels = fc_channels + fc_in_channels += num_classes if self.coarse_pred_each_layer else 0 + + out_channels = 1 if self.class_agnostic else self.num_classes + self.fc_logits = nn.Conv1d( + fc_in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, fine_grained_feats: Tensor, + coarse_feats: Tensor) -> Tensor: + """Classify each point based on fine-grained and coarse features. + + Args: + fine_grained_feats (Tensor): Fine grained feature sampled from FPN, + shape (num_rois, in_channels, num_points). + coarse_feats (Tensor): Coarse feature sampled from CoarseMaskHead, + shape (num_rois, num_classes, num_points). + + Returns: + Tensor: Point classification results, + shape (num_rois, num_classes, num_points). + """ + + x = torch.cat([fine_grained_feats, coarse_feats], dim=1) + for fc in self.fcs: + x = fc(x) + if self.coarse_pred_each_layer: + x = torch.cat((x, coarse_feats), dim=1) + return self.fc_logits(x) + + def get_targets(self, rois: Tensor, rel_roi_points: Tensor, + sampling_results: List[SamplingResult], + batch_gt_instances: InstanceList, + cfg: ConfigType) -> Tensor: + """Get training targets of MaskPointHead for all images. + + Args: + rois (Tensor): Region of Interest, shape (num_rois, 5). + rel_roi_points (Tensor): Points coordinates relative to RoI, shape + (num_rois, num_points, 2). + sampling_results (:obj:`SamplingResult`): Sampling result after + sampling and assignment. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance.
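The "shared MLP" in the forward above is just a kernel-size-1 Conv1d applied over the point axis; a standalone sketch with assumed toy shapes:

import torch
import torch.nn as nn

# toy shapes (assumed): 4 RoIs, 256-d fine-grained feats, 80-class coarse
# logits, 196 sampled points per RoI
fine = torch.randn(4, 256, 196)
coarse = torch.randn(4, 80, 196)
fc = nn.Conv1d(256 + 80, 256, kernel_size=1)  # one point-wise MLP layer
out = fc(torch.cat([fine, coarse], dim=1))    # -> (4, 256, 196)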
It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + cfg (:obj:`ConfigDict` or dict): Training cfg. + + Returns: + Tensor: Point target, shape (num_rois, num_points). + """ + + num_imgs = len(sampling_results) + rois_list = [] + rel_roi_points_list = [] + for batch_ind in range(num_imgs): + inds = (rois[:, 0] == batch_ind) + rois_list.append(rois[inds]) + rel_roi_points_list.append(rel_roi_points[inds]) + pos_assigned_gt_inds_list = [ + res.pos_assigned_gt_inds for res in sampling_results + ] + cfg_list = [cfg for _ in range(num_imgs)] + + point_targets = map(self._get_targets_single, rois_list, + rel_roi_points_list, pos_assigned_gt_inds_list, + batch_gt_instances, cfg_list) + point_targets = list(point_targets) + + if len(point_targets) > 0: + point_targets = torch.cat(point_targets) + + return point_targets + + def _get_targets_single(self, rois: Tensor, rel_roi_points: Tensor, + pos_assigned_gt_inds: Tensor, + gt_instances: InstanceData, + cfg: ConfigType) -> Tensor: + """Get training target of MaskPointHead for each image.""" + num_pos = rois.size(0) + num_points = cfg.num_points + if num_pos > 0: + gt_masks_th = ( + gt_instances.masks.to_tensor(rois.dtype, + rois.device).index_select( + 0, pos_assigned_gt_inds)) + gt_masks_th = gt_masks_th.unsqueeze(1) + rel_img_points = rel_roi_point_to_rel_img_point( + rois, rel_roi_points, gt_masks_th) + point_targets = point_sample(gt_masks_th, + rel_img_points).squeeze(1) + else: + point_targets = rois.new_zeros((0, num_points)) + return point_targets + + def loss_and_target(self, point_pred: Tensor, rel_roi_points: Tensor, + sampling_results: List[SamplingResult], + batch_gt_instances: InstanceList, + cfg: ConfigType) -> dict: + """Calculate loss for MaskPointHead. + + Args: + point_pred (Tensor): Point prediction result, shape + (num_rois, num_classes, num_points). + rel_roi_points (Tensor): Points coordinates relative to RoI, shape + (num_rois, num_points, 2). + sampling_results (:obj:`SamplingResult`): Sampling result after + sampling and assignment. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + cfg (:obj:`ConfigDict` or dict): Training cfg. + + Returns: + dict: A dictionary of point loss and point target. + """ + rois = bbox2roi([res.pos_bboxes for res in sampling_results]) + pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results]) + + point_target = self.get_targets(rois, rel_roi_points, sampling_results, + batch_gt_instances, cfg) + if self.class_agnostic: + loss_point = self.loss_point(point_pred, point_target, + torch.zeros_like(pos_labels)) + else: + loss_point = self.loss_point(point_pred, point_target, pos_labels) + + return dict(loss_point=loss_point, point_target=point_target) + + def get_roi_rel_points_train(self, mask_preds: Tensor, labels: Tensor, + cfg: ConfigType) -> Tensor: + """Get ``num_points`` most uncertain points with random points during + train. + + Sample points in [0, 1] x [0, 1] coordinate space based on their + uncertainty. The uncertainties are calculated for each point using + '_get_uncertainty()' function that takes point's logit prediction as + input. + + Args: + mask_preds (Tensor): A tensor of shape (num_rois, num_classes, + mask_height, mask_width) for class-specific or class-agnostic + prediction. + labels (Tensor): The ground truth class for each instance. + cfg (:obj:`ConfigDict` or dict): Training config of point head.
+ + Returns: + point_coords (Tensor): A tensor of shape (num_rois, num_points, 2) + that contains the coordinates of the sampled points. + """ + point_coords = get_uncertain_point_coords_with_randomness( + mask_preds, labels, cfg.num_points, cfg.oversample_ratio, + cfg.importance_sample_ratio) + return point_coords + + def get_roi_rel_points_test(self, mask_preds: Tensor, label_preds: Tensor, + cfg: ConfigType) -> Tuple[Tensor, Tensor]: + """Get ``num_points`` most uncertain points during test. + + Args: + mask_preds (Tensor): A tensor of shape (num_rois, num_classes, + mask_height, mask_width) for class-specific or class-agnostic + prediction. + label_preds (Tensor): The predicted class for each instance. + cfg (:obj:`ConfigDict` or dict): Testing config of point head. + + Returns: + tuple: + + - point_indices (Tensor): A tensor of shape (num_rois, num_points) + that contains indices from [0, mask_height x mask_width) of the + most uncertain points. + - point_coords (Tensor): A tensor of shape (num_rois, num_points, + 2) that contains [0, 1] x [0, 1] normalized coordinates of the + most uncertain points from the [mask_height, mask_width] grid. + """ + num_points = cfg.subdivision_num_points + uncertainty_map = get_uncertainty(mask_preds, label_preds) + num_rois, _, mask_height, mask_width = uncertainty_map.shape + + # During ONNX exporting, the type of each element of 'shape' is + # `Tensor(float)`, while it is `float` during PyTorch inference. + if isinstance(mask_height, torch.Tensor): + h_step = 1.0 / mask_height.float() + w_step = 1.0 / mask_width.float() + else: + h_step = 1.0 / mask_height + w_step = 1.0 / mask_width + # cast to int to avoid dynamic K for TopK op in ONNX + mask_size = int(mask_height * mask_width) + uncertainty_map = uncertainty_map.view(num_rois, mask_size) + num_points = min(mask_size, num_points) + point_indices = uncertainty_map.topk(num_points, dim=1)[1] + xs = w_step / 2.0 + (point_indices % mask_width).float() * w_step + ys = h_step / 2.0 + (point_indices // mask_width).float() * h_step + point_coords = torch.stack([xs, ys], dim=2) + return point_indices, point_coords diff --git a/mmdet/models/roi_heads/mask_heads/maskiou_head.py b/mmdet/models/roi_heads/mask_heads/maskiou_head.py new file mode 100644 index 0000000000000000000000000000000000000000..8901871e754c491f7bc94eb68a27fa1b50e29148 --- /dev/null +++ b/mmdet/models/roi_heads/mask_heads/maskiou_head.py @@ -0,0 +1,277 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import Conv2d, Linear, MaxPool2d +from mmengine.config import ConfigDict +from mmengine.model import BaseModule +from mmengine.structures import InstanceData +from torch import Tensor +from torch.nn.modules.utils import _pair + +from mmdet.models.task_modules.samplers import SamplingResult +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, InstanceList, OptMultiConfig + + +@MODELS.register_module() +class MaskIoUHead(BaseModule): + """Mask IoU Head. + + This head predicts the IoU of predicted masks and corresponding gt masks. + + Args: + num_convs (int): The number of convolution layers. Defaults to 4. + num_fcs (int): The number of fully connected layers. Defaults to 2. + roi_feat_size (int): RoI feature size. Defaults to 14. + in_channels (int): The channel number of inputs features. + Defaults to 256. + conv_out_channels (int): The feature channels of convolution layers. + Defaults to 256.
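The index-to-coordinate conversion in get_roi_rel_points_test can be sketched standalone (toy uncertainty map; all shapes assumed):

import torch

num_rois, h, w, k = 2, 4, 4, 3
uncertainty = torch.rand(num_rois, 1, h, w).view(num_rois, h * w)
point_indices = uncertainty.topk(k, dim=1)[1]  # k most uncertain flat indices
# convert flat indices to pixel-center coordinates normalized to [0, 1]
xs = 0.5 / w + (point_indices % w).float() / w
ys = 0.5 / h + (point_indices // w).float() / h
point_coords = torch.stack([xs, ys], dim=2)  # (2, 3, 2)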
+ fc_out_channels (int): The feature channels of fully connected layers. + Defaults to 1024. + num_classes (int): Number of categories excluding the background + category. Defaults to 80. + loss_iou (:obj:`ConfigDict` or dict): IoU loss. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], optional): Initialization config dict. + """ + + def __init__( + self, + num_convs: int = 4, + num_fcs: int = 2, + roi_feat_size: int = 14, + in_channels: int = 256, + conv_out_channels: int = 256, + fc_out_channels: int = 1024, + num_classes: int = 80, + loss_iou: ConfigType = dict(type='MSELoss', loss_weight=0.5), + init_cfg: OptMultiConfig = [ + dict(type='Kaiming', override=dict(name='convs')), + dict(type='Caffe2Xavier', override=dict(name='fcs')), + dict(type='Normal', std=0.01, override=dict(name='fc_mask_iou')) + ] + ) -> None: + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.conv_out_channels = conv_out_channels + self.fc_out_channels = fc_out_channels + self.num_classes = num_classes + + self.convs = nn.ModuleList() + for i in range(num_convs): + if i == 0: + # concatenation of mask feature and mask prediction + in_channels = self.in_channels + 1 + else: + in_channels = self.conv_out_channels + stride = 2 if i == num_convs - 1 else 1 + self.convs.append( + Conv2d( + in_channels, + self.conv_out_channels, + 3, + stride=stride, + padding=1)) + + roi_feat_size = _pair(roi_feat_size) + pooled_area = (roi_feat_size[0] // 2) * (roi_feat_size[1] // 2) + self.fcs = nn.ModuleList() + for i in range(num_fcs): + in_channels = ( + self.conv_out_channels * + pooled_area if i == 0 else self.fc_out_channels) + self.fcs.append(Linear(in_channels, self.fc_out_channels)) + + self.fc_mask_iou = Linear(self.fc_out_channels, self.num_classes) + self.relu = nn.ReLU() + self.max_pool = MaxPool2d(2, 2) + self.loss_iou = MODELS.build(loss_iou) + + def forward(self, mask_feat: Tensor, mask_preds: Tensor) -> Tensor: + """Forward function. + + Args: + mask_feat (Tensor): Mask features from upstream models. + mask_preds (Tensor): Mask predictions from mask head. + + Returns: + Tensor: Mask IoU predictions. + """ + mask_preds = mask_preds.sigmoid() + mask_pred_pooled = self.max_pool(mask_preds.unsqueeze(1)) + + x = torch.cat((mask_feat, mask_pred_pooled), 1) + + for conv in self.convs: + x = self.relu(conv(x)) + x = x.flatten(1) + for fc in self.fcs: + x = self.relu(fc(x)) + mask_iou = self.fc_mask_iou(x) + return mask_iou + + def loss_and_target(self, mask_iou_pred: Tensor, mask_preds: Tensor, + mask_targets: Tensor, + sampling_results: List[SamplingResult], + batch_gt_instances: InstanceList, + rcnn_train_cfg: ConfigDict) -> dict: + """Calculate the loss and targets of MaskIoUHead. + + Args: + mask_iou_pred (Tensor): Mask IoU prediction results, has shape + (num_pos, num_classes) + mask_preds (Tensor): Mask predictions from mask head, has shape + (num_pos, mask_size, mask_size). + mask_targets (Tensor): The ground truth masks assigned with + predictions, has shape + (num_pos, mask_size, mask_size). + sampling_results (List[:obj:`SamplingResult`]): Assign results of + all images in a batch after sampling. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It includes ``masks`` inside. + rcnn_train_cfg (:obj:`ConfigDict`): `train_cfg` of RCNN. + + Returns: + dict: A dictionary of loss and targets components. + The targets are only used for cascade rcnn.
+ """ + mask_iou_targets = self.get_targets( + sampling_results=sampling_results, + batch_gt_instances=batch_gt_instances, + mask_preds=mask_preds, + mask_targets=mask_targets, + rcnn_train_cfg=rcnn_train_cfg) + + pos_inds = mask_iou_targets > 0 + if pos_inds.sum() > 0: + loss_mask_iou = self.loss_iou(mask_iou_pred[pos_inds], + mask_iou_targets[pos_inds]) + else: + loss_mask_iou = mask_iou_pred.sum() * 0 + return dict(loss_mask_iou=loss_mask_iou) + + def get_targets(self, sampling_results: List[SamplingResult], + batch_gt_instances: InstanceList, mask_preds: Tensor, + mask_targets: Tensor, + rcnn_train_cfg: ConfigDict) -> Tensor: + """Compute target of mask IoU. + + Mask IoU target is the IoU of the predicted mask (inside a bbox) and + the gt mask of corresponding gt mask (the whole instance). + The intersection area is computed inside the bbox, and the gt mask area + is computed with two steps, firstly we compute the gt area inside the + bbox, then divide it by the area ratio of gt area inside the bbox and + the gt area of the whole instance. + + Args: + sampling_results (list[:obj:`SamplingResult`]): sampling results. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It includes ``masks`` inside. + mask_preds (Tensor): Predicted masks of each positive proposal, + shape (num_pos, h, w). + mask_targets (Tensor): Gt mask of each positive proposal, + binary map of the shape (num_pos, h, w). + rcnn_train_cfg (obj:`ConfigDict`): Training config for R-CNN part. + + Returns: + Tensor: mask iou target (length == num positive). + """ + pos_proposals = [res.pos_priors for res in sampling_results] + pos_assigned_gt_inds = [ + res.pos_assigned_gt_inds for res in sampling_results + ] + gt_masks = [res.masks for res in batch_gt_instances] + + # compute the area ratio of gt areas inside the proposals and + # the whole instance + area_ratios = map(self._get_area_ratio, pos_proposals, + pos_assigned_gt_inds, gt_masks) + area_ratios = torch.cat(list(area_ratios)) + assert mask_targets.size(0) == area_ratios.size(0) + + mask_preds = (mask_preds > rcnn_train_cfg.mask_thr_binary).float() + mask_pred_areas = mask_preds.sum((-1, -2)) + + # mask_preds and mask_targets are binary maps + overlap_areas = (mask_preds * mask_targets).sum((-1, -2)) + + # compute the mask area of the whole instance + gt_full_areas = mask_targets.sum((-1, -2)) / (area_ratios + 1e-7) + + mask_iou_targets = overlap_areas / ( + mask_pred_areas + gt_full_areas - overlap_areas) + return mask_iou_targets + + def _get_area_ratio(self, pos_proposals: Tensor, + pos_assigned_gt_inds: Tensor, + gt_masks: InstanceData) -> Tensor: + """Compute area ratio of the gt mask inside the proposal and the gt + mask of the corresponding instance. + + Args: + pos_proposals (Tensor): Positive proposals, has shape (num_pos, 4). + pos_assigned_gt_inds (Tensor): positive proposals assigned ground + truth index. + gt_masks (BitmapMask or PolygonMask): Gt masks (the whole instance) + of each image, with the same shape of the input image. + + Returns: + Tensor: The area ratio of the gt mask inside the proposal and the + gt mask of the corresponding instance. 
+ """ + num_pos = pos_proposals.size(0) + if num_pos > 0: + area_ratios = [] + proposals_np = pos_proposals.cpu().numpy() + pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy() + # compute mask areas of gt instances (batch processing for speedup) + gt_instance_mask_area = gt_masks.areas + for i in range(num_pos): + gt_mask = gt_masks[pos_assigned_gt_inds[i]] + + # crop the gt mask inside the proposal + bbox = proposals_np[i, :].astype(np.int32) + gt_mask_in_proposal = gt_mask.crop(bbox) + + ratio = gt_mask_in_proposal.areas[0] / ( + gt_instance_mask_area[pos_assigned_gt_inds[i]] + 1e-7) + area_ratios.append(ratio) + area_ratios = torch.from_numpy(np.stack(area_ratios)).float().to( + pos_proposals.device) + else: + area_ratios = pos_proposals.new_zeros((0, )) + return area_ratios + + def predict_by_feat(self, mask_iou_preds: Tuple[Tensor], + results_list: InstanceList) -> InstanceList: + """Predict the mask iou and calculate it into ``results.scores``. + + Args: + mask_iou_preds (Tensor): Mask IoU predictions results, has shape + (num_proposals, num_classes) + results_list (list[:obj:`InstanceData`]): Detection results of + each image. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + assert len(mask_iou_preds) == len(results_list) + for results, mask_iou_pred in zip(results_list, mask_iou_preds): + labels = results.labels + scores = results.scores + results.scores = scores * mask_iou_pred[range(labels.size(0)), + labels] + return results_list diff --git a/mmdet/models/roi_heads/mask_heads/scnet_mask_head.py b/mmdet/models/roi_heads/mask_heads/scnet_mask_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ffd30c337c37f4e280980e459c126df177fe7efa --- /dev/null +++ b/mmdet/models/roi_heads/mask_heads/scnet_mask_head.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.models.layers import ResLayer, SimplifiedBasicBlock +from mmdet.registry import MODELS +from .fcn_mask_head import FCNMaskHead + + +@MODELS.register_module() +class SCNetMaskHead(FCNMaskHead): + """Mask head for `SCNet `_. + + Args: + conv_to_res (bool, optional): if True, change the conv layers to + ``SimplifiedBasicBlock``. + """ + + def __init__(self, conv_to_res: bool = True, **kwargs) -> None: + super().__init__(**kwargs) + self.conv_to_res = conv_to_res + if conv_to_res: + assert self.conv_kernel_size == 3 + self.num_res_blocks = self.num_convs // 2 + self.convs = ResLayer( + SimplifiedBasicBlock, + self.in_channels, + self.conv_out_channels, + self.num_res_blocks, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) diff --git a/mmdet/models/roi_heads/mask_heads/scnet_semantic_head.py b/mmdet/models/roi_heads/mask_heads/scnet_semantic_head.py new file mode 100644 index 0000000000000000000000000000000000000000..55c5c8e4fae7d4e941a770d985c7253fd70f2226 --- /dev/null +++ b/mmdet/models/roi_heads/mask_heads/scnet_semantic_head.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmdet.models.layers import ResLayer, SimplifiedBasicBlock +from mmdet.registry import MODELS +from .fused_semantic_head import FusedSemanticHead + + +@MODELS.register_module() +class SCNetSemanticHead(FusedSemanticHead): + """Mask head for `SCNet `_. + + Args: + conv_to_res (bool, optional): If True, change the conv layers to + ``SimplifiedBasicBlock``. + """ + + def __init__(self, conv_to_res: bool = True, **kwargs) -> None: + super().__init__(**kwargs) + self.conv_to_res = conv_to_res + if self.conv_to_res: + num_res_blocks = self.num_convs // 2 + self.convs = ResLayer( + SimplifiedBasicBlock, + self.in_channels, + self.conv_out_channels, + num_res_blocks, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) + self.num_convs = num_res_blocks diff --git a/mmdet/models/roi_heads/mask_scoring_roi_head.py b/mmdet/models/roi_heads/mask_scoring_roi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..6545c0ed41ee7ad17b5f1b841f8bc8d65a7b6391 --- /dev/null +++ b/mmdet/models/roi_heads/mask_scoring_roi_head.py @@ -0,0 +1,208 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.structures.bbox import bbox2roi +from mmdet.utils import ConfigType, InstanceList +from ..task_modules.samplers import SamplingResult +from ..utils.misc import empty_instances +from .standard_roi_head import StandardRoIHead + + +@MODELS.register_module() +class MaskScoringRoIHead(StandardRoIHead): + """Mask Scoring RoIHead for `Mask Scoring RCNN `_. + + Args: + mask_iou_head (:obj:`ConfigDict` or dict): The config of mask_iou_head. + """ + + def __init__(self, mask_iou_head: ConfigType, **kwargs): + assert mask_iou_head is not None + super().__init__(**kwargs) + self.mask_iou_head = MODELS.build(mask_iou_head) + + def forward(self, + x: Tuple[Tensor], + rpn_results_list: InstanceList, + batch_data_samples: SampleList = None) -> tuple: + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. + + Args: + x (List[Tensor]): Multi-level features that may have different + resolutions. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + batch_data_samples (list[:obj:`DetDataSample`]): Each item contains + the meta information of each image and corresponding + annotations. + + Returns: + tuple: A tuple of features from ``bbox_head`` and ``mask_head`` + forward.
+ """ + results = () + proposals = [rpn_results.bboxes for rpn_results in rpn_results_list] + rois = bbox2roi(proposals) + # bbox head + if self.with_bbox: + bbox_results = self._bbox_forward(x, rois) + results = results + (bbox_results['cls_score'], + bbox_results['bbox_pred']) + # mask head + if self.with_mask: + mask_rois = rois[:100] + mask_results = self._mask_forward(x, mask_rois) + results = results + (mask_results['mask_preds'], ) + + # mask iou head + cls_score = bbox_results['cls_score'][:100] + mask_preds = mask_results['mask_preds'] + mask_feats = mask_results['mask_feats'] + _, labels = cls_score[:, :self.bbox_head.num_classes].max(dim=1) + mask_iou_preds = self.mask_iou_head( + mask_feats, mask_preds[range(labels.size(0)), labels]) + results = results + (mask_iou_preds, ) + + return results + + def mask_loss(self, x: Tuple[Tensor], + sampling_results: List[SamplingResult], bbox_feats, + batch_gt_instances: InstanceList) -> dict: + """Perform forward propagation and loss calculation of the mask head on + the features of the upstream network. + + Args: + x (tuple[Tensor]): Tuple of multi-level img features. + sampling_results (list["obj:`SamplingResult`]): Sampling results. + bbox_feats (Tensor): Extract bbox RoI features. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + + Returns: + dict: Usually returns a dictionary with keys: + + - `mask_preds` (Tensor): Mask prediction. + - `mask_feats` (Tensor): Extract mask RoI features. + - `mask_targets` (Tensor): Mask target of each positive\ + proposals in the image. + - `loss_mask` (dict): A dictionary of mask loss components. + - `loss_mask_iou` (Tensor): mask iou loss. + """ + if not self.share_roi_extractor: + pos_rois = bbox2roi([res.pos_priors for res in sampling_results]) + mask_results = self._mask_forward(x, pos_rois) + else: + pos_inds = [] + device = bbox_feats.device + for res in sampling_results: + pos_inds.append( + torch.ones( + res.pos_priors.shape[0], + device=device, + dtype=torch.uint8)) + pos_inds.append( + torch.zeros( + res.neg_priors.shape[0], + device=device, + dtype=torch.uint8)) + pos_inds = torch.cat(pos_inds) + + mask_results = self._mask_forward( + x, pos_inds=pos_inds, bbox_feats=bbox_feats) + + mask_loss_and_target = self.mask_head.loss_and_target( + mask_preds=mask_results['mask_preds'], + sampling_results=sampling_results, + batch_gt_instances=batch_gt_instances, + rcnn_train_cfg=self.train_cfg) + mask_targets = mask_loss_and_target['mask_targets'] + mask_results.update(loss_mask=mask_loss_and_target['loss_mask']) + if mask_results['loss_mask'] is None: + return mask_results + + # mask iou head forward and loss + pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results]) + pos_mask_pred = mask_results['mask_preds'][ + range(mask_results['mask_preds'].size(0)), pos_labels] + mask_iou_pred = self.mask_iou_head(mask_results['mask_feats'], + pos_mask_pred) + pos_mask_iou_pred = mask_iou_pred[range(mask_iou_pred.size(0)), + pos_labels] + + loss_mask_iou = self.mask_iou_head.loss_and_target( + pos_mask_iou_pred, pos_mask_pred, mask_targets, sampling_results, + batch_gt_instances, self.train_cfg) + mask_results['loss_mask'].update(loss_mask_iou) + return mask_results + + def predict_mask(self, + x: Tensor, + batch_img_metas: List[dict], + results_list: InstanceList, + rescale: bool = False) -> InstanceList: + """Perform forward propagation of the mask head and predict detection + results on the 
features of the upstream network. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + results_list (list[:obj:`InstanceData`]): Detection results of + each image. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + bboxes = [res.bboxes for res in results_list] + mask_rois = bbox2roi(bboxes) + if mask_rois.shape[0] == 0: + results_list = empty_instances( + batch_img_metas, + mask_rois.device, + task_type='mask', + instance_results=results_list, + mask_thr_binary=self.test_cfg.mask_thr_binary) + return results_list + + mask_results = self._mask_forward(x, mask_rois) + mask_preds = mask_results['mask_preds'] + mask_feats = mask_results['mask_feats'] + # get mask scores with mask iou head + labels = torch.cat([res.labels for res in results_list]) + mask_iou_preds = self.mask_iou_head( + mask_feats, mask_preds[range(labels.size(0)), labels]) + # split batch mask prediction back to each image + num_mask_rois_per_img = [len(res) for res in results_list] + mask_preds = mask_preds.split(num_mask_rois_per_img, 0) + mask_iou_preds = mask_iou_preds.split(num_mask_rois_per_img, 0) + + # TODO: Handle the case where rescale is false + results_list = self.mask_head.predict_by_feat( + mask_preds=mask_preds, + results_list=results_list, + batch_img_metas=batch_img_metas, + rcnn_test_cfg=self.test_cfg, + rescale=rescale) + results_list = self.mask_iou_head.predict_by_feat( + mask_iou_preds=mask_iou_preds, results_list=results_list) + return results_list diff --git a/mmdet/models/roi_heads/multi_instance_roi_head.py b/mmdet/models/roi_heads/multi_instance_roi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..fee55b0a5d341c03165649f59737fd34d85c207e --- /dev/null +++ b/mmdet/models/roi_heads/multi_instance_roi_head.py @@ -0,0 +1,226 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import DetDataSample +from mmdet.structures.bbox import bbox2roi +from mmdet.utils import ConfigType, InstanceList +from ..task_modules.samplers import SamplingResult +from ..utils import empty_instances, unpack_gt_instances +from .standard_roi_head import StandardRoIHead + + +@MODELS.register_module() +class MultiInstanceRoIHead(StandardRoIHead): + """The roi head for Multi-instance prediction.""" + + def __init__(self, num_instance: int = 2, *args, **kwargs) -> None: + self.num_instance = num_instance + super().__init__(*args, **kwargs) + + def init_bbox_head(self, bbox_roi_extractor: ConfigType, + bbox_head: ConfigType) -> None: + """Initialize box head and box roi extractor. + + Args: + bbox_roi_extractor (dict or ConfigDict): Config of box + roi extractor. + bbox_head (dict or ConfigDict): Config of box in box head. 
+ """ + self.bbox_roi_extractor = MODELS.build(bbox_roi_extractor) + self.bbox_head = MODELS.build(bbox_head) + + def _bbox_forward(self, x: Tuple[Tensor], rois: Tensor) -> dict: + """Box head forward function used in both training and testing. + + Args: + x (tuple[Tensor]): List of multi-level img features. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + + Returns: + dict[str, Tensor]: Usually returns a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. + - `bbox_pred` (Tensor): Box energies / deltas. + - `cls_score_ref` (Tensor): The cls_score after refine model. + - `bbox_pred_ref` (Tensor): The bbox_pred after refine model. + - `bbox_feats` (Tensor): Extract bbox RoI features. + """ + # TODO: a more flexible way to decide which feature maps to use + bbox_feats = self.bbox_roi_extractor( + x[:self.bbox_roi_extractor.num_inputs], rois) + bbox_results = self.bbox_head(bbox_feats) + + if self.bbox_head.with_refine: + bbox_results = dict( + cls_score=bbox_results[0], + bbox_pred=bbox_results[1], + cls_score_ref=bbox_results[2], + bbox_pred_ref=bbox_results[3], + bbox_feats=bbox_feats) + else: + bbox_results = dict( + cls_score=bbox_results[0], + bbox_pred=bbox_results[1], + bbox_feats=bbox_feats) + + return bbox_results + + def bbox_loss(self, x: Tuple[Tensor], + sampling_results: List[SamplingResult]) -> dict: + """Perform forward propagation and loss calculation of the bbox head on + the features of the upstream network. + + Args: + x (tuple[Tensor]): List of multi-level img features. + sampling_results (list["obj:`SamplingResult`]): Sampling results. + + Returns: + dict[str, Tensor]: Usually returns a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. + - `bbox_pred` (Tensor): Box energies / deltas. + - `bbox_feats` (Tensor): Extract bbox RoI features. + - `loss_bbox` (dict): A dictionary of bbox loss components. + """ + rois = bbox2roi([res.priors for res in sampling_results]) + bbox_results = self._bbox_forward(x, rois) + + # If there is a refining process, add refine loss. + if 'cls_score_ref' in bbox_results: + bbox_loss_and_target = self.bbox_head.loss_and_target( + cls_score=bbox_results['cls_score'], + bbox_pred=bbox_results['bbox_pred'], + rois=rois, + sampling_results=sampling_results, + rcnn_train_cfg=self.train_cfg) + bbox_results.update(loss_bbox=bbox_loss_and_target['loss_bbox']) + bbox_loss_and_target_ref = self.bbox_head.loss_and_target( + cls_score=bbox_results['cls_score_ref'], + bbox_pred=bbox_results['bbox_pred_ref'], + rois=rois, + sampling_results=sampling_results, + rcnn_train_cfg=self.train_cfg) + bbox_results['loss_bbox']['loss_rcnn_emd_ref'] = \ + bbox_loss_and_target_ref['loss_bbox']['loss_rcnn_emd'] + else: + bbox_loss_and_target = self.bbox_head.loss_and_target( + cls_score=bbox_results['cls_score'], + bbox_pred=bbox_results['bbox_pred'], + rois=rois, + sampling_results=sampling_results, + rcnn_train_cfg=self.train_cfg) + bbox_results.update(loss_bbox=bbox_loss_and_target['loss_bbox']) + + return bbox_results + + def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList, + batch_data_samples: List[DetDataSample]) -> dict: + """Perform forward propagation and loss calculation of the detection + roi on the features of the upstream network. + + Args: + x (tuple[Tensor]): List of multi-level img features. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. 
It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict[str, Tensor]: A dictionary of loss components + """ + assert len(rpn_results_list) == len(batch_data_samples) + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, batch_gt_instances_ignore, _ = outputs + + sampling_results = [] + for i in range(len(batch_data_samples)): + # rename rpn_results.bboxes to rpn_results.priors + rpn_results = rpn_results_list[i] + rpn_results.priors = rpn_results.pop('bboxes') + + assign_result = self.bbox_assigner.assign( + rpn_results, batch_gt_instances[i], + batch_gt_instances_ignore[i]) + sampling_result = self.bbox_sampler.sample( + assign_result, + rpn_results, + batch_gt_instances[i], + batch_gt_instances_ignore=batch_gt_instances_ignore[i]) + sampling_results.append(sampling_result) + + losses = dict() + # bbox head loss + if self.with_bbox: + bbox_results = self.bbox_loss(x, sampling_results) + losses.update(bbox_results['loss_bbox']) + + return losses + + def predict_bbox(self, + x: Tuple[Tensor], + batch_img_metas: List[dict], + rpn_results_list: InstanceList, + rcnn_test_cfg: ConfigType, + rescale: bool = False) -> InstanceList: + """Perform forward propagation of the bbox head and predict detection + results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + proposals = [res.bboxes for res in rpn_results_list] + rois = bbox2roi(proposals) + + if rois.shape[0] == 0: + return empty_instances( + batch_img_metas, rois.device, task_type='bbox') + + bbox_results = self._bbox_forward(x, rois) + + # split batch bbox prediction back to each image + if 'cls_score_ref' in bbox_results: + cls_scores = bbox_results['cls_score_ref'] + bbox_preds = bbox_results['bbox_pred_ref'] + else: + cls_scores = bbox_results['cls_score'] + bbox_preds = bbox_results['bbox_pred'] + num_proposals_per_img = tuple(len(p) for p in proposals) + rois = rois.split(num_proposals_per_img, 0) + cls_scores = cls_scores.split(num_proposals_per_img, 0) + + if bbox_preds is not None: + bbox_preds = bbox_preds.split(num_proposals_per_img, 0) + else: + bbox_preds = (None, ) * len(proposals) + + result_list = self.bbox_head.predict_by_feat( + rois=rois, + cls_scores=cls_scores, + bbox_preds=bbox_preds, + batch_img_metas=batch_img_metas, + rcnn_test_cfg=rcnn_test_cfg, + rescale=rescale) + return result_list diff --git a/mmdet/models/roi_heads/pisa_roi_head.py b/mmdet/models/roi_heads/pisa_roi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..45d59879da73b48df790c55d40a4a88f1d099111 --- /dev/null +++ b/mmdet/models/roi_heads/pisa_roi_head.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
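The per-image split used in predict_bbox above is plain tensor.split over the roi axis; a toy sketch (sizes assumed):

import torch

# 3 proposals for image 0, 2 for image 1; 4 values per proposal
num_proposals_per_img = (3, 2)
cls_scores = torch.randn(5, 4)
per_img_scores = cls_scores.split(num_proposals_per_img, 0)
print([s.shape for s in per_img_scores])  # torch.Size([3, 4]), torch.Size([2, 4])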
+from typing import List, Tuple + +from torch import Tensor + +from mmdet.models.task_modules import SamplingResult +from mmdet.registry import MODELS +from mmdet.structures import DetDataSample +from mmdet.structures.bbox import bbox2roi +from mmdet.utils import InstanceList +from ..losses.pisa_loss import carl_loss, isr_p +from ..utils import unpack_gt_instances +from .standard_roi_head import StandardRoIHead + + +@MODELS.register_module() +class PISARoIHead(StandardRoIHead): + r"""The RoI head for `Prime Sample Attention in Object Detection + `_.""" + + def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList, + batch_data_samples: List[DetDataSample]) -> dict: + """Perform forward propagation and loss calculation of the detection + roi on the features of the upstream network. + + Args: + x (tuple[Tensor]): List of multi-level img features. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict[str, Tensor]: A dictionary of loss components + """ + assert len(rpn_results_list) == len(batch_data_samples) + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, batch_gt_instances_ignore, _ = outputs + + # assign gts and sample proposals + num_imgs = len(batch_data_samples) + sampling_results = [] + neg_label_weights = [] + for i in range(num_imgs): + # rename rpn_results.bboxes to rpn_results.priors + rpn_results = rpn_results_list[i] + rpn_results.priors = rpn_results.pop('bboxes') + + assign_result = self.bbox_assigner.assign( + rpn_results, batch_gt_instances[i], + batch_gt_instances_ignore[i]) + sampling_result = self.bbox_sampler.sample( + assign_result, + rpn_results, + batch_gt_instances[i], + feats=[lvl_feat[i][None] for lvl_feat in x]) + if isinstance(sampling_result, tuple): + sampling_result, neg_label_weight = sampling_result + sampling_results.append(sampling_result) + neg_label_weights.append(neg_label_weight) + + losses = dict() + # bbox head forward and loss + if self.with_bbox: + bbox_results = self.bbox_loss( + x, sampling_results, neg_label_weights=neg_label_weights) + losses.update(bbox_results['loss_bbox']) + + # mask head forward and loss + if self.with_mask: + mask_results = self.mask_loss(x, sampling_results, + bbox_results['bbox_feats'], + batch_gt_instances) + losses.update(mask_results['loss_mask']) + + return losses + + def bbox_loss(self, + x: Tuple[Tensor], + sampling_results: List[SamplingResult], + neg_label_weights: List[Tensor] = None) -> dict: + """Perform forward propagation and loss calculation of the bbox head on + the features of the upstream network. + + Args: + x (tuple[Tensor]): List of multi-level img features. + sampling_results (list[:obj:`SamplingResult`]): Sampling results. + neg_label_weights (list[Tensor], optional): Image-wise negative + label weights produced by the PISA sampler. Defaults to None. + + Returns: + dict[str, Tensor]: Usually returns a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. + - `bbox_pred` (Tensor): Box energies / deltas. + - `bbox_feats` (Tensor): Extract bbox RoI features. + - `loss_bbox` (dict): A dictionary of bbox loss components.
+ """ + rois = bbox2roi([res.priors for res in sampling_results]) + bbox_results = self._bbox_forward(x, rois) + bbox_targets = self.bbox_head.get_targets(sampling_results, + self.train_cfg) + + # neg_label_weights obtained by sampler is image-wise, mapping back to + # the corresponding location in label weights + if neg_label_weights[0] is not None: + label_weights = bbox_targets[1] + cur_num_rois = 0 + for i in range(len(sampling_results)): + num_pos = sampling_results[i].pos_inds.size(0) + num_neg = sampling_results[i].neg_inds.size(0) + label_weights[cur_num_rois + num_pos:cur_num_rois + num_pos + + num_neg] = neg_label_weights[i] + cur_num_rois += num_pos + num_neg + + cls_score = bbox_results['cls_score'] + bbox_pred = bbox_results['bbox_pred'] + + # Apply ISR-P + isr_cfg = self.train_cfg.get('isr', None) + if isr_cfg is not None: + bbox_targets = isr_p( + cls_score, + bbox_pred, + bbox_targets, + rois, + sampling_results, + self.bbox_head.loss_cls, + self.bbox_head.bbox_coder, + **isr_cfg, + num_class=self.bbox_head.num_classes) + loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, rois, + *bbox_targets) + + # Add CARL Loss + carl_cfg = self.train_cfg.get('carl', None) + if carl_cfg is not None: + loss_carl = carl_loss( + cls_score, + bbox_targets[0], + bbox_pred, + bbox_targets[2], + self.bbox_head.loss_bbox, + **carl_cfg, + num_class=self.bbox_head.num_classes) + loss_bbox.update(loss_carl) + + bbox_results.update(loss_bbox=loss_bbox) + return bbox_results diff --git a/mmdet/models/roi_heads/point_rend_roi_head.py b/mmdet/models/roi_heads/point_rend_roi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..6a0641549631e243c3db25039b01fed64fb1e0d1 --- /dev/null +++ b/mmdet/models/roi_heads/point_rend_roi_head.py @@ -0,0 +1,236 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend  # noqa
+from typing import List, Tuple
+
+import torch
+import torch.nn.functional as F
+from mmcv.ops import point_sample, rel_roi_point_to_rel_img_point
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.structures.bbox import bbox2roi
+from mmdet.utils import ConfigType, InstanceList
+from ..task_modules.samplers import SamplingResult
+from ..utils import empty_instances
+from .standard_roi_head import StandardRoIHead
+
+
+@MODELS.register_module()
+class PointRendRoIHead(StandardRoIHead):
+    """`PointRend <https://arxiv.org/abs/1912.08193>`_."""
+
+    def __init__(self, point_head: ConfigType, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        assert self.with_bbox and self.with_mask
+        self.init_point_head(point_head)
+
+    def init_point_head(self, point_head: ConfigType) -> None:
+        """Initialize ``point_head``"""
+        self.point_head = MODELS.build(point_head)
+
+    def mask_loss(self, x: Tuple[Tensor],
+                  sampling_results: List[SamplingResult], bbox_feats: Tensor,
+                  batch_gt_instances: InstanceList) -> dict:
+        """Run forward function and calculate loss for mask head and point
+        head in training."""
+        mask_results = super().mask_loss(
+            x=x,
+            sampling_results=sampling_results,
+            bbox_feats=bbox_feats,
+            batch_gt_instances=batch_gt_instances)
+
+        mask_point_results = self._mask_point_loss(
+            x=x,
+            sampling_results=sampling_results,
+            mask_preds=mask_results['mask_preds'],
+            batch_gt_instances=batch_gt_instances)
+        mask_results['loss_mask'].update(
+            loss_point=mask_point_results['loss_point'])
+
+        return mask_results
+
+    def _mask_point_loss(self, x: Tuple[Tensor],
+                         sampling_results: List[SamplingResult],
+                         mask_preds: Tensor,
+                         batch_gt_instances: InstanceList) -> dict:
+        """Run forward function and calculate loss for point head in
+        training."""
+        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+        rel_roi_points = self.point_head.get_roi_rel_points_train(
+            mask_preds, pos_labels, cfg=self.train_cfg)
+        rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+
+        fine_grained_point_feats = self._get_fine_grained_point_feats(
+            x, rois, rel_roi_points)
+        coarse_point_feats = point_sample(mask_preds, rel_roi_points)
+        mask_point_pred = self.point_head(fine_grained_point_feats,
+                                          coarse_point_feats)
+
+        loss_and_target = self.point_head.loss_and_target(
+            point_pred=mask_point_pred,
+            rel_roi_points=rel_roi_points,
+            sampling_results=sampling_results,
+            batch_gt_instances=batch_gt_instances,
+            cfg=self.train_cfg)
+
+        return loss_and_target
+
+    def _mask_point_forward_test(self, x: Tuple[Tensor], rois: Tensor,
+                                 label_preds: Tensor,
+                                 mask_preds: Tensor) -> Tensor:
+        """Mask refining process with point head in testing.
+
+        Args:
+            x (tuple[Tensor]): Feature maps of all scale level.
+            rois (Tensor): shape (num_rois, 5).
+            label_preds (Tensor): The predicted class for each RoI.
+            mask_preds (Tensor): The predicted coarse masks of
+                shape (num_rois, num_classes, small_size, small_size).
+
+        Returns:
+            Tensor: The refined masks of shape (num_rois, num_classes,
+            large_size, large_size).
+ """ + refined_mask_pred = mask_preds.clone() + for subdivision_step in range(self.test_cfg.subdivision_steps): + refined_mask_pred = F.interpolate( + refined_mask_pred, + scale_factor=self.test_cfg.scale_factor, + mode='bilinear', + align_corners=False) + # If `subdivision_num_points` is larger or equal to the + # resolution of the next step, then we can skip this step + num_rois, channels, mask_height, mask_width = \ + refined_mask_pred.shape + if (self.test_cfg.subdivision_num_points >= + self.test_cfg.scale_factor**2 * mask_height * mask_width + and + subdivision_step < self.test_cfg.subdivision_steps - 1): + continue + point_indices, rel_roi_points = \ + self.point_head.get_roi_rel_points_test( + refined_mask_pred, label_preds, cfg=self.test_cfg) + + fine_grained_point_feats = self._get_fine_grained_point_feats( + x=x, rois=rois, rel_roi_points=rel_roi_points) + coarse_point_feats = point_sample(mask_preds, rel_roi_points) + mask_point_pred = self.point_head(fine_grained_point_feats, + coarse_point_feats) + + point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1) + refined_mask_pred = refined_mask_pred.reshape( + num_rois, channels, mask_height * mask_width) + refined_mask_pred = refined_mask_pred.scatter_( + 2, point_indices, mask_point_pred) + refined_mask_pred = refined_mask_pred.view(num_rois, channels, + mask_height, mask_width) + + return refined_mask_pred + + def _get_fine_grained_point_feats(self, x: Tuple[Tensor], rois: Tensor, + rel_roi_points: Tensor) -> Tensor: + """Sample fine grained feats from each level feature map and + concatenate them together. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + rois (Tensor): shape (num_rois, 5). + rel_roi_points (Tensor): A tensor of shape (num_rois, num_points, + 2) that contains [0, 1] x [0, 1] normalized coordinates of the + most uncertain points from the [mask_height, mask_width] grid. + + Returns: + Tensor: The fine grained features for each points, + has shape (num_rois, feats_channels, num_points). + """ + assert rois.shape[0] > 0, 'RoI is a empty tensor.' + num_imgs = x[0].shape[0] + fine_grained_feats = [] + for idx in range(self.mask_roi_extractor.num_inputs): + feats = x[idx] + spatial_scale = 1. / float( + self.mask_roi_extractor.featmap_strides[idx]) + point_feats = [] + for batch_ind in range(num_imgs): + # unravel batch dim + feat = feats[batch_ind].unsqueeze(0) + inds = (rois[:, 0].long() == batch_ind) + if inds.any(): + rel_img_points = rel_roi_point_to_rel_img_point( + rois=rois[inds], + rel_roi_points=rel_roi_points[inds], + img=feat.shape[2:], + spatial_scale=spatial_scale).unsqueeze(0) + point_feat = point_sample(feat, rel_img_points) + point_feat = point_feat.squeeze(0).transpose(0, 1) + point_feats.append(point_feat) + fine_grained_feats.append(torch.cat(point_feats, dim=0)) + return torch.cat(fine_grained_feats, dim=1) + + def predict_mask(self, + x: Tuple[Tensor], + batch_img_metas: List[dict], + results_list: InstanceList, + rescale: bool = False) -> InstanceList: + """Perform forward propagation of the mask head and predict detection + results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + results_list (list[:obj:`InstanceData`]): Detection results of + each image. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. 
+ Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + # don't need to consider aug_test. + bboxes = [res.bboxes for res in results_list] + mask_rois = bbox2roi(bboxes) + if mask_rois.shape[0] == 0: + results_list = empty_instances( + batch_img_metas, + mask_rois.device, + task_type='mask', + instance_results=results_list, + mask_thr_binary=self.test_cfg.mask_thr_binary) + return results_list + + mask_results = self._mask_forward(x, mask_rois) + mask_preds = mask_results['mask_preds'] + # split batch mask prediction back to each image + num_mask_rois_per_img = [len(res) for res in results_list] + mask_preds = mask_preds.split(num_mask_rois_per_img, 0) + + # refine mask_preds + mask_rois = mask_rois.split(num_mask_rois_per_img, 0) + mask_preds_refined = [] + for i in range(len(batch_img_metas)): + labels = results_list[i].labels + x_i = [xx[[i]] for xx in x] + mask_rois_i = mask_rois[i] + mask_rois_i[:, 0] = 0 + mask_pred_i = self._mask_point_forward_test( + x_i, mask_rois_i, labels, mask_preds[i]) + mask_preds_refined.append(mask_pred_i) + + # TODO: Handle the case where rescale is false + results_list = self.mask_head.predict_by_feat( + mask_preds=mask_preds_refined, + results_list=results_list, + batch_img_metas=batch_img_metas, + rcnn_test_cfg=self.test_cfg, + rescale=rescale) + return results_list diff --git a/mmdet/models/roi_heads/roi_extractors/__init__.py b/mmdet/models/roi_heads/roi_extractors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f60214991b0ed14cdbc3964aee15356c6aaf2aa --- /dev/null +++ b/mmdet/models/roi_heads/roi_extractors/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base_roi_extractor import BaseRoIExtractor +from .generic_roi_extractor import GenericRoIExtractor +from .single_level_roi_extractor import SingleRoIExtractor + +__all__ = ['BaseRoIExtractor', 'SingleRoIExtractor', 'GenericRoIExtractor'] diff --git a/mmdet/models/roi_heads/roi_extractors/base_roi_extractor.py b/mmdet/models/roi_heads/roi_extractors/base_roi_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..a8de0518818aba8d9aac7b807e3215d0da6c9b99 --- /dev/null +++ b/mmdet/models/roi_heads/roi_extractors/base_roi_extractor.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from mmcv import ops +from mmengine.model import BaseModule +from torch import Tensor + +from mmdet.utils import ConfigType, OptMultiConfig + + +class BaseRoIExtractor(BaseModule, metaclass=ABCMeta): + """Base class for RoI extractor. + + Args: + roi_layer (:obj:`ConfigDict` or dict): Specify RoI layer type and + arguments. + out_channels (int): Output channels of RoI layers. + featmap_strides (list[int]): Strides of input feature maps. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], optional): Initialization config dict. Defaults to None. 
+ """ + + def __init__(self, + roi_layer: ConfigType, + out_channels: int, + featmap_strides: List[int], + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides) + self.out_channels = out_channels + self.featmap_strides = featmap_strides + + @property + def num_inputs(self) -> int: + """int: Number of input feature maps.""" + return len(self.featmap_strides) + + def build_roi_layers(self, layer_cfg: ConfigType, + featmap_strides: List[int]) -> nn.ModuleList: + """Build RoI operator to extract feature from each level feature map. + + Args: + layer_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and + config RoI layer operation. Options are modules under + ``mmcv/ops`` such as ``RoIAlign``. + featmap_strides (list[int]): The stride of input feature map w.r.t + to the original image size, which would be used to scale RoI + coordinate (original image coordinate system) to feature + coordinate system. + + Returns: + :obj:`nn.ModuleList`: The RoI extractor modules for each level + feature map. + """ + + cfg = layer_cfg.copy() + layer_type = cfg.pop('type') + if isinstance(layer_type, str): + assert hasattr(ops, layer_type) + layer_cls = getattr(ops, layer_type) + else: + layer_cls = layer_type + roi_layers = nn.ModuleList( + [layer_cls(spatial_scale=1 / s, **cfg) for s in featmap_strides]) + return roi_layers + + def roi_rescale(self, rois: Tensor, scale_factor: float) -> Tensor: + """Scale RoI coordinates by scale factor. + + Args: + rois (Tensor): RoI (Region of Interest), shape (n, 5) + scale_factor (float): Scale factor that RoI will be multiplied by. + + Returns: + Tensor: Scaled RoI. + """ + + cx = (rois[:, 1] + rois[:, 3]) * 0.5 + cy = (rois[:, 2] + rois[:, 4]) * 0.5 + w = rois[:, 3] - rois[:, 1] + h = rois[:, 4] - rois[:, 2] + new_w = w * scale_factor + new_h = h * scale_factor + x1 = cx - new_w * 0.5 + x2 = cx + new_w * 0.5 + y1 = cy - new_h * 0.5 + y2 = cy + new_h * 0.5 + new_rois = torch.stack((rois[:, 0], x1, y1, x2, y2), dim=-1) + return new_rois + + @abstractmethod + def forward(self, + feats: Tuple[Tensor], + rois: Tensor, + roi_scale_factor: Optional[float] = None) -> Tensor: + """Extractor ROI feats. + + Args: + feats (Tuple[Tensor]): Multi-scale features. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + roi_scale_factor (Optional[float]): RoI scale factor. + Defaults to None. + + Returns: + Tensor: RoI feature. + """ + pass diff --git a/mmdet/models/roi_heads/roi_extractors/generic_roi_extractor.py b/mmdet/models/roi_heads/roi_extractors/generic_roi_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..39d4c90135d853404d564391f029558841ac9cac --- /dev/null +++ b/mmdet/models/roi_heads/roi_extractors/generic_roi_extractor.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +from mmcv.cnn.bricks import build_plugin_layer +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import OptConfigType +from .base_roi_extractor import BaseRoIExtractor + + +@MODELS.register_module() +class GenericRoIExtractor(BaseRoIExtractor): + """Extract RoI features from all level feature maps levels. + + This is the implementation of `A novel Region of Interest Extraction Layer + for Instance Segmentation `_. + + Args: + aggregation (str): The method to aggregate multiple feature maps. + Options are 'sum', 'concat'. Defaults to 'sum'. 
+        pre_cfg (:obj:`ConfigDict` or dict): Specify pre-processing modules.
+            Defaults to None.
+        post_cfg (:obj:`ConfigDict` or dict): Specify post-processing modules.
+            Defaults to None.
+        kwargs (keyword arguments): Arguments that are the same
+            as :class:`BaseRoIExtractor`.
+    """
+
+    def __init__(self,
+                 aggregation: str = 'sum',
+                 pre_cfg: OptConfigType = None,
+                 post_cfg: OptConfigType = None,
+                 **kwargs) -> None:
+        super().__init__(**kwargs)
+
+        assert aggregation in ['sum', 'concat']
+
+        self.aggregation = aggregation
+        self.with_post = post_cfg is not None
+        self.with_pre = pre_cfg is not None
+        # build pre/post processing modules
+        if self.with_post:
+            self.post_module = build_plugin_layer(post_cfg, '_post_module')[1]
+        if self.with_pre:
+            self.pre_module = build_plugin_layer(pre_cfg, '_pre_module')[1]
+
+    def forward(self,
+                feats: Tuple[Tensor],
+                rois: Tensor,
+                roi_scale_factor: Optional[float] = None) -> Tensor:
+        """Extract RoI features.
+
+        Args:
+            feats (Tuple[Tensor]): Multi-scale features.
+            rois (Tensor): RoIs with the shape (n, 5) where the first
+                column indicates batch id of each RoI.
+            roi_scale_factor (Optional[float]): RoI scale factor.
+                Defaults to None.
+
+        Returns:
+            Tensor: RoI feature.
+        """
+        out_size = self.roi_layers[0].output_size
+        num_levels = len(feats)
+        roi_feats = feats[0].new_zeros(
+            rois.size(0), self.out_channels, *out_size)
+
+        # sometimes rois is an empty tensor
+        if roi_feats.shape[0] == 0:
+            return roi_feats
+
+        if num_levels == 1:
+            return self.roi_layers[0](feats[0], rois)
+
+        if roi_scale_factor is not None:
+            rois = self.roi_rescale(rois, roi_scale_factor)
+
+        # mark the starting channels for concat mode
+        start_channels = 0
+        for i in range(num_levels):
+            roi_feats_t = self.roi_layers[i](feats[i], rois)
+            end_channels = start_channels + roi_feats_t.size(1)
+            if self.with_pre:
+                # apply pre-processing to a RoI extracted from each layer
+                roi_feats_t = self.pre_module(roi_feats_t)
+            if self.aggregation == 'sum':
+                # and sum them all
+                roi_feats += roi_feats_t
+            else:
+                # and concat them along channel dimension
+                roi_feats[:, start_channels:end_channels] = roi_feats_t
+            # update channels starting position
+            start_channels = end_channels
+        # check if concat channels match at the end
+        if self.aggregation == 'concat':
+            assert start_channels == self.out_channels
+
+        if self.with_post:
+            # apply post-processing before returning the result
+            roi_feats = self.post_module(roi_feats)
+        return roi_feats
diff --git a/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..59229e0b0b0a18dff81abca6f5c20cb50b0d542c
--- /dev/null
+++ b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
@@ -0,0 +1,119 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple
+
+import torch
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.utils import ConfigType, OptMultiConfig
+from .base_roi_extractor import BaseRoIExtractor
+
+
+@MODELS.register_module()
+class SingleRoIExtractor(BaseRoIExtractor):
+    """Extract RoI features from a single level feature map.
+
+    If there are multiple input feature levels, each RoI is mapped to a level
+    according to its scale. The mapping rule is proposed in
+    `FPN <https://arxiv.org/abs/1612.03144>`_.
+
+    Args:
+        roi_layer (:obj:`ConfigDict` or dict): Specify RoI layer type and
+            arguments.
+        out_channels (int): Output channels of RoI layers.
+ featmap_strides (List[int]): Strides of input feature maps. + finest_scale (int): Scale threshold of mapping to level 0. + Defaults to 56. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], optional): Initialization config dict. Defaults to None. + """ + + def __init__(self, + roi_layer: ConfigType, + out_channels: int, + featmap_strides: List[int], + finest_scale: int = 56, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + roi_layer=roi_layer, + out_channels=out_channels, + featmap_strides=featmap_strides, + init_cfg=init_cfg) + self.finest_scale = finest_scale + + def map_roi_levels(self, rois: Tensor, num_levels: int) -> Tensor: + """Map rois to corresponding feature levels by scales. + + - scale < finest_scale * 2: level 0 + - finest_scale * 2 <= scale < finest_scale * 4: level 1 + - finest_scale * 4 <= scale < finest_scale * 8: level 2 + - scale >= finest_scale * 8: level 3 + + Args: + rois (Tensor): Input RoIs, shape (k, 5). + num_levels (int): Total level number. + + Returns: + Tensor: Level index (0-based) of each RoI, shape (k, ) + """ + scale = torch.sqrt( + (rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2])) + target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6)) + target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long() + return target_lvls + + def forward(self, + feats: Tuple[Tensor], + rois: Tensor, + roi_scale_factor: Optional[float] = None): + """Extractor ROI feats. + + Args: + feats (Tuple[Tensor]): Multi-scale features. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + roi_scale_factor (Optional[float]): RoI scale factor. + Defaults to None. + + Returns: + Tensor: RoI feature. + """ + # convert fp32 to fp16 when amp is on + rois = rois.type_as(feats[0]) + out_size = self.roi_layers[0].output_size + num_levels = len(feats) + roi_feats = feats[0].new_zeros( + rois.size(0), self.out_channels, *out_size) + + # TODO: remove this when parrots supports + if torch.__version__ == 'parrots': + roi_feats.requires_grad = True + + if num_levels == 1: + if len(rois) == 0: + return roi_feats + return self.roi_layers[0](feats[0], rois) + + target_lvls = self.map_roi_levels(rois, num_levels) + + if roi_scale_factor is not None: + rois = self.roi_rescale(rois, roi_scale_factor) + + for i in range(num_levels): + mask = target_lvls == i + inds = mask.nonzero(as_tuple=False).squeeze(1) + if inds.numel() > 0: + rois_ = rois[inds] + roi_feats_t = self.roi_layers[i](feats[i], rois_) + roi_feats[inds] = roi_feats_t + else: + # Sometimes some pyramid levels will not be used for RoI + # feature extraction and this will cause an incomplete + # computation graph in one GPU, which is different from those + # in other GPUs and will cause a hanging error. + # Therefore, we add it to ensure each feature pyramid is + # included in the computation graph to avoid runtime bugs. + roi_feats += sum( + x.view(-1)[0] + for x in self.parameters()) * 0. + feats[i].sum() * 0. + return roi_feats diff --git a/mmdet/models/roi_heads/scnet_roi_head.py b/mmdet/models/roi_heads/scnet_roi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..e6d2bc1915bae38011cc75a720e48ed53b51ddb5 --- /dev/null +++ b/mmdet/models/roi_heads/scnet_roi_head.py @@ -0,0 +1,677 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
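To make the level-mapping rule in `SingleRoIExtractor.map_roi_levels` above concrete, here is a small usage sketch. It assumes mmdet and mmcv (with its compiled ops) are installed; the image size, channel count, and RoI coordinates are illustrative:

# Sketch: with finest_scale=56 (the default), RoIs whose scale
# sqrt(w * h) is 56, 112 and 224 map to FPN levels 0, 1 and 2.
import torch
from mmdet.models.roi_heads.roi_extractors import SingleRoIExtractor

extractor = SingleRoIExtractor(
    roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
    out_channels=256,
    featmap_strides=[4, 8, 16, 32])

# One 256x256 image; feature maps at strides 4/8/16/32.
feats = tuple(
    torch.rand(1, 256, 256 // s, 256 // s) for s in (4, 8, 16, 32))

# RoIs are (batch_idx, x1, y1, x2, y2).
rois = torch.tensor([[0., 0., 0., 56., 56.],
                     [0., 0., 0., 112., 112.],
                     [0., 0., 0., 224., 224.]])
print(extractor.map_roi_levels(rois, num_levels=4))  # tensor([0, 1, 2])

roi_feats = extractor(feats, rois)  # -> shape (3, 256, 7, 7)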
+from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.structures.bbox import bbox2roi +from mmdet.utils import ConfigType, InstanceList, OptConfigType +from ..layers import adaptive_avg_pool2d +from ..task_modules.samplers import SamplingResult +from ..utils import empty_instances, unpack_gt_instances +from .cascade_roi_head import CascadeRoIHead + + +@MODELS.register_module() +class SCNetRoIHead(CascadeRoIHead): + """RoIHead for `SCNet `_. + + Args: + num_stages (int): number of cascade stages. + stage_loss_weights (list): loss weight of cascade stages. + semantic_roi_extractor (dict): config to init semantic roi extractor. + semantic_head (dict): config to init semantic head. + feat_relay_head (dict): config to init feature_relay_head. + glbctx_head (dict): config to init global context head. + """ + + def __init__(self, + num_stages: int, + stage_loss_weights: List[float], + semantic_roi_extractor: OptConfigType = None, + semantic_head: OptConfigType = None, + feat_relay_head: OptConfigType = None, + glbctx_head: OptConfigType = None, + **kwargs) -> None: + super().__init__( + num_stages=num_stages, + stage_loss_weights=stage_loss_weights, + **kwargs) + assert self.with_bbox and self.with_mask + assert not self.with_shared_head # shared head is not supported + + if semantic_head is not None: + self.semantic_roi_extractor = MODELS.build(semantic_roi_extractor) + self.semantic_head = MODELS.build(semantic_head) + + if feat_relay_head is not None: + self.feat_relay_head = MODELS.build(feat_relay_head) + + if glbctx_head is not None: + self.glbctx_head = MODELS.build(glbctx_head) + + def init_mask_head(self, mask_roi_extractor: ConfigType, + mask_head: ConfigType) -> None: + """Initialize ``mask_head``""" + if mask_roi_extractor is not None: + self.mask_roi_extractor = MODELS.build(mask_roi_extractor) + self.mask_head = MODELS.build(mask_head) + + # TODO move to base_roi_head later + @property + def with_semantic(self) -> bool: + """bool: whether the head has semantic head""" + return hasattr(self, + 'semantic_head') and self.semantic_head is not None + + @property + def with_feat_relay(self) -> bool: + """bool: whether the head has feature relay head""" + return (hasattr(self, 'feat_relay_head') + and self.feat_relay_head is not None) + + @property + def with_glbctx(self) -> bool: + """bool: whether the head has global context head""" + return hasattr(self, 'glbctx_head') and self.glbctx_head is not None + + def _fuse_glbctx(self, roi_feats: Tensor, glbctx_feat: Tensor, + rois: Tensor) -> Tensor: + """Fuse global context feats with roi feats. + + Args: + roi_feats (Tensor): RoI features. + glbctx_feat (Tensor): Global context feature.. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + + Returns: + Tensor: Fused feature. + """ + assert roi_feats.size(0) == rois.size(0) + # RuntimeError: isDifferentiableType(variable.scalar_type()) + # INTERNAL ASSERT FAILED if detach() is not used when calling + # roi_head.predict(). 
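+        # `glbctx_feat` holds one context vector per image; the loop below
+        # broadcast-adds that vector to every RoI feature whose batch index
+        # (stored in rois[:, 0]) matches, leaving other images' RoIs intact.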
+ img_inds = torch.unique(rois[:, 0].detach().cpu(), sorted=True).long() + fused_feats = torch.zeros_like(roi_feats) + for img_id in img_inds: + inds = (rois[:, 0] == img_id.item()) + fused_feats[inds] = roi_feats[inds] + glbctx_feat[img_id] + return fused_feats + + def _slice_pos_feats(self, feats: Tensor, + sampling_results: List[SamplingResult]) -> Tensor: + """Get features from pos rois. + + Args: + feats (Tensor): Input features. + sampling_results (list["obj:`SamplingResult`]): Sampling results. + + Returns: + Tensor: Sliced features. + """ + num_rois = [res.priors.size(0) for res in sampling_results] + num_pos_rois = [res.pos_priors.size(0) for res in sampling_results] + inds = torch.zeros(sum(num_rois), dtype=torch.bool) + start = 0 + for i in range(len(num_rois)): + start = 0 if i == 0 else start + num_rois[i - 1] + stop = start + num_pos_rois[i] + inds[start:stop] = 1 + sliced_feats = feats[inds] + return sliced_feats + + def _bbox_forward(self, + stage: int, + x: Tuple[Tensor], + rois: Tensor, + semantic_feat: Optional[Tensor] = None, + glbctx_feat: Optional[Tensor] = None) -> dict: + """Box head forward function used in both training and testing. + + Args: + stage (int): The current stage in Cascade RoI Head. + x (tuple[Tensor]): List of multi-level img features. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + semantic_feat (Tensor): Semantic feature. Defaults to None. + glbctx_feat (Tensor): Global context feature. Defaults to None. + + Returns: + dict[str, Tensor]: Usually returns a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. + - `bbox_pred` (Tensor): Box energies / deltas. + - `bbox_feats` (Tensor): Extract bbox RoI features. + """ + bbox_roi_extractor = self.bbox_roi_extractor[stage] + bbox_head = self.bbox_head[stage] + bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs], + rois) + if self.with_semantic and semantic_feat is not None: + bbox_semantic_feat = self.semantic_roi_extractor([semantic_feat], + rois) + if bbox_semantic_feat.shape[-2:] != bbox_feats.shape[-2:]: + bbox_semantic_feat = adaptive_avg_pool2d( + bbox_semantic_feat, bbox_feats.shape[-2:]) + bbox_feats += bbox_semantic_feat + if self.with_glbctx and glbctx_feat is not None: + bbox_feats = self._fuse_glbctx(bbox_feats, glbctx_feat, rois) + cls_score, bbox_pred, relayed_feat = bbox_head( + bbox_feats, return_shared_feat=True) + + bbox_results = dict( + cls_score=cls_score, + bbox_pred=bbox_pred, + relayed_feat=relayed_feat) + return bbox_results + + def _mask_forward(self, + x: Tuple[Tensor], + rois: Tensor, + semantic_feat: Optional[Tensor] = None, + glbctx_feat: Optional[Tensor] = None, + relayed_feat: Optional[Tensor] = None) -> dict: + """Mask head forward function used in both training and testing. + + Args: + stage (int): The current stage in Cascade RoI Head. + x (tuple[Tensor]): Tuple of multi-level img features. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + semantic_feat (Tensor): Semantic feature. Defaults to None. + glbctx_feat (Tensor): Global context feature. Defaults to None. + relayed_feat (Tensor): Relayed feature. Defaults to None. + + Returns: + dict: Usually returns a dictionary with keys: + + - `mask_preds` (Tensor): Mask prediction. 
+ """ + mask_feats = self.mask_roi_extractor( + x[:self.mask_roi_extractor.num_inputs], rois) + if self.with_semantic and semantic_feat is not None: + mask_semantic_feat = self.semantic_roi_extractor([semantic_feat], + rois) + if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]: + mask_semantic_feat = F.adaptive_avg_pool2d( + mask_semantic_feat, mask_feats.shape[-2:]) + mask_feats += mask_semantic_feat + if self.with_glbctx and glbctx_feat is not None: + mask_feats = self._fuse_glbctx(mask_feats, glbctx_feat, rois) + if self.with_feat_relay and relayed_feat is not None: + mask_feats = mask_feats + relayed_feat + mask_preds = self.mask_head(mask_feats) + mask_results = dict(mask_preds=mask_preds) + + return mask_results + + def bbox_loss(self, + stage: int, + x: Tuple[Tensor], + sampling_results: List[SamplingResult], + semantic_feat: Optional[Tensor] = None, + glbctx_feat: Optional[Tensor] = None) -> dict: + """Run forward function and calculate loss for box head in training. + + Args: + stage (int): The current stage in Cascade RoI Head. + x (tuple[Tensor]): List of multi-level img features. + sampling_results (list["obj:`SamplingResult`]): Sampling results. + semantic_feat (Tensor): Semantic feature. Defaults to None. + glbctx_feat (Tensor): Global context feature. Defaults to None. + + Returns: + dict: Usually returns a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. + - `bbox_pred` (Tensor): Box energies / deltas. + - `bbox_feats` (Tensor): Extract bbox RoI features. + - `loss_bbox` (dict): A dictionary of bbox loss components. + - `rois` (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + - `bbox_targets` (tuple): Ground truth for proposals in a + single image. Containing the following list of Tensors: + (labels, label_weights, bbox_targets, bbox_weights) + """ + bbox_head = self.bbox_head[stage] + rois = bbox2roi([res.priors for res in sampling_results]) + bbox_results = self._bbox_forward( + stage, + x, + rois, + semantic_feat=semantic_feat, + glbctx_feat=glbctx_feat) + bbox_results.update(rois=rois) + + bbox_loss_and_target = bbox_head.loss_and_target( + cls_score=bbox_results['cls_score'], + bbox_pred=bbox_results['bbox_pred'], + rois=rois, + sampling_results=sampling_results, + rcnn_train_cfg=self.train_cfg[stage]) + + bbox_results.update(bbox_loss_and_target) + return bbox_results + + def mask_loss(self, + x: Tuple[Tensor], + sampling_results: List[SamplingResult], + batch_gt_instances: InstanceList, + semantic_feat: Optional[Tensor] = None, + glbctx_feat: Optional[Tensor] = None, + relayed_feat: Optional[Tensor] = None) -> dict: + """Run forward function and calculate loss for mask head in training. + + Args: + x (tuple[Tensor]): Tuple of multi-level img features. + sampling_results (list["obj:`SamplingResult`]): Sampling results. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + semantic_feat (Tensor): Semantic feature. Defaults to None. + glbctx_feat (Tensor): Global context feature. Defaults to None. + relayed_feat (Tensor): Relayed feature. Defaults to None. + + Returns: + dict: Usually returns a dictionary with keys: + + - `mask_preds` (Tensor): Mask prediction. + - `loss_mask` (dict): A dictionary of mask loss components. 
+ """ + pos_rois = bbox2roi([res.pos_priors for res in sampling_results]) + mask_results = self._mask_forward( + x, + pos_rois, + semantic_feat=semantic_feat, + glbctx_feat=glbctx_feat, + relayed_feat=relayed_feat) + + mask_loss_and_target = self.mask_head.loss_and_target( + mask_preds=mask_results['mask_preds'], + sampling_results=sampling_results, + batch_gt_instances=batch_gt_instances, + rcnn_train_cfg=self.train_cfg[-1]) + mask_results.update(mask_loss_and_target) + + return mask_results + + def semantic_loss(self, x: Tuple[Tensor], + batch_data_samples: SampleList) -> dict: + """Semantic segmentation loss. + + Args: + x (Tuple[Tensor]): Tuple of multi-level img features. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict: Usually returns a dictionary with keys: + + - `semantic_feat` (Tensor): Semantic feature. + - `loss_seg` (dict): Semantic segmentation loss. + """ + gt_semantic_segs = [ + data_sample.gt_sem_seg.sem_seg + for data_sample in batch_data_samples + ] + gt_semantic_segs = torch.stack(gt_semantic_segs) + semantic_pred, semantic_feat = self.semantic_head(x) + loss_seg = self.semantic_head.loss(semantic_pred, gt_semantic_segs) + + semantic_results = dict(loss_seg=loss_seg, semantic_feat=semantic_feat) + + return semantic_results + + def global_context_loss(self, x: Tuple[Tensor], + batch_gt_instances: InstanceList) -> dict: + """Global context loss. + + Args: + x (Tuple[Tensor]): Tuple of multi-level img features. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + + Returns: + dict: Usually returns a dictionary with keys: + + - `glbctx_feat` (Tensor): Global context feature. + - `loss_glbctx` (dict): Global context loss. + """ + gt_labels = [ + gt_instances.labels for gt_instances in batch_gt_instances + ] + mc_pred, glbctx_feat = self.glbctx_head(x) + loss_glbctx = self.glbctx_head.loss(mc_pred, gt_labels) + global_context_results = dict( + loss_glbctx=loss_glbctx, glbctx_feat=glbctx_feat) + + return global_context_results + + def loss(self, x: Tensor, rpn_results_list: InstanceList, + batch_data_samples: SampleList) -> dict: + """Perform forward propagation and loss calculation of the detection + roi on the features of the upstream network. + + Args: + x (tuple[Tensor]): List of multi-level img features. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. 
+ + Returns: + dict[str, Tensor]: A dictionary of loss components + """ + assert len(rpn_results_list) == len(batch_data_samples) + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \ + = outputs + + losses = dict() + + # semantic segmentation branch + if self.with_semantic: + semantic_results = self.semantic_loss( + x=x, batch_data_samples=batch_data_samples) + losses['loss_semantic_seg'] = semantic_results['loss_seg'] + semantic_feat = semantic_results['semantic_feat'] + else: + semantic_feat = None + + # global context branch + if self.with_glbctx: + global_context_results = self.global_context_loss( + x=x, batch_gt_instances=batch_gt_instances) + losses['loss_glbctx'] = global_context_results['loss_glbctx'] + glbctx_feat = global_context_results['glbctx_feat'] + else: + glbctx_feat = None + + results_list = rpn_results_list + num_imgs = len(batch_img_metas) + for stage in range(self.num_stages): + stage_loss_weight = self.stage_loss_weights[stage] + + # assign gts and sample proposals + sampling_results = [] + bbox_assigner = self.bbox_assigner[stage] + bbox_sampler = self.bbox_sampler[stage] + for i in range(num_imgs): + results = results_list[i] + # rename rpn_results.bboxes to rpn_results.priors + results.priors = results.pop('bboxes') + + assign_result = bbox_assigner.assign( + results, batch_gt_instances[i], + batch_gt_instances_ignore[i]) + sampling_result = bbox_sampler.sample( + assign_result, + results, + batch_gt_instances[i], + feats=[lvl_feat[i][None] for lvl_feat in x]) + sampling_results.append(sampling_result) + + # bbox head forward and loss + bbox_results = self.bbox_loss( + stage=stage, + x=x, + sampling_results=sampling_results, + semantic_feat=semantic_feat, + glbctx_feat=glbctx_feat) + + for name, value in bbox_results['loss_bbox'].items(): + losses[f's{stage}.{name}'] = ( + value * stage_loss_weight if 'loss' in name else value) + + # refine bboxes + if stage < self.num_stages - 1: + bbox_head = self.bbox_head[stage] + with torch.no_grad(): + results_list = bbox_head.refine_bboxes( + sampling_results=sampling_results, + bbox_results=bbox_results, + batch_img_metas=batch_img_metas) + + if self.with_feat_relay: + relayed_feat = self._slice_pos_feats(bbox_results['relayed_feat'], + sampling_results) + relayed_feat = self.feat_relay_head(relayed_feat) + else: + relayed_feat = None + + # mask head forward and loss + mask_results = self.mask_loss( + x=x, + sampling_results=sampling_results, + batch_gt_instances=batch_gt_instances, + semantic_feat=semantic_feat, + glbctx_feat=glbctx_feat, + relayed_feat=relayed_feat) + mask_stage_loss_weight = sum(self.stage_loss_weights) + losses['loss_mask'] = mask_stage_loss_weight * mask_results[ + 'loss_mask']['loss_mask'] + + return losses + + def predict(self, + x: Tuple[Tensor], + rpn_results_list: InstanceList, + batch_data_samples: SampleList, + rescale: bool = False) -> InstanceList: + """Perform forward propagation of the roi head and predict detection + results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Features from upstream network. Each + has shape (N, C, H, W). + rpn_results_list (list[:obj:`InstanceData`]): list of region + proposals. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): Whether to rescale the results to + the original image. Defaults to False. 
+ + Returns: + list[obj:`InstanceData`]: Detection results of each image. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + assert self.with_bbox, 'Bbox head must be implemented.' + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + + if self.with_semantic: + _, semantic_feat = self.semantic_head(x) + else: + semantic_feat = None + + if self.with_glbctx: + _, glbctx_feat = self.glbctx_head(x) + else: + glbctx_feat = None + + # TODO: nms_op in mmcv need be enhanced, the bbox result may get + # difference when not rescale in bbox_head + + # If it has the mask branch, the bbox branch does not need + # to be scaled to the original image scale, because the mask + # branch will scale both bbox and mask at the same time. + bbox_rescale = rescale if not self.with_mask else False + results_list = self.predict_bbox( + x=x, + semantic_feat=semantic_feat, + glbctx_feat=glbctx_feat, + batch_img_metas=batch_img_metas, + rpn_results_list=rpn_results_list, + rcnn_test_cfg=self.test_cfg, + rescale=bbox_rescale) + + if self.with_mask: + results_list = self.predict_mask( + x=x, + semantic_heat=semantic_feat, + glbctx_feat=glbctx_feat, + batch_img_metas=batch_img_metas, + results_list=results_list, + rescale=rescale) + + return results_list + + def predict_mask(self, + x: Tuple[Tensor], + semantic_heat: Tensor, + glbctx_feat: Tensor, + batch_img_metas: List[dict], + results_list: List[InstanceData], + rescale: bool = False) -> List[InstanceData]: + """Perform forward propagation of the mask head and predict detection + results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + semantic_feat (Tensor): Semantic feature. + glbctx_feat (Tensor): Global context feature. + batch_img_metas (list[dict]): List of image information. + results_list (list[:obj:`InstanceData`]): Detection results of + each image. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). 
+ """ + bboxes = [res.bboxes for res in results_list] + mask_rois = bbox2roi(bboxes) + if mask_rois.shape[0] == 0: + results_list = empty_instances( + batch_img_metas=batch_img_metas, + device=mask_rois.device, + task_type='mask', + instance_results=results_list, + mask_thr_binary=self.test_cfg.mask_thr_binary) + return results_list + + bboxes_results = self._bbox_forward( + stage=-1, + x=x, + rois=mask_rois, + semantic_feat=semantic_heat, + glbctx_feat=glbctx_feat) + relayed_feat = bboxes_results['relayed_feat'] + relayed_feat = self.feat_relay_head(relayed_feat) + + mask_results = self._mask_forward( + x=x, + rois=mask_rois, + semantic_feat=semantic_heat, + glbctx_feat=glbctx_feat, + relayed_feat=relayed_feat) + mask_preds = mask_results['mask_preds'] + + # split batch mask prediction back to each image + num_bbox_per_img = tuple(len(_bbox) for _bbox in bboxes) + mask_preds = mask_preds.split(num_bbox_per_img, 0) + + results_list = self.mask_head.predict_by_feat( + mask_preds=mask_preds, + results_list=results_list, + batch_img_metas=batch_img_metas, + rcnn_test_cfg=self.test_cfg, + rescale=rescale) + + return results_list + + def forward(self, x: Tuple[Tensor], rpn_results_list: InstanceList, + batch_data_samples: SampleList) -> tuple: + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. + + Args: + x (List[Tensor]): Multi-level features that may have different + resolutions. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + batch_data_samples (list[:obj:`DetDataSample`]): Each item contains + the meta information of each image and corresponding + annotations. + + Returns + tuple: A tuple of features from ``bbox_head`` and ``mask_head`` + forward. + """ + results = () + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + + if self.with_semantic: + _, semantic_feat = self.semantic_head(x) + else: + semantic_feat = None + + if self.with_glbctx: + _, glbctx_feat = self.glbctx_head(x) + else: + glbctx_feat = None + + proposals = [rpn_results.bboxes for rpn_results in rpn_results_list] + num_proposals_per_img = tuple(len(p) for p in proposals) + rois = bbox2roi(proposals) + # bbox head + if self.with_bbox: + rois, cls_scores, bbox_preds = self._refine_roi( + x=x, + rois=rois, + semantic_feat=semantic_feat, + glbctx_feat=glbctx_feat, + batch_img_metas=batch_img_metas, + num_proposals_per_img=num_proposals_per_img) + results = results + (cls_scores, bbox_preds) + # mask head + if self.with_mask: + rois = torch.cat(rois) + bboxes_results = self._bbox_forward( + stage=-1, + x=x, + rois=rois, + semantic_feat=semantic_feat, + glbctx_feat=glbctx_feat) + relayed_feat = bboxes_results['relayed_feat'] + relayed_feat = self.feat_relay_head(relayed_feat) + mask_results = self._mask_forward( + x=x, + rois=rois, + semantic_feat=semantic_feat, + glbctx_feat=glbctx_feat, + relayed_feat=relayed_feat) + mask_preds = mask_results['mask_preds'] + mask_preds = mask_preds.split(num_proposals_per_img, 0) + results = results + (mask_preds, ) + return results diff --git a/mmdet/models/roi_heads/shared_heads/__init__.py b/mmdet/models/roi_heads/shared_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d56636ab34d1dd2592828238099bcdccf179d6d3 --- /dev/null +++ b/mmdet/models/roi_heads/shared_heads/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .res_layer import ResLayer + +__all__ = ['ResLayer'] diff --git a/mmdet/models/roi_heads/shared_heads/res_layer.py b/mmdet/models/roi_heads/shared_heads/res_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..d9210cb928fec92135a195d44d13a8588382b947 --- /dev/null +++ b/mmdet/models/roi_heads/shared_heads/res_layer.py @@ -0,0 +1,79 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +from mmengine.model import BaseModule + +from mmdet.models.backbones import ResNet +from mmdet.models.layers import ResLayer as _ResLayer +from mmdet.registry import MODELS + + +@MODELS.register_module() +class ResLayer(BaseModule): + + def __init__(self, + depth, + stage=3, + stride=2, + dilation=1, + style='pytorch', + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + with_cp=False, + dcn=None, + pretrained=None, + init_cfg=None): + super(ResLayer, self).__init__(init_cfg) + + self.norm_eval = norm_eval + self.norm_cfg = norm_cfg + self.stage = stage + self.fp16_enabled = False + block, stage_blocks = ResNet.arch_settings[depth] + stage_block = stage_blocks[stage] + planes = 64 * 2**stage + inplanes = 64 * 2**(stage - 1) * block.expansion + + res_layer = _ResLayer( + block, + inplanes, + planes, + stage_block, + stride=stride, + dilation=dilation, + style=style, + with_cp=with_cp, + norm_cfg=self.norm_cfg, + dcn=dcn) + self.add_module(f'layer{stage + 1}', res_layer) + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be specified at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + res_layer = getattr(self, f'layer{self.stage + 1}') + out = res_layer(x) + return out + + def train(self, mode=True): + super(ResLayer, self).train(mode) + if self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() diff --git a/mmdet/models/roi_heads/sparse_roi_head.py b/mmdet/models/roi_heads/sparse_roi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..19c3e1e335ca4e4a9d5befcbffcf4665b459cb5a --- /dev/null +++ b/mmdet/models/roi_heads/sparse_roi_head.py @@ -0,0 +1,601 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +from mmengine.config import ConfigDict +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.models.task_modules.samplers import PseudoSampler +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.structures.bbox import bbox2roi +from mmdet.utils import ConfigType, InstanceList, OptConfigType +from ..utils.misc import empty_instances, unpack_gt_instances +from .cascade_roi_head import CascadeRoIHead + + +@MODELS.register_module() +class SparseRoIHead(CascadeRoIHead): + r"""The RoIHead for `Sparse R-CNN: End-to-End Object Detection with + Learnable Proposals `_ + and `Instances as Queries `_ + + Args: + num_stages (int): Number of stage whole iterative process. + Defaults to 6. + stage_loss_weights (Tuple[float]): The loss + weight of each stage. By default all stages have + the same weight 1. 
+ bbox_roi_extractor (:obj:`ConfigDict` or dict): Config of box + roi extractor. + mask_roi_extractor (:obj:`ConfigDict` or dict): Config of mask + roi extractor. + bbox_head (:obj:`ConfigDict` or dict): Config of box head. + mask_head (:obj:`ConfigDict` or dict): Config of mask head. + train_cfg (:obj:`ConfigDict` or dict, Optional): Configuration + information in train stage. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, Optional): Configuration + information in test stage. Defaults to None. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict]): Initialization config dict. Defaults to None. + """ + + def __init__(self, + num_stages: int = 6, + stage_loss_weights: Tuple[float] = (1, 1, 1, 1, 1, 1), + proposal_feature_channel: int = 256, + bbox_roi_extractor: ConfigType = dict( + type='SingleRoIExtractor', + roi_layer=dict( + type='RoIAlign', output_size=7, sampling_ratio=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + mask_roi_extractor: OptConfigType = None, + bbox_head: ConfigType = dict( + type='DIIHead', + num_classes=80, + num_fcs=2, + num_heads=8, + num_cls_fcs=1, + num_reg_fcs=3, + feedforward_channels=2048, + hidden_channels=256, + dropout=0.0, + roi_feat_size=7, + ffn_act_cfg=dict(type='ReLU', inplace=True)), + mask_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + init_cfg: OptConfigType = None) -> None: + assert bbox_roi_extractor is not None + assert bbox_head is not None + assert len(stage_loss_weights) == num_stages + self.num_stages = num_stages + self.stage_loss_weights = stage_loss_weights + self.proposal_feature_channel = proposal_feature_channel + super().__init__( + num_stages=num_stages, + stage_loss_weights=stage_loss_weights, + bbox_roi_extractor=bbox_roi_extractor, + mask_roi_extractor=mask_roi_extractor, + bbox_head=bbox_head, + mask_head=mask_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + init_cfg=init_cfg) + # train_cfg would be None when run the test.py + if train_cfg is not None: + for stage in range(num_stages): + assert isinstance(self.bbox_sampler[stage], PseudoSampler), \ + 'Sparse R-CNN and QueryInst only support `PseudoSampler`' + + def bbox_loss(self, stage: int, x: Tuple[Tensor], + results_list: InstanceList, object_feats: Tensor, + batch_img_metas: List[dict], + batch_gt_instances: InstanceList) -> dict: + """Perform forward propagation and loss calculation of the bbox head on + the features of the upstream network. + + Args: + stage (int): The current stage in iterative process. + x (tuple[Tensor]): List of multi-level img features. + results_list (List[:obj:`InstanceData`]) : List of region + proposals. + object_feats (Tensor): The object feature extracted from + the previous stage. + batch_img_metas (list[dict]): Meta information of each image. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + + Returns: + dict[str, Tensor]: Usually returns a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. + - `bbox_pred` (Tensor): Box energies / deltas. + - `bbox_feats` (Tensor): Extract bbox RoI features. + - `loss_bbox` (dict): A dictionary of bbox loss components. + """ + proposal_list = [res.bboxes for res in results_list] + rois = bbox2roi(proposal_list) + bbox_results = self._bbox_forward(stage, x, rois, object_feats, + batch_img_metas) + imgs_whwh = torch.cat( + [res.imgs_whwh[None, ...] 
for res in results_list])
+        cls_pred_list = bbox_results['detached_cls_scores']
+        proposal_list = bbox_results['detached_proposals']
+
+        sampling_results = []
+        bbox_head = self.bbox_head[stage]
+        for i in range(len(batch_img_metas)):
+            pred_instances = InstanceData()
+            # TODO: Enhance the logic
+            pred_instances.bboxes = proposal_list[i]  # for assigner
+            pred_instances.scores = cls_pred_list[i]
+            pred_instances.priors = proposal_list[i]  # for sampler
+
+            assign_result = self.bbox_assigner[stage].assign(
+                pred_instances=pred_instances,
+                gt_instances=batch_gt_instances[i],
+                gt_instances_ignore=None,
+                img_meta=batch_img_metas[i])
+
+            sampling_result = self.bbox_sampler[stage].sample(
+                assign_result, pred_instances, batch_gt_instances[i])
+            sampling_results.append(sampling_result)
+
+        bbox_results.update(sampling_results=sampling_results)
+
+        cls_score = bbox_results['cls_score']
+        decoded_bboxes = bbox_results['decoded_bboxes']
+        cls_score = cls_score.view(-1, cls_score.size(-1))
+        decoded_bboxes = decoded_bboxes.view(-1, 4)
+        bbox_loss_and_target = bbox_head.loss_and_target(
+            cls_score,
+            decoded_bboxes,
+            sampling_results,
+            self.train_cfg[stage],
+            imgs_whwh=imgs_whwh,
+            concat=True)
+        bbox_results.update(bbox_loss_and_target)
+
+        # propose for the new proposal_list
+        proposal_list = []
+        for idx in range(len(batch_img_metas)):
+            results = InstanceData()
+            results.imgs_whwh = results_list[idx].imgs_whwh
+            results.bboxes = bbox_results['detached_proposals'][idx]
+            proposal_list.append(results)
+        bbox_results.update(results_list=proposal_list)
+        return bbox_results
+
+    def _bbox_forward(self, stage: int, x: Tuple[Tensor], rois: Tensor,
+                      object_feats: Tensor,
+                      batch_img_metas: List[dict]) -> dict:
+        """Box head forward function used in both training and testing.
+        Returns all regression, classification results and an intermediate
+        feature.
+
+        Args:
+            stage (int): The current stage in iterative process.
+            x (tuple[Tensor]): List of multi-level img features.
+            rois (Tensor): RoIs with the shape (n, 5) where the first
+                column indicates batch id of each RoI.
+                Each dimension means (img_index, x1, y1, x2, y2).
+            object_feats (Tensor): The object feature extracted from
+                the previous stage.
+            batch_img_metas (list[dict]): Meta information of each image.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of bbox head outputs,
+            containing the following results:
+
+                - cls_score (Tensor): The score of each class, has
+                  shape (batch_size, num_proposals, num_classes)
+                  when using focal loss or
+                  (batch_size, num_proposals, num_classes+1)
+                  otherwise.
+                - decoded_bboxes (Tensor): The regression results
+                  with shape (batch_size, num_proposal, 4).
+                  The last dimension 4 represents
+                  [tl_x, tl_y, br_x, br_y].
+                - object_feats (Tensor): The object feature extracted
+                  from the current stage.
+                - detached_cls_scores (list[Tensor]): The detached
+                  classification results, length is batch_size, and
+                  each tensor has shape (num_proposal, num_classes).
+                - detached_proposals (list[Tensor]): The detached
+                  regression results, length is batch_size, and each
+                  tensor has shape (num_proposal, 4). The last
+                  dimension 4 represents [tl_x, tl_y, br_x, br_y].
+ """ + num_imgs = len(batch_img_metas) + bbox_roi_extractor = self.bbox_roi_extractor[stage] + bbox_head = self.bbox_head[stage] + bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs], + rois) + cls_score, bbox_pred, object_feats, attn_feats = bbox_head( + bbox_feats, object_feats) + + fake_bbox_results = dict( + rois=rois, + bbox_targets=(rois.new_zeros(len(rois), dtype=torch.long), None), + bbox_pred=bbox_pred.view(-1, bbox_pred.size(-1)), + cls_score=cls_score.view(-1, cls_score.size(-1))) + fake_sampling_results = [ + InstanceData(pos_is_gt=rois.new_zeros(object_feats.size(1))) + for _ in range(len(batch_img_metas)) + ] + + results_list = bbox_head.refine_bboxes( + sampling_results=fake_sampling_results, + bbox_results=fake_bbox_results, + batch_img_metas=batch_img_metas) + proposal_list = [res.bboxes for res in results_list] + bbox_results = dict( + cls_score=cls_score, + decoded_bboxes=torch.cat(proposal_list), + object_feats=object_feats, + attn_feats=attn_feats, + # detach then use it in label assign + detached_cls_scores=[ + cls_score[i].detach() for i in range(num_imgs) + ], + detached_proposals=[item.detach() for item in proposal_list]) + + return bbox_results + + def _mask_forward(self, stage: int, x: Tuple[Tensor], rois: Tensor, + attn_feats) -> dict: + """Mask head forward function used in both training and testing. + + Args: + stage (int): The current stage in Cascade RoI Head. + x (tuple[Tensor]): Tuple of multi-level img features. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + attn_feats (Tensot): Intermediate feature get from the last + diihead, has shape + (batch_size*num_proposals, feature_dimensions) + + Returns: + dict: Usually returns a dictionary with keys: + + - `mask_preds` (Tensor): Mask prediction. + """ + mask_roi_extractor = self.mask_roi_extractor[stage] + mask_head = self.mask_head[stage] + mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs], + rois) + # do not support caffe_c4 model anymore + mask_preds = mask_head(mask_feats, attn_feats) + + mask_results = dict(mask_preds=mask_preds) + return mask_results + + def mask_loss(self, stage: int, x: Tuple[Tensor], bbox_results: dict, + batch_gt_instances: InstanceList, + rcnn_train_cfg: ConfigDict) -> dict: + """Run forward function and calculate loss for mask head in training. + + Args: + stage (int): The current stage in Cascade RoI Head. + x (tuple[Tensor]): Tuple of multi-level img features. + bbox_results (dict): Results obtained from `bbox_loss`. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. + + Returns: + dict: Usually returns a dictionary with keys: + + - `mask_preds` (Tensor): Mask prediction. + - `loss_mask` (dict): A dictionary of mask loss components. 
+ """ + attn_feats = bbox_results['attn_feats'] + sampling_results = bbox_results['sampling_results'] + + pos_rois = bbox2roi([res.pos_priors for res in sampling_results]) + + attn_feats = torch.cat([ + feats[res.pos_inds] + for (feats, res) in zip(attn_feats, sampling_results) + ]) + mask_results = self._mask_forward(stage, x, pos_rois, attn_feats) + + mask_loss_and_target = self.mask_head[stage].loss_and_target( + mask_preds=mask_results['mask_preds'], + sampling_results=sampling_results, + batch_gt_instances=batch_gt_instances, + rcnn_train_cfg=rcnn_train_cfg) + mask_results.update(mask_loss_and_target) + + return mask_results + + def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList, + batch_data_samples: SampleList) -> dict: + """Perform forward propagation and loss calculation of the detection + roi on the features of the upstream network. + + Args: + x (tuple[Tensor]): List of multi-level img features. + rpn_results_list (List[:obj:`InstanceData`]): List of region + proposals. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict: a dictionary of loss components of all stage. + """ + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \ + = outputs + + object_feats = torch.cat( + [res.pop('features')[None, ...] for res in rpn_results_list]) + results_list = rpn_results_list + losses = {} + for stage in range(self.num_stages): + stage_loss_weight = self.stage_loss_weights[stage] + + # bbox head forward and loss + bbox_results = self.bbox_loss( + stage=stage, + x=x, + object_feats=object_feats, + results_list=results_list, + batch_img_metas=batch_img_metas, + batch_gt_instances=batch_gt_instances) + + for name, value in bbox_results['loss_bbox'].items(): + losses[f's{stage}.{name}'] = ( + value * stage_loss_weight if 'loss' in name else value) + + if self.with_mask: + mask_results = self.mask_loss( + stage=stage, + x=x, + bbox_results=bbox_results, + batch_gt_instances=batch_gt_instances, + rcnn_train_cfg=self.train_cfg[stage]) + + for name, value in mask_results['loss_mask'].items(): + losses[f's{stage}.{name}'] = ( + value * stage_loss_weight if 'loss' in name else value) + + object_feats = bbox_results['object_feats'] + results_list = bbox_results['results_list'] + return losses + + def predict_bbox(self, + x: Tuple[Tensor], + batch_img_metas: List[dict], + rpn_results_list: InstanceList, + rcnn_test_cfg: ConfigType, + rescale: bool = False) -> InstanceList: + """Perform forward propagation of the bbox head and predict detection + results on the features of the upstream network. + + Args: + x(tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
+ """ + proposal_list = [res.bboxes for res in rpn_results_list] + object_feats = torch.cat( + [res.pop('features')[None, ...] for res in rpn_results_list]) + if all([proposal.shape[0] == 0 for proposal in proposal_list]): + # There is no proposal in the whole batch + return empty_instances( + batch_img_metas, x[0].device, task_type='bbox') + + for stage in range(self.num_stages): + rois = bbox2roi(proposal_list) + bbox_results = self._bbox_forward(stage, x, rois, object_feats, + batch_img_metas) + object_feats = bbox_results['object_feats'] + cls_score = bbox_results['cls_score'] + proposal_list = bbox_results['detached_proposals'] + + num_classes = self.bbox_head[-1].num_classes + + if self.bbox_head[-1].loss_cls.use_sigmoid: + cls_score = cls_score.sigmoid() + else: + cls_score = cls_score.softmax(-1)[..., :-1] + + topk_inds_list = [] + results_list = [] + for img_id in range(len(batch_img_metas)): + cls_score_per_img = cls_score[img_id] + scores_per_img, topk_inds = cls_score_per_img.flatten(0, 1).topk( + self.test_cfg.max_per_img, sorted=False) + labels_per_img = topk_inds % num_classes + bboxes_per_img = proposal_list[img_id][topk_inds // num_classes] + topk_inds_list.append(topk_inds) + if rescale and bboxes_per_img.size(0) > 0: + assert batch_img_metas[img_id].get('scale_factor') is not None + scale_factor = bboxes_per_img.new_tensor( + batch_img_metas[img_id]['scale_factor']).repeat((1, 2)) + bboxes_per_img = ( + bboxes_per_img.view(bboxes_per_img.size(0), -1, 4) / + scale_factor).view(bboxes_per_img.size()[0], -1) + + results = InstanceData() + results.bboxes = bboxes_per_img + results.scores = scores_per_img + results.labels = labels_per_img + results_list.append(results) + if self.with_mask: + for img_id in range(len(batch_img_metas)): + # add positive information in InstanceData to predict + # mask results in `mask_head`. + proposals = bbox_results['detached_proposals'][img_id] + topk_inds = topk_inds_list[img_id] + attn_feats = bbox_results['attn_feats'][img_id] + + results_list[img_id].proposals = proposals + results_list[img_id].topk_inds = topk_inds + results_list[img_id].attn_feats = attn_feats + return results_list + + def predict_mask(self, + x: Tuple[Tensor], + batch_img_metas: List[dict], + results_list: InstanceList, + rescale: bool = False) -> InstanceList: + """Perform forward propagation of the mask head and predict detection + results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + results_list (list[:obj:`InstanceData`]): Detection results of + each image. Each item usually contains following keys: + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - proposal (Tensor): Bboxes predicted from bbox_head, + has a shape (num_instances, 4). + - topk_inds (Tensor): Topk indices of each image, has + shape (num_instances, ) + - attn_feats (Tensor): Intermediate feature get from the last + diihead, has shape (num_instances, feature_dimensions) + + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. + Each item usually contains following keys. 
+
+                - scores (Tensor): Classification scores, has a shape
+                  (num_instance, )
+                - labels (Tensor): Labels of bboxes, has a shape
+                  (num_instances, ).
+                - bboxes (Tensor): Has a shape (num_instances, 4),
+                  the last dimension 4 arrange as (x1, y1, x2, y2).
+                - masks (Tensor): Has a shape (num_instances, H, W).
+        """
+        proposal_list = [res.pop('proposals') for res in results_list]
+        topk_inds_list = [res.pop('topk_inds') for res in results_list]
+        attn_feats = torch.cat(
+            [res.pop('attn_feats')[None, ...] for res in results_list])
+
+        rois = bbox2roi(proposal_list)
+
+        if rois.shape[0] == 0:
+            results_list = empty_instances(
+                batch_img_metas,
+                rois.device,
+                task_type='mask',
+                instance_results=results_list,
+                mask_thr_binary=self.test_cfg.mask_thr_binary)
+            return results_list
+
+        last_stage = self.num_stages - 1
+        mask_results = self._mask_forward(last_stage, x, rois, attn_feats)
+
+        num_imgs = len(batch_img_metas)
+        mask_results['mask_preds'] = mask_results['mask_preds'].reshape(
+            num_imgs, -1, *mask_results['mask_preds'].size()[1:])
+        num_classes = self.bbox_head[-1].num_classes
+
+        mask_preds = []
+        for img_id in range(num_imgs):
+            topk_inds = topk_inds_list[img_id]
+            masks_per_img = mask_results['mask_preds'][img_id].flatten(
+                0, 1)[topk_inds]
+            masks_per_img = masks_per_img[:, None,
+                                          ...].repeat(1, num_classes, 1, 1)
+            mask_preds.append(masks_per_img)
+        results_list = self.mask_head[-1].predict_by_feat(
+            mask_preds,
+            results_list,
+            batch_img_metas,
+            rcnn_test_cfg=self.test_cfg,
+            rescale=rescale)
+
+        return results_list
+
+    # TODO: Need to refactor later
+    def forward(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
+                batch_data_samples: SampleList) -> tuple:
+        """Network forward process. Usually includes backbone, neck and head
+        forward without any post-processing.
+
+        Args:
+            x (List[Tensor]): Multi-level features that may have different
+                resolutions.
+            rpn_results_list (List[:obj:`InstanceData`]): List of region
+                proposals.
+            batch_data_samples (list[:obj:`DetDataSample`]): The batch
+                data samples. It usually includes information such
+                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+
+        Returns:
+            tuple: A tuple of features from ``bbox_head`` and ``mask_head``
+            forward.
+        """
+        outputs = unpack_gt_instances(batch_data_samples)
+        (batch_gt_instances, batch_gt_instances_ignore,
+         batch_img_metas) = outputs
+
+        all_stage_bbox_results = []
+        object_feats = torch.cat(
+            [res.pop('features')[None, ...]
for res in rpn_results_list]) + results_list = rpn_results_list + if self.with_bbox: + for stage in range(self.num_stages): + bbox_results = self.bbox_loss( + stage=stage, + x=x, + results_list=results_list, + object_feats=object_feats, + batch_img_metas=batch_img_metas, + batch_gt_instances=batch_gt_instances) + bbox_results.pop('loss_bbox') + # torch.jit does not support obj:SamplingResult + bbox_results.pop('results_list') + bbox_res = bbox_results.copy() + bbox_res.pop('sampling_results') + all_stage_bbox_results.append((bbox_res, )) + + if self.with_mask: + attn_feats = bbox_results['attn_feats'] + sampling_results = bbox_results['sampling_results'] + + pos_rois = bbox2roi( + [res.pos_priors for res in sampling_results]) + + attn_feats = torch.cat([ + feats[res.pos_inds] + for (feats, res) in zip(attn_feats, sampling_results) + ]) + mask_results = self._mask_forward(stage, x, pos_rois, + attn_feats) + all_stage_bbox_results[-1] += (mask_results, ) + return tuple(all_stage_bbox_results) diff --git a/mmdet/models/roi_heads/standard_roi_head.py b/mmdet/models/roi_heads/standard_roi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..8d168eba0fb2ccf6aa89bde5c637160f10aea83a --- /dev/null +++ b/mmdet/models/roi_heads/standard_roi_head.py @@ -0,0 +1,419 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +import torch +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures import DetDataSample, SampleList +from mmdet.structures.bbox import bbox2roi +from mmdet.utils import ConfigType, InstanceList +from ..task_modules.samplers import SamplingResult +from ..utils import empty_instances, unpack_gt_instances +from .base_roi_head import BaseRoIHead + + +@MODELS.register_module() +class StandardRoIHead(BaseRoIHead): + """Simplest base roi head including one bbox head and one mask head.""" + + def init_assigner_sampler(self) -> None: + """Initialize assigner and sampler.""" + self.bbox_assigner = None + self.bbox_sampler = None + if self.train_cfg: + self.bbox_assigner = TASK_UTILS.build(self.train_cfg.assigner) + self.bbox_sampler = TASK_UTILS.build( + self.train_cfg.sampler, default_args=dict(context=self)) + + def init_bbox_head(self, bbox_roi_extractor: ConfigType, + bbox_head: ConfigType) -> None: + """Initialize box head and box roi extractor. + + Args: + bbox_roi_extractor (dict or ConfigDict): Config of box + roi extractor. + bbox_head (dict or ConfigDict): Config of box in box head. + """ + self.bbox_roi_extractor = MODELS.build(bbox_roi_extractor) + self.bbox_head = MODELS.build(bbox_head) + + def init_mask_head(self, mask_roi_extractor: ConfigType, + mask_head: ConfigType) -> None: + """Initialize mask head and mask roi extractor. + + Args: + mask_roi_extractor (dict or ConfigDict): Config of mask roi + extractor. + mask_head (dict or ConfigDict): Config of mask in mask head. + """ + if mask_roi_extractor is not None: + self.mask_roi_extractor = MODELS.build(mask_roi_extractor) + self.share_roi_extractor = False + else: + self.share_roi_extractor = True + self.mask_roi_extractor = self.bbox_roi_extractor + self.mask_head = MODELS.build(mask_head) + + # TODO: Need to refactor later + def forward(self, + x: Tuple[Tensor], + rpn_results_list: InstanceList, + batch_data_samples: SampleList = None) -> tuple: + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. 
+
+        Args:
+            x (List[Tensor]): Multi-level features that may have different
+                resolutions.
+            rpn_results_list (list[:obj:`InstanceData`]): List of region
+                proposals.
+            batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
+                the meta information of each image and corresponding
+                annotations.
+
+        Returns:
+            tuple: A tuple of features from ``bbox_head`` and ``mask_head``
+            forward.
+        """
+        results = ()
+        proposals = [rpn_results.bboxes for rpn_results in rpn_results_list]
+        rois = bbox2roi(proposals)
+        # bbox head
+        if self.with_bbox:
+            bbox_results = self._bbox_forward(x, rois)
+            results = results + (bbox_results['cls_score'],
+                                 bbox_results['bbox_pred'])
+        # mask head
+        if self.with_mask:
+            mask_rois = rois[:100]
+            mask_results = self._mask_forward(x, mask_rois)
+            results = results + (mask_results['mask_preds'], )
+        return results
+
+    def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
+             batch_data_samples: List[DetDataSample]) -> dict:
+        """Perform forward propagation and loss calculation of the detection
+        roi on the features of the upstream network.
+
+        Args:
+            x (tuple[Tensor]): List of multi-level img features.
+            rpn_results_list (list[:obj:`InstanceData`]): List of region
+                proposals.
+            batch_data_samples (list[:obj:`DetDataSample`]): The batch
+                data samples. It usually includes information such
+                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        assert len(rpn_results_list) == len(batch_data_samples)
+        outputs = unpack_gt_instances(batch_data_samples)
+        batch_gt_instances, batch_gt_instances_ignore, _ = outputs
+
+        # assign gts and sample proposals
+        num_imgs = len(batch_data_samples)
+        sampling_results = []
+        for i in range(num_imgs):
+            # rename rpn_results.bboxes to rpn_results.priors
+            rpn_results = rpn_results_list[i]
+            rpn_results.priors = rpn_results.pop('bboxes')
+
+            assign_result = self.bbox_assigner.assign(
+                rpn_results, batch_gt_instances[i],
+                batch_gt_instances_ignore[i])
+            sampling_result = self.bbox_sampler.sample(
+                assign_result,
+                rpn_results,
+                batch_gt_instances[i],
+                feats=[lvl_feat[i][None] for lvl_feat in x])
+            sampling_results.append(sampling_result)
+
+        losses = dict()
+        # bbox head loss
+        if self.with_bbox:
+            bbox_results = self.bbox_loss(x, sampling_results)
+            losses.update(bbox_results['loss_bbox'])
+
+        # mask head forward and loss
+        if self.with_mask:
+            mask_results = self.mask_loss(x, sampling_results,
+                                          bbox_results['bbox_feats'],
+                                          batch_gt_instances)
+            losses.update(mask_results['loss_mask'])
+
+        return losses
+
+    def _bbox_forward(self, x: Tuple[Tensor], rois: Tensor) -> dict:
+        """Box head forward function used in both training and testing.
+
+        Args:
+            x (tuple[Tensor]): List of multi-level img features.
+            rois (Tensor): RoIs with the shape (n, 5) where the first
+                column indicates batch id of each RoI.
+
+        Returns:
+            dict[str, Tensor]: Usually returns a dictionary with keys:
+
+                - `cls_score` (Tensor): Classification scores.
+                - `bbox_pred` (Tensor): Box energies / deltas.
+                - `bbox_feats` (Tensor): Extract bbox RoI features.
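+
+        Example:
+            A toy sketch of the expected ``rois`` layout (batch index in
+            the first column), assuming two images with one proposal each:
+
+            >>> import torch
+            >>> from mmdet.structures.bbox import bbox2roi
+            >>> proposals = [torch.tensor([[0., 0., 10., 10.]]),
+            ...              torch.tensor([[5., 5., 20., 20.]])]
+            >>> bbox2roi(proposals)
+            tensor([[ 0.,  0.,  0., 10., 10.],
+                    [ 1.,  5.,  5., 20., 20.]])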
+ """ + # TODO: a more flexible way to decide which feature maps to use + bbox_feats = self.bbox_roi_extractor( + x[:self.bbox_roi_extractor.num_inputs], rois) + if self.with_shared_head: + bbox_feats = self.shared_head(bbox_feats) + cls_score, bbox_pred = self.bbox_head(bbox_feats) + + bbox_results = dict( + cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats) + return bbox_results + + def bbox_loss(self, x: Tuple[Tensor], + sampling_results: List[SamplingResult]) -> dict: + """Perform forward propagation and loss calculation of the bbox head on + the features of the upstream network. + + Args: + x (tuple[Tensor]): List of multi-level img features. + sampling_results (list["obj:`SamplingResult`]): Sampling results. + + Returns: + dict[str, Tensor]: Usually returns a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. + - `bbox_pred` (Tensor): Box energies / deltas. + - `bbox_feats` (Tensor): Extract bbox RoI features. + - `loss_bbox` (dict): A dictionary of bbox loss components. + """ + rois = bbox2roi([res.priors for res in sampling_results]) + bbox_results = self._bbox_forward(x, rois) + + bbox_loss_and_target = self.bbox_head.loss_and_target( + cls_score=bbox_results['cls_score'], + bbox_pred=bbox_results['bbox_pred'], + rois=rois, + sampling_results=sampling_results, + rcnn_train_cfg=self.train_cfg) + + bbox_results.update(loss_bbox=bbox_loss_and_target['loss_bbox']) + return bbox_results + + def mask_loss(self, x: Tuple[Tensor], + sampling_results: List[SamplingResult], bbox_feats: Tensor, + batch_gt_instances: InstanceList) -> dict: + """Perform forward propagation and loss calculation of the mask head on + the features of the upstream network. + + Args: + x (tuple[Tensor]): Tuple of multi-level img features. + sampling_results (list["obj:`SamplingResult`]): Sampling results. + bbox_feats (Tensor): Extract bbox RoI features. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + + Returns: + dict: Usually returns a dictionary with keys: + + - `mask_preds` (Tensor): Mask prediction. + - `mask_feats` (Tensor): Extract mask RoI features. + - `mask_targets` (Tensor): Mask target of each positive\ + proposals in the image. + - `loss_mask` (dict): A dictionary of mask loss components. + """ + if not self.share_roi_extractor: + pos_rois = bbox2roi([res.pos_priors for res in sampling_results]) + mask_results = self._mask_forward(x, pos_rois) + else: + pos_inds = [] + device = bbox_feats.device + for res in sampling_results: + pos_inds.append( + torch.ones( + res.pos_priors.shape[0], + device=device, + dtype=torch.uint8)) + pos_inds.append( + torch.zeros( + res.neg_priors.shape[0], + device=device, + dtype=torch.uint8)) + pos_inds = torch.cat(pos_inds) + + mask_results = self._mask_forward( + x, pos_inds=pos_inds, bbox_feats=bbox_feats) + + mask_loss_and_target = self.mask_head.loss_and_target( + mask_preds=mask_results['mask_preds'], + sampling_results=sampling_results, + batch_gt_instances=batch_gt_instances, + rcnn_train_cfg=self.train_cfg) + + mask_results.update(loss_mask=mask_loss_and_target['loss_mask']) + return mask_results + + def _mask_forward(self, + x: Tuple[Tensor], + rois: Tensor = None, + pos_inds: Optional[Tensor] = None, + bbox_feats: Optional[Tensor] = None) -> dict: + """Mask head forward function used in both training and testing. + + Args: + x (tuple[Tensor]): Tuple of multi-level img features. 
+ rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + pos_inds (Tensor, optional): Indices of positive samples. + Defaults to None. + bbox_feats (Tensor): Extract bbox RoI features. Defaults to None. + + Returns: + dict[str, Tensor]: Usually returns a dictionary with keys: + + - `mask_preds` (Tensor): Mask prediction. + - `mask_feats` (Tensor): Extract mask RoI features. + """ + assert ((rois is not None) ^ + (pos_inds is not None and bbox_feats is not None)) + if rois is not None: + mask_feats = self.mask_roi_extractor( + x[:self.mask_roi_extractor.num_inputs], rois) + if self.with_shared_head: + mask_feats = self.shared_head(mask_feats) + else: + assert bbox_feats is not None + mask_feats = bbox_feats[pos_inds] + + mask_preds = self.mask_head(mask_feats) + mask_results = dict(mask_preds=mask_preds, mask_feats=mask_feats) + return mask_results + + def predict_bbox(self, + x: Tuple[Tensor], + batch_img_metas: List[dict], + rpn_results_list: InstanceList, + rcnn_test_cfg: ConfigType, + rescale: bool = False) -> InstanceList: + """Perform forward propagation of the bbox head and predict detection + results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + proposals = [res.bboxes for res in rpn_results_list] + rois = bbox2roi(proposals) + + if rois.shape[0] == 0: + return empty_instances( + batch_img_metas, + rois.device, + task_type='bbox', + box_type=self.bbox_head.predict_box_type, + num_classes=self.bbox_head.num_classes, + score_per_cls=rcnn_test_cfg is None) + + bbox_results = self._bbox_forward(x, rois) + + # split batch bbox prediction back to each image + cls_scores = bbox_results['cls_score'] + bbox_preds = bbox_results['bbox_pred'] + num_proposals_per_img = tuple(len(p) for p in proposals) + rois = rois.split(num_proposals_per_img, 0) + cls_scores = cls_scores.split(num_proposals_per_img, 0) + + # some detector with_reg is False, bbox_preds will be None + if bbox_preds is not None: + # TODO move this to a sabl_roi_head + # the bbox prediction of some detectors like SABL is not Tensor + if isinstance(bbox_preds, torch.Tensor): + bbox_preds = bbox_preds.split(num_proposals_per_img, 0) + else: + bbox_preds = self.bbox_head.bbox_pred_split( + bbox_preds, num_proposals_per_img) + else: + bbox_preds = (None, ) * len(proposals) + + result_list = self.bbox_head.predict_by_feat( + rois=rois, + cls_scores=cls_scores, + bbox_preds=bbox_preds, + batch_img_metas=batch_img_metas, + rcnn_test_cfg=rcnn_test_cfg, + rescale=rescale) + return result_list + + def predict_mask(self, + x: Tuple[Tensor], + batch_img_metas: List[dict], + results_list: InstanceList, + rescale: bool = False) -> InstanceList: + """Perform forward propagation of the mask head and predict detection + results on the features of the upstream network. 
+ + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + results_list (list[:obj:`InstanceData`]): Detection results of + each image. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + # don't need to consider aug_test. + bboxes = [res.bboxes for res in results_list] + mask_rois = bbox2roi(bboxes) + if mask_rois.shape[0] == 0: + results_list = empty_instances( + batch_img_metas, + mask_rois.device, + task_type='mask', + instance_results=results_list, + mask_thr_binary=self.test_cfg.mask_thr_binary) + return results_list + + mask_results = self._mask_forward(x, mask_rois) + mask_preds = mask_results['mask_preds'] + # split batch mask prediction back to each image + num_mask_rois_per_img = [len(res) for res in results_list] + mask_preds = mask_preds.split(num_mask_rois_per_img, 0) + + # TODO: Handle the case where rescale is false + results_list = self.mask_head.predict_by_feat( + mask_preds=mask_preds, + results_list=results_list, + batch_img_metas=batch_img_metas, + rcnn_test_cfg=self.test_cfg, + rescale=rescale) + return results_list diff --git a/mmdet/models/roi_heads/test_mixins.py b/mmdet/models/roi_heads/test_mixins.py new file mode 100644 index 0000000000000000000000000000000000000000..940490454d9cf1fde4d69c1f890c173b92d522a1 --- /dev/null +++ b/mmdet/models/roi_heads/test_mixins.py @@ -0,0 +1,171 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# TODO: delete this file after refactor +import sys + +import torch + +from mmdet.models.layers import multiclass_nms +from mmdet.models.test_time_augs import merge_aug_bboxes, merge_aug_masks +from mmdet.structures.bbox import bbox2roi, bbox_mapping + +if sys.version_info >= (3, 7): + from mmdet.utils.contextmanagers import completed + + +class BBoxTestMixin: + + if sys.version_info >= (3, 7): + # TODO: Currently not supported + async def async_test_bboxes(self, + x, + img_metas, + proposals, + rcnn_test_cfg, + rescale=False, + **kwargs): + """Asynchronized test for box head without augmentation.""" + rois = bbox2roi(proposals) + roi_feats = self.bbox_roi_extractor( + x[:len(self.bbox_roi_extractor.featmap_strides)], rois) + if self.with_shared_head: + roi_feats = self.shared_head(roi_feats) + sleep_interval = rcnn_test_cfg.get('async_sleep_interval', 0.017) + + async with completed( + __name__, 'bbox_head_forward', + sleep_interval=sleep_interval): + cls_score, bbox_pred = self.bbox_head(roi_feats) + + img_shape = img_metas[0]['img_shape'] + scale_factor = img_metas[0]['scale_factor'] + det_bboxes, det_labels = self.bbox_head.get_bboxes( + rois, + cls_score, + bbox_pred, + img_shape, + scale_factor, + rescale=rescale, + cfg=rcnn_test_cfg) + return det_bboxes, det_labels + + # TODO: Currently not supported + def aug_test_bboxes(self, feats, img_metas, rpn_results_list, + rcnn_test_cfg): + """Test det bboxes with test time augmentation.""" + aug_bboxes = [] + aug_scores = [] + for x, img_meta in zip(feats, img_metas): + # only one image in the batch + img_shape = img_meta[0]['img_shape'] + scale_factor = img_meta[0]['scale_factor'] + flip = img_meta[0]['flip'] + flip_direction = img_meta[0]['flip_direction'] + # TODO more flexible + proposals = bbox_mapping(rpn_results_list[0][:, :4], img_shape, + scale_factor, flip, flip_direction) + rois = bbox2roi([proposals]) + bbox_results = self.bbox_forward(x, rois) + bboxes, scores = self.bbox_head.get_bboxes( + rois, + bbox_results['cls_score'], + bbox_results['bbox_pred'], + img_shape, + scale_factor, + rescale=False, + cfg=None) + aug_bboxes.append(bboxes) + aug_scores.append(scores) + # after merging, bboxes will be rescaled to the original image size + merged_bboxes, merged_scores = merge_aug_bboxes( + aug_bboxes, aug_scores, img_metas, rcnn_test_cfg) + if merged_bboxes.shape[0] == 0: + # There is no proposal in the single image + det_bboxes = merged_bboxes.new_zeros(0, 5) + det_labels = merged_bboxes.new_zeros((0, ), dtype=torch.long) + else: + det_bboxes, det_labels = multiclass_nms(merged_bboxes, + merged_scores, + rcnn_test_cfg.score_thr, + rcnn_test_cfg.nms, + rcnn_test_cfg.max_per_img) + return det_bboxes, det_labels + + +class MaskTestMixin: + + if sys.version_info >= (3, 7): + # TODO: Currently not supported + async def async_test_mask(self, + x, + img_metas, + det_bboxes, + det_labels, + rescale=False, + mask_test_cfg=None): + """Asynchronized test for mask head without augmentation.""" + # image shape of the first image in the batch (only one) + ori_shape = img_metas[0]['ori_shape'] + scale_factor = img_metas[0]['scale_factor'] + if det_bboxes.shape[0] == 0: + segm_result = [[] for _ in range(self.mask_head.num_classes)] + else: + if rescale and not isinstance(scale_factor, + (float, torch.Tensor)): + scale_factor = det_bboxes.new_tensor(scale_factor) + _bboxes = ( + det_bboxes[:, :4] * + scale_factor if rescale else det_bboxes) + mask_rois = bbox2roi([_bboxes]) + mask_feats = self.mask_roi_extractor( + 
x[:len(self.mask_roi_extractor.featmap_strides)], + mask_rois) + + if self.with_shared_head: + mask_feats = self.shared_head(mask_feats) + if mask_test_cfg and \ + mask_test_cfg.get('async_sleep_interval'): + sleep_interval = mask_test_cfg['async_sleep_interval'] + else: + sleep_interval = 0.035 + async with completed( + __name__, + 'mask_head_forward', + sleep_interval=sleep_interval): + mask_pred = self.mask_head(mask_feats) + segm_result = self.mask_head.get_results( + mask_pred, _bboxes, det_labels, self.test_cfg, ori_shape, + scale_factor, rescale) + return segm_result + + # TODO: Currently not supported + def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels): + """Test for mask head with test time augmentation.""" + if det_bboxes.shape[0] == 0: + segm_result = [[] for _ in range(self.mask_head.num_classes)] + else: + aug_masks = [] + for x, img_meta in zip(feats, img_metas): + img_shape = img_meta[0]['img_shape'] + scale_factor = img_meta[0]['scale_factor'] + flip = img_meta[0]['flip'] + flip_direction = img_meta[0]['flip_direction'] + _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape, + scale_factor, flip, flip_direction) + mask_rois = bbox2roi([_bboxes]) + mask_results = self._mask_forward(x, mask_rois) + # convert to numpy array to save memory + aug_masks.append( + mask_results['mask_pred'].sigmoid().cpu().numpy()) + merged_masks = merge_aug_masks(aug_masks, img_metas, self.test_cfg) + + ori_shape = img_metas[0][0]['ori_shape'] + scale_factor = det_bboxes.new_ones(4) + segm_result = self.mask_head.get_results( + merged_masks, + det_bboxes, + det_labels, + self.test_cfg, + ori_shape, + scale_factor=scale_factor, + rescale=False) + return segm_result diff --git a/mmdet/models/roi_heads/trident_roi_head.py b/mmdet/models/roi_heads/trident_roi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..5215327296282a8e7ca502f3321aced8a4f840b7 --- /dev/null +++ b/mmdet/models/roi_heads/trident_roi_head.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +from mmcv.ops import batched_nms +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.utils import InstanceList +from .standard_roi_head import StandardRoIHead + + +@MODELS.register_module() +class TridentRoIHead(StandardRoIHead): + """Trident roi head. + + Args: + num_branch (int): Number of branches in TridentNet. + test_branch_idx (int): In inference, all 3 branches will be used + if `test_branch_idx==-1`, otherwise only branch with index + `test_branch_idx` will be used. + """ + + def __init__(self, num_branch: int, test_branch_idx: int, + **kwargs) -> None: + self.num_branch = num_branch + self.test_branch_idx = test_branch_idx + super().__init__(**kwargs) + + def merge_trident_bboxes(self, + trident_results: InstanceList) -> InstanceData: + """Merge bbox predictions of each branch. + + Args: + trident_results (List[:obj:`InstanceData`]): A list of InstanceData + predicted from every branch. + + Returns: + :obj:`InstanceData`: merged InstanceData. 
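+
+        Example:
+            An illustrative sketch of the concatenation step, assuming two
+            branches with one detection each (the NMS config comes from
+            ``test_cfg`` in practice):
+
+            >>> import torch
+            >>> from mmengine.structures import InstanceData
+            >>> branch_results = []
+            >>> for score in (0.9, 0.8):
+            ...     res = InstanceData()
+            ...     res.bboxes = torch.rand(1, 4)
+            ...     res.scores = torch.tensor([score])
+            ...     res.labels = torch.tensor([0])
+            ...     branch_results.append(res)
+            >>> torch.cat([res.scores for res in branch_results])
+            tensor([0.9000, 0.8000])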
+ """ + bboxes = torch.cat([res.bboxes for res in trident_results]) + scores = torch.cat([res.scores for res in trident_results]) + labels = torch.cat([res.labels for res in trident_results]) + + nms_cfg = self.test_cfg['nms'] + results = InstanceData() + if bboxes.numel() == 0: + results.bboxes = bboxes + results.scores = scores + results.labels = labels + else: + det_bboxes, keep = batched_nms(bboxes, scores, labels, nms_cfg) + results.bboxes = det_bboxes[:, :-1] + results.scores = det_bboxes[:, -1] + results.labels = labels[keep] + + if self.test_cfg['max_per_img'] > 0: + results = results[:self.test_cfg['max_per_img']] + return results + + def predict(self, + x: Tuple[Tensor], + rpn_results_list: InstanceList, + batch_data_samples: SampleList, + rescale: bool = False) -> InstanceList: + """Perform forward propagation of the roi head and predict detection + results on the features of the upstream network. + + - Compute prediction bbox and label per branch. + - Merge predictions of each branch according to scores of + bboxes, i.e., bboxes with higher score are kept to give + top-k prediction. + + Args: + x (tuple[Tensor]): Features from upstream network. Each + has shape (N, C, H, W). + rpn_results_list (list[:obj:`InstanceData`]): list of region + proposals. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): Whether to rescale the results to + the original image. Defaults to True. + + Returns: + list[obj:`InstanceData`]: Detection results of each image. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + results_list = super().predict( + x=x, + rpn_results_list=rpn_results_list, + batch_data_samples=batch_data_samples, + rescale=rescale) + + num_branch = self.num_branch \ + if self.training or self.test_branch_idx == -1 else 1 + + merged_results_list = [] + for i in range(len(batch_data_samples) // num_branch): + merged_results_list.append( + self.merge_trident_bboxes(results_list[i * num_branch:(i + 1) * + num_branch])) + return merged_results_list diff --git a/mmdet/models/seg_heads/__init__.py b/mmdet/models/seg_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b489a905b1e9b6cef2e8b9575600990563128e4e --- /dev/null +++ b/mmdet/models/seg_heads/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .panoptic_fpn_head import PanopticFPNHead # noqa: F401,F403 +from .panoptic_fusion_heads import * # noqa: F401,F403 diff --git a/mmdet/models/seg_heads/base_semantic_head.py b/mmdet/models/seg_heads/base_semantic_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1db71549d89766c45012517c20cef443f4760419 --- /dev/null +++ b/mmdet/models/seg_heads/base_semantic_head.py @@ -0,0 +1,113 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from abc import ABCMeta, abstractmethod
+from typing import Dict, List, Tuple, Union
+
+import torch.nn.functional as F
+from mmengine.model import BaseModule
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.structures import SampleList
+from mmdet.utils import ConfigType, OptMultiConfig
+
+
+@MODELS.register_module()
+class BaseSemanticHead(BaseModule, metaclass=ABCMeta):
+    """Base module of Semantic Head.
+
+    Args:
+        num_classes (int): the number of classes.
+        seg_rescale_factor (float): the rescale factor for ``gt_sem_seg``,
+            which equals to ``1 / output_strides``. The output_strides is
+            for ``seg_preds``. Defaults to 1 / 4.
+        init_cfg (Optional[Union[:obj:`ConfigDict`, dict]]): the initialization
+            config.
+        loss_seg (Union[:obj:`ConfigDict`, dict]): the loss of the semantic
+            head.
+    """
+
+    def __init__(self,
+                 num_classes: int,
+                 seg_rescale_factor: float = 1 / 4.,
+                 loss_seg: ConfigType = dict(
+                     type='CrossEntropyLoss',
+                     ignore_index=255,
+                     loss_weight=1.0),
+                 init_cfg: OptMultiConfig = None) -> None:
+        super().__init__(init_cfg=init_cfg)
+        self.loss_seg = MODELS.build(loss_seg)
+        self.num_classes = num_classes
+        self.seg_rescale_factor = seg_rescale_factor
+
+    @abstractmethod
+    def forward(self, x: Union[Tensor, Tuple[Tensor]]) -> Dict[str, Tensor]:
+        """Placeholder of forward function.
+
+        Args:
+            x (Tensor): Feature maps.
+
+        Returns:
+            Dict[str, Tensor]: A dictionary, including features
+            and predicted scores. Required keys: 'seg_preds'
+            and 'feats'.
+        """
+        pass
+
+    @abstractmethod
+    def loss(self, x: Union[Tensor, Tuple[Tensor]],
+             batch_data_samples: SampleList) -> Dict[str, Tensor]:
+        """
+        Args:
+            x (Union[Tensor, Tuple[Tensor]]): Feature maps.
+            batch_data_samples (list[:obj:`DetDataSample`]): The batch
+                data samples. It usually includes information such
+                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+
+        Returns:
+            Dict[str, Tensor]: The loss of semantic head.
+        """
+        pass
+
+    def predict(self,
+                x: Union[Tensor, Tuple[Tensor]],
+                batch_img_metas: List[dict],
+                rescale: bool = False) -> List[Tensor]:
+        """Test without Augmentation.
+
+        Args:
+            x (Union[Tensor, Tuple[Tensor]]): Feature maps.
+            batch_img_metas (List[dict]): List of image information.
+            rescale (bool): Whether to rescale the results.
+                Defaults to False.
+
+        Returns:
+            list[Tensor]: semantic segmentation logits.
+        """
+        seg_preds = self.forward(x)['seg_preds']
+        seg_preds = F.interpolate(
+            seg_preds,
+            size=batch_img_metas[0]['batch_input_shape'],
+            mode='bilinear',
+            align_corners=False)
+        seg_preds = [seg_preds[i] for i in range(len(batch_img_metas))]
+
+        if rescale:
+            seg_pred_list = []
+            for i in range(len(batch_img_metas)):
+                h, w = batch_img_metas[i]['img_shape']
+                seg_pred = seg_preds[i][:, :h, :w]
+
+                h, w = batch_img_metas[i]['ori_shape']
+                seg_pred = F.interpolate(
+                    seg_pred[None],
+                    size=(h, w),
+                    mode='bilinear',
+                    align_corners=False)[0]
+                seg_pred_list.append(seg_pred)
+        else:
+            seg_pred_list = seg_preds
+
+        return seg_pred_list
diff --git a/mmdet/models/seg_heads/panoptic_fpn_head.py b/mmdet/models/seg_heads/panoptic_fpn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d8b901360922f6cdb9f8d15b60dac8d7514ee75
--- /dev/null
+++ b/mmdet/models/seg_heads/panoptic_fpn_head.py
@@ -0,0 +1,174 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.model import ModuleList
+from torch import Tensor
+
+from mmdet.registry import MODELS
+from mmdet.structures import SampleList
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+from ..layers import ConvUpsample
+from ..utils import interpolate_as
+from .base_semantic_head import BaseSemanticHead
+
+
+@MODELS.register_module()
+class PanopticFPNHead(BaseSemanticHead):
+    """PanopticFPNHead used in Panoptic FPN.
+
+    In this head, the number of output channels is ``num_stuff_classes
+    + 1``, including all stuff classes and one thing class. The stuff
+    classes will be reset from ``0`` to ``num_stuff_classes - 1``, the
+    thing classes will be merged to ``num_stuff_classes``-th channel.
+
+    Args:
+        num_things_classes (int): Number of thing classes. Default: 80.
+        num_stuff_classes (int): Number of stuff classes. Default: 53.
+        in_channels (int): Number of channels in the input feature
+            map.
+        inner_channels (int): Number of channels in inner features.
+        start_level (int): The start level of the input features
+            used in PanopticFPN.
+        end_level (int): The end level of the used features, the
+            ``end_level``-th layer will not be used.
+        conv_cfg (Optional[Union[ConfigDict, dict]]): Dictionary to construct
+            and config conv layer.
+        norm_cfg (Union[ConfigDict, dict]): Dictionary to construct and config
+            norm layer. Use ``GN`` by default.
+        init_cfg (Optional[Union[ConfigDict, dict]]): Initialization config
+            dict.
+        loss_seg (Union[ConfigDict, dict]): the loss of the semantic head.
+    """
+
+    def __init__(self,
+                 num_things_classes: int = 80,
+                 num_stuff_classes: int = 53,
+                 in_channels: int = 256,
+                 inner_channels: int = 128,
+                 start_level: int = 0,
+                 end_level: int = 4,
+                 conv_cfg: OptConfigType = None,
+                 norm_cfg: ConfigType = dict(
+                     type='GN', num_groups=32, requires_grad=True),
+                 loss_seg: ConfigType = dict(
+                     type='CrossEntropyLoss', ignore_index=-1,
+                     loss_weight=1.0),
+                 init_cfg: OptMultiConfig = None) -> None:
+        seg_rescale_factor = 1 / 2**(start_level + 2)
+        super().__init__(
+            num_classes=num_stuff_classes + 1,
+            seg_rescale_factor=seg_rescale_factor,
+            loss_seg=loss_seg,
+            init_cfg=init_cfg)
+        self.num_things_classes = num_things_classes
+        self.num_stuff_classes = num_stuff_classes
+        # Used feature layers are [start_level, end_level)
+        self.start_level = start_level
+        self.end_level = end_level
+        self.num_stages = end_level - start_level
+        self.inner_channels = inner_channels
+
+        self.conv_upsample_layers = ModuleList()
+        for i in range(start_level, end_level):
+            self.conv_upsample_layers.append(
+                ConvUpsample(
+                    in_channels,
+                    inner_channels,
+                    num_layers=i if i > 0 else 1,
+                    num_upsample=i if i > 0 else 0,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                ))
+        self.conv_logits = nn.Conv2d(inner_channels, self.num_classes, 1)
+
+    def _set_things_to_void(self, gt_semantic_seg: Tensor) -> Tensor:
+        """Merge thing classes to one class.
+
+        In PanopticFPN, the background labels will be reset from `0` to
+        `self.num_stuff_classes-1`, the foreground labels will be merged to
+        `self.num_stuff_classes`-th channel.
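+
+        Example:
+            With the defaults (80 thing classes, 53 stuff classes), an
+            illustrative remapping of two toy labels (5 is a thing,
+            85 is the 6th stuff class):
+
+            >>> import torch
+            >>> gt = torch.tensor([5, 85])
+            >>> fg = gt < 80
+            >>> bg = (gt >= 80) & (gt < 133)
+            >>> new = torch.where(bg, gt - 80, gt)
+            >>> torch.where(fg, torch.full_like(gt, 53), new)
+            tensor([53,  5])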
+ """ + gt_semantic_seg = gt_semantic_seg.int() + fg_mask = gt_semantic_seg < self.num_things_classes + bg_mask = (gt_semantic_seg >= self.num_things_classes) * ( + gt_semantic_seg < self.num_things_classes + self.num_stuff_classes) + + new_gt_seg = torch.clone(gt_semantic_seg) + new_gt_seg = torch.where(bg_mask, + gt_semantic_seg - self.num_things_classes, + new_gt_seg) + new_gt_seg = torch.where(fg_mask, + fg_mask.int() * self.num_stuff_classes, + new_gt_seg) + return new_gt_seg + + def loss(self, x: Union[Tensor, Tuple[Tensor]], + batch_data_samples: SampleList) -> Dict[str, Tensor]: + """ + Args: + x (Union[Tensor, Tuple[Tensor]]): Feature maps. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + Dict[str, Tensor]: The loss of semantic head. + """ + seg_preds = self(x)['seg_preds'] + gt_semantic_segs = [ + data_sample.gt_sem_seg.sem_seg + for data_sample in batch_data_samples + ] + + gt_semantic_segs = torch.stack(gt_semantic_segs) + if self.seg_rescale_factor != 1.0: + gt_semantic_segs = F.interpolate( + gt_semantic_segs.float(), + scale_factor=self.seg_rescale_factor, + mode='nearest').squeeze(1) + + # Things classes will be merged to one class in PanopticFPN. + gt_semantic_segs = self._set_things_to_void(gt_semantic_segs) + + if seg_preds.shape[-2:] != gt_semantic_segs.shape[-2:]: + seg_preds = interpolate_as(seg_preds, gt_semantic_segs) + seg_preds = seg_preds.permute((0, 2, 3, 1)) + + loss_seg = self.loss_seg( + seg_preds.reshape(-1, self.num_classes), # => [NxHxW, C] + gt_semantic_segs.reshape(-1).long()) + + return dict(loss_seg=loss_seg) + + def init_weights(self) -> None: + """Initialize weights.""" + super().init_weights() + nn.init.normal_(self.conv_logits.weight.data, 0, 0.01) + self.conv_logits.bias.data.zero_() + + def forward(self, x: Tuple[Tensor]) -> Dict[str, Tensor]: + """Forward. + + Args: + x (Tuple[Tensor]): Multi scale Feature maps. + + Returns: + dict[str, Tensor]: semantic segmentation predictions and + feature maps. + """ + # the number of subnets must be not more than + # the length of features. + assert self.num_stages <= len(x) + + feats = [] + for i, layer in enumerate(self.conv_upsample_layers): + f = layer(x[self.start_level + i]) + feats.append(f) + + seg_feats = torch.sum(torch.stack(feats, dim=0), dim=0) + seg_preds = self.conv_logits(seg_feats) + out = dict(seg_preds=seg_preds, seg_feats=seg_feats) + return out diff --git a/mmdet/models/seg_heads/panoptic_fusion_heads/__init__.py b/mmdet/models/seg_heads/panoptic_fusion_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..41625a61d6d1c38c633062c24b1e3455bd3ae2df --- /dev/null +++ b/mmdet/models/seg_heads/panoptic_fusion_heads/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base_panoptic_fusion_head import \ + BasePanopticFusionHead # noqa: F401,F403 +from .heuristic_fusion_head import HeuristicFusionHead # noqa: F401,F403 +from .maskformer_fusion_head import MaskFormerFusionHead # noqa: F401,F403 diff --git a/mmdet/models/seg_heads/panoptic_fusion_heads/base_panoptic_fusion_head.py b/mmdet/models/seg_heads/panoptic_fusion_heads/base_panoptic_fusion_head.py new file mode 100644 index 0000000000000000000000000000000000000000..f6b20e1cd144eaebd042b8017f143c0a643adde1 --- /dev/null +++ b/mmdet/models/seg_heads/panoptic_fusion_heads/base_panoptic_fusion_head.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. 
+from abc import ABCMeta, abstractmethod
+
+from mmengine.model import BaseModule
+
+from mmdet.registry import MODELS
+from mmdet.utils import OptConfigType, OptMultiConfig
+
+
+@MODELS.register_module()
+class BasePanopticFusionHead(BaseModule, metaclass=ABCMeta):
+    """Base class for panoptic heads."""
+
+    def __init__(self,
+                 num_things_classes: int = 80,
+                 num_stuff_classes: int = 53,
+                 test_cfg: OptConfigType = None,
+                 loss_panoptic: OptConfigType = None,
+                 init_cfg: OptMultiConfig = None,
+                 **kwargs) -> None:
+        super().__init__(init_cfg=init_cfg)
+        self.num_things_classes = num_things_classes
+        self.num_stuff_classes = num_stuff_classes
+        self.num_classes = num_things_classes + num_stuff_classes
+        self.test_cfg = test_cfg
+
+        if loss_panoptic:
+            self.loss_panoptic = MODELS.build(loss_panoptic)
+        else:
+            self.loss_panoptic = None
+
+    @property
+    def with_loss(self) -> bool:
+        """bool: whether the panoptic head contains loss function."""
+        return self.loss_panoptic is not None
+
+    @abstractmethod
+    def loss(self, **kwargs):
+        """Loss function."""
+
+    @abstractmethod
+    def predict(self, **kwargs):
+        """Predict function."""
diff --git a/mmdet/models/seg_heads/panoptic_fusion_heads/heuristic_fusion_head.py b/mmdet/models/seg_heads/panoptic_fusion_heads/heuristic_fusion_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a4a4200edd97f42e9a138e14a1d07328ad9b139
--- /dev/null
+++ b/mmdet/models/seg_heads/panoptic_fusion_heads/heuristic_fusion_head.py
@@ -0,0 +1,159 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import torch
+from mmengine.structures import InstanceData, PixelData
+from torch import Tensor
+
+from mmdet.evaluation.functional import INSTANCE_OFFSET
+from mmdet.registry import MODELS
+from mmdet.utils import InstanceList, OptConfigType, OptMultiConfig, PixelList
+from .base_panoptic_fusion_head import BasePanopticFusionHead
+
+
+@MODELS.register_module()
+class HeuristicFusionHead(BasePanopticFusionHead):
+    """Fusion Head with Heuristic method."""
+
+    def __init__(self,
+                 num_things_classes: int = 80,
+                 num_stuff_classes: int = 53,
+                 test_cfg: OptConfigType = None,
+                 init_cfg: OptMultiConfig = None,
+                 **kwargs) -> None:
+        super().__init__(
+            num_things_classes=num_things_classes,
+            num_stuff_classes=num_stuff_classes,
+            test_cfg=test_cfg,
+            loss_panoptic=None,
+            init_cfg=init_cfg,
+            **kwargs)
+
+    def loss(self, **kwargs) -> dict:
+        """HeuristicFusionHead has no training loss."""
+        return dict()
+
+    def _lay_masks(self,
+                   mask_results: InstanceData,
+                   overlap_thr: float = 0.5) -> Tensor:
+        """Lay instance masks to a result map.
+
+        Args:
+            mask_results (:obj:`InstanceData`): Instance segmentation results,
+                each contains ``bboxes``, ``labels``, ``scores`` and ``masks``.
+            overlap_thr (float): Threshold to determine whether two masks
+                overlap. Defaults to 0.5.
+
+        Returns:
+            tuple[Tensor, Tensor]: The instance id map of shape (H, W) and
+            the labels of the laid instances.
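+
+        Example:
+            A toy sketch of the paste-by-score heuristic: a later (lower
+            scoring) mask only claims pixels that are still unassigned:
+
+            >>> import torch
+            >>> id_map = torch.zeros(2, 2, dtype=torch.long)
+            >>> first = torch.tensor([[True, True], [False, False]])
+            >>> second = torch.tensor([[True, False], [True, False]])
+            >>> id_map[first] = 1
+            >>> id_map[second & (id_map == 0)] = 2
+            >>> id_map
+            tensor([[1, 1],
+                    [2, 0]])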
+ """ + bboxes = mask_results.bboxes + scores = mask_results.scores + labels = mask_results.labels + masks = mask_results.masks + + num_insts = bboxes.shape[0] + id_map = torch.zeros( + masks.shape[-2:], device=bboxes.device, dtype=torch.long) + if num_insts == 0: + return id_map, labels + + # Sort by score to use heuristic fusion + order = torch.argsort(-scores) + bboxes = bboxes[order] + labels = labels[order] + segm_masks = masks[order] + + instance_id = 1 + left_labels = [] + for idx in range(bboxes.shape[0]): + _cls = labels[idx] + _mask = segm_masks[idx] + instance_id_map = torch.ones_like( + _mask, dtype=torch.long) * instance_id + area = _mask.sum() + if area == 0: + continue + + pasted = id_map > 0 + intersect = (_mask * pasted).sum() + if (intersect / (area + 1e-5)) > overlap_thr: + continue + + _part = _mask * (~pasted) + id_map = torch.where(_part, instance_id_map, id_map) + left_labels.append(_cls) + instance_id += 1 + + if len(left_labels) > 0: + instance_labels = torch.stack(left_labels) + else: + instance_labels = bboxes.new_zeros((0, ), dtype=torch.long) + assert instance_id == (len(instance_labels) + 1) + return id_map, instance_labels + + def _predict_single(self, mask_results: InstanceData, seg_preds: Tensor, + **kwargs) -> PixelData: + """Fuse the results of instance and semantic segmentations. + + Args: + mask_results (:obj:`InstanceData`): Instance segmentation results, + each contains ``bboxes``, ``labels``, ``scores`` and ``masks``. + seg_preds (Tensor): The semantic segmentation results, + (num_stuff + 1, H, W). + + Returns: + Tensor: The panoptic segmentation result, (H, W). + """ + id_map, labels = self._lay_masks(mask_results, + self.test_cfg.mask_overlap) + + seg_results = seg_preds.argmax(dim=0) + seg_results = seg_results + self.num_things_classes + + pan_results = seg_results + instance_id = 1 + for idx in range(len(mask_results)): + _mask = id_map == (idx + 1) + if _mask.sum() == 0: + continue + _cls = labels[idx] + # simply trust detection + segment_id = _cls + instance_id * INSTANCE_OFFSET + pan_results[_mask] = segment_id + instance_id += 1 + + ids, counts = torch.unique( + pan_results % INSTANCE_OFFSET, return_counts=True) + stuff_ids = ids[ids >= self.num_things_classes] + stuff_counts = counts[ids >= self.num_things_classes] + ignore_stuff_ids = stuff_ids[ + stuff_counts < self.test_cfg.stuff_area_limit] + + assert pan_results.ndim == 2 + pan_results[(pan_results.unsqueeze(2) == ignore_stuff_ids.reshape( + 1, 1, -1)).any(dim=2)] = self.num_classes + + pan_results = PixelData(sem_seg=pan_results[None].int()) + return pan_results + + def predict(self, mask_results_list: InstanceList, + seg_preds_list: List[Tensor], **kwargs) -> PixelList: + """Predict results by fusing the results of instance and semantic + segmentations. + + Args: + mask_results_list (list[:obj:`InstanceData`]): Instance + segmentation results, each contains ``bboxes``, ``labels``, + ``scores`` and ``masks``. + seg_preds_list (Tensor): List of semantic segmentation results. + + Returns: + List[PixelData]: Panoptic segmentation result. 
+ """ + results_list = [ + self._predict_single(mask_results_list[i], seg_preds_list[i]) + for i in range(len(mask_results_list)) + ] + + return results_list diff --git a/mmdet/models/seg_heads/panoptic_fusion_heads/maskformer_fusion_head.py b/mmdet/models/seg_heads/panoptic_fusion_heads/maskformer_fusion_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1b76e6b45bb9be2584f8b3eca2e5e1c0809249fa --- /dev/null +++ b/mmdet/models/seg_heads/panoptic_fusion_heads/maskformer_fusion_head.py @@ -0,0 +1,266 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn.functional as F +from mmengine.structures import InstanceData, PixelData +from torch import Tensor + +from mmdet.evaluation.functional import INSTANCE_OFFSET +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.structures.mask import mask2bbox +from mmdet.utils import OptConfigType, OptMultiConfig +from .base_panoptic_fusion_head import BasePanopticFusionHead + + +@MODELS.register_module() +class MaskFormerFusionHead(BasePanopticFusionHead): + """MaskFormer fusion head which postprocesses results for panoptic + segmentation, instance segmentation and semantic segmentation.""" + + def __init__(self, + num_things_classes: int = 80, + num_stuff_classes: int = 53, + test_cfg: OptConfigType = None, + loss_panoptic: OptConfigType = None, + init_cfg: OptMultiConfig = None, + **kwargs): + super().__init__( + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + test_cfg=test_cfg, + loss_panoptic=loss_panoptic, + init_cfg=init_cfg, + **kwargs) + + def loss(self, **kwargs): + """MaskFormerFusionHead has no training loss.""" + return dict() + + def panoptic_postprocess(self, mask_cls: Tensor, + mask_pred: Tensor) -> PixelData: + """Panoptic segmengation inference. + + Args: + mask_cls (Tensor): Classfication outputs of shape + (num_queries, cls_out_channels) for a image. + Note `cls_out_channels` should includes + background. + mask_pred (Tensor): Mask outputs of shape + (num_queries, h, w) for a image. + + Returns: + :obj:`PixelData`: Panoptic segment result of shape \ + (h, w), each element in Tensor means: \ + ``segment_id = _cls + instance_id * INSTANCE_OFFSET``. 
+ """ + object_mask_thr = self.test_cfg.get('object_mask_thr', 0.8) + iou_thr = self.test_cfg.get('iou_thr', 0.8) + filter_low_score = self.test_cfg.get('filter_low_score', False) + + scores, labels = F.softmax(mask_cls, dim=-1).max(-1) + mask_pred = mask_pred.sigmoid() + + keep = labels.ne(self.num_classes) & (scores > object_mask_thr) + cur_scores = scores[keep] + cur_classes = labels[keep] + cur_masks = mask_pred[keep] + + cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks + + h, w = cur_masks.shape[-2:] + panoptic_seg = torch.full((h, w), + self.num_classes, + dtype=torch.int32, + device=cur_masks.device) + if cur_masks.shape[0] == 0: + # We didn't detect any mask :( + pass + else: + cur_mask_ids = cur_prob_masks.argmax(0) + instance_id = 1 + for k in range(cur_classes.shape[0]): + pred_class = int(cur_classes[k].item()) + isthing = pred_class < self.num_things_classes + mask = cur_mask_ids == k + mask_area = mask.sum().item() + original_area = (cur_masks[k] >= 0.5).sum().item() + + if filter_low_score: + mask = mask & (cur_masks[k] >= 0.5) + + if mask_area > 0 and original_area > 0: + if mask_area / original_area < iou_thr: + continue + + if not isthing: + # different stuff regions of same class will be + # merged here, and stuff share the instance_id 0. + panoptic_seg[mask] = pred_class + else: + panoptic_seg[mask] = ( + pred_class + instance_id * INSTANCE_OFFSET) + instance_id += 1 + + return PixelData(sem_seg=panoptic_seg[None]) + + def semantic_postprocess(self, mask_cls: Tensor, + mask_pred: Tensor) -> PixelData: + """Semantic segmengation postprocess. + + Args: + mask_cls (Tensor): Classfication outputs of shape + (num_queries, cls_out_channels) for a image. + Note `cls_out_channels` should includes + background. + mask_pred (Tensor): Mask outputs of shape + (num_queries, h, w) for a image. + + Returns: + :obj:`PixelData`: Semantic segment result. + """ + # TODO add semantic segmentation result + raise NotImplementedError + + def instance_postprocess(self, mask_cls: Tensor, + mask_pred: Tensor) -> InstanceData: + """Instance segmengation postprocess. + + Args: + mask_cls (Tensor): Classfication outputs of shape + (num_queries, cls_out_channels) for a image. + Note `cls_out_channels` should includes + background. + mask_pred (Tensor): Mask outputs of shape + (num_queries, h, w) for a image. + + Returns: + :obj:`InstanceData`: Instance segmentation results. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). 
+ """ + max_per_image = self.test_cfg.get('max_per_image', 100) + num_queries = mask_cls.shape[0] + # shape (num_queries, num_class) + scores = F.softmax(mask_cls, dim=-1)[:, :-1] + # shape (num_queries * num_class, ) + labels = torch.arange(self.num_classes, device=mask_cls.device).\ + unsqueeze(0).repeat(num_queries, 1).flatten(0, 1) + scores_per_image, top_indices = scores.flatten(0, 1).topk( + max_per_image, sorted=False) + labels_per_image = labels[top_indices] + + query_indices = top_indices // self.num_classes + mask_pred = mask_pred[query_indices] + + # extract things + is_thing = labels_per_image < self.num_things_classes + scores_per_image = scores_per_image[is_thing] + labels_per_image = labels_per_image[is_thing] + mask_pred = mask_pred[is_thing] + + mask_pred_binary = (mask_pred > 0).float() + mask_scores_per_image = (mask_pred.sigmoid() * + mask_pred_binary).flatten(1).sum(1) / ( + mask_pred_binary.flatten(1).sum(1) + 1e-6) + det_scores = scores_per_image * mask_scores_per_image + mask_pred_binary = mask_pred_binary.bool() + bboxes = mask2bbox(mask_pred_binary) + + results = InstanceData() + results.bboxes = bboxes + results.labels = labels_per_image + results.scores = det_scores + results.masks = mask_pred_binary + return results + + def predict(self, + mask_cls_results: Tensor, + mask_pred_results: Tensor, + batch_data_samples: SampleList, + rescale: bool = False, + **kwargs) -> List[dict]: + """Test segment without test-time aumengtation. + + Only the output of last decoder layers was used. + + Args: + mask_cls_results (Tensor): Mask classification logits, + shape (batch_size, num_queries, cls_out_channels). + Note `cls_out_channels` should includes background. + mask_pred_results (Tensor): Mask logits, shape + (batch_size, num_queries, h, w). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): If True, return boxes in + original image space. Default False. + + Returns: + list[dict]: Instance segmentation \ + results and panoptic segmentation results for each \ + image. + + .. code-block:: none + + [ + { + 'pan_results': PixelData, + 'ins_results': InstanceData, + # semantic segmentation results are not supported yet + 'sem_results': PixelData + }, + ... + ] + """ + batch_img_metas = [ + data_sample.metainfo for data_sample in batch_data_samples + ] + panoptic_on = self.test_cfg.get('panoptic_on', True) + semantic_on = self.test_cfg.get('semantic_on', False) + instance_on = self.test_cfg.get('instance_on', False) + assert not semantic_on, 'segmantic segmentation '\ + 'results are not supported yet.' 
+ + results = [] + for mask_cls_result, mask_pred_result, meta in zip( + mask_cls_results, mask_pred_results, batch_img_metas): + # remove padding + img_height, img_width = meta['img_shape'][:2] + mask_pred_result = mask_pred_result[:, :img_height, :img_width] + + if rescale: + # return result in original resolution + ori_height, ori_width = meta['ori_shape'][:2] + mask_pred_result = F.interpolate( + mask_pred_result[:, None], + size=(ori_height, ori_width), + mode='bilinear', + align_corners=False)[:, 0] + + result = dict() + if panoptic_on: + pan_results = self.panoptic_postprocess( + mask_cls_result, mask_pred_result) + result['pan_results'] = pan_results + + if instance_on: + ins_results = self.instance_postprocess( + mask_cls_result, mask_pred_result) + result['ins_results'] = ins_results + + if semantic_on: + sem_results = self.semantic_postprocess( + mask_cls_result, mask_pred_result) + result['sem_results'] = sem_results + + results.append(result) + + return results diff --git a/mmdet/models/task_modules/__init__.py b/mmdet/models/task_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7bfd8f058ed656760e0b1a3fd6118f31a799cb11 --- /dev/null +++ b/mmdet/models/task_modules/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .assigners import * # noqa: F401,F403 +from .builder import (ANCHOR_GENERATORS, BBOX_ASSIGNERS, BBOX_CODERS, + BBOX_SAMPLERS, IOU_CALCULATORS, MATCH_COSTS, + PRIOR_GENERATORS, build_anchor_generator, build_assigner, + build_bbox_coder, build_iou_calculator, build_match_cost, + build_prior_generator, build_sampler) +from .coders import * # noqa: F401,F403 +from .prior_generators import * # noqa: F401,F403 +from .samplers import * # noqa: F401,F403 +from .tracking import * # noqa: F401,F403 + +__all__ = [ + 'ANCHOR_GENERATORS', 'PRIOR_GENERATORS', 'BBOX_ASSIGNERS', 'BBOX_SAMPLERS', + 'MATCH_COSTS', 'BBOX_CODERS', 'IOU_CALCULATORS', 'build_anchor_generator', + 'build_prior_generator', 'build_assigner', 'build_sampler', + 'build_iou_calculator', 'build_match_cost', 'build_bbox_coder' +] diff --git a/mmdet/models/task_modules/assigners/__init__.py b/mmdet/models/task_modules/assigners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e564f24c95b1cc6be8a35a1a309ebf10e582032 --- /dev/null +++ b/mmdet/models/task_modules/assigners/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .approx_max_iou_assigner import ApproxMaxIoUAssigner
+from .assign_result import AssignResult
+from .atss_assigner import ATSSAssigner
+from .base_assigner import BaseAssigner
+from .center_region_assigner import CenterRegionAssigner
+from .dynamic_soft_label_assigner import DynamicSoftLabelAssigner
+from .grid_assigner import GridAssigner
+from .hungarian_assigner import HungarianAssigner
+from .iou2d_calculator import BboxOverlaps2D, BboxOverlaps2D_GLIP
+from .match_cost import (BBoxL1Cost, BinaryFocalLossCost, ClassificationCost,
+                         CrossEntropyLossCost, DiceCost, FocalLossCost,
+                         IoUCost)
+from .max_iou_assigner import MaxIoUAssigner
+from .multi_instance_assigner import MultiInstanceAssigner
+from .point_assigner import PointAssigner
+from .region_assigner import RegionAssigner
+from .sim_ota_assigner import SimOTAAssigner
+from .task_aligned_assigner import TaskAlignedAssigner
+from .topk_hungarian_assigner import TopkHungarianAssigner
+from .uniform_assigner import UniformAssigner
+
+__all__ = [
+    'BaseAssigner', 'BinaryFocalLossCost', 'MaxIoUAssigner',
+    'ApproxMaxIoUAssigner', 'AssignResult', 'PointAssigner', 'ATSSAssigner',
+    'CenterRegionAssigner', 'GridAssigner', 'HungarianAssigner',
+    'RegionAssigner', 'UniformAssigner', 'SimOTAAssigner',
+    'TaskAlignedAssigner', 'TopkHungarianAssigner', 'BBoxL1Cost',
+    'ClassificationCost', 'CrossEntropyLossCost', 'DiceCost', 'FocalLossCost',
+    'IoUCost', 'BboxOverlaps2D', 'DynamicSoftLabelAssigner',
+    'MultiInstanceAssigner', 'BboxOverlaps2D_GLIP'
+]
diff --git a/mmdet/models/task_modules/assigners/approx_max_iou_assigner.py b/mmdet/models/task_modules/assigners/approx_max_iou_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..471d54e578d640da242355b54cebe05658309ca2
--- /dev/null
+++ b/mmdet/models/task_modules/assigners/approx_max_iou_assigner.py
@@ -0,0 +1,162 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Union
+
+import torch
+from mmengine.config import ConfigDict
+from mmengine.structures import InstanceData
+
+from mmdet.registry import TASK_UTILS
+from .assign_result import AssignResult
+from .max_iou_assigner import MaxIoUAssigner
+
+
+@TASK_UTILS.register_module()
+class ApproxMaxIoUAssigner(MaxIoUAssigner):
+    """Assign a corresponding gt bbox or background to each bbox.
+
+    Each proposal will be assigned an integer indicating the ground truth
+    index. (semi-positive index: gt label (0-based), -1: background)
+
+    - -1: negative sample, no assigned gt
+    - semi-positive integer: positive sample, index (0-based) of assigned gt
+
+    Args:
+        pos_iou_thr (float): IoU threshold for positive bboxes.
+        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
+        min_pos_iou (float): Minimum iou for a bbox to be considered as a
+            positive bbox. Positive samples can have smaller IoU than
+            pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
+        gt_max_assign_all (bool): Whether to assign all bboxes with the same
+            highest overlap with some gt to that gt.
+        ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
+            `gt_bboxes_ignore` is specified). Negative values mean not
+            ignoring any bboxes.
+        ignore_wrt_candidates (bool): Whether to compute the iof between
+            `bboxes` and `gt_bboxes_ignore`, or the contrary.
+        match_low_quality (bool): Whether to allow low-quality matches. This
+            is usually allowed for RPN and single-stage detectors, but not
+            allowed in the second stage.
+        gpu_assign_thr (int): The upper bound of the number of GT for GPU
+            assign. When the number of gts is above this threshold, the
+            assignment will be done on CPU. Negative values mean never
+            assigning on CPU.
+        iou_calculator (:obj:`ConfigDict` or dict): Config of overlaps
+            Calculator.
+    """
+
+    def __init__(
+        self,
+        pos_iou_thr: float,
+        neg_iou_thr: Union[float, tuple],
+        min_pos_iou: float = .0,
+        gt_max_assign_all: bool = True,
+        ignore_iof_thr: float = -1,
+        ignore_wrt_candidates: bool = True,
+        match_low_quality: bool = True,
+        gpu_assign_thr: int = -1,
+        iou_calculator: Union[ConfigDict, dict] = dict(type='BboxOverlaps2D')
+    ) -> None:
+        self.pos_iou_thr = pos_iou_thr
+        self.neg_iou_thr = neg_iou_thr
+        self.min_pos_iou = min_pos_iou
+        self.gt_max_assign_all = gt_max_assign_all
+        self.ignore_iof_thr = ignore_iof_thr
+        self.ignore_wrt_candidates = ignore_wrt_candidates
+        self.gpu_assign_thr = gpu_assign_thr
+        self.match_low_quality = match_low_quality
+        self.iou_calculator = TASK_UTILS.build(iou_calculator)
+
+    def assign(self,
+               pred_instances: InstanceData,
+               gt_instances: InstanceData,
+               gt_instances_ignore: Optional[InstanceData] = None,
+               **kwargs) -> AssignResult:
+        """Assign gt to approxs.
+
+        This method assigns a gt bbox to each group of approxs (bboxes).
+        Each group of approxs is represented by a base approx (bbox) and
+        will be assigned with -1, or a semi-positive number.
+        background_label (-1) means negative sample,
+        semi-positive number is the index (0-based) of assigned gt.
+        The assignment is done in following steps, the order matters.
+
+        1. assign every bbox to background_label (-1)
+        2. use the max IoU of each group of approxs to represent that group
+        3. assign proposals whose iou with all gts < neg_iou_thr to background
+        4. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
+           assign it to that bbox
+        5. for each gt bbox, assign its nearest proposals (may be more than
+           one) to itself
+
+        Args:
+            pred_instances (:obj:`InstanceData`): Instances of model
+                predictions. It includes ``priors``, and the priors can
+                be anchors or points, or the bboxes predicted by the
+                previous stage, has shape (n, 4). ``approxs`` means the
+                group of approxs aligned with ``priors``, has shape
+                (n, num_approxs, 4).
+            gt_instances (:obj:`InstanceData`): Ground truth of instance
+                annotations. It usually includes ``bboxes``, with shape (k, 4),
+                and ``labels``, with shape (k, ).
+            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
+                to be ignored during training. It includes ``bboxes``
+                attribute data that is ignored during training and testing.
+                Defaults to None.
+
+        Returns:
+            :obj:`AssignResult`: The assign result.
+ """ + squares = pred_instances.priors + approxs = pred_instances.approxs + gt_bboxes = gt_instances.bboxes + gt_labels = gt_instances.labels + gt_bboxes_ignore = None if gt_instances_ignore is None else \ + gt_instances_ignore.get('bboxes', None) + approxs_per_octave = approxs.size(1) + + num_squares = squares.size(0) + num_gts = gt_bboxes.size(0) + + if num_squares == 0 or num_gts == 0: + # No predictions and/or truth, return empty assignment + overlaps = approxs.new(num_gts, num_squares) + assign_result = self.assign_wrt_overlaps(overlaps, gt_labels) + return assign_result + + # re-organize anchors by approxs_per_octave x num_squares + approxs = torch.transpose(approxs, 0, 1).contiguous().view(-1, 4) + assign_on_cpu = True if (self.gpu_assign_thr > 0) and ( + num_gts > self.gpu_assign_thr) else False + # compute overlap and assign gt on CPU when number of GT is large + if assign_on_cpu: + device = approxs.device + approxs = approxs.cpu() + gt_bboxes = gt_bboxes.cpu() + if gt_bboxes_ignore is not None: + gt_bboxes_ignore = gt_bboxes_ignore.cpu() + if gt_labels is not None: + gt_labels = gt_labels.cpu() + all_overlaps = self.iou_calculator(approxs, gt_bboxes) + + overlaps, _ = all_overlaps.view(approxs_per_octave, num_squares, + num_gts).max(dim=0) + overlaps = torch.transpose(overlaps, 0, 1) + + if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None + and gt_bboxes_ignore.numel() > 0 and squares.numel() > 0): + if self.ignore_wrt_candidates: + ignore_overlaps = self.iou_calculator( + squares, gt_bboxes_ignore, mode='iof') + ignore_max_overlaps, _ = ignore_overlaps.max(dim=1) + else: + ignore_overlaps = self.iou_calculator( + gt_bboxes_ignore, squares, mode='iof') + ignore_max_overlaps, _ = ignore_overlaps.max(dim=0) + overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1 + + assign_result = self.assign_wrt_overlaps(overlaps, gt_labels) + if assign_on_cpu: + assign_result.gt_inds = assign_result.gt_inds.to(device) + assign_result.max_overlaps = assign_result.max_overlaps.to(device) + if assign_result.labels is not None: + assign_result.labels = assign_result.labels.to(device) + return assign_result diff --git a/mmdet/models/task_modules/assigners/assign_result.py b/mmdet/models/task_modules/assigners/assign_result.py new file mode 100644 index 0000000000000000000000000000000000000000..56ca2c3c18fee94cc4a039b769e42521bd14907d --- /dev/null +++ b/mmdet/models/task_modules/assigners/assign_result.py @@ -0,0 +1,198 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import Tensor + +from mmdet.utils import util_mixins + + +class AssignResult(util_mixins.NiceRepr): + """Stores assignments between predicted and truth boxes. + + Attributes: + num_gts (int): the number of truth boxes considered when computing this + assignment + gt_inds (Tensor): for each predicted box indicates the 1-based + index of the assigned truth box. 0 means unassigned and -1 means + ignore. + max_overlaps (Tensor): the iou between the predicted box and its + assigned truth box. + labels (Tensor): If specified, for each predicted box + indicates the category label of the assigned truth box. + + Example: + >>> # An assign result between 4 predicted boxes and 9 true boxes + >>> # where only two boxes were assigned. 
+ >>> num_gts = 9 + >>> max_overlaps = torch.LongTensor([0, .5, .9, 0]) + >>> gt_inds = torch.LongTensor([-1, 1, 2, 0]) + >>> labels = torch.LongTensor([0, 3, 4, 0]) + >>> self = AssignResult(num_gts, gt_inds, max_overlaps, labels) + >>> print(str(self)) # xdoctest: +IGNORE_WANT + + >>> # Force addition of gt labels (when adding gt as proposals) + >>> new_labels = torch.LongTensor([3, 4, 5]) + >>> self.add_gt_(new_labels) + >>> print(str(self)) # xdoctest: +IGNORE_WANT + + """ + + def __init__(self, num_gts: int, gt_inds: Tensor, max_overlaps: Tensor, + labels: Tensor) -> None: + self.num_gts = num_gts + self.gt_inds = gt_inds + self.max_overlaps = max_overlaps + self.labels = labels + # Interface for possible user-defined properties + self._extra_properties = {} + + @property + def num_preds(self): + """int: the number of predictions in this assignment""" + return len(self.gt_inds) + + def set_extra_property(self, key, value): + """Set user-defined new property.""" + assert key not in self.info + self._extra_properties[key] = value + + def get_extra_property(self, key): + """Get user-defined property.""" + return self._extra_properties.get(key, None) + + @property + def info(self): + """dict: a dictionary of info about the object""" + basic_info = { + 'num_gts': self.num_gts, + 'num_preds': self.num_preds, + 'gt_inds': self.gt_inds, + 'max_overlaps': self.max_overlaps, + 'labels': self.labels, + } + basic_info.update(self._extra_properties) + return basic_info + + def __nice__(self): + """str: a "nice" summary string describing this assign result""" + parts = [] + parts.append(f'num_gts={self.num_gts!r}') + if self.gt_inds is None: + parts.append(f'gt_inds={self.gt_inds!r}') + else: + parts.append(f'gt_inds.shape={tuple(self.gt_inds.shape)!r}') + if self.max_overlaps is None: + parts.append(f'max_overlaps={self.max_overlaps!r}') + else: + parts.append('max_overlaps.shape=' + f'{tuple(self.max_overlaps.shape)!r}') + if self.labels is None: + parts.append(f'labels={self.labels!r}') + else: + parts.append(f'labels.shape={tuple(self.labels.shape)!r}') + return ', '.join(parts) + + @classmethod + def random(cls, **kwargs): + """Create random AssignResult for tests or debugging. + + Args: + num_preds: number of predicted boxes + num_gts: number of true boxes + p_ignore (float): probability of a predicted box assigned to an + ignored truth + p_assigned (float): probability of a predicted box not being + assigned + p_use_label (float | bool): with labels or not + rng (None | int | numpy.random.RandomState): seed or state + + Returns: + :obj:`AssignResult`: Randomly generated assign results. 
+ + Example: + >>> from mmdet.models.task_modules.assigners.assign_result import * # NOQA + >>> self = AssignResult.random() + >>> print(self.info) + """ + from ..samplers.sampling_result import ensure_rng + rng = ensure_rng(kwargs.get('rng', None)) + + num_gts = kwargs.get('num_gts', None) + num_preds = kwargs.get('num_preds', None) + p_ignore = kwargs.get('p_ignore', 0.3) + p_assigned = kwargs.get('p_assigned', 0.7) + num_classes = kwargs.get('num_classes', 3) + + if num_gts is None: + num_gts = rng.randint(0, 8) + if num_preds is None: + num_preds = rng.randint(0, 16) + + if num_gts == 0: + max_overlaps = torch.zeros(num_preds, dtype=torch.float32) + gt_inds = torch.zeros(num_preds, dtype=torch.int64) + labels = torch.zeros(num_preds, dtype=torch.int64) + + else: + import numpy as np + + # Create an overlap for each predicted box + max_overlaps = torch.from_numpy(rng.rand(num_preds)) + + # Construct gt_inds for each predicted box + is_assigned = torch.from_numpy(rng.rand(num_preds) < p_assigned) + # maximum number of assignments constraints + n_assigned = min(num_preds, min(num_gts, is_assigned.sum())) + + assigned_idxs = np.where(is_assigned)[0] + rng.shuffle(assigned_idxs) + assigned_idxs = assigned_idxs[0:n_assigned] + assigned_idxs.sort() + + is_assigned[:] = 0 + is_assigned[assigned_idxs] = True + + is_ignore = torch.from_numpy( + rng.rand(num_preds) < p_ignore) & is_assigned + + gt_inds = torch.zeros(num_preds, dtype=torch.int64) + + true_idxs = np.arange(num_gts) + rng.shuffle(true_idxs) + true_idxs = torch.from_numpy(true_idxs) + gt_inds[is_assigned] = true_idxs[:n_assigned].long() + + gt_inds = torch.from_numpy( + rng.randint(1, num_gts + 1, size=num_preds)) + gt_inds[is_ignore] = -1 + gt_inds[~is_assigned] = 0 + max_overlaps[~is_assigned] = 0 + + if num_classes == 0: + labels = torch.zeros(num_preds, dtype=torch.int64) + else: + labels = torch.from_numpy( + # remind that we set FG labels to [0, num_class-1] + # since mmdet v2.0 + # BG cat_id: num_class + rng.randint(0, num_classes, size=num_preds)) + labels[~is_assigned] = 0 + + self = cls(num_gts, gt_inds, max_overlaps, labels) + return self + + def add_gt_(self, gt_labels): + """Add ground truth as assigned results. + + Args: + gt_labels (torch.Tensor): Labels of gt boxes + """ + self_inds = torch.arange( + 1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device) + self.gt_inds = torch.cat([self_inds, self.gt_inds]) + + self.max_overlaps = torch.cat( + [self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps]) + + self.labels = torch.cat([gt_labels, self.labels]) diff --git a/mmdet/models/task_modules/assigners/atss_assigner.py b/mmdet/models/task_modules/assigners/atss_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..2796b990c5ae4c56bcf314e1342671d950232ae6 --- /dev/null +++ b/mmdet/models/task_modules/assigners/atss_assigner.py @@ -0,0 +1,254 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import List, Optional + +import torch +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import TASK_UTILS +from mmdet.utils import ConfigType +from .assign_result import AssignResult +from .base_assigner import BaseAssigner + + +def bbox_center_distance(bboxes: Tensor, priors: Tensor) -> Tensor: + """Compute the center distance between bboxes and priors. + + Args: + bboxes (Tensor): Shape (n, 4) for , "xyxy" format. + priors (Tensor): Shape (n, 4) for priors, "xyxy" format. 
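A quick numeric check of the center-distance formula defined above (the boxes are invented for illustration):

```python
import torch

prior = torch.tensor([0., 0., 2., 2.])   # center (1, 1)
gt = torch.tensor([5., 5., 15., 15.])    # center (10, 10)
prior_c = (prior[:2] + prior[2:]) / 2.0
gt_c = (gt[:2] + gt[2:]) / 2.0
# Euclidean distance between the two centers: sqrt(9^2 + 9^2)
print((prior_c - gt_c).pow(2).sum().sqrt())  # tensor(12.7279)
```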
+
+    Returns:
+        Tensor: Center distances between bboxes and priors.
+    """
+    bbox_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
+    bbox_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
+    bbox_points = torch.stack((bbox_cx, bbox_cy), dim=1)
+
+    priors_cx = (priors[:, 0] + priors[:, 2]) / 2.0
+    priors_cy = (priors[:, 1] + priors[:, 3]) / 2.0
+    priors_points = torch.stack((priors_cx, priors_cy), dim=1)
+
+    distances = (priors_points[:, None, :] -
+                 bbox_points[None, :, :]).pow(2).sum(-1).sqrt()
+
+    return distances
+
+
+@TASK_UTILS.register_module()
+class ATSSAssigner(BaseAssigner):
+    """Assign a corresponding gt bbox or background to each prior.
+
+    Each prior will be assigned with `0` or a positive integer
+    indicating the ground truth index.
+
+    - 0: negative sample, no assigned gt
+    - positive integer: positive sample, index (1-based) of assigned gt
+
+    If ``alpha`` is not None, it means that the dynamic cost
+    ATSSAssigner is adopted, which is currently only used in DDOD.
+
+    Args:
+        topk (int): number of priors selected in each level
+        alpha (float, optional): param of cost rate for each proposal only
+            in DDOD. Defaults to None.
+        iou_calculator (:obj:`ConfigDict` or dict): Config dict for iou
+            calculator. Defaults to ``dict(type='BboxOverlaps2D')``
+        ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
+            `gt_bboxes_ignore` is specified). Negative values mean not
+            ignoring any bboxes. Defaults to -1.
+    """
+
+    def __init__(self,
+                 topk: int,
+                 alpha: Optional[float] = None,
+                 iou_calculator: ConfigType = dict(type='BboxOverlaps2D'),
+                 ignore_iof_thr: float = -1) -> None:
+        self.topk = topk
+        self.alpha = alpha
+        self.iou_calculator = TASK_UTILS.build(iou_calculator)
+        self.ignore_iof_thr = ignore_iof_thr
+
+    # https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py
+    def assign(
+        self,
+        pred_instances: InstanceData,
+        num_level_priors: List[int],
+        gt_instances: InstanceData,
+        gt_instances_ignore: Optional[InstanceData] = None
+    ) -> AssignResult:
+        """Assign gt to priors.
+
+        The assignment is done in following steps
+
+        1. compute iou between all priors (priors of all pyramid levels)
+           and gts
+        2. compute center distance between all priors and gts
+        3. on each pyramid level, for each gt, select k priors whose centers
+           are closest to the gt center, so we select k*l priors in total as
+           candidates for each gt
+        4. get the corresponding iou for these candidates, compute their
+           mean and std, and set mean + std as the iou threshold
+        5. select candidates whose iou is greater than or equal to
+           the threshold as positive
+        6. limit the positive sample's center in gt
+
+        If ``alpha`` is not None, and ``cls_scores`` and ``bbox_preds``
+        are not None, the overlaps calculation in the first step
+        will also include the dynamic cost, which is currently only used in
+        DDOD.
+
+        Args:
+            pred_instances (:obj:`InstanceData`): Instances of model
+                predictions. It includes ``priors``, and the priors can
+                be anchors, points, or bboxes predicted by the model,
+                shape (n, 4).
+            num_level_priors (List): Number of bboxes in each level
+            gt_instances (:obj:`InstanceData`): Ground truth of instance
+                annotations. It usually includes ``bboxes`` and ``labels``
+                attributes.
+            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
+                to be ignored during training. It includes ``bboxes``
+                attribute data that is ignored during training and testing.
+                Defaults to None.
+
+        Returns:
+            :obj:`AssignResult`: The assign result.
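The adaptive threshold at the heart of steps 4-5 fits in three lines. A toy sketch with invented candidate IoUs:

```python
import torch

# Candidate IoUs for one gt, gathered from the top-k closest priors
# on each pyramid level (values made up for illustration).
candidate_ious = torch.tensor([0.12, 0.35, 0.41, 0.28, 0.52, 0.16])
iou_thr = candidate_ious.mean() + candidate_ious.std()  # ~0.458
is_pos = candidate_ious >= iou_thr  # only the 0.52 candidate is positive
```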
+ """ + gt_bboxes = gt_instances.bboxes + priors = pred_instances.priors + gt_labels = gt_instances.labels + if gt_instances_ignore is not None: + gt_bboxes_ignore = gt_instances_ignore.bboxes + else: + gt_bboxes_ignore = None + + INF = 100000000 + priors = priors[:, :4] + num_gt, num_priors = gt_bboxes.size(0), priors.size(0) + + message = 'Invalid alpha parameter because cls_scores or ' \ + 'bbox_preds are None. If you want to use the ' \ + 'cost-based ATSSAssigner, please set cls_scores, ' \ + 'bbox_preds and self.alpha at the same time. ' + + # compute iou between all bbox and gt + if self.alpha is None: + # ATSSAssigner + overlaps = self.iou_calculator(priors, gt_bboxes) + if ('scores' in pred_instances or 'bboxes' in pred_instances): + warnings.warn(message) + + else: + # Dynamic cost ATSSAssigner in DDOD + assert ('scores' in pred_instances + and 'bboxes' in pred_instances), message + cls_scores = pred_instances.scores + bbox_preds = pred_instances.bboxes + + # compute cls cost for bbox and GT + cls_cost = torch.sigmoid(cls_scores[:, gt_labels]) + + # compute iou between all bbox and gt + overlaps = self.iou_calculator(bbox_preds, gt_bboxes) + + # make sure that we are in element-wise multiplication + assert cls_cost.shape == overlaps.shape + + # overlaps is actually a cost matrix + overlaps = cls_cost**(1 - self.alpha) * overlaps**self.alpha + + # assign 0 by default + assigned_gt_inds = overlaps.new_full((num_priors, ), + 0, + dtype=torch.long) + + if num_gt == 0 or num_priors == 0: + # No ground truth or boxes, return empty assignment + max_overlaps = overlaps.new_zeros((num_priors, )) + if num_gt == 0: + # No truth, assign everything to background + assigned_gt_inds[:] = 0 + assigned_labels = overlaps.new_full((num_priors, ), + -1, + dtype=torch.long) + return AssignResult( + num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) + + # compute center distance between all bbox and gt + distances = bbox_center_distance(gt_bboxes, priors) + + if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None + and gt_bboxes_ignore.numel() > 0 and priors.numel() > 0): + ignore_overlaps = self.iou_calculator( + priors, gt_bboxes_ignore, mode='iof') + ignore_max_overlaps, _ = ignore_overlaps.max(dim=1) + ignore_idxs = ignore_max_overlaps > self.ignore_iof_thr + distances[ignore_idxs, :] = INF + assigned_gt_inds[ignore_idxs] = -1 + + # Selecting candidates based on the center distance + candidate_idxs = [] + start_idx = 0 + for level, priors_per_level in enumerate(num_level_priors): + # on each pyramid level, for each gt, + # select k bbox whose center are closest to the gt center + end_idx = start_idx + priors_per_level + distances_per_level = distances[start_idx:end_idx, :] + selectable_k = min(self.topk, priors_per_level) + _, topk_idxs_per_level = distances_per_level.topk( + selectable_k, dim=0, largest=False) + candidate_idxs.append(topk_idxs_per_level + start_idx) + start_idx = end_idx + candidate_idxs = torch.cat(candidate_idxs, dim=0) + + # get corresponding iou for the these candidates, and compute the + # mean and std, set mean + std as the iou threshold + candidate_overlaps = overlaps[candidate_idxs, torch.arange(num_gt)] + overlaps_mean_per_gt = candidate_overlaps.mean(0) + overlaps_std_per_gt = candidate_overlaps.std(0) + overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt + + is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :] + + # limit the positive sample's center in gt + for gt_idx in range(num_gt): + candidate_idxs[:, gt_idx] += gt_idx * 
num_priors + priors_cx = (priors[:, 0] + priors[:, 2]) / 2.0 + priors_cy = (priors[:, 1] + priors[:, 3]) / 2.0 + ep_priors_cx = priors_cx.view(1, -1).expand( + num_gt, num_priors).contiguous().view(-1) + ep_priors_cy = priors_cy.view(1, -1).expand( + num_gt, num_priors).contiguous().view(-1) + candidate_idxs = candidate_idxs.view(-1) + + # calculate the left, top, right, bottom distance between positive + # prior center and gt side + l_ = ep_priors_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0] + t_ = ep_priors_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1] + r_ = gt_bboxes[:, 2] - ep_priors_cx[candidate_idxs].view(-1, num_gt) + b_ = gt_bboxes[:, 3] - ep_priors_cy[candidate_idxs].view(-1, num_gt) + is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01 + + is_pos = is_pos & is_in_gts + + # if an anchor box is assigned to multiple gts, + # the one with the highest IoU will be selected. + overlaps_inf = torch.full_like(overlaps, + -INF).t().contiguous().view(-1) + index = candidate_idxs.view(-1)[is_pos.view(-1)] + overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index] + overlaps_inf = overlaps_inf.view(num_gt, -1).t() + + max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1) + assigned_gt_inds[ + max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1 + + assigned_labels = assigned_gt_inds.new_full((num_priors, ), -1) + pos_inds = torch.nonzero( + assigned_gt_inds > 0, as_tuple=False).squeeze() + if pos_inds.numel() > 0: + assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] - + 1] + return AssignResult( + num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) diff --git a/mmdet/models/task_modules/assigners/base_assigner.py b/mmdet/models/task_modules/assigners/base_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..b12280ad746c7557008313dd936a62a99e8c78d5 --- /dev/null +++ b/mmdet/models/task_modules/assigners/base_assigner.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import Optional + +from mmengine.structures import InstanceData + + +class BaseAssigner(metaclass=ABCMeta): + """Base assigner that assigns boxes to ground truth boxes.""" + + @abstractmethod + def assign(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + gt_instances_ignore: Optional[InstanceData] = None, + **kwargs): + """Assign boxes to either a ground truth boxes or a negative boxes.""" diff --git a/mmdet/models/task_modules/assigners/center_region_assigner.py b/mmdet/models/task_modules/assigners/center_region_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..11c8055c67cdf46c1ae0f877e88192db33795581 --- /dev/null +++ b/mmdet/models/task_modules/assigners/center_region_assigner.py @@ -0,0 +1,366 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +import torch +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import TASK_UTILS +from mmdet.utils import ConfigType +from .assign_result import AssignResult +from .base_assigner import BaseAssigner + + +def scale_boxes(bboxes: Tensor, scale: float) -> Tensor: + """Expand an array of boxes by a given scale. + + Args: + bboxes (Tensor): Shape (m, 4) + scale (float): The scale factor of bboxes + + Returns: + Tensor: Shape (m, 4). 
Scaled bboxes
+    """
+    assert bboxes.size(1) == 4
+    w_half = (bboxes[:, 2] - bboxes[:, 0]) * .5
+    h_half = (bboxes[:, 3] - bboxes[:, 1]) * .5
+    x_c = (bboxes[:, 2] + bboxes[:, 0]) * .5
+    y_c = (bboxes[:, 3] + bboxes[:, 1]) * .5
+
+    w_half *= scale
+    h_half *= scale
+
+    boxes_scaled = torch.zeros_like(bboxes)
+    boxes_scaled[:, 0] = x_c - w_half
+    boxes_scaled[:, 2] = x_c + w_half
+    boxes_scaled[:, 1] = y_c - h_half
+    boxes_scaled[:, 3] = y_c + h_half
+    return boxes_scaled
+
+
+def is_located_in(points: Tensor, bboxes: Tensor) -> Tensor:
+    """Whether points are located in bboxes.
+
+    Args:
+        points (Tensor): Points, shape: (m, 2).
+        bboxes (Tensor): Bounding boxes, shape: (n, 4).
+
+    Return:
+        Tensor: Flags indicating if points are located in bboxes,
+        shape: (m, n).
+    """
+    assert points.size(1) == 2
+    assert bboxes.size(1) == 4
+    return (points[:, 0].unsqueeze(1) > bboxes[:, 0].unsqueeze(0)) & \
+           (points[:, 0].unsqueeze(1) < bboxes[:, 2].unsqueeze(0)) & \
+           (points[:, 1].unsqueeze(1) > bboxes[:, 1].unsqueeze(0)) & \
+           (points[:, 1].unsqueeze(1) < bboxes[:, 3].unsqueeze(0))
+
+
+def bboxes_area(bboxes: Tensor) -> Tensor:
+    """Compute the area of an array of bboxes.
+
+    Args:
+        bboxes (Tensor): The coordinates of bboxes. Shape: (m, 4)
+
+    Returns:
+        Tensor: Area of the bboxes. Shape: (m, )
+    """
+    assert bboxes.size(1) == 4
+    w = (bboxes[:, 2] - bboxes[:, 0])
+    h = (bboxes[:, 3] - bboxes[:, 1])
+    areas = w * h
+    return areas
+
+
+@TASK_UTILS.register_module()
+class CenterRegionAssigner(BaseAssigner):
+    """Assign pixels at the center region of a bbox as positive.
+
+    Each proposal will be assigned with `-1`, `0`, or a positive integer
+    indicating the ground truth index.
+    - -1: negative samples
+    - semi-positive numbers: positive sample, index (0-based) of assigned gt
+
+    Args:
+        pos_scale (float): Threshold within which pixels are
+            labelled as positive.
+        neg_scale (float): Threshold above which pixels are
+            labelled as negative.
+        min_pos_iof (float): Minimum iof of a pixel with a gt to be
+            labelled as positive. Default: 1e-2
+        ignore_gt_scale (float): Threshold within which the pixels
+            are ignored when the gt is labelled as shadowed. Default: 0.5
+        foreground_dominate (bool): If True, the bbox will be assigned as
+            positive when a gt's kernel region overlaps with another's shadowed
+            (ignored) region, otherwise it is set as ignored. Defaults to
+            False.
+        iou_calculator (:obj:`ConfigDict` or dict): Config of overlaps
+            Calculator.
+    """
+
+    def __init__(
+        self,
+        pos_scale: float,
+        neg_scale: float,
+        min_pos_iof: float = 1e-2,
+        ignore_gt_scale: float = 0.5,
+        foreground_dominate: bool = False,
+        iou_calculator: ConfigType = dict(type='BboxOverlaps2D')
+    ) -> None:
+        self.pos_scale = pos_scale
+        self.neg_scale = neg_scale
+        self.min_pos_iof = min_pos_iof
+        self.ignore_gt_scale = ignore_gt_scale
+        self.foreground_dominate = foreground_dominate
+        self.iou_calculator = TASK_UTILS.build(iou_calculator)
+
+    def get_gt_priorities(self, gt_bboxes: Tensor) -> Tensor:
+        """Get gt priorities according to their areas.
+
+        Smaller gt has higher priority.
+
+        Args:
+            gt_bboxes (Tensor): Ground truth boxes, shape (k, 4).
+
+        Returns:
+            Tensor: The priority of gts so that gts with larger priority is
+            more likely to be assigned. Shape (k, )
+        """
+        gt_areas = bboxes_area(gt_bboxes)
+        # Rank all gt bbox areas. Smaller objects have larger priority
+        _, sort_idx = gt_areas.sort(descending=True)
+        sort_idx = sort_idx.argsort()
+        return sort_idx
+
+    def assign(self,
+               pred_instances: InstanceData,
+               gt_instances: InstanceData,
+               gt_instances_ignore: Optional[InstanceData] = None,
+               **kwargs) -> AssignResult:
+        """Assign gt to bboxes.
+
+        This method assigns gts to every prior (proposal/anchor), each prior
+        will be assigned with -1, or a semi-positive number. -1 means
+        negative sample, semi-positive number is the index (0-based) of
+        assigned gt.
+
+        Args:
+            pred_instances (:obj:`InstanceData`): Instances of model
+                predictions. It includes ``priors``, and the priors can
+                be anchors or points, or the bboxes predicted by the
+                previous stage, has shape (n, 4). The bboxes predicted by
+                the current model or stage will be named ``bboxes``,
+                ``labels``, and ``scores``, the same as the ``InstanceData``
+                in other places.
+            gt_instances (:obj:`InstanceData`): Ground truth of instance
+                annotations. It usually includes ``bboxes``, with shape (k, 4),
+                and ``labels``, with shape (k, ).
+            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
+                to be ignored during training. It includes ``bboxes``
+                attribute data that is ignored during training and testing.
+                Defaults to None.
+
+        Returns:
+            :obj:`AssignResult`: The assigned result. Note that
+            shadowed_labels of shape (N, 2) is also added as an
+            `assign_result` attribute. `shadowed_labels` is a tensor composed
+            of N pairs of [anchor_ind, class_label], where N is the number of
+            anchors that lie in the outer region of a gt, anchor_ind is the
+            shadowed anchor index and class_label is the shadowed class label.
+
+        Example:
+            >>> from mmengine.structures import InstanceData
+            >>> self = CenterRegionAssigner(0.2, 0.2)
+            >>> pred_instances = InstanceData()
+            >>> pred_instances.priors = torch.Tensor([[0, 0, 10, 10],
+            ...                                      [10, 10, 20, 20]])
+            >>> gt_instances = InstanceData()
+            >>> gt_instances.bboxes = torch.Tensor([[0, 0, 10, 10]])
+            >>> gt_instances.labels = torch.Tensor([0])
+            >>> assign_result = self.assign(pred_instances, gt_instances)
+            >>> expected_gt_inds = torch.LongTensor([1, 0])
+            >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
+        """
+        # There are in total 5 steps in the pixel assignment
+        # 1. Find core (the center region, say inner 0.2)
+        #    and shadow (the relatively outer part, say inner 0.2-0.5)
+        #    regions of every gt.
+        # 2. Find all prior bboxes that lie in gt_core and gt_shadow regions
+        # 3. Assign prior bboxes in gt_core with a one-hot id of the gt in
+        #    the image.
+        # 3.1. For overlapping objects, the prior bboxes in gt_core is
+        #    assigned with the object with smallest area
+        # 4. Assign prior bboxes with class label according to its gt id.
+        # 4.1. Assign -1 to prior bboxes lying in shadowed gts
+        # 4.2. Assign positive prior boxes with the corresponding label
+        # 5. Find pixels lying in the shadow of an object and assign them with
+        #    background label, but set the loss weight of its corresponding
+        #    gt to zero.
+
+        # TODO not extract bboxes in assign.
+        gt_bboxes = gt_instances.bboxes
+        priors = pred_instances.priors
+        gt_labels = gt_instances.labels
+
+        assert priors.size(1) == 4, 'priors must have size of 4'
+        # 1. Find the core positive and shadow region of every gt
+        gt_core = scale_boxes(gt_bboxes, self.pos_scale)
+        gt_shadow = scale_boxes(gt_bboxes, self.neg_scale)
+
+        # 2.
Find prior bboxes that lie in gt_core and gt_shadow regions + prior_centers = (priors[:, 2:4] + priors[:, 0:2]) / 2 + # The center points lie within the gt boxes + is_prior_in_gt = is_located_in(prior_centers, gt_bboxes) + # Only calculate prior and gt_core IoF. This enables small prior bboxes + # to match large gts + prior_and_gt_core_overlaps = self.iou_calculator( + priors, gt_core, mode='iof') + # The center point of effective priors should be within the gt box + is_prior_in_gt_core = is_prior_in_gt & ( + prior_and_gt_core_overlaps > self.min_pos_iof) # shape (n, k) + + is_prior_in_gt_shadow = ( + self.iou_calculator(priors, gt_shadow, mode='iof') > + self.min_pos_iof) + # Rule out center effective positive pixels + is_prior_in_gt_shadow &= (~is_prior_in_gt_core) + + num_gts, num_priors = gt_bboxes.size(0), priors.size(0) + if num_gts == 0 or num_priors == 0: + # If no gts exist, assign all pixels to negative + assigned_gt_ids = \ + is_prior_in_gt_core.new_zeros((num_priors,), + dtype=torch.long) + pixels_in_gt_shadow = assigned_gt_ids.new_empty((0, 2)) + else: + # Step 3: assign a one-hot gt id to each pixel, and smaller objects + # have high priority to assign the pixel. + sort_idx = self.get_gt_priorities(gt_bboxes) + assigned_gt_ids, pixels_in_gt_shadow = \ + self.assign_one_hot_gt_indices(is_prior_in_gt_core, + is_prior_in_gt_shadow, + gt_priority=sort_idx) + + if (gt_instances_ignore is not None + and gt_instances_ignore.bboxes.numel() > 0): + # No ground truth or boxes, return empty assignment + gt_bboxes_ignore = gt_instances_ignore.bboxes + gt_bboxes_ignore = scale_boxes( + gt_bboxes_ignore, scale=self.ignore_gt_scale) + is_prior_in_ignored_gts = is_located_in(prior_centers, + gt_bboxes_ignore) + is_prior_in_ignored_gts = is_prior_in_ignored_gts.any(dim=1) + assigned_gt_ids[is_prior_in_ignored_gts] = -1 + + # 4. Assign prior bboxes with class label according to its gt id. + # Default assigned label is the background (-1) + assigned_labels = assigned_gt_ids.new_full((num_priors, ), -1) + pos_inds = torch.nonzero(assigned_gt_ids > 0, as_tuple=False).squeeze() + if pos_inds.numel() > 0: + assigned_labels[pos_inds] = gt_labels[assigned_gt_ids[pos_inds] - + 1] + # 5. Find pixels lying in the shadow of an object + shadowed_pixel_labels = pixels_in_gt_shadow.clone() + if pixels_in_gt_shadow.numel() > 0: + pixel_idx, gt_idx =\ + pixels_in_gt_shadow[:, 0], pixels_in_gt_shadow[:, 1] + assert (assigned_gt_ids[pixel_idx] != gt_idx).all(), \ + 'Some pixels are dually assigned to ignore and gt!' + shadowed_pixel_labels[:, 1] = gt_labels[gt_idx - 1] + override = ( + assigned_labels[pixel_idx] == shadowed_pixel_labels[:, 1]) + if self.foreground_dominate: + # When a pixel is both positive and shadowed, set it as pos + shadowed_pixel_labels = shadowed_pixel_labels[~override] + else: + # When a pixel is both pos and shadowed, set it as shadowed + assigned_labels[pixel_idx[override]] = -1 + assigned_gt_ids[pixel_idx[override]] = 0 + + assign_result = AssignResult( + num_gts, assigned_gt_ids, None, labels=assigned_labels) + # Add shadowed_labels as assign_result property. Shape: (num_shadow, 2) + assign_result.set_extra_property('shadowed_labels', + shadowed_pixel_labels) + return assign_result + + def assign_one_hot_gt_indices( + self, + is_prior_in_gt_core: Tensor, + is_prior_in_gt_shadow: Tensor, + gt_priority: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: + """Assign only one gt index to each prior box. + + Gts with large gt_priority are more likely to be assigned. 
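To make the core/shadow geometry concrete, a standalone sketch of the box shrinking that `scale_boxes` performs (the 0.2/0.5 factors are illustrative `pos_scale`/`neg_scale` values, not settings from this repo):

```python
import torch

def shrink(bboxes: torch.Tensor, scale: float) -> torch.Tensor:
    # Keep the box centers, scale the half-sizes (same math as scale_boxes).
    centers = (bboxes[:, :2] + bboxes[:, 2:]) / 2
    halves = (bboxes[:, 2:] - bboxes[:, :2]) / 2 * scale
    return torch.cat([centers - halves, centers + halves], dim=1)

gt = torch.tensor([[0., 0., 100., 100.]])
print(shrink(gt, 0.2))  # core:   tensor([[40., 40., 60., 60.]])
print(shrink(gt, 0.5))  # shadow: tensor([[25., 25., 75., 75.]])
```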
+ + Args: + is_prior_in_gt_core (Tensor): Bool tensor indicating the prior + center is in the core area of a gt (e.g. 0-0.2). + Shape: (num_prior, num_gt). + is_prior_in_gt_shadow (Tensor): Bool tensor indicating the prior + center is in the shadowed area of a gt (e.g. 0.2-0.5). + Shape: (num_prior, num_gt). + gt_priority (Tensor): Priorities of gts. The gt with a higher + priority is more likely to be assigned to the bbox when the + bbox match with multiple gts. Shape: (num_gt, ). + + Returns: + tuple: Returns (assigned_gt_inds, shadowed_gt_inds). + + - assigned_gt_inds: The assigned gt index of each prior bbox \ + (i.e. index from 1 to num_gts). Shape: (num_prior, ). + - shadowed_gt_inds: shadowed gt indices. It is a tensor of \ + shape (num_ignore, 2) with first column being the shadowed prior \ + bbox indices and the second column the shadowed gt \ + indices (1-based). + """ + num_bboxes, num_gts = is_prior_in_gt_core.shape + + if gt_priority is None: + gt_priority = torch.arange( + num_gts, device=is_prior_in_gt_core.device) + assert gt_priority.size(0) == num_gts + # The bigger gt_priority, the more preferable to be assigned + # The assigned inds are by default 0 (background) + assigned_gt_inds = is_prior_in_gt_core.new_zeros((num_bboxes, ), + dtype=torch.long) + # Shadowed bboxes are assigned to be background. But the corresponding + # label is ignored during loss calculation, which is done through + # shadowed_gt_inds + shadowed_gt_inds = torch.nonzero(is_prior_in_gt_shadow, as_tuple=False) + if is_prior_in_gt_core.sum() == 0: # No gt match + shadowed_gt_inds[:, 1] += 1 # 1-based. For consistency issue + return assigned_gt_inds, shadowed_gt_inds + + # The priority of each prior box and gt pair. If one prior box is + # matched bo multiple gts. Only the pair with the highest priority + # is saved + pair_priority = is_prior_in_gt_core.new_full((num_bboxes, num_gts), + -1, + dtype=torch.long) + + # Each bbox could match with multiple gts. + # The following codes deal with this situation + # Matched bboxes (to any gt). Shape: (num_pos_anchor, ) + inds_of_match = torch.any(is_prior_in_gt_core, dim=1) + # The matched gt index of each positive bbox. Length >= num_pos_anchor + # , since one bbox could match multiple gts + matched_bbox_gt_inds = torch.nonzero( + is_prior_in_gt_core, as_tuple=False)[:, 1] + # Assign priority to each bbox-gt pair. + pair_priority[is_prior_in_gt_core] = gt_priority[matched_bbox_gt_inds] + _, argmax_priority = pair_priority[inds_of_match].max(dim=1) + assigned_gt_inds[inds_of_match] = argmax_priority + 1 # 1-based + # Zero-out the assigned anchor box to filter the shadowed gt indices + is_prior_in_gt_core[inds_of_match, argmax_priority] = 0 + # Concat the shadowed indices due to overlapping with that out side of + # effective scale. shape: (total_num_ignore, 2) + shadowed_gt_inds = torch.cat( + (shadowed_gt_inds, + torch.nonzero(is_prior_in_gt_core, as_tuple=False)), + dim=0) + # Change `is_prior_in_gt_core` back to keep arguments intact. 
+ is_prior_in_gt_core[inds_of_match, argmax_priority] = 1 + # 1-based shadowed gt indices, to be consistent with `assigned_gt_inds` + if shadowed_gt_inds.numel() > 0: + shadowed_gt_inds[:, 1] += 1 + return assigned_gt_inds, shadowed_gt_inds diff --git a/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py b/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..3fc7af39b22cd6dc00248e330547176787c23963 --- /dev/null +++ b/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py @@ -0,0 +1,227 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import TASK_UTILS +from mmdet.structures.bbox import BaseBoxes +from mmdet.utils import ConfigType +from .assign_result import AssignResult +from .base_assigner import BaseAssigner + +INF = 100000000 +EPS = 1.0e-7 + + +def center_of_mass(masks: Tensor, eps: float = 1e-7) -> Tensor: + """Compute the masks center of mass. + + Args: + masks: Mask tensor, has shape (num_masks, H, W). + eps: a small number to avoid normalizer to be zero. + Defaults to 1e-7. + Returns: + Tensor: The masks center of mass. Has shape (num_masks, 2). + """ + n, h, w = masks.shape + grid_h = torch.arange(h, device=masks.device)[:, None] + grid_w = torch.arange(w, device=masks.device) + normalizer = masks.sum(dim=(1, 2)).float().clamp(min=eps) + center_y = (masks * grid_h).sum(dim=(1, 2)) / normalizer + center_x = (masks * grid_w).sum(dim=(1, 2)) / normalizer + center = torch.cat([center_x[:, None], center_y[:, None]], dim=1) + return center + + +@TASK_UTILS.register_module() +class DynamicSoftLabelAssigner(BaseAssigner): + """Computes matching between predictions and ground truth with dynamic soft + label assignment. + + Args: + soft_center_radius (float): Radius of the soft center prior. + Defaults to 3.0. + topk (int): Select top-k predictions to calculate dynamic k + best matches for each gt. Defaults to 13. + iou_weight (float): The scale factor of iou cost. Defaults to 3.0. + iou_calculator (ConfigType): Config of overlaps Calculator. + Defaults to dict(type='BboxOverlaps2D'). + """ + + def __init__( + self, + soft_center_radius: float = 3.0, + topk: int = 13, + iou_weight: float = 3.0, + iou_calculator: ConfigType = dict(type='BboxOverlaps2D') + ) -> None: + self.soft_center_radius = soft_center_radius + self.topk = topk + self.iou_weight = iou_weight + self.iou_calculator = TASK_UTILS.build(iou_calculator) + + def assign(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + gt_instances_ignore: Optional[InstanceData] = None, + **kwargs) -> AssignResult: + """Assign gt to priors. + + Args: + pred_instances (:obj:`InstanceData`): Instances of model + predictions. It includes ``priors``, and the priors can + be anchors or points, or the bboxes predicted by the + previous stage, has shape (n, 4). The bboxes predicted by + the current model or stage will be named ``bboxes``, + ``labels``, and ``scores``, the same as the ``InstanceData`` + in other places. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes``, with shape (k, 4), + and ``labels``, with shape (k, ). + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. 
It includes ``bboxes`` + attribute data that is ignored during training and testing. + Defaults to None. + Returns: + obj:`AssignResult`: The assigned result. + """ + gt_bboxes = gt_instances.bboxes + gt_labels = gt_instances.labels + num_gt = gt_bboxes.size(0) + + decoded_bboxes = pred_instances.bboxes + pred_scores = pred_instances.scores + priors = pred_instances.priors + num_bboxes = decoded_bboxes.size(0) + + # assign 0 by default + assigned_gt_inds = decoded_bboxes.new_full((num_bboxes, ), + 0, + dtype=torch.long) + if num_gt == 0 or num_bboxes == 0: + # No ground truth or boxes, return empty assignment + max_overlaps = decoded_bboxes.new_zeros((num_bboxes, )) + if num_gt == 0: + # No truth, assign everything to background + assigned_gt_inds[:] = 0 + assigned_labels = decoded_bboxes.new_full((num_bboxes, ), + -1, + dtype=torch.long) + return AssignResult( + num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) + + prior_center = priors[:, :2] + if isinstance(gt_bboxes, BaseBoxes): + is_in_gts = gt_bboxes.find_inside_points(prior_center) + else: + # Tensor boxes will be treated as horizontal boxes by defaults + lt_ = prior_center[:, None] - gt_bboxes[:, :2] + rb_ = gt_bboxes[:, 2:] - prior_center[:, None] + + deltas = torch.cat([lt_, rb_], dim=-1) + is_in_gts = deltas.min(dim=-1).values > 0 + + valid_mask = is_in_gts.sum(dim=1) > 0 + + valid_decoded_bbox = decoded_bboxes[valid_mask] + valid_pred_scores = pred_scores[valid_mask] + num_valid = valid_decoded_bbox.size(0) + + if num_valid == 0: + # No ground truth or boxes, return empty assignment + max_overlaps = decoded_bboxes.new_zeros((num_bboxes, )) + assigned_labels = decoded_bboxes.new_full((num_bboxes, ), + -1, + dtype=torch.long) + return AssignResult( + num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) + if hasattr(gt_instances, 'masks'): + gt_center = center_of_mass(gt_instances.masks, eps=EPS) + elif isinstance(gt_bboxes, BaseBoxes): + gt_center = gt_bboxes.centers + else: + # Tensor boxes will be treated as horizontal boxes by defaults + gt_center = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2.0 + valid_prior = priors[valid_mask] + strides = valid_prior[:, 2] + distance = (valid_prior[:, None, :2] - gt_center[None, :, :] + ).pow(2).sum(-1).sqrt() / strides[:, None] + soft_center_prior = torch.pow(10, distance - self.soft_center_radius) + + pairwise_ious = self.iou_calculator(valid_decoded_bbox, gt_bboxes) + iou_cost = -torch.log(pairwise_ious + EPS) * self.iou_weight + + gt_onehot_label = ( + F.one_hot(gt_labels.to(torch.int64), + pred_scores.shape[-1]).float().unsqueeze(0).repeat( + num_valid, 1, 1)) + valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1) + + soft_label = gt_onehot_label * pairwise_ious[..., None] + scale_factor = soft_label - valid_pred_scores.sigmoid() + soft_cls_cost = F.binary_cross_entropy_with_logits( + valid_pred_scores, soft_label, + reduction='none') * scale_factor.abs().pow(2.0) + soft_cls_cost = soft_cls_cost.sum(dim=-1) + + cost_matrix = soft_cls_cost + iou_cost + soft_center_prior + + matched_pred_ious, matched_gt_inds = self.dynamic_k_matching( + cost_matrix, pairwise_ious, num_gt, valid_mask) + + # convert to AssignResult format + assigned_gt_inds[valid_mask] = matched_gt_inds + 1 + assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1) + assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long() + max_overlaps = assigned_gt_inds.new_full((num_bboxes, ), + -INF, + dtype=torch.float32) + max_overlaps[valid_mask] = matched_pred_ious + return 
AssignResult( + num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) + + def dynamic_k_matching(self, cost: Tensor, pairwise_ious: Tensor, + num_gt: int, + valid_mask: Tensor) -> Tuple[Tensor, Tensor]: + """Use IoU and matching cost to calculate the dynamic top-k positive + targets. Same as SimOTA. + + Args: + cost (Tensor): Cost matrix. + pairwise_ious (Tensor): Pairwise iou matrix. + num_gt (int): Number of gt. + valid_mask (Tensor): Mask for valid bboxes. + + Returns: + tuple: matched ious and gt indexes. + """ + matching_matrix = torch.zeros_like(cost, dtype=torch.uint8) + # select candidate topk ious for dynamic-k calculation + candidate_topk = min(self.topk, pairwise_ious.size(0)) + topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0) + # calculate dynamic k for each gt + dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1) + for gt_idx in range(num_gt): + _, pos_idx = torch.topk( + cost[:, gt_idx], k=dynamic_ks[gt_idx], largest=False) + matching_matrix[:, gt_idx][pos_idx] = 1 + + del topk_ious, dynamic_ks, pos_idx + + prior_match_gt_mask = matching_matrix.sum(1) > 1 + if prior_match_gt_mask.sum() > 0: + cost_min, cost_argmin = torch.min( + cost[prior_match_gt_mask, :], dim=1) + matching_matrix[prior_match_gt_mask, :] *= 0 + matching_matrix[prior_match_gt_mask, cost_argmin] = 1 + # get foreground mask inside box and center prior + fg_mask_inboxes = matching_matrix.sum(1) > 0 + valid_mask[valid_mask.clone()] = fg_mask_inboxes + + matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1) + matched_pred_ious = (matching_matrix * + pairwise_ious).sum(1)[fg_mask_inboxes] + return matched_pred_ious, matched_gt_inds diff --git a/mmdet/models/task_modules/assigners/grid_assigner.py b/mmdet/models/task_modules/assigners/grid_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..d8935d2df2937f90c71599e5b45ed9a3dff8cd7e --- /dev/null +++ b/mmdet/models/task_modules/assigners/grid_assigner.py @@ -0,0 +1,177 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple, Union + +import torch +from mmengine.structures import InstanceData + +from mmdet.registry import TASK_UTILS +from mmdet.utils import ConfigType +from .assign_result import AssignResult +from .base_assigner import BaseAssigner + + +@TASK_UTILS.register_module() +class GridAssigner(BaseAssigner): + """Assign a corresponding gt bbox or background to each bbox. + + Each proposals will be assigned with `-1`, `0`, or a positive integer + indicating the ground truth index. + + - -1: don't care + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + + Args: + pos_iou_thr (float): IoU threshold for positive bboxes. + neg_iou_thr (float or tuple[float, float]): IoU threshold for negative + bboxes. + min_pos_iou (float): Minimum iou for a bbox to be considered as a + positive bbox. Positive samples can have smaller IoU than + pos_iou_thr due to the 4th step (assign max IoU sample to each gt). + Defaults to 0. + gt_max_assign_all (bool): Whether to assign all bboxes with the same + highest overlap with some gt to that gt. + iou_calculator (:obj:`ConfigDict` or dict): Config of overlaps + Calculator. 
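The dynamic-k rule in `dynamic_k_matching` above reduces to summing each gt's top candidate IoUs. A toy sketch with invented IoUs:

```python
import torch

pairwise_ious = torch.tensor([[0.6, 0.1],
                              [0.5, 0.2],
                              [0.4, 0.1],
                              [0.3, 0.1]])  # (num_priors, num_gt)
topk_ious, _ = pairwise_ious.topk(3, dim=0)
dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
print(dynamic_ks.tolist())  # [1, 1]: 1.5 truncates to 1; 0.4 -> 0, clamped to 1
```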
+ """ + + def __init__( + self, + pos_iou_thr: float, + neg_iou_thr: Union[float, Tuple[float, float]], + min_pos_iou: float = .0, + gt_max_assign_all: bool = True, + iou_calculator: ConfigType = dict(type='BboxOverlaps2D') + ) -> None: + self.pos_iou_thr = pos_iou_thr + self.neg_iou_thr = neg_iou_thr + self.min_pos_iou = min_pos_iou + self.gt_max_assign_all = gt_max_assign_all + self.iou_calculator = TASK_UTILS.build(iou_calculator) + + def assign(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + gt_instances_ignore: Optional[InstanceData] = None, + **kwargs) -> AssignResult: + """Assign gt to bboxes. The process is very much like the max iou + assigner, except that positive samples are constrained within the cell + that the gt boxes fell in. + + This method assign a gt bbox to every bbox (proposal/anchor), each bbox + will be assigned with -1, 0, or a positive number. -1 means don't care, + 0 means negative sample, positive number is the index (1-based) of + assigned gt. + The assignment is done in following steps, the order matters. + + 1. assign every bbox to -1 + 2. assign proposals whose iou with all gts <= neg_iou_thr to 0 + 3. for each bbox within a cell, if the iou with its nearest gt > + pos_iou_thr and the center of that gt falls inside the cell, + assign it to that bbox + 4. for each gt bbox, assign its nearest proposals within the cell the + gt bbox falls in to itself. + + Args: + pred_instances (:obj:`InstanceData`): Instances of model + predictions. It includes ``priors``, and the priors can + be anchors or points, or the bboxes predicted by the + previous stage, has shape (n, 4). The bboxes predicted by + the current model or stage will be named ``bboxes``, + ``labels``, and ``scores``, the same as the ``InstanceData`` + in other places. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes``, with shape (k, 4), + and ``labels``, with shape (k, ). + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` + attribute data that is ignored during training and testing. + Defaults to None. + + Returns: + :obj:`AssignResult`: The assign result. + """ + gt_bboxes = gt_instances.bboxes + gt_labels = gt_instances.labels + + priors = pred_instances.priors + responsible_flags = pred_instances.responsible_flags + + num_gts, num_priors = gt_bboxes.size(0), priors.size(0) + + # compute iou between all gt and priors + overlaps = self.iou_calculator(gt_bboxes, priors) + + # 1. assign -1 by default + assigned_gt_inds = overlaps.new_full((num_priors, ), + -1, + dtype=torch.long) + + if num_gts == 0 or num_priors == 0: + # No ground truth or priors, return empty assignment + max_overlaps = overlaps.new_zeros((num_priors, )) + if num_gts == 0: + # No truth, assign everything to background + assigned_gt_inds[:] = 0 + assigned_labels = overlaps.new_full((num_priors, ), + -1, + dtype=torch.long) + return AssignResult( + num_gts, + assigned_gt_inds, + max_overlaps, + labels=assigned_labels) + + # 2. 
assign negative: below + # for each anchor, which gt best overlaps with it + # for each anchor, the max iou of all gts + # shape of max_overlaps == argmax_overlaps == num_priors + max_overlaps, argmax_overlaps = overlaps.max(dim=0) + + if isinstance(self.neg_iou_thr, float): + assigned_gt_inds[(max_overlaps >= 0) + & (max_overlaps <= self.neg_iou_thr)] = 0 + elif isinstance(self.neg_iou_thr, (tuple, list)): + assert len(self.neg_iou_thr) == 2 + assigned_gt_inds[(max_overlaps > self.neg_iou_thr[0]) + & (max_overlaps <= self.neg_iou_thr[1])] = 0 + + # 3. assign positive: falls into responsible cell and above + # positive IOU threshold, the order matters. + # the prior condition of comparison is to filter out all + # unrelated anchors, i.e. not responsible_flags + overlaps[:, ~responsible_flags.type(torch.bool)] = -1. + + # calculate max_overlaps again, but this time we only consider IOUs + # for anchors responsible for prediction + max_overlaps, argmax_overlaps = overlaps.max(dim=0) + + # for each gt, which anchor best overlaps with it + # for each gt, the max iou of all proposals + # shape of gt_max_overlaps == gt_argmax_overlaps == num_gts + gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1) + + pos_inds = (max_overlaps > self.pos_iou_thr) & responsible_flags.type( + torch.bool) + assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1 + + # 4. assign positive to max overlapped anchors within responsible cell + for i in range(num_gts): + if gt_max_overlaps[i] > self.min_pos_iou: + if self.gt_max_assign_all: + max_iou_inds = (overlaps[i, :] == gt_max_overlaps[i]) & \ + responsible_flags.type(torch.bool) + assigned_gt_inds[max_iou_inds] = i + 1 + elif responsible_flags[gt_argmax_overlaps[i]]: + assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1 + + # assign labels of positive anchors + assigned_labels = assigned_gt_inds.new_full((num_priors, ), -1) + pos_inds = torch.nonzero( + assigned_gt_inds > 0, as_tuple=False).squeeze() + if pos_inds.numel() > 0: + assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] - + 1] + + return AssignResult( + num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels) diff --git a/mmdet/models/task_modules/assigners/hungarian_assigner.py b/mmdet/models/task_modules/assigners/hungarian_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..a6745a36cdc713c74f801f62dae0d8fe3d03828f --- /dev/null +++ b/mmdet/models/task_modules/assigners/hungarian_assigner.py @@ -0,0 +1,145 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +from mmengine import ConfigDict +from mmengine.structures import InstanceData +from scipy.optimize import linear_sum_assignment +from torch import Tensor + +from mmdet.registry import TASK_UTILS +from .assign_result import AssignResult +from .base_assigner import BaseAssigner + + +@TASK_UTILS.register_module() +class HungarianAssigner(BaseAssigner): + """Computes one-to-one matching between predictions and ground truth. + + This class computes an assignment between the targets and the predictions + based on the costs. The costs are weighted sum of some components. + For DETR the costs are weighted sum of classification cost, regression L1 + cost and regression iou cost. The targets don't include the no_object, so + generally there are more predictions than targets. After the one-to-one + matching, the un-matched are treated as backgrounds. 
Thus each query
+    prediction will be assigned with `0` or a positive integer indicating the
+    ground truth index:
+
+    - 0: negative sample, no assigned gt
+    - positive integer: positive sample, index (1-based) of assigned gt
+
+    Args:
+        match_costs (:obj:`ConfigDict` or dict or \
+            List[Union[:obj:`ConfigDict`, dict]]): Match cost configs.
+    """
+
+    def __init__(
+        self, match_costs: Union[List[Union[dict, ConfigDict]], dict,
+                                 ConfigDict]
+    ) -> None:
+
+        if isinstance(match_costs, dict):
+            match_costs = [match_costs]
+        elif isinstance(match_costs, list):
+            assert len(match_costs) > 0, \
+                'match_costs must not be an empty list.'
+
+        self.match_costs = [
+            TASK_UTILS.build(match_cost) for match_cost in match_costs
+        ]
+
+    def assign(self,
+               pred_instances: InstanceData,
+               gt_instances: InstanceData,
+               img_meta: Optional[dict] = None,
+               **kwargs) -> AssignResult:
+        """Computes one-to-one matching based on the weighted costs.
+
+        This method assigns each query prediction to a ground truth or
+        background. The `assigned_gt_inds` with -1 means don't care,
+        0 means negative sample, and a positive number is the index (1-based)
+        of the assigned gt.
+        The assignment is done in the following steps, the order matters.
+
+        1. assign every prediction to -1
+        2. compute the weighted costs
+        3. do Hungarian matching on CPU based on the costs
+        4. assign all to 0 (background) first, then for each matched pair
+           between predictions and gts, treat this prediction as foreground
+           and assign the corresponding gt index (plus 1) to it.
+
+        Args:
+            pred_instances (:obj:`InstanceData`): Instances of model
+                predictions. It includes ``priors``, and the priors can
+                be anchors or points, or the bboxes predicted by the
+                previous stage, has shape (n, 4). The bboxes predicted by
+                the current model or stage will be named ``bboxes``,
+                ``labels``, and ``scores``, the same as the ``InstanceData``
+                in other places. It may include ``masks``, with shape
+                (n, h, w) or (n, l).
+            gt_instances (:obj:`InstanceData`): Ground truth of instance
+                annotations. It usually includes ``bboxes``, with shape (k, 4),
+                ``labels``, with shape (k, ) and ``masks``, with shape
+                (k, h, w) or (k, l).
+            img_meta (dict): Image information.
+
+        Returns:
+            :obj:`AssignResult`: The assigned result.
+        """
+        assert isinstance(gt_instances.labels, Tensor)
+        num_gts, num_preds = len(gt_instances), len(pred_instances)
+        gt_labels = gt_instances.labels
+        device = gt_labels.device
+
+        # 1. assign -1 by default
+        assigned_gt_inds = torch.full((num_preds, ),
+                                      -1,
+                                      dtype=torch.long,
+                                      device=device)
+        assigned_labels = torch.full((num_preds, ),
+                                     -1,
+                                     dtype=torch.long,
+                                     device=device)
+
+        if num_gts == 0 or num_preds == 0:
+            # No ground truth or boxes, return empty assignment
+            if num_gts == 0:
+                # No ground truth, assign all to background
+                assigned_gt_inds[:] = 0
+            return AssignResult(
+                num_gts=num_gts,
+                gt_inds=assigned_gt_inds,
+                max_overlaps=None,
+                labels=assigned_labels)
+
+        # 2. compute weighted cost
+        cost_list = []
+        for match_cost in self.match_costs:
+            cost = match_cost(
+                pred_instances=pred_instances,
+                gt_instances=gt_instances,
+                img_meta=img_meta)
+            cost_list.append(cost)
+        cost = torch.stack(cost_list).sum(dim=0)
+
+        # 3. do Hungarian matching on CPU using linear_sum_assignment
do Hungarian matching on CPU using linear_sum_assignment
+        cost = cost.detach().cpu()
+        if linear_sum_assignment is None:
+            raise ImportError('Please run "pip install scipy" '
+                              'to install scipy first.')
+
+        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
+        matched_row_inds = torch.from_numpy(matched_row_inds).to(device)
+        matched_col_inds = torch.from_numpy(matched_col_inds).to(device)
+
+        # 4. assign backgrounds and foregrounds
+        # assign all indices to backgrounds first
+        assigned_gt_inds[:] = 0
+        # assign foregrounds based on matching results
+        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
+        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
+        return AssignResult(
+            num_gts=num_gts,
+            gt_inds=assigned_gt_inds,
+            max_overlaps=None,
+            labels=assigned_labels)
diff --git a/mmdet/models/task_modules/assigners/iou2d_calculator.py b/mmdet/models/task_modules/assigners/iou2d_calculator.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6daa94feb46ac2f188df41c7be59ffdc3905e58
--- /dev/null
+++ b/mmdet/models/task_modules/assigners/iou2d_calculator.py
@@ -0,0 +1,88 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmdet.registry import TASK_UTILS
+from mmdet.structures.bbox import bbox_overlaps, get_box_tensor
+
+
+def cast_tensor_type(x, scale=1., dtype=None):
+    if dtype == 'fp16':
+        # scale is for preventing overflows
+        x = (x / scale).half()
+    return x
+
+
+@TASK_UTILS.register_module()
+class BboxOverlaps2D:
+    """2D Overlaps (e.g. IoUs, GIoUs) Calculator."""
+
+    def __init__(self, scale=1., dtype=None):
+        self.scale = scale
+        self.dtype = dtype
+
+    def __call__(self, bboxes1, bboxes2, mode='iou', is_aligned=False):
+        """Calculate IoU between 2D bboxes.
+
+        Args:
+            bboxes1 (Tensor or :obj:`BaseBoxes`): bboxes have shape (m, 4)
+                in <x1, y1, x2, y2> format, or shape (m, 5) in
+                <x1, y1, x2, y2, score> format.
+            bboxes2 (Tensor or :obj:`BaseBoxes`): bboxes have shape (n, 4)
+                in <x1, y1, x2, y2> format, shape (n, 5) in
+                <x1, y1, x2, y2, score> format, or be empty.
+            mode (str): "iou" (intersection over union), "iof" (intersection
+                over foreground), or "giou" (generalized intersection over
+                union).
+            is_aligned (bool, optional): If True, then m and n must be equal.
+                Default False.
+ + Returns: + Tensor: shape (m, n) if ``is_aligned `` is False else shape (m,) + """ + bboxes1 = get_box_tensor(bboxes1) + bboxes2 = get_box_tensor(bboxes2) + assert bboxes1.size(-1) in [0, 4, 5] + assert bboxes2.size(-1) in [0, 4, 5] + if bboxes2.size(-1) == 5: + bboxes2 = bboxes2[..., :4] + if bboxes1.size(-1) == 5: + bboxes1 = bboxes1[..., :4] + + if self.dtype == 'fp16': + # change tensor type to save cpu and cuda memory and keep speed + bboxes1 = cast_tensor_type(bboxes1, self.scale, self.dtype) + bboxes2 = cast_tensor_type(bboxes2, self.scale, self.dtype) + overlaps = bbox_overlaps(bboxes1, bboxes2, mode, is_aligned) + if not overlaps.is_cuda and overlaps.dtype == torch.float16: + # resume cpu float32 + overlaps = overlaps.float() + return overlaps + + return bbox_overlaps(bboxes1, bboxes2, mode, is_aligned) + + def __repr__(self): + """str: a string describing the module""" + repr_str = self.__class__.__name__ + f'(' \ + f'scale={self.scale}, dtype={self.dtype})' + return repr_str + + +@TASK_UTILS.register_module() +class BboxOverlaps2D_GLIP(BboxOverlaps2D): + + def __call__(self, bboxes1, bboxes2, mode='iou', is_aligned=False): + TO_REMOVE = 1 + area1 = (bboxes1[:, 2] - bboxes1[:, 0] + TO_REMOVE) * ( + bboxes1[:, 3] - bboxes1[:, 1] + TO_REMOVE) + area2 = (bboxes2[:, 2] - bboxes2[:, 0] + TO_REMOVE) * ( + bboxes2[:, 3] - bboxes2[:, 1] + TO_REMOVE) + + lt = torch.max(bboxes1[:, None, :2], bboxes2[:, :2]) # [N,M,2] + rb = torch.min(bboxes1[:, None, 2:], bboxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt + TO_REMOVE).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + iou = inter / (area1[:, None] + area2 - inter) + return iou diff --git a/mmdet/models/task_modules/assigners/match_cost.py b/mmdet/models/task_modules/assigners/match_cost.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc62f01f29138cba31ef2b41254f497351fe0d0 --- /dev/null +++ b/mmdet/models/task_modules/assigners/match_cost.py @@ -0,0 +1,525 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractmethod +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import TASK_UTILS +from mmdet.structures.bbox import bbox_overlaps, bbox_xyxy_to_cxcywh + + +class BaseMatchCost: + """Base match cost class. + + Args: + weight (Union[float, int]): Cost weight. Defaults to 1. + """ + + def __init__(self, weight: Union[float, int] = 1.) -> None: + self.weight = weight + + @abstractmethod + def __call__(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + img_meta: Optional[dict] = None, + **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData`): Instances of model + predictions. It includes ``priors``, and the priors can + be anchors or points, or the bboxes predicted by the + previous stage, has shape (n, 4). The bboxes predicted by + the current model or stage will be named ``bboxes``, + ``labels``, and ``scores``, the same as the ``InstanceData`` + in other places. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes``, with shape (k, 4), + and ``labels``, with shape (k, ). + img_meta (dict, optional): Image information. + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + pass + + +@TASK_UTILS.register_module() +class BBoxL1Cost(BaseMatchCost): + """BBoxL1Cost. 
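+
+    Computes the pairwise L1 distance between predicted and gt boxes after
+    normalizing both by the image size.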
+ + Note: ``bboxes`` in ``InstanceData`` passed in is of format 'xyxy' + and its coordinates are unnormalized. + + Args: + box_format (str, optional): 'xyxy' for DETR, 'xywh' for Sparse_RCNN. + Defaults to 'xyxy'. + weight (Union[float, int]): Cost weight. Defaults to 1. + + Examples: + >>> from mmdet.models.task_modules.assigners. + ... match_costs.match_cost import BBoxL1Cost + >>> import torch + >>> self = BBoxL1Cost() + >>> bbox_pred = torch.rand(1, 4) + >>> gt_bboxes= torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(bbox_pred, gt_bboxes, factor) + tensor([[1.6172, 1.6422]]) + """ + + def __init__(self, + box_format: str = 'xyxy', + weight: Union[float, int] = 1.) -> None: + super().__init__(weight=weight) + assert box_format in ['xyxy', 'xywh'] + self.box_format = box_format + + def __call__(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + img_meta: Optional[dict] = None, + **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData`): ``bboxes`` inside is + predicted boxes with unnormalized coordinate + (x, y, x, y). + gt_instances (:obj:`InstanceData`): ``bboxes`` inside is gt + bboxes with unnormalized coordinate (x, y, x, y). + img_meta (Optional[dict]): Image information. Defaults to None. + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + pred_bboxes = pred_instances.bboxes + gt_bboxes = gt_instances.bboxes + + # convert box format + if self.box_format == 'xywh': + gt_bboxes = bbox_xyxy_to_cxcywh(gt_bboxes) + pred_bboxes = bbox_xyxy_to_cxcywh(pred_bboxes) + + # normalized + img_h, img_w = img_meta['img_shape'] + factor = gt_bboxes.new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0) + gt_bboxes = gt_bboxes / factor + pred_bboxes = pred_bboxes / factor + + bbox_cost = torch.cdist(pred_bboxes, gt_bboxes, p=1) + return bbox_cost * self.weight + + +@TASK_UTILS.register_module() +class IoUCost(BaseMatchCost): + """IoUCost. + + Note: ``bboxes`` in ``InstanceData`` passed in is of format 'xyxy' + and its coordinates are unnormalized. + + Args: + iou_mode (str): iou mode such as 'iou', 'giou'. Defaults to 'giou'. + weight (Union[float, int]): Cost weight. Defaults to 1. + + Examples: + >>> from mmdet.models.task_modules.assigners. + ... match_costs.match_cost import IoUCost + >>> import torch + >>> self = IoUCost() + >>> bboxes = torch.FloatTensor([[1,1, 2, 2], [2, 2, 3, 4]]) + >>> gt_bboxes = torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]]) + >>> self(bboxes, gt_bboxes) + tensor([[-0.1250, 0.1667], + [ 0.1667, -0.5000]]) + """ + + def __init__(self, iou_mode: str = 'giou', weight: Union[float, int] = 1.): + super().__init__(weight=weight) + self.iou_mode = iou_mode + + def __call__(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + img_meta: Optional[dict] = None, + **kwargs): + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData`): ``bboxes`` inside is + predicted boxes with unnormalized coordinate + (x, y, x, y). + gt_instances (:obj:`InstanceData`): ``bboxes`` inside is gt + bboxes with unnormalized coordinate (x, y, x, y). + img_meta (Optional[dict]): Image information. Defaults to None. + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). 
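+
+        For intuition, an illustrative (hypothetical) case: with
+        ``img_shape == (100, 100)``, a prediction ``(0, 0, 10, 10)`` and a
+        gt ``(0, 0, 20, 20)`` normalize to ``(0, 0, 0.1, 0.1)`` and
+        ``(0, 0, 0.2, 0.2)``, so the L1 cost is ``0.2`` before weighting.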
+ """ + pred_bboxes = pred_instances.bboxes + gt_bboxes = gt_instances.bboxes + + # avoid fp16 overflow + if pred_bboxes.dtype == torch.float16: + fp16 = True + pred_bboxes = pred_bboxes.to(torch.float32) + else: + fp16 = False + + overlaps = bbox_overlaps( + pred_bboxes, gt_bboxes, mode=self.iou_mode, is_aligned=False) + + if fp16: + overlaps = overlaps.to(torch.float16) + + # The 1 is a constant that doesn't change the matching, so omitted. + iou_cost = -overlaps + return iou_cost * self.weight + + +@TASK_UTILS.register_module() +class ClassificationCost(BaseMatchCost): + """ClsSoftmaxCost. + + Args: + weight (Union[float, int]): Cost weight. Defaults to 1. + + Examples: + >>> from mmdet.models.task_modules.assigners. + ... match_costs.match_cost import ClassificationCost + >>> import torch + >>> self = ClassificationCost() + >>> cls_pred = torch.rand(4, 3) + >>> gt_labels = torch.tensor([0, 1, 2]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(cls_pred, gt_labels) + tensor([[-0.3430, -0.3525, -0.3045], + [-0.3077, -0.2931, -0.3992], + [-0.3664, -0.3455, -0.2881], + [-0.3343, -0.2701, -0.3956]]) + """ + + def __init__(self, weight: Union[float, int] = 1) -> None: + super().__init__(weight=weight) + + def __call__(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + img_meta: Optional[dict] = None, + **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData`): ``scores`` inside is + predicted classification logits, of shape + (num_queries, num_class). + gt_instances (:obj:`InstanceData`): ``labels`` inside should have + shape (num_gt, ). + img_meta (Optional[dict]): _description_. Defaults to None. + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + pred_scores = pred_instances.scores + gt_labels = gt_instances.labels + + pred_scores = pred_scores.softmax(-1) + cls_cost = -pred_scores[:, gt_labels] + + return cls_cost * self.weight + + +@TASK_UTILS.register_module() +class FocalLossCost(BaseMatchCost): + """FocalLossCost. + + Args: + alpha (Union[float, int]): focal_loss alpha. Defaults to 0.25. + gamma (Union[float, int]): focal_loss gamma. Defaults to 2. + eps (float): Defaults to 1e-12. + binary_input (bool): Whether the input is binary. Currently, + binary_input = True is for masks input, binary_input = False + is for label input. Defaults to False. + weight (Union[float, int]): Cost weight. Defaults to 1. + """ + + def __init__(self, + alpha: Union[float, int] = 0.25, + gamma: Union[float, int] = 2, + eps: float = 1e-12, + binary_input: bool = False, + weight: Union[float, int] = 1.) -> None: + super().__init__(weight=weight) + self.alpha = alpha + self.gamma = gamma + self.eps = eps + self.binary_input = binary_input + + def _focal_loss_cost(self, cls_pred: Tensor, gt_labels: Tensor) -> Tensor: + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + (num_queries, num_class). + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + + Returns: + torch.Tensor: cls_cost value with weight + """ + cls_pred = cls_pred.sigmoid() + neg_cost = -(1 - cls_pred + self.eps).log() * ( + 1 - self.alpha) * cls_pred.pow(self.gamma) + pos_cost = -(cls_pred + self.eps).log() * self.alpha * ( + 1 - cls_pred).pow(self.gamma) + + cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels] + return cls_cost * self.weight + + def _mask_focal_loss_cost(self, cls_pred, gt_labels) -> Tensor: + """ + Args: + cls_pred (Tensor): Predicted classification logits. 
+ in shape (num_queries, d1, ..., dn), dtype=torch.float32. + gt_labels (Tensor): Ground truth in shape (num_gt, d1, ..., dn), + dtype=torch.long. Labels should be binary. + + Returns: + Tensor: Focal cost matrix with weight in shape\ + (num_queries, num_gt). + """ + cls_pred = cls_pred.flatten(1) + gt_labels = gt_labels.flatten(1).float() + n = cls_pred.shape[1] + cls_pred = cls_pred.sigmoid() + neg_cost = -(1 - cls_pred + self.eps).log() * ( + 1 - self.alpha) * cls_pred.pow(self.gamma) + pos_cost = -(cls_pred + self.eps).log() * self.alpha * ( + 1 - cls_pred).pow(self.gamma) + + cls_cost = torch.einsum('nc,mc->nm', pos_cost, gt_labels) + \ + torch.einsum('nc,mc->nm', neg_cost, (1 - gt_labels)) + return cls_cost / n * self.weight + + def __call__(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + img_meta: Optional[dict] = None, + **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData`): Predicted instances which + must contain ``scores`` or ``masks``. + gt_instances (:obj:`InstanceData`): Ground truth which must contain + ``labels`` or ``mask``. + img_meta (Optional[dict]): Image information. Defaults to None. + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + if self.binary_input: + pred_masks = pred_instances.masks + gt_masks = gt_instances.masks + return self._mask_focal_loss_cost(pred_masks, gt_masks) + else: + pred_scores = pred_instances.scores + gt_labels = gt_instances.labels + return self._focal_loss_cost(pred_scores, gt_labels) + + +@TASK_UTILS.register_module() +class BinaryFocalLossCost(FocalLossCost): + + def _focal_loss_cost(self, cls_pred: Tensor, gt_labels: Tensor) -> Tensor: + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + (num_queries, num_class). + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + + Returns: + torch.Tensor: cls_cost value with weight + """ + cls_pred = cls_pred.flatten(1) + gt_labels = gt_labels.flatten(1).float() + cls_pred = cls_pred.sigmoid() + neg_cost = -(1 - cls_pred + self.eps).log() * ( + 1 - self.alpha) * cls_pred.pow(self.gamma) + pos_cost = -(cls_pred + self.eps).log() * self.alpha * ( + 1 - cls_pred).pow(self.gamma) + + cls_cost = torch.einsum('nc,mc->nm', pos_cost, gt_labels) + \ + torch.einsum('nc,mc->nm', neg_cost, (1 - gt_labels)) + return cls_cost * self.weight + + def __call__(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + img_meta: Optional[dict] = None, + **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData`): Predicted instances which + must contain ``scores`` or ``masks``. + gt_instances (:obj:`InstanceData`): Ground truth which must contain + ``labels`` or ``mask``. + img_meta (Optional[dict]): Image information. Defaults to None. + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + # gt_instances.text_token_mask is a repeated tensor of the same length + # of instances. Only gt_instances.text_token_mask[0] is useful + text_token_mask = torch.nonzero( + gt_instances.text_token_mask[0]).squeeze(-1) + pred_scores = pred_instances.scores[:, text_token_mask] + gt_labels = gt_instances.positive_maps[:, text_token_mask] + return self._focal_loss_cost(pred_scores, gt_labels) + + +@TASK_UTILS.register_module() +class DiceCost(BaseMatchCost): + """Cost of mask assignments based on dice losses. + + Args: + pred_act (bool): Whether to apply sigmoid to mask_pred. + Defaults to False. + eps (float): Defaults to 1e-3. 
+ naive_dice (bool): If True, use the naive dice loss + in which the power of the number in the denominator is + the first power. If False, use the second power that + is adopted by K-Net and SOLO. Defaults to True. + weight (Union[float, int]): Cost weight. Defaults to 1. + """ + + def __init__(self, + pred_act: bool = False, + eps: float = 1e-3, + naive_dice: bool = True, + weight: Union[float, int] = 1.) -> None: + super().__init__(weight=weight) + self.pred_act = pred_act + self.eps = eps + self.naive_dice = naive_dice + + def _binary_mask_dice_loss(self, mask_preds: Tensor, + gt_masks: Tensor) -> Tensor: + """ + Args: + mask_preds (Tensor): Mask prediction in shape (num_queries, *). + gt_masks (Tensor): Ground truth in shape (num_gt, *) + store 0 or 1, 0 for negative class and 1 for + positive class. + + Returns: + Tensor: Dice cost matrix in shape (num_queries, num_gt). + """ + mask_preds = mask_preds.flatten(1) + gt_masks = gt_masks.flatten(1).float() + numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks) + if self.naive_dice: + denominator = mask_preds.sum(-1)[:, None] + \ + gt_masks.sum(-1)[None, :] + else: + denominator = mask_preds.pow(2).sum(1)[:, None] + \ + gt_masks.pow(2).sum(1)[None, :] + loss = 1 - (numerator + self.eps) / (denominator + self.eps) + return loss + + def __call__(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + img_meta: Optional[dict] = None, + **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData`): Predicted instances which + must contain ``masks``. + gt_instances (:obj:`InstanceData`): Ground truth which must contain + ``mask``. + img_meta (Optional[dict]): Image information. Defaults to None. + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + pred_masks = pred_instances.masks + gt_masks = gt_instances.masks + + if self.pred_act: + pred_masks = pred_masks.sigmoid() + dice_cost = self._binary_mask_dice_loss(pred_masks, gt_masks) + return dice_cost * self.weight + + +@TASK_UTILS.register_module() +class CrossEntropyLossCost(BaseMatchCost): + """CrossEntropyLossCost. + + Args: + use_sigmoid (bool): Whether the prediction uses sigmoid + of softmax. Defaults to True. + weight (Union[float, int]): Cost weight. Defaults to 1. + """ + + def __init__(self, + use_sigmoid: bool = True, + weight: Union[float, int] = 1.) -> None: + super().__init__(weight=weight) + self.use_sigmoid = use_sigmoid + + def _binary_cross_entropy(self, cls_pred: Tensor, + gt_labels: Tensor) -> Tensor: + """ + Args: + cls_pred (Tensor): The prediction with shape (num_queries, 1, *) or + (num_queries, *). + gt_labels (Tensor): The learning label of prediction with + shape (num_gt, *). + + Returns: + Tensor: Cross entropy cost matrix in shape (num_queries, num_gt). + """ + cls_pred = cls_pred.flatten(1).float() + gt_labels = gt_labels.flatten(1).float() + n = cls_pred.shape[1] + pos = F.binary_cross_entropy_with_logits( + cls_pred, torch.ones_like(cls_pred), reduction='none') + neg = F.binary_cross_entropy_with_logits( + cls_pred, torch.zeros_like(cls_pred), reduction='none') + cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \ + torch.einsum('nc,mc->nm', neg, 1 - gt_labels) + cls_cost = cls_cost / n + + return cls_cost + + def __call__(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + img_meta: Optional[dict] = None, + **kwargs) -> Tensor: + """Compute match cost. 
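+
+        The pairwise cost is the binary cross-entropy between each predicted
+        mask and each gt mask, accumulated over mask positions with
+        ``einsum`` and averaged by the number of positions.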
+
+        Args:
+            pred_instances (:obj:`InstanceData`): Predicted instances which
+                must contain ``scores`` or ``masks``.
+            gt_instances (:obj:`InstanceData`): Ground truth which must contain
+                ``labels`` or ``masks``.
+            img_meta (Optional[dict]): Image information. Defaults to None.
+
+        Returns:
+            Tensor: Match Cost matrix of shape (num_preds, num_gts).
+        """
+        pred_masks = pred_instances.masks
+        gt_masks = gt_instances.masks
+        if self.use_sigmoid:
+            cls_cost = self._binary_cross_entropy(pred_masks, gt_masks)
+        else:
+            raise NotImplementedError
+
+        return cls_cost * self.weight
diff --git a/mmdet/models/task_modules/assigners/max_iou_assigner.py b/mmdet/models/task_modules/assigners/max_iou_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..71da54429ae0526bf52277bc3b1d24630acceaed
--- /dev/null
+++ b/mmdet/models/task_modules/assigners/max_iou_assigner.py
@@ -0,0 +1,325 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from typing import Optional, Union
+
+import torch
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmdet.registry import TASK_UTILS
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+def _perm_box(bboxes,
+              iou_calculator,
+              iou_thr=0.97,
+              perm_range=0.01,
+              counter=0,
+              max_iter=5):
+    """Compute the permuted bboxes.
+
+    Args:
+        bboxes (Tensor): Shape (n, 4) for <x1, y1, x2, y2>, "xyxy" format.
+        iou_calculator (obj): Overlaps Calculator.
+        iou_thr (float): The permuted bboxes should have IoU > iou_thr.
+        perm_range (float): The scale of permutation.
+        counter (int): Counter of permutation iteration.
+        max_iter (int): The max iterations of permutation.
+    Returns:
+        Tensor: The permuted bboxes.
+    """
+    ori_bboxes = copy.deepcopy(bboxes)
+    is_valid = True
+    N = bboxes.size(0)
+    perm_factor = bboxes.new_empty(N, 4).uniform_(1 - perm_range,
+                                                  1 + perm_range)
+    bboxes *= perm_factor
+    new_wh = bboxes[:, 2:] - bboxes[:, :2]
+    if (new_wh <= 0).any():
+        is_valid = False
+    iou = iou_calculator(ori_bboxes.unique(dim=0), bboxes)
+    if (iou < iou_thr).any():
+        is_valid = False
+    if not is_valid and counter < max_iter:
+        return _perm_box(
+            ori_bboxes,
+            iou_calculator,
+            perm_range=max(perm_range - counter * 0.001, 1e-3),
+            counter=counter + 1)
+    return bboxes
+
+
+def perm_repeat_bboxes(bboxes, iou_calculator=None, perm_repeat_cfg=None):
+    """Permute the repeated bboxes.
+
+    Args:
+        bboxes (Tensor): Shape (n, 4) for <x1, y1, x2, y2>, "xyxy" format.
+        iou_calculator (obj): Overlaps Calculator.
+        perm_repeat_cfg (Dict): Config of permutation.
+    Returns:
+        Tensor: Bboxes after permuted repeated bboxes.
+    """
+    assert isinstance(bboxes, torch.Tensor)
+    if iou_calculator is None:
+        import torchvision
+        iou_calculator = torchvision.ops.box_iou
+    bboxes = copy.deepcopy(bboxes)
+    unique_bboxes = bboxes.unique(dim=0)
+    iou_thr = perm_repeat_cfg.get('iou_thr', 0.97)
+    perm_range = perm_repeat_cfg.get('perm_range', 0.01)
+    for box in unique_bboxes:
+        inds = (bboxes == box).sum(-1).float() == 4
+        if inds.float().sum().item() == 1:
+            continue
+        bboxes[inds] = _perm_box(
+            bboxes[inds],
+            iou_calculator,
+            iou_thr=iou_thr,
+            perm_range=perm_range,
+            counter=0)
+    return bboxes
+
+
+@TASK_UTILS.register_module()
+class MaxIoUAssigner(BaseAssigner):
+    """Assign a corresponding gt bbox or background to each bbox.
+
+    Each proposal will be assigned with `-1`, or a semi-positive integer
+    indicating the ground truth index.
+ + - -1: negative sample, no assigned gt + - semi-positive integer: positive sample, index (0-based) of assigned gt + + Args: + pos_iou_thr (float): IoU threshold for positive bboxes. + neg_iou_thr (float or tuple): IoU threshold for negative bboxes. + min_pos_iou (float): Minimum iou for a bbox to be considered as a + positive bbox. Positive samples can have smaller IoU than + pos_iou_thr due to the 4th step (assign max IoU sample to each gt). + `min_pos_iou` is set to avoid assigning bboxes that have extremely + small iou with GT as positive samples. It brings about 0.3 mAP + improvements in 1x schedule but does not affect the performance of + 3x schedule. More comparisons can be found in + `PR #7464 `_. + gt_max_assign_all (bool): Whether to assign all bboxes with the same + highest overlap with some gt to that gt. + ignore_iof_thr (float): IoF threshold for ignoring bboxes (if + `gt_bboxes_ignore` is specified). Negative values mean not + ignoring any bboxes. + ignore_wrt_candidates (bool): Whether to compute the iof between + `bboxes` and `gt_bboxes_ignore`, or the contrary. + match_low_quality (bool): Whether to allow low quality matches. This is + usually allowed for RPN and single stage detectors, but not allowed + in the second stage. Details are demonstrated in Step 4. + gpu_assign_thr (int): The upper bound of the number of GT for GPU + assign. When the number of gt is above this threshold, will assign + on CPU device. Negative values mean not assign on CPU. + iou_calculator (dict): Config of overlaps Calculator. + perm_repeat_gt_cfg (dict): Config of permute repeated gt bboxes. + """ + + def __init__(self, + pos_iou_thr: float, + neg_iou_thr: Union[float, tuple], + min_pos_iou: float = .0, + gt_max_assign_all: bool = True, + ignore_iof_thr: float = -1, + ignore_wrt_candidates: bool = True, + match_low_quality: bool = True, + gpu_assign_thr: float = -1, + iou_calculator: dict = dict(type='BboxOverlaps2D'), + perm_repeat_gt_cfg=None): + self.pos_iou_thr = pos_iou_thr + self.neg_iou_thr = neg_iou_thr + self.min_pos_iou = min_pos_iou + self.gt_max_assign_all = gt_max_assign_all + self.ignore_iof_thr = ignore_iof_thr + self.ignore_wrt_candidates = ignore_wrt_candidates + self.gpu_assign_thr = gpu_assign_thr + self.match_low_quality = match_low_quality + self.iou_calculator = TASK_UTILS.build(iou_calculator) + self.perm_repeat_gt_cfg = perm_repeat_gt_cfg + + def assign(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + gt_instances_ignore: Optional[InstanceData] = None, + **kwargs) -> AssignResult: + """Assign gt to bboxes. + + This method assign a gt bbox to every bbox (proposal/anchor), each bbox + will be assigned with -1, or a semi-positive number. -1 means negative + sample, semi-positive number is the index (0-based) of assigned gt. + The assignment is done in following steps, the order matters. + + 1. assign every bbox to the background + 2. assign proposals whose iou with all gts < neg_iou_thr to 0 + 3. for each bbox, if the iou with its nearest gt >= pos_iou_thr, + assign it to that bbox + 4. for each gt bbox, assign its nearest proposals (may be more than + one) to itself + + Args: + pred_instances (:obj:`InstanceData`): Instances of model + predictions. It includes ``priors``, and the priors can + be anchors or points, or the bboxes predicted by the + previous stage, has shape (n, 4). The bboxes predicted by + the current model or stage will be named ``bboxes``, + ``labels``, and ``scores``, the same as the ``InstanceData`` + in other places. 
+ gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes``, with shape (k, 4), + and ``labels``, with shape (k, ). + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` + attribute data that is ignored during training and testing. + Defaults to None. + + Returns: + :obj:`AssignResult`: The assign result. + + Example: + >>> from mmengine.structures import InstanceData + >>> self = MaxIoUAssigner(0.5, 0.5) + >>> pred_instances = InstanceData() + >>> pred_instances.priors = torch.Tensor([[0, 0, 10, 10], + ... [10, 10, 20, 20]]) + >>> gt_instances = InstanceData() + >>> gt_instances.bboxes = torch.Tensor([[0, 0, 10, 9]]) + >>> gt_instances.labels = torch.Tensor([0]) + >>> assign_result = self.assign(pred_instances, gt_instances) + >>> expected_gt_inds = torch.LongTensor([1, 0]) + >>> assert torch.all(assign_result.gt_inds == expected_gt_inds) + """ + gt_bboxes = gt_instances.bboxes + priors = pred_instances.priors + gt_labels = gt_instances.labels + if gt_instances_ignore is not None: + gt_bboxes_ignore = gt_instances_ignore.bboxes + else: + gt_bboxes_ignore = None + + assign_on_cpu = True if (self.gpu_assign_thr > 0) and ( + gt_bboxes.shape[0] > self.gpu_assign_thr) else False + # compute overlap and assign gt on CPU when number of GT is large + if assign_on_cpu: + device = priors.device + priors = priors.cpu() + gt_bboxes = gt_bboxes.cpu() + gt_labels = gt_labels.cpu() + if gt_bboxes_ignore is not None: + gt_bboxes_ignore = gt_bboxes_ignore.cpu() + + if self.perm_repeat_gt_cfg is not None and priors.numel() > 0: + gt_bboxes_unique = perm_repeat_bboxes(gt_bboxes, + self.iou_calculator, + self.perm_repeat_gt_cfg) + else: + gt_bboxes_unique = gt_bboxes + overlaps = self.iou_calculator(gt_bboxes_unique, priors) + + if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None + and gt_bboxes_ignore.numel() > 0 and priors.numel() > 0): + if self.ignore_wrt_candidates: + ignore_overlaps = self.iou_calculator( + priors, gt_bboxes_ignore, mode='iof') + ignore_max_overlaps, _ = ignore_overlaps.max(dim=1) + else: + ignore_overlaps = self.iou_calculator( + gt_bboxes_ignore, priors, mode='iof') + ignore_max_overlaps, _ = ignore_overlaps.max(dim=0) + overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1 + + assign_result = self.assign_wrt_overlaps(overlaps, gt_labels) + if assign_on_cpu: + assign_result.gt_inds = assign_result.gt_inds.to(device) + assign_result.max_overlaps = assign_result.max_overlaps.to(device) + if assign_result.labels is not None: + assign_result.labels = assign_result.labels.to(device) + return assign_result + + def assign_wrt_overlaps(self, overlaps: Tensor, + gt_labels: Tensor) -> AssignResult: + """Assign w.r.t. the overlaps of priors with gts. + + Args: + overlaps (Tensor): Overlaps between k gt_bboxes and n bboxes, + shape(k, n). + gt_labels (Tensor): Labels of k gt_bboxes, shape (k, ). + + Returns: + :obj:`AssignResult`: The assign result. + """ + num_gts, num_bboxes = overlaps.size(0), overlaps.size(1) + + # 1. 
assign -1 by default + assigned_gt_inds = overlaps.new_full((num_bboxes, ), + -1, + dtype=torch.long) + + if num_gts == 0 or num_bboxes == 0: + # No ground truth or boxes, return empty assignment + max_overlaps = overlaps.new_zeros((num_bboxes, )) + assigned_labels = overlaps.new_full((num_bboxes, ), + -1, + dtype=torch.long) + if num_gts == 0: + # No truth, assign everything to background + assigned_gt_inds[:] = 0 + return AssignResult( + num_gts=num_gts, + gt_inds=assigned_gt_inds, + max_overlaps=max_overlaps, + labels=assigned_labels) + + # for each anchor, which gt best overlaps with it + # for each anchor, the max iou of all gts + max_overlaps, argmax_overlaps = overlaps.max(dim=0) + # for each gt, which anchor best overlaps with it + # for each gt, the max iou of all proposals + gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1) + + # 2. assign negative: below + # the negative inds are set to be 0 + if isinstance(self.neg_iou_thr, float): + assigned_gt_inds[(max_overlaps >= 0) + & (max_overlaps < self.neg_iou_thr)] = 0 + elif isinstance(self.neg_iou_thr, tuple): + assert len(self.neg_iou_thr) == 2 + assigned_gt_inds[(max_overlaps >= self.neg_iou_thr[0]) + & (max_overlaps < self.neg_iou_thr[1])] = 0 + + # 3. assign positive: above positive IoU threshold + pos_inds = max_overlaps >= self.pos_iou_thr + assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1 + + if self.match_low_quality: + # Low-quality matching will overwrite the assigned_gt_inds assigned + # in Step 3. Thus, the assigned gt might not be the best one for + # prediction. + # For example, if bbox A has 0.9 and 0.8 iou with GT bbox 1 & 2, + # bbox 1 will be assigned as the best target for bbox A in step 3. + # However, if GT bbox 2's gt_argmax_overlaps = A, bbox A's + # assigned_gt_inds will be overwritten to be bbox 2. + # This might be the reason that it is not used in ROI Heads. + for i in range(num_gts): + if gt_max_overlaps[i] >= self.min_pos_iou: + if self.gt_max_assign_all: + max_iou_inds = overlaps[i, :] == gt_max_overlaps[i] + assigned_gt_inds[max_iou_inds] = i + 1 + else: + assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1 + + assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1) + pos_inds = torch.nonzero( + assigned_gt_inds > 0, as_tuple=False).squeeze() + if pos_inds.numel() > 0: + assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] - + 1] + + return AssignResult( + num_gts=num_gts, + gt_inds=assigned_gt_inds, + max_overlaps=max_overlaps, + labels=assigned_labels) diff --git a/mmdet/models/task_modules/assigners/multi_instance_assigner.py b/mmdet/models/task_modules/assigners/multi_instance_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..1ba32afe856b3c2ad03ed89562d080f15b6ccf30 --- /dev/null +++ b/mmdet/models/task_modules/assigners/multi_instance_assigner.py @@ -0,0 +1,140 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +from mmengine.structures import InstanceData + +from mmdet.registry import TASK_UTILS +from .assign_result import AssignResult +from .max_iou_assigner import MaxIoUAssigner + + +@TASK_UTILS.register_module() +class MultiInstanceAssigner(MaxIoUAssigner): + """Assign a corresponding gt bbox or background to each proposal bbox. If + we need to use a proposal box to generate multiple predict boxes, + `MultiInstanceAssigner` can assign multiple gt to each proposal box. + + Args: + num_instance (int): How many bboxes are predicted by each proposal box. 
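+
+    Examples:
+        >>> # An illustrative sketch; the boxes and thresholds below are
+        >>> # arbitrary, and inherited MaxIoUAssigner arguments are passed
+        >>> # through **kwargs.
+        >>> from mmengine.structures import InstanceData
+        >>> import torch
+        >>> self = MultiInstanceAssigner(
+        ...     num_instance=2, pos_iou_thr=0.5, neg_iou_thr=0.3)
+        >>> pred_instances = InstanceData()
+        >>> pred_instances.priors = torch.Tensor([[0., 0., 10., 10.]])
+        >>> gt_instances = InstanceData()
+        >>> gt_instances.bboxes = torch.Tensor([[0., 0., 10., 9.],
+        ...                                     [0., 2., 10., 10.]])
+        >>> gt_instances.labels = torch.LongTensor([0, 0])
+        >>> result = self.assign(pred_instances, gt_instances)
+        >>> # one row per prior plus the appended gt boxes, with
+        >>> # `num_instance` assigned gts per row
+        >>> result.gt_inds.shape
+        torch.Size([3, 2])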
+ """ + + def __init__(self, num_instance: int = 2, **kwargs): + super().__init__(**kwargs) + self.num_instance = num_instance + + def assign(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + gt_instances_ignore: Optional[InstanceData] = None, + **kwargs) -> AssignResult: + """Assign gt to bboxes. + + This method assign gt bboxes to every bbox (proposal/anchor), each bbox + is assigned a set of gts, and the number of gts in this set is defined + by `self.num_instance`. + + Args: + pred_instances (:obj:`InstanceData`): Instances of model + predictions. It includes ``priors``, and the priors can + be anchors or points, or the bboxes predicted by the + previous stage, has shape (n, 4). The bboxes predicted by + the current model or stage will be named ``bboxes``, + ``labels``, and ``scores``, the same as the ``InstanceData`` + in other places. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes``, with shape (k, 4), + and ``labels``, with shape (k, ). + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` + attribute data that is ignored during training and testing. + Defaults to None. + + Returns: + :obj:`AssignResult`: The assign result. + """ + gt_bboxes = gt_instances.bboxes + priors = pred_instances.priors + # Set the FG label to 1 and add ignored annotations + gt_labels = gt_instances.labels + 1 + if gt_instances_ignore is not None: + gt_bboxes_ignore = gt_instances_ignore.bboxes + if hasattr(gt_instances_ignore, 'labels'): + gt_labels_ignore = gt_instances_ignore.labels + else: + gt_labels_ignore = torch.ones_like(gt_bboxes_ignore)[:, 0] * -1 + else: + gt_bboxes_ignore = None + gt_labels_ignore = None + + assign_on_cpu = True if (self.gpu_assign_thr > 0) and ( + gt_bboxes.shape[0] > self.gpu_assign_thr) else False + # compute overlap and assign gt on CPU when number of GT is large + if assign_on_cpu: + device = priors.device + priors = priors.cpu() + gt_bboxes = gt_bboxes.cpu() + gt_labels = gt_labels.cpu() + if gt_bboxes_ignore is not None: + gt_bboxes_ignore = gt_bboxes_ignore.cpu() + gt_labels_ignore = gt_labels_ignore.cpu() + + if gt_bboxes_ignore is not None: + all_bboxes = torch.cat([gt_bboxes, gt_bboxes_ignore], dim=0) + all_labels = torch.cat([gt_labels, gt_labels_ignore], dim=0) + else: + all_bboxes = gt_bboxes + all_labels = gt_labels + all_priors = torch.cat([priors, all_bboxes], dim=0) + + overlaps_normal = self.iou_calculator( + all_priors, all_bboxes, mode='iou') + overlaps_ignore = self.iou_calculator( + all_priors, all_bboxes, mode='iof') + gt_ignore_mask = all_labels.eq(-1).repeat(all_priors.shape[0], 1) + overlaps_normal = overlaps_normal * ~gt_ignore_mask + overlaps_ignore = overlaps_ignore * gt_ignore_mask + + overlaps_normal, overlaps_normal_indices = overlaps_normal.sort( + descending=True, dim=1) + overlaps_ignore, overlaps_ignore_indices = overlaps_ignore.sort( + descending=True, dim=1) + + # select the roi with the higher score + max_overlaps_normal = overlaps_normal[:, :self.num_instance].flatten() + gt_assignment_normal = overlaps_normal_indices[:, :self. + num_instance].flatten() + max_overlaps_ignore = overlaps_ignore[:, :self.num_instance].flatten() + gt_assignment_ignore = overlaps_ignore_indices[:, :self. 
+ num_instance].flatten() + + # ignore or not + ignore_assign_mask = (max_overlaps_normal < self.pos_iou_thr) * ( + max_overlaps_ignore > max_overlaps_normal) + overlaps = (max_overlaps_normal * ~ignore_assign_mask) + ( + max_overlaps_ignore * ignore_assign_mask) + gt_assignment = (gt_assignment_normal * ~ignore_assign_mask) + ( + gt_assignment_ignore * ignore_assign_mask) + + assigned_labels = all_labels[gt_assignment] + fg_mask = (overlaps >= self.pos_iou_thr) * (assigned_labels != -1) + bg_mask = (overlaps < self.neg_iou_thr) * (overlaps >= 0) + assigned_labels[fg_mask] = 1 + assigned_labels[bg_mask] = 0 + + overlaps = overlaps.reshape(-1, self.num_instance) + gt_assignment = gt_assignment.reshape(-1, self.num_instance) + assigned_labels = assigned_labels.reshape(-1, self.num_instance) + + assign_result = AssignResult( + num_gts=all_bboxes.size(0), + gt_inds=gt_assignment, + max_overlaps=overlaps, + labels=assigned_labels) + + if assign_on_cpu: + assign_result.gt_inds = assign_result.gt_inds.to(device) + assign_result.max_overlaps = assign_result.max_overlaps.to(device) + if assign_result.labels is not None: + assign_result.labels = assign_result.labels.to(device) + return assign_result diff --git a/mmdet/models/task_modules/assigners/point_assigner.py b/mmdet/models/task_modules/assigners/point_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..4da60a490b0022ac76c46db8a34f814bc9da8e2e --- /dev/null +++ b/mmdet/models/task_modules/assigners/point_assigner.py @@ -0,0 +1,155 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +from mmengine.structures import InstanceData + +from mmdet.registry import TASK_UTILS +from .assign_result import AssignResult +from .base_assigner import BaseAssigner + + +@TASK_UTILS.register_module() +class PointAssigner(BaseAssigner): + """Assign a corresponding gt bbox or background to each point. + + Each proposals will be assigned with `0`, or a positive integer + indicating the ground truth index. + + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + """ + + def __init__(self, scale: int = 4, pos_num: int = 3) -> None: + self.scale = scale + self.pos_num = pos_num + + def assign(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + gt_instances_ignore: Optional[InstanceData] = None, + **kwargs) -> AssignResult: + """Assign gt to points. + + This method assign a gt bbox to every points set, each points set + will be assigned with the background_label (-1), or a label number. + -1 is background, and semi-positive number is the index (0-based) of + assigned gt. + The assignment is done in following steps, the order matters. + + 1. assign every points to the background_label (-1) + 2. A point is assigned to some gt bbox if + (i) the point is within the k closest points to the gt bbox + (ii) the distance between this point and the gt is smaller than + other gt bboxes + + Args: + pred_instances (:obj:`InstanceData`): Instances of model + predictions. It includes ``priors``, and the priors can + be anchors or points, or the bboxes predicted by the + previous stage, has shape (n, 4). The bboxes predicted by + the current model or stage will be named ``bboxes``, + ``labels``, and ``scores``, the same as the ``InstanceData`` + in other places. + + + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes``, with shape (k, 4), + and ``labels``, with shape (k, ). 
+ gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` + attribute data that is ignored during training and testing. + Defaults to None. + Returns: + :obj:`AssignResult`: The assign result. + """ + gt_bboxes = gt_instances.bboxes + gt_labels = gt_instances.labels + # points to be assigned, shape(n, 3) while last + # dimension stands for (x, y, stride). + points = pred_instances.priors + + num_points = points.shape[0] + num_gts = gt_bboxes.shape[0] + + if num_gts == 0 or num_points == 0: + # If no truth assign everything to the background + assigned_gt_inds = points.new_full((num_points, ), + 0, + dtype=torch.long) + assigned_labels = points.new_full((num_points, ), + -1, + dtype=torch.long) + return AssignResult( + num_gts=num_gts, + gt_inds=assigned_gt_inds, + max_overlaps=None, + labels=assigned_labels) + + points_xy = points[:, :2] + points_stride = points[:, 2] + points_lvl = torch.log2( + points_stride).int() # [3...,4...,5...,6...,7...] + lvl_min, lvl_max = points_lvl.min(), points_lvl.max() + + # assign gt box + gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2 + gt_bboxes_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6) + scale = self.scale + gt_bboxes_lvl = ((torch.log2(gt_bboxes_wh[:, 0] / scale) + + torch.log2(gt_bboxes_wh[:, 1] / scale)) / 2).int() + gt_bboxes_lvl = torch.clamp(gt_bboxes_lvl, min=lvl_min, max=lvl_max) + + # stores the assigned gt index of each point + assigned_gt_inds = points.new_zeros((num_points, ), dtype=torch.long) + # stores the assigned gt dist (to this point) of each point + assigned_gt_dist = points.new_full((num_points, ), float('inf')) + points_range = torch.arange(points.shape[0]) + + for idx in range(num_gts): + gt_lvl = gt_bboxes_lvl[idx] + # get the index of points in this level + lvl_idx = gt_lvl == points_lvl + points_index = points_range[lvl_idx] + # get the points in this level + lvl_points = points_xy[lvl_idx, :] + # get the center point of gt + gt_point = gt_bboxes_xy[[idx], :] + # get width and height of gt + gt_wh = gt_bboxes_wh[[idx], :] + # compute the distance between gt center and + # all points in this level + points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1) + # find the nearest k points to gt center in this level + min_dist, min_dist_index = torch.topk( + points_gt_dist, self.pos_num, largest=False) + # the index of nearest k points to gt center in this level + min_dist_points_index = points_index[min_dist_index] + # The less_than_recorded_index stores the index + # of min_dist that is less then the assigned_gt_dist. Where + # assigned_gt_dist stores the dist from previous assigned gt + # (if exist) to each point. + less_than_recorded_index = min_dist < assigned_gt_dist[ + min_dist_points_index] + # The min_dist_points_index stores the index of points satisfy: + # (1) it is k nearest to current gt center in this level. + # (2) it is closer to current gt center than other gt center. 
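+            # Illustrative: if this point was first claimed by an earlier gt
+            # at normalized distance 0.4 and the current gt reaches it at
+            # 0.3, the comparison 0.3 < 0.4 keeps the point, so the closer
+            # gt wins.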
+ min_dist_points_index = min_dist_points_index[ + less_than_recorded_index] + # assign the result + assigned_gt_inds[min_dist_points_index] = idx + 1 + assigned_gt_dist[min_dist_points_index] = min_dist[ + less_than_recorded_index] + + assigned_labels = assigned_gt_inds.new_full((num_points, ), -1) + pos_inds = torch.nonzero( + assigned_gt_inds > 0, as_tuple=False).squeeze() + if pos_inds.numel() > 0: + assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] - + 1] + + return AssignResult( + num_gts=num_gts, + gt_inds=assigned_gt_inds, + max_overlaps=None, + labels=assigned_labels) diff --git a/mmdet/models/task_modules/assigners/region_assigner.py b/mmdet/models/task_modules/assigners/region_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..df549143086c1195efaf12a2f3e81259da0e6c97 --- /dev/null +++ b/mmdet/models/task_modules/assigners/region_assigner.py @@ -0,0 +1,239 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +import torch +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import TASK_UTILS +from ..prior_generators import anchor_inside_flags +from .assign_result import AssignResult +from .base_assigner import BaseAssigner + + +def calc_region( + bbox: Tensor, + ratio: float, + stride: int, + featmap_size: Optional[Tuple[int, int]] = None) -> Tuple[Tensor]: + """Calculate region of the box defined by the ratio, the ratio is from the + center of the box to every edge.""" + # project bbox on the feature + f_bbox = bbox / stride + x1 = torch.round((1 - ratio) * f_bbox[0] + ratio * f_bbox[2]) + y1 = torch.round((1 - ratio) * f_bbox[1] + ratio * f_bbox[3]) + x2 = torch.round(ratio * f_bbox[0] + (1 - ratio) * f_bbox[2]) + y2 = torch.round(ratio * f_bbox[1] + (1 - ratio) * f_bbox[3]) + if featmap_size is not None: + x1 = x1.clamp(min=0, max=featmap_size[1]) + y1 = y1.clamp(min=0, max=featmap_size[0]) + x2 = x2.clamp(min=0, max=featmap_size[1]) + y2 = y2.clamp(min=0, max=featmap_size[0]) + return (x1, y1, x2, y2) + + +def anchor_ctr_inside_region_flags(anchors: Tensor, stride: int, + region: Tuple[Tensor]) -> Tensor: + """Get the flag indicate whether anchor centers are inside regions.""" + x1, y1, x2, y2 = region + f_anchors = anchors / stride + x = (f_anchors[:, 0] + f_anchors[:, 2]) * 0.5 + y = (f_anchors[:, 1] + f_anchors[:, 3]) * 0.5 + flags = (x >= x1) & (x <= x2) & (y >= y1) & (y <= y2) + return flags + + +@TASK_UTILS.register_module() +class RegionAssigner(BaseAssigner): + """Assign a corresponding gt bbox or background to each bbox. + + Each proposals will be assigned with `-1`, `0`, or a positive integer + indicating the ground truth index. + + - -1: don't care + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + + Args: + center_ratio (float): ratio of the region in the center of the bbox to + define positive sample. + ignore_ratio (float): ratio of the region to define ignore samples. + """ + + def __init__(self, + center_ratio: float = 0.2, + ignore_ratio: float = 0.5) -> None: + self.center_ratio = center_ratio + self.ignore_ratio = ignore_ratio + + def assign(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + img_meta: dict, + featmap_sizes: List[Tuple[int, int]], + num_level_anchors: List[int], + anchor_scale: int, + anchor_strides: List[int], + gt_instances_ignore: Optional[InstanceData] = None, + allowed_border: int = 0) -> AssignResult: + """Assign gt to anchors. 
+ + This method assign a gt bbox to every bbox (proposal/anchor), each bbox + will be assigned with -1, 0, or a positive number. -1 means don't care, + 0 means negative sample, positive number is the index (1-based) of + assigned gt. + + The assignment is done in following steps, and the order matters. + + 1. Assign every anchor to 0 (negative) + 2. (For each gt_bboxes) Compute ignore flags based on ignore_region + then assign -1 to anchors w.r.t. ignore flags + 3. (For each gt_bboxes) Compute pos flags based on center_region then + assign gt_bboxes to anchors w.r.t. pos flags + 4. (For each gt_bboxes) Compute ignore flags based on adjacent anchor + level then assign -1 to anchors w.r.t. ignore flags + 5. Assign anchor outside of image to -1 + + Args: + pred_instances (:obj:`InstanceData`): Instances of model + predictions. It includes ``priors``, and the priors can + be anchors or points, or the bboxes predicted by the + previous stage, has shape (n, 4). The bboxes predicted by + the current model or stage will be named ``bboxes``, + ``labels``, and ``scores``, the same as the ``InstanceData`` + in other places. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes``, with shape (k, 4), + and ``labels``, with shape (k, ). + img_meta (dict): Meta info of image. + featmap_sizes (list[tuple[int, int]]): Feature map size each level. + num_level_anchors (list[int]): The number of anchors in each level. + anchor_scale (int): Scale of the anchor. + anchor_strides (list[int]): Stride of the anchor. + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` + attribute data that is ignored during training and testing. + Defaults to None. + allowed_border (int, optional): The border to allow the valid + anchor. Defaults to 0. + + Returns: + :obj:`AssignResult`: The assign result. + """ + if gt_instances_ignore is not None: + raise NotImplementedError + + num_gts = len(gt_instances) + num_bboxes = len(pred_instances) + + gt_bboxes = gt_instances.bboxes + gt_labels = gt_instances.labels + flat_anchors = pred_instances.priors + flat_valid_flags = pred_instances.valid_flags + mlvl_anchors = torch.split(flat_anchors, num_level_anchors) + + if num_gts == 0 or num_bboxes == 0: + # No ground truth or boxes, return empty assignment + max_overlaps = gt_bboxes.new_zeros((num_bboxes, )) + assigned_gt_inds = gt_bboxes.new_zeros((num_bboxes, ), + dtype=torch.long) + assigned_labels = gt_bboxes.new_full((num_bboxes, ), + -1, + dtype=torch.long) + return AssignResult( + num_gts=num_gts, + gt_inds=assigned_gt_inds, + max_overlaps=max_overlaps, + labels=assigned_labels) + + num_lvls = len(mlvl_anchors) + r1 = (1 - self.center_ratio) / 2 + r2 = (1 - self.ignore_ratio) / 2 + + scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) * + (gt_bboxes[:, 3] - gt_bboxes[:, 1])) + min_anchor_size = scale.new_full( + (1, ), float(anchor_scale * anchor_strides[0])) + target_lvls = torch.floor( + torch.log2(scale) - torch.log2(min_anchor_size) + 0.5) + target_lvls = target_lvls.clamp(min=0, max=num_lvls - 1).long() + + # 1. 
assign 0 (negative) by default + mlvl_assigned_gt_inds = [] + mlvl_ignore_flags = [] + for lvl in range(num_lvls): + assigned_gt_inds = gt_bboxes.new_full((num_level_anchors[lvl], ), + 0, + dtype=torch.long) + ignore_flags = torch.zeros_like(assigned_gt_inds) + mlvl_assigned_gt_inds.append(assigned_gt_inds) + mlvl_ignore_flags.append(ignore_flags) + + for gt_id in range(num_gts): + lvl = target_lvls[gt_id].item() + featmap_size = featmap_sizes[lvl] + stride = anchor_strides[lvl] + anchors = mlvl_anchors[lvl] + gt_bbox = gt_bboxes[gt_id, :4] + + # Compute regions + ignore_region = calc_region(gt_bbox, r2, stride, featmap_size) + ctr_region = calc_region(gt_bbox, r1, stride, featmap_size) + + # 2. Assign -1 to ignore flags + ignore_flags = anchor_ctr_inside_region_flags( + anchors, stride, ignore_region) + mlvl_assigned_gt_inds[lvl][ignore_flags] = -1 + + # 3. Assign gt_bboxes to pos flags + pos_flags = anchor_ctr_inside_region_flags(anchors, stride, + ctr_region) + mlvl_assigned_gt_inds[lvl][pos_flags] = gt_id + 1 + + # 4. Assign -1 to ignore adjacent lvl + if lvl > 0: + d_lvl = lvl - 1 + d_anchors = mlvl_anchors[d_lvl] + d_featmap_size = featmap_sizes[d_lvl] + d_stride = anchor_strides[d_lvl] + d_ignore_region = calc_region(gt_bbox, r2, d_stride, + d_featmap_size) + ignore_flags = anchor_ctr_inside_region_flags( + d_anchors, d_stride, d_ignore_region) + mlvl_ignore_flags[d_lvl][ignore_flags] = 1 + if lvl < num_lvls - 1: + u_lvl = lvl + 1 + u_anchors = mlvl_anchors[u_lvl] + u_featmap_size = featmap_sizes[u_lvl] + u_stride = anchor_strides[u_lvl] + u_ignore_region = calc_region(gt_bbox, r2, u_stride, + u_featmap_size) + ignore_flags = anchor_ctr_inside_region_flags( + u_anchors, u_stride, u_ignore_region) + mlvl_ignore_flags[u_lvl][ignore_flags] = 1 + + # 4. (cont.) Assign -1 to ignore adjacent lvl + for lvl in range(num_lvls): + ignore_flags = mlvl_ignore_flags[lvl] + mlvl_assigned_gt_inds[lvl][ignore_flags == 1] = -1 + + # 5. Assign -1 to anchor outside of image + flat_assigned_gt_inds = torch.cat(mlvl_assigned_gt_inds) + assert (flat_assigned_gt_inds.shape[0] == flat_anchors.shape[0] == + flat_valid_flags.shape[0]) + inside_flags = anchor_inside_flags(flat_anchors, flat_valid_flags, + img_meta['img_shape'], + allowed_border) + outside_flags = ~inside_flags + flat_assigned_gt_inds[outside_flags] = -1 + + assigned_labels = torch.zeros_like(flat_assigned_gt_inds) + pos_flags = flat_assigned_gt_inds > 0 + assigned_labels[pos_flags] = gt_labels[flat_assigned_gt_inds[pos_flags] + - 1] + + return AssignResult( + num_gts=num_gts, + gt_inds=flat_assigned_gt_inds, + max_overlaps=None, + labels=assigned_labels) diff --git a/mmdet/models/task_modules/assigners/sim_ota_assigner.py b/mmdet/models/task_modules/assigners/sim_ota_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..d54a8b91d132d9bf661267de666bfed7e915a65a --- /dev/null +++ b/mmdet/models/task_modules/assigners/sim_ota_assigner.py @@ -0,0 +1,223 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import TASK_UTILS +from mmdet.utils import ConfigType +from .assign_result import AssignResult +from .base_assigner import BaseAssigner + +INF = 100000.0 +EPS = 1.0e-7 + + +@TASK_UTILS.register_module() +class SimOTAAssigner(BaseAssigner): + """Computes matching between predictions and ground truth. 
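+
+    For each gt, only priors whose centers fall inside the gt box or within
+    a ``center_radius`` region around the gt center are kept as candidates;
+    a cost built from the classification and IoU terms is computed for
+    them, and a dynamic number of lowest-cost priors is assigned to that gt
+    (see ``dynamic_k_matching``).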
+ + Args: + center_radius (float): Ground truth center size + to judge whether a prior is in center. Defaults to 2.5. + candidate_topk (int): The candidate top-k which used to + get top-k ious to calculate dynamic-k. Defaults to 10. + iou_weight (float): The scale factor for regression + iou cost. Defaults to 3.0. + cls_weight (float): The scale factor for classification + cost. Defaults to 1.0. + iou_calculator (ConfigType): Config of overlaps Calculator. + Defaults to dict(type='BboxOverlaps2D'). + """ + + def __init__(self, + center_radius: float = 2.5, + candidate_topk: int = 10, + iou_weight: float = 3.0, + cls_weight: float = 1.0, + iou_calculator: ConfigType = dict(type='BboxOverlaps2D')): + self.center_radius = center_radius + self.candidate_topk = candidate_topk + self.iou_weight = iou_weight + self.cls_weight = cls_weight + self.iou_calculator = TASK_UTILS.build(iou_calculator) + + def assign(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + gt_instances_ignore: Optional[InstanceData] = None, + **kwargs) -> AssignResult: + """Assign gt to priors using SimOTA. + + Args: + pred_instances (:obj:`InstanceData`): Instances of model + predictions. It includes ``priors``, and the priors can + be anchors or points, or the bboxes predicted by the + previous stage, has shape (n, 4). The bboxes predicted by + the current model or stage will be named ``bboxes``, + ``labels``, and ``scores``, the same as the ``InstanceData`` + in other places. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes``, with shape (k, 4), + and ``labels``, with shape (k, ). + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` + attribute data that is ignored during training and testing. + Defaults to None. + Returns: + obj:`AssignResult`: The assigned result. 
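+
+        Example:
+            >>> # A minimal illustrative sketch; ``priors`` rows are
+            >>> # (cx, cy, stride_x, stride_y) and all values are arbitrary.
+            >>> from mmengine.structures import InstanceData
+            >>> import torch
+            >>> self = SimOTAAssigner()
+            >>> pred_instances = InstanceData()
+            >>> pred_instances.priors = torch.Tensor([[5., 5., 8., 8.]])
+            >>> pred_instances.bboxes = torch.Tensor([[0., 0., 10., 10.]])
+            >>> pred_instances.scores = torch.rand(1, 2)
+            >>> gt_instances = InstanceData()
+            >>> gt_instances.bboxes = torch.Tensor([[0., 0., 10., 10.]])
+            >>> gt_instances.labels = torch.LongTensor([0])
+            >>> result = self.assign(pred_instances, gt_instances)
+            >>> result.gt_inds
+            tensor([1])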
+ """ + gt_bboxes = gt_instances.bboxes + gt_labels = gt_instances.labels + num_gt = gt_bboxes.size(0) + + decoded_bboxes = pred_instances.bboxes + pred_scores = pred_instances.scores + priors = pred_instances.priors + num_bboxes = decoded_bboxes.size(0) + + # assign 0 by default + assigned_gt_inds = decoded_bboxes.new_full((num_bboxes, ), + 0, + dtype=torch.long) + if num_gt == 0 or num_bboxes == 0: + # No ground truth or boxes, return empty assignment + max_overlaps = decoded_bboxes.new_zeros((num_bboxes, )) + assigned_labels = decoded_bboxes.new_full((num_bboxes, ), + -1, + dtype=torch.long) + return AssignResult( + num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) + + valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info( + priors, gt_bboxes) + valid_decoded_bbox = decoded_bboxes[valid_mask] + valid_pred_scores = pred_scores[valid_mask] + num_valid = valid_decoded_bbox.size(0) + if num_valid == 0: + # No valid bboxes, return empty assignment + max_overlaps = decoded_bboxes.new_zeros((num_bboxes, )) + assigned_labels = decoded_bboxes.new_full((num_bboxes, ), + -1, + dtype=torch.long) + return AssignResult( + num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) + + pairwise_ious = self.iou_calculator(valid_decoded_bbox, gt_bboxes) + iou_cost = -torch.log(pairwise_ious + EPS) + + gt_onehot_label = ( + F.one_hot(gt_labels.to(torch.int64), + pred_scores.shape[-1]).float().unsqueeze(0).repeat( + num_valid, 1, 1)) + + valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1) + # disable AMP autocast and calculate BCE with FP32 to avoid overflow + with torch.cuda.amp.autocast(enabled=False): + cls_cost = ( + F.binary_cross_entropy( + valid_pred_scores.to(dtype=torch.float32), + gt_onehot_label, + reduction='none', + ).sum(-1).to(dtype=valid_pred_scores.dtype)) + + cost_matrix = ( + cls_cost * self.cls_weight + iou_cost * self.iou_weight + + (~is_in_boxes_and_center) * INF) + + matched_pred_ious, matched_gt_inds = \ + self.dynamic_k_matching( + cost_matrix, pairwise_ious, num_gt, valid_mask) + + # convert to AssignResult format + assigned_gt_inds[valid_mask] = matched_gt_inds + 1 + assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1) + assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long() + max_overlaps = assigned_gt_inds.new_full((num_bboxes, ), + -INF, + dtype=torch.float32) + max_overlaps[valid_mask] = matched_pred_ious + return AssignResult( + num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) + + def get_in_gt_and_in_center_info( + self, priors: Tensor, gt_bboxes: Tensor) -> Tuple[Tensor, Tensor]: + """Get the information of which prior is in gt bboxes and gt center + priors.""" + num_gt = gt_bboxes.size(0) + + repeated_x = priors[:, 0].unsqueeze(1).repeat(1, num_gt) + repeated_y = priors[:, 1].unsqueeze(1).repeat(1, num_gt) + repeated_stride_x = priors[:, 2].unsqueeze(1).repeat(1, num_gt) + repeated_stride_y = priors[:, 3].unsqueeze(1).repeat(1, num_gt) + + # is prior centers in gt bboxes, shape: [n_prior, n_gt] + l_ = repeated_x - gt_bboxes[:, 0] + t_ = repeated_y - gt_bboxes[:, 1] + r_ = gt_bboxes[:, 2] - repeated_x + b_ = gt_bboxes[:, 3] - repeated_y + + deltas = torch.stack([l_, t_, r_, b_], dim=1) + is_in_gts = deltas.min(dim=1).values > 0 + is_in_gts_all = is_in_gts.sum(dim=1) > 0 + + # is prior centers in gt centers + gt_cxs = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0 + gt_cys = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0 + ct_box_l = gt_cxs - self.center_radius * repeated_stride_x + ct_box_t = 
gt_cys - self.center_radius * repeated_stride_y
+        ct_box_r = gt_cxs + self.center_radius * repeated_stride_x
+        ct_box_b = gt_cys + self.center_radius * repeated_stride_y
+
+        cl_ = repeated_x - ct_box_l
+        ct_ = repeated_y - ct_box_t
+        cr_ = ct_box_r - repeated_x
+        cb_ = ct_box_b - repeated_y
+
+        ct_deltas = torch.stack([cl_, ct_, cr_, cb_], dim=1)
+        is_in_cts = ct_deltas.min(dim=1).values > 0
+        is_in_cts_all = is_in_cts.sum(dim=1) > 0
+
+        # in boxes or in centers, shape: [num_priors]
+        is_in_gts_or_centers = is_in_gts_all | is_in_cts_all
+
+        # both in boxes and centers, shape: [num_fg, num_gt]
+        is_in_boxes_and_centers = (
+            is_in_gts[is_in_gts_or_centers, :]
+            & is_in_cts[is_in_gts_or_centers, :])
+        return is_in_gts_or_centers, is_in_boxes_and_centers
+
+    def dynamic_k_matching(self, cost: Tensor, pairwise_ious: Tensor,
+                           num_gt: int,
+                           valid_mask: Tensor) -> Tuple[Tensor, Tensor]:
+        """Use IoU and matching cost to calculate the dynamic top-k positive
+        targets."""
+        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
+        # select candidate topk ious for dynamic-k calculation
+        candidate_topk = min(self.candidate_topk, pairwise_ious.size(0))
+        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
+        # calculate dynamic k for each gt
+        dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
+        for gt_idx in range(num_gt):
+            _, pos_idx = torch.topk(
+                cost[:, gt_idx], k=dynamic_ks[gt_idx], largest=False)
+            matching_matrix[:, gt_idx][pos_idx] = 1
+
+        del topk_ious, dynamic_ks, pos_idx
+
+        prior_match_gt_mask = matching_matrix.sum(1) > 1
+        if prior_match_gt_mask.sum() > 0:
+            cost_min, cost_argmin = torch.min(
+                cost[prior_match_gt_mask, :], dim=1)
+            matching_matrix[prior_match_gt_mask, :] *= 0
+            matching_matrix[prior_match_gt_mask, cost_argmin] = 1
+        # get foreground mask inside box and center prior
+        fg_mask_inboxes = matching_matrix.sum(1) > 0
+        valid_mask[valid_mask.clone()] = fg_mask_inboxes
+
+        matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
+        matched_pred_ious = (matching_matrix *
+                             pairwise_ious).sum(1)[fg_mask_inboxes]
+        return matched_pred_ious, matched_gt_inds
diff --git a/mmdet/models/task_modules/assigners/task_aligned_assigner.py b/mmdet/models/task_modules/assigners/task_aligned_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..220ea8485933ab3243f6c1e205dbf1b973df08d7
--- /dev/null
+++ b/mmdet/models/task_modules/assigners/task_aligned_assigner.py
@@ -0,0 +1,158 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+import torch
+from mmengine.structures import InstanceData
+
+from mmdet.registry import TASK_UTILS
+from mmdet.utils import ConfigType
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+INF = 100000000
+
+
+@TASK_UTILS.register_module()
+class TaskAlignedAssigner(BaseAssigner):
+    """Task aligned assigner used in the paper:
+    `TOOD: Task-aligned One-stage Object Detection.
+    <https://arxiv.org/abs/2108.07755>`_.
+
+    Assign a corresponding gt bbox or background to each predicted bbox.
+    Each bbox will be assigned with `0` or a positive integer
+    indicating the ground truth index.
+
+    - 0: negative sample, no assigned gt
+    - positive integer: positive sample, index (1-based) of assigned gt
+
+    Args:
+        topk (int): number of bbox selected in each level
+        iou_calculator (:obj:`ConfigDict` or dict): Config dict for iou
+            calculator.
Defaults to ``dict(type='BboxOverlaps2D')``
+    """
+
+    def __init__(self,
+                 topk: int,
+                 iou_calculator: ConfigType = dict(type='BboxOverlaps2D')):
+        assert topk >= 1
+        self.topk = topk
+        self.iou_calculator = TASK_UTILS.build(iou_calculator)
+
+    def assign(self,
+               pred_instances: InstanceData,
+               gt_instances: InstanceData,
+               gt_instances_ignore: Optional[InstanceData] = None,
+               alpha: int = 1,
+               beta: int = 6) -> AssignResult:
+        """Assign gt to bboxes.
+
+        The assignment is done in the following steps
+
+        1. compute the alignment metric between all bboxes (bboxes of all
+           pyramid levels) and gts
+        2. select top-k bboxes as candidates for each gt
+        3. limit the positive sample's center in gt (because the anchor-free
+           detector can only predict positive distances)
+
+
+        Args:
+            pred_instances (:obj:`InstanceData`): Instances of model
+                predictions. It includes ``priors``, and the priors can
+                be anchors, points, or bboxes predicted by the model,
+                shape (n, 4).
+            gt_instances (:obj:`InstanceData`): Ground truth of instance
+                annotations. It usually includes ``bboxes`` and ``labels``
+                attributes.
+            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
+                to be ignored during training. It includes ``bboxes``
+                attribute data that is ignored during training and testing.
+                Defaults to None.
+            alpha (int): Hyper-parameter related to alignment_metrics.
+                Defaults to 1.
+            beta (int): Hyper-parameter related to alignment_metrics.
+                Defaults to 6.
+
+        Returns:
+            :obj:`AssignResult`: The assign result.
+        """
+        priors = pred_instances.priors
+        decode_bboxes = pred_instances.bboxes
+        pred_scores = pred_instances.scores
+        gt_bboxes = gt_instances.bboxes
+        gt_labels = gt_instances.labels
+
+        priors = priors[:, :4]
+        num_gt, num_bboxes = gt_bboxes.size(0), priors.size(0)
+        # compute alignment metric between all bbox and gt
+        overlaps = self.iou_calculator(decode_bboxes, gt_bboxes).detach()
+        bbox_scores = pred_scores[:, gt_labels].detach()
+        # assign 0 by default
+        assigned_gt_inds = priors.new_full((num_bboxes, ),
+                                           0,
+                                           dtype=torch.long)
+        assign_metrics = priors.new_zeros((num_bboxes, ))
+
+        if num_gt == 0 or num_bboxes == 0:
+            # No ground truth or boxes, return empty assignment
+            max_overlaps = priors.new_zeros((num_bboxes, ))
+            if num_gt == 0:
+                # No gt boxes, assign everything to background
+                assigned_gt_inds[:] = 0
+            assigned_labels = priors.new_full((num_bboxes, ),
+                                              -1,
+                                              dtype=torch.long)
+            assign_result = AssignResult(
+                num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
+            assign_result.assign_metrics = assign_metrics
+            return assign_result
+
+        # select top-k bboxes as candidates for each gt
+        alignment_metrics = bbox_scores**alpha * overlaps**beta
+        topk = min(self.topk, alignment_metrics.size(0))
+        _, candidate_idxs = alignment_metrics.topk(topk, dim=0, largest=True)
+        candidate_metrics = alignment_metrics[candidate_idxs,
+                                              torch.arange(num_gt)]
+        is_pos = candidate_metrics > 0
+
+        # limit the positive sample's center in gt
+        priors_cx = (priors[:, 0] + priors[:, 2]) / 2.0
+        priors_cy = (priors[:, 1] + priors[:, 3]) / 2.0
+        for gt_idx in range(num_gt):
+            candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
+        ep_priors_cx = priors_cx.view(1, -1).expand(
+            num_gt, num_bboxes).contiguous().view(-1)
+        ep_priors_cy = priors_cy.view(1, -1).expand(
+            num_gt, num_bboxes).contiguous().view(-1)
+        candidate_idxs = candidate_idxs.view(-1)
+
+        # calculate the left, top, right, bottom distance between positive
+        # bbox center and gt side
+        l_ = ep_priors_cx[candidate_idxs].view(-1, num_gt) -
gt_bboxes[:, 0] + t_ = ep_priors_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1] + r_ = gt_bboxes[:, 2] - ep_priors_cx[candidate_idxs].view(-1, num_gt) + b_ = gt_bboxes[:, 3] - ep_priors_cy[candidate_idxs].view(-1, num_gt) + is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01 + is_pos = is_pos & is_in_gts + + # if an anchor box is assigned to multiple gts, + # the one with the highest iou will be selected. + overlaps_inf = torch.full_like(overlaps, + -INF).t().contiguous().view(-1) + index = candidate_idxs.view(-1)[is_pos.view(-1)] + overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index] + overlaps_inf = overlaps_inf.view(num_gt, -1).t() + + max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1) + assigned_gt_inds[ + max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1 + assign_metrics[max_overlaps != -INF] = alignment_metrics[ + max_overlaps != -INF, argmax_overlaps[max_overlaps != -INF]] + + assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1) + pos_inds = torch.nonzero( + assigned_gt_inds > 0, as_tuple=False).squeeze() + if pos_inds.numel() > 0: + assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] - + 1] + assign_result = AssignResult( + num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) + assign_result.assign_metrics = assign_metrics + return assign_result diff --git a/mmdet/models/task_modules/assigners/topk_hungarian_assigner.py b/mmdet/models/task_modules/assigners/topk_hungarian_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..e48f092ac1ae99eadfdf7502b591b57c782e6354 --- /dev/null +++ b/mmdet/models/task_modules/assigners/topk_hungarian_assigner.py @@ -0,0 +1,182 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.structures import BaseDataElement +from scipy.optimize import linear_sum_assignment + +from mmdet.registry import TASK_UTILS +from .assign_result import AssignResult +from .task_aligned_assigner import TaskAlignedAssigner + + +@TASK_UTILS.register_module() +class TopkHungarianAssigner(TaskAlignedAssigner): + """Computes 1-to-k matching between ground truth and predictions. + + This class computes an assignment between the targets and the predictions + based on the costs. The costs are weighted sum of some components. + For DETR the costs are weighted sum of classification cost, regression L1 + cost and regression iou cost. The targets don't include the no_object, so + generally there are more predictions than targets. After the 1-to-k + gt-pred matching, the un-matched are treated as backgrounds. Thus each + query prediction will be assigned with `0` or a positive integer + indicating the ground truth index: + + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + + Args: + cls_cost (dict): Classification cost configuration. + reg_cost (dict): Regression L1 cost configuration. + iou_cost (dict): Regression iou cost configuration. 
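+
+    Example:
+        >>> # An illustrative construction sketch; the cost configs below
+        >>> # simply restate this class's defaults, and ``topk`` comes
+        >>> # from the parent TaskAlignedAssigner.
+        >>> assigner = TopkHungarianAssigner(
+        ...     topk=4,
+        ...     cls_cost=dict(type='FocalLossCost', weight=2.0),
+        ...     reg_cost=dict(type='BBoxL1Cost', weight=5.0),
+        ...     iou_cost=dict(type='IoUCost', iou_mode='giou', weight=2.0))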
+ """ + + def __init__(self, + *args, + cls_cost=dict(type='FocalLossCost', weight=2.0), + reg_cost=dict(type='BBoxL1Cost', weight=5.0), + iou_cost=dict(type='IoUCost', iou_mode='giou', weight=2.0), + **kwargs): + super(TopkHungarianAssigner, self).__init__(*args, **kwargs) + + self.cls_cost = TASK_UTILS.build(cls_cost) + self.reg_cost = TASK_UTILS.build(reg_cost) + self.iou_cost = TASK_UTILS.build(iou_cost) + + def assign(self, + pred_scores, + decode_bboxes, + gt_bboxes, + gt_labels, + img_meta, + alpha=1, + beta=6, + **kwargs): + """Computes 1-to-k gt-pred matching based on the weighted costs. + + This method assign each query prediction to a ground truth or + background. The `assigned_gt_inds` with -1 means don't care, + 0 means negative sample, and positive number is the index (1-based) + of assigned gt. + The assignment is done in the following steps, the order matters. + + 1. Assign every prediction to -1. + 2. Compute the weighted costs, each cost has shape (num_pred, num_gt). + 3. Update topk to be min(topk, int(num_pred / num_gt)), then repeat + costs topk times to shape: (num_pred, num_gt * topk), so that each + gt will match topk predictions. + 3. Do Hungarian matching on CPU based on the costs. + 4. Assign all to 0 (background) first, then for each matched pair + between predictions and gts, treat this prediction as foreground + and assign the corresponding gt index (plus 1) to it. + 5. Calculate alignment metrics and overlaps of each matched pred-gt + pair. + + Args: + pred_scores (Tensor): Predicted normalized classification + scores for one image, has shape (num_dense_queries, + cls_out_channels). + decode_bboxes (Tensor): Predicted unnormalized bbox coordinates + for one image, has shape (num_dense_queries, 4) with the + last dimension arranged as (x1, y1, x2, y2). + gt_bboxes (Tensor): Unnormalized ground truth + bboxes for one image, has shape (num_gt, 4) with the + last dimension arranged as (x1, y1, x2, y2). + NOTE: num_gt is dynamic for each image. + gt_labels (Tensor): Ground truth classification + index for the image, has shape (num_gt,). + NOTE: num_gt is dynamic for each image. + img_meta (dict): Meta information for one image. + alpha (int): Hyper-parameters related to alignment_metrics. + Defaults to 1. + beta (int): Hyper-parameters related to alignment_metrics. + Defaults to 6. + + Returns: + :obj:`AssignResult`: The assigned result. + """ + pred_scores = pred_scores.detach() + decode_bboxes = decode_bboxes.detach() + temp_overlaps = self.iou_calculator(decode_bboxes, gt_bboxes).detach() + bbox_scores = pred_scores[:, gt_labels].detach() + alignment_metrics = bbox_scores**alpha * temp_overlaps**beta + + pred_instances = BaseDataElement() + gt_instances = BaseDataElement() + + pred_instances.bboxes = decode_bboxes + gt_instances.bboxes = gt_bboxes + + pred_instances.scores = pred_scores + gt_instances.labels = gt_labels + + reg_cost = self.reg_cost(pred_instances, gt_instances, img_meta) + iou_cost = self.iou_cost(pred_instances, gt_instances, img_meta) + cls_cost = self.cls_cost(pred_instances, gt_instances, img_meta) + all_cost = cls_cost + reg_cost + iou_cost + + num_gt, num_bboxes = gt_bboxes.size(0), pred_scores.size(0) + if num_gt > 0: + # assign 0 by default + assigned_gt_inds = pred_scores.new_full((num_bboxes, ), + 0, + dtype=torch.long) + select_cost = all_cost + + topk = min(self.topk, int(len(select_cost) / num_gt)) + + # Repeat the ground truth `topk` times to perform 1-to-k gt-pred + # matching. 
For example, if `num_pred` = 900, `num_gt` = 3, then + # there are only 3 gt-pred pairs in sum for 1-1 matching. + # However, for 1-k gt-pred matching, if `topk` = 4, then each + # gt is assigned 4 unique predictions, so there would be 12 + # gt-pred pairs in sum. + repeat_select_cost = select_cost[..., + None].repeat(1, 1, topk).view( + select_cost.size(0), -1) + # anchor index and gt index + matched_row_inds, matched_col_inds = linear_sum_assignment( + repeat_select_cost.detach().cpu().numpy()) + matched_row_inds = torch.from_numpy(matched_row_inds).to( + pred_scores.device) + matched_col_inds = torch.from_numpy(matched_col_inds).to( + pred_scores.device) + + match_gt_ids = matched_col_inds // topk + candidate_idxs = matched_row_inds + + assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1) + + if candidate_idxs.numel() > 0: + assigned_labels[candidate_idxs] = gt_labels[match_gt_ids] + else: + assigned_labels = None + + assigned_gt_inds[candidate_idxs] = match_gt_ids + 1 + + overlaps = self.iou_calculator( + decode_bboxes[candidate_idxs], + gt_bboxes[match_gt_ids], + is_aligned=True).detach() + + temp_pos_alignment_metrics = alignment_metrics[candidate_idxs] + pos_alignment_metrics = torch.gather(temp_pos_alignment_metrics, 1, + match_gt_ids[:, + None]).view(-1) + assign_result = AssignResult( + num_gt, assigned_gt_inds, overlaps, labels=assigned_labels) + + assign_result.assign_metrics = pos_alignment_metrics + return assign_result + else: + + assigned_gt_inds = pred_scores.new_full((num_bboxes, ), + -1, + dtype=torch.long) + + assigned_labels = pred_scores.new_full((num_bboxes, ), + -1, + dtype=torch.long) + + assigned_gt_inds[:] = 0 + return AssignResult( + 0, assigned_gt_inds, None, labels=assigned_labels) diff --git a/mmdet/models/task_modules/assigners/uniform_assigner.py b/mmdet/models/task_modules/assigners/uniform_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..9a83bfd0b46a3690dce9cf0adf2c1e676f304d06 --- /dev/null +++ b/mmdet/models/task_modules/assigners/uniform_assigner.py @@ -0,0 +1,173 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +from mmengine.structures import InstanceData + +from mmdet.registry import TASK_UTILS +from mmdet.structures.bbox import bbox_xyxy_to_cxcywh +from mmdet.utils import ConfigType +from .assign_result import AssignResult +from .base_assigner import BaseAssigner + + +@TASK_UTILS.register_module() +class UniformAssigner(BaseAssigner): + """Uniform Matching between the priors and gt boxes, which can achieve + balance in positive priors, and gt_bboxes_ignore was not considered for + now. + + Args: + pos_ignore_thr (float): the threshold to ignore positive priors + neg_ignore_thr (float): the threshold to ignore negative priors + match_times(int): Number of positive priors for each gt box. + Defaults to 4. + iou_calculator (:obj:`ConfigDict` or dict): Config dict for iou + calculator. Defaults to ``dict(type='BboxOverlaps2D')`` + """ + + def __init__(self, + pos_ignore_thr: float, + neg_ignore_thr: float, + match_times: int = 4, + iou_calculator: ConfigType = dict(type='BboxOverlaps2D')): + self.match_times = match_times + self.pos_ignore_thr = pos_ignore_thr + self.neg_ignore_thr = neg_ignore_thr + self.iou_calculator = TASK_UTILS.build(iou_calculator) + + def assign( + self, + pred_instances: InstanceData, + gt_instances: InstanceData, + gt_instances_ignore: Optional[InstanceData] = None + ) -> AssignResult: + """Assign gt to priors. 
+
+        The assignment is done in the following steps
+
+        1. assign 0 (background) by default
+        2. compute the L1 cost between boxes, using both the priors and the
+           predicted boxes
+        3. compute the ignore indexes, using the gt_bboxes and the predicted
+           boxes
+        4. compute the ignore indexes of the positive samples, using the
+           priors and the predicted boxes
+
+
+        Args:
+            pred_instances (:obj:`InstanceData`): Instances of model
+                predictions. It includes ``priors``, and the priors can
+                be anchors, points, or bboxes predicted by the model,
+                shape (n, 4).
+            gt_instances (:obj:`InstanceData`): Ground truth of instance
+                annotations. It usually includes ``bboxes`` and ``labels``
+                attributes.
+            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
+                to be ignored during training. It includes ``bboxes``
+                attribute data that is ignored during training and testing.
+                Defaults to None.
+
+        Returns:
+            :obj:`AssignResult`: The assign result.
+        """
+
+        gt_bboxes = gt_instances.bboxes
+        gt_labels = gt_instances.labels
+        priors = pred_instances.priors
+        bbox_pred = pred_instances.decoder_priors
+
+        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)
+
+        # 1. assign 0 (background) by default
+        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
+                                              0,
+                                              dtype=torch.long)
+        assigned_labels = bbox_pred.new_full((num_bboxes, ),
+                                             -1,
+                                             dtype=torch.long)
+        if num_gts == 0 or num_bboxes == 0:
+            # No ground truth or boxes, return empty assignment
+            if num_gts == 0:
+                # No ground truth, assign all to background
+                assigned_gt_inds[:] = 0
+            assign_result = AssignResult(
+                num_gts, assigned_gt_inds, None, labels=assigned_labels)
+            assign_result.set_extra_property(
+                'pos_idx', bbox_pred.new_empty(0, dtype=torch.bool))
+            assign_result.set_extra_property('pos_predicted_boxes',
+                                             bbox_pred.new_empty((0, 4)))
+            assign_result.set_extra_property('target_boxes',
+                                             bbox_pred.new_empty((0, 4)))
+            return assign_result
+
+        # 2. Compute the L1 cost between boxes.
+        # Note that we use both the priors and the predicted boxes.
+        cost_bbox = torch.cdist(
+            bbox_xyxy_to_cxcywh(bbox_pred),
+            bbox_xyxy_to_cxcywh(gt_bboxes),
+            p=1)
+        cost_bbox_priors = torch.cdist(
+            bbox_xyxy_to_cxcywh(priors), bbox_xyxy_to_cxcywh(gt_bboxes), p=1)
+
+        # We found that the topk function gives different results on cpu
+        # and cuda. In order to ensure consistency with the source code,
+        # we also use cpu mode.
+        # TODO: Check whether the performance of cpu and cuda is the same.
+        C = cost_bbox.cpu()
+        C1 = cost_bbox_priors.cpu()
+
+        # self.match_times x n
+        index = torch.topk(
+            C,  # c=b,n,x c[i]=n,x
+            k=self.match_times,
+            dim=0,
+            largest=False)[1]
+
+        # self.match_times x n
+        index1 = torch.topk(C1, k=self.match_times, dim=0, largest=False)[1]
+        # (self.match_times*2) x n
+        indexes = torch.cat((index, index1),
+                            dim=1).reshape(-1).to(bbox_pred.device)
+
+        pred_overlaps = self.iou_calculator(bbox_pred, gt_bboxes)
+        anchor_overlaps = self.iou_calculator(priors, gt_bboxes)
+        pred_max_overlaps, _ = pred_overlaps.max(dim=1)
+        anchor_max_overlaps, _ = anchor_overlaps.max(dim=0)
+
+        # 3. Compute the ignore indexes, using gt_bboxes and predicted boxes
+        ignore_idx = pred_max_overlaps > self.neg_ignore_thr
+        assigned_gt_inds[ignore_idx] = -1
+
+        # 4.
Compute the ignore indexes of the positive samples, using the
+        # priors and the predicted boxes
+        pos_gt_index = torch.arange(
+            0, C1.size(1),
+            device=bbox_pred.device).repeat(self.match_times * 2)
+        pos_ious = anchor_overlaps[indexes, pos_gt_index]
+        pos_ignore_idx = pos_ious < self.pos_ignore_thr
+
+        pos_gt_index_with_ignore = pos_gt_index + 1
+        pos_gt_index_with_ignore[pos_ignore_idx] = -1
+        assigned_gt_inds[indexes] = pos_gt_index_with_ignore
+
+        if gt_labels is not None:
+            assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
+            pos_inds = torch.nonzero(
+                assigned_gt_inds > 0, as_tuple=False).squeeze()
+            if pos_inds.numel() > 0:
+                assigned_labels[pos_inds] = gt_labels[
+                    assigned_gt_inds[pos_inds] - 1]
+        else:
+            assigned_labels = None
+
+        assign_result = AssignResult(
+            num_gts,
+            assigned_gt_inds,
+            anchor_max_overlaps,
+            labels=assigned_labels)
+        assign_result.set_extra_property('pos_idx', ~pos_ignore_idx)
+        assign_result.set_extra_property('pos_predicted_boxes',
+                                         bbox_pred[indexes])
+        assign_result.set_extra_property('target_boxes',
+                                         gt_bboxes[pos_gt_index])
+        return assign_result
diff --git a/mmdet/models/task_modules/builder.py b/mmdet/models/task_modules/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6736049fef688e0d663d6195c79ec9688dc4c5d7
--- /dev/null
+++ b/mmdet/models/task_modules/builder.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+from mmdet.registry import TASK_UTILS
+
+PRIOR_GENERATORS = TASK_UTILS
+ANCHOR_GENERATORS = TASK_UTILS
+BBOX_ASSIGNERS = TASK_UTILS
+BBOX_SAMPLERS = TASK_UTILS
+BBOX_CODERS = TASK_UTILS
+MATCH_COSTS = TASK_UTILS
+IOU_CALCULATORS = TASK_UTILS
+
+
+def build_bbox_coder(cfg, **default_args):
+    """Builder of box coder."""
+    warnings.warn('``build_bbox_coder`` will be deprecated soon, please use '
+                  '``mmdet.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
+
+
+def build_iou_calculator(cfg, default_args=None):
+    """Builder of IoU calculator."""
+    warnings.warn(
+        '``build_iou_calculator`` will be deprecated soon, please use '
+        '``mmdet.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
+
+
+def build_match_cost(cfg, default_args=None):
+    """Builder of matching cost."""
+    warnings.warn('``build_match_cost`` will be deprecated soon, please use '
+                  '``mmdet.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
+
+
+def build_assigner(cfg, **default_args):
+    """Builder of box assigner."""
+    warnings.warn('``build_assigner`` will be deprecated soon, please use '
+                  '``mmdet.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
+
+
+def build_sampler(cfg, **default_args):
+    """Builder of box sampler."""
+    warnings.warn('``build_sampler`` will be deprecated soon, please use '
+                  '``mmdet.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
+
+
+def build_prior_generator(cfg, default_args=None):
+    warnings.warn(
+        '``build_prior_generator`` will be deprecated soon, please use '
+        '``mmdet.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
+
+
+def build_anchor_generator(cfg, default_args=None):
+    warnings.warn(
+        '``build_anchor_generator`` will be deprecated soon, please use '
+        '``mmdet.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
diff --git a/mmdet/models/task_modules/coders/__init__.py
b/mmdet/models/task_modules/coders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..97c3982140021958dabdd03f8040519f946250ff
--- /dev/null
+++ b/mmdet/models/task_modules/coders/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_bbox_coder import BaseBBoxCoder
+from .bucketing_bbox_coder import BucketingBBoxCoder
+from .delta_xywh_bbox_coder import (DeltaXYWHBBoxCoder,
+                                    DeltaXYWHBBoxCoderForGLIP)
+from .distance_point_bbox_coder import DistancePointBBoxCoder
+from .legacy_delta_xywh_bbox_coder import LegacyDeltaXYWHBBoxCoder
+from .pseudo_bbox_coder import PseudoBBoxCoder
+from .tblr_bbox_coder import TBLRBBoxCoder
+from .yolo_bbox_coder import YOLOBBoxCoder
+
+__all__ = [
+    'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder',
+    'LegacyDeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'YOLOBBoxCoder',
+    'BucketingBBoxCoder', 'DistancePointBBoxCoder', 'DeltaXYWHBBoxCoderForGLIP'
+]
diff --git a/mmdet/models/task_modules/coders/base_bbox_coder.py b/mmdet/models/task_modules/coders/base_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..806d2651869e02173578c9eb331758743a068dd9
--- /dev/null
+++ b/mmdet/models/task_modules/coders/base_bbox_coder.py
@@ -0,0 +1,26 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+
+class BaseBBoxCoder(metaclass=ABCMeta):
+    """Base bounding box coder.
+
+    Args:
+        use_box_type (bool): Whether to wrap decoded boxes with the
+            box type data structure. Defaults to False.
+    """
+
+    # The size of the last dimension of the encoded tensor.
+    encode_size = 4
+
+    def __init__(self, use_box_type: bool = False, **kwargs):
+        self.use_box_type = use_box_type
+
+    @abstractmethod
+    def encode(self, bboxes, gt_bboxes):
+        """Encode deltas between bboxes and ground truth boxes."""
+
+    @abstractmethod
+    def decode(self, bboxes, bboxes_pred):
+        """Decode the predicted bboxes according to prediction and base
+        boxes."""
diff --git a/mmdet/models/task_modules/coders/bucketing_bbox_coder.py b/mmdet/models/task_modules/coders/bucketing_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4044e1cd91d619521606f3c03032a40a9fc27130
--- /dev/null
+++ b/mmdet/models/task_modules/coders/bucketing_bbox_coder.py
@@ -0,0 +1,366 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from mmdet.registry import TASK_UTILS
+from mmdet.structures.bbox import (BaseBoxes, HorizontalBoxes, bbox_rescale,
+                                   get_box_tensor)
+from .base_bbox_coder import BaseBBoxCoder
+
+
+@TASK_UTILS.register_module()
+class BucketingBBoxCoder(BaseBBoxCoder):
+    """Bucketing BBox Coder for Side-Aware Boundary Localization (SABL).
+
+    Boundary Localization with Bucketing and Bucketing Guided Rescoring
+    are implemented here.
+
+    Please refer to https://arxiv.org/abs/1912.04260 for more details.
+
+    Args:
+        num_buckets (int): Number of buckets.
+        scale_factor (int): Scale factor of proposals to generate buckets.
+        offset_topk (int): Topk buckets are used to generate
+            bucket fine regression targets. Defaults to 2.
+        offset_upperbound (float): Offset upperbound used to generate
+            bucket fine regression targets, to avoid too large offset
+            displacements. Defaults to 1.0.
+        cls_ignore_neighbor (bool): Whether to ignore the second-nearest
+            bucket. Defaults to True.
+ clip_border (bool, optional): Whether clip the objects outside the + border of the image. Defaults to True. + """ + + def __init__(self, + num_buckets: int, + scale_factor: int, + offset_topk: int = 2, + offset_upperbound: float = 1.0, + cls_ignore_neighbor: bool = True, + clip_border: bool = True, + **kwargs) -> None: + super().__init__(**kwargs) + self.num_buckets = num_buckets + self.scale_factor = scale_factor + self.offset_topk = offset_topk + self.offset_upperbound = offset_upperbound + self.cls_ignore_neighbor = cls_ignore_neighbor + self.clip_border = clip_border + + def encode(self, bboxes: Union[Tensor, BaseBoxes], + gt_bboxes: Union[Tensor, BaseBoxes]) -> Tuple[Tensor]: + """Get bucketing estimation and fine regression targets during + training. + + Args: + bboxes (torch.Tensor or :obj:`BaseBoxes`): source boxes, + e.g., object proposals. + gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): target of the + transformation, e.g., ground truth boxes. + + Returns: + encoded_bboxes(tuple[Tensor]): bucketing estimation + and fine regression targets and weights + """ + bboxes = get_box_tensor(bboxes) + gt_bboxes = get_box_tensor(gt_bboxes) + assert bboxes.size(0) == gt_bboxes.size(0) + assert bboxes.size(-1) == gt_bboxes.size(-1) == 4 + encoded_bboxes = bbox2bucket(bboxes, gt_bboxes, self.num_buckets, + self.scale_factor, self.offset_topk, + self.offset_upperbound, + self.cls_ignore_neighbor) + return encoded_bboxes + + def decode( + self, + bboxes: Union[Tensor, BaseBoxes], + pred_bboxes: Tensor, + max_shape: Optional[Tuple[int]] = None + ) -> Tuple[Union[Tensor, BaseBoxes], Tensor]: + """Apply transformation `pred_bboxes` to `boxes`. + Args: + boxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes. + pred_bboxes (torch.Tensor): Predictions for bucketing estimation + and fine regression + max_shape (tuple[int], optional): Maximum shape of boxes. + Defaults to None. + + Returns: + Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes. + """ + bboxes = get_box_tensor(bboxes) + assert len(pred_bboxes) == 2 + cls_preds, offset_preds = pred_bboxes + assert cls_preds.size(0) == bboxes.size(0) and offset_preds.size( + 0) == bboxes.size(0) + bboxes, loc_confidence = bucket2bbox(bboxes, cls_preds, offset_preds, + self.num_buckets, + self.scale_factor, max_shape, + self.clip_border) + if self.use_box_type: + bboxes = HorizontalBoxes(bboxes, clone=False) + return bboxes, loc_confidence + + +def generat_buckets(proposals: Tensor, + num_buckets: int, + scale_factor: float = 1.0) -> Tuple[Tensor]: + """Generate buckets w.r.t bucket number and scale factor of proposals. + + Args: + proposals (Tensor): Shape (n, 4) + num_buckets (int): Number of buckets. + scale_factor (float): Scale factor to rescale proposals. + + Returns: + tuple[Tensor]: (bucket_w, bucket_h, l_buckets, r_buckets, + t_buckets, d_buckets) + + - bucket_w: Width of buckets on x-axis. Shape (n, ). + - bucket_h: Height of buckets on y-axis. Shape (n, ). + - l_buckets: Left buckets. Shape (n, ceil(side_num/2)). + - r_buckets: Right buckets. Shape (n, ceil(side_num/2)). + - t_buckets: Top buckets. Shape (n, ceil(side_num/2)). + - d_buckets: Down buckets. Shape (n, ceil(side_num/2)). 
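+
+    Example:
+        >>> # A minimal illustration (hand-checked numbers); scale_factor
+        >>> # defaults to 1.0, which leaves the proposal unchanged.
+        >>> proposals = torch.Tensor([[0., 0., 8., 8.]])
+        >>> bucket_w, _, l_buckets, _, _, _ = generat_buckets(proposals, 4)
+        >>> bucket_w  # an 8-pixel side split into 4 buckets
+        tensor([2.])
+        >>> l_buckets  # centers of the two buckets nearest the left edge
+        tensor([[1., 3.]])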
+ """ + proposals = bbox_rescale(proposals, scale_factor) + + # number of buckets in each side + side_num = int(np.ceil(num_buckets / 2.0)) + pw = proposals[..., 2] - proposals[..., 0] + ph = proposals[..., 3] - proposals[..., 1] + px1 = proposals[..., 0] + py1 = proposals[..., 1] + px2 = proposals[..., 2] + py2 = proposals[..., 3] + + bucket_w = pw / num_buckets + bucket_h = ph / num_buckets + + # left buckets + l_buckets = px1[:, None] + (0.5 + torch.arange( + 0, side_num).to(proposals).float())[None, :] * bucket_w[:, None] + # right buckets + r_buckets = px2[:, None] - (0.5 + torch.arange( + 0, side_num).to(proposals).float())[None, :] * bucket_w[:, None] + # top buckets + t_buckets = py1[:, None] + (0.5 + torch.arange( + 0, side_num).to(proposals).float())[None, :] * bucket_h[:, None] + # down buckets + d_buckets = py2[:, None] - (0.5 + torch.arange( + 0, side_num).to(proposals).float())[None, :] * bucket_h[:, None] + return bucket_w, bucket_h, l_buckets, r_buckets, t_buckets, d_buckets + + +def bbox2bucket(proposals: Tensor, + gt: Tensor, + num_buckets: int, + scale_factor: float, + offset_topk: int = 2, + offset_upperbound: float = 1.0, + cls_ignore_neighbor: bool = True) -> Tuple[Tensor]: + """Generate buckets estimation and fine regression targets. + + Args: + proposals (Tensor): Shape (n, 4) + gt (Tensor): Shape (n, 4) + num_buckets (int): Number of buckets. + scale_factor (float): Scale factor to rescale proposals. + offset_topk (int): Topk buckets are used to generate + bucket fine regression targets. Defaults to 2. + offset_upperbound (float): Offset allowance to generate + bucket fine regression targets. + To avoid too large offset displacements. Defaults to 1.0. + cls_ignore_neighbor (bool): Ignore second nearest bucket or Not. + Defaults to True. + + Returns: + tuple[Tensor]: (offsets, offsets_weights, bucket_labels, cls_weights). + + - offsets: Fine regression targets. \ + Shape (n, num_buckets*2). + - offsets_weights: Fine regression weights. \ + Shape (n, num_buckets*2). + - bucket_labels: Bucketing estimation labels. \ + Shape (n, num_buckets*2). + - cls_weights: Bucketing estimation weights. \ + Shape (n, num_buckets*2). 
+ """ + assert proposals.size() == gt.size() + + # generate buckets + proposals = proposals.float() + gt = gt.float() + (bucket_w, bucket_h, l_buckets, r_buckets, t_buckets, + d_buckets) = generat_buckets(proposals, num_buckets, scale_factor) + + gx1 = gt[..., 0] + gy1 = gt[..., 1] + gx2 = gt[..., 2] + gy2 = gt[..., 3] + + # generate offset targets and weights + # offsets from buckets to gts + l_offsets = (l_buckets - gx1[:, None]) / bucket_w[:, None] + r_offsets = (r_buckets - gx2[:, None]) / bucket_w[:, None] + t_offsets = (t_buckets - gy1[:, None]) / bucket_h[:, None] + d_offsets = (d_buckets - gy2[:, None]) / bucket_h[:, None] + + # select top-k nearest buckets + l_topk, l_label = l_offsets.abs().topk( + offset_topk, dim=1, largest=False, sorted=True) + r_topk, r_label = r_offsets.abs().topk( + offset_topk, dim=1, largest=False, sorted=True) + t_topk, t_label = t_offsets.abs().topk( + offset_topk, dim=1, largest=False, sorted=True) + d_topk, d_label = d_offsets.abs().topk( + offset_topk, dim=1, largest=False, sorted=True) + + offset_l_weights = l_offsets.new_zeros(l_offsets.size()) + offset_r_weights = r_offsets.new_zeros(r_offsets.size()) + offset_t_weights = t_offsets.new_zeros(t_offsets.size()) + offset_d_weights = d_offsets.new_zeros(d_offsets.size()) + inds = torch.arange(0, proposals.size(0)).to(proposals).long() + + # generate offset weights of top-k nearest buckets + for k in range(offset_topk): + if k >= 1: + offset_l_weights[inds, l_label[:, + k]] = (l_topk[:, k] < + offset_upperbound).float() + offset_r_weights[inds, r_label[:, + k]] = (r_topk[:, k] < + offset_upperbound).float() + offset_t_weights[inds, t_label[:, + k]] = (t_topk[:, k] < + offset_upperbound).float() + offset_d_weights[inds, d_label[:, + k]] = (d_topk[:, k] < + offset_upperbound).float() + else: + offset_l_weights[inds, l_label[:, k]] = 1.0 + offset_r_weights[inds, r_label[:, k]] = 1.0 + offset_t_weights[inds, t_label[:, k]] = 1.0 + offset_d_weights[inds, d_label[:, k]] = 1.0 + + offsets = torch.cat([l_offsets, r_offsets, t_offsets, d_offsets], dim=-1) + offsets_weights = torch.cat([ + offset_l_weights, offset_r_weights, offset_t_weights, offset_d_weights + ], + dim=-1) + + # generate bucket labels and weight + side_num = int(np.ceil(num_buckets / 2.0)) + labels = torch.stack( + [l_label[:, 0], r_label[:, 0], t_label[:, 0], d_label[:, 0]], dim=-1) + + batch_size = labels.size(0) + bucket_labels = F.one_hot(labels.view(-1), side_num).view(batch_size, + -1).float() + bucket_cls_l_weights = (l_offsets.abs() < 1).float() + bucket_cls_r_weights = (r_offsets.abs() < 1).float() + bucket_cls_t_weights = (t_offsets.abs() < 1).float() + bucket_cls_d_weights = (d_offsets.abs() < 1).float() + bucket_cls_weights = torch.cat([ + bucket_cls_l_weights, bucket_cls_r_weights, bucket_cls_t_weights, + bucket_cls_d_weights + ], + dim=-1) + # ignore second nearest buckets for cls if necessary + if cls_ignore_neighbor: + bucket_cls_weights = (~((bucket_cls_weights == 1) & + (bucket_labels == 0))).float() + else: + bucket_cls_weights[:] = 1.0 + return offsets, offsets_weights, bucket_labels, bucket_cls_weights + + +def bucket2bbox(proposals: Tensor, + cls_preds: Tensor, + offset_preds: Tensor, + num_buckets: int, + scale_factor: float = 1.0, + max_shape: Optional[Union[Sequence[int], Tensor, + Sequence[Sequence[int]]]] = None, + clip_border: bool = True) -> Tuple[Tensor]: + """Apply bucketing estimation (cls preds) and fine regression (offset + preds) to generate det bboxes. + + Args: + proposals (Tensor): Boxes to be transformed. 
Shape (n, 4) + cls_preds (Tensor): bucketing estimation. Shape (n, num_buckets*2). + offset_preds (Tensor): fine regression. Shape (n, num_buckets*2). + num_buckets (int): Number of buckets. + scale_factor (float): Scale factor to rescale proposals. + max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W) + clip_border (bool, optional): Whether clip the objects outside the + border of the image. Defaults to True. + + Returns: + tuple[Tensor]: (bboxes, loc_confidence). + + - bboxes: predicted bboxes. Shape (n, 4) + - loc_confidence: localization confidence of predicted bboxes. + Shape (n,). + """ + + side_num = int(np.ceil(num_buckets / 2.0)) + cls_preds = cls_preds.view(-1, side_num) + offset_preds = offset_preds.view(-1, side_num) + + scores = F.softmax(cls_preds, dim=1) + score_topk, score_label = scores.topk(2, dim=1, largest=True, sorted=True) + + rescaled_proposals = bbox_rescale(proposals, scale_factor) + + pw = rescaled_proposals[..., 2] - rescaled_proposals[..., 0] + ph = rescaled_proposals[..., 3] - rescaled_proposals[..., 1] + px1 = rescaled_proposals[..., 0] + py1 = rescaled_proposals[..., 1] + px2 = rescaled_proposals[..., 2] + py2 = rescaled_proposals[..., 3] + + bucket_w = pw / num_buckets + bucket_h = ph / num_buckets + + score_inds_l = score_label[0::4, 0] + score_inds_r = score_label[1::4, 0] + score_inds_t = score_label[2::4, 0] + score_inds_d = score_label[3::4, 0] + l_buckets = px1 + (0.5 + score_inds_l.float()) * bucket_w + r_buckets = px2 - (0.5 + score_inds_r.float()) * bucket_w + t_buckets = py1 + (0.5 + score_inds_t.float()) * bucket_h + d_buckets = py2 - (0.5 + score_inds_d.float()) * bucket_h + + offsets = offset_preds.view(-1, 4, side_num) + inds = torch.arange(proposals.size(0)).to(proposals).long() + l_offsets = offsets[:, 0, :][inds, score_inds_l] + r_offsets = offsets[:, 1, :][inds, score_inds_r] + t_offsets = offsets[:, 2, :][inds, score_inds_t] + d_offsets = offsets[:, 3, :][inds, score_inds_d] + + x1 = l_buckets - l_offsets * bucket_w + x2 = r_buckets - r_offsets * bucket_w + y1 = t_buckets - t_offsets * bucket_h + y2 = d_buckets - d_offsets * bucket_h + + if clip_border and max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1] - 1) + y1 = y1.clamp(min=0, max=max_shape[0] - 1) + x2 = x2.clamp(min=0, max=max_shape[1] - 1) + y2 = y2.clamp(min=0, max=max_shape[0] - 1) + bboxes = torch.cat([x1[:, None], y1[:, None], x2[:, None], y2[:, None]], + dim=-1) + + # bucketing guided rescoring + loc_confidence = score_topk[:, 0] + top2_neighbor_inds = (score_label[:, 0] - score_label[:, 1]).abs() == 1 + loc_confidence += score_topk[:, 1] * top2_neighbor_inds.float() + loc_confidence = loc_confidence.view(-1, 4).mean(dim=1) + + return bboxes, loc_confidence diff --git a/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py b/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py new file mode 100644 index 0000000000000000000000000000000000000000..c2b60b5ee791e05ce4f5f8d8e1876f7f61e964ed --- /dev/null +++ b/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py @@ -0,0 +1,579 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Optional, Sequence, Union + +import numpy as np +import torch +from torch import Tensor + +from mmdet.registry import TASK_UTILS +from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor +from .base_bbox_coder import BaseBBoxCoder + + +@TASK_UTILS.register_module() +class DeltaXYWHBBoxCoder(BaseBBoxCoder): + """Delta XYWH BBox coder. 
+
+    Following the practice in `R-CNN <https://arxiv.org/abs/1311.2524>`_,
+    this coder encodes bbox (x1, y1, x2, y2) into delta (dx, dy, dw, dh) and
+    decodes delta (dx, dy, dw, dh) back to original bbox (x1, y1, x2, y2).
+
+    Args:
+        target_means (Sequence[float]): Denormalizing means of target for
+            delta coordinates
+        target_stds (Sequence[float]): Denormalizing standard deviation of
+            target for delta coordinates
+        clip_border (bool, optional): Whether clip the objects outside the
+            border of the image. Defaults to True.
+        add_ctr_clamp (bool): Whether to add center clamp. When added, the
+            predicted box is clamped if its center is too far away from
+            the original anchor's center. Only used by YOLOF. Default False.
+        ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF.
+            Default 32.
+    """
+
+    def __init__(self,
+                 target_means: Sequence[float] = (0., 0., 0., 0.),
+                 target_stds: Sequence[float] = (1., 1., 1., 1.),
+                 clip_border: bool = True,
+                 add_ctr_clamp: bool = False,
+                 ctr_clamp: int = 32,
+                 **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.means = target_means
+        self.stds = target_stds
+        self.clip_border = clip_border
+        self.add_ctr_clamp = add_ctr_clamp
+        self.ctr_clamp = ctr_clamp
+
+    def encode(self, bboxes: Union[Tensor, BaseBoxes],
+               gt_bboxes: Union[Tensor, BaseBoxes]) -> Tensor:
+        """Get box regression transformation deltas that can be used to
+        transform the ``bboxes`` into the ``gt_bboxes``.
+
+        Args:
+            bboxes (torch.Tensor or :obj:`BaseBoxes`): Source boxes,
+                e.g., object proposals.
+            gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): Target of the
+                transformation, e.g., ground-truth boxes.
+
+        Returns:
+            torch.Tensor: Box transformation deltas
+        """
+        bboxes = get_box_tensor(bboxes)
+        gt_bboxes = get_box_tensor(gt_bboxes)
+        assert bboxes.size(0) == gt_bboxes.size(0)
+        assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
+        encoded_bboxes = bbox2delta(bboxes, gt_bboxes, self.means, self.stds)
+        return encoded_bboxes
+
+    def decode(
+        self,
+        bboxes: Union[Tensor, BaseBoxes],
+        pred_bboxes: Tensor,
+        max_shape: Optional[Union[Sequence[int], Tensor,
+                                  Sequence[Sequence[int]]]] = None,
+        wh_ratio_clip: Optional[float] = 16 / 1000
+    ) -> Union[Tensor, BaseBoxes]:
+        """Apply transformation `pred_bboxes` to `boxes`.
+
+        Args:
+            bboxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes. Shape
+                (B, N, 4) or (N, 4)
+            pred_bboxes (Tensor): Encoded offsets with respect to each roi.
+                Has shape (B, N, num_classes * 4) or (B, N, 4) or
+                (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
+                when rois is a grid of anchors. Offset encoding follows [1]_.
+            max_shape (Sequence[int] or torch.Tensor or Sequence[
+                Sequence[int]], optional): Maximum bounds for boxes, specifies
+                (H, W, C) or (H, W). If bboxes shape is (B, N, 4), then
+                the max_shape should be a Sequence[Sequence[int]]
+                and the length of max_shape should also be B.
+            wh_ratio_clip (float, optional): The allowed ratio between
+                width and height.
+
+        Returns:
+            Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes.
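+
+        Example:
+            >>> # Round-trip sketch with the default (identity) means/stds;
+            >>> # zero deltas decode back to the input box.
+            >>> coder = DeltaXYWHBBoxCoder()
+            >>> rois = torch.Tensor([[0., 0., 1., 1.]])
+            >>> deltas = torch.Tensor([[0., 0., 0., 0.]])
+            >>> coder.decode(rois, deltas)
+            tensor([[0., 0., 1., 1.]])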
+ """ + bboxes = get_box_tensor(bboxes) + assert pred_bboxes.size(0) == bboxes.size(0) + if pred_bboxes.ndim == 3: + assert pred_bboxes.size(1) == bboxes.size(1) + + if pred_bboxes.ndim == 2 and not torch.onnx.is_in_onnx_export(): + # single image decode + decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means, + self.stds, max_shape, wh_ratio_clip, + self.clip_border, self.add_ctr_clamp, + self.ctr_clamp) + else: + if pred_bboxes.ndim == 3 and not torch.onnx.is_in_onnx_export(): + warnings.warn( + 'DeprecationWarning: onnx_delta2bbox is deprecated ' + 'in the case of batch decoding and non-ONNX, ' + 'please use “delta2bbox” instead. In order to improve ' + 'the decoding speed, the batch function will no ' + 'longer be supported. ') + decoded_bboxes = onnx_delta2bbox(bboxes, pred_bboxes, self.means, + self.stds, max_shape, + wh_ratio_clip, self.clip_border, + self.add_ctr_clamp, + self.ctr_clamp) + + if self.use_box_type: + assert decoded_bboxes.size(-1) == 4, \ + ('Cannot warp decoded boxes with box type when decoded boxes' + 'have shape of (N, num_classes * 4)') + decoded_bboxes = HorizontalBoxes(decoded_bboxes) + return decoded_bboxes + + +@TASK_UTILS.register_module() +class DeltaXYWHBBoxCoderForGLIP(DeltaXYWHBBoxCoder): + """This is designed specifically for the GLIP algorithm. + + In order to completely match the official performance, we need to perform + special calculations in the encoding and decoding processes, such as + additional +1 and -1 calculations. However, this is not a user-friendly + design. + """ + + def encode(self, bboxes: Union[Tensor, BaseBoxes], + gt_bboxes: Union[Tensor, BaseBoxes]) -> Tensor: + """Get box regression transformation deltas that can be used to + transform the ``bboxes`` into the ``gt_bboxes``. + + Args: + bboxes (torch.Tensor or :obj:`BaseBoxes`): Source boxes, + e.g., object proposals. + gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): Target of the + transformation, e.g., ground-truth boxes. + + Returns: + torch.Tensor: Box transformation deltas + """ + bboxes = get_box_tensor(bboxes) + gt_bboxes = get_box_tensor(gt_bboxes) + assert bboxes.size(0) == gt_bboxes.size(0) + assert bboxes.size(-1) == gt_bboxes.size(-1) == 4 + encoded_bboxes = bbox2delta(bboxes, gt_bboxes, self.means, self.stds) + return encoded_bboxes + + def decode( + self, + bboxes: Union[Tensor, BaseBoxes], + pred_bboxes: Tensor, + max_shape: Optional[Union[Sequence[int], Tensor, + Sequence[Sequence[int]]]] = None, + wh_ratio_clip: Optional[float] = 16 / 1000 + ) -> Union[Tensor, BaseBoxes]: + """Apply transformation `pred_bboxes` to `boxes`. + + Args: + bboxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes. Shape + (B, N, 4) or (N, 4) + pred_bboxes (Tensor): Encoded offsets with respect to each roi. + Has shape (B, N, num_classes * 4) or (B, N, 4) or + (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H + when rois is a grid of anchors.Offset encoding follows [1]_. + max_shape (Sequence[int] or torch.Tensor or Sequence[ + Sequence[int]],optional): Maximum bounds for boxes, specifies + (H, W, C) or (H, W). If bboxes shape is (B, N, 4), then + the max_shape should be a Sequence[Sequence[int]] + and the length of max_shape should also be B. + wh_ratio_clip (float, optional): The allowed ratio between + width and height. + + Returns: + Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes. 
+ """ + bboxes = get_box_tensor(bboxes) + assert pred_bboxes.size(0) == bboxes.size(0) + if pred_bboxes.ndim == 3: + assert pred_bboxes.size(1) == bboxes.size(1) + + if pred_bboxes.ndim == 2 and not torch.onnx.is_in_onnx_export(): + # single image decode + decoded_bboxes = delta2bbox_glip(bboxes, pred_bboxes, self.means, + self.stds, max_shape, + wh_ratio_clip, self.clip_border, + self.add_ctr_clamp, + self.ctr_clamp) + else: + raise NotImplementedError() + + if self.use_box_type: + assert decoded_bboxes.size(-1) == 4, \ + ('Cannot warp decoded boxes with box type when decoded boxes' + 'have shape of (N, num_classes * 4)') + decoded_bboxes = HorizontalBoxes(decoded_bboxes) + return decoded_bboxes + + +def bbox2delta( + proposals: Tensor, + gt: Tensor, + means: Sequence[float] = (0., 0., 0., 0.), + stds: Sequence[float] = (1., 1., 1., 1.) +) -> Tensor: + """Compute deltas of proposals w.r.t. gt. + + We usually compute the deltas of x, y, w, h of proposals w.r.t ground + truth bboxes to get regression target. + This is the inverse function of :func:`delta2bbox`. + + Args: + proposals (Tensor): Boxes to be transformed, shape (N, ..., 4) + gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4) + means (Sequence[float]): Denormalizing means for delta coordinates + stds (Sequence[float]): Denormalizing standard deviation for delta + coordinates + + Returns: + Tensor: deltas with shape (N, 4), where columns represent dx, dy, + dw, dh. + """ + assert proposals.size() == gt.size() + + proposals = proposals.float() + gt = gt.float() + px = (proposals[..., 0] + proposals[..., 2]) * 0.5 + py = (proposals[..., 1] + proposals[..., 3]) * 0.5 + pw = proposals[..., 2] - proposals[..., 0] + ph = proposals[..., 3] - proposals[..., 1] + + gx = (gt[..., 0] + gt[..., 2]) * 0.5 + gy = (gt[..., 1] + gt[..., 3]) * 0.5 + gw = gt[..., 2] - gt[..., 0] + gh = gt[..., 3] - gt[..., 1] + + dx = (gx - px) / pw + dy = (gy - py) / ph + dw = torch.log(gw / pw) + dh = torch.log(gh / ph) + deltas = torch.stack([dx, dy, dw, dh], dim=-1) + + means = deltas.new_tensor(means).unsqueeze(0) + stds = deltas.new_tensor(stds).unsqueeze(0) + deltas = deltas.sub_(means).div_(stds) + + return deltas + + +def delta2bbox(rois: Tensor, + deltas: Tensor, + means: Sequence[float] = (0., 0., 0., 0.), + stds: Sequence[float] = (1., 1., 1., 1.), + max_shape: Optional[Union[Sequence[int], Tensor, + Sequence[Sequence[int]]]] = None, + wh_ratio_clip: float = 16 / 1000, + clip_border: bool = True, + add_ctr_clamp: bool = False, + ctr_clamp: int = 32) -> Tensor: + """Apply deltas to shift/scale base boxes. + + Typically the rois are anchor or proposed bounding boxes and the deltas are + network outputs used to shift/scale those boxes. + This is the inverse function of :func:`bbox2delta`. + + Args: + rois (Tensor): Boxes to be transformed. Has shape (N, 4). + deltas (Tensor): Encoded offsets relative to each roi. + Has shape (N, num_classes * 4) or (N, 4). Note + N = num_base_anchors * W * H, when rois is a grid of + anchors. Offset encoding follows [1]_. + means (Sequence[float]): Denormalizing means for delta coordinates. + Default (0., 0., 0., 0.). + stds (Sequence[float]): Denormalizing standard deviation for delta + coordinates. Default (1., 1., 1., 1.). + max_shape (tuple[int, int]): Maximum bounds for boxes, specifies + (H, W). Default None. + wh_ratio_clip (float): Maximum aspect ratio for boxes. Default + 16 / 1000. + clip_border (bool, optional): Whether clip the objects outside the + border of the image. Default True. 
+ add_ctr_clamp (bool): Whether to add center clamp. When set to True, + the center of the prediction bounding box will be clamped to + avoid being too far away from the center of the anchor. + Only used by YOLOF. Default False. + ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF. + Default 32. + + Returns: + Tensor: Boxes with shape (N, num_classes * 4) or (N, 4), where 4 + represent tl_x, tl_y, br_x, br_y. + + References: + .. [1] https://arxiv.org/abs/1311.2524 + + Example: + >>> rois = torch.Tensor([[ 0., 0., 1., 1.], + >>> [ 0., 0., 1., 1.], + >>> [ 0., 0., 1., 1.], + >>> [ 5., 5., 5., 5.]]) + >>> deltas = torch.Tensor([[ 0., 0., 0., 0.], + >>> [ 1., 1., 1., 1.], + >>> [ 0., 0., 2., -1.], + >>> [ 0.7, -1.9, -0.5, 0.3]]) + >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3)) + tensor([[0.0000, 0.0000, 1.0000, 1.0000], + [0.1409, 0.1409, 2.8591, 2.8591], + [0.0000, 0.3161, 4.1945, 0.6839], + [5.0000, 5.0000, 5.0000, 5.0000]]) + """ + num_bboxes, num_classes = deltas.size(0), deltas.size(1) // 4 + if num_bboxes == 0: + return deltas + + deltas = deltas.reshape(-1, 4) + + means = deltas.new_tensor(means).view(1, -1) + stds = deltas.new_tensor(stds).view(1, -1) + denorm_deltas = deltas * stds + means + + dxy = denorm_deltas[:, :2] + dwh = denorm_deltas[:, 2:] + + # Compute width/height of each roi + rois_ = rois.repeat(1, num_classes).reshape(-1, 4) + pxy = ((rois_[:, :2] + rois_[:, 2:]) * 0.5) + pwh = (rois_[:, 2:] - rois_[:, :2]) + + dxy_wh = pwh * dxy + + max_ratio = np.abs(np.log(wh_ratio_clip)) + if add_ctr_clamp: + dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp) + dwh = torch.clamp(dwh, max=max_ratio) + else: + dwh = dwh.clamp(min=-max_ratio, max=max_ratio) + + gxy = pxy + dxy_wh + gwh = pwh * dwh.exp() + x1y1 = gxy - (gwh * 0.5) + x2y2 = gxy + (gwh * 0.5) + bboxes = torch.cat([x1y1, x2y2], dim=-1) + if clip_border and max_shape is not None: + bboxes[..., 0::2].clamp_(min=0, max=max_shape[1]) + bboxes[..., 1::2].clamp_(min=0, max=max_shape[0]) + bboxes = bboxes.reshape(num_bboxes, -1) + return bboxes + + +def onnx_delta2bbox(rois: Tensor, + deltas: Tensor, + means: Sequence[float] = (0., 0., 0., 0.), + stds: Sequence[float] = (1., 1., 1., 1.), + max_shape: Optional[Union[Sequence[int], Tensor, + Sequence[Sequence[int]]]] = None, + wh_ratio_clip: float = 16 / 1000, + clip_border: Optional[bool] = True, + add_ctr_clamp: bool = False, + ctr_clamp: int = 32) -> Tensor: + """Apply deltas to shift/scale base boxes. + + Typically the rois are anchor or proposed bounding boxes and the deltas are + network outputs used to shift/scale those boxes. + This is the inverse function of :func:`bbox2delta`. + + Args: + rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4) + deltas (Tensor): Encoded offsets with respect to each roi. + Has shape (B, N, num_classes * 4) or (B, N, 4) or + (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H + when rois is a grid of anchors.Offset encoding follows [1]_. + means (Sequence[float]): Denormalizing means for delta coordinates. + Default (0., 0., 0., 0.). + stds (Sequence[float]): Denormalizing standard deviation for delta + coordinates. Default (1., 1., 1., 1.). + max_shape (Sequence[int] or torch.Tensor or Sequence[ + Sequence[int]],optional): Maximum bounds for boxes, specifies + (H, W, C) or (H, W). If rois shape is (B, N, 4), then + the max_shape should be a Sequence[Sequence[int]] + and the length of max_shape should also be B. Default None. + wh_ratio_clip (float): Maximum aspect ratio for boxes. 
+ Default 16 / 1000. + clip_border (bool, optional): Whether clip the objects outside the + border of the image. Default True. + add_ctr_clamp (bool): Whether to add center clamp, when added, the + predicted box is clamped is its center is too far away from + the original anchor's center. Only used by YOLOF. Default False. + ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF. + Default 32. + + Returns: + Tensor: Boxes with shape (B, N, num_classes * 4) or (B, N, 4) or + (N, num_classes * 4) or (N, 4), where 4 represent + tl_x, tl_y, br_x, br_y. + + References: + .. [1] https://arxiv.org/abs/1311.2524 + + Example: + >>> rois = torch.Tensor([[ 0., 0., 1., 1.], + >>> [ 0., 0., 1., 1.], + >>> [ 0., 0., 1., 1.], + >>> [ 5., 5., 5., 5.]]) + >>> deltas = torch.Tensor([[ 0., 0., 0., 0.], + >>> [ 1., 1., 1., 1.], + >>> [ 0., 0., 2., -1.], + >>> [ 0.7, -1.9, -0.5, 0.3]]) + >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3)) + tensor([[0.0000, 0.0000, 1.0000, 1.0000], + [0.1409, 0.1409, 2.8591, 2.8591], + [0.0000, 0.3161, 4.1945, 0.6839], + [5.0000, 5.0000, 5.0000, 5.0000]]) + """ + means = deltas.new_tensor(means).view(1, + -1).repeat(1, + deltas.size(-1) // 4) + stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(-1) // 4) + denorm_deltas = deltas * stds + means + dx = denorm_deltas[..., 0::4] + dy = denorm_deltas[..., 1::4] + dw = denorm_deltas[..., 2::4] + dh = denorm_deltas[..., 3::4] + + x1, y1 = rois[..., 0], rois[..., 1] + x2, y2 = rois[..., 2], rois[..., 3] + # Compute center of each roi + px = ((x1 + x2) * 0.5).unsqueeze(-1).expand_as(dx) + py = ((y1 + y2) * 0.5).unsqueeze(-1).expand_as(dy) + # Compute width/height of each roi + pw = (x2 - x1).unsqueeze(-1).expand_as(dw) + ph = (y2 - y1).unsqueeze(-1).expand_as(dh) + + dx_width = pw * dx + dy_height = ph * dy + + max_ratio = np.abs(np.log(wh_ratio_clip)) + if add_ctr_clamp: + dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp) + dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp) + dw = torch.clamp(dw, max=max_ratio) + dh = torch.clamp(dh, max=max_ratio) + else: + dw = dw.clamp(min=-max_ratio, max=max_ratio) + dh = dh.clamp(min=-max_ratio, max=max_ratio) + # Use exp(network energy) to enlarge/shrink each roi + gw = pw * dw.exp() + gh = ph * dh.exp() + # Use network energy to shift the center of each roi + gx = px + dx_width + gy = py + dy_height + # Convert center-xy/width/height to top-left, bottom-right + x1 = gx - gw * 0.5 + y1 = gy - gh * 0.5 + x2 = gx + gw * 0.5 + y2 = gy + gh * 0.5 + + bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size()) + + if clip_border and max_shape is not None: + # clip bboxes with dynamic `min` and `max` for onnx + if torch.onnx.is_in_onnx_export(): + from mmdet.core.export import dynamic_clip_for_onnx + x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape) + bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size()) + return bboxes + if not isinstance(max_shape, torch.Tensor): + max_shape = x1.new_tensor(max_shape) + max_shape = max_shape[..., :2].type_as(x1) + if max_shape.ndim == 2: + assert bboxes.ndim == 3 + assert max_shape.size(0) == bboxes.size(0) + + min_xy = x1.new_tensor(0) + max_xy = torch.cat( + [max_shape] * (deltas.size(-1) // 2), + dim=-1).flip(-1).unsqueeze(-2) + bboxes = torch.where(bboxes < min_xy, min_xy, bboxes) + bboxes = torch.where(bboxes > max_xy, max_xy, bboxes) + + return bboxes + + +def delta2bbox_glip(rois: Tensor, + deltas: Tensor, + means: Sequence[float] = (0., 0., 0., 0.), + stds: 
Sequence[float] = (1., 1., 1., 1.), + max_shape: Optional[Union[Sequence[int], Tensor, + Sequence[Sequence[int]]]] = None, + wh_ratio_clip: float = 16 / 1000, + clip_border: bool = True, + add_ctr_clamp: bool = False, + ctr_clamp: int = 32) -> Tensor: + """Apply deltas to shift/scale base boxes. + + Typically the rois are anchor or proposed bounding boxes and the deltas are + network outputs used to shift/scale those boxes. + This is the inverse function of :func:`bbox2delta`. + + Args: + rois (Tensor): Boxes to be transformed. Has shape (N, 4). + deltas (Tensor): Encoded offsets relative to each roi. + Has shape (N, num_classes * 4) or (N, 4). Note + N = num_base_anchors * W * H, when rois is a grid of + anchors. Offset encoding follows [1]_. + means (Sequence[float]): Denormalizing means for delta coordinates. + Default (0., 0., 0., 0.). + stds (Sequence[float]): Denormalizing standard deviation for delta + coordinates. Default (1., 1., 1., 1.). + max_shape (tuple[int, int]): Maximum bounds for boxes, specifies + (H, W). Default None. + wh_ratio_clip (float): Maximum aspect ratio for boxes. Default + 16 / 1000. + clip_border (bool, optional): Whether clip the objects outside the + border of the image. Default True. + add_ctr_clamp (bool): Whether to add center clamp. When set to True, + the center of the prediction bounding box will be clamped to + avoid being too far away from the center of the anchor. + Only used by YOLOF. Default False. + ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF. + Default 32. + + Returns: + Tensor: Boxes with shape (N, num_classes * 4) or (N, 4), where 4 + represent tl_x, tl_y, br_x, br_y. + """ + num_bboxes, num_classes = deltas.size(0), deltas.size(1) // 4 + if num_bboxes == 0: + return deltas + + deltas = deltas.reshape(-1, 4) + + means = deltas.new_tensor(means).view(1, -1) + stds = deltas.new_tensor(stds).view(1, -1) + denorm_deltas = deltas * stds + means + + dxy = denorm_deltas[:, :2] + dwh = denorm_deltas[:, 2:] + + # Compute width/height of each roi + rois_ = rois.repeat(1, num_classes).reshape(-1, 4) + pxy = ((rois_[:, :2] + rois_[:, 2:] - 1) * 0.5) # note + pwh = (rois_[:, 2:] - rois_[:, :2]) + + dxy_wh = pwh * dxy + + max_ratio = np.abs(np.log(wh_ratio_clip)) + if add_ctr_clamp: + dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp) + dwh = torch.clamp(dwh, max=max_ratio) + else: + dwh = dwh.clamp(min=-max_ratio, max=max_ratio) + + gxy = pxy + dxy_wh + gwh = pwh * dwh.exp() + + x1y1 = gxy - (gwh - 1) * 0.5 # Note + x2y2 = gxy + (gwh - 1) * 0.5 # Note + + bboxes = torch.cat([x1y1, x2y2], dim=-1) + + if clip_border and max_shape is not None: + bboxes[..., 0::2].clamp_(min=0, max=max_shape[1] - 1) # Note + bboxes[..., 1::2].clamp_(min=0, max=max_shape[0] - 1) # Note + bboxes = bboxes.reshape(num_bboxes, -1) + return bboxes diff --git a/mmdet/models/task_modules/coders/distance_point_bbox_coder.py b/mmdet/models/task_modules/coders/distance_point_bbox_coder.py new file mode 100644 index 0000000000000000000000000000000000000000..ab26bf4b96c48df689da3722c23aa65e646348db --- /dev/null +++ b/mmdet/models/task_modules/coders/distance_point_bbox_coder.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
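+# Illustrative usage of the coder defined below (a hand-worked sketch,
+# not part of the original module):
+#
+#   coder = DistancePointBBoxCoder()
+#   points = torch.Tensor([[5., 5.]])
+#   gt = torch.Tensor([[2., 3., 9., 8.]])
+#   coder.encode(points, gt)  # distances to the four sides:
+#   # -> tensor([[3., 2., 4., 3.]])  (left, top, right, bottom)
+#   coder.decode(points, coder.encode(points, gt))
+#   # -> tensor([[2., 3., 9., 8.]])  (the gt box, recovered)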
+from typing import Optional, Sequence, Union
+
+from torch import Tensor
+
+from mmdet.registry import TASK_UTILS
+from mmdet.structures.bbox import (BaseBoxes, HorizontalBoxes, bbox2distance,
+                                   distance2bbox, get_box_tensor)
+from .base_bbox_coder import BaseBBoxCoder
+
+
+@TASK_UTILS.register_module()
+class DistancePointBBoxCoder(BaseBBoxCoder):
+    """Distance Point BBox coder.
+
+    This coder encodes gt bboxes (x1, y1, x2, y2) into (left, top, right,
+    bottom) distances from a point and decodes them back to the original
+    boxes.
+
+    Args:
+        clip_border (bool, optional): Whether to clip the objects outside
+            the border of the image. Defaults to True.
+    """
+
+    def __init__(self, clip_border: Optional[bool] = True, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.clip_border = clip_border
+
+    def encode(self,
+               points: Tensor,
+               gt_bboxes: Union[Tensor, BaseBoxes],
+               max_dis: Optional[float] = None,
+               eps: float = 0.1) -> Tensor:
+        """Encode bounding box to distances.
+
+        Args:
+            points (Tensor): Shape (N, 2). The format is [x, y].
+            gt_bboxes (Tensor or :obj:`BaseBoxes`): Shape (N, 4). The format
+                is "xyxy".
+            max_dis (float): Upper bound of the distance. Default None.
+            eps (float): A small value to ensure target < max_dis instead
+                of <=. Default 0.1.
+
+        Returns:
+            Tensor: Box transformation deltas. The shape is (N, 4).
+        """
+        gt_bboxes = get_box_tensor(gt_bboxes)
+        assert points.size(0) == gt_bboxes.size(0)
+        assert points.size(-1) == 2
+        assert gt_bboxes.size(-1) == 4
+        return bbox2distance(points, gt_bboxes, max_dis, eps)
+
+    def decode(
+        self,
+        points: Tensor,
+        pred_bboxes: Tensor,
+        max_shape: Optional[Union[Sequence[int], Tensor,
+                                  Sequence[Sequence[int]]]] = None
+    ) -> Union[Tensor, BaseBoxes]:
+        """Decode distance prediction to bounding box.
+
+        Args:
+            points (Tensor): Shape (B, N, 2) or (N, 2).
+            pred_bboxes (Tensor): Distance from the given point to 4
+                boundaries (left, top, right, bottom). Shape (B, N, 4)
+                or (N, 4).
+            max_shape (Sequence[int] or torch.Tensor or Sequence[
+                Sequence[int]], optional): Maximum bounds for boxes,
+                specifies (H, W, C) or (H, W). If priors shape is
+                (B, N, 4), then the max_shape should be a
+                Sequence[Sequence[int]], and the length of max_shape
+                should also be B. Default None.
+
+        Returns:
+            Union[Tensor, :obj:`BaseBoxes`]: Boxes with shape (N, 4) or
+            (B, N, 4).
+        """
+        assert points.size(0) == pred_bboxes.size(0)
+        assert points.size(-1) == 2
+        assert pred_bboxes.size(-1) == 4
+        if self.clip_border is False:
+            max_shape = None
+        bboxes = distance2bbox(points, pred_bboxes, max_shape)
+
+        if self.use_box_type:
+            bboxes = HorizontalBoxes(bboxes)
+        return bboxes
diff --git a/mmdet/models/task_modules/coders/legacy_delta_xywh_bbox_coder.py b/mmdet/models/task_modules/coders/legacy_delta_xywh_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9eb1bedb3fbe19433c8bdb37f80891efa2cb72fc
--- /dev/null
+++ b/mmdet/models/task_modules/coders/legacy_delta_xywh_bbox_coder.py
@@ -0,0 +1,235 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence, Union
+
+import numpy as np
+import torch
+from torch import Tensor
+
+from mmdet.registry import TASK_UTILS
+from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor
+from .base_bbox_coder import BaseBBoxCoder
+
+
+@TASK_UTILS.register_module()
+class LegacyDeltaXYWHBBoxCoder(BaseBBoxCoder):
+    """Legacy Delta XYWH BBox coder used in MMDet V1.x.
+ + Following the practice in R-CNN [1]_, this coder encodes bbox (x1, y1, x2, + y2) into delta (dx, dy, dw, dh) and decodes delta (dx, dy, dw, dh) + back to original bbox (x1, y1, x2, y2). + + Note: + The main difference between :class`LegacyDeltaXYWHBBoxCoder` and + :class:`DeltaXYWHBBoxCoder` is whether ``+ 1`` is used during width and + height calculation. We suggest to only use this coder when testing with + MMDet V1.x models. + + References: + .. [1] https://arxiv.org/abs/1311.2524 + + Args: + target_means (Sequence[float]): denormalizing means of target for + delta coordinates + target_stds (Sequence[float]): denormalizing standard deviation of + target for delta coordinates + """ + + def __init__(self, + target_means: Sequence[float] = (0., 0., 0., 0.), + target_stds: Sequence[float] = (1., 1., 1., 1.), + **kwargs) -> None: + super().__init__(**kwargs) + self.means = target_means + self.stds = target_stds + + def encode(self, bboxes: Union[Tensor, BaseBoxes], + gt_bboxes: Union[Tensor, BaseBoxes]) -> Tensor: + """Get box regression transformation deltas that can be used to + transform the ``bboxes`` into the ``gt_bboxes``. + + Args: + bboxes (torch.Tensor or :obj:`BaseBoxes`): source boxes, + e.g., object proposals. + gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): target of the + transformation, e.g., ground-truth boxes. + + Returns: + torch.Tensor: Box transformation deltas + """ + bboxes = get_box_tensor(bboxes) + gt_bboxes = get_box_tensor(gt_bboxes) + assert bboxes.size(0) == gt_bboxes.size(0) + assert bboxes.size(-1) == gt_bboxes.size(-1) == 4 + encoded_bboxes = legacy_bbox2delta(bboxes, gt_bboxes, self.means, + self.stds) + return encoded_bboxes + + def decode( + self, + bboxes: Union[Tensor, BaseBoxes], + pred_bboxes: Tensor, + max_shape: Optional[Union[Sequence[int], Tensor, + Sequence[Sequence[int]]]] = None, + wh_ratio_clip: Optional[float] = 16 / 1000 + ) -> Union[Tensor, BaseBoxes]: + """Apply transformation `pred_bboxes` to `boxes`. + + Args: + boxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes. + pred_bboxes (torch.Tensor): Encoded boxes with shape + max_shape (tuple[int], optional): Maximum shape of boxes. + Defaults to None. + wh_ratio_clip (float, optional): The allowed ratio between + width and height. + + Returns: + Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes. + """ + bboxes = get_box_tensor(bboxes) + assert pred_bboxes.size(0) == bboxes.size(0) + decoded_bboxes = legacy_delta2bbox(bboxes, pred_bboxes, self.means, + self.stds, max_shape, wh_ratio_clip) + + if self.use_box_type: + assert decoded_bboxes.size(-1) == 4, \ + ('Cannot warp decoded boxes with box type when decoded boxes' + 'have shape of (N, num_classes * 4)') + decoded_bboxes = HorizontalBoxes(decoded_bboxes) + return decoded_bboxes + + +def legacy_bbox2delta( + proposals: Tensor, + gt: Tensor, + means: Sequence[float] = (0., 0., 0., 0.), + stds: Sequence[float] = (1., 1., 1., 1.) +) -> Tensor: + """Compute deltas of proposals w.r.t. gt in the MMDet V1.x manner. + + We usually compute the deltas of x, y, w, h of proposals w.r.t ground + truth bboxes to get regression target. 
+ This is the inverse function of `delta2bbox()` + + Args: + proposals (Tensor): Boxes to be transformed, shape (N, ..., 4) + gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4) + means (Sequence[float]): Denormalizing means for delta coordinates + stds (Sequence[float]): Denormalizing standard deviation for delta + coordinates + + Returns: + Tensor: deltas with shape (N, 4), where columns represent dx, dy, + dw, dh. + """ + assert proposals.size() == gt.size() + + proposals = proposals.float() + gt = gt.float() + px = (proposals[..., 0] + proposals[..., 2]) * 0.5 + py = (proposals[..., 1] + proposals[..., 3]) * 0.5 + pw = proposals[..., 2] - proposals[..., 0] + 1.0 + ph = proposals[..., 3] - proposals[..., 1] + 1.0 + + gx = (gt[..., 0] + gt[..., 2]) * 0.5 + gy = (gt[..., 1] + gt[..., 3]) * 0.5 + gw = gt[..., 2] - gt[..., 0] + 1.0 + gh = gt[..., 3] - gt[..., 1] + 1.0 + + dx = (gx - px) / pw + dy = (gy - py) / ph + dw = torch.log(gw / pw) + dh = torch.log(gh / ph) + deltas = torch.stack([dx, dy, dw, dh], dim=-1) + + means = deltas.new_tensor(means).unsqueeze(0) + stds = deltas.new_tensor(stds).unsqueeze(0) + deltas = deltas.sub_(means).div_(stds) + + return deltas + + +def legacy_delta2bbox(rois: Tensor, + deltas: Tensor, + means: Sequence[float] = (0., 0., 0., 0.), + stds: Sequence[float] = (1., 1., 1., 1.), + max_shape: Optional[ + Union[Sequence[int], Tensor, + Sequence[Sequence[int]]]] = None, + wh_ratio_clip: float = 16 / 1000) -> Tensor: + """Apply deltas to shift/scale base boxes in the MMDet V1.x manner. + + Typically the rois are anchor or proposed bounding boxes and the deltas are + network outputs used to shift/scale those boxes. + This is the inverse function of `bbox2delta()` + + Args: + rois (Tensor): Boxes to be transformed. Has shape (N, 4) + deltas (Tensor): Encoded offsets with respect to each roi. + Has shape (N, 4 * num_classes). Note N = num_anchors * W * H when + rois is a grid of anchors. Offset encoding follows [1]_. + means (Sequence[float]): Denormalizing means for delta coordinates + stds (Sequence[float]): Denormalizing standard deviation for delta + coordinates + max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W) + wh_ratio_clip (float): Maximum aspect ratio for boxes. + + Returns: + Tensor: Boxes with shape (N, 4), where columns represent + tl_x, tl_y, br_x, br_y. + + References: + .. 
[1] https://arxiv.org/abs/1311.2524 + + Example: + >>> rois = torch.Tensor([[ 0., 0., 1., 1.], + >>> [ 0., 0., 1., 1.], + >>> [ 0., 0., 1., 1.], + >>> [ 5., 5., 5., 5.]]) + >>> deltas = torch.Tensor([[ 0., 0., 0., 0.], + >>> [ 1., 1., 1., 1.], + >>> [ 0., 0., 2., -1.], + >>> [ 0.7, -1.9, -0.5, 0.3]]) + >>> legacy_delta2bbox(rois, deltas, max_shape=(32, 32)) + tensor([[0.0000, 0.0000, 1.5000, 1.5000], + [0.0000, 0.0000, 5.2183, 5.2183], + [0.0000, 0.1321, 7.8891, 0.8679], + [5.3967, 2.4251, 6.0033, 3.7749]]) + """ + means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4) + stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4) + denorm_deltas = deltas * stds + means + dx = denorm_deltas[:, 0::4] + dy = denorm_deltas[:, 1::4] + dw = denorm_deltas[:, 2::4] + dh = denorm_deltas[:, 3::4] + max_ratio = np.abs(np.log(wh_ratio_clip)) + dw = dw.clamp(min=-max_ratio, max=max_ratio) + dh = dh.clamp(min=-max_ratio, max=max_ratio) + # Compute center of each roi + px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx) + py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy) + # Compute width/height of each roi + pw = (rois[:, 2] - rois[:, 0] + 1.0).unsqueeze(1).expand_as(dw) + ph = (rois[:, 3] - rois[:, 1] + 1.0).unsqueeze(1).expand_as(dh) + # Use exp(network energy) to enlarge/shrink each roi + gw = pw * dw.exp() + gh = ph * dh.exp() + # Use network energy to shift the center of each roi + gx = px + pw * dx + gy = py + ph * dy + # Convert center-xy/width/height to top-left, bottom-right + + # The true legacy box coder should +- 0.5 here. + # However, current implementation improves the performance when testing + # the models trained in MMDetection 1.X (~0.5 bbox AP, 0.2 mask AP) + x1 = gx - gw * 0.5 + y1 = gy - gh * 0.5 + x2 = gx + gw * 0.5 + y2 = gy + gh * 0.5 + if max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1] - 1) + y1 = y1.clamp(min=0, max=max_shape[0] - 1) + x2 = x2.clamp(min=0, max=max_shape[1] - 1) + y2 = y2.clamp(min=0, max=max_shape[0] - 1) + bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas) + return bboxes diff --git a/mmdet/models/task_modules/coders/pseudo_bbox_coder.py b/mmdet/models/task_modules/coders/pseudo_bbox_coder.py new file mode 100644 index 0000000000000000000000000000000000000000..9ee74311f6d12bde49d0c678edb60540a8c95c8b --- /dev/null +++ b/mmdet/models/task_modules/coders/pseudo_bbox_coder.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
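The legacy helpers above differ from `delta2bbox` only in the `+ 1.0` added to anchor widths and heights and the clipping to `max_shape - 1`. A minimal sketch that decodes the second row of the legacy doctest by hand, showing where the numbers come from:

```python
import torch

# Second row of the legacy_delta2bbox doctest: roi (0, 0, 1, 1),
# delta (1, 1, 1, 1), max_shape (32, 32).
roi = torch.tensor([0., 0., 1., 1.])
delta = torch.tensor([1., 1., 1., 1.])

# V1.x convention: width/height include a "+ 1" pixel, so a (0, 0, 1, 1)
# roi is 2 pixels wide, not 1.
pw, ph = roi[2] - roi[0] + 1.0, roi[3] - roi[1] + 1.0
px, py = (roi[0] + roi[2]) * 0.5, (roi[1] + roi[3]) * 0.5

gx, gy = px + pw * delta[0], py + ph * delta[1]
gw, gh = pw * delta[2].exp(), ph * delta[3].exp()

box = torch.stack([gx - gw * 0.5, gy - gh * 0.5,
                   gx + gw * 0.5, gy + gh * 0.5])
box = box.clamp(min=0., max=31.)  # legacy clips to max_shape - 1
print(box)  # tensor([0.0000, 0.0000, 5.2183, 5.2183]), as in the doctest
# Without the "+ 1", delta2bbox decodes the same roi/delta pair to
# [0.1409, 0.1409, 2.8591, 2.8591], the non-legacy doctest value.
```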
+from typing import Union + +from torch import Tensor + +from mmdet.registry import TASK_UTILS +from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor +from .base_bbox_coder import BaseBBoxCoder + + +@TASK_UTILS.register_module() +class PseudoBBoxCoder(BaseBBoxCoder): + """Pseudo bounding box coder.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def encode(self, bboxes: Tensor, gt_bboxes: Union[Tensor, + BaseBoxes]) -> Tensor: + """torch.Tensor: return the given ``bboxes``""" + gt_bboxes = get_box_tensor(gt_bboxes) + return gt_bboxes + + def decode(self, bboxes: Tensor, pred_bboxes: Union[Tensor, + BaseBoxes]) -> Tensor: + """torch.Tensor: return the given ``pred_bboxes``""" + if self.use_box_type: + pred_bboxes = HorizontalBoxes(pred_bboxes) + return pred_bboxes diff --git a/mmdet/models/task_modules/coders/tblr_bbox_coder.py b/mmdet/models/task_modules/coders/tblr_bbox_coder.py new file mode 100644 index 0000000000000000000000000000000000000000..74b388f7bad6ebc1911cee5b0b7d73bbd04de17a --- /dev/null +++ b/mmdet/models/task_modules/coders/tblr_bbox_coder.py @@ -0,0 +1,228 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence, Union + +import torch +from torch import Tensor + +from mmdet.registry import TASK_UTILS +from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor +from .base_bbox_coder import BaseBBoxCoder + + +@TASK_UTILS.register_module() +class TBLRBBoxCoder(BaseBBoxCoder): + """TBLR BBox coder. + + Following the practice in `FSAF `_, + this coder encodes gt bboxes (x1, y1, x2, y2) into (top, bottom, left, + right) and decode it back to the original. + + Args: + normalizer (list | float): Normalization factor to be + divided with when coding the coordinates. If it is a list, it should + have length of 4 indicating normalization factor in tblr dims. + Otherwise it is a unified float factor for all dims. Default: 4.0 + clip_border (bool, optional): Whether clip the objects outside the + border of the image. Defaults to True. + """ + + def __init__(self, + normalizer: Union[Sequence[float], float] = 4.0, + clip_border: bool = True, + **kwargs) -> None: + super().__init__(**kwargs) + self.normalizer = normalizer + self.clip_border = clip_border + + def encode(self, bboxes: Union[Tensor, BaseBoxes], + gt_bboxes: Union[Tensor, BaseBoxes]) -> Tensor: + """Get box regression transformation deltas that can be used to + transform the ``bboxes`` into the ``gt_bboxes`` in the (top, left, + bottom, right) order. + + Args: + bboxes (torch.Tensor or :obj:`BaseBoxes`): source boxes, + e.g., object proposals. + gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): target of the + transformation, e.g., ground truth boxes. + + Returns: + torch.Tensor: Box transformation deltas + """ + bboxes = get_box_tensor(bboxes) + gt_bboxes = get_box_tensor(gt_bboxes) + assert bboxes.size(0) == gt_bboxes.size(0) + assert bboxes.size(-1) == gt_bboxes.size(-1) == 4 + encoded_bboxes = bboxes2tblr( + bboxes, gt_bboxes, normalizer=self.normalizer) + return encoded_bboxes + + def decode( + self, + bboxes: Union[Tensor, BaseBoxes], + pred_bboxes: Tensor, + max_shape: Optional[Union[Sequence[int], Tensor, + Sequence[Sequence[int]]]] = None + ) -> Union[Tensor, BaseBoxes]: + """Apply transformation `pred_bboxes` to `boxes`. 
+ + Args: + bboxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes.Shape + (B, N, 4) or (N, 4) + pred_bboxes (torch.Tensor): Encoded boxes with shape + (B, N, 4) or (N, 4) + max_shape (Sequence[int] or torch.Tensor or Sequence[ + Sequence[int]],optional): Maximum bounds for boxes, specifies + (H, W, C) or (H, W). If bboxes shape is (B, N, 4), then + the max_shape should be a Sequence[Sequence[int]] + and the length of max_shape should also be B. + + Returns: + Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes. + """ + bboxes = get_box_tensor(bboxes) + decoded_bboxes = tblr2bboxes( + bboxes, + pred_bboxes, + normalizer=self.normalizer, + max_shape=max_shape, + clip_border=self.clip_border) + + if self.use_box_type: + decoded_bboxes = HorizontalBoxes(decoded_bboxes) + return decoded_bboxes + + +def bboxes2tblr(priors: Tensor, + gts: Tensor, + normalizer: Union[Sequence[float], float] = 4.0, + normalize_by_wh: bool = True) -> Tensor: + """Encode ground truth boxes to tblr coordinate. + + It first convert the gt coordinate to tblr format, + (top, bottom, left, right), relative to prior box centers. + The tblr coordinate may be normalized by the side length of prior bboxes + if `normalize_by_wh` is specified as True, and it is then normalized by + the `normalizer` factor. + + Args: + priors (Tensor): Prior boxes in point form + Shape: (num_proposals,4). + gts (Tensor): Coords of ground truth for each prior in point-form + Shape: (num_proposals, 4). + normalizer (Sequence[float] | float): normalization parameter of + encoded boxes. If it is a list, it has to have length = 4. + Default: 4.0 + normalize_by_wh (bool): Whether to normalize tblr coordinate by the + side length (wh) of prior bboxes. + + Return: + encoded boxes (Tensor), Shape: (num_proposals, 4) + """ + + # dist b/t match center and prior's center + if not isinstance(normalizer, float): + normalizer = torch.tensor(normalizer, device=priors.device) + assert len(normalizer) == 4, 'Normalizer must have length = 4' + assert priors.size(0) == gts.size(0) + prior_centers = (priors[:, 0:2] + priors[:, 2:4]) / 2 + xmin, ymin, xmax, ymax = gts.split(1, dim=1) + top = prior_centers[:, 1].unsqueeze(1) - ymin + bottom = ymax - prior_centers[:, 1].unsqueeze(1) + left = prior_centers[:, 0].unsqueeze(1) - xmin + right = xmax - prior_centers[:, 0].unsqueeze(1) + loc = torch.cat((top, bottom, left, right), dim=1) + if normalize_by_wh: + # Normalize tblr by anchor width and height + wh = priors[:, 2:4] - priors[:, 0:2] + w, h = torch.split(wh, 1, dim=1) + loc[:, :2] /= h # tb is normalized by h + loc[:, 2:] /= w # lr is normalized by w + # Normalize tblr by the given normalization factor + return loc / normalizer + + +def tblr2bboxes(priors: Tensor, + tblr: Tensor, + normalizer: Union[Sequence[float], float] = 4.0, + normalize_by_wh: bool = True, + max_shape: Optional[Union[Sequence[int], Tensor, + Sequence[Sequence[int]]]] = None, + clip_border: bool = True) -> Tensor: + """Decode tblr outputs to prediction boxes. + + The process includes 3 steps: 1) De-normalize tblr coordinates by + multiplying it with `normalizer`; 2) De-normalize tblr coordinates by the + prior bbox width and height if `normalize_by_wh` is `True`; 3) Convert + tblr (top, bottom, left, right) pair relative to the center of priors back + to (xmin, ymin, xmax, ymax) coordinate. + + Args: + priors (Tensor): Prior boxes in point form (x0, y0, x1, y1) + Shape: (N,4) or (B, N, 4). + tblr (Tensor): Coords of network output in tblr form + Shape: (N, 4) or (B, N, 4). 
+        normalizer (Sequence[float] | float): Normalization parameter of
+            encoded boxes. If a list, it represents the normalization
+            factors at the tblr dims; if a float, it is a unified
+            normalization factor for all dims. Default: 4.0.
+        normalize_by_wh (bool): Whether the tblr coordinates have been
+            normalized by the side length (wh) of prior bboxes.
+        max_shape (Sequence[int] or torch.Tensor or Sequence[
+            Sequence[int]], optional): Maximum bounds for boxes, specifies
+            (H, W, C) or (H, W). If priors shape is (B, N, 4), then
+            the max_shape should be a Sequence[Sequence[int]]
+            and the length of max_shape should also be B.
+        clip_border (bool, optional): Whether to clip the objects outside
+            the border of the image. Defaults to True.
+
+    Return:
+        Decoded boxes (Tensor): Boxes with shape (N, 4) or (B, N, 4)
+    """
+    if not isinstance(normalizer, float):
+        normalizer = torch.tensor(normalizer, device=priors.device)
+        assert len(normalizer) == 4, 'Normalizer must have length = 4'
+    assert priors.size(0) == tblr.size(0)
+    if priors.ndim == 3:
+        assert priors.size(1) == tblr.size(1)
+
+    loc_decode = tblr * normalizer
+    prior_centers = (priors[..., 0:2] + priors[..., 2:4]) / 2
+    if normalize_by_wh:
+        wh = priors[..., 2:4] - priors[..., 0:2]
+        w, h = torch.split(wh, 1, dim=-1)
+        # Inplace operation with slice would fail when exporting to ONNX
+        th = h * loc_decode[..., :2]  # tb
+        tw = w * loc_decode[..., 2:]  # lr
+        loc_decode = torch.cat([th, tw], dim=-1)
+    # Cannot be exported using onnx when loc_decode.split(1, dim=-1)
+    top, bottom, left, right = loc_decode.split((1, 1, 1, 1), dim=-1)
+    xmin = prior_centers[..., 0].unsqueeze(-1) - left
+    xmax = prior_centers[..., 0].unsqueeze(-1) + right
+    ymin = prior_centers[..., 1].unsqueeze(-1) - top
+    ymax = prior_centers[..., 1].unsqueeze(-1) + bottom
+
+    bboxes = torch.cat((xmin, ymin, xmax, ymax), dim=-1)
+
+    if clip_border and max_shape is not None:
+        # clip bboxes with dynamic `min` and `max` for onnx
+        if torch.onnx.is_in_onnx_export():
+            from mmdet.core.export import dynamic_clip_for_onnx
+            xmin, ymin, xmax, ymax = dynamic_clip_for_onnx(
+                xmin, ymin, xmax, ymax, max_shape)
+            bboxes = torch.cat([xmin, ymin, xmax, ymax], dim=-1)
+            return bboxes
+        if not isinstance(max_shape, torch.Tensor):
+            max_shape = priors.new_tensor(max_shape)
+        max_shape = max_shape[..., :2].type_as(priors)
+        if max_shape.ndim == 2:
+            assert bboxes.ndim == 3
+            assert max_shape.size(0) == bboxes.size(0)
+
+        min_xy = priors.new_tensor(0)
+        max_xy = torch.cat([max_shape, max_shape],
+                           dim=-1).flip(-1).unsqueeze(-2)
+        bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
+        bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
+
+    return bboxes
diff --git a/mmdet/models/task_modules/coders/yolo_bbox_coder.py b/mmdet/models/task_modules/coders/yolo_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e1c766789bec844ff359e225435bc3b2f5dd736
--- /dev/null
+++ b/mmdet/models/task_modules/coders/yolo_bbox_coder.py
@@ -0,0 +1,94 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
+import torch
+from torch import Tensor
+
+from mmdet.registry import TASK_UTILS
+from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor
+from .base_bbox_coder import BaseBBoxCoder
+
+
+@TASK_UTILS.register_module()
+class YOLOBBoxCoder(BaseBBoxCoder):
+    """YOLO BBox coder.
+
+    Following `YOLO <https://arxiv.org/abs/1506.02640>`_, this coder divides
+    the image into grids and encodes bbox (x1, y1, x2, y2) into
+    (cx, cy, dw, dh).
+ cx, cy in [0., 1.], denotes relative center position w.r.t the center of + bboxes. dw, dh are the same as :obj:`DeltaXYWHBBoxCoder`. + + Args: + eps (float): Min value of cx, cy when encoding. + """ + + def __init__(self, eps: float = 1e-6, **kwargs): + super().__init__(**kwargs) + self.eps = eps + + def encode(self, bboxes: Union[Tensor, BaseBoxes], + gt_bboxes: Union[Tensor, BaseBoxes], + stride: Union[Tensor, int]) -> Tensor: + """Get box regression transformation deltas that can be used to + transform the ``bboxes`` into the ``gt_bboxes``. + + Args: + bboxes (torch.Tensor or :obj:`BaseBoxes`): Source boxes, + e.g., anchors. + gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): Target of the + transformation, e.g., ground-truth boxes. + stride (torch.Tensor | int): Stride of bboxes. + + Returns: + torch.Tensor: Box transformation deltas + """ + bboxes = get_box_tensor(bboxes) + gt_bboxes = get_box_tensor(gt_bboxes) + assert bboxes.size(0) == gt_bboxes.size(0) + assert bboxes.size(-1) == gt_bboxes.size(-1) == 4 + x_center_gt = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) * 0.5 + y_center_gt = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) * 0.5 + w_gt = gt_bboxes[..., 2] - gt_bboxes[..., 0] + h_gt = gt_bboxes[..., 3] - gt_bboxes[..., 1] + x_center = (bboxes[..., 0] + bboxes[..., 2]) * 0.5 + y_center = (bboxes[..., 1] + bboxes[..., 3]) * 0.5 + w = bboxes[..., 2] - bboxes[..., 0] + h = bboxes[..., 3] - bboxes[..., 1] + w_target = torch.log((w_gt / w).clamp(min=self.eps)) + h_target = torch.log((h_gt / h).clamp(min=self.eps)) + x_center_target = ((x_center_gt - x_center) / stride + 0.5).clamp( + self.eps, 1 - self.eps) + y_center_target = ((y_center_gt - y_center) / stride + 0.5).clamp( + self.eps, 1 - self.eps) + encoded_bboxes = torch.stack( + [x_center_target, y_center_target, w_target, h_target], dim=-1) + return encoded_bboxes + + def decode(self, bboxes: Union[Tensor, BaseBoxes], pred_bboxes: Tensor, + stride: Union[Tensor, int]) -> Union[Tensor, BaseBoxes]: + """Apply transformation `pred_bboxes` to `boxes`. + + Args: + boxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes, + e.g. anchors. + pred_bboxes (torch.Tensor): Encoded boxes with shape + stride (torch.Tensor | int): Strides of bboxes. + + Returns: + Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes. + """ + bboxes = get_box_tensor(bboxes) + assert pred_bboxes.size(-1) == bboxes.size(-1) == 4 + xy_centers = (bboxes[..., :2] + bboxes[..., 2:]) * 0.5 + ( + pred_bboxes[..., :2] - 0.5) * stride + whs = (bboxes[..., 2:] - + bboxes[..., :2]) * 0.5 * pred_bboxes[..., 2:].exp() + decoded_bboxes = torch.stack( + (xy_centers[..., 0] - whs[..., 0], xy_centers[..., 1] - + whs[..., 1], xy_centers[..., 0] + whs[..., 0], + xy_centers[..., 1] + whs[..., 1]), + dim=-1) + + if self.use_box_type: + decoded_bboxes = HorizontalBoxes(decoded_bboxes) + return decoded_bboxes diff --git a/mmdet/models/task_modules/prior_generators/__init__.py b/mmdet/models/task_modules/prior_generators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7795e98ca77bb5ffc77ff1da848130717d8f85a6 --- /dev/null +++ b/mmdet/models/task_modules/prior_generators/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
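The TBLR encode/decode pair above is an exact inverse (up to floating point) as long as no border clipping kicks in. A small round-trip check; this sketch assumes the `mmdet` package vendored by this repo is importable, and uses made-up prior/gt boxes:

```python
import torch
from mmdet.models.task_modules.coders.tblr_bbox_coder import (bboxes2tblr,
                                                              tblr2bboxes)

priors = torch.tensor([[10., 10., 50., 30.],
                       [0., 0., 32., 32.]])
gts = torch.tensor([[12., 8., 48., 34.],
                    [4., 2., 30., 28.]])

# Encode: distances from the prior centers to the gt edges, normalized
# by the prior height (top/bottom) and width (left/right), then by 4.0.
loc = bboxes2tblr(priors, gts, normalizer=4.0)

# Decode reverses each step; with max_shape=None nothing is clipped,
# so the ground-truth boxes come back exactly.
decoded = tblr2bboxes(priors, loc, normalizer=4.0)
assert torch.allclose(decoded, gts, atol=1e-5)
```

With `clip_border=True` and a `max_shape` given, decoded corners are additionally clamped into the image, so the round trip is only exact for boxes that already lie inside the bounds.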
+from .anchor_generator import (AnchorGenerator, LegacyAnchorGenerator, + SSDAnchorGenerator, YOLOAnchorGenerator) +from .point_generator import MlvlPointGenerator, PointGenerator +from .utils import anchor_inside_flags, calc_region + +__all__ = [ + 'AnchorGenerator', 'LegacyAnchorGenerator', 'anchor_inside_flags', + 'PointGenerator', 'calc_region', 'YOLOAnchorGenerator', + 'MlvlPointGenerator', 'SSDAnchorGenerator' +] diff --git a/mmdet/models/task_modules/prior_generators/anchor_generator.py b/mmdet/models/task_modules/prior_generators/anchor_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..2757697ce2283ec8b46ba89325e63fad0be4a7e8 --- /dev/null +++ b/mmdet/models/task_modules/prior_generators/anchor_generator.py @@ -0,0 +1,848 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from mmengine.utils import is_tuple_of +from torch import Tensor +from torch.nn.modules.utils import _pair + +from mmdet.registry import TASK_UTILS +from mmdet.structures.bbox import HorizontalBoxes + +DeviceType = Union[str, torch.device] + + +@TASK_UTILS.register_module() +class AnchorGenerator: + """Standard anchor generator for 2D anchor-based detectors. + + Args: + strides (list[int] | list[tuple[int, int]]): Strides of anchors + in multiple feature levels in order (w, h). + ratios (list[float]): The list of ratios between the height and width + of anchors in a single level. + scales (list[int], Optional): Anchor scales for anchors + in a single level. It cannot be set at the same time + if `octave_base_scale` and `scales_per_octave` are set. + base_sizes (list[int], Optional): The basic sizes + of anchors in multiple levels. + If None is given, strides will be used as base_sizes. + (If strides are non square, the shortest stride is taken.) + scale_major (bool): Whether to multiply scales first when generating + base anchors. If true, the anchors in the same row will have the + same scales. By default it is True in V2.0 + octave_base_scale (int, Optional): The base scale of octave. + scales_per_octave (int, Optional): Number of scales for each octave. + `octave_base_scale` and `scales_per_octave` are usually used in + retinanet and the `scales` should be None when they are set. + centers (list[tuple[float]], Optional): The centers of the anchor + relative to the feature grid center in multiple feature levels. + By default it is set to be None and not used. If a list of tuple of + float is given, they will be used to shift the centers of anchors. + center_offset (float): The offset of center in proportion to anchors' + width and height. By default it is 0 in V2.0. + use_box_type (bool): Whether to warp anchors with the box type data + structure. Defaults to False. + + Examples: + >>> from mmdet.models.task_modules. + ... 
prior_generators import AnchorGenerator + >>> self = AnchorGenerator([16], [1.], [1.], [9]) + >>> all_anchors = self.grid_priors([(2, 2)], device='cpu') + >>> print(all_anchors) + [tensor([[-4.5000, -4.5000, 4.5000, 4.5000], + [11.5000, -4.5000, 20.5000, 4.5000], + [-4.5000, 11.5000, 4.5000, 20.5000], + [11.5000, 11.5000, 20.5000, 20.5000]])] + >>> self = AnchorGenerator([16, 32], [1.], [1.], [9, 18]) + >>> all_anchors = self.grid_priors([(2, 2), (1, 1)], device='cpu') + >>> print(all_anchors) + [tensor([[-4.5000, -4.5000, 4.5000, 4.5000], + [11.5000, -4.5000, 20.5000, 4.5000], + [-4.5000, 11.5000, 4.5000, 20.5000], + [11.5000, 11.5000, 20.5000, 20.5000]]), \ + tensor([[-9., -9., 9., 9.]])] + """ + + def __init__(self, + strides: Union[List[int], List[Tuple[int, int]]], + ratios: List[float], + scales: Optional[List[int]] = None, + base_sizes: Optional[List[int]] = None, + scale_major: bool = True, + octave_base_scale: Optional[int] = None, + scales_per_octave: Optional[int] = None, + centers: Optional[List[Tuple[float, float]]] = None, + center_offset: float = 0., + use_box_type: bool = False) -> None: + # check center and center_offset + if center_offset != 0: + assert centers is None, 'center cannot be set when center_offset' \ + f'!=0, {centers} is given.' + if not (0 <= center_offset <= 1): + raise ValueError('center_offset should be in range [0, 1], ' + f'{center_offset} is given.') + if centers is not None: + assert len(centers) == len(strides), \ + 'The number of strides should be the same as centers, got ' \ + f'{strides} and {centers}' + + # calculate base sizes of anchors + self.strides = [_pair(stride) for stride in strides] + self.base_sizes = [min(stride) for stride in self.strides + ] if base_sizes is None else base_sizes + assert len(self.base_sizes) == len(self.strides), \ + 'The number of strides should be the same as base sizes, got ' \ + f'{self.strides} and {self.base_sizes}' + + # calculate scales of anchors + assert ((octave_base_scale is not None + and scales_per_octave is not None) ^ (scales is not None)), \ + 'scales and octave_base_scale with scales_per_octave cannot' \ + ' be set at the same time' + if scales is not None: + self.scales = torch.Tensor(scales) + elif octave_base_scale is not None and scales_per_octave is not None: + octave_scales = np.array( + [2**(i / scales_per_octave) for i in range(scales_per_octave)]) + scales = octave_scales * octave_base_scale + self.scales = torch.Tensor(scales) + else: + raise ValueError('Either scales or octave_base_scale with ' + 'scales_per_octave should be set') + + self.octave_base_scale = octave_base_scale + self.scales_per_octave = scales_per_octave + self.ratios = torch.Tensor(ratios) + self.scale_major = scale_major + self.centers = centers + self.center_offset = center_offset + self.base_anchors = self.gen_base_anchors() + self.use_box_type = use_box_type + + @property + def num_base_anchors(self) -> List[int]: + """list[int]: total number of base anchors in a feature grid""" + return self.num_base_priors + + @property + def num_base_priors(self) -> List[int]: + """list[int]: The number of priors (anchors) at a point + on the feature grid""" + return [base_anchors.size(0) for base_anchors in self.base_anchors] + + @property + def num_levels(self) -> int: + """int: number of feature levels that the generator will be applied""" + return len(self.strides) + + def gen_base_anchors(self) -> List[Tensor]: + """Generate base anchors. 
+ + Returns: + list(torch.Tensor): Base anchors of a feature grid in multiple \ + feature levels. + """ + multi_level_base_anchors = [] + for i, base_size in enumerate(self.base_sizes): + center = None + if self.centers is not None: + center = self.centers[i] + multi_level_base_anchors.append( + self.gen_single_level_base_anchors( + base_size, + scales=self.scales, + ratios=self.ratios, + center=center)) + return multi_level_base_anchors + + def gen_single_level_base_anchors(self, + base_size: Union[int, float], + scales: Tensor, + ratios: Tensor, + center: Optional[Tuple[float]] = None) \ + -> Tensor: + """Generate base anchors of a single level. + + Args: + base_size (int | float): Basic size of an anchor. + scales (torch.Tensor): Scales of the anchor. + ratios (torch.Tensor): The ratio between the height + and width of anchors in a single level. + center (tuple[float], optional): The center of the base anchor + related to a single feature grid. Defaults to None. + + Returns: + torch.Tensor: Anchors in a single-level feature maps. + """ + w = base_size + h = base_size + if center is None: + x_center = self.center_offset * w + y_center = self.center_offset * h + else: + x_center, y_center = center + + h_ratios = torch.sqrt(ratios) + w_ratios = 1 / h_ratios + if self.scale_major: + ws = (w * w_ratios[:, None] * scales[None, :]).view(-1) + hs = (h * h_ratios[:, None] * scales[None, :]).view(-1) + else: + ws = (w * scales[:, None] * w_ratios[None, :]).view(-1) + hs = (h * scales[:, None] * h_ratios[None, :]).view(-1) + + # use float anchor and the anchor's center is aligned with the + # pixel center + base_anchors = [ + x_center - 0.5 * ws, y_center - 0.5 * hs, x_center + 0.5 * ws, + y_center + 0.5 * hs + ] + base_anchors = torch.stack(base_anchors, dim=-1) + + return base_anchors + + def _meshgrid(self, + x: Tensor, + y: Tensor, + row_major: bool = True) -> Tuple[Tensor]: + """Generate mesh grid of x and y. + + Args: + x (torch.Tensor): Grids of x dimension. + y (torch.Tensor): Grids of y dimension. + row_major (bool): Whether to return y grids first. + Defaults to True. + + Returns: + tuple[torch.Tensor]: The mesh grids of x and y. + """ + # use shape instead of len to keep tracing while exporting to onnx + xx = x.repeat(y.shape[0]) + yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1) + if row_major: + return xx, yy + else: + return yy, xx + + def grid_priors(self, + featmap_sizes: List[Tuple], + dtype: torch.dtype = torch.float32, + device: DeviceType = 'cuda') -> List[Tensor]: + """Generate grid anchors in multiple feature levels. + + Args: + featmap_sizes (list[tuple]): List of feature map sizes in + multiple feature levels. + dtype (:obj:`torch.dtype`): Dtype of priors. + Defaults to torch.float32. + device (str | torch.device): The device where the anchors + will be put on. + + Return: + list[torch.Tensor]: Anchors in multiple feature levels. \ + The sizes of each tensor should be [N, 4], where \ + N = width * height * num_base_anchors, width and height \ + are the sizes of the corresponding feature level, \ + num_base_anchors is the number of anchors for that level. 
+ """ + assert self.num_levels == len(featmap_sizes) + multi_level_anchors = [] + for i in range(self.num_levels): + anchors = self.single_level_grid_priors( + featmap_sizes[i], level_idx=i, dtype=dtype, device=device) + multi_level_anchors.append(anchors) + return multi_level_anchors + + def single_level_grid_priors(self, + featmap_size: Tuple[int, int], + level_idx: int, + dtype: torch.dtype = torch.float32, + device: DeviceType = 'cuda') -> Tensor: + """Generate grid anchors of a single level. + + Note: + This function is usually called by method ``self.grid_priors``. + + Args: + featmap_size (tuple[int, int]): Size of the feature maps. + level_idx (int): The index of corresponding feature map level. + dtype (obj:`torch.dtype`): Date type of points.Defaults to + ``torch.float32``. + device (str | torch.device): The device the tensor will be put on. + Defaults to 'cuda'. + + Returns: + torch.Tensor: Anchors in the overall feature maps. + """ + + base_anchors = self.base_anchors[level_idx].to(device).to(dtype) + feat_h, feat_w = featmap_size + stride_w, stride_h = self.strides[level_idx] + # First create Range with the default dtype, than convert to + # target `dtype` for onnx exporting. + shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w + shift_y = torch.arange(0, feat_h, device=device).to(dtype) * stride_h + + shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) + shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1) + # first feat_w elements correspond to the first row of shifts + # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get + # shifted anchors (K, A, 4), reshape to (K*A, 4) + + all_anchors = base_anchors[None, :, :] + shifts[:, None, :] + all_anchors = all_anchors.view(-1, 4) + # first A rows correspond to A anchors of (0, 0) in feature map, + # then (0, 1), (0, 2), ... + if self.use_box_type: + all_anchors = HorizontalBoxes(all_anchors) + return all_anchors + + def sparse_priors(self, + prior_idxs: Tensor, + featmap_size: Tuple[int, int], + level_idx: int, + dtype: torch.dtype = torch.float32, + device: DeviceType = 'cuda') -> Tensor: + """Generate sparse anchors according to the ``prior_idxs``. + + Args: + prior_idxs (Tensor): The index of corresponding anchors + in the feature map. + featmap_size (tuple[int, int]): feature map size arrange as (h, w). + level_idx (int): The level index of corresponding feature + map. + dtype (obj:`torch.dtype`): Date type of points.Defaults to + ``torch.float32``. + device (str | torch.device): The device where the points is + located. + Returns: + Tensor: Anchor with shape (N, 4), N should be equal to + the length of ``prior_idxs``. + """ + + height, width = featmap_size + num_base_anchors = self.num_base_anchors[level_idx] + base_anchor_id = prior_idxs % num_base_anchors + x = (prior_idxs // + num_base_anchors) % width * self.strides[level_idx][0] + y = (prior_idxs // width // + num_base_anchors) % height * self.strides[level_idx][1] + priors = torch.stack([x, y, x, y], 1).to(dtype).to(device) + \ + self.base_anchors[level_idx][base_anchor_id, :].to(device) + + return priors + + def grid_anchors(self, + featmap_sizes: List[Tuple], + device: DeviceType = 'cuda') -> List[Tensor]: + """Generate grid anchors in multiple feature levels. + + Args: + featmap_sizes (list[tuple]): List of feature map sizes in + multiple feature levels. + device (str | torch.device): Device where the anchors will be + put on. + + Return: + list[torch.Tensor]: Anchors in multiple feature levels. 
\ + The sizes of each tensor should be [N, 4], where \ + N = width * height * num_base_anchors, width and height \ + are the sizes of the corresponding feature level, \ + num_base_anchors is the number of anchors for that level. + """ + warnings.warn('``grid_anchors`` would be deprecated soon. ' + 'Please use ``grid_priors`` ') + + assert self.num_levels == len(featmap_sizes) + multi_level_anchors = [] + for i in range(self.num_levels): + anchors = self.single_level_grid_anchors( + self.base_anchors[i].to(device), + featmap_sizes[i], + self.strides[i], + device=device) + multi_level_anchors.append(anchors) + return multi_level_anchors + + def single_level_grid_anchors(self, + base_anchors: Tensor, + featmap_size: Tuple[int, int], + stride: Tuple[int, int] = (16, 16), + device: DeviceType = 'cuda') -> Tensor: + """Generate grid anchors of a single level. + + Note: + This function is usually called by method ``self.grid_anchors``. + + Args: + base_anchors (torch.Tensor): The base anchors of a feature grid. + featmap_size (tuple[int]): Size of the feature maps. + stride (tuple[int, int]): Stride of the feature map in order + (w, h). Defaults to (16, 16). + device (str | torch.device): Device the tensor will be put on. + Defaults to 'cuda'. + + Returns: + torch.Tensor: Anchors in the overall feature maps. + """ + + warnings.warn( + '``single_level_grid_anchors`` would be deprecated soon. ' + 'Please use ``single_level_grid_priors`` ') + + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + feat_h, feat_w = featmap_size + shift_x = torch.arange(0, feat_w, device=device) * stride[0] + shift_y = torch.arange(0, feat_h, device=device) * stride[1] + + shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) + shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1) + shifts = shifts.type_as(base_anchors) + # first feat_w elements correspond to the first row of shifts + # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get + # shifted anchors (K, A, 4), reshape to (K*A, 4) + + all_anchors = base_anchors[None, :, :] + shifts[:, None, :] + all_anchors = all_anchors.view(-1, 4) + # first A rows correspond to A anchors of (0, 0) in feature map, + # then (0, 1), (0, 2), ... + return all_anchors + + def valid_flags(self, + featmap_sizes: List[Tuple[int, int]], + pad_shape: Tuple, + device: DeviceType = 'cuda') -> List[Tensor]: + """Generate valid flags of anchors in multiple feature levels. + + Args: + featmap_sizes (list(tuple[int, int])): List of feature map sizes in + multiple feature levels. + pad_shape (tuple): The padded shape of the image. + device (str | torch.device): Device where the anchors will be + put on. + + Return: + list(torch.Tensor): Valid flags of anchors in multiple levels. + """ + assert self.num_levels == len(featmap_sizes) + multi_level_flags = [] + for i in range(self.num_levels): + anchor_stride = self.strides[i] + feat_h, feat_w = featmap_sizes[i] + h, w = pad_shape[:2] + valid_feat_h = min(int(np.ceil(h / anchor_stride[1])), feat_h) + valid_feat_w = min(int(np.ceil(w / anchor_stride[0])), feat_w) + flags = self.single_level_valid_flags((feat_h, feat_w), + (valid_feat_h, valid_feat_w), + self.num_base_anchors[i], + device=device) + multi_level_flags.append(flags) + return multi_level_flags + + def single_level_valid_flags(self, + featmap_size: Tuple[int, int], + valid_size: Tuple[int, int], + num_base_anchors: int, + device: DeviceType = 'cuda') -> Tensor: + """Generate the valid flags of anchor in a single feature map. 
+ + Args: + featmap_size (tuple[int]): The size of feature maps, arrange + as (h, w). + valid_size (tuple[int]): The valid size of the feature maps. + num_base_anchors (int): The number of base anchors. + device (str | torch.device): Device where the flags will be put on. + Defaults to 'cuda'. + + Returns: + torch.Tensor: The valid flags of each anchor in a single level \ + feature map. + """ + feat_h, feat_w = featmap_size + valid_h, valid_w = valid_size + assert valid_h <= feat_h and valid_w <= feat_w + valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device) + valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device) + valid_x[:valid_w] = 1 + valid_y[:valid_h] = 1 + valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) + valid = valid_xx & valid_yy + valid = valid[:, None].expand(valid.size(0), + num_base_anchors).contiguous().view(-1) + return valid + + def __repr__(self) -> str: + """str: a string that describes the module""" + indent_str = ' ' + repr_str = self.__class__.__name__ + '(\n' + repr_str += f'{indent_str}strides={self.strides},\n' + repr_str += f'{indent_str}ratios={self.ratios},\n' + repr_str += f'{indent_str}scales={self.scales},\n' + repr_str += f'{indent_str}base_sizes={self.base_sizes},\n' + repr_str += f'{indent_str}scale_major={self.scale_major},\n' + repr_str += f'{indent_str}octave_base_scale=' + repr_str += f'{self.octave_base_scale},\n' + repr_str += f'{indent_str}scales_per_octave=' + repr_str += f'{self.scales_per_octave},\n' + repr_str += f'{indent_str}num_levels={self.num_levels}\n' + repr_str += f'{indent_str}centers={self.centers},\n' + repr_str += f'{indent_str}center_offset={self.center_offset})' + return repr_str + + +@TASK_UTILS.register_module() +class SSDAnchorGenerator(AnchorGenerator): + """Anchor generator for SSD. + + Args: + strides (list[int] | list[tuple[int, int]]): Strides of anchors + in multiple feature levels. + ratios (list[float]): The list of ratios between the height and width + of anchors in a single level. + min_sizes (list[float]): The list of minimum anchor sizes on each + level. + max_sizes (list[float]): The list of maximum anchor sizes on each + level. + basesize_ratio_range (tuple(float)): Ratio range of anchors. Being + used when not setting min_sizes and max_sizes. + input_size (int): Size of feature map, 300 for SSD300, 512 for + SSD512. Being used when not setting min_sizes and max_sizes. + scale_major (bool): Whether to multiply scales first when generating + base anchors. If true, the anchors in the same row will have the + same scales. It is always set to be False in SSD. + use_box_type (bool): Whether to warp anchors with the box type data + structure. Defaults to False. + """ + + def __init__(self, + strides: Union[List[int], List[Tuple[int, int]]], + ratios: List[float], + min_sizes: Optional[List[float]] = None, + max_sizes: Optional[List[float]] = None, + basesize_ratio_range: Tuple[float] = (0.15, 0.9), + input_size: int = 300, + scale_major: bool = True, + use_box_type: bool = False) -> None: + assert len(strides) == len(ratios) + assert not (min_sizes is None) ^ (max_sizes is None) + self.strides = [_pair(stride) for stride in strides] + self.centers = [(stride[0] / 2., stride[1] / 2.) 
+ for stride in self.strides] + + if min_sizes is None and max_sizes is None: + # use hard code to generate SSD anchors + self.input_size = input_size + assert is_tuple_of(basesize_ratio_range, float) + self.basesize_ratio_range = basesize_ratio_range + # calculate anchor ratios and sizes + min_ratio, max_ratio = basesize_ratio_range + min_ratio = int(min_ratio * 100) + max_ratio = int(max_ratio * 100) + step = int(np.floor(max_ratio - min_ratio) / (self.num_levels - 2)) + min_sizes = [] + max_sizes = [] + for ratio in range(int(min_ratio), int(max_ratio) + 1, step): + min_sizes.append(int(self.input_size * ratio / 100)) + max_sizes.append(int(self.input_size * (ratio + step) / 100)) + if self.input_size == 300: + if basesize_ratio_range[0] == 0.15: # SSD300 COCO + min_sizes.insert(0, int(self.input_size * 7 / 100)) + max_sizes.insert(0, int(self.input_size * 15 / 100)) + elif basesize_ratio_range[0] == 0.2: # SSD300 VOC + min_sizes.insert(0, int(self.input_size * 10 / 100)) + max_sizes.insert(0, int(self.input_size * 20 / 100)) + else: + raise ValueError( + 'basesize_ratio_range[0] should be either 0.15' + 'or 0.2 when input_size is 300, got ' + f'{basesize_ratio_range[0]}.') + elif self.input_size == 512: + if basesize_ratio_range[0] == 0.1: # SSD512 COCO + min_sizes.insert(0, int(self.input_size * 4 / 100)) + max_sizes.insert(0, int(self.input_size * 10 / 100)) + elif basesize_ratio_range[0] == 0.15: # SSD512 VOC + min_sizes.insert(0, int(self.input_size * 7 / 100)) + max_sizes.insert(0, int(self.input_size * 15 / 100)) + else: + raise ValueError( + 'When not setting min_sizes and max_sizes,' + 'basesize_ratio_range[0] should be either 0.1' + 'or 0.15 when input_size is 512, got' + f' {basesize_ratio_range[0]}.') + else: + raise ValueError( + 'Only support 300 or 512 in SSDAnchorGenerator when ' + 'not setting min_sizes and max_sizes, ' + f'got {self.input_size}.') + + assert len(min_sizes) == len(max_sizes) == len(strides) + + anchor_ratios = [] + anchor_scales = [] + for k in range(len(self.strides)): + scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])] + anchor_ratio = [1.] + for r in ratios[k]: + anchor_ratio += [1 / r, r] # 4 or 6 ratio + anchor_ratios.append(torch.Tensor(anchor_ratio)) + anchor_scales.append(torch.Tensor(scales)) + + self.base_sizes = min_sizes + self.scales = anchor_scales + self.ratios = anchor_ratios + self.scale_major = scale_major + self.center_offset = 0 + self.base_anchors = self.gen_base_anchors() + self.use_box_type = use_box_type + + def gen_base_anchors(self) -> List[Tensor]: + """Generate base anchors. + + Returns: + list(torch.Tensor): Base anchors of a feature grid in multiple \ + feature levels. 
+ """ + multi_level_base_anchors = [] + for i, base_size in enumerate(self.base_sizes): + base_anchors = self.gen_single_level_base_anchors( + base_size, + scales=self.scales[i], + ratios=self.ratios[i], + center=self.centers[i]) + indices = list(range(len(self.ratios[i]))) + indices.insert(1, len(indices)) + base_anchors = torch.index_select(base_anchors, 0, + torch.LongTensor(indices)) + multi_level_base_anchors.append(base_anchors) + return multi_level_base_anchors + + def __repr__(self) -> str: + """str: a string that describes the module""" + indent_str = ' ' + repr_str = self.__class__.__name__ + '(\n' + repr_str += f'{indent_str}strides={self.strides},\n' + repr_str += f'{indent_str}scales={self.scales},\n' + repr_str += f'{indent_str}scale_major={self.scale_major},\n' + repr_str += f'{indent_str}input_size={self.input_size},\n' + repr_str += f'{indent_str}scales={self.scales},\n' + repr_str += f'{indent_str}ratios={self.ratios},\n' + repr_str += f'{indent_str}num_levels={self.num_levels},\n' + repr_str += f'{indent_str}base_sizes={self.base_sizes},\n' + repr_str += f'{indent_str}basesize_ratio_range=' + repr_str += f'{self.basesize_ratio_range})' + return repr_str + + +@TASK_UTILS.register_module() +class LegacyAnchorGenerator(AnchorGenerator): + """Legacy anchor generator used in MMDetection V1.x. + + Note: + Difference to the V2.0 anchor generator: + + 1. The center offset of V1.x anchors are set to be 0.5 rather than 0. + 2. The width/height are minused by 1 when calculating the anchors' \ + centers and corners to meet the V1.x coordinate system. + 3. The anchors' corners are quantized. + + Args: + strides (list[int] | list[tuple[int]]): Strides of anchors + in multiple feature levels. + ratios (list[float]): The list of ratios between the height and width + of anchors in a single level. + scales (list[int] | None): Anchor scales for anchors in a single level. + It cannot be set at the same time if `octave_base_scale` and + `scales_per_octave` are set. + base_sizes (list[int]): The basic sizes of anchors in multiple levels. + If None is given, strides will be used to generate base_sizes. + scale_major (bool): Whether to multiply scales first when generating + base anchors. If true, the anchors in the same row will have the + same scales. By default it is True in V2.0 + octave_base_scale (int): The base scale of octave. + scales_per_octave (int): Number of scales for each octave. + `octave_base_scale` and `scales_per_octave` are usually used in + retinanet and the `scales` should be None when they are set. + centers (list[tuple[float, float]] | None): The centers of the anchor + relative to the feature grid center in multiple feature levels. + By default it is set to be None and not used. It a list of float + is given, this list will be used to shift the centers of anchors. + center_offset (float): The offset of center in proportion to anchors' + width and height. By default it is 0.5 in V2.0 but it should be 0.5 + in v1.x models. + use_box_type (bool): Whether to warp anchors with the box type data + structure. Defaults to False. + + Examples: + >>> from mmdet.models.task_modules. + ... 
prior_generators import LegacyAnchorGenerator + >>> self = LegacyAnchorGenerator( + >>> [16], [1.], [1.], [9], center_offset=0.5) + >>> all_anchors = self.grid_anchors(((2, 2),), device='cpu') + >>> print(all_anchors) + [tensor([[ 0., 0., 8., 8.], + [16., 0., 24., 8.], + [ 0., 16., 8., 24.], + [16., 16., 24., 24.]])] + """ + + def gen_single_level_base_anchors(self, + base_size: Union[int, float], + scales: Tensor, + ratios: Tensor, + center: Optional[Tuple[float]] = None) \ + -> Tensor: + """Generate base anchors of a single level. + + Note: + The width/height of anchors are minused by 1 when calculating \ + the centers and corners to meet the V1.x coordinate system. + + Args: + base_size (int | float): Basic size of an anchor. + scales (torch.Tensor): Scales of the anchor. + ratios (torch.Tensor): The ratio between the height. + and width of anchors in a single level. + center (tuple[float], optional): The center of the base anchor + related to a single feature grid. Defaults to None. + + Returns: + torch.Tensor: Anchors in a single-level feature map. + """ + w = base_size + h = base_size + if center is None: + x_center = self.center_offset * (w - 1) + y_center = self.center_offset * (h - 1) + else: + x_center, y_center = center + + h_ratios = torch.sqrt(ratios) + w_ratios = 1 / h_ratios + if self.scale_major: + ws = (w * w_ratios[:, None] * scales[None, :]).view(-1) + hs = (h * h_ratios[:, None] * scales[None, :]).view(-1) + else: + ws = (w * scales[:, None] * w_ratios[None, :]).view(-1) + hs = (h * scales[:, None] * h_ratios[None, :]).view(-1) + + # use float anchor and the anchor's center is aligned with the + # pixel center + base_anchors = [ + x_center - 0.5 * (ws - 1), y_center - 0.5 * (hs - 1), + x_center + 0.5 * (ws - 1), y_center + 0.5 * (hs - 1) + ] + base_anchors = torch.stack(base_anchors, dim=-1).round() + + return base_anchors + + +@TASK_UTILS.register_module() +class LegacySSDAnchorGenerator(SSDAnchorGenerator, LegacyAnchorGenerator): + """Legacy anchor generator used in MMDetection V1.x. + + The difference between `LegacySSDAnchorGenerator` and `SSDAnchorGenerator` + can be found in `LegacyAnchorGenerator`. + """ + + def __init__(self, + strides: Union[List[int], List[Tuple[int, int]]], + ratios: List[float], + basesize_ratio_range: Tuple[float], + input_size: int = 300, + scale_major: bool = True, + use_box_type: bool = False) -> None: + super(LegacySSDAnchorGenerator, self).__init__( + strides=strides, + ratios=ratios, + basesize_ratio_range=basesize_ratio_range, + input_size=input_size, + scale_major=scale_major, + use_box_type=use_box_type) + self.centers = [((stride - 1) / 2., (stride - 1) / 2.) + for stride in strides] + self.base_anchors = self.gen_base_anchors() + + +@TASK_UTILS.register_module() +class YOLOAnchorGenerator(AnchorGenerator): + """Anchor generator for YOLO. + + Args: + strides (list[int] | list[tuple[int, int]]): Strides of anchors + in multiple feature levels. + base_sizes (list[list[tuple[int, int]]]): The basic sizes + of anchors in multiple levels. + """ + + def __init__(self, + strides: Union[List[int], List[Tuple[int, int]]], + base_sizes: List[List[Tuple[int, int]]], + use_box_type: bool = False) -> None: + self.strides = [_pair(stride) for stride in strides] + self.centers = [(stride[0] / 2., stride[1] / 2.) 
+ for stride in self.strides] + self.base_sizes = [] + num_anchor_per_level = len(base_sizes[0]) + for base_sizes_per_level in base_sizes: + assert num_anchor_per_level == len(base_sizes_per_level) + self.base_sizes.append( + [_pair(base_size) for base_size in base_sizes_per_level]) + self.base_anchors = self.gen_base_anchors() + self.use_box_type = use_box_type + + @property + def num_levels(self) -> int: + """int: number of feature levels that the generator will be applied""" + return len(self.base_sizes) + + def gen_base_anchors(self) -> List[Tensor]: + """Generate base anchors. + + Returns: + list(torch.Tensor): Base anchors of a feature grid in multiple \ + feature levels. + """ + multi_level_base_anchors = [] + for i, base_sizes_per_level in enumerate(self.base_sizes): + center = None + if self.centers is not None: + center = self.centers[i] + multi_level_base_anchors.append( + self.gen_single_level_base_anchors(base_sizes_per_level, + center)) + return multi_level_base_anchors + + def gen_single_level_base_anchors(self, + base_sizes_per_level: List[Tuple[int]], + center: Optional[Tuple[float]] = None) \ + -> Tensor: + """Generate base anchors of a single level. + + Args: + base_sizes_per_level (list[tuple[int]]): Basic sizes of + anchors. + center (tuple[float], optional): The center of the base anchor + related to a single feature grid. Defaults to None. + + Returns: + torch.Tensor: Anchors in a single-level feature maps. + """ + x_center, y_center = center + base_anchors = [] + for base_size in base_sizes_per_level: + w, h = base_size + + # use float anchor and the anchor's center is aligned with the + # pixel center + base_anchor = torch.Tensor([ + x_center - 0.5 * w, y_center - 0.5 * h, x_center + 0.5 * w, + y_center + 0.5 * h + ]) + base_anchors.append(base_anchor) + base_anchors = torch.stack(base_anchors, dim=0) + + return base_anchors diff --git a/mmdet/models/task_modules/prior_generators/point_generator.py b/mmdet/models/task_modules/prior_generators/point_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..c87ad656c61cb251bfdfcbd23b1cc5263c68bf5f --- /dev/null +++ b/mmdet/models/task_modules/prior_generators/point_generator.py @@ -0,0 +1,321 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple, Union + +import numpy as np +import torch +from torch import Tensor +from torch.nn.modules.utils import _pair + +from mmdet.registry import TASK_UTILS + +DeviceType = Union[str, torch.device] + + +@TASK_UTILS.register_module() +class PointGenerator: + + def _meshgrid(self, + x: Tensor, + y: Tensor, + row_major: bool = True) -> Tuple[Tensor, Tensor]: + """Generate mesh grid of x and y. + + Args: + x (torch.Tensor): Grids of x dimension. + y (torch.Tensor): Grids of y dimension. + row_major (bool): Whether to return y grids first. + Defaults to True. + + Returns: + tuple[torch.Tensor]: The mesh grids of x and y. + """ + xx = x.repeat(len(y)) + yy = y.view(-1, 1).repeat(1, len(x)).view(-1) + if row_major: + return xx, yy + else: + return yy, xx + + def grid_points(self, + featmap_size: Tuple[int, int], + stride=16, + device: DeviceType = 'cuda') -> Tensor: + """Generate grid points of a single level. + + Args: + featmap_size (tuple[int, int]): Size of the feature maps. + stride (int): The stride of corresponding feature map. + device (str | torch.device): The device the tensor will be put on. + Defaults to 'cuda'. + + Returns: + torch.Tensor: grid point in a feature map. 
+        """
+        feat_h, feat_w = featmap_size
+        shift_x = torch.arange(0., feat_w, device=device) * stride
+        shift_y = torch.arange(0., feat_h, device=device) * stride
+        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
+        stride = shift_x.new_full((shift_xx.shape[0], ), stride)
+        shifts = torch.stack([shift_xx, shift_yy, stride], dim=-1)
+        all_points = shifts.to(device)
+        return all_points
+
+    def valid_flags(self,
+                    featmap_size: Tuple[int, int],
+                    valid_size: Tuple[int, int],
+                    device: DeviceType = 'cuda') -> Tensor:
+        """Generate valid flags of anchors in a feature map.
+
+        Args:
+            featmap_size (tuple[int, int]): Size of the feature map.
+            valid_size (tuple[int, int]): The valid size of the feature map.
+            device (str | torch.device): Device where the anchors will be
+                put on.
+
+        Returns:
+            torch.Tensor: Valid flags of anchors in a level.
+        """
+        feat_h, feat_w = featmap_size
+        valid_h, valid_w = valid_size
+        assert valid_h <= feat_h and valid_w <= feat_w
+        valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
+        valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
+        valid_x[:valid_w] = 1
+        valid_y[:valid_h] = 1
+        valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
+        valid = valid_xx & valid_yy
+        return valid
+
+
+@TASK_UTILS.register_module()
+class MlvlPointGenerator:
+    """Standard points generator for multi-level (Mlvl) feature maps in 2D
+    points-based detectors.
+
+    Args:
+        strides (list[int] | list[tuple[int, int]]): Strides of anchors
+            in multiple feature levels in order (w, h).
+        offset (float): The offset of points, the value is normalized with
+            corresponding stride. Defaults to 0.5.
+    """
+
+    def __init__(self,
+                 strides: Union[List[int], List[Tuple[int, int]]],
+                 offset: float = 0.5) -> None:
+        self.strides = [_pair(stride) for stride in strides]
+        self.offset = offset
+
+    @property
+    def num_levels(self) -> int:
+        """int: number of feature levels that the generator will be
+        applied to"""
+        return len(self.strides)
+
+    @property
+    def num_base_priors(self) -> List[int]:
+        """list[int]: The number of priors (points) at a point
+        on the feature grid"""
+        return [1 for _ in range(len(self.strides))]
+
+    def _meshgrid(self,
+                  x: Tensor,
+                  y: Tensor,
+                  row_major: bool = True) -> Tuple[Tensor, Tensor]:
+        yy, xx = torch.meshgrid(y, x)
+        if row_major:
+            # warning .flatten() would cause error in ONNX exporting
+            # have to use reshape here
+            return xx.reshape(-1), yy.reshape(-1)
+
+        else:
+            return yy.reshape(-1), xx.reshape(-1)
+
+    def grid_priors(self,
+                    featmap_sizes: List[Tuple],
+                    dtype: torch.dtype = torch.float32,
+                    device: DeviceType = 'cuda',
+                    with_stride: bool = False) -> List[Tensor]:
+        """Generate grid points of multiple feature levels.
+
+        Args:
+            featmap_sizes (list[tuple]): List of feature map sizes in
+                multiple feature levels, each size arranged as (h, w).
+            dtype (:obj:`dtype`): Dtype of priors. Defaults to torch.float32.
+            device (str | torch.device): The device where the anchors will be
+                put on.
+            with_stride (bool): Whether to concatenate the stride to
+                the last dimension of points.
+
+        Returns:
+            list[torch.Tensor]: Points of multiple feature levels.
+            The sizes of each tensor should be (N, 2) when ``with_stride`` is
+            ``False``, where N = width * height, width and height
+            are the sizes of the corresponding feature level,
+            and the last dimension 2 represents (coord_x, coord_y);
+            otherwise the shape should be (N, 4),
+            and the last dimension 4 represents
+            (coord_x, coord_y, stride_w, stride_h).
+        """
+
+        assert self.num_levels == len(featmap_sizes)
+        multi_level_priors = []
+        for i in range(self.num_levels):
+            priors = self.single_level_grid_priors(
+                featmap_sizes[i],
+                level_idx=i,
+                dtype=dtype,
+                device=device,
+                with_stride=with_stride)
+            multi_level_priors.append(priors)
+        return multi_level_priors
+
+    def single_level_grid_priors(self,
+                                 featmap_size: Tuple[int],
+                                 level_idx: int,
+                                 dtype: torch.dtype = torch.float32,
+                                 device: DeviceType = 'cuda',
+                                 with_stride: bool = False) -> Tensor:
+        """Generate grid Points of a single level.
+
+        Note:
+            This function is usually called by method ``self.grid_priors``.
+
+        Args:
+            featmap_size (tuple[int]): Size of the feature maps, arranged as
+                (h, w).
+            level_idx (int): The index of corresponding feature map level.
+            dtype (:obj:`dtype`): Dtype of priors. Defaults to torch.float32.
+            device (str | torch.device): The device the tensor will be put on.
+                Defaults to 'cuda'.
+            with_stride (bool): Concatenate the stride to the last dimension
+                of points.
+
+        Returns:
+            Tensor: Points of a single feature level.
+            The shape of the tensor should be (N, 2) when ``with_stride`` is
+            ``False``, where N = width * height, width and height
+            are the sizes of the corresponding feature level,
+            and the last dimension 2 represents (coord_x, coord_y);
+            otherwise the shape should be (N, 4),
+            and the last dimension 4 represents
+            (coord_x, coord_y, stride_w, stride_h).
+        """
+        feat_h, feat_w = featmap_size
+        stride_w, stride_h = self.strides[level_idx]
+        shift_x = (torch.arange(0, feat_w, device=device) +
+                   self.offset) * stride_w
+        # keep featmap_size as Tensor instead of int, so that we
+        # can convert to ONNX correctly
+        shift_x = shift_x.to(dtype)
+
+        shift_y = (torch.arange(0, feat_h, device=device) +
+                   self.offset) * stride_h
+        # keep featmap_size as Tensor instead of int, so that we
+        # can convert to ONNX correctly
+        shift_y = shift_y.to(dtype)
+        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
+        if not with_stride:
+            shifts = torch.stack([shift_xx, shift_yy], dim=-1)
+        else:
+            # use `shape[0]` instead of `len(shift_xx)` for ONNX export
+            stride_w = shift_xx.new_full((shift_xx.shape[0], ),
+                                         stride_w).to(dtype)
+            stride_h = shift_xx.new_full((shift_yy.shape[0], ),
+                                         stride_h).to(dtype)
+            shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h],
+                                 dim=-1)
+        all_points = shifts.to(device)
+        return all_points
+
+    def valid_flags(self,
+                    featmap_sizes: List[Tuple[int, int]],
+                    pad_shape: Tuple[int],
+                    device: DeviceType = 'cuda') -> List[Tensor]:
+        """Generate valid flags of points of multiple feature levels.
+
+        Args:
+            featmap_sizes (list(tuple)): List of feature map sizes in
+                multiple feature levels, each size arranged as (h, w).
+            pad_shape (tuple(int)): The padded shape of the image,
+                arranged as (h, w).
+            device (str | torch.device): The device where the anchors will be
+                put on.
+
+        Returns:
+            list(torch.Tensor): Valid flags of points of multiple levels.
+        """
+        assert self.num_levels == len(featmap_sizes)
+        multi_level_flags = []
+        for i in range(self.num_levels):
+            point_stride = self.strides[i]
+            feat_h, feat_w = featmap_sizes[i]
+            h, w = pad_shape[:2]
+            valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h)
+            valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w)
+            flags = self.single_level_valid_flags((feat_h, feat_w),
+                                                  (valid_feat_h, valid_feat_w),
+                                                  device=device)
+            multi_level_flags.append(flags)
+        return multi_level_flags
+
+    def single_level_valid_flags(self,
+                                 featmap_size: Tuple[int, int],
+                                 valid_size: Tuple[int, int],
+                                 device: DeviceType = 'cuda') -> Tensor:
+        """Generate the valid flags of points of a single feature map.
+
+        Args:
+            featmap_size (tuple[int]): The size of feature maps, arranged
+                as (h, w).
+            valid_size (tuple[int]): The valid size of the feature maps,
+                arranged as (h, w).
+            device (str | torch.device): The device where the flags will be
+                put on. Defaults to 'cuda'.
+
+        Returns:
+            torch.Tensor: The valid flags of each points in a single level \
+                feature map.
+        """
+        feat_h, feat_w = featmap_size
+        valid_h, valid_w = valid_size
+        assert valid_h <= feat_h and valid_w <= feat_w
+        valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
+        valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
+        valid_x[:valid_w] = 1
+        valid_y[:valid_h] = 1
+        valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
+        valid = valid_xx & valid_yy
+        return valid
+
+    def sparse_priors(self,
+                      prior_idxs: Tensor,
+                      featmap_size: Tuple[int],
+                      level_idx: int,
+                      dtype: torch.dtype = torch.float32,
+                      device: DeviceType = 'cuda') -> Tensor:
+        """Generate sparse points according to the ``prior_idxs``.
+
+        Args:
+            prior_idxs (Tensor): The index of corresponding anchors
+                in the feature map.
+            featmap_size (tuple[int]): Feature map size arranged as (h, w).
+            level_idx (int): The level index of corresponding feature
+                map.
+            dtype (obj:`torch.dtype`): Data type of points. Defaults to
+                ``torch.float32``.
+            device (str | torch.device): The device where the points are
+                located.
+        Returns:
+            Tensor: Anchor with shape (N, 2), N should be equal to
+            the length of ``prior_idxs``. And last dimension
+            2 represent (coord_x, coord_y).
+        """
+        height, width = featmap_size
+        x = (prior_idxs % width + self.offset) * self.strides[level_idx][0]
+        y = ((prior_idxs // width) % height +
+             self.offset) * self.strides[level_idx][1]
+        priors = torch.stack([x, y], 1).to(dtype)
+        priors = priors.to(device)
+        return priors
diff --git a/mmdet/models/task_modules/prior_generators/utils.py b/mmdet/models/task_modules/prior_generators/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3aa2dfd49669ba931d20ad9482cb841698cceb8a
--- /dev/null
+++ b/mmdet/models/task_modules/prior_generators/utils.py
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor
+
+from mmdet.structures.bbox import BaseBoxes
+
+
+def anchor_inside_flags(flat_anchors: Tensor,
+                        valid_flags: Tensor,
+                        img_shape: Tuple[int],
+                        allowed_border: int = 0) -> Tensor:
+    """Check whether the anchors are inside the border.
+
+    Args:
+        flat_anchors (torch.Tensor): Flatten anchors, shape (n, 4).
+        valid_flags (torch.Tensor): An existing valid flags of anchors.
+        img_shape (tuple(int)): Shape of current image.
+        allowed_border (int): The border to allow the valid anchor.
+            Defaults to 0.
+
+    Returns:
+        torch.Tensor: Flags indicating whether the anchors are inside a \
+            valid range.
+    """
+    img_h, img_w = img_shape[:2]
+    if allowed_border >= 0:
+        if isinstance(flat_anchors, BaseBoxes):
+            inside_flags = valid_flags & \
+                flat_anchors.is_inside([img_h, img_w],
+                                       all_inside=True,
+                                       allowed_border=allowed_border)
+        else:
+            inside_flags = valid_flags & \
+                (flat_anchors[:, 0] >= -allowed_border) & \
+                (flat_anchors[:, 1] >= -allowed_border) & \
+                (flat_anchors[:, 2] < img_w + allowed_border) & \
+                (flat_anchors[:, 3] < img_h + allowed_border)
+    else:
+        inside_flags = valid_flags
+    return inside_flags
+
+
+def calc_region(bbox: Tensor,
+                ratio: float,
+                featmap_size: Optional[Tuple] = None) -> Tuple[int]:
+    """Calculate a proportional bbox region.
+
+    The bbox center is fixed while the new h' and w' are h * ratio and
+    w * ratio respectively.
+
+    Args:
+        bbox (Tensor): Bboxes to calculate regions, shape (n, 4).
+        ratio (float): Ratio of the output region.
+        featmap_size (tuple, Optional): Feature map size in (height, width)
+            order used for clipping the boundary. Defaults to None.
+
+    Returns:
+        tuple: x1, y1, x2, y2
+    """
+    x1 = torch.round((1 - ratio) * bbox[0] + ratio * bbox[2]).long()
+    y1 = torch.round((1 - ratio) * bbox[1] + ratio * bbox[3]).long()
+    x2 = torch.round(ratio * bbox[0] + (1 - ratio) * bbox[2]).long()
+    y2 = torch.round(ratio * bbox[1] + (1 - ratio) * bbox[3]).long()
+    if featmap_size is not None:
+        x1 = x1.clamp(min=0, max=featmap_size[1])
+        y1 = y1.clamp(min=0, max=featmap_size[0])
+        x2 = x2.clamp(min=0, max=featmap_size[1])
+        y2 = y2.clamp(min=0, max=featmap_size[0])
+    return (x1, y1, x2, y2)
diff --git a/mmdet/models/task_modules/samplers/__init__.py b/mmdet/models/task_modules/samplers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3782eb898cf8acace63b4f16204cae6c07eb6e30
--- /dev/null
+++ b/mmdet/models/task_modules/samplers/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_sampler import BaseSampler
+from .combined_sampler import CombinedSampler
+from .instance_balanced_pos_sampler import InstanceBalancedPosSampler
+from .iou_balanced_neg_sampler import IoUBalancedNegSampler
+from .mask_pseudo_sampler import MaskPseudoSampler
+from .mask_sampling_result import MaskSamplingResult
+from .multi_instance_random_sampler import MultiInsRandomSampler
+from .multi_instance_sampling_result import MultiInstanceSamplingResult
+from .ohem_sampler import OHEMSampler
+from .pseudo_sampler import PseudoSampler
+from .random_sampler import RandomSampler
+from .sampling_result import SamplingResult
+from .score_hlr_sampler import ScoreHLRSampler
+
+__all__ = [
+    'BaseSampler', 'PseudoSampler', 'RandomSampler',
+    'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
+    'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler', 'MaskPseudoSampler',
+    'MaskSamplingResult', 'MultiInstanceSamplingResult',
+    'MultiInsRandomSampler'
+]
diff --git a/mmdet/models/task_modules/samplers/base_sampler.py b/mmdet/models/task_modules/samplers/base_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..be8a9a5ee3ec4e70b19aeea21b7998cf2b131d59
--- /dev/null
+++ b/mmdet/models/task_modules/samplers/base_sampler.py
@@ -0,0 +1,136 @@
+# Copyright (c) OpenMMLab. All rights reserved.
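# A minimal usage sketch for ``calc_region`` (defined above in
# prior_generators/utils.py); the box and ratio values here are illustrative
# assumptions, with no featmap clipping:
#
#     >>> import torch
#     >>> from mmdet.models.task_modules.prior_generators.utils import (
#     ...     calc_region)
#     >>> bbox = torch.tensor([0., 0., 100., 100.])
#     >>> calc_region(bbox, ratio=0.25)  # keeps the central half of each side
#     (tensor(25), tensor(25), tensor(75), tensor(75))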
+from abc import ABCMeta, abstractmethod
+
+import torch
+from mmengine.structures import InstanceData
+
+from mmdet.structures.bbox import BaseBoxes, cat_boxes
+from ..assigners import AssignResult
+from .sampling_result import SamplingResult
+
+
+class BaseSampler(metaclass=ABCMeta):
+    """Base class of samplers.
+
+    Args:
+        num (int): Number of samples.
+        pos_fraction (float): Fraction of positive samples.
+        neg_pos_ub (int): Upper bound of the ratio of negative to positive
+            samples; -1 means no upper bound. Defaults to -1.
+        add_gt_as_proposals (bool): Whether to add ground truth
+            boxes as proposals. Defaults to True.
+    """
+
+    def __init__(self,
+                 num: int,
+                 pos_fraction: float,
+                 neg_pos_ub: int = -1,
+                 add_gt_as_proposals: bool = True,
+                 **kwargs) -> None:
+        self.num = num
+        self.pos_fraction = pos_fraction
+        self.neg_pos_ub = neg_pos_ub
+        self.add_gt_as_proposals = add_gt_as_proposals
+        self.pos_sampler = self
+        self.neg_sampler = self
+
+    @abstractmethod
+    def _sample_pos(self, assign_result: AssignResult, num_expected: int,
+                    **kwargs):
+        """Sample positive samples."""
+        pass
+
+    @abstractmethod
+    def _sample_neg(self, assign_result: AssignResult, num_expected: int,
+                    **kwargs):
+        """Sample negative samples."""
+        pass
+
+    def sample(self, assign_result: AssignResult, pred_instances: InstanceData,
+               gt_instances: InstanceData, **kwargs) -> SamplingResult:
+        """Sample positive and negative bboxes.
+
+        This is a simple implementation of bbox sampling given candidates,
+        assigning results and ground truth bboxes.
+
+        Args:
+            assign_result (:obj:`AssignResult`): Assigning results.
+            pred_instances (:obj:`InstanceData`): Instances of model
+                predictions. It includes ``priors``, and the priors can
+                be anchors or points, or the bboxes predicted by the
+                previous stage, has shape (n, 4). The bboxes predicted by
+                the current model or stage will be named ``bboxes``,
+                ``labels``, and ``scores``, the same as the ``InstanceData``
+                in other places.
+            gt_instances (:obj:`InstanceData`): Ground truth of instance
+                annotations. It usually includes ``bboxes``, with shape (k, 4),
+                and ``labels``, with shape (k, ).
+
+        Returns:
+            :obj:`SamplingResult`: Sampling result.
+
+        Example:
+            >>> import torch
+            >>> from mmengine.structures import InstanceData
+            >>> from mmdet.models.task_modules.samplers import RandomSampler
+            >>> from mmdet.models.task_modules.assigners import AssignResult
+            >>> from mmdet.models.task_modules.samplers.sampling_result import (
+            ...     ensure_rng, random_boxes)
+            >>> rng = ensure_rng(None)
+            >>> assign_result = AssignResult.random(rng=rng)
+            >>> pred_instances = InstanceData()
+            >>> pred_instances.priors = random_boxes(assign_result.num_preds,
+            ...                                      rng=rng)
+            >>> gt_instances = InstanceData()
+            >>> gt_instances.bboxes = random_boxes(assign_result.num_gts,
+            ...                                    rng=rng)
+            >>> gt_instances.labels = torch.randint(
+            ...     0, 5, (assign_result.num_gts,), dtype=torch.long)
+            >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
+            ...                      add_gt_as_proposals=False)
+            >>> self = self.sample(assign_result, pred_instances, gt_instances)
+        """
+        gt_bboxes = gt_instances.bboxes
+        priors = pred_instances.priors
+        gt_labels = gt_instances.labels
+        if len(priors.shape) < 2:
+            priors = priors[None, :]
+
+        gt_flags = priors.new_zeros((priors.shape[0], ), dtype=torch.uint8)
+        if self.add_gt_as_proposals and len(gt_bboxes) > 0:
+            # When `gt_bboxes` and `priors` are all box type, convert
+            # `gt_bboxes` type to `priors` type.
+ if (isinstance(gt_bboxes, BaseBoxes) + and isinstance(priors, BaseBoxes)): + gt_bboxes_ = gt_bboxes.convert_to(type(priors)) + else: + gt_bboxes_ = gt_bboxes + priors = cat_boxes([gt_bboxes_, priors], dim=0) + assign_result.add_gt_(gt_labels) + gt_ones = priors.new_ones(gt_bboxes_.shape[0], dtype=torch.uint8) + gt_flags = torch.cat([gt_ones, gt_flags]) + + num_expected_pos = int(self.num * self.pos_fraction) + pos_inds = self.pos_sampler._sample_pos( + assign_result, num_expected_pos, bboxes=priors, **kwargs) + # We found that sampled indices have duplicated items occasionally. + # (may be a bug of PyTorch) + pos_inds = pos_inds.unique() + num_sampled_pos = pos_inds.numel() + num_expected_neg = self.num - num_sampled_pos + if self.neg_pos_ub >= 0: + _pos = max(1, num_sampled_pos) + neg_upper_bound = int(self.neg_pos_ub * _pos) + if num_expected_neg > neg_upper_bound: + num_expected_neg = neg_upper_bound + neg_inds = self.neg_sampler._sample_neg( + assign_result, num_expected_neg, bboxes=priors, **kwargs) + neg_inds = neg_inds.unique() + + sampling_result = SamplingResult( + pos_inds=pos_inds, + neg_inds=neg_inds, + priors=priors, + gt_bboxes=gt_bboxes, + assign_result=assign_result, + gt_flags=gt_flags) + return sampling_result diff --git a/mmdet/models/task_modules/samplers/combined_sampler.py b/mmdet/models/task_modules/samplers/combined_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..8e0560e372efffe865fa32028d823280a8bd5d87 --- /dev/null +++ b/mmdet/models/task_modules/samplers/combined_sampler.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import TASK_UTILS +from .base_sampler import BaseSampler + + +@TASK_UTILS.register_module() +class CombinedSampler(BaseSampler): + """A sampler that combines positive sampler and negative sampler.""" + + def __init__(self, pos_sampler, neg_sampler, **kwargs): + super(CombinedSampler, self).__init__(**kwargs) + self.pos_sampler = TASK_UTILS.build(pos_sampler, default_args=kwargs) + self.neg_sampler = TASK_UTILS.build(neg_sampler, default_args=kwargs) + + def _sample_pos(self, **kwargs): + """Sample positive samples.""" + raise NotImplementedError + + def _sample_neg(self, **kwargs): + """Sample negative samples.""" + raise NotImplementedError diff --git a/mmdet/models/task_modules/samplers/instance_balanced_pos_sampler.py b/mmdet/models/task_modules/samplers/instance_balanced_pos_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..e48d8e9158e8dabf0bb4072b8e421de9b6410d00 --- /dev/null +++ b/mmdet/models/task_modules/samplers/instance_balanced_pos_sampler.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch + +from mmdet.registry import TASK_UTILS +from .random_sampler import RandomSampler + + +@TASK_UTILS.register_module() +class InstanceBalancedPosSampler(RandomSampler): + """Instance balanced sampler that samples equal number of positive samples + for each instance.""" + + def _sample_pos(self, assign_result, num_expected, **kwargs): + """Sample positive boxes. + + Args: + assign_result (:obj:`AssignResult`): The assigned results of boxes. + num_expected (int): The number of expected positive samples + + Returns: + Tensor or ndarray: sampled indices. 
+ """ + pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False) + if pos_inds.numel() != 0: + pos_inds = pos_inds.squeeze(1) + if pos_inds.numel() <= num_expected: + return pos_inds + else: + unique_gt_inds = assign_result.gt_inds[pos_inds].unique() + num_gts = len(unique_gt_inds) + num_per_gt = int(round(num_expected / float(num_gts)) + 1) + sampled_inds = [] + for i in unique_gt_inds: + inds = torch.nonzero( + assign_result.gt_inds == i.item(), as_tuple=False) + if inds.numel() != 0: + inds = inds.squeeze(1) + else: + continue + if len(inds) > num_per_gt: + inds = self.random_choice(inds, num_per_gt) + sampled_inds.append(inds) + sampled_inds = torch.cat(sampled_inds) + if len(sampled_inds) < num_expected: + num_extra = num_expected - len(sampled_inds) + extra_inds = np.array( + list(set(pos_inds.cpu()) - set(sampled_inds.cpu()))) + if len(extra_inds) > num_extra: + extra_inds = self.random_choice(extra_inds, num_extra) + extra_inds = torch.from_numpy(extra_inds).to( + assign_result.gt_inds.device).long() + sampled_inds = torch.cat([sampled_inds, extra_inds]) + elif len(sampled_inds) > num_expected: + sampled_inds = self.random_choice(sampled_inds, num_expected) + return sampled_inds diff --git a/mmdet/models/task_modules/samplers/iou_balanced_neg_sampler.py b/mmdet/models/task_modules/samplers/iou_balanced_neg_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..dc1f46413c99d115f31ef190b4fb198b588a156e --- /dev/null +++ b/mmdet/models/task_modules/samplers/iou_balanced_neg_sampler.py @@ -0,0 +1,158 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch + +from mmdet.registry import TASK_UTILS +from .random_sampler import RandomSampler + + +@TASK_UTILS.register_module() +class IoUBalancedNegSampler(RandomSampler): + """IoU Balanced Sampling. + + arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019) + + Sampling proposals according to their IoU. `floor_fraction` of needed RoIs + are sampled from proposals whose IoU are lower than `floor_thr` randomly. + The others are sampled from proposals whose IoU are higher than + `floor_thr`. These proposals are sampled from some bins evenly, which are + split by `num_bins` via IoU evenly. + + Args: + num (int): number of proposals. + pos_fraction (float): fraction of positive proposals. + floor_thr (float): threshold (minimum) IoU for IoU balanced sampling, + set to -1 if all using IoU balanced sampling. + floor_fraction (float): sampling fraction of proposals under floor_thr. + num_bins (int): number of bins in IoU balanced sampling. + """ + + def __init__(self, + num, + pos_fraction, + floor_thr=-1, + floor_fraction=0, + num_bins=3, + **kwargs): + super(IoUBalancedNegSampler, self).__init__(num, pos_fraction, + **kwargs) + assert floor_thr >= 0 or floor_thr == -1 + assert 0 <= floor_fraction <= 1 + assert num_bins >= 1 + + self.floor_thr = floor_thr + self.floor_fraction = floor_fraction + self.num_bins = num_bins + + def sample_via_interval(self, max_overlaps, full_set, num_expected): + """Sample according to the iou interval. + + Args: + max_overlaps (torch.Tensor): IoU between bounding boxes and ground + truth boxes. 
+            full_set (set(int)): A full set of indices of boxes.
+            num_expected (int): Number of expected samples.
+
+        Returns:
+            np.ndarray: Indices of samples.
+        """
+        max_iou = max_overlaps.max()
+        iou_interval = (max_iou - self.floor_thr) / self.num_bins
+        per_num_expected = int(num_expected / self.num_bins)
+
+        sampled_inds = []
+        for i in range(self.num_bins):
+            start_iou = self.floor_thr + i * iou_interval
+            end_iou = self.floor_thr + (i + 1) * iou_interval
+            tmp_set = set(
+                np.where(
+                    np.logical_and(max_overlaps >= start_iou,
+                                   max_overlaps < end_iou))[0])
+            tmp_inds = list(tmp_set & full_set)
+            if len(tmp_inds) > per_num_expected:
+                tmp_sampled_set = self.random_choice(tmp_inds,
+                                                     per_num_expected)
+            else:
+                tmp_sampled_set = np.array(tmp_inds, dtype=np.int64)
+            sampled_inds.append(tmp_sampled_set)
+
+        sampled_inds = np.concatenate(sampled_inds)
+        if len(sampled_inds) < num_expected:
+            num_extra = num_expected - len(sampled_inds)
+            extra_inds = np.array(list(full_set - set(sampled_inds)))
+            if len(extra_inds) > num_extra:
+                extra_inds = self.random_choice(extra_inds, num_extra)
+            sampled_inds = np.concatenate([sampled_inds, extra_inds])
+
+        return sampled_inds
+
+    def _sample_neg(self, assign_result, num_expected, **kwargs):
+        """Sample negative boxes.
+
+        Args:
+            assign_result (:obj:`AssignResult`): The assigned results of boxes.
+            num_expected (int): The number of expected negative samples
+
+        Returns:
+            Tensor or ndarray: sampled indices.
+        """
+        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
+        if neg_inds.numel() != 0:
+            neg_inds = neg_inds.squeeze(1)
+        if len(neg_inds) <= num_expected:
+            return neg_inds
+        else:
+            max_overlaps = assign_result.max_overlaps.cpu().numpy()
+            # balance sampling for negative samples
+            neg_set = set(neg_inds.cpu().numpy())
+
+            if self.floor_thr > 0:
+                floor_set = set(
+                    np.where(
+                        np.logical_and(max_overlaps >= 0,
+                                       max_overlaps < self.floor_thr))[0])
+                iou_sampling_set = set(
+                    np.where(max_overlaps >= self.floor_thr)[0])
+            elif self.floor_thr == 0:
+                floor_set = set(np.where(max_overlaps == 0)[0])
+                iou_sampling_set = set(
+                    np.where(max_overlaps > self.floor_thr)[0])
+            else:
+                floor_set = set()
+                iou_sampling_set = set(
+                    np.where(max_overlaps > self.floor_thr)[0])
+                # for sampling interval calculation
+                self.floor_thr = 0
+
+            floor_neg_inds = list(floor_set & neg_set)
+            iou_sampling_neg_inds = list(iou_sampling_set & neg_set)
+            num_expected_iou_sampling = int(num_expected *
+                                            (1 - self.floor_fraction))
+            if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
+                if self.num_bins >= 2:
+                    iou_sampled_inds = self.sample_via_interval(
+                        max_overlaps, set(iou_sampling_neg_inds),
+                        num_expected_iou_sampling)
+                else:
+                    iou_sampled_inds = self.random_choice(
+                        iou_sampling_neg_inds, num_expected_iou_sampling)
+            else:
+                iou_sampled_inds = np.array(
+                    iou_sampling_neg_inds, dtype=np.int64)
+            num_expected_floor = num_expected - len(iou_sampled_inds)
+            if len(floor_neg_inds) > num_expected_floor:
+                sampled_floor_inds = self.random_choice(
+                    floor_neg_inds, num_expected_floor)
+            else:
+                sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int64)
+            sampled_inds = np.concatenate(
+                (sampled_floor_inds, iou_sampled_inds))
+            if len(sampled_inds) < num_expected:
+                num_extra = num_expected - len(sampled_inds)
+                extra_inds = np.array(list(neg_set - set(sampled_inds)))
+                if len(extra_inds) > num_extra:
+                    extra_inds = self.random_choice(extra_inds, num_extra)
+                sampled_inds = np.concatenate((sampled_inds, extra_inds))
+            sampled_inds = torch.from_numpy(sampled_inds).long().to(
+                assign_result.gt_inds.device)
+            return sampled_inds
diff --git a/mmdet/models/task_modules/samplers/mask_pseudo_sampler.py b/mmdet/models/task_modules/samplers/mask_pseudo_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..307dd5d15c962b97dc60b899e60170d0bfed90a7
--- /dev/null
+++ b/mmdet/models/task_modules/samplers/mask_pseudo_sampler.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""copy from
+https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py."""
+
+import torch
+from mmengine.structures import InstanceData
+
+from mmdet.registry import TASK_UTILS
+from ..assigners import AssignResult
+from .base_sampler import BaseSampler
+from .mask_sampling_result import MaskSamplingResult
+
+
+@TASK_UTILS.register_module()
+class MaskPseudoSampler(BaseSampler):
+    """A pseudo sampler that does not perform actual sampling."""
+
+    def __init__(self, **kwargs):
+        pass
+
+    def _sample_pos(self, **kwargs):
+        """Sample positive samples."""
+        raise NotImplementedError
+
+    def _sample_neg(self, **kwargs):
+        """Sample negative samples."""
+        raise NotImplementedError
+
+    def sample(self, assign_result: AssignResult, pred_instances: InstanceData,
+               gt_instances: InstanceData, *args, **kwargs):
+        """Directly returns the positive and negative indices of samples.
+
+        Args:
+            assign_result (:obj:`AssignResult`): Mask assigning results.
+            pred_instances (:obj:`InstanceData`): Instances of model
+                predictions. It includes ``scores`` and ``masks`` predicted
+                by the model.
+            gt_instances (:obj:`InstanceData`): Ground truth of instance
+                annotations. It usually includes ``labels`` and ``masks``
+                attributes.
+
+        Returns:
+            :obj:`SamplingResult`: sampler results
+        """
+        pred_masks = pred_instances.masks
+        gt_masks = gt_instances.masks
+        pos_inds = torch.nonzero(
+            assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
+        neg_inds = torch.nonzero(
+            assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
+        gt_flags = pred_masks.new_zeros(pred_masks.shape[0], dtype=torch.uint8)
+        sampling_result = MaskSamplingResult(
+            pos_inds=pos_inds,
+            neg_inds=neg_inds,
+            masks=pred_masks,
+            gt_masks=gt_masks,
+            assign_result=assign_result,
+            gt_flags=gt_flags,
+            avg_factor_with_neg=False)
+        return sampling_result
diff --git a/mmdet/models/task_modules/samplers/mask_sampling_result.py b/mmdet/models/task_modules/samplers/mask_sampling_result.py
new file mode 100644
index 0000000000000000000000000000000000000000..adaa62e8a0af28bb004a34b961f672ec03988d2c
--- /dev/null
+++ b/mmdet/models/task_modules/samplers/mask_sampling_result.py
@@ -0,0 +1,68 @@
+# Copyright (c) OpenMMLab. All rights reserved.
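# A minimal sketch of driving ``MaskPseudoSampler`` (above); the mask shapes
# and assignment values are illustrative assumptions:
#
#     >>> import torch
#     >>> from mmengine.structures import InstanceData
#     >>> from mmdet.models.task_modules.assigners import AssignResult
#     >>> from mmdet.models.task_modules.samplers import MaskPseudoSampler
#     >>> # 4 predicted masks, 2 gt masks; gt_inds: 0 = negative, >0 = matched gt
#     >>> assign_result = AssignResult(
#     ...     num_gts=2, gt_inds=torch.tensor([1, 0, 2, 0]),
#     ...     max_overlaps=torch.rand(4), labels=torch.tensor([3, -1, 1, -1]))
#     >>> pred = InstanceData(masks=torch.rand(4, 32, 32))
#     >>> gt = InstanceData(masks=torch.rand(2, 32, 32) > 0.5)
#     >>> result = MaskPseudoSampler().sample(assign_result, pred, gt)
#     >>> result.pos_inds, result.neg_inds
#     (tensor([0, 2]), tensor([1, 3]))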
+"""copy from +https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py.""" + +import torch +from torch import Tensor + +from ..assigners import AssignResult +from .sampling_result import SamplingResult + + +class MaskSamplingResult(SamplingResult): + """Mask sampling result.""" + + def __init__(self, + pos_inds: Tensor, + neg_inds: Tensor, + masks: Tensor, + gt_masks: Tensor, + assign_result: AssignResult, + gt_flags: Tensor, + avg_factor_with_neg: bool = True) -> None: + self.pos_inds = pos_inds + self.neg_inds = neg_inds + self.num_pos = max(pos_inds.numel(), 1) + self.num_neg = max(neg_inds.numel(), 1) + self.avg_factor = self.num_pos + self.num_neg \ + if avg_factor_with_neg else self.num_pos + + self.pos_masks = masks[pos_inds] + self.neg_masks = masks[neg_inds] + self.pos_is_gt = gt_flags[pos_inds] + + self.num_gts = gt_masks.shape[0] + self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + if gt_masks.numel() == 0: + # hack for index error case + assert self.pos_assigned_gt_inds.numel() == 0 + self.pos_gt_masks = torch.empty_like(gt_masks) + else: + self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :] + + @property + def masks(self) -> Tensor: + """torch.Tensor: concatenated positive and negative masks.""" + return torch.cat([self.pos_masks, self.neg_masks]) + + def __nice__(self) -> str: + data = self.info.copy() + data['pos_masks'] = data.pop('pos_masks').shape + data['neg_masks'] = data.pop('neg_masks').shape + parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] + body = ' ' + ',\n '.join(parts) + return '{\n' + body + '\n}' + + @property + def info(self) -> dict: + """Returns a dictionary of info about the object.""" + return { + 'pos_inds': self.pos_inds, + 'neg_inds': self.neg_inds, + 'pos_masks': self.pos_masks, + 'neg_masks': self.neg_masks, + 'pos_is_gt': self.pos_is_gt, + 'num_gts': self.num_gts, + 'pos_assigned_gt_inds': self.pos_assigned_gt_inds, + } diff --git a/mmdet/models/task_modules/samplers/multi_instance_random_sampler.py b/mmdet/models/task_modules/samplers/multi_instance_random_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..8b74054e3a11ed6025e98e90bd0addb131a1dc02 --- /dev/null +++ b/mmdet/models/task_modules/samplers/multi_instance_random_sampler.py @@ -0,0 +1,130 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +import torch +from mmengine.structures import InstanceData +from numpy import ndarray +from torch import Tensor + +from mmdet.registry import TASK_UTILS +from ..assigners import AssignResult +from .multi_instance_sampling_result import MultiInstanceSamplingResult +from .random_sampler import RandomSampler + + +@TASK_UTILS.register_module() +class MultiInsRandomSampler(RandomSampler): + """Random sampler for multi instance. + + Note: + Multi-instance means to predict multiple detection boxes with + one proposal box. `AssignResult` may assign multiple gt boxes + to each proposal box, in this case `RandomSampler` should be + replaced by `MultiInsRandomSampler` + """ + + def _sample_pos(self, assign_result: AssignResult, num_expected: int, + **kwargs) -> Union[Tensor, ndarray]: + """Randomly sample some positive samples. + + Args: + assign_result (:obj:`AssignResult`): Bbox assigning results. + num_expected (int): The number of expected positive samples + + Returns: + Tensor or ndarray: sampled indices. 
+        """
+        pos_inds = torch.nonzero(
+            assign_result.labels[:, 0] > 0, as_tuple=False)
+        if pos_inds.numel() != 0:
+            pos_inds = pos_inds.squeeze(1)
+        if pos_inds.numel() <= num_expected:
+            return pos_inds
+        else:
+            return self.random_choice(pos_inds, num_expected)
+
+    def _sample_neg(self, assign_result: AssignResult, num_expected: int,
+                    **kwargs) -> Union[Tensor, ndarray]:
+        """Randomly sample some negative samples.
+
+        Args:
+            assign_result (:obj:`AssignResult`): Bbox assigning results.
+            num_expected (int): The number of expected negative samples
+
+        Returns:
+            Tensor or ndarray: sampled indices.
+        """
+        neg_inds = torch.nonzero(
+            assign_result.labels[:, 0] == 0, as_tuple=False)
+        if neg_inds.numel() != 0:
+            neg_inds = neg_inds.squeeze(1)
+        if len(neg_inds) <= num_expected:
+            return neg_inds
+        else:
+            return self.random_choice(neg_inds, num_expected)
+
+    def sample(self, assign_result: AssignResult, pred_instances: InstanceData,
+               gt_instances: InstanceData,
+               **kwargs) -> MultiInstanceSamplingResult:
+        """Sample positive and negative bboxes.
+
+        Args:
+            assign_result (:obj:`AssignResult`): Assigning results from
+                MultiInstanceAssigner.
+            pred_instances (:obj:`InstanceData`): Instances of model
+                predictions. It includes ``priors``, and the priors can
+                be anchors or points, or the bboxes predicted by the
+                previous stage, has shape (n, 4). The bboxes predicted by
+                the current model or stage will be named ``bboxes``,
+                ``labels``, and ``scores``, the same as the ``InstanceData``
+                in other places.
+            gt_instances (:obj:`InstanceData`): Ground truth of instance
+                annotations. It usually includes ``bboxes``, with shape (k, 4),
+                and ``labels``, with shape (k, ).
+
+        Returns:
+            :obj:`MultiInstanceSamplingResult`: Sampling result.
+        """
+
+        assert 'batch_gt_instances_ignore' in kwargs, \
+            'batch_gt_instances_ignore is necessary for MultiInsRandomSampler'
+
+        gt_bboxes = gt_instances.bboxes
+        ignore_bboxes = kwargs['batch_gt_instances_ignore'].bboxes
+        gt_and_ignore_bboxes = torch.cat([gt_bboxes, ignore_bboxes], dim=0)
+        priors = pred_instances.priors
+        if len(priors.shape) < 2:
+            priors = priors[None, :]
+        priors = priors[:, :4]
+
+        gt_flags = priors.new_zeros((priors.shape[0], ), dtype=torch.uint8)
+        priors = torch.cat([priors, gt_and_ignore_bboxes], dim=0)
+        gt_ones = priors.new_ones(
+            gt_and_ignore_bboxes.shape[0], dtype=torch.uint8)
+        gt_flags = torch.cat([gt_flags, gt_ones])
+
+        num_expected_pos = int(self.num * self.pos_fraction)
+        pos_inds = self.pos_sampler._sample_pos(assign_result,
+                                                num_expected_pos)
+        # We found that sampled indices have duplicated items occasionally.
+        # (may be a bug of PyTorch)
+        pos_inds = pos_inds.unique()
+        num_sampled_pos = pos_inds.numel()
+        num_expected_neg = self.num - num_sampled_pos
+        if self.neg_pos_ub >= 0:
+            _pos = max(1, num_sampled_pos)
+            neg_upper_bound = int(self.neg_pos_ub * _pos)
+            if num_expected_neg > neg_upper_bound:
+                num_expected_neg = neg_upper_bound
+        neg_inds = self.neg_sampler._sample_neg(assign_result,
+                                                num_expected_neg)
+        neg_inds = neg_inds.unique()
+
+        sampling_result = MultiInstanceSamplingResult(
+            pos_inds=pos_inds,
+            neg_inds=neg_inds,
+            priors=priors,
+            gt_and_ignore_bboxes=gt_and_ignore_bboxes,
+            assign_result=assign_result,
+            gt_flags=gt_flags)
+        return sampling_result
diff --git a/mmdet/models/task_modules/samplers/multi_instance_sampling_result.py b/mmdet/models/task_modules/samplers/multi_instance_sampling_result.py
new file mode 100644
index 0000000000000000000000000000000000000000..438a0aa91c0cc8904f6d8bba7139408dd99b98cf
--- /dev/null
+++ b/mmdet/models/task_modules/samplers/multi_instance_sampling_result.py
@@ -0,0 +1,56 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import Tensor
+
+from ..assigners import AssignResult
+from .sampling_result import SamplingResult
+
+
+class MultiInstanceSamplingResult(SamplingResult):
+    """Bbox sampling result. Further encapsulation of SamplingResult. Three
+    attributes neg_assigned_gt_inds, neg_gt_labels, and neg_gt_bboxes have been
+    added for SamplingResult.
+
+    Args:
+        pos_inds (Tensor): Indices of positive samples.
+        neg_inds (Tensor): Indices of negative samples.
+        priors (Tensor): The priors can be anchors or points,
+            or the bboxes predicted by the previous stage.
+        gt_and_ignore_bboxes (Tensor): Ground truth and ignore bboxes.
+        assign_result (:obj:`AssignResult`): Assigning results.
+        gt_flags (Tensor): The Ground truth flags.
+        avg_factor_with_neg (bool): If True, ``avg_factor`` equals
+            the number of total priors; otherwise, it is the number of
+            positive priors. Defaults to True.
+    """
+
+    def __init__(self,
+                 pos_inds: Tensor,
+                 neg_inds: Tensor,
+                 priors: Tensor,
+                 gt_and_ignore_bboxes: Tensor,
+                 assign_result: AssignResult,
+                 gt_flags: Tensor,
+                 avg_factor_with_neg: bool = True) -> None:
+        self.neg_assigned_gt_inds = assign_result.gt_inds[neg_inds]
+        self.neg_gt_labels = assign_result.labels[neg_inds]
+
+        if gt_and_ignore_bboxes.numel() == 0:
+            self.neg_gt_bboxes = torch.empty_like(gt_and_ignore_bboxes).view(
+                -1, 4)
+        else:
+            if len(gt_and_ignore_bboxes.shape) < 2:
+                gt_and_ignore_bboxes = gt_and_ignore_bboxes.view(-1, 4)
+            self.neg_gt_bboxes = gt_and_ignore_bboxes[
+                self.neg_assigned_gt_inds.long(), :]
+
+        # Compensate for the minus-1 operation in `SamplingResult.__init__()`.
+        assign_result.gt_inds += 1
+        super().__init__(
+            pos_inds=pos_inds,
+            neg_inds=neg_inds,
+            priors=priors,
+            gt_bboxes=gt_and_ignore_bboxes,
+            assign_result=assign_result,
+            gt_flags=gt_flags,
+            avg_factor_with_neg=avg_factor_with_neg)
diff --git a/mmdet/models/task_modules/samplers/ohem_sampler.py b/mmdet/models/task_modules/samplers/ohem_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..f478a448cde00d64caeba1d0ba613d2497a7fb12
--- /dev/null
+++ b/mmdet/models/task_modules/samplers/ohem_sampler.py
@@ -0,0 +1,111 @@
+# Copyright (c) OpenMMLab. All rights reserved.
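# The ``neg_pos_ub`` capping logic shared by the samplers above, shown as
# plain arithmetic (the numbers are illustrative assumptions):
#
#     >>> num, pos_fraction, neg_pos_ub = 256, 0.25, 3
#     >>> num_expected_pos = int(num * pos_fraction)   # 64
#     >>> num_sampled_pos = 10   # suppose only 10 positives were found
#     >>> num_expected_neg = num - num_sampled_pos     # 246
#     >>> neg_upper_bound = int(neg_pos_ub * max(1, num_sampled_pos))
#     >>> min(num_expected_neg, neg_upper_bound)       # negatives capped at 3x
#     30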
+import torch
+
+from mmdet.registry import TASK_UTILS
+from mmdet.structures.bbox import bbox2roi
+from .base_sampler import BaseSampler
+
+
+@TASK_UTILS.register_module()
+class OHEMSampler(BaseSampler):
+    r"""Online Hard Example Mining Sampler described in `Training Region-based
+    Object Detectors with Online Hard Example Mining
+    <https://arxiv.org/abs/1604.03540>`_.
+    """
+
+    def __init__(self,
+                 num,
+                 pos_fraction,
+                 context,
+                 neg_pos_ub=-1,
+                 add_gt_as_proposals=True,
+                 loss_key='loss_cls',
+                 **kwargs):
+        super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
+                                          add_gt_as_proposals)
+        self.context = context
+        if not hasattr(self.context, 'num_stages'):
+            self.bbox_head = self.context.bbox_head
+        else:
+            self.bbox_head = self.context.bbox_head[self.context.current_stage]
+
+        self.loss_key = loss_key
+
+    def hard_mining(self, inds, num_expected, bboxes, labels, feats):
+        """Keep the ``num_expected`` candidates in ``inds`` with the highest
+        classification loss."""
+        with torch.no_grad():
+            rois = bbox2roi([bboxes])
+            if not hasattr(self.context, 'num_stages'):
+                bbox_results = self.context._bbox_forward(feats, rois)
+            else:
+                bbox_results = self.context._bbox_forward(
+                    self.context.current_stage, feats, rois)
+            cls_score = bbox_results['cls_score']
+            loss = self.bbox_head.loss(
+                cls_score=cls_score,
+                bbox_pred=None,
+                rois=rois,
+                labels=labels,
+                label_weights=cls_score.new_ones(cls_score.size(0)),
+                bbox_targets=None,
+                bbox_weights=None,
+                reduction_override='none')[self.loss_key]
+            _, topk_loss_inds = loss.topk(num_expected)
+        return inds[topk_loss_inds]
+
+    def _sample_pos(self,
+                    assign_result,
+                    num_expected,
+                    bboxes=None,
+                    feats=None,
+                    **kwargs):
+        """Sample positive boxes.
+
+        Args:
+            assign_result (:obj:`AssignResult`): Assigned results
+            num_expected (int): Number of expected positive samples
+            bboxes (torch.Tensor, optional): Boxes. Defaults to None.
+            feats (list[torch.Tensor], optional): Multi-level features.
+                Defaults to None.
+
+        Returns:
+            torch.Tensor: Indices of positive samples
+        """
+        # Sample some hard positive samples
+        pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
+        if pos_inds.numel() != 0:
+            pos_inds = pos_inds.squeeze(1)
+        if pos_inds.numel() <= num_expected:
+            return pos_inds
+        else:
+            return self.hard_mining(pos_inds, num_expected, bboxes[pos_inds],
+                                    assign_result.labels[pos_inds], feats)
+
+    def _sample_neg(self,
+                    assign_result,
+                    num_expected,
+                    bboxes=None,
+                    feats=None,
+                    **kwargs):
+        """Sample negative boxes.
+
+        Args:
+            assign_result (:obj:`AssignResult`): Assigned results
+            num_expected (int): Number of expected negative samples
+            bboxes (torch.Tensor, optional): Boxes. Defaults to None.
+            feats (list[torch.Tensor], optional): Multi-level features.
+                Defaults to None.
+
+        Returns:
+            torch.Tensor: Indices of negative samples
+        """
+        # Sample some hard negative samples
+        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
+        if neg_inds.numel() != 0:
+            neg_inds = neg_inds.squeeze(1)
+        if len(neg_inds) <= num_expected:
+            return neg_inds
+        else:
+            neg_labels = assign_result.labels.new_empty(
+                neg_inds.size(0)).fill_(self.bbox_head.num_classes)
+            return self.hard_mining(neg_inds, num_expected, bboxes[neg_inds],
+                                    neg_labels, feats)
diff --git a/mmdet/models/task_modules/samplers/pseudo_sampler.py b/mmdet/models/task_modules/samplers/pseudo_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8186cc3364516f34abe1c293017db6e2042d92a
--- /dev/null
+++ b/mmdet/models/task_modules/samplers/pseudo_sampler.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
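# The core of ``OHEMSampler.hard_mining`` above is a top-k selection over
# per-candidate losses; a standalone sketch with made-up numbers:
#
#     >>> import torch
#     >>> inds = torch.tensor([4, 7, 9, 12])           # candidate indices
#     >>> loss = torch.tensor([0.1, 2.3, 0.7, 1.5])    # per-candidate cls loss
#     >>> _, topk = loss.topk(2)
#     >>> inds[topk]                                   # keep the 2 hardest
#     tensor([ 7, 12])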
+import torch
+from mmengine.structures import InstanceData
+
+from mmdet.registry import TASK_UTILS
+from ..assigners import AssignResult
+from .base_sampler import BaseSampler
+from .sampling_result import SamplingResult
+
+
+@TASK_UTILS.register_module()
+class PseudoSampler(BaseSampler):
+    """A pseudo sampler that does not perform actual sampling."""
+
+    def __init__(self, **kwargs):
+        pass
+
+    def _sample_pos(self, **kwargs):
+        """Sample positive samples."""
+        raise NotImplementedError
+
+    def _sample_neg(self, **kwargs):
+        """Sample negative samples."""
+        raise NotImplementedError
+
+    def sample(self, assign_result: AssignResult, pred_instances: InstanceData,
+               gt_instances: InstanceData, *args, **kwargs):
+        """Directly returns the positive and negative indices of samples.
+
+        Args:
+            assign_result (:obj:`AssignResult`): Bbox assigning results.
+            pred_instances (:obj:`InstanceData`): Instances of model
+                predictions. It includes ``priors``, and the priors can
+                be anchors, points, or bboxes predicted by the model,
+                shape (n, 4).
+            gt_instances (:obj:`InstanceData`): Ground truth of instance
+                annotations. It usually includes ``bboxes`` and ``labels``
+                attributes.
+
+        Returns:
+            :obj:`SamplingResult`: sampler results
+        """
+        gt_bboxes = gt_instances.bboxes
+        priors = pred_instances.priors
+
+        pos_inds = torch.nonzero(
+            assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
+        neg_inds = torch.nonzero(
+            assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
+
+        gt_flags = priors.new_zeros(priors.shape[0], dtype=torch.uint8)
+        sampling_result = SamplingResult(
+            pos_inds=pos_inds,
+            neg_inds=neg_inds,
+            priors=priors,
+            gt_bboxes=gt_bboxes,
+            assign_result=assign_result,
+            gt_flags=gt_flags,
+            avg_factor_with_neg=False)
+        return sampling_result
diff --git a/mmdet/models/task_modules/samplers/random_sampler.py b/mmdet/models/task_modules/samplers/random_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa03665fc36cc6a0084431324b16727b2dc8993e
--- /dev/null
+++ b/mmdet/models/task_modules/samplers/random_sampler.py
@@ -0,0 +1,109 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
+import torch
+from numpy import ndarray
+from torch import Tensor
+
+from mmdet.registry import TASK_UTILS
+from ..assigners import AssignResult
+from .base_sampler import BaseSampler
+
+
+@TASK_UTILS.register_module()
+class RandomSampler(BaseSampler):
+    """Random sampler.
+
+    Args:
+        num (int): Number of samples.
+        pos_fraction (float): Fraction of positive samples.
+        neg_pos_ub (int): Upper bound of the ratio of negative to positive
+            samples; -1 means no upper bound. Defaults to -1.
+        add_gt_as_proposals (bool): Whether to add ground truth
+            boxes as proposals. Defaults to True.
+    """
+
+    def __init__(self,
+                 num: int,
+                 pos_fraction: float,
+                 neg_pos_ub: int = -1,
+                 add_gt_as_proposals: bool = True,
+                 **kwargs):
+        from .sampling_result import ensure_rng
+        super().__init__(
+            num=num,
+            pos_fraction=pos_fraction,
+            neg_pos_ub=neg_pos_ub,
+            add_gt_as_proposals=add_gt_as_proposals)
+        self.rng = ensure_rng(kwargs.get('rng', None))
+
+    def random_choice(self, gallery: Union[Tensor, ndarray, list],
+                      num: int) -> Union[Tensor, ndarray]:
+        """Randomly select some elements from the gallery.
+
+        If `gallery` is a Tensor, the returned indices will be a Tensor;
+        If `gallery` is a ndarray or list, the returned indices will be a
+        ndarray.
+
+        Args:
+            gallery (Tensor | ndarray | list): indices pool.
+            num (int): expected sample num.
+
+        Returns:
+            Tensor or ndarray: sampled indices.
+        """
+        assert len(gallery) >= num
+
+        is_tensor = isinstance(gallery, torch.Tensor)
+        if not is_tensor:
+            if torch.cuda.is_available():
+                device = torch.cuda.current_device()
+            else:
+                device = 'cpu'
+            gallery = torch.tensor(gallery, dtype=torch.long, device=device)
+        # This is a temporary fix. We can revert the following code
+        # when PyTorch fixes the abnormal return of torch.randperm.
+        # See: https://github.com/open-mmlab/mmdetection/pull/5014
+        perm = torch.randperm(gallery.numel())[:num].to(device=gallery.device)
+        rand_inds = gallery[perm]
+        if not is_tensor:
+            rand_inds = rand_inds.cpu().numpy()
+        return rand_inds
+
+    def _sample_pos(self, assign_result: AssignResult, num_expected: int,
+                    **kwargs) -> Union[Tensor, ndarray]:
+        """Randomly sample some positive samples.
+
+        Args:
+            assign_result (:obj:`AssignResult`): Bbox assigning results.
+            num_expected (int): The number of expected positive samples
+
+        Returns:
+            Tensor or ndarray: sampled indices.
+        """
+        pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
+        if pos_inds.numel() != 0:
+            pos_inds = pos_inds.squeeze(1)
+        if pos_inds.numel() <= num_expected:
+            return pos_inds
+        else:
+            return self.random_choice(pos_inds, num_expected)
+
+    def _sample_neg(self, assign_result: AssignResult, num_expected: int,
+                    **kwargs) -> Union[Tensor, ndarray]:
+        """Randomly sample some negative samples.
+
+        Args:
+            assign_result (:obj:`AssignResult`): Bbox assigning results.
+            num_expected (int): The number of expected negative samples
+
+        Returns:
+            Tensor or ndarray: sampled indices.
+        """
+        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
+        if neg_inds.numel() != 0:
+            neg_inds = neg_inds.squeeze(1)
+        if len(neg_inds) <= num_expected:
+            return neg_inds
+        else:
+            return self.random_choice(neg_inds, num_expected)
diff --git a/mmdet/models/task_modules/samplers/sampling_result.py b/mmdet/models/task_modules/samplers/sampling_result.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb510ee68f24b8c444b6ed447016bfc785b825c2
--- /dev/null
+++ b/mmdet/models/task_modules/samplers/sampling_result.py
@@ -0,0 +1,240 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import numpy as np
+import torch
+from torch import Tensor
+
+from mmdet.structures.bbox import BaseBoxes, cat_boxes
+from mmdet.utils import util_mixins
+from mmdet.utils.util_random import ensure_rng
+from ..assigners import AssignResult
+
+
+def random_boxes(num=1, scale=1, rng=None):
+    """Simple version of ``kwimage.Boxes.random``
+
+    Returns:
+        Tensor: shape (n, 4) in x1, y1, x2, y2 format.
+
+    References:
+        https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390
+
+    Example:
+        >>> num = 3
+        >>> scale = 512
+        >>> rng = 0
+        >>> boxes = random_boxes(num, scale, rng)
+        >>> print(boxes)
+        tensor([[280.9925, 278.9802, 308.6148, 366.1769],
+                [216.9113, 330.6978, 224.0446, 456.5878],
+                [405.3632, 196.3221, 493.3953, 270.7942]])
+    """
+    rng = ensure_rng(rng)
+
+    tlbr = rng.rand(num, 4).astype(np.float32)
+
+    tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2])
+    tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3])
+    br_x = np.maximum(tlbr[:, 0], tlbr[:, 2])
+    br_y = np.maximum(tlbr[:, 1], tlbr[:, 3])
+
+    tlbr[:, 0] = tl_x * scale
+    tlbr[:, 1] = tl_y * scale
+    tlbr[:, 2] = br_x * scale
+    tlbr[:, 3] = br_y * scale
+
+    boxes = torch.from_numpy(tlbr)
+    return boxes
+
+
+class SamplingResult(util_mixins.NiceRepr):
+    """Bbox sampling result.
+
+    Args:
+        pos_inds (Tensor): Indices of positive samples.
+        neg_inds (Tensor): Indices of negative samples.
+        priors (Tensor): The priors can be anchors or points,
+            or the bboxes predicted by the previous stage.
+        gt_bboxes (Tensor): Ground truth of bboxes.
+        assign_result (:obj:`AssignResult`): Assigning results.
+        gt_flags (Tensor): The Ground truth flags.
+        avg_factor_with_neg (bool): If True, ``avg_factor`` equals
+            the number of total priors; otherwise, it is the number of
+            positive priors. Defaults to True.
+
+    Example:
+        >>> # xdoctest: +IGNORE_WANT
+        >>> from mmdet.models.task_modules.samplers.sampling_result import *  # NOQA
+        >>> self = SamplingResult.random(rng=10)
+        >>> print(f'self = {self}')
+        self = <SamplingResult({...})>
+    """
+
+    def __init__(self,
+                 pos_inds: Tensor,
+                 neg_inds: Tensor,
+                 priors: Tensor,
+                 gt_bboxes: Tensor,
+                 assign_result: AssignResult,
+                 gt_flags: Tensor,
+                 avg_factor_with_neg: bool = True) -> None:
+        self.pos_inds = pos_inds
+        self.neg_inds = neg_inds
+        self.num_pos = max(pos_inds.numel(), 1)
+        self.num_neg = max(neg_inds.numel(), 1)
+        self.avg_factor_with_neg = avg_factor_with_neg
+        self.avg_factor = self.num_pos + self.num_neg \
+            if avg_factor_with_neg else self.num_pos
+        self.pos_priors = priors[pos_inds]
+        self.neg_priors = priors[neg_inds]
+        self.pos_is_gt = gt_flags[pos_inds]
+
+        self.num_gts = gt_bboxes.shape[0]
+        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
+        self.pos_gt_labels = assign_result.labels[pos_inds]
+        box_dim = gt_bboxes.box_dim if isinstance(gt_bboxes, BaseBoxes) else 4
+        if gt_bboxes.numel() == 0:
+            # hack for index error case
+            assert self.pos_assigned_gt_inds.numel() == 0
+            self.pos_gt_bboxes = gt_bboxes.view(-1, box_dim)
+        else:
+            if len(gt_bboxes.shape) < 2:
+                gt_bboxes = gt_bboxes.view(-1, box_dim)
+            self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long()]
+
+    @property
+    def priors(self):
+        """torch.Tensor: concatenated positive and negative priors"""
+        return cat_boxes([self.pos_priors, self.neg_priors])
+
+    @property
+    def bboxes(self):
+        """torch.Tensor: concatenated positive and negative boxes"""
+        warnings.warn('DeprecationWarning: bboxes is deprecated, '
+                      'please use "priors" instead')
+        return self.priors
+
+    @property
+    def pos_bboxes(self):
+        warnings.warn('DeprecationWarning: pos_bboxes is deprecated, '
+                      'please use "pos_priors" instead')
+        return self.pos_priors
+
+    @property
+    def neg_bboxes(self):
+        warnings.warn('DeprecationWarning: neg_bboxes is deprecated, '
+                      'please use "neg_priors" instead')
+        return self.neg_priors
+
+    def to(self, device):
+        """Change the device of the data inplace.
+ + Example: + >>> self = SamplingResult.random() + >>> print(f'self = {self.to(None)}') + >>> # xdoctest: +REQUIRES(--gpu) + >>> print(f'self = {self.to(0)}') + """ + _dict = self.__dict__ + for key, value in _dict.items(): + if isinstance(value, (torch.Tensor, BaseBoxes)): + _dict[key] = value.to(device) + return self + + def __nice__(self): + data = self.info.copy() + data['pos_priors'] = data.pop('pos_priors').shape + data['neg_priors'] = data.pop('neg_priors').shape + parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] + body = ' ' + ',\n '.join(parts) + return '{\n' + body + '\n}' + + @property + def info(self): + """Returns a dictionary of info about the object.""" + return { + 'pos_inds': self.pos_inds, + 'neg_inds': self.neg_inds, + 'pos_priors': self.pos_priors, + 'neg_priors': self.neg_priors, + 'pos_is_gt': self.pos_is_gt, + 'num_gts': self.num_gts, + 'pos_assigned_gt_inds': self.pos_assigned_gt_inds, + 'num_pos': self.num_pos, + 'num_neg': self.num_neg, + 'avg_factor': self.avg_factor + } + + @classmethod + def random(cls, rng=None, **kwargs): + """ + Args: + rng (None | int | numpy.random.RandomState): seed or state. + kwargs (keyword arguments): + - num_preds: Number of predicted boxes. + - num_gts: Number of true boxes. + - p_ignore (float): Probability of a predicted box assigned to + an ignored truth. + - p_assigned (float): probability of a predicted box not being + assigned. + + Returns: + :obj:`SamplingResult`: Randomly generated sampling result. + + Example: + >>> from mmdet.models.task_modules.samplers.sampling_result import * # NOQA + >>> self = SamplingResult.random() + >>> print(self.__dict__) + """ + from mmengine.structures import InstanceData + + from mmdet.models.task_modules.assigners import AssignResult + from mmdet.models.task_modules.samplers import RandomSampler + rng = ensure_rng(rng) + + # make probabilistic? + num = 32 + pos_fraction = 0.5 + neg_pos_ub = -1 + + assign_result = AssignResult.random(rng=rng, **kwargs) + + # Note we could just compute an assignment + priors = random_boxes(assign_result.num_preds, rng=rng) + gt_bboxes = random_boxes(assign_result.num_gts, rng=rng) + gt_labels = torch.randint( + 0, 5, (assign_result.num_gts, ), dtype=torch.long) + + pred_instances = InstanceData() + pred_instances.priors = priors + + gt_instances = InstanceData() + gt_instances.bboxes = gt_bboxes + gt_instances.labels = gt_labels + + add_gt_as_proposals = True + + sampler = RandomSampler( + num, + pos_fraction, + neg_pos_ub=neg_pos_ub, + add_gt_as_proposals=add_gt_as_proposals, + rng=rng) + self = sampler.sample( + assign_result=assign_result, + pred_instances=pred_instances, + gt_instances=gt_instances) + return self diff --git a/mmdet/models/task_modules/samplers/score_hlr_sampler.py b/mmdet/models/task_modules/samplers/score_hlr_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..0227585b92329625d053f1e9f8c161fd02af8aef --- /dev/null +++ b/mmdet/models/task_modules/samplers/score_hlr_sampler.py @@ -0,0 +1,290 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
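# ``SamplingResult.random`` (above) is handy for smoke-testing code that
# consumes sampling results; a minimal sketch (contents vary with the rng,
# while the checked properties hold by construction):
#
#     >>> from mmdet.models.task_modules.samplers.sampling_result import (
#     ...     SamplingResult)
#     >>> result = SamplingResult.random(rng=0)
#     >>> result.priors.shape[-1], result.num_pos >= 1, result.num_neg >= 1
#     (4, True, True)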
+from typing import Union
+
+import torch
+from mmcv.ops import nms_match
+from mmengine.structures import InstanceData
+from numpy import ndarray
+from torch import Tensor
+
+from mmdet.registry import TASK_UTILS
+from mmdet.structures.bbox import bbox2roi
+from ..assigners import AssignResult
+from .base_sampler import BaseSampler
+from .sampling_result import SamplingResult
+
+
+@TASK_UTILS.register_module()
+class ScoreHLRSampler(BaseSampler):
+    r"""Importance-based Sample Reweighting (ISR_N), described in `Prime Sample
+    Attention in Object Detection <https://arxiv.org/abs/1904.04821>`_.
+
+    Score hierarchical local rank (HLR) differs from RandomSampler in the
+    negative part. It first computes Score-HLR in a two-step way,
+    then linearly maps the Score-HLR to the loss weights.
+
+    Args:
+        num (int): Total number of sampled RoIs.
+        pos_fraction (float): Fraction of positive samples.
+        context (:obj:`BaseRoIHead`): RoI head that the sampler belongs to.
+        neg_pos_ub (int): Upper bound of the ratio of num negative to num
+            positive, -1 means no upper bound. Defaults to -1.
+        add_gt_as_proposals (bool): Whether to add ground truth as proposals.
+            Defaults to True.
+        k (float): Power of the non-linear mapping. Defaults to 0.5.
+        bias (float): Shift of the non-linear mapping. Defaults to 0.
+        score_thr (float): Minimum score that a negative sample is to be
+            considered as valid bbox. Defaults to 0.05.
+        iou_thr (float): IoU threshold for NMS match. Defaults to 0.5.
+    """
+
+    def __init__(self,
+                 num: int,
+                 pos_fraction: float,
+                 context,
+                 neg_pos_ub: int = -1,
+                 add_gt_as_proposals: bool = True,
+                 k: float = 0.5,
+                 bias: float = 0,
+                 score_thr: float = 0.05,
+                 iou_thr: float = 0.5,
+                 **kwargs) -> None:
+        super().__init__(
+            num=num,
+            pos_fraction=pos_fraction,
+            neg_pos_ub=neg_pos_ub,
+            add_gt_as_proposals=add_gt_as_proposals)
+        self.k = k
+        self.bias = bias
+        self.score_thr = score_thr
+        self.iou_thr = iou_thr
+        self.context = context
+        # context of cascade detectors is a list, so distinguish them here.
+        if not hasattr(context, 'num_stages'):
+            self.bbox_roi_extractor = context.bbox_roi_extractor
+            self.bbox_head = context.bbox_head
+            self.with_shared_head = context.with_shared_head
+            if self.with_shared_head:
+                self.shared_head = context.shared_head
+        else:
+            self.bbox_roi_extractor = context.bbox_roi_extractor[
+                context.current_stage]
+            self.bbox_head = context.bbox_head[context.current_stage]
+
+    @staticmethod
+    def random_choice(gallery: Union[Tensor, ndarray, list],
+                      num: int) -> Union[Tensor, ndarray]:
+        """Randomly select some elements from the gallery.
+
+        If `gallery` is a Tensor, the returned indices will be a Tensor;
+        If `gallery` is a ndarray or list, the returned indices will be a
+        ndarray.
+
+        Args:
+            gallery (Tensor or ndarray or list): indices pool.
+            num (int): expected sample num.
+
+        Returns:
+            Tensor or ndarray: sampled indices.
+        """
+        assert len(gallery) >= num
+
+        is_tensor = isinstance(gallery, torch.Tensor)
+        if not is_tensor:
+            if torch.cuda.is_available():
+                device = torch.cuda.current_device()
+            else:
+                device = 'cpu'
+            gallery = torch.tensor(gallery, dtype=torch.long, device=device)
+        perm = torch.randperm(gallery.numel(), device=gallery.device)[:num]
+        rand_inds = gallery[perm]
+        if not is_tensor:
+            rand_inds = rand_inds.cpu().numpy()
+        return rand_inds
+
+    def _sample_pos(self, assign_result: AssignResult, num_expected: int,
+                    **kwargs) -> Union[Tensor, ndarray]:
+        """Randomly sample some positive samples.
+
+        Args:
+            assign_result (:obj:`AssignResult`): Bbox assigning results.
+ num_expected (int): The number of expected positive samples + + Returns: + Tensor or ndarray: sampled indices. + """ + pos_inds = torch.nonzero(assign_result.gt_inds > 0).flatten() + if pos_inds.numel() <= num_expected: + return pos_inds + else: + return self.random_choice(pos_inds, num_expected) + + def _sample_neg(self, assign_result: AssignResult, num_expected: int, + bboxes: Tensor, feats: Tensor, + **kwargs) -> Union[Tensor, ndarray]: + """Sample negative samples. + + Score-HLR sampler is done in the following steps: + 1. Take the maximum positive score prediction of each negative samples + as s_i. + 2. Filter out negative samples whose s_i <= score_thr, the left samples + are called valid samples. + 3. Use NMS-Match to divide valid samples into different groups, + samples in the same group will greatly overlap with each other + 4. Rank the matched samples in two-steps to get Score-HLR. + (1) In the same group, rank samples with their scores. + (2) In the same score rank across different groups, + rank samples with their scores again. + 5. Linearly map Score-HLR to the final label weights. + + Args: + assign_result (:obj:`AssignResult`): result of assigner. + num_expected (int): Expected number of samples. + bboxes (Tensor): bbox to be sampled. + feats (Tensor): Features come from FPN. + + Returns: + Tensor or ndarray: sampled indices. + """ + neg_inds = torch.nonzero(assign_result.gt_inds == 0).flatten() + num_neg = neg_inds.size(0) + if num_neg == 0: + return neg_inds, None + with torch.no_grad(): + neg_bboxes = bboxes[neg_inds] + neg_rois = bbox2roi([neg_bboxes]) + bbox_result = self.context._bbox_forward(feats, neg_rois) + cls_score, bbox_pred = bbox_result['cls_score'], bbox_result[ + 'bbox_pred'] + + ori_loss = self.bbox_head.loss( + cls_score=cls_score, + bbox_pred=None, + rois=None, + labels=neg_inds.new_full((num_neg, ), + self.bbox_head.num_classes), + label_weights=cls_score.new_ones(num_neg), + bbox_targets=None, + bbox_weights=None, + reduction_override='none')['loss_cls'] + + # filter out samples with the max score lower than score_thr + max_score, argmax_score = cls_score.softmax(-1)[:, :-1].max(-1) + valid_inds = (max_score > self.score_thr).nonzero().view(-1) + invalid_inds = (max_score <= self.score_thr).nonzero().view(-1) + num_valid = valid_inds.size(0) + num_invalid = invalid_inds.size(0) + + num_expected = min(num_neg, num_expected) + num_hlr = min(num_valid, num_expected) + num_rand = num_expected - num_hlr + if num_valid > 0: + valid_rois = neg_rois[valid_inds] + valid_max_score = max_score[valid_inds] + valid_argmax_score = argmax_score[valid_inds] + valid_bbox_pred = bbox_pred[valid_inds] + + # valid_bbox_pred shape: [num_valid, #num_classes, 4] + valid_bbox_pred = valid_bbox_pred.view( + valid_bbox_pred.size(0), -1, 4) + selected_bbox_pred = valid_bbox_pred[range(num_valid), + valid_argmax_score] + pred_bboxes = self.bbox_head.bbox_coder.decode( + valid_rois[:, 1:], selected_bbox_pred) + pred_bboxes_with_score = torch.cat( + [pred_bboxes, valid_max_score[:, None]], -1) + group = nms_match(pred_bboxes_with_score, self.iou_thr) + + # imp: importance + imp = cls_score.new_zeros(num_valid) + for g in group: + g_score = valid_max_score[g] + # g_score has already sorted + rank = g_score.new_tensor(range(g_score.size(0))) + imp[g] = num_valid - rank + g_score + _, imp_rank_inds = imp.sort(descending=True) + _, imp_rank = imp_rank_inds.sort() + hlr_inds = imp_rank_inds[:num_expected] + + if num_rand > 0: + rand_inds = torch.randperm(num_invalid)[:num_rand] + 
select_inds = torch.cat( + [valid_inds[hlr_inds], invalid_inds[rand_inds]]) + else: + select_inds = valid_inds[hlr_inds] + + neg_label_weights = cls_score.new_ones(num_expected) + + up_bound = max(num_expected, num_valid) + imp_weights = (up_bound - + imp_rank[hlr_inds].float()) / up_bound + neg_label_weights[:num_hlr] = imp_weights + neg_label_weights[num_hlr:] = imp_weights.min() + neg_label_weights = (self.bias + + (1 - self.bias) * neg_label_weights).pow( + self.k) + ori_selected_loss = ori_loss[select_inds] + new_loss = ori_selected_loss * neg_label_weights + norm_ratio = ori_selected_loss.sum() / new_loss.sum() + neg_label_weights *= norm_ratio + else: + neg_label_weights = cls_score.new_ones(num_expected) + select_inds = torch.randperm(num_neg)[:num_expected] + + return neg_inds[select_inds], neg_label_weights + + def sample(self, assign_result: AssignResult, pred_instances: InstanceData, + gt_instances: InstanceData, **kwargs) -> SamplingResult: + """Sample positive and negative bboxes. + + This is a simple implementation of bbox sampling given candidates, + assigning results and ground truth bboxes. + + Args: + assign_result (:obj:`AssignResult`): Assigning results. + pred_instances (:obj:`InstanceData`): Instances of model + predictions. It includes ``priors``, and the priors can + be anchors or points, or the bboxes predicted by the + previous stage, has shape (n, 4). The bboxes predicted by + the current model or stage will be named ``bboxes``, + ``labels``, and ``scores``, the same as the ``InstanceData`` + in other places. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes``, with shape (k, 4), + and ``labels``, with shape (k, ). + + Returns: + :obj:`SamplingResult`: Sampling result. + """ + gt_bboxes = gt_instances.bboxes + priors = pred_instances.priors + gt_labels = gt_instances.labels + + gt_flags = priors.new_zeros((priors.shape[0], ), dtype=torch.uint8) + if self.add_gt_as_proposals and len(gt_bboxes) > 0: + priors = torch.cat([gt_bboxes, priors], dim=0) + assign_result.add_gt_(gt_labels) + gt_ones = priors.new_ones(gt_bboxes.shape[0], dtype=torch.uint8) + gt_flags = torch.cat([gt_ones, gt_flags]) + + num_expected_pos = int(self.num * self.pos_fraction) + pos_inds = self.pos_sampler._sample_pos( + assign_result, num_expected_pos, bboxes=priors, **kwargs) + num_sampled_pos = pos_inds.numel() + num_expected_neg = self.num - num_sampled_pos + if self.neg_pos_ub >= 0: + _pos = max(1, num_sampled_pos) + neg_upper_bound = int(self.neg_pos_ub * _pos) + if num_expected_neg > neg_upper_bound: + num_expected_neg = neg_upper_bound + neg_inds, neg_label_weights = self.neg_sampler._sample_neg( + assign_result, num_expected_neg, bboxes=priors, **kwargs) + + sampling_result = SamplingResult( + pos_inds=pos_inds, + neg_inds=neg_inds, + priors=priors, + gt_bboxes=gt_bboxes, + assign_result=assign_result, + gt_flags=gt_flags) + return sampling_result, neg_label_weights diff --git a/mmdet/models/task_modules/tracking/__init__.py b/mmdet/models/task_modules/tracking/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57a86d739d586e47e007d26de4542d6bdeced755 --- /dev/null +++ b/mmdet/models/task_modules/tracking/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
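+# Editor's usage sketch for the ScoreHLRSampler above: unlike the other
+# samplers, its ``sample()`` returns a tuple with extra per-negative label
+# weights (hypothetical variable names):
+#
+#     sampling_result, neg_label_weights = sampler.sample(
+#         assign_result, pred_instances, gt_instances, feats=feats)
+#     # feed neg_label_weights into the RoI head's classification loss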
+from .aflink import AppearanceFreeLink +from .camera_motion_compensation import CameraMotionCompensation +from .interpolation import InterpolateTracklets +from .kalman_filter import KalmanFilter +from .similarity import embed_similarity + +__all__ = [ + 'KalmanFilter', 'InterpolateTracklets', 'embed_similarity', + 'AppearanceFreeLink', 'CameraMotionCompensation' +] diff --git a/mmdet/models/task_modules/tracking/aflink.py b/mmdet/models/task_modules/tracking/aflink.py new file mode 100644 index 0000000000000000000000000000000000000000..52461067e372b30bbd28325ead00f5381c546326 --- /dev/null +++ b/mmdet/models/task_modules/tracking/aflink.py @@ -0,0 +1,281 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import defaultdict +from typing import Tuple + +import numpy as np +import torch +from mmengine.model import BaseModule +from mmengine.runner.checkpoint import load_checkpoint +from scipy.optimize import linear_sum_assignment +from torch import Tensor, nn + +from mmdet.registry import TASK_UTILS + +INFINITY = 1e5 + + +class TemporalBlock(BaseModule): + """The temporal block of AFLink model. + + Args: + in_channel (int): the dimension of the input channels. + out_channel (int): the dimension of the output channels. + """ + + def __init__(self, + in_channel: int, + out_channel: int, + kernel_size: tuple = (7, 1)): + super(TemporalBlock, self).__init__() + self.conv = nn.Conv2d(in_channel, out_channel, kernel_size, bias=False) + self.relu = nn.ReLU(inplace=True) + self.bnf = nn.BatchNorm1d(out_channel) + self.bnx = nn.BatchNorm1d(out_channel) + self.bny = nn.BatchNorm1d(out_channel) + + def bn(self, x: Tensor) -> Tensor: + x[:, :, :, 0] = self.bnf(x[:, :, :, 0]) + x[:, :, :, 1] = self.bnx(x[:, :, :, 1]) + x[:, :, :, 2] = self.bny(x[:, :, :, 2]) + return x + + def forward(self, x: Tensor) -> Tensor: + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class FusionBlock(BaseModule): + """The fusion block of AFLink model. + + Args: + in_channel (int): the dimension of the input channels. + out_channel (int): the dimension of the output channels. + """ + + def __init__(self, in_channel: int, out_channel: int): + super(FusionBlock, self).__init__() + self.conv = nn.Conv2d(in_channel, out_channel, (1, 3), bias=False) + self.bn = nn.BatchNorm2d(out_channel) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x: Tensor) -> Tensor: + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Classifier(BaseModule): + """The classifier of AFLink model. + + Args: + in_channel (int): the dimension of the input channels. 
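+        out_channel (int): the dimension of the output channels.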
+ """ + + def __init__(self, in_channel: int, out_channel: int): + super(Classifier, self).__init__() + self.fc1 = nn.Linear(in_channel * 2, in_channel // 2) + self.relu = nn.ReLU(inplace=True) + self.fc2 = nn.Linear(in_channel // 2, out_channel) + + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: + x = torch.cat((x1, x2), dim=1) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + return x + + +class AFLinkModel(BaseModule): + """Appearance-Free Link Model.""" + + def __init__(self, + temporal_module_channels: list = [1, 32, 64, 128, 256], + fusion_module_channels: list = [256, 256], + classifier_channels: list = [256, 2]): + super(AFLinkModel, self).__init__() + self.TemporalModule_1 = nn.Sequential(*[ + TemporalBlock(temporal_module_channels[i], + temporal_module_channels[i + 1]) + for i in range(len(temporal_module_channels) - 1) + ]) + + self.TemporalModule_2 = nn.Sequential(*[ + TemporalBlock(temporal_module_channels[i], + temporal_module_channels[i + 1]) + for i in range(len(temporal_module_channels) - 1) + ]) + + self.FusionBlock_1 = FusionBlock(*fusion_module_channels) + self.FusionBlock_2 = FusionBlock(*fusion_module_channels) + + self.pooling = nn.AdaptiveAvgPool2d((1, 1)) + self.classifier = Classifier(*classifier_channels) + + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: + assert not self.training, 'Only testing is supported for AFLink.' + x1 = x1[:, :, :, :3] + x2 = x2[:, :, :, :3] + x1 = self.TemporalModule_1(x1) # [B,1,30,3] -> [B,256,6,3] + x2 = self.TemporalModule_2(x2) + x1 = self.FusionBlock_1(x1) + x2 = self.FusionBlock_2(x2) + x1 = self.pooling(x1).squeeze(-1).squeeze(-1) + x2 = self.pooling(x2).squeeze(-1).squeeze(-1) + y = self.classifier(x1, x2) + y = torch.softmax(y, dim=1)[0, 1] + return y + + +@TASK_UTILS.register_module() +class AppearanceFreeLink(BaseModule): + """Appearance-Free Link method. + + This method is proposed in + "StrongSORT: Make DeepSORT Great Again" + `StrongSORT`_. + + Args: + checkpoint (str): Checkpoint path. + temporal_threshold (tuple, optional): The temporal constraint + for tracklets association. Defaults to (0, 30). + spatial_threshold (int, optional): The spatial constraint for + tracklets association. Defaults to 75. + confidence_threshold (float, optional): The minimum confidence + threshold for tracklets association. Defaults to 0.95. + """ + + def __init__(self, + checkpoint: str, + temporal_threshold: tuple = (0, 30), + spatial_threshold: int = 75, + confidence_threshold: float = 0.95): + super(AppearanceFreeLink, self).__init__() + self.temporal_threshold = temporal_threshold + self.spatial_threshold = spatial_threshold + self.confidence_threshold = confidence_threshold + + self.model = AFLinkModel() + if checkpoint: + load_checkpoint(self.model, checkpoint) + if torch.cuda.is_available(): + self.model.cuda() + self.model.eval() + + self.device = next(self.model.parameters()).device + self.fn_l2 = lambda x, y: np.sqrt(x**2 + y**2) + + def data_transform(self, + track1: np.ndarray, + track2: np.ndarray, + length: int = 30) -> Tuple[np.ndarray]: + """Data Transformation. This is used to standardize the length of + tracks to a unified length. Then perform min-max normalization to the + motion embeddings. + + Args: + track1 (ndarray): the first track with shape (N,C). + track2 (ndarray): the second track with shape (M,C). + length (int): the unified length of tracks. Defaults to 30. + + Returns: + Tuple[ndarray]: the transformed track1 and track2. 
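+
+        Example (an illustrative sketch; assumes ``np`` is ``numpy``):
+            >>> self = AppearanceFreeLink(checkpoint='')
+            >>> track1 = np.random.rand(45, 5)  # longer than 30: tail kept
+            >>> track2 = np.random.rand(10, 5)  # shorter than 30: zero-padded
+            >>> out1, out2 = self.data_transform(track1, track2)
+            >>> out1.shape, out2.shape
+            ((30, 5), (30, 5))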
+        """
+        # fill or cut track1
+        length_1 = track1.shape[0]
+        track1 = track1[-length:] if length_1 >= length else \
+            np.pad(track1, ((length - length_1, 0), (0, 0)))
+
+        # fill or cut track2
+        length_2 = track2.shape[0]
+        track2 = track2[:length] if length_2 >= length else \
+            np.pad(track2, ((0, length - length_2), (0, 0)))
+
+        # min-max normalization
+        min_ = np.concatenate((track1, track2), axis=0).min(axis=0)
+        max_ = np.concatenate((track1, track2), axis=0).max(axis=0)
+        subtractor = (max_ + min_) / 2
+        divisor = (max_ - min_) / 2 + 1e-5
+        track1 = (track1 - subtractor) / divisor
+        track2 = (track2 - subtractor) / divisor
+
+        return track1, track2
+
+    def forward(self, pred_tracks: np.ndarray) -> np.ndarray:
+        """Forward function.
+
+        Args:
+            pred_tracks (ndarray): With shape (N, 7). Each row denotes
+                (frame_id, track_id, x1, y1, x2, y2, score).
+
+        Returns:
+            ndarray: The linked tracks with shape (N, 7). Each row denotes
+                (frame_id, track_id, x1, y1, x2, y2, score).
+        """
+        # sort tracks by the frame id
+        pred_tracks = pred_tracks[np.argsort(pred_tracks[:, 0])]
+
+        # gather track information
+        id2info = defaultdict(list)
+        for row in pred_tracks:
+            frame_id, track_id, x1, y1, x2, y2 = row[:6]
+            id2info[track_id].append([frame_id, x1, y1, x2 - x1, y2 - y1])
+        id2info = {k: np.array(v) for k, v in id2info.items()}
+        num_track = len(id2info)
+        track_ids = np.array(list(id2info))
+        cost_matrix = np.full((num_track, num_track), INFINITY)
+
+        # compute the cost matrix
+        for i, id_i in enumerate(track_ids):
+            for j, id_j in enumerate(track_ids):
+                if id_i == id_j:
+                    continue
+                info_i, info_j = id2info[id_i], id2info[id_j]
+                frame_i, box_i = info_i[-1][0], info_i[-1][1:3]
+                frame_j, box_j = info_j[0][0], info_j[0][1:3]
+                # temporal constraint
+                if not self.temporal_threshold[0] <= \
+                        frame_j - frame_i <= self.temporal_threshold[1]:
+                    continue
+                # spatial constraint
+                if self.fn_l2(box_i[0] - box_j[0], box_i[1] - box_j[1]) \
+                        > self.spatial_threshold:
+                    continue
+                # confidence constraint
+                track_i, track_j = self.data_transform(info_i, info_j)
+
+                # numpy to torch
+                track_i = torch.tensor(
+                    track_i, dtype=torch.float).to(self.device)
+                track_j = torch.tensor(
+                    track_j, dtype=torch.float).to(self.device)
+                track_i = track_i.unsqueeze(0).unsqueeze(0)
+                track_j = track_j.unsqueeze(0).unsqueeze(0)
+
+                confidence = self.model(track_i,
+                                        track_j).detach().cpu().numpy()
+                if confidence >= self.confidence_threshold:
+                    cost_matrix[i, j] = 1 - confidence
+
+        # linear assignment
+        indices = linear_sum_assignment(cost_matrix)
+        _id2id = dict()  # the temporary assignment results
+        id2id = dict()  # the final assignment results
+        for i, j in zip(indices[0], indices[1]):
+            if cost_matrix[i, j] < INFINITY:
+                _id2id[i] = j
+        for k, v in _id2id.items():
+            if k in id2id:
+                id2id[v] = id2id[k]
+            else:
+                id2id[v] = k
+
+        # link
+        for k, v in id2id.items():
+            pred_tracks[pred_tracks[:, 1] == k, 1] = v
+
+        # deduplicate
+        _, index = np.unique(pred_tracks[:, :2], return_index=True, axis=0)
+
+        return pred_tracks[index]
diff --git a/mmdet/models/task_modules/tracking/camera_motion_compensation.py b/mmdet/models/task_modules/tracking/camera_motion_compensation.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a6298494fd1c24e0e7bba457dd50864725f98c8
--- /dev/null
+++ b/mmdet/models/task_modules/tracking/camera_motion_compensation.py
@@ -0,0 +1,104 @@
+# Copyright (c) OpenMMLab. All rights reserved.
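+# Editor's usage sketch for the AppearanceFreeLink above (assumes a trained
+# AFLink checkpoint and an (N, 7) array of tracks):
+#
+#     aflink = AppearanceFreeLink(checkpoint='path/to/aflink.pth')
+#     pred_tracks = aflink.forward(pred_tracks)  # offline tracklet linking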
+import cv2 +import numpy as np +import torch +from torch import Tensor + +from mmdet.registry import TASK_UTILS +from mmdet.structures.bbox import bbox_cxcyah_to_xyxy, bbox_xyxy_to_cxcyah + + +@TASK_UTILS.register_module() +class CameraMotionCompensation: + """Camera motion compensation. + + Args: + warp_mode (str): Warp mode in opencv. + Defaults to 'cv2.MOTION_EUCLIDEAN'. + num_iters (int): Number of the iterations. Defaults to 50. + stop_eps (float): Terminate threshold. Defaults to 0.001. + """ + + def __init__(self, + warp_mode: str = 'cv2.MOTION_EUCLIDEAN', + num_iters: int = 50, + stop_eps: float = 0.001): + self.warp_mode = eval(warp_mode) + self.num_iters = num_iters + self.stop_eps = stop_eps + + def get_warp_matrix(self, img: np.ndarray, ref_img: np.ndarray) -> Tensor: + """Calculate warping matrix between two images.""" + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2GRAY) + + warp_matrix = np.eye(2, 3, dtype=np.float32) + criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, + self.num_iters, self.stop_eps) + cc, warp_matrix = cv2.findTransformECC(img, ref_img, warp_matrix, + self.warp_mode, criteria, None, + 1) + warp_matrix = torch.from_numpy(warp_matrix) + return warp_matrix + + def warp_bboxes(self, bboxes: Tensor, warp_matrix: Tensor) -> Tensor: + """Warp bounding boxes according to the warping matrix.""" + tl, br = bboxes[:, :2], bboxes[:, 2:] + tl = torch.cat((tl, torch.ones(tl.shape[0], 1).to(bboxes.device)), + dim=1) + br = torch.cat((br, torch.ones(tl.shape[0], 1).to(bboxes.device)), + dim=1) + trans_tl = torch.mm(warp_matrix, tl.t()).t() + trans_br = torch.mm(warp_matrix, br.t()).t() + trans_bboxes = torch.cat((trans_tl, trans_br), dim=1) + return trans_bboxes.to(bboxes.device) + + def warp_means(self, means: np.ndarray, warp_matrix: Tensor) -> np.ndarray: + """Warp track.mean according to the warping matrix.""" + cxcyah = torch.from_numpy(means[:, :4]).float() + xyxy = bbox_cxcyah_to_xyxy(cxcyah) + warped_xyxy = self.warp_bboxes(xyxy, warp_matrix) + warped_cxcyah = bbox_xyxy_to_cxcyah(warped_xyxy).numpy() + means[:, :4] = warped_cxcyah + return means + + def track(self, img: Tensor, ref_img: Tensor, tracks: dict, + num_samples: int, frame_id: int, metainfo: dict) -> dict: + """Tracking forward.""" + img = img.squeeze(0).cpu().numpy().transpose((1, 2, 0)) + ref_img = ref_img.squeeze(0).cpu().numpy().transpose((1, 2, 0)) + warp_matrix = self.get_warp_matrix(img, ref_img) + + # rescale the warp_matrix due to the `resize` in pipeline + scale_factor_h, scale_factor_w = metainfo['scale_factor'] + warp_matrix[0, 2] = warp_matrix[0, 2] / scale_factor_w + warp_matrix[1, 2] = warp_matrix[1, 2] / scale_factor_h + + bboxes = [] + num_bboxes = [] + means = [] + for k, v in tracks.items(): + if int(v['frame_ids'][-1]) < frame_id - 1: + _num = 1 + else: + _num = min(num_samples, len(v.bboxes)) + num_bboxes.append(_num) + bboxes.extend(v.bboxes[-_num:]) + if len(v.mean) > 0: + means.append(v.mean) + bboxes = torch.cat(bboxes, dim=0) + warped_bboxes = self.warp_bboxes(bboxes, warp_matrix.to(bboxes.device)) + + warped_bboxes = torch.split(warped_bboxes, num_bboxes) + for b, (k, v) in zip(warped_bboxes, tracks.items()): + _num = b.shape[0] + b = torch.split(b, [1] * _num) + tracks[k].bboxes[-_num:] = b + + if means: + means = np.asarray(means) + warped_means = self.warp_means(means, warp_matrix) + for m, (k, v) in zip(warped_means, tracks.items()): + tracks[k].mean = m + + return tracks diff --git 
a/mmdet/models/task_modules/tracking/interpolation.py b/mmdet/models/task_modules/tracking/interpolation.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb6a25af4f253e3ec6b9781831ff43c6bafe50e1
--- /dev/null
+++ b/mmdet/models/task_modules/tracking/interpolation.py
@@ -0,0 +1,168 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+try:
+    from sklearn.gaussian_process import GaussianProcessRegressor as GPR
+    from sklearn.gaussian_process.kernels import RBF
+    HAS_SCIKIT_LEARN = True
+except ImportError:
+    HAS_SCIKIT_LEARN = False
+
+from mmdet.registry import TASK_UTILS
+
+
+@TASK_UTILS.register_module()
+class InterpolateTracklets:
+    """Interpolate tracks to make tracks more complete.
+
+    Args:
+        min_num_frames (int, optional): The minimum length of a track that
+            will be interpolated. Defaults to 5.
+        max_num_frames (int, optional): The maximum disconnected length in
+            a track. Defaults to 20.
+        use_gsi (bool, optional): Whether to use the GSI (Gaussian-smoothed
+            interpolation) method. Defaults to False.
+        smooth_tau (int, optional): smoothing parameter in GSI.
+            Defaults to 10.
+    """
+
+    def __init__(self,
+                 min_num_frames: int = 5,
+                 max_num_frames: int = 20,
+                 use_gsi: bool = False,
+                 smooth_tau: int = 10):
+        if not HAS_SCIKIT_LEARN:
+            raise RuntimeError('scikit-learn is not installed,\
+                 please install it by: pip install scikit-learn')
+        self.min_num_frames = min_num_frames
+        self.max_num_frames = max_num_frames
+        self.use_gsi = use_gsi
+        self.smooth_tau = smooth_tau
+
+    def _interpolate_track(self,
+                           track: np.ndarray,
+                           track_id: int,
+                           max_num_frames: int = 20) -> np.ndarray:
+        """Interpolate a track linearly to make the track more complete.
+
+        This function is proposed in
+        "ByteTrack: Multi-Object Tracking by Associating Every Detection Box."
+        `ByteTrack <https://arxiv.org/abs/2110.06864>`_.
+
+        Args:
+            track (ndarray): With shape (N, 7). Each row denotes
+                (frame_id, track_id, x1, y1, x2, y2, score).
+            track_id (int): The id of the track to be interpolated.
+            max_num_frames (int, optional): The maximum disconnected length
+                in the track. Defaults to 20.
+
+        Returns:
+            ndarray: The interpolated track with shape (N, 7). Each row
+                denotes (frame_id, track_id, x1, y1, x2, y2, score)
+        """
+        assert (track[:, 1] == track_id).all(), \
+            'The track id should not change when interpolating a track.'
+
+        frame_ids = track[:, 0]
+        interpolated_track = np.zeros((0, 7))
+        # perform interpolation for the disconnected frames in the track.
+        for i in np.where(np.diff(frame_ids) > 1)[0]:
+            left_frame_id = frame_ids[i]
+            right_frame_id = frame_ids[i + 1]
+            num_disconnected_frames = int(right_frame_id - left_frame_id)
+
+            if 1 < num_disconnected_frames < max_num_frames:
+                left_bbox = track[i, 2:6]
+                right_bbox = track[i + 1, 2:6]
+
+                # perform interpolation for two adjacent tracklets.
+                for j in range(1, num_disconnected_frames):
+                    cur_bbox = j / (num_disconnected_frames) * (
+                        right_bbox - left_bbox) + left_bbox
+                    cur_result = np.ones((7, ))
+                    cur_result[0] = j + left_frame_id
+                    cur_result[1] = track_id
+                    cur_result[2:6] = cur_bbox
+
+                    interpolated_track = np.concatenate(
+                        (interpolated_track, cur_result[None]), axis=0)
+
+        interpolated_track = np.concatenate((track, interpolated_track),
+                                            axis=0)
+        return interpolated_track
+
+    def gaussian_smoothed_interpolation(self,
+                                        track: np.ndarray,
+                                        smooth_tau: int = 10) -> np.ndarray:
+        """Gaussian-Smoothed Interpolation.
+
+        This function is proposed in
+        "StrongSORT: Make DeepSORT Great Again"
+        `StrongSORT <https://arxiv.org/abs/2202.13514>`_.
+
+        Args:
+            track (ndarray): With shape (N, 7).
Each row denotes + (frame_id, track_id, x1, y1, x2, y2, score). + smooth_tau (int, optional): smoothing parameter in GSI. + Defaults to 10. + + Returns: + ndarray: The interpolated tracks with shape (N, 7). Each row + denotes (frame_id, track_id, x1, y1, x2, y2, score) + """ + len_scale = np.clip(smooth_tau * np.log(smooth_tau**3 / len(track)), + smooth_tau**-1, smooth_tau**2) + gpr = GPR(RBF(len_scale, 'fixed')) + t = track[:, 0].reshape(-1, 1) + x1 = track[:, 2].reshape(-1, 1) + y1 = track[:, 3].reshape(-1, 1) + x2 = track[:, 4].reshape(-1, 1) + y2 = track[:, 5].reshape(-1, 1) + gpr.fit(t, x1) + x1_gpr = gpr.predict(t) + gpr.fit(t, y1) + y1_gpr = gpr.predict(t) + gpr.fit(t, x2) + x2_gpr = gpr.predict(t) + gpr.fit(t, y2) + y2_gpr = gpr.predict(t) + gsi_track = [[ + t[i, 0], track[i, 1], x1_gpr[i], y1_gpr[i], x2_gpr[i], y2_gpr[i], + track[i, 6] + ] for i in range(len(t))] + return np.array(gsi_track) + + def forward(self, pred_tracks: np.ndarray) -> np.ndarray: + """Forward function. + + pred_tracks (ndarray): With shape (N, 7). Each row denotes + (frame_id, track_id, x1, y1, x2, y2, score). + + Returns: + ndarray: The interpolated tracks with shape (N, 7). Each row + denotes (frame_id, track_id, x1, y1, x2, y2, score). + """ + max_track_id = int(np.max(pred_tracks[:, 1])) + min_track_id = int(np.min(pred_tracks[:, 1])) + + # perform interpolation for each track + interpolated_tracks = [] + for track_id in range(min_track_id, max_track_id + 1): + inds = pred_tracks[:, 1] == track_id + track = pred_tracks[inds] + num_frames = len(track) + if num_frames <= 2: + continue + + if num_frames > self.min_num_frames: + interpolated_track = self._interpolate_track( + track, track_id, self.max_num_frames) + else: + interpolated_track = track + + if self.use_gsi: + interpolated_track = self.gaussian_smoothed_interpolation( + interpolated_track, self.smooth_tau) + + interpolated_tracks.append(interpolated_track) + + interpolated_tracks = np.concatenate(interpolated_tracks) + return interpolated_tracks[interpolated_tracks[:, 0].argsort()] diff --git a/mmdet/models/task_modules/tracking/kalman_filter.py b/mmdet/models/task_modules/tracking/kalman_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..a8ae1416af69bce17fd20dd5231eba2f12f7ed64 --- /dev/null +++ b/mmdet/models/task_modules/tracking/kalman_filter.py @@ -0,0 +1,267 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import numpy as np +import torch + +try: + import scipy.linalg + HAS_SCIPY = True +except ImportError: + HAS_SCIPY = False + +from mmdet.registry import TASK_UTILS + + +@TASK_UTILS.register_module() +class KalmanFilter: + """A simple Kalman filter for tracking bounding boxes in image space. + + The implementation is referred to https://github.com/nwojke/deep_sort. + + Args: + center_only (bool): If True, distance computation is done with + respect to the bounding box center position only. + Defaults to False. + use_nsa (bool): Whether to use the NSA (Noise Scale Adaptive) Kalman + Filter, which adaptively modulates the noise scale according to + the quality of detections. More details in + https://arxiv.org/abs/2202.11983. Defaults to False. 
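+
+    Example (an illustrative sketch; assumes ``np`` is ``numpy``):
+        >>> kf = KalmanFilter()
+        >>> # measurement is (center x, center y, aspect ratio, height)
+        >>> mean, covariance = kf.initiate(np.array([120., 80., 0.5, 60.]))
+        >>> mean.shape, covariance.shape
+        ((8,), (8, 8))
+        >>> mean, covariance = kf.predict(mean, covariance)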
+    """
+    chi2inv95 = {
+        1: 3.8415,
+        2: 5.9915,
+        3: 7.8147,
+        4: 9.4877,
+        5: 11.070,
+        6: 12.592,
+        7: 14.067,
+        8: 15.507,
+        9: 16.919
+    }
+
+    def __init__(self, center_only: bool = False, use_nsa: bool = False):
+        if not HAS_SCIPY:
+            raise RuntimeError('scipy is not installed,\
+                 please install it by: pip install scipy')
+        self.center_only = center_only
+        if self.center_only:
+            self.gating_threshold = self.chi2inv95[2]
+        else:
+            self.gating_threshold = self.chi2inv95[4]
+
+        self.use_nsa = use_nsa
+        ndim, dt = 4, 1.
+
+        # Create Kalman filter model matrices.
+        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
+        for i in range(ndim):
+            self._motion_mat[i, ndim + i] = dt
+        self._update_mat = np.eye(ndim, 2 * ndim)
+
+        # Motion and observation uncertainty are chosen relative to the
+        # current state estimate. These weights control the amount of
+        # uncertainty in the model. This is a bit hacky.
+        self._std_weight_position = 1. / 20
+        self._std_weight_velocity = 1. / 160
+
+    def initiate(self, measurement: np.array) -> Tuple[np.array, np.array]:
+        """Create track from unassociated measurement.
+
+        Args:
+            measurement (ndarray): Bounding box coordinates (x, y, a, h) with
+                center position (x, y), aspect ratio a, and height h.
+
+        Returns:
+            (ndarray, ndarray): Returns the mean vector (8 dimensional) and
+                covariance matrix (8x8 dimensional) of the new track.
+                Unobserved velocities are initialized to 0 mean.
+        """
+        mean_pos = measurement
+        mean_vel = np.zeros_like(mean_pos)
+        mean = np.r_[mean_pos, mean_vel]
+
+        std = [
+            2 * self._std_weight_position * measurement[3],
+            2 * self._std_weight_position * measurement[3], 1e-2,
+            2 * self._std_weight_position * measurement[3],
+            10 * self._std_weight_velocity * measurement[3],
+            10 * self._std_weight_velocity * measurement[3], 1e-5,
+            10 * self._std_weight_velocity * measurement[3]
+        ]
+        covariance = np.diag(np.square(std))
+        return mean, covariance
+
+    def predict(self, mean: np.array,
+                covariance: np.array) -> Tuple[np.array, np.array]:
+        """Run Kalman filter prediction step.
+
+        Args:
+            mean (ndarray): The 8 dimensional mean vector of the object
+                state at the previous time step.
+            covariance (ndarray): The 8x8 dimensional covariance matrix
+                of the object state at the previous time step.
+
+        Returns:
+            (ndarray, ndarray): Returns the mean vector and covariance
+                matrix of the predicted state. Unobserved velocities are
+                initialized to 0 mean.
+        """
+        std_pos = [
+            self._std_weight_position * mean[3],
+            self._std_weight_position * mean[3], 1e-2,
+            self._std_weight_position * mean[3]
+        ]
+        std_vel = [
+            self._std_weight_velocity * mean[3],
+            self._std_weight_velocity * mean[3], 1e-5,
+            self._std_weight_velocity * mean[3]
+        ]
+        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
+
+        mean = np.dot(self._motion_mat, mean)
+        covariance = np.linalg.multi_dot(
+            (self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
+
+        return mean, covariance
+
+    def project(self,
+                mean: np.array,
+                covariance: np.array,
+                bbox_score: float = 0.) -> Tuple[np.array, np.array]:
+        """Project state distribution to measurement space.
+
+        Args:
+            mean (ndarray): The state's mean vector (8 dimensional array).
+            covariance (ndarray): The state's covariance matrix (8x8
+                dimensional).
+            bbox_score (float): The confidence score of the bbox.
+                Defaults to 0.
+
+        Returns:
+            (ndarray, ndarray): Returns the projected mean and covariance
+                matrix of the given state estimate.
+ """ + std = [ + self._std_weight_position * mean[3], + self._std_weight_position * mean[3], 1e-1, + self._std_weight_position * mean[3] + ] + + if self.use_nsa: + std = [(1 - bbox_score) * x for x in std] + + innovation_cov = np.diag(np.square(std)) + + mean = np.dot(self._update_mat, mean) + covariance = np.linalg.multi_dot( + (self._update_mat, covariance, self._update_mat.T)) + return mean, covariance + innovation_cov + + def update(self, + mean: np.array, + covariance: np.array, + measurement: np.array, + bbox_score: float = 0.) -> Tuple[np.array, np.array]: + """Run Kalman filter correction step. + + Args: + mean (ndarray): The predicted state's mean vector (8 dimensional). + covariance (ndarray): The state's covariance matrix (8x8 + dimensional). + measurement (ndarray): The 4 dimensional measurement vector + (x, y, a, h), where (x, y) is the center position, a the + aspect ratio, and h the height of the bounding box. + bbox_score (float): The confidence score of the bbox. + Defaults to 0. + + Returns: + (ndarray, ndarray): Returns the measurement-corrected state + distribution. + """ + projected_mean, projected_cov = \ + self.project(mean, covariance, bbox_score) + + chol_factor, lower = scipy.linalg.cho_factor( + projected_cov, lower=True, check_finite=False) + kalman_gain = scipy.linalg.cho_solve((chol_factor, lower), + np.dot(covariance, + self._update_mat.T).T, + check_finite=False).T + innovation = measurement - projected_mean + + new_mean = mean + np.dot(innovation, kalman_gain.T) + new_covariance = covariance - np.linalg.multi_dot( + (kalman_gain, projected_cov, kalman_gain.T)) + return new_mean, new_covariance + + def gating_distance(self, + mean: np.array, + covariance: np.array, + measurements: np.array, + only_position: bool = False) -> np.array: + """Compute gating distance between state distribution and measurements. + + A suitable distance threshold can be obtained from `chi2inv95`. If + `only_position` is False, the chi-square distribution has 4 degrees of + freedom, otherwise 2. + + Args: + mean (ndarray): Mean vector over the state distribution (8 + dimensional). + covariance (ndarray): Covariance of the state distribution (8x8 + dimensional). + measurements (ndarray): An Nx4 dimensional matrix of N + measurements, each in format (x, y, a, h) where (x, y) is the + bounding box center position, a the aspect ratio, and h the + height. + only_position (bool, optional): If True, distance computation is + done with respect to the bounding box center position only. + Defaults to False. + + Returns: + ndarray: Returns an array of length N, where the i-th element + contains the squared Mahalanobis distance between + (mean, covariance) and `measurements[i]`. + """ + mean, covariance = self.project(mean, covariance) + if only_position: + mean, covariance = mean[:2], covariance[:2, :2] + measurements = measurements[:, :2] + + cholesky_factor = np.linalg.cholesky(covariance) + d = measurements - mean + z = scipy.linalg.solve_triangular( + cholesky_factor, + d.T, + lower=True, + check_finite=False, + overwrite_b=True) + squared_maha = np.sum(z * z, axis=0) + return squared_maha + + def track(self, tracks: dict, + bboxes: torch.Tensor) -> Tuple[dict, np.array]: + """Track forward. + + Args: + tracks (dict[int:dict]): Track buffer. + bboxes (Tensor): Detected bounding boxes. + + Returns: + (dict[int:dict], ndarray): Updated tracks and bboxes. 
+ """ + costs = [] + for id, track in tracks.items(): + track.mean, track.covariance = self.predict( + track.mean, track.covariance) + gating_distance = self.gating_distance(track.mean, + track.covariance, + bboxes.cpu().numpy(), + self.center_only) + costs.append(gating_distance) + + costs = np.stack(costs, 0) + costs[costs > self.gating_threshold] = np.nan + return tracks, costs diff --git a/mmdet/models/task_modules/tracking/similarity.py b/mmdet/models/task_modules/tracking/similarity.py new file mode 100644 index 0000000000000000000000000000000000000000..730e43b86214ae92ffdcab8ae39e6f9261075caa --- /dev/null +++ b/mmdet/models/task_modules/tracking/similarity.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F +from torch import Tensor + + +def embed_similarity(key_embeds: Tensor, + ref_embeds: Tensor, + method: str = 'dot_product', + temperature: int = -1) -> Tensor: + """Calculate feature similarity from embeddings. + + Args: + key_embeds (Tensor): Shape (N1, C). + ref_embeds (Tensor): Shape (N2, C). + method (str, optional): Method to calculate the similarity, + options are 'dot_product' and 'cosine'. Defaults to + 'dot_product'. + temperature (int, optional): Softmax temperature. Defaults to -1. + + Returns: + Tensor: Similarity matrix of shape (N1, N2). + """ + assert method in ['dot_product', 'cosine'] + + if method == 'cosine': + key_embeds = F.normalize(key_embeds, p=2, dim=1) + ref_embeds = F.normalize(ref_embeds, p=2, dim=1) + + similarity = torch.mm(key_embeds, ref_embeds.T) + + if temperature > 0: + similarity /= float(temperature) + return similarity diff --git a/mmdet/models/test_time_augs/__init__.py b/mmdet/models/test_time_augs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f5e4926efb011b45b3ab7d3d303fb2d105aaa192 --- /dev/null +++ b/mmdet/models/test_time_augs/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .det_tta import DetTTAModel +from .merge_augs import (merge_aug_bboxes, merge_aug_masks, + merge_aug_proposals, merge_aug_results, + merge_aug_scores) + +__all__ = [ + 'merge_aug_bboxes', 'merge_aug_masks', 'merge_aug_proposals', + 'merge_aug_scores', 'merge_aug_results', 'DetTTAModel' +] diff --git a/mmdet/models/test_time_augs/det_tta.py b/mmdet/models/test_time_augs/det_tta.py new file mode 100644 index 0000000000000000000000000000000000000000..95f91db9e1250358db0e1a572cf4c37cc7fe6e6f --- /dev/null +++ b/mmdet/models/test_time_augs/det_tta.py @@ -0,0 +1,144 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +from mmcv.ops import batched_nms +from mmengine.model import BaseTTAModel +from mmengine.registry import MODELS +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.structures import DetDataSample +from mmdet.structures.bbox import bbox_flip + + +@MODELS.register_module() +class DetTTAModel(BaseTTAModel): + """Merge augmented detection results, only bboxes corresponding score under + flipping and multi-scale resizing can be processed now. 
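+    Only the flip transform is reverted here (via ``bbox_flip``); the boxes
+    from all views are then fused with batched NMS.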
+
+    Examples:
+        >>> tta_model = dict(
+        >>>     type='DetTTAModel',
+        >>>     tta_cfg=dict(nms=dict(
+        >>>                     type='nms',
+        >>>                     iou_threshold=0.5),
+        >>>                     max_per_img=100))
+        >>>
+        >>> tta_pipeline = [
+        >>>     dict(type='LoadImageFromFile',
+        >>>          backend_args=None),
+        >>>     dict(
+        >>>         type='TestTimeAug',
+        >>>         transforms=[[
+        >>>             dict(type='Resize',
+        >>>                  scale=(1333, 800),
+        >>>                  keep_ratio=True),
+        >>>         ], [
+        >>>             dict(type='RandomFlip', prob=1.),
+        >>>             dict(type='RandomFlip', prob=0.)
+        >>>         ], [
+        >>>             dict(
+        >>>                 type='PackDetInputs',
+        >>>                 meta_keys=('img_id', 'img_path', 'ori_shape',
+        >>>                            'img_shape', 'scale_factor', 'flip',
+        >>>                            'flip_direction'))
+        >>>         ]])]
+    """
+
+    def __init__(self, tta_cfg=None, **kwargs):
+        super().__init__(**kwargs)
+        self.tta_cfg = tta_cfg
+
+    def merge_aug_bboxes(self, aug_bboxes: List[Tensor],
+                         aug_scores: List[Tensor],
+                         img_metas: List[dict]) -> Tuple[Tensor, Tensor]:
+        """Merge augmented detection bboxes and scores.
+
+        Args:
+            aug_bboxes (list[Tensor]): shape (n, 4*#class)
+            aug_scores (list[Tensor] or None): shape (n, #class)
+            img_metas (list[dict]): meta information of each augmented view.
+
+        Returns:
+            tuple[Tensor]: ``bboxes`` with shape (n, 4), where
+            4 represent (tl_x, tl_y, br_x, br_y)
+            and ``scores`` with shape (n,).
+        """
+        recovered_bboxes = []
+        for bboxes, img_info in zip(aug_bboxes, img_metas):
+            ori_shape = img_info['ori_shape']
+            flip = img_info['flip']
+            flip_direction = img_info['flip_direction']
+            if flip:
+                bboxes = bbox_flip(
+                    bboxes=bboxes,
+                    img_shape=ori_shape,
+                    direction=flip_direction)
+            recovered_bboxes.append(bboxes)
+        bboxes = torch.cat(recovered_bboxes, dim=0)
+        if aug_scores is None:
+            return bboxes
+        else:
+            scores = torch.cat(aug_scores, dim=0)
+            return bboxes, scores
+
+    def merge_preds(self, data_samples_list: List[List[DetDataSample]]):
+        """Merge batch predictions of enhanced data.
+
+        Args:
+            data_samples_list (List[List[DetDataSample]]): List of predictions
+                of all enhanced data. The outer list indicates images, and the
+                inner list corresponds to the different views of one image.
+                Each element of the inner list is a ``DetDataSample``.
+
+        Returns:
+            List[DetDataSample]: Merged batch prediction.
+        """
+        merged_data_samples = []
+        for data_samples in data_samples_list:
+            merged_data_samples.append(self._merge_single_sample(data_samples))
+        return merged_data_samples
+
+    def _merge_single_sample(
+            self, data_samples: List[DetDataSample]) -> DetDataSample:
+        """Merge predictions which come from the different views of one image
+        to one prediction.
+
+        Args:
+            data_samples (List[DetDataSample]): List of predictions
+                of enhanced data which come from one image.
+
+        Returns:
+            DetDataSample: Merged prediction.
+        """
+        aug_bboxes = []
+        aug_scores = []
+        aug_labels = []
+        img_metas = []
+        # TODO: support instance segmentation TTA
+        assert data_samples[0].pred_instances.get('masks', None) is None, \
+            'TTA for instance segmentation is not supported yet.'
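+        # gather bboxes, scores, labels and meta info from every view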
+ for data_sample in data_samples: + aug_bboxes.append(data_sample.pred_instances.bboxes) + aug_scores.append(data_sample.pred_instances.scores) + aug_labels.append(data_sample.pred_instances.labels) + img_metas.append(data_sample.metainfo) + + merged_bboxes, merged_scores = self.merge_aug_bboxes( + aug_bboxes, aug_scores, img_metas) + merged_labels = torch.cat(aug_labels, dim=0) + + if merged_bboxes.numel() == 0: + return data_samples[0] + + det_bboxes, keep_idxs = batched_nms(merged_bboxes, merged_scores, + merged_labels, self.tta_cfg.nms) + + det_bboxes = det_bboxes[:self.tta_cfg.max_per_img] + det_labels = merged_labels[keep_idxs][:self.tta_cfg.max_per_img] + + results = InstanceData() + _det_bboxes = det_bboxes.clone() + results.bboxes = _det_bboxes[:, :-1] + results.scores = _det_bboxes[:, -1] + results.labels = det_labels + det_results = data_samples[0] + det_results.pred_instances = results + return det_results diff --git a/mmdet/models/test_time_augs/merge_augs.py b/mmdet/models/test_time_augs/merge_augs.py new file mode 100644 index 0000000000000000000000000000000000000000..5935a8614c39d70253a09a339f51c144661c64fb --- /dev/null +++ b/mmdet/models/test_time_augs/merge_augs.py @@ -0,0 +1,219 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import warnings +from typing import List, Optional, Union + +import numpy as np +import torch +from mmcv.ops import nms +from mmengine.config import ConfigDict +from torch import Tensor + +from mmdet.structures.bbox import bbox_mapping_back + + +# TODO remove this, never be used in mmdet +def merge_aug_proposals(aug_proposals, img_metas, cfg): + """Merge augmented proposals (multiscale, flip, etc.) + + Args: + aug_proposals (list[Tensor]): proposals from different testing + schemes, shape (n, 5). Note that they are not rescaled to the + original image size. + + img_metas (list[dict]): list of image info dict where each dict has: + 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmdet/datasets/pipelines/formatting.py:Collect`. + + cfg (dict): rpn test config. + + Returns: + Tensor: shape (n, 4), proposals corresponding to original image scale. + """ + + cfg = copy.deepcopy(cfg) + + # deprecate arguments warning + if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg: + warnings.warn( + 'In rpn_proposal or test_cfg, ' + 'nms_thr has been moved to a dict named nms as ' + 'iou_threshold, max_num has been renamed as max_per_img, ' + 'name of original arguments and the way to specify ' + 'iou_threshold of NMS will be deprecated.') + if 'nms' not in cfg: + cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr)) + if 'max_num' in cfg: + if 'max_per_img' in cfg: + assert cfg.max_num == cfg.max_per_img, f'You set max_num and ' \ + f'max_per_img at the same time, but get {cfg.max_num} ' \ + f'and {cfg.max_per_img} respectively' \ + f'Please delete max_num which will be deprecated.' + else: + cfg.max_per_img = cfg.max_num + if 'nms_thr' in cfg: + assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set ' \ + f'iou_threshold in nms and ' \ + f'nms_thr at the same time, but get ' \ + f'{cfg.nms.iou_threshold} and {cfg.nms_thr}' \ + f' respectively. Please delete the nms_thr ' \ + f'which will be deprecated.' 
+
+    # map proposals of each view back to the original image scale before NMS
+    recovered_proposals = []
+    for proposals, img_info in zip(aug_proposals, img_metas):
+        img_shape = img_info['img_shape']
+        scale_factor = img_info['scale_factor']
+        flip = img_info['flip']
+        flip_direction = img_info['flip_direction']
+        _proposals = proposals.clone()
+        _proposals[:, :4] = bbox_mapping_back(_proposals[:, :4], img_shape,
+                                              scale_factor, flip,
+                                              flip_direction)
+        recovered_proposals.append(_proposals)
+    aug_proposals = torch.cat(recovered_proposals, dim=0)
+    merged_proposals, _ = nms(aug_proposals[:, :4].contiguous(),
+                              aug_proposals[:, -1].contiguous(),
+                              cfg.nms.iou_threshold)
+    scores = merged_proposals[:, 4]
+    _, order = scores.sort(0, descending=True)
+    num = min(cfg.max_per_img, merged_proposals.shape[0])
+    order = order[:num]
+    merged_proposals = merged_proposals[order, :]
+    return merged_proposals
+
+
+# TODO remove this, never be used in mmdet
+def merge_aug_bboxes(aug_bboxes, aug_scores, img_metas, rcnn_test_cfg):
+    """Merge augmented detection bboxes and scores.
+
+    Args:
+        aug_bboxes (list[Tensor]): shape (n, 4*#class)
+        aug_scores (list[Tensor] or None): shape (n, #class)
+        img_metas (list[list[dict]]): meta information of each augmented
+            view.
+        rcnn_test_cfg (dict): rcnn test config.
+
+    Returns:
+        tuple: (bboxes, scores)
+    """
+    recovered_bboxes = []
+    for bboxes, img_info in zip(aug_bboxes, img_metas):
+        img_shape = img_info[0]['img_shape']
+        scale_factor = img_info[0]['scale_factor']
+        flip = img_info[0]['flip']
+        flip_direction = img_info[0]['flip_direction']
+        bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip,
+                                   flip_direction)
+        recovered_bboxes.append(bboxes)
+    bboxes = torch.stack(recovered_bboxes).mean(dim=0)
+    if aug_scores is None:
+        return bboxes
+    else:
+        scores = torch.stack(aug_scores).mean(dim=0)
+        return bboxes, scores
+
+
+def merge_aug_results(aug_batch_results, aug_batch_img_metas):
+    """Merge augmented detection results. Only bboxes and their corresponding
+    scores produced under flipping and multi-scale resizing can be processed
+    now.
+
+    Args:
+        aug_batch_results (list[list[:obj:`InstanceData`]]):
+            Detection results of multiple images with
+            different augmentations.
+            The outer list indicates the augmentations and the inner
+            list indicates the batch dimension.
+            Each item usually contains the following keys.
+
+            - scores (Tensor): Classification scores, in shape
+              (num_instance,)
+            - labels (Tensor): Labels of bboxes, in shape
+              (num_instances,).
+            - bboxes (Tensor): In shape (num_instances, 4),
+              the last dimension 4 arrange as (x1, y1, x2, y2).
+        aug_batch_img_metas (list[list[dict]]): The outer list
+            indicates test-time augs (multiscale, flip, etc.)
+            and the inner list indicates
+            images in a batch. Each dict in the list contains
+            information of an image in the batch.
+
+    Returns:
+        batch_results (list[:obj:`InstanceData`]): Same with
+            the input `aug_results` except that all bboxes have
+            been mapped to the original scale.
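+
+    Example (editor's illustrative sketch; two views of a one-image batch):
+        >>> batch_results = merge_aug_results(
+        ...     aug_batch_results, aug_batch_img_metas)
+        >>> len(batch_results)  # one merged InstanceData per image
+        1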
+ """ + num_augs = len(aug_batch_results) + num_imgs = len(aug_batch_results[0]) + + batch_results = [] + aug_batch_results = copy.deepcopy(aug_batch_results) + for img_id in range(num_imgs): + aug_results = [] + for aug_id in range(num_augs): + img_metas = aug_batch_img_metas[aug_id][img_id] + results = aug_batch_results[aug_id][img_id] + + img_shape = img_metas['img_shape'] + scale_factor = img_metas['scale_factor'] + flip = img_metas['flip'] + flip_direction = img_metas['flip_direction'] + bboxes = bbox_mapping_back(results.bboxes, img_shape, scale_factor, + flip, flip_direction) + results.bboxes = bboxes + aug_results.append(results) + merged_aug_results = results.cat(aug_results) + batch_results.append(merged_aug_results) + + return batch_results + + +def merge_aug_scores(aug_scores): + """Merge augmented bbox scores.""" + if isinstance(aug_scores[0], torch.Tensor): + return torch.mean(torch.stack(aug_scores), dim=0) + else: + return np.mean(aug_scores, axis=0) + + +def merge_aug_masks(aug_masks: List[Tensor], + img_metas: dict, + weights: Optional[Union[list, Tensor]] = None) -> Tensor: + """Merge augmented mask prediction. + + Args: + aug_masks (list[Tensor]): each has shape + (n, c, h, w). + img_metas (dict): Image information. + weights (list or Tensor): Weight of each aug_masks, + the length should be n. + + Returns: + Tensor: has shape (n, c, h, w) + """ + recovered_masks = [] + for i, mask in enumerate(aug_masks): + if weights is not None: + assert len(weights) == len(aug_masks) + weight = weights[i] + else: + weight = 1 + flip = img_metas.get('flip', False) + if flip: + flip_direction = img_metas['flip_direction'] + if flip_direction == 'horizontal': + mask = mask[:, :, :, ::-1] + elif flip_direction == 'vertical': + mask = mask[:, :, ::-1, :] + elif flip_direction == 'diagonal': + mask = mask[:, :, :, ::-1] + mask = mask[:, :, ::-1, :] + else: + raise ValueError( + f"Invalid flipping direction '{flip_direction}'") + recovered_masks.append(mask[None, :] * weight) + + merged_masks = torch.cat(recovered_masks, 0).mean(dim=0) + if weights is not None: + merged_masks = merged_masks * len(weights) / sum(weights) + return merged_masks diff --git a/mmdet/models/trackers/__init__.py b/mmdet/models/trackers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..00284bb7b40dd007c28b6cc9175ac26a52c6c528 --- /dev/null +++ b/mmdet/models/trackers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base_tracker import BaseTracker +from .byte_tracker import ByteTracker +from .masktrack_rcnn_tracker import MaskTrackRCNNTracker +from .ocsort_tracker import OCSORTTracker +from .quasi_dense_tracker import QuasiDenseTracker +from .sort_tracker import SORTTracker +from .strongsort_tracker import StrongSORTTracker + +__all__ = [ + 'BaseTracker', 'ByteTracker', 'QuasiDenseTracker', 'SORTTracker', + 'StrongSORTTracker', 'OCSORTTracker', 'MaskTrackRCNNTracker' +] diff --git a/mmdet/models/trackers/base_tracker.py b/mmdet/models/trackers/base_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..0cf188653cd9adda59decd45f65fc4ede63fe3a7 --- /dev/null +++ b/mmdet/models/trackers/base_tracker.py @@ -0,0 +1,240 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from addict import Dict + + +class BaseTracker(metaclass=ABCMeta): + """Base tracker model. 
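+
+    Concrete trackers only need to implement ``track()``; buffer updates,
+    momentum-based smoothing and removal of stale tracks are handled here.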
+ + Args: + momentums (dict[str:float], optional): Momentums to update the buffers. + The `str` indicates the name of the buffer while the `float` + indicates the momentum. Defaults to None. + num_frames_retain (int, optional). If a track is disappeared more than + `num_frames_retain` frames, it will be deleted in the memo. + Defaults to 10. + """ + + def __init__(self, + momentums: Optional[dict] = None, + num_frames_retain: int = 10) -> None: + super().__init__() + if momentums is not None: + assert isinstance(momentums, dict), 'momentums must be a dict' + self.momentums = momentums + self.num_frames_retain = num_frames_retain + + self.reset() + + def reset(self) -> None: + """Reset the buffer of the tracker.""" + self.num_tracks = 0 + self.tracks = dict() + + @property + def empty(self) -> bool: + """Whether the buffer is empty or not.""" + return False if self.tracks else True + + @property + def ids(self) -> List[dict]: + """All ids in the tracker.""" + return list(self.tracks.keys()) + + @property + def with_reid(self) -> bool: + """bool: whether the framework has a reid model""" + return hasattr(self, 'reid') and self.reid is not None + + def update(self, **kwargs) -> None: + """Update the tracker. + + Args: + kwargs (dict[str: Tensor | int]): The `str` indicates the + name of the input variable. `ids` and `frame_ids` are + obligatory in the keys. + """ + memo_items = [k for k, v in kwargs.items() if v is not None] + rm_items = [k for k in kwargs.keys() if k not in memo_items] + for item in rm_items: + kwargs.pop(item) + if not hasattr(self, 'memo_items'): + self.memo_items = memo_items + else: + assert memo_items == self.memo_items + + assert 'ids' in memo_items + num_objs = len(kwargs['ids']) + id_indice = memo_items.index('ids') + assert 'frame_ids' in memo_items + frame_id = int(kwargs['frame_ids']) + if isinstance(kwargs['frame_ids'], int): + kwargs['frame_ids'] = torch.tensor([kwargs['frame_ids']] * + num_objs) + # cur_frame_id = int(kwargs['frame_ids'][0]) + for k, v in kwargs.items(): + if len(v) != num_objs: + raise ValueError('kwargs value must both equal') + + for obj in zip(*kwargs.values()): + id = int(obj[id_indice]) + if id in self.tracks: + self.update_track(id, obj) + else: + self.init_track(id, obj) + + self.pop_invalid_tracks(frame_id) + + def pop_invalid_tracks(self, frame_id: int) -> None: + """Pop out invalid tracks.""" + invalid_ids = [] + for k, v in self.tracks.items(): + if frame_id - v['frame_ids'][-1] >= self.num_frames_retain: + invalid_ids.append(k) + for invalid_id in invalid_ids: + self.tracks.pop(invalid_id) + + def update_track(self, id: int, obj: Tuple[torch.Tensor]): + """Update a track.""" + for k, v in zip(self.memo_items, obj): + v = v[None] + if self.momentums is not None and k in self.momentums: + m = self.momentums[k] + self.tracks[id][k] = (1 - m) * self.tracks[id][k] + m * v + else: + self.tracks[id][k].append(v) + + def init_track(self, id: int, obj: Tuple[torch.Tensor]): + """Initialize a track.""" + self.tracks[id] = Dict() + for k, v in zip(self.memo_items, obj): + v = v[None] + if self.momentums is not None and k in self.momentums: + self.tracks[id][k] = v + else: + self.tracks[id][k] = [v] + + @property + def memo(self) -> dict: + """Return all buffers in the tracker.""" + outs = Dict() + for k in self.memo_items: + outs[k] = [] + + for id, objs in self.tracks.items(): + for k, v in objs.items(): + if k not in outs: + continue + if self.momentums is not None and k in self.momentums: + v = v + else: + v = v[-1] + outs[k].append(v) + + 
for k, v in outs.items(): + outs[k] = torch.cat(v, dim=0) + return outs + + def get(self, + item: str, + ids: Optional[list] = None, + num_samples: Optional[int] = None, + behavior: Optional[str] = None) -> torch.Tensor: + """Get the buffer of a specific item. + + Args: + item (str): The demanded item. + ids (list[int], optional): The demanded ids. Defaults to None. + num_samples (int, optional): Number of samples to calculate the + results. Defaults to None. + behavior (str, optional): Behavior to calculate the results. + Options are `mean` | None. Defaults to None. + + Returns: + Tensor: The results of the demanded item. + """ + if ids is None: + ids = self.ids + + outs = [] + for id in ids: + out = self.tracks[id][item] + if isinstance(out, list): + if num_samples is not None: + out = out[-num_samples:] + out = torch.cat(out, dim=0) + if behavior == 'mean': + out = out.mean(dim=0, keepdim=True) + elif behavior is None: + out = out[None] + else: + raise NotImplementedError() + else: + out = out[-1] + outs.append(out) + return torch.cat(outs, dim=0) + + @abstractmethod + def track(self, *args, **kwargs): + """Tracking forward function.""" + pass + + def crop_imgs(self, + img: torch.Tensor, + meta_info: dict, + bboxes: torch.Tensor, + rescale: bool = False) -> torch.Tensor: + """Crop the images according to some bounding boxes. Typically for re- + identification sub-module. + + Args: + img (Tensor): of shape (T, C, H, W) encoding input image. + Typically these should be mean centered and std scaled. + meta_info (dict): image information dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + bboxes (Tensor): of shape (N, 4) or (N, 5). + rescale (bool, optional): If True, the bounding boxes should be + rescaled to fit the scale of the image. Defaults to False. + + Returns: + Tensor: Image tensor of shape (T, C, H, W). + """ + h, w = meta_info['img_shape'] + img = img[:, :, :h, :w] + if rescale: + factor_x, factor_y = meta_info['scale_factor'] + bboxes[:, :4] *= torch.tensor( + [factor_x, factor_y, factor_x, factor_y]).to(bboxes.device) + bboxes[:, 0] = torch.clamp(bboxes[:, 0], min=0, max=w - 1) + bboxes[:, 1] = torch.clamp(bboxes[:, 1], min=0, max=h - 1) + bboxes[:, 2] = torch.clamp(bboxes[:, 2], min=1, max=w) + bboxes[:, 3] = torch.clamp(bboxes[:, 3], min=1, max=h) + + crop_imgs = [] + for bbox in bboxes: + x1, y1, x2, y2 = map(int, bbox) + if x2 <= x1: + x2 = x1 + 1 + if y2 <= y1: + y2 = y1 + 1 + crop_img = img[:, :, y1:y2, x1:x2] + if self.reid.get('img_scale', False): + crop_img = F.interpolate( + crop_img, + size=self.reid['img_scale'], + mode='bilinear', + align_corners=False) + crop_imgs.append(crop_img) + + if len(crop_imgs) > 0: + return torch.cat(crop_imgs, dim=0) + elif self.reid.get('img_scale', False): + _h, _w = self.reid['img_scale'] + return img.new_zeros((0, 3, _h, _w)) + else: + return img.new_zeros((0, 3, h, w)) diff --git a/mmdet/models/trackers/byte_tracker.py b/mmdet/models/trackers/byte_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..11f3adc53c58339f6289cbfa77aed738259fc98c --- /dev/null +++ b/mmdet/models/trackers/byte_tracker.py @@ -0,0 +1,334 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
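+# Editor's usage sketch for the BaseTracker buffers above (hypothetical ids;
+# ``get`` can average the last ``num_samples`` boxes of each track):
+#
+#     bboxes = tracker.get('bboxes', ids=[0, 1], num_samples=3,
+#                          behavior='mean')  # -> shape (2, 4)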
+from typing import List, Optional, Tuple + +try: + import lap +except ImportError: + lap = None +import numpy as np +import torch +from mmengine.structures import InstanceData + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures import DetDataSample +from mmdet.structures.bbox import (bbox_cxcyah_to_xyxy, bbox_overlaps, + bbox_xyxy_to_cxcyah) +from .base_tracker import BaseTracker + + +@MODELS.register_module() +class ByteTracker(BaseTracker): + """Tracker for ByteTrack. + + Args: + motion (dict): Configuration of motion. Defaults to None. + obj_score_thrs (dict): Detection score threshold for matching objects. + - high (float): Threshold of the first matching. Defaults to 0.6. + - low (float): Threshold of the second matching. Defaults to 0.1. + init_track_thr (float): Detection score threshold for initializing a + new tracklet. Defaults to 0.7. + weight_iou_with_det_scores (bool): Whether using detection scores to + weight IOU which is used for matching. Defaults to True. + match_iou_thrs (dict): IOU distance threshold for matching between two + frames. + - high (float): Threshold of the first matching. Defaults to 0.1. + - low (float): Threshold of the second matching. Defaults to 0.5. + - tentative (float): Threshold of the matching for tentative + tracklets. Defaults to 0.3. + num_tentatives (int, optional): Number of continuous frames to confirm + a track. Defaults to 3. + """ + + def __init__(self, + motion: Optional[dict] = None, + obj_score_thrs: dict = dict(high=0.6, low=0.1), + init_track_thr: float = 0.7, + weight_iou_with_det_scores: bool = True, + match_iou_thrs: dict = dict(high=0.1, low=0.5, tentative=0.3), + num_tentatives: int = 3, + **kwargs): + super().__init__(**kwargs) + + if lap is None: + raise RuntimeError('lap is not installed,\ + please install it by: pip install lap') + if motion is not None: + self.motion = TASK_UTILS.build(motion) + + self.obj_score_thrs = obj_score_thrs + self.init_track_thr = init_track_thr + + self.weight_iou_with_det_scores = weight_iou_with_det_scores + self.match_iou_thrs = match_iou_thrs + + self.num_tentatives = num_tentatives + + @property + def confirmed_ids(self) -> List: + """Confirmed ids in the tracker.""" + ids = [id for id, track in self.tracks.items() if not track.tentative] + return ids + + @property + def unconfirmed_ids(self) -> List: + """Unconfirmed ids in the tracker.""" + ids = [id for id, track in self.tracks.items() if track.tentative] + return ids + + def init_track(self, id: int, obj: Tuple[torch.Tensor]) -> None: + """Initialize a track.""" + super().init_track(id, obj) + if self.tracks[id].frame_ids[-1] == 0: + self.tracks[id].tentative = False + else: + self.tracks[id].tentative = True + bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1]) # size = (1, 4) + assert bbox.ndim == 2 and bbox.shape[0] == 1 + bbox = bbox.squeeze(0).cpu().numpy() + self.tracks[id].mean, self.tracks[id].covariance = self.kf.initiate( + bbox) + + def update_track(self, id: int, obj: Tuple[torch.Tensor]) -> None: + """Update a track.""" + super().update_track(id, obj) + if self.tracks[id].tentative: + if len(self.tracks[id]['bboxes']) >= self.num_tentatives: + self.tracks[id].tentative = False + bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1]) # size = (1, 4) + assert bbox.ndim == 2 and bbox.shape[0] == 1 + bbox = bbox.squeeze(0).cpu().numpy() + track_label = self.tracks[id]['labels'][-1] + label_idx = self.memo_items.index('labels') + obj_label = obj[label_idx] + assert obj_label == track_label + 
+        self.tracks[id].mean, self.tracks[id].covariance = self.kf.update(
+            self.tracks[id].mean, self.tracks[id].covariance, bbox)
+
+    def pop_invalid_tracks(self, frame_id: int) -> None:
+        """Pop out invalid tracks."""
+        invalid_ids = []
+        for k, v in self.tracks.items():
+            # case1: disappeared frames >= self.num_frames_retain
+            case1 = frame_id - v['frame_ids'][-1] >= self.num_frames_retain
+            # case2: tentative tracks but not matched in this frame
+            case2 = v.tentative and v['frame_ids'][-1] != frame_id
+            if case1 or case2:
+                invalid_ids.append(k)
+        for invalid_id in invalid_ids:
+            self.tracks.pop(invalid_id)
+
+    def assign_ids(
+            self,
+            ids: List[int],
+            det_bboxes: torch.Tensor,
+            det_labels: torch.Tensor,
+            det_scores: torch.Tensor,
+            weight_iou_with_det_scores: Optional[bool] = False,
+            match_iou_thr: Optional[float] = 0.5
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        """Assign ids.
+
+        Args:
+            ids (list[int]): Tracking ids.
+            det_bboxes (Tensor): of shape (N, 4)
+            det_labels (Tensor): of shape (N,)
+            det_scores (Tensor): of shape (N,)
+            weight_iou_with_det_scores (bool, optional): Whether to use
+                detection scores to weight the IoUs used for matching.
+                Defaults to False.
+            match_iou_thr (float, optional): Matching threshold.
+                Defaults to 0.5.
+
+        Returns:
+            tuple(np.ndarray, np.ndarray): ``(row, col)``, where ``row[i]``
+            is the detection index matched to the i-th track and ``col[j]``
+            is the track index matched to the j-th detection; -1 means
+            unmatched.
+        """
+        # get track_bboxes
+        track_bboxes = np.zeros((0, 4))
+        for id in ids:
+            track_bboxes = np.concatenate(
+                (track_bboxes, self.tracks[id].mean[:4][None]), axis=0)
+        track_bboxes = torch.from_numpy(track_bboxes).to(det_bboxes)
+        track_bboxes = bbox_cxcyah_to_xyxy(track_bboxes)
+
+        # compute distance
+        ious = bbox_overlaps(track_bboxes, det_bboxes)
+        if weight_iou_with_det_scores:
+            ious *= det_scores
+        # support multi-class association
+        track_labels = torch.tensor([
+            self.tracks[id]['labels'][-1] for id in ids
+        ]).to(det_bboxes.device)
+
+        cate_match = det_labels[None, :] == track_labels[:, None]
+        # avoid matching detections and tracks of different categories
+        cate_cost = (1 - cate_match.int()) * 1e6
+
+        dists = (1 - ious + cate_cost).cpu().numpy()
+
+        # bipartite match
+        if dists.size > 0:
+            cost, row, col = lap.lapjv(
+                dists, extend_cost=True, cost_limit=1 - match_iou_thr)
+        else:
+            row = np.zeros(len(ids)).astype(np.int32) - 1
+            col = np.zeros(len(det_bboxes)).astype(np.int32) - 1
+        return row, col
+
+    def track(self, data_sample: DetDataSample, **kwargs) -> InstanceData:
+        """Tracking forward function.
+
+        Args:
+            data_sample (:obj:`DetDataSample`): The data sample.
+                It includes information such as `pred_instances`.
+
+        Returns:
+            :obj:`InstanceData`: Tracking results of the input images.
+            Each InstanceData usually contains ``bboxes``, ``labels``,
+            ``scores`` and ``instances_id``.
+        """
+        metainfo = data_sample.metainfo
+        bboxes = data_sample.pred_instances.bboxes
+        labels = data_sample.pred_instances.labels
+        scores = data_sample.pred_instances.scores
+
+        frame_id = metainfo.get('frame_id', -1)
+        if frame_id == 0:
+            self.reset()
+        if not hasattr(self, 'kf'):
+            self.kf = self.motion
+
+        if self.empty or bboxes.size(0) == 0:
+            valid_inds = scores > self.init_track_thr
+            scores = scores[valid_inds]
+            bboxes = bboxes[valid_inds]
+            labels = labels[valid_inds]
+            num_new_tracks = bboxes.size(0)
+            ids = torch.arange(self.num_tracks,
+                               self.num_tracks + num_new_tracks).to(labels)
+            self.num_tracks += num_new_tracks
+
+        else:
+            # 0.
init + ids = torch.full((bboxes.size(0), ), + -1, + dtype=labels.dtype, + device=labels.device) + + # get the detection bboxes for the first association + first_det_inds = scores > self.obj_score_thrs['high'] + first_det_bboxes = bboxes[first_det_inds] + first_det_labels = labels[first_det_inds] + first_det_scores = scores[first_det_inds] + first_det_ids = ids[first_det_inds] + + # get the detection bboxes for the second association + second_det_inds = (~first_det_inds) & ( + scores > self.obj_score_thrs['low']) + second_det_bboxes = bboxes[second_det_inds] + second_det_labels = labels[second_det_inds] + second_det_scores = scores[second_det_inds] + second_det_ids = ids[second_det_inds] + + # 1. use Kalman Filter to predict current location + for id in self.confirmed_ids: + # track is lost in previous frame + if self.tracks[id].frame_ids[-1] != frame_id - 1: + self.tracks[id].mean[7] = 0 + (self.tracks[id].mean, + self.tracks[id].covariance) = self.kf.predict( + self.tracks[id].mean, self.tracks[id].covariance) + + # 2. first match + first_match_track_inds, first_match_det_inds = self.assign_ids( + self.confirmed_ids, first_det_bboxes, first_det_labels, + first_det_scores, self.weight_iou_with_det_scores, + self.match_iou_thrs['high']) + # '-1' mean a detection box is not matched with tracklets in + # previous frame + valid = first_match_det_inds > -1 + first_det_ids[valid] = torch.tensor( + self.confirmed_ids)[first_match_det_inds[valid]].to(labels) + + first_match_det_bboxes = first_det_bboxes[valid] + first_match_det_labels = first_det_labels[valid] + first_match_det_scores = first_det_scores[valid] + first_match_det_ids = first_det_ids[valid] + assert (first_match_det_ids > -1).all() + + first_unmatch_det_bboxes = first_det_bboxes[~valid] + first_unmatch_det_labels = first_det_labels[~valid] + first_unmatch_det_scores = first_det_scores[~valid] + first_unmatch_det_ids = first_det_ids[~valid] + assert (first_unmatch_det_ids == -1).all() + + # 3. use unmatched detection bboxes from the first match to match + # the unconfirmed tracks + (tentative_match_track_inds, + tentative_match_det_inds) = self.assign_ids( + self.unconfirmed_ids, first_unmatch_det_bboxes, + first_unmatch_det_labels, first_unmatch_det_scores, + self.weight_iou_with_det_scores, + self.match_iou_thrs['tentative']) + valid = tentative_match_det_inds > -1 + first_unmatch_det_ids[valid] = torch.tensor(self.unconfirmed_ids)[ + tentative_match_det_inds[valid]].to(labels) + + # 4. second match for unmatched tracks from the first match + first_unmatch_track_ids = [] + for i, id in enumerate(self.confirmed_ids): + # tracklet is not matched in the first match + case_1 = first_match_track_inds[i] == -1 + # tracklet is not lost in the previous frame + case_2 = self.tracks[id].frame_ids[-1] == frame_id - 1 + if case_1 and case_2: + first_unmatch_track_ids.append(id) + + second_match_track_inds, second_match_det_inds = self.assign_ids( + first_unmatch_track_ids, second_det_bboxes, second_det_labels, + second_det_scores, False, self.match_iou_thrs['low']) + valid = second_match_det_inds > -1 + second_det_ids[valid] = torch.tensor(first_unmatch_track_ids)[ + second_match_det_inds[valid]].to(ids) + + # 5. 
gather all matched detection bboxes from step 2-4 + # we only keep matched detection bboxes in second match, which + # means the id != -1 + valid = second_det_ids > -1 + bboxes = torch.cat( + (first_match_det_bboxes, first_unmatch_det_bboxes), dim=0) + bboxes = torch.cat((bboxes, second_det_bboxes[valid]), dim=0) + + labels = torch.cat( + (first_match_det_labels, first_unmatch_det_labels), dim=0) + labels = torch.cat((labels, second_det_labels[valid]), dim=0) + + scores = torch.cat( + (first_match_det_scores, first_unmatch_det_scores), dim=0) + scores = torch.cat((scores, second_det_scores[valid]), dim=0) + + ids = torch.cat((first_match_det_ids, first_unmatch_det_ids), + dim=0) + ids = torch.cat((ids, second_det_ids[valid]), dim=0) + + # 6. assign new ids + new_track_inds = ids == -1 + ids[new_track_inds] = torch.arange( + self.num_tracks, + self.num_tracks + new_track_inds.sum()).to(labels) + self.num_tracks += new_track_inds.sum() + + self.update( + ids=ids, + bboxes=bboxes, + scores=scores, + labels=labels, + frame_ids=frame_id) + + # update pred_track_instances + pred_track_instances = InstanceData() + pred_track_instances.bboxes = bboxes + pred_track_instances.labels = labels + pred_track_instances.scores = scores + pred_track_instances.instances_id = ids + + return pred_track_instances diff --git a/mmdet/models/trackers/masktrack_rcnn_tracker.py b/mmdet/models/trackers/masktrack_rcnn_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..cc167786b8b412629885a4f134a1bf79f3dfaa93 --- /dev/null +++ b/mmdet/models/trackers/masktrack_rcnn_tracker.py @@ -0,0 +1,189 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import DetDataSample +from mmdet.structures.bbox import bbox_overlaps +from .base_tracker import BaseTracker + + +@MODELS.register_module() +class MaskTrackRCNNTracker(BaseTracker): + """Tracker for MaskTrack R-CNN. + + Args: + match_weights (dict[str : float]): The Weighting factor when computing + the match score. It contains keys as follows: + + - det_score (float): The coefficient of `det_score` when computing + match score. + - iou (float): The coefficient of `ious` when computing match + score. + - det_label (float): The coefficient of `label_deltas` when + computing match score. + """ + + def __init__(self, + match_weights: dict = dict( + det_score=1.0, iou=2.0, det_label=10.0), + **kwargs): + super().__init__(**kwargs) + self.match_weights = match_weights + + def get_match_score(self, bboxes: Tensor, labels: Tensor, scores: Tensor, + prev_bboxes: Tensor, prev_labels: Tensor, + similarity_logits: Tensor) -> Tensor: + """Get the match score. + + Args: + bboxes (torch.Tensor): of shape (num_current_bboxes, 4) in + [tl_x, tl_y, br_x, br_y] format. Denoting the detection + bboxes of current frame. + labels (torch.Tensor): of shape (num_current_bboxes, ) + scores (torch.Tensor): of shape (num_current_bboxes, ) + prev_bboxes (torch.Tensor): of shape (num_previous_bboxes, 4) in + [tl_x, tl_y, br_x, br_y] format. Denoting the detection bboxes + of previous frame. + prev_labels (torch.Tensor): of shape (num_previous_bboxes, ) + similarity_logits (torch.Tensor): of shape (num_current_bboxes, + num_previous_bboxes + 1). Denoting the similarity logits from + track head. 
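+
+        Note:
+            The score is computed as ``log(softmax(similarity_logits)) +
+            w['det_score'] * log(score) + w['iou'] * iou + w['det_label'] *
+            label_delta`` with ``w = self.match_weights``; column 0 is a
+            dummy candidate that stands for starting a new track.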
+ + Returns: + torch.Tensor: The matching score of shape (num_current_bboxes, + num_previous_bboxes + 1) + """ + similarity_scores = similarity_logits.softmax(dim=1) + + ious = bbox_overlaps(bboxes, prev_bboxes) + iou_dummy = ious.new_zeros(ious.shape[0], 1) + ious = torch.cat((iou_dummy, ious), dim=1) + + label_deltas = (labels.view(-1, 1) == prev_labels).float() + label_deltas_dummy = label_deltas.new_ones(label_deltas.shape[0], 1) + label_deltas = torch.cat((label_deltas_dummy, label_deltas), dim=1) + + match_score = similarity_scores.log() + match_score += self.match_weights['det_score'] * \ + scores.view(-1, 1).log() + match_score += self.match_weights['iou'] * ious + match_score += self.match_weights['det_label'] * label_deltas + + return match_score + + def assign_ids(self, match_scores: Tensor): + num_prev_bboxes = match_scores.shape[1] - 1 + _, match_ids = match_scores.max(dim=1) + + ids = match_ids.new_zeros(match_ids.shape[0]) - 1 + best_match_scores = match_scores.new_zeros(num_prev_bboxes) - 1e6 + for idx, match_id in enumerate(match_ids): + if match_id == 0: + ids[idx] = self.num_tracks + self.num_tracks += 1 + else: + match_score = match_scores[idx, match_id] + # TODO: fix the bug where multiple candidate might match + # with the same previous object. + if match_score > best_match_scores[match_id - 1]: + ids[idx] = self.ids[match_id - 1] + best_match_scores[match_id - 1] = match_score + return ids, best_match_scores + + def track(self, + model: torch.nn.Module, + feats: List[torch.Tensor], + data_sample: DetDataSample, + rescale=True, + **kwargs) -> InstanceData: + """Tracking forward function. + + Args: + model (nn.Module): VIS model. + img (Tensor): of shape (T, C, H, W) encoding input image. + Typically these should be mean centered and std scaled. + The T denotes the number of key images and usually is 1 in + MaskTrackRCNN method. + feats (list[Tensor]): Multi level feature maps of `img`. + data_sample (:obj:`TrackDataSample`): The data sample. + It includes information such as `pred_det_instances`. + rescale (bool, optional): If True, the bounding boxes should be + rescaled to fit the original scale of the image. Defaults to + True. + + Returns: + :obj:`InstanceData`: Tracking results of the input images. + Each InstanceData usually contains ``bboxes``, ``labels``, + ``scores`` and ``instances_id``. 
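+
+        Example:
+            A sketch of the expected call pattern; ``vis_model``, ``feats``
+            and ``det_sample`` are assumed to come from a built VIS model
+            and its detection stage:
+
+            >>> # pred = tracker.track(vis_model, feats, det_sample)
+            >>> # pred.instances_id  # Tensor of shape (N, )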
+ """ + metainfo = data_sample.metainfo + bboxes = data_sample.pred_instances.bboxes + masks = data_sample.pred_instances.masks + labels = data_sample.pred_instances.labels + scores = data_sample.pred_instances.scores + + frame_id = metainfo.get('frame_id', -1) + # create pred_track_instances + pred_track_instances = InstanceData() + + if bboxes.shape[0] == 0: + ids = torch.zeros_like(labels) + pred_track_instances = data_sample.pred_instances.clone() + pred_track_instances.instances_id = ids + return pred_track_instances + + rescaled_bboxes = bboxes.clone() + if rescale: + scale_factor = rescaled_bboxes.new_tensor( + metainfo['scale_factor']).repeat((1, 2)) + rescaled_bboxes = rescaled_bboxes * scale_factor + roi_feats, _ = model.track_head.extract_roi_feats( + feats, [rescaled_bboxes]) + + if self.empty: + num_new_tracks = bboxes.size(0) + ids = torch.arange( + self.num_tracks, + self.num_tracks + num_new_tracks, + dtype=torch.long) + self.num_tracks += num_new_tracks + else: + prev_bboxes = self.get('bboxes') + prev_labels = self.get('labels') + prev_roi_feats = self.get('roi_feats') + + similarity_logits = model.track_head.predict( + roi_feats, prev_roi_feats) + match_scores = self.get_match_score(bboxes, labels, scores, + prev_bboxes, prev_labels, + similarity_logits) + ids, _ = self.assign_ids(match_scores) + + valid_inds = ids > -1 + ids = ids[valid_inds] + bboxes = bboxes[valid_inds] + labels = labels[valid_inds] + scores = scores[valid_inds] + masks = masks[valid_inds] + roi_feats = roi_feats[valid_inds] + + self.update( + ids=ids, + bboxes=bboxes, + labels=labels, + scores=scores, + masks=masks, + roi_feats=roi_feats, + frame_ids=frame_id) + # update pred_track_instances + pred_track_instances.bboxes = bboxes + pred_track_instances.masks = masks + pred_track_instances.labels = labels + pred_track_instances.scores = scores + pred_track_instances.instances_id = ids + + return pred_track_instances diff --git a/mmdet/models/trackers/ocsort_tracker.py b/mmdet/models/trackers/ocsort_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..4e09990c603aee8ced3bf3a65ceb530142e6e873 --- /dev/null +++ b/mmdet/models/trackers/ocsort_tracker.py @@ -0,0 +1,531 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +try: + import lap +except ImportError: + lap = None +import numpy as np +import torch +from addict import Dict +from mmengine.structures import InstanceData + +from mmdet.registry import MODELS +from mmdet.structures import DetDataSample +from mmdet.structures.bbox import (bbox_cxcyah_to_xyxy, bbox_overlaps, + bbox_xyxy_to_cxcyah) +from .sort_tracker import SORTTracker + + +@MODELS.register_module() +class OCSORTTracker(SORTTracker): + """Tracker for OC-SORT. + + Args: + motion (dict): Configuration of motion. Defaults to None. + obj_score_thrs (float): Detection score threshold for matching objects. + Defaults to 0.3. + init_track_thr (float): Detection score threshold for initializing a + new tracklet. Defaults to 0.7. + weight_iou_with_det_scores (bool): Whether using detection scores to + weight IOU which is used for matching. Defaults to True. + match_iou_thr (float): IOU distance threshold for matching between two + frames. Defaults to 0.3. + num_tentatives (int, optional): Number of continuous frames to confirm + a track. Defaults to 3. + vel_consist_weight (float): Weight of the velocity consistency term in + association (OCM term in the paper). 
+        vel_delta_t (int): The time-step difference used when calculating
+            the velocity direction of tracklets.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 motion: Optional[dict] = None,
+                 obj_score_thr: float = 0.3,
+                 init_track_thr: float = 0.7,
+                 weight_iou_with_det_scores: bool = True,
+                 match_iou_thr: float = 0.3,
+                 num_tentatives: int = 3,
+                 vel_consist_weight: float = 0.2,
+                 vel_delta_t: int = 3,
+                 **kwargs):
+        if lap is None:
+            raise RuntimeError('lap is not installed,\
+                 please install it by: pip install lap')
+        super().__init__(motion=motion, **kwargs)
+        self.obj_score_thr = obj_score_thr
+        self.init_track_thr = init_track_thr
+
+        self.weight_iou_with_det_scores = weight_iou_with_det_scores
+        self.match_iou_thr = match_iou_thr
+        self.vel_consist_weight = vel_consist_weight
+        self.vel_delta_t = vel_delta_t
+
+        self.num_tentatives = num_tentatives
+
+    @property
+    def unconfirmed_ids(self):
+        """Unconfirmed ids in the tracker."""
+        ids = [id for id, track in self.tracks.items() if track.tentative]
+        return ids
+
+    def init_track(self, id: int, obj: Tuple[torch.Tensor]):
+        """Initialize a track."""
+        super().init_track(id, obj)
+        if self.tracks[id].frame_ids[-1] == 0:
+            self.tracks[id].tentative = False
+        else:
+            self.tracks[id].tentative = True
+        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size = (1, 4)
+        assert bbox.ndim == 2 and bbox.shape[0] == 1
+        bbox = bbox.squeeze(0).cpu().numpy()
+        self.tracks[id].mean, self.tracks[id].covariance = self.kf.initiate(
+            bbox)
+        # track.obs maintains the history of detections associated with
+        # this track
+        self.tracks[id].obs = []
+        bbox_id = self.memo_items.index('bboxes')
+        self.tracks[id].obs.append(obj[bbox_id])
+        # a placeholder to save mean/covariance before the track gets lost
+        # parameters to save: mean, covariance, measurement
+        self.tracks[id].tracked = True
+        self.tracks[id].saved_attr = Dict()
+        self.tracks[id].velocity = torch.tensor(
+            (-1, -1)).to(obj[bbox_id].device)  # placeholder
+
+    def update_track(self, id: int, obj: Tuple[torch.Tensor]):
+        """Update a track."""
+        super().update_track(id, obj)
+        if self.tracks[id].tentative:
+            if len(self.tracks[id]['bboxes']) >= self.num_tentatives:
+                self.tracks[id].tentative = False
+        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size = (1, 4)
+        assert bbox.ndim == 2 and bbox.shape[0] == 1
+        bbox = bbox.squeeze(0).cpu().numpy()
+        self.tracks[id].mean, self.tracks[id].covariance = self.kf.update(
+            self.tracks[id].mean, self.tracks[id].covariance, bbox)
+        self.tracks[id].tracked = True
+        bbox_id = self.memo_items.index('bboxes')
+        self.tracks[id].obs.append(obj[bbox_id])
+
+        bbox1 = self.k_step_observation(self.tracks[id])
+        bbox2 = obj[bbox_id]
+        self.tracks[id].velocity = self.vel_direction(bbox1, bbox2).to(
+            obj[bbox_id].device)
+
+    def vel_direction(self, bbox1: torch.Tensor, bbox2: torch.Tensor):
+        """Estimate the direction vector between two boxes."""
+        if bbox1.sum() < 0 or bbox2.sum() < 0:
+            return torch.tensor((-1, -1))
+        cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
+        cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
+        speed = torch.tensor([cy2 - cy1, cx2 - cx1])
+        norm = torch.sqrt((speed[0])**2 + (speed[1])**2) + 1e-6
+        return speed / norm
+
+    def vel_direction_batch(self, bboxes1: torch.Tensor,
+                            bboxes2: torch.Tensor):
+        """Estimate the direction vector given two batches of boxes."""
+        cx1, cy1 = (bboxes1[:, 0] + bboxes1[:, 2]) / 2.0, (bboxes1[:, 1] +
+                                                           bboxes1[:, 3]) / 2.0
+        cx2, cy2 = (bboxes2[:, 0] + bboxes2[:, 2]) / 2.0, (bboxes2[:, 1] +
+                                                           bboxes2[:, 3]) / 2.0
+        speed_diff_y = cy2[None, :] - cy1[:, None]
+        speed_diff_x = cx2[None, :] - cx1[:, None]
+        speed = torch.cat((speed_diff_y[..., None], speed_diff_x[..., None]),
+                          dim=-1)
+        norm = torch.sqrt((speed[:, :, 0])**2 + (speed[:, :, 1])**2) + 1e-6
+        speed[:, :, 0] /= norm
+        speed[:, :, 1] /= norm
+        return speed
+
+    def k_step_observation(self, track: Dict):
+        """Return the observation from k steps before."""
+        obs_seqs = track.obs
+        num_obs = len(obs_seqs)
+        if num_obs == 0:
+            return torch.tensor((-1, -1, -1, -1)).to(track.obs[0].device)
+        elif num_obs > self.vel_delta_t:
+            if obs_seqs[num_obs - 1 - self.vel_delta_t] is not None:
+                return obs_seqs[num_obs - 1 - self.vel_delta_t]
+            else:
+                return self.last_obs(track)
+        else:
+            return self.last_obs(track)
+
+    def ocm_assign_ids(self,
+                       ids: List[int],
+                       det_bboxes: torch.Tensor,
+                       det_labels: torch.Tensor,
+                       det_scores: torch.Tensor,
+                       weight_iou_with_det_scores: Optional[bool] = False,
+                       match_iou_thr: Optional[float] = 0.5):
+        """Apply Observation-Centric Momentum (OCM) to assign ids.
+
+        OCM adds movement direction consistency to the association cost
+        matrix, so OC-SORT uses velocity consistency in addition to IoU for
+        association. This term requires no additional assumptions beyond
+        the linear motion assumption already made by the canonical Kalman
+        filter in SORT.
+
+        Args:
+            ids (list[int]): Tracking ids.
+            det_bboxes (Tensor): of shape (N, 4)
+            det_labels (Tensor): of shape (N,)
+            det_scores (Tensor): of shape (N,)
+            weight_iou_with_det_scores (bool, optional): Whether to use
+                detection scores to weight the IoUs used for matching.
+                Defaults to False.
+            match_iou_thr (float, optional): Matching threshold.
+                Defaults to 0.5.
+
+        Returns:
+            tuple(np.ndarray, np.ndarray): The matched track and detection
+            indices; -1 means unmatched.
+        """
+        # get track_bboxes
+        track_bboxes = np.zeros((0, 4))
+        for id in ids:
+            track_bboxes = np.concatenate(
+                (track_bboxes, self.tracks[id].mean[:4][None]), axis=0)
+        track_bboxes = torch.from_numpy(track_bboxes).to(det_bboxes)
+        track_bboxes = bbox_cxcyah_to_xyxy(track_bboxes)
+
+        # compute distance
+        ious = bbox_overlaps(track_bboxes, det_bboxes)
+        if weight_iou_with_det_scores:
+            ious *= det_scores
+
+        # support multi-class association
+        track_labels = torch.tensor([
+            self.tracks[id]['labels'][-1] for id in ids
+        ]).to(det_bboxes.device)
+        cate_match = det_labels[None, :] == track_labels[:, None]
+        # avoid matching detections and tracks of different categories
+        cate_cost = (1 - cate_match.int()) * 1e6
+
+        dists = (1 - ious + cate_cost).cpu().numpy()
+
+        if len(ids) > 0 and len(det_bboxes) > 0:
+            track_velocities = torch.stack(
+                [self.tracks[id].velocity for id in ids]).to(det_bboxes.device)
+            k_step_observations = torch.stack([
+                self.k_step_observation(self.tracks[id]) for id in ids
+            ]).to(det_bboxes.device)
+            # valid1: if the track has previous observations to estimate speed
+            # valid2: if the associated observation k steps ago is a detection
+            valid1 = track_velocities.sum(dim=1) != -2
+            valid2 = k_step_observations.sum(dim=1) != -4
+            valid = valid1 & valid2
+
+            vel_to_match = self.vel_direction_batch(k_step_observations,
+                                                    det_bboxes)
+            track_velocities = track_velocities[:, None, :].repeat(
+                1, det_bboxes.shape[0], 1)
+
+            angle_cos = (vel_to_match * track_velocities).sum(dim=-1)
+            angle_cos = torch.clamp(angle_cos, min=-1, max=1)
+            angle = torch.acos(angle_cos)  # [0, pi]
+            norm_angle = (angle - np.pi / 2.) / np.pi  # [-0.5, 0.5]
+            valid_matrix = valid[:, None].int().repeat(1, det_bboxes.shape[0])
+            # zero out non-valid entries
+            valid_norm_angle = norm_angle * valid_matrix
+
+            dists += valid_norm_angle.cpu().numpy() * self.vel_consist_weight
+
+        # bipartite match
+        if dists.size > 0:
+            cost, row, col = lap.lapjv(
+                dists, extend_cost=True, cost_limit=1 - match_iou_thr)
+        else:
+            row = np.zeros(len(ids)).astype(np.int32) - 1
+            col = np.zeros(len(det_bboxes)).astype(np.int32) - 1
+        return row, col
+
+    def last_obs(self, track: Dict):
+        """Extract the last associated observation."""
+        for bbox in track.obs[::-1]:
+            if bbox is not None:
+                return bbox
+
+    def ocr_assign_ids(self,
+                       track_obs: torch.Tensor,
+                       last_track_labels: torch.Tensor,
+                       det_bboxes: torch.Tensor,
+                       det_labels: torch.Tensor,
+                       det_scores: torch.Tensor,
+                       weight_iou_with_det_scores: Optional[bool] = False,
+                       match_iou_thr: Optional[float] = 0.5):
+        """Association for Observation-Centric Recovery (OCR).
+
+        Since OCR tries to recover tracks that have been lost, whose
+        estimated velocities are out of date, it uses an IoU-only matching
+        strategy.
+
+        Args:
+            track_obs (Tensor): the last associated detections
+                (observations) of the currently unmatched tracks
+            last_track_labels (Tensor): the labels of those last
+                associated detections
+            det_bboxes (Tensor): of shape (N, 4), unmatched detections
+            det_labels (Tensor): of shape (N,)
+            det_scores (Tensor): of shape (N,)
+            weight_iou_with_det_scores (bool, optional): Whether to use
+                detection scores to weight the IoUs used for matching.
+                Defaults to False.
+            match_iou_thr (float, optional): Matching threshold.
+                Defaults to 0.5.
+
+        Returns:
+            tuple(np.ndarray, np.ndarray): The matched track and detection
+            indices; -1 means unmatched.
+        """
+        # compute distance
+        ious = bbox_overlaps(track_obs, det_bboxes)
+        if weight_iou_with_det_scores:
+            ious *= det_scores
+
+        # support multi-class association
+        cate_match = det_labels[None, :] == last_track_labels[:, None]
+        # avoid matching detections and tracks of different categories
+        cate_cost = (1 - cate_match.int()) * 1e6
+
+        dists = (1 - ious + cate_cost).cpu().numpy()
+
+        # bipartite match
+        if dists.size > 0:
+            cost, row, col = lap.lapjv(
+                dists, extend_cost=True, cost_limit=1 - match_iou_thr)
+        else:
+            row = np.zeros(len(track_obs)).astype(np.int32) - 1
+            col = np.zeros(len(det_bboxes)).astype(np.int32) - 1
+        return row, col
+
+    def online_smooth(self, track: Dict, obj: torch.Tensor):
+        """Once a track is recovered from being lost, online-smooth its
+        parameters to fix the error accumulated while it was lost.
+
+        NOTE: different virtual trajectory generation strategies can be
+        used here; we adopt naive linear interpolation by default.
+        """
+        last_match_bbox = self.last_obs(track)
+        new_match_bbox = obj
+        unmatch_len = 0
+        for bbox in track.obs[::-1]:
+            if bbox is None:
+                unmatch_len += 1
+            else:
+                break
+        bbox_shift_per_step = (new_match_bbox - last_match_bbox) / (
+            unmatch_len + 1)
+        track.mean = track.saved_attr.mean
+        track.covariance = track.saved_attr.covariance
+        for i in range(unmatch_len):
+            virtual_bbox = last_match_bbox + (i + 1) * bbox_shift_per_step
+            virtual_bbox = bbox_xyxy_to_cxcyah(virtual_bbox[None, :])
+            virtual_bbox = virtual_bbox.squeeze(0).cpu().numpy()
+            track.mean, track.covariance = self.kf.update(
+                track.mean, track.covariance, virtual_bbox)
+
+    def track(self, data_sample: DetDataSample, **kwargs) -> InstanceData:
+        """Tracking forward function.
+
+        NOTE: this implementation is slightly different from the original
+        OC-SORT implementation (https://github.com/noahcao/OC_SORT) in
+        that we do association between detections and tentative/
+        non-tentative tracks independently, while the original
+        implementation combines them.
+
+        Args:
+            data_sample (:obj:`DetDataSample`): The data sample.
+                It includes information such as `pred_instances`.
+
+        Returns:
+            :obj:`InstanceData`: Tracking results of the input images.
+            Each InstanceData usually contains ``bboxes``, ``labels``,
+            ``scores`` and ``instances_id``.
+        """
+        metainfo = data_sample.metainfo
+        bboxes = data_sample.pred_instances.bboxes
+        labels = data_sample.pred_instances.labels
+        scores = data_sample.pred_instances.scores
+        frame_id = metainfo.get('frame_id', -1)
+        if frame_id == 0:
+            self.reset()
+        if not hasattr(self, 'kf'):
+            self.kf = self.motion
+
+        if self.empty or bboxes.size(0) == 0:
+            valid_inds = scores > self.init_track_thr
+            scores = scores[valid_inds]
+            bboxes = bboxes[valid_inds]
+            labels = labels[valid_inds]
+            num_new_tracks = bboxes.size(0)
+            ids = torch.arange(self.num_tracks,
+                               self.num_tracks + num_new_tracks).to(labels)
+            self.num_tracks += num_new_tracks
+        else:
+            # 0. init
+            ids = torch.full((bboxes.size(0), ),
+                             -1,
+                             dtype=labels.dtype,
+                             device=labels.device)
+
+            # get the detection bboxes for the first association
+            det_inds = scores > self.obj_score_thr
+            det_bboxes = bboxes[det_inds]
+            det_labels = labels[det_inds]
+            det_scores = scores[det_inds]
+            det_ids = ids[det_inds]
+
+            # 1. predict by Kalman Filter
+            for id in self.confirmed_ids:
+                # track is lost in previous frame
+                if self.tracks[id].frame_ids[-1] != frame_id - 1:
+                    self.tracks[id].mean[7] = 0
+                    if self.tracks[id].tracked:
+                        self.tracks[id].saved_attr.mean = self.tracks[id].mean
+                        self.tracks[id].saved_attr.covariance = self.tracks[
+                            id].covariance
+                (self.tracks[id].mean,
+                 self.tracks[id].covariance) = self.kf.predict(
+                     self.tracks[id].mean, self.tracks[id].covariance)
+
+            # 2. match detections and tracks' predicted locations
+            match_track_inds, raw_match_det_inds = self.ocm_assign_ids(
+                self.confirmed_ids, det_bboxes, det_labels, det_scores,
+                self.weight_iou_with_det_scores, self.match_iou_thr)
+            # '-1' means a detection box is not matched with any tracklet
+            # in the previous frame
+            valid = raw_match_det_inds > -1
+            det_ids[valid] = torch.tensor(
+                self.confirmed_ids)[raw_match_det_inds[valid]].to(labels)
+
+            match_det_bboxes = det_bboxes[valid]
+            match_det_labels = det_labels[valid]
+            match_det_scores = det_scores[valid]
+            match_det_ids = det_ids[valid]
+            assert (match_det_ids > -1).all()
+
+            # unmatched tracks and detections
+            unmatch_det_bboxes = det_bboxes[~valid]
+            unmatch_det_labels = det_labels[~valid]
+            unmatch_det_scores = det_scores[~valid]
+            unmatch_det_ids = det_ids[~valid]
+            assert (unmatch_det_ids == -1).all()
+
+            # 3.
use unmatched detection bboxes from the first match to match + # the unconfirmed tracks + (tentative_match_track_inds, + tentative_match_det_inds) = self.ocm_assign_ids( + self.unconfirmed_ids, unmatch_det_bboxes, unmatch_det_labels, + unmatch_det_scores, self.weight_iou_with_det_scores, + self.match_iou_thr) + valid = tentative_match_det_inds > -1 + unmatch_det_ids[valid] = torch.tensor(self.unconfirmed_ids)[ + tentative_match_det_inds[valid]].to(labels) + + match_det_bboxes = torch.cat( + (match_det_bboxes, unmatch_det_bboxes[valid]), dim=0) + match_det_labels = torch.cat( + (match_det_labels, unmatch_det_labels[valid]), dim=0) + match_det_scores = torch.cat( + (match_det_scores, unmatch_det_scores[valid]), dim=0) + match_det_ids = torch.cat((match_det_ids, unmatch_det_ids[valid]), + dim=0) + assert (match_det_ids > -1).all() + + unmatch_det_bboxes = unmatch_det_bboxes[~valid] + unmatch_det_labels = unmatch_det_labels[~valid] + unmatch_det_scores = unmatch_det_scores[~valid] + unmatch_det_ids = unmatch_det_ids[~valid] + assert (unmatch_det_ids == -1).all() + + all_track_ids = [id for id, _ in self.tracks.items()] + unmatched_track_inds = torch.tensor( + [ind for ind in all_track_ids if ind not in match_det_ids]) + + if len(unmatched_track_inds) > 0: + # 4. still some tracks not associated yet, perform OCR + last_observations = [] + for id in unmatched_track_inds: + last_box = self.last_obs(self.tracks[id.item()]) + last_observations.append(last_box) + last_observations = torch.stack(last_observations) + last_track_labels = torch.tensor([ + self.tracks[id.item()]['labels'][-1] + for id in unmatched_track_inds + ]).to(det_bboxes.device) + + remain_det_ids = torch.full((unmatch_det_bboxes.size(0), ), + -1, + dtype=labels.dtype, + device=labels.device) + + _, ocr_match_det_inds = self.ocr_assign_ids( + last_observations, last_track_labels, unmatch_det_bboxes, + unmatch_det_labels, unmatch_det_scores, + self.weight_iou_with_det_scores, self.match_iou_thr) + + valid = ocr_match_det_inds > -1 + remain_det_ids[valid] = unmatched_track_inds.clone()[ + ocr_match_det_inds[valid]].to(labels) + + ocr_match_det_bboxes = unmatch_det_bboxes[valid] + ocr_match_det_labels = unmatch_det_labels[valid] + ocr_match_det_scores = unmatch_det_scores[valid] + ocr_match_det_ids = remain_det_ids[valid] + assert (ocr_match_det_ids > -1).all() + + ocr_unmatch_det_bboxes = unmatch_det_bboxes[~valid] + ocr_unmatch_det_labels = unmatch_det_labels[~valid] + ocr_unmatch_det_scores = unmatch_det_scores[~valid] + ocr_unmatch_det_ids = remain_det_ids[~valid] + assert (ocr_unmatch_det_ids == -1).all() + + unmatch_det_bboxes = ocr_unmatch_det_bboxes + unmatch_det_labels = ocr_unmatch_det_labels + unmatch_det_scores = ocr_unmatch_det_scores + unmatch_det_ids = ocr_unmatch_det_ids + match_det_bboxes = torch.cat( + (match_det_bboxes, ocr_match_det_bboxes), dim=0) + match_det_labels = torch.cat( + (match_det_labels, ocr_match_det_labels), dim=0) + match_det_scores = torch.cat( + (match_det_scores, ocr_match_det_scores), dim=0) + match_det_ids = torch.cat((match_det_ids, ocr_match_det_ids), + dim=0) + + # 5. 
summarize the track results + for i in range(len(match_det_ids)): + det_bbox = match_det_bboxes[i] + track_id = match_det_ids[i].item() + if not self.tracks[track_id].tracked: + # the track is lost before this step + self.online_smooth(self.tracks[track_id], det_bbox) + + for track_id in all_track_ids: + if track_id not in match_det_ids: + self.tracks[track_id].tracked = False + self.tracks[track_id].obs.append(None) + + bboxes = torch.cat((match_det_bboxes, unmatch_det_bboxes), dim=0) + labels = torch.cat((match_det_labels, unmatch_det_labels), dim=0) + scores = torch.cat((match_det_scores, unmatch_det_scores), dim=0) + ids = torch.cat((match_det_ids, unmatch_det_ids), dim=0) + # 6. assign new ids + new_track_inds = ids == -1 + + ids[new_track_inds] = torch.arange( + self.num_tracks, + self.num_tracks + new_track_inds.sum()).to(labels) + self.num_tracks += new_track_inds.sum() + + self.update( + ids=ids, + bboxes=bboxes, + labels=labels, + scores=scores, + frame_ids=frame_id) + + # update pred_track_instances + pred_track_instances = InstanceData() + pred_track_instances.bboxes = bboxes + pred_track_instances.labels = labels + pred_track_instances.scores = scores + pred_track_instances.instances_id = ids + return pred_track_instances diff --git a/mmdet/models/trackers/quasi_dense_tracker.py b/mmdet/models/trackers/quasi_dense_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..c93c3c4c3bd5c8939e77195f30a7eb2f0314e225 --- /dev/null +++ b/mmdet/models/trackers/quasi_dense_tracker.py @@ -0,0 +1,316 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import TrackDataSample +from mmdet.structures.bbox import bbox_overlaps +from .base_tracker import BaseTracker + + +@MODELS.register_module() +class QuasiDenseTracker(BaseTracker): + """Tracker for Quasi-Dense Tracking. + + Args: + init_score_thr (float): The cls_score threshold to + initialize a new tracklet. Defaults to 0.8. + obj_score_thr (float): The cls_score threshold to + update a tracked tracklet. Defaults to 0.5. + match_score_thr (float): The match threshold. Defaults to 0.5. + memo_tracklet_frames (int): The most frames in a tracklet memory. + Defaults to 10. + memo_backdrop_frames (int): The most frames in the backdrops. + Defaults to 1. + memo_momentum (float): The momentum value for embeds updating. + Defaults to 0.8. + nms_conf_thr (float): The nms threshold for confidence. + Defaults to 0.5. + nms_backdrop_iou_thr (float): The nms threshold for backdrop IoU. + Defaults to 0.3. + nms_class_iou_thr (float): The nms threshold for class IoU. + Defaults to 0.7. + with_cats (bool): Whether to track with the same category. + Defaults to True. + match_metric (str): The match metric. Defaults to 'bisoftmax'. 
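+
+    Examples:
+        An illustrative construction (the values below are placeholders,
+        not tuned settings):
+
+        >>> tracker = QuasiDenseTracker(
+        ...     init_score_thr=0.9, memo_tracklet_frames=30)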
+ """ + + def __init__(self, + init_score_thr: float = 0.8, + obj_score_thr: float = 0.5, + match_score_thr: float = 0.5, + memo_tracklet_frames: int = 10, + memo_backdrop_frames: int = 1, + memo_momentum: float = 0.8, + nms_conf_thr: float = 0.5, + nms_backdrop_iou_thr: float = 0.3, + nms_class_iou_thr: float = 0.7, + with_cats: bool = True, + match_metric: str = 'bisoftmax', + **kwargs): + super().__init__(**kwargs) + assert 0 <= memo_momentum <= 1.0 + assert memo_tracklet_frames >= 0 + assert memo_backdrop_frames >= 0 + self.init_score_thr = init_score_thr + self.obj_score_thr = obj_score_thr + self.match_score_thr = match_score_thr + self.memo_tracklet_frames = memo_tracklet_frames + self.memo_backdrop_frames = memo_backdrop_frames + self.memo_momentum = memo_momentum + self.nms_conf_thr = nms_conf_thr + self.nms_backdrop_iou_thr = nms_backdrop_iou_thr + self.nms_class_iou_thr = nms_class_iou_thr + self.with_cats = with_cats + assert match_metric in ['bisoftmax', 'softmax', 'cosine'] + self.match_metric = match_metric + + self.num_tracks = 0 + self.tracks = dict() + self.backdrops = [] + + def reset(self): + """Reset the buffer of the tracker.""" + self.num_tracks = 0 + self.tracks = dict() + self.backdrops = [] + + def update(self, ids: Tensor, bboxes: Tensor, embeds: Tensor, + labels: Tensor, scores: Tensor, frame_id: int) -> None: + """Tracking forward function. + + Args: + ids (Tensor): of shape(N, ). + bboxes (Tensor): of shape (N, 5). + embeds (Tensor): of shape (N, 256). + labels (Tensor): of shape (N, ). + scores (Tensor): of shape (N, ). + frame_id (int): The id of current frame, 0-index. + """ + tracklet_inds = ids > -1 + + for id, bbox, embed, label, score in zip(ids[tracklet_inds], + bboxes[tracklet_inds], + embeds[tracklet_inds], + labels[tracklet_inds], + scores[tracklet_inds]): + id = int(id) + # update the tracked ones and initialize new tracks + if id in self.tracks.keys(): + velocity = (bbox - self.tracks[id]['bbox']) / ( + frame_id - self.tracks[id]['last_frame']) + self.tracks[id]['bbox'] = bbox + self.tracks[id]['embed'] = ( + 1 - self.memo_momentum + ) * self.tracks[id]['embed'] + self.memo_momentum * embed + self.tracks[id]['last_frame'] = frame_id + self.tracks[id]['label'] = label + self.tracks[id]['score'] = score + self.tracks[id]['velocity'] = ( + self.tracks[id]['velocity'] * self.tracks[id]['acc_frame'] + + velocity) / ( + self.tracks[id]['acc_frame'] + 1) + self.tracks[id]['acc_frame'] += 1 + else: + self.tracks[id] = dict( + bbox=bbox, + embed=embed, + label=label, + score=score, + last_frame=frame_id, + velocity=torch.zeros_like(bbox), + acc_frame=0) + # backdrop update according to IoU + backdrop_inds = torch.nonzero(ids == -1, as_tuple=False).squeeze(1) + ious = bbox_overlaps(bboxes[backdrop_inds], bboxes) + for i, ind in enumerate(backdrop_inds): + if (ious[i, :ind] > self.nms_backdrop_iou_thr).any(): + backdrop_inds[i] = -1 + backdrop_inds = backdrop_inds[backdrop_inds > -1] + # old backdrops would be removed at first + self.backdrops.insert( + 0, + dict( + bboxes=bboxes[backdrop_inds], + embeds=embeds[backdrop_inds], + labels=labels[backdrop_inds])) + + # pop memo + invalid_ids = [] + for k, v in self.tracks.items(): + if frame_id - v['last_frame'] >= self.memo_tracklet_frames: + invalid_ids.append(k) + for invalid_id in invalid_ids: + self.tracks.pop(invalid_id) + + if len(self.backdrops) > self.memo_backdrop_frames: + self.backdrops.pop() + + @property + def memo(self) -> Tuple[Tensor, ...]: + """Get tracks memory.""" + memo_embeds = [] + 
memo_ids = [] + memo_bboxes = [] + memo_labels = [] + # velocity of tracks + memo_vs = [] + # get tracks + for k, v in self.tracks.items(): + memo_bboxes.append(v['bbox'][None, :]) + memo_embeds.append(v['embed'][None, :]) + memo_ids.append(k) + memo_labels.append(v['label'].view(1, 1)) + memo_vs.append(v['velocity'][None, :]) + memo_ids = torch.tensor(memo_ids, dtype=torch.long).view(1, -1) + # get backdrops + for backdrop in self.backdrops: + backdrop_ids = torch.full((1, backdrop['embeds'].size(0)), + -1, + dtype=torch.long) + backdrop_vs = torch.zeros_like(backdrop['bboxes']) + memo_bboxes.append(backdrop['bboxes']) + memo_embeds.append(backdrop['embeds']) + memo_ids = torch.cat([memo_ids, backdrop_ids], dim=1) + memo_labels.append(backdrop['labels'][:, None]) + memo_vs.append(backdrop_vs) + + memo_bboxes = torch.cat(memo_bboxes, dim=0) + memo_embeds = torch.cat(memo_embeds, dim=0) + memo_labels = torch.cat(memo_labels, dim=0).squeeze(1) + memo_vs = torch.cat(memo_vs, dim=0) + return memo_bboxes, memo_labels, memo_embeds, memo_ids.squeeze( + 0), memo_vs + + def track(self, + model: torch.nn.Module, + img: torch.Tensor, + feats: List[torch.Tensor], + data_sample: TrackDataSample, + rescale=True, + **kwargs) -> InstanceData: + """Tracking forward function. + + Args: + model (nn.Module): MOT model. + img (Tensor): of shape (T, C, H, W) encoding input image. + Typically these should be mean centered and std scaled. + The T denotes the number of key images and usually is 1 in + QDTrack method. + feats (list[Tensor]): Multi level feature maps of `img`. + data_sample (:obj:`TrackDataSample`): The data sample. + It includes information such as `pred_instances`. + rescale (bool, optional): If True, the bounding boxes should be + rescaled to fit the original scale of the image. Defaults to + True. + + Returns: + :obj:`InstanceData`: Tracking results of the input images. + Each InstanceData usually contains ``bboxes``, ``labels``, + ``scores`` and ``instances_id``. 
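+
+        Example:
+            With the default ``'bisoftmax'`` metric, the match score is the
+            mean of a detection-to-track softmax and a track-to-detection
+            softmax over the embedding dot products:
+
+            >>> import torch
+            >>> sim = torch.randn(3, 5)  # 3 detections x 5 memo entries
+            >>> match = (sim.softmax(dim=1) + sim.softmax(dim=0)) / 2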
+ """ + metainfo = data_sample.metainfo + bboxes = data_sample.pred_instances.bboxes + labels = data_sample.pred_instances.labels + scores = data_sample.pred_instances.scores + + frame_id = metainfo.get('frame_id', -1) + # create pred_track_instances + pred_track_instances = InstanceData() + + # return zero bboxes if there is no track targets + if bboxes.shape[0] == 0: + ids = torch.zeros_like(labels) + pred_track_instances = data_sample.pred_instances.clone() + pred_track_instances.instances_id = ids + return pred_track_instances + + # get track feats + rescaled_bboxes = bboxes.clone() + if rescale: + scale_factor = rescaled_bboxes.new_tensor( + metainfo['scale_factor']).repeat((1, 2)) + rescaled_bboxes = rescaled_bboxes * scale_factor + track_feats = model.track_head.predict(feats, [rescaled_bboxes]) + # sort according to the object_score + _, inds = scores.sort(descending=True) + bboxes = bboxes[inds] + scores = scores[inds] + labels = labels[inds] + embeds = track_feats[inds, :] + + # duplicate removal for potential backdrops and cross classes + valids = bboxes.new_ones((bboxes.size(0))) + ious = bbox_overlaps(bboxes, bboxes) + for i in range(1, bboxes.size(0)): + thr = self.nms_backdrop_iou_thr if scores[ + i] < self.obj_score_thr else self.nms_class_iou_thr + if (ious[i, :i] > thr).any(): + valids[i] = 0 + valids = valids == 1 + bboxes = bboxes[valids] + scores = scores[valids] + labels = labels[valids] + embeds = embeds[valids, :] + + # init ids container + ids = torch.full((bboxes.size(0), ), -1, dtype=torch.long) + + # match if buffer is not empty + if bboxes.size(0) > 0 and not self.empty: + (memo_bboxes, memo_labels, memo_embeds, memo_ids, + memo_vs) = self.memo + + if self.match_metric == 'bisoftmax': + feats = torch.mm(embeds, memo_embeds.t()) + d2t_scores = feats.softmax(dim=1) + t2d_scores = feats.softmax(dim=0) + match_scores = (d2t_scores + t2d_scores) / 2 + elif self.match_metric == 'softmax': + feats = torch.mm(embeds, memo_embeds.t()) + match_scores = feats.softmax(dim=1) + elif self.match_metric == 'cosine': + match_scores = torch.mm( + F.normalize(embeds, p=2, dim=1), + F.normalize(memo_embeds, p=2, dim=1).t()) + else: + raise NotImplementedError + # track with the same category + if self.with_cats: + cat_same = labels.view(-1, 1) == memo_labels.view(1, -1) + match_scores *= cat_same.float().to(match_scores.device) + # track according to match_scores + for i in range(bboxes.size(0)): + conf, memo_ind = torch.max(match_scores[i, :], dim=0) + id = memo_ids[memo_ind] + if conf > self.match_score_thr: + if id > -1: + # keep bboxes with high object score + # and remove background bboxes + if scores[i] > self.obj_score_thr: + ids[i] = id + match_scores[:i, memo_ind] = 0 + match_scores[i + 1:, memo_ind] = 0 + else: + if conf > self.nms_conf_thr: + ids[i] = -2 + # initialize new tracks + new_inds = (ids == -1) & (scores > self.init_score_thr).cpu() + num_news = new_inds.sum() + ids[new_inds] = torch.arange( + self.num_tracks, self.num_tracks + num_news, dtype=torch.long) + self.num_tracks += num_news + + self.update(ids, bboxes, embeds, labels, scores, frame_id) + tracklet_inds = ids > -1 + # update pred_track_instances + pred_track_instances.bboxes = bboxes[tracklet_inds] + pred_track_instances.labels = labels[tracklet_inds] + pred_track_instances.scores = scores[tracklet_inds] + pred_track_instances.instances_id = ids[tracklet_inds] + + return pred_track_instances diff --git a/mmdet/models/trackers/sort_tracker.py b/mmdet/models/trackers/sort_tracker.py new file mode 100644 
index 0000000000000000000000000000000000000000..c4a4fed92702f7d1ea66917a7157fcf5d0773a30 --- /dev/null +++ b/mmdet/models/trackers/sort_tracker.py @@ -0,0 +1,268 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +import numpy as np +import torch +from mmengine.structures import InstanceData + +try: + import motmetrics + from motmetrics.lap import linear_sum_assignment +except ImportError: + motmetrics = None +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures import DetDataSample +from mmdet.structures.bbox import bbox_overlaps, bbox_xyxy_to_cxcyah +from mmdet.utils import OptConfigType +from ..utils import imrenormalize +from .base_tracker import BaseTracker + + +@MODELS.register_module() +class SORTTracker(BaseTracker): + """Tracker for SORT/DeepSORT. + + Args: + obj_score_thr (float, optional): Threshold to filter the objects. + Defaults to 0.3. + motion (dict): Configuration of motion. Defaults to None. + reid (dict, optional): Configuration for the ReID model. + - num_samples (int, optional): Number of samples to calculate the + feature embeddings of a track. Default to 10. + - image_scale (tuple, optional): Input scale of the ReID model. + Default to (256, 128). + - img_norm_cfg (dict, optional): Configuration to normalize the + input. Default to None. + - match_score_thr (float, optional): Similarity threshold for the + matching process. Default to 2.0. + match_iou_thr (float, optional): Threshold of the IoU matching process. + Defaults to 0.7. + num_tentatives (int, optional): Number of continuous frames to confirm + a track. Defaults to 3. + """ + + def __init__(self, + motion: Optional[dict] = None, + obj_score_thr: float = 0.3, + reid: dict = dict( + num_samples=10, + img_scale=(256, 128), + img_norm_cfg=None, + match_score_thr=2.0), + match_iou_thr: float = 0.7, + num_tentatives: int = 3, + **kwargs): + if motmetrics is None: + raise RuntimeError('motmetrics is not installed,\ + please install it by: pip install motmetrics') + super().__init__(**kwargs) + if motion is not None: + self.motion = TASK_UTILS.build(motion) + assert self.motion is not None, 'SORT/Deep SORT need KalmanFilter' + self.obj_score_thr = obj_score_thr + self.reid = reid + self.match_iou_thr = match_iou_thr + self.num_tentatives = num_tentatives + + @property + def confirmed_ids(self) -> List: + """Confirmed ids in the tracker.""" + ids = [id for id, track in self.tracks.items() if not track.tentative] + return ids + + def init_track(self, id: int, obj: Tuple[Tensor]) -> None: + """Initialize a track.""" + super().init_track(id, obj) + self.tracks[id].tentative = True + bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1]) # size = (1, 4) + assert bbox.ndim == 2 and bbox.shape[0] == 1 + bbox = bbox.squeeze(0).cpu().numpy() + self.tracks[id].mean, self.tracks[id].covariance = self.kf.initiate( + bbox) + + def update_track(self, id: int, obj: Tuple[Tensor]) -> None: + """Update a track.""" + super().update_track(id, obj) + if self.tracks[id].tentative: + if len(self.tracks[id]['bboxes']) >= self.num_tentatives: + self.tracks[id].tentative = False + bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1]) # size = (1, 4) + assert bbox.ndim == 2 and bbox.shape[0] == 1 + bbox = bbox.squeeze(0).cpu().numpy() + self.tracks[id].mean, self.tracks[id].covariance = self.kf.update( + self.tracks[id].mean, self.tracks[id].covariance, bbox) + + def pop_invalid_tracks(self, frame_id: int) -> None: + """Pop out invalid tracks.""" + 
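+        # a track is dropped if it has been unseen for `num_frames_retain`
+        # frames, or if it is still tentative and missed the current frame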
+        invalid_ids = []
+        for k, v in self.tracks.items():
+            # case1: disappeared frames >= self.num_frames_retain
+            case1 = frame_id - v['frame_ids'][-1] >= self.num_frames_retain
+            # case2: tentative tracks but not matched in this frame
+            case2 = v.tentative and v['frame_ids'][-1] != frame_id
+            if case1 or case2:
+                invalid_ids.append(k)
+        for invalid_id in invalid_ids:
+            self.tracks.pop(invalid_id)
+
+    def track(self,
+              model: torch.nn.Module,
+              img: Tensor,
+              data_sample: DetDataSample,
+              data_preprocessor: OptConfigType = None,
+              rescale: bool = False,
+              **kwargs) -> InstanceData:
+        """Tracking forward function.
+
+        Args:
+            model (nn.Module): MOT model.
+            img (Tensor): of shape (T, C, H, W) encoding input image.
+                Typically these should be mean centered and std scaled.
+                The T denotes the number of key images and usually is 1 in
+                SORT method.
+            data_sample (:obj:`DetDataSample`): The data sample.
+                It includes information such as `pred_instances`.
+            data_preprocessor (dict or ConfigDict, optional): The pre-process
+                config of :class:`TrackDataPreprocessor`. It usually includes
+                ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
+            rescale (bool, optional): If True, the bounding boxes should be
+                rescaled to fit the original scale of the image. Defaults to
+                False.
+
+        Returns:
+            :obj:`InstanceData`: Tracking results of the input images.
+            Each InstanceData usually contains ``bboxes``, ``labels``,
+            ``scores`` and ``instances_id``.
+        """
+        metainfo = data_sample.metainfo
+        bboxes = data_sample.pred_instances.bboxes
+        labels = data_sample.pred_instances.labels
+        scores = data_sample.pred_instances.scores
+
+        frame_id = metainfo.get('frame_id', -1)
+        if frame_id == 0:
+            self.reset()
+        if not hasattr(self, 'kf'):
+            self.kf = self.motion
+
+        if self.with_reid:
+            if self.reid.get('img_norm_cfg', False):
+                img_norm_cfg = dict(
+                    mean=data_preprocessor['mean'],
+                    std=data_preprocessor['std'],
+                    to_bgr=data_preprocessor['rgb_to_bgr'])
+                reid_img = imrenormalize(img, img_norm_cfg,
+                                         self.reid['img_norm_cfg'])
+            else:
+                reid_img = img.clone()
+
+        valid_inds = scores > self.obj_score_thr
+        bboxes = bboxes[valid_inds]
+        labels = labels[valid_inds]
+        scores = scores[valid_inds]
+
+        if self.empty or bboxes.size(0) == 0:
+            num_new_tracks = bboxes.size(0)
+            ids = torch.arange(
+                self.num_tracks,
+                self.num_tracks + num_new_tracks,
+                dtype=torch.long).to(bboxes.device)
+            self.num_tracks += num_new_tracks
+            if self.with_reid:
+                crops = self.crop_imgs(reid_img, metainfo, bboxes.clone(),
+                                       rescale)
+                if crops.size(0) > 0:
+                    embeds = model.reid(crops, mode='tensor')
+                else:
+                    embeds = crops.new_zeros((0, model.reid.head.out_channels))
+        else:
+            ids = torch.full((bboxes.size(0), ), -1,
+                             dtype=torch.long).to(bboxes.device)
+
+            # motion
+            self.tracks, costs = self.motion.track(self.tracks,
+                                                   bbox_xyxy_to_cxcyah(bboxes))
+
+            active_ids = self.confirmed_ids
+            if self.with_reid:
+                crops = self.crop_imgs(reid_img, metainfo, bboxes.clone(),
+                                       rescale)
+                embeds = model.reid(crops, mode='tensor')
+
+                # reid
+                if len(active_ids) > 0:
+                    track_embeds = self.get(
+                        'embeds',
+                        active_ids,
+                        self.reid.get('num_samples', None),
+                        behavior='mean')
+                    reid_dists = torch.cdist(track_embeds, embeds)
+
+                    # support multi-class association
+                    track_labels = torch.tensor([
+                        self.tracks[id]['labels'][-1] for id in active_ids
+                    ]).to(bboxes.device)
+                    cate_match = labels[None, :] == track_labels[:, None]
+                    cate_cost = (1 - cate_match.int()) * 1e6
+                    reid_dists = (reid_dists + cate_cost).cpu().numpy()
+
+                    valid_inds =
[list(self.ids).index(_) for _ in active_ids] + reid_dists[~np.isfinite(costs[valid_inds, :])] = np.nan + + row, col = linear_sum_assignment(reid_dists) + for r, c in zip(row, col): + dist = reid_dists[r, c] + if not np.isfinite(dist): + continue + if dist <= self.reid['match_score_thr']: + ids[c] = active_ids[r] + + active_ids = [ + id for id in self.ids if id not in ids + and self.tracks[id].frame_ids[-1] == frame_id - 1 + ] + if len(active_ids) > 0: + active_dets = torch.nonzero(ids == -1).squeeze(1) + track_bboxes = self.get('bboxes', active_ids) + ious = bbox_overlaps(track_bboxes, bboxes[active_dets]) + + # support multi-class association + track_labels = torch.tensor([ + self.tracks[id]['labels'][-1] for id in active_ids + ]).to(bboxes.device) + cate_match = labels[None, active_dets] == track_labels[:, None] + cate_cost = (1 - cate_match.int()) * 1e6 + + dists = (1 - ious + cate_cost).cpu().numpy() + + row, col = linear_sum_assignment(dists) + for r, c in zip(row, col): + dist = dists[r, c] + if dist < 1 - self.match_iou_thr: + ids[active_dets[c]] = active_ids[r] + + new_track_inds = ids == -1 + ids[new_track_inds] = torch.arange( + self.num_tracks, + self.num_tracks + new_track_inds.sum(), + dtype=torch.long).to(bboxes.device) + self.num_tracks += new_track_inds.sum() + + self.update( + ids=ids, + bboxes=bboxes, + scores=scores, + labels=labels, + embeds=embeds if self.with_reid else None, + frame_ids=frame_id) + + # update pred_track_instances + pred_track_instances = InstanceData() + pred_track_instances.bboxes = bboxes + pred_track_instances.labels = labels + pred_track_instances.scores = scores + pred_track_instances.instances_id = ids + + return pred_track_instances diff --git a/mmdet/models/trackers/strongsort_tracker.py b/mmdet/models/trackers/strongsort_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..9d7075701bc3205b9ea30f03790cfa1c42a97822 --- /dev/null +++ b/mmdet/models/trackers/strongsort_tracker.py @@ -0,0 +1,273 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +import numpy as np +import torch +from mmengine.structures import InstanceData + +try: + import motmetrics + from motmetrics.lap import linear_sum_assignment +except ImportError: + motmetrics = None +from torch import Tensor + +from mmdet.models.utils import imrenormalize +from mmdet.registry import MODELS +from mmdet.structures import TrackDataSample +from mmdet.structures.bbox import bbox_overlaps, bbox_xyxy_to_cxcyah +from mmdet.utils import OptConfigType +from .sort_tracker import SORTTracker + + +def cosine_distance(x: Tensor, y: Tensor) -> np.ndarray: + """compute the cosine distance. + + Args: + x (Tensor): embeddings with shape (N,C). + y (Tensor): embeddings with shape (M,C). + + Returns: + ndarray: cosine distance with shape (N,M). + """ + x = x.cpu().numpy() + y = y.cpu().numpy() + x = x / np.linalg.norm(x, axis=1, keepdims=True) + y = y / np.linalg.norm(y, axis=1, keepdims=True) + dists = 1. - np.dot(x, y.T) + return dists + + +@MODELS.register_module() +class StrongSORTTracker(SORTTracker): + """Tracker for StrongSORT. + + Args: + obj_score_thr (float, optional): Threshold to filter the objects. + Defaults to 0.6. + motion (dict): Configuration of motion. Defaults to None. + reid (dict, optional): Configuration for the ReID model. + - num_samples (int, optional): Number of samples to calculate the + feature embeddings of a track. Default to None. + - image_scale (tuple, optional): Input scale of the ReID model. 
+ Default to (256, 128). + - img_norm_cfg (dict, optional): Configuration to normalize the + input. Default to None. + - match_score_thr (float, optional): Similarity threshold for the + matching process. Default to 0.3. + - motion_weight (float, optional): the weight of the motion cost. + Defaults to 0.02. + match_iou_thr (float, optional): Threshold of the IoU matching process. + Defaults to 0.7. + num_tentatives (int, optional): Number of continuous frames to confirm + a track. Defaults to 2. + """ + + def __init__(self, + motion: Optional[dict] = None, + obj_score_thr: float = 0.6, + reid: dict = dict( + num_samples=None, + img_scale=(256, 128), + img_norm_cfg=None, + match_score_thr=0.3, + motion_weight=0.02), + match_iou_thr: float = 0.7, + num_tentatives: int = 2, + **kwargs): + if motmetrics is None: + raise RuntimeError('motmetrics is not installed,\ + please install it by: pip install motmetrics') + super().__init__(motion, obj_score_thr, reid, match_iou_thr, + num_tentatives, **kwargs) + + def update_track(self, id: int, obj: Tuple[Tensor]) -> None: + """Update a track.""" + for k, v in zip(self.memo_items, obj): + v = v[None] + if self.momentums is not None and k in self.momentums: + m = self.momentums[k] + self.tracks[id][k] = (1 - m) * self.tracks[id][k] + m * v + else: + self.tracks[id][k].append(v) + + if self.tracks[id].tentative: + if len(self.tracks[id]['bboxes']) >= self.num_tentatives: + self.tracks[id].tentative = False + bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1]) # size = (1, 4) + assert bbox.ndim == 2 and bbox.shape[0] == 1 + bbox = bbox.squeeze(0).cpu().numpy() + score = float(self.tracks[id].scores[-1].cpu()) + self.tracks[id].mean, self.tracks[id].covariance = self.kf.update( + self.tracks[id].mean, self.tracks[id].covariance, bbox, score) + + def track(self, + model: torch.nn.Module, + img: Tensor, + data_sample: TrackDataSample, + data_preprocessor: OptConfigType = None, + rescale: bool = False, + **kwargs) -> InstanceData: + """Tracking forward function. + + Args: + model (nn.Module): MOT model. + img (Tensor): of shape (T, C, H, W) encoding input image. + Typically these should be mean centered and std scaled. + The T denotes the number of key images and usually is 1 in + SORT method. + feats (list[Tensor]): Multi level feature maps of `img`. + data_sample (:obj:`TrackDataSample`): The data sample. + It includes information such as `pred_det_instances`. + data_preprocessor (dict or ConfigDict, optional): The pre-process + config of :class:`TrackDataPreprocessor`. it usually includes, + ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. + rescale (bool, optional): If True, the bounding boxes should be + rescaled to fit the original scale of the image. Defaults to + False. + + Returns: + :obj:`InstanceData`: Tracking results of the input images. + Each InstanceData usually contains ``bboxes``, ``labels``, + ``scores`` and ``instances_id``. 
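+
+        Note:
+            Appearance and motion are fused linearly: with motion weight
+            ``w = reid['motion_weight']``, the matching cost is
+            ``(1 - w) * reid_dists + w * motion_dists`` before the
+            category cost is added.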
+ """ + metainfo = data_sample.metainfo + bboxes = data_sample.pred_instances.bboxes + labels = data_sample.pred_instances.labels + scores = data_sample.pred_instances.scores + + frame_id = metainfo.get('frame_id', -1) + if frame_id == 0: + self.reset() + if not hasattr(self, 'kf'): + self.kf = self.motion + + if self.with_reid: + if self.reid.get('img_norm_cfg', False): + img_norm_cfg = dict( + mean=data_preprocessor.get('mean', [0, 0, 0]), + std=data_preprocessor.get('std', [1, 1, 1]), + to_bgr=data_preprocessor.get('rgb_to_bgr', False)) + reid_img = imrenormalize(img, img_norm_cfg, + self.reid['img_norm_cfg']) + else: + reid_img = img.clone() + + valid_inds = scores > self.obj_score_thr + bboxes = bboxes[valid_inds] + labels = labels[valid_inds] + scores = scores[valid_inds] + + if self.empty or bboxes.size(0) == 0: + num_new_tracks = bboxes.size(0) + ids = torch.arange( + self.num_tracks, + self.num_tracks + num_new_tracks, + dtype=torch.long).to(bboxes.device) + self.num_tracks += num_new_tracks + if self.with_reid: + crops = self.crop_imgs(reid_img, metainfo, bboxes.clone(), + rescale) + if crops.size(0) > 0: + embeds = model.reid(crops, mode='tensor') + else: + embeds = crops.new_zeros((0, model.reid.head.out_channels)) + else: + ids = torch.full((bboxes.size(0), ), -1, + dtype=torch.long).to(bboxes.device) + + # motion + if model.with_cmc: + num_samples = 1 + self.tracks = model.cmc.track(self.last_img, img, self.tracks, + num_samples, frame_id, metainfo) + + self.tracks, motion_dists = self.motion.track( + self.tracks, bbox_xyxy_to_cxcyah(bboxes)) + + active_ids = self.confirmed_ids + if self.with_reid: + crops = self.crop_imgs(reid_img, metainfo, bboxes.clone(), + rescale) + embeds = model.reid(crops, mode='tensor') + + # reid + if len(active_ids) > 0: + track_embeds = self.get( + 'embeds', + active_ids, + self.reid.get('num_samples', None), + behavior='mean') + reid_dists = cosine_distance(track_embeds, embeds) + valid_inds = [list(self.ids).index(_) for _ in active_ids] + reid_dists[~np.isfinite(motion_dists[ + valid_inds, :])] = np.nan + + weight_motion = self.reid.get('motion_weight') + match_dists = (1 - weight_motion) * reid_dists + \ + weight_motion * motion_dists[valid_inds] + + # support multi-class association + track_labels = torch.tensor([ + self.tracks[id]['labels'][-1] for id in active_ids + ]).to(bboxes.device) + cate_match = labels[None, :] == track_labels[:, None] + cate_cost = ((1 - cate_match.int()) * 1e6).cpu().numpy() + match_dists = match_dists + cate_cost + + row, col = linear_sum_assignment(match_dists) + for r, c in zip(row, col): + dist = match_dists[r, c] + if not np.isfinite(dist): + continue + if dist <= self.reid['match_score_thr']: + ids[c] = active_ids[r] + + active_ids = [ + id for id in self.ids if id not in ids + and self.tracks[id].frame_ids[-1] == frame_id - 1 + ] + if len(active_ids) > 0: + active_dets = torch.nonzero(ids == -1).squeeze(1) + track_bboxes = self.get('bboxes', active_ids) + ious = bbox_overlaps(track_bboxes, bboxes[active_dets]) + + # support multi-class association + track_labels = torch.tensor([ + self.tracks[id]['labels'][-1] for id in active_ids + ]).to(bboxes.device) + cate_match = labels[None, active_dets] == track_labels[:, None] + cate_cost = (1 - cate_match.int()) * 1e6 + + dists = (1 - ious + cate_cost).cpu().numpy() + + row, col = linear_sum_assignment(dists) + for r, c in zip(row, col): + dist = dists[r, c] + if dist < 1 - self.match_iou_thr: + ids[active_dets[c]] = active_ids[r] + + new_track_inds = ids == -1 + 
ids[new_track_inds] = torch.arange( + self.num_tracks, + self.num_tracks + new_track_inds.sum(), + dtype=torch.long).to(bboxes.device) + self.num_tracks += new_track_inds.sum() + + self.update( + ids=ids, + bboxes=bboxes, + scores=scores, + labels=labels, + embeds=embeds if self.with_reid else None, + frame_ids=frame_id) + self.last_img = img + + # update pred_track_instances + pred_track_instances = InstanceData() + pred_track_instances.bboxes = bboxes + pred_track_instances.labels = labels + pred_track_instances.scores = scores + pred_track_instances.instances_id = ids + + return pred_track_instances diff --git a/mmdet/models/tracking_heads/__init__.py b/mmdet/models/tracking_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd1f0561cc076f2a603a64eb479cc6de0372a438 --- /dev/null +++ b/mmdet/models/tracking_heads/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .mask2former_track_head import Mask2FormerTrackHead +from .quasi_dense_embed_head import QuasiDenseEmbedHead +from .quasi_dense_track_head import QuasiDenseTrackHead +from .roi_embed_head import RoIEmbedHead +from .roi_track_head import RoITrackHead + +__all__ = [ + 'QuasiDenseEmbedHead', 'QuasiDenseTrackHead', 'Mask2FormerTrackHead', + 'RoIEmbedHead', 'RoITrackHead' +] diff --git a/mmdet/models/tracking_heads/mask2former_track_head.py b/mmdet/models/tracking_heads/mask2former_track_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0877241bc33fcd1ef8f7ed154d503d9dbd8ab938 --- /dev/null +++ b/mmdet/models/tracking_heads/mask2former_track_head.py @@ -0,0 +1,729 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from collections import defaultdict +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d +from mmcv.ops import point_sample +from mmengine.model import ModuleList +from mmengine.model.weight_init import caffe2_xavier_init +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.models.dense_heads import AnchorFreeHead, MaskFormerHead +from mmdet.models.utils import get_uncertain_point_coords_with_randomness +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures import TrackDataSample, TrackSampleList +from mmdet.structures.mask import mask2bbox +from mmdet.utils import (ConfigType, InstanceList, OptConfigType, + OptMultiConfig, reduce_mean) +from ..layers import Mask2FormerTransformerDecoder + + +@MODELS.register_module() +class Mask2FormerTrackHead(MaskFormerHead): + """Implements the Mask2Former head. + + See `Masked-attention Mask Transformer for Universal Image + Segmentation `_ for details. + + Args: + in_channels (list[int]): Number of channels in the input feature map. + feat_channels (int): Number of channels for features. + out_channels (int): Number of channels for output. + num_classes (int): Number of VIS classes. + num_queries (int): Number of query in Transformer decoder. + Defaults to 100. + num_transformer_feat_level (int): Number of feats levels. + Defaults to 3. + pixel_decoder (:obj:`ConfigDict` or dict): Config for pixel + decoder. + enforce_decoder_input_project (bool, optional): Whether to add + a layer to change the embed_dim of transformer encoder in + pixel decoder to the embed_dim of transformer decoder. + Defaults to False. + transformer_decoder (:obj:`ConfigDict` or dict): Config for + transformer decoder. 
+        positional_encoding (:obj:`ConfigDict` or dict): Config for
+            transformer decoder position encoding.
+            Defaults to `SinePositionalEncoding3D`.
+        loss_cls (:obj:`ConfigDict` or dict): Config of the classification
+            loss. Defaults to `CrossEntropyLoss`.
+        loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss.
+            Defaults to `CrossEntropyLoss`.
+        loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss.
+            Defaults to `DiceLoss`.
+        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
+            Mask2Former head. Defaults to None.
+        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
+            Mask2Former head. Defaults to None.
+        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
+            dict], optional): Initialization config dict. Defaults to None.
+    """
+
+    def __init__(self,
+                 in_channels: List[int],
+                 feat_channels: int,
+                 out_channels: int,
+                 num_classes: int,
+                 num_frames: int = 2,
+                 num_queries: int = 100,
+                 num_transformer_feat_level: int = 3,
+                 pixel_decoder: ConfigType = ...,
+                 enforce_decoder_input_project: bool = False,
+                 transformer_decoder: ConfigType = ...,
+                 positional_encoding: ConfigType = dict(
+                     num_feats=128, normalize=True),
+                 loss_cls: ConfigType = dict(
+                     type='CrossEntropyLoss',
+                     use_sigmoid=False,
+                     loss_weight=2.0,
+                     reduction='mean',
+                     class_weight=[1.0] * 133 + [0.1]),
+                 loss_mask: ConfigType = dict(
+                     type='CrossEntropyLoss',
+                     use_sigmoid=True,
+                     reduction='mean',
+                     loss_weight=5.0),
+                 loss_dice: ConfigType = dict(
+                     type='DiceLoss',
+                     use_sigmoid=True,
+                     activate=True,
+                     reduction='mean',
+                     naive_dice=True,
+                     eps=1.0,
+                     loss_weight=5.0),
+                 train_cfg: OptConfigType = None,
+                 test_cfg: OptConfigType = None,
+                 init_cfg: OptMultiConfig = None,
+                 **kwargs) -> None:
+        super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
+        self.num_classes = num_classes
+        self.num_frames = num_frames
+        self.num_queries = num_queries
+        self.num_transformer_feat_level = num_transformer_feat_level
+        self.num_heads = transformer_decoder.layer_cfg.cross_attn_cfg.num_heads
+        self.num_transformer_decoder_layers = transformer_decoder.num_layers
+        assert pixel_decoder.encoder.layer_cfg. \
+            self_attn_cfg.num_levels == num_transformer_feat_level
+        pixel_decoder_ = copy.deepcopy(pixel_decoder)
+        pixel_decoder_.update(
+            in_channels=in_channels,
+            feat_channels=feat_channels,
+            out_channels=out_channels)
+        self.pixel_decoder = MODELS.build(pixel_decoder_)
+        self.transformer_decoder = Mask2FormerTransformerDecoder(
+            **transformer_decoder)
+        self.decoder_embed_dims = self.transformer_decoder.embed_dims
+
+        self.decoder_input_projs = ModuleList()
+        # from low resolution to high resolution
+        for _ in range(num_transformer_feat_level):
+            if (self.decoder_embed_dims != feat_channels
+                    or enforce_decoder_input_project):
+                self.decoder_input_projs.append(
+                    Conv2d(
+                        feat_channels, self.decoder_embed_dims, kernel_size=1))
+            else:
+                self.decoder_input_projs.append(nn.Identity())
+        self.decoder_positional_encoding = MODELS.build(positional_encoding)
+        self.query_embed = nn.Embedding(self.num_queries, feat_channels)
+        self.query_feat = nn.Embedding(self.num_queries, feat_channels)
+        # from low resolution to high resolution
+        self.level_embed = nn.Embedding(self.num_transformer_feat_level,
+                                        feat_channels)
+
+        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
+        self.mask_embed = nn.Sequential(
+            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
+            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
+            nn.Linear(feat_channels, out_channels))
+
+        self.test_cfg = test_cfg
+        self.train_cfg = train_cfg
+        if train_cfg:
+            self.assigner = TASK_UTILS.build(self.train_cfg.assigner)
+            self.sampler = TASK_UTILS.build(
+                self.train_cfg['sampler'],
+                default_args=dict(context=self))
+            self.num_points = self.train_cfg.get('num_points', 12544)
+            self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
+            self.importance_sample_ratio = self.train_cfg.get(
+                'importance_sample_ratio', 0.75)
+
+        self.class_weight = loss_cls.class_weight
+        self.loss_cls = MODELS.build(loss_cls)
+        self.loss_mask = MODELS.build(loss_mask)
+        self.loss_dice = MODELS.build(loss_dice)
+
+    def init_weights(self) -> None:
+        for m in self.decoder_input_projs:
+            if isinstance(m, Conv2d):
+                caffe2_xavier_init(m, bias=0)
+
+        self.pixel_decoder.init_weights()
+
+        for p in self.transformer_decoder.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_normal_(p)
+
+    def preprocess_gt(self, batch_gt_instances: InstanceList) -> InstanceList:
+        """Preprocess the ground truth for all images.
+
+        This function reorganizes the ``gt``. For example, the masks in
+        ``batch_data_sample.gt_instances.mask`` have shape
+        ``(all_num_gts, h, w)``, without recording which image each gt
+        belongs to (assume ``num_frames`` is 2). So this function reshapes
+        ``gt_mask`` to ``(num_gts_per_img, num_frames, h, w)``. In addition,
+        the two images are not guaranteed to contain the same number of
+        instances, so ``-1`` marks nonexistent instances.
+
+        Args:
+            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+                gt_instance. It usually includes ``labels``, each is
+                ground truth labels of each bbox, with shape (num_gts, )
+                and ``masks``, each is ground truth masks of each instance
+                of an image, shape (num_gts, h, w).
+
+        Returns:
+            list[obj:`InstanceData`]: each contains the following keys
+
+                - labels (Tensor): Ground truth class indices\
+                    for an image, with shape (n, ), n is the sum of\
+                    number of stuff type and number of instance in an image.
+                - masks (Tensor): Ground truth mask for a\
+                    image, with shape (n, t, h, w).
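+
+        Example:
+            A shape-only sketch with ``num_frames=2``: if frame 0 contains
+            instances ``[7, 9]`` and frame 1 only ``[9]``, the union of ids
+            is ``[7, 9]``, the packed masks have shape ``(2, 2, h, w)`` and
+            the id matrix marks the missing entry with ``-1``:
+
+            >>> import torch
+            >>> gt_ids_per_video = torch.tensor([[7, -1], [9, 9]])
+            >>> (gt_ids_per_video == -1).any(dim=1)
+            tensor([ True, False])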
+ """ + final_batch_gt_instances = [] + batch_size = len(batch_gt_instances) // self.num_frames + for batch_idx in range(batch_size): + pair_gt_insatences = batch_gt_instances[batch_idx * + self.num_frames:batch_idx * + self.num_frames + + self.num_frames] + + assert len( + pair_gt_insatences + ) > 1, f'mask2former for vis need multi frames to train, \ + but you only use {len(pair_gt_insatences)} frames' + + _device = pair_gt_insatences[0].labels.device + + for gt_instances in pair_gt_insatences: + gt_instances.masks = gt_instances.masks.to_tensor( + dtype=torch.bool, device=_device) + all_ins_id = torch.cat([ + gt_instances.instances_ids + for gt_instances in pair_gt_insatences + ]) + all_ins_id = all_ins_id.unique().tolist() + map_ins_id = dict() + for i, ins_id in enumerate(all_ins_id): + map_ins_id[ins_id] = i + + num_instances = len(all_ins_id) + mask_shape = [ + num_instances, self.num_frames, + pair_gt_insatences[0].masks.shape[1], + pair_gt_insatences[0].masks.shape[2] + ] + gt_masks_per_video = torch.zeros( + mask_shape, dtype=torch.bool, device=_device) + gt_ids_per_video = torch.full((num_instances, self.num_frames), + -1, + dtype=torch.long, + device=_device) + gt_labels_per_video = torch.full((num_instances, ), + -1, + dtype=torch.long, + device=_device) + + for frame_id in range(self.num_frames): + cur_frame_gts = pair_gt_insatences[frame_id] + ins_ids = cur_frame_gts.instances_ids.tolist() + for i, id in enumerate(ins_ids): + gt_masks_per_video[map_ins_id[id], + frame_id, :, :] = cur_frame_gts.masks[i] + gt_ids_per_video[map_ins_id[id], + frame_id] = cur_frame_gts.instances_ids[i] + gt_labels_per_video[ + map_ins_id[id]] = cur_frame_gts.labels[i] + + tmp_instances = InstanceData( + labels=gt_labels_per_video, + masks=gt_masks_per_video.long(), + instances_id=gt_ids_per_video) + final_batch_gt_instances.append(tmp_instances) + + return final_batch_gt_instances + + def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor, + gt_instances: InstanceData, + img_meta: dict) -> Tuple[Tensor]: + """Compute classification and mask targets for one image. + + Args: + cls_score (Tensor): Mask score logits from a single decoder layer + for one image. Shape (num_queries, cls_out_channels). + mask_pred (Tensor): Mask logits for a single decoder layer for one + image. Shape (num_queries, num_frames, h, w). + gt_instances (:obj:`InstanceData`): It contains ``labels`` and + ``masks``. + img_meta (dict): Image informtation. + + Returns: + tuple[Tensor]: A tuple containing the following for one image. + + - labels (Tensor): Labels of each image. \ + shape (num_queries, ). + - label_weights (Tensor): Label weights of each image. \ + shape (num_queries, ). + - mask_targets (Tensor): Mask targets of each image. \ + shape (num_queries, num_frames, h, w). + - mask_weights (Tensor): Mask weights of each image. \ + shape (num_queries, ). + - pos_inds (Tensor): Sampled positive indices for each \ + image. + - neg_inds (Tensor): Sampled negative indices for each \ + image. + - sampling_result (:obj:`SamplingResult`): Sampling results. 
+ """ + # (num_gts, ) + gt_labels = gt_instances.labels + # (num_gts, num_frames, h, w) + gt_masks = gt_instances.masks + # sample points + num_queries = cls_score.shape[0] + num_gts = gt_labels.shape[0] + + point_coords = torch.rand((1, self.num_points, 2), + device=cls_score.device) + + # shape (num_queries, num_points) + mask_points_pred = point_sample(mask_pred, + point_coords.repeat(num_queries, 1, + 1)).flatten(1) + # shape (num_gts, num_points) + gt_points_masks = point_sample(gt_masks.float(), + point_coords.repeat(num_gts, 1, + 1)).flatten(1) + + sampled_gt_instances = InstanceData( + labels=gt_labels, masks=gt_points_masks) + sampled_pred_instances = InstanceData( + scores=cls_score, masks=mask_points_pred) + # assign and sample + assign_result = self.assigner.assign( + pred_instances=sampled_pred_instances, + gt_instances=sampled_gt_instances, + img_meta=img_meta) + pred_instances = InstanceData(scores=cls_score, masks=mask_pred) + sampling_result = self.sampler.sample( + assign_result=assign_result, + pred_instances=pred_instances, + gt_instances=gt_instances) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label target + labels = gt_labels.new_full((self.num_queries, ), + self.num_classes, + dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_labels.new_ones((self.num_queries, )) + + # mask target + mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] + mask_weights = mask_pred.new_zeros((self.num_queries, )) + mask_weights[pos_inds] = 1.0 + + return (labels, label_weights, mask_targets, mask_weights, pos_inds, + neg_inds, sampling_result) + + def _loss_by_feat_single(self, cls_scores: Tensor, mask_preds: Tensor, + batch_gt_instances: List[InstanceData], + batch_img_metas: List[dict]) -> Tuple[Tensor]: + """Loss function for outputs from a single decoder layer. + + Args: + cls_scores (Tensor): Mask score logits from a single decoder layer + for all images. Shape (batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should include + background. + mask_preds (Tensor): Mask logits for a pixel decoder for all + images. Shape (batch_size, num_queries, num_frames,h, w). + batch_gt_instances (list[obj:`InstanceData`]): each contains + ``labels`` and ``masks``. + batch_img_metas (list[dict]): List of image meta information. + + Returns: + tuple[Tensor]: Loss components for outputs from a single \ + decoder layer. 
+ """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + mask_preds_list = [mask_preds[i] for i in range(num_imgs)] + (labels_list, label_weights_list, mask_targets_list, mask_weights_list, + avg_factor) = self.get_targets(cls_scores_list, mask_preds_list, + batch_gt_instances, batch_img_metas) + # shape (batch_size, num_queries) + labels = torch.stack(labels_list, dim=0) + # shape (batch_size, num_queries) + label_weights = torch.stack(label_weights_list, dim=0) + # shape (num_total_gts, num_frames, h, w) + mask_targets = torch.cat(mask_targets_list, dim=0) + # shape (batch_size, num_queries) + mask_weights = torch.stack(mask_weights_list, dim=0) + + # classfication loss + # shape (batch_size * num_queries, ) + cls_scores = cls_scores.flatten(0, 1) + labels = labels.flatten(0, 1) + label_weights = label_weights.flatten(0, 1) + + class_weight = cls_scores.new_tensor(self.class_weight) + loss_cls = self.loss_cls( + cls_scores, + labels, + label_weights, + avg_factor=class_weight[labels].sum()) + + num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor])) + num_total_masks = max(num_total_masks, 1) + + # extract positive ones + # shape (batch_size, num_queries, num_frames, h, w) + # -> (num_total_gts, num_frames, h, w) + mask_preds = mask_preds[mask_weights > 0] + + if mask_targets.shape[0] == 0: + # zero match + loss_dice = mask_preds.sum() + loss_mask = mask_preds.sum() + return loss_cls, loss_mask, loss_dice + + with torch.no_grad(): + points_coords = get_uncertain_point_coords_with_randomness( + mask_preds.flatten(0, 1).unsqueeze(1), None, self.num_points, + self.oversample_ratio, self.importance_sample_ratio) + # shape (num_total_gts * num_frames, h, w) -> + # (num_total_gts, num_points) + mask_point_targets = point_sample( + mask_targets.flatten(0, 1).unsqueeze(1).float(), + points_coords).squeeze(1) + # shape (num_total_gts * num_frames, num_points) + mask_point_preds = point_sample( + mask_preds.flatten(0, 1).unsqueeze(1), points_coords).squeeze(1) + + # dice loss + loss_dice = self.loss_dice( + mask_point_preds, mask_point_targets, avg_factor=num_total_masks) + + # mask loss + # shape (num_total_gts * num_frames, num_points) -> + # (num_total_gts * num_frames * num_points, ) + mask_point_preds = mask_point_preds.reshape(-1) + # shape (num_total_gts, num_points) -> (num_total_gts * num_points, ) + mask_point_targets = mask_point_targets.reshape(-1) + loss_mask = self.loss_mask( + mask_point_preds, + mask_point_targets, + avg_factor=num_total_masks * self.num_points / self.num_frames) + + return loss_cls, loss_mask, loss_dice + + def _forward_head( + self, decoder_out: Tensor, mask_feature: Tensor, + attn_mask_target_size: Tuple[int, + int]) -> Tuple[Tensor, Tensor, Tensor]: + """Forward for head part which is called after every decoder layer. + + Args: + decoder_out (Tensor): in shape (num_queries, batch_size, c). + mask_feature (Tensor): in shape (batch_size, t, c, h, w). + attn_mask_target_size (tuple[int, int]): target attention + mask size. + + Returns: + tuple: A tuple contain three elements. + + - cls_pred (Tensor): Classification scores in shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should include background. + - mask_pred (Tensor): Mask scores in shape \ + (batch_size, num_queries,h, w). + - attn_mask (Tensor): Attention mask in shape \ + (batch_size * num_heads, num_queries, h, w). 
+ """ + decoder_out = self.transformer_decoder.post_norm(decoder_out) + cls_pred = self.cls_embed(decoder_out) + mask_embed = self.mask_embed(decoder_out) + + # shape (batch_size, num_queries, t, h, w) + mask_pred = torch.einsum('bqc,btchw->bqthw', mask_embed, mask_feature) + b, q, t, _, _ = mask_pred.shape + + attn_mask = F.interpolate( + mask_pred.flatten(0, 1), + attn_mask_target_size, + mode='bilinear', + align_corners=False).view(b, q, t, attn_mask_target_size[0], + attn_mask_target_size[1]) + + # shape (batch_size, num_queries, t, h, w) -> + # (batch_size, num_queries, t*h*w) -> + # (batch_size, num_head, num_queries, t*h*w) -> + # (batch_size*num_head, num_queries, t*h*w) + attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat( + (1, self.num_heads, 1, 1)).flatten(0, 1) + attn_mask = attn_mask.sigmoid() < 0.5 + attn_mask = attn_mask.detach() + + return cls_pred, mask_pred, attn_mask + + def forward( + self, x: List[Tensor], data_samples: TrackDataSample + ) -> Tuple[List[Tensor], List[Tensor]]: + """Forward function. + + Args: + x (list[Tensor]): Multi scale Features from the + upstream network, each is a 4D-tensor. + data_samples (List[:obj:`TrackDataSample`]): The Data + Samples. It usually includes information such as `gt_instance`. + + Returns: + tuple[list[Tensor]]: A tuple contains two elements. + + - cls_pred_list (list[Tensor)]: Classification logits \ + for each decoder layer. Each is a 3D-tensor with shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should include background. + - mask_pred_list (list[Tensor]): Mask logits for each \ + decoder layer. Each with shape (batch_size, num_queries, \ + h, w). + """ + mask_features, multi_scale_memorys = self.pixel_decoder(x) + bt, c_m, h_m, w_m = mask_features.shape + batch_size = bt // self.num_frames if self.training else 1 + t = bt // batch_size + mask_features = mask_features.view(batch_size, t, c_m, h_m, w_m) + # multi_scale_memorys (from low resolution to high resolution) + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) + decoder_input = decoder_input.flatten(2) + level_embed = self.level_embed.weight[i][None, :, None] + decoder_input = decoder_input + level_embed + _, c, hw = decoder_input.shape + # shape (batch_size*t, c, h, w) -> + # (batch_size, t, c, hw) -> + # (batch_size, t*h*w, c) + decoder_input = decoder_input.view(batch_size, t, c, + hw).permute(0, 1, 3, + 2).flatten(1, 2) + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + mask = decoder_input.new_zeros( + (batch_size, t) + multi_scale_memorys[i].shape[-2:], + dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding( + mask) + decoder_positional_encoding = decoder_positional_encoding.flatten( + 3).permute(0, 1, 3, 2).flatten(1, 2) + decoder_inputs.append(decoder_input) + decoder_positional_encodings.append(decoder_positional_encoding) + # shape (num_queries, c) -> (batch_size, num_queries, c) + query_feat = self.query_feat.weight.unsqueeze(0).repeat( + (batch_size, 1, 1)) + query_embed = self.query_embed.weight.unsqueeze(0).repeat( + (batch_size, 1, 1)) + + cls_pred_list = [] + mask_pred_list = [] + cls_pred, mask_pred, attn_mask = self._forward_head( + query_feat, mask_features, multi_scale_memorys[0].shape[-2:]) + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % 
self.num_transformer_feat_level + # if a mask is all True(all background), then set it all False. + attn_mask[torch.where( + attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # cross_attn + self_attn + layer = self.transformer_decoder.layers[i] + query_feat = layer( + query=query_feat, + key=decoder_inputs[level_idx], + value=decoder_inputs[level_idx], + query_pos=query_embed, + key_pos=decoder_positional_encodings[level_idx], + cross_attn_mask=attn_mask, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None) + cls_pred, mask_pred, attn_mask = self._forward_head( + query_feat, mask_features, multi_scale_memorys[ + (i + 1) % self.num_transformer_feat_level].shape[-2:]) + + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + return cls_pred_list, mask_pred_list + + def loss( + self, + x: Tuple[Tensor], + data_samples: TrackSampleList, + ) -> Dict[str, Tensor]: + """Perform forward propagation and loss calculation of the track head + on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the upstream + network, each is a 4D-tensor. + data_samples (List[:obj:`TrackDataSample`]): The Data + Samples. It usually includes information such as `gt_instance`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + batch_img_metas = [] + batch_gt_instances = [] + + for data_sample in data_samples: + video_img_metas = defaultdict(list) + for image_idx in range(len(data_sample)): + batch_gt_instances.append(data_sample[image_idx].gt_instances) + for key, value in data_sample[image_idx].metainfo.items(): + video_img_metas[key].append(value) + batch_img_metas.append(video_img_metas) + + # forward + all_cls_scores, all_mask_preds = self(x, data_samples) + + # preprocess ground truth + batch_gt_instances = self.preprocess_gt(batch_gt_instances) + # loss + losses = self.loss_by_feat(all_cls_scores, all_mask_preds, + batch_gt_instances, batch_img_metas) + + return losses + + def predict(self, + x: Tuple[Tensor], + data_samples: TrackDataSample, + rescale: bool = True) -> InstanceList: + """Test without augmentation. + + Args: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + data_samples (List[:obj:`TrackDataSample`]): The Data + Samples. It usually includes information such as `gt_instance`. + rescale (bool, Optional): If False, then returned bboxes and masks + will fit the scale of img, otherwise, returned bboxes and masks + will fit the scale of original image shape. Defaults to True. + + Returns: + list[obj:`InstanceData`]: each contains the following keys + - labels (Tensor): Prediction class indices\ + for an image, with shape (n, ), n is the sum of\ + number of stuff type and number of instance in an image. + - masks (Tensor): Prediction mask for a\ + image, with shape (n, t, h, w). 
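+
+        Example:
+            An illustrative call (assumes ``head`` is a built
+            :class:`Mask2FormerTrackHead` and ``x``/``data_samples`` come
+            from the upstream video pipeline):
+
+            >>> results = head.predict(x, data_samples)
+            >>> len(results) == len(data_samples)
+            True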
+ """ + + batch_img_metas = [ + data_samples[img_idx].metainfo + for img_idx in range(len(data_samples)) + ] + all_cls_scores, all_mask_preds = self(x, data_samples) + mask_cls_results = all_cls_scores[-1] + mask_pred_results = all_mask_preds[-1] + + mask_cls_results = mask_cls_results[0] + # upsample masks + img_shape = batch_img_metas[0]['batch_input_shape'] + mask_pred_results = F.interpolate( + mask_pred_results[0], + size=(img_shape[0], img_shape[1]), + mode='bilinear', + align_corners=False) + + results = self.predict_by_feat(mask_cls_results, mask_pred_results, + batch_img_metas) + return results + + def predict_by_feat(self, + mask_cls_results: List[Tensor], + mask_pred_results: List[Tensor], + batch_img_metas: List[dict], + rescale: bool = True) -> InstanceList: + """Get top-10 predictions. + + Args: + mask_cls_results (Tensor): Mask classification logits,\ + shape (batch_size, num_queries, cls_out_channels). + Note `cls_out_channels` should include background. + mask_pred_results (Tensor): Mask logits, shape \ + (batch_size, num_queries, h, w). + batch_img_metas (list[dict]): List of image meta information. + rescale (bool, Optional): If False, then returned bboxes and masks + will fit the scale of img, otherwise, returned bboxes and masks + will fit the scale of original image shape. Defaults to True. + + Returns: + list[obj:`InstanceData`]: each contains the following keys + - labels (Tensor): Prediction class indices\ + for an image, with shape (n, ), n is the sum of\ + number of stuff type and number of instance in an image. + - masks (Tensor): Prediction mask for a\ + image, with shape (n, t, h, w). + """ + results = [] + if len(mask_cls_results) > 0: + scores = F.softmax(mask_cls_results, dim=-1)[:, :-1] + labels = torch.arange(self.num_classes).unsqueeze(0).repeat( + self.num_queries, 1).flatten(0, 1).to(scores.device) + # keep top-10 predictions + scores_per_image, topk_indices = scores.flatten(0, 1).topk( + 10, sorted=False) + labels_per_image = labels[topk_indices] + topk_indices = topk_indices // self.num_classes + mask_pred_results = mask_pred_results[topk_indices] + + img_shape = batch_img_metas[0]['img_shape'] + mask_pred_results = \ + mask_pred_results[:, :, :img_shape[0], :img_shape[1]] + if rescale: + # return result in original resolution + ori_height, ori_width = batch_img_metas[0]['ori_shape'][:2] + mask_pred_results = F.interpolate( + mask_pred_results, + size=(ori_height, ori_width), + mode='bilinear', + align_corners=False) + + masks = mask_pred_results > 0. + + # format top-10 predictions + for img_idx in range(len(batch_img_metas)): + pred_track_instances = InstanceData() + + pred_track_instances.masks = masks[:, img_idx] + pred_track_instances.bboxes = mask2bbox(masks[:, img_idx]) + pred_track_instances.labels = labels_per_image + pred_track_instances.scores = scores_per_image + pred_track_instances.instances_id = torch.arange(10) + + results.append(pred_track_instances) + + return results diff --git a/mmdet/models/tracking_heads/quasi_dense_embed_head.py b/mmdet/models/tracking_heads/quasi_dense_embed_head.py new file mode 100644 index 0000000000000000000000000000000000000000..55e3c05b7aba188608f7dd2fdda54e0759cee03c --- /dev/null +++ b/mmdet/models/tracking_heads/quasi_dense_embed_head.py @@ -0,0 +1,347 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch import Tensor +from torch.nn.modules.utils import _pair + +from mmdet.models.task_modules import SamplingResult +from mmdet.registry import MODELS +from ..task_modules.tracking import embed_similarity + + +@MODELS.register_module() +class QuasiDenseEmbedHead(BaseModule): + """The quasi-dense roi embed head. + + Args: + embed_channels (int): The input channel of embed features. + Defaults to 256. + softmax_temp (int): Softmax temperature. Defaults to -1. + loss_track (dict): The loss function for tracking. Defaults to + MultiPosCrossEntropyLoss. + loss_track_aux (dict): The auxiliary loss function for tracking. + Defaults to MarginL2Loss. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict]): Initialization config dict. + """ + + def __init__(self, + num_convs: int = 0, + num_fcs: int = 0, + roi_feat_size: int = 7, + in_channels: int = 256, + conv_out_channels: int = 256, + with_avg_pool: bool = False, + fc_out_channels: int = 1024, + conv_cfg: Optional[dict] = None, + norm_cfg: Optional[dict] = None, + embed_channels: int = 256, + softmax_temp: int = -1, + loss_track: Optional[dict] = None, + loss_track_aux: dict = dict( + type='MarginL2Loss', + sample_ratio=3, + margin=0.3, + loss_weight=1.0, + hard_mining=True), + init_cfg: dict = dict( + type='Xavier', + layer='Linear', + distribution='uniform', + bias=0, + override=dict( + type='Normal', + name='fc_embed', + mean=0, + std=0.01, + bias=0))): + super(QuasiDenseEmbedHead, self).__init__(init_cfg=init_cfg) + self.num_convs = num_convs + self.num_fcs = num_fcs + self.roi_feat_size = _pair(roi_feat_size) + self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1] + self.in_channels = in_channels + self.conv_out_channels = conv_out_channels + self.with_avg_pool = with_avg_pool + self.fc_out_channels = fc_out_channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + if self.with_avg_pool: + self.avg_pool = nn.AvgPool2d(self.roi_feat_size) + # add convs and fcs + self.convs, self.fcs, self.last_layer_dim = self._add_conv_fc_branch( + self.num_convs, self.num_fcs, self.in_channels) + self.relu = nn.ReLU(inplace=True) + + if loss_track is None: + loss_track = dict( + type='MultiPosCrossEntropyLoss', loss_weight=0.25) + + self.fc_embed = nn.Linear(self.last_layer_dim, embed_channels) + self.softmax_temp = softmax_temp + self.loss_track = MODELS.build(loss_track) + if loss_track_aux is not None: + self.loss_track_aux = MODELS.build(loss_track_aux) + else: + self.loss_track_aux = None + + def _add_conv_fc_branch( + self, num_branch_convs: int, num_branch_fcs: int, + in_channels: int) -> Tuple[nn.ModuleList, nn.ModuleList, int]: + """Add shared or separable branch. convs -> avg pool (optional) -> fcs. + + Args: + num_branch_convs (int): The number of convoluational layers. + num_branch_fcs (int): The number of fully connection layers. + in_channels (int): The input channel of roi features. + + Returns: + Tuple[nn.ModuleList, nn.ModuleList, int]: The convs, fcs and the + last layer dimension. 
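+
+        Example:
+            A sketch of the branch built under QDTrack-style settings (e.g.
+            ``num_convs=4``, ``num_fcs=1``, ``in_channels=256``,
+            ``roi_feat_size=7``, no average pooling):
+
+            >>> # convs: 4 x ConvModule(256 -> 256, 3x3, padding=1)
+            >>> # fcs:   Linear(256 * 7 * 7 -> 1024)
+            >>> # so ``last_layer_dim`` ends up as 1024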
+ """ + last_layer_dim = in_channels + # add branch specific conv layers + branch_convs = nn.ModuleList() + if num_branch_convs > 0: + for i in range(num_branch_convs): + conv_in_channels = ( + last_layer_dim if i == 0 else self.conv_out_channels) + branch_convs.append( + ConvModule( + conv_in_channels, + self.conv_out_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + last_layer_dim = self.conv_out_channels + + # add branch specific fc layers + branch_fcs = nn.ModuleList() + if num_branch_fcs > 0: + if not self.with_avg_pool: + last_layer_dim *= self.roi_feat_area + for i in range(num_branch_fcs): + fc_in_channels = ( + last_layer_dim if i == 0 else self.fc_out_channels) + branch_fcs.append( + nn.Linear(fc_in_channels, self.fc_out_channels)) + last_layer_dim = self.fc_out_channels + + return branch_convs, branch_fcs, last_layer_dim + + def forward(self, x: Tensor) -> Tensor: + """Forward function. + + Args: + x (Tensor): The input features from ROI head. + + Returns: + Tensor: The embedding feature map. + """ + + if self.num_convs > 0: + for conv in self.convs: + x = conv(x) + x = x.flatten(1) + if self.num_fcs > 0: + for fc in self.fcs: + x = self.relu(fc(x)) + x = self.fc_embed(x) + return x + + def get_targets( + self, gt_match_indices: List[Tensor], + key_sampling_results: List[SamplingResult], + ref_sampling_results: List[SamplingResult]) -> Tuple[List, List]: + """Calculate the track targets and track weights for all samples in a + batch according to the sampling_results. + + Args: + gt_match_indices (list(Tensor)): Mapping from gt_instance_ids to + ref_gt_instance_ids of the same tracklet in a pair of images. + key_sampling_results (List[obj:SamplingResult]): Assign results of + all images in a batch after sampling. + ref_sampling_results (List[obj:SamplingResult]): Assign results of + all reference images in a batch after sampling. + + Returns: + Tuple[list[Tensor]]: Association results. + Containing the following list of Tensors: + + - track_targets (list[Tensor]): The mapping instance ids from + all positive proposals in the key image to all proposals + in the reference image, each tensor in list has + shape (len(key_pos_bboxes), len(ref_bboxes)). + - track_weights (list[Tensor]): Loss weights for all positive + proposals in a batch, each tensor in list has + shape (len(key_pos_bboxes),). + """ + + track_targets = [] + track_weights = [] + for _gt_match_indices, key_res, ref_res in zip(gt_match_indices, + key_sampling_results, + ref_sampling_results): + targets = _gt_match_indices.new_zeros( + (key_res.pos_bboxes.size(0), ref_res.bboxes.size(0)), + dtype=torch.int) + _match_indices = _gt_match_indices[key_res.pos_assigned_gt_inds] + pos2pos = (_match_indices.view( + -1, 1) == ref_res.pos_assigned_gt_inds.view(1, -1)).int() + targets[:, :pos2pos.size(1)] = pos2pos + weights = (targets.sum(dim=1) > 0).float() + track_targets.append(targets) + track_weights.append(weights) + return track_targets, track_weights + + def match( + self, key_embeds: Tensor, ref_embeds: Tensor, + key_sampling_results: List[SamplingResult], + ref_sampling_results: List[SamplingResult] + ) -> Tuple[List[Tensor], List[Tensor]]: + """Calculate the dist matrixes for loss measurement. + + Args: + key_embeds (Tensor): Embeds of positive bboxes in sampling results + of key image. + ref_embeds (Tensor): Embeds of all bboxes in sampling results + of the reference image. + key_sampling_results (List[obj:SamplingResults]): Assign results of + all images in a batch after sampling. 
+ ref_sampling_results (List[obj:SamplingResults]): Assign results of + all reference images in a batch after sampling. + + Returns: + Tuple[list[Tensor]]: Calculation results. + Containing the following list of Tensors: + + - dists (list[Tensor]): Dot-product dists between + key_embeds and ref_embeds, each tensor in list has + shape (len(key_pos_bboxes), len(ref_bboxes)). + - cos_dists (list[Tensor]): Cosine dists between + key_embeds and ref_embeds, each tensor in list has + shape (len(key_pos_bboxes), len(ref_bboxes)). + """ + + num_key_rois = [res.pos_bboxes.size(0) for res in key_sampling_results] + key_embeds = torch.split(key_embeds, num_key_rois) + num_ref_rois = [res.bboxes.size(0) for res in ref_sampling_results] + ref_embeds = torch.split(ref_embeds, num_ref_rois) + + dists, cos_dists = [], [] + for key_embed, ref_embed in zip(key_embeds, ref_embeds): + dist = embed_similarity( + key_embed, + ref_embed, + method='dot_product', + temperature=self.softmax_temp) + dists.append(dist) + if self.loss_track_aux is not None: + cos_dist = embed_similarity( + key_embed, ref_embed, method='cosine') + cos_dists.append(cos_dist) + else: + cos_dists.append(None) + return dists, cos_dists + + def loss(self, key_roi_feats: Tensor, ref_roi_feats: Tensor, + key_sampling_results: List[SamplingResult], + ref_sampling_results: List[SamplingResult], + gt_match_indices_list: List[Tensor]) -> dict: + """Calculate the track loss and the auxiliary track loss. + + Args: + key_roi_feats (Tensor): Embeds of positive bboxes in sampling + results of key image. + ref_roi_feats (Tensor): Embeds of all bboxes in sampling results + of the reference image. + key_sampling_results (List[obj:SamplingResults]): Assign results of + all images in a batch after sampling. + ref_sampling_results (List[obj:SamplingResults]): Assign results of + all reference images in a batch after sampling. + gt_match_indices_list (list(Tensor)): Mapping from gt_instances_ids + to ref_gt_instances_ids of the same tracklet in a pair of + images. + + Returns: + Dict [str: Tensor]: Calculation results. + Containing the following list of Tensors: + + - loss_track (Tensor): Results of loss_track function. + - loss_track_aux (Tensor): Results of loss_track_aux function. + """ + key_track_feats = self(key_roi_feats) + ref_track_feats = self(ref_roi_feats) + + losses = self.loss_by_feat(key_track_feats, ref_track_feats, + key_sampling_results, ref_sampling_results, + gt_match_indices_list) + return losses + + def loss_by_feat(self, key_track_feats: Tensor, ref_track_feats: Tensor, + key_sampling_results: List[SamplingResult], + ref_sampling_results: List[SamplingResult], + gt_match_indices_list: List[Tensor]) -> dict: + """Calculate the track loss and the auxiliary track loss. + + Args: + key_track_feats (Tensor): Embeds of positive bboxes in sampling + results of key image. + ref_track_feats (Tensor): Embeds of all bboxes in sampling results + of the reference image. + key_sampling_results (List[obj:SamplingResults]): Assign results of + all images in a batch after sampling. + ref_sampling_results (List[obj:SamplingResults]): Assign results of + all reference images in a batch after sampling. + gt_match_indices_list (list(Tensor)): Mapping from instances_ids + from key image to reference image of the same tracklet in a + pair of images. + + Returns: + Dict [str: Tensor]: Calculation results. + Containing the following list of Tensors: + + - loss_track (Tensor): Results of loss_track function. 
+ - loss_track_aux (Tensor): Results of loss_track_aux function. + """ + dists, cos_dists = self.match(key_track_feats, ref_track_feats, + key_sampling_results, + ref_sampling_results) + targets, weights = self.get_targets(gt_match_indices_list, + key_sampling_results, + ref_sampling_results) + losses = dict() + + loss_track = 0. + loss_track_aux = 0. + for _dists, _cos_dists, _targets, _weights in zip( + dists, cos_dists, targets, weights): + loss_track += self.loss_track( + _dists, _targets, _weights, avg_factor=_weights.sum()) + if self.loss_track_aux is not None: + loss_track_aux += self.loss_track_aux(_cos_dists, _targets) + losses['loss_track'] = loss_track / len(dists) + + if self.loss_track_aux is not None: + losses['loss_track_aux'] = loss_track_aux / len(dists) + + return losses + + def predict(self, bbox_feats: Tensor) -> Tensor: + """Perform forward propagation of the tracking head and predict + tracking results on the features of the upstream network. + + Args: + bbox_feats: The extracted roi features. + + Returns: + Tensor: The extracted track features. + """ + track_feats = self(bbox_feats) + return track_feats diff --git a/mmdet/models/tracking_heads/quasi_dense_track_head.py b/mmdet/models/tracking_heads/quasi_dense_track_head.py new file mode 100644 index 0000000000000000000000000000000000000000..bd078dac827e35c7514330870cf884001985156b --- /dev/null +++ b/mmdet/models/tracking_heads/quasi_dense_track_head.py @@ -0,0 +1,178 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +from mmengine.model import BaseModule +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures import TrackSampleList +from mmdet.structures.bbox import bbox2roi +from mmdet.utils import InstanceList + + +@MODELS.register_module() +class QuasiDenseTrackHead(BaseModule): + """The quasi-dense track head.""" + + def __init__(self, + roi_extractor: Optional[dict] = None, + embed_head: Optional[dict] = None, + regress_head: Optional[dict] = None, + train_cfg: Optional[dict] = None, + test_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None, + **kwargs): + super().__init__(init_cfg=init_cfg) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + if embed_head is not None: + self.init_embed_head(roi_extractor, embed_head) + + if regress_head is not None: + raise NotImplementedError('Regression head is not supported yet.') + + self.init_assigner_sampler() + + def init_embed_head(self, roi_extractor, embed_head) -> None: + """Initialize ``embed_head`` + + Args: + roi_extractor (dict, optional): Configuration of roi extractor. + Defaults to None. + embed_head (dict, optional): Configuration of embed head. Defaults + to None. + """ + self.roi_extractor = MODELS.build(roi_extractor) + self.embed_head = MODELS.build(embed_head) + + def init_assigner_sampler(self) -> None: + """Initialize assigner and sampler.""" + self.bbox_assigner = None + self.bbox_sampler = None + if self.train_cfg: + self.bbox_assigner = TASK_UTILS.build(self.train_cfg.assigner) + self.bbox_sampler = TASK_UTILS.build( + self.train_cfg.sampler, default_args=dict(context=self)) + + @property + def with_track(self) -> bool: + """bool: whether the multi-object tracker has an embed head""" + return hasattr(self, 'embed_head') and self.embed_head is not None + + def extract_roi_feats(self, feats: List[Tensor], + bboxes: List[Tensor]) -> Tensor: + """Extract roi features. + + Args: + feats (list[Tensor]): list of multi-level image features. 
+ bboxes (list[Tensor]): list of bboxes in sampling result. + + Returns: + Tensor: The extracted roi features. + """ + rois = bbox2roi(bboxes) + bbox_feats = self.roi_extractor(feats[:self.roi_extractor.num_inputs], + rois) + return bbox_feats + + def loss(self, key_feats: List[Tensor], ref_feats: List[Tensor], + rpn_results_list: InstanceList, + ref_rpn_results_list: InstanceList, data_samples: TrackSampleList, + **kwargs) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + key_feats (list[Tensor]): list of multi-level image features. + ref_feats (list[Tensor]): list of multi-level ref_img features. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals of key img. + ref_rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals of ref img. + data_samples (list[:obj:`TrackDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance`. + + Returns: + dict: A dictionary of loss components. + """ + assert self.with_track + num_imgs = len(data_samples) + batch_gt_instances = [] + ref_batch_gt_instances = [] + batch_gt_instances_ignore = [] + gt_match_indices_list = [] + for track_data_sample in data_samples: + key_data_sample = track_data_sample.get_key_frames()[0] + ref_data_sample = track_data_sample.get_ref_frames()[0] + batch_gt_instances.append(key_data_sample.gt_instances) + ref_batch_gt_instances.append(ref_data_sample.gt_instances) + if 'ignored_instances' in key_data_sample: + batch_gt_instances_ignore.append( + key_data_sample.ignored_instances) + else: + batch_gt_instances_ignore.append(None) + # get gt_match_indices + ins_ids = key_data_sample.gt_instances.instances_ids.tolist() + ref_ins_ids = ref_data_sample.gt_instances.instances_ids.tolist() + match_indices = Tensor([ + ref_ins_ids.index(i) if (i in ref_ins_ids and i > 0) else -1 + for i in ins_ids + ]).to(key_feats[0].device) + gt_match_indices_list.append(match_indices) + + key_sampling_results, ref_sampling_results = [], [] + for i in range(num_imgs): + rpn_results = rpn_results_list[i] + ref_rpn_results = ref_rpn_results_list[i] + # rename ref_rpn_results.bboxes to ref_rpn_results.priors + ref_rpn_results.priors = ref_rpn_results.pop('bboxes') + + assign_result = self.bbox_assigner.assign( + rpn_results, batch_gt_instances[i], + batch_gt_instances_ignore[i]) + sampling_result = self.bbox_sampler.sample( + assign_result, + rpn_results, + batch_gt_instances[i], + feats=[lvl_feat[i][None] for lvl_feat in key_feats]) + key_sampling_results.append(sampling_result) + + ref_assign_result = self.bbox_assigner.assign( + ref_rpn_results, ref_batch_gt_instances[i], + batch_gt_instances_ignore[i]) + ref_sampling_result = self.bbox_sampler.sample( + ref_assign_result, + ref_rpn_results, + ref_batch_gt_instances[i], + feats=[lvl_feat[i][None] for lvl_feat in ref_feats]) + ref_sampling_results.append(ref_sampling_result) + + key_bboxes = [res.pos_bboxes for res in key_sampling_results] + key_roi_feats = self.extract_roi_feats(key_feats, key_bboxes) + ref_bboxes = [res.bboxes for res in ref_sampling_results] + ref_roi_feats = self.extract_roi_feats(ref_feats, ref_bboxes) + + loss_track = self.embed_head.loss(key_roi_feats, ref_roi_feats, + key_sampling_results, + ref_sampling_results, + gt_match_indices_list) + + return loss_track + + def predict(self, feats: List[Tensor], + rescaled_bboxes: List[Tensor]) -> Tensor: + """Perform forward propagation of the tracking head and predict + tracking results on the features of the upstream network. 
+
+        Args:
+            feats (list[Tensor]): Multi level feature maps of `img`.
+            rescaled_bboxes (list[Tensor]): List of rescaled bboxes in the
+                sampling result.
+
+        Returns:
+            Tensor: The extracted track features.
+        """
+        bbox_feats = self.extract_roi_feats(feats, rescaled_bboxes)
+        track_feats = self.embed_head.predict(bbox_feats)
+        return track_feats
diff --git a/mmdet/models/tracking_heads/roi_embed_head.py b/mmdet/models/tracking_heads/roi_embed_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e18b81fbe52e109e7afb3e6d5e8e6624ef48242f
--- /dev/null
+++ b/mmdet/models/tracking_heads/roi_embed_head.py
@@ -0,0 +1,391 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import defaultdict
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule
+from torch import Tensor
+from torch.nn.modules.utils import _pair
+
+from mmdet.models.losses import accuracy
+from mmdet.models.task_modules import SamplingResult
+from mmdet.models.task_modules.tracking import embed_similarity
+from mmdet.registry import MODELS
+
+
+@MODELS.register_module()
+class RoIEmbedHead(BaseModule):
+    """The roi embed head.
+
+    This module is used in multi-object tracking methods, such as MaskTrack
+    R-CNN.
+
+    Args:
+        num_convs (int): The number of convolutional layers to embed roi
+            features. Defaults to 0.
+        num_fcs (int): The number of fully connected layers to embed roi
+            features. Defaults to 0.
+        roi_feat_size (int|tuple(int)): The spatial size of roi features.
+            Defaults to 7.
+        in_channels (int): The input channel of roi features. Defaults to 256.
+        conv_out_channels (int): The output channel of roi features after
+            forwarding convolutional layers. Defaults to 256.
+        with_avg_pool (bool): Whether to use average pooling before passing
+            roi features into fully connected layers. Defaults to False.
+        fc_out_channels (int): The output channel of roi features after
+            forwarding fully connected layers. Defaults to 1024.
+        conv_cfg (dict): Config dict for convolution layer. Defaults to None,
+            which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer. Defaults to None.
+        loss_match (dict): The loss function. Defaults to
+            dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
+        init_cfg (dict): Configuration of initialization. Defaults to None.
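+
+    Example:
+        An illustrative construction (assumes ``mmdet.CrossEntropyLoss`` is
+        available in the registry, as in a normal mmdet environment):
+
+        >>> head = RoIEmbedHead(num_convs=2, num_fcs=1)
+        >>> # roi features of shape (N, 256, 7, 7) are flattened and
+        >>> # embedded to (N, 1024) per image by ``extract_feat``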
+ """ + + def __init__(self, + num_convs: int = 0, + num_fcs: int = 0, + roi_feat_size: int = 7, + in_channels: int = 256, + conv_out_channels: int = 256, + with_avg_pool: bool = False, + fc_out_channels: int = 1024, + conv_cfg: Optional[dict] = None, + norm_cfg: Optional[dict] = None, + loss_match: dict = dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + init_cfg: Optional[dict] = None, + **kwargs): + super(RoIEmbedHead, self).__init__(init_cfg=init_cfg) + self.num_convs = num_convs + self.num_fcs = num_fcs + self.roi_feat_size = _pair(roi_feat_size) + self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1] + self.in_channels = in_channels + self.conv_out_channels = conv_out_channels + self.with_avg_pool = with_avg_pool + self.fc_out_channels = fc_out_channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.loss_match = MODELS.build(loss_match) + self.fp16_enabled = False + + if self.with_avg_pool: + self.avg_pool = nn.AvgPool2d(self.roi_feat_size) + # add convs and fcs + self.convs, self.fcs, self.last_layer_dim = self._add_conv_fc_branch( + self.num_convs, self.num_fcs, self.in_channels) + self.relu = nn.ReLU(inplace=True) + + def _add_conv_fc_branch( + self, num_branch_convs: int, num_branch_fcs: int, + in_channels: int) -> Tuple[nn.ModuleList, nn.ModuleList, int]: + """Add shared or separable branch. + + convs -> avg pool (optional) -> fcs + """ + last_layer_dim = in_channels + # add branch specific conv layers + branch_convs = nn.ModuleList() + if num_branch_convs > 0: + for i in range(num_branch_convs): + conv_in_channels = ( + last_layer_dim if i == 0 else self.conv_out_channels) + branch_convs.append( + ConvModule( + conv_in_channels, + self.conv_out_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + last_layer_dim = self.conv_out_channels + + # add branch specific fc layers + branch_fcs = nn.ModuleList() + if num_branch_fcs > 0: + if not self.with_avg_pool: + last_layer_dim *= self.roi_feat_area + for i in range(num_branch_fcs): + fc_in_channels = ( + last_layer_dim if i == 0 else self.fc_out_channels) + branch_fcs.append( + nn.Linear(fc_in_channels, self.fc_out_channels)) + last_layer_dim = self.fc_out_channels + + return branch_convs, branch_fcs, last_layer_dim + + @property + def custom_activation(self): + return getattr(self.loss_match, 'custom_activation', False) + + def extract_feat(self, x: Tensor, + num_x_per_img: List[int]) -> Tuple[Tensor]: + """Extract feature from the input `x`, and split the output to a list. + + Args: + x (Tensor): of shape [N, C, H, W]. N is the number of proposals. + num_x_per_img (list[int]): The `x` contains proposals of + multi-images. `num_x_per_img` denotes the number of proposals + for each image. + + Returns: + list[Tensor]: Each Tensor denotes the embed features belonging to + an image in a batch. + """ + if self.num_convs > 0: + for conv in self.convs: + x = conv(x) + + if self.num_fcs > 0: + if self.with_avg_pool: + x = self.avg_pool(x) + x = x.flatten(1) + for fc in self.fcs: + x = self.relu(fc(x)) + else: + x = x.flatten(1) + + x_split = torch.split(x, num_x_per_img, dim=0) + return x_split + + def forward( + self, x: Tensor, ref_x: Tensor, num_x_per_img: List[int], + num_x_per_ref_img: List[int] + ) -> Tuple[Tuple[Tensor], Tuple[Tensor]]: + """Computing the similarity scores between `x` and `ref_x`. + + Args: + x (Tensor): of shape [N, C, H, W]. N is the number of key frame + proposals. + ref_x (Tensor): of shape [M, C, H, W]. 
M is the number of reference + frame proposals. + num_x_per_img (list[int]): The `x` contains proposals of + multi-images. `num_x_per_img` denotes the number of proposals + for each key image. + num_x_per_ref_img (list[int]): The `ref_x` contains proposals of + multi-images. `num_x_per_ref_img` denotes the number of + proposals for each reference image. + + Returns: + tuple[tuple[Tensor], tuple[Tensor]]: Each tuple of tensor denotes + the embed features belonging to an image in a batch. + """ + x_split = self.extract_feat(x, num_x_per_img) + ref_x_split = self.extract_feat(ref_x, num_x_per_ref_img) + + return x_split, ref_x_split + + def get_targets(self, sampling_results: List[SamplingResult], + gt_instance_ids: List[Tensor], + ref_gt_instance_ids: List[Tensor]) -> Tuple[List, List]: + """Calculate the ground truth for all samples in a batch according to + the sampling_results. + + Args: + sampling_results (List[obj:SamplingResult]): Assign results of + all images in a batch after sampling. + gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes of + all images in a batch, each tensor has shape (num_gt, ). + ref_gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes + of all reference images in a batch, each tensor has shape + (num_gt, ). + + Returns: + Tuple[list[Tensor]]: Ground truth for proposals in a batch. + Containing the following list of Tensors: + + - track_id_targets (list[Tensor]): The instance ids of + Gt_labels for all proposals in a batch, each tensor in list + has shape (num_proposals,). + - track_id_weights (list[Tensor]): Labels_weights for + all proposals in a batch, each tensor in list has + shape (num_proposals,). + """ + track_id_targets = [] + track_id_weights = [] + + for res, gt_instance_id, ref_gt_instance_id in zip( + sampling_results, gt_instance_ids, ref_gt_instance_ids): + pos_instance_ids = gt_instance_id[res.pos_assigned_gt_inds] + pos_match_id = gt_instance_id.new_zeros(len(pos_instance_ids)) + for i, id in enumerate(pos_instance_ids): + if id in ref_gt_instance_id: + pos_match_id[i] = ref_gt_instance_id.tolist().index(id) + 1 + + track_id_target = gt_instance_id.new_zeros( + len(res.bboxes), dtype=torch.int64) + track_id_target[:len(res.pos_bboxes)] = pos_match_id + track_id_weight = res.bboxes.new_zeros(len(res.bboxes)) + track_id_weight[:len(res.pos_bboxes)] = 1.0 + + track_id_targets.append(track_id_target) + track_id_weights.append(track_id_weight) + + return track_id_targets, track_id_weights + + def loss( + self, + bbox_feats: Tensor, + ref_bbox_feats: Tensor, + num_bbox_per_img: int, + num_bbox_per_ref_img: int, + sampling_results: List[SamplingResult], + gt_instance_ids: List[Tensor], + ref_gt_instance_ids: List[Tensor], + reduction_override: Optional[str] = None, + ) -> dict: + """Calculate the loss in a batch. + + Args: + bbox_feats (Tensor): of shape [N, C, H, W]. N is the number of + bboxes. + ref_bbox_feats (Tensor): of shape [M, C, H, W]. M is the number of + reference bboxes. + num_bbox_per_img (list[int]): The `bbox_feats` contains proposals + of multi-images. `num_bbox_per_img` denotes the number of + proposals for each key image. + num_bbox_per_ref_img (list[int]): The `ref_bbox_feats` contains + proposals of multi-images. `num_bbox_per_ref_img` denotes the + number of proposals for each reference image. + sampling_results (List[obj:SamplingResult]): Assign results of + all images in a batch after sampling. 
+ gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes of + all images in a batch, each tensor has shape (num_gt, ). + ref_gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes + of all reference images in a batch, each tensor has shape + (num_gt, ). + reduction_override (str, optional): The method used to reduce the + loss. Options are "none", "mean" and "sum". + + Returns: + dict[str, Tensor]: a dictionary of loss components. + """ + x_split, ref_x_split = self(bbox_feats, ref_bbox_feats, + num_bbox_per_img, num_bbox_per_ref_img) + + losses = self.loss_by_feat(x_split, ref_x_split, sampling_results, + gt_instance_ids, ref_gt_instance_ids, + reduction_override) + return losses + + def loss_by_feat(self, + x_split: Tuple[Tensor], + ref_x_split: Tuple[Tensor], + sampling_results: List[SamplingResult], + gt_instance_ids: List[Tensor], + ref_gt_instance_ids: List[Tensor], + reduction_override: Optional[str] = None) -> dict: + """Calculate losses. + + Args: + x_split (Tensor): The embed features belonging to key image. + ref_x_split (Tensor): The embed features belonging to ref image. + sampling_results (List[obj:SamplingResult]): Assign results of + all images in a batch after sampling. + gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes of + all images in a batch, each tensor has shape (num_gt, ). + ref_gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes + of all reference images in a batch, each tensor has shape + (num_gt, ). + reduction_override (str, optional): The method used to reduce the + loss. Options are "none", "mean" and "sum". + + Returns: + dict[str, Tensor]: a dictionary of loss components. + """ + track_id_targets, track_id_weights = self.get_targets( + sampling_results, gt_instance_ids, ref_gt_instance_ids) + assert isinstance(track_id_targets, list) + assert isinstance(track_id_weights, list) + assert len(track_id_weights) == len(track_id_targets) + + losses = defaultdict(list) + similarity_logits = [] + for one_x, one_ref_x in zip(x_split, ref_x_split): + similarity_logit = embed_similarity( + one_x, one_ref_x, method='dot_product') + dummy = similarity_logit.new_zeros(one_x.shape[0], 1) + similarity_logit = torch.cat((dummy, similarity_logit), dim=1) + similarity_logits.append(similarity_logit) + assert isinstance(similarity_logits, list) + assert len(similarity_logits) == len(track_id_targets) + + for similarity_logit, track_id_target, track_id_weight in zip( + similarity_logits, track_id_targets, track_id_weights): + avg_factor = max(torch.sum(track_id_target > 0).float().item(), 1.) 
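+            # Index 0 of `track_id_target` is the dummy "no match" column
+            # prepended to `similarity_logit` above, so only targets > 0
+            # (proposals matched to a reference instance) contribute to
+            # `avg_factor`; the max(..., 1.) guard avoids division by zero
+            # when no proposal matches.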
+ if similarity_logit.numel() > 0: + loss_match = self.loss_match( + similarity_logit, + track_id_target, + track_id_weight, + avg_factor=avg_factor, + reduction_override=reduction_override) + if isinstance(loss_match, dict): + for key, value in loss_match.items(): + losses[key].append(value) + else: + losses['loss_match'].append(loss_match) + + valid_index = track_id_weight > 0 + valid_similarity_logit = similarity_logit[valid_index] + valid_track_id_target = track_id_target[valid_index] + if self.custom_activation: + match_accuracy = self.loss_match.get_accuracy( + valid_similarity_logit, valid_track_id_target) + for key, value in match_accuracy.items(): + losses[key].append(value) + else: + losses['match_accuracy'].append( + accuracy(valid_similarity_logit, + valid_track_id_target)) + + for key, value in losses.items(): + losses[key] = sum(losses[key]) / len(similarity_logits) + return losses + + def predict(self, roi_feats: Tensor, + prev_roi_feats: Tensor) -> List[Tensor]: + """Perform forward propagation of the tracking head and predict + tracking results on the features of the upstream network. + + Args: + roi_feats (Tensor): Feature map of current images rois. + prev_roi_feats (Tensor): Feature map of previous images rois. + + Returns: + list[Tensor]: The predicted similarity_logits of each pair of key + image and reference image. + """ + x_split, ref_x_split = self(roi_feats, prev_roi_feats, + [roi_feats.shape[0]], + [prev_roi_feats.shape[0]]) + + similarity_logits = self.predict_by_feat(x_split, ref_x_split) + + return similarity_logits + + def predict_by_feat(self, x_split: Tuple[Tensor], + ref_x_split: Tuple[Tensor]) -> List[Tensor]: + """Get similarity_logits. + + Args: + x_split (Tensor): The embed features belonging to key image. + ref_x_split (Tensor): The embed features belonging to ref image. + + Returns: + list[Tensor]: The predicted similarity_logits of each pair of key + image and reference image. + """ + similarity_logits = [] + for one_x, one_ref_x in zip(x_split, ref_x_split): + similarity_logit = embed_similarity( + one_x, one_ref_x, method='dot_product') + dummy = similarity_logit.new_zeros(one_x.shape[0], 1) + similarity_logit = torch.cat((dummy, similarity_logit), dim=1) + similarity_logits.append(similarity_logit) + return similarity_logits diff --git a/mmdet/models/tracking_heads/roi_track_head.py b/mmdet/models/tracking_heads/roi_track_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c51c810022cc856411e1de83278e38fdc2b670c8 --- /dev/null +++ b/mmdet/models/tracking_heads/roi_track_head.py @@ -0,0 +1,178 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta +from typing import List, Optional, Tuple + +from mmengine.model import BaseModule +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures import TrackSampleList +from mmdet.structures.bbox import bbox2roi +from mmdet.utils import InstanceList + + +@MODELS.register_module() +class RoITrackHead(BaseModule, metaclass=ABCMeta): + """The roi track head. + + This module is used in multi-object tracking methods, such as MaskTrack + R-CNN. + + Args: + roi_extractor (dict): Configuration of roi extractor. Defaults to None. + embed_head (dict): Configuration of embed head. Defaults to None. + train_cfg (dict): Configuration when training. Defaults to None. + test_cfg (dict): Configuration when testing. Defaults to None. + init_cfg (dict): Configuration of initialization. Defaults to None. 
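+
+    Example:
+        A configuration sketch in the style of MaskTrack R-CNN (illustrative
+        only; the concrete values below are assumptions, not a shipped
+        config):
+
+        >>> track_head = dict(
+        ...     type='RoITrackHead',
+        ...     roi_extractor=dict(
+        ...         type='SingleRoIExtractor',
+        ...         roi_layer=dict(
+        ...             type='RoIAlign', output_size=7, sampling_ratio=0),
+        ...         out_channels=256,
+        ...         featmap_strides=[4, 8, 16, 32]),
+        ...     embed_head=dict(
+        ...         type='RoIEmbedHead',
+        ...         num_fcs=2,
+        ...         roi_feat_size=7,
+        ...         in_channels=256,
+        ...         fc_out_channels=1024))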
+ """ + + def __init__(self, + roi_extractor: Optional[dict] = None, + embed_head: Optional[dict] = None, + regress_head: Optional[dict] = None, + train_cfg: Optional[dict] = None, + test_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None, + *args, + **kwargs): + super().__init__(init_cfg=init_cfg) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + if embed_head is not None: + self.init_embed_head(roi_extractor, embed_head) + + if regress_head is not None: + raise NotImplementedError('Regression head is not supported yet.') + + self.init_assigner_sampler() + + def init_embed_head(self, roi_extractor, embed_head) -> None: + """Initialize ``embed_head``""" + self.roi_extractor = MODELS.build(roi_extractor) + self.embed_head = MODELS.build(embed_head) + + def init_assigner_sampler(self) -> None: + """Initialize assigner and sampler.""" + self.bbox_assigner = None + self.bbox_sampler = None + if self.train_cfg: + self.bbox_assigner = TASK_UTILS.build(self.train_cfg.assigner) + self.bbox_sampler = TASK_UTILS.build( + self.train_cfg.sampler, default_args=dict(context=self)) + + @property + def with_track(self) -> bool: + """bool: whether the multi-object tracker has an embed head""" + return hasattr(self, 'embed_head') and self.embed_head is not None + + def extract_roi_feats( + self, feats: List[Tensor], + bboxes: List[Tensor]) -> Tuple[Tuple[Tensor], List[int]]: + """Extract roi features. + + Args: + feats (list[Tensor]): list of multi-level image features. + bboxes (list[Tensor]): list of bboxes in sampling result. + + Returns: + tuple[tuple[Tensor], list[int]]: The extracted roi features and + the number of bboxes in each image. + """ + rois = bbox2roi(bboxes) + bbox_feats = self.roi_extractor(feats[:self.roi_extractor.num_inputs], + rois) + num_bbox_per_img = [len(bbox) for bbox in bboxes] + return bbox_feats, num_bbox_per_img + + def loss(self, key_feats: List[Tensor], ref_feats: List[Tensor], + rpn_results_list: InstanceList, data_samples: TrackSampleList, + **kwargs) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + key_feats (list[Tensor]): list of multi-level image features. + ref_feats (list[Tensor]): list of multi-level ref_img features. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + data_samples (list[:obj:`TrackDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance`. + + Returns: + dict: A dictionary of loss components. 
+ """ + assert self.with_track + batch_gt_instances = [] + ref_batch_gt_instances = [] + batch_gt_instances_ignore = [] + gt_instance_ids = [] + ref_gt_instance_ids = [] + for track_data_sample in data_samples: + key_data_sample = track_data_sample.get_key_frames()[0] + ref_data_sample = track_data_sample.get_ref_frames()[0] + batch_gt_instances.append(key_data_sample.gt_instances) + ref_batch_gt_instances.append(ref_data_sample.gt_instances) + if 'ignored_instances' in key_data_sample: + batch_gt_instances_ignore.append( + key_data_sample.ignored_instances) + else: + batch_gt_instances_ignore.append(None) + + gt_instance_ids.append(key_data_sample.gt_instances.instances_ids) + ref_gt_instance_ids.append( + ref_data_sample.gt_instances.instances_ids) + + losses = dict() + num_imgs = len(data_samples) + if batch_gt_instances_ignore is None: + batch_gt_instances_ignore = [None] * num_imgs + sampling_results = [] + for i in range(num_imgs): + rpn_results = rpn_results_list[i] + + assign_result = self.bbox_assigner.assign( + rpn_results, batch_gt_instances[i], + batch_gt_instances_ignore[i]) + sampling_result = self.bbox_sampler.sample( + assign_result, + rpn_results, + batch_gt_instances[i], + feats=[lvl_feat[i][None] for lvl_feat in key_feats]) + sampling_results.append(sampling_result) + + bboxes = [res.bboxes for res in sampling_results] + bbox_feats, num_bbox_per_img = self.extract_roi_feats( + key_feats, bboxes) + + # batch_size is 1 + ref_gt_bboxes = [ + ref_batch_gt_instance.bboxes + for ref_batch_gt_instance in ref_batch_gt_instances + ] + ref_bbox_feats, num_bbox_per_ref_img = self.extract_roi_feats( + ref_feats, ref_gt_bboxes) + + loss_track = self.embed_head.loss(bbox_feats, ref_bbox_feats, + num_bbox_per_img, + num_bbox_per_ref_img, + sampling_results, gt_instance_ids, + ref_gt_instance_ids) + losses.update(loss_track) + + return losses + + def predict(self, roi_feats: Tensor, + prev_roi_feats: Tensor) -> List[Tensor]: + """Perform forward propagation of the tracking head and predict + tracking results on the features of the upstream network. + + Args: + roi_feats (Tensor): Feature map of current images rois. + prev_roi_feats (Tensor): Feature map of previous images rois. + + Returns: + list[Tensor]: The predicted similarity_logits of each pair of key + image and reference image. + """ + return self.embed_head.predict(roi_feats, prev_roi_feats)[0] diff --git a/mmdet/models/utils/__init__.py b/mmdet/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a00d9a37f33169dc1c523c68db55f823dd0424fa --- /dev/null +++ b/mmdet/models/utils/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .gaussian_target import (gather_feat, gaussian_radius, + gen_gaussian_target, get_local_maximum, + get_topk_from_heatmap, transpose_and_gather_feat) +from .image import imrenormalize +from .make_divisible import make_divisible +# Disable yapf because it conflicts with isort. 
+# yapf: disable +from .misc import (align_tensor, aligned_bilinear, center_of_mass, + empty_instances, filter_gt_instances, + filter_scores_and_topk, flip_tensor, generate_coordinate, + images_to_levels, interpolate_as, levels_to_images, + mask2ndarray, multi_apply, relative_coordinate_maps, + rename_loss_dict, reweight_loss_dict, + samplelist_boxtype2tensor, select_single_mlvl, + sigmoid_geometric_mean, unfold_wo_center, unmap, + unpack_gt_instances) +from .panoptic_gt_processing import preprocess_panoptic_gt +from .point_sample import (get_uncertain_point_coords_with_randomness, + get_uncertainty) +from .vlfuse_helper import BertEncoderLayer, VLFuse, permute_and_flatten +from .wbf import weighted_boxes_fusion + +__all__ = [ + 'gaussian_radius', 'gen_gaussian_target', 'make_divisible', + 'get_local_maximum', 'get_topk_from_heatmap', 'transpose_and_gather_feat', + 'interpolate_as', 'sigmoid_geometric_mean', 'gather_feat', + 'preprocess_panoptic_gt', 'get_uncertain_point_coords_with_randomness', + 'get_uncertainty', 'unpack_gt_instances', 'empty_instances', + 'center_of_mass', 'filter_scores_and_topk', 'flip_tensor', + 'generate_coordinate', 'levels_to_images', 'mask2ndarray', 'multi_apply', + 'select_single_mlvl', 'unmap', 'images_to_levels', + 'samplelist_boxtype2tensor', 'filter_gt_instances', 'rename_loss_dict', + 'reweight_loss_dict', 'relative_coordinate_maps', 'aligned_bilinear', + 'unfold_wo_center', 'imrenormalize', 'VLFuse', 'permute_and_flatten', + 'BertEncoderLayer', 'align_tensor', 'weighted_boxes_fusion' +] diff --git a/mmdet/models/utils/gaussian_target.py b/mmdet/models/utils/gaussian_target.py new file mode 100644 index 0000000000000000000000000000000000000000..5bf4d558ce05c4f953e1c3fcf75016e5874afce1 --- /dev/null +++ b/mmdet/models/utils/gaussian_target.py @@ -0,0 +1,268 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from math import sqrt + +import torch +import torch.nn.functional as F + + +def gaussian2D(radius, sigma=1, dtype=torch.float32, device='cpu'): + """Generate 2D gaussian kernel. + + Args: + radius (int): Radius of gaussian kernel. + sigma (int): Sigma of gaussian function. Default: 1. + dtype (torch.dtype): Dtype of gaussian tensor. Default: torch.float32. + device (str): Device of gaussian tensor. Default: 'cpu'. + + Returns: + h (Tensor): Gaussian kernel with a + ``(2 * radius + 1) * (2 * radius + 1)`` shape. + """ + x = torch.arange( + -radius, radius + 1, dtype=dtype, device=device).view(1, -1) + y = torch.arange( + -radius, radius + 1, dtype=dtype, device=device).view(-1, 1) + + h = (-(x * x + y * y) / (2 * sigma * sigma)).exp() + + h[h < torch.finfo(h.dtype).eps * h.max()] = 0 + return h + + +def gen_gaussian_target(heatmap, center, radius, k=1): + """Generate 2D gaussian heatmap. + + Args: + heatmap (Tensor): Input heatmap, the gaussian kernel will cover on + it and maintain the max value. + center (list[int]): Coord of gaussian kernel's center. + radius (int): Radius of gaussian kernel. + k (int): Coefficient of gaussian kernel. Default: 1. + + Returns: + out_heatmap (Tensor): Updated heatmap covered by gaussian kernel. 
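+
+    Example:
+        A minimal sketch, assuming a 5x5 heatmap; with the default ``k=1``
+        the kernel peak written at ``center`` equals 1:
+
+        >>> heatmap = torch.zeros((5, 5))
+        >>> out = gen_gaussian_target(heatmap, center=[2, 2], radius=1)
+        >>> float(out[2, 2])
+        1.0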
+ """ + diameter = 2 * radius + 1 + gaussian_kernel = gaussian2D( + radius, sigma=diameter / 6, dtype=heatmap.dtype, device=heatmap.device) + + x, y = center + + height, width = heatmap.shape[:2] + + left, right = min(x, radius), min(width - x, radius + 1) + top, bottom = min(y, radius), min(height - y, radius + 1) + + masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] + masked_gaussian = gaussian_kernel[radius - top:radius + bottom, + radius - left:radius + right] + out_heatmap = heatmap + torch.max( + masked_heatmap, + masked_gaussian * k, + out=out_heatmap[y - top:y + bottom, x - left:x + right]) + + return out_heatmap + + +def gaussian_radius(det_size, min_overlap): + r"""Generate 2D gaussian radius. + + This function is modified from the `official github repo + `_. + + Given ``min_overlap``, radius could computed by a quadratic equation + according to Vieta's formulas. + + There are 3 cases for computing gaussian radius, details are following: + + - Explanation of figure: ``lt`` and ``br`` indicates the left-top and + bottom-right corner of ground truth box. ``x`` indicates the + generated corner at the limited position when ``radius=r``. + + - Case1: one corner is inside the gt box and the other is outside. + + .. code:: text + + |< width >| + + lt-+----------+ - + | | | ^ + +--x----------+--+ + | | | | + | | | | height + | | overlap | | + | | | | + | | | | v + +--+---------br--+ - + | | | + +----------+--x + + To ensure IoU of generated box and gt box is larger than ``min_overlap``: + + .. math:: + \cfrac{(w-r)*(h-r)}{w*h+(w+h)r-r^2} \ge {iou} \quad\Rightarrow\quad + {r^2-(w+h)r+\cfrac{1-iou}{1+iou}*w*h} \ge 0 \\ + {a} = 1,\quad{b} = {-(w+h)},\quad{c} = {\cfrac{1-iou}{1+iou}*w*h} + {r} \le \cfrac{-b-\sqrt{b^2-4*a*c}}{2*a} + + - Case2: both two corners are inside the gt box. + + .. code:: text + + |< width >| + + lt-+----------+ - + | | | ^ + +--x-------+ | + | | | | + | |overlap| | height + | | | | + | +-------x--+ + | | | v + +----------+-br - + + To ensure IoU of generated box and gt box is larger than ``min_overlap``: + + .. math:: + \cfrac{(w-2*r)*(h-2*r)}{w*h} \ge {iou} \quad\Rightarrow\quad + {4r^2-2(w+h)r+(1-iou)*w*h} \ge 0 \\ + {a} = 4,\quad {b} = {-2(w+h)},\quad {c} = {(1-iou)*w*h} + {r} \le \cfrac{-b-\sqrt{b^2-4*a*c}}{2*a} + + - Case3: both two corners are outside the gt box. + + .. code:: text + + |< width >| + + x--+----------------+ + | | | + +-lt-------------+ | - + | | | | ^ + | | | | + | | overlap | | height + | | | | + | | | | v + | +------------br--+ - + | | | + +----------------+--x + + To ensure IoU of generated box and gt box is larger than ``min_overlap``: + + .. math:: + \cfrac{w*h}{(w+2*r)*(h+2*r)} \ge {iou} \quad\Rightarrow\quad + {4*iou*r^2+2*iou*(w+h)r+(iou-1)*w*h} \le 0 \\ + {a} = {4*iou},\quad {b} = {2*iou*(w+h)},\quad {c} = {(iou-1)*w*h} \\ + {r} \le \cfrac{-b+\sqrt{b^2-4*a*c}}{2*a} + + Args: + det_size (list[int]): Shape of object. + min_overlap (float): Min IoU with ground truth for boxes generated by + keypoints inside the gaussian kernel. + + Returns: + radius (int): Radius of gaussian kernel. 
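+
+    Example:
+        Illustrative values for a 10x10 box with ``min_overlap=0.7`` (note
+        that the returned radius is in fact a float):
+
+        >>> round(gaussian_radius([10, 10], min_overlap=0.7), 2)
+        0.82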
+ """ + height, width = det_size + + a1 = 1 + b1 = (height + width) + c1 = width * height * (1 - min_overlap) / (1 + min_overlap) + sq1 = sqrt(b1**2 - 4 * a1 * c1) + r1 = (b1 - sq1) / (2 * a1) + + a2 = 4 + b2 = 2 * (height + width) + c2 = (1 - min_overlap) * width * height + sq2 = sqrt(b2**2 - 4 * a2 * c2) + r2 = (b2 - sq2) / (2 * a2) + + a3 = 4 * min_overlap + b3 = -2 * min_overlap * (height + width) + c3 = (min_overlap - 1) * width * height + sq3 = sqrt(b3**2 - 4 * a3 * c3) + r3 = (b3 + sq3) / (2 * a3) + return min(r1, r2, r3) + + +def get_local_maximum(heat, kernel=3): + """Extract local maximum pixel with given kernel. + + Args: + heat (Tensor): Target heatmap. + kernel (int): Kernel size of max pooling. Default: 3. + + Returns: + heat (Tensor): A heatmap where local maximum pixels maintain its + own value and other positions are 0. + """ + pad = (kernel - 1) // 2 + hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad) + keep = (hmax == heat).float() + return heat * keep + + +def get_topk_from_heatmap(scores, k=20): + """Get top k positions from heatmap. + + Args: + scores (Tensor): Target heatmap with shape + [batch, num_classes, height, width]. + k (int): Target number. Default: 20. + + Returns: + tuple[torch.Tensor]: Scores, indexes, categories and coords of + topk keypoint. Containing following Tensors: + + - topk_scores (Tensor): Max scores of each topk keypoint. + - topk_inds (Tensor): Indexes of each topk keypoint. + - topk_clses (Tensor): Categories of each topk keypoint. + - topk_ys (Tensor): Y-coord of each topk keypoint. + - topk_xs (Tensor): X-coord of each topk keypoint. + """ + batch, _, height, width = scores.size() + topk_scores, topk_inds = torch.topk(scores.view(batch, -1), k) + topk_clses = topk_inds // (height * width) + topk_inds = topk_inds % (height * width) + topk_ys = topk_inds // width + topk_xs = (topk_inds % width).int().float() + return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs + + +def gather_feat(feat, ind, mask=None): + """Gather feature according to index. + + Args: + feat (Tensor): Target feature map. + ind (Tensor): Target coord index. + mask (Tensor | None): Mask of feature map. Default: None. + + Returns: + feat (Tensor): Gathered feature. + """ + dim = feat.size(2) + ind = ind.unsqueeze(2).repeat(1, 1, dim) + feat = feat.gather(1, ind) + if mask is not None: + mask = mask.unsqueeze(2).expand_as(feat) + feat = feat[mask] + feat = feat.view(-1, dim) + return feat + + +def transpose_and_gather_feat(feat, ind): + """Transpose and gather feature according to index. + + Args: + feat (Tensor): Target feature map. + ind (Tensor): Target coord index. + + Returns: + feat (Tensor): Transposed and gathered feature. + """ + feat = feat.permute(0, 2, 3, 1).contiguous() + feat = feat.view(feat.size(0), -1, feat.size(3)) + feat = gather_feat(feat, ind) + return feat diff --git a/mmdet/models/utils/image.py b/mmdet/models/utils/image.py new file mode 100644 index 0000000000000000000000000000000000000000..16b5787a78232e46f47585c99526ca2b4ca9d1a1 --- /dev/null +++ b/mmdet/models/utils/image.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +import mmcv +import numpy as np +import torch +from torch import Tensor + + +def imrenormalize(img: Union[Tensor, np.ndarray], img_norm_cfg: dict, + new_img_norm_cfg: dict) -> Union[Tensor, np.ndarray]: + """Re-normalize the image. + + Args: + img (Tensor | ndarray): Input image. If the input is a Tensor, the + shape is (1, C, H, W). 
If the input is a ndarray, the shape + is (H, W, C). + img_norm_cfg (dict): Original configuration for the normalization. + new_img_norm_cfg (dict): New configuration for the normalization. + + Returns: + Tensor | ndarray: Output image with the same type and shape of + the input. + """ + if isinstance(img, torch.Tensor): + assert img.ndim == 4 and img.shape[0] == 1 + new_img = img.squeeze(0).cpu().numpy().transpose(1, 2, 0) + new_img = _imrenormalize(new_img, img_norm_cfg, new_img_norm_cfg) + new_img = new_img.transpose(2, 0, 1)[None] + return torch.from_numpy(new_img).to(img) + else: + return _imrenormalize(img, img_norm_cfg, new_img_norm_cfg) + + +def _imrenormalize(img: Union[Tensor, np.ndarray], img_norm_cfg: dict, + new_img_norm_cfg: dict) -> Union[Tensor, np.ndarray]: + """Re-normalize the image.""" + img_norm_cfg = img_norm_cfg.copy() + new_img_norm_cfg = new_img_norm_cfg.copy() + for k, v in img_norm_cfg.items(): + if (k == 'mean' or k == 'std') and not isinstance(v, np.ndarray): + img_norm_cfg[k] = np.array(v, dtype=img.dtype) + # reverse cfg + if 'bgr_to_rgb' in img_norm_cfg: + img_norm_cfg['rgb_to_bgr'] = img_norm_cfg['bgr_to_rgb'] + img_norm_cfg.pop('bgr_to_rgb') + for k, v in new_img_norm_cfg.items(): + if (k == 'mean' or k == 'std') and not isinstance(v, np.ndarray): + new_img_norm_cfg[k] = np.array(v, dtype=img.dtype) + img = mmcv.imdenormalize(img, **img_norm_cfg) + img = mmcv.imnormalize(img, **new_img_norm_cfg) + return img diff --git a/mmdet/models/utils/make_divisible.py b/mmdet/models/utils/make_divisible.py new file mode 100644 index 0000000000000000000000000000000000000000..ed42c2eeea2a6aed03a0be5516b8d1ef1139e486 --- /dev/null +++ b/mmdet/models/utils/make_divisible.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +def make_divisible(value, divisor, min_value=None, min_ratio=0.9): + """Make divisible function. + + This function rounds the channel number to the nearest value that can be + divisible by the divisor. It is taken from the original tf repo. It ensures + that all layers have a channel number that is divisible by divisor. It can + be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa + + Args: + value (int): The original channel number. + divisor (int): The divisor to fully divide the channel number. + min_value (int): The minimum value of the output channel. + Default: None, means that the minimum value equal to the divisor. + min_ratio (float): The minimum ratio of the rounded channel number to + the original channel number. Default: 0.9. + + Returns: + int: The modified output channel number. + """ + + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than (1-min_ratio). + if new_value < min_ratio * value: + new_value += divisor + return new_value diff --git a/mmdet/models/utils/misc.py b/mmdet/models/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..2cf429153ba7e0be025396b069aef8212144e34d --- /dev/null +++ b/mmdet/models/utils/misc.py @@ -0,0 +1,697 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from functools import partial
+from typing import List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+from mmengine.structures import InstanceData
+from mmengine.utils import digit_version
+from six.moves import map, zip
+from torch import Tensor
+from torch.autograd import Function
+from torch.nn import functional as F
+
+from mmdet.structures import SampleList
+from mmdet.structures.bbox import BaseBoxes, get_box_type, stack_boxes
+from mmdet.structures.mask import BitmapMasks, PolygonMasks
+from mmdet.utils import OptInstanceList
+
+
+class SigmoidGeometricMean(Function):
+    """Forward and backward function of geometric mean of two sigmoid
+    functions.
+
+    This implementation with analytical gradient function substitutes
+    the autograd function of (x.sigmoid() * y.sigmoid()).sqrt(). The
+    original implementation incurs NaN during gradient backpropagation
+    if both x and y are very small values.
+    """
+
+    @staticmethod
+    def forward(ctx, x, y):
+        x_sigmoid = x.sigmoid()
+        y_sigmoid = y.sigmoid()
+        z = (x_sigmoid * y_sigmoid).sqrt()
+        ctx.save_for_backward(x_sigmoid, y_sigmoid, z)
+        return z
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x_sigmoid, y_sigmoid, z = ctx.saved_tensors
+        grad_x = grad_output * z * (1 - x_sigmoid) / 2
+        grad_y = grad_output * z * (1 - y_sigmoid) / 2
+        return grad_x, grad_y
+
+
+sigmoid_geometric_mean = SigmoidGeometricMean.apply
+
+
+def interpolate_as(source, target, mode='bilinear', align_corners=False):
+    """Interpolate the `source` to the shape of the `target`.
+
+    The `source` must be a Tensor, but the `target` can be a Tensor or a
+    np.ndarray with the shape (..., target_h, target_w).
+
+    Args:
+        source (Tensor): A 3D/4D Tensor with the shape (N, H, W) or
+            (N, C, H, W).
+        target (Tensor | np.ndarray): The interpolation target with the shape
+            (..., target_h, target_w).
+        mode (str): Algorithm used for interpolation. The options are the
+            same as those in F.interpolate(). Default: ``'bilinear'``.
+        align_corners (bool): The same as the argument in F.interpolate().
+
+    Returns:
+        Tensor: The interpolated source Tensor.
+    """
+    assert len(target.shape) >= 2
+
+    def _interpolate_as(source, target, mode='bilinear', align_corners=False):
+        """Interpolate the `source` (4D) to the shape of the `target`."""
+        target_h, target_w = target.shape[-2:]
+        source_h, source_w = source.shape[-2:]
+        if target_h != source_h or target_w != source_w:
+            source = F.interpolate(
+                source,
+                size=(target_h, target_w),
+                mode=mode,
+                align_corners=align_corners)
+        return source
+
+    if len(source.shape) == 3:
+        source = source[:, None, :, :]
+        source = _interpolate_as(source, target, mode, align_corners)
+        return source[:, 0, :, :]
+    else:
+        return _interpolate_as(source, target, mode, align_corners)
+
+
+def unpack_gt_instances(batch_data_samples: SampleList) -> tuple:
+    """Unpack ``gt_instances``, ``gt_instances_ignore`` and ``img_metas`` based
+    on ``batch_data_samples``
+
+    Args:
+        batch_data_samples (List[:obj:`DetDataSample`]): The Data
+            Samples. It usually includes information such as
+            `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+
+    Returns:
+        tuple:
+
+            - batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+              gt_instance. It usually includes ``bboxes`` and ``labels``
+              attributes.
+            - batch_gt_instances_ignore (list[:obj:`InstanceData`]):
+              Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+              data that is ignored during training and testing.
+              Defaults to None.
+ - batch_img_metas (list[dict]): Meta information of each image, + e.g., image size, scaling factor, etc. + """ + batch_gt_instances = [] + batch_gt_instances_ignore = [] + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + if 'ignored_instances' in data_sample: + batch_gt_instances_ignore.append(data_sample.ignored_instances) + else: + batch_gt_instances_ignore.append(None) + + return batch_gt_instances, batch_gt_instances_ignore, batch_img_metas + + +def empty_instances(batch_img_metas: List[dict], + device: torch.device, + task_type: str, + instance_results: OptInstanceList = None, + mask_thr_binary: Union[int, float] = 0, + box_type: Union[str, type] = 'hbox', + use_box_type: bool = False, + num_classes: int = 80, + score_per_cls: bool = False) -> List[InstanceData]: + """Handle predicted instances when RoI is empty. + + Note: If ``instance_results`` is not None, it will be modified + in place internally, and then return ``instance_results`` + + Args: + batch_img_metas (list[dict]): List of image information. + device (torch.device): Device of tensor. + task_type (str): Expected returned task type. it currently + supports bbox and mask. + instance_results (list[:obj:`InstanceData`]): List of instance + results. + mask_thr_binary (int, float): mask binarization threshold. + Defaults to 0. + box_type (str or type): The empty box type. Defaults to `hbox`. + use_box_type (bool): Whether to warp boxes with the box type. + Defaults to False. + num_classes (int): num_classes of bbox_head. Defaults to 80. + score_per_cls (bool): Whether to generate classwise score for + the empty instance. ``score_per_cls`` will be True when the model + needs to produce raw results without nms. Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + """ + assert task_type in ('bbox', 'mask'), 'Only support bbox and mask,' \ + f' but got {task_type}' + + if instance_results is not None: + assert len(instance_results) == len(batch_img_metas) + + results_list = [] + for img_id in range(len(batch_img_metas)): + if instance_results is not None: + results = instance_results[img_id] + assert isinstance(results, InstanceData) + else: + results = InstanceData() + + if task_type == 'bbox': + _, box_type = get_box_type(box_type) + bboxes = torch.zeros(0, box_type.box_dim, device=device) + if use_box_type: + bboxes = box_type(bboxes, clone=False) + results.bboxes = bboxes + score_shape = (0, num_classes + 1) if score_per_cls else (0, ) + results.scores = torch.zeros(score_shape, device=device) + results.labels = torch.zeros((0, ), + device=device, + dtype=torch.long) + else: + # TODO: Handle the case where rescale is false + img_h, img_w = batch_img_metas[img_id]['ori_shape'][:2] + # the type of `im_mask` will be torch.bool or torch.uint8, + # where uint8 if for visualization and debugging. + im_mask = torch.zeros( + 0, + img_h, + img_w, + device=device, + dtype=torch.bool if mask_thr_binary >= 0 else torch.uint8) + results.masks = im_mask + results_list.append(results) + return results_list + + +def multi_apply(func, *args, **kwargs): + """Apply function to a list of arguments. + + Note: + This function applies the ``func`` to multiple inputs and + map the multiple outputs of the ``func`` into different + list. Each list contains the same type of outputs corresponding + to different inputs. 
+ + Args: + func (Function): A function that will be applied to a list of + arguments + + Returns: + tuple(list): A tuple containing multiple list, each list contains \ + a kind of returned results by the function + """ + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + return tuple(map(list, zip(*map_results))) + + +def unmap(data, count, inds, fill=0): + """Unmap a subset of item (data) back to the original set of items (of size + count)""" + if data.dim() == 1: + ret = data.new_full((count, ), fill) + ret[inds.type(torch.bool)] = data + else: + new_size = (count, ) + data.size()[1:] + ret = data.new_full(new_size, fill) + ret[inds.type(torch.bool), :] = data + return ret + + +def mask2ndarray(mask): + """Convert Mask to ndarray.. + + Args: + mask (:obj:`BitmapMasks` or :obj:`PolygonMasks` or + torch.Tensor or np.ndarray): The mask to be converted. + + Returns: + np.ndarray: Ndarray mask of shape (n, h, w) that has been converted + """ + if isinstance(mask, (BitmapMasks, PolygonMasks)): + mask = mask.to_ndarray() + elif isinstance(mask, torch.Tensor): + mask = mask.detach().cpu().numpy() + elif not isinstance(mask, np.ndarray): + raise TypeError(f'Unsupported {type(mask)} data type') + return mask + + +def flip_tensor(src_tensor, flip_direction): + """flip tensor base on flip_direction. + + Args: + src_tensor (Tensor): input feature map, shape (B, C, H, W). + flip_direction (str): The flipping direction. Options are + 'horizontal', 'vertical', 'diagonal'. + + Returns: + out_tensor (Tensor): Flipped tensor. + """ + assert src_tensor.ndim == 4 + valid_directions = ['horizontal', 'vertical', 'diagonal'] + assert flip_direction in valid_directions + if flip_direction == 'horizontal': + out_tensor = torch.flip(src_tensor, [3]) + elif flip_direction == 'vertical': + out_tensor = torch.flip(src_tensor, [2]) + else: + out_tensor = torch.flip(src_tensor, [2, 3]) + return out_tensor + + +def select_single_mlvl(mlvl_tensors, batch_id, detach=True): + """Extract a multi-scale single image tensor from a multi-scale batch + tensor based on batch index. + + Note: The default value of detach is True, because the proposal gradient + needs to be detached during the training of the two-stage model. E.g + Cascade Mask R-CNN. + + Args: + mlvl_tensors (list[Tensor]): Batch tensor for all scale levels, + each is a 4D-tensor. + batch_id (int): Batch index. + detach (bool): Whether detach gradient. Default True. + + Returns: + list[Tensor]: Multi-scale single image tensor. + """ + assert isinstance(mlvl_tensors, (list, tuple)) + num_levels = len(mlvl_tensors) + + if detach: + mlvl_tensor_list = [ + mlvl_tensors[i][batch_id].detach() for i in range(num_levels) + ] + else: + mlvl_tensor_list = [ + mlvl_tensors[i][batch_id] for i in range(num_levels) + ] + return mlvl_tensor_list + + +def filter_scores_and_topk(scores, score_thr, topk, results=None): + """Filter results using score threshold and topk candidates. + + Args: + scores (Tensor): The scores, shape (num_bboxes, K). + score_thr (float): The score filter threshold. + topk (int): The number of topk candidates. + results (dict or list or Tensor, Optional): The results to + which the filtering rule is to be applied. The shape + of each item is (num_bboxes, N). + + Returns: + tuple: Filtered results + + - scores (Tensor): The scores after being filtered, \ + shape (num_bboxes_filtered, ). + - labels (Tensor): The class labels, shape \ + (num_bboxes_filtered, ). 
+ - anchor_idxs (Tensor): The anchor indexes, shape \ + (num_bboxes_filtered, ). + - filtered_results (dict or list or Tensor, Optional): \ + The filtered results. The shape of each item is \ + (num_bboxes_filtered, N). + """ + valid_mask = scores > score_thr + scores = scores[valid_mask] + valid_idxs = torch.nonzero(valid_mask) + + num_topk = min(topk, valid_idxs.size(0)) + # torch.sort is actually faster than .topk (at least on GPUs) + scores, idxs = scores.sort(descending=True) + scores = scores[:num_topk] + topk_idxs = valid_idxs[idxs[:num_topk]] + keep_idxs, labels = topk_idxs.unbind(dim=1) + + filtered_results = None + if results is not None: + if isinstance(results, dict): + filtered_results = {k: v[keep_idxs] for k, v in results.items()} + elif isinstance(results, list): + filtered_results = [result[keep_idxs] for result in results] + elif isinstance(results, torch.Tensor): + filtered_results = results[keep_idxs] + else: + raise NotImplementedError(f'Only supports dict or list or Tensor, ' + f'but get {type(results)}.') + return scores, labels, keep_idxs, filtered_results + + +def center_of_mass(mask, esp=1e-6): + """Calculate the centroid coordinates of the mask. + + Args: + mask (Tensor): The mask to be calculated, shape (h, w). + esp (float): Avoid dividing by zero. Default: 1e-6. + + Returns: + tuple[Tensor]: the coordinates of the center point of the mask. + + - center_h (Tensor): the center point of the height. + - center_w (Tensor): the center point of the width. + """ + h, w = mask.shape + grid_h = torch.arange(h, device=mask.device)[:, None] + grid_w = torch.arange(w, device=mask.device) + normalizer = mask.sum().float().clamp(min=esp) + center_h = (mask * grid_h).sum() / normalizer + center_w = (mask * grid_w).sum() / normalizer + return center_h, center_w + + +def generate_coordinate(featmap_sizes, device='cuda'): + """Generate the coordinate. + + Args: + featmap_sizes (tuple): The feature to be calculated, + of shape (N, C, W, H). + device (str): The device where the feature will be put on. + Returns: + coord_feat (Tensor): The coordinate feature, of shape (N, 2, W, H). + """ + + x_range = torch.linspace(-1, 1, featmap_sizes[-1], device=device) + y_range = torch.linspace(-1, 1, featmap_sizes[-2], device=device) + y, x = torch.meshgrid(y_range, x_range) + y = y.expand([featmap_sizes[0], 1, -1, -1]) + x = x.expand([featmap_sizes[0], 1, -1, -1]) + coord_feat = torch.cat([x, y], 1) + + return coord_feat + + +def levels_to_images(mlvl_tensor: List[torch.Tensor]) -> List[torch.Tensor]: + """Concat multi-level feature maps by image. + + [feature_level0, feature_level1...] -> [feature_image0, feature_image1...] + Convert the shape of each element in mlvl_tensor from (N, C, H, W) to + (N, H*W , C), then split the element to N elements with shape (H*W, C), and + concat elements in same image of all level along first dimension. + + Args: + mlvl_tensor (list[Tensor]): list of Tensor which collect from + corresponding level. 
Each element is of shape (N, C, H, W)
+
+    Returns:
+        list[Tensor]: A list that contains N tensors and each tensor is
+        of shape (num_elements, C)
+    """
+    batch_size = mlvl_tensor[0].size(0)
+    batch_list = [[] for _ in range(batch_size)]
+    channels = mlvl_tensor[0].size(1)
+    for t in mlvl_tensor:
+        t = t.permute(0, 2, 3, 1)
+        t = t.view(batch_size, -1, channels).contiguous()
+        for img in range(batch_size):
+            batch_list[img].append(t[img])
+    return [torch.cat(item, 0) for item in batch_list]
+
+
+def images_to_levels(target, num_levels):
+    """Convert targets by image to targets by feature level.
+
+    [target_img0, target_img1] -> [target_level0, target_level1, ...]
+    """
+    target = stack_boxes(target, 0)
+    level_targets = []
+    start = 0
+    for n in num_levels:
+        end = start + n
+        # level_targets.append(target[:, start:end].squeeze(0))
+        level_targets.append(target[:, start:end])
+        start = end
+    return level_targets
+
+
+def samplelist_boxtype2tensor(batch_data_samples: SampleList) -> SampleList:
+    for data_samples in batch_data_samples:
+        if 'gt_instances' in data_samples:
+            bboxes = data_samples.gt_instances.get('bboxes', None)
+            if isinstance(bboxes, BaseBoxes):
+                data_samples.gt_instances.bboxes = bboxes.tensor
+        if 'pred_instances' in data_samples:
+            bboxes = data_samples.pred_instances.get('bboxes', None)
+            if isinstance(bboxes, BaseBoxes):
+                data_samples.pred_instances.bboxes = bboxes.tensor
+        if 'ignored_instances' in data_samples:
+            bboxes = data_samples.ignored_instances.get('bboxes', None)
+            if isinstance(bboxes, BaseBoxes):
+                data_samples.ignored_instances.bboxes = bboxes.tensor
+
+
+_torch_version_div_indexing = (
+    'parrots' not in torch.__version__
+    and digit_version(torch.__version__) >= digit_version('1.8'))
+
+
+def floordiv(dividend, divisor, rounding_mode='trunc'):
+    if _torch_version_div_indexing:
+        return torch.div(dividend, divisor, rounding_mode=rounding_mode)
+    else:
+        return dividend // divisor
+
+
+def _filter_gt_instances_by_score(batch_data_samples: SampleList,
+                                  score_thr: float) -> SampleList:
+    """Filter ground truth (GT) instances by score.
+
+    Args:
+        batch_data_samples (SampleList): The Data
+            Samples. It usually includes information such as
+            `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+        score_thr (float): The score filter threshold.
+
+    Returns:
+        SampleList: The Data Samples filtered by score.
+    """
+    for data_samples in batch_data_samples:
+        assert 'scores' in data_samples.gt_instances, \
+            'there does not exist scores in instances'
+        if data_samples.gt_instances.bboxes.shape[0] > 0:
+            data_samples.gt_instances = data_samples.gt_instances[
+                data_samples.gt_instances.scores > score_thr]
+    return batch_data_samples
+
+
+def _filter_gt_instances_by_size(batch_data_samples: SampleList,
+                                 wh_thr: tuple) -> SampleList:
+    """Filter ground truth (GT) instances by size.
+
+    Args:
+        batch_data_samples (SampleList): The Data
+            Samples. It usually includes information such as
+            `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+        wh_thr (tuple): Minimum width and height of bbox.
+
+    Returns:
+        SampleList: The Data Samples filtered by size.
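+
+    Example:
+        An illustrative call (the 8-pixel threshold is an assumption):
+
+        >>> # drop GT boxes narrower or shorter than 8 pixels
+        >>> batch_data_samples = _filter_gt_instances_by_size(
+        ...     batch_data_samples, wh_thr=(8, 8))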
+ """ + for data_samples in batch_data_samples: + bboxes = data_samples.gt_instances.bboxes + if bboxes.shape[0] > 0: + w = bboxes[:, 2] - bboxes[:, 0] + h = bboxes[:, 3] - bboxes[:, 1] + data_samples.gt_instances = data_samples.gt_instances[ + (w > wh_thr[0]) & (h > wh_thr[1])] + return batch_data_samples + + +def filter_gt_instances(batch_data_samples: SampleList, + score_thr: float = None, + wh_thr: tuple = None): + """Filter ground truth (GT) instances by score and/or size. + + Args: + batch_data_samples (SampleList): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + score_thr (float): The score filter threshold. + wh_thr (tuple): Minimum width and height of bbox. + + Returns: + SampleList: The Data Samples filtered by score and/or size. + """ + + if score_thr is not None: + batch_data_samples = _filter_gt_instances_by_score( + batch_data_samples, score_thr) + if wh_thr is not None: + batch_data_samples = _filter_gt_instances_by_size( + batch_data_samples, wh_thr) + return batch_data_samples + + +def rename_loss_dict(prefix: str, losses: dict) -> dict: + """Rename the key names in loss dict by adding a prefix. + + Args: + prefix (str): The prefix for loss components. + losses (dict): A dictionary of loss components. + + Returns: + dict: A dictionary of loss components with prefix. + """ + return {prefix + k: v for k, v in losses.items()} + + +def reweight_loss_dict(losses: dict, weight: float) -> dict: + """Reweight losses in the dict by weight. + + Args: + losses (dict): A dictionary of loss components. + weight (float): Weight for loss components. + + Returns: + dict: A dictionary of weighted loss components. + """ + for name, loss in losses.items(): + if 'loss' in name: + if isinstance(loss, Sequence): + losses[name] = [item * weight for item in loss] + else: + losses[name] = loss * weight + return losses + + +def relative_coordinate_maps( + locations: Tensor, + centers: Tensor, + strides: Tensor, + size_of_interest: int, + feat_sizes: Tuple[int], +) -> Tensor: + """Generate the relative coordinate maps with feat_stride. + + Args: + locations (Tensor): The prior location of mask feature map. + It has shape (num_priors, 2). + centers (Tensor): The prior points of a object in + all feature pyramid. It has shape (num_pos, 2) + strides (Tensor): The prior strides of a object in + all feature pyramid. It has shape (num_pos, 1) + size_of_interest (int): The size of the region used in rel coord. + feat_sizes (Tuple[int]): The feature size H and W, which has 2 dims. + Returns: + rel_coord_feat (Tensor): The coordinate feature + of shape (num_pos, 2, H, W). 
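+
+    Example:
+        A shape-check sketch (assuming ``strides`` is a flat (num_pos, )
+        tensor, so that the broadcast in the body works out):
+
+        >>> H, W = 4, 4
+        >>> locations = torch.rand(H * W, 2)
+        >>> centers = torch.rand(3, 2)
+        >>> strides = torch.full((3, ), 8.)
+        >>> relative_coordinate_maps(
+        ...     locations, centers, strides, 8, (H, W)).shape
+        torch.Size([3, 2, 4, 4])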
+ """ + + H, W = feat_sizes + rel_coordinates = centers.reshape(-1, 1, 2) - locations.reshape(1, -1, 2) + rel_coordinates = rel_coordinates.permute(0, 2, 1).float() + rel_coordinates = rel_coordinates / ( + strides[:, None, None] * size_of_interest) + return rel_coordinates.reshape(-1, 2, H, W) + + +def aligned_bilinear(tensor: Tensor, factor: int) -> Tensor: + """aligned bilinear, used in original implement in CondInst: + + https://github.com/aim-uofa/AdelaiDet/blob/\ + c0b2092ce72442b0f40972f7c6dda8bb52c46d16/adet/utils/comm.py#L23 + """ + + assert tensor.dim() == 4 + assert factor >= 1 + assert int(factor) == factor + + if factor == 1: + return tensor + + h, w = tensor.size()[2:] + tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode='replicate') + oh = factor * h + 1 + ow = factor * w + 1 + tensor = F.interpolate( + tensor, size=(oh, ow), mode='bilinear', align_corners=True) + tensor = F.pad( + tensor, pad=(factor // 2, 0, factor // 2, 0), mode='replicate') + + return tensor[:, :, :oh - 1, :ow - 1] + + +def unfold_wo_center(x, kernel_size: int, dilation: int) -> Tensor: + """unfold_wo_center, used in original implement in BoxInst: + + https://github.com/aim-uofa/AdelaiDet/blob/\ + 4a3a1f7372c35b48ebf5f6adc59f135a0fa28d60/\ + adet/modeling/condinst/condinst.py#L53 + """ + assert x.dim() == 4 + assert kernel_size % 2 == 1 + + # using SAME padding + padding = (kernel_size + (dilation - 1) * (kernel_size - 1)) // 2 + unfolded_x = F.unfold( + x, kernel_size=kernel_size, padding=padding, dilation=dilation) + unfolded_x = unfolded_x.reshape( + x.size(0), x.size(1), -1, x.size(2), x.size(3)) + # remove the center pixels + size = kernel_size**2 + unfolded_x = torch.cat( + (unfolded_x[:, :, :size // 2], unfolded_x[:, :, size // 2 + 1:]), + dim=2) + + return unfolded_x + + +def padding_to(input_tensor: Tensor, max_len: int = 300) -> Tensor: + """Pad the first dimension of `input_tensor` to `max_len`. + + Args: + input_tensor (Tensor): The tensor to be padded, + max_len (int): Padding target size in the first dimension. + Default: 300 + https://github.com/jshilong/DDQ/blob/ddq_detr/projects/models/utils.py#L19 + Returns: + Tensor: The tensor padded with the first dimension size `max_len`. + """ + if max_len is None: + return input_tensor + num_padding = max_len - len(input_tensor) + if input_tensor.dim() > 1: + padding = input_tensor.new_zeros( + num_padding, *input_tensor.size()[1:], dtype=input_tensor.dtype) + else: + padding = input_tensor.new_zeros(num_padding, dtype=input_tensor.dtype) + output_tensor = torch.cat([input_tensor, padding], dim=0) + return output_tensor + + +def align_tensor(inputs: List[Tensor], + max_len: Optional[int] = None) -> Tensor: + """Pad each input to `max_len`, then stack them. If `max_len` is None, then + it is the max size of the first dimension of each input. + + https://github.com/jshilong/DDQ/blob/ddq_detr/projects/models/\ + utils.py#L12 + + Args: + inputs (list[Tensor]): The tensors to be padded, + Each input should have the same shape except the first dimension. + max_len (int): Padding target size in the first dimension. + Default: None + Returns: + Tensor: Stacked inputs after padding in the first dimension. 
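+
+    Example:
+        A minimal sketch:
+
+        >>> feats = [torch.zeros(2, 256), torch.zeros(5, 256)]
+        >>> align_tensor(feats).shape
+        torch.Size([2, 5, 256])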
+    """
+    if max_len is None:
+        max_len = max([len(item) for item in inputs])
+
+    return torch.stack([padding_to(item, max_len) for item in inputs])
diff --git a/mmdet/models/utils/panoptic_gt_processing.py b/mmdet/models/utils/panoptic_gt_processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3bc95fc04040b4a2a13fa63f2d02f092f725e6
--- /dev/null
+++ b/mmdet/models/utils/panoptic_gt_processing.py
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+from torch import Tensor
+
+
+def preprocess_panoptic_gt(gt_labels: Tensor, gt_masks: Tensor,
+                           gt_semantic_seg: Tensor, num_things: int,
+                           num_stuff: int) -> Tuple[Tensor, Tensor]:
+    """Preprocess the ground truth for an image.
+
+    Args:
+        gt_labels (Tensor): Ground truth labels of each bbox,
+            with shape (num_gts, ).
+        gt_masks (BitmapMasks): Ground truth masks of each instance
+            of an image, shape (num_gts, h, w).
+        gt_semantic_seg (Tensor | None): Ground truth of semantic
+            segmentation with the shape (1, h, w).
+            [0, num_thing_class - 1] means things,
+            [num_thing_class, num_class-1] means stuff,
+            255 means VOID. It's None when training instance segmentation.
+
+    Returns:
+        tuple[Tensor, Tensor]: a tuple containing the following targets.
+
+            - labels (Tensor): Ground truth class indices for an
+              image, with shape (n, ), where n is the sum of the number
+              of stuff types and the number of instances in an image.
+            - masks (Tensor): Ground truth mask for an image, with
+              shape (n, h, w). Contains stuff and things when training
+              panoptic segmentation, and things only when training
+              instance segmentation.
+    """
+    num_classes = num_things + num_stuff
+    things_masks = gt_masks.to_tensor(
+        dtype=torch.bool, device=gt_labels.device)
+
+    if gt_semantic_seg is None:
+        masks = things_masks.long()
+        return gt_labels, masks
+
+    things_labels = gt_labels
+    gt_semantic_seg = gt_semantic_seg.squeeze(0)
+
+    semantic_labels = torch.unique(
+        gt_semantic_seg,
+        sorted=False,
+        return_inverse=False,
+        return_counts=False)
+    stuff_masks_list = []
+    stuff_labels_list = []
+    for label in semantic_labels:
+        if label < num_things or label >= num_classes:
+            continue
+        stuff_mask = gt_semantic_seg == label
+        stuff_masks_list.append(stuff_mask)
+        stuff_labels_list.append(label)
+
+    if len(stuff_masks_list) > 0:
+        stuff_masks = torch.stack(stuff_masks_list, dim=0)
+        stuff_labels = torch.stack(stuff_labels_list, dim=0)
+        labels = torch.cat([things_labels, stuff_labels], dim=0)
+        masks = torch.cat([things_masks, stuff_masks], dim=0)
+    else:
+        labels = things_labels
+        masks = things_masks
+
+    masks = masks.long()
+    return labels, masks
diff --git a/mmdet/models/utils/point_sample.py b/mmdet/models/utils/point_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..1afc957f3da7d1dc030c21d40311c768c6952ea4
--- /dev/null
+++ b/mmdet/models/utils/point_sample.py
@@ -0,0 +1,88 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.ops import point_sample
+from torch import Tensor
+
+
+def get_uncertainty(mask_preds: Tensor, labels: Tensor) -> Tensor:
+    """Estimate uncertainty based on pred logits.
+
+    We estimate uncertainty as L1 distance between 0.0 and the logits
+    prediction in 'mask_preds' for the foreground class in `classes`.
+
+    Args:
+        mask_preds (Tensor): mask prediction logits, shape (num_rois,
+            num_classes, mask_height, mask_width).
+
+        labels (Tensor): Either predicted or ground truth label for
+            each predicted mask, of length num_rois.
+
+    Returns:
+        scores (Tensor): Uncertainty scores with the most uncertain
+            locations having the highest uncertainty score,
+            shape (num_rois, 1, mask_height, mask_width)
+    """
+    if mask_preds.shape[1] == 1:
+        gt_class_logits = mask_preds.clone()
+    else:
+        inds = torch.arange(mask_preds.shape[0], device=mask_preds.device)
+        gt_class_logits = mask_preds[inds, labels].unsqueeze(1)
+    return -torch.abs(gt_class_logits)
+
+
+def get_uncertain_point_coords_with_randomness(
+        mask_preds: Tensor, labels: Tensor, num_points: int,
+        oversample_ratio: float, importance_sample_ratio: float) -> Tensor:
+    """Get ``num_points`` most uncertain points with random points during
+    train.
+
+    Sample points in [0, 1] x [0, 1] coordinate space based on their
+    uncertainty. The uncertainties are calculated for each point using
+    'get_uncertainty()' function that takes point's logit prediction as
+    input.
+
+    Args:
+        mask_preds (Tensor): A tensor of shape (num_rois, num_classes,
+            mask_height, mask_width) for class-specific or class-agnostic
+            prediction.
+        labels (Tensor): The ground truth class for each instance.
+        num_points (int): The number of points to sample.
+        oversample_ratio (float): Oversampling parameter.
+        importance_sample_ratio (float): Ratio of points that are sampled
+            via importance sampling.
+
+    Returns:
+        point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
+            that contains the coordinates of sampled points.
+    """
+    assert oversample_ratio >= 1
+    assert 0 <= importance_sample_ratio <= 1
+    batch_size = mask_preds.shape[0]
+    num_sampled = int(num_points * oversample_ratio)
+    point_coords = torch.rand(
+        batch_size, num_sampled, 2, device=mask_preds.device)
+    point_logits = point_sample(mask_preds, point_coords)
+    # It is crucial to calculate uncertainty based on the sampled
+    # prediction value for the points. Calculating uncertainties of the
+    # coarse predictions first and sampling them for points leads to
+    # incorrect results. To illustrate this: assume uncertainty func(
+    # logits)=-abs(logits), a sampled point between two coarse
+    # predictions with -1 and 1 logits has 0 logits, and therefore 0
+    # uncertainty value. However, if we calculate uncertainties for the
+    # coarse predictions first, both will have -1 uncertainty,
+    # and sampled point will get -1 uncertainty.
+    point_uncertainties = get_uncertainty(point_logits, labels)
+    num_uncertain_points = int(importance_sample_ratio * num_points)
+    num_random_points = num_points - num_uncertain_points
+    idx = torch.topk(
+        point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
+    shift = num_sampled * torch.arange(
+        batch_size, dtype=torch.long, device=mask_preds.device)
+    idx += shift[:, None]
+    point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
+        batch_size, num_uncertain_points, 2)
+    if num_random_points > 0:
+        rand_roi_coords = torch.rand(
+            batch_size, num_random_points, 2, device=mask_preds.device)
+        point_coords = torch.cat((point_coords, rand_roi_coords), dim=1)
+    return point_coords
diff --git a/mmdet/models/utils/vlfuse_helper.py b/mmdet/models/utils/vlfuse_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..76b54de317c1f24d7cb40573954f988fd94fef42
--- /dev/null
+++ b/mmdet/models/utils/vlfuse_helper.py
@@ -0,0 +1,773 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from https://github.com/microsoft/GLIP/blob/main/maskrcnn_benchmark/utils/fuse_helper.py # noqa +# and https://github.com/microsoft/GLIP/blob/main/maskrcnn_benchmark/modeling/rpn/modeling_bert.py # noqa +import math +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from mmcv.cnn.bricks import DropPath +from torch import Tensor + +try: + from transformers import BertConfig, BertPreTrainedModel + from transformers.modeling_utils import apply_chunking_to_forward + from transformers.models.bert.modeling_bert import \ + BertAttention as HFBertAttention + from transformers.models.bert.modeling_bert import \ + BertIntermediate as HFBertIntermediate + from transformers.models.bert.modeling_bert import \ + BertOutput as HFBertOutput +except ImportError: + BertConfig = None + BertPreTrainedModel = object + apply_chunking_to_forward = None + HFBertAttention = object + HFBertIntermediate = object + HFBertOutput = object + +MAX_CLAMP_VALUE = 50000 + + +def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int, + W: int) -> Tensor: + """Permute and then flatten a tensor, + + from size (N, A, C, H, W) to (N, H * W * A, C). + + Args: + layer (Tensor): Tensor of shape (N, C, H, W). + N (int): Batch size. + A (int): Number of attention heads. + C (int): Number of channels. + H (int): Height of feature map. + W (int): Width of feature map. + + Returns: + Tensor: A Tensor of shape (N, H * W * A, C). + """ + layer = layer.view(N, A, C, H, W) + layer = layer.permute(0, 3, 4, 1, 2) + layer = layer.reshape(N, -1, C) + return layer + + +def clamp_values(vector: Tensor) -> Tensor: + """Clamp the values of a vector to the range [-MAX_CLAMP_VALUE, + MAX_CLAMP_VALUE]. + + Args: + vector (Tensor): Tensor of shape (N, C, H, W). + + Returns: + Tensor: A Tensor of shape (N, C, H, W) with clamped values. + """ + vector = torch.clamp(vector, min=-MAX_CLAMP_VALUE, max=MAX_CLAMP_VALUE) + return vector + + +class BiMultiHeadAttention(nn.Module): + """Bidirectional fusion Multi-Head Attention layer. + + Args: + v_dim (int): The dimension of the vision input. + l_dim (int): The dimension of the language input. + embed_dim (int): The embedding dimension for the attention operation. + num_heads (int): The number of attention heads. + dropout (float, optional): The dropout probability. Defaults to 0.1. + """ + + def __init__(self, + v_dim: int, + l_dim: int, + embed_dim: int, + num_heads: int, + dropout: float = 0.1): + super(BiMultiHeadAttention, self).__init__() + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.v_dim = v_dim + self.l_dim = l_dim + + assert ( + self.head_dim * self.num_heads == self.embed_dim + ), 'embed_dim must be divisible by num_heads ' \ + f'(got `embed_dim`: {self.embed_dim} ' \ + f'and `num_heads`: {self.num_heads}).' 
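+        # Standard scaled dot-product attention: multiplying the queries by
+        # head_dim**-0.5 keeps the attention logits at a scale that does not
+        # grow with the head dimension.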
+ self.scale = self.head_dim**(-0.5) + self.dropout = dropout + + self.v_proj = nn.Linear(self.v_dim, self.embed_dim) + self.l_proj = nn.Linear(self.l_dim, self.embed_dim) + self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim) + self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim) + + self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim) + self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim) + + self.stable_softmax_2d = False + self.clamp_min_for_underflow = True + self.clamp_max_for_overflow = True + + self._reset_parameters() + + def _shape(self, tensor: Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def _reset_parameters(self): + nn.init.xavier_uniform_(self.v_proj.weight) + self.v_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.l_proj.weight) + self.l_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.values_v_proj.weight) + self.values_v_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.values_l_proj.weight) + self.values_l_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.out_v_proj.weight) + self.out_v_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.out_l_proj.weight) + self.out_l_proj.bias.data.fill_(0) + + def forward( + self, + vision: Tensor, + lang: Tensor, + attention_mask_v: Optional[Tensor] = None, + attention_mask_l: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + bsz, tgt_len, _ = vision.size() + + query_states = self.v_proj(vision) * self.scale + key_states = self._shape(self.l_proj(lang), -1, bsz) + value_v_states = self._shape(self.values_v_proj(vision), -1, bsz) + value_l_states = self._shape(self.values_l_proj(lang), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, + bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_v_states = value_v_states.view(*proj_shape) + value_l_states = value_l_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f'Attention weights should be of ' + f'size {(bsz * self.num_heads, tgt_len, src_len)}, ' + f'but is {attn_weights.size()}') + + if self.stable_softmax_2d: + attn_weights = attn_weights - attn_weights.max() + + if self.clamp_min_for_underflow: + # Do not increase -50000, data type half has quite limited range + attn_weights = torch.clamp(attn_weights, min=-MAX_CLAMP_VALUE) + if self.clamp_max_for_overflow: + # Do not increase 50000, data type half has quite limited range + attn_weights = torch.clamp(attn_weights, max=MAX_CLAMP_VALUE) + + attn_weights_T = attn_weights.transpose(1, 2) + attn_weights_l = ( + attn_weights_T - + torch.max(attn_weights_T, dim=-1, keepdim=True)[0]) + if self.clamp_min_for_underflow: + # Do not increase -50000, data type half has quite limited range + attn_weights_l = torch.clamp(attn_weights_l, min=-MAX_CLAMP_VALUE) + if self.clamp_max_for_overflow: + # Do not increase 50000, data type half has quite limited range + attn_weights_l = torch.clamp(attn_weights_l, max=MAX_CLAMP_VALUE) + + if attention_mask_v is not None: + attention_mask_v = ( + attention_mask_v[:, None, + None, :].repeat(1, self.num_heads, 1, + 1).flatten(0, 1)) + attn_weights_l.masked_fill_(attention_mask_v, float('-inf')) + + attn_weights_l = attn_weights_l.softmax(dim=-1) + + if attention_mask_l is not None: + assert (attention_mask_l.dim() == 
2) + attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1) + attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len) + attention_mask = attention_mask.masked_fill( + attention_mask == 0, -9e15) + + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError('Attention mask should be of ' + f'size {(bsz, 1, tgt_len, src_len)}') + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, + src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, + src_len) + + attn_weights_v = nn.functional.softmax(attn_weights, dim=-1) + + attn_probs_v = F.dropout( + attn_weights_v, p=self.dropout, training=self.training) + attn_probs_l = F.dropout( + attn_weights_l, p=self.dropout, training=self.training) + + attn_output_v = torch.bmm(attn_probs_v, value_l_states) + attn_output_l = torch.bmm(attn_probs_l, value_v_states) + + if attn_output_v.size() != (bsz * self.num_heads, tgt_len, + self.head_dim): + raise ValueError( + '`attn_output_v` should be of ' + f'size {(bsz, self.num_heads, tgt_len, self.head_dim)}, ' + f'but is {attn_output_v.size()}') + + if attn_output_l.size() != (bsz * self.num_heads, src_len, + self.head_dim): + raise ValueError( + '`attn_output_l` should be of size ' + f'{(bsz, self.num_heads, src_len, self.head_dim)}, ' + f'but is {attn_output_l.size()}') + + attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, + self.head_dim) + attn_output_v = attn_output_v.transpose(1, 2) + attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim) + + attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, + self.head_dim) + attn_output_l = attn_output_l.transpose(1, 2) + attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim) + + attn_output_v = self.out_v_proj(attn_output_v) + attn_output_l = self.out_l_proj(attn_output_l) + + return attn_output_v, attn_output_l + + +class BiAttentionBlock(nn.Module): + """BiAttentionBlock Module: + + First, multi-level visual features are concat; Then the concat visual + feature and lang feature are fused by attention; Finally the newly visual + feature are split into multi levels. + + Args: + v_dim (int): The dimension of the visual features. + l_dim (int): The dimension of the language feature. + embed_dim (int): The embedding dimension for the attention operation. + num_heads (int): The number of attention heads. + dropout (float, optional): The dropout probability. Defaults to 0.1. + drop_path (float, optional): The drop path probability. + Defaults to 0.0. + init_values (float, optional): + The initial value for the scaling parameter. + Defaults to 1e-4. + """ + + def __init__(self, + v_dim: int, + l_dim: int, + embed_dim: int, + num_heads: int, + dropout: float = 0.1, + drop_path: float = .0, + init_values: float = 1e-4): + super().__init__() + + # pre layer norm + self.layer_norm_v = nn.LayerNorm(v_dim) + self.layer_norm_l = nn.LayerNorm(l_dim) + self.attn = BiMultiHeadAttention( + v_dim=v_dim, + l_dim=l_dim, + embed_dim=embed_dim, + num_heads=num_heads, + dropout=dropout) + + # add layer scale for training stability + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity()
+        self.gamma_v = nn.Parameter(
+            init_values * torch.ones(v_dim), requires_grad=True)
+        self.gamma_l = nn.Parameter(
+            init_values * torch.ones(l_dim), requires_grad=True)
+
+    def forward(self,
+                vf0: Tensor,
+                vf1: Tensor,
+                vf2: Tensor,
+                vf3: Tensor,
+                vf4: Tensor,
+                lang_feature: Tensor,
+                attention_mask_l=None):
+        visual_features = [vf0, vf1, vf2, vf3, vf4]
+        size_per_level, visual_features_flatten = [], []
+        for i, feat_per_level in enumerate(visual_features):
+            bs, c, h, w = feat_per_level.shape
+            size_per_level.append([h, w])
+            feat = permute_and_flatten(feat_per_level, bs, -1, c, h, w)
+            visual_features_flatten.append(feat)
+        visual_features_flatten = torch.cat(visual_features_flatten, dim=1)
+        new_v, new_lang_feature = self.single_attention_call(
+            visual_features_flatten,
+            lang_feature,
+            attention_mask_l=attention_mask_l)
+        # [bs, N, C] -> [bs, C, N]
+        new_v = new_v.transpose(1, 2).contiguous()
+
+        start = 0
+        # fvfs is short for fusion_visual_features
+        fvfs = []
+        for (h, w) in size_per_level:
+            new_v_per_level = new_v[:, :,
+                                    start:start + h * w].view(bs, -1, h,
+                                                              w).contiguous()
+            fvfs.append(new_v_per_level)
+            start += h * w
+
+        return fvfs[0], fvfs[1], fvfs[2], fvfs[3], fvfs[4], new_lang_feature
+
+    def single_attention_call(
+        self,
+        visual: Tensor,
+        lang: Tensor,
+        attention_mask_v: Optional[Tensor] = None,
+        attention_mask_l: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """Perform a single attention call between the visual and language
+        inputs.
+
+        Args:
+            visual (Tensor): The visual input tensor.
+            lang (Tensor): The language input tensor.
+            attention_mask_v (Optional[Tensor]):
+                An optional attention mask tensor for the visual input.
+            attention_mask_l (Optional[Tensor]):
+                An optional attention mask tensor for the language input.
+
+        Returns:
+            Tuple[Tensor, Tensor]: A tuple containing the updated
+            visual and language tensors after the attention call.
+        """
+        visual = self.layer_norm_v(visual)
+        lang = self.layer_norm_l(lang)
+        delta_v, delta_l = self.attn(
+            visual,
+            lang,
+            attention_mask_v=attention_mask_v,
+            attention_mask_l=attention_mask_l)
+        # visual, lang = visual + delta_v, lang + delta_l
+        visual = visual + self.drop_path(self.gamma_v * delta_v)
+        lang = lang + self.drop_path(self.gamma_l * delta_l)
+        return visual, lang
+
+
+class SingleScaleBiAttentionBlock(BiAttentionBlock):
+    """This is a single-scale implementation of `BiAttentionBlock`.
+
+    The only difference between it and `BiAttentionBlock` is that the
+    `forward` function of `SingleScaleBiAttentionBlock` only accepts a single
+    flattened visual feature map, while the `forward` function in
+    `BiAttentionBlock` accepts multiple visual feature maps.
+    """
+
+    def forward(self,
+                visual_feature: Tensor,
+                lang_feature: Tensor,
+                attention_mask_v=None,
+                attention_mask_l=None):
+        """Single-scale forward pass.
+
+        Args:
+            visual_feature (Tensor): The visual input tensor. Tensor of
+                shape (bs, patch_len, ch).
+            lang_feature (Tensor): The language input tensor. Tensor of
+                shape (bs, text_len, ch).
+            attention_mask_v (_type_, optional): Visual feature attention
+                mask. Defaults to None.
+            attention_mask_l (_type_, optional): Language feature attention
+                mask. Defaults to None.
+        """
+        new_v, new_lang_feature = self.single_attention_call(
+            visual_feature,
+            lang_feature,
+            attention_mask_v=attention_mask_v,
+            attention_mask_l=attention_mask_l)
+        return new_v, new_lang_feature
+
+
+class VLFuse(nn.Module):
+    """Early Fusion Module.
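+
+    A single ``BiAttentionBlock`` lets the concatenated multi-level visual
+    features and the BERT hidden states update each other in one pass.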
+ + Args: + v_dim (int): Dimension of visual features. + l_dim (int): Dimension of language features. + embed_dim (int): The embedding dimension for the attention operation. + num_heads (int): Number of attention heads. + dropout (float): Dropout probability. + drop_path (float): Drop path probability. + use_checkpoint (bool): Whether to use PyTorch's checkpoint function. + """ + + def __init__(self, + v_dim: int = 256, + l_dim: int = 768, + embed_dim: int = 2048, + num_heads: int = 8, + dropout: float = 0.1, + drop_path: float = 0.0, + use_checkpoint: bool = False): + super().__init__() + self.use_checkpoint = use_checkpoint + self.b_attn = BiAttentionBlock( + v_dim=v_dim, + l_dim=l_dim, + embed_dim=embed_dim, + num_heads=num_heads, + dropout=dropout, + drop_path=drop_path, + init_values=1.0 / 6.0) + + def forward(self, x: dict) -> dict: + """Forward pass of the VLFuse module.""" + visual_features = x['visual'] + language_dict_features = x['lang'] + + if self.use_checkpoint: + # vf is mean visual_features + # checkpoint does not allow complex data structures as input, + # such as list, so we must split them. + vf0, vf1, vf2, vf3, vf4, language_features = checkpoint.checkpoint( + self.b_attn, *visual_features, + language_dict_features['hidden'], + language_dict_features['masks']) + else: + vf0, vf1, vf2, vf3, vf4, language_features = self.b_attn( + *visual_features, language_dict_features['hidden'], + language_dict_features['masks']) + + language_dict_features['hidden'] = language_features + fused_language_dict_features = language_dict_features + + features_dict = { + 'visual': [vf0, vf1, vf2, vf3, vf4], + 'lang': fused_language_dict_features + } + + return features_dict + + +class BertEncoderLayer(BertPreTrainedModel): + """A modified version of the `BertLayer` class from the + `transformers.models.bert.modeling_bert` module. + + Args: + config (:class:`~transformers.BertConfig`): + The configuration object that + contains various parameters for the model. + clamp_min_for_underflow (bool, optional): + Whether to clamp the minimum value of the hidden states + to prevent underflow. Defaults to `False`. + clamp_max_for_overflow (bool, optional): + Whether to clamp the maximum value of the hidden states + to prevent overflow. Defaults to `False`. 
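+
+    Note:
+        Both flags are forwarded to the modified ``BertAttention`` below,
+        which clamps attention scores to ``[-MAX_CLAMP_VALUE,
+        MAX_CLAMP_VALUE]`` so that fp16 activations cannot overflow.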
+ """ + + def __init__(self, + config: BertConfig, + clamp_min_for_underflow: bool = False, + clamp_max_for_overflow: bool = False): + super().__init__(config) + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + + self.attention = BertAttention(config, clamp_min_for_underflow, + clamp_max_for_overflow) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, inputs: Dict[str, Dict[str, torch.Tensor]] + ) -> Dict[str, Dict[str, torch.Tensor]]: + """Applies the BertEncoderLayer to the input features.""" + language_dict_features = inputs['lang'] + hidden_states = language_dict_features['hidden'] + attention_mask = language_dict_features['masks'] + + device = hidden_states.device + input_shape = hidden_states.size()[:-1] + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device) + + self_attention_outputs = self.attention( + hidden_states, + extended_attention_mask, + None, + output_attentions=False, + past_key_value=None) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output) + outputs = (layer_output, ) + outputs + hidden_states = outputs[0] + + language_dict_features['hidden'] = hidden_states + + features_dict = { + 'visual': inputs['visual'], + 'lang': language_dict_features + } + + return features_dict + + def feed_forward_chunk(self, attention_output: Tensor) -> Tensor: + """Applies the intermediate and output layers of the BertEncoderLayer + to a chunk of the input sequence.""" + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# The following code is the same as the Huggingface code, +# with the only difference being the additional clamp operation. +class BertSelfAttention(nn.Module): + """BERT self-attention layer from Huggingface transformers. + + Compared to the BertSelfAttention of Huggingface, only add the clamp. + + Args: + config (:class:`~transformers.BertConfig`): + The configuration object that + contains various parameters for the model. + clamp_min_for_underflow (bool, optional): + Whether to clamp the minimum value of the hidden states + to prevent underflow. Defaults to `False`. + clamp_max_for_overflow (bool, optional): + Whether to clamp the maximum value of the hidden states + to prevent overflow. Defaults to `False`. 
+ """ + + def __init__(self, + config: BertConfig, + clamp_min_for_underflow: bool = False, + clamp_max_for_overflow: bool = False): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and \ + not hasattr(config, 'embedding_size'): + raise ValueError(f'The hidden size ({config.hidden_size}) is ' + 'not a multiple of the number of attention ' + f'heads ({config.num_attention_heads})') + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / + config.num_attention_heads) + self.all_head_size = self.num_attention_heads * \ + self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + if self.position_embedding_type == 'relative_key' or \ + self.position_embedding_type == 'relative_key_query': + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + self.clamp_min_for_underflow = clamp_min_for_underflow + self.clamp_max_for_overflow = clamp_max_for_overflow + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: Tensor) -> Tensor: + """Transpose the dimensions of `x`.""" + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: Tensor, + attention_mask: Optional[Tensor] = None, + head_mask: Optional[Tensor] = None, + encoder_hidden_states: Optional[Tensor] = None, + encoder_attention_mask: Optional[Tensor] = None, + past_key_value: Optional[Tuple[Tensor, Tensor]] = None, + output_attentions: bool = False, + ) -> Tuple[Tensor, ...]: + """Perform a forward pass through the BERT self-attention layer.""" + + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" + # to get the raw attention scores. 
+ attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + + if self.position_embedding_type == 'relative_key' or \ + self.position_embedding_type == 'relative_key_query': + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == 'relative_key': + relative_position_scores = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == 'relative_key_query': + relative_position_scores_query = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + 'bhrd,lrd->bhlr', key_layer, positional_embedding) + attention_scores = attention_scores + \ + relative_position_scores_query + \ + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + + if self.clamp_min_for_underflow: + attention_scores = torch.clamp( + attention_scores, min=-MAX_CLAMP_VALUE + ) # Do not increase -50000, data type half has quite limited range + if self.clamp_max_for_overflow: + attention_scores = torch.clamp( + attention_scores, max=MAX_CLAMP_VALUE + ) # Do not increase 50000, data type half has quite limited range + + if attention_mask is not None: + # Apply the attention mask is + # (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + if self.is_decoder: + outputs = outputs + (past_key_value, ) + return outputs + + +class BertAttention(HFBertAttention): + """BertAttention is made up of self-attention and intermediate+output. + + Compared to the BertAttention of Huggingface, only add the clamp. + + Args: + config (:class:`~transformers.BertConfig`): + The configuration object that + contains various parameters for the model. + clamp_min_for_underflow (bool, optional): + Whether to clamp the minimum value of the hidden states + to prevent underflow. Defaults to `False`. + clamp_max_for_overflow (bool, optional): + Whether to clamp the maximum value of the hidden states + to prevent overflow. Defaults to `False`. 
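+
+    Note:
+        Only ``self.self`` is replaced with the clamped ``BertSelfAttention``
+        defined above; everything else is inherited from the Huggingface
+        implementation.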
+ """ + + def __init__(self, + config: BertConfig, + clamp_min_for_underflow: bool = False, + clamp_max_for_overflow: bool = False): + super().__init__(config) + self.self = BertSelfAttention(config, clamp_min_for_underflow, + clamp_max_for_overflow) + + +class BertIntermediate(HFBertIntermediate): + """Modified from transformers.models.bert.modeling_bert.BertIntermediate. + + Compared to the BertIntermediate of Huggingface, only add the clamp. + """ + + def forward(self, hidden_states: Tensor) -> Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = clamp_values(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = clamp_values(hidden_states) + return hidden_states + + +class BertOutput(HFBertOutput): + """Modified from transformers.models.bert.modeling_bert.BertOutput. + + Compared to the BertOutput of Huggingface, only add the clamp. + """ + + def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = clamp_values(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + hidden_states = clamp_values(hidden_states) + return hidden_states diff --git a/mmdet/models/utils/wbf.py b/mmdet/models/utils/wbf.py new file mode 100644 index 0000000000000000000000000000000000000000..b26a2c669a520467c6fcf52d0eec53a69834a16a --- /dev/null +++ b/mmdet/models/utils/wbf.py @@ -0,0 +1,250 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import warnings +from typing import Tuple + +import numpy as np +import torch +from torch import Tensor + + +# References: https://github.com/ZFTurbo/Weighted-Boxes-Fusion +def weighted_boxes_fusion( + bboxes_list: list, + scores_list: list, + labels_list: list, + weights: list = None, + iou_thr: float = 0.55, + skip_box_thr: float = 0.0, + conf_type: str = 'avg', + allows_overflow: bool = False) -> Tuple[Tensor, Tensor, Tensor]: + """weighted boxes fusion is a method for + fusing predictions from different object detection models, which utilizes + confidence scores of all proposed bounding boxes to construct averaged + boxes. + + Args: + bboxes_list(list): list of boxes predictions from each model, + each box is 4 numbers. + scores_list(list): list of scores for each model + labels_list(list): list of labels for each model + weights: list of weights for each model. + Default: None, which means weight == 1 for each model + iou_thr: IoU value for boxes to be a match + skip_box_thr: exclude boxes with score lower than this variable. + conf_type: how to calculate confidence in weighted boxes. + 'avg': average value, + 'max': maximum value, + 'box_and_model_avg': box and model wise hybrid weighted average, + 'absent_model_aware_avg': weighted average that takes into + account the absent model. + allows_overflow: false if we want confidence score not exceed 1.0. + + Returns: + bboxes(Tensor): boxes coordinates (Order of boxes: x1, y1, x2, y2). + scores(Tensor): confidence scores + labels(Tensor): boxes labels + """ + + if weights is None: + weights = np.ones(len(bboxes_list)) + if len(weights) != len(bboxes_list): + print('Warning: incorrect number of weights {}. Must be: ' + '{}. Set weights equal to 1.'.format( + len(weights), len(bboxes_list))) + weights = np.ones(len(bboxes_list)) + weights = np.array(weights) + + if conf_type not in [ + 'avg', 'max', 'box_and_model_avg', 'absent_model_aware_avg' + ]: + print('Unknown conf_type: {}. 
Must be "avg", ' + '"max" or "box_and_model_avg", ' + 'or "absent_model_aware_avg"'.format(conf_type)) + exit() + + filtered_boxes = prefilter_boxes(bboxes_list, scores_list, labels_list, + weights, skip_box_thr) + if len(filtered_boxes) == 0: + return torch.Tensor(), torch.Tensor(), torch.Tensor() + + overall_boxes = [] + + for label in filtered_boxes: + boxes = filtered_boxes[label] + new_boxes = [] + weighted_boxes = np.empty((0, 8)) + + # Clusterize boxes + for j in range(0, len(boxes)): + index, best_iou = find_matching_box_fast(weighted_boxes, boxes[j], + iou_thr) + + if index != -1: + new_boxes[index].append(boxes[j]) + weighted_boxes[index] = get_weighted_box( + new_boxes[index], conf_type) + else: + new_boxes.append([boxes[j].copy()]) + weighted_boxes = np.vstack((weighted_boxes, boxes[j].copy())) + + # Rescale confidence based on number of models and boxes + for i in range(len(new_boxes)): + clustered_boxes = new_boxes[i] + if conf_type == 'box_and_model_avg': + clustered_boxes = np.array(clustered_boxes) + # weighted average for boxes + weighted_boxes[i, 1] = weighted_boxes[i, 1] * len( + clustered_boxes) / weighted_boxes[i, 2] + # identify unique model index by model index column + _, idx = np.unique(clustered_boxes[:, 3], return_index=True) + # rescale by unique model weights + weighted_boxes[i, 1] = weighted_boxes[i, 1] * clustered_boxes[ + idx, 2].sum() / weights.sum() + elif conf_type == 'absent_model_aware_avg': + clustered_boxes = np.array(clustered_boxes) + # get unique model index in the cluster + models = np.unique(clustered_boxes[:, 3]).astype(int) + # create a mask to get unused model weights + mask = np.ones(len(weights), dtype=bool) + mask[models] = False + # absent model aware weighted average + weighted_boxes[ + i, 1] = weighted_boxes[i, 1] * len(clustered_boxes) / ( + weighted_boxes[i, 2] + weights[mask].sum()) + elif conf_type == 'max': + weighted_boxes[i, 1] = weighted_boxes[i, 1] / weights.max() + elif not allows_overflow: + weighted_boxes[i, 1] = weighted_boxes[i, 1] * min( + len(weights), len(clustered_boxes)) / weights.sum() + else: + weighted_boxes[i, 1] = weighted_boxes[i, 1] * len( + clustered_boxes) / weights.sum() + overall_boxes.append(weighted_boxes) + overall_boxes = np.concatenate(overall_boxes, axis=0) + overall_boxes = overall_boxes[overall_boxes[:, 1].argsort()[::-1]] + + bboxes = torch.Tensor(overall_boxes[:, 4:]) + scores = torch.Tensor(overall_boxes[:, 1]) + labels = torch.Tensor(overall_boxes[:, 0]).int() + + return bboxes, scores, labels + + +def prefilter_boxes(boxes, scores, labels, weights, thr): + + new_boxes = dict() + + for t in range(len(boxes)): + + if len(boxes[t]) != len(scores[t]): + print('Error. Length of boxes arrays not equal to ' + 'length of scores array: {} != {}'.format( + len(boxes[t]), len(scores[t]))) + exit() + + if len(boxes[t]) != len(labels[t]): + print('Error. Length of boxes arrays not equal to ' + 'length of labels array: {} != {}'.format( + len(boxes[t]), len(labels[t]))) + exit() + + for j in range(len(boxes[t])): + score = scores[t][j] + if score < thr: + continue + label = int(labels[t][j]) + box_part = boxes[t][j] + x1 = float(box_part[0]) + y1 = float(box_part[1]) + x2 = float(box_part[2]) + y2 = float(box_part[3]) + + # Box data checks + if x2 < x1: + warnings.warn('X2 < X1 value in box. Swap them.') + x1, x2 = x2, x1 + if y2 < y1: + warnings.warn('Y2 < Y1 value in box. 
Swap them.') + y1, y2 = y2, y1 + if (x2 - x1) * (y2 - y1) == 0.0: + warnings.warn('Zero area box skipped: {}.'.format(box_part)) + continue + + # [label, score, weight, model index, x1, y1, x2, y2] + b = [ + int(label), + float(score) * weights[t], weights[t], t, x1, y1, x2, y2 + ] + + if label not in new_boxes: + new_boxes[label] = [] + new_boxes[label].append(b) + + # Sort each list in dict by score and transform it to numpy array + for k in new_boxes: + current_boxes = np.array(new_boxes[k]) + new_boxes[k] = current_boxes[current_boxes[:, 1].argsort()[::-1]] + + return new_boxes + + +def get_weighted_box(boxes, conf_type='avg'): + + box = np.zeros(8, dtype=np.float32) + conf = 0 + conf_list = [] + w = 0 + for b in boxes: + box[4:] += (b[1] * b[4:]) + conf += b[1] + conf_list.append(b[1]) + w += b[2] + box[0] = boxes[0][0] + if conf_type in ('avg', 'box_and_model_avg', 'absent_model_aware_avg'): + box[1] = conf / len(boxes) + elif conf_type == 'max': + box[1] = np.array(conf_list).max() + box[2] = w + box[3] = -1 + box[4:] /= conf + + return box + + +def find_matching_box_fast(boxes_list, new_box, match_iou): + + def bb_iou_array(boxes, new_box): + # bb intersection over union + xA = np.maximum(boxes[:, 0], new_box[0]) + yA = np.maximum(boxes[:, 1], new_box[1]) + xB = np.minimum(boxes[:, 2], new_box[2]) + yB = np.minimum(boxes[:, 3], new_box[3]) + + interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0) + + # compute the area of both the prediction and ground-truth rectangles + boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1]) + + iou = interArea / (boxAArea + boxBArea - interArea) + + return iou + + if boxes_list.shape[0] == 0: + return -1, match_iou + + boxes = boxes_list + + ious = bb_iou_array(boxes[:, 4:], new_box[4:]) + + ious[boxes[:, 0] != new_box[0]] = -1 + + best_idx = np.argmax(ious) + best_iou = ious[best_idx] + + if best_iou <= match_iou: + best_iou = match_iou + best_idx = -1 + + return best_idx, best_iou diff --git a/mmdet/models/vis/__init__.py b/mmdet/models/vis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ab63a9066bcf6cd25d7c9063cc66d9b0390b3d42 --- /dev/null +++ b/mmdet/models/vis/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .mask2former_vis import Mask2FormerVideo +from .masktrack_rcnn import MaskTrackRCNN + +__all__ = ['Mask2FormerVideo', 'MaskTrackRCNN'] diff --git a/mmdet/models/vis/mask2former_vis.py b/mmdet/models/vis/mask2former_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..6ab04296e120622f4b5e28739f4c3323d253f7d5 --- /dev/null +++ b/mmdet/models/vis/mask2former_vis.py @@ -0,0 +1,120 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Union + +from torch import Tensor + +from mmdet.models.mot import BaseMOTModel +from mmdet.registry import MODELS +from mmdet.structures import TrackDataSample, TrackSampleList +from mmdet.utils import OptConfigType, OptMultiConfig + + +@MODELS.register_module() +class Mask2FormerVideo(BaseMOTModel): + r"""Implementation of `Masked-attention Mask + Transformer for Universal Image Segmentation + `_. + + Args: + backbone (dict): Configuration of backbone. Defaults to None. + track_head (dict): Configuration of track head. Defaults to None. + data_preprocessor (dict or ConfigDict, optional): The pre-process + config of :class:`TrackDataPreprocessor`. 
it usually includes, + ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. + Defaults to None. + init_cfg (dict or list[dict]): Configuration of initialization. + Defaults to None. + """ + + def __init__(self, + backbone: Optional[dict] = None, + track_head: Optional[dict] = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super(BaseMOTModel, self).__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + if backbone is not None: + self.backbone = MODELS.build(backbone) + + if track_head is not None: + self.track_head = MODELS.build(track_head) + + self.num_classes = self.track_head.num_classes + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """Overload in order to load mmdet pretrained ckpt.""" + for key in list(state_dict): + if key.startswith('panoptic_head'): + state_dict[key.replace('panoptic', + 'track')] = state_dict.pop(key) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, + error_msgs) + + def loss(self, inputs: Tensor, data_samples: TrackSampleList, + **kwargs) -> Union[dict, tuple]: + """ + Args: + inputs (Tensor): Input images of shape (N, T, C, H, W). + These should usually be mean centered and std scaled. + data_samples (list[:obj:`TrackDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).' + # shape (N * T, C, H, W) + img = inputs.flatten(0, 1) + + x = self.backbone(img) + losses = self.track_head.loss(x, data_samples) + + return losses + + def predict(self, + inputs: Tensor, + data_samples: TrackSampleList, + rescale: bool = True) -> TrackSampleList: + """Predict results from a batch of inputs and data samples with + postprocessing. + + Args: + inputs (Tensor): of shape (N, T, C, H, W) encoding + input images. The N denotes batch size. + The T denotes the number of frames in a video. + data_samples (list[:obj:`TrackDataSample`]): The batch + data samples. It usually includes information such + as `video_data_samples`. + rescale (bool, Optional): If False, then returned bboxes and masks + will fit the scale of img, otherwise, returned bboxes and masks + will fit the scale of original image shape. Defaults to True. + + Returns: + TrackSampleList: Tracking results of the inputs. + """ + assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).' + + assert len(data_samples) == 1, \ + 'Mask2former only support 1 batch size per gpu for now.' + + # [T, C, H, W] + img = inputs[0] + track_data_sample = data_samples[0] + feats = self.backbone(img) + pred_track_ins_list = self.track_head.predict(feats, track_data_sample, + rescale) + + det_data_samples_list = [] + for idx, pred_track_ins in enumerate(pred_track_ins_list): + img_data_sample = track_data_sample[idx] + img_data_sample.pred_track_instances = pred_track_ins + det_data_samples_list.append(img_data_sample) + + results = TrackDataSample() + results.video_data_samples = det_data_samples_list + return [results] diff --git a/mmdet/models/vis/masktrack_rcnn.py b/mmdet/models/vis/masktrack_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..9c28e7b8529d3d53d5a59ecff0ea46662d035f23 --- /dev/null +++ b/mmdet/models/vis/masktrack_rcnn.py @@ -0,0 +1,181 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
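Both VIS models in this pair of files, `Mask2FormerVideo` above and `MaskTrackRCNN` below, consume video clips as 5D tensors. A minimal shape sketch (illustrative sizes only, assuming nothing beyond `torch`), matching the assertions and reshapes in their `loss`/`predict` methods:

```python
import torch

# N videos, T frames each, RGB frames (sizes are illustrative)
inputs = torch.rand(2, 3, 3, 224, 224)             # (N, T, C, H, W)

# Mask2FormerVideo.loss flattens batch and time before the backbone.
img = inputs.flatten(0, 1)                         # -> (N * T, C, H, W)
assert img.shape == (6, 3, 224, 224)

# MaskTrackRCNN.predict instead walks the clip frame by frame.
for frame_id in range(inputs.size(1)):
    single_img = inputs[:, frame_id].contiguous()  # -> (N, C, H, W)
    assert single_img.shape == (2, 3, 224, 224)
```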
+from typing import Optional + +import torch +from torch import Tensor + +from mmdet.models.mot import BaseMOTModel +from mmdet.registry import MODELS +from mmdet.structures import TrackSampleList +from mmdet.utils import OptConfigType, OptMultiConfig + + +@MODELS.register_module() +class MaskTrackRCNN(BaseMOTModel): + """Video Instance Segmentation. + + This video instance segmentor is the implementation of`MaskTrack R-CNN + `_. + + Args: + detector (dict): Configuration of detector. Defaults to None. + track_head (dict): Configuration of track head. Defaults to None. + tracker (dict): Configuration of tracker. Defaults to None. + data_preprocessor (dict or ConfigDict, optional): The pre-process + config of :class:`TrackDataPreprocessor`. it usually includes, + ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. + init_cfg (dict or list[dict]): Configuration of initialization. + Defaults to None. + """ + + def __init__(self, + detector: Optional[dict] = None, + track_head: Optional[dict] = None, + tracker: Optional[dict] = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super().__init__(data_preprocessor, init_cfg) + + if detector is not None: + self.detector = MODELS.build(detector) + assert hasattr(self.detector, 'roi_head'), \ + 'MaskTrack R-CNN only supports two stage detectors.' + + if track_head is not None: + self.track_head = MODELS.build(track_head) + if tracker is not None: + self.tracker = MODELS.build(tracker) + + def loss(self, inputs: Tensor, data_samples: TrackSampleList, + **kwargs) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (Dict[str, Tensor]): of shape (N, T, C, H, W) encoding + input images. Typically these should be mean centered and std + scaled. The N denotes batch size. The T denotes the number of + frames. + data_samples (list[:obj:`TrackDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance`. + + Returns: + dict: A dictionary of loss components. + """ + + assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).' + assert inputs.size(1) == 2, \ + 'MaskTrackRCNN can only have 1 key frame and 1 reference frame.' + + # split the data_samples into two aspects: key frames and reference + # frames + ref_data_samples, key_data_samples = [], [] + key_frame_inds, ref_frame_inds = [], [] + + # set cat_id of gt_labels to 0 in RPN + for track_data_sample in data_samples: + key_data_sample = track_data_sample.get_key_frames()[0] + key_data_samples.append(key_data_sample) + ref_data_sample = track_data_sample.get_ref_frames()[0] + ref_data_samples.append(ref_data_sample) + key_frame_inds.append(track_data_sample.key_frames_inds[0]) + ref_frame_inds.append(track_data_sample.ref_frames_inds[0]) + + key_frame_inds = torch.tensor(key_frame_inds, dtype=torch.int64) + ref_frame_inds = torch.tensor(ref_frame_inds, dtype=torch.int64) + batch_inds = torch.arange(len(inputs)) + key_imgs = inputs[batch_inds, key_frame_inds].contiguous() + ref_imgs = inputs[batch_inds, ref_frame_inds].contiguous() + + x = self.detector.extract_feat(key_imgs) + ref_x = self.detector.extract_feat(ref_imgs) + + losses = dict() + + # RPN forward and loss + if self.detector.with_rpn: + proposal_cfg = self.detector.train_cfg.get( + 'rpn_proposal', self.detector.test_cfg.rpn) + + rpn_losses, rpn_results_list = self.detector.rpn_head. 
\ + loss_and_predict(x, + key_data_samples, + proposal_cfg=proposal_cfg, + **kwargs) + + # avoid get same name with roi_head loss + keys = rpn_losses.keys() + for key in keys: + if 'loss' in key and 'rpn' not in key: + rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key) + losses.update(rpn_losses) + else: + # TODO: Not support currently, should have a check at Fast R-CNN + assert key_data_samples[0].get('proposals', None) is not None + # use pre-defined proposals in InstanceData for the second stage + # to extract ROI features. + rpn_results_list = [ + key_data_sample.proposals + for key_data_sample in key_data_samples + ] + + losses_detect = self.detector.roi_head.loss(x, rpn_results_list, + key_data_samples, **kwargs) + losses.update(losses_detect) + + losses_track = self.track_head.loss(x, ref_x, rpn_results_list, + data_samples, **kwargs) + losses.update(losses_track) + + return losses + + def predict(self, + inputs: Tensor, + data_samples: TrackSampleList, + rescale: bool = True, + **kwargs) -> TrackSampleList: + """Test without augmentation. + + Args: + inputs (Tensor): of shape (N, T, C, H, W) encoding + input images. The N denotes batch size. + The T denotes the number of frames in a video. + data_samples (list[:obj:`TrackDataSample`]): The batch + data samples. It usually includes information such + as `video_data_samples`. + rescale (bool, Optional): If False, then returned bboxes and masks + will fit the scale of img, otherwise, returned bboxes and masks + will fit the scale of original image shape. Defaults to True. + + Returns: + TrackSampleList: Tracking results of the inputs. + """ + assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).' + + assert len(data_samples) == 1, \ + 'MaskTrackRCNN only support 1 batch size per gpu for now.' + + track_data_sample = data_samples[0] + video_len = len(track_data_sample) + if track_data_sample[0].frame_id == 0: + self.tracker.reset() + + for frame_id in range(video_len): + img_data_sample = track_data_sample[frame_id] + single_img = inputs[:, frame_id].contiguous() + x = self.detector.extract_feat(single_img) + + rpn_results_list = self.detector.rpn_head.predict( + x, [img_data_sample]) + # det_results List[InstanceData] + det_results = self.detector.roi_head.predict( + x, rpn_results_list, [img_data_sample], rescale=rescale) + assert len(det_results) == 1, 'Batch inference is not supported.' + assert 'masks' in det_results[0], 'There are no mask results.' + + img_data_sample.pred_instances = det_results[0] + frame_pred_track_instances = self.tracker.track( + model=self, feats=x, data_sample=img_data_sample, **kwargs) + img_data_sample.pred_track_instances = frame_pred_track_instances + + return [track_data_sample] diff --git a/mmdet/registry.py b/mmdet/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..3a5b2b28a4f80a488994b48a99043a20c604e55e --- /dev/null +++ b/mmdet/registry.py @@ -0,0 +1,121 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""MMDetection provides 17 registry nodes to support using modules across +projects. Each node is a child of the root registry in MMEngine. + +More details can be found at +https://mmengine.readthedocs.io/en/latest/tutorials/registry.html. 
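+
+Each registry below points ``locations`` at the mmdet subpackage that
+implements the corresponding modules, so registered classes are imported
+lazily the first time they are needed.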
+""" + +from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS +from mmengine.registry import DATASETS as MMENGINE_DATASETS +from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR +from mmengine.registry import HOOKS as MMENGINE_HOOKS +from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS +from mmengine.registry import LOOPS as MMENGINE_LOOPS +from mmengine.registry import METRICS as MMENGINE_METRICS +from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS +from mmengine.registry import MODELS as MMENGINE_MODELS +from mmengine.registry import \ + OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS +from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS +from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS +from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS +from mmengine.registry import \ + RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS +from mmengine.registry import RUNNERS as MMENGINE_RUNNERS +from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS +from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS +from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS +from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS +from mmengine.registry import \ + WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS +from mmengine.registry import Registry + +# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner` +RUNNERS = Registry( + 'runner', parent=MMENGINE_RUNNERS, locations=['mmdet.engine.runner']) +# manage runner constructors that define how to initialize runners +RUNNER_CONSTRUCTORS = Registry( + 'runner constructor', + parent=MMENGINE_RUNNER_CONSTRUCTORS, + locations=['mmdet.engine.runner']) +# manage all kinds of loops like `EpochBasedTrainLoop` +LOOPS = Registry( + 'loop', parent=MMENGINE_LOOPS, locations=['mmdet.engine.runner']) +# manage all kinds of hooks like `CheckpointHook` +HOOKS = Registry( + 'hook', parent=MMENGINE_HOOKS, locations=['mmdet.engine.hooks']) + +# manage data-related modules +DATASETS = Registry( + 'dataset', parent=MMENGINE_DATASETS, locations=['mmdet.datasets']) +DATA_SAMPLERS = Registry( + 'data sampler', + parent=MMENGINE_DATA_SAMPLERS, + locations=['mmdet.datasets.samplers']) +TRANSFORMS = Registry( + 'transform', + parent=MMENGINE_TRANSFORMS, + locations=['mmdet.datasets.transforms']) + +# manage all kinds of modules inheriting `nn.Module` +MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['mmdet.models']) +# manage all kinds of model wrappers like 'MMDistributedDataParallel' +MODEL_WRAPPERS = Registry( + 'model_wrapper', + parent=MMENGINE_MODEL_WRAPPERS, + locations=['mmdet.models']) +# manage all kinds of weight initialization modules like `Uniform` +WEIGHT_INITIALIZERS = Registry( + 'weight initializer', + parent=MMENGINE_WEIGHT_INITIALIZERS, + locations=['mmdet.models']) + +# manage all kinds of optimizers like `SGD` and `Adam` +OPTIMIZERS = Registry( + 'optimizer', + parent=MMENGINE_OPTIMIZERS, + locations=['mmdet.engine.optimizers']) +# manage optimizer wrapper +OPTIM_WRAPPERS = Registry( + 'optim_wrapper', + parent=MMENGINE_OPTIM_WRAPPERS, + locations=['mmdet.engine.optimizers']) +# manage constructors that customize the optimization hyperparameters. 
+OPTIM_WRAPPER_CONSTRUCTORS = Registry( + 'optimizer constructor', + parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS, + locations=['mmdet.engine.optimizers']) +# manage all kinds of parameter schedulers like `MultiStepLR` +PARAM_SCHEDULERS = Registry( + 'parameter scheduler', + parent=MMENGINE_PARAM_SCHEDULERS, + locations=['mmdet.engine.schedulers']) +# manage all kinds of metrics +METRICS = Registry( + 'metric', parent=MMENGINE_METRICS, locations=['mmdet.evaluation']) +# manage evaluator +EVALUATOR = Registry( + 'evaluator', parent=MMENGINE_EVALUATOR, locations=['mmdet.evaluation']) + +# manage task-specific modules like anchor generators and box coders +TASK_UTILS = Registry( + 'task util', parent=MMENGINE_TASK_UTILS, locations=['mmdet.models']) + +# manage visualizer +VISUALIZERS = Registry( + 'visualizer', + parent=MMENGINE_VISUALIZERS, + locations=['mmdet.visualization']) +# manage visualizer backend +VISBACKENDS = Registry( + 'vis_backend', + parent=MMENGINE_VISBACKENDS, + locations=['mmdet.visualization']) + +# manage logprocessor +LOG_PROCESSORS = Registry( + 'log_processor', + parent=MMENGINE_LOG_PROCESSORS, + # TODO: update the location when mmdet has its own log processor + locations=['mmdet.engine']) diff --git a/mmdet/structures/.DS_Store b/mmdet/structures/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..39d10d1b4a0fd79e325357426e08735d1fa2af36 Binary files /dev/null and b/mmdet/structures/.DS_Store differ diff --git a/mmdet/structures/__init__.py b/mmdet/structures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..381c6a4f4549c2c4395d994cbd860a3e52eb9994 --- /dev/null +++ b/mmdet/structures/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .det_data_sample import DetDataSample, OptSampleList, SampleList +from .reid_data_sample import ReIDDataSample +from .track_data_sample import (OptTrackSampleList, TrackDataSample, + TrackSampleList) + +__all__ = [ + 'DetDataSample', 'SampleList', 'OptSampleList', 'TrackDataSample', + 'TrackSampleList', 'OptTrackSampleList', 'ReIDDataSample' +] diff --git a/mmdet/structures/__pycache__/__init__.cpython-311.pyc b/mmdet/structures/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa84cce715bbfd71bf176692169f04b4b29c3e14 Binary files /dev/null and b/mmdet/structures/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmdet/structures/__pycache__/det_data_sample.cpython-311.pyc b/mmdet/structures/__pycache__/det_data_sample.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9537591ea0a2b67ad1ce49e913284368bf816456 Binary files /dev/null and b/mmdet/structures/__pycache__/det_data_sample.cpython-311.pyc differ diff --git a/mmdet/structures/__pycache__/reid_data_sample.cpython-311.pyc b/mmdet/structures/__pycache__/reid_data_sample.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..858bcb3fec7e755d07d2750c61b00413e8d1859c Binary files /dev/null and b/mmdet/structures/__pycache__/reid_data_sample.cpython-311.pyc differ diff --git a/mmdet/structures/__pycache__/track_data_sample.cpython-311.pyc b/mmdet/structures/__pycache__/track_data_sample.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f65db0013736a1c381817914999a7936fb1efa6a Binary files /dev/null and b/mmdet/structures/__pycache__/track_data_sample.cpython-311.pyc differ diff --git a/mmdet/structures/bbox/__init__.py 
b/mmdet/structures/bbox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4d531986509ad1b2141118449aab39343bbde82c --- /dev/null +++ b/mmdet/structures/bbox/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base_boxes import BaseBoxes +from .bbox_overlaps import bbox_overlaps +from .box_type import (autocast_box_type, convert_box_type, get_box_type, + register_box, register_box_converter) +from .horizontal_boxes import HorizontalBoxes +from .transforms import bbox_cxcyah_to_xyxy # noqa: E501 +from .transforms import (bbox2corner, bbox2distance, bbox2result, bbox2roi, + bbox_cxcywh_to_xyxy, bbox_flip, bbox_mapping, + bbox_mapping_back, bbox_project, bbox_rescale, + bbox_xyxy_to_cxcyah, bbox_xyxy_to_cxcywh, cat_boxes, + corner2bbox, distance2bbox, empty_box_as, + find_inside_bboxes, get_box_tensor, get_box_wh, + roi2bbox, scale_boxes, stack_boxes) + +__all__ = [ + 'bbox_overlaps', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', + 'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance', + 'bbox_rescale', 'bbox_cxcywh_to_xyxy', 'bbox_xyxy_to_cxcywh', + 'find_inside_bboxes', 'bbox2corner', 'corner2bbox', 'bbox_project', + 'BaseBoxes', 'convert_box_type', 'get_box_type', 'register_box', + 'register_box_converter', 'HorizontalBoxes', 'autocast_box_type', + 'cat_boxes', 'stack_boxes', 'scale_boxes', 'get_box_wh', 'get_box_tensor', + 'empty_box_as', 'bbox_xyxy_to_cxcyah', 'bbox_cxcyah_to_xyxy' +] diff --git a/mmdet/structures/bbox/__pycache__/__init__.cpython-311.pyc b/mmdet/structures/bbox/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58501ea27a8848202664e0a82296dd409985ea51 Binary files /dev/null and b/mmdet/structures/bbox/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmdet/structures/bbox/__pycache__/base_boxes.cpython-311.pyc b/mmdet/structures/bbox/__pycache__/base_boxes.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f92ae4c9a32f7e8f30db5e567f531add8a97733 Binary files /dev/null and b/mmdet/structures/bbox/__pycache__/base_boxes.cpython-311.pyc differ diff --git a/mmdet/structures/bbox/base_boxes.py b/mmdet/structures/bbox/base_boxes.py new file mode 100644 index 0000000000000000000000000000000000000000..0ed667664a8a57a1b9b7e422af03d41274882747 --- /dev/null +++ b/mmdet/structures/bbox/base_boxes.py @@ -0,0 +1,549 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod, abstractproperty, abstractstaticmethod +from typing import List, Optional, Sequence, Tuple, Type, TypeVar, Union + +import numpy as np +import torch +from torch import BoolTensor, Tensor + +from mmdet.structures.mask.structures import BitmapMasks, PolygonMasks + +T = TypeVar('T') +DeviceType = Union[str, torch.device] +IndexType = Union[slice, int, list, torch.LongTensor, torch.cuda.LongTensor, + torch.BoolTensor, torch.cuda.BoolTensor, np.ndarray] +MaskType = Union[BitmapMasks, PolygonMasks] + + +class BaseBoxes(metaclass=ABCMeta): + """The base class for 2D box types. + + The functions of ``BaseBoxes`` lie in three fields: + + - Verify the boxes shape. + - Support tensor-like operations. + - Define abstract functions for 2D boxes. + + In ``__init__`` , ``BaseBoxes`` verifies the validity of the data shape + w.r.t ``box_dim``. The tensor with the dimension >= 2 and the length + of the last dimension being ``box_dim`` will be regarded as valid. 
+ ``BaseBoxes`` will restore them at the field ``tensor``. It's necessary + to override ``box_dim`` in subclass to guarantee the data shape is + correct. + + There are many basic tensor-like functions implemented in ``BaseBoxes``. + In most cases, users can operate ``BaseBoxes`` instance like a normal + tensor. To protect the validity of data shape, All tensor-like functions + cannot modify the last dimension of ``self.tensor``. + + When creating a new box type, users need to inherit from ``BaseBoxes`` + and override abstract methods and specify the ``box_dim``. Then, register + the new box type by using the decorator ``register_box_type``. + + Args: + data (Tensor or np.ndarray or Sequence): The box data with shape + (..., box_dim). + dtype (torch.dtype, Optional): data type of boxes. Defaults to None. + device (str or torch.device, Optional): device of boxes. + Default to None. + clone (bool): Whether clone ``boxes`` or not. Defaults to True. + """ + + # Used to verify the last dimension length + # Should override it in subclass. + box_dim: int = 0 + + def __init__(self, + data: Union[Tensor, np.ndarray, Sequence], + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceType] = None, + clone: bool = True) -> None: + if isinstance(data, (np.ndarray, Tensor, Sequence)): + data = torch.as_tensor(data) + else: + raise TypeError('boxes should be Tensor, ndarray, or Sequence, ', + f'but got {type(data)}') + + if device is not None or dtype is not None: + data = data.to(dtype=dtype, device=device) + # Clone the data to avoid potential bugs + if clone: + data = data.clone() + # handle the empty input like [] + if data.numel() == 0: + data = data.reshape((-1, self.box_dim)) + + assert data.dim() >= 2 and data.size(-1) == self.box_dim, \ + ('The boxes dimension must >= 2 and the length of the last ' + f'dimension must be {self.box_dim}, but got boxes with ' + f'shape {data.shape}.') + self.tensor = data + + def convert_to(self, dst_type: Union[str, type]) -> 'BaseBoxes': + """Convert self to another box type. + + Args: + dst_type (str or type): destination box type. + + Returns: + :obj:`BaseBoxes`: destination box type object . + """ + from .box_type import convert_box_type + return convert_box_type(self, dst_type=dst_type) + + def empty_boxes(self: T, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceType] = None) -> T: + """Create empty box. + + Args: + dtype (torch.dtype, Optional): data type of boxes. + device (str or torch.device, Optional): device of boxes. + + Returns: + T: empty boxes with shape of (0, box_dim). + """ + empty_box = self.tensor.new_zeros( + 0, self.box_dim, dtype=dtype, device=device) + return type(self)(empty_box, clone=False) + + def fake_boxes(self: T, + sizes: Tuple[int], + fill: float = 0, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceType] = None) -> T: + """Create fake boxes with specific sizes and fill values. + + Args: + sizes (Tuple[int]): The size of fake boxes. The last value must + be equal with ``self.box_dim``. + fill (float): filling value. Defaults to 0. + dtype (torch.dtype, Optional): data type of boxes. + device (str or torch.device, Optional): device of boxes. + + Returns: + T: Fake boxes with shape of ``sizes``. 
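+
+        Example:
+            >>> # for a hypothetical subclass instance with ``box_dim`` == 4
+            >>> fakes = boxes.fake_boxes((3, 4), fill=1.0)
+            >>> tuple(fakes.tensor.shape)
+            (3, 4)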
+ """ + fake_boxes = self.tensor.new_full( + sizes, fill, dtype=dtype, device=device) + return type(self)(fake_boxes, clone=False) + + def __getitem__(self: T, index: IndexType) -> T: + """Rewrite getitem to protect the last dimension shape.""" + boxes = self.tensor + if isinstance(index, np.ndarray): + index = torch.as_tensor(index, device=self.device) + if isinstance(index, Tensor) and index.dtype == torch.bool: + assert index.dim() < boxes.dim() + elif isinstance(index, tuple): + assert len(index) < boxes.dim() + # `Ellipsis`(...) is commonly used in index like [None, ...]. + # When `Ellipsis` is in index, it must be the last item. + if Ellipsis in index: + assert index[-1] is Ellipsis + + boxes = boxes[index] + if boxes.dim() == 1: + boxes = boxes.reshape(1, -1) + return type(self)(boxes, clone=False) + + def __setitem__(self: T, index: IndexType, values: Union[Tensor, T]) -> T: + """Rewrite setitem to protect the last dimension shape.""" + assert type(values) is type(self), \ + 'The value to be set must be the same box type as self' + values = values.tensor + + if isinstance(index, np.ndarray): + index = torch.as_tensor(index, device=self.device) + if isinstance(index, Tensor) and index.dtype == torch.bool: + assert index.dim() < self.tensor.dim() + elif isinstance(index, tuple): + assert len(index) < self.tensor.dim() + # `Ellipsis`(...) is commonly used in index like [None, ...]. + # When `Ellipsis` is in index, it must be the last item. + if Ellipsis in index: + assert index[-1] is Ellipsis + + self.tensor[index] = values + + def __len__(self) -> int: + """Return the length of self.tensor first dimension.""" + return self.tensor.size(0) + + def __deepcopy__(self, memo): + """Only clone the ``self.tensor`` when applying deepcopy.""" + cls = self.__class__ + other = cls.__new__(cls) + memo[id(self)] = other + other.tensor = self.tensor.clone() + return other + + def __repr__(self) -> str: + """Return a strings that describes the object.""" + return self.__class__.__name__ + '(\n' + str(self.tensor) + ')' + + def new_tensor(self, *args, **kwargs) -> Tensor: + """Reload ``new_tensor`` from self.tensor.""" + return self.tensor.new_tensor(*args, **kwargs) + + def new_full(self, *args, **kwargs) -> Tensor: + """Reload ``new_full`` from self.tensor.""" + return self.tensor.new_full(*args, **kwargs) + + def new_empty(self, *args, **kwargs) -> Tensor: + """Reload ``new_empty`` from self.tensor.""" + return self.tensor.new_empty(*args, **kwargs) + + def new_ones(self, *args, **kwargs) -> Tensor: + """Reload ``new_ones`` from self.tensor.""" + return self.tensor.new_ones(*args, **kwargs) + + def new_zeros(self, *args, **kwargs) -> Tensor: + """Reload ``new_zeros`` from self.tensor.""" + return self.tensor.new_zeros(*args, **kwargs) + + def size(self, dim: Optional[int] = None) -> Union[int, torch.Size]: + """Reload new_zeros from self.tensor.""" + # self.tensor.size(dim) cannot work when dim=None. 
+ return self.tensor.size() if dim is None else self.tensor.size(dim) + + def dim(self) -> int: + """Reload ``dim`` from self.tensor.""" + return self.tensor.dim() + + @property + def device(self) -> torch.device: + """Reload ``device`` from self.tensor.""" + return self.tensor.device + + @property + def dtype(self) -> torch.dtype: + """Reload ``dtype`` from self.tensor.""" + return self.tensor.dtype + + @property + def shape(self) -> torch.Size: + return self.tensor.shape + + def numel(self) -> int: + """Reload ``numel`` from self.tensor.""" + return self.tensor.numel() + + def numpy(self) -> np.ndarray: + """Reload ``numpy`` from self.tensor.""" + return self.tensor.numpy() + + def to(self: T, *args, **kwargs) -> T: + """Reload ``to`` from self.tensor.""" + return type(self)(self.tensor.to(*args, **kwargs), clone=False) + + def cpu(self: T) -> T: + """Reload ``cpu`` from self.tensor.""" + return type(self)(self.tensor.cpu(), clone=False) + + def cuda(self: T, *args, **kwargs) -> T: + """Reload ``cuda`` from self.tensor.""" + return type(self)(self.tensor.cuda(*args, **kwargs), clone=False) + + def clone(self: T) -> T: + """Reload ``clone`` from self.tensor.""" + return type(self)(self.tensor) + + def detach(self: T) -> T: + """Reload ``detach`` from self.tensor.""" + return type(self)(self.tensor.detach(), clone=False) + + def view(self: T, *shape: Tuple[int]) -> T: + """Reload ``view`` from self.tensor.""" + return type(self)(self.tensor.view(shape), clone=False) + + def reshape(self: T, *shape: Tuple[int]) -> T: + """Reload ``reshape`` from self.tensor.""" + return type(self)(self.tensor.reshape(shape), clone=False) + + def expand(self: T, *sizes: Tuple[int]) -> T: + """Reload ``expand`` from self.tensor.""" + return type(self)(self.tensor.expand(sizes), clone=False) + + def repeat(self: T, *sizes: Tuple[int]) -> T: + """Reload ``repeat`` from self.tensor.""" + return type(self)(self.tensor.repeat(sizes), clone=False) + + def transpose(self: T, dim0: int, dim1: int) -> T: + """Reload ``transpose`` from self.tensor.""" + ndim = self.tensor.dim() + assert dim0 != -1 and dim0 != ndim - 1 + assert dim1 != -1 and dim1 != ndim - 1 + return type(self)(self.tensor.transpose(dim0, dim1), clone=False) + + def permute(self: T, *dims: Tuple[int]) -> T: + """Reload ``permute`` from self.tensor.""" + assert dims[-1] == -1 or dims[-1] == self.tensor.dim() - 1 + return type(self)(self.tensor.permute(dims), clone=False) + + def split(self: T, + split_size_or_sections: Union[int, Sequence[int]], + dim: int = 0) -> List[T]: + """Reload ``split`` from self.tensor.""" + assert dim != -1 and dim != self.tensor.dim() - 1 + boxes_list = self.tensor.split(split_size_or_sections, dim=dim) + return [type(self)(boxes, clone=False) for boxes in boxes_list] + + def chunk(self: T, chunks: int, dim: int = 0) -> List[T]: + """Reload ``chunk`` from self.tensor.""" + assert dim != -1 and dim != self.tensor.dim() - 1 + boxes_list = self.tensor.chunk(chunks, dim=dim) + return [type(self)(boxes, clone=False) for boxes in boxes_list] + + def unbind(self: T, dim: int = 0) -> T: + """Reload ``unbind`` from self.tensor.""" + assert dim != -1 and dim != self.tensor.dim() - 1 + boxes_list = self.tensor.unbind(dim=dim) + return [type(self)(boxes, clone=False) for boxes in boxes_list] + + def flatten(self: T, start_dim: int = 0, end_dim: int = -2) -> T: + """Reload ``flatten`` from self.tensor.""" + assert end_dim != -1 and end_dim != self.tensor.dim() - 1 + return type(self)(self.tensor.flatten(start_dim, end_dim), clone=False) + 
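As a quick illustration of the tensor-like API above, a hedged sketch using `HorizontalBoxes` (the `box_dim == 4` subclass exported from `mmdet.structures.bbox`); every reshaping call must leave the trailing box dimension intact:

```python
import torch
from mmdet.structures.bbox import HorizontalBoxes

boxes = HorizontalBoxes(torch.rand(2, 3, 4))  # (..., box_dim) with box_dim=4

flat = boxes.flatten(0, 1)   # fine: end_dim stays clear of the box dimension
print(flat.shape)            # torch.Size([6, 4])

single = flat[0]             # __getitem__ re-expands 1-D results to 2-D
print(single.shape)          # torch.Size([1, 4])

# boxes.flatten(0, -1) would trip the assert above: the last (box)
# dimension of ``self.tensor`` may never be flattened away.
```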
+    def squeeze(self: T, dim: Optional[int] = None) -> T:
+        """Reload ``squeeze`` from self.tensor."""
+        boxes = self.tensor.squeeze() if dim is None else \
+            self.tensor.squeeze(dim)
+        return type(self)(boxes, clone=False)
+
+    def unsqueeze(self: T, dim: int) -> T:
+        """Reload ``unsqueeze`` from self.tensor."""
+        assert dim != -1 and dim != self.tensor.dim()
+        return type(self)(self.tensor.unsqueeze(dim), clone=False)
+
+    @classmethod
+    def cat(cls: Type[T], box_list: Sequence[T], dim: int = 0) -> T:
+        """Concatenates a box instance list into one single box instance.
+        Similar to ``torch.cat``.
+
+        Args:
+            box_list (Sequence[T]): A sequence of box instances.
+            dim (int): The dimension over which the boxes are concatenated.
+                Defaults to 0.
+
+        Returns:
+            T: Concatenated box instance.
+        """
+        assert isinstance(box_list, Sequence)
+        if len(box_list) == 0:
+            raise ValueError('box_list should not be an empty list.')
+
+        assert dim != -1 and dim != box_list[0].dim() - 1
+        assert all(isinstance(boxes, cls) for boxes in box_list)
+
+        th_box_list = [boxes.tensor for boxes in box_list]
+        return cls(torch.cat(th_box_list, dim=dim), clone=False)
+
+    @classmethod
+    def stack(cls: Type[T], box_list: Sequence[T], dim: int = 0) -> T:
+        """Stacks a sequence of box instances along a new dimension. Similar
+        to ``torch.stack``.
+
+        Args:
+            box_list (Sequence[T]): A sequence of box instances.
+            dim (int): Dimension to insert. Defaults to 0.
+
+        Returns:
+            T: Stacked box instance.
+        """
+        assert isinstance(box_list, Sequence)
+        if len(box_list) == 0:
+            raise ValueError('box_list should not be an empty list.')
+
+        assert dim != -1 and dim != box_list[0].dim()
+        assert all(isinstance(boxes, cls) for boxes in box_list)
+
+        th_box_list = [boxes.tensor for boxes in box_list]
+        return cls(torch.stack(th_box_list, dim=dim), clone=False)
+
+    @abstractproperty
+    def centers(self) -> Tensor:
+        """Return a tensor representing the centers of boxes."""
+        pass
+
+    @abstractproperty
+    def areas(self) -> Tensor:
+        """Return a tensor representing the areas of boxes."""
+        pass
+
+    @abstractproperty
+    def widths(self) -> Tensor:
+        """Return a tensor representing the widths of boxes."""
+        pass
+
+    @abstractproperty
+    def heights(self) -> Tensor:
+        """Return a tensor representing the heights of boxes."""
+        pass
+
+    @abstractmethod
+    def flip_(self,
+              img_shape: Tuple[int, int],
+              direction: str = 'horizontal') -> None:
+        """Flip boxes horizontally or vertically in-place.
+
+        Args:
+            img_shape (Tuple[int, int]): A tuple of image height and width.
+            direction (str): Flip direction, options are "horizontal",
+                "vertical" and "diagonal". Defaults to "horizontal".
+        """
+        pass
+
+    @abstractmethod
+    def translate_(self, distances: Tuple[float, float]) -> None:
+        """Translate boxes in-place.
+
+        Args:
+            distances (Tuple[float, float]): translate distances. The first
+                is horizontal distance and the second is vertical distance.
+        """
+        pass
+
+    @abstractmethod
+    def clip_(self, img_shape: Tuple[int, int]) -> None:
+        """Clip boxes according to the image shape in-place.
+
+        Args:
+            img_shape (Tuple[int, int]): A tuple of image height and width.
+        """
+        pass
+
+    @abstractmethod
+    def rotate_(self, center: Tuple[float, float], angle: float) -> None:
+        """Rotate all boxes in-place.
+
+        Args:
+            center (Tuple[float, float]): Rotation origin.
+            angle (float): Rotation angle represented in degrees. Positive
+                values mean clockwise rotation.
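+
+        Example (a sketch, using the concrete ``HorizontalBoxes`` subclass;
+        rotating an axis-aligned box 90 degrees about its own center keeps
+        the center fixed and swaps width and height):
+
+            >>> boxes = HorizontalBoxes(torch.tensor([[0., 0., 4., 2.]]))
+            >>> boxes.rotate_((2., 1.), 90)
+            >>> # boxes.tensor is now [[1., -1., 3., 3.]]: same center,
+            >>> # width and height swapped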
+        """
+        pass
+
+    @abstractmethod
+    def project_(self, homography_matrix: Union[Tensor, np.ndarray]) -> None:
+        """Geometrically transform boxes in-place.
+
+        Args:
+            homography_matrix (Tensor or np.ndarray):
+                Shape (3, 3) for geometric transformation.
+        """
+        pass
+
+    @abstractmethod
+    def rescale_(self, scale_factor: Tuple[float, float]) -> None:
+        """Rescale boxes w.r.t. scale_factor in-place.
+
+        Note:
+            Both ``rescale_`` and ``resize_`` will enlarge or shrink boxes
+            w.r.t. ``scale_factor``. The difference is that ``resize_`` only
+            changes the width and the height of boxes, but ``rescale_`` also
+            rescales the box centers simultaneously.
+
+        Args:
+            scale_factor (Tuple[float, float]): factors for scaling boxes.
+                The length should be 2.
+        """
+        pass
+
+    @abstractmethod
+    def resize_(self, scale_factor: Tuple[float, float]) -> None:
+        """Resize the box width and height w.r.t. scale_factor in-place.
+
+        Note:
+            Both ``rescale_`` and ``resize_`` will enlarge or shrink boxes
+            w.r.t. ``scale_factor``. The difference is that ``resize_`` only
+            changes the width and the height of boxes, but ``rescale_`` also
+            rescales the box centers simultaneously.
+
+        Args:
+            scale_factor (Tuple[float, float]): factors for scaling box
+                shapes. The length should be 2.
+        """
+        pass
+
+    @abstractmethod
+    def is_inside(self,
+                  img_shape: Tuple[int, int],
+                  all_inside: bool = False,
+                  allowed_border: int = 0) -> BoolTensor:
+        """Find boxes inside the image.
+
+        Args:
+            img_shape (Tuple[int, int]): A tuple of image height and width.
+            all_inside (bool): Whether the boxes are all inside the image or
+                part inside the image. Defaults to False.
+            allowed_border (int): Boxes that extend beyond the image shape
+                boundary by more than ``allowed_border`` are considered
+                "outside". Defaults to 0.
+
+        Returns:
+            BoolTensor: A BoolTensor indicating whether the box is inside
+            the image. Assuming the original boxes have shape (m, n, box_dim),
+            the output has shape (m, n).
+        """
+        pass
+
+    @abstractmethod
+    def find_inside_points(self,
+                           points: Tensor,
+                           is_aligned: bool = False) -> BoolTensor:
+        """Find inside box points. Boxes dimension must be 2.
+
+        Args:
+            points (Tensor): Points coordinates. Has shape of (m, 2).
+            is_aligned (bool): Whether ``points`` has been aligned with boxes
+                or not. If True, the length of boxes and ``points`` should be
+                the same. Defaults to False.
+
+        Returns:
+            BoolTensor: A BoolTensor indicating whether a point is inside
+            boxes. Assuming the boxes have shape of (n, box_dim): if
+            ``is_aligned`` is False, the index has shape of (m, n); if
+            ``is_aligned`` is True, m should be equal to n and the index has
+            shape of (m, ).
+        """
+        pass
+
+    @abstractstaticmethod
+    def overlaps(boxes1: 'BaseBoxes',
+                 boxes2: 'BaseBoxes',
+                 mode: str = 'iou',
+                 is_aligned: bool = False,
+                 eps: float = 1e-6) -> Tensor:
+        """Calculate overlap between two sets of boxes with their types
+        converted to the present box type.
+
+        Args:
+            boxes1 (:obj:`BaseBoxes`): BaseBoxes with shape of (m, box_dim)
+                or empty.
+            boxes2 (:obj:`BaseBoxes`): BaseBoxes with shape of (n, box_dim)
+                or empty.
+            mode (str): "iou" (intersection over union), "iof" (intersection
+                over foreground). Defaults to "iou".
+            is_aligned (bool): If True, then m and n must be equal. Defaults
+                to False.
+            eps (float): A value added to the denominator for numerical
+                stability. Defaults to 1e-6.
+
+        Returns:
+            Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,)
+        """
+        pass
+
+    @abstractstaticmethod
+    def from_instance_masks(masks: MaskType) -> 'BaseBoxes':
+        """Create boxes from instance masks.
+
+        Args:
+            masks (:obj:`BitmapMasks` or :obj:`PolygonMasks`): BitmapMasks or
+                PolygonMasks instance with length of n.
+
+        Returns:
+            :obj:`BaseBoxes`: Converted boxes with shape of (n, box_dim).
+        """
+        pass
diff --git a/mmdet/structures/bbox/bbox_overlaps.py b/mmdet/structures/bbox/bbox_overlaps.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e3435d28b38a5479a6c791f52a76d8ba293a6eb
--- /dev/null
+++ b/mmdet/structures/bbox/bbox_overlaps.py
@@ -0,0 +1,199 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+def fp16_clamp(x, min=None, max=None):
+    if not x.is_cuda and x.dtype == torch.float16:
+        # clamp for cpu float16, tensor fp16 has no clamp implementation
+        return x.float().clamp(min, max).half()
+
+    return x.clamp(min, max)
+
+
+def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
+    """Calculate overlap between two sets of bboxes.
+
+    FP16 contributed by https://github.com/open-mmlab/mmdetection/pull/4889
+    Note:
+        Assume bboxes1 is M x 4 and bboxes2 is N x 4. When mode is 'iou',
+        some new variables are generated when calculating IoU with the
+        bbox_overlaps function:
+
+        1) is_aligned is False
+            area1: M x 1
+            area2: N x 1
+            lt: M x N x 2
+            rb: M x N x 2
+            wh: M x N x 2
+            overlap: M x N x 1
+            union: M x N x 1
+            ious: M x N x 1
+
+            Total memory:
+                S = (9 x N x M + N + M) * 4 Byte
+
+            When using FP16, we can reduce:
+                R = (9 x N x M + N + M) * 4 / 2 Byte
+            R larger than (N + M) * 4 * 2 is always true when N and M >= 1.
+            Obviously, N + M <= N * M < 3 * N * M, when N >= 2 and M >= 2,
+                       N + 1 < 3 * N, when N or M is 1.
+
+            Given M = 40 (ground truth), N = 400000 (three anchor boxes
+            per grid, FPN, R-CNNs),
+                R = 275 MB (one time)
+
+            A special case (dense detection), M = 512 (ground truth),
+                R = 3516 MB = 3.43 GB
+
+            When the batch size is B, reduce:
+                B x R
+
+            Therefore, CUDA memory runs out frequently.
+
+            Experiments on GeForce RTX 2080Ti (11019 MiB):
+
+            | dtype |  M  |   N    |   Use    |   Real   |  Ideal   |
+            |:-----:|:---:|:------:|:--------:|:--------:|:--------:|
+            | FP32  | 512 | 400000 | 8020 MiB |    --    |    --    |
+            | FP16  | 512 | 400000 | 4504 MiB | 3516 MiB | 3516 MiB |
+            | FP32  |  40 | 400000 | 1540 MiB |    --    |    --    |
+            | FP16  |  40 | 400000 | 1264 MiB |  276 MiB |  275 MiB |
+
+        2) is_aligned is True
+            area1: N x 1
+            area2: N x 1
+            lt: N x 2
+            rb: N x 2
+            wh: N x 2
+            overlap: N x 1
+            union: N x 1
+            ious: N x 1
+
+            Total memory:
+                S = 11 x N * 4 Byte
+
+            When using FP16, we can reduce:
+                R = 11 x N * 4 / 2 Byte
+
+        The same holds for 'giou' (larger than 'iou').
+
+        Time-wise, FP16 is generally faster than FP32.
+
+        When gpu_assign_thr is not -1, it takes more time on cpu
+        but does not reduce memory.
+        Therefore, we can reduce half the memory and keep the speed.
+
+    If ``is_aligned`` is ``False``, then calculate the overlaps between each
+    bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned
+    pair of bboxes1 and bboxes2.
+
+    Args:
+        bboxes1 (Tensor): shape (B, m, 4) in <x1, y1, x2, y2> format or empty.
+        bboxes2 (Tensor): shape (B, n, 4) in <x1, y1, x2, y2> format or empty.
+            B indicates the batch dim, in shape (B1, B2, ..., Bn).
+            If ``is_aligned`` is ``True``, then m and n must be equal.
+        mode (str): "iou" (intersection over union), "iof" (intersection over
+            foreground) or "giou" (generalized intersection over union).
+            Default "iou".
+ is_aligned (bool, optional): If True, then m and n must be equal. + Default False. + eps (float, optional): A value added to the denominator for numerical + stability. Default 1e-6. + + Returns: + Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,) + + Example: + >>> bboxes1 = torch.FloatTensor([ + >>> [0, 0, 10, 10], + >>> [10, 10, 20, 20], + >>> [32, 32, 38, 42], + >>> ]) + >>> bboxes2 = torch.FloatTensor([ + >>> [0, 0, 10, 20], + >>> [0, 10, 10, 19], + >>> [10, 10, 20, 20], + >>> ]) + >>> overlaps = bbox_overlaps(bboxes1, bboxes2) + >>> assert overlaps.shape == (3, 3) + >>> overlaps = bbox_overlaps(bboxes1, bboxes2, is_aligned=True) + >>> assert overlaps.shape == (3, ) + + Example: + >>> empty = torch.empty(0, 4) + >>> nonempty = torch.FloatTensor([[0, 0, 10, 9]]) + >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1) + >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0) + >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0) + """ + + assert mode in ['iou', 'iof', 'giou'], f'Unsupported mode {mode}' + # Either the boxes are empty or the length of boxes' last dimension is 4 + assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0) + assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0) + + # Batch dim must be the same + # Batch dim: (B1, B2, ... Bn) + assert bboxes1.shape[:-2] == bboxes2.shape[:-2] + batch_shape = bboxes1.shape[:-2] + + rows = bboxes1.size(-2) + cols = bboxes2.size(-2) + if is_aligned: + assert rows == cols + + if rows * cols == 0: + if is_aligned: + return bboxes1.new(batch_shape + (rows, )) + else: + return bboxes1.new(batch_shape + (rows, cols)) + + area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * ( + bboxes1[..., 3] - bboxes1[..., 1]) + area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * ( + bboxes2[..., 3] - bboxes2[..., 1]) + + if is_aligned: + lt = torch.max(bboxes1[..., :2], bboxes2[..., :2]) # [B, rows, 2] + rb = torch.min(bboxes1[..., 2:], bboxes2[..., 2:]) # [B, rows, 2] + + wh = fp16_clamp(rb - lt, min=0) + overlap = wh[..., 0] * wh[..., 1] + + if mode in ['iou', 'giou']: + union = area1 + area2 - overlap + else: + union = area1 + if mode == 'giou': + enclosed_lt = torch.min(bboxes1[..., :2], bboxes2[..., :2]) + enclosed_rb = torch.max(bboxes1[..., 2:], bboxes2[..., 2:]) + else: + lt = torch.max(bboxes1[..., :, None, :2], + bboxes2[..., None, :, :2]) # [B, rows, cols, 2] + rb = torch.min(bboxes1[..., :, None, 2:], + bboxes2[..., None, :, 2:]) # [B, rows, cols, 2] + + wh = fp16_clamp(rb - lt, min=0) + overlap = wh[..., 0] * wh[..., 1] + + if mode in ['iou', 'giou']: + union = area1[..., None] + area2[..., None, :] - overlap + else: + union = area1[..., None] + if mode == 'giou': + enclosed_lt = torch.min(bboxes1[..., :, None, :2], + bboxes2[..., None, :, :2]) + enclosed_rb = torch.max(bboxes1[..., :, None, 2:], + bboxes2[..., None, :, 2:]) + + eps = union.new_tensor([eps]) + union = torch.max(union, eps) + ious = overlap / union + if mode in ['iou', 'iof']: + return ious + # calculate gious + enclose_wh = fp16_clamp(enclosed_rb - enclosed_lt, min=0) + enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1] + enclose_area = torch.max(enclose_area, eps) + gious = ious - (enclose_area - union) / enclose_area + return gious diff --git a/mmdet/structures/bbox/box_type.py b/mmdet/structures/bbox/box_type.py new file mode 100644 index 0000000000000000000000000000000000000000..c7eb5494c36c8efcbb414897f7c2532a6d3a1ddb --- /dev/null +++ b/mmdet/structures/bbox/box_type.py @@ -0,0 +1,296 @@ +# Copyright (c) OpenMMLab. 
All rights reserved.
+from typing import Callable, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import Tensor
+
+from .base_boxes import BaseBoxes
+
+BoxType = Union[np.ndarray, Tensor, BaseBoxes]
+
+box_types: dict = {}
+_box_type_to_name: dict = {}
+box_converters: dict = {}
+
+
+def _register_box(name: str, box_type: Type, force: bool = False) -> None:
+    """Register a box type.
+
+    Args:
+        name (str): The name of box type.
+        box_type (type): Box mode class to be registered.
+        force (bool): Whether to override an existing class with the same
+            name. Defaults to False.
+    """
+    assert issubclass(box_type, BaseBoxes)
+    name = name.lower()
+
+    if not force and (name in box_types or box_type in _box_type_to_name):
+        raise KeyError(f'box type {name} has been registered')
+    elif name in box_types:
+        _box_type = box_types.pop(name)
+        _box_type_to_name.pop(_box_type)
+    elif box_type in _box_type_to_name:
+        _name = _box_type_to_name.pop(box_type)
+        box_types.pop(_name)
+
+    box_types[name] = box_type
+    _box_type_to_name[box_type] = name
+
+
+def register_box(name: str,
+                 box_type: Type = None,
+                 force: bool = False) -> Union[Type, Callable]:
+    """Register a box type.
+
+    A record will be added to ``box_types``, whose key is the box type name
+    and value is the box type itself. Simultaneously, a reverse dictionary
+    ``_box_type_to_name`` will be updated. It can be used as a decorator or
+    a normal function.
+
+    Args:
+        name (str): The name of box type.
+        box_type (type, Optional): Box type class to be registered.
+            Defaults to None.
+        force (bool): Whether to override the existing box type with the same
+            name. Defaults to False.
+
+    Examples:
+        >>> from mmdet.structures.bbox import register_box
+        >>> from mmdet.structures.bbox import BaseBoxes
+
+        >>> # as a decorator
+        >>> @register_box('hbox')
+        >>> class HorizontalBoxes(BaseBoxes):
+        >>>     pass
+
+        >>> # as a normal function
+        >>> class RotatedBoxes(BaseBoxes):
+        >>>     pass
+        >>> register_box('rbox', RotatedBoxes)
+    """
+    if not isinstance(force, bool):
+        raise TypeError(f'force must be a boolean, but got {type(force)}')
+
+    # use it as a normal method: register_box(name, box_type=BoxCls)
+    if box_type is not None:
+        _register_box(name=name, box_type=box_type, force=force)
+        return box_type
+
+    # use it as a decorator: @register_box(name)
+    def _register(cls):
+        _register_box(name=name, box_type=cls, force=force)
+        return cls
+
+    return _register
+
+
+def _register_box_converter(src_type: Union[str, type],
+                            dst_type: Union[str, type],
+                            converter: Callable,
+                            force: bool = False) -> None:
+    """Register a box converter.
+
+    Args:
+        src_type (str or type): source box type name or class.
+        dst_type (str or type): destination box type name or class.
+        converter (Callable): Convert function.
+        force (bool): Whether to override the existing box type with the same
+            name. Defaults to False.
+    """
+    assert callable(converter)
+    src_type_name, _ = get_box_type(src_type)
+    dst_type_name, _ = get_box_type(dst_type)
+
+    converter_name = src_type_name + '2' + dst_type_name
+    if not force and converter_name in box_converters:
+        raise KeyError(f'The box converter from {src_type_name} to '
+                       f'{dst_type_name} has been registered.')
+
+    box_converters[converter_name] = converter
+
+
+def register_box_converter(src_type: Union[str, type],
+                           dst_type: Union[str, type],
+                           converter: Optional[Callable] = None,
+                           force: bool = False) -> Callable:
+    """Register a box converter.
+
+    A record will be added to ``box_converters``, whose key is
+    '{src_type_name}2{dst_type_name}' and value is the convert function.
+    It can be used as a decorator or a normal function.
+
+    Args:
+        src_type (str or type): source box type name or class.
+        dst_type (str or type): destination box type name or class.
+        converter (Callable): Convert function. Defaults to None.
+        force (bool): Whether to override the existing box type with the same
+            name. Defaults to False.
+
+    Examples:
+        >>> from mmdet.structures.bbox import register_box_converter
+        >>> # as a decorator
+        >>> @register_box_converter('hbox', 'rbox')
+        >>> def converter_A(boxes):
+        >>>     pass
+
+        >>> # as a normal function
+        >>> def converter_B(boxes):
+        >>>     pass
+        >>> register_box_converter('rbox', 'hbox', converter_B)
+    """
+    if not isinstance(force, bool):
+        raise TypeError(f'force must be a boolean, but got {type(force)}')
+
+    # use it as a normal method:
+    # register_box_converter(src_type, dst_type, converter=Func)
+    if converter is not None:
+        _register_box_converter(
+            src_type=src_type,
+            dst_type=dst_type,
+            converter=converter,
+            force=force)
+        return converter
+
+    # use it as a decorator: @register_box_converter(name)
+    def _register(func):
+        _register_box_converter(
+            src_type=src_type, dst_type=dst_type, converter=func, force=force)
+        return func
+
+    return _register
+
+
+def get_box_type(box_type: Union[str, type]) -> Tuple[str, type]:
+    """Get both box type name and class.
+
+    Args:
+        box_type (str or type): Single box type name or class.
+
+    Returns:
+        Tuple[str, type]: A tuple of box type name and class.
+    """
+    if isinstance(box_type, str):
+        type_name = box_type.lower()
+        assert type_name in box_types, \
+            f"Box type {type_name} hasn't been registered in box_types."
+        type_cls = box_types[type_name]
+    elif issubclass(box_type, BaseBoxes):
+        assert box_type in _box_type_to_name, \
+            f"Box type {box_type} hasn't been registered in box_types."
+        type_name = _box_type_to_name[box_type]
+        type_cls = box_type
+    else:
+        raise KeyError('box_type must be a str or class inheriting from '
+                       f'BaseBoxes, but got {type(box_type)}.')
+    return type_name, type_cls
+
+
+def convert_box_type(boxes: BoxType,
+                     *,
+                     src_type: Union[str, type] = None,
+                     dst_type: Union[str, type] = None) -> BoxType:
+    """Convert boxes from source type to destination type.
+
+    If ``boxes`` is an instance of BaseBoxes, the ``src_type`` will be set
+    as the type of ``boxes``.
+
+    Args:
+        boxes (np.ndarray or Tensor or :obj:`BaseBoxes`): boxes that need to
+            be converted.
+        src_type (str or type, Optional): source box type. Defaults to None.
+        dst_type (str or type, Optional): destination box type. Defaults to
+            None.
+
+    Returns:
+        Union[np.ndarray, Tensor, :obj:`BaseBoxes`]: Converted boxes. Its
+        type is consistent with the input's type.
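+
+    Examples:
+        >>> # a sketch, assuming 'hbox' and a hypothetical 'rbox' box type
+        >>> # plus an 'hbox' -> 'rbox' converter have been registered
+        >>> hboxes = torch.rand(5, 4)
+        >>> rboxes = convert_box_type(hboxes, src_type='hbox',
+        ...                           dst_type='rbox')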
+    """
+    assert dst_type is not None
+    dst_type_name, dst_type_cls = get_box_type(dst_type)
+
+    is_box_cls = False
+    is_numpy = False
+    if isinstance(boxes, BaseBoxes):
+        src_type_name, _ = get_box_type(type(boxes))
+        is_box_cls = True
+    elif isinstance(boxes, (Tensor, np.ndarray)):
+        assert src_type is not None
+        src_type_name, _ = get_box_type(src_type)
+        if isinstance(boxes, np.ndarray):
+            is_numpy = True
+    else:
+        raise TypeError('boxes must be an instance of BaseBoxes, Tensor or '
+                        f'ndarray, but got {type(boxes)}.')
+
+    if src_type_name == dst_type_name:
+        return boxes
+
+    converter_name = src_type_name + '2' + dst_type_name
+    assert converter_name in box_converters, \
+        "Convert function hasn't been registered in box_converters."
+    converter = box_converters[converter_name]
+
+    if is_box_cls:
+        boxes = converter(boxes.tensor)
+        return dst_type_cls(boxes)
+    elif is_numpy:
+        boxes = converter(torch.from_numpy(boxes))
+        return boxes.numpy()
+    else:
+        return converter(boxes)
+
+
+def autocast_box_type(dst_box_type='hbox') -> Callable:
+    """A decorator which automatically casts results['gt_bboxes'] to the
+    destination box type.
+
+    It is commonly used in mmdet.datasets.transforms to keep the transforms
+    compatible with the np.ndarray type of results['gt_bboxes'].
+
+    The processing speed of np.ndarray and BaseBoxes data is comparable:
+
+    - np.ndarray: 0.0509 img/s
+    - BaseBoxes: 0.0551 img/s
+
+    Args:
+        dst_box_type (str): Destination box type.
+    """
+    _, box_type_cls = get_box_type(dst_box_type)
+
+    def decorator(func: Callable) -> Callable:
+
+        def wrapper(self, results: dict, *args, **kwargs) -> dict:
+            if ('gt_bboxes' not in results
+                    or isinstance(results['gt_bboxes'], BaseBoxes)):
+                return func(self, results, *args, **kwargs)
+            elif isinstance(results['gt_bboxes'], np.ndarray):
+                results['gt_bboxes'] = box_type_cls(
+                    results['gt_bboxes'], clone=False)
+                if 'mix_results' in results:
+                    for res in results['mix_results']:
+                        if isinstance(res['gt_bboxes'], np.ndarray):
+                            res['gt_bboxes'] = box_type_cls(
+                                res['gt_bboxes'], clone=False)
+
+                _results = func(self, results, *args, **kwargs)
+
+                # In some cases, the function will process gt_bboxes in-place
+                # Simultaneously convert inputting and outputting gt_bboxes
+                # back to np.ndarray
+                if isinstance(_results, dict) and 'gt_bboxes' in _results:
+                    if isinstance(_results['gt_bboxes'], BaseBoxes):
+                        _results['gt_bboxes'] = _results['gt_bboxes'].numpy()
+                if isinstance(results['gt_bboxes'], BaseBoxes):
+                    results['gt_bboxes'] = results['gt_bboxes'].numpy()
+                return _results
+            else:
+                raise TypeError(
+                    "autocast_box_type requires results['gt_bboxes'] to "
+                    'be BaseBoxes or np.ndarray, but got '
+                    f"{type(results['gt_bboxes'])}")
+
+        return wrapper
+
+    return decorator
diff --git a/mmdet/structures/bbox/horizontal_boxes.py b/mmdet/structures/bbox/horizontal_boxes.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3a78518105fda02cef2d3a2bcaceb410759165c
--- /dev/null
+++ b/mmdet/structures/bbox/horizontal_boxes.py
@@ -0,0 +1,432 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Tuple, TypeVar, Union
+
+import cv2
+import numpy as np
+import torch
+from torch import BoolTensor, Tensor
+
+from mmdet.structures.mask.structures import BitmapMasks, PolygonMasks
+from .base_boxes import BaseBoxes
+from .bbox_overlaps import bbox_overlaps
+from .box_type import register_box
+
+T = TypeVar('T')
+DeviceType = Union[str, torch.device]
+MaskType = Union[BitmapMasks, PolygonMasks]
+
+
+@register_box(name='hbox')
+class HorizontalBoxes(BaseBoxes):
+    """The horizontal box class used in MMDetection by default.
+
+    The ``box_dim`` of ``HorizontalBoxes`` is 4, which means the length of
+    the last dimension of the data should be 4. Two modes of box data are
+    supported in ``HorizontalBoxes``:
+
+    - 'xyxy': Each row of data indicates (x1, y1, x2, y2), which are the
+      coordinates of the left-top and right-bottom points.
+    - 'cxcywh': Each row of data indicates (x, y, w, h), where (x, y) are the
+      coordinates of the box centers and (w, h) are the width and height.
+
+    ``HorizontalBoxes`` only stores data in 'xyxy' mode. If the data is
+    in 'cxcywh' mode, users need to pass ``in_mode='cxcywh'`` and the code
+    will convert the 'cxcywh' data to 'xyxy' automatically.
+
+    Args:
+        data (Tensor or np.ndarray or Sequence): The box data with shape of
+            (..., 4).
+        dtype (torch.dtype, Optional): data type of boxes. Defaults to None.
+        device (str or torch.device, Optional): device of boxes.
+            Default to None.
+        clone (bool): Whether clone ``boxes`` or not. Defaults to True.
+        in_mode (str, Optional): the mode of boxes. If it is 'cxcywh', the
+            ``data`` will be converted to 'xyxy' mode. Defaults to None.
+    """
+
+    box_dim: int = 4
+
+    def __init__(self,
+                 data: Union[Tensor, np.ndarray],
+                 dtype: torch.dtype = None,
+                 device: DeviceType = None,
+                 clone: bool = True,
+                 in_mode: Optional[str] = None) -> None:
+        super().__init__(data=data, dtype=dtype, device=device, clone=clone)
+        if isinstance(in_mode, str):
+            if in_mode not in ('xyxy', 'cxcywh'):
+                raise ValueError(f'Got invalid mode {in_mode}.')
+            if in_mode == 'cxcywh':
+                self.tensor = self.cxcywh_to_xyxy(self.tensor)
+
+    @staticmethod
+    def cxcywh_to_xyxy(boxes: Tensor) -> Tensor:
+        """Convert box coordinates from (cx, cy, w, h) to (x1, y1, x2, y2).
+
+        Args:
+            boxes (Tensor): cxcywh boxes tensor with shape of (..., 4).
+
+        Returns:
+            Tensor: xyxy boxes tensor with shape of (..., 4).
+        """
+        ctr, wh = boxes.split((2, 2), dim=-1)
+        return torch.cat([(ctr - wh / 2), (ctr + wh / 2)], dim=-1)
+
+    @staticmethod
+    def xyxy_to_cxcywh(boxes: Tensor) -> Tensor:
+        """Convert box coordinates from (x1, y1, x2, y2) to (cx, cy, w, h).
+
+        Args:
+            boxes (Tensor): xyxy boxes tensor with shape of (..., 4).
+
+        Returns:
+            Tensor: cxcywh boxes tensor with shape of (..., 4).
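+
+        Example:
+            >>> # a 10x10 box whose top-left corner is at the origin
+            >>> HorizontalBoxes.xyxy_to_cxcywh(
+            ...     torch.tensor([[0., 0., 10., 10.]]))
+            tensor([[ 5.,  5., 10., 10.]])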
+        """
+        xy1, xy2 = boxes.split((2, 2), dim=-1)
+        return torch.cat([(xy2 + xy1) / 2, (xy2 - xy1)], dim=-1)
+
+    @property
+    def cxcywh(self) -> Tensor:
+        """Return a tensor representing the cxcywh boxes."""
+        return self.xyxy_to_cxcywh(self.tensor)
+
+    @property
+    def centers(self) -> Tensor:
+        """Return a tensor representing the centers of boxes."""
+        boxes = self.tensor
+        return (boxes[..., :2] + boxes[..., 2:]) / 2
+
+    @property
+    def areas(self) -> Tensor:
+        """Return a tensor representing the areas of boxes."""
+        boxes = self.tensor
+        return (boxes[..., 2] - boxes[..., 0]) * (
+            boxes[..., 3] - boxes[..., 1])
+
+    @property
+    def widths(self) -> Tensor:
+        """Return a tensor representing the widths of boxes."""
+        boxes = self.tensor
+        return boxes[..., 2] - boxes[..., 0]
+
+    @property
+    def heights(self) -> Tensor:
+        """Return a tensor representing the heights of boxes."""
+        boxes = self.tensor
+        return boxes[..., 3] - boxes[..., 1]
+
+    def flip_(self,
+              img_shape: Tuple[int, int],
+              direction: str = 'horizontal') -> None:
+        """Flip boxes horizontally or vertically in-place.
+
+        Args:
+            img_shape (Tuple[int, int]): A tuple of image height and width.
+            direction (str): Flip direction, options are "horizontal",
+                "vertical" and "diagonal". Defaults to "horizontal".
+        """
+        assert direction in ['horizontal', 'vertical', 'diagonal']
+        flipped = self.tensor
+        boxes = flipped.clone()
+        if direction == 'horizontal':
+            flipped[..., 0] = img_shape[1] - boxes[..., 2]
+            flipped[..., 2] = img_shape[1] - boxes[..., 0]
+        elif direction == 'vertical':
+            flipped[..., 1] = img_shape[0] - boxes[..., 3]
+            flipped[..., 3] = img_shape[0] - boxes[..., 1]
+        else:
+            flipped[..., 0] = img_shape[1] - boxes[..., 2]
+            flipped[..., 1] = img_shape[0] - boxes[..., 3]
+            flipped[..., 2] = img_shape[1] - boxes[..., 0]
+            flipped[..., 3] = img_shape[0] - boxes[..., 1]
+
+    def translate_(self, distances: Tuple[float, float]) -> None:
+        """Translate boxes in-place.
+
+        Args:
+            distances (Tuple[float, float]): translate distances. The first
+                is horizontal distance and the second is vertical distance.
+        """
+        boxes = self.tensor
+        assert len(distances) == 2
+        self.tensor = boxes + boxes.new_tensor(distances).repeat(2)
+
+    def clip_(self, img_shape: Tuple[int, int]) -> None:
+        """Clip boxes according to the image shape in-place.
+
+        Args:
+            img_shape (Tuple[int, int]): A tuple of image height and width.
+        """
+        boxes = self.tensor
+        boxes[..., 0::2] = boxes[..., 0::2].clamp(0, img_shape[1])
+        boxes[..., 1::2] = boxes[..., 1::2].clamp(0, img_shape[0])
+
+    def rotate_(self, center: Tuple[float, float], angle: float) -> None:
+        """Rotate all boxes in-place.
+
+        Args:
+            center (Tuple[float, float]): Rotation origin.
+            angle (float): Rotation angle represented in degrees. Positive
+                values mean clockwise rotation.
+        """
+        boxes = self.tensor
+        rotation_matrix = boxes.new_tensor(
+            cv2.getRotationMatrix2D(center, -angle, 1))
+
+        corners = self.hbox2corner(boxes)
+        corners = torch.cat(
+            [corners, corners.new_ones(*corners.shape[:-1], 1)], dim=-1)
+        corners_T = torch.transpose(corners, -1, -2)
+        corners_T = torch.matmul(rotation_matrix, corners_T)
+        corners = torch.transpose(corners_T, -1, -2)
+        self.tensor = self.corner2hbox(corners)
+
+    def project_(self, homography_matrix: Union[Tensor, np.ndarray]) -> None:
+        """Geometrically transform boxes in-place.
+
+        Args:
+            homography_matrix (Tensor or np.ndarray):
+                Shape (3, 3) for geometric transformation.
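+
+        Example:
+            >>> # sanity sketch: an identity homography leaves boxes
+            >>> # unchanged
+            >>> boxes = HorizontalBoxes(torch.tensor([[0., 0., 10., 10.]]))
+            >>> boxes.project_(torch.eye(3))
+            >>> boxes.tensor
+            tensor([[ 0.,  0., 10., 10.]])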
+        """
+        boxes = self.tensor
+        if isinstance(homography_matrix, np.ndarray):
+            homography_matrix = boxes.new_tensor(homography_matrix)
+        corners = self.hbox2corner(boxes)
+        corners = torch.cat(
+            [corners, corners.new_ones(*corners.shape[:-1], 1)], dim=-1)
+        corners_T = torch.transpose(corners, -1, -2)
+        corners_T = torch.matmul(homography_matrix, corners_T)
+        corners = torch.transpose(corners_T, -1, -2)
+        # Convert to homogeneous coordinates by normalization
+        corners = corners[..., :2] / corners[..., 2:3]
+        self.tensor = self.corner2hbox(corners)
+
+    @staticmethod
+    def hbox2corner(boxes: Tensor) -> Tensor:
+        """Convert box coordinates from (x1, y1, x2, y2) to corners ((x1, y1),
+        (x2, y1), (x1, y2), (x2, y2)).
+
+        Args:
+            boxes (Tensor): Horizontal box tensor with shape of (..., 4).
+
+        Returns:
+            Tensor: Corner tensor with shape of (..., 4, 2).
+        """
+        x1, y1, x2, y2 = torch.split(boxes, 1, dim=-1)
+        corners = torch.cat([x1, y1, x2, y1, x1, y2, x2, y2], dim=-1)
+        return corners.reshape(*corners.shape[:-1], 4, 2)
+
+    @staticmethod
+    def corner2hbox(corners: Tensor) -> Tensor:
+        """Convert box coordinates from corners ((x1, y1), (x2, y1), (x1, y2),
+        (x2, y2)) to (x1, y1, x2, y2).
+
+        Args:
+            corners (Tensor): Corner tensor with shape of (..., 4, 2).
+
+        Returns:
+            Tensor: Horizontal box tensor with shape of (..., 4).
+        """
+        if corners.numel() == 0:
+            return corners.new_zeros((0, 4))
+        min_xy = corners.min(dim=-2)[0]
+        max_xy = corners.max(dim=-2)[0]
+        return torch.cat([min_xy, max_xy], dim=-1)
+
+    def rescale_(self, scale_factor: Tuple[float, float]) -> None:
+        """Rescale boxes w.r.t. scale_factor in-place.
+
+        Note:
+            Both ``rescale_`` and ``resize_`` will enlarge or shrink boxes
+            w.r.t. ``scale_factor``. The difference is that ``resize_`` only
+            changes the width and the height of boxes, but ``rescale_`` also
+            rescales the box centers simultaneously.
+
+        Args:
+            scale_factor (Tuple[float, float]): factors for scaling boxes.
+                The length should be 2.
+        """
+        boxes = self.tensor
+        assert len(scale_factor) == 2
+        scale_factor = boxes.new_tensor(scale_factor).repeat(2)
+        self.tensor = boxes * scale_factor
+
+    def resize_(self, scale_factor: Tuple[float, float]) -> None:
+        """Resize the box width and height w.r.t. scale_factor in-place.
+
+        Note:
+            Both ``rescale_`` and ``resize_`` will enlarge or shrink boxes
+            w.r.t. ``scale_factor``. The difference is that ``resize_`` only
+            changes the width and the height of boxes, but ``rescale_`` also
+            rescales the box centers simultaneously.
+
+        Args:
+            scale_factor (Tuple[float, float]): factors for scaling box
+                shapes. The length should be 2.
+        """
+        boxes = self.tensor
+        assert len(scale_factor) == 2
+        ctrs = (boxes[..., 2:] + boxes[..., :2]) / 2
+        wh = boxes[..., 2:] - boxes[..., :2]
+        scale_factor = boxes.new_tensor(scale_factor)
+        wh = wh * scale_factor
+        xy1 = ctrs - 0.5 * wh
+        xy2 = ctrs + 0.5 * wh
+        self.tensor = torch.cat([xy1, xy2], dim=-1)
+
+    def is_inside(self,
+                  img_shape: Tuple[int, int],
+                  all_inside: bool = False,
+                  allowed_border: int = 0) -> BoolTensor:
+        """Find boxes inside the image.
+
+        Args:
+            img_shape (Tuple[int, int]): A tuple of image height and width.
+            all_inside (bool): Whether the boxes are all inside the image or
+                part inside the image. Defaults to False.
+            allowed_border (int): Boxes that extend beyond the image shape
+                boundary by more than ``allowed_border`` are considered
+                "outside". Defaults to 0.
+
+        Returns:
+            BoolTensor: A BoolTensor indicating whether the box is inside
+            the image.
Assuming the original boxes have shape (m, n, 4),
+            the output has shape (m, n).
+        """
+        img_h, img_w = img_shape
+        boxes = self.tensor
+        if all_inside:
+            return (boxes[:, 0] >= -allowed_border) & \
+                (boxes[:, 1] >= -allowed_border) & \
+                (boxes[:, 2] < img_w + allowed_border) & \
+                (boxes[:, 3] < img_h + allowed_border)
+        else:
+            return (boxes[..., 0] < img_w + allowed_border) & \
+                (boxes[..., 1] < img_h + allowed_border) & \
+                (boxes[..., 2] > -allowed_border) & \
+                (boxes[..., 3] > -allowed_border)
+
+    def find_inside_points(self,
+                           points: Tensor,
+                           is_aligned: bool = False) -> BoolTensor:
+        """Find inside box points. Boxes dimension must be 2.
+
+        Args:
+            points (Tensor): Points coordinates. Has shape of (m, 2).
+            is_aligned (bool): Whether ``points`` has been aligned with boxes
+                or not. If True, the length of boxes and ``points`` should be
+                the same. Defaults to False.
+
+        Returns:
+            BoolTensor: A BoolTensor indicating whether a point is inside
+            boxes. Assuming the boxes have shape of (n, 4): if ``is_aligned``
+            is False, the index has shape of (m, n); if ``is_aligned`` is
+            True, m should be equal to n and the index has shape of (m, ).
+        """
+        boxes = self.tensor
+        assert boxes.dim() == 2, 'boxes dimension must be 2.'
+
+        if not is_aligned:
+            boxes = boxes[None, :, :]
+            points = points[:, None, :]
+        else:
+            assert boxes.size(0) == points.size(0)
+
+        x_min, y_min, x_max, y_max = boxes.unbind(dim=-1)
+        return (points[..., 0] >= x_min) & (points[..., 0] <= x_max) & \
+            (points[..., 1] >= y_min) & (points[..., 1] <= y_max)
+
+    def create_masks(self, img_shape: Tuple[int, int]) -> BitmapMasks:
+        """Create binary masks that fill each box.
+
+        Args:
+            img_shape (Tuple[int, int]): A tuple of image height and width.
+
+        Returns:
+            :obj:`BitmapMasks`: Converted masks
+        """
+        img_h, img_w = img_shape
+        boxes = self.tensor
+
+        xmin, ymin = boxes[:, 0:1], boxes[:, 1:2]
+        xmax, ymax = boxes[:, 2:3], boxes[:, 3:4]
+        gt_masks = np.zeros((len(boxes), img_h, img_w), dtype=np.uint8)
+        for i in range(len(boxes)):
+            gt_masks[i,
+                     int(ymin[i]):int(ymax[i]),
+                     int(xmin[i]):int(xmax[i])] = 1
+        return BitmapMasks(gt_masks, img_h, img_w)
+
+    @staticmethod
+    def overlaps(boxes1: BaseBoxes,
+                 boxes2: BaseBoxes,
+                 mode: str = 'iou',
+                 is_aligned: bool = False,
+                 eps: float = 1e-6) -> Tensor:
+        """Calculate overlap between two sets of boxes with their types
+        converted to ``HorizontalBoxes``.
+
+        Args:
+            boxes1 (:obj:`BaseBoxes`): BaseBoxes with shape of (m, box_dim)
+                or empty.
+            boxes2 (:obj:`BaseBoxes`): BaseBoxes with shape of (n, box_dim)
+                or empty.
+            mode (str): "iou" (intersection over union), "iof" (intersection
+                over foreground). Defaults to "iou".
+            is_aligned (bool): If True, then m and n must be equal. Defaults
+                to False.
+            eps (float): A value added to the denominator for numerical
+                stability. Defaults to 1e-6.
+
+        Returns:
+            Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,)
+        """
+        boxes1 = boxes1.convert_to('hbox')
+        boxes2 = boxes2.convert_to('hbox')
+        return bbox_overlaps(
+            boxes1.tensor,
+            boxes2.tensor,
+            mode=mode,
+            is_aligned=is_aligned,
+            eps=eps)
+
+    @staticmethod
+    def from_instance_masks(masks: MaskType) -> 'HorizontalBoxes':
+        """Create horizontal boxes from instance masks.
+
+        Args:
+            masks (:obj:`BitmapMasks` or :obj:`PolygonMasks`): BitmapMasks or
+                PolygonMasks instance with length of n.
+
+        Returns:
+            :obj:`HorizontalBoxes`: Converted boxes with shape of (n, 4).
+        """
+        num_masks = len(masks)
+        boxes = np.zeros((num_masks, 4), dtype=np.float32)
+        if isinstance(masks, BitmapMasks):
+            x_any = masks.masks.any(axis=1)
+            y_any = masks.masks.any(axis=2)
+            for idx in range(num_masks):
+                x = np.where(x_any[idx, :])[0]
+                y = np.where(y_any[idx, :])[0]
+                if len(x) > 0 and len(y) > 0:
+                    # use +1 for x_max and y_max so that the right and bottom
+                    # boundary of instance masks are fully included by the box
+                    boxes[idx, :] = np.array(
+                        [x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=np.float32)
+        elif isinstance(masks, PolygonMasks):
+            for idx, poly_per_obj in enumerate(masks.masks):
+                # simply use a number that is big enough for comparison with
+                # coordinates
+                xy_min = np.array([masks.width * 2, masks.height * 2],
+                                  dtype=np.float32)
+                xy_max = np.zeros(2, dtype=np.float32)
+                for p in poly_per_obj:
+                    xy = np.array(p).reshape(-1, 2).astype(np.float32)
+                    xy_min = np.minimum(xy_min, np.min(xy, axis=0))
+                    xy_max = np.maximum(xy_max, np.max(xy, axis=0))
+                boxes[idx, :2] = xy_min
+                boxes[idx, 2:] = xy_max
+        else:
+            raise TypeError(
+                '`masks` must be `BitmapMasks` or `PolygonMasks`, '
+                f'but got {type(masks)}.')
+        return HorizontalBoxes(boxes)
diff --git a/mmdet/structures/bbox/transforms.py b/mmdet/structures/bbox/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..287e6aa6fcaeaf09a8a2838a04a97157cd02a00c
--- /dev/null
+++ b/mmdet/structures/bbox/transforms.py
@@ -0,0 +1,498 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+from torch import Tensor
+
+from mmdet.structures.bbox import BaseBoxes
+
+
+def find_inside_bboxes(bboxes: Tensor, img_h: int, img_w: int) -> Tensor:
+    """Find bboxes that are at least partially inside the image.
+
+    Args:
+        bboxes (Tensor): Shape (N, 4).
+        img_h (int): Image height.
+        img_w (int): Image width.
+
+    Returns:
+        Tensor: Index of the remaining bboxes.
+    """
+    inside_inds = (bboxes[:, 0] < img_w) & (bboxes[:, 2] > 0) \
+        & (bboxes[:, 1] < img_h) & (bboxes[:, 3] > 0)
+    return inside_inds
+
+
+def bbox_flip(bboxes: Tensor,
+              img_shape: Tuple[int],
+              direction: str = 'horizontal') -> Tensor:
+    """Flip bboxes horizontally or vertically.
+
+    Args:
+        bboxes (Tensor): Shape (..., 4*k)
+        img_shape (Tuple[int]): Image shape.
+        direction (str): Flip direction, options are "horizontal", "vertical",
+            "diagonal". Default: "horizontal"
+
+    Returns:
+        Tensor: Flipped bboxes.
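+
+    Example:
+        >>> # horizontal flip in an image of (h, w) = (100, 200):
+        >>> # x1' = w - x2 and x2' = w - x1, y coordinates are unchanged
+        >>> bbox_flip(torch.tensor([[10., 10., 30., 40.]]), (100, 200))
+        tensor([[170.,  10., 190.,  40.]])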
+    """
+    assert bboxes.shape[-1] % 4 == 0
+    assert direction in ['horizontal', 'vertical', 'diagonal']
+    flipped = bboxes.clone()
+    if direction == 'horizontal':
+        flipped[..., 0::4] = img_shape[1] - bboxes[..., 2::4]
+        flipped[..., 2::4] = img_shape[1] - bboxes[..., 0::4]
+    elif direction == 'vertical':
+        flipped[..., 1::4] = img_shape[0] - bboxes[..., 3::4]
+        flipped[..., 3::4] = img_shape[0] - bboxes[..., 1::4]
+    else:
+        flipped[..., 0::4] = img_shape[1] - bboxes[..., 2::4]
+        flipped[..., 1::4] = img_shape[0] - bboxes[..., 3::4]
+        flipped[..., 2::4] = img_shape[1] - bboxes[..., 0::4]
+        flipped[..., 3::4] = img_shape[0] - bboxes[..., 1::4]
+    return flipped
+
+
+def bbox_mapping(bboxes: Tensor,
+                 img_shape: Tuple[int],
+                 scale_factor: Union[float, Tuple[float]],
+                 flip: bool,
+                 flip_direction: str = 'horizontal') -> Tensor:
+    """Map bboxes from the original image scale to testing scale."""
+    new_bboxes = bboxes * bboxes.new_tensor(scale_factor)
+    if flip:
+        new_bboxes = bbox_flip(new_bboxes, img_shape, flip_direction)
+    return new_bboxes
+
+
+def bbox_mapping_back(bboxes: Tensor,
+                      img_shape: Tuple[int],
+                      scale_factor: Union[float, Tuple[float]],
+                      flip: bool,
+                      flip_direction: str = 'horizontal') -> Tensor:
+    """Map bboxes from testing scale to original image scale."""
+    new_bboxes = bbox_flip(bboxes, img_shape,
+                           flip_direction) if flip else bboxes
+    new_bboxes = new_bboxes.view(-1, 4) / new_bboxes.new_tensor(scale_factor)
+    return new_bboxes.view(bboxes.shape)
+
+
+def bbox2roi(bbox_list: List[Union[Tensor, BaseBoxes]]) -> Tensor:
+    """Convert a list of bboxes to roi format.
+
+    Args:
+        bbox_list (List[Union[Tensor, :obj:`BaseBoxes`]]): a list of bboxes
+            corresponding to a batch of images.
+
+    Returns:
+        Tensor: shape (n, box_dim + 1), where ``box_dim`` depends on the
+        different box types. For example, if the box type in ``bbox_list``
+        is HorizontalBoxes, the output shape is (n, 5). Each row of data
+        indicates [batch_ind, x1, y1, x2, y2].
+    """
+    rois_list = []
+    for img_id, bboxes in enumerate(bbox_list):
+        bboxes = get_box_tensor(bboxes)
+        img_inds = bboxes.new_full((bboxes.size(0), 1), img_id)
+        rois = torch.cat([img_inds, bboxes], dim=-1)
+        rois_list.append(rois)
+    rois = torch.cat(rois_list, 0)
+    return rois
+
+
+def roi2bbox(rois: Tensor) -> List[Tensor]:
+    """Convert rois to bounding box format.
+
+    Args:
+        rois (Tensor): RoIs with the shape (n, 5) where the first
+            column indicates batch id of each RoI.
+
+    Returns:
+        List[Tensor]: Converted boxes of corresponding rois.
+    """
+    bbox_list = []
+    img_ids = torch.unique(rois[:, 0].cpu(), sorted=True)
+    for img_id in img_ids:
+        inds = (rois[:, 0] == img_id.item())
+        bbox = rois[inds, 1:]
+        bbox_list.append(bbox)
+    return bbox_list
+
+
+# TODO remove later
+def bbox2result(bboxes: Union[Tensor, np.ndarray], labels: Union[Tensor,
+                                                                 np.ndarray],
+                num_classes: int) -> List[np.ndarray]:
+    """Convert detection results to a list of numpy arrays.
+
+    Args:
+        bboxes (Tensor | np.ndarray): shape (n, 5)
+        labels (Tensor | np.ndarray): shape (n, )
+        num_classes (int): class number, including background class
+
+    Returns:
+        List[np.ndarray]: bbox results of each class
+    """
+    if bboxes.shape[0] == 0:
+        return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)]
+    else:
+        if isinstance(bboxes, torch.Tensor):
+            bboxes = bboxes.detach().cpu().numpy()
+            labels = labels.detach().cpu().numpy()
+        return [bboxes[labels == i, :] for i in range(num_classes)]
+
+
+def distance2bbox(
+    points: Tensor,
+    distance: Tensor,
+    max_shape: Optional[Union[Sequence[int], Tensor,
+                              Sequence[Sequence[int]]]] = None
+) -> Tensor:
+    """Decode distance prediction to bounding box.
+
+    Args:
+        points (Tensor): Shape (B, N, 2) or (N, 2).
+        distance (Tensor): Distance from the given point to 4
+            boundaries (left, top, right, bottom). Shape (B, N, 4) or (N, 4)
+        max_shape (Union[Sequence[int], Tensor, Sequence[Sequence[int]]],
+            optional): Maximum bounds for boxes, specifies
+            (H, W, C) or (H, W). If the priors' shape is (B, N, 4), then
+            the max_shape should be a Sequence[Sequence[int]]
+            and the length of max_shape should also be B.
+
+    Returns:
+        Tensor: Boxes with shape (N, 4) or (B, N, 4)
+    """
+
+    x1 = points[..., 0] - distance[..., 0]
+    y1 = points[..., 1] - distance[..., 1]
+    x2 = points[..., 0] + distance[..., 2]
+    y2 = points[..., 1] + distance[..., 3]
+
+    bboxes = torch.stack([x1, y1, x2, y2], -1)
+
+    if max_shape is not None:
+        if bboxes.dim() == 2 and not torch.onnx.is_in_onnx_export():
+            # speed up
+            bboxes[:, 0::2].clamp_(min=0, max=max_shape[1])
+            bboxes[:, 1::2].clamp_(min=0, max=max_shape[0])
+            return bboxes
+
+        # clip bboxes with dynamic `min` and `max` for onnx
+        if torch.onnx.is_in_onnx_export():
+            # TODO: delete
+            from mmdet.core.export import dynamic_clip_for_onnx
+            x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape)
+            bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
+            return bboxes
+        if not isinstance(max_shape, torch.Tensor):
+            max_shape = x1.new_tensor(max_shape)
+        max_shape = max_shape[..., :2].type_as(x1)
+        if max_shape.ndim == 2:
+            assert bboxes.ndim == 3
+            assert max_shape.size(0) == bboxes.size(0)
+
+        min_xy = x1.new_tensor(0)
+        max_xy = torch.cat([max_shape, max_shape],
+                           dim=-1).flip(-1).unsqueeze(-2)
+        bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
+        bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
+
+    return bboxes
+
+
+def bbox2distance(points: Tensor,
+                  bbox: Tensor,
+                  max_dis: Optional[float] = None,
+                  eps: float = 0.1) -> Tensor:
+    """Encode bounding box to the distances from the given points to its 4
+    boundaries.
+
+    Args:
+        points (Tensor): Shape (n, 2) or (b, n, 2), [x, y].
+        bbox (Tensor): Shape (n, 4) or (b, n, 4), "xyxy" format
+        max_dis (float, optional): Upper bound of the distance.
+        eps (float): a small value to ensure target < max_dis rather than <=.
+
+    Returns:
+        Tensor: Encoded distances.
+    """
+    left = points[..., 0] - bbox[..., 0]
+    top = points[..., 1] - bbox[..., 1]
+    right = bbox[..., 2] - points[..., 0]
+    bottom = bbox[..., 3] - points[..., 1]
+    if max_dis is not None:
+        left = left.clamp(min=0, max=max_dis - eps)
+        top = top.clamp(min=0, max=max_dis - eps)
+        right = right.clamp(min=0, max=max_dis - eps)
+        bottom = bottom.clamp(min=0, max=max_dis - eps)
+    return torch.stack([left, top, right, bottom], -1)
+
+
+def bbox_rescale(bboxes: Tensor, scale_factor: float = 1.0) -> Tensor:
+    """Rescale bounding box w.r.t. scale_factor.
+ + Args: + bboxes (Tensor): Shape (n, 4) for bboxes or (n, 5) for rois + scale_factor (float): rescale factor + + Returns: + Tensor: Rescaled bboxes. + """ + if bboxes.size(1) == 5: + bboxes_ = bboxes[:, 1:] + inds_ = bboxes[:, 0] + else: + bboxes_ = bboxes + cx = (bboxes_[:, 0] + bboxes_[:, 2]) * 0.5 + cy = (bboxes_[:, 1] + bboxes_[:, 3]) * 0.5 + w = bboxes_[:, 2] - bboxes_[:, 0] + h = bboxes_[:, 3] - bboxes_[:, 1] + w = w * scale_factor + h = h * scale_factor + x1 = cx - 0.5 * w + x2 = cx + 0.5 * w + y1 = cy - 0.5 * h + y2 = cy + 0.5 * h + if bboxes.size(1) == 5: + rescaled_bboxes = torch.stack([inds_, x1, y1, x2, y2], dim=-1) + else: + rescaled_bboxes = torch.stack([x1, y1, x2, y2], dim=-1) + return rescaled_bboxes + + +def bbox_cxcywh_to_xyxy(bbox: Tensor) -> Tensor: + """Convert bbox coordinates from (cx, cy, w, h) to (x1, y1, x2, y2). + + Args: + bbox (Tensor): Shape (n, 4) for bboxes. + + Returns: + Tensor: Converted bboxes. + """ + cx, cy, w, h = bbox.split((1, 1, 1, 1), dim=-1) + bbox_new = [(cx - 0.5 * w), (cy - 0.5 * h), (cx + 0.5 * w), (cy + 0.5 * h)] + return torch.cat(bbox_new, dim=-1) + + +def bbox_xyxy_to_cxcywh(bbox: Tensor) -> Tensor: + """Convert bbox coordinates from (x1, y1, x2, y2) to (cx, cy, w, h). + + Args: + bbox (Tensor): Shape (n, 4) for bboxes. + + Returns: + Tensor: Converted bboxes. + """ + x1, y1, x2, y2 = bbox.split((1, 1, 1, 1), dim=-1) + bbox_new = [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)] + return torch.cat(bbox_new, dim=-1) + + +def bbox2corner(bboxes: torch.Tensor) -> torch.Tensor: + """Convert bbox coordinates from (x1, y1, x2, y2) to corners ((x1, y1), + (x2, y1), (x1, y2), (x2, y2)). + + Args: + bboxes (Tensor): Shape (n, 4) for bboxes. + Returns: + Tensor: Shape (n*4, 2) for corners. + """ + x1, y1, x2, y2 = torch.split(bboxes, 1, dim=1) + return torch.cat([x1, y1, x2, y1, x1, y2, x2, y2], dim=1).reshape(-1, 2) + + +def corner2bbox(corners: torch.Tensor) -> torch.Tensor: + """Convert bbox coordinates from corners ((x1, y1), (x2, y1), (x1, y2), + (x2, y2)) to (x1, y1, x2, y2). + + Args: + corners (Tensor): Shape (n*4, 2) for corners. + Returns: + Tensor: Shape (n, 4) for bboxes. + """ + corners = corners.reshape(-1, 4, 2) + min_xy = corners.min(dim=1)[0] + max_xy = corners.max(dim=1)[0] + return torch.cat([min_xy, max_xy], dim=1) + + +def bbox_project( + bboxes: Union[torch.Tensor, np.ndarray], + homography_matrix: Union[torch.Tensor, np.ndarray], + img_shape: Optional[Tuple[int, int]] = None +) -> Union[torch.Tensor, np.ndarray]: + """Geometric transformation for bbox. + + Args: + bboxes (Union[torch.Tensor, np.ndarray]): Shape (n, 4) for bboxes. + homography_matrix (Union[torch.Tensor, np.ndarray]): + Shape (3, 3) for geometric transformation. + img_shape (Tuple[int, int], optional): Image shape. Defaults to None. + Returns: + Union[torch.Tensor, np.ndarray]: Converted bboxes. 
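+
+    Example:
+        >>> # a pure-translation homography shifts boxes by (tx, ty) = (5, 3)
+        >>> H = torch.tensor([[1., 0., 5.], [0., 1., 3.], [0., 0., 1.]])
+        >>> bbox_project(torch.tensor([[0., 0., 10., 10.]]), H)
+        tensor([[ 5.,  3., 15., 13.]])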
+    """
+    bboxes_type = type(bboxes)
+    if bboxes_type is np.ndarray:
+        bboxes = torch.from_numpy(bboxes)
+    if isinstance(homography_matrix, np.ndarray):
+        homography_matrix = torch.from_numpy(homography_matrix)
+    corners = bbox2corner(bboxes)
+    corners = torch.cat(
+        [corners, corners.new_ones(corners.shape[0], 1)], dim=1)
+    corners = torch.matmul(homography_matrix, corners.t()).t()
+    # Convert to homogeneous coordinates by normalization
+    corners = corners[:, :2] / corners[:, 2:3]
+    bboxes = corner2bbox(corners)
+    if img_shape is not None:
+        bboxes[:, 0::2] = bboxes[:, 0::2].clamp(0, img_shape[1])
+        bboxes[:, 1::2] = bboxes[:, 1::2].clamp(0, img_shape[0])
+    if bboxes_type is np.ndarray:
+        bboxes = bboxes.numpy()
+    return bboxes
+
+
+def cat_boxes(data_list: List[Union[Tensor, BaseBoxes]],
+              dim: int = 0) -> Union[Tensor, BaseBoxes]:
+    """Concatenate boxes with type of tensor or box type.
+
+    Args:
+        data_list (List[Union[Tensor, :obj:`BaseBoxes`]]): A list of tensors
+            or box types that need to be concatenated.
+        dim (int): The dimension over which the boxes are concatenated.
+            Defaults to 0.
+
+    Returns:
+        Union[Tensor, :obj:`BaseBoxes`]: Concatenated results.
+    """
+    if data_list and isinstance(data_list[0], BaseBoxes):
+        return data_list[0].cat(data_list, dim=dim)
+    else:
+        return torch.cat(data_list, dim=dim)
+
+
+def stack_boxes(data_list: List[Union[Tensor, BaseBoxes]],
+                dim: int = 0) -> Union[Tensor, BaseBoxes]:
+    """Stack boxes with type of tensor or box type.
+
+    Args:
+        data_list (List[Union[Tensor, :obj:`BaseBoxes`]]): A list of tensors
+            or box types that need to be stacked.
+        dim (int): The dimension over which the boxes are stacked.
+            Defaults to 0.
+
+    Returns:
+        Union[Tensor, :obj:`BaseBoxes`]: Stacked results.
+    """
+    if data_list and isinstance(data_list[0], BaseBoxes):
+        return data_list[0].stack(data_list, dim=dim)
+    else:
+        return torch.stack(data_list, dim=dim)
+
+
+def scale_boxes(boxes: Union[Tensor, BaseBoxes],
+                scale_factor: Tuple[float, float]) -> Union[Tensor, BaseBoxes]:
+    """Scale boxes with type of tensor or box type.
+
+    Args:
+        boxes (Tensor or :obj:`BaseBoxes`): boxes to be scaled. Its type
+            can be a tensor or a box type.
+        scale_factor (Tuple[float, float]): factors for scaling boxes.
+            The length should be 2.
+
+    Returns:
+        Union[Tensor, :obj:`BaseBoxes`]: Scaled boxes.
+    """
+    if isinstance(boxes, BaseBoxes):
+        boxes.rescale_(scale_factor)
+        return boxes
+    else:
+        # Tensor boxes will be treated as horizontal boxes
+        repeat_num = int(boxes.size(-1) / 2)
+        scale_factor = boxes.new_tensor(scale_factor).repeat((1, repeat_num))
+        return boxes * scale_factor
+
+
+def get_box_wh(boxes: Union[Tensor, BaseBoxes]) -> Tuple[Tensor, Tensor]:
+    """Get the width and height of boxes with type of tensor or box type.
+
+    Args:
+        boxes (Tensor or :obj:`BaseBoxes`): boxes with type of tensor
+            or box type.
+
+    Returns:
+        Tuple[Tensor, Tensor]: the width and height of boxes.
+    """
+    if isinstance(boxes, BaseBoxes):
+        w = boxes.widths
+        h = boxes.heights
+    else:
+        # Tensor boxes will be treated as horizontal boxes by default
+        w = boxes[:, 2] - boxes[:, 0]
+        h = boxes[:, 3] - boxes[:, 1]
+    return w, h
+
+
+def get_box_tensor(boxes: Union[Tensor, BaseBoxes]) -> Tensor:
+    """Get tensor data from box type boxes.
+
+    Args:
+        boxes (Tensor or BaseBoxes): boxes with type of tensor or box type.
+            If its type is a tensor, the boxes will be directly returned.
+            If its type is a box type, the `boxes.tensor` will be returned.
+
+    Returns:
+        Tensor: boxes tensor.
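+
+    Example:
+        >>> # plain tensors pass through untouched; box-type inputs are
+        >>> # unwrapped (a sketch, assuming HorizontalBoxes is importable)
+        >>> t = torch.rand(4, 4)
+        >>> assert get_box_tensor(t) is t
+        >>> assert get_box_tensor(HorizontalBoxes(t)).shape == (4, 4)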
+    """
+    if isinstance(boxes, BaseBoxes):
+        boxes = boxes.tensor
+    return boxes
+
+
+def empty_box_as(boxes: Union[Tensor, BaseBoxes]) -> Union[Tensor, BaseBoxes]:
+    """Generate an empty box according to the input ``boxes`` type and device.
+
+    Args:
+        boxes (Tensor or :obj:`BaseBoxes`): boxes with type of tensor
+            or box type.
+
+    Returns:
+        Union[Tensor, BaseBoxes]: Generated empty box.
+    """
+    if isinstance(boxes, BaseBoxes):
+        return boxes.empty_boxes()
+    else:
+        # Tensor boxes will be treated as horizontal boxes by default
+        return boxes.new_zeros(0, 4)
+
+
+def bbox_xyxy_to_cxcyah(bboxes: torch.Tensor) -> torch.Tensor:
+    """Convert bbox coordinates from (x1, y1, x2, y2) to (cx, cy, ratio, h).
+
+    Args:
+        bboxes (Tensor): Shape (n, 4) for bboxes.
+
+    Returns:
+        Tensor: Converted bboxes.
+    """
+    cx = (bboxes[:, 2] + bboxes[:, 0]) / 2
+    cy = (bboxes[:, 3] + bboxes[:, 1]) / 2
+    w = bboxes[:, 2] - bboxes[:, 0]
+    h = bboxes[:, 3] - bboxes[:, 1]
+    xyah = torch.stack([cx, cy, w / h, h], -1)
+    return xyah
+
+
+def bbox_cxcyah_to_xyxy(bboxes: torch.Tensor) -> torch.Tensor:
+    """Convert bbox coordinates from (cx, cy, ratio, h) to (x1, y1, x2, y2).
+
+    Args:
+        bboxes (Tensor): Shape (n, 4) for bboxes.
+
+    Returns:
+        Tensor: Converted bboxes.
+    """
+    cx, cy, ratio, h = bboxes.split((1, 1, 1, 1), dim=-1)
+    w = ratio * h
+    x1y1x2y2 = [cx - w / 2.0, cy - h / 2.0, cx + w / 2.0, cy + h / 2.0]
+    return torch.cat(x1y1x2y2, dim=-1)
diff --git a/mmdet/structures/det_data_sample.py b/mmdet/structures/det_data_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..37dd74725ed2ff5eb8a088c9d23a9ac5469b07a3
--- /dev/null
+++ b/mmdet/structures/det_data_sample.py
@@ -0,0 +1,237 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional
+
+from mmengine.structures import BaseDataElement, InstanceData, PixelData
+
+
+class DetDataSample(BaseDataElement):
+    """A data structure interface of MMDetection. They are used as interfaces
+    between different components.
+
+    The attributes in ``DetDataSample`` are divided into several parts:
+
+    - ``proposals`` (InstanceData): Region proposals used in two-stage
+      detectors.
+    - ``gt_instances`` (InstanceData): Ground truth of instance annotations.
+    - ``pred_instances`` (InstanceData): Instances of detection predictions.
+    - ``pred_track_instances`` (InstanceData): Instances of tracking
+      predictions.
+    - ``ignored_instances`` (InstanceData): Instances to be ignored during
+      training/testing.
+    - ``gt_panoptic_seg`` (PixelData): Ground truth of panoptic
+      segmentation.
+    - ``pred_panoptic_seg`` (PixelData): Prediction of panoptic
+      segmentation.
+    - ``gt_sem_seg`` (PixelData): Ground truth of semantic segmentation.
+    - ``pred_sem_seg`` (PixelData): Prediction of semantic segmentation.
+
+    Examples:
+        >>> import torch
+        >>> import numpy as np
+        >>> from mmengine.structures import InstanceData
+        >>> from mmdet.structures import DetDataSample
+
+        >>> data_sample = DetDataSample()
+        >>> img_meta = dict(img_shape=(800, 1196),
+        ...                 pad_shape=(800, 1216))
+        >>> gt_instances = InstanceData(metainfo=img_meta)
+        >>> gt_instances.bboxes = torch.rand((5, 4))
+        >>> gt_instances.labels = torch.rand((5,))
+        >>> data_sample.gt_instances = gt_instances
+        >>> assert 'img_shape' in data_sample.gt_instances.metainfo_keys()
+        >>> len(data_sample.gt_instances)
+        5
+        >>> print(data_sample)
+        <DetDataSample(...) at 0x7f21fb1b9880>
+        >>> pred_instances = InstanceData(metainfo=img_meta)
+        >>> pred_instances.bboxes = torch.rand((5, 4))
+        >>> pred_instances.scores = torch.rand((5,))
+        >>> data_sample = DetDataSample(pred_instances=pred_instances)
+        >>> assert 'pred_instances' in data_sample
+
+        >>> pred_track_instances = InstanceData(metainfo=img_meta)
+        >>> pred_track_instances.bboxes = torch.rand((5, 4))
+        >>> pred_track_instances.scores = torch.rand((5,))
+        >>> data_sample = DetDataSample(
+        ...     pred_track_instances=pred_track_instances)
+        >>> assert 'pred_track_instances' in data_sample
+
+        >>> data_sample = DetDataSample()
+        >>> gt_instances_data = dict(
+        ...     bboxes=torch.rand(2, 4),
+        ...     labels=torch.rand(2),
+        ...     masks=np.random.rand(2, 2, 2))
+        >>> gt_instances = InstanceData(**gt_instances_data)
+        >>> data_sample.gt_instances = gt_instances
+        >>> assert 'gt_instances' in data_sample
+        >>> assert 'masks' in data_sample.gt_instances
+
+        >>> data_sample = DetDataSample()
+        >>> gt_panoptic_seg_data = dict(panoptic_seg=torch.rand(2, 4))
+        >>> gt_panoptic_seg = PixelData(**gt_panoptic_seg_data)
+        >>> data_sample.gt_panoptic_seg = gt_panoptic_seg
+        >>> print(data_sample)
+        <DetDataSample(
+            gt_panoptic_seg: <PixelData(...)>
+        ) at 0x7f66c2bb7280>
+        >>> data_sample = DetDataSample()
+        >>> gt_segm_seg_data = dict(segm_seg=torch.rand(2, 2, 2))
+        >>> gt_segm_seg = PixelData(**gt_segm_seg_data)
+        >>> data_sample.gt_segm_seg = gt_segm_seg
+        >>> assert 'gt_segm_seg' in data_sample
+        >>> assert 'segm_seg' in data_sample.gt_segm_seg
+    """
+
+    @property
+    def proposals(self) -> InstanceData:
+        return self._proposals
+
+    @proposals.setter
+    def proposals(self, value: InstanceData):
+        self.set_field(value, '_proposals', dtype=InstanceData)
+
+    @proposals.deleter
+    def proposals(self):
+        del self._proposals
+
+    @property
+    def gt_instances(self) -> InstanceData:
+        return self._gt_instances
+
+    @gt_instances.setter
+    def gt_instances(self, value: InstanceData):
+        self.set_field(value, '_gt_instances', dtype=InstanceData)
+
+    @gt_instances.deleter
+    def gt_instances(self):
+        del self._gt_instances
+
+    @property
+    def pred_instances(self) -> InstanceData:
+        return self._pred_instances
+
+    @pred_instances.setter
+    def pred_instances(self, value: InstanceData):
+        self.set_field(value, '_pred_instances', dtype=InstanceData)
+
+    @pred_instances.deleter
+    def pred_instances(self):
+        del self._pred_instances
+
+    # directly add ``pred_track_instances`` in ``DetDataSample``
+    # so that the ``TrackDataSample`` does not bother to access the
+    # instance-level information.
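+    # For example, downstream code can read fields such as
+    # ``det_sample.pred_track_instances.bboxes`` on a per-frame sample
+    # directly, instead of first unwrapping a track-level data sample
+    # (``det_sample`` here is an illustrative variable name).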
+ @property + def pred_track_instances(self) -> InstanceData: + return self._pred_track_instances + + @pred_track_instances.setter + def pred_track_instances(self, value: InstanceData): + self.set_field(value, '_pred_track_instances', dtype=InstanceData) + + @pred_track_instances.deleter + def pred_track_instances(self): + del self._pred_track_instances + + @property + def ignored_instances(self) -> InstanceData: + return self._ignored_instances + + @ignored_instances.setter + def ignored_instances(self, value: InstanceData): + self.set_field(value, '_ignored_instances', dtype=InstanceData) + + @ignored_instances.deleter + def ignored_instances(self): + del self._ignored_instances + + @property + def gt_panoptic_seg(self) -> PixelData: + return self._gt_panoptic_seg + + @gt_panoptic_seg.setter + def gt_panoptic_seg(self, value: PixelData): + self.set_field(value, '_gt_panoptic_seg', dtype=PixelData) + + @gt_panoptic_seg.deleter + def gt_panoptic_seg(self): + del self._gt_panoptic_seg + + @property + def pred_panoptic_seg(self) -> PixelData: + return self._pred_panoptic_seg + + @pred_panoptic_seg.setter + def pred_panoptic_seg(self, value: PixelData): + self.set_field(value, '_pred_panoptic_seg', dtype=PixelData) + + @pred_panoptic_seg.deleter + def pred_panoptic_seg(self): + del self._pred_panoptic_seg + + @property + def gt_sem_seg(self) -> PixelData: + return self._gt_sem_seg + + @gt_sem_seg.setter + def gt_sem_seg(self, value: PixelData): + self.set_field(value, '_gt_sem_seg', dtype=PixelData) + + @gt_sem_seg.deleter + def gt_sem_seg(self): + del self._gt_sem_seg + + @property + def pred_sem_seg(self) -> PixelData: + return self._pred_sem_seg + + @pred_sem_seg.setter + def pred_sem_seg(self, value: PixelData): + self.set_field(value, '_pred_sem_seg', dtype=PixelData) + + @pred_sem_seg.deleter + def pred_sem_seg(self): + del self._pred_sem_seg + + +SampleList = List[DetDataSample] +OptSampleList = Optional[SampleList] diff --git a/mmdet/structures/mask/__init__.py b/mmdet/structures/mask/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f78394701df1b493259c4c23a79aea5c5cb8be95 --- /dev/null +++ b/mmdet/structures/mask/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
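+# A minimal usage sketch of the exports below (illustrative, not part of
+# the original module):
+#   from mmdet.structures.mask import BitmapMasks, bitmap_to_polygon
+#   masks = BitmapMasks.random(num_masks=2, height=32, width=32)
+#   contours, with_hole = bitmap_to_polygon(masks.masks[0])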
+from .mask_target import mask_target +from .structures import (BaseInstanceMasks, BitmapMasks, PolygonMasks, + bitmap_to_polygon, polygon_to_bitmap) +from .utils import encode_mask_results, mask2bbox, split_combined_polys + +__all__ = [ + 'split_combined_polys', 'mask_target', 'BaseInstanceMasks', 'BitmapMasks', + 'PolygonMasks', 'encode_mask_results', 'mask2bbox', 'polygon_to_bitmap', + 'bitmap_to_polygon' +] diff --git a/mmdet/structures/mask/__pycache__/__init__.cpython-311.pyc b/mmdet/structures/mask/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c9e4aab6561077b6c966d51e953a514144b02ee Binary files /dev/null and b/mmdet/structures/mask/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmdet/structures/mask/__pycache__/mask_target.cpython-311.pyc b/mmdet/structures/mask/__pycache__/mask_target.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..237f2912a7d921d87523ec2bade87c759d1737d8 Binary files /dev/null and b/mmdet/structures/mask/__pycache__/mask_target.cpython-311.pyc differ diff --git a/mmdet/structures/mask/__pycache__/structures.cpython-311.pyc b/mmdet/structures/mask/__pycache__/structures.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca56b23743d08ae82abb61ec7d3372b5411f4b85 Binary files /dev/null and b/mmdet/structures/mask/__pycache__/structures.cpython-311.pyc differ diff --git a/mmdet/structures/mask/mask_target.py b/mmdet/structures/mask/mask_target.py new file mode 100644 index 0000000000000000000000000000000000000000..b2fc5f1878300446b114c9f57c6a885fea8c927c --- /dev/null +++ b/mmdet/structures/mask/mask_target.py @@ -0,0 +1,127 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch +from torch.nn.modules.utils import _pair + + +def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list, + cfg): + """Compute mask target for positive proposals in multiple images. + + Args: + pos_proposals_list (list[Tensor]): Positive proposals in multiple + images, each has shape (num_pos, 4). + pos_assigned_gt_inds_list (list[Tensor]): Assigned GT indices for each + positive proposals, each has shape (num_pos,). + gt_masks_list (list[:obj:`BaseInstanceMasks`]): Ground truth masks of + each image. + cfg (dict): Config dict that specifies the mask size. + + Returns: + Tensor: Mask target of each image, has shape (num_pos, w, h). 
+ + Example: + >>> from mmengine.config import Config + >>> import mmdet + >>> from mmdet.data_elements.mask import BitmapMasks + >>> from mmdet.data_elements.mask.mask_target import * + >>> H, W = 17, 18 + >>> cfg = Config({'mask_size': (13, 14)}) + >>> rng = np.random.RandomState(0) + >>> # Positive proposals (tl_x, tl_y, br_x, br_y) for each image + >>> pos_proposals_list = [ + >>> torch.Tensor([ + >>> [ 7.2425, 5.5929, 13.9414, 14.9541], + >>> [ 7.3241, 3.6170, 16.3850, 15.3102], + >>> ]), + >>> torch.Tensor([ + >>> [ 4.8448, 6.4010, 7.0314, 9.7681], + >>> [ 5.9790, 2.6989, 7.4416, 4.8580], + >>> [ 0.0000, 0.0000, 0.1398, 9.8232], + >>> ]), + >>> ] + >>> # Corresponding class index for each proposal for each image + >>> pos_assigned_gt_inds_list = [ + >>> torch.LongTensor([7, 0]), + >>> torch.LongTensor([5, 4, 1]), + >>> ] + >>> # Ground truth mask for each true object for each image + >>> gt_masks_list = [ + >>> BitmapMasks(rng.rand(8, H, W), height=H, width=W), + >>> BitmapMasks(rng.rand(6, H, W), height=H, width=W), + >>> ] + >>> mask_targets = mask_target( + >>> pos_proposals_list, pos_assigned_gt_inds_list, + >>> gt_masks_list, cfg) + >>> assert mask_targets.shape == (5,) + cfg['mask_size'] + """ + cfg_list = [cfg for _ in range(len(pos_proposals_list))] + mask_targets = map(mask_target_single, pos_proposals_list, + pos_assigned_gt_inds_list, gt_masks_list, cfg_list) + mask_targets = list(mask_targets) + if len(mask_targets) > 0: + mask_targets = torch.cat(mask_targets) + return mask_targets + + +def mask_target_single(pos_proposals, pos_assigned_gt_inds, gt_masks, cfg): + """Compute mask target for each positive proposal in the image. + + Args: + pos_proposals (Tensor): Positive proposals. + pos_assigned_gt_inds (Tensor): Assigned GT inds of positive proposals. + gt_masks (:obj:`BaseInstanceMasks`): GT masks in the format of Bitmap + or Polygon. + cfg (dict): Config dict that indicate the mask size. + + Returns: + Tensor: Mask target of each positive proposals in the image. 
+ + Example: + >>> from mmengine.config import Config + >>> import mmdet + >>> from mmdet.data_elements.mask import BitmapMasks + >>> from mmdet.data_elements.mask.mask_target import * # NOQA + >>> H, W = 32, 32 + >>> cfg = Config({'mask_size': (7, 11)}) + >>> rng = np.random.RandomState(0) + >>> # Masks for each ground truth box (relative to the image) + >>> gt_masks_data = rng.rand(3, H, W) + >>> gt_masks = BitmapMasks(gt_masks_data, height=H, width=W) + >>> # Predicted positive boxes in one image + >>> pos_proposals = torch.FloatTensor([ + >>> [ 16.2, 5.5, 19.9, 20.9], + >>> [ 17.3, 13.6, 19.3, 19.3], + >>> [ 14.8, 16.4, 17.0, 23.7], + >>> [ 0.0, 0.0, 16.0, 16.0], + >>> [ 4.0, 0.0, 20.0, 16.0], + >>> ]) + >>> # For each predicted proposal, its assignment to a gt mask + >>> pos_assigned_gt_inds = torch.LongTensor([0, 1, 2, 1, 1]) + >>> mask_targets = mask_target_single( + >>> pos_proposals, pos_assigned_gt_inds, gt_masks, cfg) + >>> assert mask_targets.shape == (5,) + cfg['mask_size'] + """ + device = pos_proposals.device + mask_size = _pair(cfg.mask_size) + binarize = not cfg.get('soft_mask_target', False) + num_pos = pos_proposals.size(0) + if num_pos > 0: + proposals_np = pos_proposals.cpu().numpy() + maxh, maxw = gt_masks.height, gt_masks.width + proposals_np[:, [0, 2]] = np.clip(proposals_np[:, [0, 2]], 0, maxw) + proposals_np[:, [1, 3]] = np.clip(proposals_np[:, [1, 3]], 0, maxh) + pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy() + + mask_targets = gt_masks.crop_and_resize( + proposals_np, + mask_size, + device=device, + inds=pos_assigned_gt_inds, + binarize=binarize).to_ndarray() + + mask_targets = torch.from_numpy(mask_targets).float().to(device) + else: + mask_targets = pos_proposals.new_zeros((0, ) + mask_size) + + return mask_targets diff --git a/mmdet/structures/mask/structures.py b/mmdet/structures/mask/structures.py new file mode 100644 index 0000000000000000000000000000000000000000..b4fdd27570b0d11d92eba4e8f854e153750135a4 --- /dev/null +++ b/mmdet/structures/mask/structures.py @@ -0,0 +1,1193 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import itertools +from abc import ABCMeta, abstractmethod +from typing import Sequence, Type, TypeVar + +import cv2 +import mmcv +import numpy as np +import pycocotools.mask as maskUtils +import shapely.geometry as geometry +import torch +from mmcv.ops.roi_align import roi_align + +T = TypeVar('T') + + +class BaseInstanceMasks(metaclass=ABCMeta): + """Base class for instance masks.""" + + @abstractmethod + def rescale(self, scale, interpolation='nearest'): + """Rescale masks as large as possible while keeping the aspect ratio. + For details can refer to `mmcv.imrescale`. + + Args: + scale (tuple[int]): The maximum size (h, w) of rescaled mask. + interpolation (str): Same as :func:`mmcv.imrescale`. + + Returns: + BaseInstanceMasks: The rescaled masks. + """ + + @abstractmethod + def resize(self, out_shape, interpolation='nearest'): + """Resize masks to the given out_shape. + + Args: + out_shape: Target (h, w) of resized mask. + interpolation (str): See :func:`mmcv.imresize`. + + Returns: + BaseInstanceMasks: The resized masks. + """ + + @abstractmethod + def flip(self, flip_direction='horizontal'): + """Flip masks alone the given direction. + + Args: + flip_direction (str): Either 'horizontal' or 'vertical'. + + Returns: + BaseInstanceMasks: The flipped masks. + """ + + @abstractmethod + def pad(self, out_shape, pad_val): + """Pad masks to the given size of (h, w). 
+ + Args: + out_shape (tuple[int]): Target (h, w) of padded mask. + pad_val (int): The padded value. + + Returns: + BaseInstanceMasks: The padded masks. + """ + + @abstractmethod + def crop(self, bbox): + """Crop each mask by the given bbox. + + Args: + bbox (ndarray): Bbox in format [x1, y1, x2, y2], shape (4, ). + + Return: + BaseInstanceMasks: The cropped masks. + """ + + @abstractmethod + def crop_and_resize(self, + bboxes, + out_shape, + inds, + device, + interpolation='bilinear', + binarize=True): + """Crop and resize masks by the given bboxes. + + This function is mainly used in mask targets computation. + It firstly align mask to bboxes by assigned_inds, then crop mask by the + assigned bbox and resize to the size of (mask_h, mask_w) + + Args: + bboxes (Tensor): Bboxes in format [x1, y1, x2, y2], shape (N, 4) + out_shape (tuple[int]): Target (h, w) of resized mask + inds (ndarray): Indexes to assign masks to each bbox, + shape (N,) and values should be between [0, num_masks - 1]. + device (str): Device of bboxes + interpolation (str): See `mmcv.imresize` + binarize (bool): if True fractional values are rounded to 0 or 1 + after the resize operation. if False and unsupported an error + will be raised. Defaults to True. + + Return: + BaseInstanceMasks: the cropped and resized masks. + """ + + @abstractmethod + def expand(self, expanded_h, expanded_w, top, left): + """see :class:`Expand`.""" + + @property + @abstractmethod + def areas(self): + """ndarray: areas of each instance.""" + + @abstractmethod + def to_ndarray(self): + """Convert masks to the format of ndarray. + + Return: + ndarray: Converted masks in the format of ndarray. + """ + + @abstractmethod + def to_tensor(self, dtype, device): + """Convert masks to the format of Tensor. + + Args: + dtype (str): Dtype of converted mask. + device (torch.device): Device of converted masks. + + Returns: + Tensor: Converted masks in the format of Tensor. + """ + + @abstractmethod + def translate(self, + out_shape, + offset, + direction='horizontal', + border_value=0, + interpolation='bilinear'): + """Translate the masks. + + Args: + out_shape (tuple[int]): Shape for output mask, format (h, w). + offset (int | float): The offset for translate. + direction (str): The translate direction, either "horizontal" + or "vertical". + border_value (int | float): Border value. Default 0. + interpolation (str): Same as :func:`mmcv.imtranslate`. + + Returns: + Translated masks. + """ + + def shear(self, + out_shape, + magnitude, + direction='horizontal', + border_value=0, + interpolation='bilinear'): + """Shear the masks. + + Args: + out_shape (tuple[int]): Shape for output mask, format (h, w). + magnitude (int | float): The magnitude used for shear. + direction (str): The shear direction, either "horizontal" + or "vertical". + border_value (int | tuple[int]): Value used in case of a + constant border. Default 0. + interpolation (str): Same as in :func:`mmcv.imshear`. + + Returns: + ndarray: Sheared masks. + """ + + @abstractmethod + def rotate(self, out_shape, angle, center=None, scale=1.0, border_value=0): + """Rotate the masks. + + Args: + out_shape (tuple[int]): Shape for output mask, format (h, w). + angle (int | float): Rotation angle in degrees. Positive values + mean counter-clockwise rotation. + center (tuple[float], optional): Center point (w, h) of the + rotation in source image. If not specified, the center of + the image will be used. + scale (int | float): Isotropic scale factor. + border_value (int | float): Border value. 
Default 0 for masks. + + Returns: + Rotated masks. + """ + + def get_bboxes(self, dst_type='hbb'): + """Get the certain type boxes from masks. + + Please refer to ``mmdet.structures.bbox.box_type`` for more details of + the box type. + + Args: + dst_type: Destination box type. + + Returns: + :obj:`BaseBoxes`: Certain type boxes. + """ + from ..bbox import get_box_type + _, box_type_cls = get_box_type(dst_type) + return box_type_cls.from_instance_masks(self) + + @classmethod + @abstractmethod + def cat(cls: Type[T], masks: Sequence[T]) -> T: + """Concatenate a sequence of masks into one single mask instance. + + Args: + masks (Sequence[T]): A sequence of mask instances. + + Returns: + T: Concatenated mask instance. + """ + + +class BitmapMasks(BaseInstanceMasks): + """This class represents masks in the form of bitmaps. + + Args: + masks (ndarray): ndarray of masks in shape (N, H, W), where N is + the number of objects. + height (int): height of masks + width (int): width of masks + + Example: + >>> from mmdet.data_elements.mask.structures import * # NOQA + >>> num_masks, H, W = 3, 32, 32 + >>> rng = np.random.RandomState(0) + >>> masks = (rng.rand(num_masks, H, W) > 0.1).astype(np.int64) + >>> self = BitmapMasks(masks, height=H, width=W) + + >>> # demo crop_and_resize + >>> num_boxes = 5 + >>> bboxes = np.array([[0, 0, 30, 10.0]] * num_boxes) + >>> out_shape = (14, 14) + >>> inds = torch.randint(0, len(self), size=(num_boxes,)) + >>> device = 'cpu' + >>> interpolation = 'bilinear' + >>> new = self.crop_and_resize( + ... bboxes, out_shape, inds, device, interpolation) + >>> assert len(new) == num_boxes + >>> assert new.height, new.width == out_shape + """ + + def __init__(self, masks, height, width): + self.height = height + self.width = width + if len(masks) == 0: + self.masks = np.empty((0, self.height, self.width), dtype=np.uint8) + else: + assert isinstance(masks, (list, np.ndarray)) + if isinstance(masks, list): + assert isinstance(masks[0], np.ndarray) + assert masks[0].ndim == 2 # (H, W) + else: + assert masks.ndim == 3 # (N, H, W) + + self.masks = np.stack(masks).reshape(-1, height, width) + assert self.masks.shape[1] == self.height + assert self.masks.shape[2] == self.width + + def __getitem__(self, index): + """Index the BitmapMask. + + Args: + index (int | ndarray): Indices in the format of integer or ndarray. + + Returns: + :obj:`BitmapMasks`: Indexed bitmap masks. 
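+
+        Example:
+            >>> # illustrative sketch: indexing keeps the (N, H, W) layout
+            >>> m = BitmapMasks(np.zeros((4, 8, 8), dtype=np.uint8), 8, 8)
+            >>> len(m[np.array([0, 2])])
+            2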
+ """ + masks = self.masks[index].reshape(-1, self.height, self.width) + return BitmapMasks(masks, self.height, self.width) + + def __iter__(self): + return iter(self.masks) + + def __repr__(self): + s = self.__class__.__name__ + '(' + s += f'num_masks={len(self.masks)}, ' + s += f'height={self.height}, ' + s += f'width={self.width})' + return s + + def __len__(self): + """Number of masks.""" + return len(self.masks) + + def rescale(self, scale, interpolation='nearest'): + """See :func:`BaseInstanceMasks.rescale`.""" + if len(self.masks) == 0: + new_w, new_h = mmcv.rescale_size((self.width, self.height), scale) + rescaled_masks = np.empty((0, new_h, new_w), dtype=np.uint8) + else: + rescaled_masks = np.stack([ + mmcv.imrescale(mask, scale, interpolation=interpolation) + for mask in self.masks + ]) + height, width = rescaled_masks.shape[1:] + return BitmapMasks(rescaled_masks, height, width) + + def resize(self, out_shape, interpolation='nearest'): + """See :func:`BaseInstanceMasks.resize`.""" + if len(self.masks) == 0: + resized_masks = np.empty((0, *out_shape), dtype=np.uint8) + else: + resized_masks = np.stack([ + mmcv.imresize( + mask, out_shape[::-1], interpolation=interpolation) + for mask in self.masks + ]) + return BitmapMasks(resized_masks, *out_shape) + + def flip(self, flip_direction='horizontal'): + """See :func:`BaseInstanceMasks.flip`.""" + assert flip_direction in ('horizontal', 'vertical', 'diagonal') + + if len(self.masks) == 0: + flipped_masks = self.masks + else: + flipped_masks = np.stack([ + mmcv.imflip(mask, direction=flip_direction) + for mask in self.masks + ]) + return BitmapMasks(flipped_masks, self.height, self.width) + + def pad(self, out_shape, pad_val=0): + """See :func:`BaseInstanceMasks.pad`.""" + if len(self.masks) == 0: + padded_masks = np.empty((0, *out_shape), dtype=np.uint8) + else: + padded_masks = np.stack([ + mmcv.impad(mask, shape=out_shape, pad_val=pad_val) + for mask in self.masks + ]) + return BitmapMasks(padded_masks, *out_shape) + + def crop(self, bbox): + """See :func:`BaseInstanceMasks.crop`.""" + assert isinstance(bbox, np.ndarray) + assert bbox.ndim == 1 + + # clip the boundary + bbox = bbox.copy() + bbox[0::2] = np.clip(bbox[0::2], 0, self.width) + bbox[1::2] = np.clip(bbox[1::2], 0, self.height) + x1, y1, x2, y2 = bbox + w = np.maximum(x2 - x1, 1) + h = np.maximum(y2 - y1, 1) + + if len(self.masks) == 0: + cropped_masks = np.empty((0, h, w), dtype=np.uint8) + else: + cropped_masks = self.masks[:, y1:y1 + h, x1:x1 + w] + return BitmapMasks(cropped_masks, h, w) + + def crop_and_resize(self, + bboxes, + out_shape, + inds, + device='cpu', + interpolation='bilinear', + binarize=True): + """See :func:`BaseInstanceMasks.crop_and_resize`.""" + if len(self.masks) == 0: + empty_masks = np.empty((0, *out_shape), dtype=np.uint8) + return BitmapMasks(empty_masks, *out_shape) + + # convert bboxes to tensor + if isinstance(bboxes, np.ndarray): + bboxes = torch.from_numpy(bboxes).to(device=device) + if isinstance(inds, np.ndarray): + inds = torch.from_numpy(inds).to(device=device) + + num_bbox = bboxes.shape[0] + fake_inds = torch.arange( + num_bbox, device=device).to(dtype=bboxes.dtype)[:, None] + rois = torch.cat([fake_inds, bboxes], dim=1) # Nx5 + rois = rois.to(device=device) + if num_bbox > 0: + gt_masks_th = torch.from_numpy(self.masks).to(device).index_select( + 0, inds).to(dtype=rois.dtype) + targets = roi_align(gt_masks_th[:, None, :, :], rois, out_shape, + 1.0, 0, 'avg', True).squeeze(1) + if binarize: + resized_masks = (targets >= 
0.5).cpu().numpy() + else: + resized_masks = targets.cpu().numpy() + else: + resized_masks = [] + return BitmapMasks(resized_masks, *out_shape) + + def expand(self, expanded_h, expanded_w, top, left): + """See :func:`BaseInstanceMasks.expand`.""" + if len(self.masks) == 0: + expanded_mask = np.empty((0, expanded_h, expanded_w), + dtype=np.uint8) + else: + expanded_mask = np.zeros((len(self), expanded_h, expanded_w), + dtype=np.uint8) + expanded_mask[:, top:top + self.height, + left:left + self.width] = self.masks + return BitmapMasks(expanded_mask, expanded_h, expanded_w) + + def translate(self, + out_shape, + offset, + direction='horizontal', + border_value=0, + interpolation='bilinear'): + """Translate the BitmapMasks. + + Args: + out_shape (tuple[int]): Shape for output mask, format (h, w). + offset (int | float): The offset for translate. + direction (str): The translate direction, either "horizontal" + or "vertical". + border_value (int | float): Border value. Default 0 for masks. + interpolation (str): Same as :func:`mmcv.imtranslate`. + + Returns: + BitmapMasks: Translated BitmapMasks. + + Example: + >>> from mmdet.data_elements.mask.structures import BitmapMasks + >>> self = BitmapMasks.random(dtype=np.uint8) + >>> out_shape = (32, 32) + >>> offset = 4 + >>> direction = 'horizontal' + >>> border_value = 0 + >>> interpolation = 'bilinear' + >>> # Note, There seem to be issues when: + >>> # * the mask dtype is not supported by cv2.AffineWarp + >>> new = self.translate(out_shape, offset, direction, + >>> border_value, interpolation) + >>> assert len(new) == len(self) + >>> assert new.height, new.width == out_shape + """ + if len(self.masks) == 0: + translated_masks = np.empty((0, *out_shape), dtype=np.uint8) + else: + masks = self.masks + if masks.shape[-2:] != out_shape: + empty_masks = np.zeros((masks.shape[0], *out_shape), + dtype=masks.dtype) + min_h = min(out_shape[0], masks.shape[1]) + min_w = min(out_shape[1], masks.shape[2]) + empty_masks[:, :min_h, :min_w] = masks[:, :min_h, :min_w] + masks = empty_masks + translated_masks = mmcv.imtranslate( + masks.transpose((1, 2, 0)), + offset, + direction, + border_value=border_value, + interpolation=interpolation) + if translated_masks.ndim == 2: + translated_masks = translated_masks[:, :, None] + translated_masks = translated_masks.transpose( + (2, 0, 1)).astype(self.masks.dtype) + return BitmapMasks(translated_masks, *out_shape) + + def shear(self, + out_shape, + magnitude, + direction='horizontal', + border_value=0, + interpolation='bilinear'): + """Shear the BitmapMasks. + + Args: + out_shape (tuple[int]): Shape for output mask, format (h, w). + magnitude (int | float): The magnitude used for shear. + direction (str): The shear direction, either "horizontal" + or "vertical". + border_value (int | tuple[int]): Value used in case of a + constant border. + interpolation (str): Same as in :func:`mmcv.imshear`. + + Returns: + BitmapMasks: The sheared masks. 
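+
+        Example:
+            >>> # a minimal sketch, assuming mmcv.imshear is available
+            >>> self = BitmapMasks.random(num_masks=2, height=8, width=8)
+            >>> new = self.shear((8, 8), magnitude=0.5)
+            >>> assert len(new) == 2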
+ """ + if len(self.masks) == 0: + sheared_masks = np.empty((0, *out_shape), dtype=np.uint8) + else: + sheared_masks = mmcv.imshear( + self.masks.transpose((1, 2, 0)), + magnitude, + direction, + border_value=border_value, + interpolation=interpolation) + if sheared_masks.ndim == 2: + sheared_masks = sheared_masks[:, :, None] + sheared_masks = sheared_masks.transpose( + (2, 0, 1)).astype(self.masks.dtype) + return BitmapMasks(sheared_masks, *out_shape) + + def rotate(self, + out_shape, + angle, + center=None, + scale=1.0, + border_value=0, + interpolation='bilinear'): + """Rotate the BitmapMasks. + + Args: + out_shape (tuple[int]): Shape for output mask, format (h, w). + angle (int | float): Rotation angle in degrees. Positive values + mean counter-clockwise rotation. + center (tuple[float], optional): Center point (w, h) of the + rotation in source image. If not specified, the center of + the image will be used. + scale (int | float): Isotropic scale factor. + border_value (int | float): Border value. Default 0 for masks. + interpolation (str): Same as in :func:`mmcv.imrotate`. + + Returns: + BitmapMasks: Rotated BitmapMasks. + """ + if len(self.masks) == 0: + rotated_masks = np.empty((0, *out_shape), dtype=self.masks.dtype) + else: + rotated_masks = mmcv.imrotate( + self.masks.transpose((1, 2, 0)), + angle, + center=center, + scale=scale, + border_value=border_value, + interpolation=interpolation) + if rotated_masks.ndim == 2: + # case when only one mask, (h, w) + rotated_masks = rotated_masks[:, :, None] # (h, w, 1) + rotated_masks = rotated_masks.transpose( + (2, 0, 1)).astype(self.masks.dtype) + return BitmapMasks(rotated_masks, *out_shape) + + @property + def areas(self): + """See :py:attr:`BaseInstanceMasks.areas`.""" + return self.masks.sum((1, 2)) + + def to_ndarray(self): + """See :func:`BaseInstanceMasks.to_ndarray`.""" + return self.masks + + def to_tensor(self, dtype, device): + """See :func:`BaseInstanceMasks.to_tensor`.""" + return torch.tensor(self.masks, dtype=dtype, device=device) + + @classmethod + def random(cls, + num_masks=3, + height=32, + width=32, + dtype=np.uint8, + rng=None): + """Generate random bitmap masks for demo / testing purposes. + + Example: + >>> from mmdet.data_elements.mask.structures import BitmapMasks + >>> self = BitmapMasks.random() + >>> print('self = {}'.format(self)) + self = BitmapMasks(num_masks=3, height=32, width=32) + """ + from mmdet.utils.util_random import ensure_rng + rng = ensure_rng(rng) + masks = (rng.rand(num_masks, height, width) > 0.1).astype(dtype) + self = cls(masks, height=height, width=width) + return self + + @classmethod + def cat(cls: Type[T], masks: Sequence[T]) -> T: + """Concatenate a sequence of masks into one single mask instance. + + Args: + masks (Sequence[BitmapMasks]): A sequence of mask instances. + + Returns: + BitmapMasks: Concatenated mask instance. + """ + assert isinstance(masks, Sequence) + if len(masks) == 0: + raise ValueError('masks should not be an empty list.') + assert all(isinstance(m, cls) for m in masks) + + mask_array = np.concatenate([m.masks for m in masks], axis=0) + return cls(mask_array, *mask_array.shape[1:]) + + +class PolygonMasks(BaseInstanceMasks): + """This class represents masks in the form of polygons. + + Polygons is a list of three levels. 
The first level of the list + corresponds to objects, the second level to the polys that compose the + object, the third level to the poly coordinates + + Args: + masks (list[list[ndarray]]): The first level of the list + corresponds to objects, the second level to the polys that + compose the object, the third level to the poly coordinates + height (int): height of masks + width (int): width of masks + + Example: + >>> from mmdet.data_elements.mask.structures import * # NOQA + >>> masks = [ + >>> [ np.array([0, 0, 10, 0, 10, 10., 0, 10, 0, 0]) ] + >>> ] + >>> height, width = 16, 16 + >>> self = PolygonMasks(masks, height, width) + + >>> # demo translate + >>> new = self.translate((16, 16), 4., direction='horizontal') + >>> assert np.all(new.masks[0][0][1::2] == masks[0][0][1::2]) + >>> assert np.all(new.masks[0][0][0::2] == masks[0][0][0::2] + 4) + + >>> # demo crop_and_resize + >>> num_boxes = 3 + >>> bboxes = np.array([[0, 0, 30, 10.0]] * num_boxes) + >>> out_shape = (16, 16) + >>> inds = torch.randint(0, len(self), size=(num_boxes,)) + >>> device = 'cpu' + >>> interpolation = 'bilinear' + >>> new = self.crop_and_resize( + ... bboxes, out_shape, inds, device, interpolation) + >>> assert len(new) == num_boxes + >>> assert new.height, new.width == out_shape + """ + + def __init__(self, masks, height, width): + assert isinstance(masks, list) + if len(masks) > 0: + assert isinstance(masks[0], list) + assert isinstance(masks[0][0], np.ndarray) + + self.height = height + self.width = width + self.masks = masks + + def __getitem__(self, index): + """Index the polygon masks. + + Args: + index (ndarray | List): The indices. + + Returns: + :obj:`PolygonMasks`: The indexed polygon masks. + """ + if isinstance(index, np.ndarray): + if index.dtype == bool: + index = np.where(index)[0].tolist() + else: + index = index.tolist() + if isinstance(index, list): + masks = [self.masks[i] for i in index] + else: + try: + masks = self.masks[index] + except Exception: + raise ValueError( + f'Unsupported input of type {type(index)} for indexing!') + if len(masks) and isinstance(masks[0], np.ndarray): + masks = [masks] # ensure a list of three levels + return PolygonMasks(masks, self.height, self.width) + + def __iter__(self): + return iter(self.masks) + + def __repr__(self): + s = self.__class__.__name__ + '(' + s += f'num_masks={len(self.masks)}, ' + s += f'height={self.height}, ' + s += f'width={self.width})' + return s + + def __len__(self): + """Number of masks.""" + return len(self.masks) + + def rescale(self, scale, interpolation=None): + """see :func:`BaseInstanceMasks.rescale`""" + new_w, new_h = mmcv.rescale_size((self.width, self.height), scale) + if len(self.masks) == 0: + rescaled_masks = PolygonMasks([], new_h, new_w) + else: + rescaled_masks = self.resize((new_h, new_w)) + return rescaled_masks + + def resize(self, out_shape, interpolation=None): + """see :func:`BaseInstanceMasks.resize`""" + if len(self.masks) == 0: + resized_masks = PolygonMasks([], *out_shape) + else: + h_scale = out_shape[0] / self.height + w_scale = out_shape[1] / self.width + resized_masks = [] + for poly_per_obj in self.masks: + resized_poly = [] + for p in poly_per_obj: + p = p.copy() + p[0::2] = p[0::2] * w_scale + p[1::2] = p[1::2] * h_scale + resized_poly.append(p) + resized_masks.append(resized_poly) + resized_masks = PolygonMasks(resized_masks, *out_shape) + return resized_masks + + def flip(self, flip_direction='horizontal'): + """see :func:`BaseInstanceMasks.flip`""" + assert flip_direction in ('horizontal', 
'vertical', 'diagonal') + if len(self.masks) == 0: + flipped_masks = PolygonMasks([], self.height, self.width) + else: + flipped_masks = [] + for poly_per_obj in self.masks: + flipped_poly_per_obj = [] + for p in poly_per_obj: + p = p.copy() + if flip_direction == 'horizontal': + p[0::2] = self.width - p[0::2] + elif flip_direction == 'vertical': + p[1::2] = self.height - p[1::2] + else: + p[0::2] = self.width - p[0::2] + p[1::2] = self.height - p[1::2] + flipped_poly_per_obj.append(p) + flipped_masks.append(flipped_poly_per_obj) + flipped_masks = PolygonMasks(flipped_masks, self.height, + self.width) + return flipped_masks + + def crop(self, bbox): + """see :func:`BaseInstanceMasks.crop`""" + assert isinstance(bbox, np.ndarray) + assert bbox.ndim == 1 + + # clip the boundary + bbox = bbox.copy() + bbox[0::2] = np.clip(bbox[0::2], 0, self.width) + bbox[1::2] = np.clip(bbox[1::2], 0, self.height) + x1, y1, x2, y2 = bbox + w = np.maximum(x2 - x1, 1) + h = np.maximum(y2 - y1, 1) + + if len(self.masks) == 0: + cropped_masks = PolygonMasks([], h, w) + else: + # reference: https://github.com/facebookresearch/fvcore/blob/main/fvcore/transforms/transform.py # noqa + crop_box = geometry.box(x1, y1, x2, y2).buffer(0.0) + cropped_masks = [] + # suppress shapely warnings util it incorporates GEOS>=3.11.2 + # reference: https://github.com/shapely/shapely/issues/1345 + initial_settings = np.seterr() + np.seterr(invalid='ignore') + for poly_per_obj in self.masks: + cropped_poly_per_obj = [] + for p in poly_per_obj: + p = p.copy() + p = geometry.Polygon(p.reshape(-1, 2)).buffer(0.0) + # polygon must be valid to perform intersection. + if not p.is_valid: + continue + cropped = p.intersection(crop_box) + if cropped.is_empty: + continue + if isinstance(cropped, + geometry.collection.BaseMultipartGeometry): + cropped = cropped.geoms + else: + cropped = [cropped] + # one polygon may be cropped to multiple ones + for poly in cropped: + # ignore lines or points + if not isinstance( + poly, geometry.Polygon) or not poly.is_valid: + continue + coords = np.asarray(poly.exterior.coords) + # remove an extra identical vertex at the end + coords = coords[:-1] + coords[:, 0] -= x1 + coords[:, 1] -= y1 + cropped_poly_per_obj.append(coords.reshape(-1)) + # a dummy polygon to avoid misalignment between masks and boxes + if len(cropped_poly_per_obj) == 0: + cropped_poly_per_obj = [np.array([0, 0, 0, 0, 0, 0])] + cropped_masks.append(cropped_poly_per_obj) + np.seterr(**initial_settings) + cropped_masks = PolygonMasks(cropped_masks, h, w) + return cropped_masks + + def pad(self, out_shape, pad_val=0): + """padding has no effect on polygons`""" + return PolygonMasks(self.masks, *out_shape) + + def expand(self, *args, **kwargs): + """TODO: Add expand for polygon""" + raise NotImplementedError + + def crop_and_resize(self, + bboxes, + out_shape, + inds, + device='cpu', + interpolation='bilinear', + binarize=True): + """see :func:`BaseInstanceMasks.crop_and_resize`""" + out_h, out_w = out_shape + if len(self.masks) == 0: + return PolygonMasks([], out_h, out_w) + + if not binarize: + raise ValueError('Polygons are always binary, ' + 'setting binarize=False is unsupported') + + resized_masks = [] + for i in range(len(bboxes)): + mask = self.masks[inds[i]] + bbox = bboxes[i, :] + x1, y1, x2, y2 = bbox + w = np.maximum(x2 - x1, 1) + h = np.maximum(y2 - y1, 1) + h_scale = out_h / max(h, 0.1) # avoid too large scale + w_scale = out_w / max(w, 0.1) + + resized_mask = [] + for p in mask: + p = p.copy() + # crop + # pycocotools will clip 
the boundary + p[0::2] = p[0::2] - bbox[0] + p[1::2] = p[1::2] - bbox[1] + + # resize + p[0::2] = p[0::2] * w_scale + p[1::2] = p[1::2] * h_scale + resized_mask.append(p) + resized_masks.append(resized_mask) + return PolygonMasks(resized_masks, *out_shape) + + def translate(self, + out_shape, + offset, + direction='horizontal', + border_value=None, + interpolation=None): + """Translate the PolygonMasks. + + Example: + >>> self = PolygonMasks.random(dtype=np.int64) + >>> out_shape = (self.height, self.width) + >>> new = self.translate(out_shape, 4., direction='horizontal') + >>> assert np.all(new.masks[0][0][1::2] == self.masks[0][0][1::2]) + >>> assert np.all(new.masks[0][0][0::2] == self.masks[0][0][0::2] + 4) # noqa: E501 + """ + assert border_value is None or border_value == 0, \ + 'Here border_value is not '\ + f'used, and defaultly should be None or 0. got {border_value}.' + if len(self.masks) == 0: + translated_masks = PolygonMasks([], *out_shape) + else: + translated_masks = [] + for poly_per_obj in self.masks: + translated_poly_per_obj = [] + for p in poly_per_obj: + p = p.copy() + if direction == 'horizontal': + p[0::2] = np.clip(p[0::2] + offset, 0, out_shape[1]) + elif direction == 'vertical': + p[1::2] = np.clip(p[1::2] + offset, 0, out_shape[0]) + translated_poly_per_obj.append(p) + translated_masks.append(translated_poly_per_obj) + translated_masks = PolygonMasks(translated_masks, *out_shape) + return translated_masks + + def shear(self, + out_shape, + magnitude, + direction='horizontal', + border_value=0, + interpolation='bilinear'): + """See :func:`BaseInstanceMasks.shear`.""" + if len(self.masks) == 0: + sheared_masks = PolygonMasks([], *out_shape) + else: + sheared_masks = [] + if direction == 'horizontal': + shear_matrix = np.stack([[1, magnitude], + [0, 1]]).astype(np.float32) + elif direction == 'vertical': + shear_matrix = np.stack([[1, 0], [magnitude, + 1]]).astype(np.float32) + for poly_per_obj in self.masks: + sheared_poly = [] + for p in poly_per_obj: + p = np.stack([p[0::2], p[1::2]], axis=0) # [2, n] + new_coords = np.matmul(shear_matrix, p) # [2, n] + new_coords[0, :] = np.clip(new_coords[0, :], 0, + out_shape[1]) + new_coords[1, :] = np.clip(new_coords[1, :], 0, + out_shape[0]) + sheared_poly.append( + new_coords.transpose((1, 0)).reshape(-1)) + sheared_masks.append(sheared_poly) + sheared_masks = PolygonMasks(sheared_masks, *out_shape) + return sheared_masks + + def rotate(self, + out_shape, + angle, + center=None, + scale=1.0, + border_value=0, + interpolation='bilinear'): + """See :func:`BaseInstanceMasks.rotate`.""" + if len(self.masks) == 0: + rotated_masks = PolygonMasks([], *out_shape) + else: + rotated_masks = [] + rotate_matrix = cv2.getRotationMatrix2D(center, -angle, scale) + for poly_per_obj in self.masks: + rotated_poly = [] + for p in poly_per_obj: + p = p.copy() + coords = np.stack([p[0::2], p[1::2]], axis=1) # [n, 2] + # pad 1 to convert from format [x, y] to homogeneous + # coordinates format [x, y, 1] + coords = np.concatenate( + (coords, np.ones((coords.shape[0], 1), coords.dtype)), + axis=1) # [n, 3] + rotated_coords = np.matmul( + rotate_matrix[None, :, :], + coords[:, :, None])[..., 0] # [n, 2, 1] -> [n, 2] + rotated_coords[:, 0] = np.clip(rotated_coords[:, 0], 0, + out_shape[1]) + rotated_coords[:, 1] = np.clip(rotated_coords[:, 1], 0, + out_shape[0]) + rotated_poly.append(rotated_coords.reshape(-1)) + rotated_masks.append(rotated_poly) + rotated_masks = PolygonMasks(rotated_masks, *out_shape) + return rotated_masks + + def 
to_bitmap(self): + """convert polygon masks to bitmap masks.""" + bitmap_masks = self.to_ndarray() + return BitmapMasks(bitmap_masks, self.height, self.width) + + @property + def areas(self): + """Compute areas of masks. + + This func is modified from `detectron2 + `_. + The function only works with Polygons using the shoelace formula. + + Return: + ndarray: areas of each instance + """ # noqa: W501 + area = [] + for polygons_per_obj in self.masks: + area_per_obj = 0 + for p in polygons_per_obj: + area_per_obj += self._polygon_area(p[0::2], p[1::2]) + area.append(area_per_obj) + return np.asarray(area) + + def _polygon_area(self, x, y): + """Compute the area of a component of a polygon. + + Using the shoelace formula: + https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates + + Args: + x (ndarray): x coordinates of the component + y (ndarray): y coordinates of the component + + Return: + float: the are of the component + """ # noqa: 501 + return 0.5 * np.abs( + np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))) + + def to_ndarray(self): + """Convert masks to the format of ndarray.""" + if len(self.masks) == 0: + return np.empty((0, self.height, self.width), dtype=np.uint8) + bitmap_masks = [] + for poly_per_obj in self.masks: + bitmap_masks.append( + polygon_to_bitmap(poly_per_obj, self.height, self.width)) + return np.stack(bitmap_masks) + + def to_tensor(self, dtype, device): + """See :func:`BaseInstanceMasks.to_tensor`.""" + if len(self.masks) == 0: + return torch.empty((0, self.height, self.width), + dtype=dtype, + device=device) + ndarray_masks = self.to_ndarray() + return torch.tensor(ndarray_masks, dtype=dtype, device=device) + + @classmethod + def random(cls, + num_masks=3, + height=32, + width=32, + n_verts=5, + dtype=np.float32, + rng=None): + """Generate random polygon masks for demo / testing purposes. + + Adapted from [1]_ + + References: + .. [1] https://gitlab.kitware.com/computer-vision/kwimage/-/blob/928cae35ca8/kwimage/structs/polygon.py#L379 # noqa: E501 + + Example: + >>> from mmdet.data_elements.mask.structures import PolygonMasks + >>> self = PolygonMasks.random() + >>> print('self = {}'.format(self)) + """ + from mmdet.utils.util_random import ensure_rng + rng = ensure_rng(rng) + + def _gen_polygon(n, irregularity, spikeyness): + """Creates the polygon by sampling points on a circle around the + centre. Random noise is added by varying the angular spacing + between sequential points, and by varying the radial distance of + each point from the centre. + + Based on original code by Mike Ounsworth + + Args: + n (int): number of vertices + irregularity (float): [0,1] indicating how much variance there + is in the angular spacing of vertices. [0,1] will map to + [0, 2pi/numberOfVerts] + spikeyness (float): [0,1] indicating how much variance there is + in each vertex from the circle of radius aveRadius. [0,1] + will map to [0, aveRadius] + + Returns: + a list of vertices, in CCW order. 
+ """ + from scipy.stats import truncnorm + + # Generate around the unit circle + cx, cy = (0.0, 0.0) + radius = 1 + + tau = np.pi * 2 + + irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / n + spikeyness = np.clip(spikeyness, 1e-9, 1) + + # generate n angle steps + lower = (tau / n) - irregularity + upper = (tau / n) + irregularity + angle_steps = rng.uniform(lower, upper, n) + + # normalize the steps so that point 0 and point n+1 are the same + k = angle_steps.sum() / (2 * np.pi) + angles = (angle_steps / k).cumsum() + rng.uniform(0, tau) + + # Convert high and low values to be wrt the standard normal range + # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.truncnorm.html + low = 0 + high = 2 * radius + mean = radius + std = spikeyness + a = (low - mean) / std + b = (high - mean) / std + tnorm = truncnorm(a=a, b=b, loc=mean, scale=std) + + # now generate the points + radii = tnorm.rvs(n, random_state=rng) + x_pts = cx + radii * np.cos(angles) + y_pts = cy + radii * np.sin(angles) + + points = np.hstack([x_pts[:, None], y_pts[:, None]]) + + # Scale to 0-1 space + points = points - points.min(axis=0) + points = points / points.max(axis=0) + + # Randomly place within 0-1 space + points = points * (rng.rand() * .8 + .2) + min_pt = points.min(axis=0) + max_pt = points.max(axis=0) + + high = (1 - max_pt) + low = (0 - min_pt) + offset = (rng.rand(2) * (high - low)) + low + points = points + offset + return points + + def _order_vertices(verts): + """ + References: + https://stackoverflow.com/questions/1709283/how-can-i-sort-a-coordinate-list-for-a-rectangle-counterclockwise + """ + mlat = verts.T[0].sum() / len(verts) + mlng = verts.T[1].sum() / len(verts) + + tau = np.pi * 2 + angle = (np.arctan2(mlat - verts.T[0], verts.T[1] - mlng) + + tau) % tau + sortx = angle.argsort() + verts = verts.take(sortx, axis=0) + return verts + + # Generate a random exterior for each requested mask + masks = [] + for _ in range(num_masks): + exterior = _order_vertices(_gen_polygon(n_verts, 0.9, 0.9)) + exterior = (exterior * [(width, height)]).astype(dtype) + masks.append([exterior.ravel()]) + + self = cls(masks, height, width) + return self + + @classmethod + def cat(cls: Type[T], masks: Sequence[T]) -> T: + """Concatenate a sequence of masks into one single mask instance. + + Args: + masks (Sequence[PolygonMasks]): A sequence of mask instances. + + Returns: + PolygonMasks: Concatenated mask instance. + """ + assert isinstance(masks, Sequence) + if len(masks) == 0: + raise ValueError('masks should not be an empty list.') + assert all(isinstance(m, cls) for m in masks) + + mask_list = list(itertools.chain(*[m.masks for m in masks])) + return cls(mask_list, masks[0].height, masks[0].width) + + +def polygon_to_bitmap(polygons, height, width): + """Convert masks from the form of polygons to bitmaps. + + Args: + polygons (list[ndarray]): masks in polygon representation + height (int): mask height + width (int): mask width + + Return: + ndarray: the converted masks in bitmap representation + """ + rles = maskUtils.frPyObjects(polygons, height, width) + rle = maskUtils.merge(rles) + bitmap_mask = maskUtils.decode(rle).astype(bool) + return bitmap_mask + + +def bitmap_to_polygon(bitmap): + """Convert masks from the form of bitmaps to polygons. + + Args: + bitmap (ndarray): masks in bitmap representation. + + Return: + list[ndarray]: the converted mask in polygon representation. + bool: whether the mask has holes. 
+ """ + bitmap = np.ascontiguousarray(bitmap).astype(np.uint8) + # cv2.RETR_CCOMP: retrieves all of the contours and organizes them + # into a two-level hierarchy. At the top level, there are external + # boundaries of the components. At the second level, there are + # boundaries of the holes. If there is another contour inside a hole + # of a connected component, it is still put at the top level. + # cv2.CHAIN_APPROX_NONE: stores absolutely all the contour points. + outs = cv2.findContours(bitmap, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE) + contours = outs[-2] + hierarchy = outs[-1] + if hierarchy is None: + return [], False + # hierarchy[i]: 4 elements, for the indexes of next, previous, + # parent, or nested contours. If there is no corresponding contour, + # it will be -1. + with_hole = (hierarchy.reshape(-1, 4)[:, 3] >= 0).any() + contours = [c.reshape(-1, 2) for c in contours] + return contours, with_hole diff --git a/mmdet/structures/mask/utils.py b/mmdet/structures/mask/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd445e4fce1a312949f222d54d230a1a622d726 --- /dev/null +++ b/mmdet/structures/mask/utils.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import pycocotools.mask as mask_util +import torch +from mmengine.utils import slice_list + + +def split_combined_polys(polys, poly_lens, polys_per_mask): + """Split the combined 1-D polys into masks. + + A mask is represented as a list of polys, and a poly is represented as + a 1-D array. In dataset, all masks are concatenated into a single 1-D + tensor. Here we need to split the tensor into original representations. + + Args: + polys (list): a list (length = image num) of 1-D tensors + poly_lens (list): a list (length = image num) of poly length + polys_per_mask (list): a list (length = image num) of poly number + of each mask + + Returns: + list: a list (length = image num) of list (length = mask num) of \ + list (length = poly num) of numpy array. + """ + mask_polys_list = [] + for img_id in range(len(polys)): + polys_single = polys[img_id] + polys_lens_single = poly_lens[img_id].tolist() + polys_per_mask_single = polys_per_mask[img_id].tolist() + + split_polys = slice_list(polys_single, polys_lens_single) + mask_polys = slice_list(split_polys, polys_per_mask_single) + mask_polys_list.append(mask_polys) + return mask_polys_list + + +# TODO: move this function to more proper place +def encode_mask_results(mask_results): + """Encode bitmap mask to RLE code. + + Args: + mask_results (list): bitmap mask results. + + Returns: + list | tuple: RLE encoded mask. + """ + encoded_mask_results = [] + for mask in mask_results: + encoded_mask_results.append( + mask_util.encode( + np.array(mask[:, :, np.newaxis], order='F', + dtype='uint8'))[0]) # encoded with RLE + return encoded_mask_results + + +def mask2bbox(masks): + """Obtain tight bounding boxes of binary masks. + + Args: + masks (Tensor): Binary mask of shape (n, h, w). + + Returns: + Tensor: Bboxe with shape (n, 4) of \ + positive region in binary mask. 
+ """ + N = masks.shape[0] + bboxes = masks.new_zeros((N, 4), dtype=torch.float32) + x_any = torch.any(masks, dim=1) + y_any = torch.any(masks, dim=2) + for i in range(N): + x = torch.where(x_any[i, :])[0] + y = torch.where(y_any[i, :])[0] + if len(x) > 0 and len(y) > 0: + bboxes[i, :] = bboxes.new_tensor( + [x[0], y[0], x[-1] + 1, y[-1] + 1]) + + return bboxes diff --git a/mmdet/structures/reid_data_sample.py b/mmdet/structures/reid_data_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..69958eece3671c9040c1f5561e724ca2d5f8e155 --- /dev/null +++ b/mmdet/structures/reid_data_sample.py @@ -0,0 +1,123 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from numbers import Number +from typing import Sequence, Union + +import mmengine +import numpy as np +import torch +from mmengine.structures import BaseDataElement, LabelData + + +def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int], + num_classes: int = None) -> LabelData: + """Convert label of various python types to :obj:`mmengine.LabelData`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int`. + + Args: + value (torch.Tensor | numpy.ndarray | Sequence | int): Label value. + num_classes (int, optional): The number of classes. If not None, set + it to the metainfo. Defaults to None. + + Returns: + :obj:`mmengine.LabelData`: The foramtted label data. + """ + + # Handle single number + if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0: + value = int(value.item()) + + if isinstance(value, np.ndarray): + value = torch.from_numpy(value) + elif isinstance(value, Sequence) and not mmengine.utils.is_str(value): + value = torch.tensor(value) + elif isinstance(value, int): + value = torch.LongTensor([value]) + elif not isinstance(value, torch.Tensor): + raise TypeError(f'Type {type(value)} is not an available label type.') + + metainfo = {} + if num_classes is not None: + metainfo['num_classes'] = num_classes + if value.max() >= num_classes: + raise ValueError(f'The label data ({value}) should not ' + f'exceed num_classes ({num_classes}).') + label = LabelData(label=value, metainfo=metainfo) + return label + + +class ReIDDataSample(BaseDataElement): + """A data structure interface of ReID task. + + It's used as interfaces between different components. + + Meta field: + img_shape (Tuple): The shape of the corresponding input image. + Used for visualization. + ori_shape (Tuple): The original shape of the corresponding image. + Used for visualization. + num_classes (int): The number of all categories. + Used for label format conversion. + + Data field: + gt_label (LabelData): The ground truth label. + pred_label (LabelData): The predicted label. + scores (torch.Tensor): The outputs of model. 
+ """ + + @property + def gt_label(self): + return self._gt_label + + @gt_label.setter + def gt_label(self, value: LabelData): + self.set_field(value, '_gt_label', dtype=LabelData) + + @gt_label.deleter + def gt_label(self): + del self._gt_label + + def set_gt_label( + self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number] + ) -> 'ReIDDataSample': + """Set label of ``gt_label``.""" + label = format_label(value, self.get('num_classes')) + if 'gt_label' in self: # setting for the second time + self.gt_label.label = label.label + else: # setting for the first time + self.gt_label = label + return self + + def set_gt_score(self, value: torch.Tensor) -> 'ReIDDataSample': + """Set score of ``gt_label``.""" + assert isinstance(value, torch.Tensor), \ + f'The value should be a torch.Tensor but got {type(value)}.' + assert value.ndim == 1, \ + f'The dims of value should be 1, but got {value.ndim}.' + + if 'num_classes' in self: + assert value.size(0) == self.num_classes, \ + f"The length of value ({value.size(0)}) doesn't "\ + f'match the num_classes ({self.num_classes}).' + metainfo = {'num_classes': self.num_classes} + else: + metainfo = {'num_classes': value.size(0)} + + if 'gt_label' in self: # setting for the second time + self.gt_label.score = value + else: # setting for the first time + self.gt_label = LabelData(score=value, metainfo=metainfo) + return self + + @property + def pred_feature(self): + return self._pred_feature + + @pred_feature.setter + def pred_feature(self, value: torch.Tensor): + self.set_field(value, '_pred_feature', dtype=torch.Tensor) + + @pred_feature.deleter + def pred_feature(self): + del self._pred_feature diff --git a/mmdet/structures/track_data_sample.py b/mmdet/structures/track_data_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..d005a5a42f57682d0b76d60d3dae463c4b4dc727 --- /dev/null +++ b/mmdet/structures/track_data_sample.py @@ -0,0 +1,273 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence + +import numpy as np +import torch +from mmengine.structures import BaseDataElement + +from .det_data_sample import DetDataSample + + +class TrackDataSample(BaseDataElement): + """A data structure interface of tracking task in MMDetection. It is used + as interfaces between different components. + + This data structure can be viewd as a wrapper of multiple DetDataSample to + some extent. Specifically, it only contains a property: + ``video_data_samples`` which is a list of DetDataSample, each of which + corresponds to a single frame. If you want to get the property of a single + frame, you must first get the corresponding ``DetDataSample`` by indexing + and then get the property of the frame, such as ``gt_instances``, + ``pred_instances`` and so on. As for metainfo, it differs from + ``DetDataSample`` in that each value corresponds to the metainfo key is a + list where each element corresponds to information of a single frame. + + Examples: + >>> import torch + >>> from mmengine.structures import InstanceData + >>> from mmdet.structures import DetDataSample, TrackDataSample + >>> track_data_sample = TrackDataSample() + >>> # set the 1st frame + >>> frame1_data_sample = DetDataSample(metainfo=dict( + ... img_shape=(100, 100), frame_id=0)) + >>> frame1_gt_instances = InstanceData() + >>> frame1_gt_instances.bbox = torch.zeros([2, 4]) + >>> frame1_data_sample.gt_instances = frame1_gt_instances + >>> # set the 2nd frame + >>> frame2_data_sample = DetDataSample(metainfo=dict( + ... 
img_shape=(100, 100), frame_id=1)) + >>> frame2_gt_instances = InstanceData() + >>> frame2_gt_instances.bbox = torch.ones([3, 4]) + >>> frame2_data_sample.gt_instances = frame2_gt_instances + >>> track_data_sample.video_data_samples = [frame1_data_sample, + ... frame2_data_sample] + >>> # set metainfo for track_data_sample + >>> track_data_sample.set_metainfo(dict(key_frames_inds=[0])) + >>> track_data_sample.set_metainfo(dict(ref_frames_inds=[1])) + >>> print(track_data_sample) + + ) at 0x7f64bd223340>, + ) at 0x7f64bd1346d0>] + ) at 0x7f64bd2237f0> + >>> print(len(track_data_sample)) + 2 + >>> key_data_sample = track_data_sample.get_key_frames() + >>> print(key_data_sample[0].frame_id) + 0 + >>> ref_data_sample = track_data_sample.get_ref_frames() + >>> print(ref_data_sample[0].frame_id) + 1 + >>> frame1_data_sample = track_data_sample[0] + >>> print(frame1_data_sample.gt_instances.bbox) + tensor([[0., 0., 0., 0.], + [0., 0., 0., 0.]]) + >>> # Tensor-like methods + >>> cuda_track_data_sample = track_data_sample.to('cuda') + >>> cuda_track_data_sample = track_data_sample.cuda() + >>> cpu_track_data_sample = track_data_sample.cpu() + >>> cpu_track_data_sample = track_data_sample.to('cpu') + >>> fp16_instances = cuda_track_data_sample.to( + ... device=None, dtype=torch.float16, non_blocking=False, + ... copy=False, memory_format=torch.preserve_format) + """ + + @property + def video_data_samples(self) -> List[DetDataSample]: + return self._video_data_samples + + @video_data_samples.setter + def video_data_samples(self, value: List[DetDataSample]): + if isinstance(value, DetDataSample): + value = [value] + assert isinstance(value, list), 'video_data_samples must be a list' + assert isinstance( + value[0], DetDataSample + ), 'video_data_samples must be a list of DetDataSample, but got ' + f'{value[0]}' + self.set_field(value, '_video_data_samples', dtype=list) + + @video_data_samples.deleter + def video_data_samples(self): + del self._video_data_samples + + def __getitem__(self, index): + assert hasattr(self, + '_video_data_samples'), 'video_data_samples not set' + return self._video_data_samples[index] + + def get_key_frames(self): + assert hasattr(self, 'key_frames_inds'), \ + 'key_frames_inds not set' + assert isinstance(self.key_frames_inds, Sequence) + key_frames_info = [] + for index in self.key_frames_inds: + key_frames_info.append(self[index]) + return key_frames_info + + def get_ref_frames(self): + assert hasattr(self, 'ref_frames_inds'), \ + 'ref_frames_inds not set' + ref_frames_info = [] + assert isinstance(self.ref_frames_inds, Sequence) + for index in self.ref_frames_inds: + ref_frames_info.append(self[index]) + return ref_frames_info + + def __len__(self): + return len(self._video_data_samples) if hasattr( + self, '_video_data_samples') else 0 + + # TODO: add UT for this Tensor-like method + # Tensor-like methods + def to(self, *args, **kwargs) -> 'BaseDataElement': + """Apply same name function to all tensors in data_fields.""" + new_data = self.new() + for k, v_list in self.items(): + data_list = [] + for v in v_list: + if hasattr(v, 'to'): + v = v.to(*args, **kwargs) + data_list.append(v) + if len(data_list) > 0: + new_data.set_data({f'{k}': data_list}) + return new_data + + # Tensor-like methods + def cpu(self) -> 'BaseDataElement': + """Convert all tensors to CPU in data.""" + new_data = self.new() + for k, v_list in self.items(): + data_list = [] + for v in v_list: + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.cpu() + data_list.append(v) + if len(data_list) 
> 0: + new_data.set_data({f'{k}': data_list}) + return new_data + + # Tensor-like methods + def cuda(self) -> 'BaseDataElement': + """Convert all tensors to GPU in data.""" + new_data = self.new() + for k, v_list in self.items(): + data_list = [] + for v in v_list: + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.cuda() + data_list.append(v) + if len(data_list) > 0: + new_data.set_data({f'{k}': data_list}) + return new_data + + # Tensor-like methods + def npu(self) -> 'BaseDataElement': + """Convert all tensors to NPU in data.""" + new_data = self.new() + for k, v_list in self.items(): + data_list = [] + for v in v_list: + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.npu() + data_list.append(v) + if len(data_list) > 0: + new_data.set_data({f'{k}': data_list}) + return new_data + + # Tensor-like methods + def detach(self) -> 'BaseDataElement': + """Detach all tensors in data.""" + new_data = self.new() + for k, v_list in self.items(): + data_list = [] + for v in v_list: + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.detach() + data_list.append(v) + if len(data_list) > 0: + new_data.set_data({f'{k}': data_list}) + return new_data + + # Tensor-like methods + def numpy(self) -> 'BaseDataElement': + """Convert all tensors to np.ndarray in data.""" + new_data = self.new() + for k, v_list in self.items(): + data_list = [] + for v in v_list: + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.detach().cpu().numpy() + data_list.append(v) + if len(data_list) > 0: + new_data.set_data({f'{k}': data_list}) + return new_data + + def to_tensor(self) -> 'BaseDataElement': + """Convert all np.ndarray to tensor in data.""" + new_data = self.new() + for k, v_list in self.items(): + data_list = [] + for v in v_list: + if isinstance(v, np.ndarray): + v = torch.from_numpy(v) + elif isinstance(v, BaseDataElement): + v = v.to_tensor() + data_list.append(v) + if len(data_list) > 0: + new_data.set_data({f'{k}': data_list}) + return new_data + + # Tensor-like methods + def clone(self) -> 'BaseDataElement': + """Deep copy the current data element. + + Returns: + BaseDataElement: The copy of current data element. + """ + clone_data = self.__class__() + clone_data.set_metainfo(dict(self.metainfo_items())) + + for k, v_list in self.items(): + clone_item_list = [] + for v in v_list: + clone_item_list.append(v.clone()) + clone_data.set_data({k: clone_item_list}) + return clone_data + + +TrackSampleList = List[TrackDataSample] +OptTrackSampleList = Optional[TrackSampleList] diff --git a/mmdet/testing/__init__.py b/mmdet/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..766fb471022ee6f2e4e1ff13a52040ae57772e53 --- /dev/null +++ b/mmdet/testing/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
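+# A minimal usage sketch of the helpers re-exported below (illustrative;
+# the exact keyword arguments and config path are assumptions):
+#   packed = demo_mm_inputs(batch_size=2, image_shapes=[(3, 128, 128)])
+#   model_cfg = get_detector_cfg('some_config.py')  # hypothetical path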
+from ._fast_stop_training_hook import FastStopTrainingHook  # noqa: F401,F403
+from ._utils import (demo_mm_inputs, demo_mm_proposals,
+                     demo_mm_sampling_results, demo_track_inputs,
+                     get_detector_cfg, get_roi_head_cfg, random_boxes,
+                     replace_to_ceph)
+
+__all__ = [
+    'demo_mm_inputs', 'get_detector_cfg', 'get_roi_head_cfg',
+    'demo_mm_proposals', 'demo_mm_sampling_results', 'replace_to_ceph',
+    'demo_track_inputs', 'random_boxes'
+]
diff --git a/mmdet/testing/_fast_stop_training_hook.py b/mmdet/testing/_fast_stop_training_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8e3d11439f875d2c9a6ce6b8a0b33acc832c2c5
--- /dev/null
+++ b/mmdet/testing/_fast_stop_training_hook.py
@@ -0,0 +1,27 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.hooks import Hook
+
+from mmdet.registry import HOOKS
+
+
+@HOOKS.register_module()
+class FastStopTrainingHook(Hook):
+    """Stop training after a few iterations or epochs to speed up tests."""
+
+    def __init__(self, by_epoch, save_ckpt=False, stop_iter_or_epoch=5):
+        self.by_epoch = by_epoch
+        self.save_ckpt = save_ckpt
+        self.stop_iter_or_epoch = stop_iter_or_epoch
+
+    def after_train_iter(self, runner, batch_idx: int, data_batch: None,
+                         outputs: None) -> None:
+        if self.save_ckpt and self.by_epoch:
+            # If it is epoch-based and we want to save weights,
+            # we must run at least 1 epoch.
+            return
+        if runner.iter >= self.stop_iter_or_epoch:
+            raise RuntimeError('quick exit')
+
+    def after_train_epoch(self, runner) -> None:
+        if runner.epoch >= self.stop_iter_or_epoch - 1:
+            raise RuntimeError('quick exit')
diff --git a/mmdet/testing/_utils.py b/mmdet/testing/_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4d3a86deab17e9c5acd1b1fe7f42e0bfa78943d
--- /dev/null
+++ b/mmdet/testing/_utils.py
@@ -0,0 +1,469 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from os.path import dirname, exists, join
+
+import numpy as np
+import torch
+from mmengine.config import Config
+from mmengine.dataset import pseudo_collate
+from mmengine.structures import InstanceData, PixelData
+
+from mmdet.utils.util_random import ensure_rng
+from ..registry import TASK_UTILS
+from ..structures import DetDataSample, TrackDataSample
+from ..structures.bbox import HorizontalBoxes
+
+
+def _get_config_directory():
+    """Find the predefined detector config directory."""
+    try:
+        # Assume we are running in the source mmdetection repo
+        repo_dpath = dirname(dirname(dirname(__file__)))
+    except NameError:
+        # For IPython development when this __file__ is not defined
+        import mmdet
+        repo_dpath = dirname(dirname(mmdet.__file__))
+    config_dpath = join(repo_dpath, 'configs')
+    if not exists(config_dpath):
+        raise Exception('Cannot find config path')
+    return config_dpath
+
+
+def _get_config_module(fname):
+    """Load a configuration as a python module."""
+    config_dpath = _get_config_directory()
+    config_fpath = join(config_dpath, fname)
+    config_mod = Config.fromfile(config_fpath)
+    return config_mod
+
+
+def get_detector_cfg(fname):
+    """Grab configs necessary to create a detector.
+
+    These are deep copied to allow for safe modification of parameters without
+    influencing other tests.
+    """
+    config = _get_config_module(fname)
+    model = copy.deepcopy(config.model)
+    return model
+
+
+def get_roi_head_cfg(fname):
+    """Grab configs necessary to create a roi_head.
+
+    These are deep copied to allow for safe modification of parameters without
+    influencing other tests.
+ """ + config = _get_config_module(fname) + model = copy.deepcopy(config.model) + + roi_head = model.roi_head + train_cfg = None if model.train_cfg is None else model.train_cfg.rcnn + test_cfg = None if model.test_cfg is None else model.test_cfg.rcnn + roi_head.update(dict(train_cfg=train_cfg, test_cfg=test_cfg)) + return roi_head + + +def _rand_bboxes(rng, num_boxes, w, h): + cx, cy, bw, bh = rng.rand(num_boxes, 4).T + + tl_x = ((cx * w) - (w * bw / 2)).clip(0, w) + tl_y = ((cy * h) - (h * bh / 2)).clip(0, h) + br_x = ((cx * w) + (w * bw / 2)).clip(0, w) + br_y = ((cy * h) + (h * bh / 2)).clip(0, h) + + bboxes = np.vstack([tl_x, tl_y, br_x, br_y]).T + return bboxes + + +def _rand_masks(rng, num_boxes, bboxes, img_w, img_h): + from mmdet.structures.mask import BitmapMasks + masks = np.zeros((num_boxes, img_h, img_w)) + for i, bbox in enumerate(bboxes): + bbox = bbox.astype(np.int32) + mask = (rng.rand(1, bbox[3] - bbox[1], bbox[2] - bbox[0]) > + 0.3).astype(np.int64) + masks[i:i + 1, bbox[1]:bbox[3], bbox[0]:bbox[2]] = mask + return BitmapMasks(masks, height=img_h, width=img_w) + + +def demo_mm_inputs(batch_size=2, + image_shapes=(3, 128, 128), + num_items=None, + num_classes=10, + sem_seg_output_strides=1, + with_mask=False, + with_semantic=False, + use_box_type=False, + device='cpu', + texts=None, + custom_entities=False): + """Create a superset of inputs needed to run test or train batches. + + Args: + batch_size (int): batch size. Defaults to 2. + image_shapes (List[tuple], Optional): image shape. + Defaults to (3, 128, 128) + num_items (None | List[int]): specifies the number + of boxes in each batch item. Default to None. + num_classes (int): number of different labels a + box might have. Defaults to 10. + with_mask (bool): Whether to return mask annotation. + Defaults to False. + with_semantic (bool): whether to return semantic. + Defaults to False. + device (str): Destination device type. Defaults to cpu. 
+ """ + rng = np.random.RandomState(0) + + if isinstance(image_shapes, list): + assert len(image_shapes) == batch_size + else: + image_shapes = [image_shapes] * batch_size + + if isinstance(num_items, list): + assert len(num_items) == batch_size + + if texts is not None: + assert batch_size == len(texts) + + packed_inputs = [] + for idx in range(batch_size): + image_shape = image_shapes[idx] + c, h, w = image_shape + + image = rng.randint(0, 255, size=image_shape, dtype=np.uint8) + + mm_inputs = dict() + mm_inputs['inputs'] = torch.from_numpy(image).to(device) + + img_meta = { + 'img_id': idx, + 'img_shape': image_shape[1:], + 'ori_shape': image_shape[1:], + 'filename': '.png', + 'scale_factor': np.array([1.1, 1.2]), + 'flip': False, + 'flip_direction': None, + 'border': [1, 1, 1, 1] # Only used by CenterNet + } + + if texts: + img_meta['text'] = texts[idx] + img_meta['custom_entities'] = custom_entities + + data_sample = DetDataSample() + data_sample.set_metainfo(img_meta) + + # gt_instances + gt_instances = InstanceData() + if num_items is None: + num_boxes = rng.randint(1, 10) + else: + num_boxes = num_items[idx] + + bboxes = _rand_bboxes(rng, num_boxes, w, h) + labels = rng.randint(1, num_classes, size=num_boxes) + # TODO: remove this part when all model adapted with BaseBoxes + if use_box_type: + gt_instances.bboxes = HorizontalBoxes(bboxes, dtype=torch.float32) + else: + gt_instances.bboxes = torch.FloatTensor(bboxes) + gt_instances.labels = torch.LongTensor(labels) + + if with_mask: + masks = _rand_masks(rng, num_boxes, bboxes, w, h) + gt_instances.masks = masks + + # TODO: waiting for ci to be fixed + # masks = np.random.randint(0, 2, (len(bboxes), h, w), dtype=np.uint8) + # gt_instances.mask = BitmapMasks(masks, h, w) + + data_sample.gt_instances = gt_instances + + # ignore_instances + ignore_instances = InstanceData() + bboxes = _rand_bboxes(rng, num_boxes, w, h) + if use_box_type: + ignore_instances.bboxes = HorizontalBoxes( + bboxes, dtype=torch.float32) + else: + ignore_instances.bboxes = torch.FloatTensor(bboxes) + data_sample.ignored_instances = ignore_instances + + # gt_sem_seg + if with_semantic: + # assume gt_semantic_seg using scale 1/8 of the img + gt_semantic_seg = torch.from_numpy( + np.random.randint( + 0, + num_classes, (1, h // sem_seg_output_strides, + w // sem_seg_output_strides), + dtype=np.uint8)) + gt_sem_seg_data = dict(sem_seg=gt_semantic_seg) + data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data) + + mm_inputs['data_samples'] = data_sample.to(device) + + # TODO: gt_ignore + + packed_inputs.append(mm_inputs) + data = pseudo_collate(packed_inputs) + return data + + +def demo_mm_proposals(image_shapes, num_proposals, device='cpu'): + """Create a list of fake porposals. + + Args: + image_shapes (list[tuple[int]]): Batch image shapes. + num_proposals (int): The number of fake proposals. 
+ """ + rng = np.random.RandomState(0) + + results = [] + for img_shape in image_shapes: + result = InstanceData() + w, h = img_shape[1:] + proposals = _rand_bboxes(rng, num_proposals, w, h) + result.bboxes = torch.from_numpy(proposals).float() + result.scores = torch.from_numpy(rng.rand(num_proposals)).float() + result.labels = torch.zeros(num_proposals).long() + results.append(result.to(device)) + return results + + +def demo_mm_sampling_results(proposals_list, + batch_gt_instances, + batch_gt_instances_ignore=None, + assigner_cfg=None, + sampler_cfg=None, + feats=None): + """Create sample results that can be passed to BBoxHead.get_targets.""" + assert len(proposals_list) == len(batch_gt_instances) + if batch_gt_instances_ignore is None: + batch_gt_instances_ignore = [None for _ in batch_gt_instances] + else: + assert len(batch_gt_instances_ignore) == len(batch_gt_instances) + + default_assigner_cfg = dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + ignore_iof_thr=-1) + assigner_cfg = assigner_cfg if assigner_cfg is not None \ + else default_assigner_cfg + default_sampler_cfg = dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True) + sampler_cfg = sampler_cfg if sampler_cfg is not None \ + else default_sampler_cfg + bbox_assigner = TASK_UTILS.build(assigner_cfg) + bbox_sampler = TASK_UTILS.build(sampler_cfg) + + sampling_results = [] + for i in range(len(batch_gt_instances)): + if feats is not None: + feats = [lvl_feat[i][None] for lvl_feat in feats] + # rename proposals.bboxes to proposals.priors + proposals = proposals_list[i] + proposals.priors = proposals.pop('bboxes') + + assign_result = bbox_assigner.assign(proposals, batch_gt_instances[i], + batch_gt_instances_ignore[i]) + sampling_result = bbox_sampler.sample( + assign_result, proposals, batch_gt_instances[i], feats=feats) + sampling_results.append(sampling_result) + + return sampling_results + + +def demo_track_inputs(batch_size=1, + num_frames=2, + key_frames_inds=None, + image_shapes=(3, 128, 128), + num_items=None, + num_classes=1, + with_mask=False, + with_semantic=False): + """Create a superset of inputs needed to run test or train batches. + + Args: + batch_size (int): batch size. Default to 1. + num_frames (int): The number of frames. + key_frames_inds (List): The indices of key frames. + image_shapes (List[tuple], Optional): image shape. + Default to (3, 128, 128) + num_items (None | List[int]): specifies the number + of boxes in each batch item. Default to None. + num_classes (int): number of different labels a + box might have. Default to 1. + with_mask (bool): Whether to return mask annotation. + Defaults to False. + with_semantic (bool): whether to return semantic. + Default to False. 
+ """ + rng = np.random.RandomState(0) + + # Make sure the length of image_shapes is equal to ``batch_size`` + if isinstance(image_shapes, list): + assert len(image_shapes) == batch_size + else: + image_shapes = [image_shapes] * batch_size + + packed_inputs = [] + for idx in range(batch_size): + mm_inputs = dict(inputs=dict()) + _, h, w = image_shapes[idx] + + imgs = rng.randint( + 0, 255, size=(num_frames, *image_shapes[idx]), dtype=np.uint8) + mm_inputs['inputs'] = torch.from_numpy(imgs) + + img_meta = { + 'img_id': idx, + 'img_shape': image_shapes[idx][-2:], + 'ori_shape': image_shapes[idx][-2:], + 'filename': '.png', + 'scale_factor': np.array([1.1, 1.2]), + 'flip': False, + 'flip_direction': None, + 'is_video_data': True, + } + + video_data_samples = [] + for i in range(num_frames): + data_sample = DetDataSample() + img_meta['frame_id'] = i + data_sample.set_metainfo(img_meta) + + # gt_instances + gt_instances = InstanceData() + if num_items is None: + num_boxes = rng.randint(1, 10) + else: + num_boxes = num_items[idx] + + bboxes = _rand_bboxes(rng, num_boxes, w, h) + labels = rng.randint(0, num_classes, size=num_boxes) + instances_id = rng.randint(100, num_classes + 100, size=num_boxes) + gt_instances.bboxes = torch.FloatTensor(bboxes) + gt_instances.labels = torch.LongTensor(labels) + gt_instances.instances_ids = torch.LongTensor(instances_id) + + if with_mask: + masks = _rand_masks(rng, num_boxes, bboxes, w, h) + gt_instances.masks = masks + + data_sample.gt_instances = gt_instances + # ignore_instances + ignore_instances = InstanceData() + bboxes = _rand_bboxes(rng, num_boxes, w, h) + ignore_instances.bboxes = bboxes + data_sample.ignored_instances = ignore_instances + + video_data_samples.append(data_sample) + + track_data_sample = TrackDataSample() + track_data_sample.video_data_samples = video_data_samples + if key_frames_inds is not None: + assert isinstance( + key_frames_inds, + list) and len(key_frames_inds) < num_frames and max( + key_frames_inds) < num_frames + ref_frames_inds = [ + i for i in range(num_frames) if i not in key_frames_inds + ] + track_data_sample.set_metainfo( + dict(key_frames_inds=key_frames_inds)) + track_data_sample.set_metainfo( + dict(ref_frames_inds=ref_frames_inds)) + mm_inputs['data_samples'] = track_data_sample + + # TODO: gt_ignore + packed_inputs.append(mm_inputs) + data = pseudo_collate(packed_inputs) + return data + + +def random_boxes(num=1, scale=1, rng=None): + """Simple version of ``kwimage.Boxes.random`` + Returns: + Tensor: shape (n, 4) in x1, y1, x2, y2 format. 
+ References: + https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390 # noqa: E501 + Example: + >>> num = 3 + >>> scale = 512 + >>> rng = 0 + >>> boxes = random_boxes(num, scale, rng) + >>> print(boxes) + tensor([[280.9925, 278.9802, 308.6148, 366.1769], + [216.9113, 330.6978, 224.0446, 456.5878], + [405.3632, 196.3221, 493.3953, 270.7942]]) + """ + rng = ensure_rng(rng) + + tlbr = rng.rand(num, 4).astype(np.float32) + + tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2]) + tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3]) + br_x = np.maximum(tlbr[:, 0], tlbr[:, 2]) + br_y = np.maximum(tlbr[:, 1], tlbr[:, 3]) + + tlbr[:, 0] = tl_x * scale + tlbr[:, 1] = tl_y * scale + tlbr[:, 2] = br_x * scale + tlbr[:, 3] = br_y * scale + + boxes = torch.from_numpy(tlbr) + return boxes + + +# TODO: Support full ceph +def replace_to_ceph(cfg): + backend_args = dict( + backend='petrel', + path_mapping=dict({ + './data/': 's3://openmmlab/datasets/detection/', + 'data/': 's3://openmmlab/datasets/detection/' + })) + + # TODO: name is a reserved interface, which will be used later. + def _process_pipeline(dataset, name): + + def replace_img(pipeline): + if pipeline['type'] == 'LoadImageFromFile': + pipeline['backend_args'] = backend_args + + def replace_ann(pipeline): + if pipeline['type'] == 'LoadAnnotations' or pipeline[ + 'type'] == 'LoadPanopticAnnotations': + pipeline['backend_args'] = backend_args + + if 'pipeline' in dataset: + replace_img(dataset.pipeline[0]) + replace_ann(dataset.pipeline[1]) + if 'dataset' in dataset: + # dataset wrapper + replace_img(dataset.dataset.pipeline[0]) + replace_ann(dataset.dataset.pipeline[1]) + else: + # dataset wrapper + replace_img(dataset.dataset.pipeline[0]) + replace_ann(dataset.dataset.pipeline[1]) + + def _process_evaluator(evaluator, name): + if evaluator['type'] == 'CocoPanopticMetric': + evaluator['backend_args'] = backend_args + + # half ceph + _process_pipeline(cfg.train_dataloader.dataset, cfg.filename) + _process_pipeline(cfg.val_dataloader.dataset, cfg.filename) + _process_pipeline(cfg.test_dataloader.dataset, cfg.filename) + _process_evaluator(cfg.val_evaluator, cfg.filename) + _process_evaluator(cfg.test_evaluator, cfg.filename) diff --git a/mmdet/utils/__init__.py b/mmdet/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..449a890bac411f84790eb3d014175e3a48757847 --- /dev/null +++ b/mmdet/utils/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
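
The testing helpers above are designed to fabricate everything a detector test needs without touching disk. A hedged sketch of how a unit test might combine them (argument values are illustrative):

from mmdet.testing import demo_mm_inputs, demo_mm_proposals, random_boxes

# A fake two-image batch, collated the same way a dataloader would collate it
batch = demo_mm_inputs(batch_size=2, image_shapes=(3, 128, 128), num_classes=4)
assert len(batch['inputs']) == 2
print(batch['data_samples'][0].gt_instances.bboxes.shape)  # (num_boxes, 4)

# Fake RPN-style proposals for the same image shapes
proposals = demo_mm_proposals(image_shapes=[(3, 128, 128)] * 2, num_proposals=100)
assert proposals[0].bboxes.shape == (100, 4)

# Deterministic xyxy boxes, matching the random_boxes docstring example
boxes = random_boxes(num=3, scale=512, rng=0)
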
+from .collect_env import collect_env +from .compat_config import compat_cfg +from .dist_utils import (all_reduce_dict, allreduce_grads, reduce_mean, + sync_random_seed) +from .logger import get_caller_name, log_img_scale +from .memory import AvoidCUDAOOM, AvoidOOM +from .misc import (find_latest_checkpoint, get_test_pipeline_cfg, + update_data_root) +from .mot_error_visualize import imshow_mot_errors +from .replace_cfg_vals import replace_cfg_vals +from .setup_env import (register_all_modules, setup_cache_size_limit_of_dynamo, + setup_multi_processes) +from .split_batch import split_batch +from .typing_utils import (ConfigType, InstanceList, MultiConfig, + OptConfigType, OptInstanceList, OptMultiConfig, + OptPixelList, PixelList, RangeType) + +__all__ = [ + 'collect_env', 'find_latest_checkpoint', 'update_data_root', + 'setup_multi_processes', 'get_caller_name', 'log_img_scale', 'compat_cfg', + 'split_batch', 'register_all_modules', 'replace_cfg_vals', 'AvoidOOM', + 'AvoidCUDAOOM', 'all_reduce_dict', 'allreduce_grads', 'reduce_mean', + 'sync_random_seed', 'ConfigType', 'InstanceList', 'MultiConfig', + 'OptConfigType', 'OptInstanceList', 'OptMultiConfig', 'OptPixelList', + 'PixelList', 'RangeType', 'get_test_pipeline_cfg', + 'setup_cache_size_limit_of_dynamo', 'imshow_mot_errors' +] diff --git a/mmdet/utils/benchmark.py b/mmdet/utils/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..5419b2d175e3c48c063a39ae28758b386f9ab597 --- /dev/null +++ b/mmdet/utils/benchmark.py @@ -0,0 +1,529 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import time +from functools import partial +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import fuse_conv_bn +# TODO need update +# from mmcv.runner import wrap_fp16_model +from mmengine import MMLogger +from mmengine.config import Config +from mmengine.device import get_max_cuda_memory +from mmengine.dist import get_world_size +from mmengine.runner import Runner, load_checkpoint +from mmengine.utils.dl_utils import set_multi_processing +from torch.nn.parallel import DistributedDataParallel + +from mmdet.registry import DATASETS, MODELS + +try: + import psutil +except ImportError: + psutil = None + + +def custom_round(value: Union[int, float], + factor: Union[int, float], + precision: int = 2) -> float: + """Custom round function.""" + return round(value / factor, precision) + + +gb_round = partial(custom_round, factor=1024**3) + + +def print_log(msg: str, logger: Optional[MMLogger] = None) -> None: + """Print a log message.""" + if logger is None: + print(msg, flush=True) + else: + logger.info(msg) + + +def print_process_memory(p: psutil.Process, + logger: Optional[MMLogger] = None) -> None: + """print process memory info.""" + mem_used = gb_round(psutil.virtual_memory().used) + memory_full_info = p.memory_full_info() + uss_mem = gb_round(memory_full_info.uss) + if hasattr(memory_full_info, 'pss'): + pss_mem = gb_round(memory_full_info.pss) + + for children in p.children(): + child_mem_info = children.memory_full_info() + uss_mem += gb_round(child_mem_info.uss) + if hasattr(child_mem_info, 'pss'): + pss_mem += gb_round(child_mem_info.pss) + + process_count = 1 + len(p.children()) + + log_msg = f'(GB) mem_used: {mem_used:.2f} | uss: {uss_mem:.2f} | ' + if hasattr(memory_full_info, 'pss'): + log_msg += f'pss: {pss_mem:.2f} | ' + log_msg += f'total_proc: {process_count}' + print_log(log_msg, logger) + + +class BaseBenchmark: + """The benchmark base 
class. + + The ``run`` method is an external calling interface, and it will + call the ``run_once`` method ``repeat_num`` times for benchmarking. + Finally, call the ``average_multiple_runs`` method to further process + the results of multiple runs. + + Args: + max_iter (int): maximum iterations of benchmark. + log_interval (int): interval of logging. + num_warmup (int): Number of Warmup. + logger (MMLogger, optional): Formatted logger used to record messages. + """ + + def __init__(self, + max_iter: int, + log_interval: int, + num_warmup: int, + logger: Optional[MMLogger] = None): + self.max_iter = max_iter + self.log_interval = log_interval + self.num_warmup = num_warmup + self.logger = logger + + def run(self, repeat_num: int = 1) -> dict: + """benchmark entry method. + + Args: + repeat_num (int): Number of repeat benchmark. + Defaults to 1. + """ + assert repeat_num >= 1 + + results = [] + for _ in range(repeat_num): + results.append(self.run_once()) + + results = self.average_multiple_runs(results) + return results + + def run_once(self) -> dict: + """Executes the benchmark once.""" + raise NotImplementedError() + + def average_multiple_runs(self, results: List[dict]) -> dict: + """Average the results of multiple runs.""" + raise NotImplementedError() + + +class InferenceBenchmark(BaseBenchmark): + """The inference benchmark class. It will be statistical inference FPS, + CUDA memory and CPU memory information. + + Args: + cfg (mmengine.Config): config. + checkpoint (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. + distributed (bool): distributed testing flag. + is_fuse_conv_bn (bool): Whether to fuse conv and bn, this will + slightly increase the inference speed. + max_iter (int): maximum iterations of benchmark. Defaults to 2000. + log_interval (int): interval of logging. Defaults to 50. + num_warmup (int): Number of Warmup. Defaults to 5. + logger (MMLogger, optional): Formatted logger used to record messages. + """ + + def __init__(self, + cfg: Config, + checkpoint: str, + distributed: bool, + is_fuse_conv_bn: bool, + max_iter: int = 2000, + log_interval: int = 50, + num_warmup: int = 5, + logger: Optional[MMLogger] = None): + super().__init__(max_iter, log_interval, num_warmup, logger) + + assert get_world_size( + ) == 1, 'Inference benchmark does not allow distributed multi-GPU' + + self.cfg = copy.deepcopy(cfg) + self.distributed = distributed + + if psutil is None: + raise ImportError('psutil is not installed, please install it by: ' + 'pip install psutil') + + self._process = psutil.Process() + env_cfg = self.cfg.get('env_cfg') + if env_cfg.get('cudnn_benchmark'): + torch.backends.cudnn.benchmark = True + + mp_cfg: dict = env_cfg.get('mp_cfg', {}) + set_multi_processing(**mp_cfg, distributed=self.distributed) + + print_log('before build: ', self.logger) + print_process_memory(self._process, self.logger) + + self.model = self._init_model(checkpoint, is_fuse_conv_bn) + + # Because multiple processes will occupy additional CPU resources, + # FPS statistics will be more unstable when num_workers is not 0. + # It is reasonable to set num_workers to 0. 
+ dataloader_cfg = cfg.test_dataloader + dataloader_cfg['num_workers'] = 0 + dataloader_cfg['batch_size'] = 1 + dataloader_cfg['persistent_workers'] = False + self.data_loader = Runner.build_dataloader(dataloader_cfg) + + print_log('after build: ', self.logger) + print_process_memory(self._process, self.logger) + + def _init_model(self, checkpoint: str, is_fuse_conv_bn: bool) -> nn.Module: + """Initialize the model.""" + model = MODELS.build(self.cfg.model) + # TODO need update + # fp16_cfg = self.cfg.get('fp16', None) + # if fp16_cfg is not None: + # wrap_fp16_model(model) + + load_checkpoint(model, checkpoint, map_location='cpu') + if is_fuse_conv_bn: + model = fuse_conv_bn(model) + + model = model.cuda() + + if self.distributed: + model = DistributedDataParallel( + model, + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False, + find_unused_parameters=False) + + model.eval() + return model + + def run_once(self) -> dict: + """Executes the benchmark once.""" + pure_inf_time = 0 + fps = 0 + + for i, data in enumerate(self.data_loader): + + if (i + 1) % self.log_interval == 0: + print_log('==================================', self.logger) + + torch.cuda.synchronize() + start_time = time.perf_counter() + + with torch.no_grad(): + self.model.test_step(data) + + torch.cuda.synchronize() + elapsed = time.perf_counter() - start_time + + if i >= self.num_warmup: + pure_inf_time += elapsed + if (i + 1) % self.log_interval == 0: + fps = (i + 1 - self.num_warmup) / pure_inf_time + cuda_memory = get_max_cuda_memory() + + print_log( + f'Done image [{i + 1:<3}/{self.max_iter}], ' + f'fps: {fps:.1f} img/s, ' + f'times per image: {1000 / fps:.1f} ms/img, ' + f'cuda memory: {cuda_memory} MB', self.logger) + print_process_memory(self._process, self.logger) + + if (i + 1) == self.max_iter: + fps = (i + 1 - self.num_warmup) / pure_inf_time + break + + return {'fps': fps} + + def average_multiple_runs(self, results: List[dict]) -> dict: + """Average the results of multiple runs.""" + print_log('============== Done ==================', self.logger) + + fps_list_ = [round(result['fps'], 1) for result in results] + avg_fps_ = sum(fps_list_) / len(fps_list_) + outputs = {'avg_fps': avg_fps_, 'fps_list': fps_list_} + + if len(fps_list_) > 1: + times_pre_image_list_ = [ + round(1000 / result['fps'], 1) for result in results + ] + avg_times_pre_image_ = sum(times_pre_image_list_) / len( + times_pre_image_list_) + + print_log( + f'Overall fps: {fps_list_}[{avg_fps_:.1f}] img/s, ' + 'times per image: ' + f'{times_pre_image_list_}[{avg_times_pre_image_:.1f}] ' + 'ms/img', self.logger) + else: + print_log( + f'Overall fps: {fps_list_[0]:.1f} img/s, ' + f'times per image: {1000 / fps_list_[0]:.1f} ms/img', + self.logger) + + print_log(f'cuda memory: {get_max_cuda_memory()} MB', self.logger) + print_process_memory(self._process, self.logger) + + return outputs + + +class DataLoaderBenchmark(BaseBenchmark): + """The dataloader benchmark class. It will be statistical inference FPS and + CPU memory information. + + Args: + cfg (mmengine.Config): config. + distributed (bool): distributed testing flag. + dataset_type (str): benchmark data type, only supports ``train``, + ``val`` and ``test``. + max_iter (int): maximum iterations of benchmark. Defaults to 2000. + log_interval (int): interval of logging. Defaults to 50. + num_warmup (int): Number of Warmup. Defaults to 5. + logger (MMLogger, optional): Formatted logger used to record messages. 
+ """ + + def __init__(self, + cfg: Config, + distributed: bool, + dataset_type: str, + max_iter: int = 2000, + log_interval: int = 50, + num_warmup: int = 5, + logger: Optional[MMLogger] = None): + super().__init__(max_iter, log_interval, num_warmup, logger) + + assert dataset_type in ['train', 'val', 'test'], \ + 'dataset_type only supports train,' \ + f' val and test, but got {dataset_type}' + assert get_world_size( + ) == 1, 'Dataloader benchmark does not allow distributed multi-GPU' + + self.cfg = copy.deepcopy(cfg) + self.distributed = distributed + + if psutil is None: + raise ImportError('psutil is not installed, please install it by: ' + 'pip install psutil') + self._process = psutil.Process() + + mp_cfg = self.cfg.get('env_cfg', {}).get('mp_cfg') + if mp_cfg is not None: + set_multi_processing(distributed=self.distributed, **mp_cfg) + else: + set_multi_processing(distributed=self.distributed) + + print_log('before build: ', self.logger) + print_process_memory(self._process, self.logger) + + if dataset_type == 'train': + self.data_loader = Runner.build_dataloader(cfg.train_dataloader) + elif dataset_type == 'test': + self.data_loader = Runner.build_dataloader(cfg.test_dataloader) + else: + self.data_loader = Runner.build_dataloader(cfg.val_dataloader) + + self.batch_size = self.data_loader.batch_size + self.num_workers = self.data_loader.num_workers + + print_log('after build: ', self.logger) + print_process_memory(self._process, self.logger) + + def run_once(self) -> dict: + """Executes the benchmark once.""" + pure_inf_time = 0 + fps = 0 + + # benchmark with 2000 image and take the average + start_time = time.perf_counter() + for i, data in enumerate(self.data_loader): + elapsed = time.perf_counter() - start_time + + if (i + 1) % self.log_interval == 0: + print_log('==================================', self.logger) + + if i >= self.num_warmup: + pure_inf_time += elapsed + if (i + 1) % self.log_interval == 0: + fps = (i + 1 - self.num_warmup) / pure_inf_time + + print_log( + f'Done batch [{i + 1:<3}/{self.max_iter}], ' + f'fps: {fps:.1f} batch/s, ' + f'times per batch: {1000 / fps:.1f} ms/batch, ' + f'batch size: {self.batch_size}, num_workers: ' + f'{self.num_workers}', self.logger) + print_process_memory(self._process, self.logger) + + if (i + 1) == self.max_iter: + fps = (i + 1 - self.num_warmup) / pure_inf_time + break + + start_time = time.perf_counter() + + return {'fps': fps} + + def average_multiple_runs(self, results: List[dict]) -> dict: + """Average the results of multiple runs.""" + print_log('============== Done ==================', self.logger) + + fps_list_ = [round(result['fps'], 1) for result in results] + avg_fps_ = sum(fps_list_) / len(fps_list_) + outputs = {'avg_fps': avg_fps_, 'fps_list': fps_list_} + + if len(fps_list_) > 1: + times_pre_image_list_ = [ + round(1000 / result['fps'], 1) for result in results + ] + avg_times_pre_image_ = sum(times_pre_image_list_) / len( + times_pre_image_list_) + + print_log( + f'Overall fps: {fps_list_}[{avg_fps_:.1f}] img/s, ' + 'times per batch: ' + f'{times_pre_image_list_}[{avg_times_pre_image_:.1f}] ' + f'ms/batch, batch size: {self.batch_size}, num_workers: ' + f'{self.num_workers}', self.logger) + else: + print_log( + f'Overall fps: {fps_list_[0]:.1f} batch/s, ' + f'times per batch: {1000 / fps_list_[0]:.1f} ms/batch, ' + f'batch size: {self.batch_size}, num_workers: ' + f'{self.num_workers}', self.logger) + + print_process_memory(self._process, self.logger) + + return outputs + + +class 
DatasetBenchmark(BaseBenchmark): + """The dataset benchmark class. It will be statistical inference FPS, FPS + pre transform and CPU memory information. + + Args: + cfg (mmengine.Config): config. + dataset_type (str): benchmark data type, only supports ``train``, + ``val`` and ``test``. + max_iter (int): maximum iterations of benchmark. Defaults to 2000. + log_interval (int): interval of logging. Defaults to 50. + num_warmup (int): Number of Warmup. Defaults to 5. + logger (MMLogger, optional): Formatted logger used to record messages. + """ + + def __init__(self, + cfg: Config, + dataset_type: str, + max_iter: int = 2000, + log_interval: int = 50, + num_warmup: int = 5, + logger: Optional[MMLogger] = None): + super().__init__(max_iter, log_interval, num_warmup, logger) + assert dataset_type in ['train', 'val', 'test'], \ + 'dataset_type only supports train,' \ + f' val and test, but got {dataset_type}' + assert get_world_size( + ) == 1, 'Dataset benchmark does not allow distributed multi-GPU' + self.cfg = copy.deepcopy(cfg) + + if dataset_type == 'train': + dataloader_cfg = copy.deepcopy(cfg.train_dataloader) + elif dataset_type == 'test': + dataloader_cfg = copy.deepcopy(cfg.test_dataloader) + else: + dataloader_cfg = copy.deepcopy(cfg.val_dataloader) + + dataset_cfg = dataloader_cfg.pop('dataset') + dataset = DATASETS.build(dataset_cfg) + if hasattr(dataset, 'full_init'): + dataset.full_init() + self.dataset = dataset + + def run_once(self) -> dict: + """Executes the benchmark once.""" + pure_inf_time = 0 + fps = 0 + + total_index = list(range(len(self.dataset))) + np.random.shuffle(total_index) + + start_time = time.perf_counter() + for i, idx in enumerate(total_index): + if (i + 1) % self.log_interval == 0: + print_log('==================================', self.logger) + + get_data_info_start_time = time.perf_counter() + data_info = self.dataset.get_data_info(idx) + get_data_info_elapsed = time.perf_counter( + ) - get_data_info_start_time + + if (i + 1) % self.log_interval == 0: + print_log(f'get_data_info - {get_data_info_elapsed * 1000} ms', + self.logger) + + for t in self.dataset.pipeline.transforms: + transform_start_time = time.perf_counter() + data_info = t(data_info) + transform_elapsed = time.perf_counter() - transform_start_time + + if (i + 1) % self.log_interval == 0: + print_log( + f'{t.__class__.__name__} - ' + f'{transform_elapsed * 1000} ms', self.logger) + + if data_info is None: + break + + elapsed = time.perf_counter() - start_time + + if i >= self.num_warmup: + pure_inf_time += elapsed + if (i + 1) % self.log_interval == 0: + fps = (i + 1 - self.num_warmup) / pure_inf_time + + print_log( + f'Done img [{i + 1:<3}/{self.max_iter}], ' + f'fps: {fps:.1f} img/s, ' + f'times per img: {1000 / fps:.1f} ms/img', self.logger) + + if (i + 1) == self.max_iter: + fps = (i + 1 - self.num_warmup) / pure_inf_time + break + + start_time = time.perf_counter() + + return {'fps': fps} + + def average_multiple_runs(self, results: List[dict]) -> dict: + """Average the results of multiple runs.""" + print_log('============== Done ==================', self.logger) + + fps_list_ = [round(result['fps'], 1) for result in results] + avg_fps_ = sum(fps_list_) / len(fps_list_) + outputs = {'avg_fps': avg_fps_, 'fps_list': fps_list_} + + if len(fps_list_) > 1: + times_pre_image_list_ = [ + round(1000 / result['fps'], 1) for result in results + ] + avg_times_pre_image_ = sum(times_pre_image_list_) / len( + times_pre_image_list_) + + print_log( + f'Overall fps: {fps_list_}[{avg_fps_:.1f}] img/s, ' 
+ 'times per img: ' + f'{times_pre_image_list_}[{avg_times_pre_image_:.1f}] ' + 'ms/img', self.logger) + else: + print_log( + f'Overall fps: {fps_list_[0]:.1f} img/s, ' + f'times per img: {1000 / fps_list_[0]:.1f} ms/img', + self.logger) + + return outputs diff --git a/mmdet/utils/collect_env.py b/mmdet/utils/collect_env.py new file mode 100644 index 0000000000000000000000000000000000000000..b0eed80fe2e4630b78ea3b13fde6046914e47e8b --- /dev/null +++ b/mmdet/utils/collect_env.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.utils import get_git_hash +from mmengine.utils.dl_utils import collect_env as collect_base_env + +import mmdet + + +def collect_env(): + """Collect the information of the running environments.""" + env_info = collect_base_env() + env_info['MMDetection'] = mmdet.__version__ + '+' + get_git_hash()[:7] + return env_info + + +if __name__ == '__main__': + for name, val in collect_env().items(): + print(f'{name}: {val}') diff --git a/mmdet/utils/compat_config.py b/mmdet/utils/compat_config.py new file mode 100644 index 0000000000000000000000000000000000000000..133adb65c2276401eca947e223e5b7c1760de418 --- /dev/null +++ b/mmdet/utils/compat_config.py @@ -0,0 +1,139 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import warnings + +from mmengine.config import ConfigDict + + +def compat_cfg(cfg): + """This function would modify some filed to keep the compatibility of + config. + + For example, it will move some args which will be deprecated to the correct + fields. + """ + cfg = copy.deepcopy(cfg) + cfg = compat_imgs_per_gpu(cfg) + cfg = compat_loader_args(cfg) + cfg = compat_runner_args(cfg) + return cfg + + +def compat_runner_args(cfg): + if 'runner' not in cfg: + cfg.runner = ConfigDict({ + 'type': 'EpochBasedRunner', + 'max_epochs': cfg.total_epochs + }) + warnings.warn( + 'config is now expected to have a `runner` section, ' + 'please set `runner` in your config.', UserWarning) + else: + if 'total_epochs' in cfg: + assert cfg.total_epochs == cfg.runner.max_epochs + return cfg + + +def compat_imgs_per_gpu(cfg): + cfg = copy.deepcopy(cfg) + if 'imgs_per_gpu' in cfg.data: + warnings.warn('"imgs_per_gpu" is deprecated in MMDet V2.0. ' + 'Please use "samples_per_gpu" instead') + if 'samples_per_gpu' in cfg.data: + warnings.warn( + f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and ' + f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"' + f'={cfg.data.imgs_per_gpu} is used in this experiments') + else: + warnings.warn('Automatically set "samples_per_gpu"="imgs_per_gpu"=' + f'{cfg.data.imgs_per_gpu} in this experiments') + cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu + return cfg + + +def compat_loader_args(cfg): + """Deprecated sample_per_gpu in cfg.data.""" + + cfg = copy.deepcopy(cfg) + if 'train_dataloader' not in cfg.data: + cfg.data['train_dataloader'] = ConfigDict() + if 'val_dataloader' not in cfg.data: + cfg.data['val_dataloader'] = ConfigDict() + if 'test_dataloader' not in cfg.data: + cfg.data['test_dataloader'] = ConfigDict() + + # special process for train_dataloader + if 'samples_per_gpu' in cfg.data: + + samples_per_gpu = cfg.data.pop('samples_per_gpu') + assert 'samples_per_gpu' not in \ + cfg.data.train_dataloader, ('`samples_per_gpu` are set ' + 'in `data` field and ` ' + 'data.train_dataloader` ' + 'at the same time. ' + 'Please only set it in ' + '`data.train_dataloader`. 
') + cfg.data.train_dataloader['samples_per_gpu'] = samples_per_gpu + + if 'persistent_workers' in cfg.data: + + persistent_workers = cfg.data.pop('persistent_workers') + assert 'persistent_workers' not in \ + cfg.data.train_dataloader, ('`persistent_workers` are set ' + 'in `data` field and ` ' + 'data.train_dataloader` ' + 'at the same time. ' + 'Please only set it in ' + '`data.train_dataloader`. ') + cfg.data.train_dataloader['persistent_workers'] = persistent_workers + + if 'workers_per_gpu' in cfg.data: + + workers_per_gpu = cfg.data.pop('workers_per_gpu') + cfg.data.train_dataloader['workers_per_gpu'] = workers_per_gpu + cfg.data.val_dataloader['workers_per_gpu'] = workers_per_gpu + cfg.data.test_dataloader['workers_per_gpu'] = workers_per_gpu + + # special process for val_dataloader + if 'samples_per_gpu' in cfg.data.val: + # keep default value of `sample_per_gpu` is 1 + assert 'samples_per_gpu' not in \ + cfg.data.val_dataloader, ('`samples_per_gpu` are set ' + 'in `data.val` field and ` ' + 'data.val_dataloader` at ' + 'the same time. ' + 'Please only set it in ' + '`data.val_dataloader`. ') + cfg.data.val_dataloader['samples_per_gpu'] = \ + cfg.data.val.pop('samples_per_gpu') + # special process for val_dataloader + + # in case the test dataset is concatenated + if isinstance(cfg.data.test, dict): + if 'samples_per_gpu' in cfg.data.test: + assert 'samples_per_gpu' not in \ + cfg.data.test_dataloader, ('`samples_per_gpu` are set ' + 'in `data.test` field and ` ' + 'data.test_dataloader` ' + 'at the same time. ' + 'Please only set it in ' + '`data.test_dataloader`. ') + + cfg.data.test_dataloader['samples_per_gpu'] = \ + cfg.data.test.pop('samples_per_gpu') + + elif isinstance(cfg.data.test, list): + for ds_cfg in cfg.data.test: + if 'samples_per_gpu' in ds_cfg: + assert 'samples_per_gpu' not in \ + cfg.data.test_dataloader, ('`samples_per_gpu` are set ' + 'in `data.test` field and ` ' + 'data.test_dataloader` at' + ' the same time. ' + 'Please only set it in ' + '`data.test_dataloader`. ') + samples_per_gpu = max( + [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test]) + cfg.data.test_dataloader['samples_per_gpu'] = samples_per_gpu + + return cfg diff --git a/mmdet/utils/contextmanagers.py b/mmdet/utils/contextmanagers.py new file mode 100644 index 0000000000000000000000000000000000000000..fa12bfcaff1e781b0a8cc7d7c8b839c2f2955a05 --- /dev/null +++ b/mmdet/utils/contextmanagers.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
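
To make the compatibility shims above concrete, here is a hedged sketch of a legacy-style config passing through compat_cfg (the keys follow the branches handled above; the values are illustrative):

from mmengine.config import Config
from mmdet.utils import compat_cfg

legacy = Config(dict(
    total_epochs=12,
    data=dict(samples_per_gpu=2, workers_per_gpu=2,
              train=dict(), val=dict(), test=dict()),
))
cfg = compat_cfg(legacy)

# total_epochs became a runner section, and the per-GPU settings moved
# into the per-split dataloader configs
assert cfg.runner.max_epochs == 12
assert cfg.data.train_dataloader.samples_per_gpu == 2
assert cfg.data.val_dataloader.workers_per_gpu == 2
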
+import asyncio +import contextlib +import logging +import os +import time +from typing import List + +import torch + +logger = logging.getLogger(__name__) + +DEBUG_COMPLETED_TIME = bool(os.environ.get('DEBUG_COMPLETED_TIME', False)) + + +@contextlib.asynccontextmanager +async def completed(trace_name='', + name='', + sleep_interval=0.05, + streams: List[torch.cuda.Stream] = None): + """Async context manager that waits for work to complete on given CUDA + streams.""" + if not torch.cuda.is_available(): + yield + return + + stream_before_context_switch = torch.cuda.current_stream() + if not streams: + streams = [stream_before_context_switch] + else: + streams = [s if s else stream_before_context_switch for s in streams] + + end_events = [ + torch.cuda.Event(enable_timing=DEBUG_COMPLETED_TIME) for _ in streams + ] + + if DEBUG_COMPLETED_TIME: + start = torch.cuda.Event(enable_timing=True) + stream_before_context_switch.record_event(start) + + cpu_start = time.monotonic() + logger.debug('%s %s starting, streams: %s', trace_name, name, streams) + grad_enabled_before = torch.is_grad_enabled() + try: + yield + finally: + current_stream = torch.cuda.current_stream() + assert current_stream == stream_before_context_switch + + if DEBUG_COMPLETED_TIME: + cpu_end = time.monotonic() + for i, stream in enumerate(streams): + event = end_events[i] + stream.record_event(event) + + grad_enabled_after = torch.is_grad_enabled() + + # observed change of torch.is_grad_enabled() during concurrent run of + # async_test_bboxes code + assert (grad_enabled_before == grad_enabled_after + ), 'Unexpected is_grad_enabled() value change' + + are_done = [e.query() for e in end_events] + logger.debug('%s %s completed: %s streams: %s', trace_name, name, + are_done, streams) + with torch.cuda.stream(stream_before_context_switch): + while not all(are_done): + await asyncio.sleep(sleep_interval) + are_done = [e.query() for e in end_events] + logger.debug( + '%s %s completed: %s streams: %s', + trace_name, + name, + are_done, + streams, + ) + + current_stream = torch.cuda.current_stream() + assert current_stream == stream_before_context_switch + + if DEBUG_COMPLETED_TIME: + cpu_time = (cpu_end - cpu_start) * 1000 + stream_times_ms = '' + for i, stream in enumerate(streams): + elapsed_time = start.elapsed_time(end_events[i]) + stream_times_ms += f' {stream} {elapsed_time:.2f} ms' + logger.info('%s %s %.2f ms %s', trace_name, name, cpu_time, + stream_times_ms) + + +@contextlib.asynccontextmanager +async def concurrent(streamqueue: asyncio.Queue, + trace_name='concurrent', + name='stream'): + """Run code concurrently in different streams. + + :param streamqueue: asyncio.Queue instance. + + Queue tasks define the pool of streams used for concurrent execution. 
+ """ + if not torch.cuda.is_available(): + yield + return + + initial_stream = torch.cuda.current_stream() + + with torch.cuda.stream(initial_stream): + stream = await streamqueue.get() + assert isinstance(stream, torch.cuda.Stream) + + try: + with torch.cuda.stream(stream): + logger.debug('%s %s is starting, stream: %s', trace_name, name, + stream) + yield + current = torch.cuda.current_stream() + assert current == stream + logger.debug('%s %s has finished, stream: %s', trace_name, + name, stream) + finally: + streamqueue.task_done() + streamqueue.put_nowait(stream) diff --git a/mmdet/utils/dist_utils.py b/mmdet/utils/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2f2c8614a181ec0594ba157002a2760737e2c6e3 --- /dev/null +++ b/mmdet/utils/dist_utils.py @@ -0,0 +1,184 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools +import pickle +import warnings +from collections import OrderedDict + +import numpy as np +import torch +import torch.distributed as dist +from mmengine.dist import get_dist_info +from torch._utils import (_flatten_dense_tensors, _take_tensors, + _unflatten_dense_tensors) + + +def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): + if bucket_size_mb > 0: + bucket_size_bytes = bucket_size_mb * 1024 * 1024 + buckets = _take_tensors(tensors, bucket_size_bytes) + else: + buckets = OrderedDict() + for tensor in tensors: + tp = tensor.type() + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(tensor) + buckets = buckets.values() + + for bucket in buckets: + flat_tensors = _flatten_dense_tensors(bucket) + dist.all_reduce(flat_tensors) + flat_tensors.div_(world_size) + for tensor, synced in zip( + bucket, _unflatten_dense_tensors(flat_tensors, bucket)): + tensor.copy_(synced) + + +def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): + """Allreduce gradients. + + Args: + params (list[torch.Parameters]): List of parameters of a model + coalesce (bool, optional): Whether allreduce parameters as a whole. + Defaults to True. + bucket_size_mb (int, optional): Size of bucket, the unit is MB. + Defaults to -1. + """ + grads = [ + param.grad.data for param in params + if param.requires_grad and param.grad is not None + ] + world_size = dist.get_world_size() + if coalesce: + _allreduce_coalesced(grads, world_size, bucket_size_mb) + else: + for tensor in grads: + dist.all_reduce(tensor.div_(world_size)) + + +def reduce_mean(tensor): + """"Obtain the mean of tensor on different GPUs.""" + if not (dist.is_available() and dist.is_initialized()): + return tensor + tensor = tensor.clone() + dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) + return tensor + + +def obj2tensor(pyobj, device='cuda'): + """Serialize picklable python object to tensor.""" + storage = torch.ByteStorage.from_buffer(pickle.dumps(pyobj)) + return torch.ByteTensor(storage).to(device=device) + + +def tensor2obj(tensor): + """Deserialize tensor to picklable python object.""" + return pickle.loads(tensor.cpu().numpy().tobytes()) + + +@functools.lru_cache() +def _get_global_gloo_group(): + """Return a process group based on gloo backend, containing all the ranks + The result is cached.""" + if dist.get_backend() == 'nccl': + return dist.new_group(backend='gloo') + else: + return dist.group.WORLD + + +def all_reduce_dict(py_dict, op='sum', group=None, to_float=True): + """Apply all reduce function for python dict object. 
+ + The code is modified from https://github.com/Megvii- + BaseDetection/YOLOX/blob/main/yolox/utils/allreduce_norm.py. + + NOTE: make sure that py_dict in different ranks has the same keys and + the values should be in the same shape. Currently only supports + nccl backend. + + Args: + py_dict (dict): Dict to be applied all reduce op. + op (str): Operator, could be 'sum' or 'mean'. Default: 'sum' + group (:obj:`torch.distributed.group`, optional): Distributed group, + Default: None. + to_float (bool): Whether to convert all values of dict to float. + Default: True. + + Returns: + OrderedDict: reduced python dict object. + """ + warnings.warn( + 'group` is deprecated. Currently only supports NCCL backend.') + _, world_size = get_dist_info() + if world_size == 1: + return py_dict + + # all reduce logic across different devices. + py_key = list(py_dict.keys()) + if not isinstance(py_dict, OrderedDict): + py_key_tensor = obj2tensor(py_key) + dist.broadcast(py_key_tensor, src=0) + py_key = tensor2obj(py_key_tensor) + + tensor_shapes = [py_dict[k].shape for k in py_key] + tensor_numels = [py_dict[k].numel() for k in py_key] + + if to_float: + warnings.warn('Note: the "to_float" is True, you need to ' + 'ensure that the behavior is reasonable.') + flatten_tensor = torch.cat( + [py_dict[k].flatten().float() for k in py_key]) + else: + flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key]) + + dist.all_reduce(flatten_tensor, op=dist.ReduceOp.SUM) + if op == 'mean': + flatten_tensor /= world_size + + split_tensors = [ + x.reshape(shape) for x, shape in zip( + torch.split(flatten_tensor, tensor_numels), tensor_shapes) + ] + out_dict = {k: v for k, v in zip(py_key, split_tensors)} + if isinstance(py_dict, OrderedDict): + out_dict = OrderedDict(out_dict) + return out_dict + + +def sync_random_seed(seed=None, device='cuda'): + """Make sure different ranks share the same seed. + + All workers must call this function, otherwise it will deadlock. + This method is generally used in `DistributedSampler`, + because the seed should be identical across all processes + in the distributed group. + + In distributed sampling, different ranks should sample non-overlapped + data in the dataset. Therefore, this function is used to make sure that + each rank shuffles the data indices in the same order based + on the same seed. Then different ranks could use different indices + to select non-overlapped data from the same data list. + + Args: + seed (int, Optional): The seed. Default to None. + device (str): The device where the seed will be put on. + Default to 'cuda'. + + Returns: + int: Seed to be used. + """ + if seed is None: + seed = np.random.randint(2**31) + assert isinstance(seed, int) + + rank, world_size = get_dist_info() + + if world_size == 1: + return seed + + if rank == 0: + random_num = torch.tensor(seed, dtype=torch.int32, device=device) + else: + random_num = torch.tensor(0, dtype=torch.int32, device=device) + dist.broadcast(random_num, src=0) + return random_num.item() diff --git a/mmdet/utils/large_image.py b/mmdet/utils/large_image.py new file mode 100644 index 0000000000000000000000000000000000000000..f1f07c2bdc6958f2b3bdd69da0a639276252a91e --- /dev/null +++ b/mmdet/utils/large_image.py @@ -0,0 +1,104 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
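
The reduce/broadcast helpers above all degrade to no-ops in a single-process run, which keeps the same training code usable with and without torch.distributed. A minimal single-process sketch:

import torch
from mmdet.utils import all_reduce_dict, reduce_mean, sync_random_seed

loss = reduce_mean(torch.tensor(1.5))  # returned unchanged: dist is not initialized
seed = sync_random_seed(42)            # world_size == 1, so the seed comes back as-is
stats = all_reduce_dict({'loss': torch.tensor(1.5)})  # likewise a no-op here
assert seed == 42
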
+from typing import Sequence, Tuple + +import torch +from mmcv.ops import batched_nms +from mmengine.structures import InstanceData + +from mmdet.structures import DetDataSample, SampleList + + +def shift_rbboxes(bboxes: torch.Tensor, offset: Sequence[int]): + """Shift rotated bboxes with offset. + + Args: + bboxes (Tensor): The rotated bboxes need to be translated. + With shape (n, 5), which means (x, y, w, h, a). + offset (Sequence[int]): The translation offsets with shape of (2, ). + Returns: + Tensor: Shifted rotated bboxes. + """ + offset_tensor = bboxes.new_tensor(offset) + shifted_bboxes = bboxes.clone() + shifted_bboxes[:, 0:2] = shifted_bboxes[:, 0:2] + offset_tensor + return shifted_bboxes + + +def shift_predictions(det_data_samples: SampleList, + offsets: Sequence[Tuple[int, int]], + src_image_shape: Tuple[int, int]) -> SampleList: + """Shift predictions to the original image. + + Args: + det_data_samples (List[:obj:`DetDataSample`]): A list of patch results. + offsets (Sequence[Tuple[int, int]]): Positions of the left top points + of patches. + src_image_shape (Tuple[int, int]): A (height, width) tuple of the large + image's width and height. + Returns: + (List[:obj:`DetDataSample`]): shifted results. + """ + try: + from sahi.slicing import shift_bboxes, shift_masks + except ImportError: + raise ImportError('Please run "pip install -U sahi" ' + 'to install sahi first for large image inference.') + + assert len(det_data_samples) == len( + offsets), 'The `results` should has the ' 'same length with `offsets`.' + shifted_predictions = [] + for det_data_sample, offset in zip(det_data_samples, offsets): + pred_inst = det_data_sample.pred_instances.clone() + + # Check bbox type + if pred_inst.bboxes.size(-1) == 4: + # Horizontal bboxes + shifted_bboxes = shift_bboxes(pred_inst.bboxes, offset) + elif pred_inst.bboxes.size(-1) == 5: + # Rotated bboxes + shifted_bboxes = shift_rbboxes(pred_inst.bboxes, offset) + else: + raise NotImplementedError + + # shift bboxes and masks + pred_inst.bboxes = shifted_bboxes + if 'masks' in det_data_sample: + pred_inst.masks = shift_masks(pred_inst.masks, offset, + src_image_shape) + + shifted_predictions.append(pred_inst.clone()) + + shifted_predictions = InstanceData.cat(shifted_predictions) + + return shifted_predictions + + +def merge_results_by_nms(results: SampleList, offsets: Sequence[Tuple[int, + int]], + src_image_shape: Tuple[int, int], + nms_cfg: dict) -> DetDataSample: + """Merge patch results by nms. + + Args: + results (List[:obj:`DetDataSample`]): A list of patch results. + offsets (Sequence[Tuple[int, int]]): Positions of the left top points + of patches. + src_image_shape (Tuple[int, int]): A (height, width) tuple of the large + image's width and height. + nms_cfg (dict): it should specify nms type and other parameters + like `iou_threshold`. + Returns: + :obj:`DetDataSample`: merged results. + """ + shifted_instances = shift_predictions(results, offsets, src_image_shape) + + _, keeps = batched_nms( + boxes=shifted_instances.bboxes, + scores=shifted_instances.scores, + idxs=shifted_instances.labels, + nms_cfg=nms_cfg) + merged_instances = shifted_instances[keeps] + + merged_result = results[0].clone() + merged_result.pred_instances = merged_instances + return merged_result diff --git a/mmdet/utils/logger.py b/mmdet/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..9fec08bbad5517c9169eedb15b4768e7d88d39c7 --- /dev/null +++ b/mmdet/utils/logger.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. 
All rights reserved. +import inspect + +from mmengine.logging import print_log + + +def get_caller_name(): + """Get name of caller method.""" + # this_func_frame = inspect.stack()[0][0] # i.e., get_caller_name + # callee_frame = inspect.stack()[1][0] # e.g., log_img_scale + caller_frame = inspect.stack()[2][0] # e.g., caller of log_img_scale + caller_method = caller_frame.f_code.co_name + try: + caller_class = caller_frame.f_locals['self'].__class__.__name__ + return f'{caller_class}.{caller_method}' + except KeyError: # caller is a function + return caller_method + + +def log_img_scale(img_scale, shape_order='hw', skip_square=False): + """Log image size. + + Args: + img_scale (tuple): Image size to be logged. + shape_order (str, optional): The order of image shape. + 'hw' for (height, width) and 'wh' for (width, height). + Defaults to 'hw'. + skip_square (bool, optional): Whether to skip logging for square + img_scale. Defaults to False. + + Returns: + bool: Whether to have done logging. + """ + if shape_order == 'hw': + height, width = img_scale + elif shape_order == 'wh': + width, height = img_scale + else: + raise ValueError(f'Invalid shape_order {shape_order}.') + + if skip_square and (height == width): + return False + + caller = get_caller_name() + print_log( + f'image shape: height={height}, width={width} in {caller}', + logger='current') + + return True diff --git a/mmdet/utils/memory.py b/mmdet/utils/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..b6f9cbc7f9e5f54e2cc429e5e655b2a27d38d61f --- /dev/null +++ b/mmdet/utils/memory.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from collections import abc +from contextlib import contextmanager +from functools import wraps + +import torch +from mmengine.logging import MMLogger + + +def cast_tensor_type(inputs, src_type=None, dst_type=None): + """Recursively convert Tensor in inputs from ``src_type`` to ``dst_type``. + + Args: + inputs: Inputs that to be casted. + src_type (torch.dtype | torch.device): Source type. + src_type (torch.dtype | torch.device): Destination type. + + Returns: + The same type with inputs, but all contained Tensors have been cast. + """ + assert dst_type is not None + if isinstance(inputs, torch.Tensor): + if isinstance(dst_type, torch.device): + # convert Tensor to dst_device + if hasattr(inputs, 'to') and \ + hasattr(inputs, 'device') and \ + (inputs.device == src_type or src_type is None): + return inputs.to(dst_type) + else: + return inputs + else: + # convert Tensor to dst_dtype + if hasattr(inputs, 'to') and \ + hasattr(inputs, 'dtype') and \ + (inputs.dtype == src_type or src_type is None): + return inputs.to(dst_type) + else: + return inputs + # we need to ensure that the type of inputs to be casted are the same + # as the argument `src_type`. + elif isinstance(inputs, abc.Mapping): + return type(inputs)({ + k: cast_tensor_type(v, src_type=src_type, dst_type=dst_type) + for k, v in inputs.items() + }) + elif isinstance(inputs, abc.Iterable): + return type(inputs)( + cast_tensor_type(item, src_type=src_type, dst_type=dst_type) + for item in inputs) + # TODO: Currently not supported + # elif isinstance(inputs, InstanceData): + # for key, value in inputs.items(): + # inputs[key] = cast_tensor_type( + # value, src_type=src_type, dst_type=dst_type) + # return inputs + else: + return inputs + + +@contextmanager +def _ignore_torch_cuda_oom(): + """A context which ignores CUDA OOM exception from pytorch. 
+
+    Code is modified from
+    https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/memory.py  # noqa: E501
+    """
+    try:
+        yield
+    except RuntimeError as e:
+        # NOTE: the string may change?
+        if 'CUDA out of memory. ' in str(e):
+            pass
+        else:
+            raise
+
+
+class AvoidOOM:
+    """Try to convert inputs to FP16 and CPU if a PyTorch CUDA Out of
+    Memory error is encountered. It will do the following steps:
+
+        1. First retry after calling `torch.cuda.empty_cache()`.
+        2. If that still fails, it will then retry by converting inputs
+           to FP16.
+        3. If that still fails, it will try to convert inputs to CPU.
+           In this case, it expects the function to dispatch to a
+           CPU implementation.
+
+    Args:
+        to_cpu (bool): Whether to convert outputs to CPU if an OOM
+            error is encountered. This will slow down the code
+            significantly. Defaults to True.
+        test (bool): Skip `_ignore_torch_cuda_oom` so that lightweight
+            data can be used in unit tests; only used in unit tests.
+            Defaults to False.
+
+    Examples:
+        >>> from mmdet.utils.memory import AvoidOOM
+        >>> AvoidCUDAOOM = AvoidOOM()
+        >>> output = AvoidCUDAOOM.retry_if_cuda_oom(
+        >>>     some_torch_function)(input1, input2)
+        >>> # To use as a decorator
+        >>> # from mmdet.utils import AvoidCUDAOOM
+        >>> @AvoidCUDAOOM.retry_if_cuda_oom
+        >>> def function(*args, **kwargs):
+        >>>     return None
+
+    Note:
+        1. The output may be on CPU even if inputs are on GPU. Processing
+           on CPU will slow down the code significantly.
+        2. When converting inputs to CPU, it will only look at each argument
+           and check if it has `.device` and `.to` for conversion. Nested
+           structures of tensors are not supported.
+        3. Since the function might be called more than once, it has to be
+           stateless.
+    """
+
+    def __init__(self, to_cpu=True, test=False):
+        self.to_cpu = to_cpu
+        self.test = test
+
+    def retry_if_cuda_oom(self, func):
+        """Makes a function retry itself after encountering pytorch's CUDA OOM
+        error.
+
+        The implementation logic is referred to
+        https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/memory.py
+
+        Args:
+            func: a stateless callable that takes tensor-like objects
+                as arguments.
+        Returns:
+            func: a callable which retries `func` if OOM is encountered.
+        """  # noqa: W605
+
+        @wraps(func)
+        def wrapped(*args, **kwargs):
+
+            # raw function
+            if not self.test:
+                with _ignore_torch_cuda_oom():
+                    return func(*args, **kwargs)
+
+                # Clear cache and retry
+                torch.cuda.empty_cache()
+                with _ignore_torch_cuda_oom():
+                    return func(*args, **kwargs)
+
+            # get the type and device of first tensor
+            dtype, device = None, None
+            values = args + tuple(kwargs.values())
+            for value in values:
+                if isinstance(value, torch.Tensor):
+                    dtype = value.dtype
+                    device = value.device
+                    break
+            if dtype is None or device is None:
+                raise ValueError('There is no tensor in the inputs, '
+                                 'cannot get dtype and device.')
+
+            # Convert to FP16
+            fp16_args = cast_tensor_type(args, dst_type=torch.half)
+            fp16_kwargs = cast_tensor_type(kwargs, dst_type=torch.half)
+            logger = MMLogger.get_current_instance()
+            logger.warning(f'Attempting to copy inputs of {str(func)} '
+                           'to FP16 due to CUDA OOM')
+
+            # the output type will be the same as the type of the
+            # first input tensor
+            with _ignore_torch_cuda_oom():
+                output = func(*fp16_args, **fp16_kwargs)
+                output = cast_tensor_type(
+                    output, src_type=torch.half, dst_type=dtype)
+                if not self.test:
+                    return output
+            logger.warning('Using FP16 still met CUDA OOM')
+
+            # Try on CPU. This will slow down the code significantly,
+            # therefore print a notice.
+ if self.to_cpu: + logger.warning(f'Attempting to copy inputs of {str(func)} ' + 'to CPU due to CUDA OOM') + cpu_device = torch.empty(0).device + cpu_args = cast_tensor_type(args, dst_type=cpu_device) + cpu_kwargs = cast_tensor_type(kwargs, dst_type=cpu_device) + + # run on CPU and try to convert outputs back to GPU + with _ignore_torch_cuda_oom(): + logger.warning(f'Convert outputs to GPU (device={device})') + output = func(*cpu_args, **cpu_kwargs) + output = cast_tensor_type( + output, src_type=cpu_device, dst_type=device) + return output + + warnings.warn('Cannot convert output to GPU due to CUDA OOM, ' + 'the output is now on CPU, which might cause ' + 'errors if the output needs to interact with GPU ' + 'data in subsequent operations') + logger.warning('Cannot convert output to GPU due to ' + 'CUDA OOM, the output is on CPU now.') + + return func(*cpu_args, **cpu_kwargs) + else: + # may still get CUDA OOM error + return func(*args, **kwargs) + + return wrapped + + +# To use AvoidOOM as a decorator +AvoidCUDAOOM = AvoidOOM() diff --git a/mmdet/utils/misc.py b/mmdet/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..8dfb394465196cbd1e60c96f5be3aaee416d59cf --- /dev/null +++ b/mmdet/utils/misc.py @@ -0,0 +1,149 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import glob +import os +import os.path as osp +import urllib +import warnings +from typing import Tuple, Union + +import torch +from mmengine.config import Config, ConfigDict +from mmengine.logging import print_log +from mmengine.utils import scandir + +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', + '.tiff', '.webp') + + +def find_latest_checkpoint(path, suffix='pth'): + """Find the latest checkpoint from the working directory. + + Args: + path (str): The path to find checkpoints. + suffix (str): File extension. + Defaults to pth. + + Returns: + latest_path (str | None): File path of the latest checkpoint. + References: + .. [1] https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/patch.py + """ + if not osp.exists(path): + warnings.warn('The path of checkpoints does not exist.') + return None + if osp.exists(osp.join(path, f'latest.{suffix}')): + return osp.join(path, f'latest.{suffix}') + + checkpoints = glob.glob(osp.join(path, f'*.{suffix}')) + if len(checkpoints) == 0: + warnings.warn('There are no checkpoints in the path.') + return None + latest = -1 + latest_path = None + for checkpoint in checkpoints: + count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0]) + if count > latest: + latest = count + latest_path = checkpoint + return latest_path + + +def update_data_root(cfg, logger=None): + """Update data root according to env MMDET_DATASETS. + + If set env MMDET_DATASETS, update cfg.data_root according to + MMDET_DATASETS. Otherwise, using cfg.data_root as default. + + Args: + cfg (:obj:`Config`): The model config to be modified. + logger (logging.Logger | str | None): the way to print msg + """ + assert isinstance(cfg, Config), \ + f'cfg got wrong type: {type(cfg)}, expected mmengine.Config' + + if 'MMDET_DATASETS' in os.environ: + dst_root = os.environ['MMDET_DATASETS'] + print_log(f'MMDET_DATASETS has been set to be {dst_root}. '
+ f'Using {dst_root} as data root.') + else: + return + + def update(cfg, src_str, dst_str): + for k, v in cfg.items(): + if isinstance(v, ConfigDict): + update(cfg[k], src_str, dst_str) + if isinstance(v, str) and src_str in v: + cfg[k] = v.replace(src_str, dst_str) + + update(cfg.data, cfg.data_root, dst_root) + cfg.data_root = dst_root + + +def get_test_pipeline_cfg(cfg: Union[str, ConfigDict]) -> ConfigDict: + """Get the test dataset pipeline from entire config. + + Args: + cfg (str or :obj:`ConfigDict`): the entire config. Can be a config + file or a ``ConfigDict``. + + Returns: + :obj:`ConfigDict`: the config of test dataset. + """ + if isinstance(cfg, str): + cfg = Config.fromfile(cfg) + + def _get_test_pipeline_cfg(dataset_cfg): + if 'pipeline' in dataset_cfg: + return dataset_cfg.pipeline + # handle dataset wrapper + elif 'dataset' in dataset_cfg: + return _get_test_pipeline_cfg(dataset_cfg.dataset) + # handle dataset wrappers like ConcatDataset + elif 'datasets' in dataset_cfg: + return _get_test_pipeline_cfg(dataset_cfg.datasets[0]) + + raise RuntimeError('Cannot find `pipeline` in `test_dataloader`') + + return _get_test_pipeline_cfg(cfg.test_dataloader.dataset) + + +def get_file_list(source_root: str) -> Tuple[list, dict]: + """Get file list. + + Args: + source_root (str): image or video source path + + Returns: + source_file_path_list (list): A list of all source files. + source_type (dict): Source type: file, url, or dir. + """ + is_dir = os.path.isdir(source_root) + is_url = source_root.startswith(('http:/', 'https:/')) + is_file = os.path.splitext(source_root)[-1].lower() in IMG_EXTENSIONS + + source_file_path_list = [] + if is_dir: + # when input source is dir + for file in scandir(source_root, IMG_EXTENSIONS, recursive=True): + source_file_path_list.append(os.path.join(source_root, file)) + elif is_url: + # when input source is url + filename = os.path.basename( + urllib.parse.unquote(source_root).split('?')[0]) + file_save_path = os.path.join(os.getcwd(), filename) + print(f'Downloading source file to {file_save_path}') + torch.hub.download_url_to_file(source_root, file_save_path) + source_file_path_list = [file_save_path] + elif is_file: + # when input source is single image + source_file_path_list = [source_root] + else: + print('Cannot find image file.') + + source_type = dict(is_dir=is_dir, is_url=is_url, is_file=is_file) + + return source_file_path_list, source_type diff --git a/mmdet/utils/mot_error_visualize.py b/mmdet/utils/mot_error_visualize.py new file mode 100644 index 0000000000000000000000000000000000000000..01bf8645d340aa1f5ab8251211a719f2de9845b1 --- /dev/null +++ b/mmdet/utils/mot_error_visualize.py @@ -0,0 +1,273 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Union + +try: + import seaborn as sns +except ImportError: + sns = None +import cv2 +import matplotlib.pyplot as plt +import mmcv +import numpy as np +from matplotlib.patches import Rectangle +from mmengine.utils import mkdir_or_exist + + +def imshow_mot_errors(*args, backend: str = 'cv2', **kwargs): + """Show the wrong tracks on the input image. + + Args: + backend (str, optional): Backend of visualization. + Defaults to 'cv2'.
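+ *args: Positional arguments forwarded to the chosen backend. + **kwargs: Keyword arguments forwarded to the chosen backend. + + Example (a sketch; the file name and one-box arrays are hypothetical): + + >>> import numpy as np + >>> drawn = imshow_mot_errors('frame.jpg', + ... np.array([[10., 10., 50., 50., 0.9]]), + ... np.array([1]), np.array([0]), backend='cv2')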
+ """ + if backend == 'cv2': + return _cv2_show_wrong_tracks(*args, **kwargs) + elif backend == 'plt': + return _plt_show_wrong_tracks(*args, **kwargs) + else: + raise NotImplementedError() + + +def _cv2_show_wrong_tracks(img: Union[str, np.ndarray], + bboxes: np.ndarray, + ids: np.ndarray, + error_types: np.ndarray, + thickness: int = 2, + font_scale: float = 0.4, + text_width: int = 10, + text_height: int = 15, + show: bool = False, + wait_time: int = 100, + out_file: str = None) -> np.ndarray: + """Show the wrong tracks with opencv. + + Args: + img (str or ndarray): The image to be displayed. + bboxes (ndarray): A ndarray of shape (k, 5). + ids (ndarray): A ndarray of shape (k, ). + error_types (ndarray): A ndarray of shape (k, ), where 0 denotes + false positives, 1 denotes false negative and 2 denotes ID switch. + thickness (int, optional): Thickness of lines. + Defaults to 2. + font_scale (float, optional): Font scale to draw id and score. + Defaults to 0.4. + text_width (int, optional): Width to draw id and score. + Defaults to 10. + text_height (int, optional): Height to draw id and score. + Defaults to 15. + show (bool, optional): Whether to show the image on the fly. + Defaults to False. + wait_time (int, optional): Value of waitKey param. + Defaults to 100. + out_file (str, optional): The filename to write the image. + Defaults to None. + + Returns: + ndarray: Visualized image. + """ + if sns is None: + raise ImportError('please run pip install seaborn') + assert bboxes.ndim == 2, \ + f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.' + assert ids.ndim == 1, \ + f' ids ndim should be 1, but its ndim is {ids.ndim}.' + assert error_types.ndim == 1, \ + f' error_types ndim should be 1, but its ndim is {error_types.ndim}.' + assert bboxes.shape[0] == ids.shape[0], \ + 'bboxes.shape[0] and ids.shape[0] should have the same length.' + assert bboxes.shape[1] == 5, \ + f' bboxes.shape[1] should be 5, but its {bboxes.shape[1]}.' 
+ + bbox_colors = sns.color_palette() + # red, yellow, blue + bbox_colors = [bbox_colors[3], bbox_colors[1], bbox_colors[0]] + bbox_colors = [[int(255 * _c) for _c in bbox_color][::-1] + for bbox_color in bbox_colors] + + if isinstance(img, str): + img = mmcv.imread(img) + else: + assert img.ndim == 3 + + img_shape = img.shape + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) + bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0]) + + for bbox, error_type, id in zip(bboxes, error_types, ids): + x1, y1, x2, y2 = bbox[:4].astype(np.int32) + score = float(bbox[-1]) + + # bbox + bbox_color = bbox_colors[error_type] + cv2.rectangle(img, (x1, y1), (x2, y2), bbox_color, thickness=thickness) + + # FN does not have id and score + if error_type == 1: + continue + + # score + text = '{:.02f}'.format(score) + width = (len(text) - 1) * text_width + img[y1:y1 + text_height, x1:x1 + width, :] = bbox_color + cv2.putText( + img, + text, (x1, y1 + text_height - 2), + cv2.FONT_HERSHEY_COMPLEX, + font_scale, + color=(0, 0, 0)) + + # id + text = str(id) + width = len(text) * text_width + img[y1 + text_height:y1 + text_height * 2, + x1:x1 + width, :] = bbox_color + cv2.putText( + img, + str(id), (x1, y1 + text_height * 2 - 2), + cv2.FONT_HERSHEY_COMPLEX, + font_scale, + color=(0, 0, 0)) + + if show: + mmcv.imshow(img, wait_time=wait_time) + if out_file is not None: + mmcv.imwrite(img, out_file) + + return img + + +def _plt_show_wrong_tracks(img: Union[str, np.ndarray], + bboxes: np.ndarray, + ids: np.ndarray, + error_types: np.ndarray, + thickness: float = 0.1, + font_scale: float = 3.0, + text_width: int = 8, + text_height: int = 13, + show: bool = False, + wait_time: int = 100, + out_file: str = None) -> np.ndarray: + """Show the wrong tracks with matplotlib. + + Args: + img (str or ndarray): The image to be displayed. + bboxes (ndarray): A ndarray of shape (k, 5). + ids (ndarray): A ndarray of shape (k, ). + error_types (ndarray): A ndarray of shape (k, ), where 0 denotes + false positives, 1 denotes false negatives and 2 denotes + ID switches. + thickness (float, optional): Thickness of lines. + Defaults to 0.1. + font_scale (float, optional): Font scale to draw id and score. + Defaults to 3.0. + text_width (int, optional): Width to draw id and score. + Defaults to 8. + text_height (int, optional): Height to draw id and score. + Defaults to 13. + show (bool, optional): Whether to show the image on the fly. + Defaults to False. + wait_time (int, optional): Value of waitKey param. + Defaults to 100. + out_file (str, optional): The filename to write the image. + Defaults to None. + + Returns: + ndarray: Original image. + """ + assert bboxes.ndim == 2, \ + f'bboxes ndim should be 2, but its ndim is {bboxes.ndim}.' + assert ids.ndim == 1, \ + f'ids ndim should be 1, but its ndim is {ids.ndim}.' + assert error_types.ndim == 1, \ + f'error_types ndim should be 1, but its ndim is {error_types.ndim}.' + assert bboxes.shape[0] == ids.shape[0], \ + 'bboxes.shape[0] and ids.shape[0] should be equal.' + assert bboxes.shape[1] == 5, \ + f'bboxes.shape[1] should be 5, but it is {bboxes.shape[1]}.'
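+ + # Guard mirrored from the cv2 backend above: seaborn supplies the + # palette but is an optional dependency, so fail with the same hint. + if sns is None: + raise ImportError('please run pip install seaborn')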
+ + bbox_colors = sns.color_palette() + # red, yellow, blue + bbox_colors = [bbox_colors[3], bbox_colors[1], bbox_colors[0]] + + if isinstance(img, str): + img = plt.imread(img) + else: + assert img.ndim == 3 + img = mmcv.bgr2rgb(img) + + img_shape = img.shape + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) + bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0]) + + plt.imshow(img) + plt.gca().set_axis_off() + plt.autoscale(False) + plt.subplots_adjust( + top=1, bottom=0, right=1, left=0, hspace=None, wspace=None) + plt.margins(0, 0) + plt.gca().xaxis.set_major_locator(plt.NullLocator()) + plt.gca().yaxis.set_major_locator(plt.NullLocator()) + plt.rcParams['figure.figsize'] = img_shape[1], img_shape[0] + + for bbox, error_type, id in zip(bboxes, error_types, ids): + x1, y1, x2, y2, score = bbox + w, h = int(x2 - x1), int(y2 - y1) + left_top = (int(x1), int(y1)) + + # bbox + plt.gca().add_patch( + Rectangle( + left_top, + w, + h, + thickness, + edgecolor=bbox_colors[error_type], + facecolor='none')) + + # FN does not have id and score + if error_type == 1: + continue + + # score + text = '{:.02f}'.format(score) + width = len(text) * text_width + plt.gca().add_patch( + Rectangle((left_top[0], left_top[1]), + width, + text_height, + thickness, + edgecolor=bbox_colors[error_type], + facecolor=bbox_colors[error_type])) + + plt.text( + left_top[0], + left_top[1] + text_height + 2, + text, + fontsize=font_scale) + + # id + text = str(id) + width = len(text) * text_width + plt.gca().add_patch( + Rectangle((left_top[0], left_top[1] + text_height + 1), + width, + text_height, + thickness, + edgecolor=bbox_colors[error_type], + facecolor=bbox_colors[error_type])) + plt.text( + left_top[0], + left_top[1] + 2 * (text_height + 1), + text, + fontsize=font_scale) + + if out_file is not None: + mkdir_or_exist(osp.abspath(osp.dirname(out_file))) + plt.savefig(out_file, dpi=300, bbox_inches='tight', pad_inches=0.0) + + if show: + plt.draw() + plt.pause(wait_time / 1000.) + + plt.clf() + return img diff --git a/mmdet/utils/profiling.py b/mmdet/utils/profiling.py new file mode 100644 index 0000000000000000000000000000000000000000..2f53f456c72db57bfa69a8d022c92d153580209e --- /dev/null +++ b/mmdet/utils/profiling.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import contextlib +import sys +import time + +import torch + +if sys.version_info >= (3, 7): + + @contextlib.contextmanager + def profile_time(trace_name, + name, + enabled=True, + stream=None, + end_stream=None): + """Print time spent by CPU and GPU. + + Useful as a temporary context manager to find sweet spots of code + suitable for async implementation. 
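+ + Example (a sketch; ``model`` and ``inputs`` are hypothetical): + + >>> with profile_time('forward', 'backbone'): + ... feats = model(inputs)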
+ """ + if (not enabled) or not torch.cuda.is_available(): + yield + return + stream = stream if stream else torch.cuda.current_stream() + end_stream = end_stream if end_stream else stream + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + stream.record_event(start) + try: + cpu_start = time.monotonic() + yield + finally: + cpu_end = time.monotonic() + end_stream.record_event(end) + end.synchronize() + cpu_time = (cpu_end - cpu_start) * 1000 + gpu_time = start.elapsed_time(end) + msg = f'{trace_name} {name} cpu_time {cpu_time:.2f} ms ' + msg += f'gpu_time {gpu_time:.2f} ms stream {stream}' + print(msg, end_stream) diff --git a/mmdet/utils/replace_cfg_vals.py b/mmdet/utils/replace_cfg_vals.py new file mode 100644 index 0000000000000000000000000000000000000000..a3331a36ce5a22fcc4d4a955d757f5e8b6bfc6bb --- /dev/null +++ b/mmdet/utils/replace_cfg_vals.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re + +from mmengine.config import Config + + +def replace_cfg_vals(ori_cfg): + """Replace the string "${key}" with the corresponding value. + + Replace the "${key}" with the value of ori_cfg.key in the config. And + support replacing the chained ${key}. Such as, replace "${key0.key1}" + with the value of cfg.key0.key1. Code is modified from `vars.py + < https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/vars.py>`_ # noqa: E501 + + Args: + ori_cfg (mmengine.config.Config): + The origin config with "${key}" generated from a file. + + Returns: + updated_cfg [mmengine.config.Config]: + The config with "${key}" replaced by the corresponding value. + """ + + def get_value(cfg, key): + for k in key.split('.'): + cfg = cfg[k] + return cfg + + def replace_value(cfg): + if isinstance(cfg, dict): + return {key: replace_value(value) for key, value in cfg.items()} + elif isinstance(cfg, list): + return [replace_value(item) for item in cfg] + elif isinstance(cfg, tuple): + return tuple([replace_value(item) for item in cfg]) + elif isinstance(cfg, str): + # the format of string cfg may be: + # 1) "${key}", which will be replaced with cfg.key directly + # 2) "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx", + # which will be replaced with the string of the cfg.key + keys = pattern_key.findall(cfg) + values = [get_value(ori_cfg, key[2:-1]) for key in keys] + if len(keys) == 1 and keys[0] == cfg: + # the format of string cfg is "${key}" + cfg = values[0] + else: + for key, value in zip(keys, values): + # the format of string cfg is + # "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx" + assert not isinstance(value, (dict, list, tuple)), \ + f'for the format of string cfg is ' \ + f"'xxxxx${key}xxxxx' or 'xxx${key}xxx${key}xxx', " \ + f"the type of the value of '${key}' " \ + f'can not be dict, list, or tuple' \ + f'but you input {type(value)} in {cfg}' + cfg = cfg.replace(key, str(value)) + return cfg + else: + return cfg + + # the pattern of string "${key}" + pattern_key = re.compile(r'\$\{[a-zA-Z\d_.]*\}') + # the type of ori_cfg._cfg_dict is mmengine.config.ConfigDict + updated_cfg = Config( + replace_value(ori_cfg._cfg_dict), filename=ori_cfg.filename) + # replace the model with model_wrapper + if updated_cfg.get('model_wrapper', None) is not None: + updated_cfg.model = updated_cfg.model_wrapper + updated_cfg.pop('model_wrapper') + return updated_cfg diff --git a/mmdet/utils/setup_env.py b/mmdet/utils/setup_env.py new file mode 100644 index 0000000000000000000000000000000000000000..a7b37845a883752a1659fabf62c7404cff971191 --- /dev/null +++ 
b/mmdet/utils/setup_env.py @@ -0,0 +1,118 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import datetime +import logging +import os +import platform +import warnings + +import cv2 +import torch.multiprocessing as mp +from mmengine import DefaultScope +from mmengine.logging import print_log +from mmengine.utils import digit_version + + +def setup_cache_size_limit_of_dynamo(): + """Setup cache size limit of dynamo. + + Note: Due to the dynamic shape of the loss calculation and + post-processing parts in the object detection algorithm, these + functions must be compiled every time they are run. + Setting a large value for torch._dynamo.config.cache_size_limit + may result in repeated compilation, which can slow down training + and testing speed. Therefore, we need to set the default value of + cache_size_limit smaller. An empirical value is 4. + """ + + import torch + if digit_version(torch.__version__) >= digit_version('2.0.0'): + if 'DYNAMO_CACHE_SIZE_LIMIT' in os.environ: + import torch._dynamo + cache_size_limit = int(os.environ['DYNAMO_CACHE_SIZE_LIMIT']) + torch._dynamo.config.cache_size_limit = cache_size_limit + print_log( + f'torch._dynamo.config.cache_size_limit is force ' + f'set to {cache_size_limit}.', + logger='current', + level=logging.WARNING) + + +def setup_multi_processes(cfg): + """Setup multi-processing environment variables.""" + # set multi-process start method as `fork` to speed up the training + if platform.system() != 'Windows': + mp_start_method = cfg.get('mp_start_method', 'fork') + current_method = mp.get_start_method(allow_none=True) + if current_method is not None and current_method != mp_start_method: + warnings.warn( + f'Multi-processing start method `{mp_start_method}` is ' + f'different from the previous setting `{current_method}`. ' + f'It will be force set to `{mp_start_method}`. You can change ' + f'this behavior by changing `mp_start_method` in your config.') + mp.set_start_method(mp_start_method, force=True) + + # disable opencv multithreading to avoid system being overloaded + opencv_num_threads = cfg.get('opencv_num_threads', 0) + cv2.setNumThreads(opencv_num_threads) + + # setup OMP threads + # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa + workers_per_gpu = cfg.data.get('workers_per_gpu', 1) + if 'train_dataloader' in cfg.data: + workers_per_gpu = \ + max(cfg.data.train_dataloader.get('workers_per_gpu', 1), + workers_per_gpu) + + if 'OMP_NUM_THREADS' not in os.environ and workers_per_gpu > 1: + omp_num_threads = 1 + warnings.warn( + f'Setting OMP_NUM_THREADS environment variable for each process ' + f'to be {omp_num_threads} by default, to avoid your system being ' + f'overloaded. Please further tune the variable for optimal ' + f'performance in your application as needed.') + os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) + + # setup MKL threads + if 'MKL_NUM_THREADS' not in os.environ and workers_per_gpu > 1: + mkl_num_threads = 1 + warnings.warn( + f'Setting MKL_NUM_THREADS environment variable for each process ' + f'to be {mkl_num_threads} by default, to avoid your system being ' + f'overloaded. Please further tune the variable for optimal ' + f'performance in your application as needed.') + os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) + + +def register_all_modules(init_default_scope: bool = True) -> None: + """Register all modules in mmdet into the registries. + + Args: + init_default_scope (bool): Whether to initialize the mmdet default scope.
+ When `init_default_scope=True`, the global default scope will be + set to `mmdet`, and all registries will build modules from mmdet's + registry node. To understand more about the registry, please refer + to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md + Defaults to True. + """ # noqa + import mmdet.datasets # noqa: F401,F403 + import mmdet.engine # noqa: F401,F403 + import mmdet.evaluation # noqa: F401,F403 + import mmdet.models # noqa: F401,F403 + import mmdet.visualization # noqa: F401,F403 + + if init_default_scope: + never_created = DefaultScope.get_current_instance() is None \ + or not DefaultScope.check_instance_created('mmdet') + if never_created: + DefaultScope.get_instance('mmdet', scope_name='mmdet') + return + current_scope = DefaultScope.get_current_instance() + if current_scope.scope_name != 'mmdet': + warnings.warn('The current default scope ' + f'"{current_scope.scope_name}" is not "mmdet", ' + '`register_all_modules` will force the current ' + 'default scope to be "mmdet". If this is not ' + 'expected, please set `init_default_scope=False`.') + # avoid name conflict + new_instance_name = f'mmdet-{datetime.datetime.now()}' + DefaultScope.get_instance(new_instance_name, scope_name='mmdet') diff --git a/mmdet/utils/split_batch.py b/mmdet/utils/split_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..0276fb331f23c1a7f7451faf2a8f768e616d45fd --- /dev/null +++ b/mmdet/utils/split_batch.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +def split_batch(img, img_metas, kwargs): + """Split data_batch by tags. + + Code is modified from + # noqa: E501 + + Args: + img (Tensor): of shape (N, C, H, W) encoding input images. + Typically these should be mean centered and std scaled. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys, see + :class:`mmdet.datasets.pipelines.Collect`. + kwargs (dict): Specific to concrete implementation. + + Returns: + data_groups (dict): a dict of data_batch split by tags, + such as 'sup', 'unsup_teacher', and 'unsup_student'. + """ + + # only stack img in the batch + def fuse_list(obj_list, obj): + return torch.stack(obj_list) if isinstance(obj, + torch.Tensor) else obj_list + + # select data with tag from data_batch + def select_group(data_batch, current_tag): + group_flag = [tag == current_tag for tag in data_batch['tag']] + return { + k: fuse_list([vv for vv, gf in zip(v, group_flag) if gf], v) + for k, v in data_batch.items() + } + + kwargs.update({'img': img, 'img_metas': img_metas}) + kwargs.update({'tag': [meta['tag'] for meta in img_metas]}) + tags = list(set(kwargs['tag'])) + data_groups = {tag: select_group(kwargs, tag) for tag in tags} + for tag, group in data_groups.items(): + group.pop('tag') + return data_groups diff --git a/mmdet/utils/typing_utils.py b/mmdet/utils/typing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6caf6de53274594e139dbe7c1973c747229bf010 --- /dev/null +++ b/mmdet/utils/typing_utils.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved.
+"""Collecting some commonly used type hint in mmdetection.""" +from typing import List, Optional, Sequence, Tuple, Union + +from mmengine.config import ConfigDict +from mmengine.structures import InstanceData, PixelData + +# TODO: Need to avoid circular import with assigner and sampler +# Type hint of config data +ConfigType = Union[ConfigDict, dict] +OptConfigType = Optional[ConfigType] +# Type hint of one or more config data +MultiConfig = Union[ConfigType, List[ConfigType]] +OptMultiConfig = Optional[MultiConfig] + +InstanceList = List[InstanceData] +OptInstanceList = Optional[InstanceList] + +PixelList = List[PixelData] +OptPixelList = Optional[PixelList] + +RangeType = Sequence[Tuple[int, int]] diff --git a/mmdet/utils/util_mixins.py b/mmdet/utils/util_mixins.py new file mode 100644 index 0000000000000000000000000000000000000000..b83b6617f5e4a202067e1659bf448962a2a2bc72 --- /dev/null +++ b/mmdet/utils/util_mixins.py @@ -0,0 +1,105 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""This module defines the :class:`NiceRepr` mixin class, which defines a +``__repr__`` and ``__str__`` method that only depend on a custom ``__nice__`` +method, which you must define. This means you only have to overload one +function instead of two. Furthermore, if the object defines a ``__len__`` +method, then the ``__nice__`` method defaults to something sensible, otherwise +it is treated as abstract and raises ``NotImplementedError``. + +To use simply have your object inherit from :class:`NiceRepr` +(multi-inheritance should be ok). + +This code was copied from the ubelt library: https://github.com/Erotemic/ubelt + +Example: + >>> # Objects that define __nice__ have a default __str__ and __repr__ + >>> class Student(NiceRepr): + ... def __init__(self, name): + ... self.name = name + ... def __nice__(self): + ... return self.name + >>> s1 = Student('Alice') + >>> s2 = Student('Bob') + >>> print(f's1 = {s1}') + >>> print(f's2 = {s2}') + s1 = + s2 = + +Example: + >>> # Objects that define __len__ have a default __nice__ + >>> class Group(NiceRepr): + ... def __init__(self, data): + ... self.data = data + ... def __len__(self): + ... return len(self.data) + >>> g = Group([1, 2, 3]) + >>> print(f'g = {g}') + g = +""" +import warnings + + +class NiceRepr: + """Inherit from this class and define ``__nice__`` to "nicely" print your + objects. + + Defines ``__str__`` and ``__repr__`` in terms of ``__nice__`` function + Classes that inherit from :class:`NiceRepr` should redefine ``__nice__``. + If the inheriting class has a ``__len__``, method then the default + ``__nice__`` method will return its length. + + Example: + >>> class Foo(NiceRepr): + ... def __nice__(self): + ... return 'info' + >>> foo = Foo() + >>> assert str(foo) == '' + >>> assert repr(foo).startswith('>> class Bar(NiceRepr): + ... pass + >>> bar = Bar() + >>> import pytest + >>> with pytest.warns(None) as record: + >>> assert 'object at' in str(bar) + >>> assert 'object at' in repr(bar) + + Example: + >>> class Baz(NiceRepr): + ... def __len__(self): + ... 
return 5 + >>> baz = Baz() + >>> assert str(baz) == '<Baz(5)>' + """ + + def __nice__(self): + """str: a "nice" summary string describing this module""" + if hasattr(self, '__len__'): + # It is a common pattern for objects to use __len__ in __nice__ + # As a convenience we define a default __nice__ for these objects + return str(len(self)) + else: + # In all other cases force the subclass to overload __nice__ + raise NotImplementedError( + f'Define the __nice__ method for {self.__class__!r}') + + def __repr__(self): + """str: the string of the module""" + try: + nice = self.__nice__() + classname = self.__class__.__name__ + return f'<{classname}({nice}) at {hex(id(self))}>' + except NotImplementedError as ex: + warnings.warn(str(ex), category=RuntimeWarning) + return object.__repr__(self) + + def __str__(self): + """str: the string of the module""" + try: + classname = self.__class__.__name__ + nice = self.__nice__() + return f'<{classname}({nice})>' + except NotImplementedError as ex: + warnings.warn(str(ex), category=RuntimeWarning) + return object.__repr__(self) diff --git a/mmdet/utils/util_random.py b/mmdet/utils/util_random.py new file mode 100644 index 0000000000000000000000000000000000000000..dc1ecb6c03b026156c9947cb6d356a822448be0f --- /dev/null +++ b/mmdet/utils/util_random.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Helpers for random number generators.""" +import numpy as np + + +def ensure_rng(rng=None): + """Coerces input into a random number generator. + + If the input is None, then a global random state is returned. + + If the input is a numeric value, then that is used as a seed to construct a + random state. Otherwise the input is returned as-is. + + Adapted from [1]_. + + Args: + rng (int | numpy.random.RandomState | None): + if None, then defaults to the global rng. Otherwise this can be an + integer or a RandomState class + Returns: + (numpy.random.RandomState): rng - a numpy random number generator + + References: + .. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 # noqa: E501 + """ + + if rng is None: + rng = np.random.mtrand._rand + elif isinstance(rng, int): + rng = np.random.RandomState(rng) + else: + rng = rng + return rng diff --git a/mmdet/version.py b/mmdet/version.py new file mode 100644 index 0000000000000000000000000000000000000000..38ce834e15205fdef803aa27a61d68fb9e111982 --- /dev/null +++ b/mmdet/version.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +__version__ = '3.2.0' +short_version = __version__ + + +def parse_version_info(version_str): + """Parse a version string into a tuple. + + Args: + version_str (str): The version string. + Returns: + tuple[int | str]: The version info, e.g., "1.3.0" is parsed into + (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1'). + """ + version_info = [] + for x in version_str.split('.'): + if x.isdigit(): + version_info.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + version_info.append(int(patch_version[0])) + version_info.append(f'rc{patch_version[1]}') + return tuple(version_info) + + +version_info = parse_version_info(__version__) diff --git a/mmdet/visualization/__init__.py b/mmdet/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a7edaed9d8701b1be72ff2f7ca646b865007e2eb --- /dev/null +++ b/mmdet/visualization/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved.
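+# Example usage (a sketch; ``image`` is assumed to be an RGB ndarray and +# ``data_sample`` a DetDataSample, mirroring the DetLocalVisualizer +# docstring below): +# from mmdet.visualization import DetLocalVisualizer +# vis = DetLocalVisualizer() +# vis.add_datasample('demo', image, data_sample, show=False)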
+from .local_visualizer import DetLocalVisualizer, TrackLocalVisualizer +from .palette import get_palette, jitter_color, palette_val + +__all__ = [ + 'palette_val', 'get_palette', 'DetLocalVisualizer', 'jitter_color', + 'TrackLocalVisualizer' +] diff --git a/mmdet/visualization/local_visualizer.py b/mmdet/visualization/local_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..cc6521c56eb167c2c94a3f058594d9e832fb15ad --- /dev/null +++ b/mmdet/visualization/local_visualizer.py @@ -0,0 +1,699 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple, Union + +import cv2 +import mmcv +import numpy as np + +try: + import seaborn as sns +except ImportError: + sns = None +import torch +from mmengine.dist import master_only +from mmengine.structures import InstanceData, PixelData +from mmengine.visualization import Visualizer + +from ..evaluation import INSTANCE_OFFSET +from ..registry import VISUALIZERS +from ..structures import DetDataSample +from ..structures.mask import BitmapMasks, PolygonMasks, bitmap_to_polygon +from .palette import _get_adaptive_scales, get_palette, jitter_color + + +@VISUALIZERS.register_module() +class DetLocalVisualizer(Visualizer): + """MMDetection Local Visualizer. + + Args: + name (str): Name of the instance. Defaults to 'visualizer'. + image (np.ndarray, optional): the origin image to draw. The format + should be RGB. Defaults to None. + vis_backends (list, optional): Visual backend config list. + Defaults to None. + save_dir (str, optional): Save file dir for all storage backends. + If it is None, the backend storage will not save any data. + bbox_color (str, tuple(int), optional): Color of bbox lines. + The tuple of color should be in BGR order. Defaults to None. + text_color (str, tuple(int), optional): Color of texts. + The tuple of color should be in BGR order. + Defaults to (200, 200, 200). + mask_color (str, tuple(int), optional): Color of masks. + The tuple of color should be in BGR order. + Defaults to None. + line_width (int, float): The linewidth of lines. + Defaults to 3. + alpha (int, float): The transparency of bboxes or mask. + Defaults to 0.8. + + Examples: + >>> import numpy as np + >>> import torch + >>> from mmengine.structures import InstanceData + >>> from mmdet.structures import DetDataSample + >>> from mmdet.visualization import DetLocalVisualizer + + >>> det_local_visualizer = DetLocalVisualizer() + >>> image = np.random.randint(0, 256, + ... size=(10, 12, 3)).astype('uint8') + >>> gt_instances = InstanceData() + >>> gt_instances.bboxes = torch.Tensor([[1, 2, 2, 5]]) + >>> gt_instances.labels = torch.randint(0, 2, (1,)) + >>> gt_det_data_sample = DetDataSample() + >>> gt_det_data_sample.gt_instances = gt_instances + >>> det_local_visualizer.add_datasample('image', image, + ... gt_det_data_sample) + >>> det_local_visualizer.add_datasample( + ... 'image', image, gt_det_data_sample, + ... out_file='out_file.jpg') + >>> det_local_visualizer.add_datasample( + ... 'image', image, gt_det_data_sample, + ... show=True) + >>> pred_instances = InstanceData() + >>> pred_instances.bboxes = torch.Tensor([[2, 4, 4, 8]]) + >>> pred_instances.labels = torch.randint(0, 2, (1,)) + >>> pred_det_data_sample = DetDataSample() + >>> pred_det_data_sample.pred_instances = pred_instances + >>> det_local_visualizer.add_datasample('image', image, + ... gt_det_data_sample, + ... 
pred_det_data_sample) + """ + + def __init__(self, + name: str = 'visualizer', + image: Optional[np.ndarray] = None, + vis_backends: Optional[Dict] = None, + save_dir: Optional[str] = None, + bbox_color: Optional[Union[str, Tuple[int]]] = None, + text_color: Optional[Union[str, + Tuple[int]]] = (200, 200, 200), + mask_color: Optional[Union[str, Tuple[int]]] = None, + line_width: Union[int, float] = 3, + alpha: float = 0.8) -> None: + super().__init__( + name=name, + image=image, + vis_backends=vis_backends, + save_dir=save_dir) + self.bbox_color = bbox_color + self.text_color = text_color + self.mask_color = mask_color + self.line_width = line_width + self.alpha = alpha + # Set default value. When calling + # `DetLocalVisualizer().dataset_meta=xxx`, + # it will override the default value. + self.dataset_meta = {} + + def _draw_instances(self, image: np.ndarray, instances: ['InstanceData'], + classes: Optional[List[str]], + palette: Optional[List[tuple]]) -> np.ndarray: + """Draw instances of GT or prediction. + + Args: + image (np.ndarray): The image to draw. + instances (:obj:`InstanceData`): Data structure for + instance-level annotations or predictions. + classes (List[str], optional): Category information. + palette (List[tuple], optional): Palette information + corresponding to the category. + + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + self.set_image(image) + + if 'bboxes' in instances and instances.bboxes.sum() > 0: + bboxes = instances.bboxes + labels = instances.labels + + max_label = int(max(labels) if len(labels) > 0 else 0) + text_palette = get_palette(self.text_color, max_label + 1) + text_colors = [text_palette[label] for label in labels] + + bbox_color = palette if self.bbox_color is None \ + else self.bbox_color + bbox_palette = get_palette(bbox_color, max_label + 1) + colors = [bbox_palette[label] for label in labels] + self.draw_bboxes( + bboxes, + edge_colors=colors, + alpha=self.alpha, + line_widths=self.line_width) + + positions = bboxes[:, :2] + self.line_width + areas = (bboxes[:, 3] - bboxes[:, 1]) * ( + bboxes[:, 2] - bboxes[:, 0]) + scales = _get_adaptive_scales(areas) + + for i, (pos, label) in enumerate(zip(positions, labels)): + if 'label_names' in instances: + label_text = instances.label_names[i] + else: + label_text = classes[ + label] if classes is not None else f'class {label}' + if 'scores' in instances: + score = round(float(instances.scores[i]) * 100, 1) + label_text += f': {score}' + + self.draw_texts( + label_text, + pos, + colors=text_colors[i], + font_sizes=int(13 * scales[i]), + bboxes=[{ + 'facecolor': 'black', + 'alpha': 0.8, + 'pad': 0.7, + 'edgecolor': 'none' + }]) + + if 'masks' in instances: + labels = instances.labels + masks = instances.masks + if isinstance(masks, torch.Tensor): + masks = masks.numpy() + elif isinstance(masks, (PolygonMasks, BitmapMasks)): + masks = masks.to_ndarray() + + masks = masks.astype(bool) + + max_label = int(max(labels) if len(labels) > 0 else 0) + mask_color = palette if self.mask_color is None \ + else self.mask_color + mask_palette = get_palette(mask_color, max_label + 1) + colors = [jitter_color(mask_palette[label]) for label in labels] + text_palette = get_palette(self.text_color, max_label + 1) + text_colors = [text_palette[label] for label in labels] + + polygons = [] + for i, mask in enumerate(masks): + contours, _ = bitmap_to_polygon(mask) + polygons.extend(contours) + self.draw_polygons(polygons, edge_colors='w', alpha=self.alpha) + self.draw_binary_masks(masks, colors=colors, 
alphas=self.alpha) + + if len(labels) > 0 and \ + ('bboxes' not in instances or + instances.bboxes.sum() == 0): + # instances.bboxes.sum()==0 represents dummy bboxes. + # A typical example is SOLO, which has no bbox branch. + areas = [] + positions = [] + for mask in masks: + _, _, stats, centroids = cv2.connectedComponentsWithStats( + mask.astype(np.uint8), connectivity=8) + if stats.shape[0] > 1: + largest_id = np.argmax(stats[1:, -1]) + 1 + positions.append(centroids[largest_id]) + areas.append(stats[largest_id, -1]) + areas = np.stack(areas, axis=0) + scales = _get_adaptive_scales(areas) + + for i, (pos, label) in enumerate(zip(positions, labels)): + if 'label_names' in instances: + label_text = instances.label_names[i] + else: + label_text = classes[ + label] if classes is not None else f'class {label}' + if 'scores' in instances: + score = round(float(instances.scores[i]) * 100, 1) + label_text += f': {score}' + + self.draw_texts( + label_text, + pos, + colors=text_colors[i], + font_sizes=int(13 * scales[i]), + horizontal_alignments='center', + bboxes=[{ + 'facecolor': 'black', + 'alpha': 0.8, + 'pad': 0.7, + 'edgecolor': 'none' + }]) + return self.get_image() + + def _draw_panoptic_seg(self, image: np.ndarray, + panoptic_seg: ['PixelData'], + classes: Optional[List[str]], + palette: Optional[List]) -> np.ndarray: + """Draw panoptic seg of GT or prediction. + + Args: + image (np.ndarray): The image to draw. + panoptic_seg (:obj:`PixelData`): Data structure for + pixel-level annotations or predictions. + classes (List[str], optional): Category information. + palette (List, optional): Palette information corresponding + to the category. + + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + # TODO: Is there a way to bypass? + num_classes = len(classes) + + panoptic_seg_data = panoptic_seg.sem_seg[0] + + ids = np.unique(panoptic_seg_data)[::-1] + + if 'label_names' in panoptic_seg: + # open set panoptic segmentation + classes = panoptic_seg.metainfo['label_names'] + ignore_index = panoptic_seg.metainfo.get('ignore_index', + len(classes)) + ids = ids[ids != ignore_index] + else: + # for VOID label + ids = ids[ids != num_classes] + + labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64) + segms = (panoptic_seg_data[None] == ids[:, None, None]) + + max_label = int(max(labels) if len(labels) > 0 else 0) + + mask_color = palette if self.mask_color is None \ + else self.mask_color + mask_palette = get_palette(mask_color, max_label + 1) + colors = [mask_palette[label] for label in labels] + + self.set_image(image) + + # draw segm + polygons = [] + for i, mask in enumerate(segms): + contours, _ = bitmap_to_polygon(mask) + polygons.extend(contours) + self.draw_polygons(polygons, edge_colors='w', alpha=self.alpha) + self.draw_binary_masks(segms, colors=colors, alphas=self.alpha) + + # draw label + areas = [] + positions = [] + for mask in segms: + _, _, stats, centroids = cv2.connectedComponentsWithStats( + mask.astype(np.uint8), connectivity=8) + max_id = np.argmax(stats[1:, -1]) + 1 + positions.append(centroids[max_id]) + areas.append(stats[max_id, -1]) + areas = np.stack(areas, axis=0) + scales = _get_adaptive_scales(areas) + + text_palette = get_palette(self.text_color, max_label + 1) + text_colors = [text_palette[label] for label in labels] + + for i, (pos, label) in enumerate(zip(positions, labels)): + label_text = classes[label] + + self.draw_texts( + label_text, + pos, + colors=text_colors[i], + font_sizes=int(13 * scales[i]), + bboxes=[{ + 'facecolor': 'black', + 'alpha': 0.8, + 'pad': 0.7, + 'edgecolor': 'none' + }],
horizontal_alignments='center') + return self.get_image() + + def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData, + classes: Optional[List], + palette: Optional[List]) -> np.ndarray: + """Draw semantic seg of GT or prediction. + + Args: + image (np.ndarray): The image to draw. + sem_seg (:obj:`PixelData`): Data structure for pixel-level + annotations or predictions. + classes (list, optional): Input classes for result rendering, as + the prediction of segmentation model is a segment map with + label indices, `classes` is a list whose items correspond + to the label indices. If classes is not defined, + visualizer will take `cityscapes` classes by default. + Defaults to None. + palette (list, optional): Input palette for result rendering, which + is a list of colors corresponding to the classes. + Defaults to None. + + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + sem_seg_data = sem_seg.sem_seg + if isinstance(sem_seg_data, torch.Tensor): + sem_seg_data = sem_seg_data.numpy() + + # 0 ~ num_class, the value 0 means background + ids = np.unique(sem_seg_data) + ignore_index = sem_seg.metainfo.get('ignore_index', 255) + ids = ids[ids != ignore_index] + + if 'label_names' in sem_seg: + # open set semseg + label_names = sem_seg.metainfo['label_names'] + else: + label_names = classes + + labels = np.array(ids, dtype=np.int64) + colors = [palette[label] for label in labels] + + self.set_image(image) + + # draw semantic masks + for i, (label, color) in enumerate(zip(labels, colors)): + masks = sem_seg_data == label + self.draw_binary_masks(masks, colors=[color], alphas=self.alpha) + label_text = label_names[label] + _, _, stats, centroids = cv2.connectedComponentsWithStats( + masks[0].astype(np.uint8), connectivity=8) + if stats.shape[0] > 1: + largest_id = np.argmax(stats[1:, -1]) + 1 + centroids = centroids[largest_id] + + areas = stats[largest_id, -1] + scales = _get_adaptive_scales(areas) + + self.draw_texts( + label_text, + centroids, + colors=(255, 255, 255), + font_sizes=int(13 * scales), + horizontal_alignments='center', + bboxes=[{ + 'facecolor': 'black', + 'alpha': 0.8, + 'pad': 0.7, + 'edgecolor': 'none' + }]) + + return self.get_image() + + @master_only + def add_datasample( + self, + name: str, + image: np.ndarray, + data_sample: Optional['DetDataSample'] = None, + draw_gt: bool = True, + draw_pred: bool = True, + show: bool = False, + wait_time: float = 0, + # TODO: Supported in mmengine's Visualizer. + out_file: Optional[str] = None, + pred_score_thr: float = 0.3, + step: int = 0) -> None: + """Draw datasample and save to all backends. + + - If GT and prediction are plotted at the same time, they are + displayed in a stitched image where the left image is the + ground truth and the right image is the prediction. + - If ``show`` is True, all storage backends are ignored, and + the images will be displayed in a local window. + - If ``out_file`` is specified, the drawn image will be + saved to ``out_file``. It is usually used when the display + is not available. + + Args: + name (str): The image identifier. + image (np.ndarray): The image to draw. + data_sample (:obj:`DetDataSample`, optional): A data + sample that contains annotations and predictions. + Defaults to None. + draw_gt (bool): Whether to draw GT DetDataSample. Defaults to True. + draw_pred (bool): Whether to draw Prediction DetDataSample. + Defaults to True. + show (bool): Whether to display the drawn image. Defaults to False. + wait_time (float): The interval of show (s).
Defaults to 0. + out_file (str): Path to output file. Defaults to None. + pred_score_thr (float): The threshold to visualize the bboxes + and masks. Defaults to 0.3. + step (int): Global step value to record. Defaults to 0. + """ + image = image.clip(0, 255).astype(np.uint8) + classes = self.dataset_meta.get('classes', None) + palette = self.dataset_meta.get('palette', None) + + gt_img_data = None + pred_img_data = None + + if data_sample is not None: + data_sample = data_sample.cpu() + + if draw_gt and data_sample is not None: + gt_img_data = image + if 'gt_instances' in data_sample: + gt_img_data = self._draw_instances(image, + data_sample.gt_instances, + classes, palette) + if 'gt_sem_seg' in data_sample: + gt_img_data = self._draw_sem_seg(gt_img_data, + data_sample.gt_sem_seg, + classes, palette) + + if 'gt_panoptic_seg' in data_sample: + assert classes is not None, 'class information is ' \ + 'not provided when ' \ + 'visualizing panoptic ' \ + 'segmentation results.' + gt_img_data = self._draw_panoptic_seg( + gt_img_data, data_sample.gt_panoptic_seg, classes, palette) + + if draw_pred and data_sample is not None: + pred_img_data = image + if 'pred_instances' in data_sample: + pred_instances = data_sample.pred_instances + pred_instances = pred_instances[ + pred_instances.scores > pred_score_thr] + pred_img_data = self._draw_instances(image, pred_instances, + classes, palette) + + if 'pred_sem_seg' in data_sample: + pred_img_data = self._draw_sem_seg(pred_img_data, + data_sample.pred_sem_seg, + classes, palette) + + if 'pred_panoptic_seg' in data_sample: + assert classes is not None, 'class information is ' \ + 'not provided when ' \ + 'visualizing panoptic ' \ + 'segmentation results.' + pred_img_data = self._draw_panoptic_seg( + pred_img_data, data_sample.pred_panoptic_seg.numpy(), + classes, palette) + + if gt_img_data is not None and pred_img_data is not None: + drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1) + elif gt_img_data is not None: + drawn_img = gt_img_data + elif pred_img_data is not None: + drawn_img = pred_img_data + else: + # Display the original image directly if nothing is drawn. + drawn_img = image + + # It is convenient for users to obtain the drawn image. + # For example, the user wants to obtain the drawn image and + # save it as a video during video inference. + self.set_image(drawn_img) + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step) + + +def random_color(seed): + """Generate a random color according to the input seed.""" + if sns is None: + raise RuntimeError('seaborn is not installed, ' + 'please install it by: pip install seaborn') + np.random.seed(seed) + colors = sns.color_palette() + color = colors[np.random.choice(range(len(colors)))] + color = tuple([int(255 * c) for c in color]) + return color + + +@VISUALIZERS.register_module() +class TrackLocalVisualizer(Visualizer): + """Tracking Local Visualizer for the MOT, VIS tasks. + + Args: + name (str): Name of the instance. Defaults to 'visualizer'. + image (np.ndarray, optional): the origin image to draw. The format + should be RGB. Defaults to None. + vis_backends (list, optional): Visual backend config list. + Defaults to None. + save_dir (str, optional): Save file dir for all storage backends. + If it is None, the backend storage will not save any data. + line_width (int, float): The linewidth of lines. + Defaults to 3.
+ + alpha (int, float): The transparency of bboxes or mask. + Defaults to 0.8. + """ + + def __init__(self, + name: str = 'visualizer', + image: Optional[np.ndarray] = None, + vis_backends: Optional[Dict] = None, + save_dir: Optional[str] = None, + line_width: Union[int, float] = 3, + alpha: float = 0.8) -> None: + super().__init__(name, image, vis_backends, save_dir) + self.line_width = line_width + self.alpha = alpha + # Set default value. When calling + # `TrackLocalVisualizer().dataset_meta=xxx`, + # it will override the default value. + self.dataset_meta = {} + + def _draw_instances(self, image: np.ndarray, + instances: InstanceData) -> np.ndarray: + """Draw instances of GT or prediction. + + Args: + image (np.ndarray): The image to draw. + instances (:obj:`InstanceData`): Data structure for + instance-level annotations or predictions. + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + self.set_image(image) + classes = self.dataset_meta.get('classes', None) + + # get colors and texts + # for the MOT and VIS tasks + colors = [random_color(_id) for _id in instances.instances_id] + categories = [ + classes[label] if classes is not None else f'cls{label}' + for label in instances.labels + ] + if 'scores' in instances: + texts = [ + f'{category_name}\n{instance_id} | {score:.2f}' + for category_name, instance_id, score in zip( + categories, instances.instances_id, instances.scores) + ] + else: + texts = [ + f'{category_name}\n{instance_id}' for category_name, + instance_id in zip(categories, instances.instances_id) + ] + + # draw bboxes and texts + if 'bboxes' in instances: + # draw bboxes + bboxes = instances.bboxes.clone() + self.draw_bboxes( + bboxes, + edge_colors=colors, + alpha=self.alpha, + line_widths=self.line_width) + # draw texts + if texts is not None: + positions = bboxes[:, :2] + self.line_width + areas = (bboxes[:, 3] - bboxes[:, 1]) * ( + bboxes[:, 2] - bboxes[:, 0]) + scales = _get_adaptive_scales(areas.cpu().numpy()) + for i, pos in enumerate(positions): + self.draw_texts( + texts[i], + pos, + colors='black', + font_sizes=int(13 * scales[i]), + bboxes=[{ + 'facecolor': [c / 255 for c in colors[i]], + 'alpha': 0.8, + 'pad': 0.7, + 'edgecolor': 'none' + }]) + + # draw masks + if 'masks' in instances: + masks = instances.masks + polygons = [] + for i, mask in enumerate(masks): + contours, _ = bitmap_to_polygon(mask) + polygons.extend(contours) + self.draw_polygons(polygons, edge_colors='w', alpha=self.alpha) + self.draw_binary_masks(masks, colors=colors, alphas=self.alpha) + + return self.get_image() + + @master_only + def add_datasample( + self, + name: str, + image: np.ndarray, + data_sample: DetDataSample = None, + draw_gt: bool = True, + draw_pred: bool = True, + show: bool = False, + wait_time: int = 0, + # TODO: Supported in mmengine's Visualizer. + out_file: Optional[str] = None, + pred_score_thr: float = 0.3, + step: int = 0) -> None: + """Draw datasample and save to all backends. + + - If GT and prediction are plotted at the same time, they are + displayed in a stitched image where the left image is the + ground truth and the right image is the prediction. + - If ``show`` is True, all storage backends are ignored, and + the images will be displayed in a local window. + - If ``out_file`` is specified, the drawn image will be + saved to ``out_file``. It is usually used when the display + is not available. + Args: + name (str): The image identifier. + image (np.ndarray): The image to draw.
+ data_sample (OptTrackSampleList): A data + sample that contain annotations and predictions. + Defaults to None. + draw_gt (bool): Whether to draw GT TrackDataSample. + Default to True. + draw_pred (bool): Whether to draw Prediction TrackDataSample. + Defaults to True. + show (bool): Whether to display the drawn image. Default to False. + wait_time (int): The interval of show (s). Defaults to 0. + out_file (str): Path to output file. Defaults to None. + pred_score_thr (float): The threshold to visualize the bboxes + and masks. Defaults to 0.3. + step (int): Global step value to record. Defaults to 0. + """ + gt_img_data = None + pred_img_data = None + + if data_sample is not None: + data_sample = data_sample.cpu() + + if draw_gt and data_sample is not None: + assert 'gt_instances' in data_sample + gt_img_data = self._draw_instances(image, data_sample.gt_instances) + + if draw_pred and data_sample is not None: + assert 'pred_track_instances' in data_sample + pred_instances = data_sample.pred_track_instances + if 'scores' in pred_instances: + pred_instances = pred_instances[ + pred_instances.scores > pred_score_thr].cpu() + pred_img_data = self._draw_instances(image, pred_instances) + + if gt_img_data is not None and pred_img_data is not None: + drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1) + elif gt_img_data is not None: + drawn_img = gt_img_data + else: + drawn_img = pred_img_data + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step) diff --git a/mmdet/visualization/palette.py b/mmdet/visualization/palette.py new file mode 100644 index 0000000000000000000000000000000000000000..3c402c08823a60759c984093ba7f05f1e310dbd9 --- /dev/null +++ b/mmdet/visualization/palette.py @@ -0,0 +1,108 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple, Union + +import mmcv +import numpy as np +from mmengine.utils import is_str + + +def palette_val(palette: List[tuple]) -> List[tuple]: + """Convert palette to matplotlib palette. + + Args: + palette (List[tuple]): A list of color tuples. + + Returns: + List[tuple[float]]: A list of RGB matplotlib color tuples. + """ + new_palette = [] + for color in palette: + color = [c / 255 for c in color] + new_palette.append(tuple(color)) + return new_palette + + +def get_palette(palette: Union[List[tuple], str, tuple], + num_classes: int) -> List[Tuple[int]]: + """Get palette from various inputs. + + Args: + palette (list[tuple] | str | tuple): palette inputs. + num_classes (int): the number of classes. + + Returns: + list[tuple[int]]: A list of color tuples. 
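+ + Example (a sketch for the fixed-color branch): + + >>> get_palette((255, 0, 0), 3) + [(255, 0, 0), (255, 0, 0), (255, 0, 0)]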
+ """ + assert isinstance(num_classes, int) + + if isinstance(palette, list): + dataset_palette = palette + elif isinstance(palette, tuple): + dataset_palette = [palette] * num_classes + elif palette == 'random' or palette is None: + state = np.random.get_state() + # random color + np.random.seed(42) + palette = np.random.randint(0, 256, size=(num_classes, 3)) + np.random.set_state(state) + dataset_palette = [tuple(c) for c in palette] + elif palette == 'coco': + from mmdet.datasets import CocoDataset, CocoPanopticDataset + dataset_palette = CocoDataset.METAINFO['palette'] + if len(dataset_palette) < num_classes: + dataset_palette = CocoPanopticDataset.METAINFO['palette'] + elif palette == 'citys': + from mmdet.datasets import CityscapesDataset + dataset_palette = CityscapesDataset.METAINFO['palette'] + elif palette == 'voc': + from mmdet.datasets import VOCDataset + dataset_palette = VOCDataset.METAINFO['palette'] + elif is_str(palette): + dataset_palette = [mmcv.color_val(palette)[::-1]] * num_classes + else: + raise TypeError(f'Invalid type for palette: {type(palette)}') + + assert len(dataset_palette) >= num_classes, \ + 'The length of palette should not be less than `num_classes`.' + return dataset_palette + + +def _get_adaptive_scales(areas: np.ndarray, + min_area: int = 800, + max_area: int = 30000) -> np.ndarray: + """Get adaptive scales according to areas. + + The scale range is [0.5, 1.0]. When the area is less than + ``min_area``, the scale is 0.5 while the area is larger than + ``max_area``, the scale is 1.0. + + Args: + areas (ndarray): The areas of bboxes or masks with the + shape of (n, ). + min_area (int): Lower bound areas for adaptive scales. + Defaults to 800. + max_area (int): Upper bound areas for adaptive scales. + Defaults to 30000. + + Returns: + ndarray: The adaotive scales with the shape of (n, ). + """ + scales = 0.5 + (areas - min_area) // (max_area - min_area) + scales = np.clip(scales, 0.5, 1.0) + return scales + + +def jitter_color(color: tuple) -> tuple: + """Randomly jitter the given color in order to better distinguish instances + with the same class. + + Args: + color (tuple): The RGB color tuple. Each value is between [0, 255]. + + Returns: + tuple: The jittered color tuple. + """ + jitter = np.random.rand(3) + jitter = (jitter / np.linalg.norm(jitter) - 0.5) * 0.5 * 255 + color = np.clip(jitter + color, 0, 255).astype(np.uint8) + return tuple(color) diff --git a/mmpretrain/.DS_Store b/mmpretrain/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..a17fd18f5a75057aae9614d91ddda49c774c346a Binary files /dev/null and b/mmpretrain/.DS_Store differ diff --git a/mmpretrain/__init__.py b/mmpretrain/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69c585bd26fbf30bb09b28383621a63d7752890f --- /dev/null +++ b/mmpretrain/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import mmengine +from mmengine.utils import digit_version + +from .apis import * # noqa: F401, F403 +from .version import __version__ + +mmcv_minimum_version = '2.0.0' +mmcv_maximum_version = '2.2.0' +mmcv_version = digit_version(mmcv.__version__) + +mmengine_minimum_version = '0.8.3' +mmengine_maximum_version = '1.0.0' +mmengine_version = digit_version(mmengine.__version__) + +assert (mmcv_version >= digit_version(mmcv_minimum_version) + and mmcv_version < digit_version(mmcv_maximum_version)), \ + f'MMCV=={mmcv.__version__} is used but incompatible. 
diff --git a/mmpretrain/.DS_Store b/mmpretrain/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..a17fd18f5a75057aae9614d91ddda49c774c346a Binary files /dev/null and b/mmpretrain/.DS_Store differ diff --git a/mmpretrain/__init__.py b/mmpretrain/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69c585bd26fbf30bb09b28383621a63d7752890f --- /dev/null +++ b/mmpretrain/__init__.py @@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import mmengine
+from mmengine.utils import digit_version
+
+from .apis import *  # noqa: F401, F403
+from .version import __version__
+
+mmcv_minimum_version = '2.0.0'
+mmcv_maximum_version = '2.2.0'
+mmcv_version = digit_version(mmcv.__version__)
+
+mmengine_minimum_version = '0.8.3'
+mmengine_maximum_version = '1.0.0'
+mmengine_version = digit_version(mmengine.__version__)
+
+assert (mmcv_version >= digit_version(mmcv_minimum_version)
+        and mmcv_version < digit_version(mmcv_maximum_version)), \
+    f'MMCV=={mmcv.__version__} is used but incompatible. ' \
+    f'Please install mmcv>={mmcv_minimum_version}, <{mmcv_maximum_version}.'
+
+assert (mmengine_version >= digit_version(mmengine_minimum_version)
+        and mmengine_version < digit_version(mmengine_maximum_version)), \
+    f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
+    f'Please install mmengine>={mmengine_minimum_version}, ' \
+    f'<{mmengine_maximum_version}.'
+
+__all__ = ['__version__'] diff --git a/mmpretrain/__pycache__/__init__.cpython-311.pyc b/mmpretrain/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e438e101154749b4ed5753eb5dceffe47361d980 Binary files /dev/null and b/mmpretrain/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmpretrain/__pycache__/registry.cpython-311.pyc b/mmpretrain/__pycache__/registry.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e58feb14d586851dd1552cbb477896a9b8435a8 Binary files /dev/null and b/mmpretrain/__pycache__/registry.cpython-311.pyc differ diff --git a/mmpretrain/__pycache__/version.cpython-311.pyc b/mmpretrain/__pycache__/version.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c3d9950f15a4281efb5a177269c62509119512a Binary files /dev/null and b/mmpretrain/__pycache__/version.cpython-311.pyc differ diff --git a/mmpretrain/apis/__init__.py b/mmpretrain/apis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6fbf443772a983c41f7273124f843bdfbb7f0f46 --- /dev/null +++ b/mmpretrain/apis/__init__.py @@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseInferencer
+from .feature_extractor import FeatureExtractor
+from .image_caption import ImageCaptionInferencer
+from .image_classification import ImageClassificationInferencer
+from .image_retrieval import ImageRetrievalInferencer
+from .model import (ModelHub, get_model, inference_model, init_model,
+                    list_models)
+from .multimodal_retrieval import (ImageToTextRetrievalInferencer,
+                                   TextToImageRetrievalInferencer)
+from .nlvr import NLVRInferencer
+from .visual_grounding import VisualGroundingInferencer
+from .visual_question_answering import VisualQuestionAnsweringInferencer
+
+__all__ = [
+    'init_model', 'inference_model', 'list_models', 'get_model', 'ModelHub',
+    'ImageClassificationInferencer', 'ImageRetrievalInferencer',
+    'FeatureExtractor', 'ImageCaptionInferencer',
+    'TextToImageRetrievalInferencer', 'VisualGroundingInferencer',
+    'VisualQuestionAnsweringInferencer', 'ImageToTextRetrievalInferencer',
+    'BaseInferencer', 'NLVRInferencer'
+] diff --git a/mmpretrain/apis/__pycache__/__init__.cpython-311.pyc b/mmpretrain/apis/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2e9f336d0cfc8e15ab968f6fdf5a753ea9304a5 Binary files /dev/null and b/mmpretrain/apis/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmpretrain/apis/__pycache__/base.cpython-311.pyc b/mmpretrain/apis/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..724982937f4b2aa3ebd5ad442fbb53fd524cd044 Binary files /dev/null and b/mmpretrain/apis/__pycache__/base.cpython-311.pyc differ diff --git a/mmpretrain/apis/__pycache__/feature_extractor.cpython-311.pyc b/mmpretrain/apis/__pycache__/feature_extractor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3280e49e84f1cc32ba4d16f6a8087e1dd4970f1 Binary files /dev/null and 
b/mmpretrain/apis/__pycache__/feature_extractor.cpython-311.pyc differ diff --git a/mmpretrain/apis/__pycache__/image_caption.cpython-311.pyc b/mmpretrain/apis/__pycache__/image_caption.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24dc8f221bab2230490f90b1675a7cc429cf25c0 Binary files /dev/null and b/mmpretrain/apis/__pycache__/image_caption.cpython-311.pyc differ diff --git a/mmpretrain/apis/__pycache__/image_classification.cpython-311.pyc b/mmpretrain/apis/__pycache__/image_classification.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..256666ebda9b53b25e75ef184f691435853eb52d Binary files /dev/null and b/mmpretrain/apis/__pycache__/image_classification.cpython-311.pyc differ diff --git a/mmpretrain/apis/__pycache__/image_retrieval.cpython-311.pyc b/mmpretrain/apis/__pycache__/image_retrieval.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..815c5b324bd22650358daa77f4c0231dce62f3f3 Binary files /dev/null and b/mmpretrain/apis/__pycache__/image_retrieval.cpython-311.pyc differ diff --git a/mmpretrain/apis/__pycache__/model.cpython-311.pyc b/mmpretrain/apis/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a691fba803de028039d672fdc66fec2e5b9e650 Binary files /dev/null and b/mmpretrain/apis/__pycache__/model.cpython-311.pyc differ diff --git a/mmpretrain/apis/__pycache__/multimodal_retrieval.cpython-311.pyc b/mmpretrain/apis/__pycache__/multimodal_retrieval.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f027ef726a118db4d2dbb8e791989aeba2d40379 Binary files /dev/null and b/mmpretrain/apis/__pycache__/multimodal_retrieval.cpython-311.pyc differ diff --git a/mmpretrain/apis/__pycache__/nlvr.cpython-311.pyc b/mmpretrain/apis/__pycache__/nlvr.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72de52075a5dc3b16487e533645aa9eecd5cc774 Binary files /dev/null and b/mmpretrain/apis/__pycache__/nlvr.cpython-311.pyc differ diff --git a/mmpretrain/apis/__pycache__/visual_grounding.cpython-311.pyc b/mmpretrain/apis/__pycache__/visual_grounding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..071cda79ea0dc7deac8467280e284ef10f204d8d Binary files /dev/null and b/mmpretrain/apis/__pycache__/visual_grounding.cpython-311.pyc differ diff --git a/mmpretrain/apis/__pycache__/visual_question_answering.cpython-311.pyc b/mmpretrain/apis/__pycache__/visual_question_answering.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3283632210b7ed3d4867e74bc01a69d6442840a7 Binary files /dev/null and b/mmpretrain/apis/__pycache__/visual_question_answering.cpython-311.pyc differ diff --git a/mmpretrain/apis/base.py b/mmpretrain/apis/base.py new file mode 100644 index 0000000000000000000000000000000000000000..7bff6bd18675a3a0996dcd09081a15728311657f --- /dev/null +++ b/mmpretrain/apis/base.py @@ -0,0 +1,390 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from abc import abstractmethod +from math import ceil +from typing import Callable, Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch +from mmengine.config import Config +from mmengine.dataset import default_collate +from mmengine.fileio import get_file_backend +from mmengine.model import BaseModel +from mmengine.runner import load_checkpoint + +from mmpretrain.structures import DataSample +from mmpretrain.utils import track +from .model import get_model, list_models + +ModelType = Union[BaseModel, str, Config] +InputType = Union[str, np.ndarray, list] + + +class BaseInferencer: + """Base inferencer for various tasks. + + The BaseInferencer provides the standard workflow for inference as follows: + + 1. Preprocess the input data by :meth:`preprocess`. + 2. Forward the data to the model by :meth:`forward`. ``BaseInferencer`` + assumes the model inherits from :class:`mmengine.models.BaseModel` and + will call `model.test_step` in :meth:`forward` by default. + 3. Visualize the results by :meth:`visualize`. + 4. Postprocess and return the results by :meth:`postprocess`. + + When we call the subclasses inherited from BaseInferencer (not overriding + ``__call__``), the workflow will be executed in order. + + All subclasses of BaseInferencer could define the following class + attributes for customization: + + - ``preprocess_kwargs``: The keys of the kwargs that will be passed to + :meth:`preprocess`. + - ``forward_kwargs``: The keys of the kwargs that will be passed to + :meth:`forward` + - ``visualize_kwargs``: The keys of the kwargs that will be passed to + :meth:`visualize` + - ``postprocess_kwargs``: The keys of the kwargs that will be passed to + :meth:`postprocess` + + All attributes mentioned above should be a ``set`` of keys (strings), + and each key should not be duplicated. Actually, :meth:`__call__` will + dispatch all the arguments to the corresponding methods according to the + ``xxx_kwargs`` mentioned above. + + Subclasses inherited from ``BaseInferencer`` should implement + :meth:`_init_pipeline`, :meth:`visualize` and :meth:`postprocess`: + + - _init_pipeline: Return a callable object to preprocess the input data. + - visualize: Visualize the results returned by :meth:`forward`. + - postprocess: Postprocess the results returned by :meth:`forward` and + :meth:`visualize`. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``cls.list_models()`` and you can also query it in + :doc:`/modelzoo_statistics`. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str | torch.device | None): Transfer the model to the target + device. Defaults to None. + device_map (str | dict | None): A map that specifies where each + submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every + submodule of it will be sent to the same device. You can use + `device_map="auto"` to automatically generate the device map. + Defaults to None. + offload_folder (str | None): If the `device_map` contains any value + `"disk"`, the folder where we will offload weights. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). 
+ """ + + preprocess_kwargs: set = set() + forward_kwargs: set = set() + visualize_kwargs: set = set() + postprocess_kwargs: set = set() + + def __init__(self, + model: ModelType, + pretrained: Union[bool, str] = True, + device: Union[str, torch.device, None] = None, + device_map=None, + offload_folder=None, + **kwargs) -> None: + + if isinstance(model, BaseModel): + if isinstance(pretrained, str): + load_checkpoint(model, pretrained, map_location='cpu') + if device_map is not None: + from .utils import dispatch_model + model = dispatch_model( + model, + device_map=device_map, + offload_folder=offload_folder) + elif device is not None: + model.to(device) + else: + model = get_model( + model, + pretrained, + device=device, + device_map=device_map, + offload_folder=offload_folder, + **kwargs) + + model.eval() + + self.config = model._config + self.model = model + self.pipeline = self._init_pipeline(self.config) + self.visualizer = None + + def __call__( + self, + inputs, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs, + ) -> dict: + """Call the inferencer. + + Args: + inputs (InputsType): Inputs for the inferencer. + return_datasamples (bool): Whether to return results as + :obj:`BaseDataElement`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + **kwargs: Key words arguments passed to :meth:`preprocess`, + :meth:`forward`, :meth:`visualize` and :meth:`postprocess`. + Each key in kwargs should be in the corresponding set of + ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` + and ``postprocess_kwargs``. + + Returns: + dict: Inference and visualization results. + """ + ( + preprocess_kwargs, + forward_kwargs, + visualize_kwargs, + postprocess_kwargs, + ) = self._dispatch_kwargs(**kwargs) + + ori_inputs = self._inputs_to_list(inputs) + inputs = self.preprocess( + ori_inputs, batch_size=batch_size, **preprocess_kwargs) + preds = [] + for data in track( + inputs, 'Inference', total=ceil(len(ori_inputs) / batch_size)): + preds.extend(self.forward(data, **forward_kwargs)) + visualization = self.visualize(ori_inputs, preds, **visualize_kwargs) + results = self.postprocess(preds, visualization, return_datasamples, + **postprocess_kwargs) + return results + + def _inputs_to_list(self, inputs: InputType) -> list: + """Preprocess the inputs to a list. + + Cast the input data to a list of data. + + - list or tuple: return inputs + - str: + - Directory path: return all files in the directory + - other cases: return a list containing the string. The string + could be a path to file, a url or other types of string according + to the task. + - other: return a list with one item. + + Args: + inputs (str | array | list): Inputs for the inferencer. + + Returns: + list: List of input for the :meth:`preprocess`. + """ + if isinstance(inputs, str): + backend = get_file_backend(inputs) + if hasattr(backend, 'isdir') and backend.isdir(inputs): + # Backends like HttpsBackend do not implement `isdir`, so only + # those backends that implement `isdir` could accept the inputs + # as a directory + file_list = backend.list_dir_or_file(inputs, list_dir=False) + inputs = [ + backend.join_path(inputs, file) for file in file_list + ] + + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + + return list(inputs) + + def preprocess(self, inputs: InputType, batch_size: int = 1, **kwargs): + """Process the inputs into a model-feedable format. + + Customize your preprocess by overriding this method. 
Preprocess should + return an iterable object, of which each item will be used as the + input of ``model.test_step``. + + ``BaseInferencer.preprocess`` will return an iterable chunked data, + which will be used in __call__ like this: + + .. code-block:: python + + def __call__(self, inputs, batch_size=1, **kwargs): + chunked_data = self.preprocess(inputs, batch_size, **kwargs) + for batch in chunked_data: + preds = self.forward(batch, **kwargs) + + Args: + inputs (InputsType): Inputs given by user. + batch_size (int): batch size. Defaults to 1. + + Yields: + Any: Data processed by the ``pipeline`` and ``default_collate``. + """ + chunked_data = self._get_chunk_data( + map(self.pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + @torch.no_grad() + def forward(self, inputs: Union[dict, tuple], **kwargs): + """Feed the inputs to the model.""" + return self.model.test_step(inputs) + + def visualize(self, + inputs: list, + preds: List[DataSample], + show: bool = False, + **kwargs) -> List[np.ndarray]: + """Visualize predictions. + + Customize your visualization by overriding this method. visualize + should return visualization results, which could be np.ndarray or any + other objects. + + Args: + inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`. + preds (Any): Predictions of the model. + show (bool): Whether to display the image in a popup window. + Defaults to False. + + Returns: + List[np.ndarray]: Visualization results. + """ + if show: + raise NotImplementedError( + f'The `visualize` method of {self.__class__.__name__} ' + 'is not implemented.') + + @abstractmethod + def postprocess( + self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasample=False, + **kwargs, + ) -> dict: + """Process the predictions and visualization results from ``forward`` + and ``visualize``. + + This method should be responsible for the following tasks: + + 1. Convert datasamples into a json-serializable dict if needed. + 2. Pack the predictions and visualization results and return them. + 3. Dump or log the predictions. + + Customize your postprocess by overriding this method. Make sure + ``postprocess`` will return a dict with visualization results and + inference results. + + Args: + preds (List[Dict]): Predictions of the model. + visualization (np.ndarray): Visualized predictions. + return_datasample (bool): Whether to return results as datasamples. + Defaults to False. + + Returns: + dict: Inference and visualization results with key ``predictions`` + and ``visualization`` + + - ``visualization (Any)``: Returned by :meth:`visualize` + - ``predictions`` (dict or DataSample): Returned by + :meth:`forward` and processed in :meth:`postprocess`. + If ``return_datasample=False``, it usually should be a + json-serializable dict containing only basic data elements such + as strings and numbers. + """ + + @abstractmethod + def _init_pipeline(self, cfg: Config) -> Callable: + """Initialize the test pipeline. + + Return a pipeline to handle various input data, such as ``str``, + ``np.ndarray``. It is an abstract method in BaseInferencer, and should + be implemented in subclasses. + + The returned pipeline will be used to process a single data. + It will be used in :meth:`preprocess` like this: + + .. code-block:: python + def preprocess(self, inputs, batch_size, **kwargs): + ... + dataset = map(self.pipeline, dataset) + ... + """ + + def _get_chunk_data(self, inputs: Iterable, chunk_size: int): + """Get batch data from dataset. 
+
+        Args:
+            inputs (Iterable): An iterable dataset.
+            chunk_size (int): Equivalent to batch size.
+
+        Yields:
+            list: batch data.
+        """
+        inputs_iter = iter(inputs)
+        while True:
+            try:
+                chunk_data = []
+                for _ in range(chunk_size):
+                    processed_data = next(inputs_iter)
+                    chunk_data.append(processed_data)
+                yield chunk_data
+            except StopIteration:
+                if chunk_data:
+                    yield chunk_data
+                break
+
+    def _dispatch_kwargs(self, **kwargs) -> Tuple[dict, dict, dict, dict]:
+        """Dispatch kwargs to preprocess(), forward(), visualize() and
+        postprocess() according to the actual demands.
+
+        Returns:
+            Tuple[Dict, Dict, Dict, Dict]: kwargs passed to preprocess,
+            forward, visualize and postprocess respectively.
+        """
+        # Ensure each argument only matches one function
+        method_kwargs = self.preprocess_kwargs | self.forward_kwargs | \
+            self.visualize_kwargs | self.postprocess_kwargs
+
+        union_kwargs = method_kwargs | set(kwargs.keys())
+        if union_kwargs != method_kwargs:
+            unknown_kwargs = union_kwargs - method_kwargs
+            raise ValueError(
+                f'unknown argument {unknown_kwargs} for `preprocess`, '
+                '`forward`, `visualize` and `postprocess`')
+
+        preprocess_kwargs = {}
+        forward_kwargs = {}
+        visualize_kwargs = {}
+        postprocess_kwargs = {}
+
+        for key, value in kwargs.items():
+            if key in self.preprocess_kwargs:
+                preprocess_kwargs[key] = value
+            if key in self.forward_kwargs:
+                forward_kwargs[key] = value
+            if key in self.visualize_kwargs:
+                visualize_kwargs[key] = value
+            if key in self.postprocess_kwargs:
+                postprocess_kwargs[key] = value
+
+        return (
+            preprocess_kwargs,
+            forward_kwargs,
+            visualize_kwargs,
+            postprocess_kwargs,
+        )
+
+    @staticmethod
+    def list_models(pattern: Optional[str] = None):
+        """List models defined in metafile of corresponding packages.
+
+        Args:
+            pattern (str | None): A wildcard pattern to match model names.
+
+        Returns:
+            List[str]: a list of model names.
+        """
+        return list_models(pattern=pattern)
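To make the subclass contract above concrete, here is a minimal hypothetical sketch (names such as `MinimalInferencer` and `round_to` are illustrative and not part of the patch): a subclass only has to provide `_init_pipeline` and `postprocess`, and declare which call-time kwargs each stage consumes so that `_dispatch_kwargs` can route them.

```python
from mmengine.config import Config

from mmpretrain.apis import BaseInferencer


class MinimalInferencer(BaseInferencer):
    """Hypothetical subclass illustrating the BaseInferencer contract."""

    # `round_to` is routed to postprocess() by _dispatch_kwargs; any
    # keyword outside the four declared sets raises ValueError at call time.
    postprocess_kwargs = {'round_to'}

    def _init_pipeline(self, cfg: Config):
        # A real subclass builds a transform pipeline from the config;
        # an identity function keeps this sketch self-contained.
        return lambda data: data

    def postprocess(self, preds, visualization, return_datasamples=False,
                    round_to=4):
        if return_datasamples:
            return preds
        # Assumes the predicted data samples carry a `pred_score` tensor.
        return [{'pred_score': round(float(p.pred_score.max()), round_to)}
                for p in preds]
```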
diff --git a/mmpretrain/apis/feature_extractor.py b/mmpretrain/apis/feature_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..ee14f92f489497dd036fe0567786a94207924d4a --- /dev/null +++ b/mmpretrain/apis/feature_extractor.py @@ -0,0 +1,130 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Callable, List, Optional, Union
+
+import torch
+from mmcv.image import imread
+from mmengine.config import Config
+from mmengine.dataset import Compose, default_collate
+
+from mmpretrain.registry import TRANSFORMS
+from .base import BaseInferencer, InputType
+from .model import list_models
+
+
+class FeatureExtractor(BaseInferencer):
+    """The inferencer for extracting features.
+
+    Args:
+        model (BaseModel | str | Config): A model name or a path to the config
+            file, or a :obj:`BaseModel` object. The model name can be found
+            by ``FeatureExtractor.list_models()`` and you can also query it in
+            :doc:`/modelzoo_statistics`.
+        pretrained (str, optional): Path to the checkpoint. If None, it will
+            try to find a pre-defined weight from the model you specified
+            (only work if the ``model`` is a model name). Defaults to None.
+        device (str, optional): Device to run inference. If None, the available
+            device will be automatically used. Defaults to None.
+        **kwargs: Other keyword arguments to initialize the model (only work if
+            the ``model`` is a model name).
+
+    Example:
+        >>> from mmpretrain import FeatureExtractor
+        >>> inferencer = FeatureExtractor('resnet50_8xb32_in1k', backbone=dict(out_indices=(0, 1, 2, 3)))
+        >>> feats = inferencer('demo/demo.JPEG', stage='backbone')[0]
+        >>> for feat in feats:
+        ...     print(feat.shape)
+        torch.Size([256, 56, 56])
+        torch.Size([512, 28, 28])
+        torch.Size([1024, 14, 14])
+        torch.Size([2048, 7, 7])
+    """  # noqa: E501
+
+    def __call__(self,
+                 inputs: InputType,
+                 batch_size: int = 1,
+                 **kwargs) -> dict:
+        """Call the inferencer.
+
+        Args:
+            inputs (str | array | list): The image path or array, or a list of
+                images.
+            batch_size (int): Batch size. Defaults to 1.
+            **kwargs: Other keyword arguments accepted by the `extract_feat`
+                method of the model.
+
+        Returns:
+            tensor | Tuple[tensor]: The extracted features.
+        """
+        ori_inputs = self._inputs_to_list(inputs)
+        inputs = self.preprocess(ori_inputs, batch_size=batch_size)
+        preds = []
+        for data in inputs:
+            preds.extend(self.forward(data, **kwargs))
+
+        return preds
+
+    @torch.no_grad()
+    def forward(self, inputs: Union[dict, tuple], **kwargs):
+        inputs = self.model.data_preprocessor(inputs, False)['inputs']
+        outputs = self.model.extract_feat(inputs, **kwargs)
+
+        def scatter(feats, index):
+            if isinstance(feats, torch.Tensor):
+                return feats[index]
+            else:
+                # Sequence of tensor
+                return type(feats)([scatter(item, index) for item in feats])
+
+        results = []
+        for i in range(inputs.shape[0]):
+            results.append(scatter(outputs, i))
+
+        return results
+
+    def _init_pipeline(self, cfg: Config) -> Callable:
+        test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+        from mmpretrain.datasets import remove_transform
+
+        # Image loading is finished in `self.preprocess`.
+        test_pipeline_cfg = remove_transform(test_pipeline_cfg,
+                                             'LoadImageFromFile')
+        test_pipeline = Compose(
+            [TRANSFORMS.build(t) for t in test_pipeline_cfg])
+        return test_pipeline
+
+    def preprocess(self, inputs: List[InputType], batch_size: int = 1):
+
+        def load_image(input_):
+            img = imread(input_)
+            if img is None:
+                raise ValueError(f'Failed to read image {input_}.')
+            return dict(
+                img=img,
+                img_shape=img.shape[:2],
+                ori_shape=img.shape[:2],
+            )
+
+        pipeline = Compose([load_image, self.pipeline])
+
+        chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
+        yield from map(default_collate, chunked_data)
+
+    def visualize(self):
+        raise NotImplementedError(
+            "The FeatureExtractor doesn't support visualization.")
+
+    def postprocess(self):
+        raise NotImplementedError(
+            "The FeatureExtractor doesn't need postprocessing.")
+
+    @staticmethod
+    def list_models(pattern: Optional[str] = None):
+        """List all available model names.
+
+        Args:
+            pattern (str | None): A wildcard pattern to match model names.
+
+        Returns:
+            List[str]: a list of model names.
+        """
+        return list_models(pattern=pattern) diff --git a/mmpretrain/apis/image_caption.py b/mmpretrain/apis/image_caption.py new file mode 100644 index 0000000000000000000000000000000000000000..c11c0d3044d9924aba159782309d2cc20f1745bc --- /dev/null +++ b/mmpretrain/apis/image_caption.py @@ -0,0 +1,166 @@
+# Copyright (c) OpenMMLab. All rights reserved. 
+from pathlib import Path +from typing import Callable, List, Optional + +import numpy as np +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample +from .base import BaseInferencer, InputType +from .model import list_models + + +class ImageCaptionInferencer(BaseInferencer): + """The inferencer for image caption. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``ImageCaptionInferencer.list_models()`` and you can also + query it in :doc:`/modelzoo_statistics`. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + + Example: + >>> from mmpretrain import ImageCaptionInferencer + >>> inferencer = ImageCaptionInferencer('blip-base_3rdparty_caption') + >>> inferencer('demo/cat-dog.png')[0] + {'pred_caption': 'a puppy and a cat sitting on a blanket'} + """ # noqa: E501 + + visualize_kwargs: set = {'resize', 'show', 'show_dir', 'wait_time'} + + def __call__(self, + images: InputType, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + images (str | array | list): The image path or array, or a list of + images. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the prediction scores + of prediction categories. Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. + """ + return super().__call__(images, return_datasamples, batch_size, + **kwargs) + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + from mmpretrain.datasets import remove_transform + + # Image loading is finished in `self.preprocess`. 
+ test_pipeline_cfg = remove_transform(test_pipeline_cfg, + 'LoadImageFromFile') + test_pipeline = Compose( + [TRANSFORMS.build(t) for t in test_pipeline_cfg]) + return test_pipeline + + def preprocess(self, inputs: List[InputType], batch_size: int = 1): + + def load_image(input_): + img = imread(input_) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return dict( + img=img, + img_shape=img.shape[:2], + ori_shape=img.shape[:2], + ) + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[InputType], + preds: List[DataSample], + show: bool = False, + wait_time: int = 0, + resize: Optional[int] = None, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_) + if isinstance(input_, str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_image_caption( + image, + data_sample, + resize=resize, + show=show, + wait_time=wait_time, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess(self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + results.append({'pred_caption': data_sample.get('pred_caption')}) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='Image Caption') diff --git a/mmpretrain/apis/image_classification.py b/mmpretrain/apis/image_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..a20218071c7afc90c6a46d61b5ed3a8fee5bc012 --- /dev/null +++ b/mmpretrain/apis/image_classification.py @@ -0,0 +1,223 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample +from .base import BaseInferencer, InputType, ModelType +from .model import list_models + + +class ImageClassificationInferencer(BaseInferencer): + """The inferencer for image classification. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``ImageClassificationInferencer.list_models()`` and you can also + query it in :doc:`/modelzoo_statistics`. + pretrained (str, optional): Path to the checkpoint. 
If None, it will
+            try to find a pre-defined weight from the model you specified
+            (only work if the ``model`` is a model name). Defaults to None.
+        device (str, optional): Device to run inference. If None, the available
+            device will be automatically used. Defaults to None.
+        **kwargs: Other keyword arguments to initialize the model (only work if
+            the ``model`` is a model name).
+
+    Example:
+        1. Use a pre-trained model in MMPreTrain to run inference on an image.
+
+        >>> from mmpretrain import ImageClassificationInferencer
+        >>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k')
+        >>> inferencer('demo/demo.JPEG')
+        [{'pred_scores': array([...]),
+          'pred_label': 65,
+          'pred_score': 0.6649367809295654,
+          'pred_class': 'sea snake'}]
+
+        2. Use a config file and checkpoint to run inference on multiple images
+        on GPU, and save the visualization results in a folder.
+
+        >>> from mmpretrain import ImageClassificationInferencer
+        >>> inferencer = ImageClassificationInferencer(
+                model='configs/resnet/resnet50_8xb32_in1k.py',
+                pretrained='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
+                device='cuda')
+        >>> inferencer(['demo/dog.jpg', 'demo/bird.JPEG'], show_dir="./visualize/")
+    """  # noqa: E501
+
+    visualize_kwargs: set = {
+        'resize', 'rescale_factor', 'draw_score', 'show', 'show_dir',
+        'wait_time'
+    }
+
+    def __init__(self,
+                 model: ModelType,
+                 pretrained: Union[bool, str] = True,
+                 device: Union[str, torch.device, None] = None,
+                 classes=None,
+                 **kwargs) -> None:
+        super().__init__(
+            model=model, pretrained=pretrained, device=device, **kwargs)
+
+        if classes is not None:
+            self.classes = classes
+        else:
+            self.classes = getattr(self.model, '_dataset_meta',
+                                   {}).get('classes')
+
+    def __call__(self,
+                 inputs: InputType,
+                 return_datasamples: bool = False,
+                 batch_size: int = 1,
+                 **kwargs) -> dict:
+        """Call the inferencer.
+
+        Args:
+            inputs (str | array | list): The image path or array, or a list of
+                images.
+            return_datasamples (bool): Whether to return results as
+                :obj:`DataSample`. Defaults to False.
+            batch_size (int): Batch size. Defaults to 1.
+            resize (int, optional): Resize the short edge of the image to the
+                specified length before visualization. Defaults to None.
+            rescale_factor (float, optional): Rescale the image by the rescale
+                factor for visualization. This is helpful when the image is too
+                large or too small for visualization. Defaults to None.
+            draw_score (bool): Whether to draw the prediction scores
+                of prediction categories. Defaults to True.
+            show (bool): Whether to display the visualization result in a
+                window. Defaults to False.
+            wait_time (float): The display time (s). Defaults to 0, which means
+                "forever".
+            show_dir (str, optional): If not None, save the visualization
+                results in the specified directory. Defaults to None.
+
+        Returns:
+            list: The inference results.
+        """
+        return super().__call__(
+            inputs,
+            return_datasamples=return_datasamples,
+            batch_size=batch_size,
+            **kwargs)
+
+    def _init_pipeline(self, cfg: Config) -> Callable:
+        test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+        from mmpretrain.datasets import remove_transform
+
+        # Image loading is finished in `self.preprocess`. 
+ test_pipeline_cfg = remove_transform(test_pipeline_cfg, + 'LoadImageFromFile') + test_pipeline = Compose( + [TRANSFORMS.build(t) for t in test_pipeline_cfg]) + return test_pipeline + + def preprocess(self, inputs: List[InputType], batch_size: int = 1): + + def load_image(input_): + img = imread(input_) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return dict( + img=img, + img_shape=img.shape[:2], + ori_shape=img.shape[:2], + ) + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[InputType], + preds: List[DataSample], + show: bool = False, + wait_time: int = 0, + resize: Optional[int] = None, + rescale_factor: Optional[float] = None, + draw_score=True, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_) + if isinstance(input_, str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_cls( + image, + data_sample, + classes=self.classes, + resize=resize, + show=show, + wait_time=wait_time, + rescale_factor=rescale_factor, + draw_gt=False, + draw_pred=True, + draw_score=draw_score, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess(self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + pred_scores = data_sample.pred_score + pred_score = float(torch.max(pred_scores).item()) + pred_label = torch.argmax(pred_scores).item() + result = { + 'pred_scores': pred_scores.detach().cpu().numpy(), + 'pred_label': pred_label, + 'pred_score': pred_score, + } + if self.classes is not None: + result['pred_class'] = self.classes[pred_label] + results.append(result) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='Image Classification') diff --git a/mmpretrain/apis/image_retrieval.py b/mmpretrain/apis/image_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..27919b20f58afe603fb23d9aeb2fc37326683286 --- /dev/null +++ b/mmpretrain/apis/image_retrieval.py @@ -0,0 +1,288 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from pathlib import Path
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+from mmcv.image import imread
+from mmengine.config import Config
+from mmengine.dataset import BaseDataset, Compose, default_collate
+
+from mmpretrain.registry import TRANSFORMS
+from mmpretrain.structures import DataSample
+from .base import BaseInferencer, InputType, ModelType
+from .model import list_models
+
+
+class ImageRetrievalInferencer(BaseInferencer):
+    """The inferencer for image to image retrieval.
+
+    Args:
+        model (BaseModel | str | Config): A model name or a path to the config
+            file, or a :obj:`BaseModel` object. The model name can be found
+            by ``ImageRetrievalInferencer.list_models()`` and you can also
+            query it in :doc:`/modelzoo_statistics`.
+        prototype (str | list | dict | DataLoader | BaseDataset): The images to
+            be retrieved. It can be the following types:
+
+            - str: The directory of the images.
+            - list: A list of paths of the images.
+            - dict: A config dict of a prototype dataset.
+            - BaseDataset: A prototype dataset.
+            - DataLoader: A data loader to load the prototype data.
+
+        prototype_cache (str, optional): The path of the generated prototype
+            features. If it exists, the cache is loaded directly instead of
+            re-generating the prototype features. If it does not exist, the
+            generated features are saved to the path. Defaults to None.
+        pretrained (str, optional): Path to the checkpoint. If None, it will
+            try to find a pre-defined weight from the model you specified
+            (only work if the ``model`` is a model name). Defaults to None.
+        device (str, optional): Device to run inference. If None, the available
+            device will be automatically used. Defaults to None.
+        **kwargs: Other keyword arguments to initialize the model (only work if
+            the ``model`` is a model name).
+
+    Example:
+        >>> from mmpretrain import ImageRetrievalInferencer
+        >>> inferencer = ImageRetrievalInferencer(
+        ...     'resnet50-arcface_inshop',
+        ...     prototype='./demo/',
+        ...     
prototype_cache='img_retri.pth') + >>> inferencer('demo/cat-dog.png', topk=2)[0][1] + {'match_score': tensor(0.4088, device='cuda:0'), + 'sample_idx': 3, + 'sample': {'img_path': './demo/dog.jpg'}} + """ # noqa: E501 + + visualize_kwargs: set = { + 'draw_score', 'resize', 'show_dir', 'show', 'wait_time', 'topk' + } + postprocess_kwargs: set = {'topk'} + + def __init__( + self, + model: ModelType, + prototype, + prototype_cache=None, + prepare_batch_size=8, + pretrained: Union[bool, str] = True, + device: Union[str, torch.device, None] = None, + **kwargs, + ) -> None: + super().__init__( + model=model, pretrained=pretrained, device=device, **kwargs) + + self.prototype_dataset = self._prepare_prototype( + prototype, prototype_cache, prepare_batch_size) + + def _prepare_prototype(self, prototype, cache=None, batch_size=8): + from mmengine.dataset import DefaultSampler + from torch.utils.data import DataLoader + + def build_dataloader(dataset): + return DataLoader( + dataset, + batch_size=batch_size, + collate_fn=default_collate, + sampler=DefaultSampler(dataset, shuffle=False), + persistent_workers=False, + ) + + if isinstance(prototype, str): + # A directory path of images + prototype = dict( + type='CustomDataset', with_label=False, data_root=prototype) + + if isinstance(prototype, list): + test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline] + dataset = BaseDataset( + lazy_init=True, serialize_data=False, pipeline=test_pipeline) + dataset.data_list = [{ + 'sample_idx': i, + 'img_path': file + } for i, file in enumerate(prototype)] + dataset._fully_initialized = True + dataloader = build_dataloader(dataset) + elif isinstance(prototype, dict): + # A config of dataset + from mmpretrain.registry import DATASETS + test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline] + prototype.setdefault('pipeline', test_pipeline) + dataset = DATASETS.build(prototype) + dataloader = build_dataloader(dataset) + elif isinstance(prototype, DataLoader): + dataset = prototype.dataset + dataloader = prototype + elif isinstance(prototype, BaseDataset): + dataset = prototype + dataloader = build_dataloader(dataset) + else: + raise TypeError(f'Unsupported prototype type {type(prototype)}.') + + if cache is not None and Path(cache).exists(): + self.model.prototype = cache + else: + self.model.prototype = dataloader + self.model.prepare_prototype() + + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + if cache is None: + logger.info('The prototype has been prepared, you can use ' + '`save_prototype` to dump it into a pickle ' + 'file for the future usage.') + elif not Path(cache).exists(): + self.save_prototype(cache) + logger.info(f'The prototype has been saved at {cache}.') + + return dataset + + def save_prototype(self, path): + self.model.dump_prototype(path) + + def __call__(self, + inputs: InputType, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (str | array | list): The image path or array, or a list of + images. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the match scores. + Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. 
+ wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. + """ + return super().__call__(inputs, return_datasamples, batch_size, + **kwargs) + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + from mmpretrain.datasets import remove_transform + + # Image loading is finished in `self.preprocess`. + test_pipeline_cfg = remove_transform(test_pipeline_cfg, + 'LoadImageFromFile') + test_pipeline = Compose( + [TRANSFORMS.build(t) for t in test_pipeline_cfg]) + return test_pipeline + + def preprocess(self, inputs: List[InputType], batch_size: int = 1): + + def load_image(input_): + img = imread(input_) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return dict( + img=img, + img_shape=img.shape[:2], + ori_shape=img.shape[:2], + ) + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[InputType], + preds: List[DataSample], + topk: int = 3, + resize: Optional[int] = 224, + show: bool = False, + wait_time: int = 0, + draw_score=True, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_) + if isinstance(input_, str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_image_retrieval( + image, + data_sample, + self.prototype_dataset, + topk=topk, + resize=resize, + draw_score=draw_score, + show=show, + wait_time=wait_time, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess( + self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False, + topk=1, + ) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + match_scores, indices = torch.topk(data_sample.pred_score, k=topk) + matches = [] + for match_score, sample_idx in zip(match_scores, indices): + sample = self.prototype_dataset.get_data_info( + sample_idx.item()) + sample_idx = sample.pop('sample_idx') + matches.append({ + 'match_score': match_score, + 'sample_idx': sample_idx, + 'sample': sample + }) + results.append(matches) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='Image Retrieval') diff --git a/mmpretrain/apis/model.py b/mmpretrain/apis/model.py new file mode 100644 index 0000000000000000000000000000000000000000..eba475e7f791f42eb9aec384afec947f72722f27 --- /dev/null +++ b/mmpretrain/apis/model.py @@ -0,0 +1,408 @@ +# Copyright (c) OpenMMLab. 
All rights reserved.
+import copy
+import fnmatch
+import os.path as osp
+import re
+import warnings
+from os import PathLike
+from pathlib import Path
+from typing import List, Tuple, Union
+
+from mmengine.config import Config
+from modelindex.load_model_index import load
+from modelindex.models.Model import Model
+
+
+class ModelHub:
+    """A hub to host the meta information of all pre-defined models."""
+    _models_dict = {}
+    __mmpretrain_registered = False
+
+    @classmethod
+    def register_model_index(cls,
+                             model_index_path: Union[str, PathLike],
+                             config_prefix: Union[str, PathLike, None] = None):
+        """Parse the model-index file and register all models.
+
+        Args:
+            model_index_path (str | PathLike): The path of the model-index
+                file.
+            config_prefix (str | PathLike | None): The prefix of all config
+                file paths in the model-index file.
+        """
+        model_index = load(str(model_index_path))
+        model_index.build_models_with_collections()
+
+        for metainfo in model_index.models:
+            model_name = metainfo.name.lower()
+            if metainfo.name in cls._models_dict:
+                raise ValueError(
+                    'The model name {} is duplicated in {} and {}.'.format(
+                        model_name, osp.abspath(metainfo.filepath),
+                        osp.abspath(cls._models_dict[model_name].filepath)))
+            metainfo.config = cls._expand_config_path(metainfo, config_prefix)
+            cls._models_dict[model_name] = metainfo
+
+    @classmethod
+    def get(cls, model_name):
+        """Get the model's metainfo by the model name.
+
+        Args:
+            model_name (str): The name of model.
+
+        Returns:
+            modelindex.models.Model: The metainfo of the specified model.
+        """
+        cls._register_mmpretrain_models()
+        # lazy load config
+        metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
+        if metainfo is None:
+            raise ValueError(
+                f'Failed to find model "{model_name}". Please use '
+                '`mmpretrain.list_models` to get all available names.')
+        if isinstance(metainfo.config, str):
+            metainfo.config = Config.fromfile(metainfo.config)
+        return metainfo
+
+    @staticmethod
+    def _expand_config_path(metainfo: Model,
+                            config_prefix: Union[str, PathLike] = None):
+        if config_prefix is None:
+            config_prefix = osp.dirname(metainfo.filepath)
+
+        if metainfo.config is None or osp.isabs(metainfo.config):
+            config_path: str = metainfo.config
+        else:
+            config_path = osp.abspath(osp.join(config_prefix, metainfo.config))
+
+        return config_path
+
+    @classmethod
+    def _register_mmpretrain_models(cls):
+        # register models in mmpretrain
+        if not cls.__mmpretrain_registered:
+            from importlib_metadata import distribution
+            root = distribution('mmpretrain').locate_file('mmpretrain')
+            model_index_path = root / '.mim' / 'model-index.yml'
+            ModelHub.register_model_index(
+                model_index_path, config_prefix=root / '.mim')
+            cls.__mmpretrain_registered = True
+
+    @classmethod
+    def has(cls, model_name):
+        """Whether a model name is in the ModelHub."""
+        return model_name in cls._models_dict
+
+
+def get_model(model: Union[str, Config],
+              pretrained: Union[str, bool] = False,
+              device=None,
+              device_map=None,
+              offload_folder=None,
+              url_mapping: Tuple[str, str] = None,
+              **kwargs):
+    """Get a pre-defined model or create a model from config.
+
+    Args:
+        model (str | Config): The name of model, the config file path or a
+            config instance.
+        pretrained (bool | str): When using a name to specify the model, you
+            can use ``True`` to load the pre-defined pretrained weights. And
+            you can also use a string to specify the path or link of weights
+            to load. Defaults to False. 
+        device (str | torch.device | None): Transfer the model to the target
+            device. Defaults to None.
+        device_map (str | dict | None): A map that specifies where each
+            submodule should go. It doesn't need to be refined to each
+            parameter/buffer name, once a given module name is inside, every
+            submodule of it will be sent to the same device. You can use
+            `device_map="auto"` to automatically generate the device map.
+            Defaults to None.
+        offload_folder (str | None): If the `device_map` contains any value
+            `"disk"`, the folder where we will offload weights.
+        url_mapping (Tuple[str, str], optional): The mapping of pretrained
+            checkpoint link. For example, load the checkpoint from a local dir
+            instead of downloading, by ``('https://.*/', './checkpoint')``.
+            Defaults to None.
+        **kwargs: Other keyword arguments of the model config.
+
+    Returns:
+        mmengine.model.BaseModel: The result model.
+
+    Examples:
+        Get a ResNet-50 model and extract images feature:
+
+        >>> import torch
+        >>> from mmpretrain import get_model
+        >>> inputs = torch.rand(16, 3, 224, 224)
+        >>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
+        >>> feats = model.extract_feat(inputs)
+        >>> for feat in feats:
+        ...     print(feat.shape)
+        torch.Size([16, 256])
+        torch.Size([16, 512])
+        torch.Size([16, 1024])
+        torch.Size([16, 2048])
+
+        Get a Swin-Transformer model with pre-trained weights and run inference:
+
+        >>> from mmpretrain import get_model, inference_model
+        >>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
+        >>> result = inference_model(model, 'demo/demo.JPEG')
+        >>> print(result['pred_class'])
+        'sea snake'
+    """  # noqa: E501
+    if device_map is not None:
+        from .utils import dispatch_model
+        dispatch_model._verify_require()
+
+    metainfo = None
+    if isinstance(model, Config):
+        config = copy.deepcopy(model)
+        if pretrained is True and 'load_from' in config:
+            pretrained = config.load_from
+    elif isinstance(model, (str, PathLike)) and Path(model).suffix == '.py':
+        config = Config.fromfile(model)
+        if pretrained is True and 'load_from' in config:
+            pretrained = config.load_from
+    elif isinstance(model, str):
+        metainfo = ModelHub.get(model)
+        config = metainfo.config
+        if pretrained is True and metainfo.weights is not None:
+            pretrained = metainfo.weights
+    else:
+        raise TypeError('model must be a name, a path or a Config object, '
                        f'but got {type(model)}')
+
+    if pretrained is True:
+        warnings.warn('Unable to find pre-defined checkpoint of the model.')
+        pretrained = None
+    elif pretrained is False:
+        pretrained = None
+
+    if kwargs:
+        config.merge_from_dict({'model': kwargs})
+    config.model.setdefault('data_preprocessor',
+                            config.get('data_preprocessor', None))
+
+    from mmengine.registry import DefaultScope
+
+    from mmpretrain.registry import MODELS
+    with DefaultScope.overwrite_default_scope('mmpretrain'):
+        model = MODELS.build(config.model)
+
+    dataset_meta = {}
+    if pretrained:
+        # Mapping the weights to GPU may cause unexpected video memory leak
+        # which refers to https://github.com/open-mmlab/mmdetection/pull/6405
+        from mmengine.runner import load_checkpoint
+        if url_mapping is not None:
+            pretrained = re.sub(url_mapping[0], url_mapping[1], pretrained)
+        checkpoint = load_checkpoint(model, pretrained, map_location='cpu')
+        if 'dataset_meta' in checkpoint.get('meta', {}):
+            # mmpretrain 1.x
+            dataset_meta = checkpoint['meta']['dataset_meta']
+        elif 'CLASSES' in checkpoint.get('meta', {}):
+            # mmcls 0.x
+            dataset_meta = {'classes': 
checkpoint['meta']['CLASSES']}
+
+    if len(dataset_meta) == 0 and 'test_dataloader' in config:
+        from mmpretrain.registry import DATASETS
+        dataset_class = DATASETS.get(config.test_dataloader.dataset.type)
+        dataset_meta = getattr(dataset_class, 'METAINFO', {})
+
+    if device_map is not None:
+        model = dispatch_model(
+            model, device_map=device_map, offload_folder=offload_folder)
+    elif device is not None:
+        model.to(device)
+
+    model._dataset_meta = dataset_meta  # save the dataset meta
+    model._config = config  # save the config in the model
+    model._metainfo = metainfo  # save the metainfo in the model
+    model.eval()
+    return model
+
+
+def init_model(config, checkpoint=None, device=None, **kwargs):
+    """Initialize a classifier from config file (deprecated).
+
+    It is kept only for compatibility; please use :func:`get_model` instead.
+
+    Args:
+        config (str | :obj:`mmengine.Config`): Config file path or the config
+            object.
+        checkpoint (str, optional): Checkpoint path. If left as None, the model
+            will not load any weights.
+        device (str | torch.device | None): Transfer the model to the target
+            device. Defaults to None.
+        **kwargs: Other keyword arguments of the model config.
+
+    Returns:
+        nn.Module: The constructed model.
+    """
+    return get_model(config, checkpoint, device, **kwargs)
+
+
+def list_models(pattern=None, exclude_patterns=None, task=None) -> List[str]:
+    """List all models available in MMPretrain.
+
+    Args:
+        pattern (str | None): A wildcard pattern to match model names.
+            Defaults to None.
+        exclude_patterns (list | None): A list of wildcard patterns to
+            exclude names from the matched names. Defaults to None.
+        task (str | None): The evaluation task of the model.
+
+    Returns:
+        List[str]: a list of model names.
+
+    Examples:
+        List all models:
+
+        >>> from mmpretrain import list_models
+        >>> list_models()
+
+        List ResNet-50 models on ImageNet-1k dataset:
+
+        >>> from mmpretrain import list_models
+        >>> list_models('resnet*in1k')
+        ['resnet50_8xb32_in1k',
+         'resnet50_8xb32-fp16_in1k',
+         'resnet50_8xb256-rsb-a1-600e_in1k',
+         'resnet50_8xb256-rsb-a2-300e_in1k',
+         'resnet50_8xb256-rsb-a3-100e_in1k']
+
+        List Swin-Transformer models trained from scratch and exclude
+        Swin-Transformer-V2 models:
+
+        >>> from mmpretrain import list_models
+        >>> list_models('swin', exclude_patterns=['swinv2', '*-pre'])
+        ['swin-base_16xb64_in1k',
+         'swin-base_3rdparty_in1k',
+         'swin-base_3rdparty_in1k-384',
+         'swin-large_8xb8_cub-384px',
+         'swin-small_16xb64_in1k',
+         'swin-small_3rdparty_in1k',
+         'swin-tiny_16xb64_in1k',
+         'swin-tiny_3rdparty_in1k']
+
+        List all EVA models for the image classification task.
+
+        >>> from mmpretrain import list_models
+        >>> list_models('eva', task='Image Classification')
+        ['eva-g-p14_30m-in21k-pre_3rdparty_in1k-336px',
+         'eva-g-p14_30m-in21k-pre_3rdparty_in1k-560px',
+         'eva-l-p14_mim-in21k-pre_3rdparty_in1k-196px',
+         'eva-l-p14_mim-in21k-pre_3rdparty_in1k-336px',
+         'eva-l-p14_mim-pre_3rdparty_in1k-196px',
+         'eva-l-p14_mim-pre_3rdparty_in1k-336px']
+    """
+    ModelHub._register_mmpretrain_models()
+    matches = set(ModelHub._models_dict.keys())
+
+    if pattern is not None:
+        # Always match keys with any postfix.
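+        # e.g. the pattern 'swin-base' also matches 'swin-base_16xb64_in1k'.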
+        matches = set(fnmatch.filter(matches, pattern + '*'))
+
+    exclude_patterns = exclude_patterns or []
+    for exclude_pattern in exclude_patterns:
+        exclude = set(fnmatch.filter(matches, exclude_pattern + '*'))
+        matches = matches - exclude
+
+    if task is not None:
+        task_matches = []
+        for key in matches:
+            metainfo = ModelHub._models_dict[key]
+            if metainfo.results is None and task == 'null':
+                task_matches.append(key)
+            elif metainfo.results is None:
+                continue
+            elif task in [result.task for result in metainfo.results]:
+                task_matches.append(key)
+        matches = task_matches
+
+    return sorted(list(matches))
+
+
+def inference_model(model, *args, **kwargs):
+    """Run inference on an image with an automatically selected inferencer.
+
+    Automatically selects the inferencer according to the type of the model.
+    It's a shortcut for a quick start, and for advanced usage, please
+    use the corresponding inferencer class.
+
+    Here is the mapping from task to inferencer:
+
+    - Image Classification: :class:`ImageClassificationInferencer`
+    - Image Retrieval: :class:`ImageRetrievalInferencer`
+    - Image Caption: :class:`ImageCaptionInferencer`
+    - Visual Question Answering: :class:`VisualQuestionAnsweringInferencer`
+    - Visual Grounding: :class:`VisualGroundingInferencer`
+    - Text-To-Image Retrieval: :class:`TextToImageRetrievalInferencer`
+    - Image-To-Text Retrieval: :class:`ImageToTextRetrievalInferencer`
+    - NLVR: :class:`NLVRInferencer`
+
+    Args:
+        model (BaseModel | str | Config): The loaded model, the model
+            name or the config of the model.
+        *args: Positional arguments to call the inferencer.
+        **kwargs: Other keyword arguments to initialize and call the
+            corresponding inferencer.
+
+    Returns:
+        result (dict): The inference results.
+    """  # noqa: E501
+    from mmengine.model import BaseModel
+
+    if isinstance(model, BaseModel):
+        metainfo = getattr(model, '_metainfo', None)
+    else:
+        metainfo = ModelHub.get(model)
+
+    from inspect import signature
+
+    from .image_caption import ImageCaptionInferencer
+    from .image_classification import ImageClassificationInferencer
+    from .image_retrieval import ImageRetrievalInferencer
+    from .multimodal_retrieval import (ImageToTextRetrievalInferencer,
+                                       TextToImageRetrievalInferencer)
+    from .nlvr import NLVRInferencer
+    from .visual_grounding import VisualGroundingInferencer
+    from .visual_question_answering import VisualQuestionAnsweringInferencer
+    task_mapping = {
+        'Image Classification': ImageClassificationInferencer,
+        'Image Retrieval': ImageRetrievalInferencer,
+        'Image Caption': ImageCaptionInferencer,
+        'Visual Question Answering': VisualQuestionAnsweringInferencer,
+        'Visual Grounding': VisualGroundingInferencer,
+        'Text-To-Image Retrieval': TextToImageRetrievalInferencer,
+        'Image-To-Text Retrieval': ImageToTextRetrievalInferencer,
+        'NLVR': NLVRInferencer,
+    }
+
+    inferencer_type = None
+
+    if metainfo is not None and metainfo.results is not None:
+        tasks = set(result.task for result in metainfo.results)
+        inferencer_type = [
+            task_mapping.get(task) for task in tasks if task in task_mapping
+        ]
+        if len(inferencer_type) > 1:
+            inferencer_names = [cls.__name__ for cls in inferencer_type]
+            warnings.warn('The model supports multiple tasks, auto select '
+                          f'{inferencer_names[0]}, you can also use other '
+                          f'inferencer {inferencer_names} directly.')
+        inferencer_type = inferencer_type[0]
+
+    if inferencer_type is None:
+        raise NotImplementedError('No available inferencer for the model')
+
+    init_kwargs = {
+        k: kwargs.pop(k)
+        for k in list(kwargs)
+        if k in 
signature(inferencer_type).parameters.keys() + } + + inferencer = inferencer_type(model, **init_kwargs) + return inferencer(*args, **kwargs)[0] diff --git a/mmpretrain/apis/multimodal_retrieval.py b/mmpretrain/apis/multimodal_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..5eb9c859aca309306c1e775b7a03bf3bbc1c7717 --- /dev/null +++ b/mmpretrain/apis/multimodal_retrieval.py @@ -0,0 +1,603 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from pathlib import Path +from typing import Callable, List, Optional, Tuple, Union + +import mmengine +import numpy as np +import torch +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import BaseDataset, Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample +from mmpretrain.utils import track +from .base import BaseInferencer +from .base import InputType as ImageType +from .base import ModelType +from .model import list_models + + +def filter_transforms(transforms: list, data_info: dict): + """Filter pipeline to avoid KeyError with partial data info.""" + data_info = deepcopy(data_info) + filtered_transforms = [] + for t in transforms: + try: + data_info = t(data_info) + filtered_transforms.append(t) + except KeyError: + pass + return filtered_transforms + + +class TextToImageRetrievalInferencer(BaseInferencer): + """The inferencer for text to image retrieval. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``TextToImageRetrievalInferencer.list_models()`` and you can also + query it in :doc:`/modelzoo_statistics`. + prototype (str | list | dict | DataLoader | BaseDataset): The images to + be retrieved. It can be the following types: + + - str: The directory of the the images. + - list: A list of path of the images. + - dict: A config dict of the a prototype dataset. + - BaseDataset: A prototype dataset. + - DataLoader: A data loader to load the prototype data. + + prototype_cache (str, optional): The path of the generated prototype + features. If exists, directly load the cache instead of re-generate + the prototype features. If not exists, save the generated features + to the path. Defaults to None. + fast_match (bool): Some algorithms will record extra image features for + further matching, which may consume large memory, set True to avoid + this behavior. Defaults to True. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + + Example: + >>> from mmpretrain import TextToImageRetrievalInferencer + >>> inferencer = TextToImageRetrievalInferencer( + ... 'blip-base_3rdparty_retrieval', + ... prototype='./demo/', + ... 
prototype_cache='t2i_retri.pth') + >>> inferencer('A cat and a dog.')[0] + {'match_score': tensor(0.3855, device='cuda:0'), + 'sample_idx': 1, + 'sample': {'img_path': './demo/cat-dog.png'}} + """ # noqa: E501 + + visualize_kwargs: set = { + 'draw_score', 'show_dir', 'show', 'wait_time', 'figsize', 'topk' + } + postprocess_kwargs: set = {'topk'} + + def __init__(self, + model: ModelType, + prototype, + prototype_cache=None, + fast_match=True, + prepare_batch_size=8, + pretrained: Union[bool, str] = True, + device: Union[str, torch.device, None] = None, + **kwargs) -> None: + super().__init__( + model=model, pretrained=pretrained, device=device, **kwargs) + + self.img_pipeline, self.text_pipeline = self.pipeline + + if hasattr(self.model, 'fast_match'): + self.model.fast_match = fast_match + + self.prototype_dataset = self._prepare_prototype( + prototype, prototype_cache, batch_size=prepare_batch_size) + + def _prepare_prototype(self, prototype, cache=None, batch_size=8): + from mmengine.dataset import DefaultSampler + from torch.utils.data import DataLoader + + def build_dataloader(dataset): + return DataLoader( + dataset, + batch_size=batch_size, + collate_fn=default_collate, + sampler=DefaultSampler(dataset, shuffle=False), + persistent_workers=False, + ) + + if isinstance(prototype, str): + # A directory path of images + prototype = dict( + type='CustomDataset', with_label=False, data_root=prototype) + + if isinstance(prototype, list): + test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline] + dataset = BaseDataset( + lazy_init=True, serialize_data=False, pipeline=test_pipeline) + dataset.data_list = [{ + 'sample_idx': i, + 'img_path': file + } for i, file in enumerate(prototype)] + dataset._fully_initialized = True + dataloader = build_dataloader(dataset) + elif isinstance(prototype, dict): + # A config of dataset + from mmpretrain.registry import DATASETS + test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline] + prototype.setdefault('pipeline', test_pipeline) + dataset = DATASETS.build(prototype) + dataloader = build_dataloader(dataset) + elif isinstance(prototype, list): + test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline] + dataset = BaseDataset( + lazy_init=True, serialize_data=False, pipeline=test_pipeline) + dataset.data_list = [{ + 'sample_idx': i, + 'img_path': file + } for i, file in enumerate(prototype)] + dataset._fully_initialized = True + dataloader = build_dataloader(dataset) + elif isinstance(prototype, DataLoader): + dataset = prototype.dataset + dataloader = prototype + elif isinstance(prototype, BaseDataset): + dataset = prototype + dataloader = build_dataloader(dataset) + else: + raise TypeError(f'Unsupported prototype type {type(prototype)}.') + + if cache is not None and Path(cache).exists(): + self.prototype = torch.load(cache) + else: + prototype = [] + for data_batch in track(dataloader, 'Prepare prototype...'): + with torch.no_grad(): + data_batch = self.model.data_preprocessor( + data_batch, False) + feats = self.model._run_forward(data_batch, mode='tensor') + prototype.append(feats) + prototype = { + k: torch.cat([d[k] for d in prototype]) + for k in prototype[0] + } + self.prototype = prototype + + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + if cache is None: + logger.info('The prototype has been prepared, you can use ' + '`save_prototype` to dump it into a pickle ' + 'file for the future usage.') + elif not Path(cache).exists(): + self.save_prototype(cache) + 
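+            # ``save_prototype`` is a thin wrapper around ``torch.save``, so
+            # the cached file can be reloaded by any later run that passes
+            # the same ``prototype_cache`` path.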
logger.info(f'The prototype has been saved at {cache}.') + + return dataset + + def save_prototype(self, path): + torch.save(self.prototype, path) + + def __call__(self, + inputs: ImageType, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (str | array | list): The image path or array, or a list of + images. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the match scores. + Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. + """ + return super().__call__(inputs, return_datasamples, batch_size, + **kwargs) + + @torch.no_grad() + def forward(self, data: dict, **kwargs): + """Feed the inputs to the model.""" + data = self.model.data_preprocessor(data, False) + data_samples = data['data_samples'] + feats = self.prototype.copy() + feats.update(self.model.extract_feat(data_samples=data_samples)) + return self.model.predict_all(feats, data_samples, cal_i2t=False)[0] + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + test_transfroms = [TRANSFORMS.build(t) for t in test_pipeline_cfg] + img_info = {'img': np.zeros((224, 224, 3), dtype=np.uint8)} + text_info = {'text': 'example'} + img_pipeline = Compose(filter_transforms(test_transfroms, img_info)) + text_pipeline = Compose(filter_transforms(test_transfroms, text_info)) + return img_pipeline, text_pipeline + + def preprocess(self, inputs: List[str], batch_size: int = 1): + + def process_text(input_: str): + return self.text_pipeline({'text': input_}) + + chunked_data = self._get_chunk_data( + map(process_text, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[str], + preds: List[DataSample], + topk: int = 3, + figsize: Tuple[int, int] = (16, 9), + show: bool = False, + wait_time: int = 0, + draw_score=True, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (text, data_sample) in enumerate(zip(ori_inputs, preds)): + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_t2i_retrieval( + text, + data_sample, + self.prototype_dataset, + topk=topk, + fig_cfg=dict(figsize=figsize), + draw_score=draw_score, + show=show, + wait_time=wait_time, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess( + self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False, + topk=1, + ) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + match_scores, indices = 
torch.topk(data_sample.pred_score, k=topk) + matches = [] + for match_score, sample_idx in zip(match_scores, indices): + sample = self.prototype_dataset.get_data_info( + sample_idx.item()) + sample_idx = sample.pop('sample_idx') + matches.append({ + 'match_score': match_score, + 'sample_idx': sample_idx, + 'sample': sample + }) + results.append(matches) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='Text-To-Image Retrieval') + + +class ImageToTextRetrievalInferencer(BaseInferencer): + """The inferencer for image to text retrieval. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``ImageToTextRetrievalInferencer.list_models()`` and you can + also query it in :doc:`/modelzoo_statistics`. + prototype (str | list | dict | DataLoader, BaseDataset): The images to + be retrieved. It can be the following types: + + - str: The file path to load the string list. + - list: A list of string. + + prototype_cache (str, optional): The path of the generated prototype + features. If exists, directly load the cache instead of re-generate + the prototype features. If not exists, save the generated features + to the path. Defaults to None. + fast_match (bool): Some algorithms will record extra image features for + further matching, which may consume large memory, set True to avoid + this behavior. Defaults to True. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + + Example: + >>> from mmpretrain import ImageToTextRetrievalInferencer + >>> inferencer = ImageToTextRetrievalInferencer( + ... 'blip-base_3rdparty_retrieval', + ... prototype=['cat', 'dog', 'snake', 'bird'], + ... 
prototype_cache='i2t_retri.pth') + >>> inferencer('demo/bird.JPEG')[0] + {'match_score': tensor(0.3855, device='cuda:0'), + 'sample_idx': 1, + 'sample': {'img_path': './demo/cat-dog.png'}} + """ # noqa: E501 + + visualize_kwargs: set = { + 'draw_score', 'resize', 'show_dir', 'show', 'wait_time', 'topk' + } + postprocess_kwargs: set = {'topk'} + + def __init__(self, + model: ModelType, + prototype, + prototype_cache=None, + fast_match=True, + prepare_batch_size=8, + pretrained: Union[bool, str] = True, + device: Union[str, torch.device, None] = None, + **kwargs) -> None: + super().__init__( + model=model, pretrained=pretrained, device=device, **kwargs) + + self.img_pipeline, self.text_pipeline = self.pipeline + + if hasattr(self.model, 'fast_match'): + self.model.fast_match = fast_match + + self.prototype_dataset = self._prepare_prototype( + prototype, cache=prototype_cache, batch_size=prepare_batch_size) + + def _prepare_prototype(self, prototype, cache=None, batch_size=8): + from mmengine.dataset import DefaultSampler + from torch.utils.data import DataLoader + + def build_dataloader(dataset): + return DataLoader( + [ + self.text_pipeline({ + 'sample_idx': i, + 'text': text + }) for i, text in enumerate(dataset) + ], + batch_size=batch_size, + collate_fn=default_collate, + sampler=DefaultSampler(dataset, shuffle=False), + persistent_workers=False, + ) + + if isinstance(prototype, str): + # A file path of a list of string + dataset = mmengine.list_from_file(prototype) + elif mmengine.utils.is_seq_of(prototype, str): + dataset = prototype + else: + raise TypeError(f'Unsupported prototype type {type(prototype)}.') + + dataloader = build_dataloader(dataset) + + if cache is not None and Path(cache).exists(): + self.prototype = torch.load(cache) + else: + prototype = [] + for data_batch in track(dataloader, 'Prepare prototype...'): + with torch.no_grad(): + data_batch = self.model.data_preprocessor( + data_batch, False) + feats = self.model._run_forward(data_batch, mode='tensor') + prototype.append(feats) + prototype = { + k: torch.cat([d[k] for d in prototype]) + for k in prototype[0] + } + self.prototype = prototype + + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + if cache is None: + logger.info('The prototype has been prepared, you can use ' + '`save_prototype` to dump it into a pickle ' + 'file for the future usage.') + elif not Path(cache).exists(): + self.save_prototype(cache) + logger.info(f'The prototype has been saved at {cache}.') + + return dataset + + def save_prototype(self, path): + torch.save(self.prototype, path) + + def __call__(self, + inputs: ImageType, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (str | array | list): The image path or array, or a list of + images. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the match scores. + Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. 
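+
+        Example:
+            A minimal usage sketch; the image path and ``topk`` value are
+            illustrative:
+
+            >>> inferencer('demo/bird.JPEG', topk=2)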
+ """ + return super().__call__(inputs, return_datasamples, batch_size, + **kwargs) + + @torch.no_grad() + def forward(self, data: dict, **kwargs): + """Feed the inputs to the model.""" + data = self.model.data_preprocessor(data, False) + feats = self.prototype.copy() + feats.update(self.model.extract_feat(images=data['images'])) + return self.model.predict_all( + feats, data['data_samples'], cal_t2i=False)[0] + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + test_transfroms = [TRANSFORMS.build(t) for t in test_pipeline_cfg] + img_info = {'img': np.zeros((224, 224, 3), dtype=np.uint8)} + text_info = {'text': 'example'} + img_pipeline = Compose(filter_transforms(test_transfroms, img_info)) + text_pipeline = Compose(filter_transforms(test_transfroms, text_info)) + return img_pipeline, text_pipeline + + def preprocess(self, inputs: List[ImageType], batch_size: int = 1): + + def load_image(input_): + img = imread(input_) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return dict( + img=img, + img_shape=img.shape[:2], + ori_shape=img.shape[:2], + ) + + pipeline = Compose([load_image, self.img_pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[ImageType], + preds: List[DataSample], + topk: int = 3, + resize: Optional[int] = 224, + show: bool = False, + wait_time: int = 0, + draw_score=True, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_) + if isinstance(input_, str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_i2t_retrieval( + image, + data_sample, + self.prototype_dataset, + topk=topk, + resize=resize, + draw_score=draw_score, + show=show, + wait_time=wait_time, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess( + self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False, + topk=1, + ) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + match_scores, indices = torch.topk(data_sample.pred_score, k=topk) + matches = [] + for match_score, sample_idx in zip(match_scores, indices): + text = self.prototype_dataset[sample_idx.item()] + matches.append({ + 'match_score': match_score, + 'sample_idx': sample_idx, + 'text': text + }) + results.append(matches) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. 
+ """ + return list_models(pattern=pattern, task='Image-To-Text Retrieval') diff --git a/mmpretrain/apis/nlvr.py b/mmpretrain/apis/nlvr.py new file mode 100644 index 0000000000000000000000000000000000000000..9977c3b06f36fa61a3cd2edf36077a993b2030cd --- /dev/null +++ b/mmpretrain/apis/nlvr.py @@ -0,0 +1,150 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample +from .base import BaseInferencer +from .model import list_models + +InputType = Tuple[Union[str, np.ndarray], Union[str, np.ndarray], str] +InputsType = Union[List[InputType], InputType] + + +class NLVRInferencer(BaseInferencer): + """The inferencer for Natural Language for Visual Reasoning. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``NLVRInferencer.list_models()`` and you can also + query it in :doc:`/modelzoo_statistics`. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + """ + + visualize_kwargs: set = { + 'resize', 'draw_score', 'show', 'show_dir', 'wait_time' + } + + def __call__(self, + inputs: InputsType, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (tuple, List[tuple]): The input data tuples, every tuple + should include three items (left image, right image, text). + The image can be a path or numpy array. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the prediction scores + of prediction categories. Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. 
+ """ + assert isinstance(inputs, (tuple, list)) + if isinstance(inputs, tuple): + inputs = [inputs] + for input_ in inputs: + assert isinstance(input_, tuple) + assert len(input_) == 3 + + return super().__call__( + inputs, + return_datasamples=return_datasamples, + batch_size=batch_size, + **kwargs) + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + assert test_pipeline_cfg[0]['type'] == 'ApplyToList' + + list_pipeline = deepcopy(test_pipeline_cfg[0]) + if list_pipeline.scatter_key == 'img_path': + # Remove `LoadImageFromFile` + list_pipeline.transforms.pop(0) + list_pipeline.scatter_key = 'img' + + test_pipeline = Compose( + [TRANSFORMS.build(list_pipeline)] + + [TRANSFORMS.build(t) for t in test_pipeline_cfg[1:]]) + return test_pipeline + + def preprocess(self, inputs: InputsType, batch_size: int = 1): + + def load_image(input_): + img1 = imread(input_[0]) + img2 = imread(input_[1]) + text = input_[2] + if img1 is None: + raise ValueError(f'Failed to read image {input_[0]}.') + if img2 is None: + raise ValueError(f'Failed to read image {input_[1]}.') + return dict( + img=[img1, img2], + img_shape=[img1.shape[:2], img2.shape[:2]], + ori_shape=[img1.shape[:2], img2.shape[:2]], + text=text, + ) + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def postprocess(self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + pred_scores = data_sample.pred_score + pred_score = float(torch.max(pred_scores).item()) + pred_label = torch.argmax(pred_scores).item() + result = { + 'pred_scores': pred_scores.detach().cpu().numpy(), + 'pred_label': pred_label, + 'pred_score': pred_score, + } + results.append(result) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='NLVR') diff --git a/mmpretrain/apis/utils.py b/mmpretrain/apis/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..83e76325472f6925f78c746e3a10f3a58b0e6de4 --- /dev/null +++ b/mmpretrain/apis/utils.py @@ -0,0 +1,270 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +from collections import defaultdict +from contextlib import contextmanager +from itertools import chain +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn + +from mmpretrain.utils import require + + +@require('torch>=1.9.0', 'https://pytorch.org/get-started/locally/') +@require('accelerate') +def dispatch_model( + model, + device_map: Union[str, dict], + max_memory: Optional[dict] = None, + no_split_module_classes: Optional[List[str]] = None, + offload_folder: str = None, + offload_buffers: bool = False, + preload_module_classes: Optional[List[str]] = None, +): + """Split and dispatch a model across devices. + + The function depends on the `accelerate` package. Refers to + https://huggingface.co/docs/accelerate/main/en/usage_guides/big_modeling + + Args: + model (torch.nn.Module): The model to dispatch. + device_map (str | dict | None): A map that specifies where each + submodule should go. 
It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every + submodule of it will be sent to the same device. You can use + `device_map="auto"` to automatically generate the device map. + Defaults to None. + max_memory (dict | None): A dictionary device identifier to maximum + memory. Will default to the maximum memory available for each GPU + and the available CPU RAM if unset. Defaults to None. + no_split_module_classes (List[str] | None): A list of layer class names + that should never be split across device (for instance any layer + that has a residual connection). If None, try to get the settings + from the model class. Defaults to None. + offload_folder (str | None): If the `device_map` contains any value + `"disk"`, the folder where we will offload weights. + offload_buffers (bool): In the layers that are offloaded on the CPU + or the hard drive, whether or not to offload the buffers as + well as the parameters. Defaults to False. + preload_module_classes (List[str] | None): A list of classes whose + instances should load all their weights (even in the submodules) at + the beginning of the forward. This should only be used for classes + that have submodules which are registered but not called directly + during the forward, for instance if a `dense` linear layer is + registered, but at forward, `dense.weight` and `dense.bias` are + used in some operations instead of calling `dense` directly. + Defaults to None. + """ + from accelerate import dispatch_model, infer_auto_device_map + + # Check valid device_map string. + valid_map_option = ['auto', 'balanced', 'balanced_low_0', 'sequential'] + if isinstance(device_map, str) and device_map not in valid_map_option: + raise ValueError('If passing a string for `device_map`, please choose ' + f'from {valid_map_option}.') + + # Generate device map automatically + if isinstance(device_map, str): + if no_split_module_classes is None: + no_split_module_classes = getattr(model, '_no_split_modules', None) + if no_split_module_classes is None: + raise ValueError(f'{model.__class__.__name__} does not support ' + f"`device_map='{device_map}'` yet.") + + if device_map != 'sequential': + from accelerate.utils import get_balanced_memory + max_memory = get_balanced_memory( + model, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + dtype=None, + low_zero=(device_map == 'balanced_low_0'), + ) + max_memory[0] *= 0.9 + device_map = infer_auto_device_map( + model, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + dtype=None, + ) + + if 'disk' in device_map.values(): + if offload_folder is None: + raise ValueError( + 'The current `device_map` had weights offloaded to the disk. ' + 'Please provide an `offload_folder` for them.') + os.makedirs(offload_folder, exist_ok=True) + + main_device = next( + (d for d in device_map.values() if d not in ['cpu', 'disk']), 'cpu') + + model = dispatch_model( + model, + device_map=device_map, + main_device=main_device, + offload_dir=offload_folder, + offload_buffers=offload_buffers, + preload_module_classes=preload_module_classes, + ) + if hasattr(model, 'data_preprocessor'): + model.data_preprocessor._device = torch.device(main_device) + return model + + +@contextmanager +def init_empty_weights(include_buffers: bool = False): + """A context manager under which models are initialized with all parameters + on the meta device. + + With this context manager, we can create an empty model. 
Useful when just
+    initializing the model would blow the available RAM.
+
+    Besides moving the parameters to the meta device, this method will also
+    avoid loading checkpoints via `mmengine.runner.load_checkpoint` and
+    `transformers.PreTrainedModel.from_pretrained`.
+
+    Modified from https://github.com/huggingface/accelerate
+
+    Args:
+        include_buffers (bool): Whether to put all buffers on the meta device
+            during initialization.
+    """
+    device = torch.device('meta')
+
+    # move parameters and buffers to the meta device
+    old_register_parameter = nn.Module.register_parameter
+    if include_buffers:
+        old_register_buffer = nn.Module.register_buffer
+        # See https://github.com/huggingface/accelerate/pull/699
+        tensor_constructors_to_patch = {
+            torch_function_name: getattr(torch, torch_function_name)
+            for torch_function_name in ['empty', 'zeros', 'ones', 'full']
+        }
+
+    def register_parameter(module, name, param):
+        old_register_parameter(module, name, param)
+        if param is not None:
+            param_cls = type(module._parameters[name])
+            kwargs = module._parameters[name].__dict__
+            module._parameters[name] = param_cls(
+                module._parameters[name].to(device), **kwargs)
+
+    def register_buffer(module, name, buffer, *args, **kwargs):
+        old_register_buffer(module, name, buffer, *args, **kwargs)
+        if buffer is not None:
+            module._buffers[name] = module._buffers[name].to(device)
+
+    def patch_tensor_constructor(fn):
+
+        def wrapper(*args, **kwargs):
+            kwargs['device'] = device
+            return fn(*args, **kwargs)
+
+        return wrapper
+
+    # Patch load_checkpoint
+    import mmengine.runner.checkpoint as mmengine_load
+    old_load_checkpoint = mmengine_load.load_checkpoint
+
+    def patch_load_checkpoint(*args, **kwargs):
+        return {}
+
+    # Patch transformers from pretrained
+    try:
+        from transformers import PreTrainedModel
+        from transformers.models.auto.auto_factory import (AutoConfig,
+                                                           _BaseAutoModelClass)
+        with_transformers = True
+    except ImportError:
+        with_transformers = False
+
+    @classmethod
+    def patch_auto_model(cls, pretrained_model_name_or_path, *model_args,
+                         **kwargs):
+        cfg = AutoConfig.from_pretrained(pretrained_model_name_or_path,
+                                         *model_args, **kwargs)
+        return cls.from_config(cfg)
+
+    @classmethod
+    def patch_pretrained_model(cls, pretrained_model_name_or_path, *model_args,
+                               **kwargs):
+        cfg = cls.config_class.from_pretrained(pretrained_model_name_or_path,
+                                               *model_args, **kwargs)
+        return cls(cfg)
+
+    if with_transformers:
+        old_pretrained_model = PreTrainedModel.from_pretrained
+        old_auto_model = _BaseAutoModelClass.from_pretrained
+
+    try:
+        nn.Module.register_parameter = register_parameter
+        mmengine_load.load_checkpoint = patch_load_checkpoint
+        if with_transformers:
+            PreTrainedModel.from_pretrained = patch_pretrained_model
+            _BaseAutoModelClass.from_pretrained = patch_auto_model
+        if include_buffers:
+            nn.Module.register_buffer = register_buffer
+            for func in tensor_constructors_to_patch.keys():
+                tensor_constructor = patch_tensor_constructor(
+                    getattr(torch, func))
+                setattr(torch, func, tensor_constructor)
+        yield
+    finally:
+        nn.Module.register_parameter = old_register_parameter
+        mmengine_load.load_checkpoint = old_load_checkpoint
+        if with_transformers:
+            PreTrainedModel.from_pretrained = old_pretrained_model
+            _BaseAutoModelClass.from_pretrained = old_auto_model
+        if include_buffers:
+            nn.Module.register_buffer = old_register_buffer
+            for func, ori in tensor_constructors_to_patch.items():
+                setattr(torch, func, ori)
+
+
+def compute_module_sizes(
+        model: nn.Module,
+        dtype: Union[str, torch.dtype, None] = None,
+        special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None):
+    """Compute the size of each submodule of a given model."""
+
+    def get_dtype(dtype):
+        if isinstance(dtype, str):
+            dtype = getattr(torch, dtype)
+        if dtype is not None:
+            # dtype objects like ``torch.float16`` are instances of
+            # ``torch.dtype``, not subclasses, so ``isinstance`` is the
+            # correct check (``issubclass`` would raise a TypeError).
+            assert isinstance(dtype, torch.dtype)
+        return dtype
+
+    def dtype_bytes(dtype: torch.dtype):
+        if dtype is torch.bool:
+            return 1
+        if dtype.is_floating_point:
+            return torch.finfo(dtype).bits / 8
+        else:
+            return torch.iinfo(dtype).bits / 8
+
+    if dtype is not None:
+        dtype = get_dtype(dtype)
+        dtype_size = dtype_bytes(dtype)
+
+    if special_dtypes is not None:
+        special_dtypes = {
+            key: dtype_bytes(dtype)
+            for key, dtype in special_dtypes.items()
+        }
+
+    module_sizes = defaultdict(int)
+    for name, tensor in chain(
+            model.named_parameters(recurse=True),
+            model.named_buffers(recurse=True)):
+        if special_dtypes is not None and name in special_dtypes:
+            size = tensor.numel() * special_dtypes[name]
+        elif dtype is None:
+            size = tensor.numel() * tensor.element_size()
+        else:
+            size = tensor.numel() * min(dtype_size, tensor.element_size())
+        name_parts = name.split('.')
+        for idx in range(len(name_parts) + 1):
+            module_sizes['.'.join(name_parts[:idx])] += size
+
+    return module_sizes
diff --git a/mmpretrain/apis/visual_grounding.py b/mmpretrain/apis/visual_grounding.py
new file mode 100644
index 0000000000000000000000000000000000000000..0153d56f5ca10a32e9fd2ccabb0d15c1135e213d
--- /dev/null
+++ b/mmpretrain/apis/visual_grounding.py
@@ -0,0 +1,182 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from pathlib import Path
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+from mmcv.image import imread
+from mmengine.config import Config
+from mmengine.dataset import Compose, default_collate
+
+from mmpretrain.registry import TRANSFORMS
+from mmpretrain.structures import DataSample
+from .base import BaseInferencer
+from .model import list_models
+
+
+class VisualGroundingInferencer(BaseInferencer):
+    """The inferencer for visual grounding.
+
+    Args:
+        model (BaseModel | str | Config): A model name or a path to the config
+            file, or a :obj:`BaseModel` object. The model name can be found
+            by ``VisualGroundingInferencer.list_models()`` and you can also
+            query it in :doc:`/modelzoo_statistics`.
+        pretrained (str, optional): Path to the checkpoint. If None, it will
+            try to find a pre-defined weight from the model you specified
+            (only work if the ``model`` is a model name). Defaults to None.
+        device (str, optional): Device to run inference. If None, the available
+            device will be automatically used. Defaults to None.
+        **kwargs: Other keyword arguments to initialize the model (only work if
+            the ``model`` is a model name).
+
+    Example:
+        >>> from mmpretrain import VisualGroundingInferencer
+        >>> inferencer = VisualGroundingInferencer('ofa-base_3rdparty_refcoco')
+        >>> inferencer('demo/cat-dog.png', 'dog')[0]
+        {'pred_bboxes': tensor([[ 36.6000,  29.6000, 355.8000, 395.2000]])}
+    """  # noqa: E501
+
+    visualize_kwargs: set = {
+        'resize', 'show', 'show_dir', 'wait_time', 'line_width', 'bbox_color'
+    }
+
+    def __call__(self,
+                 images: Union[str, np.ndarray, list],
+                 texts: Union[str, list],
+                 return_datasamples: bool = False,
+                 batch_size: int = 1,
+                 **kwargs) -> dict:
+        """Call the inferencer.
+
+        Args:
+            images (str | array | list): The image path or array, or a list of
+                images.
+            texts (str | list): The text to do visual grounding.
+            return_datasamples (bool): Whether to return results as
+                :obj:`DataSample`.
Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the prediction scores + of prediction categories. Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + line_width (int): The line width of the bbox. Defaults to 3. + bbox_color (str | tuple): The color of the bbox. + Defaults to 'green'. + + Returns: + list: The inference results. + """ + if not isinstance(images, (list, tuple)): + assert isinstance(texts, str) + inputs = [{'img': images, 'text': texts}] + else: + inputs = [] + for i in range(len(images)): + input_ = {'img': images[i], 'text': texts[i]} + inputs.append(input_) + + return super().__call__(inputs, return_datasamples, batch_size, + **kwargs) + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + from mmpretrain.datasets import remove_transform + + # Image loading is finished in `self.preprocess`. + test_pipeline_cfg = remove_transform(test_pipeline_cfg, + 'LoadImageFromFile') + test_pipeline = Compose( + [TRANSFORMS.build(t) for t in test_pipeline_cfg]) + return test_pipeline + + def preprocess(self, inputs: List[dict], batch_size: int = 1): + + def load_image(input_: dict): + img = imread(input_['img']) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return {**input_, 'img': img} + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[dict], + preds: List[DataSample], + show: bool = False, + wait_time: int = 0, + resize: Optional[int] = None, + line_width: int = 3, + bbox_color: Union[str, tuple] = 'green', + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_['img']) + if isinstance(input_['img'], str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_['img']).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_visual_grounding( + image, + data_sample, + resize=resize, + show=show, + wait_time=wait_time, + line_width=line_width, + bbox_color=bbox_color, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess(self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + results.append({'pred_bboxes': data_sample.get('pred_bboxes')}) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. 
+
+        Args:
+            pattern (str | None): A wildcard pattern to match model names.
+
+        Returns:
+            List[str]: a list of model names.
+        """
+        return list_models(pattern=pattern, task='Visual Grounding')
diff --git a/mmpretrain/apis/visual_question_answering.py b/mmpretrain/apis/visual_question_answering.py
new file mode 100644
index 0000000000000000000000000000000000000000..616e1edc66709401df83cb5253590325e727aa98
--- /dev/null
+++ b/mmpretrain/apis/visual_question_answering.py
@@ -0,0 +1,183 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from pathlib import Path
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+from mmcv.image import imread
+from mmengine.config import Config
+from mmengine.dataset import Compose, default_collate
+
+from mmpretrain.registry import TRANSFORMS
+from mmpretrain.structures import DataSample
+from .base import BaseInferencer
+from .model import list_models
+
+
+class VisualQuestionAnsweringInferencer(BaseInferencer):
+    """The inferencer for visual question answering.
+
+    Args:
+        model (BaseModel | str | Config): A model name or a path to the config
+            file, or a :obj:`BaseModel` object. The model name can be found
+            by ``VisualQuestionAnsweringInferencer.list_models()`` and you can
+            also query it in :doc:`/modelzoo_statistics`.
+        pretrained (str, optional): Path to the checkpoint. If None, it will
+            try to find a pre-defined weight from the model you specified
+            (only work if the ``model`` is a model name). Defaults to None.
+        device (str, optional): Device to run inference. If None, the available
+            device will be automatically used. Defaults to None.
+        **kwargs: Other keyword arguments to initialize the model (only work if
+            the ``model`` is a model name).
+
+    Example:
+        >>> from mmpretrain import VisualQuestionAnsweringInferencer
+        >>> inferencer = VisualQuestionAnsweringInferencer('ofa-base_3rdparty-zeroshot_vqa')
+        >>> inferencer('demo/cat-dog.png', "What's the animal next to the dog?")[0]
+        {'question': "What's the animal next to the dog?", 'pred_answer': 'cat'}
+    """  # noqa: E501
+
+    visualize_kwargs: set = {'resize', 'show', 'show_dir', 'wait_time'}
+
+    def __call__(self,
+                 images: Union[str, np.ndarray, list],
+                 questions: Union[str, list],
+                 return_datasamples: bool = False,
+                 batch_size: int = 1,
+                 objects: Optional[List[str]] = None,
+                 **kwargs) -> dict:
+        """Call the inferencer.
+
+        Args:
+            images (str | array | list): The image path or array, or a list of
+                images.
+            questions (str | list): The question to the corresponding image.
+            return_datasamples (bool): Whether to return results as
+                :obj:`DataSample`. Defaults to False.
+            batch_size (int): Batch size. Defaults to 1.
+            objects (List[List[str]], optional): Some algorithms like OFA
+                fine-tuned VQA models require an extra object description list
+                for every image. Defaults to None.
+            resize (int, optional): Resize the short edge of the image to the
+                specified length before visualization. Defaults to None.
+            show (bool): Whether to display the visualization result in a
+                window. Defaults to False.
+            wait_time (float): The display time (s). Defaults to 0, which means
+                "forever".
+            show_dir (str, optional): If not None, save the visualization
+                results in the specified directory. Defaults to None.
+
+        Returns:
+            list: The inference results.
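+
+        Example:
+            A batched sketch; the image paths and questions are illustrative:
+
+            >>> inferencer(['demo/cat-dog.png', 'demo/bird.JPEG'],
+            ...            ['How many animals are there?', 'What is this?'])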
+ """ + if not isinstance(images, (list, tuple)): + assert isinstance(questions, str) + inputs = [{'img': images, 'question': questions}] + if objects is not None: + assert isinstance(objects[0], str) + inputs[0]['objects'] = objects + else: + inputs = [] + for i in range(len(images)): + input_ = {'img': images[i], 'question': questions[i]} + if objects is not None: + input_['objects'] = objects[i] + inputs.append(input_) + + return super().__call__(inputs, return_datasamples, batch_size, + **kwargs) + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + from mmpretrain.datasets import remove_transform + + # Image loading is finished in `self.preprocess`. + test_pipeline_cfg = remove_transform(test_pipeline_cfg, + 'LoadImageFromFile') + test_pipeline = Compose( + [TRANSFORMS.build(t) for t in test_pipeline_cfg]) + return test_pipeline + + def preprocess(self, inputs: List[dict], batch_size: int = 1): + + def load_image(input_: dict): + img = imread(input_['img']) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return {**input_, 'img': img} + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[dict], + preds: List[DataSample], + show: bool = False, + wait_time: int = 0, + resize: Optional[int] = None, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_['img']) + if isinstance(input_['img'], str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_['img']).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_vqa( + image, + data_sample, + resize=resize, + show=show, + wait_time=wait_time, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess(self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + results.append({ + 'question': data_sample.get('question'), + 'pred_answer': data_sample.get('pred_answer'), + }) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. 
+ """ + return list_models(pattern=pattern, task='Visual Question Answering') diff --git a/mmpretrain/configs/.DS_Store b/mmpretrain/configs/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..6ba79b9412da6bbc6029b54bdc777714e41305c1 Binary files /dev/null and b/mmpretrain/configs/.DS_Store differ diff --git a/mmpretrain/configs/_base_/datasets/cifar10_bs16.py b/mmpretrain/configs/_base_/datasets/cifar10_bs16.py new file mode 100644 index 0000000000000000000000000000000000000000..3737dbee9a669a231c4aa93a711fa4b231bdf073 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/cifar10_bs16.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import CIFAR10, PackInputs, RandomCrop, RandomFlip +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = CIFAR10 +data_preprocessor = dict( + num_classes=10, + # RGB format normalization parameters + mean=[125.307, 122.961, 113.8575], + std=[51.5865, 50.847, 51.255], + # loaded images are already RGB format + to_rgb=False) + +train_pipeline = [ + dict(type=RandomCrop, crop_size=32, padding=4), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict( + type=dataset_type, + data_root='data/cifar10', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict( + type=dataset_type, + data_root='data/cifar10/', + split='test', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, )) + +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/cub_bs8_384.py b/mmpretrain/configs/_base_/datasets/cub_bs8_384.py new file mode 100644 index 0000000000000000000000000000000000000000..b193bf83cedaac3d358ac54ec63618833f6544d7 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/cub_bs8_384.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CUB, CenterCrop, LoadImageFromFile, + PackInputs, RandomCrop, RandomFlip, Resize) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = CUB +data_preprocessor = dict( + num_classes=200, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=Resize, scale=510), + dict(type=RandomCrop, crop_size=384), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=Resize, scale=510), + dict(type=CenterCrop, crop_size=384), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=8, + num_workers=2, + dataset=dict( + type=dataset_type, + data_root='data/CUB_200_2011', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=8, + num_workers=2, + dataset=dict( + type=dataset_type, + data_root='data/CUB_200_2011', + split='test', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, )) + +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet21k_bs128.py b/mmpretrain/configs/_base_/datasets/imagenet21k_bs128.py new file mode 100644 index 0000000000000000000000000000000000000000..11c4c0a4b74a1218c050d2425bfa0b2915011ef6 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet21k_bs128.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (ImageNet21k, LoadImageFromFile, PackInputs, + RandomFlip, RandomResizedCrop) + +# dataset settings +dataset_type = ImageNet21k +data_preprocessor = dict( + num_classes=21842, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=224), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=128, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet21k', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs128_mbv3.py b/mmpretrain/configs/_base_/datasets/imagenet_bs128_mbv3.py new file mode 100644 index 0000000000000000000000000000000000000000..cf0aa629d72fcacca755eeef3ed16e5d21824d40 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs128_mbv3.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
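+# ImageNet-1k recipe for MobileNetV3: AutoAugment ('imagenet' policy) and
+# RandomErasing, batch size 128 per GPU, Pillow backend.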
+from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (AutoAugment, CenterCrop, ImageNet, + LoadImageFromFile, PackInputs, RandomErasing, + RandomFlip, RandomResizedCrop, ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = data_preprocessor['mean'][::-1] +bgr_std = data_preprocessor['std'][::-1] + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=224, backend='pillow'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=AutoAugment, + policies='imagenet', + hparams=dict(pad_val=[round(x) for x in bgr_mean])), + dict( + type=RandomErasing, + erase_prob=0.2, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=bgr_mean, + fill_std=bgr_std), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=256, edge='short', backend='pillow'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=128, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=128, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs256_beitv2.py b/mmpretrain/configs/_base_/datasets/imagenet_bs256_beitv2.py new file mode 100644 index 0000000000000000000000000000000000000000..f89eb17b846c25cea4c709829ff516eebb15e4e7 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs256_beitv2.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
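+# BEiT v2 pre-training data: each image yields two differently normalized
+# views plus a block-wise mask over the 14x14 patch grid (75 patches).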
+from mmengine.dataset import DefaultSampler, default_collate + +from mmpretrain.datasets import (BEiTMaskGenerator, ColorJitter, ImageNet, + LoadImageFromFile, PackInputs, RandomFlip, + RandomResizedCropAndInterpolationWithTwoPic) +from mmpretrain.models import TwoNormDataPreprocessor + +dataset_type = ImageNet +data_root = 'data/imagenet/' + +data_preprocessor = dict( + type=TwoNormDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + second_mean=[127.5, 127.5, 127.5], + second_std=[127.5, 127.5, 127.5], + to_rgb=True) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ColorJitter, brightness=0.4, contrast=0.4, saturation=0.4, + hue=0.), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=RandomResizedCropAndInterpolationWithTwoPic, + size=224, + second_size=224, + interpolation='bicubic', + second_interpolation='bicubic', + scale=(0.2, 1.0)), + dict( + type=BEiTMaskGenerator, + input_size=(14, 14), + num_masking_patches=75, + max_num_patches=75, + min_num_patches=16), + dict(type=PackInputs) +] + +train_dataloader = dict( + batch_size=256, + num_workers=8, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=default_collate), + dataset=dict( + type=dataset_type, + data_root=data_root, + split='train', + pipeline=train_pipeline)) diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs32.py b/mmpretrain/configs/_base_/datasets/imagenet_bs32.py new file mode 100644 index 0000000000000000000000000000000000000000..7d074008cc204f4ac486dc04fb3f1c638fb9e161 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs32.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
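+# Standard ImageNet-1k recipe: random resized crop and flip for training,
+# 256 short-edge resize with 224 center crop for evaluation.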
+from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile, + PackInputs, RandomFlip, RandomResizedCrop, + ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=224), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=256, edge='short'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=32, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/train.txt', + data_prefix='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=32, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/val.txt', + data_prefix='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs32_pil_resize.py b/mmpretrain/configs/_base_/datasets/imagenet_bs32_pil_resize.py new file mode 100644 index 0000000000000000000000000000000000000000..f911bc20ff68fb2bb34b3ce495bc784ac0d0f62d --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs32_pil_resize.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
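+# ImageNet-1k recipe matching imagenet_bs32.py, but with the Pillow resize
+# backend.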
+from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile, + PackInputs, RandomFlip, RandomResizedCrop, + ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=224, backend='pillow'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=256, edge='short', backend='pillow'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=32, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=32, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs32_simclr.py b/mmpretrain/configs/_base_/datasets/imagenet_bs32_simclr.py new file mode 100644 index 0000000000000000000000000000000000000000..29b698f498eb4a4e4aaf8fb0cab04129704d484a --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs32_simclr.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
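+# Orientation: this config builds the SimCLR data pipeline. `MultiView`
+# below applies the same augmentation branch twice to every image, so each
+# sample yields two correlated views for the contrastive objective.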
+from mmcv.transforms import (LoadImageFromFile, RandomApply, RandomFlip, + RandomGrayscale) +from mmengine.dataset import DefaultSampler, default_collate + +from mmpretrain.datasets import (ColorJitter, GaussianBlur, ImageNet, + MultiView, PackInputs, RandomResizedCrop) +from mmpretrain.models import SelfSupDataPreprocessor + +# dataset settings +dataset_type = ImageNet +data_root = 'data/imagenet/' +data_preprocessor = dict( + type=SelfSupDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + to_rgb=True) + +view_pipeline = [ + dict(type=RandomResizedCrop, scale=224, backend='pillow'), + dict(type=RandomFlip, prob=0.5), + dict( + type=RandomApply, + transforms=[ + dict( + type=ColorJitter, + brightness=0.8, + contrast=0.8, + saturation=0.8, + hue=0.2) + ], + prob=0.8), + dict( + type=RandomGrayscale, + prob=0.2, + keep_channels=True, + channel_weights=(0.114, 0.587, 0.2989)), + dict( + type=GaussianBlur, + magnitude_range=(0.1, 2.0), + magnitude_std='inf', + prob=0.5), +] + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=MultiView, num_views=2, transforms=[view_pipeline]), + dict(type=PackInputs) +] + +train_dataloader = dict( + batch_size=32, + num_workers=4, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=default_collate), + dataset=dict( + type=ImageNet, + data_root=data_root, + ann_file='meta/train.txt', + data_prefix=dict(img_path='train/'), + pipeline=train_pipeline)) diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs512_mae.py b/mmpretrain/configs/_base_/datasets/imagenet_bs512_mae.py new file mode 100644 index 0000000000000000000000000000000000000000..017f5b7807eee0855bb427d0f445e4225127d08e --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs512_mae.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmcv.transforms import LoadImageFromFile, RandomFlip +from mmengine.dataset.sampler import DefaultSampler + +from mmpretrain.datasets import ImageNet, PackInputs, RandomResizedCrop +from mmpretrain.models import SelfSupDataPreprocessor + +# dataset settings +dataset_type = ImageNet +data_root = 'data/imagenet/' +data_preprocessor = dict( + type=SelfSupDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + to_rgb=True) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + crop_ratio_range=(0.2, 1.0), + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5), + dict(type=PackInputs) +] + +train_dataloader = dict( + batch_size=512, + num_workers=8, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type='default_collate'), + dataset=dict( + type=dataset_type, + data_root=data_root, + split='train', + pipeline=train_pipeline)) diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs64_pil_resize.py b/mmpretrain/configs/_base_/datasets/imagenet_bs64_pil_resize.py new file mode 100644 index 0000000000000000000000000000000000000000..a2d8aea8bc2ec149031eab87f1a15540d5fec312 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs64_pil_resize.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile, + PackInputs, RandomFlip, RandomResizedCrop, + ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=224, backend='pillow'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=256, edge='short', backend='pillow'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py b/mmpretrain/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py new file mode 100644 index 0000000000000000000000000000000000000000..a5f052662e4f834892ab6f813e15e3c1c7bb4e7d --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py @@ -0,0 +1,78 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile, + PackInputs, RandomFlip, RandomResizedCrop, + ResizeEdge) +from mmpretrain.datasets.transforms import AutoAugment +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = data_preprocessor['mean'][::-1] +bgr_std = data_preprocessor['std'][::-1] + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=AutoAugment, + policies='imagenet', + hparams=dict( + pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_224.py b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_224.py new file mode 100644 index 0000000000000000000000000000000000000000..5a38943e270777426fb0ec3e991afbccce2a8873 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_224.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile, + PackInputs, RandAugment, RandomErasing, + RandomFlip, RandomResizedCrop, ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = data_preprocessor['mean'][::-1] +bgr_std = data_preprocessor['std'][::-1] + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=RandAugment, + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict( + pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')), + dict( + type=RandomErasing, + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=bgr_mean, + fill_std=bgr_std), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_256.py b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_256.py new file mode 100644 index 0000000000000000000000000000000000000000..9690ff8447895d656d345f380bf324420a9b72df --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_256.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile, + PackInputs, RandAugment, RandomErasing, + RandomFlip, RandomResizedCrop, ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = data_preprocessor['mean'][::-1] +bgr_std = data_preprocessor['std'][::-1] + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=256, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=RandAugment, + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict( + pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')), + dict( + type=RandomErasing, + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=bgr_mean, + fill_std=bgr_std), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=292, # ( 256 / 224 * 256 ) + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=256), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_384.py b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_384.py new file mode 100644 index 0000000000000000000000000000000000000000..85aeb1e2c131109f3f6d75d21e2cc1c782c82b7f --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_384.py @@ -0,0 +1,64 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.dataset import DefaultSampler
+
+from mmpretrain.datasets import (ImageNet, LoadImageFromFile, PackInputs,
+                                 RandomFlip, RandomResizedCrop, Resize)
+from mmpretrain.evaluation import Accuracy
+
+# dataset settings
+dataset_type = ImageNet
+data_preprocessor = dict(
+    num_classes=1000,
+    # RGB format normalization parameters
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    # convert image from BGR to RGB
+    to_rgb=True,
+)
+
+train_pipeline = [
+    dict(type=LoadImageFromFile),
+    dict(
+        type=RandomResizedCrop,
+        scale=384,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type=RandomFlip, prob=0.5, direction='horizontal'),
+    dict(type=PackInputs),
+]
+
+test_pipeline = [
+    dict(type=LoadImageFromFile),
+    dict(type=Resize, scale=384, backend='pillow', interpolation='bicubic'),
+    dict(type=PackInputs),
+]
+
+train_dataloader = dict(
+    batch_size=64,
+    num_workers=5,
+    dataset=dict(
+        type=dataset_type,
+        data_root='data/imagenet',
+        ann_file='meta/train.txt',
+        data_prefix='train',
+        pipeline=train_pipeline),
+    sampler=dict(type=DefaultSampler, shuffle=True),
+)
+
+val_dataloader = dict(
+    batch_size=64,
+    num_workers=5,
+    dataset=dict(
+        type=dataset_type,
+        data_root='data/imagenet',
+        ann_file='meta/val.txt',
+        data_prefix='val',
+        pipeline=test_pipeline),
+    sampler=dict(type=DefaultSampler, shuffle=False),
+)
+val_evaluator = dict(type=Accuracy, topk=(1, 5))
+
+# If you want standard test, please manually configure the test dataset
+test_dataloader = val_dataloader
+test_evaluator = val_evaluator
diff --git a/mmpretrain/configs/_base_/default_runtime.py b/mmpretrain/configs/_base_/default_runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5c748eb84b3e50d7c6b30efaa87cd3c1f2f1827
--- /dev/null
+++ b/mmpretrain/configs/_base_/default_runtime.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change recently.
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+                            LoggerHook, ParamSchedulerHook)
+from mmengine.visualization import LocalVisBackend
+
+from mmpretrain.engine.hooks import VisualizationHook
+from mmpretrain.visualization import UniversalVisualizer
+
+# configure default hooks
+default_hooks = dict(
+    # record the time of every iteration.
+    timer=dict(type=IterTimerHook),
+
+    # print log every 100 iterations.
+    logger=dict(type=LoggerHook, interval=100),
+
+    # enable the parameter scheduler.
+    param_scheduler=dict(type=ParamSchedulerHook),
+
+    # save checkpoint per epoch.
+    checkpoint=dict(type=CheckpointHook, interval=1),
+
+    # set sampler seed in distributed environment.
+    sampler_seed=dict(type=DistSamplerSeedHook),
+
+    # validation results visualization; set `enable=True` to turn it on.
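+    # for example (illustrative only, not part of the defaults):
+    # visualization=dict(type=VisualizationHook, enable=True),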
+ visualization=dict(type=VisualizationHook, enable=False), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +vis_backends = [dict(type=LocalVisBackend)] +visualizer = dict(type=UniversalVisualizer, vis_backends=vis_backends) + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# Do not need to specify default_scope with new config. Therefore set it to +# None to avoid BC-breaking. +default_scope = None diff --git a/mmpretrain/configs/_base_/models/convnext_base.py b/mmpretrain/configs/_base_/models/convnext_base.py new file mode 100644 index 0000000000000000000000000000000000000000..6315b2f1966d2484739087e1e131fe8dd9a2ad56 --- /dev/null +++ b/mmpretrain/configs/_base_/models/convnext_base.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.model import TruncNormalInit + +from mmpretrain.models import (ConvNeXt, CutMix, ImageClassifier, + LabelSmoothLoss, LinearClsHead, Mixup) + +# Model settings +model = dict( + type=ImageClassifier, + backbone=dict(type=ConvNeXt, arch='base', drop_path_rate=0.5), + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=1024, + loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'), + init_cfg=None, + ), + init_cfg=dict( + type=TruncNormalInit, layer=['Conv2d', 'Linear'], std=.02, bias=0.), + train_cfg=dict(augments=[ + dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0), + ]), +) diff --git a/mmpretrain/configs/_base_/models/mae_hivit_base_p16.py b/mmpretrain/configs/_base_/models/mae_hivit_base_p16.py new file mode 100644 index 0000000000000000000000000000000000000000..975e16b44626198ced6494b26d707d3501094f3c --- /dev/null +++ b/mmpretrain/configs/_base_/models/mae_hivit_base_p16.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmpretrain.models import (MAE, MAEHiViT, MAEPretrainDecoder, + MAEPretrainHead, PixelReconstructionLoss) + +# model settings +model = dict( + type=MAE, + backbone=dict(type=MAEHiViT, patch_size=16, arch='base', mask_ratio=0.75), + neck=dict( + type=MAEPretrainDecoder, + patch_size=16, + in_chans=3, + embed_dim=512, + decoder_embed_dim=512, + decoder_depth=6, + decoder_num_heads=16, + mlp_ratio=4., + ), + head=dict( + type=MAEPretrainHead, + norm_pix=True, + patch_size=16, + loss=dict(type=PixelReconstructionLoss, criterion='L2')), + init_cfg=[ + dict(type='Xavier', layer='Linear', distribution='uniform'), + dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0) + ]) diff --git a/mmpretrain/configs/_base_/models/mae_vit_base_p16.py b/mmpretrain/configs/_base_/models/mae_vit_base_p16.py new file mode 100644 index 0000000000000000000000000000000000000000..9347d1e8810e553ef5563a96198794ec139ea3a4 --- /dev/null +++ b/mmpretrain/configs/_base_/models/mae_vit_base_p16.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
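+# Orientation: MAE pretraining masks 75% of the patch tokens
+# (mask_ratio=0.75 below); the lightweight decoder then regresses the
+# missing pixels with an L2 reconstruction loss on per-patch-normalized
+# targets (norm_pix=True).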
+from mmpretrain.models import (MAE, MAEPretrainDecoder, MAEPretrainHead, + MAEViT, PixelReconstructionLoss) + +# model settings +model = dict( + type=MAE, + backbone=dict(type=MAEViT, arch='b', patch_size=16, mask_ratio=0.75), + neck=dict( + type=MAEPretrainDecoder, + patch_size=16, + in_chans=3, + embed_dim=768, + decoder_embed_dim=512, + decoder_depth=8, + decoder_num_heads=16, + mlp_ratio=4., + ), + head=dict( + type=MAEPretrainHead, + norm_pix=True, + patch_size=16, + loss=dict(type=PixelReconstructionLoss, criterion='L2')), + init_cfg=[ + dict(type='Xavier', layer='Linear', distribution='uniform'), + dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0) + ]) diff --git a/mmpretrain/configs/_base_/models/mobilenet_v2_1x.py b/mmpretrain/configs/_base_/models/mobilenet_v2_1x.py new file mode 100644 index 0000000000000000000000000000000000000000..17dbb9fdd88c26767c1be7faeb0689be597626df --- /dev/null +++ b/mmpretrain/configs/_base_/models/mobilenet_v2_1x.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling, + ImageClassifier, LinearClsHead, MobileNetV2) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict(type=MobileNetV2, widen_factor=1.0), + neck=dict(type=GlobalAveragePooling), + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=1280, + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5), + )) diff --git a/mmpretrain/configs/_base_/models/mobilenet_v3_small.py b/mmpretrain/configs/_base_/models/mobilenet_v3_small.py new file mode 100644 index 0000000000000000000000000000000000000000..83edab592063113b86965dfb173fca7eb6f630cd --- /dev/null +++ b/mmpretrain/configs/_base_/models/mobilenet_v3_small.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.model.weight_init import NormalInit +from torch.nn.modules.activation import Hardswish + +from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling, + ImageClassifier, MobileNetV3, + StackedLinearClsHead) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict(type=MobileNetV3, arch='small'), + neck=dict(type=GlobalAveragePooling), + head=dict( + type=StackedLinearClsHead, + num_classes=1000, + in_channels=576, + mid_channels=[1024], + dropout_rate=0.2, + act_cfg=dict(type=Hardswish), + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + init_cfg=dict( + type=NormalInit, layer='Linear', mean=0., std=0.01, bias=0.), + topk=(1, 5))) diff --git a/mmpretrain/configs/_base_/models/resnet18.py b/mmpretrain/configs/_base_/models/resnet18.py new file mode 100644 index 0000000000000000000000000000000000000000..30b8f65148611c5602858b875b9be89b31f225cb --- /dev/null +++ b/mmpretrain/configs/_base_/models/resnet18.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling, + ImageClassifier, LinearClsHead, ResNet) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict( + type=ResNet, + depth=18, + num_stages=4, + out_indices=(3, ), + style='pytorch'), + neck=dict(type=GlobalAveragePooling), + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=512, + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5), + )) diff --git a/mmpretrain/configs/_base_/models/swin_transformer_base.py b/mmpretrain/configs/_base_/models/swin_transformer_base.py new file mode 100644 index 0000000000000000000000000000000000000000..c73c254d7a8af9524091ead1d61d1320541d3c5e --- /dev/null +++ b/mmpretrain/configs/_base_/models/swin_transformer_base.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling, + ImageClassifier, LinearClsHead, SwinTransformer) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict( + type=SwinTransformer, + arch='base', + img_size=384, + stage_cfgs=dict(block_cfgs=dict(window_size=12))), + neck=dict(type=GlobalAveragePooling), + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=1024, + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5))) diff --git a/mmpretrain/configs/_base_/models/swin_transformer_v2_base.py b/mmpretrain/configs/_base_/models/swin_transformer_v2_base.py new file mode 100644 index 0000000000000000000000000000000000000000..c7566b5e1a08c7c8c9bf4afee57d973b6801d6c3 --- /dev/null +++ b/mmpretrain/configs/_base_/models/swin_transformer_v2_base.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmpretrain.models import (GlobalAveragePooling, ImageClassifier, + LabelSmoothLoss, LinearClsHead, + SwinTransformerV2) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict( + type=SwinTransformerV2, arch='base', img_size=384, drop_path_rate=0.2), + neck=dict(type=GlobalAveragePooling), + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=1024, + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'), + cal_acc=False)) diff --git a/mmpretrain/configs/_base_/models/vit_base_p16.py b/mmpretrain/configs/_base_/models/vit_base_p16.py new file mode 100644 index 0000000000000000000000000000000000000000..326c50aea815d023acf2efb12718cd847677fc60 --- /dev/null +++ b/mmpretrain/configs/_base_/models/vit_base_p16.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.model.weight_init import KaimingInit + +from mmpretrain.models import (ImageClassifier, LabelSmoothLoss, + VisionTransformer, VisionTransformerClsHead) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict( + type=VisionTransformer, + arch='b', + img_size=224, + patch_size=16, + drop_rate=0.1, + init_cfg=[ + dict( + type=KaimingInit, + layer='Conv2d', + mode='fan_in', + nonlinearity='linear') + ]), + neck=None, + head=dict( + type=VisionTransformerClsHead, + num_classes=1000, + in_channels=768, + loss=dict( + type=LabelSmoothLoss, label_smooth_val=0.1, mode='classy_vision'), + )) diff --git a/mmpretrain/configs/_base_/schedules/cifar10_bs128.py b/mmpretrain/configs/_base_/schedules/cifar10_bs128.py new file mode 100644 index 0000000000000000000000000000000000000000..8ab749e8b648aca8a50fc7775281330fa1ce2a2b --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/cifar10_bs128.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.optim import MultiStepLR +from torch.optim import SGD + +# optimizer +optim_wrapper = dict( + optimizer=dict(type=SGD, lr=0.1, momentum=0.9, weight_decay=0.0001)) +# learning policy +param_scheduler = dict( + type=MultiStepLR, by_epoch=True, milestones=[100, 150], gamma=0.1) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=200, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=128) diff --git a/mmpretrain/configs/_base_/schedules/cub_bs64.py b/mmpretrain/configs/_base_/schedules/cub_bs64.py new file mode 100644 index 0000000000000000000000000000000000000000..2ca40bfe36efdd315f63ca872bcebb5247747f26 --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/cub_bs64.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.optim import CosineAnnealingLR, LinearLR +from torch.optim import SGD + +# optimizer +optim_wrapper = dict( + optimizer=dict( + type=SGD, lr=0.01, momentum=0.9, weight_decay=0.0005, nesterov=True)) + +# learning policy +param_scheduler = [ + # warm up learning rate scheduler + dict( + type=LinearLR, + start_factor=0.01, + by_epoch=True, + begin=0, + end=5, + # update by iter + convert_to_iter_based=True), + # main learning rate scheduler + dict( + type=CosineAnnealingLR, + T_max=95, + by_epoch=True, + begin=5, + end=100, + ) +] + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=64) diff --git a/mmpretrain/configs/_base_/schedules/imagenet_bs1024_adamw_swin.py b/mmpretrain/configs/_base_/schedules/imagenet_bs1024_adamw_swin.py new file mode 100644 index 0000000000000000000000000000000000000000..60ccaa0e25ec69aa618430f51a60d949506fc406 --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/imagenet_bs1024_adamw_swin.py @@ -0,0 +1,46 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.optim import CosineAnnealingLR, LinearLR
+from torch.optim import AdamW
+
+# batch size per GPU is 128, with 8 GPUs in total:
+# lr = 5e-4 * 128 * 8 / 512 = 0.001
+optim_wrapper = dict(
+    optimizer=dict(
+        type=AdamW,
+        lr=5e-4 * 1024 / 512,
+        weight_decay=0.05,
+        eps=1e-8,
+        betas=(0.9, 0.999)),
+    paramwise_cfg=dict(
+        norm_decay_mult=0.0,
+        bias_decay_mult=0.0,
+        flat_decay_mult=0.0,
+        custom_keys={
+            '.absolute_pos_embed': dict(decay_mult=0.0),
+            '.relative_position_bias_table': dict(decay_mult=0.0)
+        }),
+)
+
+# learning policy
+param_scheduler = [
+    # warm up learning rate scheduler
+    dict(
+        type=LinearLR,
+        start_factor=1e-3,
+        by_epoch=True,
+        end=20,
+        # update by iter
+        convert_to_iter_based=True),
+    # main learning rate scheduler
+    dict(type=CosineAnnealingLR, eta_min=1e-5, by_epoch=True, begin=20)
+]
+
+# train, val, test setting
+train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
+val_cfg = dict()
+test_cfg = dict()
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR,
+# based on the actual training batch size.
+auto_scale_lr = dict(base_batch_size=1024)
diff --git a/mmpretrain/configs/_base_/schedules/imagenet_bs256.py b/mmpretrain/configs/_base_/schedules/imagenet_bs256.py
new file mode 100644
index 0000000000000000000000000000000000000000..95afa2ad292c277a84aa274786ee34a9d6b8b0ef
--- /dev/null
+++ b/mmpretrain/configs/_base_/schedules/imagenet_bs256.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change recently.
+from mmengine.optim import MultiStepLR
+from torch.optim import SGD
+
+# optimizer
+optim_wrapper = dict(
+    optimizer=dict(type=SGD, lr=0.1, momentum=0.9, weight_decay=0.0001))
+
+# learning policy
+param_scheduler = dict(
+    type=MultiStepLR, by_epoch=True, milestones=[30, 60, 90], gamma=0.1)
+
+# train, val, test setting
+train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
+val_cfg = dict()
+test_cfg = dict()
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR,
+# based on the actual training batch size.
+auto_scale_lr = dict(base_batch_size=256)
diff --git a/mmpretrain/configs/_base_/schedules/imagenet_bs256_epochstep.py b/mmpretrain/configs/_base_/schedules/imagenet_bs256_epochstep.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d245ebb9c35345c457e502f350b756ead181ffe
--- /dev/null
+++ b/mmpretrain/configs/_base_/schedules/imagenet_bs256_epochstep.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change recently.
+from mmengine.optim import StepLR
+from torch.optim import SGD
+
+# optimizer
+optim_wrapper = dict(
+    optimizer=dict(type=SGD, lr=0.045, momentum=0.9, weight_decay=0.00004))
+
+# learning policy
+param_scheduler = dict(type=StepLR, by_epoch=True, step_size=1, gamma=0.98)
+
+# train, val, test setting
+train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
+val_cfg = dict()
+test_cfg = dict()
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR,
+# based on the actual training batch size.
+auto_scale_lr = dict(base_batch_size=256)
diff --git a/mmpretrain/configs/_base_/schedules/imagenet_bs4096_adamw.py b/mmpretrain/configs/_base_/schedules/imagenet_bs4096_adamw.py
new file mode 100644
index 0000000000000000000000000000000000000000..4561f23db1b7b546fd9667ef51aed81dd9e6d4a7
--- /dev/null
+++ b/mmpretrain/configs/_base_/schedules/imagenet_bs4096_adamw.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. 
All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.optim import CosineAnnealingLR, LinearLR +from torch.optim import AdamW + +# optimizer +optim_wrapper = dict( + optimizer=dict(type=AdamW, lr=0.003, weight_decay=0.3), + # specific to vit pretrain + paramwise_cfg=dict(custom_keys={ + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0) + }), +) + +# learning policy +param_scheduler = [ + # warm up learning rate scheduler + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=30, + # update by iter + convert_to_iter_based=True), + # main learning rate scheduler + dict( + type=CosineAnnealingLR, + T_max=270, + by_epoch=True, + begin=30, + end=300, + ) +] + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/_base_/schedules/imagenet_lars_coslr_200e.py b/mmpretrain/configs/_base_/schedules/imagenet_lars_coslr_200e.py new file mode 100644 index 0000000000000000000000000000000000000000..0c7e6171e2aeb20a94277e7ca4d02b2598d73b8e --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/imagenet_lars_coslr_200e.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop + +from mmpretrain.engine.optimizers.lars import LARS + +# optimizer wrapper +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict(type=LARS, lr=4.8, weight_decay=1e-6, momentum=0.9)) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict(type=CosineAnnealingLR, T_max=190, by_epoch=True, begin=10, end=200) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=200) diff --git a/mmpretrain/configs/beit/beit_beit_base_p16_8xb256_amp_coslr_300e_in1k.py b/mmpretrain/configs/beit/beit_beit_base_p16_8xb256_amp_coslr_300e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..fe9c329abba18ecb5bad1875090f7de667a77391 --- /dev/null +++ b/mmpretrain/configs/beit/beit_beit_base_p16_8xb256_amp_coslr_300e_in1k.py @@ -0,0 +1,146 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
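+# Orientation: BEiT v1 pretraining predicts discrete visual tokens produced
+# by the frozen DALL-E encoder configured as `target_generator` below,
+# rather than reconstructing raw pixels.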
+from mmengine.config import read_base + +with read_base(): + from .._base_.default_runtime import * + +from mmengine.dataset import DefaultSampler, default_collate +from mmengine.hooks import CheckpointHook +from mmengine.model import ConstantInit, PretrainedInit, TruncNormalInit +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from mmengine.runner import EpochBasedTrainLoop +from torch.optim import AdamW + +from mmpretrain.datasets import (BEiTMaskGenerator, ColorJitter, ImageNet, + LoadImageFromFile, PackInputs, RandomFlip, + RandomResizedCropAndInterpolationWithTwoPic) +from mmpretrain.models import (BEiT, BEiTPretrainViT, BEiTV1Head, + CrossEntropyLoss, DALLEEncoder, + TwoNormDataPreprocessor) + +# dataset settings +dataset_type = ImageNet +data_root = 'data/imagenet/' +data_preprocessor = dict( + type=TwoNormDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + second_mean=[-31.875, -31.875, -31.875], + second_std=[318.75, 318.75, 318.75], + to_rgb=True) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ColorJitter, brightness=0.4, contrast=0.4, saturation=0.4, + hue=0.), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=RandomResizedCropAndInterpolationWithTwoPic, + size=224, + second_size=112, + interpolation='bicubic', + second_interpolation='lanczos', + scale=(0.08, 1.0)), + dict( + type=BEiTMaskGenerator, + input_size=(14, 14), + num_masking_patches=75, + max_num_patches=None, + min_num_patches=16), + dict(type=PackInputs) +] +train_dataloader = dict( + batch_size=256, + num_workers=8, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=default_collate), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='meta/train.txt', + data_prefix=dict(img_path='train/'), + pipeline=train_pipeline)) + +# model settings +model = dict( + type=BEiT, + backbone=dict( + type=BEiTPretrainViT, + arch='base', + patch_size=16, + drop_path_rate=0.1, + final_norm=True, + out_type='raw', + layer_scale_init_value=0.1, + init_cfg=[ + dict(type=TruncNormalInit, std=0.02, layer='Linear'), + dict(type=TruncNormalInit, std=0.02, layer='Conv2d'), + dict(type=ConstantInit, layer='LayerNorm', val=1.0, bias=0.0) + ]), + neck=None, + head=dict( + type=BEiTV1Head, + embed_dims=768, + num_embed=8192, + loss=dict(type=CrossEntropyLoss)), + target_generator=dict( + type=DALLEEncoder, + init_cfg=dict( + type=PretrainedInit, + checkpoint= # noqa: E251 + 'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/dalle_encoder.pth', # noqa: E501 + ))) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, lr=1.5e-3, betas=(0.9, 0.999), weight_decay=0.05), + clip_grad=dict(max_norm=3.0), + paramwise_cfg=dict( + custom_keys={ + # the following configurations are designed for BEiT + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + 'q_bias': dict(decay_mult=0.0), + 'v_bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0), + '.gamma': dict(decay_mult=0.0), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=1e-5, + by_epoch=True, + begin=10, + end=300, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, 
max_epochs=300) +default_hooks.update( + # only keeps the latest 3 checkpoints + checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=3)) + +randomness.update(seed=0, diff_rank_seed=True) + +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=2048) diff --git a/mmpretrain/configs/beit/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py b/mmpretrain/configs/beit/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..00a76b75e1581294d4c5fe885c1cf3d8b1caff8e --- /dev/null +++ b/mmpretrain/configs/beit/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py @@ -0,0 +1,139 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from ..._base_.datasets.imagenet_bs64_swin_224 import * + from ..._base_.schedules.imagenet_bs1024_adamw_swin import * + from ..._base_.default_runtime import * + +from mmengine.hooks import CheckpointHook +from mmengine.model import PretrainedInit, TruncNormalInit +from mmengine.optim import CosineAnnealingLR, LinearLR +from torch.optim import AdamW + +from mmpretrain.datasets import LoadImageFromFile, PackInputs, RandomFlip +from mmpretrain.engine.optimizers import \ + LearningRateDecayOptimWrapperConstructor +from mmpretrain.models import (BEiTViT, ImageClassifier, LabelSmoothLoss, + LinearClsHead) +from mmpretrain.models.utils.batch_augments import CutMix, Mixup + +data_preprocessor = dict( + num_classes=1000, + mean=[127.5, 127.5, 127.5], + std=[127.5, 127.5, 127.5], + to_rgb=True, +) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict( + type=BEiTViT, + arch='base', + img_size=224, + patch_size=16, + drop_path_rate=0.1, + out_type='avg_featmap', + use_abs_pos_emb=False, + use_rel_pos_bias=True, + use_shared_rel_pos_bias=False, + init_cfg=dict(type=PretrainedInit, checkpoint='', prefix='backbone.')), + neck=None, + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=768, + loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'), + init_cfg=[dict(type=TruncNormalInit, layer='Linear', std=0.02)]), + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=RandAugment, + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')), + dict( + type=RandomErasing, + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=0.3333333333333333, + fill_color=[103.53, 116.28, 123.675], + fill_std=[57.375, 57.12, 58.395]), + dict(type=PackInputs) +] +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs) +] + +train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader + +# optimizer wrapper +optim_wrapper = dict( + optimizer=dict(type=AdamW, 
lr=4e-3, weight_decay=0.05, betas=(0.9, 0.999)), + constructor=LearningRateDecayOptimWrapperConstructor, + paramwise_cfg=dict( + _delete_=True, + layer_decay_rate=0.65, + custom_keys={ + # the following configurations are designed for BEiT + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + 'q_bias': dict(decay_mult=0.0), + 'v_bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0), + '.gamma': dict(decay_mult=0.0), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=20, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + by_epoch=True, + begin=20, + end=100, + eta_min=1e-6, + convert_to_iter_based=True) +] + +# runtime settings +default_hooks = dict( + # save checkpoint per epoch. + checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=2)) + +train_cfg = dict(by_epoch=True, max_epochs=100) + +randomness = dict(seed=0) diff --git a/mmpretrain/configs/beit/benchmarks/beit-base-p16_8xb64_in1k.py b/mmpretrain/configs/beit/benchmarks/beit-base-p16_8xb64_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..b4718afbc80496e5129bcf3424a815248793f47a --- /dev/null +++ b/mmpretrain/configs/beit/benchmarks/beit-base-p16_8xb64_in1k.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from ..._base_.datasets.imagenet_bs64_swin_224 import * + from ..._base_.schedules.imagenet_bs1024_adamw_swin import * + from ..._base_.default_runtime import * + +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import (BEiTViT, ImageClassifier, LabelSmoothLoss, + LinearClsHead) +from mmpretrain.models.utils.batch_augments import CutMix, Mixup + +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[127.5, 127.5, 127.5], + std=[127.5, 127.5, 127.5], + # convert image from BGR to RGB + to_rgb=True, +) + +model = dict( + type=ImageClassifier, + backbone=dict( + type=BEiTViT, + arch='base', + img_size=224, + patch_size=16, + out_type='avg_featmap', + use_abs_pos_emb=False, + use_rel_pos_bias=True, + use_shared_rel_pos_bias=False, + ), + neck=None, + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=768, + loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'), + ), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=.02), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.), + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-1600e_in1k.py b/mmpretrain/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-1600e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..6bec16b342d25f69676b26703798f4dbe4a55899 --- /dev/null +++ b/mmpretrain/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-1600e_in1k.py @@ -0,0 +1,130 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
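+# Orientation: BEiTv2 swaps BEiT v1's DALL-E tokenizer for a VQ-KD target
+# generator (`type=VQKD` below), and its neck taps both an early and the
+# final backbone layer (out_indices=[-4, -1]).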
+from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs256_beitv2 import * + from .._base_.default_runtime import * + +from mmengine.model import ConstantInit, PretrainedInit, TruncNormalInit +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from mmengine.runner import EpochBasedTrainLoop +from torch.optim import AdamW + +from mmpretrain.models import (VQKD, BEiT, BEiTPretrainViT, BEiTV2Head, + BEiTV2Neck, CrossEntropyLoss) + +vqkd_encoder = dict( + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=True, + out_type='featmap', + with_cls_token=True, + frozen_stages=-1, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + use_shared_rel_pos_bias=False, + layer_scale_init_value=0., + interpolate_mode='bicubic', + patch_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None) + +layer_scale_init_value = 0.1 +drop_path_rate = 0.1 # 0. for 300 epochs and 0.1 for 1600 epochs. + +model = dict( + type=BEiT, + backbone=dict( + type=BEiTPretrainViT, + arch='base', + patch_size=16, + out_indices=[-4, -1], + drop_path_rate=drop_path_rate, + final_norm=False, + out_type='raw', + layer_scale_init_value=layer_scale_init_value, + init_cfg=[ + dict(type=TruncNormalInit, std=0.02, layer='Linear'), + dict(type=TruncNormalInit, std=0.02, layer='Conv2d'), + dict(type=ConstantInit, layer='LayerNorm', val=1.0, bias=0.0) + ]), + neck=dict( + type=BEiTV2Neck, + num_layers=2, + early_layers=9, + backbone_arch='base', + drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, + ), + head=dict( + type=BEiTV2Head, + embed_dims=768, + num_embed=8192, + loss=dict(type=CrossEntropyLoss)), + target_generator=dict( + type=VQKD, + encoder_config=vqkd_encoder, + init_cfg=dict( + type=PretrainedInit, + checkpoint= # noqa + 'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/vqkd_encoder.pth' # noqa + ))) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + # betas: (0.9, 0.98) for 300 epochs and (0.9, 0.999) for 1600 epochs. + optimizer=dict( + type=AdamW, lr=1.5e-3, betas=(0.9, 0.999), weight_decay=0.05), + clip_grad=dict(max_norm=3.0), + paramwise_cfg=dict( + custom_keys={ + # the following configurations are designed for BEiT + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + 'q_bias': dict(decay_mult=0.0), + 'v_bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0), + '.gamma': dict(decay_mult=0.0), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=1e-5, + by_epoch=True, + begin=10, + end=1600, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=1600) +default_hooks = dict( + # only keeps the latest 3 checkpoints + checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=3)) + +randomness = dict(seed=0, diff_rank_seed=True) + +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. 
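+# Worked example: this file's 8 x 256 layout gives an actual batch size of
+# 2048, equal to `base_batch_size`, so the LR above is used as-is; at, say,
+# 4 GPUs x 256 the linear rule would halve it (when LR auto-scaling is
+# enabled at launch).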
+auto_scale_lr = dict(base_batch_size=2048) diff --git a/mmpretrain/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-300e_in1k.py b/mmpretrain/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-300e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..3fe9b503c9be6a66496fc4c711cc5306e4dfdcd8 --- /dev/null +++ b/mmpretrain/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-300e_in1k.py @@ -0,0 +1,130 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs256_beitv2 import * + from .._base_.default_runtime import * + +from mmengine.model import ConstantInit, PretrainedInit, TruncNormalInit +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from mmengine.runner import EpochBasedTrainLoop +from torch.optim import AdamW + +from mmpretrain.models import (VQKD, BEiT, BEiTPretrainViT, BEiTV2Head, + BEiTV2Neck, CrossEntropyLoss) + +# model settings +vqkd_encoder = dict( + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=True, + out_type='featmap', + with_cls_token=True, + frozen_stages=-1, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + use_shared_rel_pos_bias=False, + layer_scale_init_value=0., + interpolate_mode='bicubic', + patch_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None) + +layer_scale_init_value = 0.1 +drop_path_rate = 0. # 0. for 300 epochs and 0.1 for 1600 epochs. +model = dict( + type=BEiT, + backbone=dict( + type=BEiTPretrainViT, + arch='base', + patch_size=16, + out_indices=[-4, -1], + drop_path_rate=drop_path_rate, + final_norm=False, + out_type='raw', + layer_scale_init_value=layer_scale_init_value, + init_cfg=[ + dict(type=TruncNormalInit, std=0.02, layer='Linear'), + dict(type=TruncNormalInit, std=0.02, layer='Conv2d'), + dict(type=ConstantInit, layer='LayerNorm', val=1.0, bias=0.0) + ]), + neck=dict( + type=BEiTV2Neck, + num_layers=2, + early_layers=9, + backbone_arch='base', + drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, + ), + head=dict( + type=BEiTV2Head, + embed_dims=768, + num_embed=8192, + loss=dict(type=CrossEntropyLoss)), + target_generator=dict( + type=VQKD, + encoder_config=vqkd_encoder, + init_cfg=dict( + type=PretrainedInit, + checkpoint= # noqa + 'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/vqkd_encoder.pth' # noqa + ))) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + # betas: (0.9, 0.98) for 300 epochs and (0.9, 0.999) for 1600 epochs. 
+ optimizer=dict( + type=AdamW, lr=1.5e-3, betas=(0.9, 0.98), weight_decay=0.05), + clip_grad=dict(max_norm=3.0), + paramwise_cfg=dict( + custom_keys={ + # the following configurations are designed for BEiT + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + 'q_bias': dict(decay_mult=0.0), + 'v_bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0), + '.gamma': dict(decay_mult=0.0), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=1e-5, + by_epoch=True, + begin=10, + end=300, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=300) +default_hooks = dict( + # only keeps the latest 3 checkpoints + checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=3)) + +randomness = dict(seed=0, diff_rank_seed=True) + +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=2048) diff --git a/mmpretrain/configs/beitv2/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py b/mmpretrain/configs/beitv2/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..ee32d3a98ed7da2ac60525a75d82f967df151fc7 --- /dev/null +++ b/mmpretrain/configs/beitv2/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from ..._base_.datasets.imagenet_bs64_swin_224 import * + from ..._base_.schedules.imagenet_bs1024_adamw_swin import * + from ..._base_.default_runtime import * + +from mmengine.model import PretrainedInit, TruncNormalInit +from mmengine.optim import CosineAnnealingLR, LinearLR +from torch.optim import AdamW + +from mmpretrain.engine.optimizers import \ + LearningRateDecayOptimWrapperConstructor +from mmpretrain.models import (BEiTViT, ImageClassifier, LabelSmoothLoss, + LinearClsHead) +from mmpretrain.models.utils.batch_augments import CutMix, Mixup + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict( + type=BEiTViT, + arch='base', + img_size=224, + patch_size=16, + # 0.2 for 1600 epochs pretrained models and 0.1 for 300 epochs. 
+ drop_path_rate=0.1, + out_type='avg_featmap', + use_abs_pos_emb=False, + use_rel_pos_bias=True, + use_shared_rel_pos_bias=False, + init_cfg=dict(type=PretrainedInit, checkpoint='', prefix='backbone.')), + neck=None, + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=768, + loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'), + init_cfg=[dict(type=TruncNormalInit, layer='Linear', std=0.02)]), + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=RandAugment, + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')), + dict( + type=RandomErasing, + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=0.3333333333333333, + fill_color=[103.53, 116.28, 123.675], + fill_std=[57.375, 57.12, 58.395]), + dict(type=PackInputs) +] +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs) +] + +train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader + +# optimizer wrapper +optim_wrapper = dict( + optimizer=dict(type=AdamW, lr=5e-4, weight_decay=0.05, betas=(0.9, 0.999)), + constructor=LearningRateDecayOptimWrapperConstructor, + paramwise_cfg=dict( + _delete_=True, + # 0.6 for 1600 epochs pretrained models and 0.65 for 300 epochs + layer_decay_rate=0.65, + custom_keys={ + # the following configurations are designed for BEiT + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + 'q_bias': dict(decay_mult=0.0), + 'v_bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0), + '.gamma': dict(decay_mult=0.0), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=20, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + by_epoch=True, + begin=20, + end=100, + eta_min=1e-6, + convert_to_iter_based=True) +] + +# runtime settings +default_hooks = dict( + # save checkpoint per epoch. + checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=2)) + +train_cfg = dict(by_epoch=True, max_epochs=100) + +randomness = dict(seed=0) diff --git a/mmpretrain/configs/beitv2/benchmarks/beit-base-p16_8xb64_in1k.py b/mmpretrain/configs/beitv2/benchmarks/beit-base-p16_8xb64_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..ec20ba950519c790caed99fec04fd28781d6a0d1 --- /dev/null +++ b/mmpretrain/configs/beitv2/benchmarks/beit-base-p16_8xb64_in1k.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.config import read_base + +with read_base(): + from ..._base_.datasets.imagenet_bs64_swin_224 import * + from ..._base_.schedules.imagenet_bs1024_adamw_swin import * + from ..._base_.default_runtime import * + +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import (BEiTViT, ImageClassifier, LabelSmoothLoss, + LinearClsHead) +from mmpretrain.models.utils.batch_augments.cutmix import CutMix +from mmpretrain.models.utils.batch_augments.mixup import Mixup + +model = dict( + type=ImageClassifier, + backbone=dict( + type=BEiTViT, + arch='base', + img_size=224, + patch_size=16, + out_type='avg_featmap', + use_abs_pos_emb=False, + use_rel_pos_bias=True, + use_shared_rel_pos_bias=False, + ), + neck=None, + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=768, + loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'), + ), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=.02), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.), + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/convnext/convnext-base_32xb128_in1k.py b/mmpretrain/configs/convnext/convnext-base_32xb128_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..3e8a10f020cb72bf0756bf0a3661a759c28e30de --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-base_32xb128_in1k.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +from mmpretrain.engine import EMAHook + +# dataset setting +train_dataloader.update(batch_size=128) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=None, +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=4e-5, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-base_32xb128_in21k.py b/mmpretrain/configs/convnext/convnext-base_32xb128_in21k.py new file mode 100644 index 0000000000000000000000000000000000000000..73fb0a0af1cf4d5eae41115cb1f55d3cee2bad5c --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-base_32xb128_in21k.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet21k_bs128 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model setting +model.update(head=dict(num_classes=21841)) + +# dataset setting +data_preprocessor.update(num_classes=21841) +train_dataloader.update(batch_size=128) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=dict(max_norm=5.0), +) + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. 
+# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-large_64xb64_in1k-384px.py b/mmpretrain/configs/convnext/convnext-large_64xb64_in1k-384px.py new file mode 100644 index 0000000000000000000000000000000000000000..2da428a5ad9816d36f97483c389cc4d5d9acad78 --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-large_64xb64_in1k-384px.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model setting (override the ConvNeXt-base arch inherited from the base file) +model.update(backbone=dict(arch='large'), head=dict(in_channels=1536)) + +# dataset setting +train_dataloader.update(batch_size=64) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=dict(max_norm=5.0), +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=4e-5, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (64 GPUs) x (64 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-large_64xb64_in1k.py b/mmpretrain/configs/convnext/convnext-large_64xb64_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..e11e6a9f9091e19f2aa9719e85ef7a4fed793b4b --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-large_64xb64_in1k.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model setting (override the ConvNeXt-base arch inherited from the base file) +model.update(backbone=dict(arch='large'), head=dict(in_channels=1536)) + +# dataset setting +train_dataloader.update(batch_size=64) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=None, +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=1e-4, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (64 GPUs) x (64 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-large_64xb64_in21k.py b/mmpretrain/configs/convnext/convnext-large_64xb64_in21k.py new file mode 100644 index 0000000000000000000000000000000000000000..d103dfa7678944d2d5e2248c75a8356b7d77a7dd --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-large_64xb64_in21k.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet21k_bs128 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model setting +model.update(backbone=dict(arch='large'), head=dict(num_classes=21841, in_channels=1536)) + +# dataset setting +data_preprocessor.update(num_classes=21841) +train_dataloader.update(batch_size=64) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=dict(max_norm=5.0), +) + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size.
+# base_batch_size = (64 GPUs) x (64 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-small_32xb128_in1k-384px.py b/mmpretrain/configs/convnext/convnext-small_32xb128_in1k-384px.py new file mode 100644 index 0000000000000000000000000000000000000000..9b7bce73f8e78017010de9a5525e6f9da919cbab --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-small_32xb128_in1k-384px.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model setting (override the ConvNeXt-base arch inherited from the base file) +model.update(backbone=dict(arch='small', drop_path_rate=0.4), head=dict(in_channels=768)) + +# dataset setting +train_dataloader.update(batch_size=128) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=dict(max_norm=5.0), +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=4e-5, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-small_32xb128_in1k.py b/mmpretrain/configs/convnext/convnext-small_32xb128_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..bd43ec16488573a339c3536e9ef5c7a77dc0df6a --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-small_32xb128_in1k.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model setting (override the ConvNeXt-base arch inherited from the base file) +model.update(backbone=dict(arch='small', drop_path_rate=0.4), head=dict(in_channels=768)) + +# dataset setting +train_dataloader.update(batch_size=128) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=None, +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=1e-4, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-tiny_32xb128_in1k-384px.py b/mmpretrain/configs/convnext/convnext-tiny_32xb128_in1k-384px.py new file mode 100644 index 0000000000000000000000000000000000000000..9b7bce73f8e78017010de9a5525e6f9da919cbab --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-tiny_32xb128_in1k-384px.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model setting (override the ConvNeXt-base arch inherited from the base file) +model.update(backbone=dict(arch='tiny', drop_path_rate=0.1), head=dict(in_channels=768)) + +# dataset setting +train_dataloader.update(batch_size=128) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=dict(max_norm=5.0), +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=4e-5, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size.
+# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-tiny_32xb128_in1k.py b/mmpretrain/configs/convnext/convnext-tiny_32xb128_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..bd43ec16488573a339c3536e9ef5c7a77dc0df6a --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-tiny_32xb128_in1k.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model setting (override the ConvNeXt-base arch inherited from the base file) +model.update(backbone=dict(arch='tiny', drop_path_rate=0.1), head=dict(in_channels=768)) + +# dataset setting +train_dataloader.update(batch_size=128) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=None, +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=1e-4, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in1k-384px.py b/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in1k-384px.py new file mode 100644 index 0000000000000000000000000000000000000000..2da428a5ad9816d36f97483c389cc4d5d9acad78 --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in1k-384px.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model setting (override the ConvNeXt-base arch inherited from the base file) +model.update(backbone=dict(arch='xlarge'), head=dict(in_channels=2048)) + +# dataset setting +train_dataloader.update(batch_size=64) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=dict(max_norm=5.0), +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=4e-5, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (64 GPUs) x (64 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in1k.py b/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb1157effac74299861a23b79df67c132dd276a --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in1k.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model setting (override the ConvNeXt-base arch inherited from the base file) +model.update(backbone=dict(arch='xlarge'), head=dict(in_channels=2048)) + +# dataset setting +train_dataloader.update(batch_size=64) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=None, +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=1e-4, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size.
+# base_batch_size = (64 GPUs) x (64 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in21k.py b/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in21k.py new file mode 100644 index 0000000000000000000000000000000000000000..21f10dcd605b5792caf612540f0ff3233c7675e0 --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in21k.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet21k_bs128 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model setting +model.update(backbone=dict(arch='xlarge'), head=dict(num_classes=21841, in_channels=2048)) + +# dataset setting +data_preprocessor.update(num_classes=21841) +train_dataloader.update(batch_size=64) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=dict(max_norm=5.0), +) + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (64 GPUs) x (64 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext_base_32xb128_in1k_384px.py b/mmpretrain/configs/convnext/convnext_base_32xb128_in1k_384px.py new file mode 100644 index 0000000000000000000000000000000000000000..6d90e7107a3f189eab22d866d4db3f3a1b5e06bf --- /dev/null +++ b/mmpretrain/configs/convnext/convnext_base_32xb128_in1k_384px.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +from mmpretrain.engine import EMAHook + +# dataset setting +train_dataloader.update(batch_size=128) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=dict(max_norm=5.0), +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=4e-5, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/eva/eva_mae_style_vit_base_p16_16xb256_coslr_400e_in1k.py b/mmpretrain/configs/eva/eva_mae_style_vit_base_p16_16xb256_coslr_400e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..a254ac8a84d94acdd1ec5f84059c7e75abc3cbc4 --- /dev/null +++ b/mmpretrain/configs/eva/eva_mae_style_vit_base_p16_16xb256_coslr_400e_in1k.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently.
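# NOTE (illustration): every ConvNeXt recipe above registers an EMAHook.
# mmengine documents `momentum` as the weight given to the *incoming* value
# (shadow = (1 - momentum) * shadow + momentum * online), so momentum=4e-5
# keeps a very long-horizon average of the weights. A minimal plain-PyTorch
# sketch of that update rule (illustrative only, not the mmengine code path):
import torch

@torch.no_grad()
def ema_update(shadow_params, online_params, momentum=4e-5):
    # one EMA step; with a tiny momentum the shadow weights move slowly
    for s, o in zip(shadow_params, online_params):
        s.mul_(1 - momentum).add_(o, alpha=momentum)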
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks import CheckpointHook +from mmengine.optim import CosineAnnealingLR, LinearLR, OptimWrapper +from mmengine.runner import EpochBasedTrainLoop +from torch.optim import AdamW + +from mmpretrain.models import (EVA, CLIPGenerator, CosineSimilarityLoss, + MAEPretrainDecoder, MIMHead) + +# dataset settings +train_dataloader.batch_size = 256 + +# model settings +model.type = EVA +model.init_cfg = None +model.backbone.update(init_cfg=[ + dict(type='Xavier', distribution='uniform', layer='Linear'), + dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0) +]) +model.neck.update( + type=MAEPretrainDecoder, + predict_feature_dim=512, + init_cfg=[ + dict(type='Xavier', distribution='uniform', layer='Linear'), + dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0) + ]) +model.head = dict( + type=MIMHead, + loss=dict(type=CosineSimilarityLoss, shift_factor=2.0, scale_factor=2.0)) +model.target_generator = dict( + type=CLIPGenerator, + tokenizer_path= # noqa + 'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/clip_vit_base_16.pth.tar' # noqa +) + +# optimizer wrapper +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) +find_unused_parameters = True + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=360, + by_epoch=True, + begin=40, + end=400, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=400) +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(dict(seed=0, diff_rank_seed=True)) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_1600e_in1k.py b/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_1600e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..a32cb0c2e856784dd900ffbc4ef8ad674a4eb4d9 --- /dev/null +++ b/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_1600e_in1k.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
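# NOTE (illustration): the EVA schedule above composes two schedulers over
# disjoint epoch ranges: LinearLR ramps the lr from start_factor * base_lr up
# to base_lr over epochs [0, 40), then CosineAnnealingLR decays it toward 0
# over [40, 400). The base lr itself, 1.5e-4 * 4096 / 256, is the linear
# scaling rule written out: a 256-image reference lr rescaled to an effective
# batch of 4096. A rough sketch of the resulting curve (illustrative only):
import math

def lr_at_epoch(e, base_lr=1.5e-4 * 4096 / 256, warmup=40, total=400,
                start_factor=1e-4):
    if e < warmup:  # linear warmup
        return base_lr * (start_factor + (1 - start_factor) * e / warmup)
    # cosine decay over the remaining epochs (eta_min defaults to 0)
    t = (e - warmup) / (total - warmup)
    return 0.5 * base_lr * (1 + math.cos(math.pi * t))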
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_hivit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'norm': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=1560, + by_epoch=True, + begin=40, + end=1600, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=1600) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_400e_in1k.py b/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_400e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..6ffcf6d13c049fa8802766d74f7e5c9a803b706e --- /dev/null +++ b/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_400e_in1k.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_hivit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'norm': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=360, + by_epoch=True, + begin=40, + end=400, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=400) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_800e_in1k.py b/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_800e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..f8a49b5840d8905414058b873b6a6f3a0acbb2a1 --- /dev/null +++ b/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_800e_in1k.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_hivit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'norm': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=760, + by_epoch=True, + begin=40, + end=800, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=800) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_1600e_in1k.py b/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_1600e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..ae1aba546e22e612dde9d1f41f0ae45306034ba9 --- /dev/null +++ b/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_1600e_in1k.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
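# NOTE (illustration): the MAE/HiViT pretrain configs above wrap the optimizer
# in AmpOptimWrapper with loss_scale='dynamic': mixed-precision training where
# the loss is multiplied by a scale factor that grows while steps succeed and
# shrinks on overflow. A minimal plain-PyTorch equivalent of one such step
# (sketch only; mmengine handles this inside the wrapper):
import torch

def amp_step(model, loss_fn, batch, optimizer, scaler):
    # scaler is a torch.cuda.amp.GradScaler()
    optimizer.zero_grad()
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        loss = loss_fn(model(batch))
    scaler.scale(loss).backward()  # scale the loss so fp16 grads stay representable
    scaler.step(optimizer)         # unscales grads; skips the step on inf/nan
    scaler.update()                # adjust the scale factor dynamically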
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_hivit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# model settings +model.update( + backbone=dict(type=MAEHiViT, arch='large'), + neck=dict(type=MAEPretrainDecoder, embed_dim=768)) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'norm': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=1560, + by_epoch=True, + begin=40, + end=1600, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=1600) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_400e_in1k.py b/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_400e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..cdc1259ffce7f84c83b1412c4ecc480bfbbcc202 --- /dev/null +++ b/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_400e_in1k.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_hivit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# model settings +model.update( + backbone=dict(type=MAEHiViT, arch='large'), + neck=dict(type=MAEPretrainDecoder, embed_dim=768)) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'norm': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=360, + by_epoch=True, + begin=40, + end=400, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=400) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_800e_in1k.py b/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_800e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..657ee01181ecb3c60c65c7a6fb08bed897a29012 --- /dev/null +++ b/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_800e_in1k.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_hivit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# model settings +model.update( + backbone=dict(type=MAEHiViT, arch='large'), + neck=dict(type=MAEPretrainDecoder, embed_dim=768)) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'norm': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=760, + by_epoch=True, + begin=40, + end=800, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=800) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_1600e_in1k.py b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_1600e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..a4b325df877acea22cfc565903644bbd8eb1d8cd --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_1600e_in1k.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) 
+ })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=1560, + by_epoch=True, + begin=40, + end=1600, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=1600) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_300e_in1k.py b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_300e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..6cee3bc93fd8b1c65263ac415422b9b73628e88d --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_300e_in1k.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=260, + by_epoch=True, + begin=40, + end=300, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=300) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_400e_in1k.py b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_400e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..fb78e2bdfbd9b100f925958bf312fa8081b212c4 --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_400e_in1k.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
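# NOTE (illustration): in the optimizer wrappers above, paramwise_cfg's
# custom_keys are matched against parameter names by substring, so anything
# containing 'ln', 'bias', 'pos_embed', 'mask_token' or 'cls_token' gets
# decay_mult=0.0, i.e. is excluded from weight decay. Hand-rolled, the same
# grouping would look roughly like this (illustrative, not the mmengine
# constructor):
def build_param_groups(model, base_wd=0.05,
                       no_decay=('ln', 'bias', 'pos_embed', 'mask_token',
                                 'cls_token')):
    decay, skip = [], []
    for name, param in model.named_parameters():
        (skip if any(k in name for k in no_decay) else decay).append(param)
    return [dict(params=decay, weight_decay=base_wd),
            dict(params=skip, weight_decay=0.0)]  # decay_mult = 0.0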
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=360, + by_epoch=True, + begin=40, + end=400, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=400) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_800e_in1k.py b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_800e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..f34e1dac4ee444aaca71ce6f48f194d91076d0cd --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_800e_in1k.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) 
+ })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=760, + by_epoch=True, + begin=40, + end=800, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=800) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_huge_p14_8xb512_amp_coslr_1600e_in1k.py b/mmpretrain/configs/mae/mae_vit_huge_p14_8xb512_amp_coslr_1600e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..bc91ee00bca5d77ebe873fa846cc13eaa677b9c8 --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_huge_p14_8xb512_amp_coslr_1600e_in1k.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# model settings +model.update( + backbone=dict(type=MAEViT, arch='h', patch_size=14), + neck=dict( + type=MAEPretrainDecoder, + embed_dim=1280, + patch_size=14, + num_patches=256), + head=dict(patch_size=14)) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=1560, + by_epoch=True, + begin=40, + end=1600, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=1600) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_1600e_in1k.py b/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_1600e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..ef0777af8c19e4e615cc4c5ae2976e964febd0ce --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_1600e_in1k.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
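# NOTE (illustration): the ViT-huge variant above switches patch_size from 16
# to 14, so the decoder must be told the new sequence length explicitly:
# (224 // 14) ** 2 = 256 patches, hence num_patches=256 next to patch_size=14
# (the patch-16 default at 224 px would be (224 // 16) ** 2 = 196).
assert (224 // 14) ** 2 == 256
assert (224 // 16) ** 2 == 196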
+# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# model settings +model.update( + backbone=dict(type=MAEViT, arch='l'), + neck=dict(type=MAEPretrainDecoder, embed_dim=1024)) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=1560, + by_epoch=True, + begin=40, + end=1600, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=1600) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_300e_in1k.py b/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_300e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..ea005e4b3a59ade8b778e64068e95fc4e73ed321 --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_300e_in1k.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# model settings +model.update( + backbone=dict(type=MAEViT, arch='l'), + neck=dict(type=MAEPretrainDecoder, embed_dim=1024)) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.)
+ })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=260, + by_epoch=True, + begin=40, + end=300, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=300) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_400e_in1k.py b/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_400e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..6f735491a2ca8f5a43eb8ea7bb57b8c5162f77da --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_400e_in1k.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# model settings +model.update( + backbone=dict(type=MAEViT, arch='l'), + neck=dict(type=MAEPretrainDecoder, embed_dim=1024)) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=360, + by_epoch=True, + begin=40, + end=400, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=400) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_800e_in1k.py b/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_800e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..a0a5abd58b3cde6a9422c85c08c464fbc72f5d59 --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_800e_in1k.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently.
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# model settings +model.update( + backbone=dict(type=MAEViT, arch='l'), + neck=dict(type=MAEPretrainDecoder, embed_dim=1024)) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=760, + by_epoch=True, + begin=40, + end=800, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=800) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mobilenet_v2/mobilenet_v2_8xb32_in1k.py b/mmpretrain/configs/mobilenet_v2/mobilenet_v2_8xb32_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..79eec6355017ef41c10812f6c67bbc362ad0c343 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v2/mobilenet_v2_8xb32_in1k.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs32_pil_resize import * + from .._base_.default_runtime import * + from .._base_.models.mobilenet_v2_1x import * + from .._base_.schedules.imagenet_bs256_epochstep import * diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_large_8xb128_in1k.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_large_8xb128_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..3f1bee1c132e9d5a718ba6d92be5822543b2222d --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_large_8xb128_in1k.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently.
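# NOTE (illustration): configs like the MobileNetV2 one above are pure
# compositions of base files and can be launched programmatically; a minimal
# sketch with mmengine (equivalent in spirit to tools/train.py; the work_dir
# value is just an example):
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'mmpretrain/configs/mobilenet_v2/mobilenet_v2_8xb32_in1k.py')
cfg.work_dir = 'work_dirs/mobilenet_v2_8xb32_in1k'  # Runner requires a work_dir
runner = Runner.from_cfg(cfg)
runner.train()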
+ +# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.imagenet_bs128_mbv3 import * + from .._base_.default_runtime import * + +from mmengine.optim import StepLR +from torch.optim import RMSprop + +# model settings +model.merge( + dict( + backbone=dict(arch='large'), + head=dict(in_channels=960, mid_channels=[1280]), + )) +# schedule settings +optim_wrapper = dict( + optimizer=dict( + type=RMSprop, + lr=0.064, + alpha=0.9, + momentum=0.9, + eps=0.0316, + weight_decay=1e-5)) + +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973) + +train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_050_8xb128_in1k.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_050_8xb128_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..50e1ffc6709e3fa490cb0bf5e1eb958d94264315 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_050_8xb128_in1k.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification + +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.imagenet_bs128_mbv3 import * + from .._base_.default_runtime import * + +from mmengine.optim import StepLR +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.optim import RMSprop + +# model settings +model.merge( + dict( + backbone=dict( + arch='small_050', + norm_cfg=dict(type=BatchNorm2d, eps=1e-5, momentum=0.1)), + head=dict(in_channels=288), + )) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=AutoAugment, + policies='imagenet', + hparams=dict(pad_val=[round(x) for x in [103.53, 116.28, 123.675]])), + dict( + type=RandomErasing, + erase_prob=0.2, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=[103.53, 116.28, 123.675], + fill_std=[57.375, 57.12, 58.395]), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader.merge(dict(dataset=dict(pipeline=train_pipeline))) + +val_dataloader.merge(dict(dataset=dict(pipeline=test_pipeline))) +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader + +# schedule settings +optim_wrapper = dict( + optimizer=dict( + type=RMSprop, + lr=0.064, + alpha=0.9, + momentum=0.9, + eps=0.0316, + weight_decay=1e-5)) + +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973) + +train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=10) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual 
training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_075_8xb128_in1k.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_075_8xb128_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..c8c640cd8a0ed4d3a33b7c2ffd10e4c44229307b --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_075_8xb128_in1k.py @@ -0,0 +1,83 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification + +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.imagenet_bs128_mbv3 import * + from .._base_.default_runtime import * + +from mmengine.optim import StepLR +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.optim import RMSprop + +# model settings +model.merge( + dict( + backbone=dict( + arch='small_075', + norm_cfg=dict(type=BatchNorm2d, eps=1e-5, momentum=0.1)), + head=dict(in_channels=432), + )) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=AutoAugment, + policies='imagenet', + hparams=dict(pad_val=[round(x) for x in [103.53, 116.28, 123.675]])), + dict( + type=RandomErasing, + erase_prob=0.2, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=[103.53, 116.28, 123.675], + fill_std=[57.375, 57.12, 58.395]), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader.merge(dict(dataset=dict(pipeline=train_pipeline))) +val_dataloader.merge(dict(dataset=dict(pipeline=test_pipeline))) +test_dataloader = val_dataloader + +# schedule settings +optim_wrapper = dict( + optimizer=dict( + type=RMSprop, + lr=0.064, + alpha=0.9, + momentum=0.9, + eps=0.0316, + weight_decay=1e-5)) + +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973) + +train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=10) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb128_in1k.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb128_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..0c220a01d098c1a0a8259f08c81bc07054ff9ebb --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb128_in1k.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
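All of the MobileNetV3 recipes above share the torchvision reference schedule: RMSprop at lr=0.064, decayed by StepLR(step_size=2, gamma=0.973). An illustration (not config code) of what that compounds to over the 600-epoch run:

```python
def mbv3_lr(epoch: int, base_lr: float = 0.064) -> float:
    # StepLR multiplies the LR by 0.973 once every 2 epochs.
    return base_lr * 0.973 ** (epoch // 2)

print(f'{mbv3_lr(0):.3g}')    # 0.064 at the start
print(f'{mbv3_lr(300):.3g}')  # ~0.00105 at the halfway point
print(f'{mbv3_lr(599):.3g}')  # ~1.8e-05 by the end of training
```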
+# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification + +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.imagenet_bs128_mbv3 import * + from .._base_.default_runtime import * + +from mmengine.optim import StepLR +from torch.optim import RMSprop + +# schedule settings +optim_wrapper = dict( + optimizer=dict( + type=RMSprop, + lr=0.064, + alpha=0.9, + momentum=0.9, + eps=0.0316, + weight_decay=1e-5)) + +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973) + +train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb16_cifar10.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb16_cifar10.py new file mode 100644 index 0000000000000000000000000000000000000000..0f91ee38243543b37e73c09386a5433bfcb46458 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb16_cifar10.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.cifar10_bs16 import * + from .._base_.schedules.cifar10_bs128 import * + from .._base_.default_runtime import * + +from mmengine.optim import MultiStepLR + +# model settings +model.merge( + dict( + head=dict( + _delete_=True, + type=StackedLinearClsHead, + num_classes=10, + in_channels=576, + mid_channels=[1280], + act_cfg=dict(type=Hardswish), + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5)))) +# schedule settings +param_scheduler.merge( + dict( + type=MultiStepLR, + by_epoch=True, + milestones=[120, 170], + gamma=0.1, + )) + +train_cfg.merge(dict(by_epoch=True, max_epochs=200)) diff --git a/mmpretrain/configs/resnet/resnet18_8xb32_in1k.py b/mmpretrain/configs/resnet/resnet18_8xb32_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..f16d248b6988c924e8540a7782dabee4997baba1 --- /dev/null +++ b/mmpretrain/configs/resnet/resnet18_8xb32_in1k.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs32 import * + from .._base_.default_runtime import * + from .._base_.models.resnet18 import * + from .._base_.schedules.imagenet_bs256 import * diff --git a/mmpretrain/configs/simclr/simclr_resnet50_16xb256_coslr_200e_in1k.py b/mmpretrain/configs/simclr/simclr_resnet50_16xb256_coslr_200e_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..09c738f219e561e5863dd2ce8246af005502bc83 --- /dev/null +++ b/mmpretrain/configs/simclr/simclr_resnet50_16xb256_coslr_200e_in1k.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
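The CIFAR-10 config above replaces the whole head with `_delete_=True` rather than patching it. That flag matters because mmengine merges nested dicts recursively, so without it stale keys from the base head would survive the merge. A pure-Python sketch of the two behaviours, assuming the documented merge semantics:

```python
def merge(base: dict, patch: dict) -> dict:
    patch = dict(patch)                  # don't mutate the caller's dict
    if patch.pop('_delete_', False):
        return patch                     # replace the node wholesale
    out = dict(base)
    for k, v in patch.items():
        if isinstance(v, dict) and isinstance(out.get(k), dict):
            out[k] = merge(out[k], v)    # recursive update otherwise
        else:
            out[k] = v
    return out

base_head = dict(type='LinearClsHead', num_classes=1000, in_channels=576)
patch = dict(_delete_=True, type='StackedLinearClsHead', num_classes=10,
             in_channels=576, mid_channels=[1280])
print(merge(base_head, patch))           # no stray keys from the base head
```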
+from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs32_simclr import * + from .._base_.schedules.imagenet_lars_coslr_200e import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper + +from mmpretrain.engine.optimizers.lars import LARS +from mmpretrain.models.backbones.resnet import ResNet +from mmpretrain.models.heads.contrastive_head import ContrastiveHead +from mmpretrain.models.losses.cross_entropy_loss import CrossEntropyLoss +from mmpretrain.models.necks.nonlinear_neck import NonLinearNeck +from mmpretrain.models.selfsup.simclr import SimCLR + +# dataset settings +train_dataloader.merge(dict(batch_size=256)) + +# model settings +model = dict( + type=SimCLR, + backbone=dict( + type=ResNet, + depth=50, + norm_cfg=dict(type='SyncBN'), + zero_init_residual=True), + neck=dict( + type=NonLinearNeck, # SimCLR non-linear neck + in_channels=2048, + hid_channels=2048, + out_channels=128, + num_layers=2, + with_avg_pool=True), + head=dict( + type=ContrastiveHead, + loss=dict(type=CrossEntropyLoss), + temperature=0.1), +) + +# optimizer +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict(type=LARS, lr=4.8, momentum=0.9, weight_decay=1e-6), + paramwise_cfg=dict( + custom_keys={ + 'bn': dict(decay_mult=0, lars_exclude=True), + 'bias': dict(decay_mult=0, lars_exclude=True), + # bn layer in ResNet block downsample module + 'downsample.1': dict(decay_mult=0, lars_exclude=True) + })) + +# runtime settings +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=10, max_keep_ckpts=3) diff --git a/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k.py b/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..09af3d0149a60d6b6b6aba7cc74e436e8d205dd1 --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict(img_size=224, drop_path_rate=0.5, stage_cfgs=None), + head=dict( + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict( + type=LabelSmoothLoss, + label_smooth_val=0.1, + mode='original', + loss_weight=0), + topk=None, + cal_acc=False), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) 
+ ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k_384px.py b/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k_384px.py new file mode 100644 index 0000000000000000000000000000000000000000..aacdc3274367cb07127d9261ef76149ab385c08f --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k_384px.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k.py b/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..b8fc27937d2cd521dbd32eeb22e094851405b788 --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict(arch='large', img_size=224, stage_cfgs=None), + head=dict(in_channels=1536), +) + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k_384px.py b/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k_384px.py new file mode 100644 index 0000000000000000000000000000000000000000..9a449aa656349d29038721a43041ec5c71b097bd --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k_384px.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict(arch='large'), + head=dict(in_channels=1536), +) + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer/swin_large_8xb8_cub_384px.py b/mmpretrain/configs/swin_transformer/swin_large_8xb8_cub_384px.py new file mode 100644 index 0000000000000000000000000000000000000000..2003cd3a0787b9649ee18e99e0bbb0a00ad13530 --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_large_8xb8_cub_384px.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
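The `in_channels=1536` overrides in the two Swin-L configs above follow from Swin's stage design: the channel width doubles at each of the three patch-merging stages, so the classifier input is embed_dim × 2^3. A one-line check using the standard Swin embed dims (an assumption here, not read from these configs):

```python
embed_dims = {'tiny': 96, 'small': 96, 'base': 128, 'large': 192}
for arch, c in embed_dims.items():
    print(arch, c * 2 ** 3)  # tiny/small -> 768, base -> 1024, large -> 1536
```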
+from mmengine.config import read_base +from mmengine.hooks import CheckpointHook, LoggerHook +from mmengine.model import PretrainedInit +from torch.optim.adamw import AdamW + +from mmpretrain.models import ImageClassifier + +with read_base(): + from .._base_.datasets.cub_bs8_384 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.cub_bs64 import * + +# model settings +checkpoint = 'https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin-large_3rdparty_in21k-384px.pth' # noqa + +model.update( + backbone=dict( + arch='large', + init_cfg=dict( + type=PretrainedInit, checkpoint=checkpoint, prefix='backbone')), + head=dict(num_classes=200, in_channels=1536)) + +# schedule settings +optim_wrapper = dict( + optimizer=dict( + _delete_=True, + type=AdamW, + lr=5e-6, + weight_decay=0.0005, + eps=1e-8, + betas=(0.9, 0.999)), + paramwise_cfg=dict( + norm_decay_mult=0.0, + bias_decay_mult=0.0, + custom_keys={ + '.absolute_pos_embed': dict(decay_mult=0.0), + '.relative_position_bias_table': dict(decay_mult=0.0) + }), + clip_grad=dict(max_norm=5.0), +) + +default_hooks = dict( + # log every 20 intervals + logger=dict(type=LoggerHook, interval=20), + # save last three checkpoints + checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=3)) diff --git a/mmpretrain/configs/swin_transformer/swin_small_16xb64_in1k.py b/mmpretrain/configs/swin_transformer/swin_small_16xb64_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..59792528435c038ec86d090fe75ce0d9430a18d0 --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_small_16xb64_in1k.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='small', img_size=224, drop_path_rate=0.3, stage_cfgs=None), + head=dict( + in_channels=768, + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict( + type=LabelSmoothLoss, + label_smooth_val=0.1, + mode='original', + loss_weight=0), + topk=None, + cal_acc=False), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer/swin_tiny_16xb64_in1k.py b/mmpretrain/configs/swin_transformer/swin_tiny_16xb64_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..733e1ef0ec471aa98e3be75bcf1b21c07a5b60f5 --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_tiny_16xb64_in1k.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
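In the CUB fine-tuning config above, `PretrainedInit` with `prefix='backbone'` loads only the backbone weights from the IN-21k checkpoint. A simplified picture of what the prefix filter does (a sketch, not mmengine's actual implementation):

```python
import torch

def select_prefix(state_dict: dict, prefix: str = 'backbone.') -> dict:
    # Keep only keys under the prefix, and strip it before loading.
    return {k[len(prefix):]: v for k, v in state_dict.items()
            if k.startswith(prefix)}

ckpt = {'backbone.patch_embed.proj.weight': torch.zeros(1),
        'head.fc.weight': torch.zeros(1)}
print(select_prefix(ckpt))  # only the backbone tensor, prefix removed
```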
+from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='tiny', img_size=224, drop_path_rate=0.2, stage_cfgs=None), + head=dict( + in_channels=768, + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict( + type=LabelSmoothLoss, + label_smooth_val=0.1, + mode='original', + loss_weight=0), + topk=None, + cal_acc=False), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_base_w12_8xb128_in21k_192px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w12_8xb128_in21k_192px.py new file mode 100644 index 0000000000000000000000000000000000000000..1ecc4363330d9747bee4c104be591924449694e1 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w12_8xb128_in21k_192px.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet21k_bs128 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + img_size=192, drop_path_rate=0.5, window_size=[12, 12, 12, 6]), + head=dict(num_classes=21841), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +# dataset settings +data_preprocessor = dict(num_classes=21841) + +_base_['train_pipeline'][1]['scale'] = 192 # RandomResizedCrop +_base_['test_pipeline'][1]['scale'] = 219 # ResizeEdge +_base_['test_pipeline'][2]['crop_size'] = 192 # CenterCrop diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_16xb64_in1k_256px.py new file mode 100644 index 0000000000000000000000000000000000000000..103afb42608b68bc28fa6ccc3576c09a3a9bb595 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_16xb64_in1k_256px.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
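The `_base_['train_pipeline'][1]['scale'] = 192` lines in the SwinV2 IN-21k config above patch the inherited pipelines in place by list index; the test-time 219 keeps the usual 256/224 resize-to-crop ratio (192 × 256/224 ≈ 219). The same in-place edit in plain Python, on a hypothetical pipeline:

```python
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=224),  # index 1 is the crop step
    dict(type='PackInputs'),
]
train_pipeline[1]['scale'] = 192
print(train_pipeline[1])  # {'type': 'RandomResizedCrop', 'scale': 192}
```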
+from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + img_size=256, drop_path_rate=0.5, window_size=[16, 16, 16, 8]), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_in21k_pre_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_in21k_pre_16xb64_in1k_256px.py new file mode 100644 index 0000000000000000000000000000000000000000..6588f50fffd7b694dbc6b8f1851a5ae04eb83fd6 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_in21k_pre_16xb64_in1k_256px.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + img_size=256, + window_size=[16, 16, 16, 8], + pretrained_window_sizes=[12, 12, 12, 6]), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_base_w24_in21k_pre_16xb64_in1k_384px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w24_in21k_pre_16xb64_in1k_384px.py new file mode 100644 index 0000000000000000000000000000000000000000..118c085e7550ebfe01dc98c880c8bfaf6ef2d977 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w24_in21k_pre_16xb64_in1k_384px.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + window_size=[24, 24, 24, 12], pretrained_window_sizes=[12, 12, 12, 6])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_base_w8_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w8_16xb64_in1k_256px.py new file mode 100644 index 0000000000000000000000000000000000000000..d40144cbba13654a0f9e07c721389419d8ac67d0 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w8_16xb64_in1k_256px.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
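The per-stage window lists such as `[16, 16, 16, 8]` in the SwinV2 configs above track the shrinking feature maps: with a 256px input and 4× patch embedding, stage i runs at 256/4/2^i pixels, so the final 8×8 stage cannot host a 16×16 window. A quick sanity check:

```python
img_size = 256
for i, win in enumerate([16, 16, 16, 8]):
    res = img_size // 4 // 2 ** i  # stage resolutions: 64, 32, 16, 8
    assert win <= res              # the window must fit the feature map
    print(f'stage {i}: {res}x{res} features, {win}x{win} windows')
```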
+from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict(img_size=256, drop_path_rate=0.5), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_large_w12_8xb128_in21k_192px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w12_8xb128_in21k_192px.py new file mode 100644 index 0000000000000000000000000000000000000000..1ecc4363330d9747bee4c104be591924449694e1 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w12_8xb128_in21k_192px.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet21k_bs128 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + img_size=192, drop_path_rate=0.5, window_size=[12, 12, 12, 6]), + head=dict(num_classes=21841), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +# dataset settings +data_preprocessor = dict(num_classes=21841) + +_base_['train_pipeline'][1]['scale'] = 192 # RandomResizedCrop +_base_['test_pipeline'][1]['scale'] = 219 # ResizeEdge +_base_['test_pipeline'][2]['crop_size'] = 192 # CenterCrop diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_large_w16_in21k_pre_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w16_in21k_pre_16xb64_in1k_256px.py new file mode 100644 index 0000000000000000000000000000000000000000..0a1b59df0640b456dbb90691d582bc0b0f5da85e --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w16_in21k_pre_16xb64_in1k_256px.py @@ -0,0 +1,24 @@ +# Only for evaluation +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.config import read_base + +from mmpretrain.models import CrossEntropyLoss + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='large', + img_size=256, + window_size=[16, 16, 16, 8], + pretrained_window_sizes=[12, 12, 12, 6]), + head=dict( + in_channels=1536, + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5))) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_large_w24_in21k_pre_16xb64_in1k_384px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w24_in21k_pre_16xb64_in1k_384px.py new file mode 100644 index 0000000000000000000000000000000000000000..b20bcead8410e77d63b688799d17a9913cb51f94 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w24_in21k_pre_16xb64_in1k_384px.py @@ -0,0 +1,24 @@ +# Only for evaluation +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +from mmpretrain.models import CrossEntropyLoss + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='large', + img_size=384, + window_size=[24, 24, 24, 12], + pretrained_window_sizes=[12, 12, 12, 6]), + head=dict( + in_channels=1536, + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5))) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_small_w16_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_small_w16_16xb64_in1k_256px.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd15c313954a325a9f42e2ebc2bc77a20de6cb6 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_small_w16_16xb64_in1k_256px.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='small', + img_size=256, + drop_path_rate=0.3, + window_size=[16, 16, 16, 8]), + head=dict(in_channels=768), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_small_w8_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_small_w8_16xb64_in1k_256px.py new file mode 100644 index 0000000000000000000000000000000000000000..bfec346617f3dd70269ac1e5f0e730f72232669a --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_small_w8_16xb64_in1k_256px.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
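The two evaluation-only SwinV2-L configs above swap the head loss back to plain cross-entropy and report `topk=(1, 5)` accuracy. For reference, top-k accuracy just asks whether the target is among the k highest logits; a minimal sketch:

```python
import torch

def topk_accuracy(logits: torch.Tensor, target: torch.Tensor, k: int) -> float:
    pred = logits.topk(k, dim=1).indices        # (N, k) predicted class ids
    hit = (pred == target[:, None]).any(dim=1)  # correct if target in top-k
    return hit.float().mean().item()

logits = torch.tensor([[0.1, 0.7, 0.2], [0.9, 0.05, 0.05]])
target = torch.tensor([2, 0])
print(topk_accuracy(logits, target, 1))  # 0.5
print(topk_accuracy(logits, target, 2))  # 1.0
```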
+# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict(arch='small', img_size=256, drop_path_rate=0.3), + head=dict(in_channels=768), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w16_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w16_16xb64_in1k_256px.py new file mode 100644 index 0000000000000000000000000000000000000000..f2fa160963da8739d17fd2570088ac578e189624 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w16_16xb64_in1k_256px.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='tiny', + img_size=256, + drop_path_rate=0.2, + window_size=[16, 16, 16, 8]), + head=dict(in_channels=768), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w8_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w8_16xb64_in1k_256px.py new file mode 100644 index 0000000000000000000000000000000000000000..8cca2b3830236b646d5a24652223acf00a683d8a --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w8_16xb64_in1k_256px.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict(arch='tiny', img_size=256, drop_path_rate=0.2), + head=dict(in_channels=768), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) 
+ ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/vision_transformer/vit_base_p16_32xb128_mae_in1k.py b/mmpretrain/configs/vision_transformer/vit_base_p16_32xb128_mae_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..18c2afdaf2b161ecda1ab28c3ab4b6445dd08f5e --- /dev/null +++ b/mmpretrain/configs/vision_transformer/vit_base_p16_32xb128_mae_in1k.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit +from torch.optim import AdamW + +from mmpretrain.engine import EMAHook +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.vit_base_p16 import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +model.update( + backbone=dict(drop_rate=0, drop_path_rate=0.1, init_cfg=None), + head=dict(loss=dict(mode='original')), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=.02), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.), + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +# dataset settings +train_dataloader.update(batch_size=128) + +# schedule settings +optim_wrapper.update( + optimizer=dict( + type=AdamW, + lr=1e-4 * 4096 / 256, + weight_decay=0.3, + eps=1e-8, + betas=(0.9, 0.95)), + paramwise_cfg=dict( + norm_decay_mult=0.0, + bias_decay_mult=0.0, + custom_keys={ + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0) + })) + +# runtime settings +custom_hooks = [dict(type=EMAHook, momentum=1e-4)] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/vision_transformer/vit_base_p16_64xb64_in1k.py b/mmpretrain/configs/vision_transformer/vit_base_p16_64xb64_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..8f128d1cfa00386ab4299349135c8f422c68faa7 --- /dev/null +++ b/mmpretrain/configs/vision_transformer/vit_base_p16_64xb64_in1k.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +from mmpretrain.models import Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_pil_resize_autoaug import * + from .._base_.default_runtime import * + from .._base_.models.vit_base_p16 import * + from .._base_.schedules.imagenet_bs4096_adamw import * + +# model setting +model.update( + head=dict(hidden_dim=3072), + train_cfg=dict(augments=dict(type=Mixup, alpha=0.2)), +) + +# schedule setting +optim_wrapper.update(clip_grad=dict(max_norm=1.0)) diff --git a/mmpretrain/configs/vision_transformer/vit_base_p16_64xb64_in1k_384px.py b/mmpretrain/configs/vision_transformer/vit_base_p16_64xb64_in1k_384px.py new file mode 100644 index 0000000000000000000000000000000000000000..98e01f306c1e37717cd6826fa0dba2dc74ab3807 --- /dev/null +++ b/mmpretrain/configs/vision_transformer/vit_base_p16_64xb64_in1k_384px.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
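The MAE fine-tuning recipe above also registers `EMAHook` with `momentum=1e-4`. Under mmengine's ExponentialMovingAverage convention, every iteration nudges the averaged weights toward the live ones by that factor, so the average effectively tracks on the order of the last 1/momentum = 10,000 iterations. A one-line sketch of that update:

```python
def ema_update(avg: float, param: float, momentum: float = 1e-4) -> float:
    # averaged = (1 - momentum) * averaged + momentum * live_param
    return (1 - momentum) * avg + momentum * param
```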
+from mmengine.config import read_base + +from mmpretrain.datasets import (CenterCrop, LoadImageFromFile, PackInputs, + RandomFlip, RandomResizedCrop, ResizeEdge) + +with read_base(): + from .._base_.datasets.imagenet_bs64_pil_resize import * + from .._base_.default_runtime import * + from .._base_.models.vit_base_p16 import * + from .._base_.schedules.imagenet_bs4096_adamw import * + +# model setting +model.update(backbone=dict(img_size=384)) + +# dataset setting +data_preprocessor.update( + mean=[127.5, 127.5, 127.5], + std=[127.5, 127.5, 127.5], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=384, backend='pillow'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=384, edge='short', backend='pillow'), + dict(type=CenterCrop, crop_size=384), + dict(type=PackInputs), +] + +train_dataloader.update(dataset=dict(pipeline=train_pipeline)) +val_dataloader.update(dataset=dict(pipeline=test_pipeline)) +test_dataloader.update(dataset=dict(pipeline=test_pipeline)) + +# schedule setting +optim_wrapper.update(clip_grad=dict(max_norm=1.0)) diff --git a/mmpretrain/configs/vision_transformer/vit_base_p32_64xb64_in1k.py b/mmpretrain/configs/vision_transformer/vit_base_p32_64xb64_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..3651c93b602c8a8391b6ec2c5debb12660cf27fe --- /dev/null +++ b/mmpretrain/configs/vision_transformer/vit_base_p32_64xb64_in1k.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +from mmpretrain.models import CrossEntropyLoss, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_pil_resize_autoaug import * + from .._base_.default_runtime import * + from .._base_.models.vit_base_p16 import * + from .._base_.schedules.imagenet_bs4096_adamw import * + +# model setting +model.update( + backbone=dict(patch_size=32), + head=dict( + hidden_dim=3072, + topk=(1, 5), + ), + train_cfg=dict(augments=dict(type=Mixup, alpha=0.2)), +) + +model.head.loss = dict(type=CrossEntropyLoss, loss_weight=1.0) + +# schedule setting +optim_wrapper.update(clip_grad=dict(max_norm=1.0)) diff --git a/mmpretrain/configs/vision_transformer/vit_base_p32_64xb64_in1k_384px.py b/mmpretrain/configs/vision_transformer/vit_base_p32_64xb64_in1k_384px.py new file mode 100644 index 0000000000000000000000000000000000000000..253740cc7c313d4822b92e533f22576412caf93c --- /dev/null +++ b/mmpretrain/configs/vision_transformer/vit_base_p32_64xb64_in1k_384px.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
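A note on the `mean=std=[127.5, 127.5, 127.5]` preprocessor override in the 384px ViT configs: it maps uint8 pixels into [-1, 1], the normalisation these converted ViT checkpoints expect, rather than the ImageNet statistics used elsewhere in this diff:

```python
import numpy as np

img = np.array([0.0, 127.5, 255.0])
print((img - 127.5) / 127.5)  # [-1.  0.  1.]
```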
+from mmengine.config import read_base + +from mmpretrain.datasets import (CenterCrop, LoadImageFromFile, PackInputs, + RandomFlip, RandomResizedCrop, ResizeEdge) +from mmpretrain.models import CrossEntropyLoss + +with read_base(): + from .._base_.datasets.imagenet_bs64_pil_resize import * + from .._base_.default_runtime import * + from .._base_.models.vit_base_p16 import * + from .._base_.schedules.imagenet_bs4096_adamw import * + +# model setting +model.update( + backbone=dict(img_size=384, patch_size=32), head=dict(topk=(1, 5))) + +model.head.loss = dict(type=CrossEntropyLoss, loss_weight=1.0) + +# dataset setting +data_preprocessor.update( + mean=[127.5, 127.5, 127.5], + std=[127.5, 127.5, 127.5], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=384, backend='pillow'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=384, edge='short', backend='pillow'), + dict(type=CenterCrop, crop_size=384), + dict(type=PackInputs), +] + +train_dataloader.update(dataset=dict(pipeline=train_pipeline)) +val_dataloader.update(dataset=dict(pipeline=test_pipeline)) +test_dataloader.update(dataset=dict(pipeline=test_pipeline)) + +# schedule setting +optim_wrapper.update(clip_grad=dict(max_norm=1.0)) diff --git a/mmpretrain/configs/vision_transformer/vit_large_p16_64xb64_in1k.py b/mmpretrain/configs/vision_transformer/vit_large_p16_64xb64_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..03f4a74b40c8c3c0ac4bc28c081539741ca477c8 --- /dev/null +++ b/mmpretrain/configs/vision_transformer/vit_large_p16_64xb64_in1k.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +from mmpretrain.models import CrossEntropyLoss, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_pil_resize_autoaug import * + from .._base_.default_runtime import * + from .._base_.models.vit_base_p16 import * + from .._base_.schedules.imagenet_bs4096_adamw import * + +# model setting +model.update( + backbone=dict(arch='l'), + head=dict( + hidden_dim=3072, + in_channels=1024, + topk=(1, 5), + ), + train_cfg=dict(augments=dict(type=Mixup, alpha=0.2)), +) + +model.head.loss = dict(type=CrossEntropyLoss, loss_weight=1.0) + +# schedule setting +optim_wrapper.update(clip_grad=dict(max_norm=1.0)) diff --git a/mmpretrain/configs/vision_transformer/vit_large_p16_64xb64_in1k_384px.py b/mmpretrain/configs/vision_transformer/vit_large_p16_64xb64_in1k_384px.py new file mode 100644 index 0000000000000000000000000000000000000000..eba4bc45b33bb8351dce2ce2f1112d8c1a7e9701 --- /dev/null +++ b/mmpretrain/configs/vision_transformer/vit_large_p16_64xb64_in1k_384px.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
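These ViT recipes train with `Mixup(alpha=0.2)` as their only batch augment. The standard formulation (sketched below under that assumption) draws λ from Beta(α, α) and blends both the images and their one-hot targets:

```python
import numpy as np

def mixup(x1, x2, y1, y2, alpha=0.2):
    # With alpha=0.2 the Beta draw is usually near 0 or 1, so most
    # mixed samples stay close to one of the two originals.
    lam = np.random.beta(alpha, alpha)
    return lam * x1 + (1 - lam) * x2, lam * y1 + (1 - lam) * y2
```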
+from mmengine.config import read_base + +from mmpretrain.datasets import (CenterCrop, LoadImageFromFile, PackInputs, + RandomFlip, RandomResizedCrop, ResizeEdge) +from mmpretrain.models import CrossEntropyLoss + +with read_base(): + from .._base_.datasets.imagenet_bs64_pil_resize import * + from .._base_.default_runtime import * + from .._base_.models.vit_base_p16 import * + from .._base_.schedules.imagenet_bs4096_adamw import * + +# model setting +model.update( + backbone=dict(arch='l', img_size=384), + head=dict(in_channels=1024, topk=(1, 5))) + +model.head.loss = dict(type=CrossEntropyLoss, loss_weight=1.0) + +# dataset setting +data_preprocessor.update( + mean=[127.5, 127.5, 127.5], + std=[127.5, 127.5, 127.5], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=384, backend='pillow'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=384, edge='short', backend='pillow'), + dict(type=CenterCrop, crop_size=384), + dict(type=PackInputs), +] + +train_dataloader.update(dataset=dict(pipeline=train_pipeline)) +val_dataloader.update(dataset=dict(pipeline=test_pipeline)) +test_dataloader.update(dataset=dict(pipeline=test_pipeline)) + +# schedule setting +optim_wrapper.update(clip_grad=dict(max_norm=1.0)) diff --git a/mmpretrain/configs/vision_transformer/vit_large_p32_64xb64_in1k.py b/mmpretrain/configs/vision_transformer/vit_large_p32_64xb64_in1k.py new file mode 100644 index 0000000000000000000000000000000000000000..73dae6e79d1ea37d5a74d23a805935040432daa2 --- /dev/null +++ b/mmpretrain/configs/vision_transformer/vit_large_p32_64xb64_in1k.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +from mmpretrain.models import CrossEntropyLoss, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_pil_resize_autoaug import * + from .._base_.default_runtime import * + from .._base_.models.vit_base_p16 import * + from .._base_.schedules.imagenet_bs4096_adamw import * + +# model setting +model.update( + backbone=dict(arch='l', patch_size=32), + head=dict( + hidden_dim=3072, + in_channels=1024, + topk=(1, 5), + ), + train_cfg=dict(augments=dict(type=Mixup, alpha=0.2)), +) + +model.head.loss = dict(type=CrossEntropyLoss, loss_weight=1.0) + +# schedule setting +optim_wrapper.update(clip_grad=dict(max_norm=1.0)) diff --git a/mmpretrain/configs/vision_transformer/vit_large_p32_64xb64_in1k_384px.py b/mmpretrain/configs/vision_transformer/vit_large_p32_64xb64_in1k_384px.py new file mode 100644 index 0000000000000000000000000000000000000000..82e16192f3f03dac06f5c6de0398d7f5de463461 --- /dev/null +++ b/mmpretrain/configs/vision_transformer/vit_large_p32_64xb64_in1k_384px.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently.
+from mmengine.config import read_base + +from mmpretrain.datasets import (CenterCrop, LoadImageFromFile, PackInputs, + RandomFlip, RandomResizedCrop, ResizeEdge) +from mmpretrain.models import CrossEntropyLoss + +with read_base(): + from .._base_.datasets.imagenet_bs64_pil_resize import * + from .._base_.default_runtime import * + from .._base_.models.vit_base_p16 import * + from .._base_.schedules.imagenet_bs4096_adamw import * + +# model setting +model.update( + backbone=dict(arch='l', img_size=384, patch_size=32), + head=dict(in_channels=1024, topk=(1, 5))) + +model.head.loss = dict(type=CrossEntropyLoss, loss_weight=1.0) + +# dataset setting +data_preprocessor.update( + mean=[127.5, 127.5, 127.5], + std=[127.5, 127.5, 127.5], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=384, backend='pillow'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=384, edge='short', backend='pillow'), + dict(type=CenterCrop, crop_size=384), + dict(type=PackInputs), +] + +train_dataloader.update(dataset=dict(pipeline=train_pipeline)) +val_dataloader.update(dataset=dict(pipeline=test_pipeline)) +test_dataloader.update(dataset=dict(pipeline=test_pipeline)) + +# schedule setting +optim_wrapper.update(clip_grad=dict(max_norm=1.0)) diff --git a/mmpretrain/datasets/.DS_Store b/mmpretrain/datasets/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..d5b0f96bee2462c095a8da206877131dfdf23152 Binary files /dev/null and b/mmpretrain/datasets/.DS_Store differ diff --git a/mmpretrain/datasets/__init__.py b/mmpretrain/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e621e15771480e539ce52f3264b87c19202c1602 --- /dev/null +++ b/mmpretrain/datasets/__init__.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
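The datasets package that follows guards its multimodal datasets behind a `WITH_MULTIMODAL` flag, so the plain image-classification datasets import cleanly without the extra dependencies. The shape of that pattern, with an illustrative probe (the real flag lives in `mmpretrain.utils.dependency` and covers several packages):

```python
import importlib.util

# Illustrative dependency probe, not the real check.
WITH_MULTIMODAL = importlib.util.find_spec('transformers') is not None

__all__ = ['BaseDataset']
if WITH_MULTIMODAL:
    # e.g. from .coco_caption import COCOCaption
    __all__ = __all__ + ['COCOCaption']
```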
+from mmpretrain.utils.dependency import WITH_MULTIMODAL +from .base_dataset import BaseDataset +from .builder import build_dataset +from .caltech101 import Caltech101 +from .cifar import CIFAR10, CIFAR100 +from .cub import CUB +from .custom import CustomDataset +from .dataset_wrappers import KFoldDataset +from .dtd import DTD +from .fgvcaircraft import FGVCAircraft +from .flowers102 import Flowers102 +from .food101 import Food101 +from .imagenet import ImageNet, ImageNet21k +from .inshop import InShop +from .mnist import MNIST, FashionMNIST +from .multi_label import MultiLabelDataset +from .multi_task import MultiTaskDataset +from .nlvr2 import NLVR2 +from .oxfordiiitpet import OxfordIIITPet +from .places205 import Places205 +from .samplers import * # noqa: F401,F403 +from .stanfordcars import StanfordCars +from .sun397 import SUN397 +from .transforms import * # noqa: F401,F403 +from .voc import VOC + +__all__ = [ + 'BaseDataset', 'CIFAR10', 'CIFAR100', 'CUB', 'Caltech101', 'CustomDataset', + 'DTD', 'FGVCAircraft', 'FashionMNIST', 'Flowers102', 'Food101', 'ImageNet', + 'ImageNet21k', 'InShop', 'KFoldDataset', 'MNIST', 'MultiLabelDataset', + 'MultiTaskDataset', 'NLVR2', 'OxfordIIITPet', 'Places205', 'SUN397', + 'StanfordCars', 'VOC', 'build_dataset' +] + +if WITH_MULTIMODAL: + from .coco_caption import COCOCaption + from .coco_retrieval import COCORetrieval + from .coco_vqa import COCOVQA + from .flamingo import FlamingoEvalCOCOCaption, FlamingoEvalCOCOVQA + from .flickr30k_caption import Flickr30kCaption + from .flickr30k_retrieval import Flickr30kRetrieval + from .gqa_dataset import GQA + from .iconqa import IconQA + from .infographic_vqa import InfographicVQA + from .minigpt4_dataset import MiniGPT4Dataset + from .nocaps import NoCaps + from .ocr_vqa import OCRVQA + from .refcoco import RefCOCO + from .scienceqa import ScienceQA + from .textvqa import TextVQA + from .visual_genome import VisualGenomeQA + from .vizwiz import VizWiz + from .vsr import VSR + + __all__.extend([ + 'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption', + 'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval', + 'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA', + 'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA', + 'MiniGPT4Dataset' + ]) diff --git a/mmpretrain/datasets/base_dataset.py b/mmpretrain/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dffdf04772163b5fa55afabc8e15ac8c118aadd2 --- /dev/null +++ b/mmpretrain/datasets/base_dataset.py @@ -0,0 +1,219 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from os import PathLike +from typing import List, Optional, Sequence, Union + +import mmengine +import numpy as np +from mmengine.dataset import BaseDataset as _BaseDataset + +from mmpretrain.registry import DATASETS, TRANSFORMS + + +def expanduser(path): + """Expand ~ and ~user constructions. + + If user or $HOME is unknown, do nothing. + """ + if isinstance(path, (str, PathLike)): + return osp.expanduser(path) + else: + return path + + +@DATASETS.register_module() +class BaseDataset(_BaseDataset): + """Base dataset for the image classification task. + + This dataset supports annotation files in the `OpenMMLab 2.0 style annotation + format`. + + .. _OpenMMLab 2.0 style annotation format: + https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md + + Compared with :class:`mmengine.BaseDataset`, this class implements + several useful methods.
+ + Args: + ann_file (str): Annotation file path. + metainfo (dict, optional): Meta information for the dataset, such as class + information. Defaults to None. + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str | dict): Prefix for training data. Defaults to ''. + filter_cfg (dict, optional): Config for filtering data. Defaults to None. + indices (int or Sequence[int], optional): Support using only the first few + samples in the annotation file to facilitate training/testing on a smaller + dataset. Defaults to None, which means using all ``data_infos``. + serialize_data (bool): Whether to hold memory using serialized objects; + when enabled, data loader workers can use shared RAM from the master + process instead of making a copy. Defaults to True. + pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. + test_mode (bool, optional): ``test_mode=True`` means the dataset is in the + test phase, where an error is raised when getting an item fails; + ``test_mode=False`` means the training phase, where another item is + returned randomly instead. Defaults to False. + lazy_init (bool): Whether to defer loading annotations until they are + needed. In some cases, such as visualization, only the meta information + of the dataset is needed, so it is not necessary to load the annotation + file. ``BaseDataset`` can skip loading annotations to save time by + setting ``lazy_init=True``. Defaults to False. + max_refetch (int): The maximum number of extra cycles to fetch a valid + image when ``BaseDataset.prepare_data`` gets a None image. + Defaults to 1000. + classes (str | Sequence[str], optional): Specify names of classes. + + - If it is a string, it should be a file path, and every line of + the file is the name of a class. + - If it is a sequence of strings, every item is the name of a class. + - If it is None, use the category information in the ``metainfo`` + argument, the annotation file, or the class attribute ``METAINFO``. + + Defaults to None. + """ # noqa: E501 + + def __init__(self, + ann_file: str, + metainfo: Optional[dict] = None, + data_root: str = '', + data_prefix: Union[str, dict] = '', + filter_cfg: Optional[dict] = None, + indices: Optional[Union[int, Sequence[int]]] = None, + serialize_data: bool = True, + pipeline: Sequence = (), + test_mode: bool = False, + lazy_init: bool = False, + max_refetch: int = 1000, + classes: Union[str, Sequence[str], None] = None): + if isinstance(data_prefix, str): + data_prefix = dict(img_path=expanduser(data_prefix)) + + ann_file = expanduser(ann_file) + metainfo = self._compat_classes(metainfo, classes) + + transforms = [] + for transform in pipeline: + if isinstance(transform, dict): + transforms.append(TRANSFORMS.build(transform)) + else: + transforms.append(transform) + + super().__init__( + ann_file=ann_file, + metainfo=metainfo, + data_root=data_root, + data_prefix=data_prefix, + filter_cfg=filter_cfg, + indices=indices, + serialize_data=serialize_data, + pipeline=transforms, + test_mode=test_mode, + lazy_init=lazy_init, + max_refetch=max_refetch) + + @property + def img_prefix(self): + """The prefix of images.""" + return self.data_prefix['img_path'] + + @property + def CLASSES(self): + """Return all category names.""" + return self._metainfo.get('classes', None) + + @property + def class_to_idx(self): + """Mapping from class name to class index. + + Returns: + dict: mapping from class name to class index. + """ + + return {cat: i for i, cat in enumerate(self.CLASSES)} + + def get_gt_labels(self): + """Get all ground-truth labels (categories).
+
+        Returns:
+            np.ndarray: Categories for all images.
+        """
+
+        gt_labels = np.array(
+            [self.get_data_info(i)['gt_label'] for i in range(len(self))])
+        return gt_labels
+
+    def get_cat_ids(self, idx: int) -> List[int]:
+        """Get category id by index.
+
+        Args:
+            idx (int): Index of data.
+
+        Returns:
+            cat_ids (List[int]): Image category of the specified index.
+        """
+
+        return [int(self.get_data_info(idx)['gt_label'])]
+
+    def _compat_classes(self, metainfo, classes):
+        """Merge the old-style ``classes`` argument into ``metainfo``."""
+        if isinstance(classes, str):
+            # take it as a file path
+            class_names = mmengine.list_from_file(expanduser(classes))
+        elif isinstance(classes, (tuple, list)):
+            class_names = classes
+        elif classes is not None:
+            raise ValueError(f'Unsupported type {type(classes)} of classes.')
+
+        if metainfo is None:
+            metainfo = {}
+
+        if classes is not None:
+            metainfo = {'classes': tuple(class_names), **metainfo}
+
+        return metainfo
+
+    def full_init(self):
+        """Load annotation file and set ``BaseDataset._fully_initialized`` to
+        True."""
+        super().full_init()
+
+        # To support the standard OpenMMLab 2.0 annotation format, generate
+        # metainfo in the internal format from the standard metainfo format.
+        if 'categories' in self._metainfo and 'classes' not in self._metainfo:
+            categories = sorted(
+                self._metainfo['categories'], key=lambda x: x['id'])
+            self._metainfo['classes'] = tuple(
+                [cat['category_name'] for cat in categories])
+
+    def __repr__(self):
+        """Return the basic information of the dataset.
+
+        Returns:
+            str: Formatted string.
+        """
+        head = 'Dataset ' + self.__class__.__name__
+        body = []
+        if self._fully_initialized:
+            body.append(f'Number of samples: \t{self.__len__()}')
+        else:
+            body.append("Haven't been initialized")
+
+        if self.CLASSES is not None:
+            body.append(f'Number of categories: \t{len(self.CLASSES)}')
+
+        body.extend(self.extra_repr())
+
+        if len(self.pipeline.transforms) > 0:
+            body.append('With transforms:')
+            for t in self.pipeline.transforms:
+                body.append(f'    {t}')
+
+        lines = [head] + [' ' * 4 + line for line in body]
+        return '\n'.join(lines)
+
+    def extra_repr(self) -> List[str]:
+        """The extra repr information of the dataset."""
+        body = []
+        body.append(f'Annotation file: \t{self.ann_file}')
+        body.append(f'Prefix of images: \t{self.img_prefix}')
+        return body
diff --git a/mmpretrain/datasets/builder.py b/mmpretrain/datasets/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfa3872fe9931a4946368f07dfc5f5913a3e1f9f
--- /dev/null
+++ b/mmpretrain/datasets/builder.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpretrain.registry import DATASETS
+
+
+def build_dataset(cfg):
+    """Build dataset.
+
+    Examples:
+        >>> from mmpretrain.datasets import build_dataset
+        >>> mnist_train = build_dataset(
+        ...     dict(type='MNIST', data_prefix='data/mnist/', test_mode=False))
+        >>> print(mnist_train)
+        Dataset MNIST
+        Number of samples: 60000
+        Number of categories: 10
+        Prefix of data: data/mnist/
+        >>> mnist_test = build_dataset(
+        ...     dict(type='MNIST', data_prefix='data/mnist/', test_mode=True))
+        >>> print(mnist_test)
+        Dataset MNIST
+        Number of samples: 10000
+        Number of categories: 10
+        Prefix of data: data/mnist/
+    """
+    return DATASETS.build(cfg)
diff --git a/mmpretrain/datasets/caltech101.py b/mmpretrain/datasets/caltech101.py
new file mode 100644
index 0000000000000000000000000000000000000000..71e5de85ff3bbf73c387a071f47113b46be36e2a
--- /dev/null
+++ b/mmpretrain/datasets/caltech101.py
@@ -0,0 +1,113 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+from mmengine import get_file_backend, list_from_file
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+from .categories import CALTECH101_CATEGORIES
+
+
+@DATASETS.register_module()
+class Caltech101(BaseDataset):
+    """The Caltech101 Dataset.
+
+    Support the `Caltech101 `_ Dataset.
+    After downloading and decompression, the dataset directory structure is
+    as follows.
+
+    Caltech101 dataset directory: ::
+
+        caltech-101
+        ├── 101_ObjectCategories
+        │   ├── class_x
+        │   │   ├── xx1.jpg
+        │   │   ├── xx2.jpg
+        │   │   └── ...
+        │   ├── class_y
+        │   │   ├── yy1.jpg
+        │   │   ├── yy2.jpg
+        │   │   └── ...
+        │   └── ...
+        ├── Annotations
+        │   ├── class_x
+        │   │   ├── xx1.mat
+        │   │   └── ...
+        │   └── ...
+        ├── meta
+        │   ├── train.txt
+        │   └── test.txt
+        └── ....
+
+    Please note that since there is no official split for the training and
+    test sets, you can use the train.txt and test.txt provided by us or
+    create your own annotation files. Each line of an annotation file is an
+    image path relative to ``101_ObjectCategories`` followed by its integer
+    class label. Here is the download
+    `link `_
+    for the annotations.
+
+    Args:
+        data_root (str): The root directory for the Caltech101 dataset.
+        split (str, optional): The dataset split, supports "train" and
+            "test". Defaults to "train".
+
+    Examples:
+        >>> from mmpretrain.datasets import Caltech101
+        >>> train_dataset = Caltech101(data_root='data/caltech-101', split='train')
+        >>> train_dataset
+        Dataset Caltech101
+        Number of samples: 3060
+        Number of categories: 102
+        Root of dataset: data/caltech-101
+        >>> test_dataset = Caltech101(data_root='data/caltech-101', split='test')
+        >>> test_dataset
+        Dataset Caltech101
+        Number of samples: 6728
+        Number of categories: 102
+        Root of dataset: data/caltech-101
+    """  # noqa: E501
+
+    METAINFO = {'classes': CALTECH101_CATEGORIES}
+
+    def __init__(self, data_root: str, split: str = 'train', **kwargs):
+
+        splits = ['train', 'test']
+        assert split in splits, \
+            f"The split must be one of {splits}, but got '{split}'"
+        self.split = split
+
+        self.backend = get_file_backend(data_root, enable_singleton=True)
+
+        if split == 'train':
+            ann_file = self.backend.join_path('meta', 'train.txt')
+        else:
+            ann_file = self.backend.join_path('meta', 'test.txt')
+
+        data_prefix = '101_ObjectCategories'
+        test_mode = split == 'test'
+
+        super(Caltech101, self).__init__(
+            ann_file=ann_file,
+            data_root=data_root,
+            data_prefix=data_prefix,
+            test_mode=test_mode,
+            **kwargs)
+
+    def load_data_list(self):
+        """Load images and ground truth labels."""
+
+        pairs = list_from_file(self.ann_file)
+        data_list = []
+
+        for pair in pairs:
+            path, gt_label = pair.split()
+            img_path = self.backend.join_path(self.img_prefix, path)
+            info = dict(img_path=img_path, gt_label=int(gt_label))
+            data_list.append(info)
+
+        return data_list
+
+    def extra_repr(self) -> List[str]:
+        """The extra repr information of the dataset."""
+        body = [
+            f'Root of dataset: \t{self.data_root}',
+        ]
+        return body
diff --git a/mmpretrain/datasets/categories.py b/mmpretrain/datasets/categories.py
new file mode 100644
index
0000000000000000000000000000000000000000..9e75f7953b8f41750e2517d28c76047bfe37330a --- /dev/null +++ b/mmpretrain/datasets/categories.py @@ -0,0 +1,1661 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Pre-defined categories names of various datasets. + +VOC2007_CATEGORIES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', + 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', + 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', + 'sofa', 'train', 'tvmonitor') + +CUB_CATEGORIES = ( + 'Black_footed_Albatross', 'Laysan_Albatross', 'Sooty_Albatross', + 'Groove_billed_Ani', 'Crested_Auklet', 'Least_Auklet', 'Parakeet_Auklet', + 'Rhinoceros_Auklet', 'Brewer_Blackbird', 'Red_winged_Blackbird', + 'Rusty_Blackbird', 'Yellow_headed_Blackbird', 'Bobolink', 'Indigo_Bunting', + 'Lazuli_Bunting', 'Painted_Bunting', 'Cardinal', 'Spotted_Catbird', + 'Gray_Catbird', 'Yellow_breasted_Chat', 'Eastern_Towhee', + 'Chuck_will_Widow', 'Brandt_Cormorant', 'Red_faced_Cormorant', + 'Pelagic_Cormorant', 'Bronzed_Cowbird', 'Shiny_Cowbird', 'Brown_Creeper', + 'American_Crow', 'Fish_Crow', 'Black_billed_Cuckoo', 'Mangrove_Cuckoo', + 'Yellow_billed_Cuckoo', 'Gray_crowned_Rosy_Finch', 'Purple_Finch', + 'Northern_Flicker', 'Acadian_Flycatcher', 'Great_Crested_Flycatcher', + 'Least_Flycatcher', 'Olive_sided_Flycatcher', 'Scissor_tailed_Flycatcher', + 'Vermilion_Flycatcher', 'Yellow_bellied_Flycatcher', 'Frigatebird', + 'Northern_Fulmar', 'Gadwall', 'American_Goldfinch', 'European_Goldfinch', + 'Boat_tailed_Grackle', 'Eared_Grebe', 'Horned_Grebe', 'Pied_billed_Grebe', + 'Western_Grebe', 'Blue_Grosbeak', 'Evening_Grosbeak', 'Pine_Grosbeak', + 'Rose_breasted_Grosbeak', 'Pigeon_Guillemot', 'California_Gull', + 'Glaucous_winged_Gull', 'Heermann_Gull', 'Herring_Gull', 'Ivory_Gull', + 'Ring_billed_Gull', 'Slaty_backed_Gull', 'Western_Gull', + 'Anna_Hummingbird', 'Ruby_throated_Hummingbird', 'Rufous_Hummingbird', + 'Green_Violetear', 'Long_tailed_Jaeger', 'Pomarine_Jaeger', 'Blue_Jay', + 'Florida_Jay', 'Green_Jay', 'Dark_eyed_Junco', 'Tropical_Kingbird', + 'Gray_Kingbird', 'Belted_Kingfisher', 'Green_Kingfisher', + 'Pied_Kingfisher', 'Ringed_Kingfisher', 'White_breasted_Kingfisher', + 'Red_legged_Kittiwake', 'Horned_Lark', 'Pacific_Loon', 'Mallard', + 'Western_Meadowlark', 'Hooded_Merganser', 'Red_breasted_Merganser', + 'Mockingbird', 'Nighthawk', 'Clark_Nutcracker', 'White_breasted_Nuthatch', + 'Baltimore_Oriole', 'Hooded_Oriole', 'Orchard_Oriole', 'Scott_Oriole', + 'Ovenbird', 'Brown_Pelican', 'White_Pelican', 'Western_Wood_Pewee', + 'Sayornis', 'American_Pipit', 'Whip_poor_Will', 'Horned_Puffin', + 'Common_Raven', 'White_necked_Raven', 'American_Redstart', 'Geococcyx', + 'Loggerhead_Shrike', 'Great_Grey_Shrike', 'Baird_Sparrow', + 'Black_throated_Sparrow', 'Brewer_Sparrow', 'Chipping_Sparrow', + 'Clay_colored_Sparrow', 'House_Sparrow', 'Field_Sparrow', 'Fox_Sparrow', + 'Grasshopper_Sparrow', 'Harris_Sparrow', 'Henslow_Sparrow', + 'Le_Conte_Sparrow', 'Lincoln_Sparrow', 'Nelson_Sharp_tailed_Sparrow', + 'Savannah_Sparrow', 'Seaside_Sparrow', 'Song_Sparrow', 'Tree_Sparrow', + 'Vesper_Sparrow', 'White_crowned_Sparrow', 'White_throated_Sparrow', + 'Cape_Glossy_Starling', 'Bank_Swallow', 'Barn_Swallow', 'Cliff_Swallow', + 'Tree_Swallow', 'Scarlet_Tanager', 'Summer_Tanager', 'Artic_Tern', + 'Black_Tern', 'Caspian_Tern', 'Common_Tern', 'Elegant_Tern', + 'Forsters_Tern', 'Least_Tern', 'Green_tailed_Towhee', 'Brown_Thrasher', + 'Sage_Thrasher', 'Black_capped_Vireo', 'Blue_headed_Vireo', + 'Philadelphia_Vireo', 
'Red_eyed_Vireo', 'Warbling_Vireo', + 'White_eyed_Vireo', 'Yellow_throated_Vireo', 'Bay_breasted_Warbler', + 'Black_and_white_Warbler', 'Black_throated_Blue_Warbler', + 'Blue_winged_Warbler', 'Canada_Warbler', 'Cape_May_Warbler', + 'Cerulean_Warbler', 'Chestnut_sided_Warbler', 'Golden_winged_Warbler', + 'Hooded_Warbler', 'Kentucky_Warbler', 'Magnolia_Warbler', + 'Mourning_Warbler', 'Myrtle_Warbler', 'Nashville_Warbler', + 'Orange_crowned_Warbler', 'Palm_Warbler', 'Pine_Warbler', + 'Prairie_Warbler', 'Prothonotary_Warbler', 'Swainson_Warbler', + 'Tennessee_Warbler', 'Wilson_Warbler', 'Worm_eating_Warbler', + 'Yellow_Warbler', 'Northern_Waterthrush', 'Louisiana_Waterthrush', + 'Bohemian_Waxwing', 'Cedar_Waxwing', 'American_Three_toed_Woodpecker', + 'Pileated_Woodpecker', 'Red_bellied_Woodpecker', 'Red_cockaded_Woodpecker', + 'Red_headed_Woodpecker', 'Downy_Woodpecker', 'Bewick_Wren', 'Cactus_Wren', + 'Carolina_Wren', 'House_Wren', 'Marsh_Wren', 'Rock_Wren', 'Winter_Wren', + 'Common_Yellowthroat') + +IMAGENET_CATEGORIES = ( + 'tench, Tinca tinca', + 'goldfish, Carassius auratus', + 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', # noqa: E501 + 'tiger shark, Galeocerdo cuvieri', + 'hammerhead, hammerhead shark', + 'electric ray, crampfish, numbfish, torpedo', + 'stingray', + 'cock', + 'hen', + 'ostrich, Struthio camelus', + 'brambling, Fringilla montifringilla', + 'goldfinch, Carduelis carduelis', + 'house finch, linnet, Carpodacus mexicanus', + 'junco, snowbird', + 'indigo bunting, indigo finch, indigo bird, Passerina cyanea', + 'robin, American robin, Turdus migratorius', + 'bulbul', + 'jay', + 'magpie', + 'chickadee', + 'water ouzel, dipper', + 'kite', + 'bald eagle, American eagle, Haliaeetus leucocephalus', + 'vulture', + 'great grey owl, great gray owl, Strix nebulosa', + 'European fire salamander, Salamandra salamandra', + 'common newt, Triturus vulgaris', + 'eft', + 'spotted salamander, Ambystoma maculatum', + 'axolotl, mud puppy, Ambystoma mexicanum', + 'bullfrog, Rana catesbeiana', + 'tree frog, tree-frog', + 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui', + 'loggerhead, loggerhead turtle, Caretta caretta', + 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea', # noqa: E501 + 'mud turtle', + 'terrapin', + 'box turtle, box tortoise', + 'banded gecko', + 'common iguana, iguana, Iguana iguana', + 'American chameleon, anole, Anolis carolinensis', + 'whiptail, whiptail lizard', + 'agama', + 'frilled lizard, Chlamydosaurus kingi', + 'alligator lizard', + 'Gila monster, Heloderma suspectum', + 'green lizard, Lacerta viridis', + 'African chameleon, Chamaeleo chamaeleon', + 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis', # noqa: E501 + 'African crocodile, Nile crocodile, Crocodylus niloticus', + 'American alligator, Alligator mississipiensis', + 'triceratops', + 'thunder snake, worm snake, Carphophis amoenus', + 'ringneck snake, ring-necked snake, ring snake', + 'hognose snake, puff adder, sand viper', + 'green snake, grass snake', + 'king snake, kingsnake', + 'garter snake, grass snake', + 'water snake', + 'vine snake', + 'night snake, Hypsiglena torquata', + 'boa constrictor, Constrictor constrictor', + 'rock python, rock snake, Python sebae', + 'Indian cobra, Naja naja', + 'green mamba', + 'sea snake', + 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus', + 'diamondback, diamondback rattlesnake, Crotalus adamanteus', + 'sidewinder, horned rattlesnake, Crotalus 
cerastes', + 'trilobite', + 'harvestman, daddy longlegs, Phalangium opilio', + 'scorpion', + 'black and gold garden spider, Argiope aurantia', + 'barn spider, Araneus cavaticus', + 'garden spider, Aranea diademata', + 'black widow, Latrodectus mactans', + 'tarantula', + 'wolf spider, hunting spider', + 'tick', + 'centipede', + 'black grouse', + 'ptarmigan', + 'ruffed grouse, partridge, Bonasa umbellus', + 'prairie chicken, prairie grouse, prairie fowl', + 'peacock', + 'quail', + 'partridge', + 'African grey, African gray, Psittacus erithacus', + 'macaw', + 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita', + 'lorikeet', + 'coucal', + 'bee eater', + 'hornbill', + 'hummingbird', + 'jacamar', + 'toucan', + 'drake', + 'red-breasted merganser, Mergus serrator', + 'goose', + 'black swan, Cygnus atratus', + 'tusker', + 'echidna, spiny anteater, anteater', + 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus', # noqa: E501 + 'wallaby, brush kangaroo', + 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus', # noqa: E501 + 'wombat', + 'jellyfish', + 'sea anemone, anemone', + 'brain coral', + 'flatworm, platyhelminth', + 'nematode, nematode worm, roundworm', + 'conch', + 'snail', + 'slug', + 'sea slug, nudibranch', + 'chiton, coat-of-mail shell, sea cradle, polyplacophore', + 'chambered nautilus, pearly nautilus, nautilus', + 'Dungeness crab, Cancer magister', + 'rock crab, Cancer irroratus', + 'fiddler crab', + 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica', # noqa: E501 + 'American lobster, Northern lobster, Maine lobster, Homarus americanus', # noqa: E501 + 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish', # noqa: E501 + 'crayfish, crawfish, crawdad, crawdaddy', + 'hermit crab', + 'isopod', + 'white stork, Ciconia ciconia', + 'black stork, Ciconia nigra', + 'spoonbill', + 'flamingo', + 'little blue heron, Egretta caerulea', + 'American egret, great white heron, Egretta albus', + 'bittern', + 'crane', + 'limpkin, Aramus pictus', + 'European gallinule, Porphyrio porphyrio', + 'American coot, marsh hen, mud hen, water hen, Fulica americana', + 'bustard', + 'ruddy turnstone, Arenaria interpres', + 'red-backed sandpiper, dunlin, Erolia alpina', + 'redshank, Tringa totanus', + 'dowitcher', + 'oystercatcher, oyster catcher', + 'pelican', + 'king penguin, Aptenodytes patagonica', + 'albatross, mollymawk', + 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus', # noqa: E501 + 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca', + 'dugong, Dugong dugon', + 'sea lion', + 'Chihuahua', + 'Japanese spaniel', + 'Maltese dog, Maltese terrier, Maltese', + 'Pekinese, Pekingese, Peke', + 'Shih-Tzu', + 'Blenheim spaniel', + 'papillon', + 'toy terrier', + 'Rhodesian ridgeback', + 'Afghan hound, Afghan', + 'basset, basset hound', + 'beagle', + 'bloodhound, sleuthhound', + 'bluetick', + 'black-and-tan coonhound', + 'Walker hound, Walker foxhound', + 'English foxhound', + 'redbone', + 'borzoi, Russian wolfhound', + 'Irish wolfhound', + 'Italian greyhound', + 'whippet', + 'Ibizan hound, Ibizan Podenco', + 'Norwegian elkhound, elkhound', + 'otterhound, otter hound', + 'Saluki, gazelle hound', + 'Scottish deerhound, deerhound', + 'Weimaraner', + 'Staffordshire bullterrier, Staffordshire bull terrier', + 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier', # noqa: E501 + 'Bedlington terrier', + 'Border 
terrier', + 'Kerry blue terrier', + 'Irish terrier', + 'Norfolk terrier', + 'Norwich terrier', + 'Yorkshire terrier', + 'wire-haired fox terrier', + 'Lakeland terrier', + 'Sealyham terrier, Sealyham', + 'Airedale, Airedale terrier', + 'cairn, cairn terrier', + 'Australian terrier', + 'Dandie Dinmont, Dandie Dinmont terrier', + 'Boston bull, Boston terrier', + 'miniature schnauzer', + 'giant schnauzer', + 'standard schnauzer', + 'Scotch terrier, Scottish terrier, Scottie', + 'Tibetan terrier, chrysanthemum dog', + 'silky terrier, Sydney silky', + 'soft-coated wheaten terrier', + 'West Highland white terrier', + 'Lhasa, Lhasa apso', + 'flat-coated retriever', + 'curly-coated retriever', + 'golden retriever', + 'Labrador retriever', + 'Chesapeake Bay retriever', + 'German short-haired pointer', + 'vizsla, Hungarian pointer', + 'English setter', + 'Irish setter, red setter', + 'Gordon setter', + 'Brittany spaniel', + 'clumber, clumber spaniel', + 'English springer, English springer spaniel', + 'Welsh springer spaniel', + 'cocker spaniel, English cocker spaniel, cocker', + 'Sussex spaniel', + 'Irish water spaniel', + 'kuvasz', + 'schipperke', + 'groenendael', + 'malinois', + 'briard', + 'kelpie', + 'komondor', + 'Old English sheepdog, bobtail', + 'Shetland sheepdog, Shetland sheep dog, Shetland', + 'collie', + 'Border collie', + 'Bouvier des Flandres, Bouviers des Flandres', + 'Rottweiler', + 'German shepherd, German shepherd dog, German police dog, alsatian', + 'Doberman, Doberman pinscher', + 'miniature pinscher', + 'Greater Swiss Mountain dog', + 'Bernese mountain dog', + 'Appenzeller', + 'EntleBucher', + 'boxer', + 'bull mastiff', + 'Tibetan mastiff', + 'French bulldog', + 'Great Dane', + 'Saint Bernard, St Bernard', + 'Eskimo dog, husky', + 'malamute, malemute, Alaskan malamute', + 'Siberian husky', + 'dalmatian, coach dog, carriage dog', + 'affenpinscher, monkey pinscher, monkey dog', + 'basenji', + 'pug, pug-dog', + 'Leonberg', + 'Newfoundland, Newfoundland dog', + 'Great Pyrenees', + 'Samoyed, Samoyede', + 'Pomeranian', + 'chow, chow chow', + 'keeshond', + 'Brabancon griffon', + 'Pembroke, Pembroke Welsh corgi', + 'Cardigan, Cardigan Welsh corgi', + 'toy poodle', + 'miniature poodle', + 'standard poodle', + 'Mexican hairless', + 'timber wolf, grey wolf, gray wolf, Canis lupus', + 'white wolf, Arctic wolf, Canis lupus tundrarum', + 'red wolf, maned wolf, Canis rufus, Canis niger', + 'coyote, prairie wolf, brush wolf, Canis latrans', + 'dingo, warrigal, warragal, Canis dingo', + 'dhole, Cuon alpinus', + 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus', + 'hyena, hyaena', + 'red fox, Vulpes vulpes', + 'kit fox, Vulpes macrotis', + 'Arctic fox, white fox, Alopex lagopus', + 'grey fox, gray fox, Urocyon cinereoargenteus', + 'tabby, tabby cat', + 'tiger cat', + 'Persian cat', + 'Siamese cat, Siamese', + 'Egyptian cat', + 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor', # noqa: E501 + 'lynx, catamount', + 'leopard, Panthera pardus', + 'snow leopard, ounce, Panthera uncia', + 'jaguar, panther, Panthera onca, Felis onca', + 'lion, king of beasts, Panthera leo', + 'tiger, Panthera tigris', + 'cheetah, chetah, Acinonyx jubatus', + 'brown bear, bruin, Ursus arctos', + 'American black bear, black bear, Ursus americanus, Euarctos americanus', # noqa: E501 + 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus', + 'sloth bear, Melursus ursinus, Ursus ursinus', + 'mongoose', + 'meerkat, mierkat', + 'tiger beetle', + 'ladybug, ladybeetle, lady beetle, 
ladybird, ladybird beetle', + 'ground beetle, carabid beetle', + 'long-horned beetle, longicorn, longicorn beetle', + 'leaf beetle, chrysomelid', + 'dung beetle', + 'rhinoceros beetle', + 'weevil', + 'fly', + 'bee', + 'ant, emmet, pismire', + 'grasshopper, hopper', + 'cricket', + 'walking stick, walkingstick, stick insect', + 'cockroach, roach', + 'mantis, mantid', + 'cicada, cicala', + 'leafhopper', + 'lacewing, lacewing fly', + "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", # noqa: E501 + 'damselfly', + 'admiral', + 'ringlet, ringlet butterfly', + 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus', + 'cabbage butterfly', + 'sulphur butterfly, sulfur butterfly', + 'lycaenid, lycaenid butterfly', + 'starfish, sea star', + 'sea urchin', + 'sea cucumber, holothurian', + 'wood rabbit, cottontail, cottontail rabbit', + 'hare', + 'Angora, Angora rabbit', + 'hamster', + 'porcupine, hedgehog', + 'fox squirrel, eastern fox squirrel, Sciurus niger', + 'marmot', + 'beaver', + 'guinea pig, Cavia cobaya', + 'sorrel', + 'zebra', + 'hog, pig, grunter, squealer, Sus scrofa', + 'wild boar, boar, Sus scrofa', + 'warthog', + 'hippopotamus, hippo, river horse, Hippopotamus amphibius', + 'ox', + 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis', + 'bison', + 'ram, tup', + 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis', # noqa: E501 + 'ibex, Capra ibex', + 'hartebeest', + 'impala, Aepyceros melampus', + 'gazelle', + 'Arabian camel, dromedary, Camelus dromedarius', + 'llama', + 'weasel', + 'mink', + 'polecat, fitch, foulmart, foumart, Mustela putorius', + 'black-footed ferret, ferret, Mustela nigripes', + 'otter', + 'skunk, polecat, wood pussy', + 'badger', + 'armadillo', + 'three-toed sloth, ai, Bradypus tridactylus', + 'orangutan, orang, orangutang, Pongo pygmaeus', + 'gorilla, Gorilla gorilla', + 'chimpanzee, chimp, Pan troglodytes', + 'gibbon, Hylobates lar', + 'siamang, Hylobates syndactylus, Symphalangus syndactylus', + 'guenon, guenon monkey', + 'patas, hussar monkey, Erythrocebus patas', + 'baboon', + 'macaque', + 'langur', + 'colobus, colobus monkey', + 'proboscis monkey, Nasalis larvatus', + 'marmoset', + 'capuchin, ringtail, Cebus capucinus', + 'howler monkey, howler', + 'titi, titi monkey', + 'spider monkey, Ateles geoffroyi', + 'squirrel monkey, Saimiri sciureus', + 'Madagascar cat, ring-tailed lemur, Lemur catta', + 'indri, indris, Indri indri, Indri brevicaudatus', + 'Indian elephant, Elephas maximus', + 'African elephant, Loxodonta africana', + 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens', + 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca', + 'barracouta, snoek', + 'eel', + 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch', # noqa: E501 + 'rock beauty, Holocanthus tricolor', + 'anemone fish', + 'sturgeon', + 'gar, garfish, garpike, billfish, Lepisosteus osseus', + 'lionfish', + 'puffer, pufferfish, blowfish, globefish', + 'abacus', + 'abaya', + "academic gown, academic robe, judge's robe", + 'accordion, piano accordion, squeeze box', + 'acoustic guitar', + 'aircraft carrier, carrier, flattop, attack aircraft carrier', + 'airliner', + 'airship, dirigible', + 'altar', + 'ambulance', + 'amphibian, amphibious vehicle', + 'analog clock', + 'apiary, bee house', + 'apron', + 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin', # 
noqa: E501 + 'assault rifle, assault gun', + 'backpack, back pack, knapsack, packsack, rucksack, haversack', + 'bakery, bakeshop, bakehouse', + 'balance beam, beam', + 'balloon', + 'ballpoint, ballpoint pen, ballpen, Biro', + 'Band Aid', + 'banjo', + 'bannister, banister, balustrade, balusters, handrail', + 'barbell', + 'barber chair', + 'barbershop', + 'barn', + 'barometer', + 'barrel, cask', + 'barrow, garden cart, lawn cart, wheelbarrow', + 'baseball', + 'basketball', + 'bassinet', + 'bassoon', + 'bathing cap, swimming cap', + 'bath towel', + 'bathtub, bathing tub, bath, tub', + 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon', # noqa: E501 + 'beacon, lighthouse, beacon light, pharos', + 'beaker', + 'bearskin, busby, shako', + 'beer bottle', + 'beer glass', + 'bell cote, bell cot', + 'bib', + 'bicycle-built-for-two, tandem bicycle, tandem', + 'bikini, two-piece', + 'binder, ring-binder', + 'binoculars, field glasses, opera glasses', + 'birdhouse', + 'boathouse', + 'bobsled, bobsleigh, bob', + 'bolo tie, bolo, bola tie, bola', + 'bonnet, poke bonnet', + 'bookcase', + 'bookshop, bookstore, bookstall', + 'bottlecap', + 'bow', + 'bow tie, bow-tie, bowtie', + 'brass, memorial tablet, plaque', + 'brassiere, bra, bandeau', + 'breakwater, groin, groyne, mole, bulwark, seawall, jetty', + 'breastplate, aegis, egis', + 'broom', + 'bucket, pail', + 'buckle', + 'bulletproof vest', + 'bullet train, bullet', + 'butcher shop, meat market', + 'cab, hack, taxi, taxicab', + 'caldron, cauldron', + 'candle, taper, wax light', + 'cannon', + 'canoe', + 'can opener, tin opener', + 'cardigan', + 'car mirror', + 'carousel, carrousel, merry-go-round, roundabout, whirligig', + "carpenter's kit, tool kit", + 'carton', + 'car wheel', + 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM', # noqa: E501 + 'cassette', + 'cassette player', + 'castle', + 'catamaran', + 'CD player', + 'cello, violoncello', + 'cellular telephone, cellular phone, cellphone, cell, mobile phone', + 'chain', + 'chainlink fence', + 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour', # noqa: E501 + 'chain saw, chainsaw', + 'chest', + 'chiffonier, commode', + 'chime, bell, gong', + 'china cabinet, china closet', + 'Christmas stocking', + 'church, church building', + 'cinema, movie theater, movie theatre, movie house, picture palace', + 'cleaver, meat cleaver, chopper', + 'cliff dwelling', + 'cloak', + 'clog, geta, patten, sabot', + 'cocktail shaker', + 'coffee mug', + 'coffeepot', + 'coil, spiral, volute, whorl, helix', + 'combination lock', + 'computer keyboard, keypad', + 'confectionery, confectionary, candy store', + 'container ship, containership, container vessel', + 'convertible', + 'corkscrew, bottle screw', + 'cornet, horn, trumpet, trump', + 'cowboy boot', + 'cowboy hat, ten-gallon hat', + 'cradle', + 'crane', + 'crash helmet', + 'crate', + 'crib, cot', + 'Crock Pot', + 'croquet ball', + 'crutch', + 'cuirass', + 'dam, dike, dyke', + 'desk', + 'desktop computer', + 'dial telephone, dial phone', + 'diaper, nappy, napkin', + 'digital clock', + 'digital watch', + 'dining table, board', + 'dishrag, dishcloth', + 'dishwasher, dish washer, dishwashing machine', + 'disk brake, disc brake', + 'dock, dockage, docking facility', + 'dogsled, dog sled, dog sleigh', + 'dome', + 'doormat, welcome mat', + 'drilling platform, offshore rig', + 'drum, membranophone, tympan', + 'drumstick', + 'dumbbell', + 'Dutch oven', 
+ 'electric fan, blower', + 'electric guitar', + 'electric locomotive', + 'entertainment center', + 'envelope', + 'espresso maker', + 'face powder', + 'feather boa, boa', + 'file, file cabinet, filing cabinet', + 'fireboat', + 'fire engine, fire truck', + 'fire screen, fireguard', + 'flagpole, flagstaff', + 'flute, transverse flute', + 'folding chair', + 'football helmet', + 'forklift', + 'fountain', + 'fountain pen', + 'four-poster', + 'freight car', + 'French horn, horn', + 'frying pan, frypan, skillet', + 'fur coat', + 'garbage truck, dustcart', + 'gasmask, respirator, gas helmet', + 'gas pump, gasoline pump, petrol pump, island dispenser', + 'goblet', + 'go-kart', + 'golf ball', + 'golfcart, golf cart', + 'gondola', + 'gong, tam-tam', + 'gown', + 'grand piano, grand', + 'greenhouse, nursery, glasshouse', + 'grille, radiator grille', + 'grocery store, grocery, food market, market', + 'guillotine', + 'hair slide', + 'hair spray', + 'half track', + 'hammer', + 'hamper', + 'hand blower, blow dryer, blow drier, hair dryer, hair drier', + 'hand-held computer, hand-held microcomputer', + 'handkerchief, hankie, hanky, hankey', + 'hard disc, hard disk, fixed disk', + 'harmonica, mouth organ, harp, mouth harp', + 'harp', + 'harvester, reaper', + 'hatchet', + 'holster', + 'home theater, home theatre', + 'honeycomb', + 'hook, claw', + 'hoopskirt, crinoline', + 'horizontal bar, high bar', + 'horse cart, horse-cart', + 'hourglass', + 'iPod', + 'iron, smoothing iron', + "jack-o'-lantern", + 'jean, blue jean, denim', + 'jeep, landrover', + 'jersey, T-shirt, tee shirt', + 'jigsaw puzzle', + 'jinrikisha, ricksha, rickshaw', + 'joystick', + 'kimono', + 'knee pad', + 'knot', + 'lab coat, laboratory coat', + 'ladle', + 'lampshade, lamp shade', + 'laptop, laptop computer', + 'lawn mower, mower', + 'lens cap, lens cover', + 'letter opener, paper knife, paperknife', + 'library', + 'lifeboat', + 'lighter, light, igniter, ignitor', + 'limousine, limo', + 'liner, ocean liner', + 'lipstick, lip rouge', + 'Loafer', + 'lotion', + 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system', # noqa: E501 + "loupe, jeweler's loupe", + 'lumbermill, sawmill', + 'magnetic compass', + 'mailbag, postbag', + 'mailbox, letter box', + 'maillot', + 'maillot, tank suit', + 'manhole cover', + 'maraca', + 'marimba, xylophone', + 'mask', + 'matchstick', + 'maypole', + 'maze, labyrinth', + 'measuring cup', + 'medicine chest, medicine cabinet', + 'megalith, megalithic structure', + 'microphone, mike', + 'microwave, microwave oven', + 'military uniform', + 'milk can', + 'minibus', + 'miniskirt, mini', + 'minivan', + 'missile', + 'mitten', + 'mixing bowl', + 'mobile home, manufactured home', + 'Model T', + 'modem', + 'monastery', + 'monitor', + 'moped', + 'mortar', + 'mortarboard', + 'mosque', + 'mosquito net', + 'motor scooter, scooter', + 'mountain bike, all-terrain bike, off-roader', + 'mountain tent', + 'mouse, computer mouse', + 'mousetrap', + 'moving van', + 'muzzle', + 'nail', + 'neck brace', + 'necklace', + 'nipple', + 'notebook, notebook computer', + 'obelisk', + 'oboe, hautboy, hautbois', + 'ocarina, sweet potato', + 'odometer, hodometer, mileometer, milometer', + 'oil filter', + 'organ, pipe organ', + 'oscilloscope, scope, cathode-ray oscilloscope, CRO', + 'overskirt', + 'oxcart', + 'oxygen mask', + 'packet', + 'paddle, boat paddle', + 'paddlewheel, paddle wheel', + 'padlock', + 'paintbrush', + "pajama, pyjama, pj's, jammies", + 'palace', + 'panpipe, pandean pipe, syrinx', + 'paper towel', + 'parachute, chute', 
+ 'parallel bars, bars', + 'park bench', + 'parking meter', + 'passenger car, coach, carriage', + 'patio, terrace', + 'pay-phone, pay-station', + 'pedestal, plinth, footstall', + 'pencil box, pencil case', + 'pencil sharpener', + 'perfume, essence', + 'Petri dish', + 'photocopier', + 'pick, plectrum, plectron', + 'pickelhaube', + 'picket fence, paling', + 'pickup, pickup truck', + 'pier', + 'piggy bank, penny bank', + 'pill bottle', + 'pillow', + 'ping-pong ball', + 'pinwheel', + 'pirate, pirate ship', + 'pitcher, ewer', + "plane, carpenter's plane, woodworking plane", + 'planetarium', + 'plastic bag', + 'plate rack', + 'plow, plough', + "plunger, plumber's helper", + 'Polaroid camera, Polaroid Land camera', + 'pole', + 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria', # noqa: E501 + 'poncho', + 'pool table, billiard table, snooker table', + 'pop bottle, soda bottle', + 'pot, flowerpot', + "potter's wheel", + 'power drill', + 'prayer rug, prayer mat', + 'printer', + 'prison, prison house', + 'projectile, missile', + 'projector', + 'puck, hockey puck', + 'punching bag, punch bag, punching ball, punchball', + 'purse', + 'quill, quill pen', + 'quilt, comforter, comfort, puff', + 'racer, race car, racing car', + 'racket, racquet', + 'radiator', + 'radio, wireless', + 'radio telescope, radio reflector', + 'rain barrel', + 'recreational vehicle, RV, R.V.', + 'reel', + 'reflex camera', + 'refrigerator, icebox', + 'remote control, remote', + 'restaurant, eating house, eating place, eatery', + 'revolver, six-gun, six-shooter', + 'rifle', + 'rocking chair, rocker', + 'rotisserie', + 'rubber eraser, rubber, pencil eraser', + 'rugby ball', + 'rule, ruler', + 'running shoe', + 'safe', + 'safety pin', + 'saltshaker, salt shaker', + 'sandal', + 'sarong', + 'sax, saxophone', + 'scabbard', + 'scale, weighing machine', + 'school bus', + 'schooner', + 'scoreboard', + 'screen, CRT screen', + 'screw', + 'screwdriver', + 'seat belt, seatbelt', + 'sewing machine', + 'shield, buckler', + 'shoe shop, shoe-shop, shoe store', + 'shoji', + 'shopping basket', + 'shopping cart', + 'shovel', + 'shower cap', + 'shower curtain', + 'ski', + 'ski mask', + 'sleeping bag', + 'slide rule, slipstick', + 'sliding door', + 'slot, one-armed bandit', + 'snorkel', + 'snowmobile', + 'snowplow, snowplough', + 'soap dispenser', + 'soccer ball', + 'sock', + 'solar dish, solar collector, solar furnace', + 'sombrero', + 'soup bowl', + 'space bar', + 'space heater', + 'space shuttle', + 'spatula', + 'speedboat', + "spider web, spider's web", + 'spindle', + 'sports car, sport car', + 'spotlight, spot', + 'stage', + 'steam locomotive', + 'steel arch bridge', + 'steel drum', + 'stethoscope', + 'stole', + 'stone wall', + 'stopwatch, stop watch', + 'stove', + 'strainer', + 'streetcar, tram, tramcar, trolley, trolley car', + 'stretcher', + 'studio couch, day bed', + 'stupa, tope', + 'submarine, pigboat, sub, U-boat', + 'suit, suit of clothes', + 'sundial', + 'sunglass', + 'sunglasses, dark glasses, shades', + 'sunscreen, sunblock, sun blocker', + 'suspension bridge', + 'swab, swob, mop', + 'sweatshirt', + 'swimming trunks, bathing trunks', + 'swing', + 'switch, electric switch, electrical switch', + 'syringe', + 'table lamp', + 'tank, army tank, armored combat vehicle, armoured combat vehicle', + 'tape player', + 'teapot', + 'teddy, teddy bear', + 'television, television system', + 'tennis ball', + 'thatch, thatched roof', + 'theater curtain, theatre curtain', + 'thimble', + 'thresher, thrasher, threshing machine', + 
'throne', + 'tile roof', + 'toaster', + 'tobacco shop, tobacconist shop, tobacconist', + 'toilet seat', + 'torch', + 'totem pole', + 'tow truck, tow car, wrecker', + 'toyshop', + 'tractor', + 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi', # noqa: E501 + 'tray', + 'trench coat', + 'tricycle, trike, velocipede', + 'trimaran', + 'tripod', + 'triumphal arch', + 'trolleybus, trolley coach, trackless trolley', + 'trombone', + 'tub, vat', + 'turnstile', + 'typewriter keyboard', + 'umbrella', + 'unicycle, monocycle', + 'upright, upright piano', + 'vacuum, vacuum cleaner', + 'vase', + 'vault', + 'velvet', + 'vending machine', + 'vestment', + 'viaduct', + 'violin, fiddle', + 'volleyball', + 'waffle iron', + 'wall clock', + 'wallet, billfold, notecase, pocketbook', + 'wardrobe, closet, press', + 'warplane, military plane', + 'washbasin, handbasin, washbowl, lavabo, wash-hand basin', + 'washer, automatic washer, washing machine', + 'water bottle', + 'water jug', + 'water tower', + 'whiskey jug', + 'whistle', + 'wig', + 'window screen', + 'window shade', + 'Windsor tie', + 'wine bottle', + 'wing', + 'wok', + 'wooden spoon', + 'wool, woolen, woollen', + 'worm fence, snake fence, snake-rail fence, Virginia fence', + 'wreck', + 'yawl', + 'yurt', + 'web site, website, internet site, site', + 'comic book', + 'crossword puzzle, crossword', + 'street sign', + 'traffic light, traffic signal, stoplight', + 'book jacket, dust cover, dust jacket, dust wrapper', + 'menu', + 'plate', + 'guacamole', + 'consomme', + 'hot pot, hotpot', + 'trifle', + 'ice cream, icecream', + 'ice lolly, lolly, lollipop, popsicle', + 'French loaf', + 'bagel, beigel', + 'pretzel', + 'cheeseburger', + 'hotdog, hot dog, red hot', + 'mashed potato', + 'head cabbage', + 'broccoli', + 'cauliflower', + 'zucchini, courgette', + 'spaghetti squash', + 'acorn squash', + 'butternut squash', + 'cucumber, cuke', + 'artichoke, globe artichoke', + 'bell pepper', + 'cardoon', + 'mushroom', + 'Granny Smith', + 'strawberry', + 'orange', + 'lemon', + 'fig', + 'pineapple, ananas', + 'banana', + 'jackfruit, jak, jack', + 'custard apple', + 'pomegranate', + 'hay', + 'carbonara', + 'chocolate sauce, chocolate syrup', + 'dough', + 'meat loaf, meatloaf', + 'pizza, pizza pie', + 'potpie', + 'burrito', + 'red wine', + 'espresso', + 'cup', + 'eggnog', + 'alp', + 'bubble', + 'cliff, drop, drop-off', + 'coral reef', + 'geyser', + 'lakeside, lakeshore', + 'promontory, headland, head, foreland', + 'sandbar, sand bar', + 'seashore, coast, seacoast, sea-coast', + 'valley, vale', + 'volcano', + 'ballplayer, baseball player', + 'groom, bridegroom', + 'scuba diver', + 'rapeseed', + 'daisy', + "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", # noqa: E501 + 'corn', + 'acorn', + 'hip, rose hip, rosehip', + 'buckeye, horse chestnut, conker', + 'coral fungus', + 'agaric', + 'gyromitra', + 'stinkhorn, carrion fungus', + 'earthstar', + 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa', # noqa: E501 + 'bolete', + 'ear, spike, capitulum', + 'toilet tissue, toilet paper, bathroom tissue') + +CIFAR10_CATEGORIES = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', + 'frog', 'horse', 'ship', 'truck') + +CIFAR100_CATEGORIES = ( + 'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', + 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', + 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', + 'cloud', 'cockroach', 
'couch', 'crab', 'crocodile', 'cup', 'dinosaur', + 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', + 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', + 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', + 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', + 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', + 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', + 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', + 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', + 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', + 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', + 'woman', 'worm') + +MNIST_CATEGORITES = ('0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', + '5 - five', '6 - six', '7 - seven', '8 - eight', + '9 - nine') + +FASHIONMNIST_CATEGORITES = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', + 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', + 'Ankle boot') + +PLACES205_CATEGORIES = ( + 'abbey', 'airport_terminal', 'alley', 'amphitheater', 'amusement_park', + 'aquarium', 'aqueduct', 'arch', 'art_gallery', 'art_studio', + 'assembly_line', 'attic', 'auditorium', 'apartment_building/outdoor', + 'badlands', 'ballroom', 'bamboo_forest', 'banquet_hall', 'bar', + 'baseball_field', 'basement', 'basilica', 'bayou', 'beauty_salon', + 'bedroom', 'boardwalk', 'boat_deck', 'bookstore', 'botanical_garden', + 'bowling_alley', 'boxing_ring', 'bridge', 'building_facade', + 'bus_interior', 'butchers_shop', 'butte', 'bakery/shop', 'cafeteria', + 'campsite', 'candy_store', 'canyon', 'castle', 'cemetery', 'chalet', + 'classroom', 'closet', 'clothing_store', 'coast', 'cockpit', 'coffee_shop', + 'conference_center', 'conference_room', 'construction_site', 'corn_field', + 'corridor', 'cottage_garden', 'courthouse', 'courtyard', 'creek', + 'crevasse', 'crosswalk', 'cathedral/outdoor', 'church/outdoor', 'dam', + 'dining_room', 'dock', 'dorm_room', 'driveway', 'desert/sand', + 'desert/vegetation', 'dinette/home', 'doorway/outdoor', 'engine_room', + 'excavation', 'fairway', 'fire_escape', 'fire_station', 'food_court', + 'forest_path', 'forest_road', 'formal_garden', 'fountain', + 'field/cultivated', 'field/wild', 'galley', 'game_room', 'garbage_dump', + 'gas_station', 'gift_shop', 'golf_course', 'harbor', 'herb_garden', + 'highway', 'home_office', 'hospital', 'hospital_room', 'hot_spring', + 'hotel_room', 'hotel/outdoor', 'ice_cream_parlor', 'iceberg', 'igloo', + 'islet', 'ice_skating_rink/outdoor', 'inn/outdoor', 'jail_cell', 'kasbah', + 'kindergarden_classroom', 'kitchen', 'kitchenette', 'laundromat', + 'lighthouse', 'living_room', 'lobby', 'locker_room', 'mansion', 'marsh', + 'martial_arts_gym', 'mausoleum', 'medina', 'motel', 'mountain', + 'mountain_snowy', 'music_studio', 'market/outdoor', 'monastery/outdoor', + 'museum/indoor', 'nursery', 'ocean', 'office', 'office_building', + 'orchard', 'pagoda', 'palace', 'pantry', 'parking_lot', 'parlor', + 'pasture', 'patio', 'pavilion', 'phone_booth', 'picnic_area', 'playground', + 'plaza', 'pond', 'pulpit', 'racecourse', 'raft', 'railroad_track', + 'rainforest', 'reception', 'residential_neighborhood', 'restaurant', + 'restaurant_kitchen', 'restaurant_patio', 'rice_paddy', 'river', + 'rock_arch', 'rope_bridge', 'ruin', 'runway', 'sandbar', 'schoolhouse', + 'sea_cliff', 'shed', 'shoe_shop', 'shopfront', 'shower', 'ski_resort', + 'ski_slope', 'sky', 
'skyscraper', 'slum', 'snowfield', 'staircase', + 'supermarket', 'swamp', 'stadium/baseball', 'stadium/football', + 'stage/indoor', 'subway_station/platform', 'swimming_pool/outdoor', + 'television_studio', 'topiary_garden', 'tower', 'train_railway', + 'tree_farm', 'trench', 'temple/east_asia', 'temple/south_asia', + 'track/outdoor', 'train_station/platform', 'underwater/coral_reef', + 'valley', 'vegetable_garden', 'veranda', 'viaduct', 'volcano', + 'waiting_room', 'water_tower', 'watering_hole', 'wheat_field', 'wind_farm', + 'windmill', 'yard') + +OxfordIIITPet_CATEGORIES = ( + 'Abyssinian', 'american_bulldog', 'american_pit_bull_terrier', + 'basset_hound', 'beagle', 'Bengal', 'Birman', 'Bombay', 'boxer', + 'British_Shorthair', 'chihuahua', 'Egyptian_Mau', 'english_cocker_spaniel', + 'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese', + 'japanese_chin', 'keeshond', 'leonberger', 'Maine_Coon', + 'miniature_pinscher', 'newfoundland', 'Persian', 'pomeranian', 'pug', + 'Ragdoll', 'Russian_Blue', 'saint_bernard', 'samoyed', 'scottish_terrier', + 'shiba_inu', 'Siamese', 'Sphynx', 'staffordshire_bull_terrier', + 'wheaten_terrier', 'yorkshire_terrier') + +DTD_CATEGORIES = ('banded', 'blotchy', 'braided', 'bubbly', 'bumpy', + 'chequered', 'cobwebbed', 'cracked', 'crosshatched', + 'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled', + 'frilly', 'gauzy', 'grid', 'grooved', 'honeycombed', + 'interlaced', 'knitted', 'lacelike', 'lined', 'marbled', + 'matted', 'meshed', 'paisley', 'perforated', 'pitted', + 'pleated', 'polka-dotted', 'porous', 'potholed', 'scaly', + 'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified', + 'striped', 'studded', 'swirly', 'veined', 'waffled', 'woven', + 'wrinkled', 'zigzagged') + +FGVCAIRCRAFT_CATEGORIES = ( + '707-320', '727-200', '737-200', '737-300', '737-400', '737-500', + '737-600', '737-700', '737-800', '737-900', '747-100', '747-200', + '747-300', '747-400', '757-200', '757-300', '767-200', '767-300', + '767-400', '777-200', '777-300', 'A300B4', 'A310', 'A318', 'A319', 'A320', + 'A321', 'A330-200', 'A330-300', 'A340-200', 'A340-300', 'A340-500', + 'A340-600', 'A380', 'ATR-42', 'ATR-72', 'An-12', 'BAE 146-200', + 'BAE 146-300', 'BAE-125', 'Beechcraft 1900', 'Boeing 717', 'C-130', 'C-47', + 'CRJ-200', 'CRJ-700', 'CRJ-900', 'Cessna 172', 'Cessna 208', 'Cessna 525', + 'Cessna 560', 'Challenger 600', 'DC-10', 'DC-3', 'DC-6', 'DC-8', 'DC-9-30', + 'DH-82', 'DHC-1', 'DHC-6', 'DHC-8-100', 'DHC-8-300', 'DR-400', + 'Dornier 328', 'E-170', 'E-190', 'E-195', 'EMB-120', 'ERJ 135', 'ERJ 145', + 'Embraer Legacy 600', 'Eurofighter Typhoon', 'F-16A/B', 'F/A-18', + 'Falcon 2000', 'Falcon 900', 'Fokker 100', 'Fokker 50', 'Fokker 70', + 'Global Express', 'Gulfstream IV', 'Gulfstream V', 'Hawk T1', 'Il-76', + 'L-1011', 'MD-11', 'MD-80', 'MD-87', 'MD-90', 'Metroliner', 'Model B200', + 'PA-28', 'SR-20', 'Saab 2000', 'Saab 340', 'Spitfire', 'Tornado', 'Tu-134', + 'Tu-154', 'Yak-42') + +STANFORDCARS_CATEGORIES = ( + 'AM General Hummer SUV 2000', 'Acura RL Sedan 2012', 'Acura TL Sedan 2012', + 'Acura TL Type-S 2008', 'Acura TSX Sedan 2012', + 'Acura Integra Type R 2001', 'Acura ZDX Hatchback 2012', + 'Aston Martin V8 Vantage Convertible 2012', + 'Aston Martin V8 Vantage Coupe 2012', + 'Aston Martin Virage Convertible 2012', 'Aston Martin Virage Coupe 2012', + 'Audi RS 4 Convertible 2008', 'Audi A5 Coupe 2012', 'Audi TTS Coupe 2012', + 'Audi R8 Coupe 2012', 'Audi V8 Sedan 1994', 'Audi 100 Sedan 1994', + 'Audi 100 Wagon 1994', 'Audi TT Hatchback 2011', 'Audi S6 
Sedan 2011', + 'Audi S5 Convertible 2012', 'Audi S5 Coupe 2012', 'Audi S4 Sedan 2012', + 'Audi S4 Sedan 2007', 'Audi TT RS Coupe 2012', + 'BMW ActiveHybrid 5 Sedan 2012', 'BMW 1 Series Convertible 2012', + 'BMW 1 Series Coupe 2012', 'BMW 3 Series Sedan 2012', + 'BMW 3 Series Wagon 2012', 'BMW 6 Series Convertible 2007', + 'BMW X5 SUV 2007', 'BMW X6 SUV 2012', 'BMW M3 Coupe 2012', + 'BMW M5 Sedan 2010', 'BMW M6 Convertible 2010', 'BMW X3 SUV 2012', + 'BMW Z4 Convertible 2012', + 'Bentley Continental Supersports Conv. Convertible 2012', + 'Bentley Arnage Sedan 2009', 'Bentley Mulsanne Sedan 2011', + 'Bentley Continental GT Coupe 2012', 'Bentley Continental GT Coupe 2007', + 'Bentley Continental Flying Spur Sedan 2007', + 'Bugatti Veyron 16.4 Convertible 2009', 'Bugatti Veyron 16.4 Coupe 2009', + 'Buick Regal GS 2012', 'Buick Rainier SUV 2007', 'Buick Verano Sedan 2012', + 'Buick Enclave SUV 2012', 'Cadillac CTS-V Sedan 2012', + 'Cadillac SRX SUV 2012', 'Cadillac Escalade EXT Crew Cab 2007', + 'Chevrolet Silverado 1500 Hybrid Crew Cab 2012', + 'Chevrolet Corvette Convertible 2012', 'Chevrolet Corvette ZR1 2012', + 'Chevrolet Corvette Ron Fellows Edition Z06 2007', + 'Chevrolet Traverse SUV 2012', 'Chevrolet Camaro Convertible 2012', + 'Chevrolet HHR SS 2010', 'Chevrolet Impala Sedan 2007', + 'Chevrolet Tahoe Hybrid SUV 2012', 'Chevrolet Sonic Sedan 2012', + 'Chevrolet Express Cargo Van 2007', 'Chevrolet Avalanche Crew Cab 2012', + 'Chevrolet Cobalt SS 2010', 'Chevrolet Malibu Hybrid Sedan 2010', + 'Chevrolet TrailBlazer SS 2009', + 'Chevrolet Silverado 2500HD Regular Cab 2012', + 'Chevrolet Silverado 1500 Classic Extended Cab 2007', + 'Chevrolet Express Van 2007', 'Chevrolet Monte Carlo Coupe 2007', + 'Chevrolet Malibu Sedan 2007', + 'Chevrolet Silverado 1500 Extended Cab 2012', + 'Chevrolet Silverado 1500 Regular Cab 2012', 'Chrysler Aspen SUV 2009', + 'Chrysler Sebring Convertible 2010', + 'Chrysler Town and Country Minivan 2012', 'Chrysler 300 SRT-8 2010', + 'Chrysler Crossfire Convertible 2008', + 'Chrysler PT Cruiser Convertible 2008', 'Daewoo Nubira Wagon 2002', + 'Dodge Caliber Wagon 2012', 'Dodge Caliber Wagon 2007', + 'Dodge Caravan Minivan 1997', 'Dodge Ram Pickup 3500 Crew Cab 2010', + 'Dodge Ram Pickup 3500 Quad Cab 2009', 'Dodge Sprinter Cargo Van 2009', + 'Dodge Journey SUV 2012', 'Dodge Dakota Crew Cab 2010', + 'Dodge Dakota Club Cab 2007', 'Dodge Magnum Wagon 2008', + 'Dodge Challenger SRT8 2011', 'Dodge Durango SUV 2012', + 'Dodge Durango SUV 2007', 'Dodge Charger Sedan 2012', + 'Dodge Charger SRT-8 2009', 'Eagle Talon Hatchback 1998', + 'FIAT 500 Abarth 2012', 'FIAT 500 Convertible 2012', + 'Ferrari FF Coupe 2012', 'Ferrari California Convertible 2012', + 'Ferrari 458 Italia Convertible 2012', 'Ferrari 458 Italia Coupe 2012', + 'Fisker Karma Sedan 2012', 'Ford F-450 Super Duty Crew Cab 2012', + 'Ford Mustang Convertible 2007', 'Ford Freestar Minivan 2007', + 'Ford Expedition EL SUV 2009', 'Ford Edge SUV 2012', + 'Ford Ranger SuperCab 2011', 'Ford GT Coupe 2006', + 'Ford F-150 Regular Cab 2012', 'Ford F-150 Regular Cab 2007', + 'Ford Focus Sedan 2007', 'Ford E-Series Wagon Van 2012', + 'Ford Fiesta Sedan 2012', 'GMC Terrain SUV 2012', 'GMC Savana Van 2012', + 'GMC Yukon Hybrid SUV 2012', 'GMC Acadia SUV 2012', + 'GMC Canyon Extended Cab 2012', 'Geo Metro Convertible 1993', + 'HUMMER H3T Crew Cab 2010', 'HUMMER H2 SUT Crew Cab 2009', + 'Honda Odyssey Minivan 2012', 'Honda Odyssey Minivan 2007', + 'Honda Accord Coupe 2012', 'Honda Accord Sedan 2012', + 'Hyundai Veloster 
Hatchback 2012', 'Hyundai Santa Fe SUV 2012', + 'Hyundai Tucson SUV 2012', 'Hyundai Veracruz SUV 2012', + 'Hyundai Sonata Hybrid Sedan 2012', 'Hyundai Elantra Sedan 2007', + 'Hyundai Accent Sedan 2012', 'Hyundai Genesis Sedan 2012', + 'Hyundai Sonata Sedan 2012', 'Hyundai Elantra Touring Hatchback 2012', + 'Hyundai Azera Sedan 2012', 'Infiniti G Coupe IPL 2012', + 'Infiniti QX56 SUV 2011', 'Isuzu Ascender SUV 2008', 'Jaguar XK XKR 2012', + 'Jeep Patriot SUV 2012', 'Jeep Wrangler SUV 2012', 'Jeep Liberty SUV 2012', + 'Jeep Grand Cherokee SUV 2012', 'Jeep Compass SUV 2012', + 'Lamborghini Reventon Coupe 2008', 'Lamborghini Aventador Coupe 2012', + 'Lamborghini Gallardo LP 570-4 Superleggera 2012', + 'Lamborghini Diablo Coupe 2001', 'Land Rover Range Rover SUV 2012', + 'Land Rover LR2 SUV 2012', 'Lincoln Town Car Sedan 2011', + 'MINI Cooper Roadster Convertible 2012', + 'Maybach Landaulet Convertible 2012', 'Mazda Tribute SUV 2011', + 'McLaren MP4-12C Coupe 2012', 'Mercedes-Benz 300-Class Convertible 1993', + 'Mercedes-Benz C-Class Sedan 2012', 'Mercedes-Benz SL-Class Coupe 2009', + 'Mercedes-Benz E-Class Sedan 2012', 'Mercedes-Benz S-Class Sedan 2012', + 'Mercedes-Benz Sprinter Van 2012', 'Mitsubishi Lancer Sedan 2012', + 'Nissan Leaf Hatchback 2012', 'Nissan NV Passenger Van 2012', + 'Nissan Juke Hatchback 2012', 'Nissan 240SX Coupe 1998', + 'Plymouth Neon Coupe 1999', 'Porsche Panamera Sedan 2012', + 'Ram C/V Cargo Van Minivan 2012', + 'Rolls-Royce Phantom Drophead Coupe Convertible 2012', + 'Rolls-Royce Ghost Sedan 2012', 'Rolls-Royce Phantom Sedan 2012', + 'Scion xD Hatchback 2012', 'Spyker C8 Convertible 2009', + 'Spyker C8 Coupe 2009', 'Suzuki Aerio Sedan 2007', + 'Suzuki Kizashi Sedan 2012', 'Suzuki SX4 Hatchback 2012', + 'Suzuki SX4 Sedan 2012', 'Tesla Model S Sedan 2012', + 'Toyota Sequoia SUV 2012', 'Toyota Camry Sedan 2012', + 'Toyota Corolla Sedan 2012', 'Toyota 4Runner SUV 2012', + 'Volkswagen Golf Hatchback 2012', 'Volkswagen Golf Hatchback 1991', + 'Volkswagen Beetle Hatchback 2012', 'Volvo C30 Hatchback 2012', + 'Volvo 240 Sedan 1993', 'Volvo XC90 SUV 2007', + 'smart fortwo Convertible 2012') + +SUN397_CATEGORIES = ( + 'abbey', 'airplane_cabin', 'airport_terminal', 'alley', 'amphitheater', + 'amusement_arcade', 'amusement_park', 'anechoic_chamber', + 'apartment_building_outdoor', 'apse_indoor', 'aquarium', 'aqueduct', + 'arch', 'archive', 'arrival_gate_outdoor', 'art_gallery', 'art_school', + 'art_studio', 'assembly_line', 'athletic_field_outdoor', 'atrium_public', + 'attic', 'auditorium', 'auto_factory', 'badlands', + 'badminton_court_indoor', 'baggage_claim', 'bakery_shop', + 'balcony_exterior', 'balcony_interior', 'ball_pit', 'ballroom', + 'bamboo_forest', 'banquet_hall', 'bar', 'barn', 'barndoor', + 'baseball_field', 'basement', 'basilica', 'basketball_court_outdoor', + 'bathroom', 'batters_box', 'bayou', 'bazaar_indoor', 'bazaar_outdoor', + 'beach', 'beauty_salon', 'bedroom', 'berth', 'biology_laboratory', + 'bistro_indoor', 'boardwalk', 'boat_deck', 'boathouse', 'bookstore', + 'booth_indoor', 'botanical_garden', 'bow_window_indoor', + 'bow_window_outdoor', 'bowling_alley', 'boxing_ring', 'brewery_indoor', + 'bridge', 'building_facade', 'bullring', 'burial_chamber', 'bus_interior', + 'butchers_shop', 'butte', 'cabin_outdoor', 'cafeteria', 'campsite', + 'campus', 'canal_natural', 'canal_urban', 'candy_store', 'canyon', + 'car_interior_backseat', 'car_interior_frontseat', 'carrousel', + 'casino_indoor', 'castle', 'catacomb', 'cathedral_indoor', + 'cathedral_outdoor', 
'cavern_indoor', 'cemetery', 'chalet', + 'cheese_factory', 'chemistry_lab', 'chicken_coop_indoor', + 'chicken_coop_outdoor', 'childs_room', 'church_indoor', 'church_outdoor', + 'classroom', 'clean_room', 'cliff', 'cloister_indoor', 'closet', + 'clothing_store', 'coast', 'cockpit', 'coffee_shop', 'computer_room', + 'conference_center', 'conference_room', 'construction_site', + 'control_room', 'control_tower_outdoor', 'corn_field', 'corral', + 'corridor', 'cottage_garden', 'courthouse', 'courtroom', 'courtyard', + 'covered_bridge_exterior', 'creek', 'crevasse', 'crosswalk', + 'cubicle_office', 'dam', 'delicatessen', 'dentists_office', 'desert_sand', + 'desert_vegetation', 'diner_indoor', 'diner_outdoor', 'dinette_home', + 'dinette_vehicle', 'dining_car', 'dining_room', 'discotheque', 'dock', + 'doorway_outdoor', 'dorm_room', 'driveway', 'driving_range_outdoor', + 'drugstore', 'electrical_substation', 'elevator_door', 'elevator_interior', + 'elevator_shaft', 'engine_room', 'escalator_indoor', 'excavation', + 'factory_indoor', 'fairway', 'fastfood_restaurant', 'field_cultivated', + 'field_wild', 'fire_escape', 'fire_station', 'firing_range_indoor', + 'fishpond', 'florist_shop_indoor', 'food_court', 'forest_broadleaf', + 'forest_needleleaf', 'forest_path', 'forest_road', 'formal_garden', + 'fountain', 'galley', 'game_room', 'garage_indoor', 'garbage_dump', + 'gas_station', 'gazebo_exterior', 'general_store_indoor', + 'general_store_outdoor', 'gift_shop', 'golf_course', 'greenhouse_indoor', + 'greenhouse_outdoor', 'gymnasium_indoor', 'hangar_indoor', + 'hangar_outdoor', 'harbor', 'hayfield', 'heliport', 'herb_garden', + 'highway', 'hill', 'home_office', 'hospital', 'hospital_room', + 'hot_spring', 'hot_tub_outdoor', 'hotel_outdoor', 'hotel_room', 'house', + 'hunting_lodge_outdoor', 'ice_cream_parlor', 'ice_floe', 'ice_shelf', + 'ice_skating_rink_indoor', 'ice_skating_rink_outdoor', 'iceberg', 'igloo', + 'industrial_area', 'inn_outdoor', 'islet', 'jacuzzi_indoor', 'jail_indoor', + 'jail_cell', 'jewelry_shop', 'kasbah', 'kennel_indoor', 'kennel_outdoor', + 'kindergarden_classroom', 'kitchen', 'kitchenette', 'labyrinth_outdoor', + 'lake_natural', 'landfill', 'landing_deck', 'laundromat', 'lecture_room', + 'library_indoor', 'library_outdoor', 'lido_deck_outdoor', 'lift_bridge', + 'lighthouse', 'limousine_interior', 'living_room', 'lobby', 'lock_chamber', + 'locker_room', 'mansion', 'manufactured_home', 'market_indoor', + 'market_outdoor', 'marsh', 'martial_arts_gym', 'mausoleum', 'medina', + 'moat_water', 'monastery_outdoor', 'mosque_indoor', 'mosque_outdoor', + 'motel', 'mountain', 'mountain_snowy', 'movie_theater_indoor', + 'museum_indoor', 'music_store', 'music_studio', + 'nuclear_power_plant_outdoor', 'nursery', 'oast_house', + 'observatory_outdoor', 'ocean', 'office', 'office_building', + 'oil_refinery_outdoor', 'oilrig', 'operating_room', 'orchard', + 'outhouse_outdoor', 'pagoda', 'palace', 'pantry', 'park', + 'parking_garage_indoor', 'parking_garage_outdoor', 'parking_lot', 'parlor', + 'pasture', 'patio', 'pavilion', 'pharmacy', 'phone_booth', + 'physics_laboratory', 'picnic_area', 'pilothouse_indoor', + 'planetarium_outdoor', 'playground', 'playroom', 'plaza', 'podium_indoor', + 'podium_outdoor', 'pond', 'poolroom_establishment', 'poolroom_home', + 'power_plant_outdoor', 'promenade_deck', 'pub_indoor', 'pulpit', + 'putting_green', 'racecourse', 'raceway', 'raft', 'railroad_track', + 'rainforest', 'reception', 'recreation_room', 'residential_neighborhood', + 'restaurant', 
'restaurant_kitchen', 'restaurant_patio', 'rice_paddy', + 'riding_arena', 'river', 'rock_arch', 'rope_bridge', 'ruin', 'runway', + 'sandbar', 'sandbox', 'sauna', 'schoolhouse', 'sea_cliff', 'server_room', + 'shed', 'shoe_shop', 'shopfront', 'shopping_mall_indoor', 'shower', + 'skatepark', 'ski_lodge', 'ski_resort', 'ski_slope', 'sky', 'skyscraper', + 'slum', 'snowfield', 'squash_court', 'stable', 'stadium_baseball', + 'stadium_football', 'stage_indoor', 'staircase', 'street', + 'subway_interior', 'subway_station_platform', 'supermarket', 'sushi_bar', + 'swamp', 'swimming_pool_indoor', 'swimming_pool_outdoor', + 'synagogue_indoor', 'synagogue_outdoor', 'television_studio', + 'temple_east_asia', 'temple_south_asia', 'tennis_court_indoor', + 'tennis_court_outdoor', 'tent_outdoor', 'theater_indoor_procenium', + 'theater_indoor_seats', 'thriftshop', 'throne_room', 'ticket_booth', + 'toll_plaza', 'topiary_garden', 'tower', 'toyshop', 'track_outdoor', + 'train_railway', 'train_station_platform', 'tree_farm', 'tree_house', + 'trench', 'underwater_coral_reef', 'utility_room', 'valley', + 'van_interior', 'vegetable_garden', 'veranda', 'veterinarians_office', + 'viaduct', 'videostore', 'village', 'vineyard', 'volcano', + 'volleyball_court_indoor', 'volleyball_court_outdoor', 'waiting_room', + 'warehouse_indoor', 'water_tower', 'waterfall_block', 'waterfall_fan', + 'waterfall_plunge', 'watering_hole', 'wave', 'wet_bar', 'wheat_field', + 'wind_farm', 'windmill', 'wine_cellar_barrel_storage', + 'wine_cellar_bottle_storage', 'wrestling_ring_indoor', 'yard', + 'youth_hostel') + +CALTECH101_CATEGORIES = ( + 'BACKGROUND_Google', 'Faces', 'Faces_easy', 'Leopards', 'Motorbikes', + 'accordion', 'airplanes', 'anchor', 'ant', 'barrel', 'bass', 'beaver', + 'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly', + 'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone', 'chair', + 'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish', + 'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill', + 'dolphin', 'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium', + 'ewer', 'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk', + 'gramophone', 'grand_piano', 'hawksbill', 'headphone', 'hedgehog', + 'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 'ketch', + 'lamp', 'laptop', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly', + 'menorah', 'metronome', 'minaret', 'nautilus', 'octopus', 'okapi', + 'pagoda', 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', + 'rhino', 'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion', + 'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus', + 'stop_sign', 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', + 'watch', 'water_lilly', 'wheelchair', 'wild_cat', 'windsor_chair', + 'wrench', 'yin_yang') + +FOOD101_CATEGORIES = ( + 'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', + 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', + 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', + 'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry', + 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', + 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', + 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', + 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', + 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', + 'french_onion_soup', 
'french_toast', 'fried_calamari', 'fried_rice', + 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', + 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', + 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', + 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', + 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos', + 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', + 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', + 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', + 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad', + 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara', + 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', + 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles') + +CIFAR100_CATEGORIES_CN = ( + '苹果', '水族馆鱼', '婴儿', '熊', '河狸', '床', '蜜蜂', '甲虫', '自行车', '瓶子', '碗', '小男孩', + '桥', '公共汽车', '蝴蝶', '骆驼', '易拉罐', '城堡', '毛毛虫', '牛', '椅子', '猩猩', '钟', '白云', + '蟑螂', '沙发', '螃蟹', '鳄鱼', '杯子', '恐龙', '海豚', '大象', '比目鱼', '森林', '狐狸', '小女孩', + '仓鼠', '屋子', '袋鼠', '键盘', '台灯', '割草机', '猎豹', '狮子', '蜥蜴', '龙虾', '男人', '枫树', + '摩托车', '山', '老鼠', '蘑菇', '橡树', '橙子橘子', '兰花', '水獭', '棕榈树', '梨', '皮卡车', '松树', + '田野', '盘子', '罂粟', '豪猪', '负鼠', '兔子', '浣熊', '鳐鱼', '公路', '火箭', '玫瑰', '大海', + '海豹', '鲨鱼', '尖嘴小鼠', '臭鼬', '摩天大楼', '蜗牛', '蛇', '蜘蛛', '松鼠', '电车', '向日葵', '甜椒', + '桌子', '坦克', '电话', '电视', '老虎', '拖拉机', '火车', '鳟鱼', '郁金香', '乌龟', '衣柜', '鲸鱼', + '柳树', '狼', '女人', '蠕虫') + +IMAGENET_SIMPLE_CATEGORIES = ( + 'tench', 'goldfish', 'great white shark', 'tiger shark', + 'hammerhead shark', 'electric ray', 'stingray', 'rooster', 'hen', + 'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco', + 'indigo bunting', 'American robin', 'bulbul', 'jay', 'magpie', 'chickadee', + 'American dipper', 'kite (bird of prey)', 'bald eagle', 'vulture', + 'great grey owl', 'fire salamander', 'smooth newt', 'newt', + 'spotted salamander', 'axolotl', 'American bullfrog', 'tree frog', + 'tailed frog', 'loggerhead sea turtle', 'leatherback sea turtle', + 'mud turtle', 'terrapin', 'box turtle', 'banded gecko', 'green iguana', + 'Carolina anole', 'desert grassland whiptail lizard', 'agama', + 'frilled-necked lizard', 'alligator lizard', 'Gila monster', + 'European green lizard', 'chameleon', 'Komodo dragon', 'Nile crocodile', + 'American alligator', 'triceratops', 'worm snake', 'ring-necked snake', + 'eastern hog-nosed snake', 'smooth green snake', 'kingsnake', + 'garter snake', 'water snake', 'vine snake', 'night snake', + 'boa constrictor', 'African rock python', 'Indian cobra', 'green mamba', + 'sea snake', 'Saharan horned viper', 'eastern diamondback rattlesnake', + 'sidewinder rattlesnake', 'trilobite', 'harvestman', 'scorpion', + 'yellow garden spider', 'barn spider', 'European garden spider', + 'southern black widow', 'tarantula', 'wolf spider', 'tick', 'centipede', + 'black grouse', 'ptarmigan', 'ruffed grouse', 'prairie grouse', 'peafowl', + 'quail', 'partridge', 'african grey parrot', 'macaw', + 'sulphur-crested cockatoo', 'lorikeet', 'coucal', 'bee eater', 'hornbill', + 'hummingbird', 'jacamar', 'toucan', 'duck', 'red-breasted merganser', + 'goose', 'black swan', 'tusker', 'echidna', 'platypus', 'wallaby', 'koala', + 'wombat', 'jellyfish', 'sea anemone', 'brain coral', 'flatworm', + 'nematode', 'conch', 'snail', 'slug', 'sea slug', 'chiton', + 'chambered nautilus', 'Dungeness crab', 'rock crab', 'fiddler crab', + 'red king crab', 'American lobster', 'spiny lobster', 
'crayfish', + 'hermit crab', 'isopod', 'white stork', 'black stork', 'spoonbill', + 'flamingo', 'little blue heron', 'great egret', 'bittern bird', + 'crane bird', 'limpkin', 'common gallinule', 'American coot', 'bustard', + 'ruddy turnstone', 'dunlin', 'common redshank', 'dowitcher', + 'oystercatcher', 'pelican', 'king penguin', 'albatross', 'grey whale', + 'killer whale', 'dugong', 'sea lion', 'Chihuahua', 'Japanese Chin', + 'Maltese', 'Pekingese', 'Shih Tzu', 'King Charles Spaniel', 'Papillon', + 'toy terrier', 'Rhodesian Ridgeback', 'Afghan Hound', 'Basset Hound', + 'Beagle', 'Bloodhound', 'Bluetick Coonhound', 'Black and Tan Coonhound', + 'Treeing Walker Coonhound', 'English foxhound', 'Redbone Coonhound', + 'borzoi', 'Irish Wolfhound', 'Italian Greyhound', 'Whippet', + 'Ibizan Hound', 'Norwegian Elkhound', 'Otterhound', 'Saluki', + 'Scottish Deerhound', 'Weimaraner', 'Staffordshire Bull Terrier', + 'American Staffordshire Terrier', 'Bedlington Terrier', 'Border Terrier', + 'Kerry Blue Terrier', 'Irish Terrier', 'Norfolk Terrier', + 'Norwich Terrier', 'Yorkshire Terrier', 'Wire Fox Terrier', + 'Lakeland Terrier', 'Sealyham Terrier', 'Airedale Terrier', + 'Cairn Terrier', 'Australian Terrier', 'Dandie Dinmont Terrier', + 'Boston Terrier', 'Miniature Schnauzer', 'Giant Schnauzer', + 'Standard Schnauzer', 'Scottish Terrier', 'Tibetan Terrier', + 'Australian Silky Terrier', 'Soft-coated Wheaten Terrier', + 'West Highland White Terrier', 'Lhasa Apso', 'Flat-Coated Retriever', + 'Curly-coated Retriever', 'Golden Retriever', 'Labrador Retriever', + 'Chesapeake Bay Retriever', 'German Shorthaired Pointer', 'Vizsla', + 'English Setter', 'Irish Setter', 'Gordon Setter', 'Brittany dog', + 'Clumber Spaniel', 'English Springer Spaniel', 'Welsh Springer Spaniel', + 'Cocker Spaniel', 'Sussex Spaniel', 'Irish Water Spaniel', 'Kuvasz', + 'Schipperke', 'Groenendael dog', 'Malinois', 'Briard', 'Australian Kelpie', + 'Komondor', 'Old English Sheepdog', 'Shetland Sheepdog', 'collie', + 'Border Collie', 'Bouvier des Flandres dog', 'Rottweiler', + 'German Shepherd Dog', 'Dobermann', 'Miniature Pinscher', + 'Greater Swiss Mountain Dog', 'Bernese Mountain Dog', + 'Appenzeller Sennenhund', 'Entlebucher Sennenhund', 'Boxer', 'Bullmastiff', + 'Tibetan Mastiff', 'French Bulldog', 'Great Dane', 'St. 
Bernard', 'husky', + 'Alaskan Malamute', 'Siberian Husky', 'Dalmatian', 'Affenpinscher', + 'Basenji', 'pug', 'Leonberger', 'Newfoundland dog', 'Great Pyrenees dog', + 'Samoyed', 'Pomeranian', 'Chow Chow', 'Keeshond', 'brussels griffon', + 'Pembroke Welsh Corgi', 'Cardigan Welsh Corgi', 'Toy Poodle', + 'Miniature Poodle', 'Standard Poodle', + 'Mexican hairless dog (xoloitzcuintli)', 'grey wolf', + 'Alaskan tundra wolf', 'red wolf or maned wolf', 'coyote', 'dingo', + 'dhole', 'African wild dog', 'hyena', 'red fox', 'kit fox', 'Arctic fox', + 'grey fox', 'tabby cat', 'tiger cat', 'Persian cat', 'Siamese cat', + 'Egyptian Mau', 'cougar', 'lynx', 'leopard', 'snow leopard', 'jaguar', + 'lion', 'tiger', 'cheetah', 'brown bear', 'American black bear', + 'polar bear', 'sloth bear', 'mongoose', 'meerkat', 'tiger beetle', + 'ladybug', 'ground beetle', 'longhorn beetle', 'leaf beetle', + 'dung beetle', 'rhinoceros beetle', 'weevil', 'fly', 'bee', 'ant', + 'grasshopper', 'cricket insect', 'stick insect', 'cockroach', + 'praying mantis', 'cicada', 'leafhopper', 'lacewing', 'dragonfly', + 'damselfly', 'red admiral butterfly', 'ringlet butterfly', + 'monarch butterfly', 'small white butterfly', 'sulphur butterfly', + 'gossamer-winged butterfly', 'starfish', 'sea urchin', 'sea cucumber', + 'cottontail rabbit', 'hare', 'Angora rabbit', 'hamster', 'porcupine', + 'fox squirrel', 'marmot', 'beaver', 'guinea pig', 'common sorrel horse', + 'zebra', 'pig', 'wild boar', 'warthog', 'hippopotamus', 'ox', + 'water buffalo', 'bison', 'ram (adult male sheep)', 'bighorn sheep', + 'Alpine ibex', 'hartebeest', 'impala (antelope)', 'gazelle', + 'arabian camel', 'llama', 'weasel', 'mink', 'European polecat', + 'black-footed ferret', 'otter', 'skunk', 'badger', 'armadillo', + 'three-toed sloth', 'orangutan', 'gorilla', 'chimpanzee', 'gibbon', + 'siamang', 'guenon', 'patas monkey', 'baboon', 'macaque', 'langur', + 'black-and-white colobus', 'proboscis monkey', 'marmoset', + 'white-headed capuchin', 'howler monkey', 'titi monkey', + "Geoffroy's spider monkey", 'common squirrel monkey', 'ring-tailed lemur', + 'indri', 'Asian elephant', 'African bush elephant', 'red panda', + 'giant panda', 'snoek fish', 'eel', 'silver salmon', 'rock beauty fish', + 'clownfish', 'sturgeon', 'gar fish', 'lionfish', 'pufferfish', 'abacus', + 'abaya', 'academic gown', 'accordion', 'acoustic guitar', + 'aircraft carrier', 'airliner', 'airship', 'altar', 'ambulance', + 'amphibious vehicle', 'analog clock', 'apiary', 'apron', 'trash can', + 'assault rifle', 'backpack', 'bakery', 'balance beam', 'balloon', + 'ballpoint pen', 'Band-Aid', 'banjo', 'baluster / handrail', 'barbell', + 'barber chair', 'barbershop', 'barn', 'barometer', 'barrel', 'wheelbarrow', + 'baseball', 'basketball', 'bassinet', 'bassoon', 'swimming cap', + 'bath towel', 'bathtub', 'station wagon', 'lighthouse', 'beaker', + 'military hat (bearskin or shako)', 'beer bottle', 'beer glass', + 'bell tower', 'baby bib', 'tandem bicycle', 'bikini', 'ring binder', + 'binoculars', 'birdhouse', 'boathouse', 'bobsleigh', 'bolo tie', + 'poke bonnet', 'bookcase', 'bookstore', 'bottle cap', 'hunting bow', + 'bow tie', 'brass memorial plaque', 'bra', 'breakwater', 'breastplate', + 'broom', 'bucket', 'buckle', 'bulletproof vest', 'high-speed train', + 'butcher shop', 'taxicab', 'cauldron', 'candle', 'cannon', 'canoe', + 'can opener', 'cardigan', 'car mirror', 'carousel', 'tool kit', + 'cardboard box / carton', 'car wheel', 'automated teller machine', + 'cassette', 'cassette player', 'castle', 
'catamaran', 'CD player', 'cello', + 'mobile phone', 'chain', 'chain-link fence', 'chain mail', 'chainsaw', + 'storage chest', 'chiffonier', 'bell or wind chime', 'china cabinet', + 'Christmas stocking', 'church', 'movie theater', 'cleaver', + 'cliff dwelling', 'cloak', 'clogs', 'cocktail shaker', 'coffee mug', + 'coffeemaker', 'spiral or coil', 'combination lock', 'computer keyboard', + 'candy store', 'container ship', 'convertible', 'corkscrew', 'cornet', + 'cowboy boot', 'cowboy hat', 'cradle', 'construction crane', + 'crash helmet', 'crate', 'infant bed', 'Crock Pot', 'croquet ball', + 'crutch', 'cuirass', 'dam', 'desk', 'desktop computer', + 'rotary dial telephone', 'diaper', 'digital clock', 'digital watch', + 'dining table', 'dishcloth', 'dishwasher', 'disc brake', 'dock', + 'dog sled', 'dome', 'doormat', 'drilling rig', 'drum', 'drumstick', + 'dumbbell', 'Dutch oven', 'electric fan', 'electric guitar', + 'electric locomotive', 'entertainment center', 'envelope', + 'espresso machine', 'face powder', 'feather boa', 'filing cabinet', + 'fireboat', 'fire truck', 'fire screen', 'flagpole', 'flute', + 'folding chair', 'football helmet', 'forklift', 'fountain', 'fountain pen', + 'four-poster bed', 'freight car', 'French horn', 'frying pan', 'fur coat', + 'garbage truck', 'gas mask or respirator', 'gas pump', 'goblet', 'go-kart', + 'golf ball', 'golf cart', 'gondola', 'gong', 'gown', 'grand piano', + 'greenhouse', 'radiator grille', 'grocery store', 'guillotine', + 'hair clip', 'hair spray', 'half-track', 'hammer', 'hamper', 'hair dryer', + 'hand-held computer', 'handkerchief', 'hard disk drive', 'harmonica', + 'harp', 'combine harvester', 'hatchet', 'holster', 'home theater', + 'honeycomb', 'hook', 'hoop skirt', 'gymnastic horizontal bar', + 'horse-drawn vehicle', 'hourglass', 'iPod', 'clothes iron', + 'carved pumpkin', 'jeans', 'jeep', 'T-shirt', 'jigsaw puzzle', 'rickshaw', + 'joystick', 'kimono', 'knee pad', 'knot', 'lab coat', 'ladle', 'lampshade', + 'laptop computer', 'lawn mower', 'lens cap', 'letter opener', 'library', + 'lifeboat', 'lighter', 'limousine', 'ocean liner', 'lipstick', + 'slip-on shoe', 'lotion', 'music speaker', 'loupe magnifying glass', + 'sawmill', 'magnetic compass', 'messenger bag', 'mailbox', 'tights', + 'one-piece bathing suit', 'manhole cover', 'maraca', 'marimba', 'mask', + 'matchstick', 'maypole', 'maze', 'measuring cup', 'medicine cabinet', + 'megalith', 'microphone', 'microwave oven', 'military uniform', 'milk can', + 'minibus', 'miniskirt', 'minivan', 'missile', 'mitten', 'mixing bowl', + 'mobile home', 'ford model t', 'modem', 'monastery', 'monitor', 'moped', + 'mortar and pestle', 'graduation cap', 'mosque', 'mosquito net', 'vespa', + 'mountain bike', 'tent', 'computer mouse', 'mousetrap', 'moving van', + 'muzzle', 'metal nail', 'neck brace', 'necklace', 'baby pacifier', + 'notebook computer', 'obelisk', 'oboe', 'ocarina', 'odometer', + 'oil filter', 'pipe organ', 'oscilloscope', 'overskirt', 'bullock cart', + 'oxygen mask', 'product packet / packaging', 'paddle', 'paddle wheel', + 'padlock', 'paintbrush', 'pajamas', 'palace', 'pan flute', 'paper towel', + 'parachute', 'parallel bars', 'park bench', 'parking meter', + 'railroad car', 'patio', 'payphone', 'pedestal', 'pencil case', + 'pencil sharpener', 'perfume', 'Petri dish', 'photocopier', 'plectrum', + 'Pickelhaube', 'picket fence', 'pickup truck', 'pier', 'piggy bank', + 'pill bottle', 'pillow', 'ping-pong ball', 'pinwheel', 'pirate ship', + 'drink pitcher', 'block plane', 'planetarium', 'plastic 
bag', 'plate rack', + 'farm plow', 'plunger', 'Polaroid camera', 'pole', 'police van', 'poncho', + 'pool table', 'soda bottle', 'plant pot', "potter's wheel", 'power drill', + 'prayer rug', 'printer', 'prison', 'missile', 'projector', 'hockey puck', + 'punching bag', 'purse', 'quill', 'quilt', 'race car', 'racket', + 'radiator', 'radio', 'radio telescope', 'rain barrel', + 'recreational vehicle', 'fishing casting reel', 'reflex camera', + 'refrigerator', 'remote control', 'restaurant', 'revolver', 'rifle', + 'rocking chair', 'rotisserie', 'eraser', 'rugby ball', + 'ruler measuring stick', 'sneaker', 'safe', 'safety pin', 'salt shaker', + 'sandal', 'sarong', 'saxophone', 'scabbard', 'weighing scale', + 'school bus', 'schooner', 'scoreboard', 'CRT monitor', 'screw', + 'screwdriver', 'seat belt', 'sewing machine', 'shield', 'shoe store', + 'shoji screen / room divider', 'shopping basket', 'shopping cart', + 'shovel', 'shower cap', 'shower curtain', 'ski', 'balaclava ski mask', + 'sleeping bag', 'slide rule', 'sliding door', 'slot machine', 'snorkel', + 'snowmobile', 'snowplow', 'soap dispenser', 'soccer ball', 'sock', + 'solar thermal collector', 'sombrero', 'soup bowl', 'keyboard space bar', + 'space heater', 'space shuttle', 'spatula', 'motorboat', 'spider web', + 'spindle', 'sports car', 'spotlight', 'stage', 'steam locomotive', + 'through arch bridge', 'steel drum', 'stethoscope', 'scarf', 'stone wall', + 'stopwatch', 'stove', 'strainer', 'tram', 'stretcher', 'couch', 'stupa', + 'submarine', 'suit', 'sundial', 'sunglasses', 'sunglasses', 'sunscreen', + 'suspension bridge', 'mop', 'sweatshirt', 'swim trunks / shorts', 'swing', + 'electrical switch', 'syringe', 'table lamp', 'tank', 'tape player', + 'teapot', 'teddy bear', 'television', 'tennis ball', 'thatched roof', + 'front curtain', 'thimble', 'threshing machine', 'throne', 'tile roof', + 'toaster', 'tobacco shop', 'toilet seat', 'torch', 'totem pole', + 'tow truck', 'toy store', 'tractor', 'semi-trailer truck', 'tray', + 'trench coat', 'tricycle', 'trimaran', 'tripod', 'triumphal arch', + 'trolleybus', 'trombone', 'hot tub', 'turnstile', 'typewriter keyboard', + 'umbrella', 'unicycle', 'upright piano', 'vacuum cleaner', 'vase', + 'vaulted or arched ceiling', 'velvet fabric', 'vending machine', + 'vestment', 'viaduct', 'violin', 'volleyball', 'waffle iron', 'wall clock', + 'wallet', 'wardrobe', 'military aircraft', 'sink', 'washing machine', + 'water bottle', 'water jug', 'water tower', 'whiskey jug', 'whistle', + 'hair wig', 'window screen', 'window shade', 'Windsor tie', 'wine bottle', + 'airplane wing', 'wok', 'wooden spoon', 'wool', 'split-rail fence', + 'shipwreck', 'sailboat', 'yurt', 'website', 'comic book', 'crossword', + 'traffic or street sign', 'traffic light', 'dust jacket', 'menu', 'plate', + 'guacamole', 'consomme', 'hot pot', 'trifle', 'ice cream', 'popsicle', + 'baguette', 'bagel', 'pretzel', 'cheeseburger', 'hot dog', + 'mashed potatoes', 'cabbage', 'broccoli', 'cauliflower', 'zucchini', + 'spaghetti squash', 'acorn squash', 'butternut squash', 'cucumber', + 'artichoke', 'bell pepper', 'cardoon', 'mushroom', 'Granny Smith apple', + 'strawberry', 'orange', 'lemon', 'fig', 'pineapple', 'banana', 'jackfruit', + 'cherimoya (custard apple)', 'pomegranate', 'hay', 'carbonara', + 'chocolate syrup', 'dough', 'meatloaf', 'pizza', 'pot pie', 'burrito', + 'red wine', 'espresso', 'tea cup', 'eggnog', 'mountain', 'bubble', 'cliff', + 'coral reef', 'geyser', 'lakeshore', 'promontory', 'sandbar', 'beach', + 'valley', 'volcano', 
'baseball player', 'bridegroom', 'scuba diver', + 'rapeseed', 'daisy', "yellow lady's slipper", 'corn', 'acorn', 'rose hip', + 'horse chestnut seed', 'coral fungus', 'agaric', 'gyromitra', + 'stinkhorn mushroom', 'earth star fungus', 'hen of the woods mushroom', + 'bolete', 'corn cob', 'toilet paper') diff --git a/mmpretrain/datasets/cifar.py b/mmpretrain/datasets/cifar.py new file mode 100644 index 0000000000000000000000000000000000000000..2a011daee0d74e6b06613106f7587b8ad8a7ed90 --- /dev/null +++ b/mmpretrain/datasets/cifar.py @@ -0,0 +1,210 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pickle +from typing import List, Optional + +import mmengine.dist as dist +import numpy as np +from mmengine.fileio import (LocalBackend, exists, get, get_file_backend, + join_path) +from mmengine.logging import MMLogger + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import CIFAR10_CATEGORIES, CIFAR100_CATEGORIES +from .utils import check_md5, download_and_extract_archive + + +@DATASETS.register_module() +class CIFAR10(BaseDataset): + """`CIFAR10 `_ Dataset. + + This implementation is modified from + https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py + + Args: + data_root (str): The root directory of the CIFAR Dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". + metainfo (dict, optional): Meta information for dataset, such as + categories information. Defaults to None. + download (bool): Whether to download the dataset if not exists. + Defaults to True. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ # noqa: E501 + + base_folder = 'cifar-10-batches-py' + url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' + filename = 'cifar-10-python.tar.gz' + tgz_md5 = 'c58f30108f718f92721af3b95e74349a' + train_list = [ + ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], + ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], + ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], + ['data_batch_4', '634d18415352ddfa80567beed471001a'], + ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], + ] + + test_list = [ + ['test_batch', '40351d587109b95175f43aff81a1287e'], + ] + meta = { + 'filename': 'batches.meta', + 'key': 'label_names', + 'md5': '5ff9c542aee3614f3951f8cda6e48888', + } + METAINFO = {'classes': CIFAR10_CATEGORIES} + + def __init__(self, + data_root: str = '', + split: str = 'train', + metainfo: Optional[dict] = None, + download: bool = True, + data_prefix: str = '', + test_mode: bool = False, + **kwargs): + + splits = ['train', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + # To handle the BC-breaking + if split == 'train' and test_mode: + logger = MMLogger.get_current_instance() + logger.warning('split="train" but test_mode=True. 
'
+                           'The training set will be used.')
+
+        if not data_root and not data_prefix:
+            raise RuntimeError('Please set ``data_root`` to '
+                               'specify the dataset path')
+
+        self.download = download
+        super().__init__(
+            # The CIFAR dataset doesn't need a specified annotation file
+            ann_file='',
+            metainfo=metainfo,
+            data_root=data_root,
+            data_prefix=dict(root=data_prefix),
+            test_mode=test_mode,
+            **kwargs)
+
+    def load_data_list(self):
+        """Load images and ground truth labels."""
+        root = self.data_prefix['root']
+        backend = get_file_backend(root, enable_singleton=True)
+
+        if dist.is_main_process() and not self._check_integrity():
+            if not isinstance(backend, LocalBackend):
+                raise RuntimeError(f'The dataset on {root} is not complete, '
+                                   f'please manually handle it.')
+
+            if self.download:
+                download_and_extract_archive(
+                    self.url, root, filename=self.filename, md5=self.tgz_md5)
+            else:
+                raise RuntimeError(
+                    f'Cannot find {self.__class__.__name__} dataset in '
+                    f"{self.data_prefix['root']}, you can specify "
+                    '`download=True` to download automatically.')
+
+        dist.barrier()
+        assert self._check_integrity(), \
+            'Download failed or shared storage is unavailable. Please ' \
+            f'download the dataset manually through {self.url}.'
+
+        if self.split == 'train':
+            downloaded_list = self.train_list
+        else:
+            downloaded_list = self.test_list
+
+        imgs = []
+        gt_labels = []
+
+        # load the pickled numpy arrays
+        for file_name, _ in downloaded_list:
+            file_path = join_path(root, self.base_folder, file_name)
+            entry = pickle.loads(get(file_path), encoding='latin1')
+            imgs.append(entry['data'])
+            if 'labels' in entry:
+                gt_labels.extend(entry['labels'])
+            else:
+                gt_labels.extend(entry['fine_labels'])
+
+        imgs = np.vstack(imgs).reshape(-1, 3, 32, 32)
+        imgs = imgs.transpose((0, 2, 3, 1))  # convert to HWC
+
+        if self.CLASSES is None:
+            # The metainfo in the file has the lowest priority, therefore
+            # we only need to load it if classes is not specified.
+            self._load_meta()
+
+        data_list = []
+        for img, gt_label in zip(imgs, gt_labels):
+            info = {'img': img, 'gt_label': int(gt_label)}
+            data_list.append(info)
+        return data_list
+
+    def _load_meta(self):
+        """Load categories information from metafile."""
+        root = self.data_prefix['root']
+
+        path = join_path(root, self.base_folder, self.meta['filename'])
+        md5 = self.meta.get('md5', None)
+        if not exists(path) or (md5 is not None and not check_md5(path, md5)):
+            raise RuntimeError(
+                'Dataset metadata file not found or corrupted.' +
+                ' You can use `download=True` to download it.')
+        data = pickle.loads(get(path), encoding='latin1')
+        self._metainfo.setdefault('classes', data[self.meta['key']])
+
+    def _check_integrity(self):
+        """Check the integrity of data files."""
+        root = self.data_prefix['root']
+
+        for fentry in (self.train_list + self.test_list):
+            filename, md5 = fentry[0], fentry[1]
+            fpath = join_path(root, self.base_folder, filename)
+            if not exists(fpath):
+                return False
+            if md5 is not None and not check_md5(fpath, md5):
+                return False
+        return True
+
+    def extra_repr(self) -> List[str]:
+        """The extra repr information of the dataset."""
+        body = [f"Prefix of data: \t{self.data_prefix['root']}"]
+        return body
+
+
+@DATASETS.register_module()
+class CIFAR100(CIFAR10):
+    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
+
+    Args:
+        data_root (str): The root directory of the CIFAR Dataset.
+        split (str, optional): The dataset split, supports "train" and "test".
+            Defaults to "train".
+        metainfo (dict, optional): Meta information for dataset, such as
+            categories information.
+            Defaults to None.
+        download (bool): Whether to download the dataset if not exists.
+            Defaults to True.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
+    """
+
+    base_folder = 'cifar-100-python'
+    url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
+    filename = 'cifar-100-python.tar.gz'
+    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
+    train_list = [
+        ['train', '16019d7e3df5f24257cddd939b257f8d'],
+    ]
+
+    test_list = [
+        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
+    ]
+    meta = {
+        'filename': 'meta',
+        'key': 'fine_label_names',
+        'md5': '7973b15100ade9c7d40fb424638fde48',
+    }
+    METAINFO = {'classes': CIFAR100_CATEGORIES}
diff --git a/mmpretrain/datasets/coco_caption.py b/mmpretrain/datasets/coco_caption.py
new file mode 100644
index 0000000000000000000000000000000000000000..541cda80398f7fcc7d3304d3d9f43155685ebe57
--- /dev/null
+++ b/mmpretrain/datasets/coco_caption.py
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from pathlib import Path
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+from mmengine.fileio import get_file_backend
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class COCOCaption(BaseDataset):
+    """COCO Caption dataset.
+
+    Args:
+        data_root (str): The root directory for ``data_prefix`` and
+            ``ann_file``.
+        ann_file (str): Annotation file path.
+        data_prefix (dict): Prefix for data field. Defaults to
+            ``dict(img_path='')``.
+        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
+    """
+
+    def load_data_list(self) -> List[dict]:
+        """Load data list."""
+        img_prefix = self.data_prefix['img_path']
+        annotations = mmengine.load(self.ann_file)
+        file_backend = get_file_backend(img_prefix)
+
+        data_list = []
+        for ann in annotations:
+            data_info = {
+                'image_id': Path(ann['image']).stem.split('_')[-1],
+                'img_path': file_backend.join_path(img_prefix, ann['image']),
+                'gt_caption': ann['caption'],
+            }
+
+            data_list.append(data_info)
+
+        return data_list
diff --git a/mmpretrain/datasets/coco_retrieval.py b/mmpretrain/datasets/coco_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..be8a0bcb864dddad53e96e6342f9bd987ae222e2
--- /dev/null
+++ b/mmpretrain/datasets/coco_retrieval.py
@@ -0,0 +1,148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os.path as osp
+from collections import OrderedDict
+from os import PathLike
+from typing import List, Sequence, Union
+
+from mmengine import get_file_backend
+
+from mmpretrain.registry import DATASETS, TRANSFORMS
+from .base_dataset import BaseDataset
+
+
+def expanduser(data_prefix):
+    if isinstance(data_prefix, (str, PathLike)):
+        return osp.expanduser(data_prefix)
+    else:
+        return data_prefix
+
+
+@DATASETS.register_module()
+class COCORetrieval(BaseDataset):
+    """COCO Retrieval dataset.
+
+    COCO (Common Objects in Context): The COCO dataset contains more than
+    330K images, each of which has approximately 5 descriptive annotations.
+    This dataset was released in collaboration between Microsoft and
+    Carnegie Mellon University.
+
+    COCO_2014 dataset directory: ::
+
+        COCO_2014
+        ├── val2014
+        ├── train2014
+        ├── annotations
+        │    ├── instances_train2014.json
+        │    ├── instances_val2014.json
+        │    ├── person_keypoints_train2014.json
+        │    ├── person_keypoints_val2014.json
+        │    ├── captions_train2014.json
+        │    └── captions_val2014.json
+
+    Args:
+        ann_file (str): Annotation file path.
+        test_mode (bool): Whether dataset is used for evaluation. This will
+            decide the annotation format in data list annotations.
+            Defaults to False.
+        data_root (str): The root directory for ``data_prefix`` and
+            ``ann_file``. Defaults to ''.
+        data_prefix (str | dict): Prefix for training data. Defaults to ''.
+        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
+
+    Examples:
+        >>> from mmpretrain.datasets import COCORetrieval
+        >>> train_dataset = COCORetrieval(
+        ...     data_root='coco2014/',
+        ...     ann_file='annotations/captions_train2014.json')
+        >>> train_dataset
+        Dataset COCORetrieval
+        Number of samples: 414113
+        Annotation file: /coco2014/annotations/captions_train2014.json
+        Prefix of images: /coco2014/
+        >>> val_dataset = COCORetrieval(
+        ...     data_root='coco2014/',
+        ...     ann_file='annotations/captions_val2014.json',
+        ...     test_mode=True)
+        >>> val_dataset
+        Dataset COCORetrieval
+        Number of samples: 202654
+        Annotation file: /coco2014/annotations/captions_val2014.json
+        Prefix of images: /coco2014/
+    """
+
+    def __init__(self,
+                 ann_file: str,
+                 test_mode: bool = False,
+                 data_prefix: Union[str, dict] = '',
+                 data_root: str = '',
+                 pipeline: Sequence = (),
+                 **kwargs):
+
+        if isinstance(data_prefix, str):
+            data_prefix = dict(img_path=expanduser(data_prefix))
+
+        ann_file = expanduser(ann_file)
+        transforms = []
+        for transform in pipeline:
+            if isinstance(transform, dict):
+                transforms.append(TRANSFORMS.build(transform))
+            else:
+                transforms.append(transform)
+
+        super().__init__(
+            data_root=data_root,
+            data_prefix=data_prefix,
+            test_mode=test_mode,
+            pipeline=transforms,
+            ann_file=ann_file,
+            **kwargs,
+        )
+
+    def load_data_list(self) -> List[dict]:
+        """Load data list."""
+        # get file backend
+        img_prefix = self.data_prefix['img_path']
+        file_backend = get_file_backend(img_prefix)
+
+        with open(self.ann_file, 'r') as f:
+            anno_info = json.load(f)
+        # mapping img_id to img filename
+        img_dict = OrderedDict()
+        for idx, img in enumerate(anno_info['images']):
+            if img['id'] not in img_dict:
+                img_rel_path = img['coco_url'].rsplit('/', 2)[-2:]
+                img_path = file_backend.join_path(img_prefix, *img_rel_path)
+
+                # create new idx for image
+                img_dict[img['id']] = dict(
+                    ori_id=img['id'],
+                    image_id=idx,  # will be used for evaluation
+                    img_path=img_path,
+                    text=[],
+                    gt_text_id=[],
+                    gt_image_id=[],
+                )
+
+        train_list = []
+        for idx, anno in enumerate(anno_info['annotations']):
+            anno['text'] = anno.pop('caption')
+            anno['ori_id'] = anno.pop('id')
+            anno['text_id'] = idx  # will be used for evaluation
+            # 1. prepare train data list item
+            train_data = anno.copy()
+            train_image = img_dict[train_data['image_id']]
+            train_data['img_path'] = train_image['img_path']
+            train_data['image_ori_id'] = train_image['ori_id']
+            train_data['image_id'] = train_image['image_id']
+            train_data['is_matched'] = True
+            train_list.append(train_data)
+            # 2.
prepare eval data list item based on img dict + img_dict[anno['image_id']]['gt_text_id'].append(anno['text_id']) + img_dict[anno['image_id']]['text'].append(anno['text']) + img_dict[anno['image_id']]['gt_image_id'].append( + train_image['image_id']) + + self.img_size = len(img_dict) + self.text_size = len(anno_info['annotations']) + + # return needed format data list + if self.test_mode: + return list(img_dict.values()) + return train_list diff --git a/mmpretrain/datasets/coco_vqa.py b/mmpretrain/datasets/coco_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..85f4bdcf39ef82ec47a2072dc198e6b8792d8768 --- /dev/null +++ b/mmpretrain/datasets/coco_vqa.py @@ -0,0 +1,114 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import re +from collections import Counter +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class COCOVQA(BaseDataset): + """VQAv2 dataset. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + question_file (str): Question file path. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to an empty string. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + data_prefix: str, + question_file: str, + ann_file: str = '', + **kwarg): + self.question_file = question_file + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def _join_prefix(self): + if not mmengine.is_abs(self.question_file) and self.question_file: + self.question_file = osp.join(self.data_root, self.question_file) + + return super()._join_prefix() + + def _create_image_index(self): + img_prefix = self.data_prefix['img_path'] + + files = mmengine.list_dir_or_file(img_prefix, list_dir=False) + image_index = {} + for file in files: + image_id = re.findall(r'\d{12}', file) + if len(image_id) > 0: + image_id = int(image_id[-1]) + image_index[image_id] = mmengine.join_path(img_prefix, file) + + return image_index + + def load_data_list(self) -> List[dict]: + """Load data list.""" + questions = mmengine.load(self.question_file)['questions'] + if self.ann_file: + annotations = mmengine.load(self.ann_file)['annotations'] + assert len(questions) == len(annotations) + else: + annotations = [None] * len(questions) + + # The original VQAv2 annotation file and question file includes + # only image id but no image file paths. + self.image_index = self._create_image_index() + + data_list = [] + for question, ann in zip(questions, annotations): + # question example + # { + # 'image_id': 262144, + # 'question': "Is the ball flying towards the batter?", + # 'question_id': 262144000 + # } + # + # ann example + # { + # 'question_type': "what are the", + # 'answer_type': "other", + # 'answers': [ + # {'answer': 'watching', + # 'answer_id': 1, + # 'answer_confidence': 'yes'}, + # ... 
+ # ], + # 'image_id': 262148, + # 'question_id': 262148000, + # 'multiple_choice_answer': 'watching', + # 'answer_type': 'other', + # } + + data_info = question + data_info['img_path'] = self.image_index[question['image_id']] + + if ann is not None: + assert ann['question_id'] == question['question_id'] + + # add answer_weight & answer_count, delete duplicate answer + answers = [item['answer'] for item in ann.pop('answers')] + count = Counter(answers) + answer_weight = [i / len(answers) for i in count.values()] + data_info['gt_answer'] = list(count.keys()) + data_info['gt_answer_weight'] = answer_weight + data_info.update(ann) + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/cub.py b/mmpretrain/datasets/cub.py new file mode 100644 index 0000000000000000000000000000000000000000..8db126216fb3408e2dd18255db04a851eb5fe08f --- /dev/null +++ b/mmpretrain/datasets/cub.py @@ -0,0 +1,142 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine import get_file_backend, list_from_file +from mmengine.logging import MMLogger + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import CUB_CATEGORIES + + +@DATASETS.register_module() +class CUB(BaseDataset): + """The CUB-200-2011 Dataset. + + Support the `CUB-200-2011 `_ Dataset. + Comparing with the `CUB-200 `_ Dataset, + there are much more pictures in `CUB-200-2011`. After downloading and decompression, the dataset + directory structure is as follows. + + CUB dataset directory: :: + + CUB_200_2011 + ├── images + │ ├── class_x + │ │ ├── xx1.jpg + │ │ ├── xx2.jpg + │ │ └── ... + │ ├── class_y + │ │ ├── yy1.jpg + │ │ ├── yy2.jpg + │ │ └── ... + │ └── ... + ├── images.txt + ├── image_class_labels.txt + ├── train_test_split.txt + └── .... + + Args: + data_root (str): The root directory for CUB-200-2011 dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". + + Examples: + >>> from mmpretrain.datasets import CUB + >>> train_dataset = CUB(data_root='data/CUB_200_2011', split='train') + >>> train_dataset + Dataset CUB + Number of samples: 5994 + Number of categories: 200 + Root of dataset: data/CUB_200_2011 + >>> test_dataset = CUB(data_root='data/CUB_200_2011', split='test') + >>> test_dataset + Dataset CUB + Number of samples: 5794 + Number of categories: 200 + Root of dataset: data/CUB_200_2011 + """ # noqa: E501 + + METAINFO = {'classes': CUB_CATEGORIES} + + def __init__(self, + data_root: str, + split: str = 'train', + test_mode: bool = False, + **kwargs): + + splits = ['train', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + # To handle the BC-breaking + if split == 'train' and test_mode: + logger = MMLogger.get_current_instance() + logger.warning('split="train" but test_mode=True. 
'
+                           'The training set will be used.')
+
+        ann_file = 'images.txt'
+        data_prefix = 'images'
+        image_class_labels_file = 'image_class_labels.txt'
+        train_test_split_file = 'train_test_split.txt'
+
+        self.backend = get_file_backend(data_root, enable_singleton=True)
+        self.image_class_labels_file = self.backend.join_path(
+            data_root, image_class_labels_file)
+        self.train_test_split_file = self.backend.join_path(
+            data_root, train_test_split_file)
+        super(CUB, self).__init__(
+            ann_file=ann_file,
+            data_root=data_root,
+            data_prefix=data_prefix,
+            test_mode=test_mode,
+            **kwargs)
+
+    def _load_data_from_txt(self, filepath):
+        """Load data from a CUB txt file, where every line of the file
+        holds an index and a data item."""
+        pairs = list_from_file(filepath)
+        data_dict = dict()
+        for pair in pairs:
+            idx, data_item = pair.split()
+            # all indices start from 1 in CUB files,
+            # here we need to '- 1' to let them start from 0.
+            data_dict[int(idx) - 1] = data_item
+        return data_dict
+
+    def load_data_list(self):
+        """Load images and ground truth labels."""
+        sample_dict = self._load_data_from_txt(self.ann_file)
+
+        label_dict = self._load_data_from_txt(self.image_class_labels_file)
+
+        split_dict = self._load_data_from_txt(self.train_test_split_file)
+
+        assert sample_dict.keys() == label_dict.keys() == split_dict.keys(),\
+            f'sample_ids should be the same in files {self.ann_file}, ' \
+            f'{self.image_class_labels_file} and {self.train_test_split_file}'
+
+        data_list = []
+        for sample_id in sample_dict.keys():
+            if split_dict[sample_id] == '1' and self.split == 'test':
+                # skip train samples when split='test'
+                continue
+            elif split_dict[sample_id] == '0' and self.split == 'train':
+                # skip test samples when split='train'
+                continue
+
+            img_path = self.backend.join_path(self.img_prefix,
+                                              sample_dict[sample_id])
+            gt_label = int(label_dict[sample_id]) - 1
+            info = dict(img_path=img_path, gt_label=gt_label)
+            data_list.append(info)
+
+        return data_list
+
+    def extra_repr(self) -> List[str]:
+        """The extra repr information of the dataset."""
+        body = [
+            f'Root of dataset: \t{self.data_root}',
+        ]
+        return body
diff --git a/mmpretrain/datasets/custom.py b/mmpretrain/datasets/custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb491ff0cc7f816f629603d3b8be55e3f787c373
--- /dev/null
+++ b/mmpretrain/datasets/custom.py
@@ -0,0 +1,287 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+from mmengine.fileio import (BaseStorageBackend, get_file_backend,
+                             list_from_file)
+from mmengine.logging import MMLogger
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+
+
+def find_folders(
+    root: str,
+    backend: Optional[BaseStorageBackend] = None
+) -> Tuple[List[str], Dict[str, int]]:
+    """Find classes by folders under a root.
+
+    Args:
+        root (string): root directory of folders
+        backend (BaseStorageBackend | None): The file backend of the root.
+            If None, auto infer backend from the root path. Defaults to None.
+
+    Returns:
+        Tuple[List[str], Dict[str, int]]:
+
+        - folders: The name of sub folders under the root.
+        - folder_to_idx: The map from folder name to class idx.
+    """
+    # Pre-build file backend to prevent verbose file backend inference.
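+    # e.g. a root containing sub-folders ['dog', 'cat'] yields the sorted
+    # list folders=['cat', 'dog'] and folder_to_idx={'cat': 0, 'dog': 1}.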
+ backend = backend or get_file_backend(root, enable_singleton=True) + folders = list( + backend.list_dir_or_file( + root, + list_dir=True, + list_file=False, + recursive=False, + )) + folders.sort() + folder_to_idx = {folders[i]: i for i in range(len(folders))} + return folders, folder_to_idx + + +def get_samples( + root: str, + folder_to_idx: Dict[str, int], + is_valid_file: Callable, + backend: Optional[BaseStorageBackend] = None, +): + """Make dataset by walking all images under a root. + + Args: + root (string): root directory of folders + folder_to_idx (dict): the map from class name to class idx + is_valid_file (Callable): A function that takes path of a file + and check if the file is a valid sample file. + backend (BaseStorageBackend | None): The file backend of the root. + If None, auto infer backend from the root path. Defaults to None. + + Returns: + Tuple[list, set]: + + - samples: a list of tuple where each element is (image, class_idx) + - empty_folders: The folders don't have any valid files. + """ + samples = [] + available_classes = set() + # Pre-build file backend to prevent verbose file backend inference. + backend = backend or get_file_backend(root, enable_singleton=True) + + if folder_to_idx is not None: + for folder_name in sorted(list(folder_to_idx.keys())): + _dir = backend.join_path(root, folder_name) + files = backend.list_dir_or_file( + _dir, + list_dir=False, + list_file=True, + recursive=True, + ) + for file in sorted(list(files)): + if is_valid_file(file): + path = backend.join_path(folder_name, file) + item = (path, folder_to_idx[folder_name]) + samples.append(item) + available_classes.add(folder_name) + empty_folders = set(folder_to_idx.keys()) - available_classes + else: + files = backend.list_dir_or_file( + root, + list_dir=False, + list_file=True, + recursive=True, + ) + samples = [file for file in sorted(list(files)) if is_valid_file(file)] + empty_folders = None + + return samples, empty_folders + + +@DATASETS.register_module() +class CustomDataset(BaseDataset): + """A generic dataset for multiple tasks. + + The dataset supports two kinds of style. + + 1. Use an annotation file to specify all samples, and each line indicates a + sample: + + The annotation file (for ``with_label=True``, supervised tasks.): :: + + folder_1/xxx.png 0 + folder_1/xxy.png 1 + 123.png 4 + nsdf3.png 3 + ... + + The annotation file (for ``with_label=False``, unsupervised tasks.): :: + + folder_1/xxx.png + folder_1/xxy.png + 123.png + nsdf3.png + ... + + Sample files: :: + + data_prefix/ + ├── folder_1 + │ ├── xxx.png + │ ├── xxy.png + │ └── ... + ├── 123.png + ├── nsdf3.png + └── ... + + Please use the argument ``metainfo`` to specify extra information for + the task, like ``{'classes': ('bird', 'cat', 'deer', 'dog', 'frog')}``. + + 2. Place all samples in one folder as below: + + Sample files (for ``with_label=True``, supervised tasks, we use the name + of sub-folders as the categories names): :: + + data_prefix/ + ├── class_x + │ ├── xxx.png + │ ├── xxy.png + │ └── ... + │ └── xxz.png + └── class_y + ├── 123.png + ├── nsdf3.png + ├── ... + └── asd932_.png + + Sample files (for ``with_label=False``, unsupervised tasks, we use all + sample files under the specified folder): :: + + data_prefix/ + ├── folder_1 + │ ├── xxx.png + │ ├── xxy.png + │ └── ... + ├── 123.png + ├── nsdf3.png + └── ... + + If the ``ann_file`` is specified, the dataset will be generated by the + first way, otherwise, try the second way. 
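+
+    A minimal sketch of both styles (the paths and class names below are
+    placeholders, not files shipped with this repo)::
+
+        >>> from mmpretrain.datasets import CustomDataset
+        >>> # style 1: annotation file with "path label" lines
+        >>> ann_dataset = CustomDataset(
+        ...     data_root='data/my_dataset',
+        ...     ann_file='meta/train.txt',
+        ...     data_prefix='images',
+        ...     metainfo={'classes': ('cat', 'dog')})
+        >>> # style 2: sub-folder names become the category names
+        >>> folder_dataset = CustomDataset(data_prefix='data/my_dataset/train')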
+
+    Args:
+        data_root (str): The root directory for ``data_prefix`` and
+            ``ann_file``. Defaults to ''.
+        data_prefix (str | dict): Prefix for the data. Defaults to ''.
+        ann_file (str): Annotation file path. Defaults to ''.
+        with_label (bool): Whether the annotation file includes ground truth
+            labels, or to use sub-folders to specify categories.
+            Defaults to True.
+        extensions (Sequence[str]): A sequence of allowed extensions. Defaults
+            to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
+        metainfo (dict, optional): Meta information for dataset, such as class
+            information. Defaults to None.
+        lazy_init (bool): Whether to load annotation during instantiation.
+            In some cases, such as visualization, only the meta information of
+            the dataset is needed, which is not necessary to load annotation
+            file. ``BaseDataset`` can skip loading annotations to save time
+            by setting ``lazy_init=True``. Defaults to False.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
+    """
+
+    def __init__(self,
+                 data_root: str = '',
+                 data_prefix: Union[str, dict] = '',
+                 ann_file: str = '',
+                 with_label=True,
+                 extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm',
+                                              '.bmp', '.pgm', '.tif'),
+                 metainfo: Optional[dict] = None,
+                 lazy_init: bool = False,
+                 **kwargs):
+        assert (ann_file or data_prefix or data_root), \
+            'One of `ann_file`, `data_root` and `data_prefix` must '\
+            'be specified.'
+
+        self.extensions = tuple(set([i.lower() for i in extensions]))
+        self.with_label = with_label
+
+        super().__init__(
+            # The base class requires string ann_file but this class doesn't
+            ann_file=ann_file,
+            metainfo=metainfo,
+            data_root=data_root,
+            data_prefix=data_prefix,
+            # Force to lazy_init for some modification before loading data.
+            lazy_init=True,
+            **kwargs)
+
+        # Fully initialize the dataset.
+        if not lazy_init:
+            self.full_init()
+
+    def _find_samples(self):
+        """find samples from ``data_prefix``."""
+        if self.with_label:
+            classes, folder_to_idx = find_folders(self.img_prefix)
+            samples, empty_classes = get_samples(
+                self.img_prefix,
+                folder_to_idx,
+                is_valid_file=self.is_valid_file,
+            )
+
+            self.folder_to_idx = folder_to_idx
+
+            if self.CLASSES is not None:
+                assert len(self.CLASSES) == len(classes), \
+                    f"The number of subfolders ({len(classes)}) doesn't " \
+                    f'match the number of specified classes ' \
+                    f'({len(self.CLASSES)}). Please check the data folder.'
+            else:
+                self._metainfo['classes'] = tuple(classes)
+        else:
+            samples, empty_classes = get_samples(
+                self.img_prefix,
+                None,
+                is_valid_file=self.is_valid_file,
+            )
+
+        if len(samples) == 0:
+            raise RuntimeError(
+                f'Found 0 files in subfolders of: {self.data_prefix}. '
+                f'Supported extensions are: {",".join(self.extensions)}')
+
+        if empty_classes:
+            logger = MMLogger.get_current_instance()
+            logger.warning(
+                'Found no valid file in the folder '
+                f'{", ".join(empty_classes)}. '
+                f"Supported extensions are: {', '.join(self.extensions)}")
+
+        return samples
+
+    def load_data_list(self):
+        """Load image paths and gt_labels."""
+        if not self.ann_file:
+            samples = self._find_samples()
+        elif self.with_label:
+            lines = list_from_file(self.ann_file)
+            samples = [x.strip().rsplit(' ', 1) for x in lines]
+        else:
+            samples = list_from_file(self.ann_file)
+
+        # Pre-build file backend to prevent verbose file backend inference.
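+        # Each annotated sample is a (filename, label) string pair, e.g.
+        # ('folder_1/xxx.png', '0'), and becomes
+        # {'img_path': ..., 'gt_label': 0} below.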
+        backend = get_file_backend(self.img_prefix, enable_singleton=True)
+        data_list = []
+        for sample in samples:
+            if self.with_label:
+                filename, gt_label = sample
+                img_path = backend.join_path(self.img_prefix, filename)
+                info = {'img_path': img_path, 'gt_label': int(gt_label)}
+            else:
+                img_path = backend.join_path(self.img_prefix, sample)
+                info = {'img_path': img_path}
+            data_list.append(info)
+        return data_list
+
+    def is_valid_file(self, filename: str) -> bool:
+        """Check if a file is a valid sample."""
+        return filename.lower().endswith(self.extensions)
diff --git a/mmpretrain/datasets/dataset_wrappers.py b/mmpretrain/datasets/dataset_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..1adff10beb024940f9066a407cc76ddb06b27404
--- /dev/null
+++ b/mmpretrain/datasets/dataset_wrappers.py
@@ -0,0 +1,176 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import numpy as np
+from mmengine.dataset import BaseDataset, force_full_init
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class KFoldDataset:
+    """A wrapper of dataset for K-Fold cross-validation.
+
+    K-Fold cross-validation divides all the samples into k groups of almost
+    equal size, called folds. We then use k-1 folds for training and the
+    remaining fold for validation.
+
+    Args:
+        dataset (:obj:`mmengine.dataset.BaseDataset` | dict): The dataset to
+            be divided.
+        fold (int): The fold used to do validation. Defaults to 0.
+        num_splits (int): The number of all folds. Defaults to 5.
+        test_mode (bool): Use the training dataset or validation dataset.
+            Defaults to False.
+        seed (int, optional): The seed to shuffle the dataset before
+            splitting. If None, the dataset will not be shuffled.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 dataset,
+                 fold=0,
+                 num_splits=5,
+                 test_mode=False,
+                 seed=None):
+        if isinstance(dataset, dict):
+            self.dataset = DATASETS.build(dataset)
+            # Init the dataset wrapper lazily according to the dataset setting.
+            lazy_init = dataset.get('lazy_init', False)
+        elif isinstance(dataset, BaseDataset):
+            self.dataset = dataset
+            # A built dataset instance is always initialized eagerly here.
+            lazy_init = False
+        else:
+            raise TypeError(f'Unsupported dataset type {type(dataset)}.')
+
+        self._metainfo = getattr(self.dataset, 'metainfo', {})
+        self.fold = fold
+        self.num_splits = num_splits
+        self.test_mode = test_mode
+        self.seed = seed
+
+        self._fully_initialized = False
+        if not lazy_init:
+            self.full_init()
+
+    @property
+    def metainfo(self) -> dict:
+        """Get the meta information of ``self.dataset``.
+
+        Returns:
+            dict: Meta information of the dataset.
+        """
+        # Prevent `self._metainfo` from being modified by outside.
+        return copy.deepcopy(self._metainfo)
+
+    def full_init(self):
+        """fully initialize the dataset."""
+        if self._fully_initialized:
+            return
+
+        self.dataset.full_init()
+        ori_len = len(self.dataset)
+        indices = list(range(ori_len))
+        if self.seed is not None:
+            rng = np.random.default_rng(self.seed)
+            rng.shuffle(indices)
+
+        test_start = ori_len * self.fold // self.num_splits
+        test_end = ori_len * (self.fold + 1) // self.num_splits
+        if self.test_mode:
+            indices = indices[test_start:test_end]
+        else:
+            indices = indices[:test_start] + indices[test_end:]
+
+        self._ori_indices = indices
+        self.dataset = self.dataset.get_subset(indices)
+
+        self._fully_initialized = True
+
+    @force_full_init
+    def _get_ori_dataset_idx(self, idx: int) -> int:
+        """Convert global idx to local index.
+
+        Args:
+            idx (int): Global index of ``KFoldDataset``.
+ + Returns: + int: The original index in the whole dataset. + """ + return self._ori_indices[idx] + + @force_full_init + def get_data_info(self, idx: int) -> dict: + """Get annotation by index. + + Args: + idx (int): Global index of ``KFoldDataset``. + + Returns: + dict: The idx-th annotation of the datasets. + """ + return self.dataset.get_data_info(idx) + + @force_full_init + def __len__(self): + return len(self.dataset) + + @force_full_init + def __getitem__(self, idx): + return self.dataset[idx] + + @force_full_init + def get_cat_ids(self, idx): + return self.dataset.get_cat_ids(idx) + + @force_full_init + def get_gt_labels(self): + return self.dataset.get_gt_labels() + + @property + def CLASSES(self): + """Return all categories names.""" + return self._metainfo.get('classes', None) + + @property + def class_to_idx(self): + """Map mapping class name to class index. + + Returns: + dict: mapping from class name to class index. + """ + + return {cat: i for i, cat in enumerate(self.CLASSES)} + + def __repr__(self): + """Print the basic information of the dataset. + + Returns: + str: Formatted string. + """ + head = 'Dataset ' + self.__class__.__name__ + body = [] + type_ = 'test' if self.test_mode else 'training' + body.append(f'Type: \t{type_}') + body.append(f'Seed: \t{self.seed}') + + def ordinal(n): + # Copy from https://codegolf.stackexchange.com/a/74047 + suffix = 'tsnrhtdd'[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4] + return f'{n}{suffix}' + + body.append( + f'Fold: \t{ordinal(self.fold+1)} of {self.num_splits}-fold') + if self._fully_initialized: + body.append(f'Number of samples: \t{self.__len__()}') + else: + body.append("Haven't been initialized") + + if self.CLASSES is not None: + body.append(f'Number of categories: \t{len(self.CLASSES)}') + else: + body.append('The `CLASSES` meta info is not set.') + + body.append( + f'Original dataset type:\t{self.dataset.__class__.__name__}') + + lines = [head] + [' ' * 4 + line for line in body] + return '\n'.join(lines) diff --git a/mmpretrain/datasets/dtd.py b/mmpretrain/datasets/dtd.py new file mode 100644 index 0000000000000000000000000000000000000000..034d0b1b444afebfc420eeff7e138072f7d7ee1f --- /dev/null +++ b/mmpretrain/datasets/dtd.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mat4py +from mmengine import get_file_backend + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import DTD_CATEGORIES + + +@DATASETS.register_module() +class DTD(BaseDataset): + """The Describable Texture Dataset (DTD). + + Support the `Describable Texture Dataset `_ Dataset. + After downloading and decompression, the dataset directory structure is as follows. + + DTD dataset directory: :: + + dtd + ├── images + │ ├── banded + | | ├──banded_0002.jpg + | | ├──banded_0004.jpg + | | └── ... + │ └── ... + ├── imdb + │ └── imdb.mat + ├── labels + | | ├──labels_joint_anno.txt + | | ├──test1.txt + | | ├──test2.txt + | | └── ... + │ └── ... + └── .... + + Args: + data_root (str): The root directory for Describable Texture dataset. + split (str, optional): The dataset split, supports "train", + "val", "trainval", and "test". Default to "trainval". 
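+            The ``set`` field in ``imdb.mat`` marks each image as ``1``
+            (train), ``2`` (val) or ``3`` (test); ``trainval`` keeps the
+            images whose ``set`` value is in ``{1, 2}``.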
+ + Examples: + >>> from mmpretrain.datasets import DTD + >>> train_dataset = DTD(data_root='data/dtd', split='trainval') + >>> train_dataset + Dataset DTD + Number of samples: 3760 + Number of categories: 47 + Root of dataset: data/dtd + >>> test_dataset = DTD(data_root='data/dtd', split='test') + >>> test_dataset + Dataset DTD + Number of samples: 1880 + Number of categories: 47 + Root of dataset: data/dtd + """ # noqa: E501 + + METAINFO = {'classes': DTD_CATEGORIES} + + def __init__(self, data_root: str, split: str = 'trainval', **kwargs): + + splits = ['train', 'val', 'trainval', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + data_prefix = 'images' + test_mode = split == 'test' + + self.backend = get_file_backend(data_root, enable_singleton=True) + ann_file = self.backend.join_path('imdb', 'imdb.mat') + + super(DTD, self).__init__( + ann_file=ann_file, + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + + data = mat4py.loadmat(self.ann_file)['images'] + names = data['name'] + labels = data['class'] + parts = data['set'] + num = len(names) + assert num == len(labels) == len(parts), 'get error ann file' + + if self.split == 'train': + target_set = {1} + elif self.split == 'val': + target_set = {2} + elif self.split == 'test': + target_set = {3} + else: + target_set = {1, 2} + + data_list = [] + for i in range(num): + if parts[i] in target_set: + img_name = names[i] + img_path = self.backend.join_path(self.img_prefix, img_name) + gt_label = labels[i] - 1 + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/fgvcaircraft.py b/mmpretrain/datasets/fgvcaircraft.py new file mode 100644 index 0000000000000000000000000000000000000000..696992c06bbf02f097d017a519d42f758ba5f16f --- /dev/null +++ b/mmpretrain/datasets/fgvcaircraft.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine import get_file_backend, list_from_file + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import FGVCAIRCRAFT_CATEGORIES + + +@DATASETS.register_module() +class FGVCAircraft(BaseDataset): + """The FGVC_Aircraft Dataset. + + Support the `FGVC_Aircraft Dataset `_ Dataset. + After downloading and decompression, the dataset directory structure is as follows. + + FGVC_Aircraft dataset directory: :: + + fgvc-aircraft-2013b + └── data + ├── images + │ ├── 1.jpg + │ ├── 2.jpg + │ └── ... + ├── images_variant_train.txt + ├── images_variant_test.txt + ├── images_variant_trainval.txt + ├── images_variant_val.txt + ├── variants.txt + └── .... + + Args: + data_root (str): The root directory for FGVC_Aircraft dataset. + split (str, optional): The dataset split, supports "train", + "val", "trainval", and "test". Default to "trainval". 
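+            Each line of ``images_variant_{split}.txt`` holds an image id
+            followed by the variant name (which may contain spaces); the id
+            resolves to ``data/images/<id>.jpg`` and the variant name is
+            mapped to its index in ``METAINFO['classes']``.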
+
+    Examples:
+        >>> from mmpretrain.datasets import FGVCAircraft
+        >>> train_dataset = FGVCAircraft(data_root='data/fgvc-aircraft-2013b', split='trainval')
+        >>> train_dataset
+        Dataset FGVCAircraft
+        Number of samples: 6667
+        Number of categories: 100
+        Root of dataset: data/fgvc-aircraft-2013b
+        >>> test_dataset = FGVCAircraft(data_root='data/fgvc-aircraft-2013b', split='test')
+        >>> test_dataset
+        Dataset FGVCAircraft
+        Number of samples: 3333
+        Number of categories: 100
+        Root of dataset: data/fgvc-aircraft-2013b
+    """ # noqa: E501
+
+    METAINFO = {'classes': FGVCAIRCRAFT_CATEGORIES}
+
+    def __init__(self, data_root: str, split: str = 'trainval', **kwargs):
+
+        splits = ['train', 'val', 'trainval', 'test']
+        assert split in splits, \
+            f"The split must be one of {splits}, but got '{split}'"
+        self.split = split
+
+        self.backend = get_file_backend(data_root, enable_singleton=True)
+        ann_file = self.backend.join_path('data',
+                                          f'images_variant_{split}.txt')
+        data_prefix = self.backend.join_path('data', 'images')
+        test_mode = split == 'test'
+
+        super(FGVCAircraft, self).__init__(
+            ann_file=ann_file,
+            data_root=data_root,
+            test_mode=test_mode,
+            data_prefix=data_prefix,
+            **kwargs)
+
+    def load_data_list(self):
+        """Load images and ground truth labels."""
+
+        pairs = list_from_file(self.ann_file)
+        data_list = []
+        for pair in pairs:
+            pair = pair.split()
+            img_name = pair[0]
+            class_name = ' '.join(pair[1:])
+            img_name = f'{img_name}.jpg'
+            img_path = self.backend.join_path(self.img_prefix, img_name)
+            gt_label = self.METAINFO['classes'].index(class_name)
+            info = dict(img_path=img_path, gt_label=gt_label)
+            data_list.append(info)
+
+        return data_list
+
+    def extra_repr(self) -> List[str]:
+        """The extra repr information of the dataset."""
+        body = [
+            f'Root of dataset: \t{self.data_root}',
+        ]
+        return body
diff --git a/mmpretrain/datasets/flamingo.py b/mmpretrain/datasets/flamingo.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b5745a1437537fccbc304d158a0f0c8d09f032a
--- /dev/null
+++ b/mmpretrain/datasets/flamingo.py
@@ -0,0 +1,295 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import random
+from abc import abstractmethod
+from collections import Counter
+from typing import List
+
+import mmengine
+import numpy as np
+from mmengine.dataset import BaseDataset
+from pycocotools.coco import COCO
+
+from mmpretrain.registry import DATASETS
+from .coco_vqa import COCOVQA
+
+
+class FlamingoFewShotMixin:
+    """Flamingo fewshot eval dataset mixin.
+
+    Args:
+        num_shots (int): Number of shots to perform evaluation.
+            Defaults to 0.
+            Note: 0 does not mean a strict zero-shot in the Flamingo setting.
+            It will use 2 text-only prompts without in-context images.
+        num_support_examples (int): Number of support examples to get the
+            few shots from. Defaults to 2048.
+        num_query_examples (int): Number of query examples to perform the
+            final evaluation. Defaults to 5000.
+        incontext_prompt_temp (str): In-context prompt template for few shot
+            examples. Defaults to ''.
+        final_prompt_temp (str): Final query prompt template. Defaults to ''.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
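+
+    A minimal sketch of the intended use (``MyFewShotDataset`` and its
+    method bodies are hypothetical, not part of this file)::
+
+        class MyFewShotDataset(FlamingoFewShotMixin, BaseDataset):
+
+            def parse_basic_anno(self, anno):
+                # return the fields shared by support and query subsets
+                return {'gt_answer': [anno['answer']]}
+
+            def parse_fewshot_anno(self, anno, support_list):
+                # attach ``num_shots`` examples drawn from ``support_list``
+                anno['shots'] = random.sample(support_list, self.num_shots)
+                return anno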
+ """ + + def __init__(self, + num_shots: int = 0, + num_support_examples: int = 2048, + num_query_examples: int = 5000, + incontext_prompt_temp: str = '', + final_prompt_temp: str = '', + **kwarg): + self.num_shots = num_shots + self.num_support_examples = num_support_examples + self.num_query_examples = num_query_examples + self.incontext_prompt_temp = incontext_prompt_temp + self.final_prompt_temp = final_prompt_temp + super().__init__(**kwarg) + + def get_subset_idx(self, total_num): + random_idx = np.random.choice( + total_num, + self.num_support_examples + self.num_query_examples, + replace=False) + + support_idx = random_idx[:self.num_support_examples] + query_idx = random_idx[self.num_support_examples:] + return support_idx, query_idx + + @abstractmethod + def parse_basic_anno(self, anno: dict) -> dict: + """Parse basic annotation for support and query set.""" + pass + + @abstractmethod + def parse_fewshot_anno(self, anno: dict, support_list: List) -> dict: + """Parse fewshot related annotation for query set with support list.""" + pass + + +@DATASETS.register_module() +class FlamingoEvalCOCOVQA(FlamingoFewShotMixin, COCOVQA): + """Flamingo few shot VQAv2 dataset. + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. + ann_file (str): Annotation file path. + question_file (str): Question file path. + num_shots (int): Number of shots to perform evaluation. + Defaults to 0. + Note: 0 does not mean a strict zero-shot in Flamingo setting. + It will use 2 only-text prompt without in context images. + num_support_examples (int): Number of support examples to get the + few shots from. Defaults to 2048. + num_query_examples (int): Number of query examples to perform the + final evaluation. Defaults to 5000. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + question_file: str, + ann_file: str = '', + num_shots: int = 0, + num_support_examples: int = 2048, + num_query_examples: int = 5000, + **kwarg): + super().__init__( + data_root=data_root, + question_file=question_file, + ann_file=ann_file, + num_shots=num_shots, + num_support_examples=num_support_examples, + num_query_examples=num_query_examples, + **kwarg) + + def parse_basic_anno(self, ann: dict) -> dict: + """Parse basic annotation for support and query set. + + Args: + anno (dict): Annotation for single example. + + Return: + dict: Parsed annotation for single example. + """ + if ann is None: + return {} + + answers = [a['answer'] for a in ann['answers']] + count = Counter(answers) + answer_weight = [i / len(answers) for i in count.values()] + answer_info = { + 'gt_answer': list(count.keys()), + 'gt_answer_weight': answer_weight + } + return answer_info + + def parse_fewshot_anno(self, query: dict, support_list: List) -> dict: + """Parse fewshot related annotation for query set with support list. + + Args: + anno (dict): Annotation for single example. + support_list (List): List of support subset to subsample few shots. + + Return: + dict: Parsed annotation for single example. 
+ """ + # prepare n shots examples + shots = random.sample(support_list, self.num_shots) + + # append image path for n shots + img_path = [shot['img_path'] for shot in shots] + img_path.append(query['img_path']) + query['img_path'] = img_path + + query['shots'] = [ + dict( + question=item['question'], + answer=item['gt_answer'][0], + ) for item in shots + ] + return query + + def load_data_list(self) -> List[dict]: + """Load data list.""" + questions = mmengine.load(self.question_file)['questions'] + if self.ann_file: + annotations = mmengine.load(self.ann_file)['annotations'] + assert len(questions) == len(annotations) + else: + annotations = [None] * len(questions) + if self.num_shots > 0: + raise ValueError('Unable to construct few-shot examples ' + 'since no annotation file.') + + # The original VQAv2 annotation file and question file includes + # only image id but no image file paths. + self.image_index = self._create_image_index() + + num_data = len(questions) + support_idx, query_idx = self.get_subset_idx(num_data) + + # prepare support subset + if self.num_shots > 0: + support_list = [] + for idx in support_idx: + question = questions[idx] + ann = annotations[idx] + support = {**question, **self.parse_basic_anno(ann)} + support['img_path'] = self.image_index[question['image_id']] + support_list.append(support) + + # prepare query subset + data_list = [] + for idx in query_idx: + question = questions[idx] + ann = annotations[idx] + data_info = {**question, **self.parse_basic_anno(ann)} + data_info['img_path'] = self.image_index[question['image_id']] + if self.num_shots > 0: + data_info = self.parse_fewshot_anno(data_info, support_list) + data_list.append(data_info) + + return data_list + + +@DATASETS.register_module() +class FlamingoEvalCOCOCaption(FlamingoFewShotMixin, BaseDataset): + """Flamingo few shot COCO Caption dataset. + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. + ann_file (str): Annotation file path. + data_prefix (dict): Prefix for data field. Defaults to + ``dict(img_path='')``. + num_shots (int): Number of shots to perform evaluation. + Defaults to 0. + num_support_examples (int): Number of support examples to get the + few shots from. Defaults to 2048. + num_query_examples (int): Number of query examples to perform the + final evaluation. Defaults to 5000. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + ann_file: str, + num_shots: int = 0, + num_support_examples: int = 2048, + num_query_examples: int = 5000, + **kwarg): + super().__init__( + data_root=data_root, + ann_file=ann_file, + num_shots=num_shots, + num_support_examples=num_support_examples, + num_query_examples=num_query_examples, + **kwarg) + + def parse_basic_anno(self, ann: dict, coco: COCO) -> dict: + """Parse basic annotation for support and query set. + + Args: + anno (dict): Annotation for single example. + coco (COCO): The coco dataset. + + Return: + dict: Parsed annotation for single example. + """ + img_prefix = self.data_prefix['img_path'] + img = coco.imgs[ann['image_id']] + data_info = dict( + img_path=mmengine.join_path(img_prefix, img['file_name']), + gt_caption=ann['caption'], + image_id=ann['image_id'], + ) + return data_info + + def parse_fewshot_anno(self, query: dict, support_list: List) -> dict: + """Parse fewshot related annotation for query set with support list. + + Args: + query (dict): Annotation for single example. 
+ support_list (List): List of support subset to subsample few shots. + coco (COCO): The coco dataset. + + Return: + dict: Parsed annotation for single example. + """ + # prepare n shots examples + shots = random.sample(support_list, self.num_shots) + + # append image path for n shots + img_path = [shot['img_path'] for shot in shots] + img_path.append(query['img_path']) + query['img_path'] = img_path + + query['shots'] = [dict(caption=item['gt_caption']) for item in shots] + return query + + def load_data_list(self) -> List[dict]: + """Load data list.""" + with mmengine.get_local_path(self.ann_file) as ann_file: + coco = COCO(ann_file) + + num_data = len(coco.anns) + support_idx, query_idx = self.get_subset_idx(num_data) + ann_ids = list(coco.anns) + + # prepare support subset + if self.num_shots > 0: + support_list = [] + for idx in support_idx: + support = self.parse_basic_anno(coco.anns[ann_ids[idx]], coco) + support_list.append(support) + + # prepare query subset + query_list = [] + for idx in query_idx: + data_info = self.parse_basic_anno(coco.anns[ann_ids[idx]], coco) + if self.num_shots > 0: + data_info = self.parse_fewshot_anno(data_info, support_list) + query_list.append(data_info) + + return query_list diff --git a/mmpretrain/datasets/flickr30k_caption.py b/mmpretrain/datasets/flickr30k_caption.py new file mode 100644 index 0000000000000000000000000000000000000000..f0f6841a2c87a0b3eaa3a7abd5b8fda1cb235bc0 --- /dev/null +++ b/mmpretrain/datasets/flickr30k_caption.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset +from mmengine.fileio import get_file_backend + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class Flickr30kCaption(BaseDataset): + """Flickr30k Caption dataset. To generate coco-style GT annotation for + evaluation, please refer to + tools/dataset_converters/convert_flickr30k_ann.py. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + ann_file (str): Annotation file path for training and validation. + split (str): 'train', 'val' or 'test'. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
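+
+    Examples (a usage sketch; the paths are illustrative and the annotation
+    file is assumed to be the Karpathy-style ``dataset_flickr30k.json``,
+    which matches the ``images``/``sentences``/``split`` fields parsed
+    below):
+        >>> from mmpretrain.datasets import Flickr30kCaption
+        >>> train_dataset = Flickr30kCaption(
+        ...     data_root='data/flickr30k',
+        ...     data_prefix='images',
+        ...     ann_file='annotations/dataset_flickr30k.json',
+        ...     split='train')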
+ """ + + def __init__(self, data_root: str, data_prefix: str, ann_file: str, + split: str, **kwarg): + + assert split in ['train', 'val', 'test'], \ + '`split` must be train, val or test' + self.split = split + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + img_prefix = self.data_prefix['img_path'] + annotations = mmengine.load(self.ann_file) + file_backend = get_file_backend(img_prefix) + + data_list = [] + + for img in annotations['images']: + + # img_example={ + # "sentids": [0, 1, 2], + # "imgid": 0, + # "sentences": [ + # {"raw": "Two men in green shirts standing in a yard.", + # "imgid": 0, "sentid": 0}, + # {"raw": "A man in a blue shirt standing in a garden.", + # "imgid": 0, "sentid": 1}, + # {"raw": "Two friends enjoy time spent together.", + # "imgid": 0, "sentid": 2} + # ], + # "split": "train", + # "filename": "1000092795.jpg" + # }, + + if img['split'] != self.split: + continue + + for sentence in img['sentences']: + data_info = { + 'image_id': img['imgid'], + 'img_path': file_backend.join_path(img_prefix, + img['filename']), + 'gt_caption': sentence['raw'] + } + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/flickr30k_retrieval.py b/mmpretrain/datasets/flickr30k_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..9f43c151b2079b3f72cf620577923efc57987316 --- /dev/null +++ b/mmpretrain/datasets/flickr30k_retrieval.py @@ -0,0 +1,110 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict +from typing import List + +import mmengine +from mmengine import get_file_backend + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class Flickr30kRetrieval(BaseDataset): + """Flickr30k Retrieval dataset. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + ann_file (str): Annotation file path for training and validation. + split (str): 'train', 'val' or 'test'. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
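+
+    Note that in ``test_mode`` the data list holds one entry per image (with
+    all matched captions attached for retrieval evaluation), while in
+    training it holds one entry per caption. A usage sketch with
+    illustrative paths, assuming the same Karpathy-style annotation file as
+    the caption dataset:
+        >>> from mmpretrain.datasets import Flickr30kRetrieval
+        >>> test_dataset = Flickr30kRetrieval(
+        ...     data_root='data/flickr30k',
+        ...     data_prefix='images',
+        ...     ann_file='annotations/dataset_flickr30k.json',
+        ...     split='test',
+        ...     test_mode=True)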
+ """ + + def __init__(self, data_root: str, data_prefix: str, ann_file: str, + split: str, **kwarg): + + assert split in ['train', 'val', 'test'], \ + '`split` must be train, val or test' + self.split = split + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + # get file backend + img_prefix = self.data_prefix['img_path'] + file_backend = get_file_backend(img_prefix) + + annotations = mmengine.load(self.ann_file) + + # mapping img_id to img filename + img_dict = OrderedDict() + img_idx = 0 + sentence_idx = 0 + train_list = [] + for img in annotations['images']: + + # img_example={ + # "sentids": [0, 1, 2], + # "imgid": 0, + # "sentences": [ + # {"raw": "Two men in green shirts standing in a yard.", + # "imgid": 0, "sentid": 0}, + # {"raw": "A man in a blue shirt standing in a garden.", + # "imgid": 0, "sentid": 1}, + # {"raw": "Two friends enjoy time spent together.", + # "imgid": 0, "sentid": 2} + # ], + # "split": "train", + # "filename": "1000092795.jpg" + # }, + + if img['split'] != self.split: + continue + + # create new idx for image + train_image = dict( + ori_id=img['imgid'], + image_id=img_idx, # used for evaluation + img_path=file_backend.join_path(img_prefix, img['filename']), + text=[], + gt_text_id=[], + gt_image_id=[], + ) + + for sentence in img['sentences']: + ann = {} + ann['text'] = sentence['raw'] + ann['ori_id'] = sentence['sentid'] + ann['text_id'] = sentence_idx # used for evaluation + + ann['image_ori_id'] = train_image['ori_id'] + ann['image_id'] = train_image['image_id'] + ann['img_path'] = train_image['img_path'] + ann['is_matched'] = True + + # 1. prepare train data list item + train_list.append(ann) + # 2. prepare eval data list item based on img dict + train_image['text'].append(ann['text']) + train_image['gt_text_id'].append(ann['text_id']) + train_image['gt_image_id'].append(ann['image_id']) + + sentence_idx += 1 + + img_dict[img['imgid']] = train_image + img_idx += 1 + + self.img_size = len(img_dict) + self.text_size = len(train_list) + + # return needed format data list + if self.test_mode: + return list(img_dict.values()) + return train_list diff --git a/mmpretrain/datasets/flowers102.py b/mmpretrain/datasets/flowers102.py new file mode 100644 index 0000000000000000000000000000000000000000..fe76dcc8422c8692261800b204a6262b60002e81 --- /dev/null +++ b/mmpretrain/datasets/flowers102.py @@ -0,0 +1,104 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mat4py +from mmengine import get_file_backend + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class Flowers102(BaseDataset): + """The Oxford 102 Flower Dataset. + + Support the `Oxford 102 Flowers Dataset `_ Dataset. + After downloading and decompression, the dataset directory structure is as follows. + + Flowers102 dataset directory: :: + + Flowers102 + ├── jpg + │ ├── image_00001.jpg + │ ├── image_00002.jpg + │ └── ... + ├── imagelabels.mat + ├── setid.mat + └── ... + + Args: + data_root (str): The root directory for Oxford 102 Flowers dataset. + split (str, optional): The dataset split, supports "train", + "val", "trainval", and "test". Default to "trainval". 
+ + Examples: + >>> from mmpretrain.datasets import Flowers102 + >>> train_dataset = Flowers102(data_root='data/Flowers102', split='trainval') + >>> train_dataset + Dataset Flowers102 + Number of samples: 2040 + Root of dataset: data/Flowers102 + >>> test_dataset = Flowers102(data_root='data/Flowers102', split='test') + >>> test_dataset + Dataset Flowers102 + Number of samples: 6149 + Root of dataset: data/Flowers102 + """ # noqa: E501 + + def __init__(self, data_root: str, split: str = 'trainval', **kwargs): + splits = ['train', 'val', 'trainval', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + ann_file = 'imagelabels.mat' + data_prefix = 'jpg' + train_test_split_file = 'setid.mat' + test_mode = split == 'test' + + self.backend = get_file_backend(data_root, enable_singleton=True) + + self.train_test_split_file = self.backend.join_path( + data_root, train_test_split_file) + + super(Flowers102, self).__init__( + ann_file=ann_file, + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + + label_dict = mat4py.loadmat(self.ann_file)['labels'] + split_list = mat4py.loadmat(self.train_test_split_file) + + if self.split == 'train': + split_list = split_list['trnid'] + elif self.split == 'val': + split_list = split_list['valid'] + elif self.split == 'test': + split_list = split_list['tstid'] + else: + train_ids = split_list['trnid'] + val_ids = split_list['valid'] + train_ids.extend(val_ids) + split_list = train_ids + + data_list = [] + for sample_id in split_list: + img_name = 'image_%05d.jpg' % (sample_id) + img_path = self.backend.join_path(self.img_prefix, img_name) + gt_label = int(label_dict[sample_id - 1]) - 1 + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/food101.py b/mmpretrain/datasets/food101.py new file mode 100644 index 0000000000000000000000000000000000000000..4ce7ffeee91c6843c259149770e9de4ad9f4317a --- /dev/null +++ b/mmpretrain/datasets/food101.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine import get_file_backend, list_from_file + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import FOOD101_CATEGORIES + + +@DATASETS.register_module() +class Food101(BaseDataset): + """The Food101 Dataset. + + Support the `Food101 Dataset `_ Dataset. + After downloading and decompression, the dataset directory structure is as follows. + + Food101 dataset directory: :: + + food-101 + ├── images + │ ├── class_x + │ │ ├── xx1.jpg + │ │ ├── xx2.jpg + │ │ └── ... + │ ├── class_y + │ │ ├── yy1.jpg + │ │ ├── yy2.jpg + │ │ └── ... + │ └── ... + ├── meta + │ ├── train.txt + │ └── test.txt + └── .... + + Args: + data_root (str): The root directory for Food101 dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". 
+ + Examples: + >>> from mmpretrain.datasets import Food101 + >>> train_dataset = Food101(data_root='data/food-101', split='train') + >>> train_dataset + Dataset Food101 + Number of samples: 75750 + Number of categories: 101 + Root of dataset: data/food-101 + >>> test_dataset = Food101(data_root='data/food-101', split='test') + >>> test_dataset + Dataset Food101 + Number of samples: 25250 + Number of categories: 101 + Root of dataset: data/food-101 + """ # noqa: E501 + + METAINFO = {'classes': FOOD101_CATEGORIES} + + def __init__(self, data_root: str, split: str = 'train', **kwargs): + + splits = ['train', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + self.backend = get_file_backend(data_root, enable_singleton=True) + if split == 'train': + ann_file = self.backend.join_path('meta', 'train.txt') + else: + ann_file = self.backend.join_path('meta', 'test.txt') + + test_mode = split == 'test' + data_prefix = 'images' + + super(Food101, self).__init__( + ann_file=ann_file, + data_root=data_root, + test_mode=test_mode, + data_prefix=data_prefix, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + + pairs = list_from_file(self.ann_file) + data_list = [] + for pair in pairs: + class_name, img_name = pair.split('/') + img_name = f'{img_name}.jpg' + img_path = self.backend.join_path(self.img_prefix, class_name, + img_name) + gt_label = self.METAINFO['classes'].index(class_name) + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/gqa_dataset.py b/mmpretrain/datasets/gqa_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..741791bc2bb51f768e8907aac7f002f0e730aeea --- /dev/null +++ b/mmpretrain/datasets/gqa_dataset.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class GQA(BaseDataset): + """GQA dataset. + + We use the annotation file from LAVIS, and you can download all annotation files from following links: # noqa: E501 + + train: + https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json # noqa: E501 + val: + https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/testdev_balanced_questions.json # noqa: E501 + test: + https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/test_balanced_questions.json # noqa: E501 + + and images from the official website: + https://cs.stanford.edu/people/dorarad/gqa/index.html + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to an empty string. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
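+
+    A usage sketch (the directory layout is illustrative; the annotation
+    file is one of the LAVIS files listed above):
+        >>> from mmpretrain.datasets import GQA
+        >>> val_dataset = GQA(
+        ...     data_root='data/gqa',
+        ...     data_prefix='images',
+        ...     ann_file='annotations/testdev_balanced_questions.json')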
+ """ + + def __init__(self, + data_root: str, + data_prefix: str, + ann_file: str = '', + **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + annotations = mmengine.load(self.ann_file) + + data_list = [] + for ann in annotations: + # ann example + # { + # 'question': "Is it overcast?", + # 'answer': 'no, + # 'image_id': n161313.jpg, + # 'question_id': 262148000, + # .... + # } + data_info = dict() + data_info['img_path'] = osp.join(self.data_prefix['img_path'], + ann['image']) + data_info['question'] = ann['question'] + data_info['gt_answer'] = ann['answer'] + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/iconqa.py b/mmpretrain/datasets/iconqa.py new file mode 100644 index 0000000000000000000000000000000000000000..20c4d87ddea463f7c326cb0062b2634d4d06342e --- /dev/null +++ b/mmpretrain/datasets/iconqa.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset +from mmengine.fileio import list_dir_or_file +from mmengine.utils import check_file_exist + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class IconQA(BaseDataset): + """IconQA: A benchmark for abstract diagram understanding + and visual language reasoning. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of the specific task and split. + eg. ``iconqa/val/choose_text/``. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, data_root: str, data_prefix: str, **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + sample_list = list( + list_dir_or_file(self.data_prefix['img_path'], list_file=False)) + + data_list = list() + for sample_id in sample_list: + # data json + # { + # "question": "How likely is it that you will pick a black one?", + # "choices": [ + # "certain", + # "unlikely", + # "impossible", + # "probable" + # ], + # "answer": 2, + # "ques_type": "choose_txt", + # "grade": "grade1", + # "label": "S2" + # } + data_info = mmengine.load( + mmengine.join_path(self.data_prefix['img_path'], sample_id, + 'data.json')) + data_info['gt_answer'] = data_info['choices'][int( + data_info['answer'])] + data_info['img_path'] = mmengine.join_path( + self.data_prefix['img_path'], sample_id, 'image.png') + check_file_exist(data_info['img_path']) + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/imagenet.py b/mmpretrain/datasets/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..771d6ee454e3dc094962ca09036888f97ffb2d21 --- /dev/null +++ b/mmpretrain/datasets/imagenet.py @@ -0,0 +1,235 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +from mmengine import fileio +from mmengine.logging import MMLogger + +from mmpretrain.registry import DATASETS +from .categories import IMAGENET_CATEGORIES +from .custom import CustomDataset + + +@DATASETS.register_module() +class ImageNet(CustomDataset): + """`ImageNet `_ Dataset. + + The dataset supports two kinds of directory format, + + :: + + imagenet + ├── train + │ ├──class_x + | | ├── x1.jpg + | | ├── x2.jpg + | | └── ... 
+ │ ├── class_y + | | ├── y1.jpg + | | ├── y2.jpg + | | └── ... + | └── ... + ├── val + │ ├──class_x + | | └── ... + │ ├── class_y + | | └── ... + | └── ... + └── test + ├── test1.jpg + ├── test2.jpg + └── ... + + or :: + + imagenet + ├── train + │ ├── x1.jpg + │ ├── y1.jpg + │ └── ... + ├── val + │ ├── x3.jpg + │ ├── y3.jpg + │ └── ... + ├── test + │ ├── test1.jpg + │ ├── test2.jpg + │ └── ... + └── meta + ├── train.txt + └── val.txt + + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + split (str): The dataset split, supports "train", "val" and "test". + Default to ''. + data_prefix (str | dict): Prefix for training data. Defaults to ''. + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + **kwargs: Other keyword arguments in :class:`CustomDataset` and + :class:`BaseDataset`. + + + Examples: + >>> from mmpretrain.datasets import ImageNet + >>> train_dataset = ImageNet(data_root='data/imagenet', split='train') + >>> train_dataset + Dataset ImageNet + Number of samples: 1281167 + Number of categories: 1000 + Root of dataset: data/imagenet + >>> test_dataset = ImageNet(data_root='data/imagenet', split='val') + >>> test_dataset + Dataset ImageNet + Number of samples: 50000 + Number of categories: 1000 + Root of dataset: data/imagenet + """ # noqa: E501 + + IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif') + METAINFO = {'classes': IMAGENET_CATEGORIES} + + def __init__(self, + data_root: str = '', + split: str = '', + data_prefix: Union[str, dict] = '', + ann_file: str = '', + metainfo: Optional[dict] = None, + **kwargs): + kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs} + + if split: + splits = ['train', 'val', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + + if split == 'test': + logger = MMLogger.get_current_instance() + logger.info( + 'Since the ImageNet1k test set does not provide label' + 'annotations, `with_label` is set to False') + kwargs['with_label'] = False + + data_prefix = split if data_prefix == '' else data_prefix + + if ann_file == '': + _ann_path = fileio.join_path(data_root, 'meta', f'{split}.txt') + if fileio.exists(_ann_path): + ann_file = fileio.join_path('meta', f'{split}.txt') + + super().__init__( + data_root=data_root, + data_prefix=data_prefix, + ann_file=ann_file, + metainfo=metainfo, + **kwargs) + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body + + +@DATASETS.register_module() +class ImageNet21k(CustomDataset): + """ImageNet21k Dataset. + + Since the dataset ImageNet21k is extremely big, contains 21k+ classes + and 1.4B files. We won't provide the default categories list. Please + specify it from the ``classes`` argument. + The dataset directory structure is as follows, + + ImageNet21k dataset directory :: + + imagenet21k + ├── train + │ ├──class_x + | | ├── x1.jpg + | | ├── x2.jpg + | | └── ... + │ ├── class_y + | | ├── y1.jpg + | | ├── y2.jpg + | | └── ... + | └── ... + └── meta + └── train.txt + + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str | dict): Prefix for training data. Defaults to ''. + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as class + information. 
Defaults to None. + multi_label (bool): Not implement by now. Use multi label or not. + Defaults to False. + **kwargs: Other keyword arguments in :class:`CustomDataset` and + :class:`BaseDataset`. + + Examples: + >>> from mmpretrain.datasets import ImageNet21k + >>> train_dataset = ImageNet21k(data_root='data/imagenet21k', split='train') + >>> train_dataset + Dataset ImageNet21k + Number of samples: 14197088 + Annotation file: data/imagenet21k/meta/train.txt + Prefix of images: data/imagenet21k/train + """ # noqa: E501 + + IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif') + + def __init__(self, + data_root: str = '', + split: str = '', + data_prefix: Union[str, dict] = '', + ann_file: str = '', + metainfo: Optional[dict] = None, + multi_label: bool = False, + **kwargs): + if multi_label: + raise NotImplementedError( + 'The `multi_label` option is not supported by now.') + self.multi_label = multi_label + + if split: + splits = ['train'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'.\ + If you want to specify your own validation set or test set,\ + please set split to None." + + self.split = split + data_prefix = split if data_prefix == '' else data_prefix + + if not ann_file: + _ann_path = fileio.join_path(data_root, 'meta', f'{split}.txt') + if fileio.exists(_ann_path): + ann_file = fileio.join_path('meta', f'{split}.txt') + + logger = MMLogger.get_current_instance() + + if not ann_file: + logger.warning( + 'The ImageNet21k dataset is large, and scanning directory may ' + 'consume long time. Considering to specify the `ann_file` to ' + 'accelerate the initialization.') + + kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs} + super().__init__( + data_root=data_root, + data_prefix=data_prefix, + ann_file=ann_file, + metainfo=metainfo, + **kwargs) + + if self.CLASSES is None: + logger.warning( + 'The CLASSES is not stored in the `ImageNet21k` class. ' + 'Considering to specify the `classes` argument if you need ' + 'do inference on the ImageNet-21k dataset') diff --git a/mmpretrain/datasets/infographic_vqa.py b/mmpretrain/datasets/infographic_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..46f5b0a37455677fe548c04f305cffb77402b775 --- /dev/null +++ b/mmpretrain/datasets/infographic_vqa.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class InfographicVQA(BaseDataset): + """Infographic VQA dataset. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file``. + data_prefix (str): The directory of images. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to an empty string. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
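+
+    A usage sketch (both the directory layout and the annotation file name
+    are illustrative assumptions, not fixed by this class):
+        >>> from mmpretrain.datasets import InfographicVQA
+        >>> val_dataset = InfographicVQA(
+        ...     data_root='data/infographicvqa',
+        ...     data_prefix='images',
+        ...     ann_file='annotations/infographicVQA_val_v1.0.json')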
+ """ + + def __init__(self, + data_root: str, + data_prefix: str, + ann_file: str = '', + **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + annotations = mmengine.load(self.ann_file) + annotations = annotations['data'] + + data_list = [] + for ann in annotations: + # ann example + # { + # "questionId": 98313, + # "question": "Which social platform has heavy female audience?", + # "image_local_name": "37313.jpeg", + # "image_url": "https://xxx.png", + # "ocr_output_file": "37313.json", + # "answers": [ + # "pinterest" + # ], + # "data_split": "val" + # } + data_info = dict() + data_info['question'] = ann['question'] + data_info['img_path'] = mmengine.join_path( + self.data_prefix['img_path'], ann['image_local_name']) + if 'answers' in ann.keys(): # test splits do not include gt + data_info['gt_answer'] = ann['answers'] + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/inshop.py b/mmpretrain/datasets/inshop.py new file mode 100644 index 0000000000000000000000000000000000000000..f64f1779632d4a98d0e36d59750f4a1e8cbd4aed --- /dev/null +++ b/mmpretrain/datasets/inshop.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine import get_file_backend, list_from_file + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class InShop(BaseDataset): + """InShop Dataset for Image Retrieval. + + Please download the images from the homepage + 'https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html' + (In-shop Clothes Retrieval Benchmark -> Img -> img.zip, + Eval/list_eval_partition.txt), and organize them as follows way: :: + + In-shop Clothes Retrieval Benchmark (data_root)/ + ├── Eval / + │ └── list_eval_partition.txt (ann_file) + ├── Img (img_prefix) + │ └── img/ + ├── README.txt + └── ..... + + Args: + data_root (str): The root directory for dataset. + split (str): Choose from 'train', 'query' and 'gallery'. + Defaults to 'train'. + data_prefix (str | dict): Prefix for training data. + Defaults to 'Img'. + ann_file (str): Annotation file path, path relative to + ``data_root``. Defaults to 'Eval/list_eval_partition.txt'. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + + Examples: + >>> from mmpretrain.datasets import InShop + >>> + >>> # build train InShop dataset + >>> inshop_train_cfg = dict(data_root='data/inshop', split='train') + >>> inshop_train = InShop(**inshop_train_cfg) + >>> inshop_train + Dataset InShop + Number of samples: 25882 + The `CLASSES` meta info is not set. + Root of dataset: data/inshop + >>> + >>> # build query InShop dataset + >>> inshop_query_cfg = dict(data_root='data/inshop', split='query') + >>> inshop_query = InShop(**inshop_query_cfg) + >>> inshop_query + Dataset InShop + Number of samples: 14218 + The `CLASSES` meta info is not set. + Root of dataset: data/inshop + >>> + >>> # build gallery InShop dataset + >>> inshop_gallery_cfg = dict(data_root='data/inshop', split='gallery') + >>> inshop_gallery = InShop(**inshop_gallery_cfg) + >>> inshop_gallery + Dataset InShop + Number of samples: 12612 + The `CLASSES` meta info is not set. 
+        Root of dataset:  data/inshop
+    """
+
+    def __init__(self,
+                 data_root: str,
+                 split: str = 'train',
+                 data_prefix: str = 'Img',
+                 ann_file: str = 'Eval/list_eval_partition.txt',
+                 **kwargs):
+
+        assert split in ('train', 'query', 'gallery'), "'split' of `InShop`" \
+            f" must be one of ['train', 'query', 'gallery'], but got '{split}'"
+        self.backend = get_file_backend(data_root, enable_singleton=True)
+        self.split = split
+        super().__init__(
+            data_root=data_root,
+            data_prefix=data_prefix,
+            ann_file=ann_file,
+            **kwargs)
+
+    def _process_annotations(self):
+        lines = list_from_file(self.ann_file)
+
+        anno_train = dict(metainfo=dict(), data_list=list())
+        anno_gallery = dict(metainfo=dict(), data_list=list())
+
+        # item_id to label, each item corresponds to one class label
+        class_num = 0
+        gt_label_train = {}
+
+        # item_id to label, each label corresponds to several items
+        gallery_num = 0
+        gt_label_gallery = {}
+
+        # (lines[0], lines[1]) are the image number and the field names;
+        # each following line is formatted as
+        # 'image_name, item_id, evaluation_status'
+        for line in lines[2:]:
+            img_name, item_id, status = line.split()
+            img_path = self.backend.join_path(self.img_prefix, img_name)
+            if status == 'train':
+                if item_id not in gt_label_train:
+                    gt_label_train[item_id] = class_num
+                    class_num += 1
+                # item_id to class_id (for the training set)
+                anno_train['data_list'].append(
+                    dict(img_path=img_path, gt_label=gt_label_train[item_id]))
+            elif status == 'gallery':
+                if item_id not in gt_label_gallery:
+                    gt_label_gallery[item_id] = []
+                # Since there are multiple images for each item,
+                # record the corresponding item for each image.
+                gt_label_gallery[item_id].append(gallery_num)
+                anno_gallery['data_list'].append(
+                    dict(img_path=img_path, sample_idx=gallery_num))
+                gallery_num += 1
+
+        if self.split == 'train':
+            anno_train['metainfo']['class_number'] = class_num
+            anno_train['metainfo']['sample_number'] = \
+                len(anno_train['data_list'])
+            return anno_train
+        elif self.split == 'gallery':
+            anno_gallery['metainfo']['sample_number'] = gallery_num
+            return anno_gallery
+
+        # Generate the labels for the query (val) set
+        anno_query = dict(metainfo=dict(), data_list=list())
+        query_num = 0
+        for line in lines[2:]:
+            img_name, item_id, status = line.split()
+            img_path = self.backend.join_path(self.img_prefix, img_name)
+            if status == 'query':
+                anno_query['data_list'].append(
+                    dict(
+                        img_path=img_path, gt_label=gt_label_gallery[item_id]))
+                query_num += 1
+
+        anno_query['metainfo']['sample_number'] = query_num
+        return anno_query
+
+    def load_data_list(self):
+        """Load data list.
+
+        For the train set, return image and ground truth label. For the query
+        set, return image and ids of images in gallery. For the gallery set,
+        return image and its id.
+        """
+        data_info = self._process_annotations()
+        data_list = data_info['data_list']
+        return data_list
+
+    def extra_repr(self):
+        """The extra repr information of the dataset."""
+        body = [f'Root of dataset: \t{self.data_root}']
+        return body
diff --git a/mmpretrain/datasets/minigpt4_dataset.py b/mmpretrain/datasets/minigpt4_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e14e5c354e26b7a3810173d3a344f96c9a3ee049
--- /dev/null
+++ b/mmpretrain/datasets/minigpt4_dataset.py
@@ -0,0 +1,79 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List + +import mmengine +from mmengine.dataset import BaseDataset +from mmengine.fileio import get_file_backend + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class MiniGPT4Dataset(BaseDataset): + """Dataset for training MiniGPT4. + + MiniGPT4 dataset directory: + + minigpt4_dataset + ├── image + │ ├── id0.jpg + │ │── id1.jpg + │ │── id2.jpg + │ └── ... + └── conversation_data.json + + The structure of conversation_data.json: + + [ + // English data + { + "id": str(id0), + "conversation": "###Ask: [Ask content] + ###Answer: [Answer content]" + }, + + // Chinese data + { + "id": str(id1), + "conversation": "###问: [Ask content] + ###答:[Answer content]" + }, + + ... + ] + + Args: + data_root (str): The root directory for ``ann_file`` and ``image``. + ann_file (str): Conversation file path. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def load_data_list(self) -> List[dict]: + file_backend = get_file_backend(self.data_root) + conversation_path = file_backend.join_path(self.data_root, + self.ann_file) + conversation = mmengine.load(conversation_path) + img_ids = {} + n = 0 + for conv in conversation: + img_id = conv['id'] + if img_id not in img_ids.keys(): + img_ids[img_id] = n + n += 1 + + img_root = file_backend.join_path(self.data_root, 'image') + data_list = [] + for conv in conversation: + img_file = '{}.jpg'.format(conv['id']) + chat_content = conv['conversation'] + lang = 'en' if chat_content.startswith('###Ask: ') else 'zh' + data_info = { + 'image_id': img_ids[conv['id']], + 'img_path': file_backend.join_path(img_root, img_file), + 'chat_content': chat_content, + 'lang': lang, + } + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/mnist.py b/mmpretrain/datasets/mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..425267fe8034860d3b78c6af5b565ddb6efc7c10 --- /dev/null +++ b/mmpretrain/datasets/mnist.py @@ -0,0 +1,234 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import codecs +from typing import List, Optional +from urllib.parse import urljoin + +import mmengine.dist as dist +import numpy as np +import torch +from mmengine.fileio import LocalBackend, exists, get_file_backend, join_path +from mmengine.logging import MMLogger + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import FASHIONMNIST_CATEGORITES, MNIST_CATEGORITES +from .utils import (download_and_extract_archive, open_maybe_compressed_file, + rm_suffix) + + +@DATASETS.register_module() +class MNIST(BaseDataset): + """`MNIST `_ Dataset. + + This implementation is modified from + https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py + + Args: + data_root (str): The root directory of the MNIST Dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". + metainfo (dict, optional): Meta information for dataset, such as + categories information. Defaults to None. + download (bool): Whether to download the dataset if not exists. + Defaults to True. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
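+
+    Examples (the sample counts below are the standard MNIST splits; the
+    exact ``repr`` output and the data root are illustrative):
+        >>> from mmpretrain.datasets import MNIST
+        >>> train_dataset = MNIST(data_root='data/mnist', split='train')
+        >>> train_dataset
+        Dataset MNIST
+        Number of samples:  60000
+        Number of categories:  10
+        Prefix of data:  data/mnist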
+ """ # noqa: E501 + + url_prefix = 'http://yann.lecun.com/exdb/mnist/' + # train images and labels + train_list = [ + ['train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'], + ['train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'], + ] + # test images and labels + test_list = [ + ['t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'], + ['t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c'], + ] + METAINFO = {'classes': MNIST_CATEGORITES} + + def __init__(self, + data_root: str = '', + split: str = 'train', + metainfo: Optional[dict] = None, + download: bool = True, + data_prefix: str = '', + test_mode: bool = False, + **kwargs): + + splits = ['train', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + # To handle the BC-breaking + if split == 'train' and test_mode: + logger = MMLogger.get_current_instance() + logger.warning('split="train" but test_mode=True. ' + 'The training set will be used.') + + if not data_root and not data_prefix: + raise RuntimeError('Please set ``data_root`` to' + 'specify the dataset path') + + self.download = download + super().__init__( + # The MNIST dataset doesn't need specify annotation file + ann_file='', + metainfo=metainfo, + data_root=data_root, + data_prefix=dict(root=data_prefix), + test_mode=test_mode, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + root = self.data_prefix['root'] + backend = get_file_backend(root, enable_singleton=True) + + if dist.is_main_process() and not self._check_exists(): + if not isinstance(backend, LocalBackend): + raise RuntimeError(f'The dataset on {root} is not integrated, ' + f'please manually handle it.') + + if self.download: + self._download() + else: + raise RuntimeError( + f'Cannot find {self.__class__.__name__} dataset in ' + f"{self.data_prefix['root']}, you can specify " + '`download=True` to download automatically.') + + dist.barrier() + assert self._check_exists(), \ + 'Download failed or shared storage is unavailable. Please ' \ + f'download the dataset manually through {self.url_prefix}.' + + if not self.test_mode: + file_list = self.train_list + else: + file_list = self.test_list + + # load data from SN3 files + imgs = read_image_file(join_path(root, rm_suffix(file_list[0][0]))) + gt_labels = read_label_file( + join_path(root, rm_suffix(file_list[1][0]))) + + data_infos = [] + for img, gt_label in zip(imgs, gt_labels): + gt_label = np.array(gt_label, dtype=np.int64) + info = {'img': img.numpy(), 'gt_label': gt_label} + data_infos.append(info) + return data_infos + + def _check_exists(self): + """Check the exists of data files.""" + root = self.data_prefix['root'] + + for filename, _ in (self.train_list + self.test_list): + # get extracted filename of data + extract_filename = rm_suffix(filename) + fpath = join_path(root, extract_filename) + if not exists(fpath): + return False + return True + + def _download(self): + """Download and extract data files.""" + root = self.data_prefix['root'] + + for filename, md5 in (self.train_list + self.test_list): + url = urljoin(self.url_prefix, filename) + download_and_extract_archive( + url, download_root=root, filename=filename, md5=md5) + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [f"Prefix of data: \t{self.data_prefix['root']}"] + return body + + +@DATASETS.register_module() +class FashionMNIST(MNIST): + """`Fashion-MNIST `_ + Dataset. 
+ + Args: + data_root (str): The root directory of the MNIST Dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". + metainfo (dict, optional): Meta information for dataset, such as + categories information. Defaults to None. + download (bool): Whether to download the dataset if not exists. + Defaults to True. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + url_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' + # train images and labels + train_list = [ + ['train-images-idx3-ubyte.gz', '8d4fb7e6c68d591d4c3dfef9ec88bf0d'], + ['train-labels-idx1-ubyte.gz', '25c81989df183df01b3e8a0aad5dffbe'], + ] + # test images and labels + test_list = [ + ['t10k-images-idx3-ubyte.gz', 'bef4ecab320f06d8554ea6380940ec79'], + ['t10k-labels-idx1-ubyte.gz', 'bb300cfdad3c16e7a12a480ee83cd310'], + ] + METAINFO = {'classes': FASHIONMNIST_CATEGORITES} + + +def get_int(b: bytes) -> int: + """Convert bytes to int.""" + return int(codecs.encode(b, 'hex'), 16) + + +def read_sn3_pascalvincent_tensor(path: str, + strict: bool = True) -> torch.Tensor: + """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx- + io.lsh'). + + Argument may be a filename, compressed filename, or file object. + """ + # typemap + if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'): + read_sn3_pascalvincent_tensor.typemap = { + 8: (torch.uint8, np.uint8, np.uint8), + 9: (torch.int8, np.int8, np.int8), + 11: (torch.int16, np.dtype('>i2'), 'i2'), + 12: (torch.int32, np.dtype('>i4'), 'i4'), + 13: (torch.float32, np.dtype('>f4'), 'f4'), + 14: (torch.float64, np.dtype('>f8'), 'f8') + } + # read + with open_maybe_compressed_file(path) as f: + data = f.read() + # parse + magic = get_int(data[0:4]) + nd = magic % 256 + ty = magic // 256 + assert nd >= 1 and nd <= 3 + assert ty >= 8 and ty <= 14 + m = read_sn3_pascalvincent_tensor.typemap[ty] + s = [get_int(data[4 * (i + 1):4 * (i + 2)]) for i in range(nd)] + parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1))) + assert parsed.shape[0] == np.prod(s) or not strict + return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s) + + +def read_label_file(path: str) -> torch.Tensor: + """Read labels from SN3 label file.""" + with open(path, 'rb') as f: + x = read_sn3_pascalvincent_tensor(f, strict=False) + assert (x.dtype == torch.uint8) + assert (x.ndimension() == 1) + return x.long() + + +def read_image_file(path: str) -> torch.Tensor: + """Read images from SN3 image file.""" + with open(path, 'rb') as f: + x = read_sn3_pascalvincent_tensor(f, strict=False) + assert (x.dtype == torch.uint8) + assert (x.ndimension() == 3) + return x diff --git a/mmpretrain/datasets/multi_label.py b/mmpretrain/datasets/multi_label.py new file mode 100644 index 0000000000000000000000000000000000000000..58a9c7cd5f097689d29700004e2ed815934a1594 --- /dev/null +++ b/mmpretrain/datasets/multi_label.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class MultiLabelDataset(BaseDataset): + """Multi-label Dataset. + + This dataset support annotation file in `OpenMMLab 2.0 style annotation + format`. + + The annotation format is shown as follows. + + .. code-block:: none + + { + "metainfo": + { + "classes":['A', 'B', 'C'....] 
+ }, + "data_list": + [ + { + "img_path": "test_img1.jpg", + 'gt_label': [0, 1], + }, + { + "img_path": "test_img2.jpg", + 'gt_label': [2], + }, + ] + .... + } + + + Args: + ann_file (str): Annotation file path. + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str | dict): Prefix for training data. Defaults to ''. + filter_cfg (dict, optional): Config for filter data. Defaults to None. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Defaults to None which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. Defaults + to True. + pipeline (list, optional): Processing pipeline. Defaults to []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Defaults to False. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Defaults to False. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Defaults to 1000. + classes (str | Sequence[str], optional): Specify names of classes. + + - If is string, it should be a file path, and the every line of + the file is a name of a class. + - If is a sequence of string, every item is a name of class. + - If is None, use categories information in ``metainfo`` argument, + annotation file or the class attribute ``METAINFO``. + + Defaults to None. + """ + + def get_cat_ids(self, idx: int) -> List[int]: + """Get category ids by index. + + Args: + idx (int): Index of data. + + Returns: + cat_ids (List[int]): Image categories of specified index. + """ + return self.get_data_info(idx)['gt_label'] diff --git a/mmpretrain/datasets/multi_task.py b/mmpretrain/datasets/multi_task.py new file mode 100644 index 0000000000000000000000000000000000000000..443df0e7d7de11962d472d33b25b4bbff562524f --- /dev/null +++ b/mmpretrain/datasets/multi_task.py @@ -0,0 +1,337 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os.path as osp +from os import PathLike +from typing import Optional, Sequence + +import mmengine +from mmcv.transforms import Compose +from mmengine.fileio import get_file_backend + +from .builder import DATASETS + + +def expanduser(path): + if isinstance(path, (str, PathLike)): + return osp.expanduser(path) + else: + return path + + +def isabs(uri): + return osp.isabs(uri) or ('://' in uri) + + +@DATASETS.register_module() +class MultiTaskDataset: + """Custom dataset for multi-task dataset. + + To use the dataset, please generate and provide an annotation file in the + below format: + + .. 
code-block:: json + + { + "metainfo": { + "tasks": + [ + 'gender' + 'wear' + ] + }, + "data_list": [ + { + "img_path": "a.jpg", + gt_label:{ + "gender": 0, + "wear": [1, 0, 1, 0] + } + }, + { + "img_path": "b.jpg", + gt_label:{ + "gender": 1, + "wear": [1, 0, 1, 0] + } + } + ] + } + + Assume we put our dataset in the ``data/mydataset`` folder in the + repository and organize it as the below format: :: + + mmpretrain/ + └── data + └── mydataset + ├── annotation + │   ├── train.json + │   ├── test.json + │   └── val.json + ├── train + │   ├── a.jpg + │   └── ... + ├── test + │   ├── b.jpg + │   └── ... + └── val + ├── c.jpg + └── ... + + We can use the below config to build datasets: + + .. code:: python + + >>> from mmpretrain.datasets import build_dataset + >>> train_cfg = dict( + ... type="MultiTaskDataset", + ... ann_file="annotation/train.json", + ... data_root="data/mydataset", + ... # The `img_path` field in the train annotation file is relative + ... # to the `train` folder. + ... data_prefix='train', + ... ) + >>> train_dataset = build_dataset(train_cfg) + + Or we can put all files in the same folder: :: + + mmpretrain/ + └── data + └── mydataset + ├── train.json + ├── test.json + ├── val.json + ├── a.jpg + ├── b.jpg + ├── c.jpg + └── ... + + And we can use the below config to build datasets: + + .. code:: python + + >>> from mmpretrain.datasets import build_dataset + >>> train_cfg = dict( + ... type="MultiTaskDataset", + ... ann_file="train.json", + ... data_root="data/mydataset", + ... # the `data_prefix` is not required since all paths are + ... # relative to the `data_root`. + ... ) + >>> train_dataset = build_dataset(train_cfg) + + + Args: + ann_file (str): The annotation file path. It can be either absolute + path or relative path to the ``data_root``. + metainfo (dict, optional): The extra meta information. It should be + a dict with the same format as the ``"metainfo"`` field in the + annotation file. Defaults to None. + data_root (str, optional): The root path of the data directory. It's + the prefix of the ``data_prefix`` and the ``ann_file``. And it can + be a remote path like "s3://openmmlab/xxx/". Defaults to None. + data_prefix (str, optional): The base folder relative to the + ``data_root`` for the ``"img_path"`` field in the annotation file. + Defaults to None. + pipeline (Sequence[dict]): A list of dict, where each element + represents a operation defined in + :mod:`mmpretrain.datasets.pipelines`. Defaults to an empty tuple. + test_mode (bool): in train mode or test mode. Defaults to False. + """ + METAINFO = dict() + + def __init__(self, + ann_file: str, + metainfo: Optional[dict] = None, + data_root: Optional[str] = None, + data_prefix: Optional[str] = None, + pipeline: Sequence = (), + test_mode: bool = False): + + self.data_root = expanduser(data_root) + + # Inference the file client + if self.data_root is not None: + self.file_backend = get_file_backend(uri=self.data_root) + else: + self.file_backend = None + + self.ann_file = self._join_root(expanduser(ann_file)) + self.data_prefix = self._join_root(data_prefix) + + self.test_mode = test_mode + self.pipeline = Compose(pipeline) + self.data_list = self.load_data_list(self.ann_file, metainfo) + + def _join_root(self, path): + """Join ``self.data_root`` with the specified path. + + If the path is an absolute path, just return the path. And if the + path is None, return ``self.data_root``. 
+ + Examples: + >>> self.data_root = 'a/b/c' + >>> self._join_root('d/e/') + 'a/b/c/d/e' + >>> self._join_root('https://openmmlab.com') + 'https://openmmlab.com' + >>> self._join_root(None) + 'a/b/c' + """ + if path is None: + return self.data_root + if isabs(path): + return path + + joined_path = self.file_backend.join_path(self.data_root, path) + return joined_path + + @classmethod + def _get_meta_info(cls, in_metainfo: dict = None) -> dict: + """Collect meta information from the dictionary of meta. + + Args: + in_metainfo (dict): Meta information dict. + + Returns: + dict: Parsed meta information. + """ + # `cls.METAINFO` will be overwritten by in_meta + metainfo = copy.deepcopy(cls.METAINFO) + if in_metainfo is None: + return metainfo + + metainfo.update(in_metainfo) + + return metainfo + + def load_data_list(self, ann_file, metainfo_override=None): + """Load annotations from an annotation file. + + Args: + ann_file (str): Absolute annotation file path if ``self.root=None`` + or relative path if ``self.root=/path/to/data/``. + + Returns: + list[dict]: A list of annotation. + """ + annotations = mmengine.load(ann_file) + if not isinstance(annotations, dict): + raise TypeError(f'The annotations loaded from annotation file ' + f'should be a dict, but got {type(annotations)}!') + if 'data_list' not in annotations: + raise ValueError('The annotation file must have the `data_list` ' + 'field.') + metainfo = annotations.get('metainfo', {}) + raw_data_list = annotations['data_list'] + + # Set meta information. + assert isinstance(metainfo, dict), 'The `metainfo` field in the '\ + f'annotation file should be a dict, but got {type(metainfo)}' + if metainfo_override is not None: + assert isinstance(metainfo_override, dict), 'The `metainfo` ' \ + f'argument should be a dict, but got {type(metainfo_override)}' + metainfo.update(metainfo_override) + self._metainfo = self._get_meta_info(metainfo) + + data_list = [] + for i, raw_data in enumerate(raw_data_list): + try: + data_list.append(self.parse_data_info(raw_data)) + except AssertionError as e: + raise RuntimeError( + f'The format check fails during parse the item {i} of ' + f'the annotation file with error: {e}') + return data_list + + def parse_data_info(self, raw_data): + """Parse raw annotation to target format. + + This method will return a dict which contains the data information of a + sample. + + Args: + raw_data (dict): Raw data information load from ``ann_file`` + + Returns: + dict: Parsed annotation. + """ + assert isinstance(raw_data, dict), \ + f'The item should be a dict, but got {type(raw_data)}' + assert 'img_path' in raw_data, \ + "The item doesn't have `img_path` field." + data = dict( + img_path=self._join_root(raw_data['img_path']), + gt_label=raw_data['gt_label'], + ) + return data + + @property + def metainfo(self) -> dict: + """Get meta information of dataset. + + Returns: + dict: meta information collected from ``cls.METAINFO``, + annotation file and metainfo argument during instantiation. + """ + return copy.deepcopy(self._metainfo) + + def prepare_data(self, idx): + """Get data processed by ``self.pipeline``. + + Args: + idx (int): The index of ``data_info``. + + Returns: + Any: Depends on ``self.pipeline``. + """ + results = copy.deepcopy(self.data_list[idx]) + return self.pipeline(results) + + def __len__(self): + """Get the length of the whole dataset. + + Returns: + int: The length of filtered dataset. 
+        """
+        return len(self.data_list)
+
+    def __getitem__(self, idx):
+        """Get the idx-th image and data information of dataset after
+        ``self.pipeline``.
+
+        Args:
+            idx (int): The index of the data.
+
+        Returns:
+            dict: The idx-th image and data information after
+            ``self.pipeline``.
+        """
+        return self.prepare_data(idx)
+
+    def __repr__(self):
+        """Print the basic information of the dataset.
+
+        Returns:
+            str: Formatted string.
+        """
+        head = 'Dataset ' + self.__class__.__name__
+        body = [f'Number of samples: \t{self.__len__()}']
+        if self.data_root is not None:
+            body.append(f'Root location: \t{self.data_root}')
+        body.append(f'Annotation file: \t{self.ann_file}')
+        if self.data_prefix is not None:
+            body.append(f'Prefix of images: \t{self.data_prefix}')
+        # -------------------- extra repr --------------------
+        tasks = self.metainfo['tasks']
+        body.append(f'For {len(tasks)} tasks')
+        for task in tasks:
+            body.append(f'    {task} ')
+        # ----------------------------------------------------
+
+        if len(self.pipeline.transforms) > 0:
+            body.append('With transforms:')
+            for t in self.pipeline.transforms:
+                body.append(f'    {t}')
+
+        lines = [head] + [' ' * 4 + line for line in body]
+        return '\n'.join(lines)
diff --git a/mmpretrain/datasets/nlvr2.py b/mmpretrain/datasets/nlvr2.py
new file mode 100644
index 0000000000000000000000000000000000000000..0063090657714406049a6daa6fa3c0d868422590
--- /dev/null
+++ b/mmpretrain/datasets/nlvr2.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+from typing import List
+
+from mmengine.fileio import get_file_backend, list_from_file
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+
+
+@DATASETS.register_module()
+class NLVR2(BaseDataset):
+    """NLVR2 dataset for natural language visual reasoning on image pairs."""
+
+    def load_data_list(self) -> List[dict]:
+        """Load data list."""
+
+        data_list = []
+        img_prefix = self.data_prefix['img_path']
+        file_backend = get_file_backend(img_prefix)
+        examples = list_from_file(self.ann_file)
+
+        for example in examples:
+            example = json.loads(example)
+            prefix = example['identifier'].rsplit('-', 1)[0]
+            train_data = {}
+            train_data['text'] = example['sentence']
+            train_data['gt_label'] = {'True': 1, 'False': 0}[example['label']]
+            train_data['img_path'] = [
+                file_backend.join_path(img_prefix, prefix + f'-img{i}.png')
+                for i in range(2)
+            ]
+
+            data_list.append(train_data)
+
+        return data_list
diff --git a/mmpretrain/datasets/nocaps.py b/mmpretrain/datasets/nocaps.py
new file mode 100644
index 0000000000000000000000000000000000000000..65116e9cecc2d9983ef72ca3eee24ff7baedacc0
--- /dev/null
+++ b/mmpretrain/datasets/nocaps.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+from mmengine.fileio import get_file_backend
+from pycocotools.coco import COCO
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class NoCaps(BaseDataset):
+    """NoCaps dataset.
+
+    Args:
+        data_root (str): The root directory for ``data_prefix`` and
+            ``ann_file``.
+        ann_file (str): Annotation file path.
+        data_prefix (dict): Prefix for data field. Defaults to
+            ``dict(img_path='')``.
+        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
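+
+    A usage sketch (paths are illustrative; ``nocaps_val_4500_captions.json``
+    refers to the official COCO-style validation annotation):
+        >>> from mmpretrain.datasets import NoCaps
+        >>> val_dataset = NoCaps(
+        ...     data_root='data/nocaps',
+        ...     data_prefix=dict(img_path='images'),
+        ...     ann_file='annotations/nocaps_val_4500_captions.json')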
+ """ + + def load_data_list(self) -> List[dict]: + """Load data list.""" + img_prefix = self.data_prefix['img_path'] + with mmengine.get_local_path(self.ann_file) as ann_file: + coco = COCO(ann_file) + + file_backend = get_file_backend(img_prefix) + data_list = [] + for ann in coco.anns.values(): + image_id = ann['image_id'] + image_path = file_backend.join_path( + img_prefix, coco.imgs[image_id]['file_name']) + data_info = { + 'image_id': image_id, + 'img_path': image_path, + 'gt_caption': None + } + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/ocr_vqa.py b/mmpretrain/datasets/ocr_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..55aa6913e3c4464444e8b971ccabf68aa2d99904 --- /dev/null +++ b/mmpretrain/datasets/ocr_vqa.py @@ -0,0 +1,91 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class OCRVQA(BaseDataset): + """OCR-VQA dataset. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + ann_file (str): Annotation file path for training and validation. + split (str): 'train', 'val' or 'test'. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, data_root: str, data_prefix: str, ann_file: str, + split: str, **kwarg): + + assert split in ['train', 'val', 'test'], \ + '`split` must be train, val or test' + self.split = split + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + + split_dict = {1: 'train', 2: 'val', 3: 'test'} + + annotations = mmengine.load(self.ann_file) + + # ann example + # "761183272": { + # "imageURL": \ + # "http://ecx.images-amazon.com/images/I/61Y5cOdHJbL.jpg", + # "questions": [ + # "Who wrote this book?", + # "What is the title of this book?", + # "What is the genre of this book?", + # "Is this a games related book?", + # "What is the year printed on this calendar?"], + # "answers": [ + # "Sandra Boynton", + # "Mom's Family Wall Calendar 2016", + # "Calendars", + # "No", + # "2016"], + # "title": "Mom's Family Wall Calendar 2016", + # "authorName": "Sandra Boynton", + # "genre": "Calendars", + # "split": 1 + # }, + + data_list = [] + + for key, ann in annotations.items(): + if self.split != split_dict[ann['split']]: + continue + + extension = osp.splitext(ann['imageURL'])[1] + if extension not in ['.jpg', '.png']: + continue + img_path = mmengine.join_path(self.data_prefix['img_path'], + key + extension) + for question, answer in zip(ann['questions'], ann['answers']): + data_info = {} + data_info['img_path'] = img_path + data_info['question'] = question + data_info['gt_answer'] = answer + data_info['gt_answer_weight'] = [1.0] + + data_info['imageURL'] = ann['imageURL'] + data_info['title'] = ann['title'] + data_info['authorName'] = ann['authorName'] + data_info['genre'] = ann['genre'] + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/oxfordiiitpet.py b/mmpretrain/datasets/oxfordiiitpet.py new file mode 100644 index 0000000000000000000000000000000000000000..23c8b7db8679e99c6ed2698b9eb140cd6151d445 --- /dev/null +++ b/mmpretrain/datasets/oxfordiiitpet.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. 
All rights reserved.
+from typing import List
+
+from mmengine import get_file_backend, list_from_file
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+from .categories import OxfordIIITPet_CATEGORIES
+
+
+@DATASETS.register_module()
+class OxfordIIITPet(BaseDataset):
+    """The Oxford-IIIT Pets Dataset.
+
+    Support the Oxford-IIIT Pets Dataset.
+    After downloading and decompression, the dataset directory structure is as follows.
+
+    Oxford-IIIT_Pets dataset directory: ::
+
+        Oxford-IIIT_Pets
+        ├── images
+        │   ├── Abyssinian_1.jpg
+        │   ├── Abyssinian_2.jpg
+        │   └── ...
+        ├── annotations
+        │   ├── trainval.txt
+        │   ├── test.txt
+        │   ├── list.txt
+        │   └── ...
+        └── ....
+
+    Args:
+        data_root (str): The root directory for the Oxford-IIIT Pets dataset.
+        split (str, optional): The dataset split, supports "trainval" and
+            "test". Defaults to "trainval".
+
+    Examples:
+        >>> from mmpretrain.datasets import OxfordIIITPet
+        >>> train_dataset = OxfordIIITPet(data_root='data/Oxford-IIIT_Pets', split='trainval')
+        >>> train_dataset
+        Dataset OxfordIIITPet
+            Number of samples:  3680
+            Number of categories:       37
+            Root of dataset:    data/Oxford-IIIT_Pets
+        >>> test_dataset = OxfordIIITPet(data_root='data/Oxford-IIIT_Pets', split='test')
+        >>> test_dataset
+        Dataset OxfordIIITPet
+            Number of samples:  3669
+            Number of categories:       37
+            Root of dataset:    data/Oxford-IIIT_Pets
+    """  # noqa: E501
+
+    METAINFO = {'classes': OxfordIIITPet_CATEGORIES}
+
+    def __init__(self, data_root: str, split: str = 'trainval', **kwargs):
+
+        splits = ['trainval', 'test']
+        assert split in splits, \
+            f"The split must be one of {splits}, but got '{split}'"
+        self.split = split
+
+        self.backend = get_file_backend(data_root, enable_singleton=True)
+        if split == 'trainval':
+            ann_file = self.backend.join_path('annotations', 'trainval.txt')
+        else:
+            ann_file = self.backend.join_path('annotations', 'test.txt')
+
+        data_prefix = 'images'
+        test_mode = split == 'test'
+
+        super(OxfordIIITPet, self).__init__(
+            ann_file=ann_file,
+            data_root=data_root,
+            data_prefix=data_prefix,
+            test_mode=test_mode,
+            **kwargs)
+
+    def load_data_list(self):
+        """Load images and ground truth labels."""
+
+        pairs = list_from_file(self.ann_file)
+        data_list = []
+        for pair in pairs:
+            img_name, class_id, _, _ = pair.split()
+            img_name = f'{img_name}.jpg'
+            img_path = self.backend.join_path(self.img_prefix, img_name)
+            gt_label = int(class_id) - 1
+            info = dict(img_path=img_path, gt_label=gt_label)
+            data_list.append(info)
+        return data_list
+
+    def extra_repr(self) -> List[str]:
+        """The extra repr information of the dataset."""
+        body = [
+            f'Root of dataset: \t{self.data_root}',
+        ]
+        return body
diff --git a/mmpretrain/datasets/places205.py b/mmpretrain/datasets/places205.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3ba1ff631a7a4840b66cf63ec53585ec064560d
--- /dev/null
+++ b/mmpretrain/datasets/places205.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Union
+
+from mmpretrain.registry import DATASETS
+from .categories import PLACES205_CATEGORIES
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class Places205(CustomDataset):
+    """The Places205 Dataset.
+
+    Args:
+        data_root (str): The root directory for ``data_prefix`` and
+            ``ann_file``. Defaults to ''.
+        data_prefix (str | dict): Prefix for training data. Defaults
+            to ''.
+        ann_file (str): Annotation file path. Defaults to ''.
+ metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + **kwargs: Other keyword arguments in :class:`CustomDataset` and + :class:`BaseDataset`. + """ + + IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif') + METAINFO = {'classes': PLACES205_CATEGORIES} + + def __init__(self, + data_root: str = '', + data_prefix: Union[str, dict] = '', + ann_file: str = '', + metainfo: Optional[dict] = None, + **kwargs): + kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs} + super().__init__( + data_root=data_root, + data_prefix=data_prefix, + ann_file=ann_file, + metainfo=metainfo, + **kwargs) diff --git a/mmpretrain/datasets/refcoco.py b/mmpretrain/datasets/refcoco.py new file mode 100644 index 0000000000000000000000000000000000000000..39c3d3e65e5ffdcb5a49fc183473138cfba8938a --- /dev/null +++ b/mmpretrain/datasets/refcoco.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import List + +import mmengine +import numpy as np +from mmengine.dataset import BaseDataset +from pycocotools.coco import COCO + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class RefCOCO(BaseDataset): + """RefCOCO dataset. + + RefCOCO is a popular dataset used for the task of visual grounding. + Here are the steps for accessing and utilizing the + RefCOCO dataset. + + You can access the RefCOCO dataset from the official source: + https://github.com/lichengunc/refer + + The RefCOCO dataset is organized in a structured format: :: + + FeaturesDict({ + 'coco_annotations': Sequence({ + 'area': int64, + 'bbox': BBoxFeature(shape=(4,), dtype=float32), + 'id': int64, + 'label': int64, + }), + 'image': Image(shape=(None, None, 3), dtype=uint8), + 'image/id': int64, + 'objects': Sequence({ + 'area': int64, + 'bbox': BBoxFeature(shape=(4,), dtype=float32), + 'gt_box_index': int64, + 'id': int64, + 'label': int64, + 'refexp': Sequence({ + 'raw': Text(shape=(), dtype=string), + 'refexp_id': int64, + }), + }), + }) + + Args: + ann_file (str): Annotation file path. + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str): Prefix for training data. + pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
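+        split_file (str): The split file path, which contains the sample
+            ids of each split (loaded with ``mmengine.load``).
+        split (str): The split of dataset. Defaults to 'train'.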
+ """ + + def __init__(self, + data_root, + ann_file, + data_prefix, + split_file, + split='train', + **kwargs): + self.split_file = split_file + self.split = split + + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwargs, + ) + + def _join_prefix(self): + if not mmengine.is_abs(self.split_file) and self.split_file: + self.split_file = osp.join(self.data_root, self.split_file) + + return super()._join_prefix() + + def load_data_list(self) -> List[dict]: + """Load data list.""" + with mmengine.get_local_path(self.ann_file) as ann_file: + coco = COCO(ann_file) + splits = mmengine.load(self.split_file, file_format='pkl') + img_prefix = self.data_prefix['img_path'] + + data_list = [] + join_path = mmengine.fileio.get_file_backend(img_prefix).join_path + for refer in splits: + if refer['split'] != self.split: + continue + + ann = coco.anns[refer['ann_id']] + img = coco.imgs[ann['image_id']] + sentences = refer['sentences'] + bbox = np.array(ann['bbox'], dtype=np.float32) + bbox[2:4] = bbox[0:2] + bbox[2:4] # XYWH -> XYXY + + for sent in sentences: + data_info = { + 'img_path': join_path(img_prefix, img['file_name']), + 'image_id': ann['image_id'], + 'ann_id': ann['id'], + 'text': sent['sent'], + 'gt_bboxes': bbox[None, :], + } + data_list.append(data_info) + + if len(data_list) == 0: + raise ValueError(f'No sample in split "{self.split}".') + + return data_list diff --git a/mmpretrain/datasets/samplers/__init__.py b/mmpretrain/datasets/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bccf9c34659e19764871a696260cf5884696ca1 --- /dev/null +++ b/mmpretrain/datasets/samplers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .repeat_aug import RepeatAugSampler +from .sequential import SequentialSampler + +__all__ = ['RepeatAugSampler', 'SequentialSampler'] diff --git a/mmpretrain/datasets/samplers/repeat_aug.py b/mmpretrain/datasets/samplers/repeat_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..d833a1954d7d9d181c368d5b3b956c25df241c1a --- /dev/null +++ b/mmpretrain/datasets/samplers/repeat_aug.py @@ -0,0 +1,101 @@ +import math +from typing import Iterator, Optional, Sized + +import torch +from mmengine.dist import get_dist_info, is_main_process, sync_random_seed +from torch.utils.data import Sampler + +from mmpretrain.registry import DATA_SAMPLERS + + +@DATA_SAMPLERS.register_module() +class RepeatAugSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset for + distributed, with repeated augmentation. It ensures that different each + augmented version of a sample will be visible to a different process (GPU). + Heavily based on torch.utils.data.DistributedSampler. + + This sampler was taken from + https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py + Used in + Copyright (c) 2015-present, Facebook, Inc. + + Args: + dataset (Sized): The dataset. + shuffle (bool): Whether shuffle the dataset or not. Defaults to True. + num_repeats (int): The repeat times of every sample. Defaults to 3. + seed (int, optional): Random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Defaults to None. 
+ """ + + def __init__(self, + dataset: Sized, + shuffle: bool = True, + num_repeats: int = 3, + seed: Optional[int] = None): + rank, world_size = get_dist_info() + self.rank = rank + self.world_size = world_size + + self.dataset = dataset + self.shuffle = shuffle + if not self.shuffle and is_main_process(): + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.warning('The RepeatAugSampler always picks a ' + 'fixed part of data if `shuffle=False`.') + + if seed is None: + seed = sync_random_seed() + self.seed = seed + self.epoch = 0 + self.num_repeats = num_repeats + + # The number of repeated samples in the rank + self.num_samples = math.ceil( + len(self.dataset) * num_repeats / world_size) + # The total number of repeated samples in all ranks. + self.total_size = self.num_samples * world_size + # The number of selected samples in the rank + self.num_selected_samples = math.ceil(len(self.dataset) / world_size) + + def __iter__(self) -> Iterator[int]: + """Iterate the indices.""" + # deterministically shuffle based on epoch and seed + if self.shuffle: + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = list(range(len(self.dataset))) + + # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....] + indices = [x for x in indices for _ in range(self.num_repeats)] + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + indices += indices[:padding_size] + assert len(indices) == self.total_size + + # subsample per rank + indices = indices[self.rank:self.total_size:self.world_size] + assert len(indices) == self.num_samples + + # return up to num selected samples + return iter(indices[:self.num_selected_samples]) + + def __len__(self) -> int: + """The number of samples in this rank.""" + return self.num_selected_samples + + def set_epoch(self, epoch: int) -> None: + """Sets the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas use a different + random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch diff --git a/mmpretrain/datasets/samplers/sequential.py b/mmpretrain/datasets/samplers/sequential.py new file mode 100644 index 0000000000000000000000000000000000000000..e3b940c2eabc2ab9c2401cd1923776fc067e9f6c --- /dev/null +++ b/mmpretrain/datasets/samplers/sequential.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Iterator + +import torch +from mmengine.dataset import DefaultSampler + +from mmpretrain.registry import DATA_SAMPLERS + + +@DATA_SAMPLERS.register_module() +class SequentialSampler(DefaultSampler): + """Sequential sampler which supports different subsample policy. + + Args: + dataset (Sized): The dataset. + round_up (bool): Whether to add extra samples to make the number of + samples evenly divisible by the world size. Defaults to True. + subsample_type (str): The method to subsample data on different rank. + Supported type: + + - ``'default'``: Original torch behavior. Sample the examples one + by one for each GPU in terms. For instance, 8 examples on 2 GPUs, + GPU0: [0,2,4,8], GPU1: [1,3,5,7] + - ``'sequential'``: Subsample all examples to n chunk sequntially. 
+              For instance, 8 examples on 2 GPUs,
+              GPU0: [0,1,2,3], GPU1: [4,5,6,7]
+    """
+
+    def __init__(self, subsample_type: str = 'default', **kwargs) -> None:
+        super().__init__(shuffle=False, **kwargs)
+
+        if subsample_type not in ['default', 'sequential']:
+            raise ValueError(f'Unsupported subsample type "{subsample_type}",'
+                             ' please choose from ["default", "sequential"]')
+        self.subsample_type = subsample_type
+
+    def __iter__(self) -> Iterator[int]:
+        """Iterate the indices."""
+        indices = torch.arange(len(self.dataset)).tolist()
+
+        # add extra samples to make it evenly divisible
+        if self.round_up:
+            indices = (
+                indices *
+                int(self.total_size / len(indices) + 1))[:self.total_size]
+
+        # subsample
+        if self.subsample_type == 'default':
+            indices = indices[self.rank:self.total_size:self.world_size]
+        elif self.subsample_type == 'sequential':
+            num_samples_per_rank = self.total_size // self.world_size
+            indices = indices[self.rank *
+                              num_samples_per_rank:(self.rank + 1) *
+                              num_samples_per_rank]
+
+        return iter(indices)
diff --git a/mmpretrain/datasets/scienceqa.py b/mmpretrain/datasets/scienceqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e442491be85540980c0309b65d32a12c9c85542
--- /dev/null
+++ b/mmpretrain/datasets/scienceqa.py
@@ -0,0 +1,109 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from typing import Callable, List, Sequence
+
+import mmengine
+from mmengine.dataset import BaseDataset
+from mmengine.fileio import get_file_backend
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class ScienceQA(BaseDataset):
+    """ScienceQA dataset.
+
+    This dataset is used to load the multimodal data of the ScienceQA
+    dataset.
+
+    Args:
+        data_root (str): The root directory for ``data_prefix`` and
+            ``ann_file``.
+        split (str): The split of dataset. Options: ``train``, ``val``,
+            ``test``, ``trainval``, ``minival``, and ``minitest``.
+        split_file (str): The split file of dataset, which contains the
+            ids of data samples in the split.
+        ann_file (str): Annotation file path.
+        image_only (bool): Whether only to load data with image. Defaults to
+            False.
+        data_prefix (dict): Prefix for data field. Defaults to
+            ``dict(img_path='')``.
+        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
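+
+    Examples:
+        >>> # A hypothetical local layout; adjust the file names to match
+        >>> # your ScienceQA download.
+        >>> dataset = ScienceQA(
+        ...     data_root='data/scienceqa',
+        ...     split='val',
+        ...     split_file='pid_splits.json',
+        ...     ann_file='problems.json',
+        ...     data_prefix=dict(img_path='val'))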
+ """ + + def __init__(self, + data_root: str, + split: str, + split_file: str, + ann_file: str, + image_only: bool = False, + data_prefix: dict = dict(img_path=''), + pipeline: Sequence[Callable] = (), + **kwargs): + assert split in [ + 'train', 'val', 'test', 'trainval', 'minival', 'minitest' + ], f'Invalid split {split}' + self.split = split + self.split_file = os.path.join(data_root, split_file) + self.image_only = image_only + + super().__init__( + data_root=data_root, + ann_file=ann_file, + data_prefix=data_prefix, + pipeline=pipeline, + **kwargs) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + img_prefix = self.data_prefix['img_path'] + annotations = mmengine.load(self.ann_file) + current_data_split = mmengine.load(self.split_file)[self.split] # noqa + + file_backend = get_file_backend(img_prefix) + + data_list = [] + for data_id in current_data_split: + ann = annotations[data_id] + if self.image_only and ann['image'] is None: + continue + data_info = { + 'image_id': + data_id, + 'question': + ann['question'], + 'choices': + ann['choices'], + 'gt_answer': + ann['answer'], + 'hint': + ann['hint'], + 'image_name': + ann['image'], + 'task': + ann['task'], + 'grade': + ann['grade'], + 'subject': + ann['subject'], + 'topic': + ann['topic'], + 'category': + ann['category'], + 'skill': + ann['skill'], + 'lecture': + ann['lecture'], + 'solution': + ann['solution'], + 'split': + ann['split'], + 'img_path': + file_backend.join_path(img_prefix, data_id, ann['image']) + if ann['image'] is not None else None, + 'has_image': + True if ann['image'] is not None else False, + } + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/stanfordcars.py b/mmpretrain/datasets/stanfordcars.py new file mode 100644 index 0000000000000000000000000000000000000000..355697943cf693869f35f2a0bd71abdfa0396722 --- /dev/null +++ b/mmpretrain/datasets/stanfordcars.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mat4py +from mmengine import get_file_backend + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import STANFORDCARS_CATEGORIES + + +@DATASETS.register_module() +class StanfordCars(BaseDataset): + """The Stanford Cars Dataset. + + Support the `Stanford Cars Dataset `_ Dataset. + The official website provides two ways to organize the dataset. + Therefore, after downloading and decompression, the dataset directory structure is as follows. + + Stanford Cars dataset directory: :: + + Stanford_Cars + ├── car_ims + │ ├── 00001.jpg + │ ├── 00002.jpg + │ └── ... + └── cars_annos.mat + + or :: + + Stanford_Cars + ├── cars_train + │ ├── 00001.jpg + │ ├── 00002.jpg + │ └── ... + ├── cars_test + │ ├── 00001.jpg + │ ├── 00002.jpg + │ └── ... + └── devkit + ├── cars_meta.mat + ├── cars_train_annos.mat + ├── cars_test_annos.mat + ├── cars_test_annoswithlabels.mat + ├── eval_train.m + └── train_perfect_preds.txt + + Args: + data_root (str): The root directory for Stanford Cars dataset. + split (str, optional): The dataset split, supports "train" + and "test". Default to "train". 
+
+    Examples:
+        >>> from mmpretrain.datasets import StanfordCars
+        >>> train_dataset = StanfordCars(data_root='data/Stanford_Cars', split='train')
+        >>> train_dataset
+        Dataset StanfordCars
+            Number of samples:  8144
+            Number of categories:       196
+            Root of dataset:    data/Stanford_Cars
+        >>> test_dataset = StanfordCars(data_root='data/Stanford_Cars', split='test')
+        >>> test_dataset
+        Dataset StanfordCars
+            Number of samples:  8041
+            Number of categories:       196
+            Root of dataset:    data/Stanford_Cars
+    """  # noqa: E501
+
+    METAINFO = {'classes': STANFORDCARS_CATEGORIES}
+
+    def __init__(self, data_root: str, split: str = 'train', **kwargs):
+
+        splits = ['train', 'test']
+        assert split in splits, \
+            f"The split must be one of {splits}, but got '{split}'"
+        self.split = split
+
+        test_mode = split == 'test'
+        self.backend = get_file_backend(data_root, enable_singleton=True)
+
+        anno_file_path = self.backend.join_path(data_root, 'cars_annos.mat')
+        if self.backend.exists(anno_file_path):
+            ann_file = 'cars_annos.mat'
+            data_prefix = ''
+        else:
+            if test_mode:
+                ann_file = self.backend.join_path(
+                    'devkit', 'cars_test_annos_withlabels.mat')
+                data_prefix = 'cars_test'
+            else:
+                ann_file = self.backend.join_path('devkit',
+                                                  'cars_train_annos.mat')
+                data_prefix = 'cars_train'
+
+            if not self.backend.exists(
+                    self.backend.join_path(data_root, ann_file)):
+                doc_url = 'https://mmpretrain.readthedocs.io/en/latest/api/datasets.html#stanfordcars'  # noqa: E501
+                raise RuntimeError(
+                    f'The dataset is incorrectly organized, please refer to '
+                    f'{doc_url} and reorganize your folders.')
+
+        super(StanfordCars, self).__init__(
+            ann_file=ann_file,
+            data_root=data_root,
+            data_prefix=data_prefix,
+            test_mode=test_mode,
+            **kwargs)
+
+    def load_data_list(self):
+        data = mat4py.loadmat(self.ann_file)['annotations']
+
+        data_list = []
+        if 'test' in data.keys():
+            # first way
+            img_paths, labels, test = data['relative_im_path'], data[
+                'class'], data['test']
+            num = len(img_paths)
+            assert num == len(labels) == len(test), 'Invalid annotation file'
+            for i in range(num):
+                if not self.test_mode and test[i] == 1:
+                    continue
+                if self.test_mode and test[i] == 0:
+                    continue
+                img_path = self.backend.join_path(self.img_prefix,
+                                                  img_paths[i])
+                gt_label = labels[i] - 1
+                info = dict(img_path=img_path, gt_label=gt_label)
+                data_list.append(info)
+        else:
+            # second way
+            img_names, labels = data['fname'], data['class']
+            num = len(img_names)
+            assert num == len(labels), 'Invalid annotation file'
+            for i in range(num):
+                img_path = self.backend.join_path(self.img_prefix,
+                                                  img_names[i])
+                gt_label = labels[i] - 1
+                info = dict(img_path=img_path, gt_label=gt_label)
+                data_list.append(info)
+
+        return data_list
+
+    def extra_repr(self) -> List[str]:
+        """The extra repr information of the dataset."""
+        body = [
+            f'Root of dataset: \t{self.data_root}',
+        ]
+        return body
diff --git a/mmpretrain/datasets/sun397.py b/mmpretrain/datasets/sun397.py
new file mode 100644
index 0000000000000000000000000000000000000000..1039a0690f8096082d5c55f89d743478fdf5b22d
--- /dev/null
+++ b/mmpretrain/datasets/sun397.py
@@ -0,0 +1,125 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+from mmengine import get_file_backend, list_from_file
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+from .categories import SUN397_CATEGORIES
+
+
+@DATASETS.register_module()
+class SUN397(BaseDataset):
+    """The SUN397 Dataset.
+
+    Support the SUN397 Dataset.
+    After downloading and decompression, the dataset directory structure is as follows.
+
+    SUN397 dataset directory: ::
+
+        SUN397
+        ├── SUN397
+        │   ├── a
+        │   │   ├── abbey
+        │   │   │   ├── sun_aaalbzqrimafwbiv.jpg
+        │   │   │   └── ...
+        │   │   ├── airplane_cabin
+        │   │   │   ├── sun_aadqdkqaslqqoblu.jpg
+        │   │   │   └── ...
+        │   │   └── ...
+        │   ├── b
+        │   │   └── ...
+        │   ├── c
+        │   │   └── ...
+        │   └── ...
+        └── Partitions
+            ├── ClassName.txt
+            ├── Training_01.txt
+            ├── Testing_01.txt
+            └── ...
+
+    Args:
+        data_root (str): The root directory for the SUN397 dataset.
+        split (str, optional): The dataset split, supports "train" and "test".
+            Defaults to "train".
+
+    Examples:
+        >>> from mmpretrain.datasets import SUN397
+        >>> train_dataset = SUN397(data_root='data/SUN397', split='train')
+        >>> train_dataset
+        Dataset SUN397
+            Number of samples:  19850
+            Number of categories:       397
+            Root of dataset:    data/SUN397
+        >>> test_dataset = SUN397(data_root='data/SUN397', split='test')
+        >>> test_dataset
+        Dataset SUN397
+            Number of samples:  19850
+            Number of categories:       397
+            Root of dataset:    data/SUN397
+
+    **Note that some images are not JPEG files although their names end with
+    ".jpg". The backend of SUN397 should be "pillow" as below to read these
+    images properly:**
+
+    .. code-block:: python
+
+        pipeline = [
+            dict(type='LoadImageFromFile', imdecode_backend='pillow'),
+            dict(type='RandomResizedCrop', scale=224),
+            dict(type='PackInputs')
+        ]
+    """  # noqa: E501
+
+    METAINFO = {'classes': SUN397_CATEGORIES}
+
+    def __init__(self, data_root: str, split: str = 'train', **kwargs):
+
+        splits = ['train', 'test']
+        assert split in splits, \
+            f"The split must be one of {splits}, but got '{split}'"
+        self.split = split
+
+        self.backend = get_file_backend(data_root, enable_singleton=True)
+        if split == 'train':
+            ann_file = self.backend.join_path('Partitions', 'Training_01.txt')
+        else:
+            ann_file = self.backend.join_path('Partitions', 'Testing_01.txt')
+
+        data_prefix = 'SUN397'
+        test_mode = split == 'test'
+
+        super(SUN397, self).__init__(
+            ann_file=ann_file,
+            data_root=data_root,
+            test_mode=test_mode,
+            data_prefix=data_prefix,
+            **kwargs)
+
+    def load_data_list(self):
+        pairs = list_from_file(self.ann_file)
+        data_list = []
+        for pair in pairs:
+            img_path = self.backend.join_path(self.img_prefix, pair[1:])
+            items = pair.split('/')
+            class_name = '_'.join(items[2:-1])
+            gt_label = self.METAINFO['classes'].index(class_name)
+            info = dict(img_path=img_path, gt_label=gt_label)
+            data_list.append(info)
+
+        return data_list
+
+    def __getitem__(self, idx: int) -> dict:
+        try:
+            return super().__getitem__(idx)
+        except AttributeError:
+            raise RuntimeError(
+                'Some images in the SUN397 dataset are not JPEG files '
+                'although their names end with ".jpg". The backend of '
+                'SUN397 should be "pillow" to read these images properly.')
+
+    def extra_repr(self) -> List[str]:
+        """The extra repr information of the dataset."""
+        body = [
+            f'Root of dataset: \t{self.data_root}',
+        ]
+        return body
diff --git a/mmpretrain/datasets/textvqa.py b/mmpretrain/datasets/textvqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..48a82b45ef1a4cc0bad2ab45b32b8ba8d28b2a60
--- /dev/null
+++ b/mmpretrain/datasets/textvqa.py
@@ -0,0 +1,105 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import Counter
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class TextVQA(BaseDataset):
+    """TextVQA dataset.
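+
+    TextVQA requires models to read and reason about text in images to
+    answer questions about them.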
+
+    val image:
+    https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
+    test image:
+    https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip
+    val json:
+    https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
+    test json:
+    https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_test.json
+
+    folder structure:
+
+        data/textvqa
+        ├── annotations
+        │   ├── TextVQA_0.5.1_test.json
+        │   └── TextVQA_0.5.1_val.json
+        └── images
+            ├── test_images
+            └── train_images
+
+    Args:
+        data_root (str): The root directory for ``data_prefix`` and
+            ``ann_file``.
+        data_prefix (str): The directory of images.
+        ann_file (str, optional): Annotation file path for training and
+            validation. Defaults to an empty string.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
+    """
+
+    def __init__(self,
+                 data_root: str,
+                 data_prefix: str,
+                 ann_file: str = '',
+                 **kwargs):
+        super().__init__(
+            data_root=data_root,
+            data_prefix=dict(img_path=data_prefix),
+            ann_file=ann_file,
+            **kwargs,
+        )
+
+    def load_data_list(self) -> List[dict]:
+        """Load data list."""
+        annotations = mmengine.load(self.ann_file)['data']
+
+        data_list = []
+
+        for ann in annotations:
+
+            # ann example
+            # {
+            #     'question': 'what is the brand of...is camera?',
+            #     'image_id': '003a8ae2ef43b901',
+            #     'image_classes': [
+            #         'Cassette deck', 'Printer', ...
+            #     ],
+            #     'flickr_original_url': 'https://farm2.static...04a6_o.jpg',
+            #     'flickr_300k_url': 'https://farm2.static...04a6_o.jpg',
+            #     'image_width': 1024,
+            #     'image_height': 664,
+            #     'answers': [
+            #         'nous les gosses',
+            #         'dakota',
+            #         'clos culombu',
+            #         'dakota digital' ...
+            #     ],
+            #     'question_tokens':
+            #         ['what', 'is', 'the', 'brand', 'of', 'this', 'camera'],
+            #     'question_id': 34602,
+            #     'set_name': 'val'
+            # }
+
+            data_info = dict(question=ann['question'])
+            data_info['question_id'] = ann['question_id']
+            data_info['image_id'] = ann['image_id']
+
+            img_path = mmengine.join_path(self.data_prefix['img_path'],
+                                          ann['image_id'] + '.jpg')
+            data_info['img_path'] = img_path
+
+            if 'answers' in ann:
+                answers = ann.pop('answers')
+                count = Counter(answers)
+                answer_weight = [i / len(answers) for i in count.values()]
+                data_info['gt_answer'] = list(count.keys())
+                data_info['gt_answer_weight'] = answer_weight
+
+            data_list.append(data_info)
+
+        return data_list
diff --git a/mmpretrain/datasets/transforms/__init__.py b/mmpretrain/datasets/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..617503f26e8c0ec9e1b48b952df2e22a5a5b522d
--- /dev/null
+++ b/mmpretrain/datasets/transforms/__init__.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
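+# The mmcv transforms imported below are re-registered into mmpretrain's
+# TRANSFORMS registry (see the loop near the bottom of this file), so configs
+# can refer to them by name.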
+from mmcv.transforms import (CenterCrop, LoadImageFromFile, Normalize, + RandomFlip, RandomGrayscale, RandomResize, Resize) + +from mmpretrain.registry import TRANSFORMS +from .auto_augment import (AutoAugment, AutoContrast, BaseAugTransform, + Brightness, ColorTransform, Contrast, Cutout, + Equalize, GaussianBlur, Invert, Posterize, + RandAugment, Rotate, Sharpness, Shear, Solarize, + SolarizeAdd, Translate) +from .formatting import (Collect, NumpyToPIL, PackInputs, PackMultiTaskInputs, + PILToNumpy, Transpose) +from .processing import (Albumentations, BEiTMaskGenerator, CleanCaption, + ColorJitter, EfficientNetCenterCrop, + EfficientNetRandomCrop, Lighting, + MAERandomResizedCrop, RandomCrop, RandomErasing, + RandomResizedCrop, + RandomResizedCropAndInterpolationWithTwoPic, + RandomTranslatePad, ResizeEdge, SimMIMMaskGenerator) +from .utils import get_transform_idx, remove_transform +from .wrappers import ApplyToList, MultiView + +for t in (CenterCrop, LoadImageFromFile, Normalize, RandomFlip, + RandomGrayscale, RandomResize, Resize): + TRANSFORMS.register_module(module=t) + +__all__ = [ + 'NumpyToPIL', 'PILToNumpy', 'Transpose', 'Collect', 'RandomCrop', + 'RandomResizedCrop', 'Shear', 'Translate', 'Rotate', 'Invert', + 'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize', + 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd', + 'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', + 'PackInputs', 'Albumentations', 'EfficientNetRandomCrop', + 'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform', + 'PackMultiTaskInputs', 'GaussianBlur', 'BEiTMaskGenerator', + 'SimMIMMaskGenerator', 'CenterCrop', 'LoadImageFromFile', 'Normalize', + 'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView', + 'ApplyToList', 'CleanCaption', 'RandomTranslatePad', + 'RandomResizedCropAndInterpolationWithTwoPic', 'get_transform_idx', + 'remove_transform', 'MAERandomResizedCrop' +] diff --git a/mmpretrain/datasets/transforms/auto_augment.py b/mmpretrain/datasets/transforms/auto_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..4705d5ec04e38aa8286904a1e02b0bc56d79f09e --- /dev/null +++ b/mmpretrain/datasets/transforms/auto_augment.py @@ -0,0 +1,1244 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +from copy import deepcopy +from math import ceil +from numbers import Number +from typing import List, Optional, Sequence, Tuple, Union + +import mmcv +import numpy as np +from mmcv.transforms import BaseTransform, Compose, RandomChoice +from mmcv.transforms.utils import cache_randomness +from mmengine.utils import is_list_of, is_seq_of +from PIL import Image, ImageFilter + +from mmpretrain.registry import TRANSFORMS + + +def merge_hparams(policy: dict, hparams: dict) -> dict: + """Merge hyperparameters into policy config. + + Only merge partial hyperparameters required of the policy. + + Args: + policy (dict): Original policy config dict. + hparams (dict): Hyperparameters need to be merged. + + Returns: + dict: Policy config dict after adding ``hparams``. + """ + policy = deepcopy(policy) + op = TRANSFORMS.get(policy['type']) + assert op is not None, f'Invalid policy type "{policy["type"]}".' + + op_args = inspect.getfullargspec(op.__init__).args + for key, value in hparams.items(): + if key in op_args and key not in policy: + policy[key] = value + return policy + + +@TRANSFORMS.register_module() +class AutoAugment(RandomChoice): + """Auto augmentation. 
+
+    This data augmentation is proposed in `AutoAugment: Learning Augmentation
+    Policies from Data <https://arxiv.org/abs/1805.09501>`_.
+
+    Args:
+        policies (str | list[list[dict]]): The policies of auto augmentation.
+            If a string, use a preset policies collection like "imagenet". If
+            a list, each item is a sub-policy composed of several augmentation
+            policy dicts. When AutoAugment is called, a random sub-policy in
+            ``policies`` will be selected to augment images.
+        hparams (dict): Configs of hyperparameters. Hyperparameters will be
+            used in policies that require these arguments if these arguments
+            are not set in policy dicts. Defaults to ``dict(pad_val=128)``.
+
+    .. admonition:: Available preset policies
+
+        - ``"imagenet"``: Policy for ImageNet, comes from
+          `DeepVoltaire/AutoAugment`_
+
+    .. _DeepVoltaire/AutoAugment: https://github.com/DeepVoltaire/AutoAugment
+    """
+
+    def __init__(self,
+                 policies: Union[str, List[List[dict]]],
+                 hparams: dict = dict(pad_val=128)):
+        if isinstance(policies, str):
+            assert policies in AUTOAUG_POLICIES, 'Invalid policies, ' \
+                f'please choose from {list(AUTOAUG_POLICIES.keys())}.'
+            policies = AUTOAUG_POLICIES[policies]
+        self.hparams = hparams
+        self.policies = [[merge_hparams(t, hparams) for t in sub]
+                         for sub in policies]
+        transforms = [[TRANSFORMS.build(t) for t in sub] for sub in policies]
+
+        super().__init__(transforms=transforms)
+
+    def __repr__(self) -> str:
+        policies_str = ''
+        for sub in self.policies:
+            policies_str += '\n    ' + ', \t'.join([t['type'] for t in sub])
+
+        repr_str = self.__class__.__name__
+        repr_str += f'(policies:{policies_str}\n)'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class RandAugment(BaseTransform):
+    r"""Random augmentation.
+
+    This data augmentation is proposed in `RandAugment: Practical automated
+    data augmentation with a reduced search space
+    <https://arxiv.org/abs/1909.13719>`_.
+
+    Args:
+        policies (str | list[dict]): The policies of random augmentation.
+            If a string, use a preset policies collection like
+            "timm_increasing". If a list, each item is one specific
+            augmentation policy dict. The policy dict should have these keys:
+
+            - ``type`` (str): The type of augmentation.
+            - ``magnitude_range`` (Sequence[number], optional): For
+              augmentations that have a magnitude, you need to specify the
+              magnitude level mapping range. For example, assume
+              ``total_level`` is 10, ``magnitude_level=3`` specifies a
+              magnitude of 3 if ``magnitude_range=(0, 10)``, and a magnitude
+              of 7 if ``magnitude_range=(10, 0)``.
+            - other keyword arguments of the augmentation.
+
+        num_policies (int): Number of policies to select from ``policies``
+            each time.
+        magnitude_level (int | float): Magnitude level for all the selected
+            augmentations.
+        magnitude_std (Number | str): Deviation of magnitude noise applied.
+
+            - If a positive number, the magnitude obeys a normal distribution
+              :math:`\mathcal{N}(magnitude\_level, magnitude\_std)`.
+            - If 0 or a negative number, the magnitude remains unchanged.
+            - If the str "inf", the magnitude obeys a uniform distribution
+              :math:`Uniform(min, magnitude)`.
+        total_level (int | float): Total level for the magnitude. Defaults to
+            10.
+        hparams (dict): Configs of hyperparameters. Hyperparameters will be
+            used in policies that require these arguments if these arguments
+            are not set in policy dicts. Defaults to ``dict(pad_val=128)``.
+
+    .. admonition:: Available preset policies
+
+        - ``"timm_increasing"``: The ``_RAND_INCREASING_TRANSFORMS`` policy
+          from `timm`_
+
+    .. _timm: https://github.com/rwightman/pytorch-image-models
+
+    Examples:
+
+        To use the "timm_increasing" policies collection, select two policies
+        every time, with a magnitude_level of 6 for every policy (the total
+        level is 10 by default):
+
+        >>> import numpy as np
+        >>> from mmpretrain.datasets import RandAugment
+        >>> transform = RandAugment(
+        ...     policies='timm_increasing',
+        ...     num_policies=2,
+        ...     magnitude_level=6,
+        ... )
+        >>> data = {'img': np.random.randint(0, 256, (224, 224, 3))}
+        >>> results = transform(data)
+        >>> print(results['img'].shape)
+        (224, 224, 3)
+
+        If you want the ``magnitude_level`` to change randomly every time,
+        you can use ``magnitude_std`` to specify the random distribution.
+        For example, a normal distribution :math:`\mathcal{N}(6, 0.5)`:
+
+        >>> transform = RandAugment(
+        ...     policies='timm_increasing',
+        ...     num_policies=2,
+        ...     magnitude_level=6,
+        ...     magnitude_std=0.5,
+        ... )
+
+        You can also use your own policies:
+
+        >>> policies = [
+        ...     dict(type='AutoContrast'),
+        ...     dict(type='Rotate', magnitude_range=(0, 30)),
+        ...     dict(type='ColorTransform', magnitude_range=(0, 0.9)),
+        ... ]
+        >>> transform = RandAugment(
+        ...     policies=policies,
+        ...     num_policies=2,
+        ...     magnitude_level=6
+        ... )
+
+    Note:
+        ``magnitude_std`` will introduce some randomness to the policy,
+        following https://github.com/rwightman/pytorch-image-models.
+
+        When ``magnitude_std=0``, we calculate the magnitude as follows:
+
+        .. math::
+            \text{magnitude} = \frac{\text{magnitude\_level}}
+            {\text{total\_level}} \times (\text{val2} - \text{val1})
+            + \text{val1}
+    """
+
+    def __init__(self,
+                 policies: Union[str, List[dict]],
+                 num_policies: int,
+                 magnitude_level: int,
+                 magnitude_std: Union[Number, str] = 0.,
+                 total_level: int = 10,
+                 hparams: dict = dict(pad_val=128)):
+        if isinstance(policies, str):
+            assert policies in RANDAUG_POLICIES, 'Invalid policies, ' \
+                f'please choose from {list(RANDAUG_POLICIES.keys())}.'
+            policies = RANDAUG_POLICIES[policies]
+
+        assert is_list_of(policies, dict), 'policies must be a list of dict.'
+
+        assert isinstance(magnitude_std, (Number, str)), \
+            '`magnitude_std` must be of number or str type, ' \
+            f'got {type(magnitude_std)} instead.'
+        if isinstance(magnitude_std, str):
+            assert magnitude_std == 'inf', \
+                '`magnitude_std` must be of number or "inf", ' \
+                f'got "{magnitude_std}" instead.'
+
+        assert num_policies > 0, 'num_policies must be greater than 0.'
+        assert magnitude_level >= 0, 'magnitude_level must be no less than 0.'
+        assert total_level > 0, 'total_level must be greater than 0.'
+
+        self.num_policies = num_policies
+        self.magnitude_level = magnitude_level
+        self.magnitude_std = magnitude_std
+        self.total_level = total_level
+        self.hparams = hparams
+        self.policies = []
+        self.transforms = []
+
+        randaug_cfg = dict(
+            magnitude_level=magnitude_level,
+            total_level=total_level,
+            magnitude_std=magnitude_std)
+
+        for policy in policies:
+            self._check_policy(policy)
+            policy = merge_hparams(policy, hparams)
+            policy.pop('magnitude_key', None)  # For backward compatibility
+            if 'magnitude_range' in policy:
+                policy.update(randaug_cfg)
+            self.policies.append(policy)
+            self.transforms.append(TRANSFORMS.build(policy))
+
+    def __iter__(self):
+        """Iterate all transforms."""
+        return iter(self.transforms)
+
+    def _check_policy(self, policy):
+        """Check whether the sub-policy dict is available."""
+        assert isinstance(policy, dict) and 'type' in policy, \
+            'Each policy must be a dict with key "type".'
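+        # Policies without a magnitude (e.g. ``AutoContrast``) need no
+        # `magnitude_range`; when it is present, it must be a pair of
+        # numbers giving the mapping range of the shared magnitude level.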
+ type_name = policy['type'] + + if 'magnitude_range' in policy: + magnitude_range = policy['magnitude_range'] + assert is_seq_of(magnitude_range, Number), \ + f'`magnitude_range` of RandAugment policy {type_name} ' \ + 'should be a sequence with two numbers.' + + @cache_randomness + def random_policy_indices(self) -> np.ndarray: + """Return the random chosen transform indices.""" + indices = np.arange(len(self.policies)) + return np.random.choice(indices, size=self.num_policies).tolist() + + def transform(self, results: dict) -> Optional[dict]: + """Randomly choose a sub-policy to apply.""" + + chosen_policies = [ + self.transforms[i] for i in self.random_policy_indices() + ] + + sub_pipeline = Compose(chosen_policies) + return sub_pipeline(results) + + def __repr__(self) -> str: + policies_str = '' + for policy in self.policies: + policies_str += '\n ' + f'{policy["type"]}' + if 'magnitude_range' in policy: + val1, val2 = policy['magnitude_range'] + policies_str += f' ({val1}, {val2})' + + repr_str = self.__class__.__name__ + repr_str += f'(num_policies={self.num_policies}, ' + repr_str += f'magnitude_level={self.magnitude_level}, ' + repr_str += f'total_level={self.total_level}, ' + repr_str += f'policies:{policies_str}\n)' + return repr_str + + +class BaseAugTransform(BaseTransform): + r"""The base class of augmentation transform for RandAugment. + + This class provides several common attributions and methods to support the + magnitude level mapping and magnitude level randomness in + :class:`RandAugment`. + + Args: + magnitude_level (int | float): Magnitude level. + magnitude_range (Sequence[number], optional): For augmentation have + magnitude argument, maybe "magnitude", "angle" or other, you can + specify the magnitude level mapping range to generate the magnitude + argument. For example, assume ``total_level`` is 10, + ``magnitude_level=3`` specify magnitude is 3 if + ``magnitude_range=(0, 10)`` while specify magnitude is 7 if + ``magnitude_range=(10, 0)``. Defaults to None. + magnitude_std (Number | str): Deviation of magnitude noise applied. + + - If positive number, the magnitude obeys normal distribution + :math:`\mathcal{N}(magnitude, magnitude_std)`. + - If 0 or negative number, magnitude remains unchanged. + - If str "inf", the magnitude obeys uniform distribution + :math:`Uniform(min, magnitude)`. + + Defaults to 0. + total_level (int | float): Total level for the magnitude. Defaults to + 10. + prob (float): The probability for performing transformation therefore + should be in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0. + """ + + def __init__(self, + magnitude_level: int = 10, + magnitude_range: Tuple[float, float] = None, + magnitude_std: Union[str, float] = 0., + total_level: int = 10, + prob: float = 0.5, + random_negative_prob: float = 0.5): + self.magnitude_level = magnitude_level + self.magnitude_range = magnitude_range + self.magnitude_std = magnitude_std + self.total_level = total_level + self.prob = prob + self.random_negative_prob = random_negative_prob + + @cache_randomness + def random_disable(self): + """Randomly disable the transform.""" + return np.random.rand() > self.prob + + @cache_randomness + def random_magnitude(self): + """Randomly generate magnitude.""" + magnitude = self.magnitude_level + # if magnitude_std is positive number or 'inf', move + # magnitude_value randomly. 
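+        # 'inf' draws the level from Uniform(0, magnitude_level), while a
+        # positive std draws it from N(magnitude_level, magnitude_std); the
+        # result is clipped to [0, total_level] before being mapped linearly
+        # into `magnitude_range` below.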
+ if self.magnitude_std == 'inf': + magnitude = np.random.uniform(0, magnitude) + elif self.magnitude_std > 0: + magnitude = np.random.normal(magnitude, self.magnitude_std) + magnitude = np.clip(magnitude, 0, self.total_level) + + val1, val2 = self.magnitude_range + magnitude = (magnitude / self.total_level) * (val2 - val1) + val1 + return magnitude + + @cache_randomness + def random_negative(self, value): + """Randomly negative the value.""" + if np.random.rand() < self.random_negative_prob: + return -value + else: + return value + + def extra_repr(self): + """Extra repr string when auto-generating magnitude is enabled.""" + if self.magnitude_range is not None: + repr_str = f', magnitude_level={self.magnitude_level}, ' + repr_str += f'magnitude_range={self.magnitude_range}, ' + repr_str += f'magnitude_std={self.magnitude_std}, ' + repr_str += f'total_level={self.total_level}, ' + return repr_str + else: + return '' + + +@TRANSFORMS.register_module() +class Shear(BaseAugTransform): + """Shear images. + + Args: + magnitude (int | float | None): The magnitude used for shear. If None, + generate from ``magnitude_range``, see :class:`BaseAugTransform`. + Defaults to None. + pad_val (int, Sequence[int]): Pixel pad_val value for constant fill. + If a sequence of length 3, it is used to pad_val R, G, B channels + respectively. Defaults to 128. + prob (float): The probability for performing shear therefore should be + in range [0, 1]. Defaults to 0.5. + direction (str): The shearing direction. Options are 'horizontal' and + 'vertical'. Defaults to 'horizontal'. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + interpolation (str): Interpolation method. Options are 'nearest', + 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to 'bicubic'. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + magnitude: Union[int, float, None] = None, + pad_val: Union[int, Sequence[int]] = 128, + prob: float = 0.5, + direction: str = 'horizontal', + random_negative_prob: float = 0.5, + interpolation: str = 'bicubic', + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (magnitude is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `magnitude` and `magnitude_range`.' + + self.magnitude = magnitude + if isinstance(pad_val, Sequence): + self.pad_val = tuple(pad_val) + else: + self.pad_val = pad_val + + assert direction in ('horizontal', 'vertical'), 'direction must be ' \ + f'either "horizontal" or "vertical", got "{direction}" instead.' 
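+        # `magnitude` is the shear factor passed to `mmcv.imshear`; its sign
+        # may be flipped at transform time according to
+        # `random_negative_prob`.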
+ self.direction = direction + + self.interpolation = interpolation + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.magnitude is not None: + magnitude = self.random_negative(self.magnitude) + else: + magnitude = self.random_negative(self.random_magnitude()) + + img = results['img'] + img_sheared = mmcv.imshear( + img, + magnitude, + direction=self.direction, + border_value=self.pad_val, + interpolation=self.interpolation) + results['img'] = img_sheared.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'direction={self.direction}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}, ' + repr_str += f'interpolation={self.interpolation}{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class Translate(BaseAugTransform): + """Translate images. + + Args: + magnitude (int | float | None): The magnitude used for translate. Note + that the offset is calculated by magnitude * size in the + corresponding direction. With a magnitude of 1, the whole image + will be moved out of the range. If None, generate from + ``magnitude_range``, see :class:`BaseAugTransform`. + pad_val (int, Sequence[int]): Pixel pad_val value for constant fill. + If a sequence of length 3, it is used to pad_val R, G, B channels + respectively. Defaults to 128. + prob (float): The probability for performing translate therefore should + be in range [0, 1]. Defaults to 0.5. + direction (str): The translating direction. Options are 'horizontal' + and 'vertical'. Defaults to 'horizontal'. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + interpolation (str): Interpolation method. Options are 'nearest', + 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to 'nearest'. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + magnitude: Union[int, float, None] = None, + pad_val: Union[int, Sequence[int]] = 128, + prob: float = 0.5, + direction: str = 'horizontal', + random_negative_prob: float = 0.5, + interpolation: str = 'nearest', + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (magnitude is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `magnitude` and `magnitude_range`.' + + self.magnitude = magnitude + if isinstance(pad_val, Sequence): + self.pad_val = tuple(pad_val) + else: + self.pad_val = pad_val + + assert direction in ('horizontal', 'vertical'), 'direction must be ' \ + f'either "horizontal" or "vertical", got "{direction}" instead.' 
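+        # The pixel offset is computed at transform time as
+        # `magnitude * width` (horizontal) or `magnitude * height`
+        # (vertical), so a magnitude of 1 moves the image fully out of view.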
+ self.direction = direction + + self.interpolation = interpolation + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.magnitude is not None: + magnitude = self.random_negative(self.magnitude) + else: + magnitude = self.random_negative(self.random_magnitude()) + + img = results['img'] + height, width = img.shape[:2] + if self.direction == 'horizontal': + offset = magnitude * width + else: + offset = magnitude * height + img_translated = mmcv.imtranslate( + img, + offset, + direction=self.direction, + border_value=self.pad_val, + interpolation=self.interpolation) + results['img'] = img_translated.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'direction={self.direction}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}, ' + repr_str += f'interpolation={self.interpolation}{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class Rotate(BaseAugTransform): + """Rotate images. + + Args: + angle (float, optional): The angle used for rotate. Positive values + stand for clockwise rotation. If None, generate from + ``magnitude_range``, see :class:`BaseAugTransform`. + Defaults to None. + center (tuple[float], optional): Center point (w, h) of the rotation in + the source image. If None, the center of the image will be used. + Defaults to None. + scale (float): Isotropic scale factor. Defaults to 1.0. + pad_val (int, Sequence[int]): Pixel pad_val value for constant fill. + If a sequence of length 3, it is used to pad_val R, G, B channels + respectively. Defaults to 128. + prob (float): The probability for performing rotate therefore should be + in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the angle + negative, which should be in range [0,1]. Defaults to 0.5. + interpolation (str): Interpolation method. Options are 'nearest', + 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to 'nearest'. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + angle: Optional[float] = None, + center: Optional[Tuple[float]] = None, + scale: float = 1.0, + pad_val: Union[int, Sequence[int]] = 128, + prob: float = 0.5, + random_negative_prob: float = 0.5, + interpolation: str = 'nearest', + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (angle is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `angle` and `magnitude_range`.' 
+ + self.angle = angle + self.center = center + self.scale = scale + if isinstance(pad_val, Sequence): + self.pad_val = tuple(pad_val) + else: + self.pad_val = pad_val + + self.interpolation = interpolation + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.angle is not None: + angle = self.random_negative(self.angle) + else: + angle = self.random_negative(self.random_magnitude()) + + img = results['img'] + img_rotated = mmcv.imrotate( + img, + angle, + center=self.center, + scale=self.scale, + border_value=self.pad_val, + interpolation=self.interpolation) + results['img'] = img_rotated.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(angle={self.angle}, ' + repr_str += f'center={self.center}, ' + repr_str += f'scale={self.scale}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}, ' + repr_str += f'interpolation={self.interpolation}{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class AutoContrast(BaseAugTransform): + """Auto adjust image contrast. + + Args: + prob (float): The probability for performing auto contrast + therefore should be in range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, prob: float = 0.5, **kwargs): + super().__init__(prob=prob, **kwargs) + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + img = results['img'] + img_contrasted = mmcv.auto_contrast(img) + results['img'] = img_contrasted.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob})' + return repr_str + + +@TRANSFORMS.register_module() +class Invert(BaseAugTransform): + """Invert images. + + Args: + prob (float): The probability for performing invert therefore should + be in range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, prob: float = 0.5, **kwargs): + super().__init__(prob=prob, **kwargs) + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + img = results['img'] + img_inverted = mmcv.iminvert(img) + results['img'] = img_inverted.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob})' + return repr_str + + +@TRANSFORMS.register_module() +class Equalize(BaseAugTransform): + """Equalize the image histogram. + + Args: + prob (float): The probability for performing equalize therefore should + be in range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, prob: float = 0.5, **kwargs): + super().__init__(prob=prob, **kwargs) + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + img = results['img'] + img_equalized = mmcv.imequalize(img) + results['img'] = img_equalized.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob})' + return repr_str + + +@TRANSFORMS.register_module() +class Solarize(BaseAugTransform): + """Solarize images (invert all pixel values above a threshold). 
+
+    Args:
+        thr (int | float | None): The threshold above which the pixel values
+            will be inverted. If None, generate from ``magnitude_range``,
+            see :class:`BaseAugTransform`. Defaults to None.
+        prob (float): The probability for solarizing therefore should be in
+            range [0, 1]. Defaults to 0.5.
+        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+    """
+
+    def __init__(self,
+                 thr: Union[int, float, None] = None,
+                 prob: float = 0.5,
+                 **kwargs):
+        super().__init__(prob=prob, random_negative_prob=0., **kwargs)
+        assert (thr is None) ^ (self.magnitude_range is None), \
+            'Please specify only one of `thr` and `magnitude_range`.'
+
+        self.thr = thr
+
+    def transform(self, results):
+        """Apply transform to results."""
+        if self.random_disable():
+            return results
+
+        if self.thr is not None:
+            thr = self.thr
+        else:
+            thr = self.random_magnitude()
+
+        img = results['img']
+        img_solarized = mmcv.solarize(img, thr=thr)
+        results['img'] = img_solarized.astype(img.dtype)
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(thr={self.thr}, '
+        repr_str += f'prob={self.prob}{self.extra_repr()})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class SolarizeAdd(BaseAugTransform):
+    """SolarizeAdd images (add a certain value to pixels below a threshold).
+
+    Args:
+        magnitude (int | float | None): The value to be added to pixels below
+            the thr. If None, generate from ``magnitude_range``, see
+            :class:`BaseAugTransform`. Defaults to None.
+        thr (int | float): The threshold below which the pixel values will be
+            adjusted. Defaults to 128.
+        prob (float): The probability for solarizing therefore should be in
+            range [0, 1]. Defaults to 0.5.
+        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+    """
+
+    def __init__(self,
+                 magnitude: Union[int, float, None] = None,
+                 thr: Union[int, float] = 128,
+                 prob: float = 0.5,
+                 **kwargs):
+        super().__init__(prob=prob, random_negative_prob=0., **kwargs)
+        assert (magnitude is None) ^ (self.magnitude_range is None), \
+            'Please specify only one of `magnitude` and `magnitude_range`.'
+
+        self.magnitude = magnitude
+
+        assert isinstance(thr, (int, float)), 'The thr type must '\
+            f'be int or float, but got {type(thr)} instead.'
+        self.thr = thr
+
+    def transform(self, results):
+        """Apply transform to results."""
+        if self.random_disable():
+            return results
+
+        if self.magnitude is not None:
+            magnitude = self.magnitude
+        else:
+            magnitude = self.random_magnitude()
+
+        img = results['img']
+        img_solarized = np.where(img < self.thr,
+                                 np.minimum(img + magnitude, 255), img)
+        results['img'] = img_solarized.astype(img.dtype)
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(magnitude={self.magnitude}, '
+        repr_str += f'thr={self.thr}, '
+        repr_str += f'prob={self.prob}{self.extra_repr()})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class Posterize(BaseAugTransform):
+    """Posterize images (reduce the number of bits for each color channel).
+
+    Args:
+        bits (int, optional): Number of bits for each pixel in the output img,
+            which should be less than or equal to 8. If None, generate from
+            ``magnitude_range``, see :class:`BaseAugTransform`.
+            Defaults to None.
+        prob (float): The probability for posterizing therefore should be in
+            range [0, 1]. Defaults to 0.5.
+        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
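+
+    Example:
+        A minimal usage sketch (illustrative values only)::
+
+            >>> t = Posterize(bits=4, prob=1.)  # keep 4 most significant bits
+            >>> # or sample the bit depth, as in the RandAugment policies:
+            >>> t = Posterize(magnitude_range=(4, 0))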
+ """ + + def __init__(self, + bits: Optional[int] = None, + prob: float = 0.5, + **kwargs): + super().__init__(prob=prob, random_negative_prob=0., **kwargs) + assert (bits is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `bits` and `magnitude_range`.' + + if bits is not None: + assert bits <= 8, \ + f'The bits must be less than 8, got {bits} instead.' + self.bits = bits + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.bits is not None: + bits = self.bits + else: + bits = self.random_magnitude() + + # To align timm version, we need to round up to integer here. + bits = ceil(bits) + + img = results['img'] + img_posterized = mmcv.posterize(img, bits=bits) + results['img'] = img_posterized.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(bits={self.bits}, ' + repr_str += f'prob={self.prob}{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class Contrast(BaseAugTransform): + """Adjust images contrast. + + Args: + magnitude (int | float | None): The magnitude used for adjusting + contrast. A positive magnitude would enhance the contrast and + a negative magnitude would make the image grayer. A magnitude=0 + gives the origin img. If None, generate from ``magnitude_range``, + see :class:`BaseAugTransform`. Defaults to None. + prob (float): The probability for performing contrast adjusting + therefore should be in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + """ + + def __init__(self, + magnitude: Union[int, float, None] = None, + prob: float = 0.5, + random_negative_prob: float = 0.5, + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (magnitude is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `magnitude` and `magnitude_range`.' + + self.magnitude = magnitude + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.magnitude is not None: + magnitude = self.random_negative(self.magnitude) + else: + magnitude = self.random_negative(self.random_magnitude()) + + img = results['img'] + img_contrasted = mmcv.adjust_contrast(img, factor=1 + magnitude) + results['img'] = img_contrasted.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}' + repr_str += f'{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class ColorTransform(BaseAugTransform): + """Adjust images color balance. + + Args: + magnitude (int | float | None): The magnitude used for color transform. + A positive magnitude would enhance the color and a negative + magnitude would make the image grayer. A magnitude=0 gives the + origin img. If None, generate from ``magnitude_range``, see + :class:`BaseAugTransform`. Defaults to None. + prob (float): The probability for performing ColorTransform therefore + should be in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. 
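+
+    Example:
+        A minimal usage sketch; values here are illustrative and not part
+        of the original diff::
+
+            >>> import numpy as np
+            >>> t = ColorTransform(magnitude=0.5, prob=1.)
+            >>> results = dict(img=np.random.randint(
+            ...     0, 256, (32, 32, 3), dtype=np.uint8))
+            >>> results = t(results)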
+    """
+
+    def __init__(self,
+                 magnitude: Union[int, float, None] = None,
+                 prob: float = 0.5,
+                 random_negative_prob: float = 0.5,
+                 **kwargs):
+        super().__init__(
+            prob=prob, random_negative_prob=random_negative_prob, **kwargs)
+        assert (magnitude is None) ^ (self.magnitude_range is None), \
+            'Please specify only one of `magnitude` and `magnitude_range`.'
+
+        self.magnitude = magnitude
+
+    def transform(self, results):
+        """Apply transform to results."""
+        if self.random_disable():
+            return results
+
+        if self.magnitude is not None:
+            magnitude = self.random_negative(self.magnitude)
+        else:
+            magnitude = self.random_negative(self.random_magnitude())
+
+        img = results['img']
+        img_color_adjusted = mmcv.adjust_color(img, alpha=1 + magnitude)
+        results['img'] = img_color_adjusted.astype(img.dtype)
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(magnitude={self.magnitude}, '
+        repr_str += f'prob={self.prob}, '
+        repr_str += f'random_negative_prob={self.random_negative_prob}'
+        repr_str += f'{self.extra_repr()})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class Brightness(BaseAugTransform):
+    """Adjust images brightness.
+
+    Args:
+        magnitude (int | float | None): The magnitude used for adjusting
+            brightness. A positive magnitude would enhance the brightness and a
+            negative magnitude would make the image darker. A magnitude=0 gives
+            the origin img. If None, generate from ``magnitude_range``, see
+            :class:`BaseAugTransform`. Defaults to None.
+        prob (float): The probability for performing brightness adjusting
+            therefore should be in range [0, 1]. Defaults to 0.5.
+        random_negative_prob (float): The probability that turns the magnitude
+            negative, which should be in range [0,1]. Defaults to 0.5.
+        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+    """
+
+    def __init__(self,
+                 magnitude: Union[int, float, None] = None,
+                 prob: float = 0.5,
+                 random_negative_prob: float = 0.5,
+                 **kwargs):
+        super().__init__(
+            prob=prob, random_negative_prob=random_negative_prob, **kwargs)
+        assert (magnitude is None) ^ (self.magnitude_range is None), \
+            'Please specify only one of `magnitude` and `magnitude_range`.'
+
+        self.magnitude = magnitude
+
+    def transform(self, results):
+        """Apply transform to results."""
+        if self.random_disable():
+            return results
+
+        if self.magnitude is not None:
+            magnitude = self.random_negative(self.magnitude)
+        else:
+            magnitude = self.random_negative(self.random_magnitude())
+
+        img = results['img']
+        img_brightened = mmcv.adjust_brightness(img, factor=1 + magnitude)
+        results['img'] = img_brightened.astype(img.dtype)
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(magnitude={self.magnitude}, '
+        repr_str += f'prob={self.prob}, '
+        repr_str += f'random_negative_prob={self.random_negative_prob}'
+        repr_str += f'{self.extra_repr()})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class Sharpness(BaseAugTransform):
+    """Adjust images sharpness.
+
+    Args:
+        magnitude (int | float | None): The magnitude used for adjusting
+            sharpness. A positive magnitude would enhance the sharpness and a
+            negative magnitude would make the image blurry. A magnitude=0 gives
+            the origin img. If None, generate from ``magnitude_range``, see
+            :class:`BaseAugTransform`. Defaults to None.
+        prob (float): The probability for performing sharpness adjusting
+            therefore should be in range [0, 1]. Defaults to 0.5.
+ random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + magnitude: Union[int, float, None] = None, + prob: float = 0.5, + random_negative_prob: float = 0.5, + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (magnitude is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `magnitude` and `magnitude_range`.' + + self.magnitude = magnitude + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.magnitude is not None: + magnitude = self.random_negative(self.magnitude) + else: + magnitude = self.random_negative(self.random_magnitude()) + + img = results['img'] + img_sharpened = mmcv.adjust_sharpness(img, factor=1 + magnitude) + results['img'] = img_sharpened.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}' + repr_str += f'{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class Cutout(BaseAugTransform): + """Cutout images. + + Args: + shape (int | tuple(int) | None): Expected cutout shape (h, w). + If given as a single value, the value will be used for both h and + w. If None, generate from ``magnitude_range``, see + :class:`BaseAugTransform`. Defaults to None. + pad_val (int, Sequence[int]): Pixel pad_val value for constant fill. + If it is a sequence, it must have the same length with the image + channels. Defaults to 128. + prob (float): The probability for performing cutout therefore should + be in range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + shape: Union[int, Tuple[int], None] = None, + pad_val: Union[int, Sequence[int]] = 128, + prob: float = 0.5, + **kwargs): + super().__init__(prob=prob, random_negative_prob=0., **kwargs) + assert (shape is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `shape` and `magnitude_range`.' + + self.shape = shape + if isinstance(pad_val, Sequence): + self.pad_val = tuple(pad_val) + else: + self.pad_val = pad_val + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.shape is not None: + shape = self.shape + else: + shape = int(self.random_magnitude()) + + img = results['img'] + img_cutout = mmcv.cutout(img, shape, pad_val=self.pad_val) + results['img'] = img_cutout.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(shape={self.shape}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'prob={self.prob}{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class GaussianBlur(BaseAugTransform): + """Gaussian blur images. + + Args: + radius (int, float, optional): The blur radius. If None, generate from + ``magnitude_range``, see :class:`BaseAugTransform`. + Defaults to None. + prob (float): The probability for posterizing therefore should be in + range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. 
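+
+    Example:
+        A minimal usage sketch (illustrative values only)::
+
+            >>> t = GaussianBlur(radius=2., prob=1.)  # fixed blur radius
+            >>> t = GaussianBlur(magnitude_range=(0.1, 2.0))  # sampled radius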
+ """ + + def __init__(self, + radius: Union[int, float, None] = None, + prob: float = 0.5, + **kwargs): + super().__init__(prob=prob, random_negative_prob=0., **kwargs) + assert (radius is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `radius` and `magnitude_range`.' + + self.radius = radius + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.radius is not None: + radius = self.radius + else: + radius = self.random_magnitude() + + img = results['img'] + pil_img = Image.fromarray(img) + pil_img = pil_img.filter(ImageFilter.GaussianBlur(radius=radius)) + results['img'] = np.array(pil_img, dtype=img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(radius={self.radius}, ' + repr_str += f'prob={self.prob}{self.extra_repr()})' + return repr_str + + +# yapf: disable +# flake8: noqa +AUTOAUG_POLICIES = { + # Policy for ImageNet, refers to + # https://github.com/DeepVoltaire/AutoAugment/blame/master/autoaugment.py + 'imagenet': [ + [dict(type='Posterize', bits=4, prob=0.4), dict(type='Rotate', angle=30., prob=0.6)], + [dict(type='Solarize', thr=256 / 9 * 4, prob=0.6), dict(type='AutoContrast', prob=0.6)], + [dict(type='Equalize', prob=0.8), dict(type='Equalize', prob=0.6)], + [dict(type='Posterize', bits=5, prob=0.6), dict(type='Posterize', bits=5, prob=0.6)], + [dict(type='Equalize', prob=0.4), dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)], + [dict(type='Equalize', prob=0.4), dict(type='Rotate', angle=30 / 9 * 8, prob=0.8)], + [dict(type='Solarize', thr=256 / 9 * 6, prob=0.6), dict(type='Equalize', prob=0.6)], + [dict(type='Posterize', bits=6, prob=0.8), dict(type='Equalize', prob=1.)], + [dict(type='Rotate', angle=10., prob=0.2), dict(type='Solarize', thr=256 / 9, prob=0.6)], + [dict(type='Equalize', prob=0.6), dict(type='Posterize', bits=5, prob=0.4)], + [dict(type='Rotate', angle=30 / 9 * 8, prob=0.8), dict(type='ColorTransform', magnitude=0., prob=0.4)], + [dict(type='Rotate', angle=30., prob=0.4), dict(type='Equalize', prob=0.6)], + [dict(type='Equalize', prob=0.0), dict(type='Equalize', prob=0.8)], + [dict(type='Invert', prob=0.6), dict(type='Equalize', prob=1.)], + [dict(type='ColorTransform', magnitude=0.4, prob=0.6), dict(type='Contrast', magnitude=0.8, prob=1.)], + [dict(type='Rotate', angle=30 / 9 * 8, prob=0.8), dict(type='ColorTransform', magnitude=0.2, prob=1.)], + [dict(type='ColorTransform', magnitude=0.8, prob=0.8), dict(type='Solarize', thr=256 / 9 * 2, prob=0.8)], + [dict(type='Sharpness', magnitude=0.7, prob=0.4), dict(type='Invert', prob=0.6)], + [dict(type='Shear', magnitude=0.3 / 9 * 5, prob=0.6, direction='horizontal'), dict(type='Equalize', prob=1.)], + [dict(type='ColorTransform', magnitude=0., prob=0.4), dict(type='Equalize', prob=0.6)], + [dict(type='Equalize', prob=0.4), dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)], + [dict(type='Solarize', thr=256 / 9 * 4, prob=0.6), dict(type='AutoContrast', prob=0.6)], + [dict(type='Invert', prob=0.6), dict(type='Equalize', prob=1.)], + [dict(type='ColorTransform', magnitude=0.4, prob=0.6), dict(type='Contrast', magnitude=0.8, prob=1.)], + [dict(type='Equalize', prob=0.8), dict(type='Equalize', prob=0.6)], + ], +} + +RANDAUG_POLICIES = { + # Refers to `_RAND_INCREASING_TRANSFORMS` in pytorch-image-models + 'timm_increasing': [ + dict(type='AutoContrast'), + dict(type='Equalize'), + dict(type='Invert'), + dict(type='Rotate', magnitude_range=(0, 30)), + 
dict(type='Posterize', magnitude_range=(4, 0)), + dict(type='Solarize', magnitude_range=(256, 0)), + dict(type='SolarizeAdd', magnitude_range=(0, 110)), + dict(type='ColorTransform', magnitude_range=(0, 0.9)), + dict(type='Contrast', magnitude_range=(0, 0.9)), + dict(type='Brightness', magnitude_range=(0, 0.9)), + dict(type='Sharpness', magnitude_range=(0, 0.9)), + dict(type='Shear', magnitude_range=(0, 0.3), direction='horizontal'), + dict(type='Shear', magnitude_range=(0, 0.3), direction='vertical'), + dict(type='Translate', magnitude_range=(0, 0.45), direction='horizontal'), + dict(type='Translate', magnitude_range=(0, 0.45), direction='vertical'), + ], + 'simple_increasing': [ + dict(type='AutoContrast'), + dict(type='Equalize'), + dict(type='Rotate', magnitude_range=(0, 30)), + dict(type='Shear', magnitude_range=(0, 0.3), direction='horizontal'), + dict(type='Shear', magnitude_range=(0, 0.3), direction='vertical'), + ], +} diff --git a/mmpretrain/datasets/transforms/formatting.py b/mmpretrain/datasets/transforms/formatting.py new file mode 100644 index 0000000000000000000000000000000000000000..e4d331636a883ce602e419e0867aea7b513b4d87 --- /dev/null +++ b/mmpretrain/datasets/transforms/formatting.py @@ -0,0 +1,353 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import defaultdict +from collections.abc import Sequence + +import cv2 +import numpy as np +import torch +import torchvision.transforms.functional as F +from mmcv.transforms import BaseTransform +from mmengine.utils import is_str +from PIL import Image + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample, MultiTaskDataSample + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + """ + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not is_str(data): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError( + f'Type {type(data)} cannot be converted to tensor.' + 'Supported types are: `numpy.ndarray`, `torch.Tensor`, ' + '`Sequence`, `int` and `float`') + + +@TRANSFORMS.register_module() +class PackInputs(BaseTransform): + """Pack the inputs data. + + **Required Keys:** + + - ``input_key`` + - ``*algorithm_keys`` + - ``*meta_keys`` + + **Deleted Keys:** + + All other keys in the dict. + + **Added Keys:** + + - inputs (:obj:`torch.Tensor`): The forward data of models. + - data_samples (:obj:`~mmpretrain.structures.DataSample`): The + annotation info of the sample. + + Args: + input_key (str): The key of element to feed into the model forwarding. + Defaults to 'img'. + algorithm_keys (Sequence[str]): The keys of custom elements to be used + in the algorithm. Defaults to an empty tuple. + meta_keys (Sequence[str]): The keys of meta information to be saved in + the data sample. Defaults to :attr:`PackInputs.DEFAULT_META_KEYS`. + + .. admonition:: Default algorithm keys + + Besides the specified ``algorithm_keys``, we will set some default keys + into the output data sample and do some formatting. Therefore, you + don't need to set these keys in the ``algorithm_keys``. + + - ``gt_label``: The ground-truth label. The value will be converted + into a 1-D tensor. 
+ - ``gt_score``: The ground-truth score. The value will be converted + into a 1-D tensor. + - ``mask``: The mask for some self-supervise tasks. The value will + be converted into a tensor. + + .. admonition:: Default meta keys + + - ``sample_idx``: The id of the image sample. + - ``img_path``: The path to the image file. + - ``ori_shape``: The original shape of the image as a tuple (H, W). + - ``img_shape``: The shape of the image after the pipeline as a + tuple (H, W). + - ``scale_factor``: The scale factor between the resized image and + the original image. + - ``flip``: A boolean indicating if image flip transform was used. + - ``flip_direction``: The flipping direction. + """ + + DEFAULT_META_KEYS = ('sample_idx', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction') + + def __init__(self, + input_key='img', + algorithm_keys=(), + meta_keys=DEFAULT_META_KEYS): + self.input_key = input_key + self.algorithm_keys = algorithm_keys + self.meta_keys = meta_keys + + @staticmethod + def format_input(input_): + if isinstance(input_, list): + return [PackInputs.format_input(item) for item in input_] + elif isinstance(input_, np.ndarray): + if input_.ndim == 2: # For grayscale image. + input_ = np.expand_dims(input_, -1) + if input_.ndim == 3 and not input_.flags.c_contiguous: + input_ = np.ascontiguousarray(input_.transpose(2, 0, 1)) + input_ = to_tensor(input_) + elif input_.ndim == 3: + # convert to tensor first to accelerate, see + # https://github.com/open-mmlab/mmdetection/pull/9533 + input_ = to_tensor(input_).permute(2, 0, 1).contiguous() + else: + # convert input with other shape to tensor without permute, + # like video input (num_crops, C, T, H, W). + input_ = to_tensor(input_) + elif isinstance(input_, Image.Image): + input_ = F.pil_to_tensor(input_) + elif not isinstance(input_, torch.Tensor): + raise TypeError(f'Unsupported input type {type(input_)}.') + + return input_ + + def transform(self, results: dict) -> dict: + """Method to pack the input data.""" + + packed_results = dict() + if self.input_key in results: + input_ = results[self.input_key] + packed_results['inputs'] = self.format_input(input_) + + data_sample = DataSample() + + # Set default keys + if 'gt_label' in results: + data_sample.set_gt_label(results['gt_label']) + if 'gt_score' in results: + data_sample.set_gt_score(results['gt_score']) + if 'mask' in results: + data_sample.set_mask(results['mask']) + + # Set custom algorithm keys + for key in self.algorithm_keys: + if key in results: + data_sample.set_field(results[key], key) + + # Set meta keys + for key in self.meta_keys: + if key in results: + data_sample.set_field(results[key], key, field_type='metainfo') + + packed_results['data_samples'] = data_sample + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(input_key='{self.input_key}', " + repr_str += f'algorithm_keys={self.algorithm_keys}, ' + repr_str += f'meta_keys={self.meta_keys})' + return repr_str + + +@TRANSFORMS.register_module() +class PackMultiTaskInputs(BaseTransform): + """Convert all image labels of multi-task dataset to a dict of tensor. 
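+
+    Each field listed in ``multi_task_fields`` is expected to be a dict
+    mapping task names to task-specific values; the values are regrouped
+    per task and each task is packed by its own handler into a
+    :class:`MultiTaskDataSample`.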
+
+    Args:
+        multi_task_fields (Sequence[str]): Fields whose values are dicts
+            mapping task names to task-specific values; they are regrouped
+            per task before packing.
+        input_key (str): The key of the element to feed into the model
+            forwarding. Defaults to 'img'.
+        task_handlers (dict): Mapping from task name to the pack transform
+            config used for that task; tasks without an entry fall back to
+            a default :class:`PackInputs`. Defaults to an empty dict.
+    """
+
+    def __init__(self,
+                 multi_task_fields,
+                 input_key='img',
+                 task_handlers=dict()):
+        self.multi_task_fields = multi_task_fields
+        self.input_key = input_key
+        self.task_handlers = defaultdict(PackInputs)
+        for task_name, task_handler in task_handlers.items():
+            self.task_handlers[task_name] = TRANSFORMS.build(task_handler)
+
+    def transform(self, results: dict) -> dict:
+        """Method to pack the input data.
+
+        The expected input format is, for example::
+
+            results = {
+                'img_path': 'a.png',
+                'gt_label': {'task1': 1, 'task3': 3},
+                'img': array(...),
+            }
+        """
+        packed_results = dict()
+        results = results.copy()
+
+        if self.input_key in results:
+            input_ = results[self.input_key]
+            packed_results['inputs'] = PackInputs.format_input(input_)
+
+        task_results = defaultdict(dict)
+        for field in self.multi_task_fields:
+            if field in results:
+                value = results.pop(field)
+                for k, v in value.items():
+                    task_results[k].update({field: v})
+
+        data_sample = MultiTaskDataSample()
+        for task_name, task_result in task_results.items():
+            task_handler = self.task_handlers[task_name]
+            task_pack_result = task_handler({**results, **task_result})
+            data_sample.set_field(task_pack_result['data_samples'], task_name)
+
+        packed_results['data_samples'] = data_sample
+        return packed_results
+
+    def __repr__(self):
+        repr = self.__class__.__name__
+        task_handlers = ', '.join(
+            f"'{name}': {handler.__class__.__name__}"
+            for name, handler in self.task_handlers.items())
+        repr += f'(multi_task_fields={self.multi_task_fields}, '
+        repr += f"input_key='{self.input_key}', "
+        repr += f'task_handlers={{{task_handlers}}})'
+        return repr
+
+
+@TRANSFORMS.register_module()
+class Transpose(BaseTransform):
+    """Transpose numpy array.
+
+    **Required Keys:**
+
+    - ``*keys``
+
+    **Modified Keys:**
+
+    - ``*keys``
+
+    Args:
+        keys (List[str]): The fields to transpose.
+        order (List[int]): The output dimensions order.
+    """
+
+    def __init__(self, keys, order):
+        self.keys = keys
+        self.order = order
+
+    def transform(self, results):
+        """Method to transpose array."""
+        for key in self.keys:
+            results[key] = results[key].transpose(self.order)
+        return results
+
+    def __repr__(self):
+        return self.__class__.__name__ + \
+            f'(keys={self.keys}, order={self.order})'
+
+
+@TRANSFORMS.register_module(('NumpyToPIL', 'ToPIL'))
+class NumpyToPIL(BaseTransform):
+    """Convert the image from OpenCV format to :obj:`PIL.Image.Image`.
+
+    **Required Keys:**
+
+    - ``img``
+
+    **Modified Keys:**
+
+    - ``img``
+
+    Args:
+        to_rgb (bool): Whether to convert img to RGB. Defaults to False.
+    """
+
+    def __init__(self, to_rgb: bool = False) -> None:
+        self.to_rgb = to_rgb
+
+    def transform(self, results: dict) -> dict:
+        """Method to convert images to :obj:`PIL.Image.Image`."""
+        img = results['img']
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.to_rgb else img
+
+        results['img'] = Image.fromarray(img)
+        return results
+
+    def __repr__(self) -> str:
+        return self.__class__.__name__ + f'(to_rgb={self.to_rgb})'
+
+
+@TRANSFORMS.register_module(('PILToNumpy', 'ToNumpy'))
+class PILToNumpy(BaseTransform):
+    """Convert img to :obj:`numpy.ndarray`.
+
+    **Required Keys:**
+
+    - ``img``
+
+    **Modified Keys:**
+
+    - ``img``
+
+    Args:
+        to_bgr (bool): Whether to convert img to BGR. Defaults to False.
+        dtype (str, optional): The dtype of the converted numpy array.
+            Defaults to None.
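+
+    Example:
+        Typically paired with :class:`NumpyToPIL` around PIL-only transforms
+        in a pipeline config (an illustrative sketch, not from the original
+        diff)::
+
+            >>> pipeline = [
+            ...     dict(type='NumpyToPIL', to_rgb=True),
+            ...     dict(type='torchvision/RandomAutocontrast'),
+            ...     dict(type='PILToNumpy', to_bgr=True),
+            ... ]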
+    """
+
+    def __init__(self, to_bgr: bool = False, dtype=None) -> None:
+        self.to_bgr = to_bgr
+        self.dtype = dtype
+
+    def transform(self, results: dict) -> dict:
+        """Method to convert img to :obj:`numpy.ndarray`."""
+        img = np.array(results['img'], dtype=self.dtype)
+        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if self.to_bgr else img
+
+        results['img'] = img
+        return results
+
+    def __repr__(self) -> str:
+        return self.__class__.__name__ + \
+            f'(to_bgr={self.to_bgr}, dtype={self.dtype})'
+
+
+@TRANSFORMS.register_module()
+class Collect(BaseTransform):
+    """Collect and only reserve the specified fields.
+
+    **Required Keys:**
+
+    - ``*keys``
+
+    **Deleted Keys:**
+
+    All keys except those in the argument ``*keys``.
+
+    Args:
+        keys (Sequence[str]): The keys of the fields to be collected.
+    """
+
+    def __init__(self, keys):
+        self.keys = keys
+
+    def transform(self, results):
+        data = {}
+        for key in self.keys:
+            data[key] = results[key]
+        return data
+
+    def __repr__(self):
+        return self.__class__.__name__ + f'(keys={self.keys})'
diff --git a/mmpretrain/datasets/transforms/processing.py b/mmpretrain/datasets/transforms/processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c640f6b1fa6d4e250ce2f8db59c038382e915f6
--- /dev/null
+++ b/mmpretrain/datasets/transforms/processing.py
@@ -0,0 +1,1795 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import math
+import numbers
+import re
+import string
+from enum import EnumMeta
+from numbers import Number
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+import mmcv
+import mmengine
+import numpy as np
+import torch
+import torchvision
+import torchvision.transforms.functional as F
+from mmcv.transforms import BaseTransform
+from mmcv.transforms.utils import cache_randomness
+from PIL import Image
+from torchvision import transforms
+from torchvision.transforms.transforms import InterpolationMode
+
+from mmpretrain.registry import TRANSFORMS
+
+try:
+    import albumentations
+except ImportError:
+    albumentations = None
+
+
+def _str_to_torch_dtype(t: str):
+    """Map a dtype name string to :obj:`torch.dtype`."""
+    import torch  # noqa: F401,F403
+    return eval(f'torch.{t}')
+
+
+def _interpolation_modes_from_str(t: str):
+    """Map an interpolation name string to ``InterpolationMode``."""
+    t = t.lower()
+    inverse_modes_mapping = {
+        'nearest': InterpolationMode.NEAREST,
+        'bilinear': InterpolationMode.BILINEAR,
+        'bicubic': InterpolationMode.BICUBIC,
+        'box': InterpolationMode.BOX,
+        'hamming': InterpolationMode.HAMMING,
+        'lanczos': InterpolationMode.LANCZOS,
+    }
+    return inverse_modes_mapping[t]
+
+
+class TorchVisonTransformWrapper:
+
+    def __init__(self, transform, *args, **kwargs):
+        if 'interpolation' in kwargs and isinstance(kwargs['interpolation'],
+                                                    str):
+            kwargs['interpolation'] = _interpolation_modes_from_str(
+                kwargs['interpolation'])
+        if 'dtype' in kwargs and isinstance(kwargs['dtype'], str):
+            kwargs['dtype'] = _str_to_torch_dtype(kwargs['dtype'])
+        self.t = transform(*args, **kwargs)
+
+    def __call__(self, results):
+        results['img'] = self.t(results['img'])
+        return results
+
+    def __repr__(self) -> str:
+        return f'TorchVision{repr(self.t)}'
+
+
+def register_vision_transforms() -> List[str]:
+    """Register transforms in ``torchvision.transforms`` to the ``TRANSFORMS``
+    registry.
+
+    Returns:
+        List[str]: A list of registered transforms' names.
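+
+    Example:
+        Registered transforms become available from the config system under
+        a ``torchvision/`` prefix (an illustrative sketch)::
+
+            >>> cfg = dict(type='torchvision/GaussianBlur', kernel_size=3)
+            >>> blur = TRANSFORMS.build(cfg)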
+ """ + vision_transforms = [] + for module_name in dir(torchvision.transforms): + if not re.match('[A-Z]', module_name): + # must startswith a capital letter + continue + _transform = getattr(torchvision.transforms, module_name) + if inspect.isclass(_transform) and callable( + _transform) and not isinstance(_transform, (EnumMeta)): + from functools import partial + TRANSFORMS.register_module( + module=partial( + TorchVisonTransformWrapper, transform=_transform), + name=f'torchvision/{module_name}') + vision_transforms.append(f'torchvision/{module_name}') + return vision_transforms + + +# register all the transforms in torchvision by using a transform wrapper +VISION_TRANSFORMS = register_vision_transforms() + + +@TRANSFORMS.register_module() +class RandomCrop(BaseTransform): + """Crop the given Image at a random location. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + - img_shape + + Args: + crop_size (int | Sequence): Desired output size of the crop. If + crop_size is an int instead of sequence like (h, w), a square crop + (crop_size, crop_size) is made. + padding (int | Sequence, optional): Optional padding on each border + of the image. If a sequence of length 4 is provided, it is used to + pad left, top, right, bottom borders respectively. If a sequence + of length 2 is provided, it is used to pad left/right, top/bottom + borders, respectively. Default: None, which means no padding. + pad_if_needed (bool): It will pad the image if smaller than the + desired size to avoid raising an exception. Since cropping is done + after padding, the padding seems to be done at a random offset. + Default: False. + pad_val (Number | Sequence[Number]): Pixel pad_val value for constant + fill. If a tuple of length 3, it is used to pad_val R, G, B + channels respectively. Default: 0. + padding_mode (str): Type of padding. Defaults to "constant". Should + be one of the following: + + - ``constant``: Pads with a constant value, this value is specified + with pad_val. + - ``edge``: pads with the last value at the edge of the image. + - ``reflect``: Pads with reflection of image without repeating the + last value on the edge. For example, padding [1, 2, 3, 4] + with 2 elements on both sides in reflect mode will result + in [3, 2, 1, 2, 3, 4, 3, 2]. + - ``symmetric``: Pads with reflection of image repeating the last + value on the edge. For example, padding [1, 2, 3, 4] with + 2 elements on both sides in symmetric mode will result in + [2, 1, 1, 2, 3, 4, 4, 3]. + """ + + def __init__(self, + crop_size: Union[Sequence, int], + padding: Optional[Union[Sequence, int]] = None, + pad_if_needed: bool = False, + pad_val: Union[Number, Sequence[Number]] = 0, + padding_mode: str = 'constant'): + if isinstance(crop_size, Sequence): + assert len(crop_size) == 2 + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size + else: + assert crop_size > 0 + self.crop_size = (crop_size, crop_size) + # check padding mode + assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] + self.padding = padding + self.pad_if_needed = pad_if_needed + self.pad_val = pad_val + self.padding_mode = padding_mode + + @cache_randomness + def rand_crop_params(self, img: np.ndarray): + """Get parameters for ``crop`` for a random crop. + + Args: + img (ndarray): Image to be cropped. + + Returns: + tuple: Params (offset_h, offset_w, target_h, target_w) to be + passed to ``crop`` for random crop. 
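+
+        Note:
+            If the image is smaller than ``crop_size`` on either side, the
+            target size is shrunk to fit that side; padding, if desired, is
+            handled separately in ``transform`` via ``pad_if_needed``.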
+ """ + h, w = img.shape[:2] + target_h, target_w = self.crop_size + if w == target_w and h == target_h: + return 0, 0, h, w + elif w < target_w or h < target_h: + target_w = min(w, target_w) + target_h = min(h, target_h) + + offset_h = np.random.randint(0, h - target_h + 1) + offset_w = np.random.randint(0, w - target_w + 1) + + return offset_h, offset_w, target_h, target_w + + def transform(self, results: dict) -> dict: + """Transform function to randomly crop images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' + key in result dict is updated according to crop size. + """ + img = results['img'] + if self.padding is not None: + img = mmcv.impad(img, padding=self.padding, pad_val=self.pad_val) + + # pad img if needed + if self.pad_if_needed: + h_pad = math.ceil(max(0, self.crop_size[0] - img.shape[0]) / 2) + w_pad = math.ceil(max(0, self.crop_size[1] - img.shape[1]) / 2) + + img = mmcv.impad( + img, + padding=(w_pad, h_pad, w_pad, h_pad), + pad_val=self.pad_val, + padding_mode=self.padding_mode) + + offset_h, offset_w, target_h, target_w = self.rand_crop_params(img) + img = mmcv.imcrop( + img, + np.array([ + offset_w, + offset_h, + offset_w + target_w - 1, + offset_h + target_h - 1, + ])) + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}' + repr_str += f', padding={self.padding}' + repr_str += f', pad_if_needed={self.pad_if_needed}' + repr_str += f', pad_val={self.pad_val}' + repr_str += f', padding_mode={self.padding_mode})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomResizedCrop(BaseTransform): + """Crop the given image to random scale and aspect ratio. + + A crop of random size (default: of 0.08 to 1.0) of the original size and a + random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio + is made. This crop is finally resized to given size. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + - img_shape + + Args: + scale (sequence | int): Desired output scale of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + crop_ratio_range (tuple): Range of the random size of the cropped + image compared to the original image. Defaults to (0.08, 1.0). + aspect_ratio_range (tuple): Range of the random aspect ratio of the + cropped image compared to the original image. + Defaults to (3. / 4., 4. / 3.). + max_attempts (int): Maximum number of attempts before falling back to + Central Crop. Defaults to 10. + interpolation (str): Interpolation method, accepted values are + 'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to + 'bilinear'. + backend (str): The image resize backend type, accepted values are + 'cv2' and 'pillow'. Defaults to 'cv2'. + """ + + def __init__(self, + scale: Union[Sequence, int], + crop_ratio_range: Tuple[float, float] = (0.08, 1.0), + aspect_ratio_range: Tuple[float, float] = (3. / 4., 4. 
/ 3.),
+                 max_attempts: int = 10,
+                 interpolation: str = 'bilinear',
+                 backend: str = 'cv2') -> None:
+        if isinstance(scale, Sequence):
+            assert len(scale) == 2
+            assert scale[0] > 0 and scale[1] > 0
+            self.scale = scale
+        else:
+            assert scale > 0
+            self.scale = (scale, scale)
+        if (crop_ratio_range[0] > crop_ratio_range[1]) or (
+                aspect_ratio_range[0] > aspect_ratio_range[1]):
+            raise ValueError(
+                'range should be of kind (min, max). '
+                f'But received crop_ratio_range {crop_ratio_range} '
+                f'and aspect_ratio_range {aspect_ratio_range}.')
+        assert isinstance(max_attempts, int) and max_attempts >= 0, \
+            'max_attempts must be int and no less than 0.'
+        assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
+                                 'lanczos')
+
+        self.crop_ratio_range = crop_ratio_range
+        self.aspect_ratio_range = aspect_ratio_range
+        self.max_attempts = max_attempts
+        self.interpolation = interpolation
+        self.backend = backend
+
+    @cache_randomness
+    def rand_crop_params(self, img: np.ndarray) -> Tuple[int, int, int, int]:
+        """Get parameters for ``crop`` for a random sized crop.
+
+        Args:
+            img (ndarray): Image to be cropped.
+
+        Returns:
+            tuple: Params (offset_h, offset_w, target_h, target_w) to be
+                passed to `crop` for a random sized crop.
+        """
+        h, w = img.shape[:2]
+        area = h * w
+
+        for _ in range(self.max_attempts):
+            target_area = np.random.uniform(*self.crop_ratio_range) * area
+            log_ratio = (math.log(self.aspect_ratio_range[0]),
+                         math.log(self.aspect_ratio_range[1]))
+            aspect_ratio = math.exp(np.random.uniform(*log_ratio))
+            target_w = int(round(math.sqrt(target_area * aspect_ratio)))
+            target_h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+            if 0 < target_w <= w and 0 < target_h <= h:
+                offset_h = np.random.randint(0, h - target_h + 1)
+                offset_w = np.random.randint(0, w - target_w + 1)
+
+                return offset_h, offset_w, target_h, target_w
+
+        # Fallback to central crop
+        in_ratio = float(w) / float(h)
+        if in_ratio < min(self.aspect_ratio_range):
+            target_w = w
+            target_h = int(round(target_w / min(self.aspect_ratio_range)))
+        elif in_ratio > max(self.aspect_ratio_range):
+            target_h = h
+            target_w = int(round(target_h * max(self.aspect_ratio_range)))
+        else:  # whole image
+            target_w = w
+            target_h = h
+        offset_h = (h - target_h) // 2
+        offset_w = (w - target_w) // 2
+        return offset_h, offset_w, target_h, target_w
+
+    def transform(self, results: dict) -> dict:
+        """Transform function to randomly resized crop images.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Randomly resized cropped results, 'img_shape'
+                key in result dict is updated according to crop size.
+        """
+        img = results['img']
+        offset_h, offset_w, target_h, target_w = self.rand_crop_params(img)
+        img = mmcv.imcrop(
+            img,
+            bboxes=np.array([
+                offset_w, offset_h, offset_w + target_w - 1,
+                offset_h + target_h - 1
+            ]))
+        img = mmcv.imresize(
+            img,
+            tuple(self.scale[::-1]),
+            interpolation=self.interpolation,
+            backend=self.backend)
+        results['img'] = img
+        results['img_shape'] = img.shape
+
+        return results
+
+    def __repr__(self):
+        """Print the basic information of the transform.
+
+        Returns:
+            str: Formatted string.
+ """ + repr_str = self.__class__.__name__ + f'(scale={self.scale}' + repr_str += ', crop_ratio_range=' + repr_str += f'{tuple(round(s, 4) for s in self.crop_ratio_range)}' + repr_str += ', aspect_ratio_range=' + repr_str += f'{tuple(round(r, 4) for r in self.aspect_ratio_range)}' + repr_str += f', max_attempts={self.max_attempts}' + repr_str += f', interpolation={self.interpolation}' + repr_str += f', backend={self.backend})' + return repr_str + + +@TRANSFORMS.register_module() +class EfficientNetRandomCrop(RandomResizedCrop): + """EfficientNet style RandomResizedCrop. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + - img_shape + + Args: + scale (int): Desired output scale of the crop. Only int size is + accepted, a square crop (size, size) is made. + min_covered (Number): Minimum ratio of the cropped area to the original + area. Defaults to 0.1. + crop_padding (int): The crop padding parameter in efficientnet style + center crop. Defaults to 32. + crop_ratio_range (tuple): Range of the random size of the cropped + image compared to the original image. Defaults to (0.08, 1.0). + aspect_ratio_range (tuple): Range of the random aspect ratio of the + cropped image compared to the original image. + Defaults to (3. / 4., 4. / 3.). + max_attempts (int): Maximum number of attempts before falling back to + Central Crop. Defaults to 10. + interpolation (str): Interpolation method, accepted values are + 'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to + 'bicubic'. + backend (str): The image resize backend type, accepted values are + 'cv2' and 'pillow'. Defaults to 'cv2'. + """ + + def __init__(self, + scale: int, + min_covered: float = 0.1, + crop_padding: int = 32, + interpolation: str = 'bicubic', + **kwarg): + assert isinstance(scale, int) + super().__init__(scale, interpolation=interpolation, **kwarg) + assert min_covered >= 0, 'min_covered should be no less than 0.' + assert crop_padding >= 0, 'crop_padding should be no less than 0.' + + self.min_covered = min_covered + self.crop_padding = crop_padding + + # https://github.com/kakaobrain/fast-autoaugment/blob/master/FastAutoAugment/data.py # noqa + @cache_randomness + def rand_crop_params(self, img: np.ndarray) -> Tuple[int, int, int, int]: + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (ndarray): Image to be cropped. + + Returns: + tuple: Params (offset_h, offset_w, target_h, target_w) to be + passed to `crop` for a random sized crop. + """ + h, w = img.shape[:2] + area = h * w + min_target_area = self.crop_ratio_range[0] * area + max_target_area = self.crop_ratio_range[1] * area + + for _ in range(self.max_attempts): + aspect_ratio = np.random.uniform(*self.aspect_ratio_range) + min_target_h = int( + round(math.sqrt(min_target_area / aspect_ratio))) + max_target_h = int( + round(math.sqrt(max_target_area / aspect_ratio))) + + if max_target_h * aspect_ratio > w: + max_target_h = int((w + 0.5 - 1e-7) / aspect_ratio) + if max_target_h * aspect_ratio > w: + max_target_h -= 1 + + max_target_h = min(max_target_h, h) + min_target_h = min(max_target_h, min_target_h) + + # slightly differs from tf implementation + target_h = int( + round(np.random.uniform(min_target_h, max_target_h))) + target_w = int(round(target_h * aspect_ratio)) + target_area = target_h * target_w + + # slight differs from tf. 
In tf, if target_area > max_target_area,
+            # area will be recalculated
+            if (target_area < min_target_area or target_area > max_target_area
+                    or target_w > w or target_h > h
+                    or target_area < self.min_covered * area):
+                continue
+
+            offset_h = np.random.randint(0, h - target_h + 1)
+            offset_w = np.random.randint(0, w - target_w + 1)
+
+            return offset_h, offset_w, target_h, target_w
+
+        # Fallback to central crop
+        img_short = min(h, w)
+        crop_size = self.scale[0] / (self.scale[0] +
+                                     self.crop_padding) * img_short
+
+        offset_h = max(0, int(round((h - crop_size) / 2.)))
+        offset_w = max(0, int(round((w - crop_size) / 2.)))
+        return offset_h, offset_w, crop_size, crop_size
+
+    def __repr__(self):
+        """Print the basic information of the transform.
+
+        Returns:
+            str: Formatted string.
+        """
+        repr_str = super().__repr__()[:-1]
+        repr_str += f', min_covered={self.min_covered}'
+        repr_str += f', crop_padding={self.crop_padding})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class RandomErasing(BaseTransform):
+    """Randomly selects a rectangle region in an image and erases its pixels.
+
+    **Required Keys:**
+
+    - img
+
+    **Modified Keys:**
+
+    - img
+
+    Args:
+        erase_prob (float): Probability that image will be randomly erased.
+            Default: 0.5
+        min_area_ratio (float): Minimum erased area / input image area.
+            Default: 0.02
+        max_area_ratio (float): Maximum erased area / input image area.
+            Default: 0.4
+        aspect_range (sequence | float): Aspect ratio range of erased area.
+            If a float, it will be converted to (aspect_ratio,
+            1/aspect_ratio). Default: (3/10, 10/3)
+        mode (str): Fill method in erased area, can be:
+
+            - const (default): All pixels are assigned the same value.
+            - rand: Each pixel is assigned a random value in [0, 255].
+
+        fill_color (sequence | Number): Base color filled in erased area.
+            Defaults to (128, 128, 128).
+        fill_std (sequence | Number, optional): If set and ``mode`` is 'rand',
+            fill erased area with random color from normal distribution
+            (mean=fill_color, std=fill_std); If not set, fill erased area with
+            random color from uniform distribution (0~255). Defaults to None.
+
+    Note:
+        See `Random Erasing Data Augmentation
+        `_
+
+        This paper provides 4 modes: RE-R, RE-M, RE-0, RE-255, and uses RE-M
+        as default. The config of these 4 modes are:
+
+        - RE-R: RandomErasing(mode='rand')
+        - RE-M: RandomErasing(mode='const', fill_color=(123.67, 116.3, 103.5))
+        - RE-0: RandomErasing(mode='const', fill_color=0)
+        - RE-255: RandomErasing(mode='const', fill_color=255)
+    """
+
+    def __init__(self,
+                 erase_prob=0.5,
+                 min_area_ratio=0.02,
+                 max_area_ratio=0.4,
+                 aspect_range=(3 / 10, 10 / 3),
+                 mode='const',
+                 fill_color=(128, 128, 128),
+                 fill_std=None):
+        assert isinstance(erase_prob, float) and 0. <= erase_prob <= 1.
+        assert isinstance(min_area_ratio, float) and 0. <= min_area_ratio <= 1.
+        assert isinstance(max_area_ratio, float) and 0. <= max_area_ratio <= 1.
+        assert min_area_ratio <= max_area_ratio, \
+            'min_area_ratio should not be larger than max_area_ratio'
+        if isinstance(aspect_range, float):
+            aspect_range = min(aspect_range, 1 / aspect_range)
+            aspect_range = (aspect_range, 1 / aspect_range)
+        assert isinstance(aspect_range, Sequence) and len(aspect_range) == 2 \
+            and all(isinstance(x, float) for x in aspect_range), \
+            'aspect_range should be a float or Sequence with two float.'
+        assert all(x > 0 for x in aspect_range), \
+            'aspect_range should be positive.'
+ assert aspect_range[0] <= aspect_range[1], \ + 'In aspect_range (min, max), min should be smaller than max.' + assert mode in ['const', 'rand'], \ + 'Please select `mode` from ["const", "rand"].' + if isinstance(fill_color, Number): + fill_color = [fill_color] * 3 + assert isinstance(fill_color, Sequence) and len(fill_color) == 3 \ + and all(isinstance(x, Number) for x in fill_color), \ + 'fill_color should be a float or Sequence with three int.' + if fill_std is not None: + if isinstance(fill_std, Number): + fill_std = [fill_std] * 3 + assert isinstance(fill_std, Sequence) and len(fill_std) == 3 \ + and all(isinstance(x, Number) for x in fill_std), \ + 'fill_std should be a float or Sequence with three int.' + + self.erase_prob = erase_prob + self.min_area_ratio = min_area_ratio + self.max_area_ratio = max_area_ratio + self.aspect_range = aspect_range + self.mode = mode + self.fill_color = fill_color + self.fill_std = fill_std + + def _fill_pixels(self, img, top, left, h, w): + """Fill pixels to the patch of image.""" + if self.mode == 'const': + patch = np.empty((h, w, 3), dtype=np.uint8) + patch[:, :] = np.array(self.fill_color, dtype=np.uint8) + elif self.fill_std is None: + # Uniform distribution + patch = np.random.uniform(0, 256, (h, w, 3)).astype(np.uint8) + else: + # Normal distribution + patch = np.random.normal(self.fill_color, self.fill_std, (h, w, 3)) + patch = np.clip(patch.astype(np.int32), 0, 255).astype(np.uint8) + + img[top:top + h, left:left + w] = patch + return img + + @cache_randomness + def random_disable(self): + """Randomly disable the transform.""" + return np.random.rand() > self.erase_prob + + @cache_randomness + def random_patch(self, img_h, img_w): + """Randomly generate patch the erase.""" + # convert the aspect ratio to log space to equally handle width and + # height. + log_aspect_range = np.log( + np.array(self.aspect_range, dtype=np.float32)) + aspect_ratio = np.exp(np.random.uniform(*log_aspect_range)) + area = img_h * img_w + area *= np.random.uniform(self.min_area_ratio, self.max_area_ratio) + + h = min(int(round(np.sqrt(area * aspect_ratio))), img_h) + w = min(int(round(np.sqrt(area / aspect_ratio))), img_w) + top = np.random.randint(0, img_h - h) if img_h > h else 0 + left = np.random.randint(0, img_w - w) if img_w > w else 0 + return top, left, h, w + + def transform(self, results): + """ + Args: + results (dict): Results dict from pipeline + + Returns: + dict: Results after the transformation. + """ + if self.random_disable(): + return results + + img = results['img'] + img_h, img_w = img.shape[:2] + + img = self._fill_pixels(img, *self.random_patch(img_h, img_w)) + + results['img'] = img + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(erase_prob={self.erase_prob}, ' + repr_str += f'min_area_ratio={self.min_area_ratio}, ' + repr_str += f'max_area_ratio={self.max_area_ratio}, ' + repr_str += f'aspect_range={self.aspect_range}, ' + repr_str += f'mode={self.mode}, ' + repr_str += f'fill_color={self.fill_color}, ' + repr_str += f'fill_std={self.fill_std})' + return repr_str + + +@TRANSFORMS.register_module() +class EfficientNetCenterCrop(BaseTransform): + r"""EfficientNet style center crop. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + - img_shape + + Args: + crop_size (int): Expected size after cropping with the format + of (h, w). + crop_padding (int): The crop padding parameter in efficientnet style + center crop. Defaults to 32. 
+        interpolation (str): Interpolation method, accepted values are
+            'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'.
+            Defaults to 'bicubic'.
+        backend (str): The image resize backend type, accepted values are
+            `cv2` and `pillow`. Defaults to `cv2`.
+
+    Notes:
+        - If the image is smaller than the crop size, return the original
+          image.
+        - The pipeline first performs the center crop with the
+          ``crop_size_`` computed as:
+
+        .. math::
+
+            \text{crop_size_} = \frac{\text{crop_size}}{\text{crop_size} +
+            \text{crop_padding}} \times \text{short_edge}
+
+        And then the pipeline resizes the img to the input crop size.
+    """
+
+    def __init__(self,
+                 crop_size: int,
+                 crop_padding: int = 32,
+                 interpolation: str = 'bicubic',
+                 backend: str = 'cv2'):
+        assert isinstance(crop_size, int)
+        assert crop_size > 0
+        assert crop_padding >= 0
+        assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
+                                 'lanczos')
+
+        self.crop_size = crop_size
+        self.crop_padding = crop_padding
+        self.interpolation = interpolation
+        self.backend = backend
+
+    def transform(self, results: dict) -> dict:
+        """Transform function to center crop images in EfficientNet style.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: EfficientNet style center cropped results, 'img_shape'
+                key in result dict is updated according to crop size.
+        """
+        img = results['img']
+        h, w = img.shape[:2]
+
+        # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/preprocessing.py#L118 # noqa
+        img_short = min(h, w)
+        crop_size = self.crop_size / (self.crop_size +
+                                      self.crop_padding) * img_short
+
+        offset_h = max(0, int(round((h - crop_size) / 2.)))
+        offset_w = max(0, int(round((w - crop_size) / 2.)))
+
+        # crop the image
+        img = mmcv.imcrop(
+            img,
+            bboxes=np.array([
+                offset_w, offset_h, offset_w + crop_size - 1,
+                offset_h + crop_size - 1
+            ]))
+        # resize image
+        img = mmcv.imresize(
+            img, (self.crop_size, self.crop_size),
+            interpolation=self.interpolation,
+            backend=self.backend)
+        results['img'] = img
+        results['img_shape'] = img.shape
+
+        return results
+
+    def __repr__(self):
+        """Print the basic information of the transform.
+
+        Returns:
+            str: Formatted string.
+        """
+        repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}'
+        repr_str += f', crop_padding={self.crop_padding}'
+        repr_str += f', interpolation={self.interpolation}'
+        repr_str += f', backend={self.backend})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class ResizeEdge(BaseTransform):
+    """Resize images along the specified edge.
+
+    **Required Keys:**
+
+    - img
+
+    **Modified Keys:**
+
+    - img
+    - img_shape
+
+    **Added Keys:**
+
+    - scale
+    - scale_factor
+
+    Args:
+        scale (int): The target length of the resized edge.
+        edge (str): The edge to resize. Defaults to 'short'.
+        backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
+            These two backends generate slightly different results.
+            Defaults to 'cv2'.
+        interpolation (str): Interpolation method, accepted values are
+            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+            backend, "nearest", "bilinear" for 'pillow' backend.
+            Defaults to 'bilinear'.
+    """
+
+    def __init__(self,
+                 scale: int,
+                 edge: str = 'short',
+                 backend: str = 'cv2',
+                 interpolation: str = 'bilinear') -> None:
+        allow_edges = ['short', 'long', 'width', 'height']
+        assert edge in allow_edges, \
+            f'Invalid edge "{edge}", please specify from {allow_edges}.'
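+        # For example, with edge='short' and scale=256, a 480x640 (H x W)
+        # image is resized to 256x341: the short side is pinned to `scale`
+        # and the aspect ratio is preserved (illustrative numbers).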
+ self.edge = edge + self.scale = scale + self.backend = backend + self.interpolation = interpolation + + def _resize_img(self, results: dict) -> None: + """Resize images with ``results['scale']``.""" + + img, w_scale, h_scale = mmcv.imresize( + results['img'], + results['scale'], + interpolation=self.interpolation, + return_scale=True, + backend=self.backend) + results['img'] = img + results['img_shape'] = img.shape[:2] + results['scale'] = img.shape[:2][::-1] + results['scale_factor'] = (w_scale, h_scale) + + def transform(self, results: Dict) -> Dict: + """Transform function to resize images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img', 'scale', 'scale_factor', + 'img_shape' keys are updated in result dict. + """ + assert 'img' in results, 'No `img` field in the input.' + + h, w = results['img'].shape[:2] + if any([ + # conditions to resize the width + self.edge == 'short' and w < h, + self.edge == 'long' and w > h, + self.edge == 'width', + ]): + width = self.scale + height = int(self.scale * h / w) + else: + height = self.scale + width = int(self.scale * w / h) + results['scale'] = (width, height) + + self._resize_img(results) + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(scale={self.scale}, ' + repr_str += f'edge={self.edge}, ' + repr_str += f'backend={self.backend}, ' + repr_str += f'interpolation={self.interpolation})' + return repr_str + + +@TRANSFORMS.register_module() +class ColorJitter(BaseTransform): + """Randomly change the brightness, contrast and saturation of an image. + + Modified from + https://github.com/pytorch/vision/blob/main/torchvision/transforms/transforms.py + Licensed under the BSD 3-Clause License. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + + Args: + brightness (float | Sequence[float] (min, max)): How much to jitter + brightness. brightness_factor is chosen uniformly from + ``[max(0, 1 - brightness), 1 + brightness]`` or the given + ``[min, max]``. Should be non negative numbers. Defaults to 0. + contrast (float | Sequence[float] (min, max)): How much to jitter + contrast. contrast_factor is chosen uniformly from + ``[max(0, 1 - contrast), 1 + contrast]`` or the given + ``[min, max]``. Should be non negative numbers. Defaults to 0. + saturation (float | Sequence[float] (min, max)): How much to jitter + saturation. saturation_factor is chosen uniformly from + ``[max(0, 1 - saturation), 1 + saturation]`` or the given + ``[min, max]``. Should be non negative numbers. Defaults to 0. + hue (float | Sequence[float] (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from ``[-hue, hue]`` (0 <= hue + <= 0.5) or the given ``[min, max]`` (-0.5 <= min <= max <= 0.5). + Defaults to 0. + backend (str): The backend to operate the image. 
Defaults to 'pillow' + """ + + def __init__(self, + brightness: Union[float, Sequence[float]] = 0., + contrast: Union[float, Sequence[float]] = 0., + saturation: Union[float, Sequence[float]] = 0., + hue: Union[float, Sequence[float]] = 0., + backend='pillow'): + self.brightness = self._set_range(brightness, 'brightness') + self.contrast = self._set_range(contrast, 'contrast') + self.saturation = self._set_range(saturation, 'saturation') + self.hue = self._set_range(hue, 'hue', center=0, bound=(-0.5, 0.5)) + self.backend = backend + + def _set_range(self, value, name, center=1, bound=(0, float('inf'))): + """Set the range of magnitudes.""" + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError( + f'If {name} is a single number, it must be non negative.') + value = (center - float(value), center + float(value)) + + if isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + value = np.clip(value, bound[0], bound[1]) + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.warning(f'ColorJitter {name} values exceed the bound ' + f'{bound}, clipped to the bound.') + else: + raise TypeError(f'{name} should be a single number ' + 'or a list/tuple with length 2.') + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + else: + value = tuple(value) + + return value + + @cache_randomness + def _rand_params(self): + """Get random parameters including magnitudes and indices of + transforms.""" + trans_inds = np.random.permutation(4) + b, c, s, h = (None, ) * 4 + + if self.brightness is not None: + b = np.random.uniform(self.brightness[0], self.brightness[1]) + if self.contrast is not None: + c = np.random.uniform(self.contrast[0], self.contrast[1]) + if self.saturation is not None: + s = np.random.uniform(self.saturation[0], self.saturation[1]) + if self.hue is not None: + h = np.random.uniform(self.hue[0], self.hue[1]) + + return trans_inds, b, c, s, h + + def transform(self, results: Dict) -> Dict: + """Transform function to resize images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: ColorJitter results, 'img' key is updated in result dict. + """ + img = results['img'] + trans_inds, brightness, contrast, saturation, hue = self._rand_params() + + for index in trans_inds: + if index == 0 and brightness is not None: + img = mmcv.adjust_brightness( + img, brightness, backend=self.backend) + elif index == 1 and contrast is not None: + img = mmcv.adjust_contrast(img, contrast, backend=self.backend) + elif index == 2 and saturation is not None: + img = mmcv.adjust_color( + img, alpha=saturation, backend=self.backend) + elif index == 3 and hue is not None: + img = mmcv.adjust_hue(img, hue, backend=self.backend) + + results['img'] = img + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(brightness={self.brightness}, ' + repr_str += f'contrast={self.contrast}, ' + repr_str += f'saturation={self.saturation}, ' + repr_str += f'hue={self.hue})' + return repr_str + + +@TRANSFORMS.register_module() +class Lighting(BaseTransform): + """Adjust images lighting using AlexNet-style PCA jitter. 
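+
+    A random vector ``alpha`` is drawn from ``N(0, alphastd)``, and every
+    pixel is shifted by ``eigvec @ (alpha * eigval)``, following the PCA
+    color augmentation introduced with AlexNet.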
+
+    **Required Keys:**
+
+    - img
+
+    **Modified Keys:**
+
+    - img
+
+    Args:
+        eigval (Sequence[float]): the eigenvalues of the covariance matrix
+            of pixel values.
+        eigvec (list[list]): the eigenvectors of the covariance matrix of
+            pixel values.
+        alphastd (float): The standard deviation for distribution of alpha.
+            Defaults to 0.1.
+        to_rgb (bool): Whether to convert img to rgb. Defaults to False.
+    """
+
+    def __init__(self,
+                 eigval: Sequence[float],
+                 eigvec: Sequence[float],
+                 alphastd: float = 0.1,
+                 to_rgb: bool = False):
+        assert isinstance(eigval, Sequence), \
+            f'eigval must be Sequence, got {type(eigval)} instead.'
+        assert isinstance(eigvec, Sequence), \
+            f'eigvec must be Sequence, got {type(eigvec)} instead.'
+        for vec in eigvec:
+            assert isinstance(vec, Sequence) and len(vec) == len(eigvec[0]), \
+                'eigvec must contain lists with equal length.'
+        assert isinstance(alphastd, float), \
+            f'alphastd should be of type float, got {type(alphastd)} instead.'
+
+        self.eigval = np.array(eigval)
+        self.eigvec = np.array(eigvec)
+        self.alphastd = alphastd
+        self.to_rgb = to_rgb
+
+    def transform(self, results: Dict) -> Dict:
+        """Transform function to adjust image lighting.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Lighting-adjusted results, 'img' key is updated in result
+                dict.
+        """
+        assert 'img' in results, 'No `img` field in the input.'
+
+        img = results['img']
+        img_lighting = mmcv.adjust_lighting(
+            img,
+            self.eigval,
+            self.eigvec,
+            alphastd=self.alphastd,
+            to_rgb=self.to_rgb)
+        results['img'] = img_lighting.astype(img.dtype)
+        return results
+
+    def __repr__(self):
+        """Print the basic information of the transform.
+
+        Returns:
+            str: Formatted string.
+        """
+        repr_str = self.__class__.__name__
+        repr_str += f'(eigval={self.eigval.tolist()}, '
+        repr_str += f'eigvec={self.eigvec.tolist()}, '
+        repr_str += f'alphastd={self.alphastd}, '
+        repr_str += f'to_rgb={self.to_rgb})'
+        return repr_str
+
+
+# 'Albu' is used in previous versions of mmpretrain, here is for
+# compatibility; users can use both 'Albumentations' and 'Albu'.
+@TRANSFORMS.register_module(['Albumentations', 'Albu'])
+class Albumentations(BaseTransform):
+    """Wrapper to use augmentation from albumentations library.
+
+    **Required Keys:**
+
+    - img
+
+    **Modified Keys:**
+
+    - img
+    - img_shape
+
+    Adds custom transformations from albumentations library.
+    More details can be found in
+    `Albumentations <https://albumentations.ai>`_.
+    An example of ``transforms`` is as follows:
+
+    .. code-block::
+
+        [
+            dict(
+                type='ShiftScaleRotate',
+                shift_limit=0.0625,
+                scale_limit=0.0,
+                rotate_limit=0,
+                interpolation=1,
+                p=0.5),
+            dict(
+                type='RandomBrightnessContrast',
+                brightness_limit=[0.1, 0.3],
+                contrast_limit=[0.1, 0.3],
+                p=0.2),
+            dict(type='ChannelShuffle', p=0.1),
+            dict(
+                type='OneOf',
+                transforms=[
+                    dict(type='Blur', blur_limit=3, p=1.0),
+                    dict(type='MedianBlur', blur_limit=3, p=1.0)
+                ],
+                p=0.1),
+        ]
+
+    Args:
+        transforms (List[Dict]): List of albumentations transform configs.
+        keymap (Optional[Dict]): Mapping of mmpretrain to albumentations
+            fields, in format {'input key': 'albumentation-style key'}.
+            Defaults to None.
+
+    Example:
+        >>> import mmcv
+        >>> from mmpretrain.datasets import Albumentations
+        >>> transforms = [
+        ...     dict(
+        ...         type='ShiftScaleRotate',
+        ...         shift_limit=0.0625,
+        ...         scale_limit=0.0,
+        ...         rotate_limit=0,
+        ...         interpolation=1,
+        ...         p=0.5),
+        ...     dict(
+        ...         type='RandomBrightnessContrast',
+        ...
brightness_limit=[0.1, 0.3], + ... contrast_limit=[0.1, 0.3], + ... p=0.2), + ... dict(type='ChannelShuffle', p=0.1), + ... dict( + ... type='OneOf', + ... transforms=[ + ... dict(type='Blur', blur_limit=3, p=1.0), + ... dict(type='MedianBlur', blur_limit=3, p=1.0) + ... ], + ... p=0.1), + ... ] + >>> albu = Albumentations(transforms) + >>> data = {'img': mmcv.imread('./demo/demo.JPEG')} + >>> data = albu(data) + >>> print(data['img'].shape) + (375, 500, 3) + """ + + def __init__(self, transforms: List[Dict], keymap: Optional[Dict] = None): + if albumentations is None: + raise RuntimeError('albumentations is not installed') + else: + from albumentations import Compose as albu_Compose + + assert isinstance(transforms, list), 'transforms must be a list.' + if keymap is not None: + assert isinstance(keymap, dict), 'keymap must be None or a dict. ' + + self.transforms = transforms + + self.aug = albu_Compose( + [self.albu_builder(t) for t in self.transforms]) + + if not keymap: + self.keymap_to_albu = dict(img='image') + else: + self.keymap_to_albu = keymap + self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()} + + def albu_builder(self, cfg: Dict): + """Import a module from albumentations. + + It inherits some of :func:`build_from_cfg` logic. + Args: + cfg (dict): Config dict. It should at least contain the key "type". + Returns: + obj: The constructed object. + """ + + assert isinstance(cfg, dict) and 'type' in cfg, 'each item in ' \ + "transforms must be a dict with keyword 'type'." + args = cfg.copy() + + obj_type = args.pop('type') + if mmengine.is_str(obj_type): + obj_cls = getattr(albumentations, obj_type) + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError( + f'type must be a str or valid type, but got {type(obj_type)}') + + if 'transforms' in args: + args['transforms'] = [ + self.albu_builder(transform) + for transform in args['transforms'] + ] + + return obj_cls(**args) + + @staticmethod + def mapper(d, keymap): + """Dictionary mapper. + + Renames keys according to keymap provided. + Args: + d (dict): old dict + keymap (dict): {'old_key':'new_key'} + Returns: + dict: new dict. + """ + + updated_dict = {} + for k, v in zip(d.keys(), d.values()): + new_k = keymap.get(k, k) + updated_dict[new_k] = d[k] + return updated_dict + + def transform(self, results: Dict) -> Dict: + """Transform function to perform albumentations transforms. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Transformed results, 'img' and 'img_shape' keys are + updated in result dict. + """ + assert 'img' in results, 'No `img` field in the input.' + + # dict to albumentations format + results = self.mapper(results, self.keymap_to_albu) + results = self.aug(**results) + + # back to the original format + results = self.mapper(results, self.keymap_back) + results['img_shape'] = results['img'].shape[:2] + + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(transforms={repr(self.transforms)})' + return repr_str + + +@TRANSFORMS.register_module() +class SimMIMMaskGenerator(BaseTransform): + """Generate random block mask for each Image. + + **Added Keys**: + + - mask + + This module is used in SimMIM to generate masks. + + Args: + input_size (int): Size of input image. Defaults to 192. + mask_patch_size (int): Size of each block mask. Defaults to 32. + model_patch_size (int): Patch size of each token. 
            Defaults to 4.
+        mask_ratio (float): The mask ratio of image. Defaults to 0.6.
+    """
+
+    def __init__(self,
+                 input_size: int = 192,
+                 mask_patch_size: int = 32,
+                 model_patch_size: int = 4,
+                 mask_ratio: float = 0.6):
+        self.input_size = input_size
+        self.mask_patch_size = mask_patch_size
+        self.model_patch_size = model_patch_size
+        self.mask_ratio = mask_ratio
+
+        assert self.input_size % self.mask_patch_size == 0
+        assert self.mask_patch_size % self.model_patch_size == 0
+
+        self.rand_size = self.input_size // self.mask_patch_size
+        self.scale = self.mask_patch_size // self.model_patch_size
+
+        self.token_count = self.rand_size**2
+        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))
+
+    def transform(self, results: dict) -> dict:
+        """Method to generate random block mask for each image in SimMIM.
+
+        Args:
+            results (dict): Result dict from previous pipeline.
+
+        Returns:
+            dict: Result dict with added key ``mask``.
+        """
+        mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
+        mask = np.zeros(self.token_count, dtype=int)
+        mask[mask_idx] = 1
+
+        mask = mask.reshape((self.rand_size, self.rand_size))
+        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)
+
+        results.update({'mask': mask})
+
+        return results
+
+    def __repr__(self) -> str:
+        repr_str = self.__class__.__name__
+        repr_str += f'(input_size={self.input_size}, '
+        repr_str += f'mask_patch_size={self.mask_patch_size}, '
+        repr_str += f'model_patch_size={self.model_patch_size}, '
+        repr_str += f'mask_ratio={self.mask_ratio})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class BEiTMaskGenerator(BaseTransform):
+    """Generate mask for image.
+
+    **Added Keys**:
+
+    - mask
+
+    This module is borrowed from
+    https://github.com/microsoft/unilm/tree/master/beit
+
+    Args:
+        input_size (int): The size of input image.
+        num_masking_patches (int): The number of patches to be masked.
+        min_num_patches (int): The minimum number of patches to be masked
+            in the process of generating mask. Defaults to 4.
+        max_num_patches (int, optional): The maximum number of patches to be
+            masked in the process of generating mask. Defaults to None.
+        min_aspect (float): The minimum aspect ratio of mask blocks.
+            Defaults to 0.3.
+        max_aspect (float, optional): The maximum aspect ratio of mask
+            blocks. Defaults to None.
+    """
+
+    def __init__(self,
+                 input_size: int,
+                 num_masking_patches: int,
+                 min_num_patches: int = 4,
+                 max_num_patches: Optional[int] = None,
+                 min_aspect: float = 0.3,
+                 max_aspect: Optional[float] = None) -> None:
+        if not isinstance(input_size, tuple):
+            input_size = (input_size, ) * 2
+        self.height, self.width = input_size
+
+        self.num_patches = self.height * self.width
+
+        self.num_masking_patches = num_masking_patches
+        self.min_num_patches = min_num_patches
+        self.max_num_patches = num_masking_patches if max_num_patches is None \
+            else max_num_patches
+
+        max_aspect = max_aspect or 1 / min_aspect
+        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
+
+    def _mask(self, mask: np.ndarray, max_mask_patches: int) -> int:
+        """Generate one random block of the mask.
+
+        Args:
+            mask (np.ndarray): The mask to be generated.
+            max_mask_patches (int): The maximum number of patches to be
+                masked.
+
+        Returns:
+            int: The number of patches masked.
+ """ + delta = 0 + for _ in range(10): + target_area = np.random.uniform(self.min_num_patches, + max_mask_patches) + aspect_ratio = math.exp(np.random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < self.width and h < self.height: + top = np.random.randint(0, self.height - h) + left = np.random.randint(0, self.width - w) + + num_masked = mask[top:top + h, left:left + w].sum() + # Overlap + if 0 < h * w - num_masked <= max_mask_patches: + for i in range(top, top + h): + for j in range(left, left + w): + if mask[i, j] == 0: + mask[i, j] = 1 + delta += 1 + if delta > 0: + break + return delta + + def transform(self, results: dict) -> dict: + """Method to generate random block mask for each Image in BEiT. + + Args: + results (dict): Result dict from previous pipeline. + + Returns: + dict: Result dict with added key ``mask``. + """ + mask = np.zeros(shape=(self.height, self.width), dtype=int) + + mask_count = 0 + while mask_count != self.num_masking_patches: + max_mask_patches = self.num_masking_patches - mask_count + max_mask_patches = min(max_mask_patches, self.max_num_patches) + + delta = self._mask(mask, max_mask_patches) + mask_count += delta + results.update({'mask': mask}) + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(height={self.height}, ' + repr_str += f'width={self.width}, ' + repr_str += f'num_patches={self.num_patches}, ' + repr_str += f'num_masking_patches={self.num_masking_patches}, ' + repr_str += f'min_num_patches={self.min_num_patches}, ' + repr_str += f'max_num_patches={self.max_num_patches}, ' + repr_str += f'log_aspect_ratio={self.log_aspect_ratio})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomResizedCropAndInterpolationWithTwoPic(BaseTransform): + """Crop the given PIL Image to random size and aspect ratio with random + interpolation. + + **Required Keys**: + + - img + + **Modified Keys**: + + - img + + **Added Keys**: + + - target_img + + This module is borrowed from + https://github.com/microsoft/unilm/tree/master/beit. + + A crop of random size (default: of 0.08 to 1.0) of the original size and a + random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio + is made. This crop is finally resized to given size. This is popularly used + to train the Inception networks. This module first crops the image and + resizes the crop to two different sizes. + + Args: + size (Union[tuple, int]): Expected output size of each edge of the + first image. + second_size (Union[tuple, int], optional): Expected output size of each + edge of the second image. + scale (tuple[float, float]): Range of size of the origin size cropped. + Defaults to (0.08, 1.0). + ratio (tuple[float, float]): Range of aspect ratio of the origin aspect + ratio cropped. Defaults to (3./4., 4./3.). + interpolation (str): The interpolation for the first image. Defaults + to ``bilinear``. + second_interpolation (str): The interpolation for the second image. + Defaults to ``lanczos``. + """ + + def __init__(self, + size: Union[tuple, int], + second_size=None, + scale=(0.08, 1.0), + ratio=(3. / 4., 4. 
                 / 3.),
+                 interpolation='bilinear',
+                 second_interpolation='lanczos') -> None:
+        if isinstance(size, tuple):
+            self.size = size
+        else:
+            self.size = (size, size)
+        if second_size is not None:
+            if isinstance(second_size, tuple):
+                self.second_size = second_size
+            else:
+                self.second_size = (second_size, second_size)
+        else:
+            self.second_size = None
+        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+            raise ValueError('range should be of kind (min, max)')
+
+        if interpolation == 'random':
+            self.interpolation = ('bilinear', 'bicubic')
+        else:
+            self.interpolation = interpolation
+        self.second_interpolation = second_interpolation
+        self.scale = scale
+        self.ratio = ratio
+
+    @staticmethod
+    def get_params(img: np.ndarray, scale: tuple,
+                   ratio: tuple) -> Sequence[int]:
+        """Get parameters for ``crop`` for a random sized crop.
+
+        Args:
+            img (np.ndarray): Image to be cropped.
+            scale (tuple): range of size of the origin size cropped
+            ratio (tuple): range of aspect ratio of the origin aspect
+                ratio cropped
+
+        Returns:
+            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+                sized crop.
+        """
+        img_h, img_w = img.shape[:2]
+        area = img_h * img_w
+
+        for _ in range(10):
+            target_area = np.random.uniform(*scale) * area
+            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+            aspect_ratio = math.exp(np.random.uniform(*log_ratio))
+
+            w = int(round(math.sqrt(target_area * aspect_ratio)))
+            h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+            if w < img_w and h < img_h:
+                i = np.random.randint(0, img_h - h)
+                j = np.random.randint(0, img_w - w)
+                return i, j, h, w
+
+        # Fallback to central crop
+        in_ratio = img_w / img_h
+        if in_ratio < min(ratio):
+            w = img_w
+            h = int(round(w / min(ratio)))
+        elif in_ratio > max(ratio):
+            h = img_h
+            w = int(round(h * max(ratio)))
+        else:  # whole image
+            w = img_w
+            h = img_h
+        i = (img_h - h) // 2
+        j = (img_w - w) // 2
+        return i, j, h, w
+
+    def transform(self, results: dict) -> dict:
+        """Crop the given image and resize it to two different sizes.
+
+        This module crops the given image randomly and resizes the crop to
+        two different sizes. This is popularly used in BEiT-style masked
+        image modeling, where an off-the-shelf model is used to provide the
+        target.
+
+        Args:
+            results (dict): Results from previous pipeline.
+
+        Returns:
+            dict: Results after applying this transformation.
+        """
+        img = results['img']
+        i, j, h, w = self.get_params(img, self.scale, self.ratio)
+        if isinstance(self.interpolation, (tuple, list)):
+            interpolation = np.random.choice(self.interpolation)
+        else:
+            interpolation = self.interpolation
+        if self.second_size is None:
+            img = img[i:i + h, j:j + w]
+            img = mmcv.imresize(img, self.size, interpolation=interpolation)
+            results.update({'img': img})
+        else:
+            img = img[i:i + h, j:j + w]
+            img_sample = mmcv.imresize(
+                img, self.size, interpolation=interpolation)
+            img_target = mmcv.imresize(
+                img, self.second_size,
+                interpolation=self.second_interpolation)
+            results.update({'img': [img_sample, img_target]})
+        return results
+
+    def __repr__(self) -> str:
+        repr_str = self.__class__.__name__
+        repr_str += f'(size={self.size}, '
+        repr_str += f'second_size={self.second_size}, '
+        repr_str += f'interpolation={self.interpolation}, '
+        repr_str += f'second_interpolation={self.second_interpolation}, '
+        repr_str += f'scale={self.scale}, '
+        repr_str += f'ratio={self.ratio})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class CleanCaption(BaseTransform):
+    """Clean caption text.
+
+    Remove some useless punctuation for the caption task.
+
+    **Required Keys:**
+
+    - ``*keys``
+
+    **Modified Keys:**
+
+    - ``*keys``
+
+    Args:
+        keys (Sequence[str], optional): The keys of text to be cleaned.
+            Defaults to 'gt_caption'.
+        remove_chars (str): The characters to be removed. Defaults to
+            :py:attr:`string.punctuation`.
+        lowercase (bool): Whether to convert the text to lowercase.
+            Defaults to True.
+        remove_dup_space (bool): Whether to remove duplicated whitespaces.
+            Defaults to True.
+        strip (bool): Whether to remove leading and trailing whitespaces.
+            Defaults to True.
+    """
+
+    def __init__(
+        self,
+        keys='gt_caption',
+        remove_chars=string.punctuation,
+        lowercase=True,
+        remove_dup_space=True,
+        strip=True,
+    ):
+        if isinstance(keys, str):
+            keys = [keys]
+        self.keys = keys
+        self.transtab = str.maketrans({ch: None for ch in remove_chars})
+        self.lowercase = lowercase
+        self.remove_dup_space = remove_dup_space
+        self.strip = strip
+
+    def _clean(self, text):
+        """Perform text cleaning before tokenizer."""
+
+        if self.strip:
+            text = text.strip()
+
+        text = text.translate(self.transtab)
+
+        if self.remove_dup_space:
+            text = re.sub(r'\s{2,}', ' ', text)
+
+        if self.lowercase:
+            text = text.lower()
+
+        return text
+
+    def clean(self, text):
+        """Perform text cleaning before tokenizer."""
+        if isinstance(text, (list, tuple)):
+            return [self._clean(item) for item in text]
+        elif isinstance(text, str):
+            return self._clean(text)
+        else:
+            raise TypeError('text must be a string or a list of strings')
+
+    def transform(self, results: dict) -> dict:
+        """Method to clean the input text data."""
+        for key in self.keys:
+            results[key] = self.clean(results[key])
+        return results
+
+
+@TRANSFORMS.register_module()
+class OFAAddObjects(BaseTransform):
+
+    def transform(self, results: dict) -> dict:
+        if 'objects' not in results:
+            raise ValueError(
+                'Some OFA fine-tuned models require `objects` field in the '
+                'dataset, which is generated by VinVL. Or please use '
+                'zero-shot configs. See '
+                'https://github.com/OFA-Sys/OFA/issues/189')
+
+        if 'question' in results:
+            prompt = '{} object: {}'.format(
+                results['question'],
+                ' '.join(results['objects']),
+            )
+            results['decoder_prompt'] = prompt
+            results['question'] = prompt
+        # BaseTransform expects the results dict to be returned; without
+        # this, downstream transforms would receive None.
+        return results
+
+
+@TRANSFORMS.register_module()
+class RandomTranslatePad(BaseTransform):
+
+    def __init__(self, size=640, aug_translate=False):
+        self.size = size
+        self.aug_translate = aug_translate
+
+    @cache_randomness
+    def rand_translate_params(self, dh, dw):
+        top = np.random.randint(0, dh)
+        left = np.random.randint(0, dw)
+        return top, left
+
+    def transform(self, results: dict) -> dict:
+        img = results['img']
+        h, w = img.shape[:-1]
+        dw = self.size - w
+        dh = self.size - h
+        if self.aug_translate:
+            top, left = self.rand_translate_params(dh, dw)
+        else:
+            top = round(dh / 2.0 - 0.1)
+            left = round(dw / 2.0 - 0.1)
+
+        out_img = np.zeros((self.size, self.size, 3), dtype=np.float32)
+        out_img[top:top + h, left:left + w, :] = img
+        results['img'] = out_img
+        results['img_shape'] = (self.size, self.size)
+
+        # translate box
+        if 'gt_bboxes' in results.keys():
+            for i in range(len(results['gt_bboxes'])):
+                box = results['gt_bboxes'][i]
+                box[0], box[2] = box[0] + left, box[2] + left
+                box[1], box[3] = box[1] + top, box[3] + top
+                results['gt_bboxes'][i] = box
+
+        return results
+
+
+@TRANSFORMS.register_module()
+class MAERandomResizedCrop(transforms.RandomResizedCrop):
+    """RandomResizedCrop for matching TF/TPU implementation: no for-loop is
+    used.
+ + This may lead to results different with torchvision's version. + Following BYOL's TF code: + https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 # noqa: E501 + """ + + @staticmethod + def get_params(img: Image.Image, scale: tuple, ratio: tuple) -> Tuple: + width, height = img.size + area = height * width + + target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() + log_ratio = torch.log(torch.tensor(ratio)) + aspect_ratio = torch.exp( + torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item() + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + w = min(w, width) + h = min(h, height) + + i = torch.randint(0, height - h + 1, size=(1, )).item() + j = torch.randint(0, width - w + 1, size=(1, )).item() + + return i, j, h, w + + def forward(self, results: dict) -> dict: + """The forward function of MAERandomResizedCrop. + + Args: + results (dict): The results dict contains the image and all these + information related to the image. + + Returns: + dict: The results dict contains the cropped image and all these + information related to the image. + """ + img = results['img'] + i, j, h, w = self.get_params(img, self.scale, self.ratio) + img = F.resized_crop(img, i, j, h, w, self.size, self.interpolation) + results['img'] = img + return results diff --git a/mmpretrain/datasets/transforms/utils.py b/mmpretrain/datasets/transforms/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d7940486fc9904c14f5a5a4a959022c11456c968 --- /dev/null +++ b/mmpretrain/datasets/transforms/utils.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import List, Union + +from mmcv.transforms import BaseTransform + +PIPELINE_TYPE = List[Union[dict, BaseTransform]] + + +def get_transform_idx(pipeline: PIPELINE_TYPE, target: str) -> int: + """Returns the index of the transform in a pipeline. + + Args: + pipeline (List[dict] | List[BaseTransform]): The transforms list. + target (str): The target transform class name. + + Returns: + int: The transform index. Returns -1 if not found. + """ + for i, transform in enumerate(pipeline): + if isinstance(transform, dict): + if isinstance(transform['type'], type): + if transform['type'].__name__ == target: + return i + else: + if transform['type'] == target: + return i + else: + if transform.__class__.__name__ == target: + return i + + return -1 + + +def remove_transform(pipeline: PIPELINE_TYPE, target: str, inplace=False): + """Remove the target transform type from the pipeline. + + Args: + pipeline (List[dict] | List[BaseTransform]): The transforms list. + target (str): The target transform class name. + inplace (bool): Whether to modify the pipeline inplace. + + Returns: + The modified transform. + """ + idx = get_transform_idx(pipeline, target) + if not inplace: + pipeline = copy.deepcopy(pipeline) + while idx >= 0: + pipeline.pop(idx) + idx = get_transform_idx(pipeline, target) + + return pipeline diff --git a/mmpretrain/datasets/transforms/wrappers.py b/mmpretrain/datasets/transforms/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..c0dfd730b4db0dc80ed315b79658cfbf683e4035 --- /dev/null +++ b/mmpretrain/datasets/transforms/wrappers.py @@ -0,0 +1,144 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
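+# This module provides transform wrappers: ``MultiView`` repeats an image
+# through several augmentation pipelines (as used by contrastive-learning
+# methods), and ``ApplyToList`` applies one pipeline to every item of a
+# list field.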
+import copy
+from typing import Callable, List, Union
+
+from mmcv.transforms import BaseTransform, Compose
+
+from mmpretrain.registry import TRANSFORMS
+
+# Define type of transform or transform config
+Transform = Union[dict, Callable[[dict], dict]]
+
+
+@TRANSFORMS.register_module()
+class MultiView(BaseTransform):
+    """A transform wrapper for multiple views of an image.
+
+    Args:
+        transforms (list[list[dict | callable]]): The pipelines to be
+            wrapped; each inner list of transform objects or config dicts
+            is composed into one pipeline.
+        num_views (int | list[int]): The number of views to generate with
+            each pipeline. Must have the same length as ``transforms``;
+            an int is treated as a single-element list.
+
+    Examples:
+        >>> # Example 1: MultiView with 1 pipeline and 2 views
+        >>> pipeline = [
+        >>>     dict(type='MultiView',
+        >>>         num_views=2,
+        >>>         transforms=[
+        >>>             [
+        >>>                dict(type='Resize', scale=224)],
+        >>>         ])
+        >>> ]
+        >>> # Example 2: MultiView with 2 pipelines, the first with 2 views,
+        >>> # the second with 6 views
+        >>> pipeline = [
+        >>>     dict(type='MultiView',
+        >>>         num_views=[2, 6],
+        >>>         transforms=[
+        >>>             [
+        >>>                dict(type='Resize', scale=224)],
+        >>>             [
+        >>>                dict(type='Resize', scale=224),
+        >>>                dict(type='RandomSolarize')],
+        >>>         ])
+        >>> ]
+    """
+
+    def __init__(self, transforms: List[List[Transform]],
+                 num_views: Union[int, List[int]]) -> None:
+
+        if isinstance(num_views, int):
+            num_views = [num_views]
+        assert isinstance(num_views, List)
+        assert len(num_views) == len(transforms)
+        self.num_views = num_views
+
+        self.pipelines = []
+        for trans in transforms:
+            pipeline = Compose(trans)
+            self.pipelines.append(pipeline)
+
+        self.transforms = []
+        for i in range(len(num_views)):
+            self.transforms.extend([self.pipelines[i]] * num_views[i])
+
+    def transform(self, results: dict) -> dict:
+        """Apply transformation to inputs.
+
+        Args:
+            results (dict): Result dict from previous pipelines.
+
+        Returns:
+            dict: Transformed results.
+        """
+        multi_views_outputs = dict(img=[])
+        for trans in self.transforms:
+            inputs = copy.deepcopy(results)
+            outputs = trans(inputs)
+
+            multi_views_outputs['img'].append(outputs['img'])
+        results.update(multi_views_outputs)
+        return results
+
+    def __repr__(self) -> str:
+        repr_str = self.__class__.__name__ + '('
+        for i, p in enumerate(self.pipelines):
+            repr_str += f'\nPipeline {i + 1} with {self.num_views[i]} views:\n'
+            repr_str += str(p)
+        repr_str += ')'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class ApplyToList(BaseTransform):
+    """A transform wrapper to apply the wrapped transforms to a list of
+    items. For example, to load and resize a list of images.
+
+    Args:
+        transforms (list[dict | callable]): Sequence of transform config
+            dicts to be wrapped.
+        scatter_key (str): The key to scatter data dict. If the field is a
+            list, scatter the list to multiple data dicts to do
+            transformation.
+        collate_keys (List[str]): The keys to collate from multiple data
+            dicts. The fields in ``collate_keys`` will be composed into a
+            list after transformation, and the other fields will be adopted
+            from the first data dict.
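+
+    Example:
+        >>> # A sketch; the key names below are illustrative.
+        >>> pipeline = ApplyToList(
+        ...     transforms=[dict(type='LoadImageFromFile'),
+        ...                 dict(type='Resize', scale=(224, 224))],
+        ...     scatter_key='img_path',
+        ...     collate_keys=['img', 'img_shape'])
+        >>> results = pipeline({'img_path': ['a.jpg', 'b.jpg']})
+        >>> # results['img'] is a list with one entry per input path.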
+ """ + + def __init__(self, transforms, scatter_key, collate_keys): + super().__init__() + + self.transforms = Compose([TRANSFORMS.build(t) for t in transforms]) + self.scatter_key = scatter_key + self.collate_keys = set(collate_keys) + self.collate_keys.add(self.scatter_key) + + def transform(self, results: dict): + scatter_field = results.get(self.scatter_key) + + if isinstance(scatter_field, list): + scattered_results = [] + for item in scatter_field: + single_results = copy.deepcopy(results) + single_results[self.scatter_key] = item + scattered_results.append(self.transforms(single_results)) + + final_output = scattered_results[0] + + # merge output list to single output + for key in scattered_results[0].keys(): + if key in self.collate_keys: + final_output[key] = [ + single[key] for single in scattered_results + ] + return final_output + else: + return self.transforms(results) diff --git a/mmpretrain/datasets/utils.py b/mmpretrain/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb60e432c374c1a904700a7348f706fa0e523eb --- /dev/null +++ b/mmpretrain/datasets/utils.py @@ -0,0 +1,243 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import gzip +import hashlib +import os +import os.path +import shutil +import tarfile +import tempfile +import urllib.error +import urllib.request +import zipfile + +from mmengine.fileio import LocalBackend, get_file_backend + +__all__ = [ + 'rm_suffix', 'check_integrity', 'download_and_extract_archive', + 'open_maybe_compressed_file' +] + + +def rm_suffix(s, suffix=None): + if suffix is None: + return s[:s.rfind('.')] + else: + return s[:s.rfind(suffix)] + + +def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024): + md5 = hashlib.md5() + backend = get_file_backend(fpath, enable_singleton=True) + if isinstance(backend, LocalBackend): + # Enable chunk update for local file. + with open(fpath, 'rb') as f: + for chunk in iter(lambda: f.read(chunk_size), b''): + md5.update(chunk) + else: + md5.update(backend.get(fpath)) + return md5.hexdigest() + + +def check_md5(fpath, md5, **kwargs): + return md5 == calculate_md5(fpath, **kwargs) + + +def check_integrity(fpath, md5=None): + if not os.path.isfile(fpath): + return False + if md5 is None: + return True + return check_md5(fpath, md5) + + +def download_url_to_file(url, dst, hash_prefix=None, progress=True): + """Download object at the given URL to a local path. + + Modified from + https://pytorch.org/docs/stable/hub.html#torch.hub.download_url_to_file + + Args: + url (str): URL of the object to download + dst (str): Full path where object will be saved, + e.g. ``/tmp/temporary_file`` + hash_prefix (string, optional): If not None, the SHA256 downloaded + file should start with ``hash_prefix``. Defaults to None. + progress (bool): whether or not to display a progress bar to stderr. + Defaults to True + """ + file_size = None + req = urllib.request.Request(url) + u = urllib.request.urlopen(req) + meta = u.info() + if hasattr(meta, 'getheaders'): + content_length = meta.getheaders('Content-Length') + else: + content_length = meta.get_all('Content-Length') + if content_length is not None and len(content_length) > 0: + file_size = int(content_length[0]) + + # We deliberately save it in a temp file and move it after download is + # complete. This prevents a local file being overridden by a broken + # download. 
+ dst = os.path.expanduser(dst) + dst_dir = os.path.dirname(dst) + f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) + + import rich.progress + columns = [ + rich.progress.DownloadColumn(), + rich.progress.BarColumn(bar_width=None), + rich.progress.TimeRemainingColumn(), + ] + try: + if hash_prefix is not None: + sha256 = hashlib.sha256() + with rich.progress.Progress(*columns) as pbar: + task = pbar.add_task('download', total=file_size, visible=progress) + while True: + buffer = u.read(8192) + if len(buffer) == 0: + break + f.write(buffer) + if hash_prefix is not None: + sha256.update(buffer) + pbar.update(task, advance=len(buffer)) + + f.close() + if hash_prefix is not None: + digest = sha256.hexdigest() + if digest[:len(hash_prefix)] != hash_prefix: + raise RuntimeError( + 'invalid hash value (expected "{}", got "{}")'.format( + hash_prefix, digest)) + shutil.move(f.name, dst) + finally: + f.close() + if os.path.exists(f.name): + os.remove(f.name) + + +def download_url(url, root, filename=None, md5=None): + """Download a file from a url and place it in root. + + Args: + url (str): URL to download file from. + root (str): Directory to place downloaded file in. + filename (str | None): Name to save the file under. + If filename is None, use the basename of the URL. + md5 (str | None): MD5 checksum of the download. + If md5 is None, download without md5 check. + """ + root = os.path.expanduser(root) + if not filename: + filename = os.path.basename(url) + fpath = os.path.join(root, filename) + + os.makedirs(root, exist_ok=True) + + if check_integrity(fpath, md5): + print(f'Using downloaded and verified file: {fpath}') + else: + try: + print(f'Downloading {url} to {fpath}') + download_url_to_file(url, fpath) + except (urllib.error.URLError, IOError) as e: + if url[:5] == 'https': + url = url.replace('https:', 'http:') + print('Failed download. Trying https -> http instead.' 
+                      f' Downloading {url} to {fpath}')
+                download_url_to_file(url, fpath)
+            else:
+                raise e
+        # check integrity of downloaded file
+        if not check_integrity(fpath, md5):
+            raise RuntimeError('File not found or corrupted.')
+
+
+def _is_tarxz(filename):
+    return filename.endswith('.tar.xz')
+
+
+def _is_tar(filename):
+    return filename.endswith('.tar')
+
+
+def _is_targz(filename):
+    return filename.endswith('.tar.gz')
+
+
+def _is_tgz(filename):
+    return filename.endswith('.tgz')
+
+
+def _is_gzip(filename):
+    return filename.endswith('.gz') and not filename.endswith('.tar.gz')
+
+
+def _is_zip(filename):
+    return filename.endswith('.zip')
+
+
+def extract_archive(from_path, to_path=None, remove_finished=False):
+    if to_path is None:
+        to_path = os.path.dirname(from_path)
+
+    if _is_tar(from_path):
+        with tarfile.open(from_path, 'r') as tar:
+            tar.extractall(path=to_path)
+    elif _is_targz(from_path) or _is_tgz(from_path):
+        with tarfile.open(from_path, 'r:gz') as tar:
+            tar.extractall(path=to_path)
+    elif _is_tarxz(from_path):
+        with tarfile.open(from_path, 'r:xz') as tar:
+            tar.extractall(path=to_path)
+    elif _is_gzip(from_path):
+        to_path = os.path.join(
+            to_path,
+            os.path.splitext(os.path.basename(from_path))[0])
+        with open(to_path, 'wb') as out_f, gzip.GzipFile(from_path) as zip_f:
+            out_f.write(zip_f.read())
+    elif _is_zip(from_path):
+        with zipfile.ZipFile(from_path, 'r') as z:
+            z.extractall(to_path)
+    else:
+        raise ValueError(f'Extraction of {from_path} not supported')
+
+    if remove_finished:
+        os.remove(from_path)
+
+
+def download_and_extract_archive(url,
+                                 download_root,
+                                 extract_root=None,
+                                 filename=None,
+                                 md5=None,
+                                 remove_finished=False):
+    download_root = os.path.expanduser(download_root)
+    if extract_root is None:
+        extract_root = download_root
+    if not filename:
+        filename = os.path.basename(url)
+
+    download_url(url, download_root, filename, md5)
+
+    archive = os.path.join(download_root, filename)
+    print(f'Extracting {archive} to {extract_root}')
+    extract_archive(archive, extract_root, remove_finished)
+
+
+def open_maybe_compressed_file(path: str):
+    """Return a file object that possibly decompresses 'path' on the fly.
+
+    Decompression occurs when argument `path` is a string and ends with
+    '.gz' or '.xz'.
+    """
+    if not isinstance(path, str):
+        return path
+    if path.endswith('.gz'):
+        import gzip
+        return gzip.open(path, 'rb')
+    if path.endswith('.xz'):
+        import lzma
+        return lzma.open(path, 'rb')
+    return open(path, 'rb')
diff --git a/mmpretrain/datasets/vg_vqa.py b/mmpretrain/datasets/vg_vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d83884c804086c060bcfe27e833bff28dc28e9e
--- /dev/null
+++ b/mmpretrain/datasets/vg_vqa.py
@@ -0,0 +1,77 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+from mmengine.fileio import load
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+
+
+@DATASETS.register_module()
+class VGVQA(BaseDataset):
+    """Visual Genome VQA dataset."""
+
+    def load_data_list(self) -> List[dict]:
+        """Load data list.
+
+        Compared to BaseDataset, the only difference is that the vg_vqa
+        annotation file is already a list of data. There is no 'metainfo'.
+        """
+
+        raw_data_list = load(self.ann_file)
+        if not isinstance(raw_data_list, list):
+            raise TypeError(
+                f'The VQA annotations loaded from annotation file '
+                f'should be a list, but got {type(raw_data_list)}!')
+
+        # load and parse data_infos.
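+        # 'vqa'-style records receive soft answer weights proportional to
+        # answer frequency; 'vg' records keep one answer with a fixed
+        # weight of 0.2 (see the branches below).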
+ data_list = [] + for raw_data_info in raw_data_list: + # parse raw data information to target format + data_info = self.parse_data_info(raw_data_info) + if isinstance(data_info, dict): + # For VQA tasks, each `data_info` looks like: + # { + # "question_id": 986769, + # "question": "How many people are there?", + # "answer": "two", + # "image": "image/1.jpg", + # "dataset": "vg" + # } + + # change 'image' key to 'img_path' + # TODO: This process will be removed, after the annotation file + # is preprocess. + data_info['img_path'] = data_info['image'] + del data_info['image'] + + if 'answer' in data_info: + # add answer_weight & answer_count, delete duplicate answer + if data_info['dataset'] == 'vqa': + answer_weight = {} + for answer in data_info['answer']: + if answer in answer_weight.keys(): + answer_weight[answer] += 1 / len( + data_info['answer']) + else: + answer_weight[answer] = 1 / len( + data_info['answer']) + + data_info['answer'] = list(answer_weight.keys()) + data_info['answer_weight'] = list( + answer_weight.values()) + data_info['answer_count'] = len(answer_weight) + + elif data_info['dataset'] == 'vg': + data_info['answers'] = [data_info['answer']] + data_info['answer_weight'] = [0.2] + data_info['answer_count'] = 1 + + data_list.append(data_info) + + else: + raise TypeError( + f'Each VQA data element loaded from annotation file ' + f'should be a dict, but got {type(data_info)}!') + + return data_list diff --git a/mmpretrain/datasets/visual_genome.py b/mmpretrain/datasets/visual_genome.py new file mode 100644 index 0000000000000000000000000000000000000000..8c33b86c4f81d0be0f2830618ad100196b461dcf --- /dev/null +++ b/mmpretrain/datasets/visual_genome.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from itertools import chain +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class VisualGenomeQA(BaseDataset): + """Visual Genome Question Answering dataset. + + dataset structure: :: + + data_root + ├── image + │   ├── 1.jpg + │   ├── 2.jpg + │   └── ... + └── question_answers.json + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. Defaults to ``"image"``. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to ``"question_answers.json"``. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + data_prefix: str = 'image', + ann_file: str = 'question_answers.json', + **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def _create_image_index(self): + img_prefix = self.data_prefix['img_path'] + + files = mmengine.list_dir_or_file(img_prefix, list_dir=False) + image_index = {} + for file in files: + image_id = re.findall(r'\d+', file) + if len(image_id) > 0: + image_id = int(image_id[-1]) + image_index[image_id] = mmengine.join_path(img_prefix, file) + + return image_index + + def load_data_list(self) -> List[dict]: + """Load data list.""" + annotations = mmengine.load(self.ann_file) + + # The original Visual Genome annotation file and question file includes + # only image id but no image file paths. 
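+        # Map image ids to file paths, e.g. {1: 'image/1.jpg', ...}.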
+        self.image_index = self._create_image_index()
+
+        data_list = []
+        for qas in chain.from_iterable(ann['qas'] for ann in annotations):
+            # ann example
+            # {
+            #     'id': 1,
+            #     'qas': [
+            #         {
+            #             'a_objects': [],
+            #             'question': 'What color is the clock?',
+            #             'image_id': 1,
+            #             'qa_id': 986768,
+            #             'answer': 'Two.',
+            #             'q_objects': [],
+            #         }
+            #         ...
+            #     ]
+            # }
+
+            data_info = {
+                'img_path': self.image_index[qas['image_id']],
+                'question': qas['question'],
+                'question_id': qas['qa_id'],
+                'image_id': qas['image_id'],
+                'gt_answer': [qas['answer']],
+            }
+
+            data_list.append(data_info)
+
+        return data_list
diff --git a/mmpretrain/datasets/vizwiz.py b/mmpretrain/datasets/vizwiz.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b5dd394524cac5ad514351ac2a93286c75e1b17
--- /dev/null
+++ b/mmpretrain/datasets/vizwiz.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import Counter
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class VizWiz(BaseDataset):
+    """VizWiz dataset.
+
+    Args:
+        data_root (str): The root directory for ``data_prefix``, ``ann_file``
+            and ``question_file``.
+        data_prefix (str): The directory of images.
+        ann_file (str, optional): Annotation file path for training and
+            validation. Defaults to an empty string.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
+    """
+
+    def __init__(self,
+                 data_root: str,
+                 data_prefix: str,
+                 ann_file: str = '',
+                 **kwarg):
+        super().__init__(
+            data_root=data_root,
+            data_prefix=dict(img_path=data_prefix),
+            ann_file=ann_file,
+            **kwarg,
+        )
+
+    def load_data_list(self) -> List[dict]:
+        """Load data list."""
+        annotations = mmengine.load(self.ann_file)
+
+        data_list = []
+        for ann in annotations:
+            # {
+            #     "image": "VizWiz_val_00000001.jpg",
+            #     "question": "Can you tell me what this medicine is please?",
+            #     "answers": [
+            #         {"answer": "no", "answer_confidence": "yes"},
+            #         {"answer": "unanswerable", "answer_confidence": "yes"},
+            #         {"answer": "night time", "answer_confidence": "maybe"},
+            #         {"answer": "unanswerable", "answer_confidence": "yes"},
+            #         {"answer": "night time", "answer_confidence": "maybe"},
+            #         {"answer": "night time cold medicine",
+            #          "answer_confidence": "maybe"},
+            #         {"answer": "night time", "answer_confidence": "maybe"},
+            #         {"answer": "night time", "answer_confidence": "maybe"},
+            #         {"answer": "night time", "answer_confidence": "maybe"},
+            #         {"answer": "night time medicine",
+            #          "answer_confidence": "yes"}
+            #     ],
+            #     "answer_type": "other",
+            #     "answerable": 1
+            # },
+            data_info = dict()
+            data_info['question'] = ann['question']
+            data_info['img_path'] = mmengine.join_path(
+                self.data_prefix['img_path'], ann['image'])
+
+            if 'answerable' not in ann:
+                data_list.append(data_info)
+            else:
+                if ann['answerable'] == 1:
+                    # add answer_weight & answer_count,
+                    # delete duplicate answer
+                    answers = []
+                    for item in ann.pop('answers'):
+                        if item['answer_confidence'] == 'yes' and item[
+                                'answer'] != 'unanswerable':
+                            answers.append(item['answer'])
+                    count = Counter(answers)
+                    answer_weight = [i / len(answers) for i in count.values()]
+                    data_info['gt_answer'] = list(count.keys())
+                    data_info['gt_answer_weight'] = answer_weight
+                    # data_info.update(ann)
+                    data_list.append(data_info)
+
+        return data_list
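+
+
+# A usage sketch (paths below are illustrative, not shipped defaults):
+#   dataset = VizWiz(data_root='data/vizwiz', data_prefix='val',
+#                    ann_file='Annotations/val.json')
+#   dataset[0]['question'], dataset[0].get('gt_answer')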
diff --git a/mmpretrain/datasets/voc.py b/mmpretrain/datasets/voc.py
new file mode 100644
index 0000000000000000000000000000000000000000..39544de7a1794a2d965189c692f652cc56b218f9
--- /dev/null
+++ b/mmpretrain/datasets/voc.py
@@ -0,0 +1,195 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import xml.etree.ElementTree as ET
+from typing import List, Optional, Union
+
+from mmengine import get_file_backend, list_from_file
+from mmengine.logging import MMLogger
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import expanduser
+from .categories import VOC2007_CATEGORIES
+from .multi_label import MultiLabelDataset
+
+
+@DATASETS.register_module()
+class VOC(MultiLabelDataset):
+    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset.
+
+    After decompression, the dataset directory structure is as follows:
+
+    VOC dataset directory: ::
+
+        VOC2007
+        ├── JPEGImages
+        │   ├── xxx.jpg
+        │   ├── xxy.jpg
+        │   └── ...
+        ├── Annotations
+        │   ├── xxx.xml
+        │   ├── xxy.xml
+        │   └── ...
+        └── ImageSets
+            └── Main
+                ├── train.txt
+                ├── val.txt
+                ├── trainval.txt
+                ├── test.txt
+                └── ...
+
+    An extra difficult label is in VOC annotations; we use
+    ``gt_label_difficult`` to record the difficult labels in each sample,
+    and the corresponding evaluation should take care of this field when
+    calculating metrics. Usually, difficult labels are reckoned as negative
+    by default.
+
+    Args:
+        data_root (str): The root directory for VOC dataset.
+        split (str, optional): The dataset split, supports "train",
+            "val", "trainval", and "test". Defaults to "trainval".
+        image_set_path (str, optional): The path of the image set file,
+            which lists image ids of the sub dataset; this path is relative
+            to ``data_root``. Defaults to ''.
+        data_prefix (dict): Prefix for data and annotation, keywords
+            'img_path' and 'ann_path' can be set. Defaults to
+            ``dict(img_path='JPEGImages', ann_path='Annotations')``.
+        metainfo (dict, optional): Meta information for dataset, such as
+            categories information. Defaults to None.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
+
+    Examples:
+        >>> from mmpretrain.datasets import VOC
+        >>> train_dataset = VOC(data_root='data/VOC2007', split='trainval')
+        >>> train_dataset
+        Dataset VOC
+        Number of samples: 5011
+        Number of categories: 20
+        Prefix of dataset: data/VOC2007
+        Path of image set: data/VOC2007/ImageSets/Main/trainval.txt
+        Prefix of images: data/VOC2007/JPEGImages
+        Prefix of annotations: data/VOC2007/Annotations
+        >>> test_dataset = VOC(data_root='data/VOC2007', split='test')
+        >>> test_dataset
+        Dataset VOC
+        Number of samples: 4952
+        Number of categories: 20
+        Prefix of dataset: data/VOC2007
+        Path of image set: data/VOC2007/ImageSets/Main/test.txt
+        Prefix of images: data/VOC2007/JPEGImages
+        Prefix of annotations: data/VOC2007/Annotations
+    """  # noqa: E501
+
+    METAINFO = {'classes': VOC2007_CATEGORIES}
+
+    def __init__(self,
+                 data_root: str,
+                 split: str = 'trainval',
+                 image_set_path: str = '',
+                 data_prefix: Union[str, dict] = dict(
+                     img_path='JPEGImages', ann_path='Annotations'),
+                 test_mode: bool = False,
+                 metainfo: Optional[dict] = None,
+                 **kwargs):
+
+        self.backend = get_file_backend(data_root, enable_singleton=True)
+
+        if split:
+            splits = ['train', 'val', 'trainval', 'test']
+            assert split in splits, \
+                f"The split must be one of {splits}, but got '{split}'"
+        self.split = split
+
+        if not data_prefix:
+            data_prefix = dict(
+                img_path='JPEGImages', ann_path='Annotations')
+        if not image_set_path:
+            image_set_path = self.backend.join_path(
+                'ImageSets', 'Main', f'{split}.txt')
+
+        # To handle the BC-breaking
+        if (split == 'train' or split == 'trainval') and test_mode:
+            logger = MMLogger.get_current_instance()
+            logger.warning(f'split="{split}" but test_mode=True. '
+                           f'The {split} set will be used.')
+
+        if isinstance(data_prefix, str):
+            data_prefix = dict(img_path=expanduser(data_prefix))
+        assert isinstance(data_prefix, dict) and 'img_path' in data_prefix, \
+            '`data_prefix` must be a dict with key img_path'
+
+        if (split and split not in ['val', 'test']) or not test_mode:
+            assert 'ann_path' in data_prefix and data_prefix[
+                'ann_path'] is not None, \
+                '"ann_path" must be set in `data_prefix` ' \
+                'when validation or test set is used.'
+
+        self.data_root = data_root
+        self.image_set_path = self.backend.join_path(
+            data_root, image_set_path)
+
+        super().__init__(
+            ann_file='',
+            metainfo=metainfo,
+            data_root=data_root,
+            data_prefix=data_prefix,
+            test_mode=test_mode,
+            **kwargs)
+
+    @property
+    def ann_prefix(self):
+        """The prefix of annotations."""
+        if 'ann_path' in self.data_prefix:
+            return self.data_prefix['ann_path']
+        else:
+            return None
+
+    def _get_labels_from_xml(self, img_id):
+        """Get gt_labels and labels_difficult from xml file."""
+        xml_path = self.backend.join_path(self.ann_prefix, f'{img_id}.xml')
+        content = self.backend.get(xml_path)
+        root = ET.fromstring(content)
+
+        labels, labels_difficult = set(), set()
+        for obj in root.findall('object'):
+            label_name = obj.find('name').text
+            # in case customized dataset has wrong labels
+            # or CLASSES has been overridden.
+ if label_name not in self.CLASSES: + continue + label = self.class_to_idx[label_name] + difficult = int(obj.find('difficult').text) + if difficult: + labels_difficult.add(label) + else: + labels.add(label) + + return list(labels), list(labels_difficult) + + def load_data_list(self): + """Load images and ground truth labels.""" + data_list = [] + img_ids = list_from_file(self.image_set_path) + + for img_id in img_ids: + img_path = self.backend.join_path(self.img_prefix, f'{img_id}.jpg') + + labels, labels_difficult = None, None + if self.ann_prefix is not None: + labels, labels_difficult = self._get_labels_from_xml(img_id) + + info = dict( + img_path=img_path, + gt_label=labels, + gt_label_difficult=labels_difficult) + data_list.append(info) + + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Prefix of dataset: \t{self.data_root}', + f'Path of image set: \t{self.image_set_path}', + f'Prefix of images: \t{self.img_prefix}', + f'Prefix of annotations: \t{self.ann_prefix}' + ] + + return body diff --git a/mmpretrain/datasets/vsr.py b/mmpretrain/datasets/vsr.py new file mode 100644 index 0000000000000000000000000000000000000000..7b109592bd020d57e3db8f2ff610901e2a1d9f31 --- /dev/null +++ b/mmpretrain/datasets/vsr.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class VSR(BaseDataset): + """VSR: Visual Spatial Reasoning dataset. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to an empty string. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + data_prefix: str, + ann_file: str = '', + **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + annotations = mmengine.load(self.ann_file) + + data_list = [] + for ann in annotations: + # ann example + # { + # "image": "train2017/000000372029.jpg", + # "question": "The dog is on the surfboard.", + # "answer": true + # } + data_info = dict() + data_info['img_path'] = mmengine.join_path( + self.data_prefix['img_path'], ann['image']) + data_info['question'] = ann['question'] + data_info['gt_answer'] = 'yes' if ann['answer'] else 'no' + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/engine/.DS_Store b/mmpretrain/engine/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..a32d42bbcc89e9c0dc0122ca60ba95a14bd59e32 Binary files /dev/null and b/mmpretrain/engine/.DS_Store differ diff --git a/mmpretrain/engine/__init__.py b/mmpretrain/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..332fea0909b4abdc6a83cf7662ea916a777d99dd --- /dev/null +++ b/mmpretrain/engine/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .hooks import *  # noqa: F401, F403
+from .optimizers import *  # noqa: F401, F403
+from .runners import *  # noqa: F401, F403
+from .schedulers import *  # noqa: F401, F403
diff --git a/mmpretrain/engine/hooks/__init__.py b/mmpretrain/engine/hooks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc9e22be7e96d636f202066f2e00e7699b730619
--- /dev/null
+++ b/mmpretrain/engine/hooks/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .class_num_check_hook import ClassNumCheckHook
+from .densecl_hook import DenseCLHook
+from .ema_hook import EMAHook
+from .margin_head_hooks import SetAdaptiveMarginsHook
+from .precise_bn_hook import PreciseBNHook
+from .retriever_hooks import PrepareProtoBeforeValLoopHook
+from .simsiam_hook import SimSiamHook
+from .swav_hook import SwAVHook
+from .switch_recipe_hook import SwitchRecipeHook
+from .visualization_hook import VisualizationHook
+from .warmup_param_hook import WarmupParamHook
+
+__all__ = [
+    'ClassNumCheckHook', 'PreciseBNHook', 'VisualizationHook',
+    'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook',
+    'SetAdaptiveMarginsHook', 'EMAHook', 'SimSiamHook', 'DenseCLHook',
+    'SwAVHook', 'WarmupParamHook'
+]
diff --git a/mmpretrain/engine/hooks/class_num_check_hook.py b/mmpretrain/engine/hooks/class_num_check_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..38170d6604810c575aa5c2c9435c0b75cfa761b2
--- /dev/null
+++ b/mmpretrain/engine/hooks/class_num_check_hook.py
@@ -0,0 +1,63 @@
+# Copyright (c) OpenMMLab. All rights reserved
+from mmengine.hooks import Hook
+from mmengine.utils import is_seq_of
+
+from mmpretrain.registry import HOOKS
+
+
+@HOOKS.register_module()
+class ClassNumCheckHook(Hook):
+    """Class Number Check HOOK."""
+
+    def _check_head(self, runner, dataset):
+        """Check whether the `num_classes` in head matches the length of
+        `CLASSES` in `dataset`.
+
+        Args:
+            runner (obj:`Runner`): runner object.
+            dataset (obj:`BaseDataset`): the dataset to check.
+        """
+        model = runner.model
+        if dataset.CLASSES is None:
+            runner.logger.warning(
+                f'Please set class information in `metainfo` '
+                f'in the {dataset.__class__.__name__} and '
+                f'check if it is consistent with the `num_classes` '
+                f'of head')
+        else:
+            assert is_seq_of(dataset.CLASSES, str), \
+                (f'Class information in `metainfo` in '
+                 f'{dataset.__class__.__name__} should be a tuple of str.')
+            for _, module in model.named_modules():
+                if hasattr(module, 'num_classes'):
+                    assert module.num_classes == len(dataset.CLASSES), \
+                        (f'The `num_classes` ({module.num_classes}) in '
+                         f'{module.__class__.__name__} of '
+                         f'{model.__class__.__name__} does not match '
+                         f'the length of class information in `metainfo` '
+                         f'({len(dataset.CLASSES)}) in '
+                         f'{dataset.__class__.__name__}')
+
+    def before_train(self, runner):
+        """Check whether the training dataset is compatible with head.
+
+        Args:
+            runner (obj:`IterBasedRunner`): Iter based Runner.
+        """
+        self._check_head(runner, runner.train_dataloader.dataset)
+
+    def before_val(self, runner):
+        """Check whether the validation dataset is compatible with head.
+
+        Args:
+            runner (obj:`IterBasedRunner`): Iter based Runner.
+        """
+        self._check_head(runner, runner.val_dataloader.dataset)
+
+    def before_test(self, runner):
+        """Check whether the test dataset is compatible with head.
+
+        Args:
+            runner (obj:`IterBasedRunner`): Iter based Runner.
+ """ + self._check_head(runner, runner.test_dataloader.dataset) diff --git a/mmpretrain/engine/hooks/densecl_hook.py b/mmpretrain/engine/hooks/densecl_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..8c7e17d3419cbc2a540d3aecd81e223eed670df2 --- /dev/null +++ b/mmpretrain/engine/hooks/densecl_hook.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +from mmengine.hooks import Hook + +from mmpretrain.registry import HOOKS +from mmpretrain.utils import get_ori_model + + +@HOOKS.register_module() +class DenseCLHook(Hook): + """Hook for DenseCL. + + This hook includes ``loss_lambda`` warmup in DenseCL. + Borrowed from the authors' code: ``_. + + Args: + start_iters (int): The number of warmup iterations to set + ``loss_lambda=0``. Defaults to 1000. + """ + + def __init__(self, start_iters: int = 1000) -> None: + self.start_iters = start_iters + + def before_train(self, runner) -> None: + """Obtain ``loss_lambda`` from algorithm.""" + assert hasattr(get_ori_model(runner.model), 'loss_lambda'), \ + "The runner must have attribute \"loss_lambda\" in DenseCL." + self.loss_lambda = get_ori_model(runner.model).loss_lambda + + def before_train_iter(self, + runner, + batch_idx: int, + data_batch: Optional[Sequence[dict]] = None) -> None: + """Adjust ``loss_lambda`` every train iter.""" + assert hasattr(get_ori_model(runner.model), 'loss_lambda'), \ + "The runner must have attribute \"loss_lambda\" in DenseCL." + cur_iter = runner.iter + if cur_iter >= self.start_iters: + get_ori_model(runner.model).loss_lambda = self.loss_lambda + else: + get_ori_model(runner.model).loss_lambda = 0. diff --git a/mmpretrain/engine/hooks/ema_hook.py b/mmpretrain/engine/hooks/ema_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..284d211b628c411f0eb712d1c558dc6aa2eb8996 --- /dev/null +++ b/mmpretrain/engine/hooks/ema_hook.py @@ -0,0 +1,216 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import itertools +import warnings +from typing import Dict, Optional + +from mmengine.hooks import EMAHook as BaseEMAHook +from mmengine.logging import MMLogger +from mmengine.runner import Runner + +from mmpretrain.registry import HOOKS + + +@HOOKS.register_module() +class EMAHook(BaseEMAHook): + """A Hook to apply Exponential Moving Average (EMA) on the model during + training. + + Comparing with :class:`mmengine.hooks.EMAHook`, this hook accepts + ``evaluate_on_ema`` and ``evaluate_on_origin`` arguments. By default, the + ``evaluate_on_ema`` is enabled, and if you want to do validation and + testing on both original and EMA models, please set both arguments + ``True``. + + Note: + - EMAHook takes priority over CheckpointHook. + - The original model parameters are actually saved in ema field after + train. + - ``begin_iter`` and ``begin_epoch`` cannot be set at the same time. + + Args: + ema_type (str): The type of EMA strategy to use. You can find the + supported strategies in :mod:`mmengine.model.averaged_model`. + Defaults to 'ExponentialMovingAverage'. + strict_load (bool): Whether to strictly enforce that the keys of + ``state_dict`` in checkpoint match the keys returned by + ``self.module.state_dict``. Defaults to False. + Changed in v0.3.0. + begin_iter (int): The number of iteration to enable ``EMAHook``. + Defaults to 0. + begin_epoch (int): The number of epoch to enable ``EMAHook``. + Defaults to 0. 
+ evaluate_on_ema (bool): Whether to evaluate (validate and test) + on EMA model during val-loop and test-loop. Defaults to True. + evaluate_on_origin (bool): Whether to evaluate (validate and test) + on the original model during val-loop and test-loop. + Defaults to False. + **kwargs: Keyword arguments passed to subclasses of + :obj:`BaseAveragedModel` + """ + + priority = 'NORMAL' + + def __init__(self, + ema_type: str = 'ExponentialMovingAverage', + strict_load: bool = False, + begin_iter: int = 0, + begin_epoch: int = 0, + evaluate_on_ema: bool = True, + evaluate_on_origin: bool = False, + **kwargs): + super().__init__( + ema_type=ema_type, + strict_load=strict_load, + begin_iter=begin_iter, + begin_epoch=begin_epoch, + **kwargs) + + if not evaluate_on_ema and not evaluate_on_origin: + warnings.warn( + 'Automatically set `evaluate_on_origin=True` since the ' + '`evaluate_on_ema` is disabled. If you want to disable ' + 'all validation, please modify the `val_interval` of ' + 'the `train_cfg`.', UserWarning) + evaluate_on_origin = True + + self.evaluate_on_ema = evaluate_on_ema + self.evaluate_on_origin = evaluate_on_origin + self.load_ema_from_ckpt = False + + def before_train(self, runner) -> None: + super().before_train(runner) + if not runner._resume and self.load_ema_from_ckpt: + # If loaded EMA state dict but not want to resume training + # overwrite the EMA state dict with the source model. + MMLogger.get_current_instance().info( + 'Load from a checkpoint with EMA parameters but not ' + 'resume training. Initialize the model parameters with ' + 'EMA parameters') + for p_ema, p_src in zip(self._ema_params, self._src_params): + p_src.data.copy_(p_ema.data) + + def before_val_epoch(self, runner) -> None: + """We load parameter values from ema model to source model before + validation. + + Args: + runner (Runner): The runner of the training process. + """ + if self.evaluate_on_ema: + # Swap when evaluate on ema + self._swap_ema_parameters() + + def after_val_epoch(self, + runner, + metrics: Optional[Dict[str, float]] = None) -> None: + """We recover source model's parameter from ema model after validation. + + Args: + runner (Runner): The runner of the validation process. + metrics (Dict[str, float], optional): Evaluation results of all + metrics on validation dataset. The keys are the names of the + metrics, and the values are corresponding results. + """ + if self.evaluate_on_ema: + # Swap when evaluate on ema + self._swap_ema_parameters() + + if self.evaluate_on_ema and self.evaluate_on_origin: + # Re-evaluate if evaluate on both ema and origin. + val_loop = runner.val_loop + + runner.model.eval() + for idx, data_batch in enumerate(val_loop.dataloader): + val_loop.run_iter(idx, data_batch) + + # compute metrics + origin_metrics = val_loop.evaluator.evaluate( + len(val_loop.dataloader.dataset)) + + for k, v in origin_metrics.items(): + runner.message_hub.update_scalar(f'val/{k}_origin', v) + + def before_test_epoch(self, runner) -> None: + """We load parameter values from ema model to source model before test. + + Args: + runner (Runner): The runner of the training process. 
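A config sketch for the EMA behaviour documented above; the two evaluate flags come straight from this docstring, other settings are omitted:

```python
custom_hooks = [
    dict(
        type='EMAHook',
        evaluate_on_ema=True,     # validate/test with the averaged weights
        evaluate_on_origin=True,  # extra metrics logged as 'val/<name>_origin'
    )
]
```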
+ """ + if self.evaluate_on_ema: + # Swap when evaluate on ema + self._swap_ema_parameters() + MMLogger.get_current_instance().info('Start testing on EMA model.') + else: + MMLogger.get_current_instance().info( + 'Start testing on the original model.') + + def after_test_epoch(self, + runner: Runner, + metrics: Optional[Dict[str, float]] = None) -> None: + """We recover source model's parameter from ema model after test. + + Args: + runner (Runner): The runner of the testing process. + metrics (Dict[str, float], optional): Evaluation results of all + metrics on test dataset. The keys are the names of the + metrics, and the values are corresponding results. + """ + if self.evaluate_on_ema: + # Swap when evaluate on ema + self._swap_ema_parameters() + + if self.evaluate_on_ema and self.evaluate_on_origin: + # Re-evaluate if evaluate on both ema and origin. + MMLogger.get_current_instance().info( + 'Start testing on the original model.') + test_loop = runner.test_loop + + runner.model.eval() + for idx, data_batch in enumerate(test_loop.dataloader): + test_loop.run_iter(idx, data_batch) + + # compute metrics + origin_metrics = test_loop.evaluator.evaluate( + len(test_loop.dataloader.dataset)) + + for k, v in origin_metrics.items(): + runner.message_hub.update_scalar(f'test/{k}_origin', v) + + def after_load_checkpoint(self, runner, checkpoint: dict) -> None: + """Resume ema parameters from checkpoint. + + Args: + runner (Runner): The runner of the testing process. + """ + from mmengine.runner.checkpoint import load_state_dict + if 'ema_state_dict' in checkpoint: + # The original model parameters are actually saved in ema + # field swap the weights back to resume ema state. + self._swap_ema_state_dict(checkpoint) + self.ema_model.load_state_dict( + checkpoint['ema_state_dict'], strict=self.strict_load) + self.load_ema_from_ckpt = True + + # Support load checkpoint without ema state dict. + else: + load_state_dict( + self.ema_model.module, + copy.deepcopy(checkpoint['state_dict']), + strict=self.strict_load) + + @property + def _src_params(self): + if self.ema_model.update_buffers: + return itertools.chain(self.src_model.parameters(), + self.src_model.buffers()) + else: + return self.src_model.parameters() + + @property + def _ema_params(self): + if self.ema_model.update_buffers: + return itertools.chain(self.ema_model.module.parameters(), + self.ema_model.module.buffers()) + else: + return self.ema_model.module.parameters() diff --git a/mmpretrain/engine/hooks/margin_head_hooks.py b/mmpretrain/engine/hooks/margin_head_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..fbeae7a347453153ff4ab3bef958acb549623f6f --- /dev/null +++ b/mmpretrain/engine/hooks/margin_head_hooks.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved +import numpy as np +from mmengine.hooks import Hook +from mmengine.model import is_model_wrapper + +from mmpretrain.models.heads import ArcFaceClsHead +from mmpretrain.registry import HOOKS + + +@HOOKS.register_module() +class SetAdaptiveMarginsHook(Hook): + r"""Set adaptive-margins in ArcFaceClsHead based on the power of + category-wise count. + + A PyTorch implementation of paper `Google Landmark Recognition 2020 + Competition Third Place Solution `_. + The margins will be + :math:`\text{f}(n) = (marginMax - marginMin) · norm(n^p) + marginMin`. + The `n` indicates the number of occurrences of a category. + + Args: + margin_min (float): Lower bound of margins. Defaults to 0.05. + margin_max (float): Upper bound of margins. 
Defaults to 0.5. + power (float): The power of category frequency. Defaults to -0.25. + """ + + def __init__(self, margin_min=0.05, margin_max=0.5, power=-0.25) -> None: + self.margin_min = margin_min + self.margin_max = margin_max + self.margin_range = margin_max - margin_min + self.p = power + + def before_train(self, runner): + """Change the margins in ArcFaceClsHead. + + Args: + runner (obj: `Runner`): Runner. + """ + model = runner.model + if is_model_wrapper(model): + model = model.module + + if not (hasattr(model, 'head') + and isinstance(model.head, ArcFaceClsHead)): + raise ValueError( + 'Hook ``SetAdaptiveMarginsHook`` could only be used ' + f'for ``ArcFaceClsHead``, but got ' + f'{type(getattr(model, "head", None))}') + + # generate margins based on the dataset. + gt_labels = runner.train_dataloader.dataset.get_gt_labels() + label_count = np.bincount(gt_labels) + label_count[label_count == 0] = 1 # At least one occurrence + pow_freq = np.power(label_count, self.p) + + min_f, max_f = pow_freq.min(), pow_freq.max() + normalized_pow_freq = (pow_freq - min_f) / (max_f - min_f) + margins = normalized_pow_freq * self.margin_range + self.margin_min + + assert len(margins) == model.head.num_classes + + model.head.set_margins(margins)
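To make the margin rule concrete, a small NumPy sketch with illustrative class counts:

```python
import numpy as np

label_count = np.array([1000, 100, 10])  # occurrences per class (illustrative)
pow_freq = np.power(label_count, -0.25)  # power = -0.25
normalized = (pow_freq - pow_freq.min()) / (pow_freq.max() - pow_freq.min())
margins = normalized * (0.5 - 0.05) + 0.05
# -> [0.05, ~0.21, 0.5]: the rarest class receives the largest margin
```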
+ """ + # There is no need for reduction in the single-proc case + if num_gpus == 1: + return tensors + # Queue the reductions + reductions = [] + for tensor in tensors: + reduction = torch.distributed.all_reduce(tensor, async_op=True) + reductions.append(reduction) + # Wait for reductions to finish + for reduction in reductions: + reduction.wait() + # Scale the results + for tensor in tensors: + tensor.mul_(1.0 / num_gpus) + return tensors + + +@torch.no_grad() +def update_bn_stats( + model: nn.Module, + loader: DataLoader, + num_samples: int = 8192, + logger: Optional[Union[logging.Logger, str]] = None) -> None: + """Computes precise BN stats on training data. + + Args: + model (nn.module): The model whose bn stats will be recomputed. + loader (DataLoader): PyTorch dataloader._dataloader + num_samples (int): The number of samples to update the bn stats. + Defaults to 8192. + logger (logging.Logger or str, optional): If the type of logger is + ``logging.Logger``, we directly use logger to log messages. + Some special loggers are: + - "silent": No message will be printed. + - "current": Use latest created logger to log message. + - other str: Instance name of logger. The corresponding logger + will log message if it has been created, otherwise will raise a + `ValueError`. + - None: The `print()` method will be used to print log messages. + """ + if is_model_wrapper(model): + model = model.module + + # get dist info + rank, world_size = mmengine.dist.get_dist_info() + # Compute the number of mini-batches to use, if the size of dataloader is + # less than num_iters, use all the samples in dataloader. + num_iter = num_samples // (loader.batch_size * world_size) + num_iter = min(num_iter, len(loader)) + # Retrieve the BN layers + bn_layers = [ + m for m in model.modules() + if m.training and isinstance(m, (_BatchNorm)) + ] + if len(bn_layers) == 0: + print_log('No BN found in model', logger=logger, level=logging.WARNING) + return + print_log( + f'{len(bn_layers)} BN found, run {num_iter} iters...', logger=logger) + + # Finds all the other norm layers with training=True. 
+ + +@torch.no_grad() +def update_bn_stats( + model: nn.Module, + loader: DataLoader, + num_samples: int = 8192, + logger: Optional[Union[logging.Logger, str]] = None) -> None: + """Computes precise BN stats on training data. + + Args: + model (nn.Module): The model whose bn stats will be recomputed. + loader (DataLoader): The PyTorch dataloader. + num_samples (int): The number of samples to update the bn stats. + Defaults to 8192. + logger (logging.Logger or str, optional): If the type of logger is + ``logging.Logger``, we directly use logger to log messages. + Some special loggers are: + - "silent": No message will be printed. + - "current": Use latest created logger to log message. + - other str: Instance name of logger. The corresponding logger + will log message if it has been created, otherwise will raise a + `ValueError`. + - None: The `print()` method will be used to print log messages. + """ + if is_model_wrapper(model): + model = model.module + + # get dist info + rank, world_size = mmengine.dist.get_dist_info() + # Compute the number of mini-batches to use; if the dataloader holds fewer + # batches than num_iter, use all the samples in the dataloader. + num_iter = num_samples // (loader.batch_size * world_size) + num_iter = min(num_iter, len(loader)) + # Retrieve the BN layers + bn_layers = [ + m for m in model.modules() + if m.training and isinstance(m, _BatchNorm) + ] + if len(bn_layers) == 0: + print_log('No BN found in model', logger=logger, level=logging.WARNING) + return + print_log( + f'{len(bn_layers)} BN found, run {num_iter} iters...', logger=logger) + + # Find all the other norm layers with training=True. + other_norm_layers = [ + m for m in model.modules() + if m.training and isinstance(m, (_InstanceNorm, GroupNorm)) + ] + if len(other_norm_layers) > 0: + print_log( + 'IN/GN stats will not be updated in PreciseBNHook.', + logger=logger, + level=logging.INFO) + + # Initialize BN stats storage for computing + # mean(mean(batch)) and mean(var(batch)) + running_means = [torch.zeros_like(bn.running_mean) for bn in bn_layers] + running_vars = [torch.zeros_like(bn.running_var) for bn in bn_layers] + # Remember momentum values + momentums = [bn.momentum for bn in bn_layers] + # Set momentum to 1.0 to compute BN stats that reflect the current batch + for bn in bn_layers: + bn.momentum = 1.0 + # Average the BN stats for each BN layer over the batches + if rank == 0: + prog_bar = ProgressBar(num_iter) + + for data in itertools.islice(loader, num_iter): + data = model.data_preprocessor(data, False) + model(**data) + + for i, bn in enumerate(bn_layers): + running_means[i] += bn.running_mean / num_iter + running_vars[i] += bn.running_var / num_iter + if rank == 0: + prog_bar.update() + + # Sync BN stats across GPUs (no reduction if 1 GPU used) + running_means = scaled_all_reduce(running_means, world_size) + running_vars = scaled_all_reduce(running_vars, world_size) + # Set BN stats and restore original momentum values + for i, bn in enumerate(bn_layers): + bn.running_mean = running_means[i] + bn.running_var = running_vars[i] + bn.momentum = momentums[i] + + +@HOOKS.register_module() +class PreciseBNHook(Hook): + """Precise BN hook. + + Recompute and update the batch norm stats to make them more precise. During + training both the BN stats and the weights change after every iteration, so + the running average cannot precisely reflect the actual stats of the + current model. + + With this hook, the BN stats are recomputed with fixed weights, to make the + running average more precise. Specifically, it computes the true average of + per-batch mean/variance instead of the running average. See Sec. 3 of the + paper `Rethinking Batch in BatchNorm ` + for details. + + This hook will update BN stats, so it should be executed before + ``CheckpointHook`` and ``EMAHook``, generally set its priority to + "ABOVE_NORMAL". + + Args: + num_samples (int): The number of samples to update the bn stats. + Defaults to 8192. + interval (int): The interval at which to perform precise BN. If the + train loop is `EpochBasedTrainLoop` or `by_epoch=True`, its unit is + 'epoch'; if the train loop is `IterBasedTrainLoop` or + `by_epoch=False`, its unit is 'iter'. Defaults to 1. + """ + + def __init__(self, num_samples: int = 8192, interval: int = 1) -> None: + assert interval > 0 and num_samples > 0, "'interval' and " \ + "'num_samples' must be greater than 0." + + self.interval = interval + self.num_samples = num_samples + + def _perform_precise_bn(self, runner: Runner) -> None: + """Perform precise BN.""" + print_log( + f'Running Precise BN for {self.num_samples} samples...', + logger=runner.logger) + update_bn_stats( + runner.model, + runner.train_loop.dataloader, + self.num_samples, + logger=runner.logger) + print_log('Finish Precise BN, BN stats updated.', logger=runner.logger) + + def after_train_epoch(self, runner: Runner) -> None: + """Calculate precise BN and broadcast BN stats across GPUs. + + Args: + runner (obj:`Runner`): The runner of the training process. + """ + # If the runner uses `EpochBasedTrainLoop`, perform precise BN every + # `self.interval` epochs.
+ if isinstance(runner.train_loop, + EpochBasedTrainLoop) and self.every_n_epochs( + runner, self.interval): + self._perform_precise_bn(runner) + + def after_train_iter(self, + runner, + batch_idx: int, + data_batch: DATA_BATCH = None, + outputs: Optional[dict] = None) -> None: + """Calculate precise BN and broadcast BN stats across GPUs. + + Args: + runner (obj:`Runner`): The runner of the training process. + batch_idx (int): The index of the current batch in the train loop. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. + """ + # If the runner uses `IterBasedTrainLoop`, perform precise BN every + # `self.interval` iters. + if isinstance(runner.train_loop, + IterBasedTrainLoop) and self.every_n_train_iters( + runner, self.interval): + self._perform_precise_bn(runner) diff --git a/mmpretrain/engine/hooks/retriever_hooks.py b/mmpretrain/engine/hooks/retriever_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd7c7aaff3175491b1ea1508e33b07b7c2ea8d4 --- /dev/null +++ b/mmpretrain/engine/hooks/retriever_hooks.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved +import warnings + +from mmengine.hooks import Hook +from mmengine.model import is_model_wrapper + +from mmpretrain.models import BaseRetriever +from mmpretrain.registry import HOOKS + + +@HOOKS.register_module() +class PrepareProtoBeforeValLoopHook(Hook): + """The hook to prepare the prototype in retrievers. + + Since the encoders of the retriever change during training, the prototype + changes accordingly, so the `prototype_vecs` needs to be regenerated before + the validation loop. + """ + + def before_val(self, runner) -> None: + model = runner.model + if is_model_wrapper(model): + model = model.module + + if isinstance(model, BaseRetriever): + if hasattr(model, 'prepare_prototype'): + model.prepare_prototype() + else: + warnings.warn( + 'Only the `mmpretrain.models.retrievers.BaseRetriever` ' + 'can execute `PrepareProtoBeforeValLoopHook`, but got ' + f'`{type(model)}`') diff --git a/mmpretrain/engine/hooks/simsiam_hook.py b/mmpretrain/engine/hooks/simsiam_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..fabc4faca02bb78b92c39de68fa8a18e56d544f5 --- /dev/null +++ b/mmpretrain/engine/hooks/simsiam_hook.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +from mmengine.hooks import Hook + +from mmpretrain.registry import HOOKS + + +@HOOKS.register_module() +class SimSiamHook(Hook): + """Hook for SimSiam. + + This hook is for SimSiam to fix the learning rate of the predictor. + + Args: + fix_pred_lr (bool): whether to fix the lr of the predictor or not. + lr (float): the value of the fixed lr. + adjust_by_epoch (bool, optional): whether to set lr by epoch or iter. + Defaults to True.
+ """ + + def __init__(self, + fix_pred_lr: bool, + lr: float, + adjust_by_epoch: Optional[bool] = True) -> None: + self.fix_pred_lr = fix_pred_lr + self.lr = lr + self.adjust_by_epoch = adjust_by_epoch + + def before_train_iter(self, + runner, + batch_idx: int, + data_batch: Optional[Sequence[dict]] = None) -> None: + """fix lr of predictor by iter.""" + if self.adjust_by_epoch: + return + else: + if self.fix_pred_lr: + for param_group in runner.optim_wrapper.optimizer.param_groups: + if 'fix_lr' in param_group and param_group['fix_lr']: + param_group['lr'] = self.lr + + def before_train_epoch(self, runner) -> None: + """fix lr of predictor by epoch.""" + if self.fix_pred_lr: + for param_group in runner.optim_wrapper.optimizer.param_groups: + if 'fix_lr' in param_group and param_group['fix_lr']: + param_group['lr'] = self.lr diff --git a/mmpretrain/engine/hooks/swav_hook.py b/mmpretrain/engine/hooks/swav_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..be5f3a36bdd7fc44e77700988f1759181e5ce54d --- /dev/null +++ b/mmpretrain/engine/hooks/swav_hook.py @@ -0,0 +1,119 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Dict, List, Optional, Sequence + +import torch +from mmengine.device import get_device +from mmengine.dist import get_rank, get_world_size, is_distributed +from mmengine.hooks import Hook +from mmengine.logging import MMLogger + +from mmpretrain.registry import HOOKS +from mmpretrain.utils import get_ori_model + + +@HOOKS.register_module() +class SwAVHook(Hook): + """Hook for SwAV. + + This hook builds the queue in SwAV according to ``epoch_queue_starts``. + The queue will be saved in ``runner.work_dir`` or loaded at start epoch + if the path folder has queues saved before. + + Args: + batch_size (int): the batch size per GPU for computing. + epoch_queue_starts (int, optional): from this epoch, starts to use the + queue. Defaults to 15. + crops_for_assign (list[int], optional): list of crops id used for + computing assignments. Defaults to [0, 1]. + feat_dim (int, optional): feature dimension of output vector. + Defaults to 128. + queue_length (int, optional): length of the queue (0 for no queue). + Defaults to 0. + interval (int, optional): the interval to save the queue. + Defaults to 1. + frozen_layers_cfg (dict, optional): Dict to config frozen layers. + The key-value pair is layer name and its frozen iters. If frozen, + the layers don't need gradient. Defaults to dict(). 
+ """ + + def __init__( + self, + batch_size: int, + epoch_queue_starts: Optional[int] = 15, + crops_for_assign: Optional[List[int]] = [0, 1], + feat_dim: Optional[int] = 128, + queue_length: Optional[int] = 0, + interval: Optional[int] = 1, + frozen_layers_cfg: Optional[Dict] = dict() + ) -> None: + self.batch_size = batch_size * get_world_size() + self.epoch_queue_starts = epoch_queue_starts + self.crops_for_assign = crops_for_assign + self.feat_dim = feat_dim + self.queue_length = queue_length + self.interval = interval + self.frozen_layers_cfg = frozen_layers_cfg + self.requires_grad = True + self.queue = None + + def before_run(self, runner) -> None: + """Check whether the queues exist locally or not.""" + if is_distributed(): + self.queue_path = osp.join(runner.work_dir, + 'queue' + str(get_rank()) + '.pth') + else: + self.queue_path = osp.join(runner.work_dir, 'queue.pth') + + # load the queues if queues exist locally + if osp.isfile(self.queue_path): + self.queue = torch.load(self.queue_path)['queue'] + get_ori_model(runner.model).head.loss_module.queue = self.queue + MMLogger.get_current_instance().info( + f'Load queue from file: {self.queue_path}') + + # the queue needs to be divisible by the batch size + self.queue_length -= self.queue_length % self.batch_size + + def before_train_iter(self, + runner, + batch_idx: int, + data_batch: Optional[Sequence[dict]] = None) -> None: + """Freeze layers before specific iters according to the config.""" + for layer, frozen_iters in self.frozen_layers_cfg.items(): + if runner.iter < frozen_iters and self.requires_grad: + self.requires_grad = False + for name, p in get_ori_model(runner.model).named_parameters(): + if layer in name: + p.requires_grad = False + elif runner.iter >= frozen_iters and not self.requires_grad: + self.requires_grad = True + for name, p in get_ori_model(runner.model).named_parameters(): + if layer in name: + p.requires_grad = True + + def before_train_epoch(self, runner) -> None: + """Check the queues' state.""" + # optionally starts a queue + if self.queue_length > 0 \ + and runner.epoch >= self.epoch_queue_starts \ + and self.queue is None: + + self.queue = torch.zeros( + len(self.crops_for_assign), + self.queue_length // runner.world_size, + self.feat_dim, + device=get_device(), + ) + + # set the boolean type of use_the_queue + get_ori_model(runner.model).head.loss_module.queue = self.queue + get_ori_model(runner.model).head.loss_module.use_queue = False + + def after_train_epoch(self, runner) -> None: + """Save the queues locally.""" + self.queue = get_ori_model(runner.model).head.loss_module.queue + + if self.queue is not None and self.every_n_epochs( + runner, self.interval): + torch.save({'queue': self.queue}, self.queue_path) diff --git a/mmpretrain/engine/hooks/switch_recipe_hook.py b/mmpretrain/engine/hooks/switch_recipe_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..914b9572eb22d2cd2f54c519273c86baf2e0894d --- /dev/null +++ b/mmpretrain/engine/hooks/switch_recipe_hook.py @@ -0,0 +1,169 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from collections import OrderedDict +from copy import deepcopy + +from mmcv.transforms import Compose +from mmengine.hooks import Hook +from mmengine.model import is_model_wrapper + +from mmpretrain.models.utils import RandomBatchAugment +from mmpretrain.registry import HOOKS, MODEL_WRAPPERS, MODELS + + +@HOOKS.register_module() +class SwitchRecipeHook(Hook): + """Switch the training recipe during the training loop, currently covering + the train pipeline, batch augments and loss. + + Args: + schedule (list): Every item of the schedule list should be a dict, and + the dict should have ``action_epoch`` and some of + ``train_pipeline``, ``batch_augments`` and ``loss`` keys: + + - ``action_epoch`` (int): switch training recipe at which epoch. + - ``train_pipeline`` (list, optional): The new data pipeline of the + train dataset. If not specified, keep the original settings. + - ``batch_augments`` (dict | None, optional): The new batch + augmentations used during training. See :mod:`Batch Augmentations + ` for more details. + If None, disable batch augmentations. If not specified, keep the + original settings. + - ``loss`` (dict, optional): The new loss module config. If not + specified, keep the original settings. + + Example: + To use this hook in config files. + + .. code:: python + + custom_hooks = [ + dict( + type='SwitchRecipeHook', + schedule=[ + dict( + action_epoch=30, + train_pipeline=pipeline_after_30e, + batch_augments=batch_augments_after_30e, + loss=loss_after_30e, + ), + dict( + action_epoch=60, + # Disable batch augmentations after 60e + # and keep other settings. + batch_augments=None, + ), + ] + ) + ] + """ + priority = 'NORMAL' + + def __init__(self, schedule): + recipes = {} + for recipe in schedule: + assert 'action_epoch' in recipe, \ + 'Please set `action_epoch` in every item ' \ + 'of the `schedule` in the SwitchRecipeHook.' + recipe = deepcopy(recipe) + if 'train_pipeline' in recipe: + recipe['train_pipeline'] = Compose(recipe['train_pipeline']) + if 'batch_augments' in recipe: + batch_augments = recipe['batch_augments'] + if isinstance(batch_augments, dict): + batch_augments = RandomBatchAugment(**batch_augments) + recipe['batch_augments'] = batch_augments + if 'loss' in recipe: + loss = recipe['loss'] + if isinstance(loss, dict): + loss = MODELS.build(loss) + recipe['loss'] = loss + + action_epoch = recipe.pop('action_epoch') + assert action_epoch not in recipes, \ + f'The `action_epoch` {action_epoch} is repeated ' \ + 'in the SwitchRecipeHook.' + recipes[action_epoch] = recipe + self.schedule = OrderedDict(sorted(recipes.items())) + + def before_train(self, runner) -> None: + """Set up before the run. If resuming from a checkpoint, apply all the + switches scheduled before the current epoch. + + Args: + runner (Runner): The runner of the training, validation or testing + process.
+ """ + if runner._resume: + for action_epoch, recipe in self.schedule.items(): + if action_epoch >= runner.epoch + 1: + break + self._do_switch(runner, recipe, + f' (resume recipe of epoch {action_epoch})') + + def before_train_epoch(self, runner): + """do before train epoch.""" + recipe = self.schedule.get(runner.epoch + 1, None) + if recipe is not None: + self._do_switch(runner, recipe, f' at epoch {runner.epoch + 1}') + + def _do_switch(self, runner, recipe, extra_info=''): + """do the switch aug process.""" + if 'batch_augments' in recipe: + self._switch_batch_augments(runner, recipe['batch_augments']) + runner.logger.info(f'Switch batch augments{extra_info}.') + + if 'train_pipeline' in recipe: + self._switch_train_pipeline(runner, recipe['train_pipeline']) + runner.logger.info(f'Switch train pipeline{extra_info}.') + + if 'loss' in recipe: + self._switch_loss(runner, recipe['loss']) + runner.logger.info(f'Switch loss{extra_info}.') + + @staticmethod + def _switch_batch_augments(runner, batch_augments): + """switch the train augments.""" + model = runner.model + if is_model_wrapper(model): + model = model.module + + model.data_preprocessor.batch_augments = batch_augments + + @staticmethod + def _switch_train_pipeline(runner, train_pipeline): + """switch the train loader dataset pipeline.""" + + def switch_pipeline(dataset, pipeline): + if hasattr(dataset, 'pipeline'): + # for usual dataset + dataset.pipeline = pipeline + elif hasattr(dataset, 'datasets'): + # for concat dataset wrapper + for ds in dataset.datasets: + switch_pipeline(ds, pipeline) + elif hasattr(dataset, 'dataset'): + # for other dataset wrappers + switch_pipeline(dataset.dataset, pipeline) + else: + raise RuntimeError( + 'Cannot access the `pipeline` of the dataset.') + + train_loader = runner.train_loop.dataloader + switch_pipeline(train_loader.dataset, train_pipeline) + + # To restart the iterator of dataloader when `persistent_workers=True` + train_loader._iterator = None + + @staticmethod + def _switch_loss(runner, loss_module): + """switch the loss module.""" + model = runner.model + if is_model_wrapper(model, MODEL_WRAPPERS): + model = model.module + + if hasattr(model, 'loss_module'): + model.loss_module = loss_module + elif hasattr(model, 'head') and hasattr(model.head, 'loss_module'): + model.head.loss_module = loss_module + else: + raise RuntimeError('Cannot access the `loss_module` of the model.') diff --git a/mmpretrain/engine/hooks/visualization_hook.py b/mmpretrain/engine/hooks/visualization_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..64d2230a79db971bef78d77bcf80c40365bddb15 --- /dev/null +++ b/mmpretrain/engine/hooks/visualization_hook.py @@ -0,0 +1,126 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import os.path as osp +from typing import Optional, Sequence + +from mmengine.fileio import join_path +from mmengine.hooks import Hook +from mmengine.runner import EpochBasedTrainLoop, Runner +from mmengine.visualization import Visualizer + +from mmpretrain.registry import HOOKS +from mmpretrain.structures import DataSample + + +@HOOKS.register_module() +class VisualizationHook(Hook): + """Classification Visualization Hook. Used to visualize validation and + testing prediction results. + + - If ``out_dir`` is specified, all storage backends are ignored + and save the image to the ``out_dir``. + - If ``show`` is True, plot the result image in a window, please + confirm you are able to access the graphical interface. 
+ + Args: + enable (bool): Whether to enable this hook. Defaults to False. + interval (int): The interval of samples to visualize. Defaults to 5000. + show (bool): Whether to display the drawn image. Defaults to False. + out_dir (str, optional): Directory where painted images will be saved + in the testing process. If None, the images are handled by the + backends of the visualizer. Defaults to None. + **kwargs: other keyword arguments of + :meth:`mmpretrain.visualization.UniversalVisualizer.visualize_cls`. + """ + + def __init__(self, + enable=False, + interval: int = 5000, + show: bool = False, + out_dir: Optional[str] = None, + **kwargs): + self._visualizer: Visualizer = Visualizer.get_current_instance() + + self.enable = enable + self.interval = interval + self.show = show + self.out_dir = out_dir + + self.draw_args = {**kwargs, 'show': show} + + def _draw_samples(self, + batch_idx: int, + data_batch: dict, + data_samples: Sequence[DataSample], + step: int = 0) -> None: + """Visualize every ``self.interval`` samples from a data batch. + + Args: + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + data_samples (Sequence[:obj:`DataSample`]): Outputs from model. + step (int): Global step value to record. Defaults to 0. + """ + if self.enable is False: + return + + batch_size = len(data_samples) + images = data_batch['inputs'] + start_idx = batch_size * batch_idx + end_idx = start_idx + batch_size + + # The first index divisible by the interval, after the start index + first_sample_id = math.ceil(start_idx / self.interval) * self.interval + + for sample_id in range(first_sample_id, end_idx, self.interval): + image = images[sample_id - start_idx] + image = image.permute(1, 2, 0).cpu().numpy().astype('uint8') + + data_sample = data_samples[sample_id - start_idx] + if 'img_path' in data_sample: + # osp.basename works across platforms, including file clients. + sample_name = osp.basename(data_sample.get('img_path')) + else: + sample_name = str(sample_id) + + # Copy so that out_file does not leak into the shared draw args. + draw_args = self.draw_args.copy() + if self.out_dir is not None: + draw_args['out_file'] = join_path(self.out_dir, + f'{sample_name}_{step}.png') + + self._visualizer.visualize_cls( + image=image, + data_sample=data_sample, + step=step, + name=sample_name, + **draw_args, + ) + + def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[DataSample]) -> None: + """Visualize every ``self.interval`` samples during validation. + + Args: + runner (:obj:`Runner`): The runner of the validation process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`DataSample`]): Outputs from model. + """ + if isinstance(runner.train_loop, EpochBasedTrainLoop): + step = runner.epoch + else: + step = runner.iter + + self._draw_samples(batch_idx, data_batch, outputs, step=step) + + def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[DataSample]) -> None: + """Visualize every ``self.interval`` samples during test. + + Args: + runner (:obj:`Runner`): The runner of the testing process. + batch_idx (int): The index of the current batch in the test loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`DataSample`]): Outputs from model. + """ + self._draw_samples(batch_idx, data_batch, outputs, step=0)
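The interval arithmetic in `_draw_samples` selects the global sample ids 0, interval, 2 * interval, and so on; a worked example:

```python
import math

interval, batch_size, batch_idx = 5000, 256, 39
start_idx = batch_size * batch_idx                            # 9984
end_idx = start_idx + batch_size                              # 10240
first_sample_id = math.ceil(start_idx / interval) * interval  # 10000
# 10000 < 10240, so this batch draws global sample 10000
```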
+ """ + self._draw_samples(batch_idx, data_batch, outputs, step=0) diff --git a/mmpretrain/engine/hooks/warmup_param_hook.py b/mmpretrain/engine/hooks/warmup_param_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..b45d8918dbbcb9cf5d12c252621908f0b6c1f251 --- /dev/null +++ b/mmpretrain/engine/hooks/warmup_param_hook.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import operator as op +from typing import Any, Optional, Union + +from mmengine.hooks import Hook + +from mmpretrain.registry import HOOKS +from mmpretrain.utils import get_ori_model + + +@HOOKS.register_module() +class WarmupParamHook(Hook): + """This is a hook used for changing the parameters other than optimizations + that need to warmup inside the module. + + This hook can extend with more detailed warmup rule if necessary. + + Args: + param_name (str): The parameter name that needs to be altered. + module_name (str): Module name that belongs to the model. Such as + `head`, `head.loss`, etc. + warmup_epochs (int): The warmup epochs for this parameter. + """ + + def __init__( + self, + param_name: str, + module_name: str, + warmup_epochs: int, + ) -> None: + self.param_name = param_name + self.warmup_epochs = warmup_epochs + # getter for module which saves the changed parameter + self.module_getter = op.attrgetter(module_name) + + def get_param(self, runner) -> Any: + """Get the parameter.""" + try: + module = self.module_getter(get_ori_model(runner.model)) + return getattr(module, self.param_name) + except AttributeError as e: + raise AttributeError(f'{e}. Please check hook settings.') + + def set_param(self, runner, value) -> None: + """Set the parameter.""" + try: + module = self.module_getter(get_ori_model(runner.model)) + setattr(module, self.param_name, value) + except AttributeError as e: + raise AttributeError(f'{e}. Please check hook settings.') + + def before_train(self, runner) -> None: + """Get the original value before train.""" + self.ori_val = self.get_param(runner) + + def before_train_iter( + self, + runner, + batch_idx: int, + data_batch: Optional[Union[dict, tuple, list]] = None) -> None: + """Set the warmup value before each train iter.""" + cur_iter = runner.iter + iters_per_epoch = runner.max_iters / runner.max_epochs + new_val = self.ori_val * min( + 1, cur_iter / (self.warmup_epochs * iters_per_epoch)) + self.set_param(runner, new_val) diff --git a/mmpretrain/engine/optimizers/__init__.py b/mmpretrain/engine/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd53a37630b2a0dfbb69b1020518b9ec4ff03715 --- /dev/null +++ b/mmpretrain/engine/optimizers/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .adan_t import Adan +from .lamb import Lamb +from .lars import LARS +from .layer_decay_optim_wrapper_constructor import \ + LearningRateDecayOptimWrapperConstructor + +__all__ = ['Lamb', 'Adan', 'LARS', 'LearningRateDecayOptimWrapperConstructor'] diff --git a/mmpretrain/engine/optimizers/adan_t.py b/mmpretrain/engine/optimizers/adan_t.py new file mode 100644 index 0000000000000000000000000000000000000000..571a71b6fe561fb33053af2fd6d2161a775918e4 --- /dev/null +++ b/mmpretrain/engine/optimizers/adan_t.py @@ -0,0 +1,312 @@ +# Copyright 2022 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List + +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer + +from mmpretrain.registry import OPTIMIZERS + + +@OPTIMIZERS.register_module() +class Adan(Optimizer): + """Implements a pytorch variant of Adan. + + Adan was proposed in + Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models. # noqa + https://arxiv.org/abs/2208.06677 + Arguments: + params (iterable): iterable of parameters to optimize + or dicts defining parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float, float], optional): coefficients used + for computing running averages of gradient. + (default: (0.98, 0.92, 0.99)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): decoupled weight decay + (L2 penalty) (default: 0) + max_grad_norm (float, optional): value used to clip + global grad norm (default: 0.0 no clip) + no_prox (bool): how to perform the decoupled weight decay + (default: False) + foreach (bool): if True would use torch._foreach implementation. + It's faster but uses slightly more memory. + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.98, 0.92, 0.99), + eps=1e-8, + weight_decay=0.0, + max_grad_norm=0.0, + no_prox=False, + foreach: bool = True): + if not 0.0 <= max_grad_norm: + raise ValueError('Invalid Max grad norm: {}'.format(max_grad_norm)) + if not 0.0 <= lr: + raise ValueError('Invalid learning rate: {}'.format(lr)) + if not 0.0 <= eps: + raise ValueError('Invalid epsilon value: {}'.format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError('Invalid beta parameter at index 0: {}'.format( + betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError('Invalid beta parameter at index 1: {}'.format( + betas[1])) + if not 0.0 <= betas[2] < 1.0: + raise ValueError('Invalid beta parameter at index 2: {}'.format( + betas[2])) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + max_grad_norm=max_grad_norm, + no_prox=no_prox, + foreach=foreach) + super().__init__(params, defaults) + + def __setstate__(self, state): + super(Adan, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('no_prox', False) + + @torch.no_grad() + def restart_opt(self): + for group in self.param_groups: + group['step'] = 0 + for p in group['params']: + if p.requires_grad: + state = self.state[p] + # State initialization + + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + # Exponential moving average of gradient difference + state['exp_avg_diff'] = torch.zeros_like(p) + + @torch.no_grad() + def step(self): + """Performs a single optimization step.""" + if self.defaults['max_grad_norm'] > 0: + device = self.param_groups[0]['params'][0].device + global_grad_norm = torch.zeros(1, device=device) + + max_grad_norm = torch.tensor( + self.defaults['max_grad_norm'], device=device) + for group in
self.param_groups: + + for p in group['params']: + if p.grad is not None: + grad = p.grad + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm = torch.sqrt(global_grad_norm) + group['eps'] + + clip_global_grad_norm = \ + torch.clamp(max_grad_norm / global_grad_norm, max=1.0) + else: + clip_global_grad_norm = 1.0 + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + exp_avg_diffs = [] + pre_grads = [] + + beta1, beta2, beta3 = group['betas'] + # assume same step across group now to simplify things + # per parameter step can be easily support + # by making it tensor, or pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + bias_correction1 = 1.0 - beta1**group['step'] + bias_correction2 = 1.0 - beta2**group['step'] + bias_correction3 = 1.0 - beta3**group['step'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + grads.append(p.grad) + + state = self.state[p] + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + state['exp_avg_diff'] = torch.zeros_like(p) + + if 'pre_grad' not in state or group['step'] == 1: + # at first step grad wouldn't be clipped + # by `clip_global_grad_norm` + # this is only to simplify implementation + state['pre_grad'] = p.grad + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + exp_avg_diffs.append(state['exp_avg_diff']) + pre_grads.append(state['pre_grad']) + + kwargs = dict( + params=params_with_grad, + grads=grads, + exp_avgs=exp_avgs, + exp_avg_sqs=exp_avg_sqs, + exp_avg_diffs=exp_avg_diffs, + pre_grads=pre_grads, + beta1=beta1, + beta2=beta2, + beta3=beta3, + bias_correction1=bias_correction1, + bias_correction2=bias_correction2, + bias_correction3_sqrt=math.sqrt(bias_correction3), + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + no_prox=group['no_prox'], + clip_global_grad_norm=clip_global_grad_norm, + ) + if group['foreach']: + copy_grads = _multi_tensor_adan(**kwargs) + else: + copy_grads = _single_tensor_adan(**kwargs) + + for p, copy_grad in zip(params_with_grad, copy_grads): + self.state[p]['pre_grad'] = copy_grad + + +def _single_tensor_adan( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + exp_avg_diffs: List[Tensor], + pre_grads: List[Tensor], + *, + beta1: float, + beta2: float, + beta3: float, + bias_correction1: float, + bias_correction2: float, + bias_correction3_sqrt: float, + lr: float, + weight_decay: float, + eps: float, + no_prox: bool, + clip_global_grad_norm: Tensor, +): + copy_grads = [] + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + exp_avg_diff = exp_avg_diffs[i] + pre_grad = pre_grads[i] + + grad = grad.mul_(clip_global_grad_norm) + copy_grads.append(grad.clone()) + + diff = grad - pre_grad + update = grad + beta2 * diff + + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t + exp_avg_diff.mul_(beta2).add_(diff, alpha=1 - beta2) # diff_t + exp_avg_sq.mul_(beta3).addcmul_(update, update, value=1 - beta3) # n_t + + denom = (exp_avg_sq.sqrt() / bias_correction3_sqrt).add_(eps) + update = exp_avg / bias_correction1 + update.add_(beta2 * exp_avg_diff / bias_correction2).div_(denom) + + if no_prox: + param.mul_(1 - lr * weight_decay) + param.add_(update, alpha=-lr) + else: + param.add_(update, alpha=-lr) + param.div_(1 + lr * weight_decay) + return copy_grads + + 
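A minimal usage sketch for the `Adan` optimizer defined above; the model and hyper-parameters are illustrative:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = Adan(model.parameters(), lr=1e-3, weight_decay=0.02)
loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()
```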
+def _multi_tensor_adan( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + exp_avg_diffs: List[Tensor], + pre_grads: List[Tensor], + *, + beta1: float, + beta2: float, + beta3: float, + bias_correction1: float, + bias_correction2: float, + bias_correction3_sqrt: float, + lr: float, + weight_decay: float, + eps: float, + no_prox: bool, + clip_global_grad_norm: Tensor, +): + if clip_global_grad_norm < 1.0: + torch._foreach_mul_(grads, clip_global_grad_norm.item()) + copy_grads = [g.clone() for g in grads] + + diff = torch._foreach_sub(grads, pre_grads) + # NOTE: the line below, while looking identical, gives a different + # result due to float precision errors. + # using mul+add produces identical results to single-tensor, + # using add+alpha doesn't + # update = torch._foreach_add(grads, torch._foreach_mul(diff, beta2)) + update = torch._foreach_add(grads, diff, alpha=beta2) + + torch._foreach_mul_(exp_avgs, beta1) + torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) # m_t + + torch._foreach_mul_(exp_avg_diffs, beta2) + torch._foreach_add_(exp_avg_diffs, diff, alpha=1 - beta2) # diff_t + + torch._foreach_mul_(exp_avg_sqs, beta3) + torch._foreach_addcmul_( + exp_avg_sqs, update, update, value=1 - beta3) # n_t + + denom = torch._foreach_sqrt(exp_avg_sqs) + torch._foreach_div_(denom, bias_correction3_sqrt) + torch._foreach_add_(denom, eps) + + update = torch._foreach_div(exp_avgs, bias_correction1) + # NOTE: same issue as above. + # beta2 * diff / bias_correction2 != diff * (beta2 / bias_correction2) # noqa + # using faster version by default. uncomment for tests to pass + # torch._foreach_add_(update, torch._foreach_div(torch._foreach_mul(exp_avg_diffs, beta2), bias_correction2)) # noqa + torch._foreach_add_( + update, torch._foreach_mul(exp_avg_diffs, beta2 / bias_correction2)) + torch._foreach_div_(update, denom) + + if no_prox: + torch._foreach_mul_(params, 1 - lr * weight_decay) + # mirror the single-tensor path: apply the update after the decay + torch._foreach_add_(params, update, alpha=-lr) + else: + torch._foreach_add_(params, update, alpha=-lr) + torch._foreach_div_(params, 1 + lr * weight_decay) + return copy_grads diff --git a/mmpretrain/engine/optimizers/lamb.py b/mmpretrain/engine/optimizers/lamb.py new file mode 100644 index 0000000000000000000000000000000000000000..0b44a1c168e03fa7f569388beec206fe68c64749 --- /dev/null +++ b/mmpretrain/engine/optimizers/lamb.py @@ -0,0 +1,228 @@ +"""PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb. + +This optimizer code was adapted from the following (starting with latest) +* https://github.com/HabanaAI/Model-References/blob/ +2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py +* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/ +LanguageModeling/Transformer-XL/pytorch/lamb.py +* https://github.com/cybertronai/pytorch-lamb + +Use FusedLamb if you can (GPU). The reason for including this variant of Lamb +is to have a version that is +similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or +cannot install/use APEX. + +In addition to some cleanup, this Lamb impl has been modified to support +PyTorch XLA and has been tested on TPU. + +Original copyrights for above sources are below. + +Modifications Copyright 2021 Ross Wightman +""" +# Copyright (c) 2021, Habana Labs Ltd. All rights reserved. + +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MIT License +# +# Copyright (c) 2019 cybertronai +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import math + +import torch +from torch.optim import Optimizer + +from mmpretrain.registry import OPTIMIZERS + + +@OPTIMIZERS.register_module() +class Lamb(Optimizer): + """A pure pytorch variant of FuseLAMB (NvLamb variant) optimizer. + + This class is copied from `timm`_. The LAMB was proposed in `Large Batch + Optimization for Deep Learning - Training BERT in 76 minutes`_. + + .. _timm: + https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py + .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its norm. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + grad_averaging (bool, optional): whether apply (1-beta2) to grad when + calculating running averages of gradient. 
(default: True) + max_grad_norm (float, optional): value used to clip global grad norm + (default: 1.0) + trust_clip (bool): enable LAMBC trust ratio clipping (default: False) + always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 + weight decay parameter (default: False) + """ # noqa: E501 + + def __init__(self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0.01, + grad_averaging=True, + max_grad_norm=1.0, + trust_clip=False, + always_adapt=False): + defaults = dict( + lr=lr, + bias_correction=bias_correction, + betas=betas, + eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + max_grad_norm=max_grad_norm, + trust_clip=trust_clip, + always_adapt=always_adapt) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + device = self.param_groups[0]['params'][0].device + one_tensor = torch.tensor( + 1.0, device=device + ) # because torch.where doesn't handle scalars correctly + global_grad_norm = torch.zeros(1, device=device) + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + 'Lamb does not support sparse gradients, consider ' + 'SparseAdam instead.') + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm = torch.sqrt(global_grad_norm) + # FIXME it'd be nice to remove explicit tensor conversion of scalars + # when torch.where promotes + # scalar types properly https://github.com/pytorch/pytorch/issues/9190 + max_grad_norm = torch.tensor( + self.defaults['max_grad_norm'], device=device) + clip_global_grad_norm = torch.where(global_grad_norm > max_grad_norm, + global_grad_norm / max_grad_norm, + one_tensor) + + for group in self.param_groups: + bias_correction = 1 if group['bias_correction'] else 0 + beta1, beta2 = group['betas'] + grad_averaging = 1 if group['grad_averaging'] else 0 + beta3 = 1 - beta1 if grad_averaging else 1.0 + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or + # pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + if bias_correction: + bias_correction1 = 1 - beta1**group['step'] + bias_correction2 = 1 - beta2**group['step'] + else: + bias_correction1, bias_correction2 = 1.0, 1.0 + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.div_(clip_global_grad_norm) + state = self.state[p] + + # State initialization + if len(state) == 0: + # Exponential moving average of gradient valuesa + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t + exp_avg_sq.mul_(beta2).addcmul_( + grad, grad, value=1 - beta2) # v_t + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( + group['eps']) + update = (exp_avg / bias_correction1).div_(denom) + + weight_decay = group['weight_decay'] + if weight_decay != 0: + update.add_(p, alpha=weight_decay) + + if weight_decay != 0 or group['always_adapt']: + # 
Layer-wise LR adaptation. By default, skip adaptation on
+                    # parameters that are
+                    # excluded from weight decay, unless always_adapt == True,
+                    # then always enabled.
+                    w_norm = p.norm(2.0)
+                    g_norm = update.norm(2.0)
+                    # FIXME nested where required since logical and/or not
+                    # working in PT XLA
+                    trust_ratio = torch.where(
+                        w_norm > 0,
+                        torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
+                        one_tensor,
+                    )
+                    if group['trust_clip']:
+                        # LAMBC trust clipping, upper bound fixed at one
+                        trust_ratio = torch.minimum(trust_ratio, one_tensor)
+                    update.mul_(trust_ratio)
+
+                p.add_(update, alpha=-group['lr'])
+
+        return loss
diff --git a/mmpretrain/engine/optimizers/lars.py b/mmpretrain/engine/optimizers/lars.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e388878374e3d1e7408861a5f1830b00df5664b
--- /dev/null
+++ b/mmpretrain/engine/optimizers/lars.py
@@ -0,0 +1,130 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Iterable
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+from mmpretrain.registry import OPTIMIZERS
+
+
+@OPTIMIZERS.register_module()
+class LARS(Optimizer):
+    """Implements layer-wise adaptive rate scaling for SGD.
+
+    Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg:
+    `Large Batch Training of Convolutional Networks
+    <https://arxiv.org/abs/1708.03888>`_.
+
+    Args:
+        params (Iterable): Iterable of parameters to optimize or dicts
+            defining parameter groups.
+        lr (float): Base learning rate.
+        momentum (float): Momentum factor. Defaults to 0.
+        weight_decay (float): Weight decay (L2 penalty). Defaults to 0.
+        dampening (float): Dampening for momentum. Defaults to 0.
+        eta (float): LARS coefficient. Defaults to 0.001.
+        nesterov (bool): Enables Nesterov momentum. Defaults to False.
+        eps (float): A small number to avoid dividing by zero.
+            Defaults to 1e-8.
+
+    Example:
+        >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9,
+        >>>                  weight_decay=1e-4, eta=1e-3)
+        >>> optimizer.zero_grad()
+        >>> loss_fn(model(input), target).backward()
+        >>> optimizer.step()
+    """
+
+    def __init__(self,
+                 params: Iterable,
+                 lr: float,
+                 momentum: float = 0,
+                 weight_decay: float = 0,
+                 dampening: float = 0,
+                 eta: float = 0.001,
+                 nesterov: bool = False,
+                 eps: float = 1e-8) -> None:
+        # ``lr`` must be a non-negative float; the original ``and`` made
+        # this check a no-op for float learning rates.
+        if not isinstance(lr, float) or lr < 0.0:
+            raise ValueError(f'Invalid learning rate: {lr}')
+        if momentum < 0.0:
+            raise ValueError(f'Invalid momentum value: {momentum}')
+        if weight_decay < 0.0:
+            raise ValueError(f'Invalid weight_decay value: {weight_decay}')
+        if eta < 0.0:
+            raise ValueError(f'Invalid LARS coefficient value: {eta}')
+
+        defaults = dict(
+            lr=lr,
+            momentum=momentum,
+            dampening=dampening,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            eta=eta)
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError(
+                'Nesterov momentum requires a momentum and zero dampening')
+
+        self.eps = eps
+        super().__init__(params, defaults)
+
+    def __setstate__(self, state) -> None:
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('nesterov', False)
+
+    @torch.no_grad()
+    def step(self, closure=None) -> torch.Tensor:
+        """Performs a single optimization step.
+
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + eta = group['eta'] + nesterov = group['nesterov'] + lr = group['lr'] + lars_exclude = group.get('lars_exclude', False) + + for p in group['params']: + if p.grad is None: + continue + + d_p = p.grad + + if lars_exclude: + local_lr = 1. + else: + weight_norm = torch.norm(p).item() + grad_norm = torch.norm(d_p).item() + if weight_norm != 0 and grad_norm != 0: + # Compute local learning rate for this layer + local_lr = eta * weight_norm / \ + (grad_norm + weight_decay * weight_norm + self.eps) + else: + local_lr = 1. + + actual_lr = local_lr * lr + d_p = d_p.add(p, alpha=weight_decay).mul(actual_lr) + if momentum != 0: + param_state = self.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = \ + torch.clone(d_p).detach() + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + if nesterov: + d_p = d_p.add(buf, alpha=momentum) + else: + d_p = buf + p.add_(-d_p) + + return loss diff --git a/mmpretrain/engine/optimizers/layer_decay_optim_wrapper_constructor.py b/mmpretrain/engine/optimizers/layer_decay_optim_wrapper_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..09c6abc54a9f49cc789bf91d2bf74b0ec68902c4 --- /dev/null +++ b/mmpretrain/engine/optimizers/layer_decay_optim_wrapper_constructor.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import defaultdict +from typing import Callable, List, Optional + +from mmengine.logging import MMLogger +from mmengine.optim import DefaultOptimWrapperConstructor +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm +from torch import nn +from torch.nn import GroupNorm, LayerNorm + +from mmpretrain.registry import OPTIM_WRAPPER_CONSTRUCTORS + + +@OPTIM_WRAPPER_CONSTRUCTORS.register_module() +class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor): + """Different learning rates are set for different layers of backbone. + + By default, each parameter share the same optimizer settings, and we + provide an argument ``paramwise_cfg`` to specify parameter-wise settings. + It is a dict and may contain the following fields: + + - ``layer_decay_rate`` (float): The learning rate of a parameter will + multiply it by multiple times according to the layer depth of the + parameter. Usually, it's less than 1, so that the earlier layers will + have a lower learning rate. Defaults to 1. + - ``bias_decay_mult`` (float): It will be multiplied to the weight + decay for all bias parameters (except for those in normalization layers). + - ``norm_decay_mult`` (float): It will be multiplied to the weight + decay for all weight and bias parameters of normalization layers. + - ``flat_decay_mult`` (float): It will be multiplied to the weight + decay for all one-dimensional parameters + - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If + one of the keys in ``custom_keys`` is a substring of the name of one + parameter, then the setting of the parameter will be specified by + ``custom_keys[key]`` and other setting like ``bias_decay_mult`` will be + ignored. It should be a dict and may contain fields ``decay_mult``. + (The ``lr_mult`` is disabled in this constructor). 
+ + Example: + + In the config file, you can use this constructor as below: + + .. code:: python + + optim_wrapper = dict( + optimizer=dict( + type='AdamW', + lr=4e-3, + weight_decay=0.05, + eps=1e-8, + betas=(0.9, 0.999)), + constructor='LearningRateDecayOptimWrapperConstructor', + paramwise_cfg=dict( + layer_decay_rate=0.75, # layer-wise lr decay factor + norm_decay_mult=0., + flat_decay_mult=0., + custom_keys={ + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0) + })) + """ + + def add_params(self, + params: List[dict], + module: nn.Module, + prefix: str = '', + get_layer_depth: Optional[Callable] = None, + **kwargs) -> None: + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (List[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + optimizer_cfg (dict): The configuration of optimizer. + prefix (str): The prefix of the module. + """ + # get param-wise options + custom_keys = self.paramwise_cfg.get('custom_keys', {}) + # first sort with alphabet order and then sort with reversed len of str + sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) + logger = MMLogger.get_current_instance() + + # The model should have `get_layer_depth` method + if get_layer_depth is None and not hasattr(module, 'get_layer_depth'): + raise NotImplementedError('The layer-wise learning rate decay need' + f' the model {type(module)} has' + ' `get_layer_depth` method.') + else: + get_layer_depth = get_layer_depth or module.get_layer_depth + + bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None) + norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None) + flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None) + decay_rate = self.paramwise_cfg.get('layer_decay_rate', 1.0) + + # special rules for norm layers and depth-wise conv layers + is_norm = isinstance(module, + (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) + + for name, param in module.named_parameters(recurse=False): + param_group = {'params': [param]} + param_name = prefix + name + if not param.requires_grad: + continue + + if self.base_wd is not None: + base_wd = self.base_wd + custom_key = next( + filter(lambda k: k in param_name, sorted_keys), None) + # custom parameters decay + if custom_key is not None: + custom_cfg = custom_keys[custom_key].copy() + decay_mult = custom_cfg.pop('decay_mult', 1.) + + param_group['weight_decay'] = base_wd * decay_mult + # add custom settings to param_group + param_group.update(custom_cfg) + # norm decay + elif is_norm and norm_decay_mult is not None: + param_group['weight_decay'] = base_wd * norm_decay_mult + # bias decay + elif name == 'bias' and bias_decay_mult is not None: + param_group['weight_decay'] = base_wd * bias_decay_mult + # flatten parameters decay + elif param.ndim == 1 and flat_decay_mult is not None: + param_group['weight_decay'] = base_wd * flat_decay_mult + else: + param_group['weight_decay'] = base_wd + + layer_id, max_id = get_layer_depth(param_name) + scale = decay_rate**(max_id - layer_id - 1) + param_group['lr'] = self.base_lr * scale + param_group['lr_scale'] = scale + param_group['layer_id'] = layer_id + param_group['param_name'] = param_name + + params.append(param_group) + + for child_name, child_mod in module.named_children(): + child_prefix = f'{prefix}{child_name}.' 
+                self.add_params(
+                    params,
+                    child_mod,
+                    prefix=child_prefix,
+                    get_layer_depth=get_layer_depth,
+                )
+
+        if prefix == '':
+            layer_params = defaultdict(list)
+            for param in params:
+                layer_params[param['layer_id']].append(param)
+            for layer_id, layer_params in layer_params.items():
+                lr_scale = layer_params[0]['lr_scale']
+                lr = layer_params[0]['lr']
+                msg = [
+                    f'layer {layer_id} params '
+                    f'(lr={lr:.3g}, lr_scale={lr_scale:.3g}):'
+                ]
+                for param in layer_params:
+                    msg.append(f'\t{param["param_name"]}: '
+                               f'weight_decay={param["weight_decay"]:.3g}')
+                logger.debug('\n'.join(msg))
diff --git a/mmpretrain/engine/runners/__init__.py b/mmpretrain/engine/runners/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..23206e1ea7c83fa1d547c677b3fe5203f8c5485f
--- /dev/null
+++ b/mmpretrain/engine/runners/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .retrieval_loop import RetrievalTestLoop, RetrievalValLoop
+
+__all__ = ['RetrievalTestLoop', 'RetrievalValLoop']
diff --git a/mmpretrain/engine/runners/retrieval_loop.py b/mmpretrain/engine/runners/retrieval_loop.py
new file mode 100644
index 0000000000000000000000000000000000000000..d15387eddeb9075c23949f95a77ed59006bb9a38
--- /dev/null
+++ b/mmpretrain/engine/runners/retrieval_loop.py
@@ -0,0 +1,168 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+from mmengine.model import is_model_wrapper
+from mmengine.runner import TestLoop, ValLoop, autocast
+
+from mmpretrain.registry import LOOPS
+
+
+@LOOPS.register_module()
+class RetrievalValLoop(ValLoop):
+    """Loop for multimodal retrieval validation.
+
+    Args:
+        runner (Runner): A reference of runner.
+        dataloader (Dataloader or dict): A dataloader object or a dict to
+            build a dataloader.
+        evaluator (Evaluator or dict or list): Used for computing metrics.
+        fp16 (bool): Whether to enable fp16 validation.
+            Defaults to False.
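+
+    Example:
+        An illustrative config sketch (an assumed snippet, not part of this
+        diff) for selecting this loop; it presumes the dataset exposes
+        ``img_size``/``text_size`` and the model implements ``predict_all``:
+
+        .. code:: python
+
+            val_cfg = dict(type='RetrievalValLoop', fp16=False)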
+ """ + + def run(self) -> dict: + """Launch val.""" + self.runner.call_hook('before_val') + self.runner.call_hook('before_val_epoch') + self.runner.model.eval() + + feats_local = [] + data_samples_local = [] + + for idx, data_batch in enumerate(self.dataloader): + with torch.no_grad(): + self.runner.call_hook( + 'before_val_iter', batch_idx=idx, data_batch=data_batch) + # predictions should be sequence of BaseDataElement + with autocast(enabled=self.fp16): + if is_model_wrapper(self.runner.model): + data_preprocessor = self.runner.model.module.data_preprocessor # noqa: E501 + else: + data_preprocessor = self.runner.model.data_preprocessor + + # get features for retrieval instead of data samples + data_batch = data_preprocessor(data_batch, False) + feats = self.runner.model._run_forward( + data_batch, mode='tensor') + feats_local.append(feats) + data_samples_local.extend(data_batch['data_samples']) + self.runner.call_hook( + 'after_val_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=feats) + + # concatenate different features + feats_local = { + k: torch.cat([dic[k] for dic in feats_local]) + for k in feats_local[0] + } + + # get predictions + if is_model_wrapper(self.runner.model): + predict_all_fn = self.runner.model.module.predict_all + else: + predict_all_fn = self.runner.model.predict_all + + img_size = self.dataloader.dataset.img_size + text_size = self.dataloader.dataset.text_size + with torch.no_grad(): + i2t_data_samples, t2i_data_samples = predict_all_fn( + feats_local, + data_samples_local, + num_images=img_size, + num_texts=text_size, + ) + + # process in evaluator and compute metrics + self.evaluator.process(i2t_data_samples, None) + i2t_metrics = self.evaluator.evaluate(img_size) + i2t_metrics = {f'i2t/{k}': v for k, v in i2t_metrics.items()} + self.evaluator.process(t2i_data_samples, None) + t2i_metrics = self.evaluator.evaluate(text_size) + t2i_metrics = {f't2i/{k}': v for k, v in t2i_metrics.items()} + metrics = {**i2t_metrics, **t2i_metrics} + + self.runner.call_hook('after_val_epoch', metrics=metrics) + self.runner.call_hook('after_val') + return metrics + + +@LOOPS.register_module() +class RetrievalTestLoop(TestLoop): + """Loop for multimodal retrieval test. + + Args: + runner (Runner): A reference of runner. + dataloader (Dataloader or dict): A dataloader object or a dict to + build a dataloader. + evaluator (Evaluator or dict or list): Used for computing metrics. + fp16 (bool): Whether to enable fp16 testing. Defaults to + False. 
+ """ + + def run(self) -> dict: + """Launch test.""" + self.runner.call_hook('before_test') + self.runner.call_hook('before_test_epoch') + self.runner.model.eval() + + feats_local = [] + data_samples_local = [] + + for idx, data_batch in enumerate(self.dataloader): + with torch.no_grad(): + self.runner.call_hook( + 'before_test_iter', batch_idx=idx, data_batch=data_batch) + # predictions should be sequence of BaseDataElement + with autocast(enabled=self.fp16): + if is_model_wrapper(self.runner.model): + data_preprocessor = self.runner.model.module.data_preprocessor # noqa: E501 + else: + data_preprocessor = self.runner.model.data_preprocessor + # get features for retrieval instead of data samples + data_batch = data_preprocessor(data_batch, False) + feats = self.runner.model._run_forward( + data_batch, mode='tensor') + feats_local.append(feats) + data_samples_local.extend(data_batch['data_samples']) + self.runner.call_hook( + 'after_test_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=feats) + + # concatenate different features + feats_local = { + k: torch.cat([dic[k] for dic in feats_local]) + for k in feats_local[0] + } + + # get predictions + if is_model_wrapper(self.runner.model): + predict_all_fn = self.runner.model.module.predict_all + else: + predict_all_fn = self.runner.model.predict_all + + img_size = self.dataloader.dataset.img_size + text_size = self.dataloader.dataset.text_size + with torch.no_grad(): + i2t_data_samples, t2i_data_samples = predict_all_fn( + feats_local, + data_samples_local, + num_images=img_size, + num_texts=text_size, + ) + + # process in evaluator and compute metrics + self.evaluator.process(i2t_data_samples, None) + i2t_metrics = self.evaluator.evaluate(img_size) + i2t_metrics = {f'i2t/{k}': v for k, v in i2t_metrics.items()} + self.evaluator.process(t2i_data_samples, None) + t2i_metrics = self.evaluator.evaluate(text_size) + t2i_metrics = {f't2i/{k}': v for k, v in t2i_metrics.items()} + metrics = {**i2t_metrics, **t2i_metrics} + + self.runner.call_hook('after_test_epoch', metrics=metrics) + self.runner.call_hook('after_test') + return metrics diff --git a/mmpretrain/engine/schedulers/__init__.py b/mmpretrain/engine/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..68b6a5477b84a53e060e0e6d43fdac830adebffb --- /dev/null +++ b/mmpretrain/engine/schedulers/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .weight_decay_scheduler import CosineAnnealingWeightDecay + +__all__ = ['CosineAnnealingWeightDecay'] diff --git a/mmpretrain/engine/schedulers/weight_decay_scheduler.py b/mmpretrain/engine/schedulers/weight_decay_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..7e725a4c3f53856cf848ed7e6a225a178b36ab98 --- /dev/null +++ b/mmpretrain/engine/schedulers/weight_decay_scheduler.py @@ -0,0 +1,64 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +from mmengine.optim.scheduler import CosineAnnealingParamScheduler + +from mmpretrain.registry import PARAM_SCHEDULERS + + +class WeightDecaySchedulerMixin: + """A mixin class for learning rate schedulers.""" + + def __init__(self, optimizer, *args, **kwargs): + super().__init__(optimizer, 'weight_decay', *args, **kwargs) + + +@PARAM_SCHEDULERS.register_module() +class CosineAnnealingWeightDecay(WeightDecaySchedulerMixin, + CosineAnnealingParamScheduler): + """Set the weight decay value of each parameter group using a cosine + annealing schedule. 
+ + If the weight decay was set to be 0 initially, the weight decay value will + be 0 constantly during the training. + """ + + def _get_value(self) -> list: + """Compute value using chainable form of the scheduler.""" + + def _get_eta_min(base_value): + if self.eta_min_ratio is None: + return self.eta_min + return base_value * self.eta_min_ratio + + if self.last_step == 0: + return [ + group[self.param_name] for group in self.optimizer.param_groups + ] + elif (self.last_step - 1 - self.T_max) % (2 * self.T_max) == 0: + weight_decay_value_list = [] + for base_value, group in zip(self.base_values, + self.optimizer.param_groups): + if base_value == 0: + group_value = 0 + else: + group_value = group[self.param_name] + ( + base_value - _get_eta_min(base_value)) * ( + 1 - math.cos(math.pi / self.T_max)) / 2 + weight_decay_value_list.append(group_value) + return weight_decay_value_list + + weight_decay_value_list = [] + for base_value, group in zip(self.base_values, + self.optimizer.param_groups): + if base_value == 0: + group_value = 0 + else: + group_value = ( + 1 + math.cos(math.pi * self.last_step / self.T_max)) / ( + 1 + math.cos(math.pi * + (self.last_step - 1) / self.T_max) + ) * (group[self.param_name] - + _get_eta_min(base_value)) + _get_eta_min(base_value) + weight_decay_value_list.append(group_value) + return weight_decay_value_list diff --git a/mmpretrain/evaluation/.DS_Store b/mmpretrain/evaluation/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..224f1978a6ddf3a50f9548352948f8a6b40d9edc Binary files /dev/null and b/mmpretrain/evaluation/.DS_Store differ diff --git a/mmpretrain/evaluation/__init__.py b/mmpretrain/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f70dc226d30f7b8e4ee5a44ca163ad1ae04eabf5 --- /dev/null +++ b/mmpretrain/evaluation/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .functional import * # noqa: F401,F403 +from .metrics import * # noqa: F401,F403 diff --git a/mmpretrain/evaluation/__pycache__/__init__.cpython-311.pyc b/mmpretrain/evaluation/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ccb74dbb539364c930cafe2f7ff66661a33cbf6 Binary files /dev/null and b/mmpretrain/evaluation/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmpretrain/evaluation/functional/__init__.py b/mmpretrain/evaluation/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef101fec61e72abc0eb90266d453b5b22331378d --- /dev/null +++ b/mmpretrain/evaluation/functional/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/mmpretrain/evaluation/functional/__pycache__/__init__.cpython-311.pyc b/mmpretrain/evaluation/functional/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d078ddfc96f8e3e7d005254612f7dfe38588be6c Binary files /dev/null and b/mmpretrain/evaluation/functional/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmpretrain/evaluation/metrics/ANLS.py b/mmpretrain/evaluation/metrics/ANLS.py new file mode 100644 index 0000000000000000000000000000000000000000..14917f16e343b1f9c73a44af34f800c3ae72fd22 --- /dev/null +++ b/mmpretrain/evaluation/metrics/ANLS.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
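+#
+# A worked example of the ANLS scoring implemented below: with ground-truth
+# answers ['yes'] and prediction 'yes!', the normalized Levenshtein distance
+# is levenshtein('yes', 'yes!') / max(len('yes'), len('yes!')) = 1 / 4 = 0.25,
+# so the per-sample score is 1 - 0.25 = 0.75. Scores below the threshold
+# (0.5 by default) are zeroed before averaging over all samples.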
+from typing import List, Optional + +from mmengine.evaluator import BaseMetric + +from mmpretrain.registry import METRICS + + +@METRICS.register_module() +class ANLS(BaseMetric): + """ANLS metric. + + Compute the Average Normalized Levenshtein Similarity(ANLS). + + Args: + threshold (float): ANLS threshold used for determining if the answer + has been correctly selected but not properly recognized, + or on the contrary, the output is a wrong text selected from the + options and given as an answer. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + default_prefix = 'ANLS' + + def __init__(self, + threshold: float = 0.5, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + self.threshold = threshold + + def process(self, data_batch, data_samples) -> None: + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for sample in data_samples: + gt_answer = sample.get('gt_answer') + result = { + 'pred_answer': sample.get('pred_answer'), + 'gt_answer': gt_answer + } + + self.results.append(result) + + def compute_metrics(self, results: List) -> dict: + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + total_score = 0. + for result in results: + sample_score_list = [] + pred = ' '.join(result['pred_answer'].strip().lower().split()) + for gt in result['gt_answer']: + gt = ' '.join(gt.strip().lower().split()) + dist = levenshtein_distance(gt, pred) + length = max( + len(gt.upper()), len(result['pred_answer'].upper())) + sample_score_list.append(0.0 if length == 0 else float(dist) / + float(length)) + + per_sample_score = 1. - min(sample_score_list) + if per_sample_score < self.threshold: + per_sample_score = 0. + + total_score += per_sample_score + + total_score = total_score / len(results) + return {'ANLS': total_score} + + +def levenshtein_distance(s1, s2): + if len(s1) > len(s2): + s1, s2 = s2, s1 + + distances = range(len(s1) + 1) + for i2, c2 in enumerate(s2): + distances_ = [i2 + 1] + for i1, c1 in enumerate(s1): + if c1 == c2: + distances_.append(distances[i1]) + else: + distances_.append(1 + min((distances[i1], distances[i1 + 1], + distances_[-1]))) + distances = distances_ + return distances[-1] diff --git a/mmpretrain/evaluation/metrics/__init__.py b/mmpretrain/evaluation/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e572efeb91e8ba64c46ab6241fe611bff136a210 --- /dev/null +++ b/mmpretrain/evaluation/metrics/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .ANLS import ANLS +from .caption import COCOCaption +from .gqa import GQAAcc +from .multi_label import AveragePrecision, MultiLabelMetric +from .multi_task import MultiTasksMetric +from .nocaps import NocapsSave +from .retrieval import RetrievalAveragePrecision, RetrievalRecall +from .scienceqa import ScienceQAMetric +from .shape_bias_label import ShapeBiasMetric +from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric +from .visual_grounding_eval import VisualGroundingMetric +from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric +from .vqa import ReportVQA, VQAAcc + +__all__ = [ + 'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision', + 'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric', + 'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption', + 'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave', + 'RetrievalAveragePrecision', 'ShapeBiasMetric', 'ANLS' +] diff --git a/mmpretrain/evaluation/metrics/__pycache__/ANLS.cpython-311.pyc b/mmpretrain/evaluation/metrics/__pycache__/ANLS.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53bf91b0eb85e52b2c8cd1f8d47069261073348a Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/ANLS.cpython-311.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/__init__.cpython-311.pyc b/mmpretrain/evaluation/metrics/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbcd41c7cc56ffbac84f8bc55fccc63325749e6d Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/caption.cpython-311.pyc b/mmpretrain/evaluation/metrics/__pycache__/caption.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d64fa722e88c9dce8d4c76530abd7b1c4157dd8b Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/caption.cpython-311.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/gqa.cpython-311.pyc b/mmpretrain/evaluation/metrics/__pycache__/gqa.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0242b959489b47e9ddadeab05dcaebbd1d2ed480 Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/gqa.cpython-311.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/multi_label.cpython-311.pyc b/mmpretrain/evaluation/metrics/__pycache__/multi_label.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92923954fe698ee0b9923f67a782bb7ffd0a3c2e Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/multi_label.cpython-311.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/multi_task.cpython-311.pyc b/mmpretrain/evaluation/metrics/__pycache__/multi_task.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08918b40858e23126564ee758c119d7e6a48b0af Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/multi_task.cpython-311.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/nocaps.cpython-311.pyc b/mmpretrain/evaluation/metrics/__pycache__/nocaps.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54e85104891a4c758f3290f17cd57ddf6f15ee3e Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/nocaps.cpython-311.pyc differ diff --git 
a/mmpretrain/evaluation/metrics/__pycache__/retrieval.cpython-311.pyc b/mmpretrain/evaluation/metrics/__pycache__/retrieval.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc6b1c9e8120c2444d3ca329c8f550a4fe9a3fd2 Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/retrieval.cpython-311.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/scienceqa.cpython-311.pyc b/mmpretrain/evaluation/metrics/__pycache__/scienceqa.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f75ab55d5204c40c19381d0a4a9656b586fbac06 Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/scienceqa.cpython-311.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/shape_bias_label.cpython-311.pyc b/mmpretrain/evaluation/metrics/__pycache__/shape_bias_label.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a25add97f8915ab76cda309331827798159d2bdc Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/shape_bias_label.cpython-311.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/single_label.cpython-311.pyc b/mmpretrain/evaluation/metrics/__pycache__/single_label.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fefd58b24eb2d6c557a948aa57f9ba517e92680 Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/single_label.cpython-311.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/visual_grounding_eval.cpython-311.pyc b/mmpretrain/evaluation/metrics/__pycache__/visual_grounding_eval.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47da0f903258939fd147520f75ea99229f3fe995 Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/visual_grounding_eval.cpython-311.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/voc_multi_label.cpython-311.pyc b/mmpretrain/evaluation/metrics/__pycache__/voc_multi_label.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49fcaa7276a83d7836da2c6c9ff0cef4f95a1b67 Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/voc_multi_label.cpython-311.pyc differ diff --git a/mmpretrain/evaluation/metrics/__pycache__/vqa.cpython-311.pyc b/mmpretrain/evaluation/metrics/__pycache__/vqa.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66e9aa5dfb026fc8865a0bc1be5cbf3c42b7cf96 Binary files /dev/null and b/mmpretrain/evaluation/metrics/__pycache__/vqa.cpython-311.pyc differ diff --git a/mmpretrain/evaluation/metrics/caption.py b/mmpretrain/evaluation/metrics/caption.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bffabfa97a9c6faec7ecc0ffb6d9ba2f435b97 --- /dev/null +++ b/mmpretrain/evaluation/metrics/caption.py @@ -0,0 +1,136 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os +import tempfile +from typing import List, Optional + +from mmengine.evaluator import BaseMetric +from mmengine.utils import track_iter_progress + +from mmpretrain.registry import METRICS +from mmpretrain.utils import require + +try: + from pycocoevalcap.eval import COCOEvalCap + from pycocotools.coco import COCO +except ImportError: + COCOEvalCap = None + COCO = None + + +@METRICS.register_module() +class COCOCaption(BaseMetric): + """Coco Caption evaluation wrapper. + + Save the generated captions and transform into coco format. + Calling COCO API for caption metrics. 
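+
+    An illustrative evaluator config (the annotation path here is a
+    hypothetical placeholder) might look like:
+
+    .. code:: python
+
+        val_evaluator = dict(
+            type='COCOCaption',
+            ann_file='data/coco/annotations/coco_karpathy_val_gt.json')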
+ + Args: + ann_file (str): the path for the COCO format caption ground truth + json file, load for evaluations. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + + @require('pycocoevalcap') + def __init__(self, + ann_file: str, + collect_device: str = 'cpu', + prefix: Optional[str] = None): + super().__init__(collect_device=collect_device, prefix=prefix) + self.ann_file = ann_file + + def process(self, data_batch, data_samples): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + + for data_sample in data_samples: + result = dict() + + result['caption'] = data_sample.get('pred_caption') + result['image_id'] = int(data_sample.get('image_id')) + + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. 
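+        # What follows: dump the collected predictions into a temporary
+        # json file in COCO result format (deduplicated by image_id), then
+        # score it against ``self.ann_file`` with pycocoevalcap, which
+        # reports metrics such as BLEU, METEOR, ROUGE-L and CIDEr.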
+ + with tempfile.TemporaryDirectory() as temp_dir: + + eval_result_file = save_result( + result=results, + result_dir=temp_dir, + filename='m4-caption_pred', + remove_duplicate='image_id', + ) + + coco_val = coco_caption_eval(eval_result_file, self.ann_file) + + return coco_val + + +def save_result(result, result_dir, filename, remove_duplicate=''): + """Saving predictions as json file for evaluation.""" + + # combine results from all processes + result_new = [] + + if remove_duplicate: + result_new = [] + id_list = [] + for res in track_iter_progress(result): + if res[remove_duplicate] not in id_list: + id_list.append(res[remove_duplicate]) + result_new.append(res) + result = result_new + + final_result_file_url = os.path.join(result_dir, '%s.json' % filename) + print(f'result file saved to {final_result_file_url}') + json.dump(result, open(final_result_file_url, 'w')) + + return final_result_file_url + + +def coco_caption_eval(results_file, ann_file): + """Evaluation between gt json and prediction json files.""" + # create coco object and coco_result object + coco = COCO(ann_file) + coco_result = coco.loadRes(results_file) + + # create coco_eval object by taking coco and coco_result + coco_eval = COCOEvalCap(coco, coco_result) + + # make sure the image ids are the same + coco_eval.params['image_id'] = coco_result.getImgIds() + + # This will take some times at the first run + coco_eval.evaluate() + + # print output evaluation scores + for metric, score in coco_eval.eval.items(): + print(f'{metric}: {score:.3f}') + + return coco_eval.eval diff --git a/mmpretrain/evaluation/metrics/gqa.py b/mmpretrain/evaluation/metrics/gqa.py new file mode 100644 index 0000000000000000000000000000000000000000..d5e8b0725524839c5b0a15a8ba6fb4eed689e589 --- /dev/null +++ b/mmpretrain/evaluation/metrics/gqa.py @@ -0,0 +1,78 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +from mmengine.evaluator import BaseMetric + +from mmpretrain.evaluation.metrics.vqa import (_process_digit_article, + _process_punctuation) +from mmpretrain.registry import METRICS + + +@METRICS.register_module() +class GQAAcc(BaseMetric): + """GQA Acc metric. + + Compute GQA accuracy. + + Args: + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + default_prefix = 'GQA' + + def __init__(self, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + + def process(self, data_batch, data_samples) -> None: + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. 
+ """ + for sample in data_samples: + gt_answer = sample.get('gt_answer') + result = { + 'pred_answer': sample.get('pred_answer'), + 'gt_answer': gt_answer + } + + self.results.append(result) + + def compute_metrics(self, results: List) -> dict: + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + acc = [] + for result in results: + pred_answer = self._process_answer(result['pred_answer']) + gt_answer = self._process_answer(result['gt_answer']) + gqa_acc = 1 if pred_answer == gt_answer else 0 + acc.append(gqa_acc) + + accuracy = sum(acc) / len(acc) + + metrics = {'acc': accuracy} + return metrics + + def _process_answer(self, answer) -> str: + answer = _process_punctuation(answer) + answer = _process_digit_article(answer) + return answer diff --git a/mmpretrain/evaluation/metrics/multi_label.py b/mmpretrain/evaluation/metrics/multi_label.py new file mode 100644 index 0000000000000000000000000000000000000000..bd91aac4449c845fbed514ed5f800bd971236ade --- /dev/null +++ b/mmpretrain/evaluation/metrics/multi_label.py @@ -0,0 +1,599 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Union + +import numpy as np +import torch +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger + +from mmpretrain.registry import METRICS +from mmpretrain.structures import label_to_onehot +from .single_label import _precision_recall_f1_support, to_tensor + + +@METRICS.register_module() +class MultiLabelMetric(BaseMetric): + r"""A collection of precision, recall, f1-score and support for + multi-label tasks. + + The collection of metrics is for single-label multi-class classification. + And all these metrics are based on the confusion matrix of every category: + + .. image:: ../../_static/image/confusion-matrix.png + :width: 60% + :align: center + + All metrics can be formulated use variables above: + + **Precision** is the fraction of correct predictions in all predictions: + + .. math:: + \text{Precision} = \frac{TP}{TP+FP} + + **Recall** is the fraction of correct predictions in all targets: + + .. math:: + \text{Recall} = \frac{TP}{TP+FN} + + **F1-score** is the harmonic mean of the precision and recall: + + .. math:: + \text{F1-score} = \frac{2\times\text{Recall}\times\text{Precision}}{\text{Recall}+\text{Precision}} + + **Support** is the number of samples: + + .. math:: + \text{Support} = TP + TN + FN + FP + + Args: + thr (float, optional): Predictions with scores under the threshold + are considered as negative. If None, the ``topk`` predictions will + be considered as positive. If the ``topk`` is also None, use + ``thr=0.5`` as default. Defaults to None. + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. If None, use ``thr`` to determine positive + predictions. If both ``thr`` and ``topk`` are not None, use + ``thr``. Defaults to None. + items (Sequence[str]): The detailed metric items to evaluate, select + from "precision", "recall", "f1-score" and "support". + Defaults to ``('precision', 'recall', 'f1-score')``. + average (str | None): How to calculate the final metrics from the + confusion matrix of every category. It supports three modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. 
+ - `"micro"`: Average the confusion matrix over all categories and + calculate metrics on the mean confusion matrix. + - `None`: Calculate metrics of every category and output directly. + + Defaults to "macro". + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + Examples: + >>> import torch + >>> from mmpretrain.evaluation import MultiLabelMetric + >>> # ------ The Basic Usage for category indices labels ------- + >>> y_pred = [[0], [1], [0, 1], [3]] + >>> y_true = [[0, 3], [0, 2], [1], [3]] + >>> # Output precision, recall, f1-score and support + >>> MultiLabelMetric.calculate( + ... y_pred, y_true, pred_indices=True, target_indices=True, num_classes=4) + (tensor(50.), tensor(50.), tensor(45.8333), tensor(6)) + >>> # ----------- The Basic Usage for one-hot labels ----------- + >>> y_pred = torch.tensor([[1, 1, 0, 0], + ... [1, 1, 0, 0], + ... [0, 0, 1, 0], + ... [0, 1, 0, 0], + ... [0, 1, 0, 0]]) + >>> y_true = torch.Tensor([[1, 1, 0, 0], + ... [0, 0, 1, 0], + ... [1, 1, 1, 0], + ... [1, 0, 0, 0], + ... [1, 0, 0, 0]]) + >>> MultiLabelMetric.calculate(y_pred, y_true) + (tensor(43.7500), tensor(31.2500), tensor(33.3333), tensor(8)) + >>> # --------- The Basic Usage for one-hot pred scores --------- + >>> y_pred = torch.rand(y_true.size()) + >>> y_pred + tensor([[0.4575, 0.7335, 0.3934, 0.2572], + [0.1318, 0.1004, 0.8248, 0.6448], + [0.8349, 0.6294, 0.7896, 0.2061], + [0.4037, 0.7308, 0.6713, 0.8374], + [0.3779, 0.4836, 0.0313, 0.0067]]) + >>> # Calculate with different threshold. + >>> MultiLabelMetric.calculate(y_pred, y_true, thr=0.1) + (tensor(42.5000), tensor(75.), tensor(53.1746), tensor(8)) + >>> # Calculate with topk. + >>> MultiLabelMetric.calculate(y_pred, y_true, topk=1) + (tensor(62.5000), tensor(31.2500), tensor(39.1667), tensor(8)) + >>> + >>> # ------------------- Use with Evalutor ------------------- + >>> from mmpretrain.structures import DataSample + >>> from mmengine.evaluator import Evaluator + >>> data_sampels = [ + ... DataSample().set_pred_score(pred).set_gt_score(gt) + ... 
for pred, gt in zip(torch.rand(1000, 5), torch.randint(0, 2, (1000, 5)))]
+        >>> evaluator = Evaluator(metrics=MultiLabelMetric(thr=0.5))
+        >>> evaluator.process(data_sampels)
+        >>> evaluator.evaluate(1000)
+        {
+            'multi-label/precision': 50.72898037055408,
+            'multi-label/recall': 50.06836461357571,
+            'multi-label/f1-score': 50.384466955258475
+        }
+        >>> # Evaluate on each class by using topk strategy
+        >>> evaluator = Evaluator(metrics=MultiLabelMetric(topk=1, average=None))
+        >>> evaluator.process(data_sampels)
+        >>> evaluator.evaluate(1000)
+        {
+            'multi-label/precision_top1_classwise': [48.22, 50.54, 50.99, 44.18, 52.5],
+            'multi-label/recall_top1_classwise': [18.92, 19.22, 19.92, 20.0, 20.27],
+            'multi-label/f1-score_top1_classwise': [27.18, 27.85, 28.65, 27.54, 29.25]
+        }
+    """  # noqa: E501
+    default_prefix: Optional[str] = 'multi-label'
+
+    def __init__(self,
+                 thr: Optional[float] = None,
+                 topk: Optional[int] = None,
+                 items: Sequence[str] = ('precision', 'recall', 'f1-score'),
+                 average: Optional[str] = 'macro',
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None) -> None:
+
+        logger = MMLogger.get_current_instance()
+        if thr is None and topk is None:
+            thr = 0.5
+            logger.warning('Neither thr nor topk is given, set thr as 0.5 '
+                           'by default.')
+        elif thr is not None and topk is not None:
+            logger.warning('Both thr and topk are given, '
+                           'use threshold in favor of top-k.')
+
+        self.thr = thr
+        self.topk = topk
+        self.average = average
+
+        for item in items:
+            assert item in ['precision', 'recall', 'f1-score', 'support'], \
+                f'The metric {item} is not supported by `MultiLabelMetric`,' \
+                ' please choose from "precision", "recall", "f1-score" and ' \
+                '"support".'
+        self.items = tuple(items)
+
+        super().__init__(collect_device=collect_device, prefix=prefix)
+
+    def process(self, data_batch, data_samples: Sequence[dict]):
+        """Process one batch of data samples.
+
+        The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch: A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from the model.
+        """
+        for data_sample in data_samples:
+            result = dict()
+
+            result['pred_score'] = data_sample['pred_score'].clone()
+            num_classes = result['pred_score'].size()[-1]
+
+            if 'gt_score' in data_sample:
+                result['gt_score'] = data_sample['gt_score'].clone()
+            else:
+                result['gt_score'] = label_to_onehot(data_sample['gt_label'],
+                                                     num_classes)
+
+            # Save the result to `self.results`.
+            self.results.append(result)
+
+    def compute_metrics(self, results: List):
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list): The processed results of each batch.
+
+        Returns:
+            Dict: The computed metrics. The keys are the names of the metrics,
+            and the values are corresponding results.
+        """
+        # NOTICE: don't access `self.results` from the method. `self.results`
+        # are a list of results from multiple batch, while the input `results`
+        # are the collected results.
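+        # Naming of the returned keys, as implemented below: with a custom
+        # threshold, e.g. thr=0.3, keys look like 'precision_thr-0.30' (no
+        # suffix for the default thr=0.5); with top-k, 'precision_top1';
+        # and with average=None a '_classwise' suffix is appended.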
+ metrics = {} + + target = torch.stack([res['gt_score'] for res in results]) + pred = torch.stack([res['pred_score'] for res in results]) + + metric_res = self.calculate( + pred, + target, + pred_indices=False, + target_indices=False, + average=self.average, + thr=self.thr, + topk=self.topk) + + def pack_results(precision, recall, f1_score, support): + single_metrics = {} + if 'precision' in self.items: + single_metrics['precision'] = precision + if 'recall' in self.items: + single_metrics['recall'] = recall + if 'f1-score' in self.items: + single_metrics['f1-score'] = f1_score + if 'support' in self.items: + single_metrics['support'] = support + return single_metrics + + if self.thr: + suffix = '' if self.thr == 0.5 else f'_thr-{self.thr:.2f}' + for k, v in pack_results(*metric_res).items(): + metrics[k + suffix] = v + else: + for k, v in pack_results(*metric_res).items(): + metrics[k + f'_top{self.topk}'] = v + + result_metrics = dict() + for k, v in metrics.items(): + if self.average is None: + result_metrics[k + '_classwise'] = v.detach().cpu().tolist() + elif self.average == 'macro': + result_metrics[k] = v.item() + else: + result_metrics[k + f'_{self.average}'] = v.item() + return result_metrics + + @staticmethod + def calculate( + pred: Union[torch.Tensor, np.ndarray, Sequence], + target: Union[torch.Tensor, np.ndarray, Sequence], + pred_indices: bool = False, + target_indices: bool = False, + average: Optional[str] = 'macro', + thr: Optional[float] = None, + topk: Optional[int] = None, + num_classes: Optional[int] = None + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Calculate the precision, recall, f1-score. + + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, num_classes)`` or a sequence of index/onehot + format labels. + target (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, num_classes)`` or a sequence of index/onehot + format labels. + pred_indices (bool): Whether the ``pred`` is a sequence of + category index labels. If True, ``num_classes`` must be set. + Defaults to False. + target_indices (bool): Whether the ``target`` is a sequence of + category index labels. If True, ``num_classes`` must be set. + Defaults to False. + average (str | None): How to calculate the final metrics from + the confusion matrix of every category. It supports three + modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. + - `"micro"`: Average the confusion matrix over all categories + and calculate metrics on the mean confusion matrix. + - `None`: Calculate metrics of every category and output + directly. + + Defaults to "macro". + thr (float, optional): Predictions with scores under the thresholds + are considered as negative. Defaults to None. + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. Defaults to None. + num_classes (Optional, int): The number of classes. If the ``pred`` + is indices instead of onehot, this argument is required. + Defaults to None. + + Returns: + Tuple: The tuple contains precision, recall and f1-score. + And the type of each item is: + + - torch.Tensor: A tensor for each metric. The shape is (1, ) if + ``average`` is not None, and (C, ) if ``average`` is None. + + Notes: + If both ``thr`` and ``topk`` are set, use ``thr` to determine + positive predictions. 
If neither is set, use ``thr=0.5`` as + default. + """ + average_options = ['micro', 'macro', None] + assert average in average_options, 'Invalid `average` argument, ' \ + f'please specicy from {average_options}.' + + def _format_label(label, is_indices): + """format various label to torch.Tensor.""" + if isinstance(label, np.ndarray): + assert label.ndim == 2, 'The shape `pred` and `target` ' \ + 'array must be (N, num_classes).' + label = torch.from_numpy(label) + elif isinstance(label, torch.Tensor): + assert label.ndim == 2, 'The shape `pred` and `target` ' \ + 'tensor must be (N, num_classes).' + elif isinstance(label, Sequence): + if is_indices: + assert num_classes is not None, 'For index-type labels, ' \ + 'please specify `num_classes`.' + label = torch.stack([ + label_to_onehot(indices, num_classes) + for indices in label + ]) + else: + label = torch.stack( + [to_tensor(onehot) for onehot in label]) + else: + raise TypeError( + 'The `pred` and `target` must be type of torch.tensor or ' + f'np.ndarray or sequence but get {type(label)}.') + return label + + pred = _format_label(pred, pred_indices) + target = _format_label(target, target_indices).long() + + assert pred.shape == target.shape, \ + f"The size of pred ({pred.shape}) doesn't match "\ + f'the target ({target.shape}).' + + if num_classes is not None: + assert pred.size(1) == num_classes, \ + f'The shape of `pred` ({pred.shape}) '\ + f"doesn't match the num_classes ({num_classes})." + num_classes = pred.size(1) + + thr = 0.5 if (thr is None and topk is None) else thr + + if thr is not None: + # a label is predicted positive if larger than thr + pos_inds = (pred >= thr).long() + else: + # top-k labels will be predicted positive for any example + _, topk_indices = pred.topk(topk) + pos_inds = torch.zeros_like(pred).scatter_(1, topk_indices, 1) + pos_inds = pos_inds.long() + + return _precision_recall_f1_support(pos_inds, target, average) + + +def _average_precision(pred: torch.Tensor, + target: torch.Tensor) -> torch.Tensor: + r"""Calculate the average precision for a single class. + + AP summarizes a precision-recall curve as the weighted mean of maximum + precisions obtained for any r'>r, where r is the recall: + + .. math:: + \text{AP} = \sum_n (R_n - R_{n-1}) P_n + + Note that no approximation is involved since the curve is piecewise + constant. + + Args: + pred (torch.Tensor): The model prediction with shape + ``(N, num_classes)``. + target (torch.Tensor): The target of predictions with shape + ``(N, num_classes)``. + + Returns: + torch.Tensor: average precision result. + """ + assert pred.shape == target.shape, \ + f"The size of pred ({pred.shape}) doesn't match "\ + f'the target ({target.shape}).' + + # a small value for division by zero errors + eps = torch.finfo(torch.float32).eps + + # get rid of -1 target such as difficult sample + # that is not wanted in evaluation results. 
+ valid_index = target > -1 + pred = pred[valid_index] + target = target[valid_index] + + # sort examples + sorted_pred_inds = torch.argsort(pred, dim=0, descending=True) + sorted_target = target[sorted_pred_inds] + + # get indexes when gt_true is positive + pos_inds = sorted_target == 1 + + # Calculate cumulative tp case numbers + tps = torch.cumsum(pos_inds, 0) + total_pos = tps[-1].item() # the last of tensor may change later + + # Calculate cumulative tp&fp(pred_poss) case numbers + pred_pos_nums = torch.arange(1, len(sorted_target) + 1).to(pred.device) + pred_pos_nums[pred_pos_nums < eps] = eps + + tps[torch.logical_not(pos_inds)] = 0 + precision = tps / pred_pos_nums.float() + ap = torch.sum(precision, 0) / max(total_pos, eps) + return ap + + +@METRICS.register_module() +class AveragePrecision(BaseMetric): + r"""Calculate the average precision with respect of classes. + + AveragePrecision (AP) summarizes a precision-recall curve as the weighted + mean of maximum precisions obtained for any r'>r, where r is the recall: + + .. math:: + \text{AP} = \sum_n (R_n - R_{n-1}) P_n + + Note that no approximation is involved since the curve is piecewise + constant. + + Args: + average (str | None): How to calculate the final metrics from + every category. It supports two modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. The result of this mode + is also called **mAP**. + - `None`: Calculate metrics of every category and output directly. + + Defaults to "macro". + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + References + ---------- + 1. `Wikipedia entry for the Average precision + `_ + + Examples: + >>> import torch + >>> from mmpretrain.evaluation import AveragePrecision + >>> # --------- The Basic Usage for one-hot pred scores --------- + >>> y_pred = torch.Tensor([[0.9, 0.8, 0.3, 0.2], + ... [0.1, 0.2, 0.2, 0.1], + ... [0.7, 0.5, 0.9, 0.3], + ... [0.8, 0.1, 0.1, 0.2]]) + >>> y_true = torch.Tensor([[1, 1, 0, 0], + ... [0, 1, 0, 0], + ... [0, 0, 1, 0], + ... [1, 0, 0, 0]]) + >>> AveragePrecision.calculate(y_pred, y_true) + tensor(70.833) + >>> # ------------------- Use with Evalutor ------------------- + >>> from mmpretrain.structures import DataSample + >>> from mmengine.evaluator import Evaluator + >>> data_samples = [ + ... DataSample().set_pred_score(i).set_gt_score(j) + ... for i, j in zip(y_pred, y_true) + ... ] + >>> evaluator = Evaluator(metrics=AveragePrecision()) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(5) + {'multi-label/mAP': 70.83333587646484} + >>> # Evaluate on each class + >>> evaluator = Evaluator(metrics=AveragePrecision(average=None)) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(5) + {'multi-label/AP_classwise': [100., 83.33, 100., 0.]} + """ + default_prefix: Optional[str] = 'multi-label' + + def __init__(self, + average: Optional[str] = 'macro', + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + self.average = average + + def process(self, data_batch, data_samples: Sequence[dict]): + """Process one batch of data samples. 
+ + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + + for data_sample in data_samples: + result = dict() + + result['pred_score'] = data_sample['pred_score'].clone() + num_classes = result['pred_score'].size()[-1] + + if 'gt_score' in data_sample: + result['gt_score'] = data_sample['gt_score'].clone() + else: + result['gt_score'] = label_to_onehot(data_sample['gt_label'], + num_classes) + + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. `self.results` + # are a list of results from multiple batch, while the input `results` + # are the collected results. + + # concat + target = torch.stack([res['gt_score'] for res in results]) + pred = torch.stack([res['pred_score'] for res in results]) + + ap = self.calculate(pred, target, self.average) + + result_metrics = dict() + + if self.average is None: + result_metrics['AP_classwise'] = ap.detach().cpu().tolist() + else: + result_metrics['mAP'] = ap.item() + + return result_metrics + + @staticmethod + def calculate(pred: Union[torch.Tensor, np.ndarray], + target: Union[torch.Tensor, np.ndarray], + average: Optional[str] = 'macro') -> torch.Tensor: + r"""Calculate the average precision for a single class. + + Args: + pred (torch.Tensor | np.ndarray): The model predictions with + shape ``(N, num_classes)``. + target (torch.Tensor | np.ndarray): The target of predictions + with shape ``(N, num_classes)``. + average (str | None): The average method. It supports two modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. The result of this mode + is also called mAP. + - `None`: Calculate metrics of every category and output + directly. + + Defaults to "macro". + + Returns: + torch.Tensor: the average precision of all classes. + """ + average_options = ['macro', None] + assert average in average_options, 'Invalid `average` argument, ' \ + f'please specicy from {average_options}.' + + pred = to_tensor(pred) + target = to_tensor(target) + assert pred.ndim == 2 and pred.shape == target.shape, \ + 'Both `pred` and `target` should have shape `(N, num_classes)`.' + + num_classes = pred.shape[1] + ap = pred.new_zeros(num_classes) + for k in range(num_classes): + ap[k] = _average_precision(pred[:, k], target[:, k]) + if average == 'macro': + return ap.mean() * 100.0 + else: + return ap * 100 diff --git a/mmpretrain/evaluation/metrics/multi_task.py b/mmpretrain/evaluation/metrics/multi_task.py new file mode 100644 index 0000000000000000000000000000000000000000..0e6af7680192883308df5f24b65ec38c9bb65ce6 --- /dev/null +++ b/mmpretrain/evaluation/metrics/multi_task.py @@ -0,0 +1,120 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
diff --git a/mmpretrain/evaluation/metrics/multi_task.py b/mmpretrain/evaluation/metrics/multi_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e6af7680192883308df5f24b65ec38c9bb65ce6
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/multi_task.py
@@ -0,0 +1,120 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Sequence
+
+from mmengine.evaluator import BaseMetric
+
+from mmpretrain.registry import METRICS
+
+
+@METRICS.register_module()
+class MultiTasksMetric(BaseMetric):
+    """Wrapper metric for multi-task learning.
+
+    Args:
+        task_metrics (dict): A dictionary whose keys are task names and whose
+            values are lists of metric configs for the corresponding task.
+
+    Examples:
+        >>> import torch
+        >>> from mmpretrain.evaluation import MultiTasksMetric
+        >>> # -------------------- The Basic Usage --------------------
+        >>> task_metrics = {
+        ...     'task0': [dict(type='Accuracy', topk=(1, ))],
+        ...     'task1': [dict(type='Accuracy', topk=(1, 3))]
+        ... }
+        >>> pred = [{
+        ...     'pred_task': {
+        ...         'task0': torch.tensor([0.7, 0.0, 0.3]),
+        ...         'task1': torch.tensor([0.5, 0.2, 0.3])
+        ...     },
+        ...     'gt_task': {
+        ...         'task0': torch.tensor(0),
+        ...         'task1': torch.tensor(2)
+        ...     }
+        ... }, {
+        ...     'pred_task': {
+        ...         'task0': torch.tensor([0.0, 0.0, 1.0]),
+        ...         'task1': torch.tensor([0.0, 0.0, 1.0])
+        ...     },
+        ...     'gt_task': {
+        ...         'task0': torch.tensor(2),
+        ...         'task1': torch.tensor(2)
+        ...     }
+        ... }]
+        >>> metric = MultiTasksMetric(task_metrics)
+        >>> metric.process(None, pred)
+        >>> results = metric.evaluate(2)
+        results = {
+            'task0_accuracy/top1': 100.0,
+            'task1_accuracy/top1': 50.0,
+            'task1_accuracy/top3': 100.0
+        }
+    """
+
+    def __init__(self,
+                 task_metrics: Dict,
+                 collect_device: str = 'cpu') -> None:
+        self.task_metrics = task_metrics
+        super().__init__(collect_device=collect_device)
+
+        self._metrics = {}
+        for task_name in self.task_metrics.keys():
+            self._metrics[task_name] = []
+            for metric in self.task_metrics[task_name]:
+                self._metrics[task_name].append(METRICS.build(metric))
+
+    def process(self, data_batch, data_samples: Sequence[dict]):
+        """Process one batch of data samples.
+
+        The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch: A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from the model.
+        """
+        for task_name in self.task_metrics.keys():
+            filtered_data_samples = []
+            for data_sample in data_samples:
+                eval_mask = data_sample[task_name]['eval_mask']
+                if eval_mask:
+                    filtered_data_samples.append(data_sample[task_name])
+            for metric in self._metrics[task_name]:
+                metric.process(data_batch, filtered_data_samples)
+
+    def compute_metrics(self, results: list) -> dict:
+        raise NotImplementedError(
+            'compute metrics should not be used here directly')
+
+    def evaluate(self, size):
+        """Evaluate the model performance of the whole dataset after
+        processing all batches.
+
+        Args:
+            size (int): Length of the entire validation dataset. When batch
+                size > 1, the dataloader may pad some data samples to make
+                sure all ranks have the same length of dataset slice. The
+                ``collect_results`` function will drop the padded data based
+                on this size.
+        Returns:
+            dict: Evaluation metrics dict on the val dataset. The keys are
+            "{task_name}_{metric_name}", and the values
+            are corresponding results.
+ """ + metrics = {} + for task_name in self._metrics: + for metric in self._metrics[task_name]: + name = metric.__class__.__name__ + if name == 'MultiTasksMetric' or metric.results: + results = metric.evaluate(size) + else: + results = {metric.__class__.__name__: 0} + for key in results: + name = f'{task_name}_{key}' + if name in results: + """Inspired from https://github.com/open- + mmlab/mmengine/ bl ob/ed20a9cba52ceb371f7c825131636b9e2 + 747172e/mmengine/evalua tor/evaluator.py#L84-L87.""" + raise ValueError( + 'There are multiple metric results with the same' + f'metric name {name}. Please make sure all metrics' + 'have different prefixes.') + metrics[name] = results[key] + return metrics diff --git a/mmpretrain/evaluation/metrics/nocaps.py b/mmpretrain/evaluation/metrics/nocaps.py new file mode 100644 index 0000000000000000000000000000000000000000..e8e1d0625b66dfa1abe59bd6f83ea2a6c0b3d446 --- /dev/null +++ b/mmpretrain/evaluation/metrics/nocaps.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import mmengine + +from mmpretrain.registry import METRICS +from mmpretrain.utils import require +from .caption import COCOCaption, save_result + +try: + from pycocoevalcap.eval import COCOEvalCap + from pycocotools.coco import COCO +except ImportError: + COCOEvalCap = None + COCO = None + + +@METRICS.register_module() +class NocapsSave(COCOCaption): + """Nocaps evaluation wrapper. + + Save the generated captions and transform into coco format. + The dumped file can be submitted to the official evluation system. + + Args: + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + + @require('pycocoevalcap') + def __init__(self, + save_dir: str = './', + collect_device: str = 'cpu', + prefix: Optional[str] = None): + super(COCOCaption, self).__init__( + collect_device=collect_device, prefix=prefix) + self.save_dir = save_dir + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + """ + mmengine.mkdir_or_exist(self.save_dir) + save_result( + result=results, + result_dir=self.save_dir, + filename='nocap_pred', + remove_duplicate='image_id', + ) + + return dict() diff --git a/mmpretrain/evaluation/metrics/retrieval.py b/mmpretrain/evaluation/metrics/retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..9813486b521c5b73d7be96901ea4f604bbe2a938 --- /dev/null +++ b/mmpretrain/evaluation/metrics/retrieval.py @@ -0,0 +1,445 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Union + +import mmengine +import numpy as np +import torch +from mmengine.evaluator import BaseMetric +from mmengine.utils import is_seq_of + +from mmpretrain.registry import METRICS +from mmpretrain.structures import label_to_onehot +from .single_label import to_tensor + + +@METRICS.register_module() +class RetrievalRecall(BaseMetric): + r"""Recall evaluation metric for image retrieval. 
+ + Args: + topk (int | Sequence[int]): If the ground truth label matches one of + the best **k** predictions, the sample will be regard as a positive + prediction. If the parameter is a tuple, all of top-k recall will + be calculated and outputted together. Defaults to 1. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + Examples: + Use in the code: + + >>> import torch + >>> from mmpretrain.evaluation import RetrievalRecall + >>> # -------------------- The Basic Usage -------------------- + >>> y_pred = [[0], [1], [2], [3]] + >>> y_true = [[0, 1], [2], [1], [0, 3]] + >>> RetrievalRecall.calculate( + >>> y_pred, y_true, topk=1, pred_indices=True, target_indices=True) + [tensor([50.])] + >>> # Calculate the recall@1 and recall@5 for non-indices input. + >>> y_score = torch.rand((1000, 10)) + >>> import torch.nn.functional as F + >>> y_true = F.one_hot(torch.arange(0, 1000) % 10, num_classes=10) + >>> RetrievalRecall.calculate(y_score, y_true, topk=(1, 5)) + [tensor(9.3000), tensor(48.4000)] + >>> + >>> # ------------------- Use with Evalutor ------------------- + >>> from mmpretrain.structures import DataSample + >>> from mmengine.evaluator import Evaluator + >>> data_samples = [ + ... DataSample().set_gt_label([0, 1]).set_pred_score( + ... torch.rand(10)) + ... for i in range(1000) + ... ] + >>> evaluator = Evaluator(metrics=RetrievalRecall(topk=(1, 5))) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(1000) + {'retrieval/Recall@1': 20.700000762939453, + 'retrieval/Recall@5': 78.5999984741211} + + Use in OpenMMLab configs: + + .. code:: python + + val_evaluator = dict(type='RetrievalRecall', topk=(1, 5)) + test_evaluator = val_evaluator + """ + default_prefix: Optional[str] = 'retrieval' + + def __init__(self, + topk: Union[int, Sequence[int]], + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + topk = (topk, ) if isinstance(topk, int) else topk + + for k in topk: + if k <= 0: + raise ValueError('`topk` must be a ingter larger than 0 ' + 'or seq of ingter larger than 0.') + + self.topk = topk + super().__init__(collect_device=collect_device, prefix=prefix) + + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]): + """Process one batch of data and predictions. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch (Sequence[dict]): A batch of data from the dataloader. + predictions (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + pred_score = data_sample['pred_score'].clone() + gt_label = data_sample['gt_label'] + + if 'gt_score' in data_sample: + target = data_sample.get('gt_score').clone() + else: + num_classes = pred_score.size()[-1] + target = label_to_onehot(gt_label, num_classes) + + # Because the retrieval output logit vector will be much larger + # compared to the normal classification, to save resources, the + # evaluation results are computed each batch here and then reduce + # all results at the end. 
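+            # Note: ``calculate`` returns one recall tensor per requested k,
+            # so every entry appended to ``self.results`` is a list aligned
+            # with ``self.topk``.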
+ result = RetrievalRecall.calculate( + pred_score.unsqueeze(0), target.unsqueeze(0), topk=self.topk) + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + result_metrics = dict() + for i, k in enumerate(self.topk): + recall_at_k = sum([r[i].item() for r in results]) / len(results) + result_metrics[f'Recall@{k}'] = recall_at_k + + return result_metrics + + @staticmethod + def calculate(pred: Union[np.ndarray, torch.Tensor], + target: Union[np.ndarray, torch.Tensor], + topk: Union[int, Sequence[int]], + pred_indices: (bool) = False, + target_indices: (bool) = False) -> float: + """Calculate the average recall. + + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, M)`` or a sequence of index/onehot + format labels. + target (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, M)`` or a sequence of index/onehot + format labels. + topk (int, Sequence[int]): Predictions with the k-th highest + scores are considered as positive. + pred_indices (bool): Whether the ``pred`` is a sequence of + category index labels. Defaults to False. + target_indices (bool): Whether the ``target`` is a sequence of + category index labels. Defaults to False. + + Returns: + List[float]: the average recalls. + """ + topk = (topk, ) if isinstance(topk, int) else topk + for k in topk: + if k <= 0: + raise ValueError('`topk` must be a ingter larger than 0 ' + 'or seq of ingter larger than 0.') + + max_keep = max(topk) + pred = _format_pred(pred, max_keep, pred_indices) + target = _format_target(target, target_indices) + + assert len(pred) == len(target), ( + f'Length of `pred`({len(pred)}) and `target` ({len(target)}) ' + f'must be the same.') + + num_samples = len(pred) + results = [] + for k in topk: + recalls = torch.zeros(num_samples) + for i, (sample_pred, + sample_target) in enumerate(zip(pred, target)): + sample_pred = np.array(to_tensor(sample_pred).cpu()) + sample_target = np.array(to_tensor(sample_target).cpu()) + recalls[i] = int(np.in1d(sample_pred[:k], sample_target).max()) + results.append(recalls.mean() * 100) + return results + + +@METRICS.register_module() +class RetrievalAveragePrecision(BaseMetric): + r"""Calculate the average precision for image retrieval. + + Args: + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. + mode (str, optional): The mode to calculate AP, choose from + 'IR'(information retrieval) and 'integrate'. Defaults to 'IR'. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. 
+
+    Note:
+        If ``mode`` is set to 'IR', use the Stanford AP calculation of
+        information retrieval as in the Wikipedia page [1]; if set to
+        'integrate', the method integrates over the precision-recall curve
+        by averaging two adjacent precision points, then multiplying by the
+        recall step, like mAP in the detection task. This is the convention
+        for the Revisited Oxford/Paris datasets [2].
+
+    References:
+        [1] `Wikipedia entry for the Average precision
+        <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision>`_
+
+        [2] `The Oxford Buildings Dataset
+        <https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/>`_
+
+    Examples:
+        Use in code:
+
+        >>> import torch
+        >>> import numpy as np
+        >>> from mmpretrain.evaluation import RetrievalAveragePrecision
+        >>> # using index format inputs
+        >>> pred = [ torch.Tensor([idx for idx in range(100)]) ] * 3
+        >>> target = [[0, 3, 6, 8, 35], [1, 2, 54, 105], [2, 42, 205]]
+        >>> RetrievalAveragePrecision.calculate(pred, target, 10, True, True)
+        29.246031746031747
+        >>> # using tensor format inputs
+        >>> pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
+        >>> target = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1]] * 2)
+        >>> RetrievalAveragePrecision.calculate(pred, target, 10)
+        62.222222222222214
+
+        Use in OpenMMLab config files:
+
+        .. code:: python
+
+            val_evaluator = dict(type='RetrievalAveragePrecision', topk=100)
+            test_evaluator = val_evaluator
+    """
+
+    default_prefix: Optional[str] = 'retrieval'
+
+    def __init__(self,
+                 topk: Optional[int] = None,
+                 mode: Optional[str] = 'IR',
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None) -> None:
+        if topk is None or (isinstance(topk, int) and topk <= 0):
+            raise ValueError('`topk` must be an integer larger than 0.')
+
+        mode_options = ['IR', 'integrate']
+        assert mode in mode_options, \
+            f'Invalid `mode` argument, please specify from {mode_options}.'
+
+        self.topk = topk
+        self.mode = mode
+        super().__init__(collect_device=collect_device, prefix=prefix)
+
+    def process(self, data_batch: Sequence[dict],
+                data_samples: Sequence[dict]):
+        """Process one batch of data and predictions.
+
+        The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch (Sequence[dict]): A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from the model.
+        """
+        for data_sample in data_samples:
+            pred_score = data_sample.get('pred_score').clone()
+
+            if 'gt_score' in data_sample:
+                target = data_sample.get('gt_score').clone()
+            else:
+                gt_label = data_sample.get('gt_label')
+                num_classes = pred_score.size()[-1]
+                target = label_to_onehot(gt_label, num_classes)
+
+            # Because the retrieval output logit vector is much larger than
+            # in normal classification tasks, to save resources the
+            # evaluation results are computed per batch here and then
+            # reduced over all results at the end.
+            result = RetrievalAveragePrecision.calculate(
+                pred_score.unsqueeze(0),
+                target.unsqueeze(0),
+                self.topk,
+                mode=self.mode)
+            self.results.append(result)
+
+    def compute_metrics(self, results: List):
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list): The processed results of each batch.
+
+        Returns:
+            Dict: The computed metrics. The keys are the names of the metrics,
+            and the values are corresponding results.
+        """
+        result_metrics = dict()
+        result_metrics[f'mAP@{self.topk}'] = np.mean(self.results).item()
+
+        return result_metrics
+
+    @staticmethod
+    def calculate(pred: Union[np.ndarray, torch.Tensor],
+                  target: Union[np.ndarray, torch.Tensor],
+                  topk: Optional[int] = None,
+                  pred_indices: bool = False,
+                  target_indices: bool = False,
+                  mode: str = 'IR') -> float:
+        """Calculate the average precision.
+
+        Args:
+            pred (torch.Tensor | np.ndarray | Sequence): The prediction
+                results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
+                shape ``(N, M)`` or a sequence of index/onehot
+                format labels.
+            target (torch.Tensor | np.ndarray | Sequence): The target of
+                predictions. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
+                shape ``(N, M)`` or a sequence of index/onehot
+                format labels.
+            topk (int, optional): Predictions with the k-th highest scores
+                are considered as positive.
+            pred_indices (bool): Whether the ``pred`` is a sequence of
+                category index labels. Defaults to False.
+            target_indices (bool): Whether the ``target`` is a sequence of
+                category index labels. Defaults to False.
+            mode (Optional[str]): The mode to calculate AP, choose from
+                'IR' (information retrieval) and 'integrate'.
+                Defaults to 'IR'.
+
+        Note:
+            If ``mode`` is set to 'IR', use the Stanford AP calculation of
+            information retrieval as in the Wikipedia page; if set to
+            'integrate', the method integrates over the precision-recall
+            curve by averaging two adjacent precision points, then
+            multiplying by the recall step, like mAP in the detection task.
+            This is the convention for the Revisited Oxford/Paris datasets.
+
+        Returns:
+            float: the average precision of the query image.
+
+        References:
+            [1] `Wikipedia entry for Average precision (information_retrieval)
+            <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision>`_
+
+            [2] `The Oxford Buildings Dataset
+            <https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/>`_
+        """
+        if topk is None or (isinstance(topk, int) and topk <= 0):
+            raise ValueError('`topk` must be an integer larger than 0.')
+
+        mode_options = ['IR', 'integrate']
+        assert mode in mode_options, \
+            f'Invalid `mode` argument, please specify from {mode_options}.'
+
+        pred = _format_pred(pred, topk, pred_indices)
+        target = _format_target(target, target_indices)
+
+        assert len(pred) == len(target), (
+            f'Length of `pred`({len(pred)}) and `target` ({len(target)}) '
+            f'must be the same.')
+
+        num_samples = len(pred)
+        aps = np.zeros(num_samples)
+        for i, (sample_pred,
+                sample_target) in enumerate(zip(pred, target)):
+            aps[i] = _calculate_ap_for_sample(sample_pred, sample_target,
+                                              mode)
+
+        return aps.mean()
+
+
+def _calculate_ap_for_sample(pred, target, mode):
+    pred = np.array(to_tensor(pred).cpu())
+    target = np.array(to_tensor(target).cpu())
+
+    num_preds = len(pred)
+
+    # 0-based ranks at which a relevant item was retrieved
+    positive_ranks = np.arange(num_preds)[np.in1d(pred, target)]
+
+    ap = 0
+    for i, rank in enumerate(positive_ranks):
+        if mode == 'IR':
+            # Stanford AP of information retrieval: precision at each hit
+            precision = (i + 1) / (rank + 1)
+            ap += precision
+        else:
+            # average two adjacent precision points on the PR curve
+            old_precision = i / rank if rank > 0 else 1
+            cur_precision = (i + 1) / (rank + 1)
+            prediction = (old_precision + cur_precision) / 2
+            ap += prediction
+    ap = ap / len(target)
+
+    return ap * 100
+
+
+def _format_pred(label, topk=None, is_indices=False):
+    """Format various labels to List[indices]."""
+    if is_indices:
+        assert isinstance(label, Sequence), \
+            '`pred` must be Sequence of indices when' \
+            f' `pred_indices` set to True, but get {type(label)}'
+        for i, sample_pred in enumerate(label):
+            assert is_seq_of(sample_pred, int) or isinstance(
+                sample_pred, (np.ndarray, torch.Tensor)), \
+                '`pred` should be Sequence of indices when `pred_indices` ' \
+                f'set to True, but pred[{i}] is {sample_pred}'
+            if topk:
+                label[i] = sample_pred[:min(topk, len(sample_pred))]
+        return label
+    if isinstance(label, np.ndarray):
+        label = torch.from_numpy(label)
+    elif not isinstance(label, torch.Tensor):
+        raise TypeError(f'The pred must be type of torch.Tensor, '
+                        f'np.ndarray or Sequence but get {type(label)}.')
+    topk = topk if topk else label.size()[-1]
+    _, indices = label.topk(topk)
+    return indices
+
+
+def _format_target(label, is_indices=False):
+    """Format various labels to List[indices]."""
+    if is_indices:
+        assert isinstance(label, Sequence), \
+            '`target` must be Sequence of indices when' \
+            f' `target_indices` set to True, but get {type(label)}'
+        for i, sample_gt in enumerate(label):
+            assert is_seq_of(sample_gt, int) or isinstance(
+                sample_gt, (np.ndarray, torch.Tensor)), \
+                '`target` should be Sequence of indices when ' \
+                f'`target_indices` set to True, but target[{i}] is {sample_gt}'
+        return label
+
+    if isinstance(label, np.ndarray):
+        label = torch.from_numpy(label)
+    elif isinstance(label, Sequence) and not mmengine.is_str(label):
+        label = torch.tensor(label)
+    elif not isinstance(label, torch.Tensor):
+        raise TypeError(f'The target must be type of torch.Tensor, '
+                        f'np.ndarray or Sequence but get {type(label)}.')
+
+    indices = [sample_gt.nonzero().squeeze(-1) for sample_gt in label]
+    return indices
diff --git a/mmpretrain/evaluation/metrics/scienceqa.py b/mmpretrain/evaluation/metrics/scienceqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebf01c78cc88e5ce5e232fe837a0d77293386112
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/scienceqa.py
@@ -0,0 +1,170 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import random
+from typing import List, Optional
+
+from mmengine.evaluator import BaseMetric
+
+from mmpretrain.registry import METRICS
+
+
+def get_pred_idx(prediction: str, choices: List[str],
+                 options: List[str]) -> int:  # noqa
+    """Get the index (e.g. 2) from the prediction (e.g. 'C').
+
+    Args:
+        prediction (str): The prediction from the model,
+            from ['A', 'B', 'C', 'D', 'E']
+        choices (List(str)): The choices for the question,
+            from ['A', 'B', 'C', 'D', 'E']
+        options (List(str)): The options for the question,
+            from ['A', 'B', 'C', 'D', 'E']
+
+    Returns:
+        int: The index of the prediction, from [0, 1, 2, 3, 4]
+    """
+    if prediction in options[:len(choices)]:
+        return options.index(prediction)
+    else:
+        return random.choice(range(len(choices)))
+
+
+@METRICS.register_module()
+class ScienceQAMetric(BaseMetric):
+    """Evaluation Metric for ScienceQA.
+
+    Args:
+        options (List(str)): Options for each question. Defaults to
+            ["A", "B", "C", "D", "E"].
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+    """
+
+    def __init__(self,
+                 options: List[str] = ['A', 'B', 'C', 'D', 'E'],
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None) -> None:
+        super().__init__(collect_device=collect_device, prefix=prefix)
+        self.options = options
+
+    def process(self, data_batch, data_samples) -> None:
+        """Process one batch of data samples.
+
+        data_samples should contain the following keys:
+        1. pred_answer (str): The prediction from the model,
+            from ['A', 'B', 'C', 'D', 'E']
+        2. choices (List(str)): The choices for the question,
+            from ['A', 'B', 'C', 'D', 'E']
+        3. grade (int): The grade for the question, from grade1 to grade12
+        4. subject (str): The subject for the question, from
+            ['natural science', 'social science', 'language science']
+        5. answer (str): The answer for the question, from
+            ['A', 'B', 'C', 'D', 'E']
+        6. hint (str): The hint for the question
+        7. has_image (bool): Whether or not the question has an image
+
+        The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch: A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from the model.
+ """ + for data_sample in data_samples: + result = dict() + choices = data_sample.get('choices') + result['prediction'] = get_pred_idx( + data_sample.get('pred_answer'), choices, self.options) + result['grade'] = data_sample.get('grade') + result['subject'] = data_sample.get('subject') + result['answer'] = data_sample.get('gt_answer') + hint = data_sample.get('hint') + has_image = data_sample.get('has_image', False) + result['no_context'] = True if not has_image and len( + hint) == 0 else False # noqa + result['has_text'] = True if len(hint) > 0 else False + result['has_image'] = has_image + + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List) -> dict: + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. + metrics = dict() + + all_acc = [] + acc_natural = [] + acc_social = [] + acc_language = [] + acc_has_text = [] + acc_has_image = [] + acc_no_context = [] + acc_grade_1_6 = [] + acc_grade_7_12 = [] + + for result in results: + correct = result['prediction'] == result['answer'] + all_acc.append(correct) + # different subjects + if result['subject'] == 'natural science': + acc_natural.append(correct) + elif result['subject'] == 'social science': + acc_social.append(correct) + elif result['subject'] == 'language science': + acc_language.append(correct) + + # different context + if result['has_text']: + acc_has_text.append(correct) + elif result['has_image']: + acc_has_image.append(correct) + elif result['no_context']: + acc_no_context.append(correct) + + # different grade + if result['grade'] in [ + 'grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6' + ]: + acc_grade_1_6.append(correct) + elif result['grade'] in [ + 'grade7', 'grade8', 'grade9', 'grade10', 'grade11', + 'grade12' + ]: + acc_grade_7_12.append(correct) + + metrics['all_acc'] = sum(all_acc) / len(all_acc) + if len(acc_natural) > 0: + metrics['acc_natural'] = sum(acc_natural) / len(acc_natural) + if len(acc_social) > 0: + metrics['acc_social'] = sum(acc_social) / len(acc_social) + if len(acc_language) > 0: + metrics['acc_language'] = sum(acc_language) / len(acc_language) + if len(acc_has_text) > 0: + metrics['acc_has_text'] = sum(acc_has_text) / len(acc_has_text) + if len(acc_has_image) > 0: + metrics['acc_has_image'] = sum(acc_has_image) / len(acc_has_image) + if len(acc_no_context) > 0: + metrics['acc_no_context'] = sum(acc_no_context) / len( + acc_no_context) + if len(acc_grade_1_6) > 0: + metrics['acc_grade_1_6'] = sum(acc_grade_1_6) / len(acc_grade_1_6) + if len(acc_grade_7_12) > 0: + metrics['acc_grade_7_12'] = sum(acc_grade_7_12) / len( + acc_grade_7_12) + + return metrics diff --git a/mmpretrain/evaluation/metrics/shape_bias_label.py b/mmpretrain/evaluation/metrics/shape_bias_label.py new file mode 100644 index 0000000000000000000000000000000000000000..27c80a36073a9e6edd5e6583e213ed93374b165e --- /dev/null +++ b/mmpretrain/evaluation/metrics/shape_bias_label.py @@ -0,0 +1,172 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import csv
+import os
+import os.path as osp
+from typing import List, Sequence
+
+import numpy as np
+import torch
+from mmengine.dist.utils import get_rank
+from mmengine.evaluator import BaseMetric
+
+from mmpretrain.registry import METRICS
+
+
+@METRICS.register_module()
+class ShapeBiasMetric(BaseMetric):
+    """Evaluate the model on the ``cue_conflict`` dataset.
+
+    This module evaluates the model on the ``cue_conflict`` OOD dataset in
+    order to measure the shape bias of the model. In addition to computing
+    the top-1 accuracy, this module also generates a CSV file recording the
+    detailed prediction results, which can then be used to plot the
+    shape-bias curve.
+
+    Args:
+        csv_dir (str): The directory to save the csv file.
+        model_name (str): The name of the csv file. Please note that the
+            model name should be a unique identifier.
+        dataset_name (str): The name of the dataset. Default: 'cue_conflict'.
+    """
+
+    # mapping several classes from ImageNet-1K to the same category
+    airplane_indices = [404]
+    bear_indices = [294, 295, 296, 297]
+    bicycle_indices = [444, 671]
+    bird_indices = [
+        8, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 22, 23, 24, 80, 81, 82, 83,
+        87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 98, 99, 100, 127, 128, 129,
+        130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
+        145
+    ]
+    boat_indices = [472, 554, 625, 814, 914]
+    bottle_indices = [440, 720, 737, 898, 899, 901, 907]
+    car_indices = [436, 511, 817]
+    cat_indices = [281, 282, 283, 284, 285, 286]
+    chair_indices = [423, 559, 765, 857]
+    clock_indices = [409, 530, 892]
+    dog_indices = [
+        152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165,
+        166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179,
+        180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 193, 194,
+        195, 196, 197, 198, 199, 200, 201, 202, 203, 205, 206, 207, 208, 209,
+        210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
+        224, 225, 226, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238,
+        239, 240, 241, 243, 244, 245, 246, 247, 248, 249, 250, 252, 253, 254,
+        255, 256, 257, 259, 261, 262, 263, 265, 266, 267, 268
+    ]
+    elephant_indices = [385, 386]
+    keyboard_indices = [508, 878]
+    knife_indices = [499]
+    oven_indices = [766]
+    truck_indices = [555, 569, 656, 675, 717, 734, 864, 867]
+
+    def __init__(self,
+                 csv_dir: str,
+                 model_name: str,
+                 dataset_name: str = 'cue_conflict',
+                 **kwargs) -> None:
+        super().__init__(**kwargs)
+
+        self.categories = sorted([
+            'knife', 'keyboard', 'elephant', 'bicycle', 'airplane', 'clock',
+            'oven', 'chair', 'bear', 'boat', 'cat', 'bottle', 'truck', 'car',
+            'bird', 'dog'
+        ])
+        self.csv_dir = csv_dir
+        self.model_name = model_name
+        self.dataset_name = dataset_name
+        if get_rank() == 0:
+            self.csv_path = self.create_csv()
+
+    def process(self, data_batch, data_samples: Sequence[dict]) -> None:
+        """Process one batch of data samples.
+
+        The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch: A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from the model.
+ """ + for data_sample in data_samples: + result = dict() + if 'pred_score' in data_sample: + result['pred_score'] = data_sample['pred_score'].cpu() + else: + result['pred_label'] = data_sample['pred_label'].cpu() + result['gt_label'] = data_sample['gt_label'].cpu() + result['gt_category'] = data_sample['img_path'].split('/')[-2] + result['img_name'] = data_sample['img_path'].split('/')[-1] + + aggregated_category_probabilities = [] + # get the prediction for each category of current instance + for category in self.categories: + category_indices = getattr(self, f'{category}_indices') + category_probabilities = torch.gather( + result['pred_score'], 0, + torch.tensor(category_indices)).mean() + aggregated_category_probabilities.append( + category_probabilities) + # sort the probabilities in descending order + pred_indices = torch.stack(aggregated_category_probabilities + ).argsort(descending=True).numpy() + result['pred_category'] = np.take(self.categories, pred_indices) + + # Save the result to `self.results`. + self.results.append(result) + + def create_csv(self) -> str: + """Create a csv file to store the results.""" + session_name = 'session-1' + csv_path = osp.join( + self.csv_dir, self.dataset_name + '_' + self.model_name + '_' + + session_name + '.csv') + if osp.exists(csv_path): + os.remove(csv_path) + directory = osp.dirname(csv_path) + if not osp.exists(directory): + os.makedirs(directory, exist_ok=True) + with open(csv_path, 'w') as f: + writer = csv.writer(f) + writer.writerow([ + 'subj', 'session', 'trial', 'rt', 'object_response', + 'category', 'condition', 'imagename' + ]) + return csv_path + + def dump_results_to_csv(self, results: List[dict]) -> None: + """Dump the results to a csv file. + + Args: + results (List[dict]): A list of results. + """ + for i, result in enumerate(results): + img_name = result['img_name'] + category = result['gt_category'] + condition = 'NaN' + with open(self.csv_path, 'a') as f: + writer = csv.writer(f) + writer.writerow([ + self.model_name, 1, i + 1, 'NaN', + result['pred_category'][0], category, condition, img_name + ]) + + def compute_metrics(self, results: List[dict]) -> dict: + """Compute the metrics from the results. + + Args: + results (List[dict]): A list of results. + + Returns: + dict: A dict of metrics. + """ + if get_rank() == 0: + self.dump_results_to_csv(results) + metrics = dict() + metrics['accuracy/top1'] = np.mean([ + result['pred_category'][0] == result['gt_category'] + for result in results + ]) + + return metrics diff --git a/mmpretrain/evaluation/metrics/single_label.py b/mmpretrain/evaluation/metrics/single_label.py new file mode 100644 index 0000000000000000000000000000000000000000..f9329b9567e698a4e3ebdb7d77f0f8404b81ad4c --- /dev/null +++ b/mmpretrain/evaluation/metrics/single_label.py @@ -0,0 +1,776 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from itertools import product +from typing import List, Optional, Sequence, Union + +import mmengine +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.evaluator import BaseMetric + +from mmpretrain.registry import METRICS + + +def to_tensor(value): + """Convert value to torch.Tensor.""" + if isinstance(value, np.ndarray): + value = torch.from_numpy(value) + elif isinstance(value, Sequence) and not mmengine.is_str(value): + value = torch.tensor(value) + elif not isinstance(value, torch.Tensor): + raise TypeError(f'{type(value)} is not an available argument.') + return value + + +def _precision_recall_f1_support(pred_positive, gt_positive, average): + """calculate base classification task metrics, such as precision, recall, + f1_score, support.""" + average_options = ['micro', 'macro', None] + assert average in average_options, 'Invalid `average` argument, ' \ + f'please specify from {average_options}.' + + # ignore -1 target such as difficult sample that is not wanted + # in evaluation results. + # only for calculate multi-label without affecting single-label behavior + ignored_index = gt_positive == -1 + pred_positive[ignored_index] = 0 + gt_positive[ignored_index] = 0 + + class_correct = (pred_positive & gt_positive) + if average == 'micro': + tp_sum = class_correct.sum() + pred_sum = pred_positive.sum() + gt_sum = gt_positive.sum() + else: + tp_sum = class_correct.sum(0) + pred_sum = pred_positive.sum(0) + gt_sum = gt_positive.sum(0) + + precision = tp_sum / torch.clamp(pred_sum, min=1).float() * 100 + recall = tp_sum / torch.clamp(gt_sum, min=1).float() * 100 + f1_score = 2 * precision * recall / torch.clamp( + precision + recall, min=torch.finfo(torch.float32).eps) + if average in ['macro', 'micro']: + precision = precision.mean(0) + recall = recall.mean(0) + f1_score = f1_score.mean(0) + support = gt_sum.sum(0) + else: + support = gt_sum + return precision, recall, f1_score, support + + +@METRICS.register_module() +class Accuracy(BaseMetric): + r"""Accuracy evaluation metric. + + For either binary classification or multi-class classification, the + accuracy is the fraction of correct predictions in all predictions: + + .. math:: + + \text{Accuracy} = \frac{N_{\text{correct}}}{N_{\text{all}}} + + Args: + topk (int | Sequence[int]): If the ground truth label matches one of + the best **k** predictions, the sample will be regard as a positive + prediction. If the parameter is a tuple, all of top-k accuracy will + be calculated and outputted together. Defaults to 1. + thrs (Sequence[float | None] | float | None): If a float, predictions + with score lower than the threshold will be regard as the negative + prediction. If None, not apply threshold. If the parameter is a + tuple, accuracy based on all thresholds will be calculated and + outputted together. Defaults to 0. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. 
+ + Examples: + >>> import torch + >>> from mmpretrain.evaluation import Accuracy + >>> # -------------------- The Basic Usage -------------------- + >>> y_pred = [0, 2, 1, 3] + >>> y_true = [0, 1, 2, 3] + >>> Accuracy.calculate(y_pred, y_true) + tensor([50.]) + >>> # Calculate the top1 and top5 accuracy. + >>> y_score = torch.rand((1000, 10)) + >>> y_true = torch.zeros((1000, )) + >>> Accuracy.calculate(y_score, y_true, topk=(1, 5)) + [[tensor([9.9000])], [tensor([51.5000])]] + >>> + >>> # ------------------- Use with Evalutor ------------------- + >>> from mmpretrain.structures import DataSample + >>> from mmengine.evaluator import Evaluator + >>> data_samples = [ + ... DataSample().set_gt_label(0).set_pred_score(torch.rand(10)) + ... for i in range(1000) + ... ] + >>> evaluator = Evaluator(metrics=Accuracy(topk=(1, 5))) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(1000) + { + 'accuracy/top1': 9.300000190734863, + 'accuracy/top5': 51.20000076293945 + } + """ + default_prefix: Optional[str] = 'accuracy' + + def __init__(self, + topk: Union[int, Sequence[int]] = (1, ), + thrs: Union[float, Sequence[Union[float, None]], None] = 0., + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + + if isinstance(topk, int): + self.topk = (topk, ) + else: + self.topk = tuple(topk) + + if isinstance(thrs, float) or thrs is None: + self.thrs = (thrs, ) + else: + self.thrs = tuple(thrs) + + def process(self, data_batch, data_samples: Sequence[dict]): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + + for data_sample in data_samples: + result = dict() + if 'pred_score' in data_sample: + result['pred_score'] = data_sample['pred_score'].cpu() + else: + result['pred_label'] = data_sample['pred_label'].cpu() + result['gt_label'] = data_sample['gt_label'].cpu() + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. + metrics = {} + + # concat + target = torch.cat([res['gt_label'] for res in results]) + if 'pred_score' in results[0]: + pred = torch.stack([res['pred_score'] for res in results]) + + try: + acc = self.calculate(pred, target, self.topk, self.thrs) + except ValueError as e: + # If the topk is invalid. + raise ValueError( + str(e) + ' Please check the `val_evaluator` and ' + '`test_evaluator` fields in your config file.') + + multi_thrs = len(self.thrs) > 1 + for i, k in enumerate(self.topk): + for j, thr in enumerate(self.thrs): + name = f'top{k}' + if multi_thrs: + name += '_no-thr' if thr is None else f'_thr-{thr:.2f}' + metrics[name] = acc[i][j].item() + else: + # If only label in the `pred_label`. 
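+            # Hard labels carry no scores, so the ``topk``/``thrs`` grid does
+            # not apply here; only plain top-1 accuracy is reported.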
+            pred = torch.cat([res['pred_label'] for res in results])
+            acc = self.calculate(pred, target, self.topk, self.thrs)
+            metrics['top1'] = acc.item()
+
+        return metrics
+
+    @staticmethod
+    def calculate(
+        pred: Union[torch.Tensor, np.ndarray, Sequence],
+        target: Union[torch.Tensor, np.ndarray, Sequence],
+        topk: Sequence[int] = (1, ),
+        thrs: Sequence[Union[float, None]] = (0., ),
+    ) -> Union[torch.Tensor, List[List[torch.Tensor]]]:
+        """Calculate the accuracy.
+
+        Args:
+            pred (torch.Tensor | np.ndarray | Sequence): The prediction
+                results. It can be labels (N, ), or scores of every
+                class (N, C).
+            target (torch.Tensor | np.ndarray | Sequence): The target of
+                each prediction with shape (N, ).
+            topk (Sequence[int]): Predictions with the k-th highest scores
+                are considered as positive. Defaults to (1, ).
+            thrs (Sequence[float | None]): Predictions with scores under
+                the thresholds are considered negative. It's only used
+                when ``pred`` is scores. None means no thresholds.
+                Defaults to (0., ).
+
+        Returns:
+            torch.Tensor | List[List[torch.Tensor]]: Accuracy.
+
+            - torch.Tensor: If the ``pred`` is a sequence of label instead of
+              score (number of dimensions is 1). Only return a top-1 accuracy
+              tensor, and ignore the arguments ``topk`` and ``thrs``.
+            - List[List[torch.Tensor]]: If the ``pred`` is a sequence of score
+              (number of dimensions is 2). Return the accuracy on each ``topk``
+              and ``thrs``. And the first dim is ``topk``, the second dim is
+              ``thrs``.
+        """
+
+        pred = to_tensor(pred)
+        target = to_tensor(target).to(torch.int64)
+        num = pred.size(0)
+        assert pred.size(0) == target.size(0), \
+            f"The size of pred ({pred.size(0)}) doesn't match "\
+            f'the target ({target.size(0)}).'
+
+        if pred.ndim == 1:
+            # For pred label, ignore topk and thrs
+            pred_label = pred.int()
+            correct = pred.eq(target).float().sum(0, keepdim=True)
+            acc = correct.mul_(100. / num)
+            return acc
+        else:
+            # For pred score, calculate on all topk and thresholds.
+            pred = pred.float()
+            maxk = max(topk)
+
+            if maxk > pred.size(1):
+                raise ValueError(
+                    f'Top-{maxk} accuracy is unavailable since the number of '
+                    f'categories is {pred.size(1)}.')
+
+            pred_score, pred_label = pred.topk(maxk, dim=1)
+            pred_label = pred_label.t()
+            correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
+            results = []
+            for k in topk:
+                results.append([])
+                for thr in thrs:
+                    # Only prediction values larger than thr are counted
+                    # as correct
+                    _correct = correct
+                    if thr is not None:
+                        _correct = _correct & (pred_score.t() > thr)
+                    correct_k = _correct[:k].reshape(-1).float().sum(
+                        0, keepdim=True)
+                    acc = correct_k.mul_(100. / num)
+                    results[-1].append(acc)
+            return results
+
+
+@METRICS.register_module()
+class SingleLabelMetric(BaseMetric):
+    r"""A collection of precision, recall, f1-score and support for
+    single-label tasks.
+
+    The collection of metrics is for single-label multi-class classification.
+    And all these metrics are based on the confusion matrix of every category:
+
+    .. image:: ../../_static/image/confusion-matrix.png
+       :width: 60%
+       :align: center
+
+    All metrics can be formulated using the variables above:
+
+    **Precision** is the fraction of correct predictions in all predictions:
+
+    .. math::
+        \text{Precision} = \frac{TP}{TP+FP}
+
+    **Recall** is the fraction of correct predictions in all targets:
+
+    .. math::
+        \text{Recall} = \frac{TP}{TP+FN}
+
+    **F1-score** is the harmonic mean of the precision and recall:
+
+    .. 
math:: + \text{F1-score} = \frac{2\times\text{Recall}\times\text{Precision}}{\text{Recall}+\text{Precision}} + + **Support** is the number of samples: + + .. math:: + \text{Support} = TP + TN + FN + FP + + Args: + thrs (Sequence[float | None] | float | None): If a float, predictions + with score lower than the threshold will be regard as the negative + prediction. If None, only the top-1 prediction will be regard as + the positive prediction. If the parameter is a tuple, accuracy + based on all thresholds will be calculated and outputted together. + Defaults to 0. + items (Sequence[str]): The detailed metric items to evaluate, select + from "precision", "recall", "f1-score" and "support". + Defaults to ``('precision', 'recall', 'f1-score')``. + average (str | None): How to calculate the final metrics from the + confusion matrix of every category. It supports three modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. + - `"micro"`: Average the confusion matrix over all categories and + calculate metrics on the mean confusion matrix. + - `None`: Calculate metrics of every category and output directly. + + Defaults to "macro". + num_classes (int, optional): The number of classes. Defaults to None. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + Examples: + >>> import torch + >>> from mmpretrain.evaluation import SingleLabelMetric + >>> # -------------------- The Basic Usage -------------------- + >>> y_pred = [0, 1, 1, 3] + >>> y_true = [0, 2, 1, 3] + >>> # Output precision, recall, f1-score and support. + >>> SingleLabelMetric.calculate(y_pred, y_true, num_classes=4) + (tensor(62.5000), tensor(75.), tensor(66.6667), tensor(4)) + >>> # Calculate with different thresholds. + >>> y_score = torch.rand((1000, 10)) + >>> y_true = torch.zeros((1000, )) + >>> SingleLabelMetric.calculate(y_score, y_true, thrs=(0., 0.9)) + [(tensor(10.), tensor(0.9500), tensor(1.7352), tensor(1000)), + (tensor(10.), tensor(0.5500), tensor(1.0427), tensor(1000))] + >>> + >>> # ------------------- Use with Evalutor ------------------- + >>> from mmpretrain.structures import DataSample + >>> from mmengine.evaluator import Evaluator + >>> data_samples = [ + ... DataSample().set_gt_label(i%5).set_pred_score(torch.rand(5)) + ... for i in range(1000) + ... 
] + >>> evaluator = Evaluator(metrics=SingleLabelMetric()) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(1000) + {'single-label/precision': 19.650691986083984, + 'single-label/recall': 19.600000381469727, + 'single-label/f1-score': 19.619548797607422} + >>> # Evaluate on each class + >>> evaluator = Evaluator(metrics=SingleLabelMetric(average=None)) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(1000) + { + 'single-label/precision_classwise': [21.1, 18.7, 17.8, 19.4, 16.1], + 'single-label/recall_classwise': [18.5, 18.5, 17.0, 20.0, 18.0], + 'single-label/f1-score_classwise': [19.7, 18.6, 17.1, 19.7, 17.0] + } + """ # noqa: E501 + default_prefix: Optional[str] = 'single-label' + + def __init__(self, + thrs: Union[float, Sequence[Union[float, None]], None] = 0., + items: Sequence[str] = ('precision', 'recall', 'f1-score'), + average: Optional[str] = 'macro', + num_classes: Optional[int] = None, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + + if isinstance(thrs, float) or thrs is None: + self.thrs = (thrs, ) + else: + self.thrs = tuple(thrs) + + for item in items: + assert item in ['precision', 'recall', 'f1-score', 'support'], \ + f'The metric {item} is not supported by `SingleLabelMetric`,' \ + ' please specify from "precision", "recall", "f1-score" and ' \ + '"support".' + self.items = tuple(items) + self.average = average + self.num_classes = num_classes + + def process(self, data_batch, data_samples: Sequence[dict]): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + + for data_sample in data_samples: + result = dict() + if 'pred_score' in data_sample: + result['pred_score'] = data_sample['pred_score'].cpu() + else: + num_classes = self.num_classes or data_sample.get( + 'num_classes') + assert num_classes is not None, \ + 'The `num_classes` must be specified if no `pred_score`.' + result['pred_label'] = data_sample['pred_label'].cpu() + result['num_classes'] = num_classes + result['gt_label'] = data_sample['gt_label'].cpu() + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. `self.results` + # are a list of results from multiple batch, while the input `results` + # are the collected results. 
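+        # ``pack_results`` defined below keeps only the metric items
+        # requested via ``self.items``.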
+ metrics = {} + + def pack_results(precision, recall, f1_score, support): + single_metrics = {} + if 'precision' in self.items: + single_metrics['precision'] = precision + if 'recall' in self.items: + single_metrics['recall'] = recall + if 'f1-score' in self.items: + single_metrics['f1-score'] = f1_score + if 'support' in self.items: + single_metrics['support'] = support + return single_metrics + + # concat + target = torch.cat([res['gt_label'] for res in results]) + if 'pred_score' in results[0]: + pred = torch.stack([res['pred_score'] for res in results]) + metrics_list = self.calculate( + pred, target, thrs=self.thrs, average=self.average) + + multi_thrs = len(self.thrs) > 1 + for i, thr in enumerate(self.thrs): + if multi_thrs: + suffix = '_no-thr' if thr is None else f'_thr-{thr:.2f}' + else: + suffix = '' + + for k, v in pack_results(*metrics_list[i]).items(): + metrics[k + suffix] = v + else: + # If only label in the `pred_label`. + pred = torch.cat([res['pred_label'] for res in results]) + res = self.calculate( + pred, + target, + average=self.average, + num_classes=results[0]['num_classes']) + metrics = pack_results(*res) + + result_metrics = dict() + for k, v in metrics.items(): + + if self.average is None: + result_metrics[k + '_classwise'] = v.cpu().detach().tolist() + elif self.average == 'micro': + result_metrics[k + f'_{self.average}'] = v.item() + else: + result_metrics[k] = v.item() + + return result_metrics + + @staticmethod + def calculate( + pred: Union[torch.Tensor, np.ndarray, Sequence], + target: Union[torch.Tensor, np.ndarray, Sequence], + thrs: Sequence[Union[float, None]] = (0., ), + average: Optional[str] = 'macro', + num_classes: Optional[int] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Calculate the precision, recall, f1-score and support. + + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. It can be labels (N, ), or scores of every + class (N, C). + target (torch.Tensor | np.ndarray | Sequence): The target of + each prediction with shape (N, ). + thrs (Sequence[float | None]): Predictions with scores under + the thresholds are considered negative. It's only used + when ``pred`` is scores. None means no thresholds. + Defaults to (0., ). + average (str | None): How to calculate the final metrics from + the confusion matrix of every category. It supports three + modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. + - `"micro"`: Average the confusion matrix over all categories + and calculate metrics on the mean confusion matrix. + - `None`: Calculate metrics of every category and output + directly. + + Defaults to "macro". + num_classes (Optional, int): The number of classes. If the ``pred`` + is label instead of scores, this argument is required. + Defaults to None. + + Returns: + Tuple: The tuple contains precision, recall and f1-score. + And the type of each item is: + + - torch.Tensor: If the ``pred`` is a sequence of label instead of + score (number of dimensions is 1). Only returns a tensor for + each metric. The shape is (1, ) if ``classwise`` is False, and + (C, ) if ``classwise`` is True. + - List[torch.Tensor]: If the ``pred`` is a sequence of score + (number of dimensions is 2). Return the metrics on each ``thrs``. + The shape of tensor is (1, ) if ``classwise`` is False, and (C, ) + if ``classwise`` is True. 
+ """ + average_options = ['micro', 'macro', None] + assert average in average_options, 'Invalid `average` argument, ' \ + f'please specify from {average_options}.' + + pred = to_tensor(pred) + target = to_tensor(target).to(torch.int64) + assert pred.size(0) == target.size(0), \ + f"The size of pred ({pred.size(0)}) doesn't match "\ + f'the target ({target.size(0)}).' + + if pred.ndim == 1: + assert num_classes is not None, \ + 'Please specify the `num_classes` if the `pred` is labels ' \ + 'intead of scores.' + gt_positive = F.one_hot(target.flatten(), num_classes) + pred_positive = F.one_hot(pred.to(torch.int64), num_classes) + return _precision_recall_f1_support(pred_positive, gt_positive, + average) + else: + # For pred score, calculate on all thresholds. + num_classes = pred.size(1) + pred_score, pred_label = torch.topk(pred, k=1) + pred_score = pred_score.flatten() + pred_label = pred_label.flatten() + + gt_positive = F.one_hot(target.flatten(), num_classes) + + results = [] + for thr in thrs: + pred_positive = F.one_hot(pred_label, num_classes) + if thr is not None: + pred_positive[pred_score <= thr] = 0 + results.append( + _precision_recall_f1_support(pred_positive, gt_positive, + average)) + + return results + + +@METRICS.register_module() +class ConfusionMatrix(BaseMetric): + r"""A metric to calculate confusion matrix for single-label tasks. + + Args: + num_classes (int, optional): The number of classes. Defaults to None. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + Examples: + + 1. The basic usage. + + >>> import torch + >>> from mmpretrain.evaluation import ConfusionMatrix + >>> y_pred = [0, 1, 1, 3] + >>> y_true = [0, 2, 1, 3] + >>> ConfusionMatrix.calculate(y_pred, y_true, num_classes=4) + tensor([[1, 0, 0, 0], + [0, 1, 0, 0], + [0, 1, 0, 0], + [0, 0, 0, 1]]) + >>> # plot the confusion matrix + >>> import matplotlib.pyplot as plt + >>> y_score = torch.rand((1000, 10)) + >>> y_true = torch.randint(10, (1000, )) + >>> matrix = ConfusionMatrix.calculate(y_score, y_true) + >>> ConfusionMatrix().plot(matrix) + >>> plt.show() + + 2. In the config file + + .. 
code:: python + + val_evaluator = dict(type='ConfusionMatrix') + test_evaluator = dict(type='ConfusionMatrix') + """ # noqa: E501 + default_prefix = 'confusion_matrix' + + def __init__(self, + num_classes: Optional[int] = None, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device, prefix) + + self.num_classes = num_classes + + def process(self, data_batch, data_samples: Sequence[dict]) -> None: + for data_sample in data_samples: + if 'pred_score' in data_sample: + pred_score = data_sample['pred_score'] + pred_label = pred_score.argmax(dim=0, keepdim=True) + self.num_classes = pred_score.size(0) + else: + pred_label = data_sample['pred_label'] + + self.results.append({ + 'pred_label': pred_label, + 'gt_label': data_sample['gt_label'], + }) + + def compute_metrics(self, results: list) -> dict: + pred_labels = [] + gt_labels = [] + for result in results: + pred_labels.append(result['pred_label']) + gt_labels.append(result['gt_label']) + confusion_matrix = ConfusionMatrix.calculate( + torch.cat(pred_labels), + torch.cat(gt_labels), + num_classes=self.num_classes) + return {'result': confusion_matrix} + + @staticmethod + def calculate(pred, target, num_classes=None) -> dict: + """Calculate the confusion matrix for single-label task. + + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. It can be labels (N, ), or scores of every + class (N, C). + target (torch.Tensor | np.ndarray | Sequence): The target of + each prediction with shape (N, ). + num_classes (Optional, int): The number of classes. If the ``pred`` + is label instead of scores, this argument is required. + Defaults to None. + + Returns: + torch.Tensor: The confusion matrix. + """ + pred = to_tensor(pred) + target_label = to_tensor(target).int() + + assert pred.size(0) == target_label.size(0), \ + f"The size of pred ({pred.size(0)}) doesn't match "\ + f'the target ({target_label.size(0)}).' + assert target_label.ndim == 1 + + if pred.ndim == 1: + assert num_classes is not None, \ + 'Please specify the `num_classes` if the `pred` is labels ' \ + 'intead of scores.' + pred_label = pred + else: + num_classes = num_classes or pred.size(1) + pred_label = torch.argmax(pred, dim=1).flatten() + + with torch.no_grad(): + indices = num_classes * target_label + pred_label + matrix = torch.bincount(indices, minlength=num_classes**2) + matrix = matrix.reshape(num_classes, num_classes) + + return matrix + + @staticmethod + def plot(confusion_matrix: torch.Tensor, + include_values: bool = False, + cmap: str = 'viridis', + classes: Optional[List[str]] = None, + colorbar: bool = True, + show: bool = True): + """Draw a confusion matrix by matplotlib. + + Modified from `Scikit-Learn + `_ + + Args: + confusion_matrix (torch.Tensor): The confusion matrix to draw. + include_values (bool): Whether to draw the values in the figure. + Defaults to False. + cmap (str): The color map to use. Defaults to use "viridis". + classes (list[str], optional): The names of categories. + Defaults to None, which means to use index number. + colorbar (bool): Whether to show the colorbar. Defaults to True. + show (bool): Whether to show the figure immediately. + Defaults to True. 
+ """ # noqa: E501 + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(figsize=(10, 10)) + + num_classes = confusion_matrix.size(0) + + im_ = ax.imshow(confusion_matrix, interpolation='nearest', cmap=cmap) + text_ = None + cmap_min, cmap_max = im_.cmap(0), im_.cmap(1.0) + + if include_values: + text_ = np.empty_like(confusion_matrix, dtype=object) + + # print text with appropriate color depending on background + thresh = (confusion_matrix.max() + confusion_matrix.min()) / 2.0 + + for i, j in product(range(num_classes), range(num_classes)): + color = cmap_max if confusion_matrix[i, + j] < thresh else cmap_min + + text_cm = format(confusion_matrix[i, j], '.2g') + text_d = format(confusion_matrix[i, j], 'd') + if len(text_d) < len(text_cm): + text_cm = text_d + + text_[i, j] = ax.text( + j, i, text_cm, ha='center', va='center', color=color) + + display_labels = classes or np.arange(num_classes) + + if colorbar: + fig.colorbar(im_, ax=ax) + ax.set( + xticks=np.arange(num_classes), + yticks=np.arange(num_classes), + xticklabels=display_labels, + yticklabels=display_labels, + ylabel='True label', + xlabel='Predicted label', + ) + ax.invert_yaxis() + ax.xaxis.tick_top() + + ax.set_ylim((num_classes - 0.5, -0.5)) + # Automatically rotate the x labels. + fig.autofmt_xdate(ha='center') + + if show: + plt.show() + return fig diff --git a/mmpretrain/evaluation/metrics/visual_grounding_eval.py b/mmpretrain/evaluation/metrics/visual_grounding_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..ad16e5adf4660496b3a984087294ed9c0fee6537 --- /dev/null +++ b/mmpretrain/evaluation/metrics/visual_grounding_eval.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torchvision.ops.boxes as boxes +from mmengine.evaluator import BaseMetric + +from mmpretrain.registry import METRICS + + +def aligned_box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor): + area1 = boxes.box_area(boxes1) + area2 = boxes.box_area(boxes2) + + lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # (B, 2) + rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # (B, 2) + + wh = boxes._upcast(rb - lt).clamp(min=0) # (B, 2) + inter = wh[:, 0] * wh[:, 1] # (B, ) + + union = area1 + area2 - inter + iou = inter / union + return iou + + +@METRICS.register_module() +class VisualGroundingMetric(BaseMetric): + """Visual Grounding evaluator. + + Calculate the box mIOU and box grounding accuracy for visual grounding + model. + + Args: + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + default_prefix = 'visual-grounding' + + def process(self, data_batch, data_samples): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. 
+ """ + for preds in data_samples: + + pred_box = preds['pred_bboxes'].squeeze() + box_gt = torch.Tensor(preds['gt_bboxes']).squeeze() + + result = { + 'box': pred_box.to('cpu').squeeze(), + 'box_target': box_gt.squeeze(), + } + + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + pred_boxes = torch.stack([each['box'] for each in results]) + gt_boxes = torch.stack([each['box_target'] for each in results]) + iou = aligned_box_iou(pred_boxes, gt_boxes) + accu_num = torch.sum(iou >= 0.5) + + miou = torch.mean(iou) + acc = accu_num / len(gt_boxes) + coco_val = {'miou': miou, 'acc': acc} + return coco_val diff --git a/mmpretrain/evaluation/metrics/voc_multi_label.py b/mmpretrain/evaluation/metrics/voc_multi_label.py new file mode 100644 index 0000000000000000000000000000000000000000..1034852722796271c7ade9d75c3442cce8f1d0d1 --- /dev/null +++ b/mmpretrain/evaluation/metrics/voc_multi_label.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +from mmpretrain.registry import METRICS +from mmpretrain.structures import label_to_onehot +from .multi_label import AveragePrecision, MultiLabelMetric + + +class VOCMetricMixin: + """A mixin class for VOC dataset metrics, VOC annotations have extra + `difficult` attribute for each object, therefore, extra option is needed + for calculating VOC metrics. + + Args: + difficult_as_postive (Optional[bool]): Whether to map the difficult + labels as positive in one-hot ground truth for evaluation. If it + set to True, map difficult gt labels to positive ones(1), If it + set to False, map difficult gt labels to negative ones(0). + Defaults to None, the difficult labels will be set to '-1'. + """ + + def __init__(self, + *arg, + difficult_as_positive: Optional[bool] = None, + **kwarg): + self.difficult_as_positive = difficult_as_positive + super().__init__(*arg, **kwarg) + + def process(self, data_batch, data_samples: Sequence[dict]): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + result = dict() + gt_label = data_sample['gt_label'] + gt_label_difficult = data_sample['gt_label_difficult'] + + result['pred_score'] = data_sample['pred_score'].clone() + num_classes = result['pred_score'].size()[-1] + + if 'gt_score' in data_sample: + result['gt_score'] = data_sample['gt_score'].clone() + else: + result['gt_score'] = label_to_onehot(gt_label, num_classes) + + # VOC annotation labels all the objects in a single image + # therefore, some categories are appeared both in + # difficult objects and non-difficult objects. + # Here we reckon those labels which are only exists in difficult + # objects as difficult labels. 
+            difficult_label = set(gt_label_difficult) - (
+                set(gt_label_difficult) & set(gt_label.tolist()))
+
+            # set difficult label for better eval
+            if self.difficult_as_positive is None:
+                result['gt_score'][[*difficult_label]] = -1
+            elif self.difficult_as_positive:
+                result['gt_score'][[*difficult_label]] = 1
+
+            # Save the result to `self.results`.
+            self.results.append(result)
+
+
+@METRICS.register_module()
+class VOCMultiLabelMetric(VOCMetricMixin, MultiLabelMetric):
+    """A collection of metrics for multi-label multi-class classification
+    tasks based on confusion matrix for the VOC dataset.
+
+    It includes precision, recall, f1-score and support.
+
+    Args:
+        difficult_as_positive (Optional[bool]): Whether to map the difficult
+            labels as positive in one-hot ground truth for evaluation. If it
+            is set to True, map difficult gt labels to positive ones (1); if
+            it is set to False, map difficult gt labels to negative ones (0).
+            Defaults to None, in which case the difficult labels are set
+            to '-1'.
+        **kwarg: Refer to `MultiLabelMetric` for detailed docstrings.
+    """
+
+
+@METRICS.register_module()
+class VOCAveragePrecision(VOCMetricMixin, AveragePrecision):
+    """Calculate the average precision with respect to classes for the VOC
+    dataset.
+
+    Args:
+        difficult_as_positive (Optional[bool]): Whether to map the difficult
+            labels as positive in one-hot ground truth for evaluation. If it
+            is set to True, map difficult gt labels to positive ones (1); if
+            it is set to False, map difficult gt labels to negative ones (0).
+            Defaults to None, in which case the difficult labels are set
+            to '-1'.
+        **kwarg: Refer to `AveragePrecision` for detailed docstrings.
+    """
diff --git a/mmpretrain/evaluation/metrics/vqa.py b/mmpretrain/evaluation/metrics/vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd77ba9bc23e013c41ac095810740bdb71d33fb3
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/vqa.py
@@ -0,0 +1,315 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Partly adopted from https://github.com/GT-Vision-Lab/VQA
+# Copyright (c) 2014, Aishwarya Agrawal
+from typing import List, Optional
+
+import mmengine
+from mmengine.evaluator import BaseMetric
+from mmengine.logging import MMLogger
+
+from mmpretrain.registry import METRICS
+
+
+def _process_punctuation(inText):
+    import re
+    outText = inText
+    punct = [
+        ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_',
+        '-', '>', '<', '@', '`', ',', '?', '!'
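+        # NOTE: apostrophes are deliberately not stripped here; tokens such
+        # as "isn't" are normalized later through the contraction table in
+        # _process_digit_article.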
+ ] + commaStrip = re.compile('(\d)(,)(\d)') # noqa: W605 + periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') # noqa: W605 + for p in punct: + if (p + ' ' in inText or ' ' + p in inText) or (re.search( + commaStrip, inText) is not None): + outText = outText.replace(p, '') + else: + outText = outText.replace(p, ' ') + outText = periodStrip.sub('', outText, re.UNICODE) + return outText + + +def _process_digit_article(inText): + outText = [] + tempText = inText.lower().split() + articles = ['a', 'an', 'the'] + manualMap = { + 'none': '0', + 'zero': '0', + 'one': '1', + 'two': '2', + 'three': '3', + 'four': '4', + 'five': '5', + 'six': '6', + 'seven': '7', + 'eight': '8', + 'nine': '9', + 'ten': '10', + } + contractions = { + 'aint': "ain't", + 'arent': "aren't", + 'cant': "can't", + 'couldve': "could've", + 'couldnt': "couldn't", + "couldn'tve": "couldn't've", + "couldnt've": "couldn't've", + 'didnt': "didn't", + 'doesnt': "doesn't", + 'dont': "don't", + 'hadnt': "hadn't", + "hadnt've": "hadn't've", + "hadn'tve": "hadn't've", + 'hasnt': "hasn't", + 'havent': "haven't", + 'hed': "he'd", + "hed've": "he'd've", + "he'dve": "he'd've", + 'hes': "he's", + 'howd': "how'd", + 'howll': "how'll", + 'hows': "how's", + "Id've": "I'd've", + "I'dve": "I'd've", + 'Im': "I'm", + 'Ive': "I've", + 'isnt': "isn't", + 'itd': "it'd", + "itd've": "it'd've", + "it'dve": "it'd've", + 'itll': "it'll", + "let's": "let's", + 'maam': "ma'am", + 'mightnt': "mightn't", + "mightnt've": "mightn't've", + "mightn'tve": "mightn't've", + 'mightve': "might've", + 'mustnt': "mustn't", + 'mustve': "must've", + 'neednt': "needn't", + 'notve': "not've", + 'oclock': "o'clock", + 'oughtnt': "oughtn't", + "ow's'at": "'ow's'at", + "'ows'at": "'ow's'at", + "'ow'sat": "'ow's'at", + 'shant': "shan't", + "shed've": "she'd've", + "she'dve": "she'd've", + "she's": "she's", + 'shouldve': "should've", + 'shouldnt': "shouldn't", + "shouldnt've": "shouldn't've", + "shouldn'tve": "shouldn't've", + "somebody'd": 'somebodyd', + "somebodyd've": "somebody'd've", + "somebody'dve": "somebody'd've", + 'somebodyll': "somebody'll", + 'somebodys': "somebody's", + 'someoned': "someone'd", + "someoned've": "someone'd've", + "someone'dve": "someone'd've", + 'someonell': "someone'll", + 'someones': "someone's", + 'somethingd': "something'd", + "somethingd've": "something'd've", + "something'dve": "something'd've", + 'somethingll': "something'll", + 'thats': "that's", + 'thered': "there'd", + "thered've": "there'd've", + "there'dve": "there'd've", + 'therere': "there're", + 'theres': "there's", + 'theyd': "they'd", + "theyd've": "they'd've", + "they'dve": "they'd've", + 'theyll': "they'll", + 'theyre': "they're", + 'theyve': "they've", + 'twas': "'twas", + 'wasnt': "wasn't", + "wed've": "we'd've", + "we'dve": "we'd've", + 'weve': "we've", + 'werent': "weren't", + 'whatll': "what'll", + 'whatre': "what're", + 'whats': "what's", + 'whatve': "what've", + 'whens': "when's", + 'whered': "where'd", + 'wheres': "where's", + 'whereve': "where've", + 'whod': "who'd", + "whod've": "who'd've", + "who'dve": "who'd've", + 'wholl': "who'll", + 'whos': "who's", + 'whove': "who've", + 'whyll': "why'll", + 'whyre': "why're", + 'whys': "why's", + 'wont': "won't", + 'wouldve': "would've", + 'wouldnt': "wouldn't", + "wouldnt've": "wouldn't've", + "wouldn'tve": "wouldn't've", + 'yall': "y'all", + "yall'll": "y'all'll", + "y'allll": "y'all'll", + "yall'd've": "y'all'd've", + "y'alld've": "y'all'd've", + "y'all'dve": "y'all'd've", + 'youd': "you'd", + "youd've": "you'd've", + "you'dve": 
"you'd've", + 'youll': "you'll", + 'youre': "you're", + 'youve': "you've", + } + for word in tempText: + word = manualMap.setdefault(word, word) + if word not in articles: + outText.append(word) + for wordId, word in enumerate(outText): + if word in contractions: + outText[wordId] = contractions[word] + outText = ' '.join(outText) + return outText + + +@METRICS.register_module() +class VQAAcc(BaseMetric): + '''VQA Acc metric. + Args: + + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + ''' + default_prefix = 'VQA' + + def __init__(self, + full_score_weight: float = 0.3, + collect_device: str = 'cpu', + prefix: Optional[str] = None): + super().__init__(collect_device=collect_device, prefix=prefix) + self.full_score_weight = full_score_weight + + def process(self, data_batch, data_samples): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for sample in data_samples: + gt_answer = sample.get('gt_answer') + gt_answer_weight = sample.get('gt_answer_weight') + if isinstance(gt_answer, str): + gt_answer = [gt_answer] + if gt_answer_weight is None: + gt_answer_weight = [1. / (len(gt_answer))] * len(gt_answer) + + result = { + 'pred_answer': sample.get('pred_answer'), + 'gt_answer': gt_answer, + 'gt_answer_weight': gt_answer_weight, + } + + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + acc = [] + for result in results: + pred_answer = self._process_answer(result['pred_answer']) + gt_answer = [ + self._process_answer(answer) for answer in result['gt_answer'] + ] + answer_weight = result['gt_answer_weight'] + + weight_sum = 0 + for i, gt in enumerate(gt_answer): + if gt == pred_answer: + weight_sum += answer_weight[i] + vqa_acc = min(1.0, weight_sum / self.full_score_weight) + acc.append(vqa_acc) + + accuracy = sum(acc) / len(acc) * 100 + + metrics = {'acc': accuracy} + return metrics + + def _process_answer(self, answer): + answer = answer.replace('\n', ' ') + answer = answer.replace('\t', ' ') + answer = answer.strip() + answer = _process_punctuation(answer) + answer = _process_digit_article(answer) + return answer + + +@METRICS.register_module() +class ReportVQA(BaseMetric): + """Dump VQA result to the standard json format for VQA evaluation. + + Args: + file_path (str): The file path to save the result file. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. 
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+    """
+    default_prefix = 'VQA'
+
+    def __init__(self,
+                 file_path: str,
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None):
+        super().__init__(collect_device=collect_device, prefix=prefix)
+        if not file_path.endswith('.json'):
+            raise ValueError('The output file must be a json file.')
+        self.file_path = file_path
+
+    def process(self, data_batch, data_samples) -> None:
+        """Transfer tensors in predictions to CPU."""
+        for sample in data_samples:
+            question_id = sample['question_id']
+            pred_answer = sample['pred_answer']
+
+            result = {
+                'question_id': int(question_id),
+                'answer': pred_answer,
+            }
+
+            self.results.append(result)
+
+    def compute_metrics(self, results: List):
+        """Dump the results to a json file."""
+        mmengine.dump(results, self.file_path)
+        logger = MMLogger.get_current_instance()
+        logger.info(f'Results have been saved to {self.file_path}.')
+        return {}
diff --git a/mmpretrain/models/.DS_Store b/mmpretrain/models/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..44c41c24785fdc8c6130baaa05aec30aaf66ed72
Binary files /dev/null and b/mmpretrain/models/.DS_Store differ
diff --git a/mmpretrain/models/__init__.py b/mmpretrain/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f583114ee54abd7885759c63b45231252ae0db1
--- /dev/null
+++ b/mmpretrain/models/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .backbones import *  # noqa: F401,F403
+from .builder import (BACKBONES, CLASSIFIERS, HEADS, LOSSES, NECKS,
+                      build_backbone, build_classifier, build_head,
+                      build_loss, build_neck)
+from .classifiers import *  # noqa: F401,F403
+from .heads import *  # noqa: F401,F403
+from .losses import *  # noqa: F401,F403
+from .multimodal import *  # noqa: F401,F403
+from .necks import *  # noqa: F401,F403
+from .peft import *  # noqa: F401,F403
+from .retrievers import *  # noqa: F401,F403
+from .selfsup import *  # noqa: F401,F403
+from .tta import *  # noqa: F401,F403
+from .utils import *  # noqa: F401,F403
+
+__all__ = [
+    'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'CLASSIFIERS', 'build_backbone',
+    'build_head', 'build_neck', 'build_loss', 'build_classifier'
+]
diff --git a/mmpretrain/models/__pycache__/__init__.cpython-311.pyc b/mmpretrain/models/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c83faa83c61526f03e4213135adb4329791025f0
Binary files /dev/null and b/mmpretrain/models/__pycache__/__init__.cpython-311.pyc differ
diff --git a/mmpretrain/models/__pycache__/builder.cpython-311.pyc b/mmpretrain/models/__pycache__/builder.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b8dbb89ebb02e520e6ffc24370f3f289e0383be
Binary files /dev/null and b/mmpretrain/models/__pycache__/builder.cpython-311.pyc differ
diff --git a/mmpretrain/models/backbones/__init__.py b/mmpretrain/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..60e37fb7b6e15cadd0eef4a3c9c79c856fbf4247
--- /dev/null
+++ b/mmpretrain/models/backbones/__init__.py
@@ -0,0 +1,129 @@
+# Copyright (c) OpenMMLab. All rights reserved.
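+# NOTE: importing a backbone module here also registers its class in the
+# MODELS registry as a side effect, so configs can refer to it by name.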
+from .alexnet import AlexNet +from .beit import BEiTViT +from .conformer import Conformer +from .convmixer import ConvMixer +from .convnext import ConvNeXt +from .cspnet import CSPDarkNet, CSPNet, CSPResNet, CSPResNeXt +from .davit import DaViT +from .deit import DistilledVisionTransformer +from .deit3 import DeiT3 +from .densenet import DenseNet +from .edgenext import EdgeNeXt +from .efficientformer import EfficientFormer +from .efficientnet import EfficientNet +from .efficientnet_v2 import EfficientNetV2 +from .hivit import HiViT +from .hornet import HorNet +from .hrnet import HRNet +from .inception_v3 import InceptionV3 +from .lenet import LeNet5 +from .levit import LeViT +from .mixmim import MixMIMTransformer +from .mlp_mixer import MlpMixer +from .mobilenet_v2 import MobileNetV2 +from .mobilenet_v3 import MobileNetV3 +from .mobileone import MobileOne +from .mobilevit import MobileViT +from .mvit import MViT +from .poolformer import PoolFormer +from .regnet import RegNet +from .replknet import RepLKNet +from .repmlp import RepMLPNet +from .repvgg import RepVGG +from .res2net import Res2Net +from .resnest import ResNeSt +from .resnet import ResNet, ResNetV1c, ResNetV1d +from .resnet_cifar import ResNet_CIFAR +from .resnext import ResNeXt +from .revvit import RevVisionTransformer +from .riformer import RIFormer +from .seresnet import SEResNet +from .seresnext import SEResNeXt +from .shufflenet_v1 import ShuffleNetV1 +from .shufflenet_v2 import ShuffleNetV2 +from .sparse_convnext import SparseConvNeXt +from .sparse_resnet import SparseResNet +from .swin_transformer import SwinTransformer +from .swin_transformer_v2 import SwinTransformerV2 +from .t2t_vit import T2T_ViT +from .timm_backbone import TIMMBackbone +from .tinyvit import TinyViT +from .tnt import TNT +from .twins import PCPVT, SVT +from .van import VAN +from .vgg import VGG +from .vig import PyramidVig, Vig +from .vision_transformer import VisionTransformer +from .vit_eva02 import ViTEVA02 +from .vit_sam import ViTSAM +from .xcit import XCiT + +__all__ = [ + 'LeNet5', + 'AlexNet', + 'VGG', + 'RegNet', + 'ResNet', + 'ResNeXt', + 'ResNetV1d', + 'ResNeSt', + 'ResNet_CIFAR', + 'SEResNet', + 'SEResNeXt', + 'ShuffleNetV1', + 'ShuffleNetV2', + 'MobileNetV2', + 'MobileNetV3', + 'VisionTransformer', + 'SwinTransformer', + 'TNT', + 'TIMMBackbone', + 'T2T_ViT', + 'Res2Net', + 'RepVGG', + 'Conformer', + 'MlpMixer', + 'DistilledVisionTransformer', + 'PCPVT', + 'SVT', + 'EfficientNet', + 'EfficientNetV2', + 'ConvNeXt', + 'HRNet', + 'ResNetV1c', + 'ConvMixer', + 'EdgeNeXt', + 'CSPDarkNet', + 'CSPResNet', + 'CSPResNeXt', + 'CSPNet', + 'RepLKNet', + 'RepMLPNet', + 'PoolFormer', + 'RIFormer', + 'DenseNet', + 'VAN', + 'InceptionV3', + 'MobileOne', + 'EfficientFormer', + 'SwinTransformerV2', + 'MViT', + 'DeiT3', + 'HorNet', + 'MobileViT', + 'DaViT', + 'BEiTViT', + 'RevVisionTransformer', + 'MixMIMTransformer', + 'TinyViT', + 'LeViT', + 'Vig', + 'PyramidVig', + 'XCiT', + 'ViTSAM', + 'ViTEVA02', + 'HiViT', + 'SparseResNet', + 'SparseConvNeXt', +] diff --git a/mmpretrain/models/backbones/__pycache__/__init__.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..117e35f9413f5b10c0accfa138484d551944aa5b Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/alexnet.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/alexnet.cpython-311.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..d4f1e810f0f5246df46a7ff52c6306c2ec6de76b Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/alexnet.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/base_backbone.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/base_backbone.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bfde1004a284a602b2532a7f837a070573f49bf Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/base_backbone.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/beit.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/beit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ca14053b809cc5f71709c28477a6f1e73d68c21 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/beit.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/conformer.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/conformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b94476b5275b08820c9340f4d8677dd1c16b1269 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/conformer.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/convmixer.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/convmixer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1804c9d35cd7d349da2f93b294cff83831875c0e Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/convmixer.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/convnext.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/convnext.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5894ab6cfb639bb9ab2488e4009935ab2d3357ec Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/convnext.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/cspnet.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/cspnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3d58159ffce41841767bba11b97026adb784553 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/cspnet.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/davit.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/davit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1e203598dd654f60c54ec28befc1f8970b44e85 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/davit.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/deit.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/deit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e434dfbf3845cd347c52e101cff28f88e4a413b5 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/deit.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/deit3.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/deit3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa3a633a0fb11ba8b93938791030b8b8868278e4 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/deit3.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/densenet.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/densenet.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..f55dd5ea455dc04a01a09a4fd732f9f9de072a0e Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/densenet.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/edgenext.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/edgenext.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fc9b9575e5837c86dcc94767bf69eba579829be Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/edgenext.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/efficientformer.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/efficientformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecf2aebe61636a223a7506563ce1424c385626ae Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/efficientformer.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/efficientnet.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/efficientnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b8cfa82b8fe8495dd9a7cb28fb8f79202b7a906 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/efficientnet.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/efficientnet_v2.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/efficientnet_v2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82f2c40ab804b090ab346111ddc8deaf36236c05 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/efficientnet_v2.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/hivit.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/hivit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adbac545bf8249fca81cc7d52197be4254204232 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/hivit.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/hornet.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/hornet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19141901251900228c959586de8e47f1fec47d59 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/hornet.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/hrnet.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/hrnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b945741d49b95d77d7df38bb343c390c6dabee1 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/hrnet.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/inception_v3.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/inception_v3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9203cd6e76ed0ea67424293c491a24d98702f2f Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/inception_v3.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/lenet.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/lenet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58a47ce26391ed0c80b38d80f2c18dc0d28c467f Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/lenet.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/levit.cpython-311.pyc 
b/mmpretrain/models/backbones/__pycache__/levit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a6b08ea81bcc5ef8a751a4a69fc7fcb431ea452 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/levit.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/mixmim.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/mixmim.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49a7553fa75ea5947a6fe04de7d0062197e64881 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mixmim.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/mlp_mixer.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/mlp_mixer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f346b6cebc209822a51c5bd49e4dd0f570fe508 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mlp_mixer.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/mobilenet_v2.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/mobilenet_v2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f882944eb485fbf179966e6ac926408d9eca89e8 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mobilenet_v2.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/mobilenet_v3.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/mobilenet_v3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..911b564c034ff20fe4ad3ec49dc65d7de8175edb Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mobilenet_v3.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/mobileone.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/mobileone.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a7d56a4a33daced10d84ece4e6f6a57f4ac9ac5 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mobileone.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/mobilevit.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/mobilevit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed92dfc5887daa1327fb38df5f1953efe260e26d Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mobilevit.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/mvit.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/mvit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53201b7f0d46f4213921fb36984c27d62b5e7894 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mvit.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/poolformer.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/poolformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a75ef18e8ca1a007e87746269f6af25559435c44 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/poolformer.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/regnet.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/regnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63a01054a9ce7ccfe03119ba7908f8c660743523 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/regnet.cpython-311.pyc differ diff --git 
a/mmpretrain/models/backbones/__pycache__/replknet.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/replknet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee66bf43cb5ef334790d53d5aa2bfea733c5f43e Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/replknet.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/repmlp.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/repmlp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41588fb5c2329f1283e4611579f40975b13a2392 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/repmlp.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/repvgg.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/repvgg.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44664efac8ab17f9e71c10535a0cd01c48a3f5da Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/repvgg.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/res2net.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/res2net.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91b96f5b8f1854f9dc6a941131b4f15e7824cea0 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/res2net.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/resnest.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/resnest.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33cee99d7a3b5e0a7cdb5d4a3a3b6351f6563a3e Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/resnest.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/resnet.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/resnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c880df53c4b835bf54eb1b05c203002f65bcd3e Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/resnet.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/resnet_cifar.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/resnet_cifar.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b33dff460feba217bbf03e1e6b75c3c23ae45722 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/resnet_cifar.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/resnext.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/resnext.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dca147be96638948268c50e56efc7ae3f72c5be Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/resnext.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/revvit.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/revvit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c1da2f791f9d4d10718fd63f029c0fe4fd158ba Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/revvit.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/riformer.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/riformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f4d67ecac448b9a3e2bb640058fab08c41f1068 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/riformer.cpython-311.pyc differ diff --git 
a/mmpretrain/models/backbones/__pycache__/seresnet.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/seresnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f8168a79cac80c62ba800298d3e3ed2f258c91d Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/seresnet.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/seresnext.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/seresnext.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e747be71415641281ce86d08a3d22d1e32e503a Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/seresnext.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/shufflenet_v1.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/shufflenet_v1.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..496e78cad82531ec799229ce132955abe401dd7f Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/shufflenet_v1.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/shufflenet_v2.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/shufflenet_v2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a3073f7c958d1f43118c7726b4683ed8f5cb860 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/shufflenet_v2.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/sparse_convnext.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/sparse_convnext.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d475df0d7fbead6818e693f143c315e03f2e1821 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/sparse_convnext.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/sparse_resnet.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/sparse_resnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e05bfeea5507e3b6f6b8de020f6ba791223d86f6 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/sparse_resnet.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/swin_transformer.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/swin_transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6c8ac9ae1c85d728cb22e21b9dd9cf941e03e5d Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/swin_transformer.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/swin_transformer_v2.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/swin_transformer_v2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9c5ab6ac14abb4af1c53c8f3b1bc28c58c12900 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/swin_transformer_v2.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/t2t_vit.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/t2t_vit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a9c3e58124cbe71dcc6f22946b6a260dac013c1 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/t2t_vit.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/timm_backbone.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/timm_backbone.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..50c346160df2546b7a047d4f54fa906335c67b1d Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/timm_backbone.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/tinyvit.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/tinyvit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24a2c7f2ad19fce206c1eaf15766977102153f02 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/tinyvit.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/tnt.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/tnt.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4047ba6727c55e62adeeea8f32028650155abe9e Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/tnt.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/twins.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/twins.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..023517982a4a71fbfb37d99ac7514962fbb84094 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/twins.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/van.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/van.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b294c9b043116116909d26ab6eadd28c0af706e Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/van.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/vgg.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/vgg.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74ad53495fa88a41c4ba568db39afa0579007ce8 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/vgg.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/vig.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/vig.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d61e0bdc9e9c82d7e0f62c465a57a885a35ded8d Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/vig.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/vision_transformer.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/vision_transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dcabed8ceb32a0951ac488b36235d892cbc1f65 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/vision_transformer.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/vit_eva02.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/vit_eva02.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d3c68dcaa9a9f0029cd5b580983bdd4170ba5e8 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/vit_eva02.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/vit_sam.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/vit_sam.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3278de78c3489ef5a21b3e49ce30c1e677cb8918 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/vit_sam.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/__pycache__/xcit.cpython-311.pyc b/mmpretrain/models/backbones/__pycache__/xcit.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..1cfbfa4b7aef364fd11cb154a5baeeafd87adc51 Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/xcit.cpython-311.pyc differ diff --git a/mmpretrain/models/backbones/alexnet.py b/mmpretrain/models/backbones/alexnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c2891fdd2c878e243331f572f6e3e562232d46 --- /dev/null +++ b/mmpretrain/models/backbones/alexnet.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +@MODELS.register_module() +class AlexNet(BaseBackbone): + """`AlexNet `_ backbone. + + The input for AlexNet is a 224x224 RGB image. + + Args: + num_classes (int): number of classes for classification. + The default value is -1, which uses the backbone as + a feature extractor without the top classifier. + """ + + def __init__(self, num_classes=-1): + super(AlexNet, self).__init__() + self.num_classes = num_classes + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(64, 192, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(192, 384, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + ) + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(256 * 6 * 6, 4096), + nn.ReLU(inplace=True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(inplace=True), + nn.Linear(4096, num_classes), + ) + + def forward(self, x): + + x = self.features(x) + if self.num_classes > 0: + x = x.view(x.size(0), 256 * 6 * 6) + x = self.classifier(x) + + return (x, ) diff --git a/mmpretrain/models/backbones/base_backbone.py b/mmpretrain/models/backbones/base_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..751aa956ba2ad178ea9e40875b6e610ee7bbbcd3 --- /dev/null +++ b/mmpretrain/models/backbones/base_backbone.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + +from mmengine.model import BaseModule + + +class BaseBackbone(BaseModule, metaclass=ABCMeta): + """Base backbone. + + This class defines the basic functions of a backbone. Any backbone that + inherits this class should at least define its own `forward` function. + """ + + def __init__(self, init_cfg=None): + super(BaseBackbone, self).__init__(init_cfg) + + @abstractmethod + def forward(self, x): + """Forward computation. + + Args: + x (tensor | tuple[tensor]): x could be a Torch.tensor or a tuple of + Torch.tensor, containing input data for forward computation. + """ + pass + + def train(self, mode=True): + """Set module status before forward computation. + + Args: + mode (bool): Whether it is train_mode or test_mode + """ + super(BaseBackbone, self).train(mode) diff --git a/mmpretrain/models/backbones/beit.py b/mmpretrain/models/backbones/beit.py new file mode 100644 index 0000000000000000000000000000000000000000..3c7d9085182a989a8b2a6b26e90c35702759f36f --- /dev/null +++ b/mmpretrain/models/backbones/beit.py @@ -0,0 +1,697 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
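+# This module implements the BEiT family of ViT backbones, including the
+# shared relative position bias table and layer-scaled encoder layers.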
+from typing import List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from ..utils import (BEiTAttention, build_norm_layer, resize_pos_embed, + resize_relative_position_bias_table, to_2tuple) +from .base_backbone import BaseBackbone +from .vision_transformer import TransformerEncoderLayer + + +class RelativePositionBias(BaseModule): + """Relative Position Bias. + + This module is copied from + https://github.com/microsoft/unilm/blob/master/beit/modeling_finetune.py#L209. + + Args: + window_size (Sequence[int]): The window size of the relative + position bias. + num_heads (int): The number of head in multi-head attention. + with_cls_token (bool): To indicate the backbone has cls_token or not. + Defaults to True. + """ + + def __init__( + self, + window_size: Sequence[int], + num_heads: int, + with_cls_token: bool = True, + ) -> None: + super().__init__() + self.window_size = window_size + if with_cls_token: + num_extra_tokens = 3 + else: + num_extra_tokens = 0 + # cls to token & token to cls & cls to cls + self.num_relative_distance = (2 * window_size[0] - 1) * ( + 2 * window_size[1] - 1) + num_extra_tokens + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each + # token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] -\ + coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + if with_cls_token: + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1] + 1, ) * 2, + dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum( + -1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + else: + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1], ) * 2, + dtype=relative_coords.dtype) + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + + self.register_buffer('relative_position_index', + relative_position_index) + + def forward(self) -> torch.Tensor: + # Wh*Ww,Wh*Ww,nH + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) + return relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class BEiTTransformerEncoderLayer(TransformerEncoderLayer): + """Implements one encoder layer in BEiT. + + Comparing with conventional ``TransformerEncoderLayer``, this module + adds weights to the shortcut connection. 
In addition, ``BEiTAttention``
+    is used to replace the original ``MultiheadAttention`` in
+    ``TransformerEncoderLayer``.
+
+    Args:
+        embed_dims (int): The feature dimension.
+        num_heads (int): Parallel attention heads.
+        feedforward_channels (int): The hidden dimension for FFNs.
+        layer_scale_init_value (float): The initialization value for
+            the learnable scaling of attention and FFN. 1 means no scaling.
+        drop_rate (float): Probability of an element to be zeroed
+            after the feed forward layer. Defaults to 0.
+        window_size (tuple[int]): The height and width of the window.
+            Defaults to None.
+        use_rel_pos_bias (bool): Whether to use unique relative position bias;
+            if False, use the shared relative position bias defined in the
+            backbone.
+        attn_drop_rate (float): The dropout rate for the attention layer.
+            Defaults to 0.0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.0.
+        num_fcs (int): The number of fully-connected layers for FFNs.
+            Defaults to 2.
+        bias (bool | str): The option to add learnable bias for q, k, v. If
+            bias is True, it will add learnable bias. If bias is 'qv_bias',
+            it will only add learnable bias for q, v. If bias is False, it
+            will not add bias for q, k, v. Defaults to 'qv_bias'.
+        act_cfg (dict): The activation config for FFNs.
+            Defaults to ``dict(type='GELU')``.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to dict(type='LN').
+        attn_cfg (dict): The configuration for the attention layer.
+            Defaults to an empty dict.
+        ffn_cfg (dict): The configuration for the ffn layer.
+            Defaults to ``dict(add_identity=False)``.
+        init_cfg (dict or List[dict], optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 embed_dims: int,
+                 num_heads: int,
+                 feedforward_channels: int,
+                 layer_scale_init_value: float,
+                 window_size: Tuple[int, int],
+                 use_rel_pos_bias: bool,
+                 drop_rate: float = 0.,
+                 attn_drop_rate: float = 0.,
+                 drop_path_rate: float = 0.,
+                 num_fcs: int = 2,
+                 bias: Union[str, bool] = 'qv_bias',
+                 act_cfg: dict = dict(type='GELU'),
+                 norm_cfg: dict = dict(type='LN'),
+                 attn_cfg: dict = dict(),
+                 ffn_cfg: dict = dict(add_identity=False),
+                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
+        super().__init__(
+            embed_dims=embed_dims,
+            num_heads=num_heads,
+            feedforward_channels=feedforward_channels,
+            attn_drop_rate=attn_drop_rate,
+            drop_path_rate=0.,
+            drop_rate=0.,
+            num_fcs=num_fcs,
+            act_cfg=act_cfg,
+            norm_cfg=norm_cfg,
+            init_cfg=init_cfg)
+
+        attn_cfg = {
+            'window_size': window_size,
+            'use_rel_pos_bias': use_rel_pos_bias,
+            'qk_scale': None,
+            'embed_dims': embed_dims,
+            'num_heads': num_heads,
+            'attn_drop': attn_drop_rate,
+            'proj_drop': drop_rate,
+            'bias': bias,
+            **attn_cfg,
+        }
+        self.attn = BEiTAttention(**attn_cfg)
+
+        ffn_cfg = {
+            'embed_dims': embed_dims,
+            'feedforward_channels': feedforward_channels,
+            'num_fcs': num_fcs,
+            'ffn_drop': drop_rate,
+            'dropout_layer': dict(type='DropPath', drop_prob=drop_path_rate),
+            'act_cfg': act_cfg,
+            **ffn_cfg,
+        }
+        self.ffn = FFN(**ffn_cfg)
+
+        # NOTE: drop path for stochastic depth, we shall see if
+        # this is better than dropout here
+        dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate)
+        self.drop_path = build_dropout(
+            dropout_layer) if dropout_layer else nn.Identity()
+
+        if layer_scale_init_value > 0:
+            self.gamma_1 = nn.Parameter(
+                layer_scale_init_value * torch.ones((embed_dims)),
+                requires_grad=True)
+            self.gamma_2 = nn.Parameter(
+                layer_scale_init_value * torch.ones((embed_dims)),
+                requires_grad=True)
+        else:
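+            # no layer scale: fall back to plain residual branches in
+            # forward() below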
+            self.gamma_1, self.gamma_2 = None, None
+
+    def forward(self, x: torch.Tensor,
+                rel_pos_bias: torch.Tensor) -> torch.Tensor:
+        if self.gamma_1 is None:
+            x = x + self.drop_path(
+                self.attn(self.ln1(x), rel_pos_bias=rel_pos_bias))
+            x = x + self.drop_path(self.ffn(self.ln2(x)))
+        else:
+            x = x + self.drop_path(self.gamma_1 * self.attn(
+                self.ln1(x), rel_pos_bias=rel_pos_bias))
+            x = x + self.drop_path(self.gamma_2 * self.ffn(self.ln2(x)))
+        return x
+
+
+@MODELS.register_module()
+class BEiTViT(BaseBackbone):
+    """Backbone for BEiT.
+
+    A PyTorch implementation of `BEiT: BERT Pre-Training of Image Transformers
+    `_
+    A PyTorch implementation of `BEiT v2: Masked Image Modeling with
+    Vector-Quantized Visual Tokenizers `_
+
+    Args:
+        arch (str | dict): BEiT architecture. If a string is used, choose
+            from 'base' and 'large'. If a dict is used, it should have the
+            below keys:
+
+            - **embed_dims** (int): The dimensions of embedding.
+            - **num_layers** (int): The number of transformer encoder layers.
+            - **num_heads** (int): The number of heads in attention modules.
+            - **feedforward_channels** (int): The hidden dimensions in
+              feedforward modules.
+
+            Defaults to 'base'.
+        img_size (int | tuple): The expected input image shape. Because we
+            support dynamic input shape, just set the argument to the most
+            common input image shape. Defaults to 224.
+        patch_size (int | tuple): The patch size in patch embedding.
+            Defaults to 16.
+        in_channels (int): The number of input channels. Defaults to 3.
+        out_indices (Sequence | int): Output from which stages.
+            Defaults to -1, which means the last stage.
+        drop_rate (float): Probability of an element to be zeroed.
+            Defaults to 0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        bias (bool | str): The option to add learnable bias for q, k, v. If
+            bias is True, it will add learnable bias. If bias is 'qv_bias',
+            it will only add learnable bias for q, v. If bias is False, it
+            will not add bias for q, k, v. Defaults to 'qv_bias'.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='LN')``.
+        final_norm (bool): Whether to add an additional layer to normalize
+            the final feature map. Defaults to True.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: The class token tensor with shape (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor including patch tokens and
+              class tokens with shape (B, L, C).
+
+            Defaults to ``"avg_featmap"``.
+        with_cls_token (bool): Whether to concatenate the class token into
+            the image tokens as transformer input. Defaults to True.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval
+            mode). -1 means not freezing any parameters. Defaults to -1.
+        use_abs_pos_emb (bool): Use position embedding like vanilla ViT.
+            Defaults to False.
+        use_rel_pos_bias (bool): Use relative position embedding in each
+            transformer encoder layer. Defaults to True.
+        use_shared_rel_pos_bias (bool): Use shared relative position
+            embedding; all transformer encoder layers share the same relative
+            position embedding. Defaults to False.
+        layer_scale_init_value (float): The initialization value for
+            the learnable scaling of attention and FFN. Defaults to 0.1.
+        interpolate_mode (str): Select the interpolate mode for position
+            embedding vector resize. Defaults to "bicubic".
+        patch_cfg (dict): Configs of patch embedding.
Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims': 768, + 'num_layers': 8, + 'num_heads': 8, + 'feedforward_channels': 768 * 3, + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 3072 + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': 4096 + }), + **dict.fromkeys( + ['eva-g', 'eva-giant'], + { + # The implementation in EVA + # + 'embed_dims': 1408, + 'num_layers': 40, + 'num_heads': 16, + 'feedforward_channels': 6144 + }), + **dict.fromkeys( + ['deit-t', 'deit-tiny'], { + 'embed_dims': 192, + 'num_layers': 12, + 'num_heads': 3, + 'feedforward_channels': 192 * 4 + }), + **dict.fromkeys( + ['deit-s', 'deit-small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 6, + 'feedforward_channels': 384 * 4 + }), + **dict.fromkeys( + ['deit-b', 'deit-base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 768 * 4 + }), + } + num_extra_tokens = 1 # class token + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} + + def __init__(self, + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0, + drop_path_rate=0, + bias='qv_bias', + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=False, + out_type='avg_featmap', + with_cls_token=True, + frozen_stages=-1, + use_abs_pos_emb=False, + use_rel_pos_bias=True, + use_shared_rel_pos_bias=False, + interpolate_mode='bicubic', + layer_scale_init_value=0.1, + patch_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None): + super(BEiTViT, self).__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.img_size = to_2tuple(img_size) + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + + # Set cls token + self.with_cls_token = with_cls_token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + self.num_extra_tokens = 1 + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when `out_type="cls_token"`.') + + # Set position embedding + self.interpolate_mode = interpolate_mode + if use_abs_pos_emb: + self.pos_embed = nn.Parameter( + 
torch.zeros(1, num_patches + self.num_extra_tokens, + self.embed_dims)) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + else: + self.pos_embed = None + self.drop_after_pos = nn.Dropout(p=drop_rate) + + assert not (use_rel_pos_bias and use_shared_rel_pos_bias), ( + '`use_rel_pos_bias` and `use_shared_rel_pos_bias` cannot be set ' + 'to True at the same time') + self.use_rel_pos_bias = use_rel_pos_bias + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias( + window_size=self.patch_resolution, + num_heads=self.arch_settings['num_heads']) + else: + self.rel_pos_bias = None + self._register_load_state_dict_pre_hook( + self._prepare_relative_position_bias_table) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + layer_scale_init_value=layer_scale_init_value, + window_size=self.patch_resolution, + use_rel_pos_bias=use_rel_pos_bias, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + bias=bias, + norm_cfg=norm_cfg) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg)) + + self.frozen_stages = frozen_stages + self.final_norm = final_norm + if final_norm: + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + + if out_type == 'avg_featmap': + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + # freeze stages only when self.frozen_stages > 0 + if self.frozen_stages > 0: + self._freeze_stages() + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + def init_weights(self): + super(BEiTViT, self).init_weights() + + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if (not self.with_cls_token + and ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1): + # Remove cls token from state dict if it's not used. 
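+            # The checkpoint stores pos_embed as (1, num_patches + 1, C)
+            # with the cls token embedding in the first slot, which is what
+            # the slice below strips.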
+ state_dict[name] = state_dict[name][:, 1:] + ckpt_pos_embed_shape = state_dict[name].shape + + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + @staticmethod + def resize_pos_embed(*args, **kwargs): + """Interface for backward-compatibility.""" + return resize_pos_embed(*args, **kwargs) + + def _freeze_stages(self): + # freeze position embedding + if self.pos_embed is not None: + self.pos_embed.requires_grad = False + # set dropout to eval model + self.drop_after_pos.eval() + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + # freeze cls_token + if self.with_cls_token: + self.cls_token.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + # freeze the last layer norm + if self.frozen_stages == len(self.layers): + if self.final_norm: + self.ln1.eval() + for param in self.ln1.parameters(): + param.requires_grad = False + + if self.out_type == 'avg_featmap': + self.ln2.eval() + for param in self.ln2.parameters(): + param.requires_grad = False + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) + + if self.pos_embed is not None: + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + rel_pos_bias = self.rel_pos_bias() \ + if self.rel_pos_bias is not None else None + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x, rel_pos_bias) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return self.ln2(patch_token.mean(dim=1)) + + def _prepare_relative_position_bias_table(self, state_dict, prefix, *args, + **kwargs): + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + + if self.use_rel_pos_bias and 'rel_pos_bias.relative_position_bias_table' in state_dict: # noqa:E501 + logger.info('Expand the shared relative position embedding to ' + 'each transformer block.') + rel_pos_bias = state_dict[ + 'rel_pos_bias.relative_position_bias_table'] + for i in range(self.num_layers): + state_dict[ + f'layers.{i}.attn.relative_position_bias_table'] = \ + rel_pos_bias.clone() + state_dict.pop('rel_pos_bias.relative_position_bias_table') + 
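+            # Drop the paired index buffer as well: each layer's attention
+            # registers its own relative_position_index, so the shared one
+            # has no matching target in the model.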
state_dict.pop('rel_pos_bias.relative_position_index')
+
+        state_dict_model = self.state_dict()
+        all_keys = list(state_dict_model.keys())
+        for key in all_keys:
+            if 'relative_position_bias_table' in key:
+                ckpt_key = prefix + key
+                if ckpt_key not in state_dict:
+                    continue
+                rel_pos_bias_pretrained = state_dict[ckpt_key]
+                rel_pos_bias_current = state_dict_model[key]
+                L1, nH1 = rel_pos_bias_pretrained.size()
+                L2, nH2 = rel_pos_bias_current.size()
+                src_size = int((L1 - 3)**0.5)
+                dst_size = int((L2 - 3)**0.5)
+                if L1 != L2:
+                    extra_tokens = rel_pos_bias_pretrained[-3:, :]
+                    rel_pos_bias = rel_pos_bias_pretrained[:-3, :]
+
+                    new_rel_pos_bias = resize_relative_position_bias_table(
+                        src_size, dst_size, rel_pos_bias, nH1)
+                    new_rel_pos_bias = torch.cat(
+                        (new_rel_pos_bias, extra_tokens), dim=0)
+                    logger.info('Resize the relative_position_bias_table from '
+                                f'{state_dict[ckpt_key].shape} to '
+                                f'{new_rel_pos_bias.shape}')
+                    state_dict[ckpt_key] = new_rel_pos_bias
+
+                    # The index buffer needs to be re-generated.
+                    index_buffer = ckpt_key.replace('bias_table', 'index')
+                    if index_buffer in state_dict:
+                        del state_dict[index_buffer]
+
+    def get_layer_depth(self, param_name: str, prefix: str = ''):
+        """Get the layer-wise depth of a parameter.
+
+        Args:
+            param_name (str): The name of the parameter.
+            prefix (str): The prefix for the parameter.
+                Defaults to an empty string.
+
+        Returns:
+            Tuple[int, int]: The layer-wise depth and the num of layers.
+
+        Note:
+            The first depth is the stem module (``layer_depth=0``), and the
+            last depth is the subsequent module (``layer_depth=num_layers-1``).
+        """
+        num_layers = self.num_layers + 2
+
+        if not param_name.startswith(prefix):
+            # For subsequent module like head
+            return num_layers - 1, num_layers
+
+        param_name = param_name[len(prefix):]
+
+        if param_name in ('cls_token', 'pos_embed'):
+            layer_depth = 0
+        elif param_name.startswith('patch_embed'):
+            layer_depth = 0
+        elif param_name.startswith('layers'):
+            layer_id = int(param_name.split('.')[1])
+            layer_depth = layer_id + 1
+        else:
+            layer_depth = num_layers - 1
+
+        return layer_depth, num_layers
diff --git a/mmpretrain/models/backbones/conformer.py b/mmpretrain/models/backbones/conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..eda72b0595b6923a7f1f563ae7186ca533f85023
--- /dev/null
+++ b/mmpretrain/models/backbones/conformer.py
@@ -0,0 +1,621 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_activation_layer, build_norm_layer
+from mmcv.cnn.bricks.drop import DropPath
+from mmcv.cnn.bricks.transformer import AdaptivePadding
+from mmengine.model import BaseModule
+from mmengine.model.weight_init import trunc_normal_
+
+from mmpretrain.registry import MODELS
+from .base_backbone import BaseBackbone
+from .vision_transformer import TransformerEncoderLayer
+
+
+class ConvBlock(BaseModule):
+    """Basic convolution block used in Conformer.
+
+    This block includes three convolution modules, and supports three new
+    functions:
+    1. Returns the output of both the final layer and the second convolution
+    module.
+    2. Fuses the input of the second convolution module with an extra input
+    feature map.
+    3. Supports adding an extra convolution module to the identity
+    connection.
+
+    Args:
+        in_channels (int): The number of input channels.
+        out_channels (int): The number of output channels.
+        stride (int): The stride of the second convolution module.
+            Defaults to 1.
+        groups (int): The groups of the second convolution module.
+            Defaults to 1.
+        drop_path_rate (float): The rate of the DropPath layer. Defaults to 0.
+        with_residual_conv (bool): Whether to add an extra convolution module
+            to the identity connection. Defaults to False.
+        norm_cfg (dict): The config of normalization layers.
+            Defaults to ``dict(type='BN', eps=1e-6)``.
+        act_cfg (dict): The config of activation functions.
+            Defaults to ``dict(type='ReLU', inplace=True)``.
+        init_cfg (dict, optional): The extra config to initialize the module.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride=1,
+                 groups=1,
+                 drop_path_rate=0.,
+                 with_residual_conv=False,
+                 norm_cfg=dict(type='BN', eps=1e-6),
+                 act_cfg=dict(type='ReLU', inplace=True),
+                 init_cfg=None):
+        super(ConvBlock, self).__init__(init_cfg=init_cfg)
+
+        expansion = 4
+        mid_channels = out_channels // expansion
+
+        self.conv1 = nn.Conv2d(
+            in_channels,
+            mid_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False)
+        self.bn1 = build_norm_layer(norm_cfg, mid_channels)[1]
+        self.act1 = build_activation_layer(act_cfg)
+
+        self.conv2 = nn.Conv2d(
+            mid_channels,
+            mid_channels,
+            kernel_size=3,
+            stride=stride,
+            groups=groups,
+            padding=1,
+            bias=False)
+        self.bn2 = build_norm_layer(norm_cfg, mid_channels)[1]
+        self.act2 = build_activation_layer(act_cfg)
+
+        self.conv3 = nn.Conv2d(
+            mid_channels,
+            out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False)
+        self.bn3 = build_norm_layer(norm_cfg, out_channels)[1]
+        self.act3 = build_activation_layer(act_cfg)
+
+        if with_residual_conv:
+            self.residual_conv = nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=1,
+                stride=stride,
+                padding=0,
+                bias=False)
+            self.residual_bn = build_norm_layer(norm_cfg, out_channels)[1]
+
+        self.with_residual_conv = with_residual_conv
+        self.drop_path = DropPath(
+            drop_path_rate) if drop_path_rate > 0.
else nn.Identity() + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + + def forward(self, x, fusion_features=None, out_conv2=True): + identity = x + + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.conv2(x) if fusion_features is None else self.conv2( + x + fusion_features) + x = self.bn2(x) + x2 = self.act2(x) + + x = self.conv3(x2) + x = self.bn3(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + if self.with_residual_conv: + identity = self.residual_conv(identity) + identity = self.residual_bn(identity) + + x += identity + x = self.act3(x) + + if out_conv2: + return x, x2 + else: + return x + + +class FCUDown(BaseModule): + """CNN feature maps -> Transformer patch embeddings.""" + + def __init__(self, + in_channels, + out_channels, + down_stride, + with_cls_token=True, + norm_cfg=dict(type='LN', eps=1e-6), + act_cfg=dict(type='GELU'), + init_cfg=None): + super(FCUDown, self).__init__(init_cfg=init_cfg) + self.down_stride = down_stride + self.with_cls_token = with_cls_token + + self.conv_project = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.sample_pooling = nn.AvgPool2d( + kernel_size=down_stride, stride=down_stride) + + self.ln = build_norm_layer(norm_cfg, out_channels)[1] + self.act = build_activation_layer(act_cfg) + + def forward(self, x, x_t): + x = self.conv_project(x) # [N, C, H, W] + + x = self.sample_pooling(x).flatten(2).transpose(1, 2) + x = self.ln(x) + x = self.act(x) + + if self.with_cls_token: + x = torch.cat([x_t[:, 0][:, None, :], x], dim=1) + + return x + + +class FCUUp(BaseModule): + """Transformer patch embeddings -> CNN feature maps.""" + + def __init__(self, + in_channels, + out_channels, + up_stride, + with_cls_token=True, + norm_cfg=dict(type='BN', eps=1e-6), + act_cfg=dict(type='ReLU', inplace=True), + init_cfg=None): + super(FCUUp, self).__init__(init_cfg=init_cfg) + + self.up_stride = up_stride + self.with_cls_token = with_cls_token + + self.conv_project = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.bn = build_norm_layer(norm_cfg, out_channels)[1] + self.act = build_activation_layer(act_cfg) + + def forward(self, x, H, W): + B, _, C = x.shape + # [N, 197, 384] -> [N, 196, 384] -> [N, 384, 196] -> [N, 384, 14, 14] + if self.with_cls_token: + x_r = x[:, 1:].transpose(1, 2).reshape(B, C, H, W) + else: + x_r = x.transpose(1, 2).reshape(B, C, H, W) + + x_r = self.act(self.bn(self.conv_project(x_r))) + + return F.interpolate( + x_r, size=(H * self.up_stride, W * self.up_stride)) + + +class ConvTransBlock(BaseModule): + """Basic module for Conformer. + + This module is a fusion of CNN block transformer encoder block. + + Args: + in_channels (int): The number of input channels in conv blocks. + out_channels (int): The number of output channels in conv blocks. + embed_dims (int): The embedding dimension in transformer blocks. + conv_stride (int): The stride of conv2d layers. Defaults to 1. + groups (int): The groups of conv blocks. Defaults to 1. + with_residual_conv (bool): Whether to add a conv-bn layer to the + identity connect in the conv block. Defaults to False. + down_stride (int): The stride of the downsample pooling layer. + Defaults to 4. + num_heads (int): The number of heads in transformer attention layers. + Defaults to 12. + mlp_ratio (float): The expansion ratio in transformer FFN module. + Defaults to 4. + qkv_bias (bool): Enable bias for qkv if True. Defaults to False. + with_cls_token (bool): Whether use class token or not. 
+ Defaults to True. + drop_rate (float): The dropout rate of the output projection and + FFN in the transformer block. Defaults to 0. + attn_drop_rate (float): The dropout rate after the attention + calculation in the transformer block. Defaults to 0. + drop_path_rate (bloat): The drop path rate in both the conv block + and the transformer block. Defaults to 0. + last_fusion (bool): Whether this block is the last stage. If so, + downsample the fusion feature map. + init_cfg (dict, optional): The extra config to initialize the module. + Defaults to None. + """ + + def __init__(self, + in_channels, + out_channels, + embed_dims, + conv_stride=1, + groups=1, + with_residual_conv=False, + down_stride=4, + num_heads=12, + mlp_ratio=4., + qkv_bias=False, + with_cls_token=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + last_fusion=False, + init_cfg=None): + super(ConvTransBlock, self).__init__(init_cfg=init_cfg) + expansion = 4 + self.cnn_block = ConvBlock( + in_channels=in_channels, + out_channels=out_channels, + with_residual_conv=with_residual_conv, + stride=conv_stride, + groups=groups) + + if last_fusion: + self.fusion_block = ConvBlock( + in_channels=out_channels, + out_channels=out_channels, + stride=2, + with_residual_conv=True, + groups=groups, + drop_path_rate=drop_path_rate) + else: + self.fusion_block = ConvBlock( + in_channels=out_channels, + out_channels=out_channels, + groups=groups, + drop_path_rate=drop_path_rate) + + self.squeeze_block = FCUDown( + in_channels=out_channels // expansion, + out_channels=embed_dims, + down_stride=down_stride, + with_cls_token=with_cls_token) + + self.expand_block = FCUUp( + in_channels=embed_dims, + out_channels=out_channels // expansion, + up_stride=down_stride, + with_cls_token=with_cls_token) + + self.trans_block = TransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=int(embed_dims * mlp_ratio), + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + attn_drop_rate=attn_drop_rate, + qkv_bias=qkv_bias, + norm_cfg=dict(type='LN', eps=1e-6)) + + self.down_stride = down_stride + self.embed_dim = embed_dims + self.last_fusion = last_fusion + + def forward(self, cnn_input, trans_input): + x, x_conv2 = self.cnn_block(cnn_input, out_conv2=True) + + _, _, H, W = x_conv2.shape + + # Convert the feature map of conv2 to transformer embedding + # and concat with class token. + conv2_embedding = self.squeeze_block(x_conv2, trans_input) + + trans_output = self.trans_block(conv2_embedding + trans_input) + + # Convert the transformer output embedding to feature map + trans_features = self.expand_block(trans_output, H // self.down_stride, + W // self.down_stride) + x = self.fusion_block( + x, fusion_features=trans_features, out_conv2=False) + + return x, trans_output + + +@MODELS.register_module() +class Conformer(BaseBackbone): + """Conformer backbone. + + A PyTorch implementation of : `Conformer: Local Features Coupling Global + Representations for Visual Recognition `_ + + Args: + arch (str | dict): Conformer architecture. Defaults to 'tiny'. + patch_size (int): The patch size. Defaults to 16. + base_channels (int): The base number of channels in CNN network. + Defaults to 64. + mlp_ratio (float): The expansion ratio of FFN network in transformer + block. Defaults to 4. + with_cls_token (bool): Whether use class token or not. + Defaults to True. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + out_indices (Sequence | int): Output from which stages. 
+ Defaults to -1, means the last stage. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'embed_dims': 384, + 'channel_ratio': 1, + 'num_heads': 6, + 'depths': 12 + }), + **dict.fromkeys(['s', 'small'], + {'embed_dims': 384, + 'channel_ratio': 4, + 'num_heads': 6, + 'depths': 12 + }), + **dict.fromkeys(['b', 'base'], + {'embed_dims': 576, + 'channel_ratio': 6, + 'num_heads': 9, + 'depths': 12 + }), + } # yapf: disable + + _version = 1 + + def __init__(self, + arch='tiny', + patch_size=16, + base_channels=64, + mlp_ratio=4., + qkv_bias=True, + with_cls_token=True, + drop_path_rate=0., + norm_eval=True, + frozen_stages=0, + out_indices=-1, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'depths', 'num_heads', 'channel_ratio' + } + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.num_features = self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + self.channel_ratio = self.arch_settings['channel_ratio'] + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.depths + index + 1 + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.norm_eval = norm_eval + self.frozen_stages = frozen_stages + + self.with_cls_token = with_cls_token + if self.with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + + # stochastic depth decay rule + self.trans_dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, self.depths) + ] + + # Stem stage: get the feature maps by conv block + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, + bias=False) # 1 / 2 [112, 112] + self.bn1 = nn.BatchNorm2d(64) + self.act1 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1) # 1 / 4 [56, 56] + + assert patch_size % 16 == 0, 'The patch size of Conformer must ' \ + 'be divisible by 16.' 
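+        # The stem (stride-2 conv followed by stride-2 maxpool) has already
+        # downsampled the input by 4, so dividing by 4 here gives
+        # trans_patch_conv an overall effective stride of `patch_size`
+        # (e.g. patch_size=16 -> trans_down_stride=4).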
+ trans_down_stride = patch_size // 4 + + # To solve the issue #680 + # Auto pad the feature map to be divisible by trans_down_stride + self.auto_pad = AdaptivePadding(trans_down_stride, trans_down_stride) + + # 1 stage + stage1_channels = int(base_channels * self.channel_ratio) + self.conv_1 = ConvBlock( + in_channels=64, + out_channels=stage1_channels, + with_residual_conv=True, + stride=1) + self.trans_patch_conv = nn.Conv2d( + 64, + self.embed_dims, + kernel_size=trans_down_stride, + stride=trans_down_stride, + padding=0) + + self.trans_1 = TransformerEncoderLayer( + embed_dims=self.embed_dims, + num_heads=self.num_heads, + feedforward_channels=int(self.embed_dims * mlp_ratio), + drop_path_rate=self.trans_dpr[0], + qkv_bias=qkv_bias, + norm_cfg=dict(type='LN', eps=1e-6)) + + # 2~4 stage + init_stage = 2 + fin_stage = self.depths // 3 + 1 + for i in range(init_stage, fin_stage): + self.add_module( + f'conv_trans_{i}', + ConvTransBlock( + in_channels=stage1_channels, + out_channels=stage1_channels, + embed_dims=self.embed_dims, + conv_stride=1, + with_residual_conv=False, + down_stride=trans_down_stride, + num_heads=self.num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path_rate=self.trans_dpr[i - 1], + with_cls_token=self.with_cls_token)) + + stage2_channels = int(base_channels * self.channel_ratio * 2) + # 5~8 stage + init_stage = fin_stage # 5 + fin_stage = fin_stage + self.depths // 3 # 9 + for i in range(init_stage, fin_stage): + if i == init_stage: + conv_stride = 2 + in_channels = stage1_channels + else: + conv_stride = 1 + in_channels = stage2_channels + + with_residual_conv = True if i == init_stage else False + self.add_module( + f'conv_trans_{i}', + ConvTransBlock( + in_channels=in_channels, + out_channels=stage2_channels, + embed_dims=self.embed_dims, + conv_stride=conv_stride, + with_residual_conv=with_residual_conv, + down_stride=trans_down_stride // 2, + num_heads=self.num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path_rate=self.trans_dpr[i - 1], + with_cls_token=self.with_cls_token)) + + stage3_channels = int(base_channels * self.channel_ratio * 2 * 2) + # 9~12 stage + init_stage = fin_stage # 9 + fin_stage = fin_stage + self.depths // 3 # 13 + for i in range(init_stage, fin_stage): + if i == init_stage: + conv_stride = 2 + in_channels = stage2_channels + with_residual_conv = True + else: + conv_stride = 1 + in_channels = stage3_channels + with_residual_conv = False + + last_fusion = (i == self.depths) + + self.add_module( + f'conv_trans_{i}', + ConvTransBlock( + in_channels=in_channels, + out_channels=stage3_channels, + embed_dims=self.embed_dims, + conv_stride=conv_stride, + with_residual_conv=with_residual_conv, + down_stride=trans_down_stride // 4, + num_heads=self.num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path_rate=self.trans_dpr[i - 1], + with_cls_token=self.with_cls_token, + last_fusion=last_fusion)) + self.fin_stage = fin_stage + + self.pooling = nn.AdaptiveAvgPool2d(1) + self.trans_norm = nn.LayerNorm(self.embed_dims) + + if self.with_cls_token: + trunc_normal_(self.cls_token, std=.02) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + 
nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) + + if hasattr(m, 'zero_init_last_bn'): + m.zero_init_last_bn() + + def init_weights(self): + super(Conformer, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + self.apply(self._init_weights) + + def forward(self, x): + output = [] + B = x.shape[0] + if self.with_cls_token: + cls_tokens = self.cls_token.expand(B, -1, -1) + + # stem + x_base = self.maxpool(self.act1(self.bn1(self.conv1(x)))) + x_base = self.auto_pad(x_base) + + # 1 stage [N, 64, 56, 56] -> [N, 128, 56, 56] + x = self.conv_1(x_base, out_conv2=False) + x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2) + if self.with_cls_token: + x_t = torch.cat([cls_tokens, x_t], dim=1) + x_t = self.trans_1(x_t) + + # 2 ~ final + for i in range(2, self.fin_stage): + stage = getattr(self, f'conv_trans_{i}') + x, x_t = stage(x, x_t) + if i in self.out_indices: + if self.with_cls_token: + output.append([ + self.pooling(x).flatten(1), + self.trans_norm(x_t)[:, 0] + ]) + else: + # if no class token, use the mean patch token + # as the transformer feature. + output.append([ + self.pooling(x).flatten(1), + self.trans_norm(x_t).mean(dim=1) + ]) + + return tuple(output) diff --git a/mmpretrain/models/backbones/convmixer.py b/mmpretrain/models/backbones/convmixer.py new file mode 100644 index 0000000000000000000000000000000000000000..480050d5ce1aa29f190dbc24ec1413573d541cb1 --- /dev/null +++ b/mmpretrain/models/backbones/convmixer.py @@ -0,0 +1,176 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import (Conv2dAdaptivePadding, build_activation_layer, + build_norm_layer) +from mmengine.utils import digit_version + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class Residual(nn.Module): + + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x): + return self.fn(x) + x + + +@MODELS.register_module() +class ConvMixer(BaseBackbone): + """ConvMixer. . + + A PyTorch implementation of : `Patches Are All You Need? + `_ + + Modified from the `official repo + `_ + and `timm + `_. + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``ConvMixer.arch_settings``. And if dict, it + should include the following two keys: + + - embed_dims (int): The dimensions of patch embedding. + - depth (int): Number of repetitions of ConvMixer Layer. + - patch_size (int): The patch size. + - kernel_size (int): The kernel size of depthwise conv layers. + + Defaults to '768/32'. + in_channels (int): Number of input image channels. Defaults to 3. + patch_size (int): The size of one patch in the patch embed layer. + Defaults to 7. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='BN')``. + act_cfg (dict): The config dict for activation after each convolution. + Defaults to ``dict(type='GELU')``. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + init_cfg (dict, optional): Initialization config dict. 
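+            Defaults to None.
+
+    Example:
+        >>> import torch
+        >>> from mmpretrain.models import ConvMixer
+        >>> # Minimal forward-pass sketch; the shapes below assume the
+        >>> # default '768/32' arch, whose stride-7 patch stem maps a
+        >>> # 224x224 input to a 32x32 feature map.
+        >>> model = ConvMixer(arch='768/32')
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outs = model(inputs)
+        >>> print(tuple(outs[-1].shape))
+        (1, 768, 32, 32)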
+ """ + arch_settings = { + '768/32': { + 'embed_dims': 768, + 'depth': 32, + 'patch_size': 7, + 'kernel_size': 7 + }, + '1024/20': { + 'embed_dims': 1024, + 'depth': 20, + 'patch_size': 14, + 'kernel_size': 9 + }, + '1536/20': { + 'embed_dims': 1536, + 'depth': 20, + 'patch_size': 7, + 'kernel_size': 9 + }, + } + + def __init__(self, + arch='768/32', + in_channels=3, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='GELU'), + out_indices=-1, + frozen_stages=0, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + essential_keys = { + 'embed_dims', 'depth', 'patch_size', 'kernel_size' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + + self.embed_dims = arch['embed_dims'] + self.depth = arch['depth'] + self.patch_size = arch['patch_size'] + self.kernel_size = arch['kernel_size'] + self.act = build_activation_layer(act_cfg) + + # check out indices and frozen stages + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.depth + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # Set stem layers + self.stem = nn.Sequential( + nn.Conv2d( + in_channels, + self.embed_dims, + kernel_size=self.patch_size, + stride=self.patch_size), self.act, + build_norm_layer(norm_cfg, self.embed_dims)[1]) + + # Set conv2d according to torch version + convfunc = nn.Conv2d + if digit_version(torch.__version__) < digit_version('1.9.0'): + convfunc = Conv2dAdaptivePadding + + # Repetitions of ConvMixer Layer + self.stages = nn.Sequential(*[ + nn.Sequential( + Residual( + nn.Sequential( + convfunc( + self.embed_dims, + self.embed_dims, + self.kernel_size, + groups=self.embed_dims, + padding='same'), self.act, + build_norm_layer(norm_cfg, self.embed_dims)[1])), + nn.Conv2d(self.embed_dims, self.embed_dims, kernel_size=1), + self.act, + build_norm_layer(norm_cfg, self.embed_dims)[1]) + for _ in range(self.depth) + ]) + + self._freeze_stages() + + def forward(self, x): + x = self.stem(x) + outs = [] + for i, stage in enumerate(self.stages): + x = stage(x) + if i in self.out_indices: + outs.append(x) + + # x = self.pooling(x).flatten(1) + return tuple(outs) + + def train(self, mode=True): + super(ConvMixer, self).train(mode) + self._freeze_stages() + + def _freeze_stages(self): + for i in range(self.frozen_stages): + stage = self.stages[i] + stage.eval() + for param in stage.parameters(): + param.requires_grad = False diff --git a/mmpretrain/models/backbones/convnext.py b/mmpretrain/models/backbones/convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..6a954f5b980186a86565a228669c6917bda14f68 --- /dev/null +++ b/mmpretrain/models/backbones/convnext.py @@ -0,0 +1,412 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from functools import partial +from itertools import chain +from typing import Sequence + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmpretrain.registry import MODELS +from ..utils import GRN, build_norm_layer +from .base_backbone import BaseBackbone + + +class ConvNeXtBlock(BaseModule): + """ConvNeXt Block. + + Args: + in_channels (int): The number of input channels. + dw_conv_cfg (dict): Config of depthwise convolution. + Defaults to ``dict(kernel_size=7, padding=3)``. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='LN2d', eps=1e-6)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + mlp_ratio (float): The expansion ratio in both pointwise convolution. + Defaults to 4. + linear_pw_conv (bool): Whether to use linear layer to do pointwise + convolution. More details can be found in the note. + Defaults to True. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-6. + + Note: + There are two equivalent implementations: + + 1. DwConv -> LayerNorm -> 1x1 Conv -> GELU -> 1x1 Conv; + all outputs are in (N, C, H, W). + 2. DwConv -> LayerNorm -> Permute to (N, H, W, C) -> Linear -> GELU + -> Linear; Permute back + + As default, we use the second to align with the official repository. + And it may be slightly faster. + """ + + def __init__(self, + in_channels, + dw_conv_cfg=dict(kernel_size=7, padding=3), + norm_cfg=dict(type='LN2d', eps=1e-6), + act_cfg=dict(type='GELU'), + mlp_ratio=4., + linear_pw_conv=True, + drop_path_rate=0., + layer_scale_init_value=1e-6, + use_grn=False, + with_cp=False): + super().__init__() + self.with_cp = with_cp + + self.depthwise_conv = nn.Conv2d( + in_channels, in_channels, groups=in_channels, **dw_conv_cfg) + + self.linear_pw_conv = linear_pw_conv + self.norm = build_norm_layer(norm_cfg, in_channels) + + mid_channels = int(mlp_ratio * in_channels) + if self.linear_pw_conv: + # Use linear layer to do pointwise conv. + pw_conv = nn.Linear + else: + pw_conv = partial(nn.Conv2d, kernel_size=1) + + self.pointwise_conv1 = pw_conv(in_channels, mid_channels) + self.act = MODELS.build(act_cfg) + self.pointwise_conv2 = pw_conv(mid_channels, in_channels) + + if use_grn: + self.grn = GRN(mid_channels) + else: + self.grn = None + + self.gamma = nn.Parameter( + layer_scale_init_value * torch.ones((in_channels)), + requires_grad=True) if layer_scale_init_value > 0 else None + + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. 
else nn.Identity() + + def forward(self, x): + + def _inner_forward(x): + shortcut = x + x = self.depthwise_conv(x) + + if self.linear_pw_conv: + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x, data_format='channel_last') + x = self.pointwise_conv1(x) + x = self.act(x) + if self.grn is not None: + x = self.grn(x, data_format='channel_last') + x = self.pointwise_conv2(x) + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + else: + x = self.norm(x, data_format='channel_first') + x = self.pointwise_conv1(x) + x = self.act(x) + + if self.grn is not None: + x = self.grn(x, data_format='channel_first') + x = self.pointwise_conv2(x) + + if self.gamma is not None: + x = x.mul(self.gamma.view(1, -1, 1, 1)) + + x = shortcut + self.drop_path(x) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +@MODELS.register_module() +class ConvNeXt(BaseBackbone): + """ConvNeXt v1&v2 backbone. + + A PyTorch implementation of `A ConvNet for the 2020s + `_ and + `ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders + `_ + + Modified from the `official repo + `_ + and `timm + `_. + + To use ConvNeXt v2, please set ``use_grn=True`` and ``layer_scale_init_value=0.``. + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``ConvNeXt.arch_settings``. And if dict, it + should include the following two keys: + + - depths (list[int]): Number of blocks at each stage. + - channels (list[int]): The number of channels at each stage. + + Defaults to 'tiny'. + in_channels (int): Number of input image channels. Defaults to 3. + stem_patch_size (int): The size of one patch in the stem layer. + Defaults to 4. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='LN2d', eps=1e-6)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + linear_pw_conv (bool): Whether to use linear layer to do pointwise + convolution. Defaults to True. + use_grn (bool): Whether to add Global Response Normalization in the + blocks. Defaults to False. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-6. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + gap_before_final_norm (bool): Whether to globally average the feature + map before the final norm layer. In the official repo, it's only + used in classification task. Defaults to True. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. 
+ init_cfg (dict, optional): Initialization config dict + """ # noqa: E501 + arch_settings = { + 'atto': { + 'depths': [2, 2, 6, 2], + 'channels': [40, 80, 160, 320] + }, + 'femto': { + 'depths': [2, 2, 6, 2], + 'channels': [48, 96, 192, 384] + }, + 'pico': { + 'depths': [2, 2, 6, 2], + 'channels': [64, 128, 256, 512] + }, + 'nano': { + 'depths': [2, 2, 8, 2], + 'channels': [80, 160, 320, 640] + }, + 'tiny': { + 'depths': [3, 3, 9, 3], + 'channels': [96, 192, 384, 768] + }, + 'small': { + 'depths': [3, 3, 27, 3], + 'channels': [96, 192, 384, 768] + }, + 'base': { + 'depths': [3, 3, 27, 3], + 'channels': [128, 256, 512, 1024] + }, + 'large': { + 'depths': [3, 3, 27, 3], + 'channels': [192, 384, 768, 1536] + }, + 'xlarge': { + 'depths': [3, 3, 27, 3], + 'channels': [256, 512, 1024, 2048] + }, + 'huge': { + 'depths': [3, 3, 27, 3], + 'channels': [352, 704, 1408, 2816] + } + } + + def __init__(self, + arch='tiny', + in_channels=3, + stem_patch_size=4, + norm_cfg=dict(type='LN2d', eps=1e-6), + act_cfg=dict(type='GELU'), + linear_pw_conv=True, + use_grn=False, + drop_path_rate=0., + layer_scale_init_value=1e-6, + out_indices=-1, + frozen_stages=0, + gap_before_final_norm=True, + with_cp=False, + init_cfg=[ + dict( + type='TruncNormal', + layer=['Conv2d', 'Linear'], + std=.02, + bias=0.), + dict( + type='Constant', layer=['LayerNorm'], val=1., + bias=0.), + ]): + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + assert 'depths' in arch and 'channels' in arch, \ + f'The arch dict must have "depths" and "channels", ' \ + f'but got {list(arch.keys())}.' + + self.depths = arch['depths'] + self.channels = arch['channels'] + assert (isinstance(self.depths, Sequence) + and isinstance(self.channels, Sequence) + and len(self.depths) == len(self.channels)), \ + f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \ + 'should be both sequence with the same length.' + + self.num_stages = len(self.depths) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 4 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.frozen_stages = frozen_stages + self.gap_before_final_norm = gap_before_final_norm + + # stochastic depth decay rule + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + block_idx = 0 + + # 4 downsample layers between stages, including the stem layer. 
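+        # With the default stem_patch_size=4 this yields feature strides of
+        # 4, 8, 16 and 32: the stem divides the resolution by 4, and each
+        # later downsample layer (LN + stride-2 conv) halves it again.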
+ self.downsample_layers = ModuleList() + stem = nn.Sequential( + nn.Conv2d( + in_channels, + self.channels[0], + kernel_size=stem_patch_size, + stride=stem_patch_size), + build_norm_layer(norm_cfg, self.channels[0]), + ) + self.downsample_layers.append(stem) + + # 4 feature resolution stages, each consisting of multiple residual + # blocks + self.stages = nn.ModuleList() + + for i in range(self.num_stages): + depth = self.depths[i] + channels = self.channels[i] + + if i >= 1: + downsample_layer = nn.Sequential( + build_norm_layer(norm_cfg, self.channels[i - 1]), + nn.Conv2d( + self.channels[i - 1], + channels, + kernel_size=2, + stride=2), + ) + self.downsample_layers.append(downsample_layer) + + stage = Sequential(*[ + ConvNeXtBlock( + in_channels=channels, + drop_path_rate=dpr[block_idx + j], + norm_cfg=norm_cfg, + act_cfg=act_cfg, + linear_pw_conv=linear_pw_conv, + layer_scale_init_value=layer_scale_init_value, + use_grn=use_grn, + with_cp=with_cp) for j in range(depth) + ]) + block_idx += depth + + self.stages.append(stage) + + if i in self.out_indices: + norm_layer = build_norm_layer(norm_cfg, channels) + self.add_module(f'norm{i}', norm_layer) + + self._freeze_stages() + + def forward(self, x): + outs = [] + for i, stage in enumerate(self.stages): + x = self.downsample_layers[i](x) + x = stage(x) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + if self.gap_before_final_norm: + gap = x.mean([-2, -1], keepdim=True) + outs.append(norm_layer(gap).flatten(1)) + else: + outs.append(norm_layer(x)) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + downsample_layer = self.downsample_layers[i] + stage = self.stages[i] + downsample_layer.eval() + stage.eval() + for param in chain(downsample_layer.parameters(), + stage.parameters()): + param.requires_grad = False + + def train(self, mode=True): + super(ConvNeXt, self).train(mode) + self._freeze_stages() + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. + """ + + max_layer_id = 12 if self.depths[-2] > 9 else 6 + + if not param_name.startswith(prefix): + # For subsequent module like head + return max_layer_id + 1, max_layer_id + 2 + + param_name = param_name[len(prefix):] + if param_name.startswith('downsample_layers'): + stage_id = int(param_name.split('.')[1]) + if stage_id == 0: + layer_id = 0 + elif stage_id == 1 or stage_id == 2: + layer_id = stage_id + 1 + else: # stage_id == 3: + layer_id = max_layer_id + + elif param_name.startswith('stages'): + stage_id = int(param_name.split('.')[1]) + block_id = int(param_name.split('.')[2]) + if stage_id == 0 or stage_id == 1: + layer_id = stage_id + 1 + elif stage_id == 2: + layer_id = 3 + block_id // 3 + else: # stage_id == 3: + layer_id = max_layer_id + + # final norm layer + else: + layer_id = max_layer_id + 1 + + return layer_id, max_layer_id + 2 diff --git a/mmpretrain/models/backbones/cspnet.py b/mmpretrain/models/backbones/cspnet.py new file mode 100644 index 0000000000000000000000000000000000000000..7492e97702c28861dcce2808207a35e67f32f752 --- /dev/null +++ b/mmpretrain/models/backbones/cspnet.py @@ -0,0 +1,679 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import math
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+from mmcv.cnn.bricks import DropPath
+from mmengine.model import BaseModule, Sequential
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmpretrain.registry import MODELS
+from ..utils import to_ntuple
+from .resnet import Bottleneck as ResNetBottleneck
+from .resnext import Bottleneck as ResNeXtBottleneck
+
+eps = 1.0e-5
+
+
+class DarknetBottleneck(BaseModule):
+    """The basic bottleneck block used in Darknet. Each DarknetBottleneck
+    consists of two ConvModules and the input is added to the final output.
+    Each ConvModule is composed of Conv, BN, and LeakyReLU. The first conv
+    layer has a 1x1 filter and the second one a 3x3 filter.
+
+    Args:
+        in_channels (int): The input channels of this Module.
+        out_channels (int): The output channels of this Module.
+        expansion (int): The ratio of ``out_channels/mid_channels`` where
+            ``mid_channels`` is the input/output channels of conv2.
+            Defaults to 2.
+        add_identity (bool): Whether to add identity to the out.
+            Defaults to True.
+        use_depthwise (bool): Whether to use depthwise separable convolution.
+            Defaults to False.
+        conv_cfg (dict): Config dict for convolution layer. Defaults to None,
+            which means using conv2d.
+        drop_path_rate (float): The ratio of the drop path layer.
+            Defaults to 0.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='BN', eps=1e-5)``.
+        act_cfg (dict): Config dict for activation layer.
+            Defaults to ``dict(type='LeakyReLU', inplace=True)``.
+        init_cfg (dict, optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 expansion=2,
+                 add_identity=True,
+                 use_depthwise=False,
+                 conv_cfg=None,
+                 drop_path_rate=0,
+                 norm_cfg=dict(type='BN', eps=1e-5),
+                 act_cfg=dict(type='LeakyReLU', inplace=True),
+                 init_cfg=None):
+        super().__init__(init_cfg)
+        hidden_channels = int(out_channels / expansion)
+        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
+        self.conv1 = ConvModule(
+            in_channels,
+            hidden_channels,
+            1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        self.conv2 = conv(
+            hidden_channels,
+            out_channels,
+            3,
+            stride=1,
+            padding=1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        self.add_identity = \
+            add_identity and in_channels == out_channels
+
+        self.drop_path = DropPath(drop_prob=drop_path_rate
+                                  ) if drop_path_rate > eps else nn.Identity()
+
+    def forward(self, x):
+        identity = x
+        out = self.conv1(x)
+        out = self.conv2(out)
+        out = self.drop_path(out)
+
+        if self.add_identity:
+            return out + identity
+        else:
+            return out
+
+
+class CSPStage(BaseModule):
+    """Cross Stage Partial Stage.
+
+    .. code:: text
+
+        Downsample Convolution (optional)
+                    |
+                    |
+            Expand Convolution
+                    |
+                    |
+           Split to xa, xb
+                    |          \
+                    |           \
+                    |         blocks(xb)
+                    |           /
+                    |          / transition
+                    |         /
+            Concat xa, blocks(xb)
+                    |
+         Transition Convolution
+
+    Args:
+        block_fn (nn.Module): The basic block function in the Stage.
+        in_channels (int): The input channels of the CSP layer.
+        out_channels (int): The output channels of the CSP layer.
+        has_downsampler (bool): Whether to add a downsampler in the stage.
+            Default: True.
+        down_growth (bool): Whether to expand the channels in the
+            downsampler layer of the stage. Default: False.
+        expand_ratio (float): The expand ratio to adjust the number of
+            channels of the expand conv layer. Default: 0.5
+        bottle_ratio (float): Ratio to adjust the number of channels of the
+            hidden layer.
Default: 2.
+        block_dpr (float): The ratio of the drop path layer in the
+            blocks of the stage. Default: 0.
+        num_blocks (int): Number of blocks. Default: 1
+        conv_cfg (dict, optional): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN')
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='LeakyReLU', inplace=True)
+    """
+
+    def __init__(self,
+                 block_fn,
+                 in_channels,
+                 out_channels,
+                 has_downsampler=True,
+                 down_growth=False,
+                 expand_ratio=0.5,
+                 bottle_ratio=2,
+                 num_blocks=1,
+                 block_dpr=0,
+                 block_args={},
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN', eps=1e-5),
+                 act_cfg=dict(type='LeakyReLU', inplace=True),
+                 init_cfg=None):
+        super().__init__(init_cfg)
+        # grow downsample channels to output channels
+        down_channels = out_channels if down_growth else in_channels
+        block_dpr = to_ntuple(num_blocks)(block_dpr)
+
+        if has_downsampler:
+            self.downsample_conv = ConvModule(
+                in_channels=in_channels,
+                out_channels=down_channels,
+                kernel_size=3,
+                stride=2,
+                padding=1,
+                groups=32 if block_fn is ResNeXtBottleneck else 1,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+        else:
+            self.downsample_conv = nn.Identity()
+
+        exp_channels = int(down_channels * expand_ratio)
+        self.expand_conv = ConvModule(
+            in_channels=down_channels,
+            out_channels=exp_channels,
+            kernel_size=1,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg if block_fn is DarknetBottleneck else None)
+
+        assert exp_channels % 2 == 0, \
+            'The channel number before blocks must be divisible by 2.'
+        block_channels = exp_channels // 2
+        blocks = []
+        for i in range(num_blocks):
+            block_cfg = dict(
+                in_channels=block_channels,
+                out_channels=block_channels,
+                expansion=bottle_ratio,
+                drop_path_rate=block_dpr[i],
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg,
+                **block_args)
+            blocks.append(block_fn(**block_cfg))
+        self.blocks = Sequential(*blocks)
+        self.atfer_blocks_conv = ConvModule(
+            block_channels,
+            block_channels,
+            1,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+
+        self.final_conv = ConvModule(
+            2 * block_channels,
+            out_channels,
+            1,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+
+    def forward(self, x):
+        x = self.downsample_conv(x)
+        x = self.expand_conv(x)
+
+        split = x.shape[1] // 2
+        xa, xb = x[:, :split], x[:, split:]
+
+        xb = self.blocks(xb)
+        xb = self.atfer_blocks_conv(xb).contiguous()
+
+        x_final = torch.cat((xa, xb), dim=1)
+        return self.final_conv(x_final)
+
+
+class CSPNet(BaseModule):
+    """The abstract CSP Network class.
+
+    A PyTorch implementation of `CSPNet: A New Backbone that can Enhance
+    Learning Capability of CNN <https://arxiv.org/abs/1911.11929>`_
+
+    This class is an abstract class because the Cross Stage Partial Network
+    (CSPNet) is a kind of universal network structure, and you can customize
+    the network block to implement networks like CSPResNet, CSPResNeXt and
+    CSPDarkNet.
+
+    Args:
+        arch (dict): The architecture of the CSPNet.
+            It should have the following keys:
+
+            - block_fn (Callable): A function or class to return a block
+              module, and it should accept at least ``in_channels``,
+              ``out_channels``, ``expansion``, ``drop_path_rate``,
+              ``norm_cfg`` and ``act_cfg``.
+            - in_channels (Tuple[int]): The number of input channels of each
+              stage.
+            - out_channels (Tuple[int]): The number of output channels of
+              each stage.
+            - num_blocks (Tuple[int]): The number of blocks in each stage.
+            - expand_ratio (float | Tuple[float]): The expansion ratio in
+              the expand convolution of each stage. Defaults to 0.5.
+ - bottle_ratio (float | Tuple[float]): The expansion ratio of + blocks in each stage. Defaults to 2. + - has_downsampler (bool | Tuple[bool]): Whether to add a + downsample convolution in each stage. Defaults to True + - down_growth (bool | Tuple[bool]): Whether to expand the channels + in the downsampler layer of each stage. Defaults to False. + - block_args (dict | Tuple[dict], optional): The extra arguments to + the blocks in each stage. Defaults to None. + + stem_fn (Callable): A function or class to return a stem module. + And it should accept ``in_channels``. + in_channels (int): Number of input image channels. Defaults to 3. + out_indices (int | Sequence[int]): Output from which stages. + Defaults to -1, which means the last stage. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + conv_cfg (dict, optional): The config dict for conv layers in blocks. + Defaults to None, which means use Conv2d. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='BN', eps=1e-5)``. + act_cfg (dict): The config dict for activation functions. + Defaults to ``dict(type='LeakyReLU', inplace=True)``. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + init_cfg (dict, optional): The initialization settings. + Defaults to ``dict(type='Kaiming', layer='Conv2d'))``. + + Example: + >>> from functools import partial + >>> import torch + >>> import torch.nn as nn + >>> from mmpretrain.models import CSPNet + >>> from mmpretrain.models.backbones.resnet import Bottleneck + >>> + >>> # A simple example to build CSPNet. + >>> arch = dict( + ... block_fn=Bottleneck, + ... in_channels=[32, 64], + ... out_channels=[64, 128], + ... num_blocks=[3, 4] + ... ) + >>> stem_fn = partial(nn.Conv2d, out_channels=32, kernel_size=3) + >>> model = CSPNet(arch=arch, stem_fn=stem_fn, out_indices=(0, 1)) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> outs = model(inputs) + >>> for out in outs: + ... print(out.shape) + ... + (1, 64, 111, 111) + (1, 128, 56, 56) + """ + + def __init__(self, + arch, + stem_fn, + in_channels=3, + out_indices=-1, + frozen_stages=-1, + drop_path_rate=0., + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-5), + act_cfg=dict(type='LeakyReLU', inplace=True), + norm_eval=False, + init_cfg=dict(type='Kaiming', layer='Conv2d')): + super().__init__(init_cfg=init_cfg) + self.arch = self.expand_arch(arch) + self.num_stages = len(self.arch['in_channels']) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + if frozen_stages not in range(-1, self.num_stages): + raise ValueError('frozen_stages must be in range(-1, ' + f'{self.num_stages}). 
But received '
+                             f'{frozen_stages}')
+        self.frozen_stages = frozen_stages
+
+        self.stem = stem_fn(in_channels)
+
+        stages = []
+        depths = self.arch['num_blocks']
+        dpr = torch.linspace(0, drop_path_rate, sum(depths)).split(depths)
+
+        for i in range(self.num_stages):
+            stage_cfg = {k: v[i] for k, v in self.arch.items()}
+            csp_stage = CSPStage(
+                **stage_cfg,
+                block_dpr=dpr[i].tolist(),
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg,
+                init_cfg=init_cfg)
+            stages.append(csp_stage)
+        self.stages = Sequential(*stages)
+
+        if isinstance(out_indices, int):
+            out_indices = [out_indices]
+        assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+            f'got {type(out_indices)} instead.'
+        out_indices = list(out_indices)
+        for i, index in enumerate(out_indices):
+            if index < 0:
+                out_indices[i] = len(self.stages) + index
+            assert 0 <= out_indices[i] < len(self.stages), \
+                f'Invalid out_indices {index}.'
+        self.out_indices = out_indices
+
+    @staticmethod
+    def expand_arch(arch):
+        num_stages = len(arch['in_channels'])
+
+        def to_tuple(x, name=''):
+            if isinstance(x, (list, tuple)):
+                assert len(x) == num_stages, \
+                    f'The length of {name} ({len(x)}) does not ' \
+                    f'equal the number of stages ({num_stages})'
+                return tuple(x)
+            else:
+                return (x, ) * num_stages
+
+        full_arch = {k: to_tuple(v, k) for k, v in arch.items()}
+        if 'block_args' not in full_arch:
+            full_arch['block_args'] = to_tuple({})
+        return full_arch
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            self.stem.eval()
+            for param in self.stem.parameters():
+                param.requires_grad = False
+
+            for i in range(self.frozen_stages + 1):
+                m = self.stages[i]
+                m.eval()
+                for param in m.parameters():
+                    param.requires_grad = False
+
+    def train(self, mode=True):
+        super(CSPNet, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                if isinstance(m, _BatchNorm):
+                    m.eval()
+
+    def forward(self, x):
+        outs = []
+
+        x = self.stem(x)
+        for i, stage in enumerate(self.stages):
+            x = stage(x)
+            if i in self.out_indices:
+                outs.append(x)
+        return tuple(outs)
+
+
+@MODELS.register_module()
+class CSPDarkNet(CSPNet):
+    """CSP-Darknet backbone used in YOLOv4.
+
+    Args:
+        depth (int): Depth of CSP-Darknet. Default: 53.
+        in_channels (int): Number of input image channels. Default: 3.
+        out_indices (Sequence[int]): Output from which stages.
+            Default: (4, ).
+        frozen_stages (int): Stages to be frozen (stop grad and set eval
+            mode). -1 means not freezing any parameters. Default: -1.
+        conv_cfg (dict): Config dict for convolution layer. Default: None.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+            Default: dict(type='BN', eps=1e-5).
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='LeakyReLU', inplace=True).
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
+
+    Example:
+        >>> from mmpretrain.models import CSPDarkNet
+        >>> import torch
+        >>> model = CSPDarkNet(depth=53, out_indices=(0, 1, 2, 3, 4))
+        >>> model.eval()
+        >>> inputs = torch.rand(1, 3, 416, 416)
+        >>> level_outputs = model(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        ...
+        (1, 64, 208, 208)
+        (1, 128, 104, 104)
+        (1, 256, 52, 52)
+        (1, 512, 26, 26)
+        (1, 1024, 13, 13)
+    """
+    arch_settings = {
+        53:
+        dict(
+            block_fn=DarknetBottleneck,
+            in_channels=(32, 64, 128, 256, 512),
+            out_channels=(64, 128, 256, 512, 1024),
+            num_blocks=(1, 2, 8, 8, 4),
+            expand_ratio=(2, 1, 1, 1, 1),
+            bottle_ratio=(2, 1, 1, 1, 1),
+            has_downsampler=True,
+            down_growth=True,
+        ),
+    }
+
+    def __init__(self,
+                 depth,
+                 in_channels=3,
+                 out_indices=(4, ),
+                 frozen_stages=-1,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN', eps=1e-5),
+                 act_cfg=dict(type='LeakyReLU', inplace=True),
+                 norm_eval=False,
+                 init_cfg=dict(
+                     type='Kaiming',
+                     layer='Conv2d',
+                     a=math.sqrt(5),
+                     distribution='uniform',
+                     mode='fan_in',
+                     nonlinearity='leaky_relu')):
+
+        assert depth in self.arch_settings, 'depth must be one of ' \
+            f'{list(self.arch_settings.keys())}, but got {depth}.'
+
+        super().__init__(
+            arch=self.arch_settings[depth],
+            stem_fn=self._make_stem_layer,
+            in_channels=in_channels,
+            out_indices=out_indices,
+            frozen_stages=frozen_stages,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg,
+            norm_eval=norm_eval,
+            init_cfg=init_cfg)
+
+    def _make_stem_layer(self, in_channels):
+        """Use a stride=1 conv as the stem in CSPDarknet."""
+        # `stem_channels` equals to the `in_channels` in the first stage.
+        stem_channels = self.arch['in_channels'][0]
+        stem = ConvModule(
+            in_channels=in_channels,
+            out_channels=stem_channels,
+            kernel_size=3,
+            padding=1,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        return stem
+
+
+@MODELS.register_module()
+class CSPResNet(CSPNet):
+    """CSP-ResNet backbone.
+
+    Args:
+        depth (int): Depth of CSP-ResNet. Default: 50.
+        in_channels (int): Number of input image channels. Default: 3.
+        out_indices (Sequence[int]): Output from which stages.
+            Default: (3, ).
+        frozen_stages (int): Stages to be frozen (stop grad and set eval
+            mode). -1 means not freezing any parameters. Default: -1.
+        deep_stem (bool): Whether to use a deep stem of three 3x3 convs
+            instead of a single 7x7 conv. Default: False.
+        conv_cfg (dict): Config dict for convolution layer. Default: None.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+            Default: dict(type='BN', eps=1e-5).
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='LeakyReLU', inplace=True).
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
+
+    Example:
+        >>> from mmpretrain.models import CSPResNet
+        >>> import torch
+        >>> model = CSPResNet(depth=50, out_indices=(0, 1, 2, 3))
+        >>> model.eval()
+        >>> inputs = torch.rand(1, 3, 416, 416)
+        >>> level_outputs = model(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        ...
+        (1, 128, 104, 104)
+        (1, 256, 52, 52)
+        (1, 512, 26, 26)
+        (1, 1024, 13, 13)
+    """
+    arch_settings = {
+        50:
+        dict(
+            block_fn=ResNetBottleneck,
+            in_channels=(64, 128, 256, 512),
+            out_channels=(128, 256, 512, 1024),
+            num_blocks=(3, 3, 5, 2),
+            expand_ratio=4,
+            bottle_ratio=2,
+            has_downsampler=(False, True, True, True),
+            down_growth=False),
+    }
+
+    def __init__(self,
+                 depth,
+                 in_channels=3,
+                 out_indices=(3, ),
+                 frozen_stages=-1,
+                 deep_stem=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN', eps=1e-5),
+                 act_cfg=dict(type='LeakyReLU', inplace=True),
+                 norm_eval=False,
+                 init_cfg=dict(type='Kaiming', layer='Conv2d')):
+        assert depth in self.arch_settings, 'depth must be one of ' \
+            f'{list(self.arch_settings.keys())}, but got {depth}.'
+ self.deep_stem = deep_stem + + super().__init__( + arch=self.arch_settings[depth], + stem_fn=self._make_stem_layer, + in_channels=in_channels, + out_indices=out_indices, + frozen_stages=frozen_stages, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + norm_eval=norm_eval, + init_cfg=init_cfg) + + def _make_stem_layer(self, in_channels): + # `stem_channels` equals to the `in_channels` in the first stage. + stem_channels = self.arch['in_channels'][0] + if self.deep_stem: + stem = nn.Sequential( + ConvModule( + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + ConvModule( + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + ConvModule( + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + else: + stem = nn.Sequential( + ConvModule( + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + return stem + + +@MODELS.register_module() +class CSPResNeXt(CSPResNet): + """CSP-ResNeXt backbone. + + Args: + depth (int): Depth of CSP-ResNeXt. Default: 50. + out_indices (Sequence[int]): Output from which stages. + Default: (4, ). + frozen_stages (int): Stages to be frozen (stop grad and set eval + mode). -1 means not freezing any parameters. Default: -1. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Example: + >>> from mmpretrain.models import CSPResNeXt + >>> import torch + >>> model = CSPResNeXt(depth=50, out_indices=(0, 1, 2, 3)) + >>> model.eval() + >>> inputs = torch.rand(1, 3, 224, 224) + >>> level_outputs = model(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + ... + (1, 256, 56, 56) + (1, 512, 28, 28) + (1, 1024, 14, 14) + (1, 2048, 7, 7) + """ + arch_settings = { + 50: + dict( + block_fn=ResNeXtBottleneck, + in_channels=(64, 256, 512, 1024), + out_channels=(256, 512, 1024, 2048), + num_blocks=(3, 3, 5, 2), + expand_ratio=(4, 2, 2, 2), + bottle_ratio=4, + has_downsampler=(False, True, True, True), + down_growth=False, + # the base_channels is changed from 64 to 32 in CSPNet + block_args=dict(base_channels=32), + ), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) diff --git a/mmpretrain/models/backbones/davit.py b/mmpretrain/models/backbones/davit.py new file mode 100644 index 0000000000000000000000000000000000000000..cf25e2ed7137fb403e38801b50b355c4306331d6 --- /dev/null +++ b/mmpretrain/models/backbones/davit.py @@ -0,0 +1,834 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from copy import deepcopy +from typing import Sequence, Tuple + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmcv.cnn.bricks import Conv2d +from mmcv.cnn.bricks.transformer import FFN, AdaptivePadding, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.utils import to_2tuple +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.registry import MODELS +from ..utils import ShiftWindowMSA + + +class DaViTWindowMSA(BaseModule): + """Window based multi-head self-attention (W-MSA) module for DaViT. + + The differences between DaViTWindowMSA & WindowMSA: + 1. Without relative position bias. + + Args: + embed_dims (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + attn_drop (float, optional): Dropout ratio of attention weight. + Defaults to 0. + proj_drop (float, optional): Dropout ratio of output. Defaults to 0. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0., + init_cfg=None): + + super().__init__(init_cfg) + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = qk_scale or head_embed_dims**-0.5 + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww, + Wh*Ww), value should be between (-inf, 0]. + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @staticmethod + def double_step_seq(step1, len1, step2, len2): + seq1 = torch.arange(0, step1 * len1, step1) + seq2 = torch.arange(0, step2 * len2, step2) + return (seq1[:, None] + seq2[None, :]).reshape(1, -1) + + +class ConvPosEnc(BaseModule): + """DaViT conv pos encode block. + + Args: + embed_dims (int): Number of input channels. + kernel_size (int): The kernel size of the first convolution. + Defaults to 3. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. 
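+
+    A minimal usage sketch (the sizes below are arbitrary illustrations, not
+    defaults):
+
+    Example:
+        >>> import torch
+        >>> cpe = ConvPosEnc(embed_dims=96, kernel_size=3)
+        >>> tokens = torch.rand(1, 56 * 56, 96)
+        >>> # The depth-wise conv position encoding keeps the token shape.
+        >>> cpe(tokens, size=(56, 56)).shape
+        torch.Size([1, 3136, 96])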
+    """
+
+    def __init__(self, embed_dims, kernel_size=3, init_cfg=None):
+        super(ConvPosEnc, self).__init__(init_cfg)
+        self.proj = Conv2d(
+            embed_dims,
+            embed_dims,
+            kernel_size,
+            stride=1,
+            padding=kernel_size // 2,
+            groups=embed_dims)
+
+    def forward(self, x, size: Tuple[int, int]):
+        B, N, C = x.shape
+        H, W = size
+        assert N == H * W
+
+        feat = x.transpose(1, 2).view(B, C, H, W)
+        feat = self.proj(feat)
+        feat = feat.flatten(2).transpose(1, 2)
+        x = x + feat
+        return x
+
+
+class DaViTDownSample(BaseModule):
+    """DaViT downsample block.
+
+    Args:
+        in_channels (int): The number of input channels.
+        out_channels (int): The number of output channels.
+        conv_type (str): The type of convolution
+            to generate patch embedding. Default: "Conv2d".
+        kernel_size (int): The kernel size of the projection convolution.
+            Defaults to 2.
+        stride (int): The stride of the projection convolution.
+            Defaults to 2.
+        padding (int | tuple | str): The padding length of
+            embedding conv. When it is a string, it means the mode
+            of adaptive padding, support "same" and "corner" now.
+            Defaults to "corner".
+        dilation (int): Dilation of the convolution layers. Defaults to 1.
+        bias (bool): Bias of embed conv. Default: True.
+        norm_cfg (dict, optional): Config dict for normalization layer.
+            Defaults to None.
+        init_cfg (dict, optional): The extra config for initialization.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 conv_type='Conv2d',
+                 kernel_size=2,
+                 stride=2,
+                 padding='same',
+                 dilation=1,
+                 bias=True,
+                 norm_cfg=None,
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
+        self.out_channels = out_channels
+        if stride is None:
+            stride = kernel_size
+
+        kernel_size = to_2tuple(kernel_size)
+        stride = to_2tuple(stride)
+        dilation = to_2tuple(dilation)
+
+        if isinstance(padding, str):
+            self.adaptive_padding = AdaptivePadding(
+                kernel_size=kernel_size,
+                stride=stride,
+                dilation=dilation,
+                padding=padding)
+            # disable the padding of conv
+            padding = 0
+        else:
+            self.adaptive_padding = None
+        padding = to_2tuple(padding)
+
+        self.projection = build_conv_layer(
+            dict(type=conv_type),
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=bias)
+
+        if norm_cfg is not None:
+            self.norm = build_norm_layer(norm_cfg, in_channels)[1]
+        else:
+            self.norm = None
+
+    def forward(self, x, input_size):
+        if self.adaptive_padding:
+            x = self.adaptive_padding(x)
+        H, W = input_size
+        B, L, C = x.shape
+        assert L == H * W, 'input feature has wrong size'
+
+        if self.norm is not None:
+            x = self.norm(x)
+        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
+
+        x = self.projection(x)
+        output_size = (x.size(2), x.size(3))
+        x = x.flatten(2).transpose(1, 2)
+        return x, output_size
+
+
+class ChannelAttention(BaseModule):
+    """DaViT channel attention.
+
+    Args:
+        embed_dims (int): Number of input channels.
+        num_heads (int): Number of attention heads. Defaults to 8.
+        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
+        init_cfg (dict, optional): The extra config for initialization.
+            Defaults to None.
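+
+    A minimal usage sketch (shapes are arbitrary illustrations): channel
+    attention mixes information across channels, so the token shape is
+    preserved.
+
+    Example:
+        >>> import torch
+        >>> attn = ChannelAttention(embed_dims=64, num_heads=8)
+        >>> tokens = torch.rand(2, 49, 64)
+        >>> attn(tokens).shape
+        torch.Size([2, 49, 64])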
+    """
+
+    def __init__(self, embed_dims, num_heads=8, qkv_bias=False, init_cfg=None):
+        super().__init__(init_cfg)
+        self.embed_dims = embed_dims
+        self.num_heads = num_heads
+        self.head_dims = embed_dims // num_heads
+        self.scale = self.head_dims**-0.5
+
+        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
+        self.proj = nn.Linear(embed_dims, embed_dims)
+
+    def forward(self, x):
+        B, N, _ = x.shape
+
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
+                                  self.head_dims).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]
+
+        k = k * self.scale
+        attention = k.transpose(-1, -2) @ v
+        attention = attention.softmax(dim=-1)
+
+        x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
+        x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
+        x = self.proj(x)
+        return x
+
+
+class ChannelBlock(BaseModule):
+    """DaViT channel attention block.
+
+    Args:
+        embed_dims (int): Number of input channels.
+        num_heads (int): Number of attention heads.
+        ffn_ratio (float): The expansion ratio of feedforward network hidden
+            layer channels. Defaults to 4.
+        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
+        drop_path (float): The drop path rate after attention and ffn.
+            Defaults to 0.
+        ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
+        norm_cfg (dict): The config of norm layers.
+            Defaults to ``dict(type='LN')``.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Defaults to False.
+        init_cfg (dict, optional): The extra config for initialization.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 num_heads,
+                 ffn_ratio=4.,
+                 qkv_bias=False,
+                 drop_path=0.,
+                 ffn_cfgs=dict(),
+                 norm_cfg=dict(type='LN'),
+                 with_cp=False,
+                 init_cfg=None):
+        super().__init__(init_cfg)
+        self.with_cp = with_cp
+
+        self.cpe1 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)
+        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
+        self.attn = ChannelAttention(
+            embed_dims, num_heads=num_heads, qkv_bias=qkv_bias)
+        self.cpe2 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)
+
+        _ffn_cfgs = {
+            'embed_dims': embed_dims,
+            'feedforward_channels': int(embed_dims * ffn_ratio),
+            'num_fcs': 2,
+            'ffn_drop': 0,
+            'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
+            'act_cfg': dict(type='GELU'),
+            **ffn_cfgs
+        }
+        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
+        self.ffn = FFN(**_ffn_cfgs)
+
+    def forward(self, x, hw_shape):
+
+        def _inner_forward(x):
+            x = self.cpe1(x, hw_shape)
+            identity = x
+            x = self.norm1(x)
+            x = self.attn(x)
+            x = x + identity
+
+            x = self.cpe2(x, hw_shape)
+            identity = x
+            x = self.norm2(x)
+            x = self.ffn(x, identity=identity)
+
+            return x
+
+        if self.with_cp and x.requires_grad:
+            x = cp.checkpoint(_inner_forward, x)
+        else:
+            x = _inner_forward(x)
+
+        return x
+
+
+class SpatialBlock(BaseModule):
+    """DaViT spatial attention block.
+
+    Args:
+        embed_dims (int): Number of input channels.
+        num_heads (int): Number of attention heads.
+        window_size (int): The height and width of the window. Defaults to 7.
+        ffn_ratio (float): The expansion ratio of feedforward network hidden
+            layer channels. Defaults to 4.
+        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
+        drop_path (float): The drop path rate after attention and ffn.
+            Defaults to 0.
+        pad_small_map (bool): If True, pad the small feature map to the window
+            size, which is common used in detection and segmentation.
If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + attn_cfgs (dict): The extra config of Shift Window-MSA. + Defaults to empty dict. + ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict. + norm_cfg (dict): The config of norm layers. + Defaults to ``dict(type='LN')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size=7, + ffn_ratio=4., + qkv_bias=True, + drop_path=0., + pad_small_map=False, + attn_cfgs=dict(), + ffn_cfgs=dict(), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + + super(SpatialBlock, self).__init__(init_cfg) + self.with_cp = with_cp + + self.cpe1 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3) + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + _attn_cfgs = { + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'shift_size': 0, + 'window_size': window_size, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'qkv_bias': qkv_bias, + 'pad_small_map': pad_small_map, + 'window_msa': DaViTWindowMSA, + **attn_cfgs + } + self.attn = ShiftWindowMSA(**_attn_cfgs) + self.cpe2 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3) + + _ffn_cfgs = { + 'embed_dims': embed_dims, + 'feedforward_channels': int(embed_dims * ffn_ratio), + 'num_fcs': 2, + 'ffn_drop': 0, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'act_cfg': dict(type='GELU'), + **ffn_cfgs + } + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN(**_ffn_cfgs) + + def forward(self, x, hw_shape): + + def _inner_forward(x): + x = self.cpe1(x, hw_shape) + identity = x + x = self.norm1(x) + x = self.attn(x, hw_shape) + x = x + identity + + x = self.cpe2(x, hw_shape) + identity = x + x = self.norm2(x) + x = self.ffn(x, identity=identity) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class DaViTBlock(BaseModule): + """DaViT block. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + drop_path (float): The drop path rate after attention and ffn. + Defaults to 0. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + attn_cfgs (dict): The extra config of Shift Window-MSA. + Defaults to empty dict. + ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict. + norm_cfg (dict): The config of norm layers. + Defaults to ``dict(type='LN')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. 
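+
+    A minimal usage sketch (all sizes are arbitrary illustrations): the block
+    chains spatial window attention and channel attention, keeping the token
+    shape unchanged.
+
+    Example:
+        >>> import torch
+        >>> block = DaViTBlock(embed_dims=96, num_heads=3, window_size=7)
+        >>> tokens = torch.rand(1, 14 * 14, 96)
+        >>> block(tokens, hw_shape=(14, 14)).shape
+        torch.Size([1, 196, 96])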
+ """ + + def __init__(self, + embed_dims, + num_heads, + window_size=7, + ffn_ratio=4., + qkv_bias=True, + drop_path=0., + pad_small_map=False, + attn_cfgs=dict(), + ffn_cfgs=dict(), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + + super(DaViTBlock, self).__init__(init_cfg) + self.spatial_block = SpatialBlock( + embed_dims, + num_heads, + window_size=window_size, + ffn_ratio=ffn_ratio, + qkv_bias=qkv_bias, + drop_path=drop_path, + pad_small_map=pad_small_map, + attn_cfgs=attn_cfgs, + ffn_cfgs=ffn_cfgs, + norm_cfg=norm_cfg, + with_cp=with_cp) + self.channel_block = ChannelBlock( + embed_dims, + num_heads, + ffn_ratio=ffn_ratio, + qkv_bias=qkv_bias, + drop_path=drop_path, + ffn_cfgs=ffn_cfgs, + norm_cfg=norm_cfg, + with_cp=False) + + def forward(self, x, hw_shape): + x = self.spatial_block(x, hw_shape) + x = self.channel_block(x, hw_shape) + + return x + + +class DaViTBlockSequence(BaseModule): + """Module with successive DaViT blocks and downsample layer. + + Args: + embed_dims (int): Number of input channels. + depth (int): Number of successive DaViT blocks. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + downsample (bool): Downsample the output of blocks by patch merging. + Defaults to False. + downsample_cfg (dict): The extra config of the patch merging layer. + Defaults to empty dict. + drop_paths (Sequence[float] | float): The drop path rate in each block. + Defaults to 0. + block_cfgs (Sequence[dict] | dict): The extra config of each block. + Defaults to empty dicts. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. 
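+
+    A minimal usage sketch (sizes are arbitrary illustrations): with a
+    downsample layer, the stage halves the resolution and doubles the
+    channels.
+
+    Example:
+        >>> import torch
+        >>> stage = DaViTBlockSequence(
+        ...     embed_dims=96, depth=1, num_heads=3, downsample=True)
+        >>> tokens = torch.rand(1, 14 * 14, 96)
+        >>> out, out_shape = stage(tokens, in_shape=(14, 14))
+        >>> out.shape, out_shape
+        (torch.Size([1, 49, 192]), (7, 7))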
+ """ + + def __init__(self, + embed_dims, + depth, + num_heads, + window_size=7, + ffn_ratio=4., + qkv_bias=True, + downsample=False, + downsample_cfg=dict(), + drop_paths=0., + block_cfgs=dict(), + with_cp=False, + pad_small_map=False, + init_cfg=None): + super().__init__(init_cfg) + + if not isinstance(drop_paths, Sequence): + drop_paths = [drop_paths] * depth + + if not isinstance(block_cfgs, Sequence): + block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)] + + self.embed_dims = embed_dims + self.blocks = ModuleList() + for i in range(depth): + _block_cfg = { + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'window_size': window_size, + 'ffn_ratio': ffn_ratio, + 'qkv_bias': qkv_bias, + 'drop_path': drop_paths[i], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + **block_cfgs[i] + } + block = DaViTBlock(**_block_cfg) + self.blocks.append(block) + + if downsample: + _downsample_cfg = { + 'in_channels': embed_dims, + 'out_channels': 2 * embed_dims, + 'norm_cfg': dict(type='LN'), + **downsample_cfg + } + self.downsample = DaViTDownSample(**_downsample_cfg) + else: + self.downsample = None + + def forward(self, x, in_shape, do_downsample=True): + for block in self.blocks: + x = block(x, in_shape) + + if self.downsample is not None and do_downsample: + x, out_shape = self.downsample(x, in_shape) + else: + out_shape = in_shape + return x, out_shape + + @property + def out_channels(self): + if self.downsample: + return self.downsample.out_channels + else: + return self.embed_dims + + +@MODELS.register_module() +class DaViT(BaseBackbone): + """DaViT. + + A PyTorch implement of : `DaViT: Dual Attention Vision Transformers + `_ + + Inspiration from + https://github.com/dingmyu/davit + + Args: + arch (str | dict): DaViT architecture. If use string, choose from + 'tiny', 'small', 'base' and 'large', 'huge', 'giant'. If use dict, + it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **num_heads** (List[int]): The number of heads in attention + modules of each stage. + + Defaults to 't'. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 4. + in_channels (int): The num of input channels. Defaults to 3. + window_size (int): The height and width of the window. Defaults to 7. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + out_after_downsample (bool): Whether to output the feature map of a + stage after the following downsample layer. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + norm_cfg (dict): Config dict for normalization layer for all output + features. Defaults to ``dict(type='LN')`` + stage_cfgs (Sequence[dict] | dict): Extra config dict for each + stage. Defaults to an empty dict. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. 
+ out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], { + 'embed_dims': 96, + 'depths': [1, 1, 3, 1], + 'num_heads': [3, 6, 12, 24] + }), + **dict.fromkeys(['s', 'small'], { + 'embed_dims': 96, + 'depths': [1, 1, 9, 1], + 'num_heads': [3, 6, 12, 24] + }), + **dict.fromkeys(['b', 'base'], { + 'embed_dims': 128, + 'depths': [1, 1, 9, 1], + 'num_heads': [4, 8, 16, 32] + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 192, + 'depths': [1, 1, 9, 1], + 'num_heads': [6, 12, 24, 48] + }), + **dict.fromkeys( + ['h', 'huge'], { + 'embed_dims': 256, + 'depths': [1, 1, 9, 1], + 'num_heads': [8, 16, 32, 64] + }), + **dict.fromkeys( + ['g', 'giant'], { + 'embed_dims': 384, + 'depths': [1, 1, 12, 3], + 'num_heads': [12, 24, 48, 96] + }), + } + + def __init__(self, + arch='t', + patch_size=4, + in_channels=3, + window_size=7, + ffn_ratio=4., + qkv_bias=True, + drop_path_rate=0.1, + out_after_downsample=False, + pad_small_map=False, + norm_cfg=dict(type='LN'), + stage_cfgs=dict(), + frozen_stages=-1, + norm_eval=False, + out_indices=(3, ), + with_cp=False, + init_cfg=None): + super().__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'embed_dims', 'depths', 'num_heads'} + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + self.num_layers = len(self.depths) + self.out_indices = out_indices + self.out_after_downsample = out_after_downsample + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + + # stochastic depth decay rule + total_depth = sum(self.depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + _patch_cfg = dict( + in_channels=in_channels, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=7, + stride=patch_size, + padding='same', + norm_cfg=dict(type='LN'), + ) + self.patch_embed = PatchEmbed(**_patch_cfg) + + self.stages = ModuleList() + embed_dims = [self.embed_dims] + for i, (depth, + num_heads) in enumerate(zip(self.depths, self.num_heads)): + if isinstance(stage_cfgs, Sequence): + stage_cfg = stage_cfgs[i] + else: + stage_cfg = deepcopy(stage_cfgs) + downsample = True if i < self.num_layers - 1 else False + _stage_cfg = { + 'embed_dims': embed_dims[-1], + 'depth': depth, + 'num_heads': num_heads, + 'window_size': window_size, + 'ffn_ratio': ffn_ratio, + 'qkv_bias': qkv_bias, + 'downsample': downsample, + 'drop_paths': dpr[:depth], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + **stage_cfg + } + + stage = DaViTBlockSequence(**_stage_cfg) + self.stages.append(stage) + + dpr = dpr[depth:] + embed_dims.append(stage.out_channels) + + self.num_features = embed_dims[:-1] + + # add a norm layer for each output + for i in out_indices: + if norm_cfg is not None: + norm_layer = build_norm_layer(norm_cfg, + self.num_features[i])[1] + else: + 
norm_layer = nn.Identity() + + self.add_module(f'norm{i}', norm_layer) + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(0, self.frozen_stages + 1): + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + for i in self.out_indices: + if i <= self.frozen_stages: + for param in getattr(self, f'norm{i}').parameters(): + param.requires_grad = False + + def forward(self, x): + x, hw_shape = self.patch_embed(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape = stage( + x, hw_shape, do_downsample=self.out_after_downsample) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(x) + out = out.view(-1, *hw_shape, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + if stage.downsample is not None and not self.out_after_downsample: + x, hw_shape = stage.downsample(x, hw_shape) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/deit.py b/mmpretrain/models/backbones/deit.py new file mode 100644 index 0000000000000000000000000000000000000000..9ae340829bece31536d0c0ac119ffe635bce82e0 --- /dev/null +++ b/mmpretrain/models/backbones/deit.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from .vision_transformer import VisionTransformer + + +@MODELS.register_module() +class DistilledVisionTransformer(VisionTransformer): + """Distilled Vision Transformer. + + A PyTorch implement of : `Training data-efficient image transformers & + distillation through attention `_ + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small' + and 'deit-base'. If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + + Defaults to 'deit-base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: A tuple with the class token and the + distillation token. 
The shapes of both tensors are (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor includes patch tokens and
+              class tokens with shape (B, L, C).
+
+            Defaults to ``"cls_token"``.
+        interpolate_mode (str): Select the interpolate mode for position
+            embedding vector resize. Defaults to "bicubic".
+        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
+            dict.
+        layer_cfgs (Sequence | dict): Configs of each transformer layer in
+            encoder. Defaults to an empty dict.
+        init_cfg (dict, optional): Initialization config dict.
+            Defaults to None.
+    """
+    num_extra_tokens = 2  # class token and distillation token
+
+    def __init__(self, arch='deit-base', *args, **kwargs):
+        super(DistilledVisionTransformer, self).__init__(
+            arch=arch,
+            with_cls_token=True,
+            *args,
+            **kwargs,
+        )
+        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+
+    def forward(self, x):
+        B = x.shape[0]
+        x, patch_resolution = self.patch_embed(x)
+
+        # stole cls_tokens impl from Phil Wang, thanks
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        dist_token = self.dist_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, dist_token, x), dim=1)
+        x = x + self.resize_pos_embed(
+            self.pos_embed,
+            self.patch_resolution,
+            patch_resolution,
+            mode=self.interpolate_mode,
+            num_extra_tokens=self.num_extra_tokens)
+        x = self.drop_after_pos(x)
+
+        outs = []
+        for i, layer in enumerate(self.layers):
+            x = layer(x)
+
+            if i == len(self.layers) - 1 and self.final_norm:
+                x = self.ln1(x)
+
+            if i in self.out_indices:
+                outs.append(self._format_output(x, patch_resolution))
+
+        return tuple(outs)
+
+    def _format_output(self, x, hw):
+        if self.out_type == 'cls_token':
+            return x[:, 0], x[:, 1]
+
+        return super()._format_output(x, hw)
+
+    def init_weights(self):
+        super(DistilledVisionTransformer, self).init_weights()
+
+        if not (isinstance(self.init_cfg, dict)
+                and self.init_cfg['type'] == 'Pretrained'):
+            trunc_normal_(self.dist_token, std=0.02)
diff --git a/mmpretrain/models/backbones/deit3.py b/mmpretrain/models/backbones/deit3.py
new file mode 100644
index 0000000000000000000000000000000000000000..acedabe42d66a8073f34b1b0ae87501522fcc1b5
--- /dev/null
+++ b/mmpretrain/models/backbones/deit3.py
@@ -0,0 +1,454 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence
+
+import numpy as np
+import torch
+from mmcv.cnn import Linear, build_activation_layer
+from mmcv.cnn.bricks.drop import build_dropout
+from mmcv.cnn.bricks.transformer import PatchEmbed
+from mmengine.model import BaseModule, ModuleList, Sequential
+from mmengine.utils import deprecated_api_warning
+from torch import nn
+
+from mmpretrain.registry import MODELS
+from ..utils import (LayerScale, MultiheadAttention, build_norm_layer,
+                     resize_pos_embed, to_2tuple)
+from .vision_transformer import VisionTransformer
+
+
+class DeiT3FFN(BaseModule):
+    """FFN for DeiT3.
+
+    The differences between DeiT3FFN & FFN:
+    1. Use LayerScale.
+
+    Args:
+        embed_dims (int): The feature dimension. Same as
+            `MultiheadAttention`. Default: 256.
+        feedforward_channels (int): The hidden dimension of FFNs.
+            Default: 1024.
+        num_fcs (int, optional): The number of fully-connected layers in
+            FFNs. Default: 2.
+        act_cfg (dict, optional): The activation config for FFNs.
+            Default: dict(type='ReLU', inplace=True).
+        ffn_drop (float, optional): Probability of an element to be
+            zeroed in FFN.
Default 0.0. + add_identity (bool, optional): Whether to add the + identity connection. Default: `True`. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + use_layer_scale (bool): Whether to use layer_scale in + DeiT3FFN. Defaults to True. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + @deprecated_api_warning( + { + 'dropout': 'ffn_drop', + 'add_residual': 'add_identity' + }, + cls_name='FFN') + def __init__(self, + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0., + dropout_layer=None, + add_identity=True, + use_layer_scale=True, + init_cfg=None, + **kwargs): + super().__init__(init_cfg) + assert num_fcs >= 2, 'num_fcs should be no less ' \ + f'than 2. got {num_fcs}.' + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.num_fcs = num_fcs + self.act_cfg = act_cfg + self.activate = build_activation_layer(act_cfg) + + layers = [] + in_channels = embed_dims + for _ in range(num_fcs - 1): + layers.append( + Sequential( + Linear(in_channels, feedforward_channels), self.activate, + nn.Dropout(ffn_drop))) + in_channels = feedforward_channels + layers.append(Linear(feedforward_channels, embed_dims)) + layers.append(nn.Dropout(ffn_drop)) + self.layers = Sequential(*layers) + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else torch.nn.Identity() + self.add_identity = add_identity + + if use_layer_scale: + self.gamma2 = LayerScale(embed_dims) + else: + self.gamma2 = nn.Identity() + + @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN') + def forward(self, x, identity=None): + """Forward function for `FFN`. + + The function would add x to the output tensor if residue is None. + """ + out = self.layers(x) + out = self.gamma2(out) + if not self.add_identity: + return self.dropout_layer(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +class DeiT3TransformerEncoderLayer(BaseModule): + """Implements one encoder layer in DeiT3. + + The differences between DeiT3TransformerEncoderLayer & + TransformerEncoderLayer: + 1. Use LayerScale. + + Args: + embed_dims (int): The feature dimension + num_heads (int): Parallel attention heads + feedforward_channels (int): The hidden dimension for FFNs + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + use_layer_scale (bool): Whether to use layer_scale in + DeiT3TransformerEncoderLayer. Defaults to True. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 num_heads,
+                 feedforward_channels,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 num_fcs=2,
+                 qkv_bias=True,
+                 use_layer_scale=True,
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='LN'),
+                 init_cfg=None):
+        super(DeiT3TransformerEncoderLayer, self).__init__(init_cfg=init_cfg)
+
+        self.embed_dims = embed_dims
+
+        self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
+
+        self.attn = MultiheadAttention(
+            embed_dims=embed_dims,
+            num_heads=num_heads,
+            attn_drop=attn_drop_rate,
+            proj_drop=drop_rate,
+            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+            qkv_bias=qkv_bias,
+            use_layer_scale=use_layer_scale)
+
+        self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
+
+        self.ffn = DeiT3FFN(
+            embed_dims=embed_dims,
+            feedforward_channels=feedforward_channels,
+            num_fcs=num_fcs,
+            ffn_drop=drop_rate,
+            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+            act_cfg=act_cfg,
+            use_layer_scale=use_layer_scale)
+
+    def init_weights(self):
+        super(DeiT3TransformerEncoderLayer, self).init_weights()
+        for m in self.ffn.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+                nn.init.normal_(m.bias, std=1e-6)
+
+    def forward(self, x):
+        x = x + self.attn(self.ln1(x))
+        x = self.ffn(self.ln2(x), identity=x)
+        return x
+
+
+@MODELS.register_module()
+class DeiT3(VisionTransformer):
+    """DeiT3 backbone.
+
+    A PyTorch implementation of: `DeiT III: Revenge of the ViT
+    <https://arxiv.org/abs/2204.07118>`_
+
+    The differences between DeiT3 & VisionTransformer:
+
+    1. Use LayerScale.
+    2. Concat cls token after adding pos_embed.
+
+    Args:
+        arch (str | dict): DeiT3 architecture. If use string,
+            choose from 'small', 'base', 'medium', 'large' and 'huge'.
+            If use dict, it should have below keys:
+
+            - **embed_dims** (int): The dimensions of embedding.
+            - **num_layers** (int): The number of transformer encoder layers.
+            - **num_heads** (int): The number of heads in attention modules.
+            - **feedforward_channels** (int): The hidden dimensions in
+              feedforward modules.
+
+            Defaults to 'base'.
+        img_size (int | tuple): The expected input image shape. Because we
+            support dynamic input shape, just set the argument to the most
+            common input image shape. Defaults to 224.
+        patch_size (int | tuple): The patch size in patch embedding.
+            Defaults to 16.
+        in_channels (int): The num of input channels. Defaults to 3.
+        out_indices (Sequence | int): Output from which stages.
+            Defaults to -1, meaning the last stage.
+        drop_rate (float): Probability of an element to be zeroed.
+            Defaults to 0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        qkv_bias (bool): Whether to add bias for qkv in attention modules.
+            Defaults to True.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='LN', eps=1e-6)``.
+        final_norm (bool): Whether to add an additional layer to normalize
+            final feature map. Defaults to True.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: The class token tensor with shape (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor includes patch tokens and
+              class tokens with shape (B, L, C).
+
+            Defaults to ``"cls_token"``.
+        with_cls_token (bool): Whether concatenating class token into image
+            tokens as transformer input. Defaults to True.
+        use_layer_scale (bool): Whether to use layer_scale in DeiT3.
+            Defaults to True.
+        interpolate_mode (str): Select the interpolate mode for position
+            embedding vector resize. Defaults to "bicubic".
+        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
+            dict.
+        layer_cfgs (Sequence | dict): Configs of each transformer layer in
+            encoder. Defaults to an empty dict.
+        init_cfg (dict, optional): Initialization config dict.
+            Defaults to None.
+    """
+    arch_zoo = {
+        **dict.fromkeys(
+            ['s', 'small'], {
+                'embed_dims': 384,
+                'num_layers': 12,
+                'num_heads': 6,
+                'feedforward_channels': 1536,
+            }),
+        **dict.fromkeys(
+            ['m', 'medium'], {
+                'embed_dims': 512,
+                'num_layers': 12,
+                'num_heads': 8,
+                'feedforward_channels': 2048,
+            }),
+        **dict.fromkeys(
+            ['b', 'base'], {
+                'embed_dims': 768,
+                'num_layers': 12,
+                'num_heads': 12,
+                'feedforward_channels': 3072
+            }),
+        **dict.fromkeys(
+            ['l', 'large'], {
+                'embed_dims': 1024,
+                'num_layers': 24,
+                'num_heads': 16,
+                'feedforward_channels': 4096
+            }),
+        **dict.fromkeys(
+            ['h', 'huge'], {
+                'embed_dims': 1280,
+                'num_layers': 32,
+                'num_heads': 16,
+                'feedforward_channels': 5120
+            }),
+    }
+    num_extra_tokens = 1  # class token
+
+    def __init__(self,
+                 arch='base',
+                 img_size=224,
+                 patch_size=16,
+                 in_channels=3,
+                 out_indices=-1,
+                 drop_rate=0.,
+                 drop_path_rate=0.,
+                 qkv_bias=True,
+                 norm_cfg=dict(type='LN', eps=1e-6),
+                 final_norm=True,
+                 out_type='cls_token',
+                 with_cls_token=True,
+                 use_layer_scale=True,
+                 interpolate_mode='bicubic',
+                 patch_cfg=dict(),
+                 layer_cfgs=dict(),
+                 init_cfg=None):
+        super(VisionTransformer, self).__init__(init_cfg)
+
+        if isinstance(arch, str):
+            arch = arch.lower()
+            assert arch in set(self.arch_zoo), \
+                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+            self.arch_settings = self.arch_zoo[arch]
+        else:
+            essential_keys = {
+                'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
+            }
+            assert isinstance(arch, dict) and essential_keys <= set(arch), \
+                f'Custom arch needs a dict with keys {essential_keys}'
+            self.arch_settings = arch
+
+        self.embed_dims = self.arch_settings['embed_dims']
+        self.num_layers = self.arch_settings['num_layers']
+        self.img_size = to_2tuple(img_size)
+
+        # Set patch embedding
+        _patch_cfg = dict(
+            in_channels=in_channels,
+            input_size=img_size,
+            embed_dims=self.embed_dims,
+            conv_type='Conv2d',
+            kernel_size=patch_size,
+            stride=patch_size,
+        )
+        _patch_cfg.update(patch_cfg)
+        self.patch_embed = PatchEmbed(**_patch_cfg)
+        self.patch_resolution = self.patch_embed.init_out_size
+        num_patches = self.patch_resolution[0] * self.patch_resolution[1]
+
+        # Set out type
+        if out_type not in self.OUT_TYPES:
+            raise ValueError(f'Unsupported `out_type` {out_type}, please '
+                             f'choose from {self.OUT_TYPES}')
+        self.out_type = out_type
+
+        # Set cls token
+        if with_cls_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+        elif out_type != 'cls_token':
+            self.cls_token = None
+            self.num_extra_tokens = 0
+        else:
+            raise ValueError(
+                'with_cls_token must be True when `out_type="cls_token"`.')
+
+        # Set position embedding
+        self.interpolate_mode = interpolate_mode
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, num_patches, self.embed_dims))
+        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
+
+        self.drop_after_pos = nn.Dropout(p=drop_rate)
+
+        if isinstance(out_indices, int):
+            out_indices = [out_indices]
+        assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+            f'got {type(out_indices)} instead.'
+        out_indices = list(out_indices)
+ for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + drop_rate=drop_rate, + drop_path_rate=dpr[i], + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + use_layer_scale=use_layer_scale) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(DeiT3TransformerEncoderLayer(**_layer_cfg)) + + self.final_norm = final_norm + if final_norm: + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=0) + x = self.drop_after_pos(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1]))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed( + state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + num_extra_tokens=0, # The cls token adding is after pos_embed + ) diff --git a/mmpretrain/models/backbones/densenet.py b/mmpretrain/models/backbones/densenet.py new file mode 100644 index 0000000000000000000000000000000000000000..c9f05302f9b84cd38c7c03701fc21ffd109c1620 --- /dev/null +++ b/mmpretrain/models/backbones/densenet.py @@ -0,0 +1,332 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import math +from itertools import chain +from typing import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn.bricks import build_activation_layer, build_norm_layer +from torch.jit.annotations import List + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class DenseLayer(BaseBackbone): + """DenseBlock layers.""" + + def __init__(self, + in_channels, + growth_rate, + bn_size, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + drop_rate=0., + memory_efficient=False): + super(DenseLayer, self).__init__() + + self.norm1 = build_norm_layer(norm_cfg, in_channels)[1] + self.conv1 = nn.Conv2d( + in_channels, + bn_size * growth_rate, + kernel_size=1, + stride=1, + bias=False) + self.act = build_activation_layer(act_cfg) + self.norm2 = build_norm_layer(norm_cfg, bn_size * growth_rate)[1] + self.conv2 = nn.Conv2d( + bn_size * growth_rate, + growth_rate, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.drop_rate = float(drop_rate) + self.memory_efficient = memory_efficient + + def bottleneck_fn(self, xs): + # type: (List[torch.Tensor]) -> torch.Tensor + concated_features = torch.cat(xs, 1) + bottleneck_output = self.conv1( + self.act(self.norm1(concated_features))) # noqa: T484 + return bottleneck_output + + # todo: rewrite when torchscript supports any + def any_requires_grad(self, x): + # type: (List[torch.Tensor]) -> bool + for tensor in x: + if tensor.requires_grad: + return True + return False + + # This decorator indicates to the compiler that a function or method + # should be ignored and replaced with the raising of an exception. + # Here this function is incompatible with torchscript. + @torch.jit.unused # noqa: T484 + def call_checkpoint_bottleneck(self, x): + # type: (List[torch.Tensor]) -> torch.Tensor + def closure(*xs): + return self.bottleneck_fn(xs) + + # Here use torch.utils.checkpoint to rerun a forward-pass during + # backward in bottleneck to save memories. 
+        return cp.checkpoint(closure, *x)
+
+    def forward(self, x):  # noqa: F811
+        # type: (List[torch.Tensor]) -> torch.Tensor
+        # the input features must be a list of Tensors
+        assert isinstance(x, list)
+
+        if self.memory_efficient and self.any_requires_grad(x):
+            if torch.jit.is_scripting():
+                raise Exception('Memory Efficient not supported in JIT')
+            bottleneck_output = self.call_checkpoint_bottleneck(x)
+        else:
+            bottleneck_output = self.bottleneck_fn(x)
+
+        new_features = self.conv2(self.act(self.norm2(bottleneck_output)))
+        if self.drop_rate > 0:
+            new_features = F.dropout(
+                new_features, p=self.drop_rate, training=self.training)
+        return new_features
+
+
+class DenseBlock(nn.Module):
+    """DenseNet Blocks."""
+
+    def __init__(self,
+                 num_layers,
+                 in_channels,
+                 bn_size,
+                 growth_rate,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 drop_rate=0.,
+                 memory_efficient=False):
+        super(DenseBlock, self).__init__()
+        self.block = nn.ModuleList([
+            DenseLayer(
+                in_channels + i * growth_rate,
+                growth_rate=growth_rate,
+                bn_size=bn_size,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg,
+                drop_rate=drop_rate,
+                memory_efficient=memory_efficient) for i in range(num_layers)
+        ])
+
+    def forward(self, init_features):
+        features = [init_features]
+        for layer in self.block:
+            new_features = layer(features)
+            features.append(new_features)
+        return torch.cat(features, 1)
+
+
+class DenseTransition(nn.Sequential):
+    """DenseNet Transition Layers."""
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU')):
+        super(DenseTransition, self).__init__()
+        self.add_module('norm', build_norm_layer(norm_cfg, in_channels)[1])
+        self.add_module('act', build_activation_layer(act_cfg))
+        self.add_module(
+            'conv',
+            nn.Conv2d(
+                in_channels, out_channels, kernel_size=1, stride=1,
+                bias=False))
+        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
+
+
+@MODELS.register_module()
+class DenseNet(BaseBackbone):
+    """DenseNet.
+
+    A PyTorch implementation of: `Densely Connected Convolutional Networks
+    <https://arxiv.org/abs/1608.06993>`_
+
+    Modified from the `official repo
+    <https://github.com/liuzhuang13/DenseNet>`_ and `pytorch
+    <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_.
+
+    Args:
+        arch (str | dict): The model's architecture. If string, it should be
+            one of architecture in ``DenseNet.arch_settings``. And if dict, it
+            should include the following three keys:
+
+            - growth_rate (int): Each layer of DenseBlock produces `k` feature
+              maps. The `k` here is referred to as the growth rate of the
+              network.
+            - depths (list[int]): Number of repeated layers in each
+              DenseBlock.
+            - init_channels (int): The output channels of stem layers.
+
+            Defaults to '121'.
+        in_channels (int): Number of input image channels. Defaults to 3.
+        bn_size (int): Refers to channel expansion parameter of 1x1
+            convolution layer. Defaults to 4.
+        drop_rate (float): Drop rate of Dropout Layer. Defaults to 0.
+        compression_factor (float): The reduction rate of transition layers.
+            Defaults to 0.5.
+        memory_efficient (bool): If True, uses checkpointing. Much more memory
+            efficient, but slower. Defaults to False.
+            See `"paper" <https://arxiv.org/abs/1707.06990>`_.
+        norm_cfg (dict): The config dict for norm layers.
+            Defaults to ``dict(type='BN')``.
+        act_cfg (dict): The config dict for activation after each convolution.
+            Defaults to ``dict(type='ReLU')``.
+        out_indices (Sequence | int): Output from which stages.
+            Defaults to -1, meaning the last stage.
+        frozen_stages (int): Stages to be frozen (all param fixed).
+            Defaults to 0, which means not freezing any parameters.
+        init_cfg (dict, optional): Initialization config dict.
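+
+    Example:
+        >>> # A minimal sketch; the output shape below assumes the default
+        >>> # '121' settings with a 224x224 input.
+        >>> import torch
+        >>> model = DenseNet(arch='121', out_indices=(3, ))
+        >>> model.eval()
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> level_outputs = model(inputs)
+        >>> level_outputs[0].shape
+        torch.Size([1, 1024, 7, 7])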
+ """ + arch_settings = { + '121': { + 'growth_rate': 32, + 'depths': [6, 12, 24, 16], + 'init_channels': 64, + }, + '169': { + 'growth_rate': 32, + 'depths': [6, 12, 32, 32], + 'init_channels': 64, + }, + '201': { + 'growth_rate': 32, + 'depths': [6, 12, 48, 32], + 'init_channels': 64, + }, + '161': { + 'growth_rate': 48, + 'depths': [6, 12, 36, 24], + 'init_channels': 96, + }, + } + + def __init__(self, + arch='121', + in_channels=3, + bn_size=4, + drop_rate=0, + compression_factor=0.5, + memory_efficient=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + out_indices=-1, + frozen_stages=0, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + essential_keys = {'growth_rate', 'depths', 'init_channels'} + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + + self.growth_rate = arch['growth_rate'] + self.depths = arch['depths'] + self.init_channels = arch['init_channels'] + self.act = build_activation_layer(act_cfg) + + self.num_stages = len(self.depths) + + # check out indices and frozen stages + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_stages + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # Set stem layers + self.stem = nn.Sequential( + nn.Conv2d( + in_channels, + self.init_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False), + build_norm_layer(norm_cfg, self.init_channels)[1], self.act, + nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + + # Repetitions of DenseNet Blocks + self.stages = nn.ModuleList() + self.transitions = nn.ModuleList() + + channels = self.init_channels + for i in range(self.num_stages): + depth = self.depths[i] + + stage = DenseBlock( + num_layers=depth, + in_channels=channels, + bn_size=bn_size, + growth_rate=self.growth_rate, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop_rate=drop_rate, + memory_efficient=memory_efficient) + self.stages.append(stage) + channels += depth * self.growth_rate + + if i != self.num_stages - 1: + transition = DenseTransition( + in_channels=channels, + out_channels=math.floor(channels * compression_factor), + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + channels = math.floor(channels * compression_factor) + else: + # Final layers after dense block is just bn with act. + # Unlike the paper, the original repo also put this in + # transition layer, whereas torchvision take this out. + # We reckon this as transition layer here. 
+ transition = nn.Sequential( + build_norm_layer(norm_cfg, channels)[1], + self.act, + ) + self.transitions.append(transition) + + self._freeze_stages() + + def forward(self, x): + x = self.stem(x) + outs = [] + for i in range(self.num_stages): + x = self.stages[i](x) + x = self.transitions[i](x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + downsample_layer = self.transitions[i] + stage = self.stages[i] + downsample_layer.eval() + stage.eval() + for param in chain(downsample_layer.parameters(), + stage.parameters()): + param.requires_grad = False + + def train(self, mode=True): + super(DenseNet, self).train(mode) + self._freeze_stages() diff --git a/mmpretrain/models/backbones/edgenext.py b/mmpretrain/models/backbones/edgenext.py new file mode 100644 index 0000000000000000000000000000000000000000..ad4e768e7561eb49da3603f4394faaebed7c9251 --- /dev/null +++ b/mmpretrain/models/backbones/edgenext.py @@ -0,0 +1,398 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from itertools import chain +from typing import Sequence + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmpretrain.registry import MODELS +from ..utils import (ChannelMultiheadAttention, PositionEncodingFourier, + build_norm_layer) +from .base_backbone import BaseBackbone +from .convnext import ConvNeXtBlock + + +class SDTAEncoder(BaseModule): + """A PyTorch implementation of split depth-wise transpose attention (SDTA) + encoder. + + Inspiration from + https://github.com/mmaaz60/EdgeNeXt + Args: + in_channel (int): Number of input channels. + drop_path_rate (float): Stochastic depth dropout rate. + Defaults to 0. + layer_scale_init_value (float): Initial value of layer scale. + Defaults to 1e-6. + mlp_ratio (int): Number of channels ratio in the MLP. + Defaults to 4. + use_pos_emb (bool): Whether to use position encoding. + Defaults to True. + num_heads (int): Number of heads in the multihead attention. + Defaults to 8. + qkv_bias (bool): Whether to use bias in the multihead attention. + Defaults to True. + attn_drop (float): Dropout rate of the attention. + Defaults to 0. + proj_drop (float): Dropout rate of the projection. + Defaults to 0. + layer_scale_init_value (float): Initial value of layer scale. + Defaults to 1e-6. + norm_cfg (dict): Dictionary to construct normalization layer. + Defaults to ``dict(type='LN')``. + act_cfg (dict): Dictionary to construct activation layer. + Defaults to ``dict(type='GELU')``. + scales (int): Number of scales. Default to 1. 
+ """ + + def __init__(self, + in_channel, + drop_path_rate=0., + layer_scale_init_value=1e-6, + mlp_ratio=4, + use_pos_emb=True, + num_heads=8, + qkv_bias=True, + attn_drop=0., + proj_drop=0., + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + scales=1, + init_cfg=None): + super(SDTAEncoder, self).__init__(init_cfg=init_cfg) + conv_channels = max( + int(math.ceil(in_channel / scales)), + int(math.floor(in_channel // scales))) + self.conv_channels = conv_channels + self.num_convs = scales if scales == 1 else scales - 1 + + self.conv_modules = ModuleList() + for i in range(self.num_convs): + self.conv_modules.append( + nn.Conv2d( + conv_channels, + conv_channels, + kernel_size=3, + padding=1, + groups=conv_channels)) + + self.pos_embed = PositionEncodingFourier( + embed_dims=in_channel) if use_pos_emb else None + + self.norm_csa = build_norm_layer(norm_cfg, in_channel) + self.gamma_csa = nn.Parameter( + layer_scale_init_value * torch.ones(in_channel), + requires_grad=True) if layer_scale_init_value > 0 else None + self.csa = ChannelMultiheadAttention( + embed_dims=in_channel, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop) + + self.norm = build_norm_layer(norm_cfg, in_channel) + self.pointwise_conv1 = nn.Linear(in_channel, mlp_ratio * in_channel) + self.act = MODELS.build(act_cfg) + self.pointwise_conv2 = nn.Linear(mlp_ratio * in_channel, in_channel) + self.gamma = nn.Parameter( + layer_scale_init_value * torch.ones(in_channel), + requires_grad=True) if layer_scale_init_value > 0 else None + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + spx = torch.split(x, self.conv_channels, dim=1) + for i in range(self.num_convs): + if i == 0: + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.conv_modules[i](sp) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + x = torch.cat((out, spx[self.num_convs]), 1) + + # Channel Self-attention + B, C, H, W = x.shape + x = x.reshape(B, C, H * W).permute(0, 2, 1) + if self.pos_embed: + pos_encoding = self.pos_embed((B, H, W)) + pos_encoding = pos_encoding.reshape(B, -1, + x.shape[1]).permute(0, 2, 1) + x += pos_encoding + + x = x + self.drop_path(self.gamma_csa * self.csa(self.norm_csa(x))) + x = x.reshape(B, H, W, C) + + # Inverted Bottleneck + x = self.norm(x) + x = self.pointwise_conv1(x) + x = self.act(x) + x = self.pointwise_conv2(x) + + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W) + + x = shortcut + self.drop_path(x) + + return x + + +@MODELS.register_module() +class EdgeNeXt(BaseBackbone): + """EdgeNeXt. + + A PyTorch implementation of: `EdgeNeXt: Efficiently Amalgamated + CNN-Transformer Architecture for Mobile Vision Applications + `_ + + Inspiration from + https://github.com/mmaaz60/EdgeNeXt + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architectures in ``EdgeNeXt.arch_settings``. + And if dict, it should include the following keys: + + - channels (list[int]): The number of channels at each stage. + - depths (list[int]): The number of blocks at each stage. + - num_heads (list[int]): The number of heads at each stage. + + Defaults to 'xxsmall'. + in_channels (int): The number of input channels. + Defaults to 3. + global_blocks (list[int]): The number of global blocks. + Defaults to [0, 1, 1, 1]. + global_block_type (list[str]): The type of global blocks. + Defaults to ['None', 'SDTA', 'SDTA', 'SDTA']. 
+ drop_path_rate (float): Stochastic depth dropout rate. + Defaults to 0. + layer_scale_init_value (float): Initial value of layer scale. + Defaults to 1e-6. + linear_pw_conv (bool): Whether to use linear layer to do pointwise + convolution. Defaults to False. + mlp_ratio (int): The number of channel ratio in MLP layers. + Defaults to 4. + conv_kernel_size (list[int]): The kernel size of convolutional layers + at each stage. Defaults to [3, 5, 7, 9]. + use_pos_embd_csa (list[bool]): Whether to use positional embedding in + Channel Self-Attention. Defaults to [False, True, False, False]. + use_pos_emebd_global (bool): Whether to use positional embedding for + whole network. Defaults to False. + d2_scales (list[int]): The number of channel groups used for SDTA at + each stage. Defaults to [2, 2, 3, 4]. + norm_cfg (dict): The config of normalization layer. + Defaults to ``dict(type='LN2d', eps=1e-6)``. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + gap_before_final_norm (bool): Whether to globally average the feature + map before the final norm layer. Defaults to True. + act_cfg (dict): The config of activation layer. + Defaults to ``dict(type='GELU')``. + init_cfg (dict, optional): Config for initialization. + Defaults to None. + """ + arch_settings = { + 'xxsmall': { # parameters: 1.3M + 'channels': [24, 48, 88, 168], + 'depths': [2, 2, 6, 2], + 'num_heads': [4, 4, 4, 4] + }, + 'xsmall': { # parameters: 2.3M + 'channels': [32, 64, 100, 192], + 'depths': [3, 3, 9, 3], + 'num_heads': [4, 4, 4, 4] + }, + 'small': { # parameters: 5.6M + 'channels': [48, 96, 160, 304], + 'depths': [3, 3, 9, 3], + 'num_heads': [8, 8, 8, 8] + }, + 'base': { # parameters: 18.51M + 'channels': [80, 160, 288, 584], + 'depths': [3, 3, 9, 3], + 'num_heads': [8, 8, 8, 8] + }, + } + + def __init__(self, + arch='xxsmall', + in_channels=3, + global_blocks=[0, 1, 1, 1], + global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'], + drop_path_rate=0., + layer_scale_init_value=1e-6, + linear_pw_conv=True, + mlp_ratio=4, + conv_kernel_sizes=[3, 5, 7, 9], + use_pos_embd_csa=[False, True, False, False], + use_pos_embd_global=False, + d2_scales=[2, 2, 3, 4], + norm_cfg=dict(type='LN2d', eps=1e-6), + out_indices=-1, + frozen_stages=0, + gap_before_final_norm=True, + act_cfg=dict(type='GELU'), + init_cfg=None): + super(EdgeNeXt, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in self.arch_settings, \ + f'Arch {arch} is not in default archs ' \ + f'{set(self.arch_settings)}' + self.arch_settings = self.arch_settings[arch] + elif isinstance(arch, dict): + essential_keys = {'channels', 'depths', 'num_heads'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.channels = self.arch_settings['channels'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + self.num_layers = len(self.depths) + self.use_pos_embd_global = use_pos_embd_global + + for g in global_block_type: + assert g in ['None', + 'SDTA'], f'Global block type {g} is not supported' + + self.num_stages = len(self.depths) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} 
instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 4 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.frozen_stages = frozen_stages + self.gap_before_final_norm = gap_before_final_norm + + if self.use_pos_embd_global: + self.pos_embed = PositionEncodingFourier( + embed_dims=self.channels[0]) + else: + self.pos_embed = None + + # stochastic depth decay rule + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + + self.downsample_layers = ModuleList() + stem = nn.Sequential( + nn.Conv2d(in_channels, self.channels[0], kernel_size=4, stride=4), + build_norm_layer(norm_cfg, self.channels[0]), + ) + self.downsample_layers.append(stem) + + self.stages = ModuleList() + block_idx = 0 + for i in range(self.num_stages): + depth = self.depths[i] + channels = self.channels[i] + + if i >= 1: + downsample_layer = nn.Sequential( + build_norm_layer(norm_cfg, self.channels[i - 1]), + nn.Conv2d( + self.channels[i - 1], + channels, + kernel_size=2, + stride=2, + )) + self.downsample_layers.append(downsample_layer) + + stage_blocks = [] + for j in range(depth): + if j > depth - global_blocks[i] - 1: + stage_blocks.append( + SDTAEncoder( + in_channel=channels, + drop_path_rate=dpr[block_idx + j], + mlp_ratio=mlp_ratio, + scales=d2_scales[i], + use_pos_emb=use_pos_embd_csa[i], + num_heads=self.num_heads[i], + )) + else: + dw_conv_cfg = dict( + kernel_size=conv_kernel_sizes[i], + padding=conv_kernel_sizes[i] // 2, + ) + stage_blocks.append( + ConvNeXtBlock( + in_channels=channels, + dw_conv_cfg=dw_conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + linear_pw_conv=linear_pw_conv, + drop_path_rate=dpr[block_idx + j], + layer_scale_init_value=layer_scale_init_value, + )) + block_idx += depth + + stage_blocks = Sequential(*stage_blocks) + self.stages.append(stage_blocks) + + if i in self.out_indices: + out_norm_cfg = dict(type='LN') if self.gap_before_final_norm \ + else norm_cfg + norm_layer = build_norm_layer(out_norm_cfg, channels) + self.add_module(f'norm{i}', norm_layer) + + def init_weights(self) -> None: + # TODO: need to be implemented in the future + return super().init_weights() + + def forward(self, x): + outs = [] + for i, stage in enumerate(self.stages): + x = self.downsample_layers[i](x) + x = stage(x) + if self.pos_embed and i == 0: + B, _, H, W = x.shape + x += self.pos_embed((B, H, W)) + + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + if self.gap_before_final_norm: + gap = x.mean([-2, -1], keepdim=True) + outs.append(norm_layer(gap.flatten(1))) + else: + # The output of LayerNorm2d may be discontiguous, which + # may cause some problem in the downstream tasks + outs.append(norm_layer(x).contiguous()) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + downsample_layer = self.downsample_layers[i] + stage = self.stages[i] + downsample_layer.eval() + stage.eval() + for param in chain(downsample_layer.parameters(), + stage.parameters()): + param.requires_grad = False + + def train(self, mode=True): + super(EdgeNeXt, self).train(mode) + self._freeze_stages() diff --git a/mmpretrain/models/backbones/efficientformer.py b/mmpretrain/models/backbones/efficientformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c2525c8faaa745ff5404e91004421f2360dd1c41 --- /dev/null +++ b/mmpretrain/models/backbones/efficientformer.py @@ -0,0 +1,606 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
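+# Note: EfficientFormer interleaves two block styles: ``Meta4D`` blocks that
+# keep the (B, C, H, W) layout and mix tokens with pooling, and ``Meta3D``
+# blocks that operate on flattened (B, N, C) tokens with attention; a ``Flat``
+# module bridges the two layouts in the last stage (see ``basic_blocks``).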
+import itertools +from typing import Optional, Sequence + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import (ConvModule, DropPath, build_activation_layer, + build_norm_layer) +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmpretrain.registry import MODELS +from ..utils import LayerScale +from .base_backbone import BaseBackbone +from .poolformer import Pooling + + +class AttentionWithBias(BaseModule): + """Multi-head Attention Module with attention_bias. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. Defaults to 8. + key_dim (int): The dimension of q, k. Defaults to 32. + attn_ratio (float): The dimension of v equals to + ``key_dim * attn_ratio``. Defaults to 4. + resolution (int): The height and width of attention_bias. + Defaults to 7. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads=8, + key_dim=32, + attn_ratio=4., + resolution=7, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.attn_ratio = attn_ratio + self.key_dim = key_dim + self.nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + h = self.dh + self.nh_kd * 2 + self.qkv = nn.Linear(embed_dims, h) + self.proj = nn.Linear(self.dh, embed_dims) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N, N)) + + @torch.no_grad() + def train(self, mode=True): + """change the mode of model.""" + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): + """forward function. + + Args: + x (tensor): input features with shape of (B, N, C) + """ + B, N, _ = x.shape + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + q, k, v = qkv.split([self.key_dim, self.key_dim, self.d], dim=-1) + + attn = ((q @ k.transpose(-2, -1)) * self.scale + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab)) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class Flat(nn.Module): + """Flat the input from (B, C, H, W) to (B, H*W, C).""" + + def __init__(self, ): + super().__init__() + + def forward(self, x: torch.Tensor): + x = x.flatten(2).transpose(1, 2) + return x + + +class LinearMlp(BaseModule): + """Mlp implemented with linear. + + The shape of input and output tensor are (B, N, C). + + Args: + in_features (int): Dimension of input features. + hidden_features (int): Dimension of hidden features. + out_features (int): Dimension of output features. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + drop (float): Dropout rate. Defaults to 0.0. 
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_cfg=dict(type='GELU'), + drop=0., + init_cfg=None): + super().__init__(init_cfg=init_cfg) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = build_activation_layer(act_cfg) + self.drop1 = nn.Dropout(drop) + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop2 = nn.Dropout(drop) + + def forward(self, x): + """ + Args: + x (torch.Tensor): input tensor with shape (B, N, C). + + Returns: + torch.Tensor: output tensor with shape (B, N, C). + """ + x = self.drop1(self.act(self.fc1(x))) + x = self.drop2(self.fc2(x)) + return x + + +class ConvMlp(BaseModule): + """Mlp implemented with 1*1 convolutions. + + Args: + in_features (int): Dimension of input features. + hidden_features (int): Dimension of hidden features. + out_features (int): Dimension of output features. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + drop (float): Dropout rate. Defaults to 0.0. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='GELU'), + drop=0., + init_cfg=None): + super().__init__(init_cfg=init_cfg) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.act = build_activation_layer(act_cfg) + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.norm1 = build_norm_layer(norm_cfg, hidden_features)[1] + self.norm2 = build_norm_layer(norm_cfg, out_features)[1] + + self.drop = nn.Dropout(drop) + + def forward(self, x): + """ + Args: + x (torch.Tensor): input tensor with shape (B, C, H, W). + + Returns: + torch.Tensor: output tensor with shape (B, C, H, W). + """ + + x = self.act(self.norm1(self.fc1(x))) + x = self.drop(x) + x = self.norm2(self.fc2(x)) + x = self.drop(x) + return x + + +class Meta3D(BaseModule): + """Meta Former block using 3 dimensions inputs, ``torch.Tensor`` with shape + (B, N, C).""" + + def __init__(self, + dim, + mlp_ratio=4., + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + drop=0., + drop_path=0., + use_layer_scale=True, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.norm1 = build_norm_layer(norm_cfg, dim)[1] + self.token_mixer = AttentionWithBias(dim) + self.norm2 = build_norm_layer(norm_cfg, dim)[1] + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = LinearMlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_cfg=act_cfg, + drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
\
+            else nn.Identity()
+        if use_layer_scale:
+            self.ls1 = LayerScale(dim)
+            self.ls2 = LayerScale(dim)
+        else:
+            self.ls1, self.ls2 = nn.Identity(), nn.Identity()
+
+    def forward(self, x):
+        x = x + self.drop_path(self.ls1(self.token_mixer(self.norm1(x))))
+        x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
+        return x
+
+
+class Meta4D(BaseModule):
+    """MetaFormer block using 4-dimensional inputs, i.e. ``torch.Tensor``
+    with shape (B, C, H, W)."""
+
+    def __init__(self,
+                 dim,
+                 pool_size=3,
+                 mlp_ratio=4.,
+                 act_cfg=dict(type='GELU'),
+                 drop=0.,
+                 drop_path=0.,
+                 use_layer_scale=True,
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
+
+        self.token_mixer = Pooling(pool_size=pool_size)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = ConvMlp(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_cfg=act_cfg,
+            drop=drop)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. \
+            else nn.Identity()
+        if use_layer_scale:
+            self.ls1 = LayerScale(dim, data_format='channels_first')
+            self.ls2 = LayerScale(dim, data_format='channels_first')
+        else:
+            self.ls1, self.ls2 = nn.Identity(), nn.Identity()
+
+    def forward(self, x):
+        x = x + self.drop_path(self.ls1(self.token_mixer(x)))
+        x = x + self.drop_path(self.ls2(self.mlp(x)))
+        return x
+
+
+def basic_blocks(in_channels,
+                 out_channels,
+                 index,
+                 layers,
+                 pool_size=3,
+                 mlp_ratio=4.,
+                 act_cfg=dict(type='GELU'),
+                 drop_rate=.0,
+                 drop_path_rate=0.,
+                 use_layer_scale=True,
+                 vit_num=1,
+                 has_downsamper=False):
+    """Generate EfficientFormer blocks for a stage."""
+    blocks = []
+    if has_downsamper:
+        blocks.append(
+            ConvModule(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=3,
+                stride=2,
+                padding=1,
+                bias=True,
+                norm_cfg=dict(type='BN'),
+                act_cfg=None))
+    if index == 3 and vit_num == layers[index]:
+        blocks.append(Flat())
+    for block_idx in range(layers[index]):
+        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (
+            sum(layers) - 1)
+        if index == 3 and layers[index] - block_idx <= vit_num:
+            blocks.append(
+                Meta3D(
+                    out_channels,
+                    mlp_ratio=mlp_ratio,
+                    act_cfg=act_cfg,
+                    drop=drop_rate,
+                    drop_path=block_dpr,
+                    use_layer_scale=use_layer_scale,
+                ))
+        else:
+            blocks.append(
+                Meta4D(
+                    out_channels,
+                    pool_size=pool_size,
+                    act_cfg=act_cfg,
+                    drop=drop_rate,
+                    drop_path=block_dpr,
+                    use_layer_scale=use_layer_scale))
+            if index == 3 and layers[index] - block_idx - 1 == vit_num:
+                blocks.append(Flat())
+    blocks = nn.Sequential(*blocks)
+    return blocks
+
+
+@MODELS.register_module()
+class EfficientFormer(BaseBackbone):
+    """EfficientFormer.
+
+    A PyTorch implementation of EfficientFormer introduced by:
+    `EfficientFormer: Vision Transformers at MobileNet Speed
+    <https://arxiv.org/abs/2206.01191>`_
+
+    Modified from the `official repo
+    <https://github.com/snap-research/EfficientFormer>`_.
+
+    Args:
+        arch (str | dict): The model's architecture. If string, it should be
+            one of architecture in ``EfficientFormer.arch_settings``. And if
+            dict, it should include the following 4 keys:
+
+            - layers (list[int]): Number of blocks at each stage.
+            - embed_dims (list[int]): The number of channels at each stage.
+            - downsamples (list[bool]): Whether to downsample in each of the
+              four stages.
+            - vit_num (int): The num of vit blocks in the last stage.
+
+            Defaults to 'l1'.
+
+        in_channels (int): The num of input channels. Defaults to 3.
+        pool_size (int): The pooling size of ``Meta4D`` blocks. Defaults to 3.
+        mlp_ratios (int): The expansion ratio of the MLP in ``Meta3D`` blocks.
+            Defaults to 4.
+ reshape_last_feat (bool): Whether to reshape the feature map from + (B, N, C) to (B, C, H, W) in the last stage, when the ``vit-num`` + in ``arch`` is not 0. Defaults to False. Usually set to True + in downstream tasks. + out_indices (Sequence[int]): Output from which stages. + Defaults to -1. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + drop_rate (float): Dropout rate. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + use_layer_scale (bool): Whether to use use_layer_scale in MetaFormer + block. Defaults to True. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + + Example: + >>> from mmpretrain.models import EfficientFormer + >>> import torch + >>> inputs = torch.rand((1, 3, 224, 224)) + >>> # build EfficientFormer backbone for classification task + >>> model = EfficientFormer(arch="l1") + >>> model.eval() + >>> level_outputs = model(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 448, 49) + >>> # build EfficientFormer backbone for downstream task + >>> model = EfficientFormer( + >>> arch="l3", + >>> out_indices=(0, 1, 2, 3), + >>> reshape_last_feat=True) + >>> model.eval() + >>> level_outputs = model(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 56, 56) + (1, 128, 28, 28) + (1, 320, 14, 14) + (1, 512, 7, 7) + """ # noqa: E501 + + # --layers: [x,x,x,x], numbers of layers for the four stages + # --embed_dims: [x,x,x,x], embedding dims for the four stages + # --downsamples: [x,x,x,x], has downsample or not in the four stages + # --vit_num:(int), the num of vit blocks in the last stage + arch_settings = { + 'l1': { + 'layers': [3, 2, 6, 4], + 'embed_dims': [48, 96, 224, 448], + 'downsamples': [False, True, True, True], + 'vit_num': 1, + }, + 'l3': { + 'layers': [4, 4, 12, 6], + 'embed_dims': [64, 128, 320, 512], + 'downsamples': [False, True, True, True], + 'vit_num': 4, + }, + 'l7': { + 'layers': [6, 6, 18, 8], + 'embed_dims': [96, 192, 384, 768], + 'downsamples': [False, True, True, True], + 'vit_num': 8, + }, + } + + def __init__(self, + arch='l1', + in_channels=3, + pool_size=3, + mlp_ratios=4, + reshape_last_feat=False, + out_indices=-1, + frozen_stages=-1, + act_cfg=dict(type='GELU'), + drop_rate=0., + drop_path_rate=0., + use_layer_scale=True, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + self.num_extra_tokens = 0 # no cls_token, no dist_token + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + default_keys = set(self.arch_settings['l1'].keys()) + assert set(arch.keys()) == default_keys, \ + f'The arch dict must have {default_keys}, ' \ + f'but got {list(arch.keys())}.' + + self.layers = arch['layers'] + self.embed_dims = arch['embed_dims'] + self.downsamples = arch['downsamples'] + assert isinstance(self.layers, list) and isinstance( + self.embed_dims, list) and isinstance(self.downsamples, list) + assert len(self.layers) == len(self.embed_dims) == len( + self.downsamples) + + self.vit_num = arch['vit_num'] + self.reshape_last_feat = reshape_last_feat + + assert self.vit_num >= 0, "'vit_num' must be an integer " \ + 'greater than or equal to 0.' 
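+        # e.g. for arch 'l1', layers=[3, 2, 6, 4] and vit_num=1: only the
+        # last of the 4 blocks in the final stage is a Meta3D attention
+        # block, and a Flat() is inserted right before it (see basic_blocks).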
+ assert self.vit_num <= self.layers[-1], ( + "'vit_num' must be an integer smaller than layer number") + + self._make_stem(in_channels, self.embed_dims[0]) + + # set the main block in network + network = [] + for i in range(len(self.layers)): + if i != 0: + in_channels = self.embed_dims[i - 1] + else: + in_channels = self.embed_dims[i] + out_channels = self.embed_dims[i] + stage = basic_blocks( + in_channels, + out_channels, + i, + self.layers, + pool_size=pool_size, + mlp_ratio=mlp_ratios, + act_cfg=act_cfg, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + vit_num=self.vit_num, + use_layer_scale=use_layer_scale, + has_downsamper=self.downsamples[i]) + network.append(stage) + + self.network = ModuleList(network) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 4 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + + self.out_indices = out_indices + for i_layer in self.out_indices: + if not self.reshape_last_feat and \ + i_layer == 3 and self.vit_num > 0: + layer = build_norm_layer( + dict(type='LN'), self.embed_dims[i_layer])[1] + else: + # use GN with 1 group as channel-first LN2D + layer = build_norm_layer( + dict(type='GN', num_groups=1), self.embed_dims[i_layer])[1] + + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self.frozen_stages = frozen_stages + self._freeze_stages() + + def _make_stem(self, in_channels: int, stem_channels: int): + """make 2-ConvBNReLu stem layer.""" + self.patch_embed = Sequential( + ConvModule( + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + inplace=True), + ConvModule( + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=2, + padding=1, + bias=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + inplace=True)) + + def forward_tokens(self, x): + outs = [] + for idx, block in enumerate(self.network): + if idx == len(self.network) - 1: + N, _, H, W = x.shape + if self.downsamples[idx]: + H, W = H // 2, W // 2 + x = block(x) + if idx in self.out_indices: + norm_layer = getattr(self, f'norm{idx}') + + if idx == len(self.network) - 1 and x.dim() == 3: + # when ``vit-num`` > 0 and in the last stage, + # if `self.reshape_last_feat`` is True, reshape the + # features to `BCHW` format before the final normalization. + # if `self.reshape_last_feat`` is False, do + # normalization directly and permute the features to `BCN`. + if self.reshape_last_feat: + x = x.permute((0, 2, 1)).reshape(N, -1, H, W) + x_out = norm_layer(x) + else: + x_out = norm_layer(x).permute((0, 2, 1)) + else: + x_out = norm_layer(x) + + outs.append(x_out.contiguous()) + return tuple(outs) + + def forward(self, x): + # input embedding + x = self.patch_embed(x) + # through stages + x = self.forward_tokens(x) + return x + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(self.frozen_stages): + # Include both block and downsample layer. 
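+                # Each entry of self.network is one stage; when a stage has a
+                # downsample layer, it is the first module inside that
+                # Sequential (see basic_blocks), so freezing the stage
+                # freezes it as well.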
+ module = self.network[i] + module.eval() + for param in module.parameters(): + param.requires_grad = False + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + norm_layer.eval() + for param in norm_layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(EfficientFormer, self).train(mode) + self._freeze_stages() diff --git a/mmpretrain/models/backbones/efficientnet.py b/mmpretrain/models/backbones/efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9ec7ee81186610f7adb8af92325471d794509ddc --- /dev/null +++ b/mmpretrain/models/backbones/efficientnet.py @@ -0,0 +1,410 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn.bricks import ConvModule, DropPath +from mmengine.model import BaseModule, Sequential + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.models.utils import InvertedResidual, SELayer, make_divisible +from mmpretrain.registry import MODELS + + +class EdgeResidual(BaseModule): + """Edge Residual Block. + + Args: + in_channels (int): The input channels of this module. + out_channels (int): The output channels of this module. + mid_channels (int): The input channels of the second convolution. + kernel_size (int): The kernel size of the first convolution. + Defaults to 3. + stride (int): The stride of the first convolution. Defaults to 1. + se_cfg (dict, optional): Config dict for se layer. Defaults to None, + which means no se layer. + with_residual (bool): Use residual connection. Defaults to True. + conv_cfg (dict, optional): Config dict for convolution layer. + Defaults to None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='ReLU')``. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict | list[dict], optional): Initialization config dict. 
+ """ + + def __init__(self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + stride=1, + se_cfg=None, + with_residual=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + drop_path_rate=0., + with_cp=False, + init_cfg=None): + super(EdgeResidual, self).__init__(init_cfg=init_cfg) + assert stride in [1, 2] + self.with_cp = with_cp + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0 else nn.Identity() + self.with_se = se_cfg is not None + self.with_residual = ( + stride == 1 and in_channels == out_channels and with_residual) + + if self.with_se: + assert isinstance(se_cfg, dict) + + self.conv1 = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + if self.with_se: + self.se = SELayer(**se_cfg) + + self.conv2 = ConvModule( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, x): + + def _inner_forward(x): + out = x + out = self.conv1(out) + + if self.with_se: + out = self.se(out) + + out = self.conv2(out) + + if self.with_residual: + return x + self.drop_path(out) + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +def model_scaling(layer_setting, arch_setting): + """Scaling operation to the layer's parameters according to the + arch_setting.""" + # scale width + new_layer_setting = copy.deepcopy(layer_setting) + for layer_cfg in new_layer_setting: + for block_cfg in layer_cfg: + block_cfg[1] = make_divisible(block_cfg[1] * arch_setting[0], 8) + + # scale depth + split_layer_setting = [new_layer_setting[0]] + for layer_cfg in new_layer_setting[1:-1]: + tmp_index = [0] + for i in range(len(layer_cfg) - 1): + if layer_cfg[i + 1][1] != layer_cfg[i][1]: + tmp_index.append(i + 1) + tmp_index.append(len(layer_cfg)) + for i in range(len(tmp_index) - 1): + split_layer_setting.append(layer_cfg[tmp_index[i]:tmp_index[i + + 1]]) + split_layer_setting.append(new_layer_setting[-1]) + + num_of_layers = [len(layer_cfg) for layer_cfg in split_layer_setting[1:-1]] + new_layers = [ + int(math.ceil(arch_setting[1] * num)) for num in num_of_layers + ] + + merge_layer_setting = [split_layer_setting[0]] + for i, layer_cfg in enumerate(split_layer_setting[1:-1]): + if new_layers[i] <= num_of_layers[i]: + tmp_layer_cfg = layer_cfg[:new_layers[i]] + else: + tmp_layer_cfg = copy.deepcopy(layer_cfg) + [layer_cfg[-1]] * ( + new_layers[i] - num_of_layers[i]) + if tmp_layer_cfg[0][3] == 1 and i != 0: + merge_layer_setting[-1] += tmp_layer_cfg.copy() + else: + merge_layer_setting.append(tmp_layer_cfg.copy()) + merge_layer_setting.append(split_layer_setting[-1]) + + return merge_layer_setting + + +@MODELS.register_module() +class EfficientNet(BaseBackbone): + """EfficientNet backbone. + + Args: + arch (str): Architecture of efficientnet. Defaults to b0. + out_indices (Sequence[int]): Output from which stages. + Defaults to (6, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Defaults to None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Defaults to dict(type='BN'). + act_cfg (dict): Config dict for activation layer. 
+ Defaults to dict(type='Swish'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + """ + + # Parameters to build layers. + # 'b' represents the architecture of normal EfficientNet family includes + # 'b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8'. + # 'e' represents the architecture of EfficientNet-EdgeTPU including 'es', + # 'em', 'el'. + # 6 parameters are needed to construct a layer, From left to right: + # - kernel_size: The kernel size of the block + # - out_channel: The number of out_channels of the block + # - se_ratio: The sequeeze ratio of SELayer. + # - stride: The stride of the block + # - expand_ratio: The expand_ratio of the mid_channels + # - block_type: -1: Not a block, 0: InvertedResidual, 1: EdgeResidual + layer_settings = { + 'b': [[[3, 32, 0, 2, 0, -1]], + [[3, 16, 4, 1, 1, 0]], + [[3, 24, 4, 2, 6, 0], + [3, 24, 4, 1, 6, 0]], + [[5, 40, 4, 2, 6, 0], + [5, 40, 4, 1, 6, 0]], + [[3, 80, 4, 2, 6, 0], + [3, 80, 4, 1, 6, 0], + [3, 80, 4, 1, 6, 0], + [5, 112, 4, 1, 6, 0], + [5, 112, 4, 1, 6, 0], + [5, 112, 4, 1, 6, 0]], + [[5, 192, 4, 2, 6, 0], + [5, 192, 4, 1, 6, 0], + [5, 192, 4, 1, 6, 0], + [5, 192, 4, 1, 6, 0], + [3, 320, 4, 1, 6, 0]], + [[1, 1280, 0, 1, 0, -1]] + ], + 'e': [[[3, 32, 0, 2, 0, -1]], + [[3, 24, 0, 1, 3, 1]], + [[3, 32, 0, 2, 8, 1], + [3, 32, 0, 1, 8, 1]], + [[3, 48, 0, 2, 8, 1], + [3, 48, 0, 1, 8, 1], + [3, 48, 0, 1, 8, 1], + [3, 48, 0, 1, 8, 1]], + [[5, 96, 0, 2, 8, 0], + [5, 96, 0, 1, 8, 0], + [5, 96, 0, 1, 8, 0], + [5, 96, 0, 1, 8, 0], + [5, 96, 0, 1, 8, 0], + [5, 144, 0, 1, 8, 0], + [5, 144, 0, 1, 8, 0], + [5, 144, 0, 1, 8, 0], + [5, 144, 0, 1, 8, 0]], + [[5, 192, 0, 2, 8, 0], + [5, 192, 0, 1, 8, 0]], + [[1, 1280, 0, 1, 0, -1]] + ] + } # yapf: disable + + # Parameters to build different kinds of architecture. + # From left to right: scaling factor for width, scaling factor for depth, + # resolution. + arch_settings = { + 'b0': (1.0, 1.0, 224), + 'b1': (1.0, 1.1, 240), + 'b2': (1.1, 1.2, 260), + 'b3': (1.2, 1.4, 300), + 'b4': (1.4, 1.8, 380), + 'b5': (1.6, 2.2, 456), + 'b6': (1.8, 2.6, 528), + 'b7': (2.0, 3.1, 600), + 'b8': (2.2, 3.6, 672), + 'l2': (4.3, 5.3, 800), + 'es': (1.0, 1.0, 224), + 'em': (1.0, 1.1, 240), + 'el': (1.2, 1.4, 300) + } + + def __init__(self, + arch='b0', + drop_path_rate=0., + out_indices=(6, ), + frozen_stages=0, + conv_cfg=dict(type='Conv2dAdaptivePadding'), + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='Swish'), + norm_eval=False, + with_cp=False, + init_cfg=[ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + layer=['_BatchNorm', 'GroupNorm'], + val=1) + ]): + super(EfficientNet, self).__init__(init_cfg) + assert arch in self.arch_settings, \ + f'"{arch}" is not one of the arch_settings ' \ + f'({", ".join(self.arch_settings.keys())})' + self.arch_setting = self.arch_settings[arch] + # layer_settings of arch='l2' is 'b' + self.layer_setting = self.layer_settings['b' if arch == + 'l2' else arch[:1]] + for index in out_indices: + if index not in range(0, len(self.layer_setting)): + raise ValueError('the item in out_indices must in ' + f'range(0, {len(self.layer_setting)}). 
' + f'But received {index}') + + if frozen_stages not in range(len(self.layer_setting) + 1): + raise ValueError('frozen_stages must be in range(0, ' + f'{len(self.layer_setting) + 1}). ' + f'But received {frozen_stages}') + self.drop_path_rate = drop_path_rate + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.layer_setting = model_scaling(self.layer_setting, + self.arch_setting) + block_cfg_0 = self.layer_setting[0][0] + block_cfg_last = self.layer_setting[-1][0] + self.in_channels = make_divisible(block_cfg_0[1], 8) + self.out_channels = block_cfg_last[1] + self.layers = nn.ModuleList() + self.layers.append( + ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=block_cfg_0[0], + stride=block_cfg_0[3], + padding=block_cfg_0[0] // 2, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.make_layer() + self.layers.append( + ConvModule( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=block_cfg_last[0], + stride=block_cfg_last[3], + padding=block_cfg_last[0] // 2, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + def make_layer(self): + # Without the first and the final conv block. + layer_setting = self.layer_setting[1:-1] + + total_num_blocks = sum([len(x) for x in layer_setting]) + block_idx = 0 + dpr = [ + x.item() + for x in torch.linspace(0, self.drop_path_rate, total_num_blocks) + ] # stochastic depth decay rule + + for layer_cfg in layer_setting: + layer = [] + for i, block_cfg in enumerate(layer_cfg): + (kernel_size, out_channels, se_ratio, stride, expand_ratio, + block_type) = block_cfg + + mid_channels = int(self.in_channels * expand_ratio) + out_channels = make_divisible(out_channels, 8) + if se_ratio <= 0: + se_cfg = None + else: + se_cfg = dict( + channels=mid_channels, + ratio=expand_ratio * se_ratio, + divisor=1, + act_cfg=(self.act_cfg, dict(type='Sigmoid'))) + if block_type == 1: # edge tpu + if i > 0 and expand_ratio == 3: + with_residual = False + expand_ratio = 4 + else: + with_residual = True + mid_channels = int(self.in_channels * expand_ratio) + if se_cfg is not None: + se_cfg = dict( + channels=mid_channels, + ratio=se_ratio * expand_ratio, + divisor=1, + act_cfg=(self.act_cfg, dict(type='Sigmoid'))) + block = partial(EdgeResidual, with_residual=with_residual) + else: + block = InvertedResidual + layer.append( + block( + in_channels=self.in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + drop_path_rate=dpr[block_idx], + with_cp=self.with_cp)) + self.in_channels = out_channels + block_idx += 1 + self.layers.append(Sequential(*layer)) + + def forward(self, x): + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(EfficientNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() diff --git a/mmpretrain/models/backbones/efficientnet_v2.py 
b/mmpretrain/models/backbones/efficientnet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..fec002a4dac46f756f00ed8f596b37028ba18c37 --- /dev/null +++ b/mmpretrain/models/backbones/efficientnet_v2.py @@ -0,0 +1,343 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import ConvModule, DropPath +from mmengine.model import Sequential +from torch import Tensor + +from mmpretrain.registry import MODELS +from ..utils import InvertedResidual as MBConv +from .base_backbone import BaseBackbone +from .efficientnet import EdgeResidual as FusedMBConv + + +class EnhancedConvModule(ConvModule): + """ConvModule with short-cut and droppath. + + Args: + in_channels (int): Number of channels in the input feature map. + Same as that in ``nn._ConvNd``. + out_channels (int): Number of channels produced by the convolution. + Same as that in ``nn._ConvNd``. + kernel_size (int | tuple[int]): Size of the convolving kernel. + Same as that in ``nn._ConvNd``. + stride (int | tuple[int]): Stride of the convolution. + Same as that in ``nn._ConvNd``. + has_skip (bool): Whether there is short-cut. Defaults to False. + drop_path_rate (float): Stochastic depth rate. Default 0.0. + padding (int | tuple[int]): Zero-padding added to both sides of + the input. Same as that in ``nn._ConvNd``. + dilation (int | tuple[int]): Spacing between kernel elements. + Same as that in ``nn._ConvNd``. + groups (int): Number of blocked connections from input channels to + output channels. Same as that in ``nn._ConvNd``. + bias (bool | str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise + False. Default: "auto". + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + inplace (bool): Whether to use inplace mode for activation. + Default: True. + with_spectral_norm (bool): Whether use spectral norm in conv module. + Default: False. + padding_mode (str): If the `padding_mode` has not been supported by + current `Conv2d` in PyTorch, we will use our own padding layer + instead. Currently, we support ['zeros', 'circular'] with official + implementation and ['reflect'] with our own implementation. + Default: 'zeros'. + order (tuple[str]): The order of conv/norm/activation layers. It is a + sequence of "conv", "norm" and "act". Common examples are + ("conv", "norm", "act") and ("act", "conv", "norm"). + Default: ('conv', 'norm', 'act'). + """ + + def __init__(self, *args, has_skip=False, drop_path_rate=0, **kwargs): + super().__init__(*args, **kwargs) + self.has_skip = has_skip + if self.has_skip and (self.in_channels != self.out_channels + or self.stride != (1, 1)): + raise ValueError('the stride must be 1 and the `in_channels` and' + ' `out_channels` must be the same , when ' + '`has_skip` is True in `EnhancedConvModule` .') + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate else nn.Identity() + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + short_cut = x + x = super().forward(x, **kwargs) + if self.has_skip: + x = self.drop_path(x) + short_cut + return x + + +@MODELS.register_module() +class EfficientNetV2(BaseBackbone): + """EfficientNetV2 backbone. 
+
+    A PyTorch implementation of EfficientNetV2 introduced by:
+    `EfficientNetV2: Smaller Models and Faster Training
+    <https://arxiv.org/abs/2104.00298>`_
+
+    Args:
+        arch (str): Architecture of efficientnetv2. Defaults to s.
+        in_channels (int): Number of input image channels. Defaults to 3.
+        drop_path_rate (float): The ratio of the stochastic depth.
+            Defaults to 0.0.
+        out_indices (Sequence[int]): Output from which stages.
+            Defaults to (-1, ).
+        frozen_stages (int): Stages to be frozen (all param fixed).
+            Defaults to 0, which means not freezing any parameters.
+        conv_cfg (dict): Config dict for convolution layer.
+            Defaults to None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to dict(type='BN').
+        act_cfg (dict): Config dict for activation layer.
+            Defaults to dict(type='Swish').
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Defaults to False.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Defaults to False.
+    """
+
+    # Parameters to build layers. From left to right:
+    # - repeat (int): The repeat number of the block in the layer
+    # - kernel_size (int): The kernel size of the layer
+    # - stride (int): The stride of the first block of the layer
+    # - expand_ratio (int, float): The expand_ratio of the mid_channels
+    # - in_channel (int): The number of in_channels of the layer
+    # - out_channel (int): The number of out_channels of the layer
+    # - se_ratio (float): The squeeze ratio of SELayer.
+    # - block_type (int): -2: ConvModule, -1: EnhancedConvModule,
+    #   0: FusedMBConv, 1: MBConv
+    arch_settings = {
+        **dict.fromkeys(['small', 's'], [[2, 3, 1, 1, 24, 24, 0.0, -1],
+                                         [4, 3, 2, 4, 24, 48, 0.0, 0],
+                                         [4, 3, 2, 4, 48, 64, 0.0, 0],
+                                         [6, 3, 2, 4, 64, 128, 0.25, 1],
+                                         [9, 3, 1, 6, 128, 160, 0.25, 1],
+                                         [15, 3, 2, 6, 160, 256, 0.25, 1],
+                                         [1, 1, 1, 1, 256, 1280, 0.0, -2]]),
+        **dict.fromkeys(['m', 'medium'], [[3, 3, 1, 1, 24, 24, 0.0, -1],
+                                          [5, 3, 2, 4, 24, 48, 0.0, 0],
+                                          [5, 3, 2, 4, 48, 80, 0.0, 0],
+                                          [7, 3, 2, 4, 80, 160, 0.25, 1],
+                                          [14, 3, 1, 6, 160, 176, 0.25, 1],
+                                          [18, 3, 2, 6, 176, 304, 0.25, 1],
+                                          [5, 3, 1, 6, 304, 512, 0.25, 1],
+                                          [1, 1, 1, 1, 512, 1280, 0.0, -2]]),
+        **dict.fromkeys(['l', 'large'], [[4, 3, 1, 1, 32, 32, 0.0, -1],
+                                         [7, 3, 2, 4, 32, 64, 0.0, 0],
+                                         [7, 3, 2, 4, 64, 96, 0.0, 0],
+                                         [10, 3, 2, 4, 96, 192, 0.25, 1],
+                                         [19, 3, 1, 6, 192, 224, 0.25, 1],
+                                         [25, 3, 2, 6, 224, 384, 0.25, 1],
+                                         [7, 3, 1, 6, 384, 640, 0.25, 1],
+                                         [1, 1, 1, 1, 640, 1280, 0.0, -2]]),
+        **dict.fromkeys(['xl'], [[4, 3, 1, 1, 32, 32, 0.0, -1],
+                                 [8, 3, 2, 4, 32, 64, 0.0, 0],
+                                 [8, 3, 2, 4, 64, 96, 0.0, 0],
+                                 [16, 3, 2, 4, 96, 192, 0.25, 1],
+                                 [24, 3, 1, 6, 192, 256, 0.25, 1],
+                                 [32, 3, 2, 6, 256, 512, 0.25, 1],
+                                 [8, 3, 1, 6, 512, 640, 0.25, 1],
+                                 [1, 1, 1, 1, 640, 1280, 0.0, -2]]),
+        **dict.fromkeys(['b0'], [[1, 3, 1, 1, 32, 16, 0.0, -1],
+                                 [2, 3, 2, 4, 16, 32, 0.0, 0],
+                                 [2, 3, 2, 4, 32, 48, 0.0, 0],
+                                 [3, 3, 2, 4, 48, 96, 0.25, 1],
+                                 [5, 3, 1, 6, 96, 112, 0.25, 1],
+                                 [8, 3, 2, 6, 112, 192, 0.25, 1],
+                                 [1, 1, 1, 1, 192, 1280, 0.0, -2]]),
+        **dict.fromkeys(['b1'], [[2, 3, 1, 1, 32, 16, 0.0, -1],
+                                 [3, 3, 2, 4, 16, 32, 0.0, 0],
+                                 [3, 3, 2, 4, 32, 48, 0.0, 0],
+                                 [4, 3, 2, 4, 48, 96, 0.25, 1],
+                                 [6, 3, 1, 6, 96, 112, 0.25, 1],
+                                 [9, 3, 2, 6, 112, 192, 0.25, 1],
+                                 [1, 1, 1, 1, 192, 1280, 0.0, -2]]),
+        **dict.fromkeys(['b2'], [[2, 3, 1, 1, 32, 16, 0.0, -1],
+                                 [3, 3, 2, 4, 16, 32, 0.0, 0],
+                                 [3, 3, 2, 4, 32, 56, 0.0, 0],
+                                 [4, 3, 2, 4, 56, 104, 0.25, 1],
+                                 [6, 3, 1, 6, 104, 120, 0.25, 1],
+                                 [10, 3, 2, 6, 120, 208, 0.25, 1],
+                                 [1, 1, 1, 1, 208, 1408, 0.0, -2]]),
+        **dict.fromkeys(['b3'], [[2, 3, 1, 1, 40, 16, 0.0, -1],
+                                 [3, 3, 2, 4, 16, 40, 0.0, 0],
+                                 [3, 3, 2, 4, 40, 56, 0.0, 0],
+                                 [5, 3, 2, 4, 56, 112, 0.25, 1],
+                                 [7, 3, 1, 6, 112, 136, 0.25, 1],
+                                 [12, 3, 2, 6, 136, 232, 0.25, 1],
+                                 [1, 1, 1, 1, 232, 1536, 0.0, -2]])
+    }
+
+    def __init__(self,
+                 arch: str = 's',
+                 in_channels: int = 3,
+                 drop_path_rate: float = 0.,
+                 out_indices: Sequence[int] = (-1, ),
+                 frozen_stages: int = 0,
+                 conv_cfg=dict(type='Conv2dAdaptivePadding'),
+                 norm_cfg=dict(type='BN', eps=1e-3, momentum=0.1),
+                 act_cfg=dict(type='Swish'),
+                 norm_eval: bool = False,
+                 with_cp: bool = False,
+                 init_cfg=[
+                     dict(type='Kaiming', layer='Conv2d'),
+                     dict(
+                         type='Constant',
+                         layer=['_BatchNorm', 'GroupNorm'],
+                         val=1)
+                 ]):
+        super(EfficientNetV2, self).__init__(init_cfg)
+        assert arch in self.arch_settings, \
+            f'"{arch}" is not one of the arch_settings ' \
+            f'({", ".join(self.arch_settings.keys())})'
+        self.arch = self.arch_settings[arch]
+        if frozen_stages not in range(len(self.arch) + 1):
+            raise ValueError('frozen_stages must be in range(0, '
+                             f'{len(self.arch) + 1}), '
+                             f'but got {frozen_stages}')
+        self.drop_path_rate = drop_path_rate
+        self.frozen_stages = frozen_stages
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+
+        self.layers = nn.ModuleList()
+        assert self.arch[-1][-1] == -2, \
+            f'the last block_type of `arch_setting` must be -2, ' \
+            f'but got `{self.arch[-1][-1]}`'
+        self.in_channels = in_channels
+        self.out_channels = self.arch[-1][5]
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+
+        self.make_layers()
+
+        # There are len(self.arch) + 1 layers in the backbone:
+        # the first conv + (len(self.arch) - 1) stages + the last conv
+        if isinstance(out_indices, int):
+            out_indices = [out_indices]
+        assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+            f'got {type(out_indices)} instead.'
+        out_indices = list(out_indices)
+        for i, index in enumerate(out_indices):
+            if index < 0:
+                out_indices[i] = len(self.layers) + index
+            assert 0 <= out_indices[i] <= len(self.layers), \
+                f'Invalid out_indices {index}.'
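+        # e.g. arch='s': self.layers holds 1 stem conv + 6 stages + 1 final
+        # 1280-channel conv, so out_indices=(-1, ) selects the final conv
+        # features of shape (B, 1280, H/32, W/32) for a (B, 3, H, W) input.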
+ self.out_indices = out_indices + + def make_layers(self, ): + # make the first layer + self.layers.append( + ConvModule( + in_channels=self.in_channels, + out_channels=self.arch[0][4], + kernel_size=3, + stride=2, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + in_channels = self.arch[0][4] + layer_setting = self.arch[:-1] + + total_num_blocks = sum([x[0] for x in layer_setting]) + block_idx = 0 + dpr = [ + x.item() + for x in torch.linspace(0, self.drop_path_rate, total_num_blocks) + ] # stochastic depth decay rule + + for layer_cfg in layer_setting: + layer = [] + (repeat, kernel_size, stride, expand_ratio, _, out_channels, + se_ratio, block_type) = layer_cfg + for i in range(repeat): + stride = stride if i == 0 else 1 + if block_type == -1: + has_skip = stride == 1 and in_channels == out_channels + droppath_rate = dpr[block_idx] if has_skip else 0.0 + layer.append( + EnhancedConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + has_skip=has_skip, + drop_path_rate=droppath_rate, + stride=stride, + padding=1, + conv_cfg=None, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + in_channels = out_channels + else: + mid_channels = int(in_channels * expand_ratio) + se_cfg = None + if block_type != 0 and se_ratio > 0: + se_cfg = dict( + channels=mid_channels, + ratio=expand_ratio * (1.0 / se_ratio), + divisor=1, + act_cfg=(self.act_cfg, dict(type='Sigmoid'))) + block = FusedMBConv if block_type == 0 else MBConv + conv_cfg = self.conv_cfg if stride == 2 else None + layer.append( + block( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + conv_cfg=conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + drop_path_rate=dpr[block_idx], + with_cp=self.with_cp)) + in_channels = out_channels + block_idx += 1 + self.layers.append(Sequential(*layer)) + + # make the last layer + self.layers.append( + ConvModule( + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=self.arch[-1][1], + stride=self.arch[-1][2], + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + def forward(self, x: Tensor) -> Tuple[Tensor]: + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(EfficientNetV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() diff --git a/mmpretrain/models/backbones/hivit.py b/mmpretrain/models/backbones/hivit.py new file mode 100644 index 0000000000000000000000000000000000000000..981cbf819138ace2c2e8441e7e65f927883c55fd --- /dev/null +++ b/mmpretrain/models/backbones/hivit.py @@ -0,0 +1,656 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import DropPath +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer, to_2tuple +from .base_backbone import BaseBackbone + + +class Mlp(nn.Module): + """MLP block. + + Args: + in_features (int): Number of input dims. + hidden_features (int): Number of hidden dims. + out_feature (int): Number of out dims. 
+        act_layer: MLP activation layer.
+        drop (float): MLP dropout rate.
+    """
+
+    def __init__(self,
+                 in_features,
+                 hidden_features=None,
+                 out_features=None,
+                 act_layer=nn.GELU,
+                 drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    """Multi-head self-attention with optional relative position bias.
+
+    Args:
+        input_size (int): Input size.
+        dim (int): Number of input dims.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool): Enable bias for qkv projections if True.
+        qk_scale (float): The scale factor applied to ``q @ k``. Defaults
+            to None, which falls back to ``head_dim ** -0.5``.
+        attn_drop (float): The dropout rate for attention output weights.
+            Defaults to 0.
+        proj_drop (float): Probability of an element to be zeroed
+            after the feed forward layer. Defaults to 0.
+        rpe (bool): If True, add a relative position bias to the
+            attention weights.
+    """
+
+    def __init__(self,
+                 input_size,
+                 dim,
+                 num_heads,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 rpe=True):
+        super().__init__()
+        self.input_size = input_size
+        self.dim = dim
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        # define a parameter table of relative position bias
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * input_size - 1) *
+                        (2 * input_size - 1), num_heads)) if rpe else None
+        if rpe:
+            coords_h = torch.arange(input_size)
+            coords_w = torch.arange(input_size)
+            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
+            coords_flatten = torch.flatten(coords, 1)
+            relative_coords = coords_flatten[:, :,
+                                             None] - coords_flatten[:, None, :]
+            relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+            relative_coords[:, :, 0] += input_size - 1
+            relative_coords[:, :, 1] += input_size - 1
+            relative_coords[:, :, 0] *= 2 * input_size - 1
+            relative_position_index = relative_coords.sum(-1)
+            self.register_buffer('relative_position_index',
+                                 relative_position_index)
+
+            trunc_normal_(self.relative_position_bias_table, std=.02)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, rpe_index=None, mask=None):
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
+                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[
+            2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        if rpe_index is not None:
+            rpe_index = self.relative_position_index.view(-1)
+            S = int(math.sqrt(rpe_index.size(-1)))
+            relative_position_bias = self.relative_position_bias_table[
+                rpe_index].view(-1, S, S, self.num_heads)
+            relative_position_bias = relative_position_bias.permute(
+                0, 3, 1, 2).contiguous()
+            attn = attn + relative_position_bias
+        if mask is not None:
+            mask = mask.bool()
+            attn = attn.masked_fill(~mask[:, None, None, :], float('-inf'))
+        attn = self.softmax(attn)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class BlockWithRPE(nn.Module):
+    """HiViT block.
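The relative-position bookkeeping in `Attention.__init__` is easiest to see on a toy grid; a self-contained sketch for a 2x2 token layout:

import torch

input_size = 2
coords = torch.stack(torch.meshgrid(
    torch.arange(input_size), torch.arange(input_size), indexing='ij'))
coords_flatten = torch.flatten(coords, 1)                  # (2, 4)
rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]
rel = rel.permute(1, 2, 0).contiguous()                    # (4, 4, 2)
rel[:, :, 0] += input_size - 1                             # shift offsets >= 0
rel[:, :, 1] += input_size - 1
rel[:, :, 0] *= 2 * input_size - 1                         # row-major folding
index = rel.sum(-1)                                        # (4, 4)
print(index)  # each entry indexes the (2 * 2 - 1) ** 2 = 9-row bias table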
+
+    Args:
+        input_size (int): Input size.
+        dim (int): Number of input dims.
+        num_heads (int): Number of attention heads.
+        mlp_ratio (int): Ratio of MLP hidden dim to embedding dim.
+        qkv_bias (bool): Enable bias for qkv projections if True.
+        qk_scale (float): The scale factor applied to ``q @ k``. Defaults
+            to None, which falls back to ``head_dim ** -0.5``.
+        drop (float): Probability of an element to be zeroed
+            after the feed forward layer. Defaults to 0.
+        attn_drop (float): The dropout rate for attention output weights.
+            Defaults to 0.
+        drop_path (float): Stochastic depth rate. Defaults to 0.
+        rpe (bool): If True, add a relative position bias to the
+            attention weights.
+        layer_scale_init_value (float): Layer-scale init values. Defaults to 0.
+        act_layer: MLP activation layer.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='LN')``.
+    """
+
+    def __init__(self,
+                 input_size,
+                 dim,
+                 num_heads=0.,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 rpe=True,
+                 layer_scale_init_value=0.0,
+                 act_layer=nn.GELU,
+                 norm_cfg=dict(type='LN')):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.mlp_ratio = mlp_ratio
+
+        with_attn = num_heads > 0.
+
+        self.norm1 = build_norm_layer(norm_cfg, dim) if with_attn else None
+        self.attn = Attention(
+            input_size,
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            rpe=rpe,
+        ) if with_attn else None
+
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = build_norm_layer(norm_cfg, dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=drop)
+
+        if layer_scale_init_value > 0:
+            self.gamma_1 = nn.Parameter(
+                layer_scale_init_value * torch.ones(
+                    (dim)), requires_grad=True) if with_attn else None
+            self.gamma_2 = nn.Parameter(
+                layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+        else:
+            self.gamma_1, self.gamma_2 = None, None
+
+    def forward(self, x, rpe_index=None, mask=None):
+        if self.attn is not None:
+            if self.gamma_1 is not None:
+                x = x + self.drop_path(
+                    self.gamma_1 * self.attn(self.norm1(x), rpe_index, mask))
+            else:
+                x = x + self.drop_path(
+                    self.attn(self.norm1(x), rpe_index, mask))
+        if self.gamma_2 is not None:
+            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+        else:
+            x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """PatchEmbed for HiViT.
+
+    Args:
+        img_size (int): Input image size.
+        patch_size (int): Patch size. Defaults to 16.
+        inner_patches (int): The number of sub-patches along each side of a
+            patch. Defaults to 4.
+        in_chans (int): Number of image input channels.
+        embed_dim (int): Transformer embedding dimension.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='LN')``.
+        kernel_size (int): Kernel size.
+        pad_size (int): Pad size.
+ """ + + def __init__(self, + img_size=224, + patch_size=16, + inner_patches=4, + in_chans=3, + embed_dim=128, + norm_cfg=None, + kernel_size=None, + pad_size=None): + super().__init__() + img_size = to_2tuple(img_size) if not isinstance(img_size, + tuple) else img_size + patch_size = to_2tuple(patch_size) + patches_resolution = [ + img_size[0] // patch_size[0], img_size[1] // patch_size[1] + ] + self.img_size = img_size + self.patch_size = patch_size + self.inner_patches = inner_patches + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + conv_size = [size // inner_patches for size in patch_size] + kernel_size = kernel_size or conv_size + pad_size = pad_size or 0 + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=kernel_size, + stride=conv_size, + padding=pad_size) + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + patches_resolution = (H // self.patch_size[0], W // self.patch_size[1]) + num_patches = patches_resolution[0] * patches_resolution[1] + x = self.proj(x).view( + B, + -1, + patches_resolution[0], + self.inner_patches, + patches_resolution[1], + self.inner_patches, + ).permute(0, 2, 4, 3, 5, 1).reshape(B, num_patches, self.inner_patches, + self.inner_patches, -1) + if self.norm is not None: + x = self.norm(x) + return x + + +class PatchMerge(nn.Module): + """PatchMerge for HiViT. + + Args: + dim (int): Number of input channels. + norm_cfg (dict): Config dict for normalization layer. + """ + + def __init__(self, dim, norm_cfg): + super().__init__() + self.norm = build_norm_layer(norm_cfg, dim * 4) + self.reduction = nn.Linear(dim * 4, dim * 2, bias=False) + + def forward(self, x, *args, **kwargs): + is_main_stage = len(x.shape) == 3 + if is_main_stage: + B, N, C = x.shape + S = int(math.sqrt(N)) + x = x.reshape(B, S // 2, 2, S // 2, 2, C) \ + .permute(0, 1, 3, 2, 4, 5) \ + .reshape(B, -1, 2, 2, C) + x0 = x[..., 0::2, 0::2, :] + x1 = x[..., 1::2, 0::2, :] + x2 = x[..., 0::2, 1::2, :] + x3 = x[..., 1::2, 1::2, :] + + x = torch.cat([x0, x1, x2, x3], dim=-1) + x = self.norm(x) + x = self.reduction(x) + + if is_main_stage: + x = x[:, :, 0, 0, :] + return x + + +@MODELS.register_module() +class HiViT(BaseBackbone): + """HiViT. + + A PyTorch implement of: `HiViT: A Simple and More Efficient Design + of Hierarchical Vision Transformer `_. + + Args: + arch (str | dict): Swin Transformer architecture. If use string, choose + from 'tiny', 'small', and'base'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **num_heads** (int): The number of heads in attention + modules of each stage. + + Defaults to 'tiny'. + img_size (int): Input image size. + patch_size (int): Patch size. Defaults to 16. + inner_patches (int): Inner patch. Defaults to 4. + in_chans (int): Number of image input channels. + embed_dim (int): Transformer embedding dimension. + depths (list[int]): Number of successive HiViT blocks. + num_heads (int): Number of attention heads. + stem_mlp_ratio (int): Ratio of MLP hidden dim to embedding dim + in the first two stages. + mlp_ratio (int): Ratio of MLP hidden dim to embedding dim in + the last stage. + qkv_bias (bool): Enable bias for qkv projections if True. + qk_scale (float): The number of divider after q@k. Default to None. 
+ drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + ape (bool): If True, add absolute position embedding to + the patch embedding. + rpe (bool): If True, add relative position embedding to + the patch embedding. + patch_norm (bool): If True, use norm_cfg for normalization layer. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + kernel_size (int): Kernel size. + pad_size (int): Pad size. + layer_scale_init_value (float): Layer-scale init values. Defaults to 0. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'embed_dims': 384, + 'depths': [1, 1, 10], + 'num_heads': 6}), + **dict.fromkeys(['s', 'small'], + {'embed_dims': 384, + 'depths': [2, 2, 20], + 'num_heads': 6}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': 512, + 'depths': [2, 2, 24], + 'num_heads': 8}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': 768, + 'depths': [2, 2, 40], + 'num_heads': 12}), + } # yapf: disable + + num_extra_tokens = 0 + + def __init__(self, + arch='base', + img_size=224, + patch_size=16, + inner_patches=4, + in_chans=3, + stem_mlp_ratio=3., + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.0, + norm_cfg=dict(type='LN'), + out_indices=[23], + ape=True, + rpe=False, + patch_norm=True, + frozen_stages=-1, + kernel_size=None, + pad_size=None, + layer_scale_init_value=0.0, + init_cfg=None): + super(HiViT, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'embed_dims', 'depths', 'num_heads'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + + self.num_stages = len(self.depths) + self.ape = ape + self.rpe = rpe + self.patch_size = patch_size + self.num_features = self.embed_dims + self.mlp_ratio = mlp_ratio + self.num_main_blocks = self.depths[-1] + self.out_indices = out_indices + self.out_indices[-1] = self.depths[-1] - 1 + + img_size = to_2tuple(img_size) if not isinstance(img_size, + tuple) else img_size + + embed_dim = self.embed_dims // 2**(self.num_stages - 1) + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + inner_patches=inner_patches, + in_chans=in_chans, + embed_dim=embed_dim, + norm_cfg=norm_cfg if patch_norm else None, + kernel_size=kernel_size, + pad_size=pad_size) + num_patches = self.patch_embed.num_patches + Hp, Wp = self.patch_embed.patches_resolution + + if rpe: + assert Hp == Wp, 'If you use relative position, make sure H == W ' + 'of input size' + + # absolute position embedding + if ape: + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.num_features)) + trunc_normal_(self.pos_embed, std=.02) + if rpe: + # get 
pair-wise relative position index for each token inside the + # window + coords_h = torch.arange(Hp) + coords_w = torch.arange(Wp) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += Hp - 1 + relative_coords[:, :, 1] += Wp - 1 + relative_coords[:, :, 0] *= 2 * Wp - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer('relative_position_index', + relative_position_index) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = iter( + x.item() + for x in torch.linspace(0, drop_path_rate, + sum(self.depths) + sum(self.depths[:-1]))) + + # build blocks + self.blocks = nn.ModuleList() + for stage_i, stage_depth in enumerate(self.depths): + is_main_stage = embed_dim == self.num_features + nhead = self.num_heads if is_main_stage else 0 + ratio = mlp_ratio if is_main_stage else stem_mlp_ratio + # every block not in main stage includes two mlp blocks + stage_depth = stage_depth if is_main_stage else stage_depth * 2 + for _ in range(stage_depth): + self.blocks.append( + BlockWithRPE( + Hp, + embed_dim, + nhead, + ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=next(dpr), + rpe=rpe, + norm_cfg=norm_cfg, + layer_scale_init_value=layer_scale_init_value, + )) + if stage_i + 1 < self.num_stages: + self.blocks.append(PatchMerge(embed_dim, norm_cfg)) + embed_dim *= 2 + + self.frozen_stages = frozen_stages + if self.frozen_stages > 0: + self._freeze_stages() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def interpolate_pos_encoding(self, x, h, w): + npatch = x.shape[1] + N = self.pos_embed.shape[1] + if npatch == N and w == h: + return self.pos_embed + patch_pos_embed = self.pos_embed + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), + dim).permute(0, 3, 1, 2), + scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)), + mode='bicubic', + ) + assert int(h0) == patch_pos_embed.shape[-2] and int( + w0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward(self, x): + B, C, H, W = x.shape + Hp, Wp = H // self.patch_size, W // self.patch_size + + x = self.patch_embed(x) + + outs = [] + for i, blk in enumerate(self.blocks[:-self.num_main_blocks]): + x = blk(x) + if i in self.out_indices: + x = x.reshape(B, Hp, Wp, *x.shape[-3:]).permute( + 0, 5, 1, 3, 2, 4).reshape(B, -1, Hp * x.shape[-3], + Wp * x.shape[-2]).contiguous() + outs.append(x) + + x = x[..., 0, 0, :] + if self.ape: + x = x + self.interpolate_pos_encoding(x, H, W) + x = self.pos_drop(x) + + rpe_index = True if self.rpe else None + + for i, blk in enumerate(self.blocks[-self.num_main_blocks:]): + x = blk(x, rpe_index) + if i in self.out_indices: + x = x.transpose(1, 2).view(B, -1, 
Hp, Wp).contiguous()
+                outs.append(x)
+
+        return tuple(outs)
+
+    def _freeze_stages(self):
+        # freeze position embedding
+        if self.pos_embed is not None:
+            self.pos_embed.requires_grad = False
+        # set dropout to eval mode
+        self.pos_drop.eval()
+        # freeze patch embedding
+        self.patch_embed.eval()
+        for param in self.patch_embed.parameters():
+            param.requires_grad = False
+        # freeze layers
+        for i in range(1, self.frozen_stages + 1):
+            m = self.blocks[i - 1]
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+        # freeze the last layer norm, if one is defined (this backbone does
+        # not create `fc_norm` itself, so guard the access)
+        if hasattr(self, 'fc_norm'):
+            for param in self.fc_norm.parameters():
+                param.requires_grad = False
+
+    def get_layer_depth(self, param_name: str, prefix: str = ''):
+        """Get the layer-wise depth of a parameter.
+
+        Args:
+            param_name (str): The name of the parameter.
+            prefix (str): The prefix for the parameter.
+                Defaults to an empty string.
+
+        Returns:
+            Tuple[int, int]: The layer-wise depth and the num of layers.
+
+        Note:
+            The first depth is the stem module (``layer_depth=0``), and the
+            last depth is the subsequent module (``layer_depth=num_layers-1``).
+        """
+        self.num_layers = len(self.blocks)
+        num_layers = self.num_layers + 2
+
+        if not param_name.startswith(prefix):
+            # For subsequent module like head
+            return num_layers - 1, num_layers
+
+        param_name = param_name[len(prefix):]
+
+        if param_name == 'pos_embed':
+            layer_depth = 0
+        elif param_name.startswith('patch_embed'):
+            layer_depth = 0
+        elif param_name.startswith('blocks'):
+            layer_id = int(param_name.split('.')[1])
+            layer_depth = layer_id + 1
+        else:
+            layer_depth = num_layers - 1
+
+        return layer_depth, num_layers
diff --git a/mmpretrain/models/backbones/hornet.py b/mmpretrain/models/backbones/hornet.py
new file mode 100644
index 0000000000000000000000000000000000000000..460f2dc57975712b5eae8308e2fca9c38b89a3e2
--- /dev/null
+++ b/mmpretrain/models/backbones/hornet.py
@@ -0,0 +1,500 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Adapted from official impl at https://github.com/raoyongming/HorNet.
+try:
+    import torch.fft
+    fft = True
+except ImportError:
+    fft = None
+
+import copy
+from functools import partial
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from mmcv.cnn.bricks import DropPath
+
+from mmpretrain.models.backbones.base_backbone import BaseBackbone
+from mmpretrain.registry import MODELS
+from ..utils import LayerScale
+
+
+def get_dwconv(dim, kernel_size, bias=True):
+    """Build a depth-wise convolution."""
+    return nn.Conv2d(
+        dim,
+        dim,
+        kernel_size=kernel_size,
+        padding=(kernel_size - 1) // 2,
+        bias=bias,
+        groups=dim)
+
+
+class HorNetLayerNorm(nn.Module):
+    """An implementation of LayerNorm of HorNet.
+
+    The differences between HorNetLayerNorm & torch LayerNorm:
+        1. Supports two data formats: channels_last or channels_first.
+
+    Args:
+        normalized_shape (int or list or torch.Size): input shape from an
+            expected input of size.
+        eps (float): a value added to the denominator for numerical stability.
+            Defaults to 1e-6.
+        data_format (str): The ordering of the dimensions in the inputs.
+            channels_last corresponds to inputs with shape (batch_size, height,
+            width, channels) while channels_first corresponds to inputs with
+            shape (batch_size, channels, height, width).
+            Defaults to 'channels_last'.
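The channels_first behavior described here amounts to a LayerNorm over the channel axis; a self-contained check that the manual math matches `F.layer_norm` on permuted input:

import torch
import torch.nn.functional as F

x = torch.rand(2, 8, 4, 4)                    # (N, C, H, W)
w, b, eps = torch.ones(8), torch.zeros(8), 1e-6

u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
manual = w[:, None, None] * (x - u) / torch.sqrt(s + eps) + b[:, None, None]

ref = F.layer_norm(x.permute(0, 2, 3, 1), (8, ), w, b, eps).permute(0, 3, 1, 2)
print(torch.allclose(manual, ref, atol=1e-6))  # True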
+ """ + + def __init__(self, + normalized_shape, + eps=1e-6, + data_format='channels_last'): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ['channels_last', 'channels_first']: + raise ValueError( + 'data_format must be channels_last or channels_first') + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == 'channels_last': + return F.layer_norm(x, self.normalized_shape, self.weight, + self.bias, self.eps) + elif self.data_format == 'channels_first': + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class GlobalLocalFilter(nn.Module): + """A GlobalLocalFilter of HorNet. + + Args: + dim (int): Number of input channels. + h (int): Height of complex_weight. + Defaults to 14. + w (int): Width of complex_weight. + Defaults to 8. + """ + + def __init__(self, dim, h=14, w=8): + super().__init__() + self.dw = nn.Conv2d( + dim // 2, + dim // 2, + kernel_size=3, + padding=1, + bias=False, + groups=dim // 2) + self.complex_weight = nn.Parameter( + torch.randn(dim // 2, h, w, 2, dtype=torch.float32) * 0.02) + self.pre_norm = HorNetLayerNorm( + dim, eps=1e-6, data_format='channels_first') + self.post_norm = HorNetLayerNorm( + dim, eps=1e-6, data_format='channels_first') + + def forward(self, x): + x = self.pre_norm(x) + x1, x2 = torch.chunk(x, 2, dim=1) + x1 = self.dw(x1) + + x2 = x2.to(torch.float32) + B, C, a, b = x2.shape + x2 = torch.fft.rfft2(x2, dim=(2, 3), norm='ortho') + + weight = self.complex_weight + if not weight.shape[1:3] == x2.shape[2:4]: + weight = F.interpolate( + weight.permute(3, 0, 1, 2), + size=x2.shape[2:4], + mode='bilinear', + align_corners=True).permute(1, 2, 3, 0) + + weight = torch.view_as_complex(weight.contiguous()) + + x2 = x2 * weight + x2 = torch.fft.irfft2(x2, s=(a, b), dim=(2, 3), norm='ortho') + + x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)], + dim=2).reshape(B, 2 * C, a, b) + x = self.post_norm(x) + return x + + +class gnConv(nn.Module): + """A gnConv of HorNet. + + Args: + dim (int): Number of input channels. + order (int): Order of gnConv. + Defaults to 5. + dw_cfg (dict): The Config for dw conv. + Defaults to ``dict(type='DW', kernel_size=7)``. + scale (float): Scaling parameter of gflayer outputs. + Defaults to 1.0. 
+ """ + + def __init__(self, + dim, + order=5, + dw_cfg=dict(type='DW', kernel_size=7), + scale=1.0): + super().__init__() + self.order = order + self.dims = [dim // 2**i for i in range(order)] + self.dims.reverse() + self.proj_in = nn.Conv2d(dim, 2 * dim, 1) + + cfg = copy.deepcopy(dw_cfg) + dw_type = cfg.pop('type') + assert dw_type in ['DW', 'GF'],\ + 'dw_type should be `DW` or `GF`' + if dw_type == 'DW': + self.dwconv = get_dwconv(sum(self.dims), **cfg) + elif dw_type == 'GF': + self.dwconv = GlobalLocalFilter(sum(self.dims), **cfg) + + self.proj_out = nn.Conv2d(dim, dim, 1) + + self.projs = nn.ModuleList([ + nn.Conv2d(self.dims[i], self.dims[i + 1], 1) + for i in range(order - 1) + ]) + + self.scale = scale + + def forward(self, x): + x = self.proj_in(x) + y, x = torch.split(x, (self.dims[0], sum(self.dims)), dim=1) + + x = self.dwconv(x) * self.scale + + dw_list = torch.split(x, self.dims, dim=1) + x = y * dw_list[0] + + for i in range(self.order - 1): + x = self.projs[i](x) * dw_list[i + 1] + + x = self.proj_out(x) + + return x + + +class HorNetBlock(nn.Module): + """A block of HorNet. + + Args: + dim (int): Number of input channels. + order (int): Order of gnConv. + Defaults to 5. + dw_cfg (dict): The Config for dw conv. + Defaults to ``dict(type='DW', kernel_size=7)``. + scale (float): Scaling parameter of gflayer outputs. + Defaults to 1.0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + use_layer_scale (bool): Whether to use use_layer_scale in HorNet + block. Defaults to True. + """ + + def __init__(self, + dim, + order=5, + dw_cfg=dict(type='DW', kernel_size=7), + scale=1.0, + drop_path_rate=0., + use_layer_scale=True): + super().__init__() + self.out_channels = dim + + self.norm1 = HorNetLayerNorm( + dim, eps=1e-6, data_format='channels_first') + self.gnconv = gnConv(dim, order, dw_cfg, scale) + self.norm2 = HorNetLayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + + if use_layer_scale: + self.gamma1 = LayerScale(dim, data_format='channels_first') + self.gamma2 = LayerScale(dim) + else: + self.gamma1, self.gamma2 = nn.Identity(), nn.Identity() + + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x): + x = x + self.drop_path(self.gamma1(self.gnconv(self.norm1(x)))) + + input = x + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm2(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + x = self.gamma2(x) + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +@MODELS.register_module() +class HorNet(BaseBackbone): + """HorNet backbone. + + A PyTorch implementation of paper `HorNet: Efficient High-Order Spatial + Interactions with Recursive Gated Convolutions + `_ . + Inspiration from https://github.com/raoyongming/HorNet + + Args: + arch (str | dict): HorNet architecture. + + If use string, choose from 'tiny', 'small', 'base' and 'large'. + If use dict, it should have below keys: + + - **base_dim** (int): The base dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **orders** (List[int]): The number of order of gnConv in each + stage. + - **dw_cfg** (List[dict]): The Config for dw conv. + + Defaults to 'tiny'. + in_channels (int): Number of input image channels. Defaults to 3. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + scale (float): Scaling parameter of gflayer outputs. 
Defaults to 1/3. + use_layer_scale (bool): Whether to use use_layer_scale in HorNet + block. Defaults to True. + out_indices (Sequence[int]): Output from which stages. + Default: ``(3, )``. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + gap_before_final_norm (bool): Whether to globally average the feature + map before the final norm layer. In the official repo, it's only + used in classification task. Defaults to True. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'base_dim': 64, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}), + **dict.fromkeys(['t-gf', 'tiny-gf'], + {'base_dim': 64, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=14, w=8), + dict(type='GF', h=7, w=4)]}), + **dict.fromkeys(['s', 'small'], + {'base_dim': 96, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}), + **dict.fromkeys(['s-gf', 'small-gf'], + {'base_dim': 96, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=14, w=8), + dict(type='GF', h=7, w=4)]}), + **dict.fromkeys(['b', 'base'], + {'base_dim': 128, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}), + **dict.fromkeys(['b-gf', 'base-gf'], + {'base_dim': 128, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=14, w=8), + dict(type='GF', h=7, w=4)]}), + **dict.fromkeys(['b-gf384', 'base-gf384'], + {'base_dim': 128, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=24, w=12), + dict(type='GF', h=13, w=7)]}), + **dict.fromkeys(['l', 'large'], + {'base_dim': 192, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}), + **dict.fromkeys(['l-gf', 'large-gf'], + {'base_dim': 192, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=14, w=8), + dict(type='GF', h=7, w=4)]}), + **dict.fromkeys(['l-gf384', 'large-gf384'], + {'base_dim': 192, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=24, w=12), + dict(type='GF', h=13, w=7)]}), + } # yapf: disable + + def __init__(self, + arch='tiny', + in_channels=3, + drop_path_rate=0., + scale=1 / 3, + use_layer_scale=True, + out_indices=(3, ), + frozen_stages=-1, + with_cp=False, + gap_before_final_norm=True, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + if fft is None: + raise RuntimeError( + 'Failed to import torch.fft. 
Please install "torch>=1.7".')
+
+        if isinstance(arch, str):
+            arch = arch.lower()
+            assert arch in set(self.arch_zoo), \
+                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+            self.arch_settings = self.arch_zoo[arch]
+        else:
+            essential_keys = {'base_dim', 'depths', 'orders', 'dw_cfg'}
+            assert isinstance(arch, dict) and set(arch) == essential_keys, \
+                f'Custom arch needs a dict with keys {essential_keys}'
+            self.arch_settings = arch
+
+        self.scale = scale
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+        self.with_cp = with_cp
+        self.gap_before_final_norm = gap_before_final_norm
+
+        base_dim = self.arch_settings['base_dim']
+        dims = list(map(lambda x: 2**x * base_dim, range(4)))
+
+        self.downsample_layers = nn.ModuleList()
+        stem = nn.Sequential(
+            nn.Conv2d(in_channels, dims[0], kernel_size=4, stride=4),
+            HorNetLayerNorm(dims[0], eps=1e-6, data_format='channels_first'))
+        self.downsample_layers.append(stem)
+        for i in range(3):
+            downsample_layer = nn.Sequential(
+                HorNetLayerNorm(
+                    dims[i], eps=1e-6, data_format='channels_first'),
+                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
+            )
+            self.downsample_layers.append(downsample_layer)
+
+        total_depth = sum(self.arch_settings['depths'])
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
+        ]  # stochastic depth decay rule
+
+        cur_block_idx = 0
+        self.stages = nn.ModuleList()
+        for i in range(4):
+            stage = nn.Sequential(*[
+                HorNetBlock(
+                    dim=dims[i],
+                    order=self.arch_settings['orders'][i],
+                    dw_cfg=self.arch_settings['dw_cfg'][i],
+                    scale=self.scale,
+                    drop_path_rate=dpr[cur_block_idx + j],
+                    use_layer_scale=use_layer_scale)
+                for j in range(self.arch_settings['depths'][i])
+            ])
+            self.stages.append(stage)
+            cur_block_idx += self.arch_settings['depths'][i]
+
+        if isinstance(out_indices, int):
+            out_indices = [out_indices]
+        assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+            f'got {type(out_indices)} instead.'
+        out_indices = list(out_indices)
+        for i, index in enumerate(out_indices):
+            if index < 0:
+                out_indices[i] = len(self.stages) + index
+            assert 0 <= out_indices[i] <= len(self.stages), \
+                f'Invalid out_indices {index}.'
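The chunk widths that `gnConv` above splits its doubled projection into grow geometrically; a sketch of the arithmetic for dim=64, order=3:

dim, order = 64, 3
dims = [dim // 2**i for i in range(order)]  # [64, 32, 16]
dims.reverse()                              # [16, 32, 64]

# proj_in maps dim -> 2 * dim channels: a 16-channel gate `y` plus
# sum(dims) = 112 channels for the depth-wise branch.
assert dims[0] + sum(dims) == 2 * dim

# Each recursive step projects dims[i] -> dims[i + 1] and gates the result
# with the matching depth-wise chunk, ending at full width for proj_out.
for i in range(order - 1):
    print(f'step {i}: {dims[i]} -> {dims[i + 1]} channels')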
+ self.out_indices = out_indices + + norm_layer = partial( + HorNetLayerNorm, eps=1e-6, data_format='channels_first') + for i_layer in out_indices: + layer = norm_layer(dims[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + def train(self, mode=True): + super(HorNet, self).train(mode) + self._freeze_stages() + + def _freeze_stages(self): + for i in range(0, self.frozen_stages + 1): + # freeze patch embed + m = self.downsample_layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze blocks + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + if i in self.out_indices: + # freeze norm + m = getattr(self, f'norm{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x): + outs = [] + for i in range(4): + x = self.downsample_layers[i](x) + if self.with_cp: + x = checkpoint.checkpoint_sequential(self.stages[i], + len(self.stages[i]), x) + else: + x = self.stages[i](x) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + if self.gap_before_final_norm: + gap = x.mean([-2, -1], keepdim=True) + outs.append(norm_layer(gap).flatten(1)) + else: + # The output of LayerNorm2d may be discontiguous, which + # may cause some problem in the downstream tasks + outs.append(norm_layer(x).contiguous()) + return tuple(outs) diff --git a/mmpretrain/models/backbones/hrnet.py b/mmpretrain/models/backbones/hrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..99afa908531326f05ff1c977f0146a528683af43 --- /dev/null +++ b/mmpretrain/models/backbones/hrnet.py @@ -0,0 +1,563 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule, ModuleList, Sequential +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.registry import MODELS +from .resnet import BasicBlock, Bottleneck, ResLayer, get_expansion + + +class HRModule(BaseModule): + """High-Resolution Module for HRNet. + + In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange + is in this module. + + Args: + num_branches (int): The number of branches. + block (``BaseModule``): Convolution block module. + num_blocks (tuple): The number of blocks in each branch. + The length must be equal to ``num_branches``. + num_channels (tuple): The number of base channels in each branch. + The length must be equal to ``num_branches``. + multiscale_output (bool): Whether to output multi-level features + produced by multiple branches. If False, only the first level + feature will be output. Defaults to True. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + conv_cfg (dict, optional): Dictionary to construct and config conv + layer. Defaults to None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Defaults to ``dict(type='BN')``. + block_init_cfg (dict, optional): The initialization configs of every + blocks. Defaults to None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. 
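Assuming the file is importable as a registered mmpretrain backbone (an assumption about the install, not something this diff guarantees), HorNet usage would look roughly like the sketch below; with `gap_before_final_norm=True` the output is a pooled feature vector:

import torch
from mmpretrain.models.backbones import HorNet  # assumed import path

model = HorNet(arch='tiny', out_indices=(3, ))
model.eval()
with torch.no_grad():
    (feat, ) = model(torch.rand(1, 3, 224, 224))
print(feat.shape)  # torch.Size([1, 512]): 64 * 2**3 channels for 'tiny'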
+ """ + + def __init__(self, + num_branches, + block, + num_blocks, + in_channels, + num_channels, + multiscale_output=True, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + block_init_cfg=None, + init_cfg=None): + super(HRModule, self).__init__(init_cfg) + self.block_init_cfg = block_init_cfg + self._check_branches(num_branches, num_blocks, in_channels, + num_channels) + + self.in_channels = in_channels + self.num_branches = num_branches + + self.multiscale_output = multiscale_output + self.norm_cfg = norm_cfg + self.conv_cfg = conv_cfg + self.with_cp = with_cp + self.branches = self._make_branches(num_branches, block, num_blocks, + num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=False) + + def _check_branches(self, num_branches, num_blocks, in_channels, + num_channels): + if num_branches != len(num_blocks): + error_msg = f'NUM_BRANCHES({num_branches}) ' \ + f'!= NUM_BLOCKS({len(num_blocks)})' + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = f'NUM_BRANCHES({num_branches}) ' \ + f'!= NUM_CHANNELS({len(num_channels)})' + raise ValueError(error_msg) + + if num_branches != len(in_channels): + error_msg = f'NUM_BRANCHES({num_branches}) ' \ + f'!= NUM_INCHANNELS({len(in_channels)})' + raise ValueError(error_msg) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + out_channels = num_channels[i] * get_expansion(block) + branches.append( + ResLayer( + block=block, + num_blocks=num_blocks[i], + in_channels=self.in_channels[i], + out_channels=out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + with_cp=self.with_cp, + init_cfg=self.block_init_cfg, + )) + + return ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + in_channels = self.in_channels + fuse_layers = [] + num_out_branches = num_branches if self.multiscale_output else 1 + for i in range(num_out_branches): + fuse_layer = [] + for j in range(num_branches): + if j > i: + # Upsample the feature maps of smaller scales. + fuse_layer.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, in_channels[i])[1], + nn.Upsample( + scale_factor=2**(j - i), mode='nearest'))) + elif j == i: + # Keep the feature map with the same scale. + fuse_layer.append(None) + else: + # Downsample the feature maps of larger scales. + conv_downsamples = [] + for k in range(i - j): + # Use stacked convolution layers to downsample. 
+ if k == i - j - 1: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[i])[1])) + else: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[j], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[j])[1], + nn.ReLU(inplace=False))) + fuse_layer.append(nn.Sequential(*conv_downsamples)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def forward(self, x): + """Forward function.""" + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = 0 + for j in range(self.num_branches): + if i == j: + y += x[j] + else: + y += self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + return x_fuse + + +@MODELS.register_module() +class HRNet(BaseModule): + """HRNet backbone. + + `High-Resolution Representations for Labeling Pixels and Regions + `_. + + Args: + arch (str): The preset HRNet architecture, includes 'w18', 'w30', + 'w32', 'w40', 'w44', 'w48', 'w64'. It will only be used if + extra is ``None``. Defaults to 'w32'. + extra (dict, optional): Detailed configuration for each stage of HRNet. + There must be 4 stages, the configuration for each stage must have + 5 keys: + + - num_modules (int): The number of HRModule in this stage. + - num_branches (int): The number of branches in the HRModule. + - block (str): The type of convolution block. Please choose between + 'BOTTLENECK' and 'BASIC'. + - num_blocks (tuple): The number of blocks in each branch. + The length must be equal to num_branches. + - num_channels (tuple): The number of base channels in each branch. + The length must be equal to num_branches. + + Defaults to None. + in_channels (int): Number of input image channels. Defaults to 3. + conv_cfg (dict, optional): Dictionary to construct and config conv + layer. Defaults to None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Defaults to ``dict(type='BN')``. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Defaults to False. + multiscale_output (bool): Whether to output multi-level features + produced by multiple branches. If False, only the first level + feature will be output. Defaults to True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. 
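The fuse layers built in `HRModule._make_fuse_layers` follow one rule per branch pair; a sketch that prints the rule for three branches:

num_branches = 3
for i in range(num_branches):        # output branch
    for j in range(num_branches):    # input branch
        if j > i:
            print(f'{j} -> {i}: 1x1 conv, then upsample x{2**(j - i)}')
        elif j == i:
            print(f'{j} -> {i}: identity')
        else:
            print(f'{j} -> {i}: {i - j} stacked stride-2 3x3 conv(s)')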
+ + Example: + >>> import torch + >>> from mmpretrain.models import HRNet + >>> extra = dict( + >>> stage1=dict( + >>> num_modules=1, + >>> num_branches=1, + >>> block='BOTTLENECK', + >>> num_blocks=(4, ), + >>> num_channels=(64, )), + >>> stage2=dict( + >>> num_modules=1, + >>> num_branches=2, + >>> block='BASIC', + >>> num_blocks=(4, 4), + >>> num_channels=(32, 64)), + >>> stage3=dict( + >>> num_modules=4, + >>> num_branches=3, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4), + >>> num_channels=(32, 64, 128)), + >>> stage4=dict( + >>> num_modules=3, + >>> num_branches=4, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4, 4), + >>> num_channels=(32, 64, 128, 256))) + >>> self = HRNet(extra, in_channels=1) + >>> self.eval() + >>> inputs = torch.rand(1, 1, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 32, 8, 8) + (1, 64, 4, 4) + (1, 128, 2, 2) + (1, 256, 1, 1) + """ + + blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} + arch_zoo = { + # num_modules, num_branches, block, num_blocks, num_channels + 'w18': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (18, 36)], + [4, 3, 'BASIC', (4, 4, 4), (18, 36, 72)], + [3, 4, 'BASIC', (4, 4, 4, 4), (18, 36, 72, 144)]], + 'w30': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (30, 60)], + [4, 3, 'BASIC', (4, 4, 4), (30, 60, 120)], + [3, 4, 'BASIC', (4, 4, 4, 4), (30, 60, 120, 240)]], + 'w32': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (32, 64)], + [4, 3, 'BASIC', (4, 4, 4), (32, 64, 128)], + [3, 4, 'BASIC', (4, 4, 4, 4), (32, 64, 128, 256)]], + 'w40': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (40, 80)], + [4, 3, 'BASIC', (4, 4, 4), (40, 80, 160)], + [3, 4, 'BASIC', (4, 4, 4, 4), (40, 80, 160, 320)]], + 'w44': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (44, 88)], + [4, 3, 'BASIC', (4, 4, 4), (44, 88, 176)], + [3, 4, 'BASIC', (4, 4, 4, 4), (44, 88, 176, 352)]], + 'w48': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (48, 96)], + [4, 3, 'BASIC', (4, 4, 4), (48, 96, 192)], + [3, 4, 'BASIC', (4, 4, 4, 4), (48, 96, 192, 384)]], + 'w64': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (64, 128)], + [4, 3, 'BASIC', (4, 4, 4), (64, 128, 256)], + [3, 4, 'BASIC', (4, 4, 4, 4), (64, 128, 256, 512)]], + } # yapf:disable + + def __init__(self, + arch='w32', + extra=None, + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN'), + norm_eval=False, + with_cp=False, + zero_init_residual=False, + multiscale_output=True, + init_cfg=[ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + super(HRNet, self).__init__(init_cfg) + + extra = self.parse_arch(arch, extra) + + # Assert configurations of 4 stages are in extra + for i in range(1, 5): + assert f'stage{i}' in extra, f'Missing stage{i} config in "extra".' 
+ # Assert whether the length of `num_blocks` and `num_channels` are + # equal to `num_branches` + cfg = extra[f'stage{i}'] + assert len(cfg['num_blocks']) == cfg['num_branches'] and \ + len(cfg['num_channels']) == cfg['num_branches'] + + self.extra = extra + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + self.zero_init_residual = zero_init_residual + + # -------------------- stem net -------------------- + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + out_channels=64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1) + self.add_module(self.norm1_name, norm1) + + self.conv2 = build_conv_layer( + self.conv_cfg, + in_channels=64, + out_channels=64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2) + self.add_module(self.norm2_name, norm2) + self.relu = nn.ReLU(inplace=True) + + # -------------------- stage 1 -------------------- + self.stage1_cfg = self.extra['stage1'] + base_channels = self.stage1_cfg['num_channels'] + block_type = self.stage1_cfg['block'] + num_blocks = self.stage1_cfg['num_blocks'] + + block = self.blocks_dict[block_type] + num_channels = [ + channel * get_expansion(block) for channel in base_channels + ] + # To align with the original code, use layer1 instead of stage1 here. + self.layer1 = ResLayer( + block, + in_channels=64, + out_channels=num_channels[0], + num_blocks=num_blocks[0]) + pre_num_channels = num_channels + + # -------------------- stage 2~4 -------------------- + for i in range(2, 5): + stage_cfg = self.extra[f'stage{i}'] + base_channels = stage_cfg['num_channels'] + block = self.blocks_dict[stage_cfg['block']] + multiscale_output_ = multiscale_output if i == 4 else True + + num_channels = [ + channel * get_expansion(block) for channel in base_channels + ] + # The transition layer from layer1 to stage2 + transition = self._make_transition_layer(pre_num_channels, + num_channels) + self.add_module(f'transition{i-1}', transition) + stage = self._make_stage( + stage_cfg, num_channels, multiscale_output=multiscale_output_) + self.add_module(f'stage{i}', stage) + + pre_num_channels = num_channels + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: the normalization layer named "norm2" """ + return getattr(self, self.norm2_name) + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + # For existing scale branches, + # add conv block when the channels are not the same. + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + num_channels_cur_layer[i])[1], + nn.ReLU(inplace=True))) + else: + transition_layers.append(nn.Identity()) + else: + # For new scale branches, add stacked downsample conv blocks. + # For example, num_branches_pre = 2, for the 4th branch, add + # stacked two downsample conv blocks. 
+                conv_downsamples = []
+                for j in range(i + 1 - num_branches_pre):
+                    in_channels = num_channels_pre_layer[-1]
+                    out_channels = num_channels_cur_layer[i] \
+                        if j == i - num_branches_pre else in_channels
+                    conv_downsamples.append(
+                        nn.Sequential(
+                            build_conv_layer(
+                                self.conv_cfg,
+                                in_channels,
+                                out_channels,
+                                kernel_size=3,
+                                stride=2,
+                                padding=1,
+                                bias=False),
+                            build_norm_layer(self.norm_cfg, out_channels)[1],
+                            nn.ReLU(inplace=True)))
+                transition_layers.append(nn.Sequential(*conv_downsamples))
+
+        return nn.ModuleList(transition_layers)
+
+    def _make_stage(self, layer_config, in_channels, multiscale_output=True):
+        num_modules = layer_config['num_modules']
+        num_branches = layer_config['num_branches']
+        num_blocks = layer_config['num_blocks']
+        num_channels = layer_config['num_channels']
+        block = self.blocks_dict[layer_config['block']]
+
+        hr_modules = []
+        block_init_cfg = None
+        if self.zero_init_residual:
+            if block is BasicBlock:
+                block_init_cfg = dict(
+                    type='Constant', val=0, override=dict(name='norm2'))
+            elif block is Bottleneck:
+                block_init_cfg = dict(
+                    type='Constant', val=0, override=dict(name='norm3'))
+
+        for i in range(num_modules):
+            # multiscale_output is only used for the last module
+            if not multiscale_output and i == num_modules - 1:
+                reset_multiscale_output = False
+            else:
+                reset_multiscale_output = True
+
+            hr_modules.append(
+                HRModule(
+                    num_branches,
+                    block,
+                    num_blocks,
+                    in_channels,
+                    num_channels,
+                    reset_multiscale_output,
+                    with_cp=self.with_cp,
+                    norm_cfg=self.norm_cfg,
+                    conv_cfg=self.conv_cfg,
+                    block_init_cfg=block_init_cfg))
+
+        return Sequential(*hr_modules)
+
+    def forward(self, x):
+        """Forward function."""
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.norm2(x)
+        x = self.relu(x)
+        x = self.layer1(x)
+
+        x_list = [x]
+
+        for i in range(2, 5):
+            # Apply transition
+            transition = getattr(self, f'transition{i-1}')
+            inputs = []
+            for j, layer in enumerate(transition):
+                if j < len(x_list):
+                    inputs.append(layer(x_list[j]))
+                else:
+                    inputs.append(layer(x_list[-1]))
+            # Forward HRModule
+            stage = getattr(self, f'stage{i}')
+            x_list = stage(inputs)
+
+        return tuple(x_list)
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keeping the
+        normalization layers frozen."""
+        super(HRNet, self).train(mode)
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval() has an effect on BatchNorm only
+                if isinstance(m, _BatchNorm):
+                    m.eval()
+
+    def parse_arch(self, arch, extra=None):
+        if extra is not None:
+            return extra
+
+        assert arch in self.arch_zoo, \
+            ('Invalid arch, please choose arch from '
+             f'{list(self.arch_zoo.keys())}, or specify `extra` '
+             'argument directly.')
+
+        extra = dict()
+        for i, stage_setting in enumerate(self.arch_zoo[arch], start=1):
+            extra[f'stage{i}'] = dict(
+                num_modules=stage_setting[0],
+                num_branches=stage_setting[1],
+                block=stage_setting[2],
+                num_blocks=stage_setting[3],
+                num_channels=stage_setting[4],
+            )
+
+        return extra
diff --git a/mmpretrain/models/backbones/inception_v3.py b/mmpretrain/models/backbones/inception_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6c04b9fba4b50fce31539d14874dc7a47a539a
--- /dev/null
+++ b/mmpretrain/models/backbones/inception_v3.py
@@ -0,0 +1,501 @@
+# Copyright (c) OpenMMLab. All rights reserved.
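For reference, a sketch of what `HRNet.parse_arch('w32')` above produces for its first stage; each arch_zoo row is (num_modules, num_branches, block, num_blocks, num_channels):

row = [1, 1, 'BOTTLENECK', (4, ), (64, )]  # the 'w32' stage-1 row
stage1 = dict(
    num_modules=row[0],
    num_branches=row[1],
    block=row[2],
    num_blocks=row[3],
    num_channels=row[4],
)
print(stage1)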
+from typing import Optional, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import build_conv_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class BasicConv2d(BaseModule): + """A basic convolution block including convolution, batch norm and ReLU. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + conv_cfg (dict, optional): The config of convolution layer. + Defaults to None, which means to use ``nn.Conv2d``. + init_cfg (dict, optional): The config of initialization. + Defaults to None. + **kwargs: Other keyword arguments of the convolution layer. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + conv_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None, + **kwargs) -> None: + super().__init__(init_cfg=init_cfg) + self.conv = build_conv_layer( + conv_cfg, in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + x = self.conv(x) + x = self.bn(x) + return self.relu(x) + + +class InceptionA(BaseModule): + """Type-A Inception block. + + Args: + in_channels (int): The number of input channels. + pool_features (int): The number of channels in pooling branch. + conv_cfg (dict, optional): The convolution layer config in the + :class:`BasicConv2d` block. Defaults to None. + init_cfg (dict, optional): The config of initialization. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + pool_features: int, + conv_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + self.branch1x1 = BasicConv2d( + in_channels, 64, kernel_size=1, conv_cfg=conv_cfg) + + self.branch5x5_1 = BasicConv2d( + in_channels, 48, kernel_size=1, conv_cfg=conv_cfg) + self.branch5x5_2 = BasicConv2d( + 48, 64, kernel_size=5, padding=2, conv_cfg=conv_cfg) + + self.branch3x3dbl_1 = BasicConv2d( + in_channels, 64, kernel_size=1, conv_cfg=conv_cfg) + self.branch3x3dbl_2 = BasicConv2d( + 64, 96, kernel_size=3, padding=1, conv_cfg=conv_cfg) + self.branch3x3dbl_3 = BasicConv2d( + 96, 96, kernel_size=3, padding=1, conv_cfg=conv_cfg) + + self.branch_pool_downsample = nn.AvgPool2d( + kernel_size=3, stride=1, padding=1) + self.branch_pool = BasicConv2d( + in_channels, pool_features, kernel_size=1, conv_cfg=conv_cfg) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = self.branch_pool_downsample(x) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class InceptionB(BaseModule): + """Type-B Inception block. + + Args: + in_channels (int): The number of input channels. + conv_cfg (dict, optional): The convolution layer config in the + :class:`BasicConv2d` block. Defaults to None. + init_cfg (dict, optional): The config of initialization. + Defaults to None. 
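The four `InceptionA` branches concatenate along channels to 64 + 64 + 96 + pool_features outputs; a hypothetical check (the import path is an assumption):

import torch
from mmpretrain.models.backbones.inception_v3 import InceptionA  # assumed

block = InceptionA(in_channels=192, pool_features=32).eval()
with torch.no_grad():
    out = block(torch.rand(1, 192, 35, 35))
print(out.shape)  # torch.Size([1, 256, 35, 35]): 64 + 64 + 96 + 32 channels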
+ """ + + def __init__(self, + in_channels: int, + conv_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + self.branch3x3 = BasicConv2d( + in_channels, 384, kernel_size=3, stride=2, conv_cfg=conv_cfg) + + self.branch3x3dbl_1 = BasicConv2d( + in_channels, 64, kernel_size=1, conv_cfg=conv_cfg) + self.branch3x3dbl_2 = BasicConv2d( + 64, 96, kernel_size=3, padding=1, conv_cfg=conv_cfg) + self.branch3x3dbl_3 = BasicConv2d( + 96, 96, kernel_size=3, stride=2, conv_cfg=conv_cfg) + + self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + branch3x3 = self.branch3x3(x) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = self.branch_pool(x) + + outputs = [branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class InceptionC(BaseModule): + """Type-C Inception block. + + Args: + in_channels (int): The number of input channels. + channels_7x7 (int): The number of channels in 7x7 convolution branch. + conv_cfg (dict, optional): The convolution layer config in the + :class:`BasicConv2d` block. Defaults to None. + init_cfg (dict, optional): The config of initialization. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + channels_7x7: int, + conv_cfg: Optional[dict] = None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.branch1x1 = BasicConv2d( + in_channels, 192, kernel_size=1, conv_cfg=conv_cfg) + + c7 = channels_7x7 + self.branch7x7_1 = BasicConv2d( + in_channels, c7, kernel_size=1, conv_cfg=conv_cfg) + self.branch7x7_2 = BasicConv2d( + c7, c7, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg) + self.branch7x7_3 = BasicConv2d( + c7, 192, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg) + + self.branch7x7dbl_1 = BasicConv2d( + in_channels, c7, kernel_size=1, conv_cfg=conv_cfg) + self.branch7x7dbl_2 = BasicConv2d( + c7, c7, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg) + self.branch7x7dbl_3 = BasicConv2d( + c7, c7, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg) + self.branch7x7dbl_4 = BasicConv2d( + c7, c7, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg) + self.branch7x7dbl_5 = BasicConv2d( + c7, 192, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg) + + self.branch_pool_downsample = nn.AvgPool2d( + kernel_size=3, stride=1, padding=1) + self.branch_pool = BasicConv2d( + in_channels, 192, kernel_size=1, conv_cfg=conv_cfg) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + branch_pool = self.branch_pool_downsample(x) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class InceptionD(BaseModule): + """Type-D Inception block. + + Args: + in_channels (int): The number of input channels. + conv_cfg (dict, optional): The convolution layer config in the + :class:`BasicConv2d` block. Defaults to None. 
+ init_cfg (dict, optional): The config of initialization. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + conv_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + self.branch3x3_1 = BasicConv2d( + in_channels, 192, kernel_size=1, conv_cfg=conv_cfg) + self.branch3x3_2 = BasicConv2d( + 192, 320, kernel_size=3, stride=2, conv_cfg=conv_cfg) + + self.branch7x7x3_1 = BasicConv2d( + in_channels, 192, kernel_size=1, conv_cfg=conv_cfg) + self.branch7x7x3_2 = BasicConv2d( + 192, 192, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg) + self.branch7x7x3_3 = BasicConv2d( + 192, 192, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg) + self.branch7x7x3_4 = BasicConv2d( + 192, 192, kernel_size=3, stride=2, conv_cfg=conv_cfg) + + self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + branch3x3 = self.branch3x3_1(x) + branch3x3 = self.branch3x3_2(branch3x3) + + branch7x7x3 = self.branch7x7x3_1(x) + branch7x7x3 = self.branch7x7x3_2(branch7x7x3) + branch7x7x3 = self.branch7x7x3_3(branch7x7x3) + branch7x7x3 = self.branch7x7x3_4(branch7x7x3) + + branch_pool = self.branch_pool(x) + outputs = [branch3x3, branch7x7x3, branch_pool] + return torch.cat(outputs, 1) + + +class InceptionE(BaseModule): + """Type-E Inception block. + + Args: + in_channels (int): The number of input channels. + conv_cfg (dict, optional): The convolution layer config in the + :class:`BasicConv2d` block. Defaults to None. + init_cfg (dict, optional): The config of initialization. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + conv_cfg: Optional[dict] = None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.branch1x1 = BasicConv2d( + in_channels, 320, kernel_size=1, conv_cfg=conv_cfg) + + self.branch3x3_1 = BasicConv2d( + in_channels, 384, kernel_size=1, conv_cfg=conv_cfg) + self.branch3x3_2a = BasicConv2d( + 384, 384, kernel_size=(1, 3), padding=(0, 1), conv_cfg=conv_cfg) + self.branch3x3_2b = BasicConv2d( + 384, 384, kernel_size=(3, 1), padding=(1, 0), conv_cfg=conv_cfg) + + self.branch3x3dbl_1 = BasicConv2d( + in_channels, 448, kernel_size=1, conv_cfg=conv_cfg) + self.branch3x3dbl_2 = BasicConv2d( + 448, 384, kernel_size=3, padding=1, conv_cfg=conv_cfg) + self.branch3x3dbl_3a = BasicConv2d( + 384, 384, kernel_size=(1, 3), padding=(0, 1), conv_cfg=conv_cfg) + self.branch3x3dbl_3b = BasicConv2d( + 384, 384, kernel_size=(3, 1), padding=(1, 0), conv_cfg=conv_cfg) + + self.branch_pool_downsample = nn.AvgPool2d( + kernel_size=3, stride=1, padding=1) + self.branch_pool = BasicConv2d( + in_channels, 192, kernel_size=1, conv_cfg=conv_cfg) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + branch_pool = self.branch_pool_downsample(x) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class InceptionAux(BaseModule): + """The Inception block for the auxiliary classification branch. 
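+
+    In :class:`InceptionV3` this head is attached after ``Mixed_6e`` and is
+    only evaluated in training mode, where its logits provide an auxiliary
+    classification loss with a shorter gradient path to the early layers.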
+
+    Args:
+        in_channels (int): The number of input channels.
+        num_classes (int): The number of categories.
+        conv_cfg (dict, optional): The convolution layer config in the
+            :class:`BasicConv2d` block. Defaults to None.
+        init_cfg (dict, optional): The config of initialization.
+            Defaults to use trunc normal with ``std=0.01`` for Conv2d layers
+            and use trunc normal with ``std=0.001`` for Linear layers.
+    """
+
+    def __init__(self,
+                 in_channels: int,
+                 num_classes: int,
+                 conv_cfg: Optional[dict] = None,
+                 init_cfg: Optional[dict] = [
+                     dict(type='TruncNormal', layer='Conv2d', std=0.01),
+                     dict(type='TruncNormal', layer='Linear', std=0.001)
+                 ]):
+        super().__init__(init_cfg=init_cfg)
+        self.downsample = nn.AvgPool2d(kernel_size=5, stride=3)
+        self.conv0 = BasicConv2d(
+            in_channels, 128, kernel_size=1, conv_cfg=conv_cfg)
+        self.conv1 = BasicConv2d(128, 768, kernel_size=5, conv_cfg=conv_cfg)
+        self.gap = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(768, num_classes)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward function."""
+        # N x 768 x 17 x 17
+        x = self.downsample(x)
+        # N x 768 x 5 x 5
+        x = self.conv0(x)
+        # N x 128 x 5 x 5
+        x = self.conv1(x)
+        # N x 768 x 1 x 1
+        # Adaptive average pooling
+        x = self.gap(x)
+        # N x 768 x 1 x 1
+        x = torch.flatten(x, 1)
+        # N x 768
+        x = self.fc(x)
+        # N x num_classes
+        return x
+
+
+@MODELS.register_module()
+class InceptionV3(BaseBackbone):
+    """Inception V3 backbone.
+
+    A PyTorch implementation of `Rethinking the Inception Architecture for
+    Computer Vision <https://arxiv.org/abs/1512.00567>`_.
+
+    This implementation is modified from
+    https://github.com/pytorch/vision/blob/main/torchvision/models/inception.py.
+    Licensed under the BSD 3-Clause License.
+
+    Args:
+        num_classes (int): The number of categories. Defaults to 1000.
+        aux_logits (bool): Whether to enable the auxiliary branch. If False,
+            the auxiliary logits output will be None. Defaults to False.
+        dropout (float): Dropout rate. Defaults to 0.5.
+        init_cfg (dict, optional): The config of initialization. Defaults
+            to use trunc normal with ``std=0.1`` for all Conv2d and Linear
+            layers and constant with ``val=1`` for all BatchNorm2d layers.
+
+    Example:
+        >>> import torch
+        >>> from mmpretrain.models import build_backbone
+        >>>
+        >>> inputs = torch.rand(2, 3, 299, 299)
+        >>> cfg = dict(type='InceptionV3', num_classes=100)
+        >>> backbone = build_backbone(cfg)
+        >>> aux_out, out = backbone(inputs)
+        >>> # The auxiliary branch is disabled by default.
+ >>> assert aux_out is None + >>> print(out.shape) + torch.Size([2, 100]) + >>> cfg = dict(type='InceptionV3', num_classes=100, aux_logits=True) + >>> backbone = build_backbone(cfg) + >>> aux_out, out = backbone(inputs) + >>> print(aux_out.shape, out.shape) + torch.Size([2, 100]) torch.Size([2, 100]) + """ + + def __init__( + self, + num_classes: int = 1000, + aux_logits: bool = False, + dropout: float = 0.5, + init_cfg: Optional[dict] = [ + dict(type='TruncNormal', layer=['Conv2d', 'Linear'], std=0.1), + dict(type='Constant', layer='BatchNorm2d', val=1) + ], + ) -> None: + super().__init__(init_cfg=init_cfg) + + self.aux_logits = aux_logits + self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2) + self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) + self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) + self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) + self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) + self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Mixed_5b = InceptionA(192, pool_features=32) + self.Mixed_5c = InceptionA(256, pool_features=64) + self.Mixed_5d = InceptionA(288, pool_features=64) + self.Mixed_6a = InceptionB(288) + self.Mixed_6b = InceptionC(768, channels_7x7=128) + self.Mixed_6c = InceptionC(768, channels_7x7=160) + self.Mixed_6d = InceptionC(768, channels_7x7=160) + self.Mixed_6e = InceptionC(768, channels_7x7=192) + self.AuxLogits: Optional[nn.Module] = None + if aux_logits: + self.AuxLogits = InceptionAux(768, num_classes) + self.Mixed_7a = InceptionD(768) + self.Mixed_7b = InceptionE(1280) + self.Mixed_7c = InceptionE(2048) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.dropout = nn.Dropout(p=dropout) + self.fc = nn.Linear(2048, num_classes) + + def forward( + self, + x: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + """Forward function.""" + # N x 3 x 299 x 299 + x = self.Conv2d_1a_3x3(x) + # N x 32 x 149 x 149 + x = self.Conv2d_2a_3x3(x) + # N x 32 x 147 x 147 + x = self.Conv2d_2b_3x3(x) + # N x 64 x 147 x 147 + x = self.maxpool1(x) + # N x 64 x 73 x 73 + x = self.Conv2d_3b_1x1(x) + # N x 80 x 73 x 73 + x = self.Conv2d_4a_3x3(x) + # N x 192 x 71 x 71 + x = self.maxpool2(x) + # N x 192 x 35 x 35 + x = self.Mixed_5b(x) + # N x 256 x 35 x 35 + x = self.Mixed_5c(x) + # N x 288 x 35 x 35 + x = self.Mixed_5d(x) + # N x 288 x 35 x 35 + x = self.Mixed_6a(x) + # N x 768 x 17 x 17 + x = self.Mixed_6b(x) + # N x 768 x 17 x 17 + x = self.Mixed_6c(x) + # N x 768 x 17 x 17 + x = self.Mixed_6d(x) + # N x 768 x 17 x 17 + x = self.Mixed_6e(x) + # N x 768 x 17 x 17 + aux: Optional[torch.Tensor] = None + if self.aux_logits and self.training: + aux = self.AuxLogits(x) + # N x 768 x 17 x 17 + x = self.Mixed_7a(x) + # N x 1280 x 8 x 8 + x = self.Mixed_7b(x) + # N x 2048 x 8 x 8 + x = self.Mixed_7c(x) + # N x 2048 x 8 x 8 + # Adaptive average pooling + x = self.avgpool(x) + # N x 2048 x 1 x 1 + x = self.dropout(x) + # N x 2048 x 1 x 1 + x = torch.flatten(x, 1) + # N x 2048 + x = self.fc(x) + # N x 1000 (num_classes) + return aux, x diff --git a/mmpretrain/models/backbones/lenet.py b/mmpretrain/models/backbones/lenet.py new file mode 100644 index 0000000000000000000000000000000000000000..8e423c0b15a60660714617e47fd68857b3a6d1e0 --- /dev/null +++ b/mmpretrain/models/backbones/lenet.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
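+# LeNet-5 maps a 1x32x32 grayscale input through three 5x5 conv stages
+# (interleaved with 2x2 average pooling) down to a 120-d feature vector.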
+import torch.nn as nn + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +@MODELS.register_module() +class LeNet5(BaseBackbone): + """`LeNet5 `_ backbone. + + The input for LeNet-5 is a 32×32 grayscale image. + + Args: + num_classes (int): number of classes for classification. + The default value is -1, which uses the backbone as + a feature extractor without the top classifier. + """ + + def __init__(self, num_classes=-1): + super(LeNet5, self).__init__() + self.num_classes = num_classes + self.features = nn.Sequential( + nn.Conv2d(1, 6, kernel_size=5, stride=1), nn.Tanh(), + nn.AvgPool2d(kernel_size=2), + nn.Conv2d(6, 16, kernel_size=5, stride=1), nn.Tanh(), + nn.AvgPool2d(kernel_size=2), + nn.Conv2d(16, 120, kernel_size=5, stride=1), nn.Tanh()) + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Linear(120, 84), + nn.Tanh(), + nn.Linear(84, num_classes), + ) + + def forward(self, x): + + x = self.features(x) + if self.num_classes > 0: + x = self.classifier(x.squeeze()) + + return (x, ) diff --git a/mmpretrain/models/backbones/levit.py b/mmpretrain/models/backbones/levit.py new file mode 100644 index 0000000000000000000000000000000000000000..5f7aa324e28b1725fb9e67110a26ea2d5c2831bd --- /dev/null +++ b/mmpretrain/models/backbones/levit.py @@ -0,0 +1,522 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import itertools + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, fuse_conv_bn +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer + + +class HybridBackbone(BaseModule): + + def __init__( + self, + embed_dim, + kernel_size=3, + stride=2, + pad=1, + dilation=1, + groups=1, + act_cfg=dict(type='HSwish'), + conv_cfg=None, + norm_cfg=dict(type='BN'), + init_cfg=None, + ): + super(HybridBackbone, self).__init__(init_cfg=init_cfg) + + self.input_channels = [ + 3, embed_dim // 8, embed_dim // 4, embed_dim // 2 + ] + self.output_channels = [ + embed_dim // 8, embed_dim // 4, embed_dim // 2, embed_dim + ] + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.patch_embed = Sequential() + + for i in range(len(self.input_channels)): + conv_bn = ConvolutionBatchNorm( + self.input_channels[i], + self.output_channels[i], + kernel_size=kernel_size, + stride=stride, + pad=pad, + dilation=dilation, + groups=groups, + norm_cfg=norm_cfg, + ) + self.patch_embed.add_module('%d' % (2 * i), conv_bn) + if i < len(self.input_channels) - 1: + self.patch_embed.add_module('%d' % (i * 2 + 1), + build_activation_layer(act_cfg)) + + def forward(self, x): + x = self.patch_embed(x) + return x + + +class ConvolutionBatchNorm(BaseModule): + + def __init__( + self, + in_channel, + out_channel, + kernel_size=3, + stride=2, + pad=1, + dilation=1, + groups=1, + norm_cfg=dict(type='BN'), + ): + super(ConvolutionBatchNorm, self).__init__() + self.conv = nn.Conv2d( + in_channel, + out_channel, + kernel_size=kernel_size, + stride=stride, + padding=pad, + dilation=dilation, + groups=groups, + bias=False) + self.bn = build_norm_layer(norm_cfg, out_channel) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + @torch.no_grad() + def fuse(self): + return fuse_conv_bn(self).conv + + +class LinearBatchNorm(BaseModule): + + def __init__(self, in_feature, out_feature, norm_cfg=dict(type='BN1d')): + super(LinearBatchNorm, 
self).__init__() + self.linear = nn.Linear(in_feature, out_feature, bias=False) + self.bn = build_norm_layer(norm_cfg, out_feature) + + def forward(self, x): + x = self.linear(x) + x = self.bn(x.flatten(0, 1)).reshape_as(x) + return x + + @torch.no_grad() + def fuse(self): + w = self.bn.weight / (self.bn.running_var + self.bn.eps)**0.5 + w = self.linear.weight * w[:, None] + b = self.bn.bias - self.bn.running_mean * self.bn.weight / \ + (self.bn.running_var + self.bn.eps) ** 0.5 + + factory_kwargs = { + 'device': self.linear.weight.device, + 'dtype': self.linear.weight.dtype + } + bias = nn.Parameter( + torch.empty(self.linear.out_features, **factory_kwargs)) + self.linear.register_parameter('bias', bias) + self.linear.weight.data.copy_(w) + self.linear.bias.data.copy_(b) + return self.linear + + +class Residual(BaseModule): + + def __init__(self, block, drop_path_rate=0.): + super(Residual, self).__init__() + self.block = block + if drop_path_rate > 0: + self.drop_path = DropPath(drop_path_rate) + else: + self.drop_path = nn.Identity() + + def forward(self, x): + x = x + self.drop_path(self.block(x)) + return x + + +class Attention(BaseModule): + + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + act_cfg=dict(type='HSwish'), + resolution=14, + ): + super(Attention, self).__init__() + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + self.qkv = LinearBatchNorm(dim, h) + self.proj = nn.Sequential( + build_activation_layer(act_cfg), LinearBatchNorm(self.dh, dim)) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N, N)) + + @torch.no_grad() + def train(self, mode=True): + """change the mode of model.""" + super(Attention, self).train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,N,C) + B, N, C = x.shape # 2 196 128 + qkv = self.qkv(x) # 2 196 128 + q, k, v = qkv.view(B, N, self.num_heads, -1).split( + [self.key_dim, self.key_dim, self.d], + dim=3) # q 2 196 4 16 ; k 2 196 4 16; v 2 196 4 32 + q = q.permute(0, 2, 1, 3) # 2 4 196 16 + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = ((q @ k.transpose(-2, -1)) * + self.scale # 2 4 196 16 * 2 4 16 196 -> 2 4 196 196 + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab)) + attn = attn.softmax(dim=-1) # 2 4 196 196 -> 2 4 196 196 + x = (attn @ v).transpose(1, 2).reshape( + B, N, + self.dh) # 2 4 196 196 * 2 4 196 32 -> 2 4 196 32 -> 2 196 128 + x = self.proj(x) + return x + + +class MLP(nn.Sequential): + + def __init__(self, embed_dim, mlp_ratio, act_cfg=dict(type='HSwish')): + super(MLP, self).__init__() + h = embed_dim * mlp_ratio + self.linear1 = LinearBatchNorm(embed_dim, h) + self.activation = build_activation_layer(act_cfg) + self.linear2 = LinearBatchNorm(h, embed_dim) + + 
def forward(self, x): + x = self.linear1(x) + x = self.activation(x) + x = self.linear2(x) + return x + + +class Subsample(BaseModule): + + def __init__(self, stride, resolution): + super(Subsample, self).__init__() + self.stride = stride + self.resolution = resolution + + def forward(self, x): + B, _, C = x.shape + # B, N, C -> B, H, W, C + x = x.view(B, self.resolution, self.resolution, C) + x = x[:, ::self.stride, ::self.stride] + x = x.reshape(B, -1, C) # B, H', W', C -> B, N', C + return x + + +class AttentionSubsample(nn.Sequential): + + def __init__(self, + in_dim, + out_dim, + key_dim, + num_heads=8, + attn_ratio=2, + act_cfg=dict(type='HSwish'), + stride=2, + resolution=14): + super(AttentionSubsample, self).__init__() + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * self.num_heads + self.attn_ratio = attn_ratio + self.sub_resolution = (resolution - 1) // stride + 1 + h = self.dh + nh_kd + self.kv = LinearBatchNorm(in_dim, h) + + self.q = nn.Sequential( + Subsample(stride, resolution), LinearBatchNorm(in_dim, nh_kd)) + self.proj = nn.Sequential( + build_activation_layer(act_cfg), LinearBatchNorm(self.dh, out_dim)) + + self.stride = stride + self.resolution = resolution + points = list(itertools.product(range(resolution), range(resolution))) + sub_points = list( + itertools.product( + range(self.sub_resolution), range(self.sub_resolution))) + N = len(points) + N_sub = len(sub_points) + attention_offsets = {} + idxs = [] + for p1 in sub_points: + for p2 in points: + size = 1 + offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2), + abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N_sub, N)) + + @torch.no_grad() + def train(self, mode=True): + super(AttentionSubsample, self).train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): + B, N, C = x.shape + k, v = self.kv(x).view(B, N, self.num_heads, + -1).split([self.key_dim, self.d], dim=3) + k = k.permute(0, 2, 1, 3) # BHNC + v = v.permute(0, 2, 1, 3) # BHNC + q = self.q(x).view(B, self.sub_resolution**2, self.num_heads, + self.key_dim).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + \ + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab) + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) + x = self.proj(x) + return x + + +@MODELS.register_module() +class LeViT(BaseBackbone): + """LeViT backbone. + + A PyTorch implementation of `LeViT: A Vision Transformer in ConvNet's + Clothing for Faster Inference `_ + + Modified from the official implementation: + https://github.com/facebookresearch/LeViT + + Args: + arch (str | dict): LeViT architecture. + + If use string, choose from '128s', '128', '192', '256' and '384'. + If use dict, it should have below keys: + + - **embed_dims** (List[int]): The embed dimensions of each stage. + - **key_dims** (List[int]): The embed dimensions of the key in the + attention layers of each stage. + - **num_heads** (List[int]): The number of heads in each stage. 
+            - **depths** (List[int]): The number of blocks in each stage.
+
+        img_size (int): Input image size. Defaults to 224.
+        patch_size (int | tuple): The patch size. Defaults to 16.
+        attn_ratio (int): Ratio of hidden dimensions of the value in attention
+            layers. Defaults to 2.
+        mlp_ratio (int): Ratio of hidden dimensions in MLP layers.
+            Defaults to 2.
+        act_cfg (dict): The config of activation functions.
+            Defaults to ``dict(type='HSwish')``.
+        hybrid_backbone (callable): A callable object to build the patch embed
+            module. Defaults to use :class:`HybridBackbone`.
+        out_indices (Sequence | int): Output from which stages.
+            Defaults to -1, meaning the last stage.
+        deploy (bool): Whether to switch the model structure to
+            deployment mode. Defaults to False.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Defaults to None.
+    """
+    arch_zoo = {
+        '128s': {
+            'embed_dims': [128, 256, 384],
+            'num_heads': [4, 6, 8],
+            'depths': [2, 3, 4],
+            'key_dims': [16, 16, 16],
+        },
+        '128': {
+            'embed_dims': [128, 256, 384],
+            'num_heads': [4, 8, 12],
+            'depths': [4, 4, 4],
+            'key_dims': [16, 16, 16],
+        },
+        '192': {
+            'embed_dims': [192, 288, 384],
+            'num_heads': [3, 5, 6],
+            'depths': [4, 4, 4],
+            'key_dims': [32, 32, 32],
+        },
+        '256': {
+            'embed_dims': [256, 384, 512],
+            'num_heads': [4, 6, 8],
+            'depths': [4, 4, 4],
+            'key_dims': [32, 32, 32],
+        },
+        '384': {
+            'embed_dims': [384, 512, 768],
+            'num_heads': [6, 9, 12],
+            'depths': [4, 4, 4],
+            'key_dims': [32, 32, 32],
+        },
+    }
+
+    def __init__(self,
+                 arch,
+                 img_size=224,
+                 patch_size=16,
+                 attn_ratio=2,
+                 mlp_ratio=2,
+                 act_cfg=dict(type='HSwish'),
+                 hybrid_backbone=HybridBackbone,
+                 out_indices=-1,
+                 deploy=False,
+                 drop_path_rate=0,
+                 init_cfg=None):
+        super(LeViT, self).__init__(init_cfg=init_cfg)
+
+        if isinstance(arch, str):
+            arch = arch.lower()
+            assert arch in set(self.arch_zoo), \
+                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+            self.arch = self.arch_zoo[arch]
+        elif isinstance(arch, dict):
+            # The key names must match ``arch_zoo`` (and the attribute
+            # accesses below), hence the plural forms.
+            essential_keys = {'embed_dims', 'num_heads', 'depths', 'key_dims'}
+            assert set(arch) == essential_keys, \
+                f'Custom arch needs a dict with keys {essential_keys}'
+            self.arch = arch
+        else:
+            raise TypeError('Expect "arch" to be either a string '
+                            f'or a dict, got {type(arch)}')
+
+        self.embed_dims = self.arch['embed_dims']
+        self.num_heads = self.arch['num_heads']
+        self.key_dims = self.arch['key_dims']
+        self.depths = self.arch['depths']
+        self.num_stages = len(self.embed_dims)
+        self.drop_path_rate = drop_path_rate
+
+        self.patch_embed = hybrid_backbone(self.embed_dims[0])
+
+        self.resolutions = []
+        resolution = img_size // patch_size
+        self.stages = ModuleList()
+        for i, (embed_dims, key_dims, depth, num_heads) in enumerate(
+                zip(self.embed_dims, self.key_dims, self.depths,
+                    self.num_heads)):
+            blocks = []
+            if i > 0:
+                downsample = AttentionSubsample(
+                    in_dim=self.embed_dims[i - 1],
+                    out_dim=embed_dims,
+                    key_dim=key_dims,
+                    num_heads=self.embed_dims[i - 1] // key_dims,
+                    attn_ratio=4,
+                    act_cfg=act_cfg,
+                    stride=2,
+                    resolution=resolution)
+                blocks.append(downsample)
+                resolution = downsample.sub_resolution
+                if mlp_ratio > 0:
+                    blocks.append(
+                        Residual(
+                            MLP(embed_dims, mlp_ratio, act_cfg=act_cfg),
+                            self.drop_path_rate))
+            self.resolutions.append(resolution)
+            for _ in range(depth):
+                blocks.append(
+                    Residual(
+                        Attention(
+                            embed_dims,
+                            key_dims,
+                            num_heads,
+                            attn_ratio=attn_ratio,
+                            act_cfg=act_cfg,
+                            resolution=resolution,
+                        ), self.drop_path_rate))
+                if mlp_ratio > 0:
+
blocks.append( + Residual( + MLP(embed_dims, mlp_ratio, act_cfg=act_cfg), + self.drop_path_rate)) + + self.stages.append(Sequential(*blocks)) + + if isinstance(out_indices, int): + out_indices = [out_indices] + elif isinstance(out_indices, tuple): + out_indices = list(out_indices) + elif not isinstance(out_indices, list): + raise TypeError('"out_indices" must by a list, tuple or int, ' + f'get {type(out_indices)} instead.') + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_stages + index + assert 0 <= out_indices[i] < self.num_stages, \ + f'Invalid out_indices {index}.' + self.out_indices = out_indices + + self.deploy = False + if deploy: + self.switch_to_deploy() + + def switch_to_deploy(self): + if self.deploy: + return + fuse_parameters(self) + self.deploy = True + + def forward(self, x): + x = self.patch_embed(x) + x = x.flatten(2).transpose(1, 2) # B, C, H, W -> B, L, C + outs = [] + for i, stage in enumerate(self.stages): + x = stage(x) + B, _, C = x.shape + if i in self.out_indices: + out = x.reshape(B, self.resolutions[i], self.resolutions[i], C) + out = out.permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return tuple(outs) + + +def fuse_parameters(module): + for child_name, child in module.named_children(): + if hasattr(child, 'fuse'): + setattr(module, child_name, child.fuse()) + else: + fuse_parameters(child) diff --git a/mmpretrain/models/backbones/mixmim.py b/mmpretrain/models/backbones/mixmim.py new file mode 100644 index 0000000000000000000000000000000000000000..2c67aa0c3a45c5c85adbacb94ae90dc170b2d0bb --- /dev/null +++ b/mmpretrain/models/backbones/mixmim.py @@ -0,0 +1,533 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.drop import DropPath +from mmcv.cnn.bricks.transformer import PatchEmbed, PatchMerging +from mmengine.model import BaseModule +from torch import nn +from torch.utils.checkpoint import checkpoint + +from mmpretrain.registry import MODELS +from ..utils import WindowMSA, to_2tuple +from .base_backbone import BaseBackbone +from .vision_transformer import TransformerEncoderLayer + + +class MixMIMWindowAttention(WindowMSA): + """MixMIM Window Attention. + + Compared with WindowMSA, we add some modifications + in ``forward`` to meet the requirement of MixMIM during + pretraining. + + Implements one windown attention in MixMIM. + Args: + embed_dims (int): The feature dimension. + window_size (list): The height and width of the window. + num_heads (int): The number of head in attention. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + attn_drop_rate (float): attention drop rate. + Defaults to 0. + proj_drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
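+
+    Note:
+        When a binary ``mask`` is given (marking which source image each
+        token of the mixed input comes from), the attention logits between
+        tokens of different source images are shifted by a large negative
+        constant, so every token effectively attends only to tokens from
+        its own image.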
+ """ + + def __init__(self, + embed_dims, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None): + + super().__init__( + embed_dims=embed_dims, + window_size=window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop_rate, + proj_drop=proj_drop_rate, + init_cfg=init_cfg) + + def forward(self, x, mask=None): + + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + mask = mask.reshape(B_, 1, 1, N) + mask_new = mask * mask.transpose( + 2, 3) + (1 - mask) * (1 - mask).transpose(2, 3) + mask_new = 1 - mask_new + + if mask_new.dtype == torch.float16: + attn = attn - 65500 * mask_new + else: + attn = attn - 1e30 * mask_new + + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MixMIMBlock(TransformerEncoderLayer): + """MixMIM Block. Implements one block in MixMIM. + + Args: + embed_dims (int): The feature dimension. + input_resolution (tuple): Input resolution of this layer. + num_heads (int): The number of head in attention, + window_size (list): The height and width of the window. + mlp_ratio (int): The MLP ration in FFN. + num_fcs (int): The number of linear layers in a block. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + proj_drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + attn_drop_rate (float): attention drop rate. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. + Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
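+
+    Note:
+        If the input resolution is smaller than the configured window, the
+        window size is clamped to ``min(input_resolution)``. Windows are
+        never shifted between blocks, unlike in Swin Transformer.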
+ """ + + def __init__(self, + embed_dims, + input_resolution, + num_heads, + window_size=7, + mlp_ratio=4., + num_fcs=2, + qkv_bias=True, + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + + super().__init__( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=int(mlp_ratio * embed_dims), + drop_rate=proj_drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + num_fcs=num_fcs, + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + init_cfg=init_cfg) + + self.embed_dims = embed_dims + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + if min(self.input_resolution) <= self.window_size: + self.window_size = min(self.input_resolution) + + self.attn = MixMIMWindowAttention( + embed_dims=embed_dims, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=proj_drop_rate) + + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + @staticmethod + def window_reverse(windows, H, W, window_size): + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + @staticmethod + def window_partition(x, window_size): + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows + + def forward(self, x, attn_mask=None): + H, W = self.input_resolution + B, L, C = x.shape + + shortcut = x + x = self.ln1(x) + x = x.view(B, H, W, C) + + # partition windows + x_windows = self.window_partition( + x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C + if attn_mask is not None: + attn_mask = attn_mask.repeat(B, 1, 1) # B, N, 1 + attn_mask = attn_mask.view(B, H, W, 1) + attn_mask = self.window_partition(attn_mask, self.window_size) + attn_mask = attn_mask.view(-1, self.window_size * self.window_size, + 1) + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + x = self.window_reverse(attn_windows, H, W, + self.window_size) # B H' W' C + + x = x.view(B, H * W, C) + + x = shortcut + self.drop_path(x) + + x = self.ffn(self.norm2(x), identity=x) # ffn contains DropPath + + return x + + +class MixMIMLayer(BaseModule): + """Implements one MixMIM layer, which may contains several MixMIM blocks. + + Args: + embed_dims (int): The feature dimension. + input_resolution (tuple): Input resolution of this layer. + depth (int): The number of blocks in this layer. + num_heads (int): The number of head in attention, + window_size (list): The height and width of the window. + mlp_ratio (int): The MLP ration in FFN. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + proj_drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + attn_drop_rate (float): attention drop rate. + Defaults to 0. 
+        drop_path_rate (float): stochastic depth rate.
+            Defaults to 0.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='LN')``.
+        downsample (class, optional): Downsample the output of blocks by
+            patch merging. Defaults to None.
+        use_checkpoint (bool): Whether to use checkpointing to
+            reduce GPU memory cost. Defaults to False.
+        init_cfg (dict, optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 embed_dims: int,
+                 input_resolution: int,
+                 depth: int,
+                 num_heads: int,
+                 window_size: int,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 proj_drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=[0.],
+                 norm_cfg=dict(type='LN'),
+                 downsample=None,
+                 use_checkpoint=False,
+                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
+        super().__init__(init_cfg=init_cfg)
+        self.embed_dims = embed_dims
+        self.input_resolution = input_resolution
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList()
+        for i in range(depth):
+            self.blocks.append(
+                MixMIMBlock(
+                    embed_dims=embed_dims,
+                    input_resolution=input_resolution,
+                    num_heads=num_heads,
+                    window_size=window_size,
+                    mlp_ratio=mlp_ratio,
+                    qkv_bias=qkv_bias,
+                    proj_drop_rate=proj_drop_rate,
+                    attn_drop_rate=attn_drop_rate,
+                    drop_path_rate=drop_path_rate[i],
+                    norm_cfg=norm_cfg))
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(
+                in_channels=embed_dims,
+                out_channels=2 * embed_dims,
+                norm_cfg=norm_cfg)
+        else:
+            self.downsample = None
+
+    def forward(self, x, attn_mask=None):
+        for blk in self.blocks:
+            if self.use_checkpoint:
+                x = checkpoint(blk, x, attn_mask)
+            else:
+                x = blk(x, attn_mask=attn_mask)
+        if self.downsample is not None:
+            x, _ = self.downsample(x, self.input_resolution)
+        return x
+
+    def extra_repr(self) -> str:
+        return f'dim={self.embed_dims}, \
+            input_resolution={self.input_resolution}, depth={self.depth}'
+
+
+@MODELS.register_module()
+class MixMIMTransformer(BaseBackbone):
+    """MixMIM backbone.
+
+    A PyTorch implementation of `MixMIM: Mixed and Masked Image
+    Modeling for Efficient Visual Representation Learning
+    <https://arxiv.org/abs/2205.13137>`_.
+
+    Args:
+        arch (str | dict): MixMIM architecture. If use string,
+            choose from 'base', 'large' and 'huge'.
+            If use dict, it should have below keys:
+
+            - **embed_dims** (int): The dimensions of embedding.
+            - **depths** (List[int]): The number of blocks in each stage.
+            - **num_heads** (List[int]): The number of heads in the
+              attention modules of each stage.
+
+            Defaults to 'base'.
+        mlp_ratio (int): The mlp ratio in FFN. Defaults to 4.
+        img_size (int | tuple): The expected input image shape. Because we
+            support dynamic input shape, just set the argument to the most
+            common input image shape. Defaults to 224.
+        patch_size (int | tuple): The patch size in patch embedding.
+            Defaults to 4.
+        in_channels (int): The number of input channels. Defaults to 3.
+        window_size (list): The window size of each stage.
+            Defaults to ``[14, 14, 14, 7]``.
+        qkv_bias (bool): Whether to add bias for qkv in attention modules.
+            Defaults to True.
+        patch_cfg (dict): Extra config dict for patch embedding.
+            Defaults to an empty dict.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='LN')``.
+        drop_rate (float): Probability of an element to be zeroed.
+            Defaults to 0.
+        drop_path_rate (float): stochastic depth rate. Defaults to 0.
+        attn_drop_rate (float): attention drop rate. Defaults to 0.
+        use_checkpoint (bool): Whether to use checkpointing to
+            reduce GPU memory cost. Defaults to False.
+ init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 128, + 'depths': [2, 2, 18, 2], + 'num_heads': [4, 8, 16, 32] + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 192, + 'depths': [2, 2, 18, 2], + 'num_heads': [6, 12, 24, 48] + }), + **dict.fromkeys( + ['h', 'huge'], { + 'embed_dims': 352, + 'depths': [2, 2, 18, 2], + 'num_heads': [11, 22, 44, 88] + }), + } + + def __init__( + self, + arch='base', + mlp_ratio=4, + img_size=224, + patch_size=4, + in_channels=3, + window_size=[14, 14, 14, 7], + qkv_bias=True, + patch_cfg=dict(), + norm_cfg=dict(type='LN'), + drop_rate=0.0, + drop_path_rate=0.0, + attn_drop_rate=0.0, + use_checkpoint=False, + init_cfg: Optional[dict] = None, + ) -> None: + super(MixMIMTransformer, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'embed_dims', 'depths', 'num_heads'} + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + + self.encoder_stride = 32 + + self.num_layers = len(self.depths) + self.qkv_bias = qkv_bias + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.use_checkpoint = use_checkpoint + self.mlp_ratio = mlp_ratio + self.window_size = window_size + + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + norm_cfg=dict(type='LN'), + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + self.dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + self.layers.append( + MixMIMLayer( + embed_dims=int(self.embed_dims * 2**i_layer), + input_resolution=(self.patch_resolution[0] // (2**i_layer), + self.patch_resolution[1] // + (2**i_layer)), + depth=self.depths[i_layer], + num_heads=self.num_heads[i_layer], + window_size=self.window_size[i_layer], + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + proj_drop_rate=self.drop_rate, + attn_drop_rate=self.attn_drop_rate, + drop_path_rate=self.dpr[sum(self.depths[:i_layer] + ):sum(self.depths[:i_layer + + 1])], + norm_cfg=norm_cfg, + downsample=PatchMerging if + (i_layer < self.num_layers - 1) else None, + use_checkpoint=self.use_checkpoint)) + + self.num_features = int(self.embed_dims * 2**(self.num_layers - 1)) + self.drop_after_pos = nn.Dropout(p=self.drop_rate) + + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches, self.embed_dims), + requires_grad=False) + + _, self.norm = build_norm_layer(norm_cfg, self.num_features) + + def forward(self, x: torch.Tensor): + x, _ = self.patch_embed(x) + + x = x + self.absolute_pos_embed + x = self.drop_after_pos(x) + + for layer in self.layers: + x = layer(x, attn_mask=None) + + x = self.norm(x) + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = 
torch.flatten(x, 1) + + return (x, ) + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. + + Note: + The first depth is the stem module (``layer_depth=0``), and the + last depth is the subsequent module (``layer_depth=num_layers-1``) + """ + num_layers = sum(self.depths) + 2 + + if not param_name.startswith(prefix): + # For subsequent module like neck and head + if param_name.startswith('neck'): + return num_layers - 2, num_layers + else: + return num_layers - 1, num_layers + + param_name = param_name[len(prefix):] + + stem_layers = ('patch_embed', 'absolute_pos_embed', 'pos_embed') + if any(stem in param_name for stem in stem_layers): + layer_depth = 0 + elif param_name.startswith('layers'): + layer_id = int(param_name.split('.')[1]) + block_id = param_name.split('.')[3] + + if block_id in ('downsample', 'reduction', 'norm'): + layer_depth = sum(self.depths[:layer_id + 1]) + else: + layer_depth = sum(self.depths[:layer_id]) + int(block_id) + 1 + else: + layer_depth = num_layers - 2 + + return layer_depth, num_layers diff --git a/mmpretrain/models/backbones/mlp_mixer.py b/mmpretrain/models/backbones/mlp_mixer.py new file mode 100644 index 0000000000000000000000000000000000000000..26fb8ce0186c2451a5698c413ebf2bc24f33b6ec --- /dev/null +++ b/mmpretrain/models/backbones/mlp_mixer.py @@ -0,0 +1,263 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList + +from mmpretrain.registry import MODELS +from ..utils import to_2tuple +from .base_backbone import BaseBackbone + + +class MixerBlock(BaseModule): + """Mlp-Mixer basic block. + + Basic module of `MLP-Mixer: An all-MLP Architecture for Vision + `_ + + Args: + num_tokens (int): The number of patched tokens + embed_dims (int): The feature dimension + tokens_mlp_dims (int): The hidden dimension for tokens FFNs + channels_mlp_dims (int): The hidden dimension for channels FFNs + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
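+
+    Note:
+        ``token_mix`` runs on the transposed input, mixing information
+        across the ``num_tokens`` dimension, while ``channel_mix`` is an
+        ordinary FFN over the channel dimension.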
+ """ + + def __init__(self, + num_tokens, + embed_dims, + tokens_mlp_dims, + channels_mlp_dims, + drop_rate=0., + drop_path_rate=0., + num_fcs=2, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg=None): + super(MixerBlock, self).__init__(init_cfg=init_cfg) + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + self.token_mix = FFN( + embed_dims=num_tokens, + feedforward_channels=tokens_mlp_dims, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=False) + + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + self.add_module(self.norm2_name, norm2) + self.channel_mix = FFN( + embed_dims=embed_dims, + feedforward_channels=channels_mlp_dims, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + @property + def norm2(self): + return getattr(self, self.norm2_name) + + def init_weights(self): + super(MixerBlock, self).init_weights() + for m in self.token_mix.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + for m in self.channel_mix.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, x): + out = self.norm1(x).transpose(1, 2) + x = x + self.token_mix(out).transpose(1, 2) + x = self.channel_mix(self.norm2(x), identity=x) + return x + + +@MODELS.register_module() +class MlpMixer(BaseBackbone): + """Mlp-Mixer backbone. + + Pytorch implementation of `MLP-Mixer: An all-MLP Architecture for Vision + `_ + + Args: + arch (str | dict): MLP Mixer architecture. If use string, choose from + 'small', 'base' and 'large'. If use dict, it should have below + keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of MLP blocks. + - **tokens_mlp_dims** (int): The hidden dimensions for tokens FFNs. + - **channels_mlp_dims** (int): The The hidden dimensions for + channels FFNs. + + Defaults to 'base'. + img_size (int | tuple): The input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + out_indices (Sequence | int): Output from which layer. + Defaults to -1, means the last layer. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + act_cfg (dict): The activation config for FFNs. Default GELU. + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each mixer block layer. + Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
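+
+    Example:
+        >>> import torch
+        >>> from mmpretrain.models import build_backbone
+        >>>
+        >>> cfg = dict(type='MlpMixer', arch='b', img_size=224)
+        >>> backbone = build_backbone(cfg)
+        >>> outs = backbone(torch.rand(2, 3, 224, 224))
+        >>> # With the default ``out_indices=-1``, only the last layer's
+        >>> # output is returned, transposed to (N, embed_dims, num_tokens).
+        >>> print(outs[-1].shape)
+        torch.Size([2, 768, 196])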
+    """
+
+    arch_zoo = {
+        **dict.fromkeys(
+            ['s', 'small'], {
+                'embed_dims': 512,
+                'num_layers': 8,
+                'tokens_mlp_dims': 256,
+                'channels_mlp_dims': 2048,
+            }),
+        **dict.fromkeys(
+            ['b', 'base'], {
+                'embed_dims': 768,
+                'num_layers': 12,
+                'tokens_mlp_dims': 384,
+                'channels_mlp_dims': 3072,
+            }),
+        **dict.fromkeys(
+            ['l', 'large'], {
+                'embed_dims': 1024,
+                'num_layers': 24,
+                'tokens_mlp_dims': 512,
+                'channels_mlp_dims': 4096,
+            }),
+    }
+
+    def __init__(self,
+                 arch='base',
+                 img_size=224,
+                 patch_size=16,
+                 out_indices=-1,
+                 drop_rate=0.,
+                 drop_path_rate=0.,
+                 norm_cfg=dict(type='LN'),
+                 act_cfg=dict(type='GELU'),
+                 patch_cfg=dict(),
+                 layer_cfgs=dict(),
+                 init_cfg=None):
+        super(MlpMixer, self).__init__(init_cfg)
+
+        if isinstance(arch, str):
+            arch = arch.lower()
+            assert arch in set(self.arch_zoo), \
+                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+            self.arch_settings = self.arch_zoo[arch]
+        else:
+            essential_keys = {
+                'embed_dims', 'num_layers', 'tokens_mlp_dims',
+                'channels_mlp_dims'
+            }
+            assert isinstance(arch, dict) and set(arch) == essential_keys, \
+                f'Custom arch needs a dict with keys {essential_keys}'
+            self.arch_settings = arch
+
+        self.embed_dims = self.arch_settings['embed_dims']
+        self.num_layers = self.arch_settings['num_layers']
+        self.tokens_mlp_dims = self.arch_settings['tokens_mlp_dims']
+        self.channels_mlp_dims = self.arch_settings['channels_mlp_dims']
+
+        self.img_size = to_2tuple(img_size)
+
+        _patch_cfg = dict(
+            input_size=img_size,
+            embed_dims=self.embed_dims,
+            conv_type='Conv2d',
+            kernel_size=patch_size,
+            stride=patch_size,
+        )
+        _patch_cfg.update(patch_cfg)
+        self.patch_embed = PatchEmbed(**_patch_cfg)
+        self.patch_resolution = self.patch_embed.init_out_size
+        num_patches = self.patch_resolution[0] * self.patch_resolution[1]
+
+        if isinstance(out_indices, int):
+            out_indices = [out_indices]
+        assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+            f'got {type(out_indices)} instead.'
+        for i, index in enumerate(out_indices):
+            if index < 0:
+                out_indices[i] = self.num_layers + index
+                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
+            else:
+                # a non-negative index is valid only if it is smaller than
+                # the number of layers
+                assert index < self.num_layers, f'Invalid out_indices {index}'
+        self.out_indices = out_indices
+
+        self.layers = ModuleList()
+        if isinstance(layer_cfgs, dict):
+            layer_cfgs = [layer_cfgs] * self.num_layers
+        for i in range(self.num_layers):
+            _layer_cfg = dict(
+                num_tokens=num_patches,
+                embed_dims=self.embed_dims,
+                tokens_mlp_dims=self.tokens_mlp_dims,
+                channels_mlp_dims=self.channels_mlp_dims,
+                drop_rate=drop_rate,
+                drop_path_rate=drop_path_rate,
+                act_cfg=act_cfg,
+                norm_cfg=norm_cfg,
+            )
+            _layer_cfg.update(layer_cfgs[i])
+            self.layers.append(MixerBlock(**_layer_cfg))
+
+        self.norm1_name, norm1 = build_norm_layer(
+            norm_cfg, self.embed_dims, postfix=1)
+        self.add_module(self.norm1_name, norm1)
+
+    @property
+    def norm1(self):
+        return getattr(self, self.norm1_name)
+
+    def forward(self, x):
+        assert x.shape[2:] == self.img_size, \
+            "The MLP-Mixer doesn't support dynamic input shape.
" \ + f'Please input images with shape {self.img_size}' + x, _ = self.patch_embed(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1: + x = self.norm1(x) + + if i in self.out_indices: + out = x.transpose(1, 2) + outs.append(out) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/mobilenet_v2.py b/mmpretrain/models/backbones/mobilenet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..bca1418a13c4ed81c4666e7f53b0417c36b2e99b --- /dev/null +++ b/mmpretrain/models/backbones/mobilenet_v2.py @@ -0,0 +1,264 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.models.utils import make_divisible +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class InvertedResidual(BaseModule): + """InvertedResidual block for MobileNetV2. + + Args: + in_channels (int): The input channels of the InvertedResidual block. + out_channels (int): The output channels of the InvertedResidual block. + stride (int): Stride of the middle (first) 3x3 convolution. + expand_ratio (int): adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor + """ + + def __init__(self, + in_channels, + out_channels, + stride, + expand_ratio, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + with_cp=False, + init_cfg=None): + super(InvertedResidual, self).__init__(init_cfg) + self.stride = stride + assert stride in [1, 2], f'stride must in [1, 2]. ' \ + f'But received {stride}.' + self.with_cp = with_cp + self.use_res_connect = self.stride == 1 and in_channels == out_channels + hidden_dim = int(round(in_channels * expand_ratio)) + + layers = [] + if expand_ratio != 1: + layers.append( + ConvModule( + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + layers.extend([ + ConvModule( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=stride, + padding=1, + groups=hidden_dim, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + + def _inner_forward(x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@MODELS.register_module() +class MobileNetV2(BaseBackbone): + """MobileNetV2 backbone. + + Args: + widen_factor (float): Width multiplier, multiply number of + channels in each layer by this amount. Default: 1.0. + out_indices (None or Sequence[int]): Output from which stages. + Default: (7, ). 
+ frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + # Parameters to build layers. 4 parameters are needed to construct a + # layer, from left to right: expand_ratio, channel, num_blocks, stride. + arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2], + [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2], + [6, 320, 1, 1]] + + def __init__(self, + widen_factor=1., + out_indices=(7, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + norm_eval=False, + with_cp=False, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + super(MobileNetV2, self).__init__(init_cfg) + self.widen_factor = widen_factor + self.out_indices = out_indices + for index in out_indices: + if index not in range(0, 8): + raise ValueError('the item in out_indices must in ' + f'range(0, 8). But received {index}') + + if frozen_stages not in range(-1, 8): + raise ValueError('frozen_stages must be in range(-1, 8). ' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.in_channels = make_divisible(32 * widen_factor, 8) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.layers = [] + + for i, layer_cfg in enumerate(self.arch_settings): + expand_ratio, channel, num_blocks, stride = layer_cfg + out_channels = make_divisible(channel * widen_factor, 8) + inverted_res_layer = self.make_layer( + out_channels=out_channels, + num_blocks=num_blocks, + stride=stride, + expand_ratio=expand_ratio) + layer_name = f'layer{i + 1}' + self.add_module(layer_name, inverted_res_layer) + self.layers.append(layer_name) + + if widen_factor > 1.0: + self.out_channel = int(1280 * widen_factor) + else: + self.out_channel = 1280 + + layer = ConvModule( + in_channels=self.in_channels, + out_channels=self.out_channel, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.add_module('conv2', layer) + self.layers.append('conv2') + + def make_layer(self, out_channels, num_blocks, stride, expand_ratio): + """Stack InvertedResidual blocks to build a layer for MobileNetV2. + + Args: + out_channels (int): out_channels of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + expand_ratio (int): Expand the number of channels of the + hidden layer in InvertedResidual by this ratio. Default: 6. 
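+
+        Note:
+            Only the first block uses the given ``stride``; the remaining
+            ``num_blocks - 1`` blocks use stride 1 and keep ``out_channels``.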
+ """ + layers = [] + for i in range(num_blocks): + if i >= 1: + stride = 1 + layers.append( + InvertedResidual( + self.in_channels, + out_channels, + stride, + expand_ratio=expand_ratio, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileNetV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpretrain/models/backbones/mobilenet_v3.py b/mmpretrain/models/backbones/mobilenet_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..577dba94040dec5ecda9388819b8b5205f307dce --- /dev/null +++ b/mmpretrain/models/backbones/mobilenet_v3.py @@ -0,0 +1,217 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import ConvModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.registry import MODELS +from ..utils import InvertedResidual +from .base_backbone import BaseBackbone + + +@MODELS.register_module() +class MobileNetV3(BaseBackbone): + """MobileNetV3 backbone. + + Args: + arch (str): Architecture of mobilnetv3, from {small, large}. + Default: small. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + out_indices (None or Sequence[int]): Output from which stages. + Default: None, which means output tensors from final stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. + Default: False. 
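+
+    Note:
+        When ``out_indices`` is None, it defaults to the last layer of the
+        chosen architecture: ``(12, )`` for the ``small`` variants and
+        ``(16, )`` for ``large``.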
+ """ + # Parameters to build each block: + # [kernel size, mid channels, out channels, with_se, act type, stride] + arch_settings = { + 'small': [[3, 16, 16, True, 'ReLU', 2], + [3, 72, 24, False, 'ReLU', 2], + [3, 88, 24, False, 'ReLU', 1], + [5, 96, 40, True, 'HSwish', 2], + [5, 240, 40, True, 'HSwish', 1], + [5, 240, 40, True, 'HSwish', 1], + [5, 120, 48, True, 'HSwish', 1], + [5, 144, 48, True, 'HSwish', 1], + [5, 288, 96, True, 'HSwish', 2], + [5, 576, 96, True, 'HSwish', 1], + [5, 576, 96, True, 'HSwish', 1]], + 'small_075': [[3, 16, 16, True, 'ReLU', 2], + [3, 72, 24, False, 'ReLU', 2], + [3, 88, 24, False, 'ReLU', 1], + [5, 96, 32, True, 'HSwish', 2], + [5, 192, 32, True, 'HSwish', 1], + [5, 192, 32, True, 'HSwish', 1], + [5, 96, 40, True, 'HSwish', 1], + [5, 120, 40, True, 'HSwish', 1], + [5, 240, 72, True, 'HSwish', 2], + [5, 432, 72, True, 'HSwish', 1], + [5, 432, 72, True, 'HSwish', 1]], + 'small_050': [[3, 16, 8, True, 'ReLU', 2], + [3, 40, 16, False, 'ReLU', 2], + [3, 56, 16, False, 'ReLU', 1], + [5, 64, 24, True, 'HSwish', 2], + [5, 144, 24, True, 'HSwish', 1], + [5, 144, 24, True, 'HSwish', 1], + [5, 72, 24, True, 'HSwish', 1], + [5, 72, 24, True, 'HSwish', 1], + [5, 144, 48, True, 'HSwish', 2], + [5, 288, 48, True, 'HSwish', 1], + [5, 288, 48, True, 'HSwish', 1]], + 'large': [[3, 16, 16, False, 'ReLU', 1], + [3, 64, 24, False, 'ReLU', 2], + [3, 72, 24, False, 'ReLU', 1], + [5, 72, 40, True, 'ReLU', 2], + [5, 120, 40, True, 'ReLU', 1], + [5, 120, 40, True, 'ReLU', 1], + [3, 240, 80, False, 'HSwish', 2], + [3, 200, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 480, 112, True, 'HSwish', 1], + [3, 672, 112, True, 'HSwish', 1], + [5, 672, 160, True, 'HSwish', 2], + [5, 960, 160, True, 'HSwish', 1], + [5, 960, 160, True, 'HSwish', 1]] + } # yapf: disable + + def __init__(self, + arch='small', + conv_cfg=None, + norm_cfg=dict(type='BN', eps=0.001, momentum=0.01), + out_indices=None, + frozen_stages=-1, + norm_eval=False, + with_cp=False, + init_cfg=[ + dict( + type='Kaiming', + layer=['Conv2d'], + nonlinearity='leaky_relu'), + dict(type='Normal', layer=['Linear'], std=0.01), + dict(type='Constant', layer=['BatchNorm2d'], val=1) + ]): + super(MobileNetV3, self).__init__(init_cfg) + assert arch in self.arch_settings + if out_indices is None: + out_indices = (12, ) if 'small' in arch else (16, ) + for order, index in enumerate(out_indices): + if index not in range(0, len(self.arch_settings[arch]) + 2): + raise ValueError( + 'the item in out_indices must in ' + f'range(0, {len(self.arch_settings[arch]) + 2}). ' + f'But received {index}') + + if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2): + raise ValueError('frozen_stages must be in range(-1, ' + f'{len(self.arch_settings[arch]) + 2}). 
' + f'But received {frozen_stages}') + self.arch = arch + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.layers = self._make_layer() + self.feat_dim = self.arch_settings[arch][-1][1] + + def _make_layer(self): + layers = [] + layer_setting = self.arch_settings[self.arch] + in_channels = 16 + + layer = ConvModule( + in_channels=3, + out_channels=in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + self.add_module('layer0', layer) + layers.append('layer0') + + for i, params in enumerate(layer_setting): + (kernel_size, mid_channels, out_channels, with_se, act, + stride) = params + if with_se: + se_cfg = dict( + channels=mid_channels, + ratio=4, + act_cfg=(dict(type='ReLU'), + dict( + type='HSigmoid', + bias=3, + divisor=6, + min_value=0, + max_value=1))) + else: + se_cfg = None + + layer = InvertedResidual( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type=act), + with_cp=self.with_cp) + in_channels = out_channels + layer_name = 'layer{}'.format(i + 1) + self.add_module(layer_name, layer) + layers.append(layer_name) + + # Build the last layer before pooling + # TODO: No dilation + layer = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + layer_name = 'layer{}'.format(len(layer_setting) + 1) + self.add_module(layer_name, layer) + layers.append(layer_name) + + return layers + + def forward(self, x): + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(0, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileNetV3, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpretrain/models/backbones/mobileone.py b/mmpretrain/models/backbones/mobileone.py new file mode 100644 index 0000000000000000000000000000000000000000..1111441af82d43a49d15ecbb5dc0778fc9f87596 --- /dev/null +++ b/mmpretrain/models/backbones/mobileone.py @@ -0,0 +1,515 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from official impl https://github.com/apple/ml-mobileone/blob/main/mobileone.py # noqa: E501 +from typing import Optional, Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer +from mmengine.model import BaseModule, ModuleList, Sequential +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.registry import MODELS +from ..utils.se_layer import SELayer +from .base_backbone import BaseBackbone + + +class MobileOneBlock(BaseModule): + """MobileOne block for MobileOne backbone. + + Args: + in_channels (int): The input channels of the block. + out_channels (int): The output channels of the block. + kernel_size (int): The kernel size of the convs in the block. 
If the + kernel size is large than 1, there will be a ``branch_scale`` in + the block. + num_convs (int): Number of the convolution branches in the block. + stride (int): Stride of convolution layers. Defaults to 1. + padding (int): Padding of the convolution layers. Defaults to 1. + dilation (int): Dilation of the convolution layers. Defaults to 1. + groups (int): Groups of the convolution layers. Defaults to 1. + se_cfg (None or dict): The configuration of the se module. + Defaults to None. + norm_cfg (dict): Configuration to construct and config norm layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='ReLU')``. + deploy (bool): Whether the model structure is in the deployment mode. + Defaults to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: int, + num_convs: int, + stride: int = 1, + padding: int = 1, + dilation: int = 1, + groups: int = 1, + se_cfg: Optional[dict] = None, + conv_cfg: Optional[dict] = None, + norm_cfg: Optional[dict] = dict(type='BN'), + act_cfg: Optional[dict] = dict(type='ReLU'), + deploy: bool = False, + init_cfg: Optional[dict] = None): + super(MobileOneBlock, self).__init__(init_cfg) + + assert se_cfg is None or isinstance(se_cfg, dict) + if se_cfg is not None: + self.se = SELayer(channels=out_channels, **se_cfg) + else: + self.se = nn.Identity() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.num_conv_branches = num_convs + self.stride = stride + self.padding = padding + self.se_cfg = se_cfg + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.deploy = deploy + self.groups = groups + self.dilation = dilation + + if deploy: + self.branch_reparam = build_conv_layer( + conv_cfg, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + groups=self.groups, + stride=stride, + padding=padding, + dilation=dilation, + bias=True) + else: + # judge if input shape and output shape are the same. + # If true, add a normalized identity shortcut. 
+            if out_channels == in_channels and stride == 1:
+                self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1]
+            else:
+                self.branch_norm = None
+
+            self.branch_scale = None
+            if kernel_size > 1:
+                self.branch_scale = self.create_conv_bn(kernel_size=1)
+
+            self.branch_conv_list = ModuleList()
+            for _ in range(num_convs):
+                self.branch_conv_list.append(
+                    self.create_conv_bn(
+                        kernel_size=kernel_size,
+                        padding=padding,
+                        dilation=dilation))
+
+        self.act = build_activation_layer(act_cfg)
+
+    def create_conv_bn(self, kernel_size, dilation=1, padding=0):
+        """Create a (conv + bn) Sequential layer."""
+        conv_bn = Sequential()
+        conv_bn.add_module(
+            'conv',
+            build_conv_layer(
+                self.conv_cfg,
+                in_channels=self.in_channels,
+                out_channels=self.out_channels,
+                kernel_size=kernel_size,
+                groups=self.groups,
+                stride=self.stride,
+                dilation=dilation,
+                padding=padding,
+                bias=False))
+        conv_bn.add_module(
+            'norm',
+            build_norm_layer(self.norm_cfg, num_features=self.out_channels)[1])
+
+        return conv_bn
+
+    def forward(self, x):
+
+        def _inner_forward(inputs):
+            if self.deploy:
+                return self.branch_reparam(inputs)
+
+            inner_out = 0
+            if self.branch_norm is not None:
+                inner_out = self.branch_norm(inputs)
+
+            if self.branch_scale is not None:
+                inner_out += self.branch_scale(inputs)
+
+            for branch_conv in self.branch_conv_list:
+                inner_out += branch_conv(inputs)
+
+            return inner_out
+
+        return self.act(self.se(_inner_forward(x)))
+
+    def switch_to_deploy(self):
+        """Switch the model structure from training mode to deployment
+        mode."""
+        if self.deploy:
+            return
+        assert self.norm_cfg['type'] == 'BN', \
+            "Switch is not allowed when norm_cfg['type'] != 'BN'."
+
+        reparam_weight, reparam_bias = self.reparameterize()
+        self.branch_reparam = build_conv_layer(
+            self.conv_cfg,
+            self.in_channels,
+            self.out_channels,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+            groups=self.groups,
+            bias=True)
+        self.branch_reparam.weight.data = reparam_weight
+        self.branch_reparam.bias.data = reparam_bias
+
+        for param in self.parameters():
+            param.detach_()
+        delattr(self, 'branch_conv_list')
+        if hasattr(self, 'branch_scale'):
+            delattr(self, 'branch_scale')
+        delattr(self, 'branch_norm')
+
+        self.deploy = True
+
+    def reparameterize(self):
+        """Fuse all the parameters of all branches.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: Parameters after fusion of all
+                branches. The first element is the weights and the second is
+                the bias.
+        """
+        weight_conv, bias_conv = 0, 0
+        for branch_conv in self.branch_conv_list:
+            weight, bias = self._fuse_conv_bn(branch_conv)
+            weight_conv += weight
+            bias_conv += bias
+
+        weight_scale, bias_scale = 0, 0
+        if self.branch_scale is not None:
+            weight_scale, bias_scale = self._fuse_conv_bn(self.branch_scale)
+            # Pad scale branch kernel to match conv branch kernel size.
+            pad = self.kernel_size // 2
+            weight_scale = F.pad(weight_scale, [pad, pad, pad, pad])
+
+        weight_norm, bias_norm = 0, 0
+        if self.branch_norm:
+            tmp_conv_bn = self._norm_to_conv(self.branch_norm)
+            weight_norm, bias_norm = self._fuse_conv_bn(tmp_conv_bn)
+
+        return (weight_conv + weight_scale + weight_norm,
+                bias_conv + bias_scale + bias_norm)
+
+    def _fuse_conv_bn(self, branch):
+        """Fuse the parameters in a branch with a conv and bn.
+
+        Args:
+            branch (mmcv.runner.Sequential): A branch with conv and bn.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: The parameters obtained after
+                fusing the parameters of conv and bn in one branch.
+                The first element is the weight and the second is the bias.
+        """
+        if branch is None:
+            return 0, 0
+        kernel = branch.conv.weight
+        running_mean = branch.norm.running_mean
+        running_var = branch.norm.running_var
+        gamma = branch.norm.weight
+        beta = branch.norm.bias
+        eps = branch.norm.eps
+
+        std = (running_var + eps).sqrt()
+        fused_weight = (gamma / std).reshape(-1, 1, 1, 1) * kernel
+        fused_bias = beta - running_mean * gamma / std
+
+        return fused_weight, fused_bias
+
+    def _norm_to_conv(self, branch_norm):
+        """Convert a norm layer to a conv-bn sequence towards
+        ``self.kernel_size``.
+
+        Args:
+            branch_norm (nn.BatchNorm2d): A branch only with bn in the block.
+
+        Returns:
+            (mmcv.runner.Sequential): a sequential with conv and bn.
+        """
+        input_dim = self.in_channels // self.groups
+        conv_weight = torch.zeros(
+            (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
+            dtype=branch_norm.weight.dtype)
+
+        for i in range(self.in_channels):
+            conv_weight[i, i % input_dim, self.kernel_size // 2,
+                        self.kernel_size // 2] = 1
+        conv_weight = conv_weight.to(branch_norm.weight.device)
+
+        tmp_conv = self.create_conv_bn(kernel_size=self.kernel_size)
+        tmp_conv.conv.weight.data = conv_weight
+        tmp_conv.norm = branch_norm
+        return tmp_conv
+
+
+@MODELS.register_module()
+class MobileOne(BaseBackbone):
+    """MobileOne backbone.
+
+    A PyTorch implementation of: `An Improved One millisecond Mobile Backbone
+    <https://arxiv.org/abs/2206.04040>`_
+
+    Args:
+        arch (str | dict): MobileOne architecture. If use string, choose
+            from 's0', 's1', 's2', 's3' and 's4'. If use dict, it should
+            have below keys:
+
+            - num_blocks (Sequence[int]): Number of blocks in each stage.
+            - width_factor (Sequence[float]): Width factor in each stage.
+            - num_conv_branches (Sequence[int]): Number of conv branches
+              in each stage.
+            - num_se_blocks (Sequence[int]): Number of SE layers in each
+              stage, all the SE layers are placed in the subsequent order
+              in each stage.
+
+            Defaults to 's0'.
+        in_channels (int): Number of input image channels. Default: 3.
+        out_indices (Sequence[int] | int): Output from which stages.
+            Defaults to ``(3, )``.
+        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+            not freezing any parameters. Defaults to -1.
+        conv_cfg (dict | None): The config dict for conv layers.
+            Defaults to None.
+        norm_cfg (dict): The config dict for norm layers.
+            Defaults to ``dict(type='BN')``.
+        act_cfg (dict): Config dict for activation layer.
+            Defaults to ``dict(type='ReLU')``.
+        deploy (bool): Whether to switch the model structure to deployment
+            mode. Defaults to False.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Defaults to False.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+
+    Example:
+        >>> from mmpretrain.models import MobileOne
+        >>> import torch
+        >>> x = torch.rand(1, 3, 224, 224)
+        >>> model = MobileOne("s0", out_indices=(0, 1, 2, 3))
+        >>> model.eval()
+        >>> outputs = model(x)
+        >>> for out in outputs:
+        ...     print(tuple(out.shape))
+        (1, 48, 56, 56)
+        (1, 128, 28, 28)
+        (1, 256, 14, 14)
+        (1, 1024, 7, 7)
+    """
+
+    arch_zoo = {
+        's0':
+        dict(
+            num_blocks=[2, 8, 10, 1],
+            width_factor=[0.75, 1.0, 1.0, 2.0],
+            num_conv_branches=[4, 4, 4, 4],
+            num_se_blocks=[0, 0, 0, 0]),
+        's1':
+        dict(
+            num_blocks=[2, 8, 10, 1],
+            width_factor=[1.5, 1.5, 2.0, 2.5],
+            num_conv_branches=[1, 1, 1, 1],
+            num_se_blocks=[0, 0, 0, 0]),
+        's2':
+        dict(
+            num_blocks=[2, 8, 10, 1],
+            width_factor=[1.5, 2.0, 2.5, 4.0],
+            num_conv_branches=[1, 1, 1, 1],
+            num_se_blocks=[0, 0, 0, 0]),
+        's3':
+        dict(
+            num_blocks=[2, 8, 10, 1],
+            width_factor=[2.0, 2.5, 3.0, 4.0],
+            num_conv_branches=[1, 1, 1, 1],
+            num_se_blocks=[0, 0, 0, 0]),
+        's4':
+        dict(
+            num_blocks=[2, 8, 10, 1],
+            width_factor=[3.0, 3.5, 3.5, 4.0],
+            num_conv_branches=[1, 1, 1, 1],
+            num_se_blocks=[0, 0, 5, 1])
+    }
+
+    def __init__(self,
+                 arch,
+                 in_channels=3,
+                 out_indices=(3, ),
+                 frozen_stages=-1,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 se_cfg=dict(ratio=16),
+                 deploy=False,
+                 norm_eval=False,
+                 init_cfg=[
+                     dict(type='Kaiming', layer=['Conv2d']),
+                     dict(type='Constant', val=1, layer=['_BatchNorm'])
+                 ]):
+        super(MobileOne, self).__init__(init_cfg)
+
+        if isinstance(arch, str):
+            assert arch in self.arch_zoo, f'"arch": "{arch}"' \
+                f' is not one of the {list(self.arch_zoo.keys())}'
+            arch = self.arch_zoo[arch]
+        elif not isinstance(arch, dict):
+            raise TypeError('Expect "arch" to be either a string '
+                            f'or a dict, got {type(arch)}')
+
+        self.arch = arch
+        for k, value in self.arch.items():
+            assert isinstance(value, list) and len(value) == 4, \
+                f'the value of {k} in arch must be a list with 4 items.'
+
+        self.in_channels = in_channels
+        self.deploy = deploy
+        self.frozen_stages = frozen_stages
+        self.norm_eval = norm_eval
+
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.se_cfg = se_cfg
+        self.act_cfg = act_cfg
+
+        base_channels = [64, 128, 256, 512]
+        channels = min(64,
+                       int(base_channels[0] * self.arch['width_factor'][0]))
+        self.stage0 = MobileOneBlock(
+            self.in_channels,
+            channels,
+            stride=2,
+            kernel_size=3,
+            num_convs=1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg,
+            deploy=deploy)
+
+        self.in_planes = channels
+        self.stages = []
+        for i, num_blocks in enumerate(self.arch['num_blocks']):
+            planes = int(base_channels[i] * self.arch['width_factor'][i])
+
+            stage = self._make_stage(planes, num_blocks,
+                                     arch['num_se_blocks'][i],
+                                     arch['num_conv_branches'][i])
+
+            stage_name = f'stage{i + 1}'
+            self.add_module(stage_name, stage)
+            self.stages.append(stage_name)
+
+        if isinstance(out_indices, int):
+            out_indices = [out_indices]
+        assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+            f'got {type(out_indices)} instead.'
+        out_indices = list(out_indices)
+        for i, index in enumerate(out_indices):
+            if index < 0:
+                out_indices[i] = len(self.stages) + index
+            assert 0 <= out_indices[i] <= len(self.stages), \
+                f'Invalid out_indices {index}.'
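+        # negative values have been converted to absolute stage indices above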
+        self.out_indices = out_indices
+
+    def _make_stage(self, planes, num_blocks, num_se, num_conv_branches):
+        strides = [2] + [1] * (num_blocks - 1)
+        if num_se > num_blocks:
+            raise ValueError('Number of SE blocks cannot '
+                             'exceed number of layers.')
+        blocks = []
+        for i in range(num_blocks):
+            use_se = False
+            if i >= (num_blocks - num_se):
+                use_se = True
+
+            blocks.append(
+                # Depthwise conv
+                MobileOneBlock(
+                    in_channels=self.in_planes,
+                    out_channels=self.in_planes,
+                    kernel_size=3,
+                    num_convs=num_conv_branches,
+                    stride=strides[i],
+                    padding=1,
+                    groups=self.in_planes,
+                    se_cfg=self.se_cfg if use_se else None,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg,
+                    deploy=self.deploy))
+
+            blocks.append(
+                # Pointwise conv
+                MobileOneBlock(
+                    in_channels=self.in_planes,
+                    out_channels=planes,
+                    kernel_size=1,
+                    num_convs=num_conv_branches,
+                    stride=1,
+                    padding=0,
+                    se_cfg=self.se_cfg if use_se else None,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg,
+                    deploy=self.deploy))
+
+            self.in_planes = planes
+
+        return Sequential(*blocks)
+
+    def forward(self, x):
+        x = self.stage0(x)
+        outs = []
+        for i, stage_name in enumerate(self.stages):
+            stage = getattr(self, stage_name)
+            x = stage(x)
+            if i in self.out_indices:
+                outs.append(x)
+
+        return tuple(outs)
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            self.stage0.eval()
+            for param in self.stage0.parameters():
+                param.requires_grad = False
+            for i in range(self.frozen_stages):
+                stage = getattr(self, f'stage{i+1}')
+                stage.eval()
+                for param in stage.parameters():
+                    param.requires_grad = False
+
+    def train(self, mode=True):
+        """Switch the model to train mode or not."""
+        super(MobileOne, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                if isinstance(m, _BatchNorm):
+                    m.eval()
+
+    def switch_to_deploy(self):
+        """Switch the model to deploy mode, which has a smaller number of
+        parameters and calculations."""
+        for m in self.modules():
+            if isinstance(m, MobileOneBlock):
+                m.switch_to_deploy()
+        self.deploy = True
diff --git a/mmpretrain/models/backbones/mobilevit.py b/mmpretrain/models/backbones/mobilevit.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e4043fe46049a4d1bddecc6b7b3768236318e82
--- /dev/null
+++ b/mmpretrain/models/backbones/mobilevit.py
@@ -0,0 +1,431 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Callable, Optional, Sequence
+
+import torch
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, build_norm_layer
+from torch import nn
+
+from mmpretrain.registry import MODELS
+from .base_backbone import BaseBackbone
+from .mobilenet_v2 import InvertedResidual
+from .vision_transformer import TransformerEncoderLayer
+
+
+class MobileVitBlock(nn.Module):
+    """MobileViT block.
+
+    According to the paper, the MobileViT block has a local representation,
+    a transformer-as-convolution layer which consists of a global
+    representation with unfolding and folding, and a final fusion layer.
+
+    Args:
+        in_channels (int): Number of input image channels.
+        transformer_dim (int): Number of transformer channels.
+        ffn_dim (int): Number of ffn channels in transformer block.
+        out_channels (int): Number of channels in output.
+        conv_ksize (int): Conv kernel size in local representation
+            and fusion. Defaults to 3.
+        conv_cfg (dict, optional): Config dict for convolution layer.
+            Defaults to None, which means using conv2d.
+ norm_cfg (dict, optional): Config dict for normalization layer. + Defaults to dict(type='BN'). + act_cfg (dict, optional): Config dict for activation layer. + Defaults to dict(type='Swish'). + num_transformer_blocks (int): Number of transformer blocks in + a MobileViT block. Defaults to 2. + patch_size (int): Patch size for unfolding and folding. + Defaults to 2. + num_heads (int): Number of heads in global representation. + Defaults to 4. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + no_fusion (bool): Whether to remove the fusion layer. + Defaults to False. + transformer_norm_cfg (dict, optional): Config dict for normalization + layer in transformer. Defaults to dict(type='LN'). + """ + + def __init__( + self, + in_channels: int, + transformer_dim: int, + ffn_dim: int, + out_channels: int, + conv_ksize: int = 3, + conv_cfg: Optional[dict] = None, + norm_cfg: Optional[dict] = dict(type='BN'), + act_cfg: Optional[dict] = dict(type='Swish'), + num_transformer_blocks: int = 2, + patch_size: int = 2, + num_heads: int = 4, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + no_fusion: bool = False, + transformer_norm_cfg: Callable = dict(type='LN'), + ): + super(MobileVitBlock, self).__init__() + + self.local_rep = nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=conv_ksize, + padding=int((conv_ksize - 1) / 2), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=in_channels, + out_channels=transformer_dim, + kernel_size=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=None, + act_cfg=None), + ) + + global_rep = [ + TransformerEncoderLayer( + embed_dims=transformer_dim, + num_heads=num_heads, + feedforward_channels=ffn_dim, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + qkv_bias=True, + act_cfg=dict(type='Swish'), + norm_cfg=transformer_norm_cfg) + for _ in range(num_transformer_blocks) + ] + global_rep.append( + build_norm_layer(transformer_norm_cfg, transformer_dim)[1]) + self.global_rep = nn.Sequential(*global_rep) + + self.conv_proj = ConvModule( + in_channels=transformer_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + if no_fusion: + self.conv_fusion = None + else: + self.conv_fusion = ConvModule( + in_channels=in_channels + out_channels, + out_channels=out_channels, + kernel_size=conv_ksize, + padding=int((conv_ksize - 1) / 2), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.patch_size = (patch_size, patch_size) + self.patch_area = self.patch_size[0] * self.patch_size[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + + # Local representation + x = self.local_rep(x) + + # Unfold (feature map -> patches) + patch_h, patch_w = self.patch_size + B, C, H, W = x.shape + new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil( + W / patch_w) * patch_w + num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w # n_h, n_w # noqa + num_patches = num_patch_h * num_patch_w # N + interpolate = False + if new_h != H or new_w != W: + # Note: Padding can be done, but then it needs to be handled in attention function. 
+            x = F.interpolate(
+                x, size=(new_h, new_w), mode='bilinear', align_corners=False)
+            interpolate = True
+
+        # [B, C, H, W] --> [B * C * n_h, n_w, p_h, p_w]
+        x = x.reshape(B * C * num_patch_h, patch_h, num_patch_w,
+                      patch_w).transpose(1, 2)
+        # [B * C * n_h, n_w, p_h, p_w] --> [BP, N, C]
+        # where P = p_h * p_w and N = n_h * n_w
+        x = x.reshape(B, C, num_patches,
+                      self.patch_area).transpose(1, 3).reshape(
+                          B * self.patch_area, num_patches, -1)
+
+        # Global representations
+        x = self.global_rep(x)
+
+        # Fold (patch -> feature map)
+        # [B, P, N, C] --> [B*C*n_h, n_w, p_h, p_w]
+        x = x.contiguous().view(B, self.patch_area, num_patches, -1)
+        x = x.transpose(1, 3).reshape(B * C * num_patch_h, num_patch_w,
+                                      patch_h, patch_w)
+        # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w]
+        # --> [B, C, H, W]
+        x = x.transpose(1, 2).reshape(B, C, num_patch_h * patch_h,
+                                      num_patch_w * patch_w)
+        if interpolate:
+            x = F.interpolate(
+                x, size=(H, W), mode='bilinear', align_corners=False)
+
+        x = self.conv_proj(x)
+        if self.conv_fusion is not None:
+            x = self.conv_fusion(torch.cat((shortcut, x), dim=1))
+        return x
+
+
+@MODELS.register_module()
+class MobileViT(BaseBackbone):
+    """MobileViT backbone.
+
+    A PyTorch implementation of: `MobileViT: Light-weight, General-purpose,
+    and Mobile-friendly Vision Transformer
+    <https://arxiv.org/abs/2110.02178>`_
+
+    Modified from the `official repo
+    <https://github.com/apple/ml-cvnets>`_ and `timm
+    <https://github.com/rwightman/pytorch-image-models>`_.
+
+    Args:
+        arch (str | List[list]): Architecture of MobileViT.
+
+            - If a string, choose from "small", "x_small" and "xx_small".
+
+            - If a list, every item should be also a list, and the first item
+              of the sub-list can be chosen from "mobilenetv2" and "mobilevit",
+              which indicates the type of this layer sequence. If "mobilenetv2",
+              the other items are the arguments of
+              :attr:`~MobileViT.make_mobilenetv2_layer` (except ``in_channels``)
+              and if "mobilevit", the other items are the arguments of
+              :attr:`~MobileViT.make_mobilevit_layer` (except ``in_channels``).
+
+            Defaults to "small".
+        in_channels (int): Number of input image channels. Defaults to 3.
+        stem_channels (int): Channels of stem layer. Defaults to 16.
+        last_exp_factor (int): Channels expand factor of last layer.
+            Defaults to 4.
+        out_indices (Sequence[int]): Output from which stages.
+            Defaults to (4, ).
+        frozen_stages (int): Stages to be frozen (all param fixed).
+            Defaults to -1, which means not freezing any parameters.
+        conv_cfg (dict, optional): Config dict for convolution layer.
+            Defaults to None, which means using conv2d.
+        norm_cfg (dict, optional): Config dict for normalization layer.
+            Defaults to dict(type='BN').
+        act_cfg (dict, optional): Config dict for activation layer.
+            Defaults to dict(type='Swish').
+        init_cfg (dict, optional): Initialization config dict.
+    """ # noqa
+
+    # Parameters to build layers. The first param is the type of layer.
+    # For `mobilenetv2` layer, the rest params from left to right are:
+    #     out channels, stride, num of blocks, expand_ratio.
+    # For `mobilevit` layer, the rest params from left to right are:
+    #     out channels, stride, transformer_channels, ffn channels,
+    #     num of transformer blocks, expand_ratio.
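+    # e.g. ['mobilevit', 96, 2, 144, 288, 2, 4] builds a stride-2 stage with
+    # 96 output channels, a transformer dim of 144, an ffn dim of 288,
+    # 2 transformer blocks and expand_ratio 4.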
+ arch_settings = { + 'small': [ + ['mobilenetv2', 32, 1, 1, 4], + ['mobilenetv2', 64, 2, 3, 4], + ['mobilevit', 96, 2, 144, 288, 2, 4], + ['mobilevit', 128, 2, 192, 384, 4, 4], + ['mobilevit', 160, 2, 240, 480, 3, 4], + ], + 'x_small': [ + ['mobilenetv2', 32, 1, 1, 4], + ['mobilenetv2', 48, 2, 3, 4], + ['mobilevit', 64, 2, 96, 192, 2, 4], + ['mobilevit', 80, 2, 120, 240, 4, 4], + ['mobilevit', 96, 2, 144, 288, 3, 4], + ], + 'xx_small': [ + ['mobilenetv2', 16, 1, 1, 2], + ['mobilenetv2', 24, 2, 3, 2], + ['mobilevit', 48, 2, 64, 128, 2, 2], + ['mobilevit', 64, 2, 80, 160, 4, 2], + ['mobilevit', 80, 2, 96, 192, 3, 2], + ] + } + + def __init__(self, + arch='small', + in_channels=3, + stem_channels=16, + last_exp_factor=4, + out_indices=(4, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='Swish'), + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + super(MobileViT, self).__init__(init_cfg) + if isinstance(arch, str): + arch = arch.lower() + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a list.' + arch = self.arch_settings[arch] + + self.arch = arch + self.num_stages = len(arch) + + # check out indices and frozen stages + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_stages + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + if frozen_stages not in range(-1, self.num_stages): + raise ValueError('frozen_stages must be in range(-1, ' + f'{self.num_stages}). ' + f'But received {frozen_stages}') + self.frozen_stages = frozen_stages + + _make_layer_func = { + 'mobilenetv2': self.make_mobilenetv2_layer, + 'mobilevit': self.make_mobilevit_layer, + } + + self.stem = ConvModule( + in_channels=in_channels, + out_channels=stem_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + in_channels = stem_channels + layers = [] + for i, layer_settings in enumerate(arch): + layer_type, settings = layer_settings[0], layer_settings[1:] + layer, out_channels = _make_layer_func[layer_type](in_channels, + *settings) + layers.append(layer) + in_channels = out_channels + self.layers = nn.Sequential(*layers) + + self.conv_1x1_exp = ConvModule( + in_channels=in_channels, + out_channels=last_exp_factor * in_channels, + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + @staticmethod + def make_mobilevit_layer(in_channels, + out_channels, + stride, + transformer_dim, + ffn_dim, + num_transformer_blocks, + expand_ratio=4): + """Build mobilevit layer, which consists of one InvertedResidual and + one MobileVitBlock. + + Args: + in_channels (int): The input channels. + out_channels (int): The output channels. + stride (int): The stride of the first 3x3 convolution in the + ``InvertedResidual`` layers. + transformer_dim (int): The channels of the transformer layers. + ffn_dim (int): The mid-channels of the feedforward network in + transformer layers. + num_transformer_blocks (int): The number of transformer blocks. + expand_ratio (int): adjusts number of channels of the hidden layer + in ``InvertedResidual`` by this amount. Defaults to 4. 
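+
+        Returns:
+            tuple[nn.Sequential, int]: The built layer sequence and its
+                output channels.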
+ """ + layer = [] + layer.append( + InvertedResidual( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + expand_ratio=expand_ratio, + act_cfg=dict(type='Swish'), + )) + layer.append( + MobileVitBlock( + in_channels=out_channels, + transformer_dim=transformer_dim, + ffn_dim=ffn_dim, + out_channels=out_channels, + num_transformer_blocks=num_transformer_blocks, + )) + return nn.Sequential(*layer), out_channels + + @staticmethod + def make_mobilenetv2_layer(in_channels, + out_channels, + stride, + num_blocks, + expand_ratio=4): + """Build mobilenetv2 layer, which consists of several InvertedResidual + layers. + + Args: + in_channels (int): The input channels. + out_channels (int): The output channels. + stride (int): The stride of the first 3x3 convolution in the + ``InvertedResidual`` layers. + num_blocks (int): The number of ``InvertedResidual`` blocks. + expand_ratio (int): adjusts number of channels of the hidden layer + in ``InvertedResidual`` by this amount. Defaults to 4. + """ + layer = [] + for i in range(num_blocks): + stride = stride if i == 0 else 1 + + layer.append( + InvertedResidual( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + expand_ratio=expand_ratio, + act_cfg=dict(type='Swish'), + )) + in_channels = out_channels + return nn.Sequential(*layer), out_channels + + def _freeze_stages(self): + for i in range(0, self.frozen_stages): + layer = self.layers[i] + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileViT, self).train(mode) + self._freeze_stages() + + def forward(self, x): + x = self.stem(x) + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i == len(self.layers) - 1: + x = self.conv_1x1_exp(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/mvit.py b/mmpretrain/models/backbones/mvit.py new file mode 100644 index 0000000000000000000000000000000000000000..68aee97ddf3077ca58e488f38e9d9422b171d691 --- /dev/null +++ b/mmpretrain/models/backbones/mvit.py @@ -0,0 +1,700 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmcv.cnn.bricks.transformer import PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ +from mmengine.utils import to_2tuple + +from ..builder import BACKBONES +from ..utils import resize_pos_embed +from .base_backbone import BaseBackbone + + +def resize_decomposed_rel_pos(rel_pos, q_size, k_size): + """Get relative positional embeddings according to the relative positions + of query and key sizes. + + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. 
+ resized = F.interpolate( + # (L, C) -> (1, C, L) + rel_pos.transpose(0, 1).unsqueeze(0), + size=max_rel_dist, + mode='linear', + ) + # (1, C, L) -> (L, C) + resized = resized.squeeze(0).transpose(0, 1) + else: + resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_h_ratio = max(k_size / q_size, 1.0) + k_h_ratio = max(q_size / k_size, 1.0) + q_coords = torch.arange(q_size)[:, None] * q_h_ratio + k_coords = torch.arange(k_size)[None, :] * k_h_ratio + relative_coords = (q_coords - k_coords) + (k_size - 1) * k_h_ratio + + return resized[relative_coords.long()] + + +def add_decomposed_rel_pos(attn, + q, + q_shape, + k_shape, + rel_pos_h, + rel_pos_w, + has_cls_token=False): + """Spatial Relative Positional Embeddings.""" + sp_idx = 1 if has_cls_token else 0 + B, num_heads, _, C = q.shape + q_h, q_w = q_shape + k_h, k_w = k_shape + + Rh = resize_decomposed_rel_pos(rel_pos_h, q_h, k_h) + Rw = resize_decomposed_rel_pos(rel_pos_w, q_w, k_w) + + r_q = q[:, :, sp_idx:].reshape(B, num_heads, q_h, q_w, C) + rel_h = torch.einsum('byhwc,hkc->byhwk', r_q, Rh) + rel_w = torch.einsum('byhwc,wkc->byhwk', r_q, Rw) + rel_pos_embed = rel_h[:, :, :, :, :, None] + rel_w[:, :, :, :, None, :] + + attn_map = attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w) + attn_map += rel_pos_embed + attn[:, :, sp_idx:, sp_idx:] = attn_map.view(B, -1, q_h * q_w, k_h * k_w) + + return attn + + +class MLP(BaseModule): + """Two-layer multilayer perceptron. + + Comparing with :class:`mmcv.cnn.bricks.transformer.FFN`, this class allows + different input and output channel numbers. + + Args: + in_channels (int): The number of input channels. + hidden_channels (int, optional): The number of hidden layer channels. + If None, same as the ``in_channels``. Defaults to None. + out_channels (int, optional): The number of output channels. If None, + same as the ``in_channels``. Defaults to None. + act_cfg (dict): The config of activation function. + Defaults to ``dict(type='GELU')``. + init_cfg (dict, optional): The config of weight initialization. + Defaults to None. + """ + + def __init__(self, + in_channels, + hidden_channels=None, + out_channels=None, + act_cfg=dict(type='GELU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + out_channels = out_channels or in_channels + hidden_channels = hidden_channels or in_channels + self.fc1 = nn.Linear(in_channels, hidden_channels) + self.act = build_activation_layer(act_cfg) + self.fc2 = nn.Linear(hidden_channels, out_channels) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +def attention_pool(x: torch.Tensor, + pool: nn.Module, + in_size: tuple, + norm: Optional[nn.Module] = None): + """Pooling the feature tokens. + + Args: + x (torch.Tensor): The input tensor, should be with shape + ``(B, num_heads, L, C)`` or ``(B, L, C)``. + pool (nn.Module): The pooling module. + in_size (Tuple[int]): The shape of the input feature map. + norm (nn.Module, optional): The normalization module. + Defaults to None. 
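+
+    Returns:
+        tuple[torch.Tensor, tuple]: The pooled tokens and the output
+            spatial size ``(H', W')``.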
+ """ + ndim = x.ndim + if ndim == 4: + B, num_heads, L, C = x.shape + elif ndim == 3: + num_heads = 1 + B, L, C = x.shape + else: + raise RuntimeError(f'Unsupported input dimension {x.shape}') + + H, W = in_size + assert L == H * W + + # (B, num_heads, H*W, C) -> (B*num_heads, C, H, W) + x = x.reshape(B * num_heads, H, W, C).permute(0, 3, 1, 2).contiguous() + x = pool(x) + out_size = x.shape[-2:] + + # (B*num_heads, C, H', W') -> (B, num_heads, H'*W', C) + x = x.reshape(B, num_heads, C, -1).transpose(2, 3) + + if norm is not None: + x = norm(x) + + if ndim == 3: + x = x.squeeze(1) + + return x, out_size + + +class MultiScaleAttention(BaseModule): + """Multiscale Multi-head Attention block. + + Args: + in_dims (int): Number of input channels. + out_dims (int): Number of output channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key and + value. Defaults to True. + norm_cfg (dict): The config of normalization layers. + Defaults to ``dict(type='LN')``. + pool_kernel (tuple): kernel size for qkv pooling layers. + Defaults to (3, 3). + stride_q (int): stride size for q pooling layer. Defaults to 1. + stride_kv (int): stride size for kv pooling layer. Defaults to 1. + rel_pos_spatial (bool): Whether to enable the spatial relative + position embedding. Defaults to True. + residual_pooling (bool): Whether to enable the residual connection + after attention pooling. Defaults to True. + input_size (Tuple[int], optional): The input resolution, necessary + if enable the ``rel_pos_spatial``. Defaults to None. + rel_pos_zero_init (bool): If True, zero initialize relative + positional parameters. Defaults to False. + init_cfg (dict, optional): The config of weight initialization. + Defaults to None. + """ + + def __init__(self, + in_dims, + out_dims, + num_heads, + qkv_bias=True, + norm_cfg=dict(type='LN'), + pool_kernel=(3, 3), + stride_q=1, + stride_kv=1, + rel_pos_spatial=False, + residual_pooling=True, + input_size=None, + rel_pos_zero_init=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.num_heads = num_heads + self.in_dims = in_dims + self.out_dims = out_dims + + head_dim = out_dims // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(in_dims, out_dims * 3, bias=qkv_bias) + self.proj = nn.Linear(out_dims, out_dims) + + # qkv pooling + pool_padding = [k // 2 for k in pool_kernel] + pool_dims = out_dims // num_heads + + def build_pooling(stride): + pool = nn.Conv2d( + pool_dims, + pool_dims, + pool_kernel, + stride=stride, + padding=pool_padding, + groups=pool_dims, + bias=False, + ) + norm = build_norm_layer(norm_cfg, pool_dims)[1] + return pool, norm + + self.pool_q, self.norm_q = build_pooling(stride_q) + self.pool_k, self.norm_k = build_pooling(stride_kv) + self.pool_v, self.norm_v = build_pooling(stride_kv) + + self.residual_pooling = residual_pooling + + self.rel_pos_spatial = rel_pos_spatial + self.rel_pos_zero_init = rel_pos_zero_init + if self.rel_pos_spatial: + # initialize relative positional embeddings + assert input_size[0] == input_size[1] + + size = input_size[0] + rel_dim = 2 * max(size // stride_q, size // stride_kv) - 1 + self.rel_pos_h = nn.Parameter(torch.zeros(rel_dim, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_dim, head_dim)) + + def init_weights(self): + """Weight initialization.""" + super().init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress rel_pos_zero_init if use pretrained model. 
+ return + + if not self.rel_pos_zero_init: + trunc_normal_(self.rel_pos_h, std=0.02) + trunc_normal_(self.rel_pos_w, std=0.02) + + def forward(self, x, in_size): + """Forward the MultiScaleAttention.""" + B, N, _ = x.shape # (B, H*W, C) + + # qkv: (B, H*W, 3, num_heads, C) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1) + # q, k, v: (B, num_heads, H*W, C) + q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0) + + q, q_shape = attention_pool(q, self.pool_q, in_size, norm=self.norm_q) + k, k_shape = attention_pool(k, self.pool_k, in_size, norm=self.norm_k) + v, v_shape = attention_pool(v, self.pool_v, in_size, norm=self.norm_v) + + attn = (q * self.scale) @ k.transpose(-2, -1) + if self.rel_pos_spatial: + attn = add_decomposed_rel_pos(attn, q, q_shape, k_shape, + self.rel_pos_h, self.rel_pos_w) + + attn = attn.softmax(dim=-1) + x = attn @ v + + if self.residual_pooling: + x = x + q + + # (B, num_heads, H'*W', C'//num_heads) -> (B, H'*W', C') + x = x.transpose(1, 2).reshape(B, -1, self.out_dims) + x = self.proj(x) + + return x, q_shape + + +class MultiScaleBlock(BaseModule): + """Multiscale Transformer blocks. + + Args: + in_dims (int): Number of input channels. + out_dims (int): Number of output channels. + num_heads (int): Number of attention heads. + mlp_ratio (float): Ratio of hidden dimensions in MLP layers. + Defaults to 4.0. + qkv_bias (bool): If True, add a learnable bias to query, key and + value. Defaults to True. + drop_path (float): Stochastic depth rate. Defaults to 0. + norm_cfg (dict): The config of normalization layers. + Defaults to ``dict(type='LN')``. + act_cfg (dict): The config of activation function. + Defaults to ``dict(type='GELU')``. + qkv_pool_kernel (tuple): kernel size for qkv pooling layers. + Defaults to (3, 3). + stride_q (int): stride size for q pooling layer. Defaults to 1. + stride_kv (int): stride size for kv pooling layer. Defaults to 1. + rel_pos_spatial (bool): Whether to enable the spatial relative + position embedding. Defaults to True. + residual_pooling (bool): Whether to enable the residual connection + after attention pooling. Defaults to True. + dim_mul_in_attention (bool): Whether to multiply the ``embed_dims`` in + attention layers. If False, multiply it in MLP layers. + Defaults to True. + input_size (Tuple[int], optional): The input resolution, necessary + if enable the ``rel_pos_spatial``. Defaults to None. + rel_pos_zero_init (bool): If True, zero initialize relative + positional parameters. Defaults to False. + init_cfg (dict, optional): The config of weight initialization. + Defaults to None. 
+ """ + + def __init__( + self, + in_dims, + out_dims, + num_heads, + mlp_ratio=4.0, + qkv_bias=True, + drop_path=0.0, + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + qkv_pool_kernel=(3, 3), + stride_q=1, + stride_kv=1, + rel_pos_spatial=True, + residual_pooling=True, + dim_mul_in_attention=True, + input_size=None, + rel_pos_zero_init=False, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + self.in_dims = in_dims + self.out_dims = out_dims + self.norm1 = build_norm_layer(norm_cfg, in_dims)[1] + self.dim_mul_in_attention = dim_mul_in_attention + + attn_dims = out_dims if dim_mul_in_attention else in_dims + self.attn = MultiScaleAttention( + in_dims, + attn_dims, + num_heads=num_heads, + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + pool_kernel=qkv_pool_kernel, + stride_q=stride_q, + stride_kv=stride_kv, + rel_pos_spatial=rel_pos_spatial, + residual_pooling=residual_pooling, + input_size=input_size, + rel_pos_zero_init=rel_pos_zero_init) + self.drop_path = DropPath( + drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = build_norm_layer(norm_cfg, attn_dims)[1] + + self.mlp = MLP( + in_channels=attn_dims, + hidden_channels=int(attn_dims * mlp_ratio), + out_channels=out_dims, + act_cfg=act_cfg) + + if in_dims != out_dims: + self.proj = nn.Linear(in_dims, out_dims) + else: + self.proj = None + + if stride_q > 1: + kernel_skip = stride_q + 1 + padding_skip = int(kernel_skip // 2) + self.pool_skip = nn.MaxPool2d( + kernel_skip, stride_q, padding_skip, ceil_mode=False) + + if input_size is not None: + input_size = to_2tuple(input_size) + out_size = [size // stride_q for size in input_size] + self.init_out_size = out_size + else: + self.init_out_size = None + else: + self.pool_skip = None + self.init_out_size = input_size + + def forward(self, x, in_size): + x_norm = self.norm1(x) + x_attn, out_size = self.attn(x_norm, in_size) + + if self.dim_mul_in_attention and self.proj is not None: + skip = self.proj(x_norm) + else: + skip = x + + if self.pool_skip is not None: + skip, _ = attention_pool(skip, self.pool_skip, in_size) + + x = skip + self.drop_path(x_attn) + x_norm = self.norm2(x) + x_mlp = self.mlp(x_norm) + + if not self.dim_mul_in_attention and self.proj is not None: + skip = self.proj(x_norm) + else: + skip = x + + x = skip + self.drop_path(x_mlp) + + return x, out_size + + +@BACKBONES.register_module() +class MViT(BaseBackbone): + """Multi-scale ViT v2. + + A PyTorch implement of : `MViTv2: Improved Multiscale Vision Transformers + for Classification and Detection `_ + + Inspiration from `the official implementation + `_ and `the detectron2 + implementation `_ + + Args: + arch (str | dict): MViT architecture. If use string, choose + from 'tiny', 'small', 'base' and 'large'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of layers. + - **num_heads** (int): The number of heads in attention + modules of the initial layer. + - **downscale_indices** (List[int]): The layer indices to downscale + the feature map. + + Defaults to 'base'. + img_size (int): The expected input image shape. Defaults to 224. + in_channels (int): The num of input channels. Defaults to 3. + out_scales (int | Sequence[int]): The output scale indices. + They should not exceed the length of ``downscale_indices``. + Defaults to -1, which means the last scale. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. 
+ use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults to False. + interpolate_mode (str): Select the interpolate mode for absolute + position embedding vector resize. Defaults to "bicubic". + pool_kernel (tuple): kernel size for qkv pooling layers. + Defaults to (3, 3). + dim_mul (int): The magnification for ``embed_dims`` in the downscale + layers. Defaults to 2. + head_mul (int): The magnification for ``num_heads`` in the downscale + layers. Defaults to 2. + adaptive_kv_stride (int): The stride size for kv pooling in the initial + layer. Defaults to 4. + rel_pos_spatial (bool): Whether to enable the spatial relative position + embedding. Defaults to True. + residual_pooling (bool): Whether to enable the residual connection + after attention pooling. Defaults to True. + dim_mul_in_attention (bool): Whether to multiply the ``embed_dims`` in + attention layers. If False, multiply it in MLP layers. + Defaults to True. + rel_pos_zero_init (bool): If True, zero initialize relative + positional parameters. Defaults to False. + mlp_ratio (float): Ratio of hidden dimensions in MLP layers. + Defaults to 4.0. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + norm_cfg (dict): Config dict for normalization layer for all output + features. Defaults to ``dict(type='LN', eps=1e-6)``. + patch_cfg (dict): Config dict for the patch embedding layer. + Defaults to ``dict(kernel_size=7, stride=4, padding=3)``. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + + Examples: + >>> import torch + >>> from mmpretrain.models import build_backbone + >>> + >>> cfg = dict(type='MViT', arch='tiny', out_scales=[0, 1, 2, 3]) + >>> model = build_backbone(cfg) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> outputs = model(inputs) + >>> for i, output in enumerate(outputs): + >>> print(f'scale{i}: {output.shape}') + scale0: torch.Size([1, 96, 56, 56]) + scale1: torch.Size([1, 192, 28, 28]) + scale2: torch.Size([1, 384, 14, 14]) + scale3: torch.Size([1, 768, 7, 7]) + """ + arch_zoo = { + 'tiny': { + 'embed_dims': 96, + 'num_layers': 10, + 'num_heads': 1, + 'downscale_indices': [1, 3, 8] + }, + 'small': { + 'embed_dims': 96, + 'num_layers': 16, + 'num_heads': 1, + 'downscale_indices': [1, 3, 14] + }, + 'base': { + 'embed_dims': 96, + 'num_layers': 24, + 'num_heads': 1, + 'downscale_indices': [2, 5, 21] + }, + 'large': { + 'embed_dims': 144, + 'num_layers': 48, + 'num_heads': 2, + 'downscale_indices': [2, 8, 44] + }, + } + num_extra_tokens = 0 + + def __init__(self, + arch='base', + img_size=224, + in_channels=3, + out_scales=-1, + drop_path_rate=0., + use_abs_pos_embed=False, + interpolate_mode='bicubic', + pool_kernel=(3, 3), + dim_mul=2, + head_mul=2, + adaptive_kv_stride=4, + rel_pos_spatial=True, + residual_pooling=True, + dim_mul_in_attention=True, + rel_pos_zero_init=False, + mlp_ratio=4., + qkv_bias=True, + norm_cfg=dict(type='LN', eps=1e-6), + patch_cfg=dict(kernel_size=7, stride=4, padding=3), + init_cfg=None): + super().__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'downscale_indices' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = 
self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.num_heads = self.arch_settings['num_heads'] + self.downscale_indices = self.arch_settings['downscale_indices'] + self.num_scales = len(self.downscale_indices) + 1 + self.stage_indices = { + index - 1: i + for i, index in enumerate(self.downscale_indices) + } + self.stage_indices[self.num_layers - 1] = self.num_scales - 1 + self.use_abs_pos_embed = use_abs_pos_embed + self.interpolate_mode = interpolate_mode + + if isinstance(out_scales, int): + out_scales = [out_scales] + assert isinstance(out_scales, Sequence), \ + f'"out_scales" must by a sequence or int, ' \ + f'get {type(out_scales)} instead.' + for i, index in enumerate(out_scales): + if index < 0: + out_scales[i] = self.num_scales + index + assert 0 <= out_scales[i] <= self.num_scales, \ + f'Invalid out_scales {index}' + self.out_scales = sorted(list(out_scales)) + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + # Set absolute position embedding + if self.use_abs_pos_embed: + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.embed_dims)) + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.blocks = ModuleList() + out_dims_list = [self.embed_dims] + num_heads = self.num_heads + stride_kv = adaptive_kv_stride + input_size = self.patch_resolution + for i in range(self.num_layers): + if i in self.downscale_indices: + num_heads *= head_mul + stride_q = 2 + stride_kv = max(stride_kv // 2, 1) + else: + stride_q = 1 + + # Set output embed_dims + if dim_mul_in_attention and i in self.downscale_indices: + # multiply embed_dims in downscale layers. + out_dims = out_dims_list[-1] * dim_mul + elif not dim_mul_in_attention and i + 1 in self.downscale_indices: + # multiply embed_dims before downscale layers. + out_dims = out_dims_list[-1] * dim_mul + else: + out_dims = out_dims_list[-1] + + attention_block = MultiScaleBlock( + in_dims=out_dims_list[-1], + out_dims=out_dims, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=dpr[i], + norm_cfg=norm_cfg, + qkv_pool_kernel=pool_kernel, + stride_q=stride_q, + stride_kv=stride_kv, + rel_pos_spatial=rel_pos_spatial, + residual_pooling=residual_pooling, + dim_mul_in_attention=dim_mul_in_attention, + input_size=input_size, + rel_pos_zero_init=rel_pos_zero_init) + self.blocks.append(attention_block) + + input_size = attention_block.init_out_size + out_dims_list.append(out_dims) + + if i in self.stage_indices: + stage_index = self.stage_indices[i] + if stage_index in self.out_scales: + norm_layer = build_norm_layer(norm_cfg, out_dims)[1] + self.add_module(f'norm{stage_index}', norm_layer) + + def init_weights(self): + super().init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. 
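+            # (pos_embed will be loaded from the checkpoint, so skip the
+            # random initialization below)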
+            return
+
+        if self.use_abs_pos_embed:
+            trunc_normal_(self.pos_embed, std=0.02)
+
+    def forward(self, x):
+        """Forward the MViT."""
+        B = x.shape[0]
+        x, patch_resolution = self.patch_embed(x)
+
+        if self.use_abs_pos_embed:
+            x = x + resize_pos_embed(
+                self.pos_embed,
+                self.patch_resolution,
+                patch_resolution,
+                mode=self.interpolate_mode,
+                num_extra_tokens=self.num_extra_tokens)
+
+        outs = []
+        for i, block in enumerate(self.blocks):
+            x, patch_resolution = block(x, patch_resolution)
+
+            if i in self.stage_indices:
+                stage_index = self.stage_indices[i]
+                if stage_index in self.out_scales:
+                    B, _, C = x.shape
+                    x = getattr(self, f'norm{stage_index}')(x)
+                    out = x.transpose(1, 2).reshape(B, C, *patch_resolution)
+                    outs.append(out.contiguous())
+
+        return tuple(outs)
diff --git a/mmpretrain/models/backbones/poolformer.py b/mmpretrain/models/backbones/poolformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2ad67043dbeb0ce6969c2770853342b30df2a74
--- /dev/null
+++ b/mmpretrain/models/backbones/poolformer.py
@@ -0,0 +1,416 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
+from mmengine.model import BaseModule
+
+from mmpretrain.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+class PatchEmbed(nn.Module):
+    """Patch Embedding module implemented by a layer of convolution.
+
+    Input: tensor in shape [B, C, H, W]
+    Output: tensor in shape [B, C, H/stride, W/stride]
+
+    Args:
+        patch_size (int): Patch size of the patch embedding. Defaults to 16.
+        stride (int): Stride of the patch embedding. Defaults to 16.
+        padding (int): Padding of the patch embedding. Defaults to 0.
+        in_chans (int): Input channels. Defaults to 3.
+        embed_dim (int): Output dimension of the patch embedding.
+            Defaults to 768.
+        norm_layer (module): Normalization module. Defaults to None (not use).
+    """
+
+    def __init__(self,
+                 patch_size=16,
+                 stride=16,
+                 padding=0,
+                 in_chans=3,
+                 embed_dim=768,
+                 norm_layer=None):
+        super().__init__()
+        self.proj = nn.Conv2d(
+            in_chans,
+            embed_dim,
+            kernel_size=patch_size,
+            stride=stride,
+            padding=padding)
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+    def forward(self, x):
+        x = self.proj(x)
+        x = self.norm(x)
+        return x
+
+
+class Pooling(nn.Module):
+    """Pooling module.
+
+    Args:
+        pool_size (int): Pooling size. Defaults to 3.
+    """
+
+    def __init__(self, pool_size=3):
+        super().__init__()
+        self.pool = nn.AvgPool2d(
+            pool_size,
+            stride=1,
+            padding=pool_size // 2,
+            count_include_pad=False)
+
+    def forward(self, x):
+        return self.pool(x) - x
+
+
+class Mlp(nn.Module):
+    """Mlp implemented with 1x1 convolutions.
+
+    Input: Tensor with shape [B, C, H, W].
+    Output: Tensor with shape [B, C, H, W].
+
+    Args:
+        in_features (int): Dimension of input features.
+        hidden_features (int): Dimension of hidden features.
+        out_features (int): Dimension of output features.
+        act_cfg (dict): The config dict for activation between pointwise
+            convolution. Defaults to ``dict(type='GELU')``.
+        drop (float): Dropout rate. Defaults to 0.0.
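+
+    Example:
+        >>> # illustrative usage; any [B, C, H, W] input works
+        >>> import torch
+        >>> mlp = Mlp(in_features=64, hidden_features=256)
+        >>> out = mlp(torch.rand(1, 64, 14, 14))
+        >>> tuple(out.shape)
+        (1, 64, 14, 14)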
+ """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_cfg=dict(type='GELU'), + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.act = build_activation_layer(act_cfg) + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class PoolFormerBlock(BaseModule): + """PoolFormer Block. + + Args: + dim (int): Embedding dim. + pool_size (int): Pooling size. Defaults to 3. + mlp_ratio (float): Mlp expansion ratio. Defaults to 4. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='GN', num_groups=1)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + drop (float): Dropout rate. Defaults to 0. + drop_path (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-5. + """ + + def __init__(self, + dim, + pool_size=3, + mlp_ratio=4., + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + drop=0., + drop_path=0., + layer_scale_init_value=1e-5): + + super().__init__() + + self.norm1 = build_norm_layer(norm_cfg, dim)[1] + self.token_mixer = Pooling(pool_size=pool_size) + self.norm2 = build_norm_layer(norm_cfg, dim)[1] + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_cfg=act_cfg, + drop=drop) + + # The following two techniques are useful to train deep PoolFormers. + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True) + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True) + + def forward(self, x): + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * + self.token_mixer(self.norm1(x))) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * + self.mlp(self.norm2(x))) + return x + + +def basic_blocks(dim, + index, + layers, + pool_size=3, + mlp_ratio=4., + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + drop_rate=.0, + drop_path_rate=0., + layer_scale_init_value=1e-5): + """ + generate PoolFormer blocks for a stage + return: PoolFormer blocks + """ + blocks = [] + for block_idx in range(layers[index]): + block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / ( + sum(layers) - 1) + blocks.append( + PoolFormerBlock( + dim, + pool_size=pool_size, + mlp_ratio=mlp_ratio, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop=drop_rate, + drop_path=block_dpr, + layer_scale_init_value=layer_scale_init_value, + )) + blocks = nn.Sequential(*blocks) + + return blocks + + +@MODELS.register_module() +class PoolFormer(BaseBackbone): + """PoolFormer. + + A PyTorch implementation of PoolFormer introduced by: + `MetaFormer is Actually What You Need for Vision `_ + + Modified from the `official repo + `. + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``PoolFormer.arch_settings``. And if dict, it + should include the following two keys: + + - layers (list[int]): Number of blocks at each stage. 
+ - embed_dims (list[int]): The number of channels at each stage. + - mlp_ratios (list[int]): Expansion ratio of MLPs. + - layer_scale_init_value (float): Init value for Layer Scale. + + Defaults to 'S12'. + + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='LN2d', eps=1e-6)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + in_patch_size (int): The patch size of input image patch embedding. + Defaults to 7. + in_stride (int): The stride of input image patch embedding. + Defaults to 4. + in_pad (int): The padding of input image patch embedding. + Defaults to 2. + down_patch_size (int): The patch size of downsampling patch embedding. + Defaults to 3. + down_stride (int): The stride of downsampling patch embedding. + Defaults to 2. + down_pad (int): The padding of downsampling patch embedding. + Defaults to 1. + drop_rate (float): Dropout rate. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + out_indices (Sequence | int): Output from which network position. + Index 0-6 respectively corresponds to + [stage1, downsampling, stage2, downsampling, stage3, downsampling, stage4] + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + init_cfg (dict, optional): Initialization config dict + """ # noqa: E501 + + # --layers: [x,x,x,x], numbers of layers for the four stages + # --embed_dims, --mlp_ratios: + # embedding dims and mlp ratios for the four stages + # --downsamples: flags to apply downsampling or not in four blocks + arch_settings = { + 's12': { + 'layers': [2, 2, 6, 2], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-5, + }, + 's24': { + 'layers': [4, 4, 12, 4], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-5, + }, + 's36': { + 'layers': [6, 6, 18, 6], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + 'm36': { + 'layers': [6, 6, 18, 6], + 'embed_dims': [96, 192, 384, 768], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + 'm48': { + 'layers': [8, 8, 24, 8], + 'embed_dims': [96, 192, 384, 768], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + } + + def __init__(self, + arch='s12', + pool_size=3, + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + in_patch_size=7, + in_stride=4, + in_pad=2, + down_patch_size=3, + down_stride=2, + down_pad=1, + drop_rate=0., + drop_path_rate=0., + out_indices=-1, + frozen_stages=0, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + assert 'layers' in arch and 'embed_dims' in arch, \ + f'The arch dict must have "layers" and "embed_dims", ' \ + f'but got {list(arch.keys())}.' 
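# [Editor's example] Besides the preset strings above, `arch` accepts a
# dict as long as it carries 'layers' and 'embed_dims'; 'mlp_ratios' and
# 'layer_scale_init_value' fall back to [4, 4, 4, 4] and 1e-5. A sketch
# with a *hypothetical* custom variant (not a released PoolFormer
# config), assuming the usual mmpretrain package layout:
import torch
from mmpretrain.models import PoolFormer

custom_arch = dict(
    layers=[2, 2, 4, 2],            # blocks per stage
    embed_dims=[48, 96, 240, 384],  # channels per stage
)
model = PoolFormer(arch=custom_arch)
model.eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])  # default out_indices=-1 -> last stage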
+ + layers = arch['layers'] + embed_dims = arch['embed_dims'] + mlp_ratios = arch['mlp_ratios'] \ + if 'mlp_ratios' in arch else [4, 4, 4, 4] + layer_scale_init_value = arch['layer_scale_init_value'] \ + if 'layer_scale_init_value' in arch else 1e-5 + + self.patch_embed = PatchEmbed( + patch_size=in_patch_size, + stride=in_stride, + padding=in_pad, + in_chans=3, + embed_dim=embed_dims[0]) + + # set the main block in network + network = [] + for i in range(len(layers)): + stage = basic_blocks( + embed_dims[i], + i, + layers, + pool_size=pool_size, + mlp_ratio=mlp_ratios[i], + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value) + network.append(stage) + if i >= len(layers) - 1: + break + if embed_dims[i] != embed_dims[i + 1]: + # downsampling between two stages + network.append( + PatchEmbed( + patch_size=down_patch_size, + stride=down_stride, + padding=down_pad, + in_chans=embed_dims[i], + embed_dim=embed_dims[i + 1])) + + self.network = nn.ModuleList(network) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 7 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + if self.out_indices: + for i_layer in self.out_indices: + layer = build_norm_layer(norm_cfg, + embed_dims[(i_layer + 1) // 2])[1] + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self.frozen_stages = frozen_stages + self._freeze_stages() + + def forward_embeddings(self, x): + x = self.patch_embed(x) + return x + + def forward_tokens(self, x): + outs = [] + for idx, block in enumerate(self.network): + x = block(x) + if idx in self.out_indices: + norm_layer = getattr(self, f'norm{idx}') + x_out = norm_layer(x) + outs.append(x_out) + return tuple(outs) + + def forward(self, x): + # input embedding + x = self.forward_embeddings(x) + # through backbone + x = self.forward_tokens(x) + return x + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(self.frozen_stages): + # Include both block and downsample layer. + module = self.network[i] + module.eval() + for param in module.parameters(): + param.requires_grad = False + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + norm_layer.eval() + for param in norm_layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(PoolFormer, self).train(mode) + self._freeze_stages() diff --git a/mmpretrain/models/backbones/regnet.py b/mmpretrain/models/backbones/regnet.py new file mode 100644 index 0000000000000000000000000000000000000000..85dbdef0bfeb607ecddff1d68d1cf405b61bea65 --- /dev/null +++ b/mmpretrain/models/backbones/regnet.py @@ -0,0 +1,312 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmpretrain.registry import MODELS +from .resnet import ResNet +from .resnext import Bottleneck + + +@MODELS.register_module() +class RegNet(ResNet): + """RegNet backbone. + + More details can be found in `paper `_ . + + Args: + arch (dict): The parameter of RegNets. 
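# [Editor's example] In `forward_tokens` above, indices 0-6 walk the
# interleaved [stage, downsample, stage, ...] list, so the four stage
# outputs sit at the even positions and the matching norm layer is built
# with embed_dims[(i + 1) // 2]. A multi-scale feature-extraction sketch,
# assuming the usual mmpretrain package layout:
import torch
from mmpretrain.models import PoolFormer

model = PoolFormer('s12', out_indices=(0, 2, 4, 6))
model.eval()
with torch.no_grad():
    outs = model(torch.randn(1, 3, 224, 224))
for out in outs:
    print(tuple(out.shape))
# (1, 64, 56, 56) (1, 128, 28, 28) (1, 320, 14, 14) (1, 512, 7, 7)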
+ - w0 (int): initial width + - wa (float): slope of width + - wm (float): quantization parameter to quantize the width + - depth (int): depth of the backbone + - group_w (int): width of group + - bot_mul (float): bottleneck ratio, i.e. expansion of bottleneck. + strides (Sequence[int]): Strides of the first block of each stage. + base_channels (int): Base channels after stem layer. + in_channels (int): Number of input image channels. Default: 3. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. Default: "pytorch". + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. Default: -1. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + + Example: + >>> from mmpretrain.models import RegNet + >>> import torch + >>> self = RegNet( + arch=dict( + w0=88, + wa=26.31, + wm=2.25, + group_w=48, + depth=25, + bot_mul=1.0)) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 96, 8, 8) + (1, 192, 4, 4) + (1, 432, 2, 2) + (1, 1008, 1, 1) + """ + arch_settings = { + 'regnetx_400mf': + dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22, bot_mul=1.0), + 'regnetx_800mf': + dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16, bot_mul=1.0), + 'regnetx_1.6gf': + dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18, bot_mul=1.0), + 'regnetx_3.2gf': + dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25, bot_mul=1.0), + 'regnetx_4.0gf': + dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23, bot_mul=1.0), + 'regnetx_6.4gf': + dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17, bot_mul=1.0), + 'regnetx_8.0gf': + dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23, bot_mul=1.0), + 'regnetx_12gf': + dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, bot_mul=1.0), + } + + def __init__(self, + arch, + in_channels=3, + stem_channels=32, + base_channels=32, + strides=(2, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(3, ), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + with_cp=False, + zero_init_residual=True, + init_cfg=None): + super(ResNet, self).__init__(init_cfg) + + # Generate RegNet parameters first + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'"arch": "{arch}" is not one of the' \ + ' arch_settings' + arch = self.arch_settings[arch] + elif not isinstance(arch, dict): + raise TypeError('Expect "arch" to be either a string ' + f'or a dict, got {type(arch)}') + + widths, num_stages = self.generate_regnet( + arch['w0'], + arch['wa'], + arch['wm'], + arch['depth'], + ) + # Convert to per stage format + stage_widths, stage_blocks = self.get_stages_from_blocks(widths) + # 
Generate group widths and bot muls + group_widths = [arch['group_w'] for _ in range(num_stages)] + self.bottleneck_ratio = [arch['bot_mul'] for _ in range(num_stages)] + # Adjust the compatibility of stage_widths and group_widths + stage_widths, group_widths = self.adjust_width_group( + stage_widths, self.bottleneck_ratio, group_widths) + + # Group params by stage + self.stage_widths = stage_widths + self.group_widths = group_widths + self.depth = sum(stage_blocks) + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + if self.deep_stem: + raise NotImplementedError( + 'deep_stem has not been implemented for RegNet') + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.zero_init_residual = zero_init_residual + self.stage_blocks = stage_blocks[:num_stages] + + self._make_stem_layer(in_channels, stem_channels) + + _in_channels = stem_channels + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = self.strides[i] + dilation = self.dilations[i] + group_width = self.group_widths[i] + width = int(round(self.stage_widths[i] * self.bottleneck_ratio[i])) + stage_groups = width // group_width + + res_layer = self.make_res_layer( + block=Bottleneck, + num_blocks=num_blocks, + in_channels=_in_channels, + out_channels=self.stage_widths[i], + expansion=1, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=self.with_cp, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + base_channels=self.stage_widths[i], + groups=stage_groups, + width_per_group=group_width) + _in_channels = self.stage_widths[i] + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = stage_widths[-1] + + def _make_stem_layer(self, in_channels, base_channels): + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + base_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, base_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + + def generate_regnet(self, + initial_width, + width_slope, + width_parameter, + depth, + divisor=8): + """Generates per block width from RegNet parameters. + + Args: + initial_width ([int]): Initial width of the backbone + width_slope ([float]): Slope of the quantized linear function + width_parameter ([int]): Parameter used to quantize the width. + depth ([int]): Depth of the backbone. + divisor (int): The divisor of channels. Defaults to 8. + + Returns: + tuple: tuple containing: + - list: Widths of each stage. + - int: The number of stages. 
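# [Editor's example] The width schedule described above is a quantized
# log-linear ramp: fit w_j = w0 + wa * j, snap each width to the nearest
# power of wm times w0, then round to a multiple of `divisor`. A
# standalone NumPy sketch mirroring the method body that follows, run on
# the regnetx_400mf parameters from `arch_settings`:
import numpy as np

def regnet_widths(w0, wa, wm, depth, divisor=8):
    widths_cont = np.arange(depth) * wa + w0
    ks = np.round(np.log(widths_cont / w0) / np.log(wm))
    widths = np.round(w0 * np.power(wm, ks) / divisor) * divisor
    return widths.astype(int).tolist(), len(np.unique(widths))

widths, num_stages = regnet_widths(24, 24.48, 2.54, 22)
print(num_stages)  # 4 -- blocks sharing a width form one stage
print(widths)      # per-block widths, piecewise constant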
+ """ + assert width_slope >= 0 + assert initial_width > 0 + assert width_parameter > 1 + assert initial_width % divisor == 0 + widths_cont = np.arange(depth) * width_slope + initial_width + ks = np.round( + np.log(widths_cont / initial_width) / np.log(width_parameter)) + widths = initial_width * np.power(width_parameter, ks) + widths = np.round(np.divide(widths, divisor)) * divisor + num_stages = len(np.unique(widths)) + widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist() + return widths, num_stages + + @staticmethod + def quantize_float(number, divisor): + """Converts a float to closest non-zero int divisible by divior. + + Args: + number (int): Original number to be quantized. + divisor (int): Divisor used to quantize the number. + + Returns: + int: quantized number that is divisible by devisor. + """ + return int(round(number / divisor) * divisor) + + def adjust_width_group(self, widths, bottleneck_ratio, groups): + """Adjusts the compatibility of widths and groups. + + Args: + widths (list[int]): Width of each stage. + bottleneck_ratio (float): Bottleneck ratio. + groups (int): number of groups in each stage + + Returns: + tuple(list): The adjusted widths and groups of each stage. + """ + bottleneck_width = [ + int(w * b) for w, b in zip(widths, bottleneck_ratio) + ] + groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_width)] + bottleneck_width = [ + self.quantize_float(w_bot, g) + for w_bot, g in zip(bottleneck_width, groups) + ] + widths = [ + int(w_bot / b) + for w_bot, b in zip(bottleneck_width, bottleneck_ratio) + ] + return widths, groups + + def get_stages_from_blocks(self, widths): + """Gets widths/stage_blocks of network at each stage. + + Args: + widths (list[int]): Width in each stage. + + Returns: + tuple(list): width and depth of each stage + """ + width_diff = [ + width != width_prev + for width, width_prev in zip(widths + [0], [0] + widths) + ] + stage_widths = [ + width for width, diff in zip(widths, width_diff[:-1]) if diff + ] + stage_blocks = np.diff([ + depth for depth, diff in zip(range(len(width_diff)), width_diff) + if diff + ]).tolist() + return stage_widths, stage_blocks + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/replknet.py b/mmpretrain/models/backbones/replknet.py new file mode 100644 index 0000000000000000000000000000000000000000..4dce4154fbe1d95806eec118b69ff70f0d74c1c6 --- /dev/null +++ b/mmpretrain/models/backbones/replknet.py @@ -0,0 +1,668 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +def conv_bn(in_channels, + out_channels, + kernel_size, + stride, + padding, + groups, + dilation=1, + norm_cfg=dict(type='BN')): + """Construct a sequential conv and bn. + + Args: + in_channels (int): Dimension of input features. + out_channels (int): Dimension of output features. + kernel_size (int): kernel_size of the convolution. + stride (int): stride of the convolution. 
+ padding (int): stride of the convolution. + groups (int): groups of the convolution. + dilation (int): dilation of the convolution. Default to 1. + norm_cfg (dict): dictionary to construct and config norm layer. + Default to ``dict(type='BN', requires_grad=True)``. + + Returns: + nn.Sequential(): A conv layer and a batch norm layer. + """ + if padding is None: + padding = kernel_size // 2 + result = nn.Sequential() + result.add_module( + 'conv', + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False)) + result.add_module('bn', build_norm_layer(norm_cfg, out_channels)[1]) + return result + + +def conv_bn_relu(in_channels, + out_channels, + kernel_size, + stride, + padding, + groups, + dilation=1): + """Construct a sequential conv, bn and relu. + + Args: + in_channels (int): Dimension of input features. + out_channels (int): Dimension of output features. + kernel_size (int): kernel_size of the convolution. + stride (int): stride of the convolution. + padding (int): stride of the convolution. + groups (int): groups of the convolution. + dilation (int): dilation of the convolution. Default to 1. + + Returns: + nn.Sequential(): A conv layer, batch norm layer and a relu function. + """ + + if padding is None: + padding = kernel_size // 2 + result = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + dilation=dilation) + result.add_module('nonlinear', nn.ReLU()) + return result + + +def fuse_bn(conv, bn): + """Fuse the parameters in a branch with a conv and bn. + + Args: + conv (nn.Conv2d): The convolution module to fuse. + bn (nn.BatchNorm2d): The batch normalization to fuse. + + Returns: + tuple[torch.Tensor, torch.Tensor]: The parameters obtained after + fusing the parameters of conv and bn in one branch. + The first element is the weight and the second is the bias. + """ + kernel = conv.weight + running_mean = bn.running_mean + running_var = bn.running_var + gamma = bn.weight + beta = bn.bias + eps = bn.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class ReparamLargeKernelConv(BaseModule): + """Super large kernel implemented by with large convolutions. + + Input: Tensor with shape [B, C, H, W]. + Output: Tensor with shape [B, C, H, W]. + + Args: + in_channels (int): Dimension of input features. + out_channels (int): Dimension of output features. + kernel_size (int): kernel_size of the large convolution. + stride (int): stride of the large convolution. + groups (int): groups of the large convolution. + small_kernel (int): kernel_size of the small convolution. + small_kernel_merged (bool): Whether to switch the model structure to + deployment mode (merge the small kernel to the large kernel). + Default to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + groups, + small_kernel, + small_kernel_merged=False, + init_cfg=None): + super(ReparamLargeKernelConv, self).__init__(init_cfg) + self.kernel_size = kernel_size + self.small_kernel = small_kernel + self.small_kernel_merged = small_kernel_merged + # We assume the conv does not change the feature map size, + # so padding = k//2. 
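# [Editor's example] A quick equivalence check for `fuse_bn` above:
# folding an eval-mode BatchNorm into the preceding conv must reproduce
# the two-layer output exactly (up to float error). Plain PyTorch sketch:
import torch
import torch.nn as nn

conv = nn.Conv2d(8, 16, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(16)
bn.running_mean.uniform_(-1, 1)   # fake non-trivial running stats
bn.running_var.uniform_(0.5, 2.0)
bn.eval()                         # fusion assumes fixed statistics

std = (bn.running_var + bn.eps).sqrt()
t = (bn.weight / std).reshape(-1, 1, 1, 1)
fused = nn.Conv2d(8, 16, 3, padding=1, bias=True)
fused.weight.data = conv.weight * t
fused.bias.data = bn.bias - bn.running_mean * bn.weight / std

x = torch.randn(2, 8, 32, 32)
with torch.no_grad():
    print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # True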
+ # Otherwise, you may configure padding as you wish, + # and change the padding of small_conv accordingly. + padding = kernel_size // 2 + if small_kernel_merged: + self.lkb_reparam = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=1, + groups=groups, + bias=True) + else: + self.lkb_origin = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=1, + groups=groups) + if small_kernel is not None: + assert small_kernel <= kernel_size + self.small_conv = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=small_kernel, + stride=stride, + padding=small_kernel // 2, + groups=groups, + dilation=1) + + def forward(self, inputs): + if hasattr(self, 'lkb_reparam'): + out = self.lkb_reparam(inputs) + else: + out = self.lkb_origin(inputs) + if hasattr(self, 'small_conv'): + out += self.small_conv(inputs) + return out + + def get_equivalent_kernel_bias(self): + eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn) + if hasattr(self, 'small_conv'): + small_k, small_b = fuse_bn(self.small_conv.conv, + self.small_conv.bn) + eq_b += small_b + # add to the central part + eq_k += nn.functional.pad( + small_k, [(self.kernel_size - self.small_kernel) // 2] * 4) + return eq_k, eq_b + + def merge_kernel(self): + """Switch the model structure from training mode to deployment mode.""" + if self.small_kernel_merged: + return + eq_k, eq_b = self.get_equivalent_kernel_bias() + self.lkb_reparam = nn.Conv2d( + in_channels=self.lkb_origin.conv.in_channels, + out_channels=self.lkb_origin.conv.out_channels, + kernel_size=self.lkb_origin.conv.kernel_size, + stride=self.lkb_origin.conv.stride, + padding=self.lkb_origin.conv.padding, + dilation=self.lkb_origin.conv.dilation, + groups=self.lkb_origin.conv.groups, + bias=True) + + self.lkb_reparam.weight.data = eq_k + self.lkb_reparam.bias.data = eq_b + self.__delattr__('lkb_origin') + if hasattr(self, 'small_conv'): + self.__delattr__('small_conv') + + self.small_kernel_merged = True + + +class ConvFFN(BaseModule): + """Mlp implemented by with 1*1 convolutions. + + Input: Tensor with shape [B, C, H, W]. + Output: Tensor with shape [B, C, H, W]. + + Args: + in_channels (int): Dimension of input features. + internal_channels (int): Dimension of hidden features. + out_channels (int): Dimension of output features. + drop_path (float): Stochastic depth rate. Defaults to 0. + norm_cfg (dict): dictionary to construct and config norm layer. + Default to ``dict(type='BN', requires_grad=True)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels, + internal_channels, + out_channels, + drop_path, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='GELU'), + init_cfg=None): + super(ConvFFN, self).__init__(init_cfg) + self.drop_path = DropPath( + drop_prob=drop_path) if drop_path > 0. 
else nn.Identity() + self.preffn_bn = build_norm_layer(norm_cfg, in_channels)[1] + self.pw1 = conv_bn( + in_channels=in_channels, + out_channels=internal_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1) + self.pw2 = conv_bn( + in_channels=internal_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1) + self.nonlinear = build_activation_layer(act_cfg) + + def forward(self, x): + out = self.preffn_bn(x) + out = self.pw1(out) + out = self.nonlinear(out) + out = self.pw2(out) + return x + self.drop_path(out) + + +class RepLKBlock(BaseModule): + """RepLKBlock for RepLKNet backbone. + + Args: + in_channels (int): The input channels of the block. + dw_channels (int): The intermediate channels of the block, + i.e., input channels of the large kernel convolution. + block_lk_size (int): size of the super large kernel. Defaults: 31. + small_kernel (int): size of the parallel small kernel. Defaults: 5. + drop_path (float): Stochastic depth rate. Defaults: 0. + small_kernel_merged (bool): Whether to switch the model structure to + deployment mode (merge the small kernel to the large kernel). + Default to False. + norm_cfg (dict): dictionary to construct and config norm layer. + Default to ``dict(type='BN', requires_grad=True)``. + act_cfg (dict): Config dict for activation layer. + Default to ``dict(type='ReLU')``. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default to None + """ + + def __init__(self, + in_channels, + dw_channels, + block_lk_size, + small_kernel, + drop_path, + small_kernel_merged=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(RepLKBlock, self).__init__(init_cfg) + self.pw1 = conv_bn_relu(in_channels, dw_channels, 1, 1, 0, groups=1) + self.pw2 = conv_bn(dw_channels, in_channels, 1, 1, 0, groups=1) + self.large_kernel = ReparamLargeKernelConv( + in_channels=dw_channels, + out_channels=dw_channels, + kernel_size=block_lk_size, + stride=1, + groups=dw_channels, + small_kernel=small_kernel, + small_kernel_merged=small_kernel_merged) + self.lk_nonlinear = build_activation_layer(act_cfg) + self.prelkb_bn = build_norm_layer(norm_cfg, in_channels)[1] + self.drop_path = DropPath( + drop_prob=drop_path) if drop_path > 0. else nn.Identity() + # print('drop path:', self.drop_path) + + def forward(self, x): + out = self.prelkb_bn(x) + out = self.pw1(out) + out = self.large_kernel(out) + out = self.lk_nonlinear(out) + out = self.pw2(out) + return x + self.drop_path(out) + + +class RepLKNetStage(BaseModule): + """ + generate RepLKNet blocks for a stage + return: RepLKNet blocks + + Args: + channels (int): The input channels of the stage. + num_blocks (int): The number of blocks of the stage. + stage_lk_size (int): size of the super large kernel. Defaults: 31. + drop_path (float): Stochastic depth rate. Defaults: 0. + small_kernel (int): size of the parallel small kernel. Defaults: 5. + dw_ratio (float): The intermediate channels + expansion ratio of the block. Defaults: 1. + ffn_ratio (float): Mlp expansion ratio. Defaults to 4. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default to False. + small_kernel_merged (bool): Whether to switch the model structure to + deployment mode (merge the small kernel to the large kernel). + Default to False. + norm_intermediate_features (bool): Construct and config norm layer + or not. 
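# [Editor's example] The key identity behind `get_equivalent_kernel_bias`
# above: zero-padding a small kernel to the large kernel's size (and
# using the large kernel's padding) leaves the convolution unchanged, so
# the two branches can be summed into a single weight. Plain PyTorch
# sketch:
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 32, 32)
small_k = torch.randn(8, 4, 3, 3)
large_k = F.pad(small_k, [(7 - 3) // 2] * 4)  # 3x3 centred in a 7x7

out_small = F.conv2d(x, small_k, padding=3 // 2)
out_large = F.conv2d(x, large_k, padding=7 // 2)
print(torch.allclose(out_small, out_large, atol=1e-5))  # True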
+ Using True will normalize the intermediate features for + downstream dense prediction tasks. + norm_cfg (dict): dictionary to construct and config norm layer. + Default to ``dict(type='BN', requires_grad=True)``. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default to None + """ + + def __init__( + self, + channels, + num_blocks, + stage_lk_size, + drop_path, + small_kernel, + dw_ratio=1, + ffn_ratio=4, + with_cp=False, # train with torch.utils.checkpoint to save memory + small_kernel_merged=False, + norm_intermediate_features=False, + norm_cfg=dict(type='BN'), + init_cfg=None): + super(RepLKNetStage, self).__init__(init_cfg) + self.with_cp = with_cp + blks = [] + for i in range(num_blocks): + block_drop_path = drop_path[i] if isinstance(drop_path, + list) else drop_path + # Assume all RepLK Blocks within a stage share the same lk_size. + # You may tune it on your own model. + replk_block = RepLKBlock( + in_channels=channels, + dw_channels=int(channels * dw_ratio), + block_lk_size=stage_lk_size, + small_kernel=small_kernel, + drop_path=block_drop_path, + small_kernel_merged=small_kernel_merged) + convffn_block = ConvFFN( + in_channels=channels, + internal_channels=int(channels * ffn_ratio), + out_channels=channels, + drop_path=block_drop_path) + blks.append(replk_block) + blks.append(convffn_block) + self.blocks = nn.ModuleList(blks) + if norm_intermediate_features: + self.norm = build_norm_layer(norm_cfg, channels)[1] + else: + self.norm = nn.Identity() + + def forward(self, x): + for blk in self.blocks: + if self.with_cp: + x = checkpoint.checkpoint(blk, x) # Save training memory + else: + x = blk(x) + return x + + +@MODELS.register_module() +class RepLKNet(BaseBackbone): + """RepLKNet backbone. + + A PyTorch impl of : + `Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs + `_ + + Args: + arch (str | dict): The parameter of RepLKNet. + If it's a dict, it should contain the following keys: + + - large_kernel_sizes (Sequence[int]): + Large kernel size in each stage. + - layers (Sequence[int]): Number of blocks in each stage. + - channels (Sequence[int]): Number of channels in each stage. + - small_kernel (int): size of the parallel small kernel. + - dw_ratio (float): The intermediate channels + expansion ratio of the block. + in_channels (int): Number of input image channels. Default to 3. + ffn_ratio (float): Mlp expansion ratio. Defaults to 4. + out_indices (Sequence[int]): Output from which stages. + Default to (3, ). + strides (Sequence[int]): Strides of the first block of each stage. + Default to (2, 2, 2, 2). + dilations (Sequence[int]): Dilation of each stage. + Default to (1, 1, 1, 1). + frozen_stages (int): Stages to be frozen + (all param fixed). -1 means not freezing any parameters. + Default to -1. + conv_cfg (dict | None): The config dict for conv layers. + Default to None. + norm_cfg (dict): The config dict for norm layers. + Default to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Default to ``dict(type='ReLU')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default to False. + deploy (bool): Whether to switch the model structure to deployment + mode. Default to False. + norm_intermediate_features (bool): Construct and + config norm layer or not. + Using True will normalize the intermediate features + for downstream dense prediction tasks. 
+ norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + arch_settings = { + '31B': + dict( + large_kernel_sizes=[31, 29, 27, 13], + layers=[2, 2, 18, 2], + channels=[128, 256, 512, 1024], + small_kernel=5, + dw_ratio=1), + '31L': + dict( + large_kernel_sizes=[31, 29, 27, 13], + layers=[2, 2, 18, 2], + channels=[192, 384, 768, 1536], + small_kernel=5, + dw_ratio=1), + 'XL': + dict( + large_kernel_sizes=[27, 27, 27, 13], + layers=[2, 2, 18, 2], + channels=[256, 512, 1024, 2048], + small_kernel=None, + dw_ratio=1.5), + } + + def __init__(self, + arch, + in_channels=3, + ffn_ratio=4, + out_indices=(3, ), + strides=(2, 2, 2, 2), + dilations=(1, 1, 1, 1), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False, + drop_path_rate=0.3, + small_kernel_merged=False, + norm_intermediate_features=False, + norm_eval=False, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + super(RepLKNet, self).__init__(init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'"arch": "{arch}" is not one of the arch_settings' + arch = self.arch_settings[arch] + elif not isinstance(arch, dict): + raise TypeError('Expect "arch" to be either a string ' + f'or a dict, got {type(arch)}') + + assert len(arch['layers']) == len( + arch['channels']) == len(strides) == len(dilations) + assert max(out_indices) < len(arch['layers']) + + self.arch = arch + self.in_channels = in_channels + self.out_indices = out_indices + self.strides = strides + self.dilations = dilations + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.with_cp = with_cp + self.drop_path_rate = drop_path_rate + self.small_kernel_merged = small_kernel_merged + self.norm_eval = norm_eval + self.norm_intermediate_features = norm_intermediate_features + + self.out_indices = out_indices + + base_width = self.arch['channels'][0] + self.norm_intermediate_features = norm_intermediate_features + self.num_stages = len(self.arch['layers']) + self.stem = nn.ModuleList([ + conv_bn_relu( + in_channels=in_channels, + out_channels=base_width, + kernel_size=3, + stride=2, + padding=1, + groups=1), + conv_bn_relu( + in_channels=base_width, + out_channels=base_width, + kernel_size=3, + stride=1, + padding=1, + groups=base_width), + conv_bn_relu( + in_channels=base_width, + out_channels=base_width, + kernel_size=1, + stride=1, + padding=0, + groups=1), + conv_bn_relu( + in_channels=base_width, + out_channels=base_width, + kernel_size=3, + stride=2, + padding=1, + groups=base_width) + ]) + # stochastic depth. We set block-wise drop-path rate. + # The higher level blocks are more likely to be dropped. + # This implementation follows Swin. 
+ dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, + sum(self.arch['layers'])) + ] + self.stages = nn.ModuleList() + self.transitions = nn.ModuleList() + for stage_idx in range(self.num_stages): + layer = RepLKNetStage( + channels=self.arch['channels'][stage_idx], + num_blocks=self.arch['layers'][stage_idx], + stage_lk_size=self.arch['large_kernel_sizes'][stage_idx], + drop_path=dpr[sum(self.arch['layers'][:stage_idx] + ):sum(self.arch['layers'][:stage_idx + 1])], + small_kernel=self.arch['small_kernel'], + dw_ratio=self.arch['dw_ratio'], + ffn_ratio=ffn_ratio, + with_cp=with_cp, + small_kernel_merged=small_kernel_merged, + norm_intermediate_features=(stage_idx in out_indices)) + self.stages.append(layer) + if stage_idx < len(self.arch['layers']) - 1: + transition = nn.Sequential( + conv_bn_relu( + self.arch['channels'][stage_idx], + self.arch['channels'][stage_idx + 1], + 1, + 1, + 0, + groups=1), + conv_bn_relu( + self.arch['channels'][stage_idx + 1], + self.arch['channels'][stage_idx + 1], + 3, + stride=2, + padding=1, + groups=self.arch['channels'][stage_idx + 1])) + self.transitions.append(transition) + + def forward_features(self, x): + x = self.stem[0](x) + for stem_layer in self.stem[1:]: + if self.with_cp: + x = checkpoint.checkpoint(stem_layer, x) # save memory + else: + x = stem_layer(x) + + # Need the intermediate feature maps + outs = [] + for stage_idx in range(self.num_stages): + x = self.stages[stage_idx](x) + if stage_idx in self.out_indices: + outs.append(self.stages[stage_idx].norm(x)) + # For RepLKNet-XL normalize the features + # before feeding them into the heads + if stage_idx < self.num_stages - 1: + x = self.transitions[stage_idx](x) + return outs + + def forward(self, x): + x = self.forward_features(x) + return tuple(x) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + for i in range(self.frozen_stages): + stage = self.stages[i] + stage.eval() + for param in stage.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(RepLKNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() + + def switch_to_deploy(self): + for m in self.modules(): + if hasattr(m, 'merge_kernel'): + m.merge_kernel() + self.small_kernel_merged = True diff --git a/mmpretrain/models/backbones/repmlp.py b/mmpretrain/models/backbones/repmlp.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c06c4875710b33c57f2794c437034d93169b30 --- /dev/null +++ b/mmpretrain/models/backbones/repmlp.py @@ -0,0 +1,578 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Adapted from official impl at https://github.com/DingXiaoH/RepMLP. 
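# [Editor's example] End-to-end check of the re-parameterization above:
# after `merge_kernel` (reached via RepLKNet's `switch_to_deploy`), the
# single fused conv must match the two-branch training-time module in
# eval mode. A sketch assuming replknet.py is importable from mmpretrain
# as usual:
import torch
from mmpretrain.models.backbones.replknet import ReparamLargeKernelConv

blk = ReparamLargeKernelConv(
    in_channels=16, out_channels=16, kernel_size=13,
    stride=1, groups=16, small_kernel=5)
blk.eval()  # BN must run on its running stats for the fusion to be exact

x = torch.randn(1, 16, 32, 32)
with torch.no_grad():
    before = blk(x)
    blk.merge_kernel()
    after = blk(x)
print(torch.allclose(before, after, atol=1e-4))  # True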
+import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer, + build_norm_layer) +from mmcv.cnn.bricks.transformer import PatchEmbed as _PatchEmbed +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmpretrain.models.utils import SELayer, to_2tuple +from mmpretrain.registry import MODELS + + +def fuse_bn(conv_or_fc, bn): + """fuse conv and bn.""" + std = (bn.running_var + bn.eps).sqrt() + tmp_weight = bn.weight / std + tmp_weight = tmp_weight.reshape(-1, 1, 1, 1) + + if len(tmp_weight) == conv_or_fc.weight.size(0): + return (conv_or_fc.weight * tmp_weight, + bn.bias - bn.running_mean * bn.weight / std) + else: + # in RepMLPBlock, dim0 of fc3 weights and fc3_bn weights + # are different. + repeat_times = conv_or_fc.weight.size(0) // len(tmp_weight) + repeated = tmp_weight.repeat_interleave(repeat_times, 0) + fused_weight = conv_or_fc.weight * repeated + bias = bn.bias - bn.running_mean * bn.weight / std + fused_bias = (bias).repeat_interleave(repeat_times, 0) + return (fused_weight, fused_bias) + + +class PatchEmbed(_PatchEmbed): + """Image to Patch Embedding. + + Compared with default Patch Embedding(in ViT), Patch Embedding of RepMLP + have ReLu and do not convert output tensor into shape (N, L, C). + + Args: + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + conv_type (str): The type of convolution + to generate patch embedding. Default: "Conv2d". + kernel_size (int): The kernel_size of embedding conv. Default: 16. + stride (int): The slide stride of embedding conv. + Default: 16. + padding (int | tuple | string): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int): The dilation rate of embedding conv. Default: 1. + bias (bool): Bias of embed conv. Default: True. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None. + input_size (int | tuple | None): The size of input, which will be + used to calculate the out size. Only works when `dynamic_size` + is False. Default: None. + init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, *args, **kwargs): + super(PatchEmbed, self).__init__(*args, **kwargs) + self.relu = nn.ReLU() + + def forward(self, x): + """ + Args: + x (Tensor): Has shape (B, C, H, W). In most case, C is 3. + Returns: + tuple: Contains merged results and its spatial shape. + - x (Tensor): The output tensor. + - out_size (tuple[int]): Spatial shape of x, arrange as + (out_h, out_w). + """ + + if self.adaptive_padding: + x = self.adaptive_padding(x) + + x = self.projection(x) + if self.norm is not None: + x = self.norm(x) + x = self.relu(x) + out_size = (x.shape[2], x.shape[3]) + return x, out_size + + +class GlobalPerceptron(SELayer): + """GlobalPerceptron implemented by using ``mmpretrain.modes.SELayer``. + + Args: + input_channels (int): The number of input (and output) channels + in the GlobalPerceptron. + ratio (int): Squeeze ratio in GlobalPerceptron, the intermediate + channel will be ``make_divisible(channels // ratio, divisor)``. 
+ """ + + def __init__(self, input_channels: int, ratio: int, **kwargs) -> None: + super(GlobalPerceptron, self).__init__( + channels=input_channels, + ratio=ratio, + return_weight=True, + act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')), + **kwargs) + + +class RepMLPBlock(BaseModule): + """Basic RepMLPNet, consists of PartitionPerceptron and GlobalPerceptron. + + Args: + channels (int): The number of input and the output channels of the + block. + path_h (int): The height of patches. + path_w (int): The weidth of patches. + reparam_conv_kernels (Squeue(int) | None): The conv kernels in the + GlobalPerceptron. Default: None. + globalperceptron_ratio (int): The reducation ratio in the + GlobalPerceptron. Default: 4. + num_sharesets (int): The number of sharesets in the + PartitionPerceptron. Default 1. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + deploy (bool): Whether to switch the model structure to + deployment mode. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + channels, + path_h, + path_w, + reparam_conv_kernels=None, + globalperceptron_ratio=4, + num_sharesets=1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + deploy=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.deploy = deploy + self.channels = channels + self.num_sharesets = num_sharesets + self.path_h, self.path_w = path_h, path_w + # the input channel of fc3 + self._path_vec_channles = path_h * path_w * num_sharesets + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.gp = GlobalPerceptron( + input_channels=channels, ratio=globalperceptron_ratio) + + # using a conv layer to implement a fc layer + self.fc3 = build_conv_layer( + conv_cfg, + in_channels=self._path_vec_channles, + out_channels=self._path_vec_channles, + kernel_size=1, + stride=1, + padding=0, + bias=deploy, + groups=num_sharesets) + if deploy: + self.fc3_bn = nn.Identity() + else: + norm_layer = build_norm_layer(norm_cfg, num_sharesets)[1] + self.add_module('fc3_bn', norm_layer) + + self.reparam_conv_kernels = reparam_conv_kernels + if not deploy and reparam_conv_kernels is not None: + for k in reparam_conv_kernels: + conv_branch = ConvModule( + in_channels=num_sharesets, + out_channels=num_sharesets, + kernel_size=k, + stride=1, + padding=k // 2, + norm_cfg=dict(type='BN', requires_grad=True), + groups=num_sharesets, + act_cfg=None) + self.__setattr__('repconv{}'.format(k), conv_branch) + + def partition(self, x, h_parts, w_parts): + # convert (N, C, H, W) to (N, h_parts, w_parts, C, path_h, path_w) + x = x.reshape(-1, self.channels, h_parts, self.path_h, w_parts, + self.path_w) + x = x.permute(0, 2, 4, 1, 3, 5) + return x + + def partition_affine(self, x, h_parts, w_parts): + """perform Partition Perceptron.""" + fc_inputs = x.reshape(-1, self._path_vec_channles, 1, 1) + out = self.fc3(fc_inputs) + out = out.reshape(-1, self.num_sharesets, self.path_h, self.path_w) + out = self.fc3_bn(out) + out = out.reshape(-1, h_parts, w_parts, self.num_sharesets, + self.path_h, self.path_w) + return out + + def forward(self, inputs): + # Global Perceptron + global_vec = self.gp(inputs) + + origin_shape = inputs.size() + h_parts = origin_shape[2] // self.path_h + w_parts = origin_shape[3] // self.path_w + + partitions = self.partition(inputs, h_parts, w_parts) 
+ + # Channel Perceptron + fc3_out = self.partition_affine(partitions, h_parts, w_parts) + + # perform Local Perceptron + if self.reparam_conv_kernels is not None and not self.deploy: + conv_inputs = partitions.reshape(-1, self.num_sharesets, + self.path_h, self.path_w) + conv_out = 0 + for k in self.reparam_conv_kernels: + conv_branch = self.__getattr__('repconv{}'.format(k)) + conv_out += conv_branch(conv_inputs) + conv_out = conv_out.reshape(-1, h_parts, w_parts, + self.num_sharesets, self.path_h, + self.path_w) + fc3_out += conv_out + + # N, h_parts, w_parts, num_sharesets, out_h, out_w + fc3_out = fc3_out.permute(0, 3, 1, 4, 2, 5) + out = fc3_out.reshape(*origin_shape) + out = out * global_vec + return out + + def get_equivalent_fc3(self): + """get the equivalent fc3 weight and bias.""" + fc_weight, fc_bias = fuse_bn(self.fc3, self.fc3_bn) + if self.reparam_conv_kernels is not None: + largest_k = max(self.reparam_conv_kernels) + largest_branch = self.__getattr__('repconv{}'.format(largest_k)) + total_kernel, total_bias = fuse_bn(largest_branch.conv, + largest_branch.bn) + for k in self.reparam_conv_kernels: + if k != largest_k: + k_branch = self.__getattr__('repconv{}'.format(k)) + kernel, bias = fuse_bn(k_branch.conv, k_branch.bn) + total_kernel += F.pad(kernel, [(largest_k - k) // 2] * 4) + total_bias += bias + rep_weight, rep_bias = self._convert_conv_to_fc( + total_kernel, total_bias) + final_fc3_weight = rep_weight.reshape_as(fc_weight) + fc_weight + final_fc3_bias = rep_bias + fc_bias + else: + final_fc3_weight = fc_weight + final_fc3_bias = fc_bias + return final_fc3_weight, final_fc3_bias + + def local_inject(self): + """inject the Local Perceptron into Partition Perceptron.""" + self.deploy = True + # Locality Injection + fc3_weight, fc3_bias = self.get_equivalent_fc3() + # Remove Local Perceptron + if self.reparam_conv_kernels is not None: + for k in self.reparam_conv_kernels: + self.__delattr__('repconv{}'.format(k)) + self.__delattr__('fc3') + self.__delattr__('fc3_bn') + self.fc3 = build_conv_layer( + self.conv_cfg, + self._path_vec_channles, + self._path_vec_channles, + 1, + 1, + 0, + bias=True, + groups=self.num_sharesets) + self.fc3_bn = nn.Identity() + self.fc3.weight.data = fc3_weight + self.fc3.bias.data = fc3_bias + + def _convert_conv_to_fc(self, conv_kernel, conv_bias): + """convert conv_k1 to fc, which is still a conv_k2, and the k2 > k1.""" + in_channels = torch.eye(self.path_h * self.path_w).repeat( + 1, self.num_sharesets).reshape(self.path_h * self.path_w, + self.num_sharesets, self.path_h, + self.path_w).to(conv_kernel.device) + fc_k = F.conv2d( + in_channels, + conv_kernel, + padding=(conv_kernel.size(2) // 2, conv_kernel.size(3) // 2), + groups=self.num_sharesets) + fc_k = fc_k.reshape(self.path_w * self.path_w, self.num_sharesets * + self.path_h * self.path_w).t() + fc_bias = conv_bias.repeat_interleave(self.path_h * self.path_w) + return fc_k, fc_bias + + +class RepMLPNetUnit(BaseModule): + """A basic unit in RepMLPNet : [REPMLPBlock + BN + ConvFFN + BN]. + + Args: + channels (int): The number of input and the output channels of the + unit. + path_h (int): The height of patches. + path_w (int): The weidth of patches. + reparam_conv_kernels (Squeue(int) | None): The conv kernels in the + GlobalPerceptron. Default: None. + globalperceptron_ratio (int): The reducation ratio in the + GlobalPerceptron. Default: 4. + num_sharesets (int): The number of sharesets in the + PartitionPerceptron. Default 1. 
+ conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + deploy (bool): Whether to switch the model structure to + deployment mode. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + channels, + path_h, + path_w, + reparam_conv_kernels, + globalperceptron_ratio, + norm_cfg=dict(type='BN', requires_grad=True), + ffn_expand=4, + num_sharesets=1, + deploy=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.repmlp_block = RepMLPBlock( + channels=channels, + path_h=path_h, + path_w=path_w, + reparam_conv_kernels=reparam_conv_kernels, + globalperceptron_ratio=globalperceptron_ratio, + num_sharesets=num_sharesets, + deploy=deploy) + self.ffn_block = ConvFFN(channels, channels * ffn_expand) + norm1 = build_norm_layer(norm_cfg, channels)[1] + self.add_module('norm1', norm1) + norm2 = build_norm_layer(norm_cfg, channels)[1] + self.add_module('norm2', norm2) + + def forward(self, x): + y = x + self.repmlp_block(self.norm1(x)) + out = y + self.ffn_block(self.norm2(y)) + return out + + +class ConvFFN(nn.Module): + """ConvFFN implemented by using point-wise convs.""" + + def __init__(self, + in_channels, + hidden_channels=None, + out_channels=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='GELU')): + super().__init__() + out_features = out_channels or in_channels + hidden_features = hidden_channels or in_channels + self.ffn_fc1 = ConvModule( + in_channels=in_channels, + out_channels=hidden_features, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=norm_cfg, + act_cfg=None) + self.ffn_fc2 = ConvModule( + in_channels=hidden_features, + out_channels=out_features, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=norm_cfg, + act_cfg=None) + self.act = build_activation_layer(act_cfg) + + def forward(self, x): + x = self.ffn_fc1(x) + x = self.act(x) + x = self.ffn_fc2(x) + return x + + +@MODELS.register_module() +class RepMLPNet(BaseModule): + """RepMLPNet backbone. + + A PyTorch impl of : `RepMLP: Re-parameterizing Convolutions into + Fully-connected Layers for Image Recognition + `_ + + Args: + arch (str | dict): RepMLP architecture. If use string, choose + from 'base' and 'b'. If use dict, it should have below keys: + + - channels (List[int]): Number of blocks in each stage. + - depths (List[int]): The number of blocks in each branch. + - sharesets_nums (List[int]): RepVGG Block that declares + the need to apply group convolution. + + img_size (int | tuple): The size of input image. Defaults: 224. + in_channels (int): Number of input image channels. Default: 3. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 4. + out_indices (Sequence[int]): Output from which stages. + Default: ``(3, )``. + reparam_conv_kernels (Squeue(int) | None): The conv kernels in the + GlobalPerceptron. Default: None. + globalperceptron_ratio (int): The reducation ratio in the + GlobalPerceptron. Default: 4. + num_sharesets (int): The number of sharesets in the + PartitionPerceptron. Default 1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + Default: dict(type='BN', requires_grad=True). + patch_cfg (dict): Extra config dict for patch embedding. 
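# [Editor's example] Shape walk-through of `RepMLPBlock.partition` and
# its inverse permute in `forward` above: the feature map is cut into an
# h_parts x w_parts grid of path_h x path_w patches and later restored.
# Plain PyTorch sketch:
import torch

N, C, H, W = 2, 8, 16, 16
path_h = path_w = 4
h_parts, w_parts = H // path_h, W // path_w

x = torch.randn(N, C, H, W)
parts = x.reshape(N, C, h_parts, path_h, w_parts, path_w)
parts = parts.permute(0, 2, 4, 1, 3, 5)  # N, hp, wp, C, ph, pw
print(tuple(parts.shape))                # (2, 4, 4, 8, 4, 4)

restored = parts.permute(0, 3, 1, 4, 2, 5).reshape(N, C, H, W)
print(torch.equal(restored, x))          # True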
+ Defaults to an empty dict. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + deploy (bool): Whether to switch the model structure to deployment + mode. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + arch_zoo = { + **dict.fromkeys(['b', 'base'], + {'channels': [96, 192, 384, 768], + 'depths': [2, 2, 12, 2], + 'sharesets_nums': [1, 4, 32, 128]}), + } # yapf: disable + + num_extra_tokens = 0 # there is no cls-token in RepMLP + + def __init__(self, + arch, + img_size=224, + in_channels=3, + patch_size=4, + out_indices=(3, ), + reparam_conv_kernels=(3, ), + globalperceptron_ratio=4, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + patch_cfg=dict(), + final_norm=True, + deploy=False, + init_cfg=None): + super(RepMLPNet, self).__init__(init_cfg=init_cfg) + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'channels', 'depths', 'sharesets_nums'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}.' + self.arch_settings = arch + + self.img_size = to_2tuple(img_size) + self.patch_size = to_2tuple(patch_size) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.num_stage = len(self.arch_settings['channels']) + for value in self.arch_settings.values(): + assert isinstance(value, list) and len(value) == self.num_stage, ( + 'Length of setting item in arch dict must be type of list and' + ' have the same length.') + + self.channels = self.arch_settings['channels'] + self.depths = self.arch_settings['depths'] + self.sharesets_nums = self.arch_settings['sharesets_nums'] + + _patch_cfg = dict( + in_channels=in_channels, + input_size=self.img_size, + embed_dims=self.channels[0], + conv_type='Conv2d', + kernel_size=self.patch_size, + stride=self.patch_size, + norm_cfg=self.norm_cfg, + bias=False) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + self.patch_hs = [ + self.patch_resolution[0] // 2**i for i in range(self.num_stage) + ] + self.patch_ws = [ + self.patch_resolution[1] // 2**i for i in range(self.num_stage) + ] + + self.stages = ModuleList() + self.downsample_layers = ModuleList() + for stage_idx in range(self.num_stage): + # make stage layers + _stage_cfg = dict( + channels=self.channels[stage_idx], + path_h=self.patch_hs[stage_idx], + path_w=self.patch_ws[stage_idx], + reparam_conv_kernels=reparam_conv_kernels, + globalperceptron_ratio=globalperceptron_ratio, + norm_cfg=self.norm_cfg, + ffn_expand=4, + num_sharesets=self.sharesets_nums[stage_idx], + deploy=deploy) + stage_blocks = [ + RepMLPNetUnit(**_stage_cfg) + for _ in range(self.depths[stage_idx]) + ] + self.stages.append(Sequential(*stage_blocks)) + + # make downsample layers + if stage_idx < self.num_stage - 1: + self.downsample_layers.append( + ConvModule( + in_channels=self.channels[stage_idx], + out_channels=self.channels[stage_idx + 1], + kernel_size=2, + stride=2, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True)) + + self.out_indice = out_indices + + if final_norm: + norm_layer = build_norm_layer(norm_cfg, self.channels[-1])[1] + else: + norm_layer = nn.Identity() + 
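# [Editor's example] Why RepMLPNet is fixed-resolution: the Partition
# Perceptron's fc3 is sized from path_h x path_w at build time, with the
# per-stage patch grid halving as computed above. Pure Python sketch of
# that schedule for the default img_size and patch_size:
img_size, patch_size, num_stage = 224, 4, 4
patch_resolution = img_size // patch_size  # 56 after the patch embedding

patch_hs = [patch_resolution // 2**i for i in range(num_stage)]
print(patch_hs)  # [56, 28, 14, 7] -- baked into fc3, hence the
                 # fixed input-shape assert in `forward` below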
self.add_module('final_norm', norm_layer) + + def forward(self, x): + assert x.shape[2:] == self.img_size, \ + "The Rep-MLP doesn't support dynamic input shape. " \ + f'Please input images with shape {self.img_size}' + + outs = [] + + x, _ = self.patch_embed(x) + for i, stage in enumerate(self.stages): + x = stage(x) + + # downsample after each stage except last stage + if i < len(self.stages) - 1: + downsample = self.downsample_layers[i] + x = downsample(x) + + if i in self.out_indice: + if self.final_norm and i == len(self.stages) - 1: + out = self.final_norm(x) + else: + out = x + outs.append(out) + + return tuple(outs) + + def switch_to_deploy(self): + for m in self.modules(): + if hasattr(m, 'local_inject'): + m.local_inject() diff --git a/mmpretrain/models/backbones/repvgg.py b/mmpretrain/models/backbones/repvgg.py new file mode 100644 index 0000000000000000000000000000000000000000..67c9d147546eb2839a44749040a1a787ee5ce0ea --- /dev/null +++ b/mmpretrain/models/backbones/repvgg.py @@ -0,0 +1,622 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer, + build_norm_layer) +from mmengine.model import BaseModule, Sequential +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm +from torch import nn + +from mmpretrain.registry import MODELS +from ..utils.se_layer import SELayer +from .base_backbone import BaseBackbone + + +class RepVGGBlock(BaseModule): + """RepVGG block for RepVGG backbone. + + Args: + in_channels (int): The input channels of the block. + out_channels (int): The output channels of the block. + stride (int): Stride of the 3x3 and 1x1 convolution layer. Default: 1. + padding (int): Padding of the 3x3 convolution layer. + dilation (int): Dilation of the 3x3 convolution layer. + groups (int): Groups of the 3x3 and 1x1 convolution layer. Default: 1. + padding_mode (str): Padding mode of the 3x3 convolution layer. + Default: 'zeros'. + se_cfg (None or dict): The configuration of the se module. + Default: None. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + deploy (bool): Whether to switch the model structure to + deployment mode. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ Default: None + """ + + def __init__(self, + in_channels, + out_channels, + stride=1, + padding=1, + dilation=1, + groups=1, + padding_mode='zeros', + se_cfg=None, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + deploy=False, + init_cfg=None): + super(RepVGGBlock, self).__init__(init_cfg) + + assert se_cfg is None or isinstance(se_cfg, dict) + + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.se_cfg = se_cfg + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.deploy = deploy + + if deploy: + self.branch_reparam = build_conv_layer( + conv_cfg, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=True, + padding_mode=padding_mode) + else: + # judge if input shape and output shape are the same. + # If true, add a normalized identity shortcut. + if out_channels == in_channels and stride == 1 and \ + padding == dilation: + self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1] + else: + self.branch_norm = None + + self.branch_3x3 = self.create_conv_bn( + kernel_size=3, + dilation=dilation, + padding=padding, + ) + self.branch_1x1 = self.create_conv_bn(kernel_size=1) + + if se_cfg is not None: + self.se_layer = SELayer(channels=out_channels, **se_cfg) + else: + self.se_layer = None + + self.act = build_activation_layer(act_cfg) + + def create_conv_bn(self, kernel_size, dilation=1, padding=0): + conv_bn = Sequential() + conv_bn.add_module( + 'conv', + build_conv_layer( + self.conv_cfg, + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=kernel_size, + stride=self.stride, + dilation=dilation, + padding=padding, + groups=self.groups, + bias=False)) + conv_bn.add_module( + 'norm', + build_norm_layer(self.norm_cfg, num_features=self.out_channels)[1]) + + return conv_bn + + def forward(self, x): + + def _inner_forward(inputs): + if self.deploy: + return self.branch_reparam(inputs) + + if self.branch_norm is None: + branch_norm_out = 0 + else: + branch_norm_out = self.branch_norm(inputs) + + inner_out = self.branch_3x3(inputs) + self.branch_1x1( + inputs) + branch_norm_out + + if self.se_cfg is not None: + inner_out = self.se_layer(inner_out) + + return inner_out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.act(out) + + return out + + def switch_to_deploy(self): + """Switch the model structure from training mode to deployment mode.""" + if self.deploy: + return + assert self.norm_cfg['type'] == 'BN', \ + "Switch is not allowed when norm_cfg['type'] != 'BN'." + + reparam_weight, reparam_bias = self.reparameterize() + self.branch_reparam = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.out_channels, + kernel_size=3, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + bias=True) + self.branch_reparam.weight.data = reparam_weight + self.branch_reparam.bias.data = reparam_bias + + for param in self.parameters(): + param.detach_() + delattr(self, 'branch_3x3') + delattr(self, 'branch_1x1') + delattr(self, 'branch_norm') + + self.deploy = True + + def reparameterize(self): + """Fuse all the parameters of all branches. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Parameters after fusion of all + branches. 
The first element is the weight and the second is
+            the bias.
+        """
+        weight_3x3, bias_3x3 = self._fuse_conv_bn(self.branch_3x3)
+        weight_1x1, bias_1x1 = self._fuse_conv_bn(self.branch_1x1)
+        # pad a conv1x1 weight to a conv3x3 weight
+        weight_1x1 = F.pad(weight_1x1, [1, 1, 1, 1], value=0)
+
+        weight_norm, bias_norm = 0, 0
+        if self.branch_norm:
+            tmp_conv_bn = self._norm_to_conv3x3(self.branch_norm)
+            weight_norm, bias_norm = self._fuse_conv_bn(tmp_conv_bn)
+
+        return (weight_3x3 + weight_1x1 + weight_norm,
+                bias_3x3 + bias_1x1 + bias_norm)
+
+    def _fuse_conv_bn(self, branch):
+        """Fuse the parameters in a branch with a conv and bn.
+
+        Args:
+            branch (mmengine.model.Sequential): A branch with conv and bn.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: The parameters obtained after
+            fusing the parameters of conv and bn in one branch.
+            The first element is the weight and the second is the bias.
+        """
+        if branch is None:
+            return 0, 0
+        conv_weight = branch.conv.weight
+        running_mean = branch.norm.running_mean
+        running_var = branch.norm.running_var
+        gamma = branch.norm.weight
+        beta = branch.norm.bias
+        eps = branch.norm.eps
+
+        std = (running_var + eps).sqrt()
+        fused_weight = (gamma / std).reshape(-1, 1, 1, 1) * conv_weight
+        fused_bias = -running_mean * gamma / std + beta
+
+        return fused_weight, fused_bias
+
+    def _norm_to_conv3x3(self, branch_norm):
+        """Convert a norm layer to a conv3x3-bn sequence.
+
+        Args:
+            branch_norm (nn.BatchNorm2d): A branch only with bn in the block.
+
+        Returns:
+            tmp_conv3x3 (mmengine.model.Sequential): A sequential with
+            conv3x3 and bn.
+        """
+        input_dim = self.in_channels // self.groups
+        conv_weight = torch.zeros((self.in_channels, input_dim, 3, 3),
+                                  dtype=branch_norm.weight.dtype)
+
+        for i in range(self.in_channels):
+            conv_weight[i, i % input_dim, 1, 1] = 1
+        conv_weight = conv_weight.to(branch_norm.weight.device)
+
+        tmp_conv3x3 = self.create_conv_bn(kernel_size=3)
+        tmp_conv3x3.conv.weight.data = conv_weight
+        tmp_conv3x3.norm = branch_norm
+        return tmp_conv3x3
+
+
+class MTSPPF(BaseModule):
+    """MTSPPF block for YOLOX-PAI RepVGG backbone.
+
+    Args:
+        in_channels (int): The input channels of the block.
+        out_channels (int): The output channels of the block.
+        norm_cfg (dict): dictionary to construct and config norm layer.
+            Default: dict(type='BN').
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU').
+        kernel_size (int): Kernel size of pooling. Default: 5.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 kernel_size=5):
+        super().__init__()
+        hidden_features = in_channels // 2  # hidden channels
+        self.conv1 = ConvModule(
+            in_channels,
+            hidden_features,
+            1,
+            stride=1,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        self.conv2 = ConvModule(
+            hidden_features * 4,
+            out_channels,
+            1,
+            stride=1,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        self.maxpool = nn.MaxPool2d(
+            kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        y1 = self.maxpool(x)
+        y2 = self.maxpool(y1)
+        return self.conv2(torch.cat([x, y1, y2, self.maxpool(y2)], 1))
+
+
+@MODELS.register_module()
+class RepVGG(BaseBackbone):
+    """RepVGG backbone.
+
+    A PyTorch impl of: `RepVGG: Making VGG-style ConvNets Great Again
+    <https://arxiv.org/abs/2101.03697>`_
+
+    Args:
+        arch (str | dict): RepVGG architecture. If a string, choose from
+            'A0', 'A1', 'A2', 'B0', 'B1', 'B1g2', 'B1g4', 'B2', 'B2g2',
+            'B2g4', 'B3', 'B3g2', 'B3g4' or 'D2se'.
If a dict, it should
+            have the below keys:
+
+            - **num_blocks** (Sequence[int]): Number of blocks in each stage.
+            - **width_factor** (Sequence[float]): Width multiplier of each
+              stage.
+            - **group_layer_map** (dict | None): Map from block index to the
+              number of groups, declaring which RepVGG blocks should use
+              group convolution.
+            - **se_cfg** (dict | None): SE Layer config.
+            - **stem_channels** (int, optional): The stem channels, the final
+              stem channels will be
+              ``min(stem_channels, base_channels*width_factor[0])``.
+              If not set here, 64 is used by default in the code.
+
+        in_channels (int): Number of input image channels. Defaults to 3.
+        base_channels (int): Base channels of RepVGG backbone, which works
+            together with width_factor. Defaults to 64.
+        out_indices (Sequence[int]): Output from which stages.
+            Defaults to ``(3, )``.
+        strides (Sequence[int]): Strides of the first block of each stage.
+            Defaults to ``(2, 2, 2, 2)``.
+        dilations (Sequence[int]): Dilation of each stage.
+            Defaults to ``(1, 1, 1, 1)``.
+        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+            not freezing any parameters. Defaults to -1.
+        conv_cfg (dict | None): The config dict for conv layers.
+            Defaults to None.
+        norm_cfg (dict): The config dict for norm layers.
+            Defaults to ``dict(type='BN')``.
+        act_cfg (dict): Config dict for activation layer.
+            Defaults to ``dict(type='ReLU')``.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Defaults to False.
+        deploy (bool): Whether to switch the model structure to deployment
+            mode. Defaults to False.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Defaults to False.
+        add_ppf (bool): Whether to use the MTSPPF block. Defaults to False.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Defaults to None.
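+
+    Example (a minimal usage sketch; the printed shapes assume the default
+    ``out_indices=(3, )`` and a 224x224 input):
+        >>> from mmpretrain.models import RepVGG
+        >>> import torch
+        >>> model = RepVGG(arch='A0')
+        >>> model.eval()
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outputs = model(inputs)
+        >>> for out in outputs:
+        ...     print(tuple(out.shape))
+        (1, 1280, 7, 7)
+        >>> # fuse the 3x3, 1x1 and identity branches into single 3x3 convs;
+        >>> # outputs should stay numerically (almost) identical afterwards
+        >>> model.switch_to_deploy()
+        >>> outputs = model(inputs)
+        >>> for out in outputs:
+        ...     print(tuple(out.shape))
+        (1, 1280, 7, 7)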
+ """ + + groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26] + g2_layer_map = {layer: 2 for layer in groupwise_layers} + g4_layer_map = {layer: 4 for layer in groupwise_layers} + + arch_settings = { + 'A0': + dict( + num_blocks=[2, 4, 14, 1], + width_factor=[0.75, 0.75, 0.75, 2.5], + group_layer_map=None, + se_cfg=None), + 'A1': + dict( + num_blocks=[2, 4, 14, 1], + width_factor=[1, 1, 1, 2.5], + group_layer_map=None, + se_cfg=None), + 'A2': + dict( + num_blocks=[2, 4, 14, 1], + width_factor=[1.5, 1.5, 1.5, 2.75], + group_layer_map=None, + se_cfg=None), + 'B0': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[1, 1, 1, 2.5], + group_layer_map=None, + se_cfg=None, + stem_channels=64), + 'B1': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2, 2, 2, 4], + group_layer_map=None, + se_cfg=None), + 'B1g2': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2, 2, 2, 4], + group_layer_map=g2_layer_map, + se_cfg=None), + 'B1g4': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2, 2, 2, 4], + group_layer_map=g4_layer_map, + se_cfg=None), + 'B2': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2.5, 2.5, 2.5, 5], + group_layer_map=None, + se_cfg=None), + 'B2g2': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2.5, 2.5, 2.5, 5], + group_layer_map=g2_layer_map, + se_cfg=None), + 'B2g4': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2.5, 2.5, 2.5, 5], + group_layer_map=g4_layer_map, + se_cfg=None), + 'B3': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[3, 3, 3, 5], + group_layer_map=None, + se_cfg=None), + 'B3g2': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[3, 3, 3, 5], + group_layer_map=g2_layer_map, + se_cfg=None), + 'B3g4': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[3, 3, 3, 5], + group_layer_map=g4_layer_map, + se_cfg=None), + 'D2se': + dict( + num_blocks=[8, 14, 24, 1], + width_factor=[2.5, 2.5, 2.5, 5], + group_layer_map=None, + se_cfg=dict(ratio=16, divisor=1)), + 'yolox-pai-small': + dict( + num_blocks=[3, 5, 7, 3], + width_factor=[1, 1, 1, 1], + group_layer_map=None, + se_cfg=None, + stem_channels=32), + } + + def __init__(self, + arch, + in_channels=3, + base_channels=64, + out_indices=(3, ), + strides=(2, 2, 2, 2), + dilations=(1, 1, 1, 1), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False, + deploy=False, + norm_eval=False, + add_ppf=False, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + super(RepVGG, self).__init__(init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'"arch": "{arch}" is not one of the arch_settings' + arch = self.arch_settings[arch] + elif not isinstance(arch, dict): + raise TypeError('Expect "arch" to be either a string ' + f'or a dict, got {type(arch)}') + + assert len(arch['num_blocks']) == len( + arch['width_factor']) == len(strides) == len(dilations) + assert max(out_indices) < len(arch['num_blocks']) + if arch['group_layer_map'] is not None: + assert max(arch['group_layer_map'].keys()) <= sum( + arch['num_blocks']) + + if arch['se_cfg'] is not None: + assert isinstance(arch['se_cfg'], dict) + + self.base_channels = base_channels + self.arch = arch + self.in_channels = in_channels + self.out_indices = out_indices + self.strides = strides + self.dilations = dilations + self.deploy = deploy + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.with_cp = with_cp + 
self.norm_eval = norm_eval + + # defaults to 64 to prevert BC-breaking if stem_channels + # not in arch dict; + # the stem channels should not be larger than that of stage1. + channels = min( + arch.get('stem_channels', 64), + int(self.base_channels * self.arch['width_factor'][0])) + self.stem = RepVGGBlock( + self.in_channels, + channels, + stride=2, + se_cfg=arch['se_cfg'], + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + deploy=deploy) + + next_create_block_idx = 1 + self.stages = [] + for i in range(len(arch['num_blocks'])): + num_blocks = self.arch['num_blocks'][i] + stride = self.strides[i] + dilation = self.dilations[i] + out_channels = int(self.base_channels * 2**i * + self.arch['width_factor'][i]) + + stage, next_create_block_idx = self._make_stage( + channels, out_channels, num_blocks, stride, dilation, + next_create_block_idx, init_cfg) + stage_name = f'stage_{i + 1}' + self.add_module(stage_name, stage) + self.stages.append(stage_name) + + channels = out_channels + + if add_ppf: + self.ppf = MTSPPF( + out_channels, + out_channels, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + kernel_size=5) + else: + self.ppf = nn.Identity() + + def _make_stage(self, in_channels, out_channels, num_blocks, stride, + dilation, next_create_block_idx, init_cfg): + strides = [stride] + [1] * (num_blocks - 1) + dilations = [dilation] * num_blocks + + blocks = [] + for i in range(num_blocks): + groups = self.arch['group_layer_map'].get( + next_create_block_idx, + 1) if self.arch['group_layer_map'] is not None else 1 + blocks.append( + RepVGGBlock( + in_channels, + out_channels, + stride=strides[i], + padding=dilations[i], + dilation=dilations[i], + groups=groups, + se_cfg=self.arch['se_cfg'], + with_cp=self.with_cp, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + deploy=self.deploy, + init_cfg=init_cfg)) + in_channels = out_channels + next_create_block_idx += 1 + + return Sequential(*blocks), next_create_block_idx + + def forward(self, x): + x = self.stem(x) + outs = [] + for i, stage_name in enumerate(self.stages): + stage = getattr(self, stage_name) + x = stage(x) + if i + 1 == len(self.stages): + x = self.ppf(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + for i in range(self.frozen_stages): + stage = getattr(self, f'stage_{i+1}') + stage.eval() + for param in stage.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(RepVGG, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() + + def switch_to_deploy(self): + for m in self.modules(): + if isinstance(m, RepVGGBlock): + m.switch_to_deploy() + self.deploy = True diff --git a/mmpretrain/models/backbones/res2net.py b/mmpretrain/models/backbones/res2net.py new file mode 100644 index 0000000000000000000000000000000000000000..6e9bb6df37a2d2c9d19e613faa50ce0103aff357 --- /dev/null +++ b/mmpretrain/models/backbones/res2net.py @@ -0,0 +1,317 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
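+# Implementation note (a brief summary of the Bottle2neck below, not
+# upstream documentation): conv1 expands the input to `width * scales`
+# channels, which are split into `scales` groups; each 3x3 conv consumes its
+# own group plus, in 'normal' blocks, the previous conv's output, so later
+# groups accumulate increasingly large receptive fields before all groups
+# are concatenated and fused by conv3.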
+import math + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import ModuleList, Sequential + +from mmpretrain.registry import MODELS +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNet + + +class Bottle2neck(_Bottleneck): + expansion = 4 + + def __init__(self, + in_channels, + out_channels, + scales=4, + base_width=26, + base_channels=64, + stage_type='normal', + **kwargs): + """Bottle2neck block for Res2Net.""" + super(Bottle2neck, self).__init__(in_channels, out_channels, **kwargs) + assert scales > 1, 'Res2Net degenerates to ResNet when scales = 1.' + + mid_channels = out_channels // self.expansion + width = int(math.floor(mid_channels * (base_width / base_channels))) + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width * scales, postfix=1) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + width * scales, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + + if stage_type == 'stage': + self.pool = nn.AvgPool2d( + kernel_size=3, stride=self.conv2_stride, padding=1) + + self.convs = ModuleList() + self.bns = ModuleList() + for i in range(scales - 1): + self.convs.append( + build_conv_layer( + self.conv_cfg, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + bias=False)) + self.bns.append( + build_norm_layer(self.norm_cfg, width, postfix=i + 1)[1]) + + self.conv3 = build_conv_layer( + self.conv_cfg, + width * scales, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.stage_type = stage_type + self.scales = scales + self.width = width + delattr(self, 'conv2') + delattr(self, self.norm2_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + spx = torch.split(out, self.width, 1) + sp = self.convs[0](spx[0].contiguous()) + sp = self.relu(self.bns[0](sp)) + out = sp + for i in range(1, self.scales - 1): + if self.stage_type == 'stage': + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp.contiguous()) + sp = self.relu(self.bns[i](sp)) + out = torch.cat((out, sp), 1) + + if self.stage_type == 'normal' and self.scales != 1: + out = torch.cat((out, spx[self.scales - 1]), 1) + elif self.stage_type == 'stage' and self.scales != 1: + out = torch.cat((out, self.pool(spx[self.scales - 1])), 1) + + out = self.conv3(out) + out = self.norm3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Res2Layer(Sequential): + """Res2Layer to build Res2Net style backbone. + + Args: + block (nn.Module): block used to build ResLayer. + inplanes (int): inplanes of block. + planes (int): planes of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottle2neck. Defaults to True. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. 
+ Default: dict(type='BN') + scales (int): Scales used in Res2Net. Default: 4 + base_width (int): Basic width of each scale. Default: 26 + drop_path_rate (float or np.ndarray): stochastic depth rate. + Default: 0. + """ + + def __init__(self, + block, + in_channels, + out_channels, + num_blocks, + stride=1, + avg_down=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + scales=4, + base_width=26, + drop_path_rate=0.0, + **kwargs): + self.block = block + + if isinstance(drop_path_rate, float): + drop_path_rate = [drop_path_rate] * num_blocks + + assert len(drop_path_rate + ) == num_blocks, 'Please check the length of drop_path_rate' + + downsample = None + if stride != 1 or in_channels != out_channels: + if avg_down: + downsample = nn.Sequential( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False), + build_conv_layer( + conv_cfg, + in_channels, + out_channels, + kernel_size=1, + stride=1, + bias=False), + build_norm_layer(norm_cfg, out_channels)[1], + ) + else: + downsample = nn.Sequential( + build_conv_layer( + conv_cfg, + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(norm_cfg, out_channels)[1], + ) + + layers = [] + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + scales=scales, + base_width=base_width, + stage_type='stage', + drop_path_rate=drop_path_rate[0], + **kwargs)) + in_channels = out_channels + for i in range(1, num_blocks): + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + scales=scales, + base_width=base_width, + drop_path_rate=drop_path_rate[i], + **kwargs)) + super(Res2Layer, self).__init__(*layers) + + +@MODELS.register_module() +class Res2Net(ResNet): + """Res2Net backbone. + + A PyTorch implement of : `Res2Net: A New Multi-scale Backbone + Architecture `_ + + Args: + depth (int): Depth of Res2Net, choose from {50, 101, 152}. + scales (int): Scales used in Res2Net. Defaults to 4. + base_width (int): Basic width of each scale. Defaults to 26. + in_channels (int): Number of input image channels. Defaults to 3. + num_stages (int): Number of Res2Net stages. Defaults to 4. + strides (Sequence[int]): Strides of the first block of each stage. + Defaults to ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Defaults to ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. + Defaults to ``(3, )``. + style (str): "pytorch" or "caffe". If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. Defaults to "pytorch". + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Defaults to True. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottle2neck. Defaults to True. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_cfg (dict): Dictionary to construct and config norm layer. + Defaults to ``dict(type='BN', requires_grad=True)``. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. 
+ zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Defaults to True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + + Example: + >>> from mmpretrain.models import Res2Net + >>> import torch + >>> model = Res2Net(depth=50, + ... scales=4, + ... base_width=26, + ... out_indices=(0, 1, 2, 3)) + >>> model.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = model.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 256, 8, 8) + (1, 512, 4, 4) + (1, 1024, 2, 2) + (1, 2048, 1, 1) + """ + + arch_settings = { + 50: (Bottle2neck, (3, 4, 6, 3)), + 101: (Bottle2neck, (3, 4, 23, 3)), + 152: (Bottle2neck, (3, 8, 36, 3)) + } + + def __init__(self, + scales=4, + base_width=26, + style='pytorch', + deep_stem=True, + avg_down=True, + init_cfg=None, + **kwargs): + self.scales = scales + self.base_width = base_width + super(Res2Net, self).__init__( + style=style, + deep_stem=deep_stem, + avg_down=avg_down, + init_cfg=init_cfg, + **kwargs) + + def make_res_layer(self, **kwargs): + return Res2Layer( + scales=self.scales, + base_width=self.base_width, + base_channels=self.base_channels, + **kwargs) diff --git a/mmpretrain/models/backbones/resnest.py b/mmpretrain/models/backbones/resnest.py new file mode 100644 index 0000000000000000000000000000000000000000..4bb438f042d606946fd7b69d73568f28563e0efa --- /dev/null +++ b/mmpretrain/models/backbones/resnest.py @@ -0,0 +1,339 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmpretrain.registry import MODELS +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResLayer, ResNetV1d + + +class RSoftmax(nn.Module): + """Radix Softmax module in ``SplitAttentionConv2d``. + + Args: + radix (int): Radix of input. + groups (int): Groups of input. + """ + + def __init__(self, radix, groups): + super().__init__() + self.radix = radix + self.groups = groups + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttentionConv2d(nn.Module): + """Split-Attention Conv2d. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int | tuple[int]): Same as nn.Conv2d. + stride (int | tuple[int]): Same as nn.Conv2d. + padding (int | tuple[int]): Same as nn.Conv2d. + dilation (int | tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of SplitAttentionConv2d. + Default: 4. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None. 
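+
+    Example (a minimal shape check; the second positional argument is the
+    ``channels`` of the attended output):
+        >>> import torch
+        >>> conv = SplitAttentionConv2d(64, 64, kernel_size=3, padding=1)
+        >>> x = torch.rand(1, 64, 56, 56)
+        >>> tuple(conv(x).shape)
+        (1, 64, 56, 56)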
+ """ + + def __init__(self, + in_channels, + channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + radix=2, + reduction_factor=4, + conv_cfg=None, + norm_cfg=dict(type='BN')): + super(SplitAttentionConv2d, self).__init__() + inter_channels = max(in_channels * radix // reduction_factor, 32) + self.radix = radix + self.groups = groups + self.channels = channels + self.conv = build_conv_layer( + conv_cfg, + in_channels, + channels * radix, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups * radix, + bias=False) + self.norm0_name, norm0 = build_norm_layer( + norm_cfg, channels * radix, postfix=0) + self.add_module(self.norm0_name, norm0) + self.relu = nn.ReLU(inplace=True) + self.fc1 = build_conv_layer( + None, channels, inter_channels, 1, groups=self.groups) + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, inter_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.fc2 = build_conv_layer( + None, inter_channels, channels * radix, 1, groups=self.groups) + self.rsoftmax = RSoftmax(radix, groups) + + @property + def norm0(self): + return getattr(self, self.norm0_name) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def forward(self, x): + x = self.conv(x) + x = self.norm0(x) + x = self.relu(x) + + batch, rchannel = x.shape[:2] + if self.radix > 1: + splits = x.view(batch, self.radix, -1, *x.shape[2:]) + gap = splits.sum(dim=1) + else: + gap = x + gap = F.adaptive_avg_pool2d(gap, 1) + gap = self.fc1(gap) + + gap = self.norm1(gap) + gap = self.relu(gap) + + atten = self.fc2(gap) + atten = self.rsoftmax(atten).view(batch, -1, 1, 1) + + if self.radix > 1: + attens = atten.view(batch, self.radix, -1, *atten.shape[2:]) + out = torch.sum(attens * splits, dim=1) + else: + out = atten * x + return out.contiguous() + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeSt. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of SplitAttentionConv2d. + Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module, optional): downsample operation on identity + branch. Default: None + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + conv_cfg (dict, optional): dictionary to construct and config conv + layer. Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. 
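+
+    Example (an illustrative sketch; a standalone block needs matching input
+    and output channels because no ``downsample`` is built automatically):
+        >>> import torch
+        >>> block = Bottleneck(in_channels=256, out_channels=256)
+        >>> x = torch.rand(1, 256, 14, 14)
+        >>> tuple(block(x).shape)
+        (1, 256, 14, 14)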
+ """ + + def __init__(self, + in_channels, + out_channels, + groups=1, + width_per_group=4, + base_channels=64, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + super(Bottleneck, self).__init__(in_channels, out_channels, **kwargs) + + self.groups = groups + self.width_per_group = width_per_group + + # For ResNet bottleneck, middle channels are determined by expansion + # and out_channels, but for ResNeXt bottleneck, it is determined by + # groups and width_per_group and the stage it is located in. + if groups != 1: + assert self.mid_channels % base_channels == 0 + self.mid_channels = ( + groups * width_per_group * self.mid_channels // base_channels) + + self.avg_down_stride = avg_down_stride and self.conv2_stride > 1 + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=1) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = SplitAttentionConv2d( + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=1 if self.avg_down_stride else self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + radix=radix, + reduction_factor=reduction_factor, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) + delattr(self, self.norm2_name) + + if self.avg_down_stride: + self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1) + + self.conv3 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + + if self.avg_down_stride: + out = self.avd_layer(out) + + out = self.conv3(out) + out = self.norm3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class ResNeSt(ResNetV1d): + """ResNeSt backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152, 200}. + groups (int): Groups of conv2 in Bottleneck. Default: 32. + width_per_group (int): Width per group of conv2 in Bottleneck. + Default: 4. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of SplitAttentionConv2d. + Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. 
If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)), + 200: (Bottleneck, (3, 24, 36, 3)), + 269: (Bottleneck, (3, 30, 48, 8)) + } + + def __init__(self, + depth, + groups=1, + width_per_group=4, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + self.groups = groups + self.width_per_group = width_per_group + self.radix = radix + self.reduction_factor = reduction_factor + self.avg_down_stride = avg_down_stride + super(ResNeSt, self).__init__(depth=depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer( + groups=self.groups, + width_per_group=self.width_per_group, + base_channels=self.base_channels, + radix=self.radix, + reduction_factor=self.reduction_factor, + avg_down_stride=self.avg_down_stride, + **kwargs) diff --git a/mmpretrain/models/backbones/resnet.py b/mmpretrain/models/backbones/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4a254f7c2b76f03974e05194b39fbb802684873a --- /dev/null +++ b/mmpretrain/models/backbones/resnet.py @@ -0,0 +1,768 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer, + build_norm_layer) +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule +from mmengine.model.weight_init import constant_init +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + +eps = 1.0e-5 + + +class BasicBlock(BaseModule): + """BasicBlock for ResNet. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + expansion (int): The ratio of ``out_channels/mid_channels`` where + ``mid_channels`` is the output channels of conv1. This is a + reserved argument in BasicBlock and should always be 1. Default: 1. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module, optional): downsample operation on identity + branch. Default: None. + style (str): `pytorch` or `caffe`. It is unused and reserved for + unified API with Bottleneck. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + conv_cfg (dict, optional): dictionary to construct and config conv + layer. 
Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + """ + + def __init__(self, + in_channels, + out_channels, + expansion=1, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + drop_path_rate=0.0, + act_cfg=dict(type='ReLU', inplace=True), + init_cfg=None): + super(BasicBlock, self).__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.expansion = expansion + assert self.expansion == 1 + assert out_channels % expansion == 0 + self.mid_channels = out_channels // expansion + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, out_channels, postfix=2) + + self.conv1 = build_conv_layer( + conv_cfg, + in_channels, + self.mid_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, + self.mid_channels, + out_channels, + 3, + padding=1, + bias=False) + self.add_module(self.norm2_name, norm2) + + self.relu = build_activation_layer(act_cfg) + self.downsample = downsample + self.drop_path = DropPath(drop_prob=drop_path_rate + ) if drop_path_rate > eps else nn.Identity() + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + @property + def norm2(self): + return getattr(self, self.norm2_name) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out = self.drop_path(out) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Bottleneck(BaseModule): + """Bottleneck block for ResNet. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + expansion (int): The ratio of ``out_channels/mid_channels`` where + ``mid_channels`` is the input/output channels of conv2. Default: 4. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module, optional): downsample operation on identity + branch. Default: None. + style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the + stride-two layer is the 3x3 conv layer, otherwise the stride-two + layer is the first 1x1 conv layer. Default: "pytorch". + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + conv_cfg (dict, optional): dictionary to construct and config conv + layer. Default: None + norm_cfg (dict): dictionary to construct and config norm layer. 
+ Default: dict(type='BN') + """ + + def __init__(self, + in_channels, + out_channels, + expansion=4, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU', inplace=True), + drop_path_rate=0.0, + init_cfg=None): + super(Bottleneck, self).__init__(init_cfg=init_cfg) + assert style in ['pytorch', 'caffe'] + + self.in_channels = in_channels + self.out_channels = out_channels + self.expansion = expansion + assert out_channels % expansion == 0 + self.mid_channels = out_channels // expansion + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + if self.style == 'pytorch': + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + norm_cfg, out_channels, postfix=3) + + self.conv1 = build_conv_layer( + conv_cfg, + in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + conv_cfg, + self.mid_channels, + out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.relu = build_activation_layer(act_cfg) + self.downsample = downsample + self.drop_path = DropPath(drop_prob=drop_path_rate + ) if drop_path_rate > eps else nn.Identity() + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + @property + def norm2(self): + return getattr(self, self.norm2_name) + + @property + def norm3(self): + return getattr(self, self.norm3_name) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out = self.drop_path(out) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +def get_expansion(block, expansion=None): + """Get the expansion of a residual block. + + The block expansion will be obtained by the following order: + + 1. If ``expansion`` is given, just return it. + 2. If ``block`` has the attribute ``expansion``, then return + ``block.expansion``. + 3. Return the default value according the the block type: + 1 for ``BasicBlock`` and 4 for ``Bottleneck``. + + Args: + block (class): The block class. + expansion (int | None): The given expansion ratio. + + Returns: + int: The expansion of the block. 
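+
+    Example (illustrative):
+        >>> get_expansion(Bottleneck)
+        4
+        >>> get_expansion(BasicBlock)
+        1
+        >>> get_expansion(Bottleneck, expansion=2)
+        2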
+ """ + if isinstance(expansion, int): + assert expansion > 0 + elif expansion is None: + if hasattr(block, 'expansion'): + expansion = block.expansion + elif issubclass(block, BasicBlock): + expansion = 1 + elif issubclass(block, Bottleneck): + expansion = 4 + else: + raise TypeError(f'expansion is not specified for {block.__name__}') + else: + raise TypeError('expansion must be an integer or None') + + return expansion + + +class ResLayer(nn.Sequential): + """ResLayer to build ResNet style backbone. + + Args: + block (nn.Module): Residual block used to build ResLayer. + num_blocks (int): Number of blocks. + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + expansion (int, optional): The expansion for BasicBlock/Bottleneck. + If not specified, it will firstly be obtained via + ``block.expansion``. If the block has no attribute "expansion", + the following default values will be used: 1 for BasicBlock and + 4 for Bottleneck. Default: None. + stride (int): stride of the first block. Default: 1. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False + conv_cfg (dict, optional): dictionary to construct and config conv + layer. Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + drop_path_rate (float or list): stochastic depth rate. + Default: 0. + """ + + def __init__(self, + block, + num_blocks, + in_channels, + out_channels, + expansion=None, + stride=1, + avg_down=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + drop_path_rate=0.0, + **kwargs): + self.block = block + self.expansion = get_expansion(block, expansion) + + if isinstance(drop_path_rate, float): + drop_path_rate = [drop_path_rate] * num_blocks + + assert len(drop_path_rate + ) == num_blocks, 'Please check the length of drop_path_rate' + + downsample = None + if stride != 1 or in_channels != out_channels: + downsample = [] + conv_stride = stride + if avg_down and stride != 1: + conv_stride = 1 + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + in_channels, + out_channels, + kernel_size=1, + stride=conv_stride, + bias=False), + build_norm_layer(norm_cfg, out_channels)[1] + ]) + downsample = nn.Sequential(*downsample) + + layers = [] + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + expansion=self.expansion, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + drop_path_rate=drop_path_rate[0], + **kwargs)) + in_channels = out_channels + for i in range(1, num_blocks): + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + expansion=self.expansion, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + drop_path_rate=drop_path_rate[i], + **kwargs)) + super(ResLayer, self).__init__(*layers) + + +@MODELS.register_module() +class ResNet(BaseBackbone): + """ResNet backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + base_channels (int): Middle channels of the first stage. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. 
+ dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. + Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + + Example: + >>> from mmpretrain.models import ResNet + >>> import torch + >>> self = ResNet(depth=18) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 8, 8) + (1, 128, 4, 4) + (1, 256, 2, 2) + (1, 512, 1, 1) + """ + + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth, + in_channels=3, + stem_channels=64, + base_channels=64, + expansion=None, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(3, ), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + with_cp=False, + zero_init_residual=True, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ], + drop_path_rate=0.0): + super(ResNet, self).__init__(init_cfg) + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + self.depth = depth + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.zero_init_residual = zero_init_residual + self.block, stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + self.expansion = get_expansion(self.block, expansion) + + self._make_stem_layer(in_channels, stem_channels) + + self.res_layers = [] + _in_channels = stem_channels + _out_channels = base_channels * self.expansion + + # stochastic depth decay rule + total_depth = sum(stage_blocks) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] + + 
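+        # `dpr` holds one stochastic-depth rate per block over the whole
+        # network; each stage below consumes its leading `num_blocks` entries
+        # and the list is advanced, so rates grow linearly with block depth.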
for i, num_blocks in enumerate(self.stage_blocks): + stride = strides[i] + dilation = dilations[i] + res_layer = self.make_res_layer( + block=self.block, + num_blocks=num_blocks, + in_channels=_in_channels, + out_channels=_out_channels, + expansion=self.expansion, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + drop_path_rate=dpr[:num_blocks]) + _in_channels = _out_channels + _out_channels *= 2 + dpr = dpr[num_blocks:] + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = res_layer[-1].out_channels + + def make_res_layer(self, **kwargs): + return ResLayer(**kwargs) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def _make_stem_layer(self, in_channels, stem_channels): + if self.deep_stem: + self.stem = nn.Sequential( + ConvModule( + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True), + ConvModule( + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True), + ConvModule( + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True)) + else: + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + else: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f'layer{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self): + super(ResNet, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress zero_init_residual if use pretrained model. + return + + if self.zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + constant_init(m.norm3, 0) + elif isinstance(m, BasicBlock): + constant_init(m.norm2, 0) + + def forward(self, x): + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def train(self, mode=True): + super(ResNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer id to set the different learning rates for ResNet. 
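+
+        Note: blocks are grouped into a small fixed number of virtual layers
+        per stage (via ``blk2``/``blk3`` below), so the number of
+        learning-rate levels stays bounded even for very deep variants.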
+
+        ResNet stages:
+        50:      [3, 4, 6, 3]
+        101:     [3, 4, 23, 3]
+        152:     [3, 8, 36, 3]
+        200:     [3, 24, 36, 3]
+        eca269d: [3, 30, 48, 8]
+
+        Args:
+            param_name (str): The name of the parameter.
+            prefix (str): The prefix for the parameter.
+                Defaults to an empty string.
+
+        Returns:
+            Tuple[int, int]: The layer-wise depth and the num of layers.
+        """
+        depths = self.stage_blocks
+        if depths[1] == 4 and depths[2] == 6:
+            blk2, blk3 = 2, 3
+        elif depths[1] == 4 and depths[2] == 23:
+            blk2, blk3 = 2, 3
+        elif depths[1] == 8 and depths[2] == 36:
+            blk2, blk3 = 4, 4
+        elif depths[1] == 24 and depths[2] == 36:
+            blk2, blk3 = 4, 4
+        elif depths[1] == 30 and depths[2] == 48:
+            blk2, blk3 = 5, 6
+        else:
+            raise NotImplementedError
+
+        N2, N3 = math.ceil(depths[1] / blk2 -
+                           1e-5), math.ceil(depths[2] / blk3 - 1e-5)
+        N = 2 + N2 + N3  # r50: 2 + 2 + 2 = 6
+        max_layer_id = N + 1  # r50: 2 + 2 + 2 + 1(like head) = 7
+
+        if not param_name.startswith(prefix):
+            # For subsequent module like head
+            return max_layer_id, max_layer_id + 1
+
+        if param_name.startswith('backbone.layer'):
+            stage_id = int(param_name.split('.')[1][5:])
+            block_id = int(param_name.split('.')[2])
+
+            if stage_id == 1:
+                layer_id = 1
+            elif stage_id == 2:
+                layer_id = 2 + block_id // blk2  # r50: 2, 3
+            elif stage_id == 3:
+                layer_id = 2 + N2 + block_id // blk3  # r50: 4, 5
+            else:  # stage_id == 4
+                layer_id = N  # r50: 6
+            return layer_id, max_layer_id + 1
+
+        else:
+            return 0, max_layer_id + 1
+
+
+@MODELS.register_module()
+class ResNetV1c(ResNet):
+    """ResNetV1c backbone.
+
+    This variant is described in `Bag of Tricks
+    <https://arxiv.org/abs/1812.01187>`_.
+
+    Compared with default ResNet (ResNetV1b), ResNetV1c replaces the 7x7 conv
+    in the input stem with three 3x3 convs.
+    """
+
+    def __init__(self, **kwargs):
+        super(ResNetV1c, self).__init__(
+            deep_stem=True, avg_down=False, **kwargs)
+
+
+@MODELS.register_module()
+class ResNetV1d(ResNet):
+    """ResNetV1d backbone.
+
+    This variant is described in `Bag of Tricks
+    <https://arxiv.org/abs/1812.01187>`_.
+
+    Compared with default ResNet (ResNetV1b), ResNetV1d replaces the 7x7 conv
+    in the input stem with three 3x3 convs. And in the downsampling block, a
+    2x2 avg_pool with stride 2 is added before conv, whose stride is changed
+    to 1.
+    """
+
+    def __init__(self, **kwargs):
+        super(ResNetV1d, self).__init__(
+            deep_stem=True, avg_down=True, **kwargs)
diff --git a/mmpretrain/models/backbones/resnet_cifar.py b/mmpretrain/models/backbones/resnet_cifar.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f17f92fd76a690ea90977b38ab2ea00345ba903
--- /dev/null
+++ b/mmpretrain/models/backbones/resnet_cifar.py
@@ -0,0 +1,81 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import build_conv_layer, build_norm_layer
+
+from mmpretrain.registry import MODELS
+from .resnet import ResNet
+
+
+@MODELS.register_module()
+class ResNet_CIFAR(ResNet):
+    """ResNet backbone for CIFAR.
+
+    Compared to standard ResNet, it uses `kernel_size=3` and `stride=1` in
+    conv1, and does not apply MaxPooling after the stem. It has been proven
+    to be more efficient than standard ResNet in other public codebases,
+    e.g.,
+    `https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py`.
+
+    Args:
+        depth (int): Network depth, from {18, 34, 50, 101, 152}.
+        in_channels (int): Number of input image channels. Default: 3.
+        stem_channels (int): Output channels of the stem layer. Default: 64.
+        base_channels (int): Middle channels of the first stage. Default: 64.
+        num_stages (int): Stages of the network. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): This network has a specifically designed stem, thus + deep_stem is asserted to be False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + """ + + def __init__(self, depth, deep_stem=False, **kwargs): + super(ResNet_CIFAR, self).__init__( + depth, deep_stem=deep_stem, **kwargs) + assert not self.deep_stem, 'ResNet_CIFAR does not support deep_stem' + + def _make_stem_layer(self, in_channels, base_channels): + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + base_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, base_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) diff --git a/mmpretrain/models/backbones/resnext.py b/mmpretrain/models/backbones/resnext.py new file mode 100644 index 0000000000000000000000000000000000000000..8858b7d3dffdcb20677e091fba4f5a1084d086a3 --- /dev/null +++ b/mmpretrain/models/backbones/resnext.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmpretrain.registry import MODELS +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResLayer, ResNet + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeXt. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module, optional): downsample operation on identity + branch. Default: None + style (str): `pytorch` or `caffe`.
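A quick sketch of what the CIFAR stem changes in practice; the shapes are expectations for a 32x32 input with depth 18, and the import assumes `ResNet_CIFAR` is exported like the other backbones.

```python
# Sketch: the stride-1 3x3 stem and missing max-pool keep stage 1 at 32x32.
import torch
from mmpretrain.models import ResNet_CIFAR

model = ResNet_CIFAR(depth=18, out_indices=(0, 1, 2, 3))
model.eval()
with torch.no_grad():
    feats = model(torch.rand(1, 3, 32, 32))
for feat in feats:
    print(tuple(feat.shape))
# Expected: (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), (1, 512, 4, 4)
```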
If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + conv_cfg (dict, optional): dictionary to construct and config conv + layer. Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + """ + + def __init__(self, + in_channels, + out_channels, + base_channels=64, + groups=32, + width_per_group=4, + **kwargs): + super(Bottleneck, self).__init__(in_channels, out_channels, **kwargs) + self.groups = groups + self.width_per_group = width_per_group + + # For ResNet bottleneck, middle channels are determined by expansion + # and out_channels, but for ResNeXt bottleneck, it is determined by + # groups and width_per_group and the stage it is located in. + if groups != 1: + assert self.mid_channels % base_channels == 0 + self.mid_channels = ( + groups * width_per_group * self.mid_channels // base_channels) + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + +@MODELS.register_module() +class ResNeXt(ResNet): + """ResNeXt backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152}. + groups (int): Groups of conv2 in Bottleneck. Default: 32. + width_per_group (int): Width per group of conv2 in Bottleneck. + Default: 4. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. 
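The group-width rule in `Bottleneck.__init__` above is easiest to check with concrete numbers; this sketch redoes the arithmetic for the first stage of the common ResNeXt-50 32x4d configuration (the values are assumed from that configuration, not taken from a config file).

```python
# Worked example of the ResNeXt mid-channel rule (first stage, 32x4d).
expansion, out_channels = 4, 256
mid_channels = out_channels // expansion  # plain ResNet rule: 64
groups, width_per_group, base_channels = 32, 4, 64
mid_channels = groups * width_per_group * mid_channels // base_channels
print(mid_channels)  # 128: conv2 runs 32 groups of width 4 at this stage
```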
+ norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, depth, groups=32, width_per_group=4, **kwargs): + self.groups = groups + self.width_per_group = width_per_group + super(ResNeXt, self).__init__(depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer( + groups=self.groups, + width_per_group=self.width_per_group, + base_channels=self.base_channels, + **kwargs) diff --git a/mmpretrain/models/backbones/revvit.py b/mmpretrain/models/backbones/revvit.py new file mode 100644 index 0000000000000000000000000000000000000000..f2e6c28c943c83d0580634ac04450ee7ffc5f478 --- /dev/null +++ b/mmpretrain/models/backbones/revvit.py @@ -0,0 +1,671 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import sys + +import numpy as np +import torch +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ +from torch import nn +from torch.autograd import Function as Function + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.registry import MODELS +from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed, + to_2tuple) + + +class RevBackProp(Function): + """Custom Backpropagation function to allow (A) flushing memory in forward + and (B) activation recomputation reversibly in backward for gradient + calculation. + + Inspired by + https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py + """ + + @staticmethod + def forward( + ctx, + x, + layers, + buffer_layers, # List of layer ids for intermediate activations to buffer + ): + """Reversible Forward pass. + + Any intermediate activations from `buffer_layers` are cached in ctx + during the forward pass. This is not necessary for standard use cases. + Each reversible layer implements its own forward pass logic. + """ + buffer_layers.sort() + x1, x2 = torch.chunk(x, 2, dim=-1) + intermediate = [] + + for layer in layers: + x1, x2 = layer(x1, x2) + if layer.layer_id in buffer_layers: + intermediate.extend([x1.detach(), x2.detach()]) + + if len(buffer_layers) == 0: + all_tensors = [x1.detach(), x2.detach()] + else: + intermediate = [torch.LongTensor(buffer_layers), *intermediate] + all_tensors = [x1.detach(), x2.detach(), *intermediate] + + ctx.save_for_backward(*all_tensors) + ctx.layers = layers + + return torch.cat([x1, x2], dim=-1) + + @staticmethod + def backward(ctx, dx): + """Reversible Backward pass. + + Any intermediate activations from `buffer_layers` are recovered from + ctx. Each layer implements its own logic for backward pass (both + activation recomputation and grad calculation).
+ """ + d_x1, d_x2 = torch.chunk(dx, 2, dim=-1) + # retrieve params from ctx for backward + x1, x2, *int_tensors = ctx.saved_tensors + # no buffering + if len(int_tensors) != 0: + buffer_layers = int_tensors[0].tolist() + else: + buffer_layers = [] + + layers = ctx.layers + + for _, layer in enumerate(layers[::-1]): + if layer.layer_id in buffer_layers: + x1, x2, d_x1, d_x2 = layer.backward_pass( + y1=int_tensors[buffer_layers.index(layer.layer_id) * 2 + + 1], + y2=int_tensors[buffer_layers.index(layer.layer_id) * 2 + + 2], + d_y1=d_x1, + d_y2=d_x2, + ) + else: + x1, x2, d_x1, d_x2 = layer.backward_pass( + y1=x1, + y2=x2, + d_y1=d_x1, + d_y2=d_x2, + ) + + dx = torch.cat([d_x1, d_x2], dim=-1) + + del int_tensors + del d_x1, d_x2, x1, x2 + + return dx, None, None + + +class RevTransformerEncoderLayer(BaseModule): + """Reversible Transformer Encoder Layer. + + This module is a building block of Reversible Transformer Encoder, + which support backpropagation without storing activations. + The residual connection is not applied to the FFN layer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed. + Default: 0.0 + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0 + drop_path_rate (float): stochastic depth rate. + Default 0.0 + num_fcs (int): The number of linear in FFN + Default: 2 + qkv_bias (bool): enable bias for qkv if True. + Default: True + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU') + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + layer_id (int): The layer id of current layer. Used in RevBackProp. + Default: 0 + init_cfg (dict or list[dict], optional): Initialization config dict. 
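The recomputation performed in this backward pass is only exact because every stochastic op is replayed under the seed captured in the forward pass (see `seed_cuda` below). The property it relies on is simply that reseeding the RNG reproduces a dropout mask, as this toy sketch shows:

```python
# Toy sketch: under the same seed, a dropout layer reproduces its mask exactly,
# which is what makes reversible activation recomputation bit-exact.
import torch

drop = torch.nn.Dropout(p=0.5)
drop.train()
torch.manual_seed(1234)
a = drop(torch.ones(8))
torch.manual_seed(1234)
b = drop(torch.ones(8))
assert torch.equal(a, b)
```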
+ """ + + def __init__(self, + embed_dims: int, + num_heads: int, + feedforward_channels: int, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + num_fcs: int = 2, + qkv_bias: bool = True, + act_cfg: dict = dict(type='GELU'), + norm_cfg: dict = dict(type='LN'), + layer_id: int = 0, + init_cfg=None): + super(RevTransformerEncoderLayer, self).__init__(init_cfg=init_cfg) + + self.drop_path_cfg = dict(type='DropPath', drop_prob=drop_path_rate) + self.embed_dims = embed_dims + + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + + self.attn = MultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + qkv_bias=qkv_bias) + + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + act_cfg=act_cfg, + add_identity=False) + + self.layer_id = layer_id + self.seeds = {} + + def init_weights(self): + super(RevTransformerEncoderLayer, self).init_weights() + for m in self.ffn.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + def seed_cuda(self, key): + """Fix seeds to allow for stochastic elements such as dropout to be + reproduced exactly in activation recomputation in the backward pass.""" + # randomize seeds + # use cuda generator if available + if (hasattr(torch.cuda, 'default_generators') + and len(torch.cuda.default_generators) > 0): + # GPU + device_idx = torch.cuda.current_device() + seed = torch.cuda.default_generators[device_idx].seed() + else: + # CPU + seed = int(torch.seed() % sys.maxsize) + + self.seeds[key] = seed + torch.manual_seed(self.seeds[key]) + + def forward(self, x1, x2): + """ + Implementation of Reversible TransformerEncoderLayer + + ` + x = x + self.attn(self.ln1(x)) + x = self.ffn(self.ln2(x), identity=x) + ` + """ + self.seed_cuda('attn') + # attention output + f_x2 = self.attn(self.ln1(x2)) + # apply droppath on attention output + self.seed_cuda('droppath') + f_x2_dropped = build_dropout(self.drop_path_cfg)(f_x2) + y1 = x1 + f_x2_dropped + + # free memory + if self.training: + del x1 + + # ffn output + self.seed_cuda('ffn') + g_y1 = self.ffn(self.ln2(y1)) + # apply droppath on ffn output + torch.manual_seed(self.seeds['droppath']) + g_y1_dropped = build_dropout(self.drop_path_cfg)(g_y1) + # final output + y2 = x2 + g_y1_dropped + + # free memory + if self.training: + del x2 + + return y1, y2 + + def backward_pass(self, y1, y2, d_y1, d_y2): + """Activation re-compute with the following equation. 
+ + x2 = y2 - g(y1), g = FFN + x1 = y1 - f(x2), f = MSHA + """ + + # temporarily record intermediate activation for G + # and use them for gradient calculation of G + with torch.enable_grad(): + y1.requires_grad = True + + torch.manual_seed(self.seeds['ffn']) + g_y1 = self.ffn(self.ln2(y1)) + + torch.manual_seed(self.seeds['droppath']) + g_y1 = build_dropout(self.drop_path_cfg)(g_y1) + + g_y1.backward(d_y2, retain_graph=True) + + # activate recomputation is by design and not part of + # the computation graph in forward pass + with torch.no_grad(): + x2 = y2 - g_y1 + del g_y1 + + d_y1 = d_y1 + y1.grad + y1.grad = None + + # record F activation and calculate gradients on F + with torch.enable_grad(): + x2.requires_grad = True + + torch.manual_seed(self.seeds['attn']) + f_x2 = self.attn(self.ln1(x2)) + + torch.manual_seed(self.seeds['droppath']) + f_x2 = build_dropout(self.drop_path_cfg)(f_x2) + + f_x2.backward(d_y1, retain_graph=True) + + # propagate reverse computed activations at the + # start of the previous block + with torch.no_grad(): + x1 = y1 - f_x2 + del f_x2, y1 + + d_y2 = d_y2 + x2.grad + + x2.grad = None + x2 = x2.detach() + + return x1, x2, d_y1, d_y2 + + +class TwoStreamFusion(nn.Module): + """A general constructor for neural modules fusing two equal sized tensors + in forward. + + Args: + mode (str): The mode of fusion. Options are 'add', 'max', 'min', + 'avg', 'concat'. + """ + + def __init__(self, mode: str): + super().__init__() + self.mode = mode + + if mode == 'add': + self.fuse_fn = lambda x: torch.stack(x).sum(dim=0) + elif mode == 'max': + self.fuse_fn = lambda x: torch.stack(x).max(dim=0).values + elif mode == 'min': + self.fuse_fn = lambda x: torch.stack(x).min(dim=0).values + elif mode == 'avg': + self.fuse_fn = lambda x: torch.stack(x).mean(dim=0) + elif mode == 'concat': + self.fuse_fn = lambda x: torch.cat(x, dim=-1) + else: + raise NotImplementedError + + def forward(self, x): + # split the tensor into two halves in the channel dimension + x = torch.chunk(x, 2, dim=2) + return self.fuse_fn(x) + + +@MODELS.register_module() +class RevVisionTransformer(BaseBackbone): + """Reversible Vision Transformer. + + A PyTorch implementation of : `Reversible Vision Transformers + `_ # noqa: E501 + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small' + and 'deit-base'. If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + + Defaults to 'base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. 
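`backward_pass` above inverts the two residual updates analytically rather than storing activations. A toy round trip with stand-in `f`/`g` (plain linear layers instead of the real attention and FFN) shows why the inversion is exact:

```python
# Toy round trip of the additive coupling (f and g stand in for MSHA and FFN).
import torch

f = torch.nn.Linear(8, 8)
g = torch.nn.Linear(8, 8)
x1, x2 = torch.randn(4, 8), torch.randn(4, 8)

y1 = x1 + f(x2)          # forward: first half
y2 = x2 + g(y1)          # forward: second half

x2_rec = y2 - g(y1)      # backward: recover x2 without having stored it
x1_rec = y1 - f(x2_rec)  # backward: recover x1 without having stored it

assert torch.allclose(x1, x1_rec, atol=1e-6)
assert torch.allclose(x2, x2_rec, atol=1e-6)
```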
Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"avg_featmap"``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + fusion_mode (str): The fusion mode of transformer layers. + Defaults to 'concat'. + no_custom_backward (bool): Whether to use custom backward. + Defaults to False. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims': 768, + 'num_layers': 8, + 'num_heads': 8, + 'feedforward_channels': 768 * 3, + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 3072 + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': 4096 + }), + **dict.fromkeys( + ['h', 'huge'], + { + # The same as the implementation in MAE + # + 'embed_dims': 1280, + 'num_layers': 32, + 'num_heads': 16, + 'feedforward_channels': 5120 + }), + **dict.fromkeys( + ['deit-t', 'deit-tiny'], { + 'embed_dims': 192, + 'num_layers': 12, + 'num_heads': 3, + 'feedforward_channels': 192 * 4 + }), + **dict.fromkeys( + ['deit-s', 'deit-small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 6, + 'feedforward_channels': 384 * 4 + }), + **dict.fromkeys( + ['deit-b', 'deit-base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 768 * 4 + }), + } + num_extra_tokens = 0 # The official RevViT doesn't have class token + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} + + def __init__(self, + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=True, + out_type='avg_featmap', + with_cls_token=False, + frozen_stages=-1, + interpolate_mode='bicubic', + patch_cfg=dict(), + layer_cfgs=dict(), + fusion_mode='concat', + no_custom_backward=False, + init_cfg=None): + super(RevVisionTransformer, self).__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.img_size = to_2tuple(img_size) + self.no_custom_backward = no_custom_backward + + # Set patch embedding + _patch_cfg = dict( + 
in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + + # Set cls token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + self.num_extra_tokens = 1 + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when `out_type="cls_token"`.') + + # Set position embedding + self.interpolate_mode = interpolate_mode + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_extra_tokens, + self.embed_dims)) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + drop_rate=drop_rate, + drop_path_rate=dpr[i], + qkv_bias=qkv_bias, + layer_id=i, + norm_cfg=norm_cfg) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(RevTransformerEncoderLayer(**_layer_cfg)) + + # fusion operation for the final output + self.fusion_layer = TwoStreamFusion(mode=fusion_mode) + + self.frozen_stages = frozen_stages + self.final_norm = final_norm + if final_norm: + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims * 2) + + # freeze stages only when self.frozen_stages > 0 + if self.frozen_stages > 0: + self._freeze_stages() + + def init_weights(self): + super(RevVisionTransformer, self).init_weights() + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + trunc_normal_(self.pos_embed, std=0.02) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + @staticmethod + def resize_pos_embed(*args, **kwargs): + """Interface for backward-compatibility.""" + return resize_pos_embed(*args, **kwargs) + + def _freeze_stages(self): + # freeze position embedding + self.pos_embed.requires_grad = False + # set dropout to eval model + self.drop_after_pos.eval() + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + # freeze cls_token + if self.cls_token is not None: + 
self.cls_token.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + # freeze the last layer norm + if self.frozen_stages == len(self.layers) and self.final_norm: + self.ln1.eval() + for param in self.ln1.parameters(): + param.requires_grad = False + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.cls_token is not None: + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + x = torch.cat([x, x], dim=-1) + + # forward with different conditions + if not self.training or self.no_custom_backward: + # in eval/inference model + executing_fn = RevVisionTransformer._forward_vanilla_bp + else: + # use custom backward when self.training=True. + executing_fn = RevBackProp.apply + + x = executing_fn(x, self.layers, []) + + if self.final_norm: + x = self.ln1(x) + x = self.fusion_layer(x) + + return (self._format_output(x, patch_resolution), ) + + @staticmethod + def _forward_vanilla_bp(hidden_state, layers, buffer=[]): + """Using reversible layers without reversible backpropagation. + + Debugging purpose only. Activated with self.no_custom_backward + """ + # split into ffn state(ffn_out) and attention output(attn_out) + ffn_out, attn_out = torch.chunk(hidden_state, 2, dim=-1) + del hidden_state + + for _, layer in enumerate(layers): + attn_out, ffn_out = layer(attn_out, ffn_out) + + return torch.cat([attn_out, ffn_out], dim=-1) + + def _format_output(self, x, hw): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return patch_token.mean(dim=1) diff --git a/mmpretrain/models/backbones/riformer.py b/mmpretrain/models/backbones/riformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ad7cb4d37c2ac6f1479fd3c533c456f3b0a0c45e --- /dev/null +++ b/mmpretrain/models/backbones/riformer.py @@ -0,0 +1,390 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import DropPath, build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone +from .poolformer import Mlp, PatchEmbed + + +class Affine(nn.Module): + """Affine Transformation module. + + Args: + in_features (int): Input dimension. + """ + + def __init__(self, in_features): + super().__init__() + self.affine = nn.Conv2d( + in_features, + in_features, + kernel_size=1, + stride=1, + padding=0, + groups=in_features, + bias=True) + + def forward(self, x): + return self.affine(x) - x + + +class RIFormerBlock(BaseModule): + """RIFormer Block. + + Args: + dim (int): Embedding dim. + mlp_ratio (float): Mlp expansion ratio. Defaults to 4. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='GN', num_groups=1)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + drop (float): Dropout rate. Defaults to 0. 
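Because the forward pass duplicates the token stream (`torch.cat([x, x], dim=-1)`) and the default `concat` fusion keeps both halves, the output feature width should be twice `embed_dims`. A usage sketch under that assumption (import path taken from this file; the printed shape is an expectation, not a verified value):

```python
# Sketch: RevViT inference; eval mode takes the _forward_vanilla_bp branch.
import torch
from mmpretrain.models.backbones.revvit import RevVisionTransformer

model = RevVisionTransformer(arch='deit-tiny', out_type='avg_featmap')
model.eval()
with torch.no_grad():
    feats = model(torch.rand(1, 3, 224, 224))
print(feats[0].shape)  # expected torch.Size([1, 384]), i.e. 2 * embed_dims (192)
```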
+ drop_path (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-5. + deploy (bool): Whether to switch the model structure to + deployment mode. Default: False. + """ + + def __init__(self, + dim, + mlp_ratio=4., + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + drop=0., + drop_path=0., + layer_scale_init_value=1e-5, + deploy=False): + + super().__init__() + + if deploy: + self.norm_reparam = build_norm_layer(norm_cfg, dim)[1] + else: + self.norm1 = build_norm_layer(norm_cfg, dim)[1] + self.token_mixer = Affine(in_features=dim) + self.norm2 = build_norm_layer(norm_cfg, dim)[1] + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_cfg=act_cfg, + drop=drop) + + # The following two techniques are useful to train deep RIFormers. + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True) + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True) + self.norm_cfg = norm_cfg + self.dim = dim + self.deploy = deploy + + def forward(self, x): + if hasattr(self, 'norm_reparam'): + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * + self.norm_reparam(x)) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * + self.mlp(self.norm2(x))) + else: + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * + self.token_mixer(self.norm1(x))) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * + self.mlp(self.norm2(x))) + return x + + def fuse_affine(self, norm, token_mixer): + gamma_affn = token_mixer.affine.weight.reshape(-1) + gamma_affn = gamma_affn - torch.ones_like(gamma_affn) + beta_affn = token_mixer.affine.bias + gamma_ln = norm.weight + beta_ln = norm.bias + return (gamma_ln * gamma_affn), (beta_ln * gamma_affn + beta_affn) + + def get_equivalent_scale_bias(self): + eq_s, eq_b = self.fuse_affine(self.norm1, self.token_mixer) + return eq_s, eq_b + + def switch_to_deploy(self): + if self.deploy: + return + eq_s, eq_b = self.get_equivalent_scale_bias() + self.norm_reparam = build_norm_layer(self.norm_cfg, self.dim)[1] + self.norm_reparam.weight.data = eq_s + self.norm_reparam.bias.data = eq_b + self.__delattr__('norm1') + if hasattr(self, 'token_mixer'): + self.__delattr__('token_mixer') + self.deploy = True + + +def basic_blocks(dim, + index, + layers, + mlp_ratio=4., + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + drop_rate=.0, + drop_path_rate=0., + layer_scale_init_value=1e-5, + deploy=False): + """generate RIFormer blocks for a stage.""" + blocks = [] + for block_idx in range(layers[index]): + block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / ( + sum(layers) - 1) + blocks.append( + RIFormerBlock( + dim, + mlp_ratio=mlp_ratio, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop=drop_rate, + drop_path=block_dpr, + layer_scale_init_value=layer_scale_init_value, + deploy=deploy, + )) + blocks = nn.Sequential(*blocks) + + return blocks + + +@MODELS.register_module() +class RIFormer(BaseBackbone): + """RIFormer. + + A PyTorch implementation of RIFormer introduced by: + `RIFormer: Keep Your Vision Backbone Effective But Removing Token Mixer `_ + + Args: + arch (str | dict): The model's architecture. 
If string, it should be + one of the architectures in ``RIFormer.arch_settings``. And if dict, it + should include the following keys (only ``layers`` and ``embed_dims`` + are required; the rest fall back to defaults): + + - layers (list[int]): Number of blocks at each stage. + - embed_dims (list[int]): The number of channels at each stage. + - mlp_ratios (list[int]): Expansion ratio of MLPs. + - layer_scale_init_value (float): Init value for Layer Scale. + + Defaults to 's12'. + + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='GN', num_groups=1)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + in_patch_size (int): The patch size of input image patch embedding. + Defaults to 7. + in_stride (int): The stride of input image patch embedding. + Defaults to 4. + in_pad (int): The padding of input image patch embedding. + Defaults to 2. + down_patch_size (int): The patch size of downsampling patch embedding. + Defaults to 3. + down_stride (int): The stride of downsampling patch embedding. + Defaults to 2. + down_pad (int): The padding of downsampling patch embedding. + Defaults to 1. + drop_rate (float): Dropout rate. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + out_indices (Sequence | int): Output from which network position. + Index 0-6 respectively corresponds to + [stage1, downsampling, stage2, downsampling, stage3, downsampling, stage4] + Defaults to -1, which means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to -1, which means not freezing any parameters. + deploy (bool): Whether to switch the model structure to + deployment mode. Default: False. + init_cfg (dict, optional): Initialization config dict. + """ # noqa: E501 + + # --layers: [x,x,x,x], numbers of layers for the four stages + # --embed_dims, --mlp_ratios: + # embedding dims and mlp ratios for the four stages + # --downsamples: flags to apply downsampling or not in four blocks + arch_settings = { + 's12': { + 'layers': [2, 2, 6, 2], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-5, + }, + 's24': { + 'layers': [4, 4, 12, 4], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-5, + }, + 's36': { + 'layers': [6, 6, 18, 6], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + 'm36': { + 'layers': [6, 6, 18, 6], + 'embed_dims': [96, 192, 384, 768], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + 'm48': { + 'layers': [8, 8, 24, 8], + 'embed_dims': [96, 192, 384, 768], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + } + + def __init__(self, + arch='s12', + in_channels=3, + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + in_patch_size=7, + in_stride=4, + in_pad=2, + down_patch_size=3, + down_stride=2, + down_pad=1, + drop_rate=0., + drop_path_rate=0., + out_indices=-1, + frozen_stages=-1, + init_cfg=None, + deploy=False): + + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + assert 'layers' in arch and 'embed_dims' in arch, \ + f'The arch dict must have "layers" and "embed_dims", ' \ + f'but got {list(arch.keys())}.'
+ + layers = arch['layers'] + embed_dims = arch['embed_dims'] + mlp_ratios = arch['mlp_ratios'] \ + if 'mlp_ratios' in arch else [4, 4, 4, 4] + layer_scale_init_value = arch['layer_scale_init_value'] \ + if 'layer_scale_init_value' in arch else 1e-5 + + self.patch_embed = PatchEmbed( + patch_size=in_patch_size, + stride=in_stride, + padding=in_pad, + in_chans=in_channels, + embed_dim=embed_dims[0]) + + # set the main block in network + network = [] + for i in range(len(layers)): + stage = basic_blocks( + embed_dims[i], + i, + layers, + mlp_ratio=mlp_ratios[i], + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, + deploy=deploy) + network.append(stage) + if i >= len(layers) - 1: + break + if embed_dims[i] != embed_dims[i + 1]: + # downsampling between two stages + network.append( + PatchEmbed( + patch_size=down_patch_size, + stride=down_stride, + padding=down_pad, + in_chans=embed_dims[i], + embed_dim=embed_dims[i + 1])) + + self.network = nn.ModuleList(network) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 7 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + if self.out_indices: + for i_layer in self.out_indices: + layer = build_norm_layer(norm_cfg, + embed_dims[(i_layer + 1) // 2])[1] + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self.frozen_stages = frozen_stages + self._freeze_stages() + self.deploy = deploy + + def forward_embeddings(self, x): + x = self.patch_embed(x) + return x + + def forward_tokens(self, x): + outs = [] + for idx, block in enumerate(self.network): + x = block(x) + if idx in self.out_indices: + norm_layer = getattr(self, f'norm{idx}') + x_out = norm_layer(x) + outs.append(x_out) + return tuple(outs) + + def forward(self, x): + # input embedding + x = self.forward_embeddings(x) + # through backbone + x = self.forward_tokens(x) + return x + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(0, self.frozen_stages + 1): + # Include both block and downsample layer. + module = self.network[i] + module.eval() + for param in module.parameters(): + param.requires_grad = False + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + norm_layer.eval() + for param in norm_layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(RIFormer, self).train(mode) + self._freeze_stages() + return self + + def switch_to_deploy(self): + for m in self.modules(): + if isinstance(m, RIFormerBlock): + m.switch_to_deploy() + self.deploy = True diff --git a/mmpretrain/models/backbones/seresnet.py b/mmpretrain/models/backbones/seresnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4437c17fa06d62f57ac18a31967a35b4f44f190f --- /dev/null +++ b/mmpretrain/models/backbones/seresnet.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.utils.checkpoint as cp + +from mmpretrain.registry import MODELS +from ..utils.se_layer import SELayer +from .resnet import Bottleneck, ResLayer, ResNet + + +class SEBottleneck(Bottleneck): + """SEBottleneck block for SEResNet. 
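Since `fuse_affine` folds the affine token mixer and its preceding norm into the single `norm_reparam`, outputs before and after `switch_to_deploy` should agree. A sketch of that check (the import path is assumed from this file layout; `drop_path=0.` keeps the block deterministic):

```python
# Sketch: re-parameterization should preserve the RIFormer block's output.
import torch
from mmpretrain.models.backbones.riformer import RIFormerBlock

block = RIFormerBlock(dim=64, drop_path=0.)  # starts in the training structure
block.eval()
x = torch.rand(1, 64, 14, 14)
with torch.no_grad():
    y_before = block(x)
    block.switch_to_deploy()  # folds norm1 + Affine into norm_reparam
    y_after = block(x)
print(torch.allclose(y_before, y_after, atol=1e-5))  # expected: True
```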
+ + Args: + in_channels (int): The input channels of the SEBottleneck block. + out_channels (int): The output channel of the SEBottleneck block. + se_ratio (int): Squeeze ratio in SELayer. Default: 16 + """ + + def __init__(self, in_channels, out_channels, se_ratio=16, **kwargs): + super(SEBottleneck, self).__init__(in_channels, out_channels, **kwargs) + self.se_layer = SELayer(out_channels, ratio=se_ratio) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + + out = self.se_layer(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class SEResNet(ResNet): + """SEResNet backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152}. + se_ratio (int): Squeeze ratio in SELayer. Default: 16. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + + Example: + >>> from mmpretrain.models import SEResNet + >>> import torch + >>> self = SEResNet(depth=50) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 224, 224) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... 
print(tuple(level_out.shape)) + (1, 2048, 7, 7) + """ + + arch_settings = { + 50: (SEBottleneck, (3, 4, 6, 3)), + 101: (SEBottleneck, (3, 4, 23, 3)), + 152: (SEBottleneck, (3, 8, 36, 3)) + } + + def __init__(self, depth, se_ratio=16, **kwargs): + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for SEResNet') + self.se_ratio = se_ratio + super(SEResNet, self).__init__(depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer(se_ratio=self.se_ratio, **kwargs) diff --git a/mmpretrain/models/backbones/seresnext.py b/mmpretrain/models/backbones/seresnext.py new file mode 100644 index 0000000000000000000000000000000000000000..6a2838074225930795d6d8ad70ba067b6ad4c2da --- /dev/null +++ b/mmpretrain/models/backbones/seresnext.py @@ -0,0 +1,155 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmpretrain.registry import MODELS +from .resnet import ResLayer +from .seresnet import SEBottleneck as _SEBottleneck +from .seresnet import SEResNet + + +class SEBottleneck(_SEBottleneck): + """SEBottleneck block for SEResNeXt. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + base_channels (int): Middle channels of the first stage. Default: 64. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module, optional): downsample operation on identity + branch. Default: None + se_ratio (int): Squeeze ratio in SELayer. Default: 16 + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + conv_cfg (dict, optional): dictionary to construct and config conv + layer. Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + """ + + def __init__(self, + in_channels, + out_channels, + base_channels=64, + groups=32, + width_per_group=4, + se_ratio=16, + **kwargs): + super(SEBottleneck, self).__init__(in_channels, out_channels, se_ratio, + **kwargs) + self.groups = groups + self.width_per_group = width_per_group + + # We follow the same rationale as ResNeXt to compute mid_channels. + # For SEResNet bottleneck, middle channels are determined by expansion + # and out_channels, but for SEResNeXt bottleneck, it is determined by + # groups and width_per_group and the stage it is located in.
+ if groups != 1: + assert self.mid_channels % base_channels == 0 + self.mid_channels = ( + groups * width_per_group * self.mid_channels // base_channels) + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + +@MODELS.register_module() +class SEResNeXt(SEResNet): + """SEResNeXt backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152}. + groups (int): Groups of conv2 in Bottleneck. Default: 32. + width_per_group (int): Width per group of conv2 in Bottleneck. + Default: 4. + se_ratio (int): Squeeze ratio in SELayer. Default: 16. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. 
+ """ + + arch_settings = { + 50: (SEBottleneck, (3, 4, 6, 3)), + 101: (SEBottleneck, (3, 4, 23, 3)), + 152: (SEBottleneck, (3, 8, 36, 3)) + } + + def __init__(self, depth, groups=32, width_per_group=4, **kwargs): + self.groups = groups + self.width_per_group = width_per_group + super(SEResNeXt, self).__init__(depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer( + groups=self.groups, + width_per_group=self.width_per_group, + base_channels=self.base_channels, + **kwargs) diff --git a/mmpretrain/models/backbones/shufflenet_v1.py b/mmpretrain/models/backbones/shufflenet_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..2cc3617f93b82fa5e37fa2bb5b47d93e6bd9a58f --- /dev/null +++ b/mmpretrain/models/backbones/shufflenet_v1.py @@ -0,0 +1,321 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule, build_activation_layer +from mmengine.model import BaseModule +from mmengine.model.weight_init import constant_init, normal_init +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.models.utils import channel_shuffle, make_divisible +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class ShuffleUnit(BaseModule): + """ShuffleUnit block. + + ShuffleNet unit with pointwise group convolution (GConv) and channel + shuffle. + + Args: + in_channels (int): The input channels of the ShuffleUnit. + out_channels (int): The output channels of the ShuffleUnit. + groups (int): The number of groups to be used in grouped 1x1 + convolutions in each ShuffleUnit. Default: 3 + first_block (bool): Whether it is the first ShuffleUnit of a + sequential ShuffleUnits. Default: True, which means not using the + grouped 1x1 convolution. + combine (str): The ways to combine the input and output + branches. Default: 'add'. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + groups=3, + first_block=True, + combine='add', + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False): + super(ShuffleUnit, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.first_block = first_block + self.combine = combine + self.groups = groups + self.bottleneck_channels = self.out_channels // 4 + self.with_cp = with_cp + + if self.combine == 'add': + self.depthwise_stride = 1 + self._combine_func = self._add + assert in_channels == out_channels, ( + 'in_channels must be equal to out_channels when combine ' + 'is add') + elif self.combine == 'concat': + self.depthwise_stride = 2 + self._combine_func = self._concat + self.out_channels -= self.in_channels + self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + else: + raise ValueError(f'Cannot combine tensors with {self.combine}. 
' + 'Only "add" and "concat" are supported') + + self.first_1x1_groups = 1 if first_block else self.groups + self.g_conv_1x1_compress = ConvModule( + in_channels=self.in_channels, + out_channels=self.bottleneck_channels, + kernel_size=1, + groups=self.first_1x1_groups, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.depthwise_conv3x3_bn = ConvModule( + in_channels=self.bottleneck_channels, + out_channels=self.bottleneck_channels, + kernel_size=3, + stride=self.depthwise_stride, + padding=1, + groups=self.bottleneck_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + self.g_conv_1x1_expand = ConvModule( + in_channels=self.bottleneck_channels, + out_channels=self.out_channels, + kernel_size=1, + groups=self.groups, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + self.act = build_activation_layer(act_cfg) + + @staticmethod + def _add(x, out): + # residual connection + return x + out + + @staticmethod + def _concat(x, out): + # concatenate along channel axis + return torch.cat((x, out), 1) + + def forward(self, x): + + def _inner_forward(x): + residual = x + + out = self.g_conv_1x1_compress(x) + out = self.depthwise_conv3x3_bn(out) + + if self.groups > 1: + out = channel_shuffle(out, self.groups) + + out = self.g_conv_1x1_expand(out) + + if self.combine == 'concat': + residual = self.avgpool(residual) + out = self.act(out) + out = self._combine_func(residual, out) + else: + out = self._combine_func(residual, out) + out = self.act(out) + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@MODELS.register_module() +class ShuffleNetV1(BaseBackbone): + """ShuffleNetV1 backbone. + + Args: + groups (int): The number of groups to be used in grouped 1x1 + convolutions in each ShuffleUnit. Default: 3. + widen_factor (float): Width multiplier - adjusts the number + of channels in each layer by this amount. Default: 1.0. + out_indices (Sequence[int]): Output from which stages. + Default: (2, ) + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + groups=3, + widen_factor=1.0, + out_indices=(2, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + norm_eval=False, + with_cp=False, + init_cfg=None): + super(ShuffleNetV1, self).__init__(init_cfg) + self.init_cfg = init_cfg + self.stage_blocks = [4, 8, 4] + self.groups = groups + + for index in out_indices: + if index not in range(0, 3): + raise ValueError('the item in out_indices must in ' + f'range(0, 3). But received {index}') + + if frozen_stages not in range(-1, 3): + raise ValueError('frozen_stages must be in range(-1, 3). 
' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + if groups == 1: + channels = (144, 288, 576) + elif groups == 2: + channels = (200, 400, 800) + elif groups == 3: + channels = (240, 480, 960) + elif groups == 4: + channels = (272, 544, 1088) + elif groups == 8: + channels = (384, 768, 1536) + else: + raise ValueError(f'{groups} groups is not supported for 1x1 ' + 'Grouped Convolutions') + + channels = [make_divisible(ch * widen_factor, 8) for ch in channels] + + self.in_channels = int(24 * widen_factor) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layers = nn.ModuleList() + for i, num_blocks in enumerate(self.stage_blocks): + first_block = True if i == 0 else False + layer = self.make_layer(channels[i], num_blocks, first_block) + self.layers.append(layer) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(self.frozen_stages): + layer = self.layers[i] + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def init_weights(self): + super(ShuffleNetV1, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + if 'conv1' in name: + normal_init(m, mean=0, std=0.01) + else: + normal_init(m, mean=0, std=1.0 / m.weight.shape[1]) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, val=1, bias=0.0001) + if isinstance(m, _BatchNorm): + if m.running_mean is not None: + nn.init.constant_(m.running_mean, 0) + + def make_layer(self, out_channels, num_blocks, first_block=False): + """Stack ShuffleUnit blocks to make a layer. + + Args: + out_channels (int): out_channels of the block. + num_blocks (int): Number of blocks. + first_block (bool): Whether is the first ShuffleUnit of a + sequential ShuffleUnits. Default: False, which means using + the grouped 1x1 convolution. 
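`channel_shuffle`, applied between the grouped convolutions above, interleaves the group blocks so information can cross groups. A tiny sketch (the import is the same one `shufflenet_v2.py` below uses) makes the permutation visible:

```python
# Tiny sketch: with two groups of four channels, channel_shuffle interleaves them.
import torch
from mmpretrain.models.utils import channel_shuffle

x = torch.arange(8.).reshape(1, 8, 1, 1)  # channel ids 0..7
y = channel_shuffle(x, groups=2)
print(y.flatten().tolist())  # [0.0, 4.0, 1.0, 5.0, 2.0, 6.0, 3.0, 7.0]
```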
+ """ + layers = [] + for i in range(num_blocks): + first_block = first_block if i == 0 else False + combine_mode = 'concat' if i == 0 else 'add' + layers.append( + ShuffleUnit( + self.in_channels, + out_channels, + groups=self.groups, + first_block=first_block, + combine=combine_mode, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.maxpool(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def train(self, mode=True): + super(ShuffleNetV1, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpretrain/models/backbones/shufflenet_v2.py b/mmpretrain/models/backbones/shufflenet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..02f9c749a814b0b4ee4e04dd6afacda078ae6f39 --- /dev/null +++ b/mmpretrain/models/backbones/shufflenet_v2.py @@ -0,0 +1,305 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from mmengine.model.weight_init import constant_init, normal_init +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.models.utils import channel_shuffle +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class InvertedResidual(BaseModule): + """InvertedResidual block for ShuffleNetV2 backbone. + + Args: + in_channels (int): The input channels of the block. + out_channels (int): The output channels of the block. + stride (int): Stride of the 3x3 convolution layer. Default: 1 + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor. 
+ """ + + def __init__(self, + in_channels, + out_channels, + stride=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False, + init_cfg=None): + super(InvertedResidual, self).__init__(init_cfg) + self.stride = stride + self.with_cp = with_cp + + branch_features = out_channels // 2 + if self.stride == 1: + assert in_channels == branch_features * 2, ( + f'in_channels ({in_channels}) should equal to ' + f'branch_features * 2 ({branch_features * 2}) ' + 'when stride is 1') + + if in_channels != branch_features * 2: + assert self.stride != 1, ( + f'stride ({self.stride}) should not equal 1 when ' + f'in_channels != branch_features * 2') + + if self.stride > 1: + self.branch1 = nn.Sequential( + ConvModule( + in_channels, + in_channels, + kernel_size=3, + stride=self.stride, + padding=1, + groups=in_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + in_channels, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ) + + self.branch2 = nn.Sequential( + ConvModule( + in_channels if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + branch_features, + branch_features, + kernel_size=3, + stride=self.stride, + padding=1, + groups=branch_features, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, x): + + def _inner_forward(x): + if self.stride > 1: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + else: + # Channel Split operation. using these lines of code to replace + # ``chunk(x, 2, dim=1)`` can make it easier to deploy a + # shufflenetv2 model by using mmdeploy. + channels = x.shape[1] + c = channels // 2 + channels % 2 + x1 = x[:, :c, :, :] + x2 = x[:, c:, :, :] + + out = torch.cat((x1, self.branch2(x2)), dim=1) + + out = channel_shuffle(out, 2) + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@MODELS.register_module() +class ShuffleNetV2(BaseBackbone): + """ShuffleNetV2 backbone. + + Args: + widen_factor (float): Width multiplier - adjusts the number of + channels in each layer by this amount. Default: 1.0. + out_indices (Sequence[int]): Output from which stages. + Default: (0, 1, 2, 3). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. 
+ """ + + def __init__(self, + widen_factor=1.0, + out_indices=(3, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + norm_eval=False, + with_cp=False, + init_cfg=None): + super(ShuffleNetV2, self).__init__(init_cfg) + self.stage_blocks = [4, 8, 4] + for index in out_indices: + if index not in range(0, 4): + raise ValueError('the item in out_indices must in ' + f'range(0, 4). But received {index}') + + if frozen_stages not in range(-1, 4): + raise ValueError('frozen_stages must be in range(-1, 4). ' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + if widen_factor == 0.5: + channels = [48, 96, 192, 1024] + elif widen_factor == 1.0: + channels = [116, 232, 464, 1024] + elif widen_factor == 1.5: + channels = [176, 352, 704, 1024] + elif widen_factor == 2.0: + channels = [244, 488, 976, 2048] + else: + raise ValueError('widen_factor must be in [0.5, 1.0, 1.5, 2.0]. ' + f'But received {widen_factor}') + + self.in_channels = 24 + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layers = nn.ModuleList() + for i, num_blocks in enumerate(self.stage_blocks): + layer = self._make_layer(channels[i], num_blocks) + self.layers.append(layer) + + output_channels = channels[-1] + self.layers.append( + ConvModule( + in_channels=self.in_channels, + out_channels=output_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def _make_layer(self, out_channels, num_blocks): + """Stack blocks to make a layer. + + Args: + out_channels (int): out_channels of the block. + num_blocks (int): number of blocks. + """ + layers = [] + for i in range(num_blocks): + stride = 2 if i == 0 else 1 + layers.append( + InvertedResidual( + in_channels=self.in_channels, + out_channels=out_channels, + stride=stride, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + + for i in range(self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self): + super(ShuffleNetV2, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. 
+            return
+
+        for name, m in self.named_modules():
+            if isinstance(m, nn.Conv2d):
+                if 'conv1' in name:
+                    normal_init(m, mean=0, std=0.01)
+                else:
+                    normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
+            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                # ``constant_init`` takes the module itself (not its weight
+                # tensor), so that both the weight and the bias are set.
+                constant_init(m, val=1, bias=0.0001)
+                if isinstance(m, _BatchNorm):
+                    if m.running_mean is not None:
+                        nn.init.constant_(m.running_mean, 0)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.maxpool(x)
+
+        outs = []
+        for i, layer in enumerate(self.layers):
+            x = layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+
+        return tuple(outs)
+
+    def train(self, mode=True):
+        super(ShuffleNetV2, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eval()
diff --git a/mmpretrain/models/backbones/sparse_convnext.py b/mmpretrain/models/backbones/sparse_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f361360af460746a0f70206becb519252135596
--- /dev/null
+++ b/mmpretrain/models/backbones/sparse_convnext.py
@@ -0,0 +1,298 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Sequence, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmengine.model import ModuleList, Sequential
+
+from mmpretrain.registry import MODELS
+from ..utils import (SparseAvgPooling, SparseConv2d, SparseHelper,
+                     SparseMaxPooling, build_norm_layer)
+from .convnext import ConvNeXt, ConvNeXtBlock
+
+
+class SparseConvNeXtBlock(ConvNeXtBlock):
+    """Sparse ConvNeXt Block.
+
+    Note:
+        There are two equivalent implementations:
+        1. DwConv -> SparseLayerNorm -> 1x1 Conv -> GELU -> 1x1 Conv;
+           all outputs are in (N, C, H, W).
+        2. DwConv -> SparseLayerNorm -> Permute to (N, H, W, C) -> Linear ->
+           GELU -> Linear; Permute back
+        By default, we use the second to align with the official repository,
+        and it may be slightly faster.
+    """
+
+    def forward(self, x):
+
+        def _inner_forward(x):
+            shortcut = x
+            x = self.depthwise_conv(x)
+
+            if self.linear_pw_conv:
+                x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+                x = self.norm(x, data_format='channel_last')
+                x = self.pointwise_conv1(x)
+                x = self.act(x)
+                if self.grn is not None:
+                    x = self.grn(x, data_format='channel_last')
+                x = self.pointwise_conv2(x)
+                x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+            else:
+                x = self.norm(x, data_format='channel_first')
+                x = self.pointwise_conv1(x)
+                x = self.act(x)
+
+                if self.grn is not None:
+                    x = self.grn(x, data_format='channel_first')
+                x = self.pointwise_conv2(x)
+
+            if self.gamma is not None:
+                x = x.mul(self.gamma.view(1, -1, 1, 1))
+
+            # Re-apply the sparse mask so that masked positions stay zero.
+            x *= SparseHelper._get_active_map_or_index(
+                H=x.shape[2], returning_active_map=True)
+
+            x = shortcut + self.drop_path(x)
+            return x
+
+        if self.with_cp and x.requires_grad:
+            x = cp.checkpoint(_inner_forward, x)
+        else:
+            x = _inner_forward(x)
+        return x
+
+
+@MODELS.register_module()
+class SparseConvNeXt(ConvNeXt):
+    """ConvNeXt with sparse module conversion function.
+
+    Modified from
+    https://github.com/keyu-tian/SparK/blob/main/models/convnext.py
+    and
+    https://github.com/keyu-tian/SparK/blob/main/encoder.py
+    To use ConvNeXt v2, please set ``use_grn=True`` and ``layer_scale_init_value=0.``.
+
+    Args:
+        arch (str | dict): The model's architecture. If string, it should be
+            one of the architectures in ``ConvNeXt.arch_settings``.
And if dict, it + should include the following two keys: + - depths (list[int]): Number of blocks at each stage. + - channels (list[int]): The number of channels at each stage. + Defaults to 'tiny'. + in_channels (int): Number of input image channels. Defaults to 3. + stem_patch_size (int): The size of one patch in the stem layer. + Defaults to 4. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='SparseLN2d', eps=1e-6)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + linear_pw_conv (bool): Whether to use linear layer to do pointwise + convolution. Defaults to True. + use_grn (bool): Whether to add Global Response Normalization in the + blocks. Defaults to False. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-6. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + gap_before_output (bool): Whether to globally average the feature + map before the final norm layer. In the official repo, it's only + used in classification task. Defaults to True. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): Initialization config dict. + """ # noqa: E501 + + def __init__(self, + arch: str = 'small', + in_channels: int = 3, + stem_patch_size: int = 4, + norm_cfg: dict = dict(type='SparseLN2d', eps=1e-6), + act_cfg: dict = dict(type='GELU'), + linear_pw_conv: bool = True, + use_grn: bool = False, + drop_path_rate: float = 0, + layer_scale_init_value: float = 1e-6, + out_indices: int = -1, + frozen_stages: int = 0, + gap_before_output: bool = True, + with_cp: bool = False, + init_cfg: Optional[Union[dict, List[dict]]] = [ + dict( + type='TruncNormal', + layer=['Conv2d', 'Linear'], + std=.02, + bias=0.), + dict( + type='Constant', layer=['LayerNorm'], val=1., + bias=0.), + ]): + super(ConvNeXt, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + assert 'depths' in arch and 'channels' in arch, \ + f'The arch dict must have "depths" and "channels", ' \ + f'but got {list(arch.keys())}.' + + self.depths = arch['depths'] + self.channels = arch['channels'] + assert (isinstance(self.depths, Sequence) + and isinstance(self.channels, Sequence) + and len(self.depths) == len(self.channels)), \ + f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \ + 'should be both sequence with the same length.' + + self.num_stages = len(self.depths) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 4 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.frozen_stages = frozen_stages + self.gap_before_output = gap_before_output + + # 4 downsample layers between stages, including the stem layer. 
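+        # The stem reduces the resolution by ``stem_patch_size`` with a
+        # patch-embedding style convolution; the three later downsample
+        # layers are a norm layer followed by a 2x2 stride-2 convolution.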
+ self.downsample_layers = ModuleList() + stem = nn.Sequential( + nn.Conv2d( + in_channels, + self.channels[0], + kernel_size=stem_patch_size, + stride=stem_patch_size), + build_norm_layer(norm_cfg, self.channels[0]), + ) + self.downsample_layers.append(stem) + + # stochastic depth decay rule + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + block_idx = 0 + + # 4 feature resolution stages, each consisting of multiple residual + # blocks + self.stages = nn.ModuleList() + for i in range(self.num_stages): + depth = self.depths[i] + channels = self.channels[i] + + if i >= 1: + downsample_layer = nn.Sequential( + build_norm_layer(norm_cfg, self.channels[i - 1]), + nn.Conv2d( + self.channels[i - 1], + channels, + kernel_size=2, + stride=2), + ) + self.downsample_layers.append(downsample_layer) + + stage = Sequential(*[ + SparseConvNeXtBlock( + in_channels=channels, + drop_path_rate=dpr[block_idx + j], + norm_cfg=norm_cfg, + act_cfg=act_cfg, + linear_pw_conv=linear_pw_conv, + layer_scale_init_value=layer_scale_init_value, + use_grn=use_grn, + with_cp=with_cp) for j in range(depth) + ]) + block_idx += depth + + self.stages.append(stage) + + self.dense_model_to_sparse(m=self) + + def forward(self, x): + outs = [] + for i, stage in enumerate(self.stages): + x = self.downsample_layers[i](x) + x = stage(x) + if i in self.out_indices: + if self.gap_before_output: + gap = x.mean([-2, -1], keepdim=True) + outs.append(gap.flatten(1)) + else: + outs.append(x) + + return tuple(outs) + + def dense_model_to_sparse(self, m: nn.Module) -> nn.Module: + """Convert regular dense modules to sparse modules.""" + output = m + if isinstance(m, nn.Conv2d): + m: nn.Conv2d + bias = m.bias is not None + output = SparseConv2d( + m.in_channels, + m.out_channels, + kernel_size=m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + groups=m.groups, + bias=bias, + padding_mode=m.padding_mode, + ) + output.weight.data.copy_(m.weight.data) + if bias: + output.bias.data.copy_(m.bias.data) + + elif isinstance(m, nn.MaxPool2d): + m: nn.MaxPool2d + output = SparseMaxPooling( + m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + return_indices=m.return_indices, + ceil_mode=m.ceil_mode) + + elif isinstance(m, nn.AvgPool2d): + m: nn.AvgPool2d + output = SparseAvgPooling( + m.kernel_size, + m.stride, + m.padding, + ceil_mode=m.ceil_mode, + count_include_pad=m.count_include_pad, + divisor_override=m.divisor_override) + + # elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)): + # m: nn.BatchNorm2d + # output = (SparseSyncBatchNorm2d + # if enable_sync_bn else SparseBatchNorm2d)( + # m.weight.shape[0], + # eps=m.eps, + # momentum=m.momentum, + # affine=m.affine, + # track_running_stats=m.track_running_stats) + # output.weight.data.copy_(m.weight.data) + # output.bias.data.copy_(m.bias.data) + # output.running_mean.data.copy_(m.running_mean.data) + # output.running_var.data.copy_(m.running_var.data) + # output.num_batches_tracked.data.copy_(m.num_batches_tracked.data) + + for name, child in m.named_children(): + output.add_module(name, self.dense_model_to_sparse(child)) + del m + return output diff --git a/mmpretrain/models/backbones/sparse_resnet.py b/mmpretrain/models/backbones/sparse_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..67597f1f0327f466a6841333c8247f96238ce35f --- /dev/null +++ b/mmpretrain/models/backbones/sparse_resnet.py @@ -0,0 +1,179 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
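+# A SparK-style sparse encoder: the dense ResNet modules are converted to
+# sparse counterparts so that, during masked pre-training, activations at
+# masked positions are zeroed out after every layer.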
+import re +from typing import Optional, Tuple + +import torch.nn as nn + +from mmpretrain.models.utils.sparse_modules import (SparseAvgPooling, + SparseBatchNorm2d, + SparseConv2d, + SparseMaxPooling, + SparseSyncBatchNorm2d) +from mmpretrain.registry import MODELS +from .resnet import ResNet + + +@MODELS.register_module() +class SparseResNet(ResNet): + """ResNet with sparse module conversion function. + + Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py + + Args: + depth (int): Network depth, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Defaults to 3. + stem_channels (int): Output channels of the stem layer. Defaults to 64. + base_channels (int): Middle channels of the first stage. + Defaults to 64. + num_stages (int): Stages of the network. Defaults to 4. + strides (Sequence[int]): Strides of the first block of each stage. + Defaults to ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Defaults to ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. + Defaults to ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Defaults to False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + conv_cfg (dict | None): The config dict for conv layers. + Defaults to None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Defaults to True. + drop_path_rate (float): stochastic depth rate. Defaults to 0. 
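+
+        Example:
+            A minimal construction sketch. The sparse forward pass also needs
+            the masking state prepared by the SparK algorithm, so only the
+            module conversion is illustrated here:
+
+            >>> from mmpretrain.models import SparseResNet
+            >>> model = SparseResNet(depth=50)
+            >>> type(model.conv1).__name__
+            'SparseConv2d'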
+ """ + + def __init__(self, + depth: int, + in_channels: int = 3, + stem_channels: int = 64, + base_channels: int = 64, + expansion: Optional[int] = None, + num_stages: int = 4, + strides: Tuple[int] = (1, 2, 2, 2), + dilations: Tuple[int] = (1, 1, 1, 1), + out_indices: Tuple[int] = (3, ), + style: str = 'pytorch', + deep_stem: bool = False, + avg_down: bool = False, + frozen_stages: int = -1, + conv_cfg: Optional[dict] = None, + norm_cfg: dict = dict(type='SparseSyncBatchNorm2d'), + norm_eval: bool = False, + with_cp: bool = False, + zero_init_residual: bool = False, + init_cfg: Optional[dict] = [ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ], + drop_path_rate: float = 0, + **kwargs): + super().__init__( + depth=depth, + in_channels=in_channels, + stem_channels=stem_channels, + base_channels=base_channels, + expansion=expansion, + num_stages=num_stages, + strides=strides, + dilations=dilations, + out_indices=out_indices, + style=style, + deep_stem=deep_stem, + avg_down=avg_down, + frozen_stages=frozen_stages, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + norm_eval=norm_eval, + with_cp=with_cp, + zero_init_residual=zero_init_residual, + init_cfg=init_cfg, + drop_path_rate=drop_path_rate, + **kwargs) + norm_type = norm_cfg['type'] + enable_sync_bn = False + if re.search('Sync', norm_type) is not None: + enable_sync_bn = True + self.dense_model_to_sparse(m=self, enable_sync_bn=enable_sync_bn) + + def dense_model_to_sparse(self, m: nn.Module, + enable_sync_bn: bool) -> nn.Module: + """Convert regular dense modules to sparse modules.""" + output = m + if isinstance(m, nn.Conv2d): + m: nn.Conv2d + bias = m.bias is not None + output = SparseConv2d( + m.in_channels, + m.out_channels, + kernel_size=m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + groups=m.groups, + bias=bias, + padding_mode=m.padding_mode, + ) + output.weight.data.copy_(m.weight.data) + if bias: + output.bias.data.copy_(m.bias.data) + + elif isinstance(m, nn.MaxPool2d): + m: nn.MaxPool2d + output = SparseMaxPooling( + m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + return_indices=m.return_indices, + ceil_mode=m.ceil_mode) + + elif isinstance(m, nn.AvgPool2d): + m: nn.AvgPool2d + output = SparseAvgPooling( + m.kernel_size, + m.stride, + m.padding, + ceil_mode=m.ceil_mode, + count_include_pad=m.count_include_pad, + divisor_override=m.divisor_override) + + elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)): + m: nn.BatchNorm2d + output = (SparseSyncBatchNorm2d + if enable_sync_bn else SparseBatchNorm2d)( + m.weight.shape[0], + eps=m.eps, + momentum=m.momentum, + affine=m.affine, + track_running_stats=m.track_running_stats) + output.weight.data.copy_(m.weight.data) + output.bias.data.copy_(m.bias.data) + output.running_mean.data.copy_(m.running_mean.data) + output.running_var.data.copy_(m.running_var.data) + output.num_batches_tracked.data.copy_(m.num_batches_tracked.data) + + elif isinstance(m, (nn.Conv1d, )): + raise NotImplementedError + + for name, child in m.named_children(): + output.add_module( + name, + self.dense_model_to_sparse( + child, enable_sync_bn=enable_sync_bn)) + del m + return output diff --git a/mmpretrain/models/backbones/swin_transformer.py b/mmpretrain/models/backbones/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..559fd5e9150f78a9801fcb9070e114b4e96113c5 --- /dev/null +++ b/mmpretrain/models/backbones/swin_transformer.py @@ -0,0 
+1,585 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed, PatchMerging +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.registry import MODELS +from ..utils import (ShiftWindowMSA, resize_pos_embed, + resize_relative_position_bias_table, to_2tuple) +from .base_backbone import BaseBackbone + + +class SwinBlock(BaseModule): + """Swin Transformer block. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + shift (bool): Shift the attention window or not. Defaults to False. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + drop_path (float): The drop path rate after attention and ffn. + Defaults to 0. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + attn_cfgs (dict): The extra config of Shift Window-MSA. + Defaults to empty dict. + ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict. + norm_cfg (dict): The config of norm layers. + Defaults to ``dict(type='LN')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size=7, + shift=False, + ffn_ratio=4., + drop_path=0., + pad_small_map=False, + attn_cfgs=dict(), + ffn_cfgs=dict(), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + + super(SwinBlock, self).__init__(init_cfg) + self.with_cp = with_cp + + _attn_cfgs = { + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'shift_size': window_size // 2 if shift else 0, + 'window_size': window_size, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'pad_small_map': pad_small_map, + **attn_cfgs + } + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = ShiftWindowMSA(**_attn_cfgs) + + _ffn_cfgs = { + 'embed_dims': embed_dims, + 'feedforward_channels': int(embed_dims * ffn_ratio), + 'num_fcs': 2, + 'ffn_drop': 0, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'act_cfg': dict(type='GELU'), + **ffn_cfgs + } + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN(**_ffn_cfgs) + + def forward(self, x, hw_shape): + + def _inner_forward(x): + identity = x + x = self.norm1(x) + x = self.attn(x, hw_shape) + x = x + identity + + identity = x + x = self.norm2(x) + x = self.ffn(x, identity=identity) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class SwinBlockSequence(BaseModule): + """Module with successive Swin Transformer blocks and downsample layer. + + Args: + embed_dims (int): Number of input channels. + depth (int): Number of successive swin transformer blocks. 
+ num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + downsample (bool): Downsample the output of blocks by patch merging. + Defaults to False. + downsample_cfg (dict): The extra config of the patch merging layer. + Defaults to empty dict. + drop_paths (Sequence[float] | float): The drop path rate in each block. + Defaults to 0. + block_cfgs (Sequence[dict] | dict): The extra config of each block. + Defaults to empty dicts. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + depth, + num_heads, + window_size=7, + downsample=False, + downsample_cfg=dict(), + drop_paths=0., + block_cfgs=dict(), + with_cp=False, + pad_small_map=False, + init_cfg=None): + super().__init__(init_cfg) + + if not isinstance(drop_paths, Sequence): + drop_paths = [drop_paths] * depth + + if not isinstance(block_cfgs, Sequence): + block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)] + + self.embed_dims = embed_dims + self.blocks = ModuleList() + for i in range(depth): + _block_cfg = { + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'window_size': window_size, + 'shift': False if i % 2 == 0 else True, + 'drop_path': drop_paths[i], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + **block_cfgs[i] + } + block = SwinBlock(**_block_cfg) + self.blocks.append(block) + + if downsample: + _downsample_cfg = { + 'in_channels': embed_dims, + 'out_channels': 2 * embed_dims, + 'norm_cfg': dict(type='LN'), + **downsample_cfg + } + self.downsample = PatchMerging(**_downsample_cfg) + else: + self.downsample = None + + def forward(self, x, in_shape, do_downsample=True): + for block in self.blocks: + x = block(x, in_shape) + + if self.downsample is not None and do_downsample: + x, out_shape = self.downsample(x, in_shape) + else: + out_shape = in_shape + return x, out_shape + + @property + def out_channels(self): + if self.downsample: + return self.downsample.out_channels + else: + return self.embed_dims + + +@MODELS.register_module() +class SwinTransformer(BaseBackbone): + """Swin Transformer. + + A PyTorch implement of : `Swin Transformer: + Hierarchical Vision Transformer using Shifted Windows + `_ + + Inspiration from + https://github.com/microsoft/Swin-Transformer + + Args: + arch (str | dict): Swin Transformer architecture. If use string, choose + from 'tiny', 'small', 'base' and 'large'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **num_heads** (List[int]): The number of heads in attention + modules of each stage. + + Defaults to 'tiny'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 4. + in_channels (int): The num of input channels. Defaults to 3. + window_size (int): The height and width of the window. Defaults to 7. 
+        drop_rate (float): Dropout rate after embedding. Defaults to 0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
+        out_after_downsample (bool): Whether to output the feature map of a
+            stage after the following downsample layer. Defaults to False.
+        use_abs_pos_embed (bool): If True, add absolute position embedding to
+            the patch embedding. Defaults to False.
+        interpolate_mode (str): Select the interpolate mode for absolute
+            position embedding vector resize. Defaults to "bicubic".
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Defaults to False.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters. Defaults to -1.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Defaults to False.
+        pad_small_map (bool): If True, pad the small feature map to the window
+            size, which is commonly used in detection and segmentation. If
+            False, avoid shifting window and shrink the window size to the
+            size of the feature map, which is commonly used in classification.
+            Defaults to False.
+        norm_cfg (dict): Config dict for the normalization layer of all output
+            features. Defaults to ``dict(type='LN')``.
+        stage_cfgs (Sequence[dict] | dict): Extra config dict for each
+            stage. Defaults to an empty dict.
+        patch_cfg (dict): Extra config dict for patch embedding.
+            Defaults to an empty dict.
+        init_cfg (dict, optional): The Config for initialization.
+            Defaults to None.
+
+    Examples:
+        >>> from mmpretrain.models import SwinTransformer
+        >>> import torch
+        >>> self = SwinTransformer(arch='tiny')
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> output = self.forward(inputs)
+        >>> print(output[0].shape)
+        torch.Size([1, 768, 7, 7])
+    """
+    arch_zoo = {
+        **dict.fromkeys(['t', 'tiny'],
+                        {'embed_dims': 96,
+                         'depths': [2, 2, 6, 2],
+                         'num_heads': [3, 6, 12, 24]}),
+        **dict.fromkeys(['s', 'small'],
+                        {'embed_dims': 96,
+                         'depths': [2, 2, 18, 2],
+                         'num_heads': [3, 6, 12, 24]}),
+        **dict.fromkeys(['b', 'base'],
+                        {'embed_dims': 128,
+                         'depths': [2, 2, 18, 2],
+                         'num_heads': [4, 8, 16, 32]}),
+        **dict.fromkeys(['l', 'large'],
+                        {'embed_dims': 192,
+                         'depths': [2, 2, 18, 2],
+                         'num_heads': [6, 12, 24, 48]}),
+    }  # yapf: disable
+
+    _version = 3
+    num_extra_tokens = 0
+
+    def __init__(self,
+                 arch='tiny',
+                 img_size=224,
+                 patch_size=4,
+                 in_channels=3,
+                 window_size=7,
+                 drop_rate=0.,
+                 drop_path_rate=0.1,
+                 out_indices=(3, ),
+                 out_after_downsample=False,
+                 use_abs_pos_embed=False,
+                 interpolate_mode='bicubic',
+                 with_cp=False,
+                 frozen_stages=-1,
+                 norm_eval=False,
+                 pad_small_map=False,
+                 norm_cfg=dict(type='LN'),
+                 stage_cfgs=dict(),
+                 patch_cfg=dict(),
+                 init_cfg=None):
+        super(SwinTransformer, self).__init__(init_cfg=init_cfg)
+
+        if isinstance(arch, str):
+            arch = arch.lower()
+            assert arch in set(self.arch_zoo), \
+                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+            self.arch_settings = self.arch_zoo[arch]
+        else:
+            essential_keys = {'embed_dims', 'depths', 'num_heads'}
+            assert isinstance(arch, dict) and set(arch) == essential_keys, \
+                f'Custom arch needs a dict with keys {essential_keys}'
+            self.arch_settings = arch
+
+        self.embed_dims = self.arch_settings['embed_dims']
+        self.depths = self.arch_settings['depths']
+        self.num_heads =
self.arch_settings['num_heads'] + self.num_layers = len(self.depths) + self.out_indices = out_indices + self.out_after_downsample = out_after_downsample + self.use_abs_pos_embed = use_abs_pos_embed + self.interpolate_mode = interpolate_mode + self.frozen_stages = frozen_stages + + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + norm_cfg=dict(type='LN'), + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + if self.use_abs_pos_embed: + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.embed_dims)) + self._register_load_state_dict_pre_hook( + self._prepare_abs_pos_embed) + + self._register_load_state_dict_pre_hook( + self._prepare_relative_position_bias_table) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + self.norm_eval = norm_eval + + # stochastic depth + total_depth = sum(self.depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + self.stages = ModuleList() + embed_dims = [self.embed_dims] + for i, (depth, + num_heads) in enumerate(zip(self.depths, self.num_heads)): + if isinstance(stage_cfgs, Sequence): + stage_cfg = stage_cfgs[i] + else: + stage_cfg = deepcopy(stage_cfgs) + downsample = True if i < self.num_layers - 1 else False + _stage_cfg = { + 'embed_dims': embed_dims[-1], + 'depth': depth, + 'num_heads': num_heads, + 'window_size': window_size, + 'downsample': downsample, + 'drop_paths': dpr[:depth], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + **stage_cfg + } + + stage = SwinBlockSequence(**_stage_cfg) + self.stages.append(stage) + + dpr = dpr[depth:] + embed_dims.append(stage.out_channels) + + if self.out_after_downsample: + self.num_features = embed_dims[1:] + else: + self.num_features = embed_dims[:-1] + + for i in out_indices: + if norm_cfg is not None: + norm_layer = build_norm_layer(norm_cfg, + self.num_features[i])[1] + else: + norm_layer = nn.Identity() + + self.add_module(f'norm{i}', norm_layer) + + def init_weights(self): + super(SwinTransformer, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + + if self.use_abs_pos_embed: + trunc_normal_(self.absolute_pos_embed, std=0.02) + + def forward(self, x): + x, hw_shape = self.patch_embed(x) + if self.use_abs_pos_embed: + x = x + resize_pos_embed( + self.absolute_pos_embed, self.patch_resolution, hw_shape, + self.interpolate_mode, self.num_extra_tokens) + x = self.drop_after_pos(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape = stage( + x, hw_shape, do_downsample=self.out_after_downsample) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(x) + out = out.view(-1, *hw_shape, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + if stage.downsample is not None and not self.out_after_downsample: + x, hw_shape = stage.downsample(x, hw_shape) + + return tuple(outs) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, *args, + **kwargs): + """load checkpoints.""" + # Names of some parameters in has been changed. 
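+        # Before checkpoint version 2, the final norm layer was simply named
+        # ``norm``; it is remapped to the ``norm{i}`` of the last stage.
+        # Before version 3, ``attn_mask`` buffers were saved in checkpoints;
+        # they are dropped here because they are re-generated at runtime.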
+ version = local_metadata.get('version', None) + if (version is None + or version < 2) and self.__class__ is SwinTransformer: + final_stage_num = len(self.stages) - 1 + state_dict_keys = list(state_dict.keys()) + for k in state_dict_keys: + if k.startswith('norm.') or k.startswith('backbone.norm.'): + convert_key = k.replace('norm.', f'norm{final_stage_num}.') + state_dict[convert_key] = state_dict[k] + del state_dict[k] + if (version is None + or version < 3) and self.__class__ is SwinTransformer: + state_dict_keys = list(state_dict.keys()) + for k in state_dict_keys: + if 'attn_mask' in k: + del state_dict[k] + + super()._load_from_state_dict(state_dict, prefix, local_metadata, + *args, **kwargs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(0, self.frozen_stages + 1): + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + for i in self.out_indices: + if i <= self.frozen_stages: + for param in getattr(self, f'norm{i}').parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(SwinTransformer, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'absolute_pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.absolute_pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + 'Resize the absolute_pos_embed shape from ' + f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + def _prepare_relative_position_bias_table(self, state_dict, prefix, *args, + **kwargs): + state_dict_model = self.state_dict() + all_keys = list(state_dict_model.keys()) + for key in all_keys: + if 'relative_position_bias_table' in key: + ckpt_key = prefix + key + if ckpt_key not in state_dict: + continue + relative_position_bias_table_pretrained = state_dict[ckpt_key] + relative_position_bias_table_current = state_dict_model[key] + L1, nH1 = relative_position_bias_table_pretrained.size() + L2, nH2 = relative_position_bias_table_current.size() + if L1 != L2: + src_size = int(L1**0.5) + dst_size = int(L2**0.5) + new_rel_pos_bias = resize_relative_position_bias_table( + src_size, dst_size, + relative_position_bias_table_pretrained, nH1) + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info('Resize the relative_position_bias_table from ' + f'{state_dict[ckpt_key].shape} to ' + f'{new_rel_pos_bias.shape}') + state_dict[ckpt_key] = new_rel_pos_bias + + # The index buffer need to be re-generated. + index_buffer = ckpt_key.replace('bias_table', 'index') + del state_dict[index_buffer] + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. 
+ Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. + + Note: + The first depth is the stem module (``layer_depth=0``), and the + last depth is the subsequent module (``layer_depth=num_layers-1``) + """ + num_layers = sum(self.depths) + 2 + + if not param_name.startswith(prefix): + # For subsequent module like head + return num_layers - 1, num_layers + + param_name = param_name[len(prefix):] + + if param_name.startswith('patch_embed'): + layer_depth = 0 + elif param_name.startswith('stages'): + stage_id = int(param_name.split('.')[1]) + block_id = param_name.split('.')[3] + if block_id in ('reduction', 'norm'): + layer_depth = sum(self.depths[:stage_id + 1]) + else: + layer_depth = sum(self.depths[:stage_id]) + int(block_id) + 1 + else: + layer_depth = num_layers - 1 + + return layer_depth, num_layers diff --git a/mmpretrain/models/backbones/swin_transformer_v2.py b/mmpretrain/models/backbones/swin_transformer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..142505a808ae3fc631d54e1a56ae483db242da31 --- /dev/null +++ b/mmpretrain/models/backbones/swin_transformer_v2.py @@ -0,0 +1,567 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from ..builder import MODELS +from ..utils import (PatchMerging, ShiftWindowMSA, WindowMSAV2, + resize_pos_embed, to_2tuple) +from .base_backbone import BaseBackbone + + +class SwinBlockV2(BaseModule): + """Swin Transformer V2 block. Use post normalization. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + shift (bool): Shift the attention window or not. Defaults to False. + extra_norm (bool): Whether add extra norm at the end of main branch. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + drop_path (float): The drop path rate after attention and ffn. + Defaults to 0. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + attn_cfgs (dict): The extra config of Shift Window-MSA. + Defaults to empty dict. + ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict. + norm_cfg (dict): The config of norm layers. + Defaults to ``dict(type='LN')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + pretrained_window_size (int): Window size in pretrained. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. 
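+
+        Example:
+            An illustrative shape check on a single 8x8 window grid, i.e. a
+            64-token sequence with ``hw_shape=(8, 8)`` and the default
+            ``window_size=8``:
+
+            >>> import torch
+            >>> block = SwinBlockV2(embed_dims=96, num_heads=3)
+            >>> x = torch.rand(1, 64, 96)
+            >>> block(x, hw_shape=(8, 8)).shape
+            torch.Size([1, 64, 96])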
+ """ + + def __init__(self, + embed_dims, + num_heads, + window_size=8, + shift=False, + extra_norm=False, + ffn_ratio=4., + drop_path=0., + pad_small_map=False, + attn_cfgs=dict(), + ffn_cfgs=dict(), + norm_cfg=dict(type='LN'), + with_cp=False, + pretrained_window_size=0, + init_cfg=None): + + super(SwinBlockV2, self).__init__(init_cfg) + self.with_cp = with_cp + self.extra_norm = extra_norm + + _attn_cfgs = { + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'shift_size': window_size // 2 if shift else 0, + 'window_size': window_size, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'pad_small_map': pad_small_map, + **attn_cfgs + } + # use V2 attention implementation + _attn_cfgs.update( + window_msa=WindowMSAV2, + pretrained_window_size=to_2tuple(pretrained_window_size)) + self.attn = ShiftWindowMSA(**_attn_cfgs) + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + + _ffn_cfgs = { + 'embed_dims': embed_dims, + 'feedforward_channels': int(embed_dims * ffn_ratio), + 'num_fcs': 2, + 'ffn_drop': 0, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'act_cfg': dict(type='GELU'), + 'add_identity': False, + **ffn_cfgs + } + self.ffn = FFN(**_ffn_cfgs) + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + + # add extra norm for every n blocks in huge and giant model + if self.extra_norm: + self.norm3 = build_norm_layer(norm_cfg, embed_dims)[1] + + def forward(self, x, hw_shape): + + def _inner_forward(x): + # Use post normalization + identity = x + x = self.attn(x, hw_shape) + x = self.norm1(x) + x = x + identity + + identity = x + x = self.ffn(x) + x = self.norm2(x) + x = x + identity + + if self.extra_norm: + x = self.norm3(x) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class SwinBlockV2Sequence(BaseModule): + """Module with successive Swin Transformer blocks and downsample layer. + + Args: + embed_dims (int): Number of input channels. + depth (int): Number of successive swin transformer blocks. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + downsample (bool): Downsample the output of blocks by patch merging. + Defaults to False. + downsample_cfg (dict): The extra config of the patch merging layer. + Defaults to empty dict. + drop_paths (Sequence[float] | float): The drop path rate in each block. + Defaults to 0. + block_cfgs (Sequence[dict] | dict): The extra config of each block. + Defaults to empty dicts. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + extra_norm_every_n_blocks (int): Add extra norm at the end of main + branch every n blocks. Defaults to 0, which means no needs for + extra norm layer. + pretrained_window_size (int): Window size in pretrained. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. 
+ """ + + def __init__(self, + embed_dims, + depth, + num_heads, + window_size=8, + downsample=False, + downsample_cfg=dict(), + drop_paths=0., + block_cfgs=dict(), + with_cp=False, + pad_small_map=False, + extra_norm_every_n_blocks=0, + pretrained_window_size=0, + init_cfg=None): + super().__init__(init_cfg) + + if not isinstance(drop_paths, Sequence): + drop_paths = [drop_paths] * depth + + if not isinstance(block_cfgs, Sequence): + block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)] + + if downsample: + self.out_channels = 2 * embed_dims + _downsample_cfg = { + 'in_channels': embed_dims, + 'out_channels': self.out_channels, + 'norm_cfg': dict(type='LN'), + **downsample_cfg + } + self.downsample = PatchMerging(**_downsample_cfg) + else: + self.out_channels = embed_dims + self.downsample = None + + self.blocks = ModuleList() + for i in range(depth): + extra_norm = True if extra_norm_every_n_blocks and \ + (i + 1) % extra_norm_every_n_blocks == 0 else False + _block_cfg = { + 'embed_dims': self.out_channels, + 'num_heads': num_heads, + 'window_size': window_size, + 'shift': False if i % 2 == 0 else True, + 'extra_norm': extra_norm, + 'drop_path': drop_paths[i], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + 'pretrained_window_size': pretrained_window_size, + **block_cfgs[i] + } + block = SwinBlockV2(**_block_cfg) + self.blocks.append(block) + + def forward(self, x, in_shape): + if self.downsample: + x, out_shape = self.downsample(x, in_shape) + else: + out_shape = in_shape + + for block in self.blocks: + x = block(x, out_shape) + + return x, out_shape + + +@MODELS.register_module() +class SwinTransformerV2(BaseBackbone): + """Swin Transformer V2. + + A PyTorch implement of : `Swin Transformer V2: + Scaling Up Capacity and Resolution + `_ + + Inspiration from + https://github.com/microsoft/Swin-Transformer + + Args: + arch (str | dict): Swin Transformer architecture. If use string, choose + from 'tiny', 'small', 'base' and 'large'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **num_heads** (List[int]): The number of heads in attention + modules of each stage. + - **extra_norm_every_n_blocks** (int): Add extra norm at the end + of main branch every n blocks. + + Defaults to 'tiny'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 4. + in_channels (int): The num of input channels. Defaults to 3. + window_size (int | Sequence): The height and width of the window. + Defaults to 7. + drop_rate (float): Dropout rate after embedding. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults to False. + interpolate_mode (str): Select the interpolate mode for absolute + position embeding vector resize. Defaults to "bicubic". + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). 
Note: Effect on Batch Norm and its variants only. Defaults to False.
+        pad_small_map (bool): If True, pad the small feature map to the window
+            size, which is commonly used in detection and segmentation. If
+            False, avoid shifting window and shrink the window size to the
+            size of the feature map, which is commonly used in classification.
+            Defaults to False.
+        norm_cfg (dict): Config dict for the normalization layer of all output
+            features. Defaults to ``dict(type='LN')``.
+        stage_cfgs (Sequence[dict] | dict): Extra config dict for each
+            stage. Defaults to an empty dict.
+        patch_cfg (dict): Extra config dict for patch embedding.
+            Defaults to an empty dict.
+        pretrained_window_sizes (tuple(int)): Pretrained window sizes of
+            each layer.
+        init_cfg (dict, optional): The Config for initialization.
+            Defaults to None.
+
+    Examples:
+        >>> from mmpretrain.models import SwinTransformerV2
+        >>> import torch
+        >>> self = SwinTransformerV2(arch='tiny')
+        >>> inputs = torch.rand(1, 3, 256, 256)
+        >>> output = self.forward(inputs)
+        >>> print(output[0].shape)
+        torch.Size([1, 768, 8, 8])
+    """
+    arch_zoo = {
+        **dict.fromkeys(['t', 'tiny'],
+                        {'embed_dims': 96,
+                         'depths': [2, 2, 6, 2],
+                         'num_heads': [3, 6, 12, 24],
+                         'extra_norm_every_n_blocks': 0}),
+        **dict.fromkeys(['s', 'small'],
+                        {'embed_dims': 96,
+                         'depths': [2, 2, 18, 2],
+                         'num_heads': [3, 6, 12, 24],
+                         'extra_norm_every_n_blocks': 0}),
+        **dict.fromkeys(['b', 'base'],
+                        {'embed_dims': 128,
+                         'depths': [2, 2, 18, 2],
+                         'num_heads': [4, 8, 16, 32],
+                         'extra_norm_every_n_blocks': 0}),
+        **dict.fromkeys(['l', 'large'],
+                        {'embed_dims': 192,
+                         'depths': [2, 2, 18, 2],
+                         'num_heads': [6, 12, 24, 48],
+                         'extra_norm_every_n_blocks': 0}),
+        # The head counts for the huge and giant variants are not settled;
+        # they were employed in a parallel study on self-supervised learning.
+ **dict.fromkeys(['h', 'huge'], + {'embed_dims': 352, + 'depths': [2, 2, 18, 2], + 'num_heads': [8, 16, 32, 64], + 'extra_norm_every_n_blocks': 6}), + **dict.fromkeys(['g', 'giant'], + {'embed_dims': 512, + 'depths': [2, 2, 42, 4], + 'num_heads': [16, 32, 64, 128], + 'extra_norm_every_n_blocks': 6}), + } # yapf: disable + + _version = 1 + num_extra_tokens = 0 + + def __init__(self, + arch='tiny', + img_size=256, + patch_size=4, + in_channels=3, + window_size=8, + drop_rate=0., + drop_path_rate=0.1, + out_indices=(3, ), + use_abs_pos_embed=False, + interpolate_mode='bicubic', + with_cp=False, + frozen_stages=-1, + norm_eval=False, + pad_small_map=False, + norm_cfg=dict(type='LN'), + stage_cfgs=dict(), + patch_cfg=dict(), + pretrained_window_sizes=[0, 0, 0, 0], + init_cfg=None): + super(SwinTransformerV2, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'depths', 'num_heads', + 'extra_norm_every_n_blocks' + } + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + self.extra_norm_every_n_blocks = self.arch_settings[ + 'extra_norm_every_n_blocks'] + self.num_layers = len(self.depths) + self.out_indices = out_indices + self.use_abs_pos_embed = use_abs_pos_embed + self.interpolate_mode = interpolate_mode + self.frozen_stages = frozen_stages + + if isinstance(window_size, int): + self.window_sizes = [window_size for _ in range(self.num_layers)] + elif isinstance(window_size, Sequence): + assert len(window_size) == self.num_layers, \ + f'Length of window_sizes {len(window_size)} is not equal to '\ + f'length of stages {self.num_layers}.' 
+ self.window_sizes = window_size + else: + raise TypeError('window_size should be a Sequence or int.') + + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + norm_cfg=dict(type='LN'), + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + if self.use_abs_pos_embed: + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.embed_dims)) + self._register_load_state_dict_pre_hook( + self._prepare_abs_pos_embed) + + self._register_load_state_dict_pre_hook(self._delete_reinit_params) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + self.norm_eval = norm_eval + + # stochastic depth + total_depth = sum(self.depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + self.stages = ModuleList() + embed_dims = [self.embed_dims] + for i, (depth, + num_heads) in enumerate(zip(self.depths, self.num_heads)): + if isinstance(stage_cfgs, Sequence): + stage_cfg = stage_cfgs[i] + else: + stage_cfg = deepcopy(stage_cfgs) + downsample = True if i > 0 else False + _stage_cfg = { + 'embed_dims': embed_dims[-1], + 'depth': depth, + 'num_heads': num_heads, + 'window_size': self.window_sizes[i], + 'downsample': downsample, + 'drop_paths': dpr[:depth], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + 'extra_norm_every_n_blocks': self.extra_norm_every_n_blocks, + 'pretrained_window_size': pretrained_window_sizes[i], + 'downsample_cfg': dict(use_post_norm=True), + **stage_cfg + } + + stage = SwinBlockV2Sequence(**_stage_cfg) + self.stages.append(stage) + + dpr = dpr[depth:] + embed_dims.append(stage.out_channels) + + for i in out_indices: + if norm_cfg is not None: + norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1] + else: + norm_layer = nn.Identity() + + self.add_module(f'norm{i}', norm_layer) + + def init_weights(self): + super(SwinTransformerV2, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. 
+ return + + if self.use_abs_pos_embed: + trunc_normal_(self.absolute_pos_embed, std=0.02) + + def forward(self, x): + x, hw_shape = self.patch_embed(x) + + if self.use_abs_pos_embed: + x = x + resize_pos_embed( + self.absolute_pos_embed, self.patch_resolution, hw_shape, + self.interpolate_mode, self.num_extra_tokens) + x = self.drop_after_pos(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape = stage(x, hw_shape) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(x) + out = out.view(-1, *hw_shape, + stage.out_channels).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(0, self.frozen_stages + 1): + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + for i in self.out_indices: + if i <= self.frozen_stages: + for param in getattr(self, f'norm{i}').parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(SwinTransformerV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'absolute_pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.absolute_pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + 'Resize the absolute_pos_embed shape from ' + f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + def _delete_reinit_params(self, state_dict, prefix, *args, **kwargs): + # delete relative_position_index since we always re-init it + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + 'Delete `relative_position_index` and `relative_coords_table` ' + 'since we always re-init these params according to the ' + '`window_size`, which might cause unwanted but unworried ' + 'warnings when loading checkpoint.') + relative_position_index_keys = [ + k for k in state_dict.keys() if 'relative_position_index' in k + ] + for k in relative_position_index_keys: + del state_dict[k] + + # delete relative_coords_table since we always re-init it + relative_position_index_keys = [ + k for k in state_dict.keys() if 'relative_coords_table' in k + ] + for k in relative_position_index_keys: + del state_dict[k] diff --git a/mmpretrain/models/backbones/t2t_vit.py b/mmpretrain/models/backbones/t2t_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..a57b95e1fb00b227c400e7b32fa612e3539503c6 --- /dev/null +++ b/mmpretrain/models/backbones/t2t_vit.py @@ -0,0 +1,447 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
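+# This file implements T2T-ViT: a Tokens-to-Token (T2T) module that
+# progressively aggregates neighboring tokens into a smaller token set,
+# followed by a ViT-style transformer encoder.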
+from copy import deepcopy +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn.bricks.transformer import FFN +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed, + to_2tuple) +from .base_backbone import BaseBackbone + + +class T2TTransformerLayer(BaseModule): + """Transformer Layer for T2T_ViT. + + Comparing with :obj:`TransformerEncoderLayer` in ViT, it supports + different ``input_dims`` and ``embed_dims``. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs + input_dims (int, optional): The input token dimension. + Defaults to None. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``(input_dims // num_heads) ** -0.5`` if set. Defaults to None. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + + Notes: + In general, ``qk_scale`` should be ``head_dims ** -0.5``, i.e. + ``(embed_dims // num_heads) ** -0.5``. However, in the official + code, it uses ``(input_dims // num_heads) ** -0.5``, so here we + keep the same with the official implementation. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + input_dims=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=False, + qk_scale=None, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg=None): + super(T2TTransformerLayer, self).__init__(init_cfg=init_cfg) + + self.v_shortcut = True if input_dims is not None else False + input_dims = input_dims or embed_dims + + self.ln1 = build_norm_layer(norm_cfg, input_dims) + + self.attn = MultiheadAttention( + input_dims=input_dims, + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias, + qk_scale=qk_scale or (input_dims // num_heads)**-0.5, + v_shortcut=self.v_shortcut) + + self.ln2 = build_norm_layer(norm_cfg, embed_dims) + + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + def forward(self, x): + if self.v_shortcut: + x = self.attn(self.ln1(x)) + else: + x = x + self.attn(self.ln1(x)) + x = self.ffn(self.ln2(x), identity=x) + return x + + +class T2TModule(BaseModule): + """Tokens-to-Token module. + + "Tokens-to-Token module" (T2T Module) can model the local structure + information of images and reduce the length of tokens progressively. 
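+
+    With the default soft splits (strides 4, 2 and 2), a 224x224 input is
+    reduced to a 14x14 grid, i.e. 196 tokens.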
+ + Args: + img_size (int): Input image size + in_channels (int): Number of input channels + embed_dims (int): Embedding dimension + token_dims (int): Tokens dimension in T2TModuleAttention. + use_performer (bool): If True, use Performer version self-attention to + adopt regular self-attention. Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Default: None. + + Notes: + Usually, ``token_dim`` is set as a small value (32 or 64) to reduce + MACs + """ + + def __init__( + self, + img_size=224, + in_channels=3, + embed_dims=384, + token_dims=64, + use_performer=False, + init_cfg=None, + ): + super(T2TModule, self).__init__(init_cfg) + + self.embed_dims = embed_dims + + self.soft_split0 = nn.Unfold( + kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) + self.soft_split1 = nn.Unfold( + kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + self.soft_split2 = nn.Unfold( + kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + + if not use_performer: + self.attention1 = T2TTransformerLayer( + input_dims=in_channels * 7 * 7, + embed_dims=token_dims, + num_heads=1, + feedforward_channels=token_dims) + + self.attention2 = T2TTransformerLayer( + input_dims=token_dims * 3 * 3, + embed_dims=token_dims, + num_heads=1, + feedforward_channels=token_dims) + + self.project = nn.Linear(token_dims * 3 * 3, embed_dims) + else: + raise NotImplementedError("Performer hasn't been implemented.") + + # there are 3 soft split, stride are 4,2,2 separately + out_side = img_size // (4 * 2 * 2) + self.init_out_size = [out_side, out_side] + self.num_patches = out_side**2 + + @staticmethod + def _get_unfold_size(unfold: nn.Unfold, input_size): + h, w = input_size + kernel_size = to_2tuple(unfold.kernel_size) + stride = to_2tuple(unfold.stride) + padding = to_2tuple(unfold.padding) + dilation = to_2tuple(unfold.dilation) + + h_out = (h + 2 * padding[0] - dilation[0] * + (kernel_size[0] - 1) - 1) // stride[0] + 1 + w_out = (w + 2 * padding[1] - dilation[1] * + (kernel_size[1] - 1) - 1) // stride[1] + 1 + return (h_out, w_out) + + def forward(self, x): + # step0: soft split + hw_shape = self._get_unfold_size(self.soft_split0, x.shape[2:]) + x = self.soft_split0(x).transpose(1, 2) + + for step in [1, 2]: + # re-structurization/reconstruction + attn = getattr(self, f'attention{step}') + x = attn(x).transpose(1, 2) + B, C, _ = x.shape + x = x.reshape(B, C, hw_shape[0], hw_shape[1]) + + # soft split + soft_split = getattr(self, f'soft_split{step}') + hw_shape = self._get_unfold_size(soft_split, hw_shape) + x = soft_split(x).transpose(1, 2) + + # final tokens + x = self.project(x) + return x, hw_shape + + +def get_sinusoid_encoding(n_position, embed_dims): + """Generate sinusoid encoding table. + + Sinusoid encoding is a kind of relative position encoding method came from + `Attention Is All You Need`_. + + Args: + n_position (int): The length of the input token. + embed_dims (int): The position embedding dimension. + + Returns: + :obj:`torch.FloatTensor`: The sinusoid encoding table. 
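+
+    Examples:
+        A minimal shape check (illustrative only):
+
+        >>> table = get_sinusoid_encoding(197, 384)
+        >>> table.shape
+        torch.Size([1, 197, 384])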
+ """ + + def get_position_angle_vec(position): + return [ + position / np.power(10000, 2 * (i // 2) / embed_dims) + for i in range(embed_dims) + ] + + sinusoid_table = np.array( + [get_position_angle_vec(pos) for pos in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +@MODELS.register_module() +class T2T_ViT(BaseBackbone): + """Tokens-to-Token Vision Transformer (T2T-ViT) + + A PyTorch implementation of `Tokens-to-Token ViT: Training Vision + Transformers from Scratch on ImageNet `_ + + Args: + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + in_channels (int): Number of input channels. + embed_dims (int): Embedding dimension. + num_layers (int): Num of transformer layers in encoder. + Defaults to 14. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Dropout rate after position embedding. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. Defaults to + ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"cls_token"``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + t2t_cfg (dict): Extra config of Tokens-to-Token module. + Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. 
+ """ + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} + + def __init__(self, + img_size=224, + in_channels=3, + embed_dims=384, + num_layers=14, + out_indices=-1, + drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + final_norm=True, + out_type='cls_token', + with_cls_token=True, + interpolate_mode='bicubic', + t2t_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None): + super().__init__(init_cfg) + + # Token-to-Token Module + self.tokens_to_token = T2TModule( + img_size=img_size, + in_channels=in_channels, + embed_dims=embed_dims, + **t2t_cfg) + self.patch_resolution = self.tokens_to_token.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + + # Set cls token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + self.num_extra_tokens = 1 + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when `out_type="cls_token"`.') + + # Set position embedding + self.interpolate_mode = interpolate_mode + sinusoid_table = get_sinusoid_encoding( + num_patches + self.num_extra_tokens, embed_dims) + self.register_buffer('pos_embed', sinusoid_table) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must be a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = num_layers + index + assert 0 <= out_indices[i] <= num_layers, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + # stochastic depth decay rule + dpr = [x for x in np.linspace(0, drop_path_rate, num_layers)] + + self.encoder = ModuleList() + for i in range(num_layers): + if isinstance(layer_cfgs, Sequence): + layer_cfg = layer_cfgs[i] + else: + layer_cfg = deepcopy(layer_cfgs) + layer_cfg = { + 'embed_dims': embed_dims, + 'num_heads': 6, + 'feedforward_channels': 3 * embed_dims, + 'drop_path_rate': dpr[i], + 'qkv_bias': False, + 'norm_cfg': norm_cfg, + **layer_cfg + } + + layer = T2TTransformerLayer(**layer_cfg) + self.encoder.append(layer) + + self.final_norm = final_norm + if final_norm: + self.norm = build_norm_layer(norm_cfg, embed_dims) + else: + self.norm = nn.Identity() + + def init_weights(self): + super().init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress custom init if use pretrained model. 
+ return + + trunc_normal_(self.cls_token, std=.02) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.tokens_to_token.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.tokens_to_token(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + outs = [] + for i, layer in enumerate(self.encoder): + x = layer(x) + + if i == len(self.encoder) - 1 and self.final_norm: + x = self.norm(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return patch_token.mean(dim=1) diff --git a/mmpretrain/models/backbones/timm_backbone.py b/mmpretrain/models/backbones/timm_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..51ecbdbb077be0643026de2ec91c0169263a41f7 --- /dev/null +++ b/mmpretrain/models/backbones/timm_backbone.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from mmengine.logging import MMLogger + +from mmpretrain.registry import MODELS +from mmpretrain.utils import require +from .base_backbone import BaseBackbone + + +def print_timm_feature_info(feature_info): + """Print feature_info of timm backbone to help development and debug. + + Args: + feature_info (list[dict] | timm.models.features.FeatureInfo | None): + feature_info of timm backbone. + """ + logger = MMLogger.get_current_instance() + if feature_info is None: + logger.warning('This backbone does not have feature_info') + elif isinstance(feature_info, list): + for feat_idx, each_info in enumerate(feature_info): + logger.info(f'backbone feature_info[{feat_idx}]: {each_info}') + else: + try: + logger.info(f'backbone out_indices: {feature_info.out_indices}') + logger.info(f'backbone out_channels: {feature_info.channels()}') + logger.info(f'backbone out_strides: {feature_info.reduction()}') + except AttributeError: + logger.warning('Unexpected format of backbone feature_info') + + +@MODELS.register_module() +class TIMMBackbone(BaseBackbone): + """Wrapper to use backbones from timm library. + + More details can be found in + `timm `_. + See especially the document for `feature extraction + `_. 
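+
+    Example (a usage sketch; assumes ``timm`` is installed and the model
+    name exists in the installed timm version):
+
+        >>> import torch
+        >>> from mmpretrain.models import TIMMBackbone
+        >>> model = TIMMBackbone(model_name='resnet18', features_only=True)
+        >>> feats = model(torch.rand(1, 3, 224, 224))
+        >>> isinstance(feats, tuple)
+        True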
+ + Args: + model_name (str): Name of timm model to instantiate. + features_only (bool): Whether to extract feature pyramid (multi-scale + feature maps from the deepest layer at each stride). For Vision + Transformer models that do not support this argument, + set this False. Defaults to False. + pretrained (bool): Whether to load pretrained weights. + Defaults to False. + checkpoint_path (str): Path of checkpoint to load at the last of + ``timm.create_model``. Defaults to empty string, which means + not loading. + in_channels (int): Number of input image channels. Defaults to 3. + init_cfg (dict or list[dict], optional): Initialization config dict of + OpenMMLab projects. Defaults to None. + **kwargs: Other timm & model specific arguments. + """ + + @require('timm') + def __init__(self, + model_name, + features_only=False, + pretrained=False, + checkpoint_path='', + in_channels=3, + init_cfg=None, + **kwargs): + import timm + + if not isinstance(pretrained, bool): + raise TypeError('pretrained must be bool, not str for model path') + if features_only and checkpoint_path: + warnings.warn( + 'Using both features_only and checkpoint_path will cause error' + ' in timm. See ' + 'https://github.com/rwightman/pytorch-image-models/issues/488') + + super(TIMMBackbone, self).__init__(init_cfg) + if 'norm_layer' in kwargs: + norm_class = MODELS.get(kwargs['norm_layer']) + + def build_norm(*args, **kwargs): + return norm_class(*args, **kwargs) + + kwargs['norm_layer'] = build_norm + self.timm_model = timm.create_model( + model_name=model_name, + features_only=features_only, + pretrained=pretrained, + in_chans=in_channels, + checkpoint_path=checkpoint_path, + **kwargs) + + # reset classifier + if hasattr(self.timm_model, 'reset_classifier'): + self.timm_model.reset_classifier(0, '') + + # Hack to use pretrained weights from timm + if pretrained or checkpoint_path: + self._is_init = True + + feature_info = getattr(self.timm_model, 'feature_info', None) + print_timm_feature_info(feature_info) + + def forward(self, x): + features = self.timm_model(x) + if isinstance(features, (list, tuple)): + features = tuple(features) + else: + features = (features, ) + return features diff --git a/mmpretrain/models/backbones/tinyvit.py b/mmpretrain/models/backbones/tinyvit.py new file mode 100644 index 0000000000000000000000000000000000000000..5279832184343a6e8ff4b253891de1b990192775 --- /dev/null +++ b/mmpretrain/models/backbones/tinyvit.py @@ -0,0 +1,769 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence, Tuple + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer +from mmengine.model import BaseModule, ModuleList, Sequential +from torch.nn import functional as F + +from mmpretrain.registry import MODELS +from ..utils import LeAttention +from .base_backbone import BaseBackbone + + +class ConvBN2d(Sequential): + """An implementation of Conv2d + BatchNorm2d with support of fusion. + + Modified from + https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + kernel_size (int): The size of the convolution kernel. + Default: 1. + stride (int): The stride of the convolution. + Default: 1. + padding (int): The padding of the convolution. + Default: 0. + dilation (int): The dilation of the convolution. + Default: 1. 
+        groups (int): The number of groups in the convolution.
+            Default: 1.
+        bn_weight_init (float): The initial value of the weight of
+            the nn.BatchNorm2d layer. Default: 1.0.
+        init_cfg (dict): The initialization config of the module.
+            Default: None.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bn_weight_init=1.0,
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
+        self.add_module(
+            'conv2d',
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+                bias=False))
+        bn2d = nn.BatchNorm2d(num_features=out_channels)
+        # bn initialization
+        torch.nn.init.constant_(bn2d.weight, bn_weight_init)
+        torch.nn.init.constant_(bn2d.bias, 0)
+        self.add_module('bn2d', bn2d)
+
+    @torch.no_grad()
+    def fuse(self):
+        # Fold the BatchNorm statistics into the convolution weights so the
+        # fused module computes the same function with a single Conv2d.
+        conv2d, bn2d = self._modules.values()
+        w = bn2d.weight / (bn2d.running_var + bn2d.eps)**0.5
+        w = conv2d.weight * w[:, None, None, None]
+        b = bn2d.bias - bn2d.running_mean * bn2d.weight / \
+            (bn2d.running_var + bn2d.eps)**0.5
+
+        m = nn.Conv2d(
+            in_channels=w.size(1) * self.conv2d.groups,
+            out_channels=w.size(0),
+            kernel_size=w.shape[2:],
+            stride=self.conv2d.stride,
+            padding=self.conv2d.padding,
+            dilation=self.conv2d.dilation,
+            groups=self.conv2d.groups)
+        m.weight.data.copy_(w)
+        m.bias.data.copy_(b)
+        return m
+
+
+class PatchEmbed(BaseModule):
+    """Patch Embedding for Vision Transformer.
+
+    Adapted from
+    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py
+
+    Different from `mmcv.cnn.bricks.transformer.PatchEmbed`, this module uses
+    Conv2d and BatchNorm2d to implement PatchEmbedding, and the output shape
+    is (N, C, H, W).
+
+    Args:
+        in_channels (int): The number of input channels.
+        embed_dim (int): The embedding dimension.
+        resolution (Tuple[int, int]): The resolution of the input feature.
+        act_cfg (dict): The activation config of the module.
+            Default: dict(type='GELU').
+    """
+
+    def __init__(self,
+                 in_channels,
+                 embed_dim,
+                 resolution,
+                 act_cfg=dict(type='GELU')):
+        super().__init__()
+        img_size: Tuple[int, int] = resolution
+        self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
+        self.num_patches = self.patches_resolution[0] * \
+            self.patches_resolution[1]
+        self.in_channels = in_channels
+        self.embed_dim = embed_dim
+        self.seq = nn.Sequential(
+            ConvBN2d(
+                in_channels,
+                embed_dim // 2,
+                kernel_size=3,
+                stride=2,
+                padding=1),
+            build_activation_layer(act_cfg),
+            ConvBN2d(
+                embed_dim // 2, embed_dim, kernel_size=3, stride=2,
+                padding=1),
+        )
+
+    def forward(self, x):
+        return self.seq(x)
+
+
+class PatchMerging(nn.Module):
+    """Patch Merging for TinyViT.
+
+    Adapted from
+    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py
+
+    Different from `mmpretrain.models.utils.PatchMerging`, this module uses
+    Conv2d and BatchNorm2d to implement PatchMerging.
+
+    Args:
+        in_channels (int): The number of input channels.
+        resolution (Tuple[int, int]): The resolution of the input feature.
+        out_channels (int): The number of output channels.
+        act_cfg (dict): The activation config of the module.
+            Default: dict(type='GELU').
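+
+    Notes:
+        A 3-dim input (B, L, C) is first reshaped to (B, C, H, W); the 2x
+        downsampling is then a stride-2 depthwise convolution rather than
+        the slicing-based merging used in Swin.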
+ """ + + def __init__(self, + resolution, + in_channels, + out_channels, + act_cfg=dict(type='GELU')): + super().__init__() + + self.img_size = resolution + + self.act = build_activation_layer(act_cfg) + self.conv1 = ConvBN2d(in_channels, out_channels, kernel_size=1) + self.conv2 = ConvBN2d( + out_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + groups=out_channels) + self.conv3 = ConvBN2d(out_channels, out_channels, kernel_size=1) + self.out_resolution = (resolution[0] // 2, resolution[1] // 2) + + def forward(self, x): + if len(x.shape) == 3: + H, W = self.img_size + B = x.shape[0] + x = x.view(B, H, W, -1).permute(0, 3, 1, 2) + x = self.conv1(x) + x = self.act(x) + x = self.conv2(x) + x = self.act(x) + x = self.conv3(x) + + x = x.flatten(2).transpose(1, 2) + return x + + +class MBConvBlock(nn.Module): + """Mobile Inverted Residual Bottleneck Block for TinyViT. Adapted from + https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + expand_ratio (int): The expand ratio of the hidden channels. + drop_rate (float): The drop rate of the block. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + """ + + def __init__(self, + in_channels, + out_channels, + expand_ratio, + drop_path, + act_cfg=dict(type='GELU')): + super().__init__() + self.in_channels = in_channels + hidden_channels = int(in_channels * expand_ratio) + + # linear + self.conv1 = ConvBN2d(in_channels, hidden_channels, kernel_size=1) + self.act = build_activation_layer(act_cfg) + # depthwise conv + self.conv2 = ConvBN2d( + in_channels=hidden_channels, + out_channels=hidden_channels, + kernel_size=3, + stride=1, + padding=1, + groups=hidden_channels) + # linear + self.conv3 = ConvBN2d( + hidden_channels, out_channels, kernel_size=1, bn_weight_init=0.0) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.act(x) + + x = self.conv2(x) + x = self.act(x) + + x = self.conv3(x) + + x = self.drop_path(x) + + x += shortcut + x = self.act(x) + + return x + + +class ConvStage(BaseModule): + """Convolution Stage for TinyViT. + + Adapted from + https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py + + Args: + in_channels (int): The number of input channels. + resolution (Tuple[int, int]): The resolution of the input feature. + depth (int): The number of blocks in the stage. + act_cfg (dict): The activation config of the module. + drop_path (float): The drop path of the block. + downsample (None | nn.Module): The downsample operation. + Default: None. + use_checkpoint (bool): Whether to use checkpointing to save memory. + out_channels (int): The number of output channels. + conv_expand_ratio (int): The expand ratio of the hidden channels. + Default: 4. + init_cfg (dict | list[dict], optional): Initialization config dict. + Default: None. 
+ """ + + def __init__(self, + in_channels, + resolution, + depth, + act_cfg, + drop_path=0., + downsample=None, + use_checkpoint=False, + out_channels=None, + conv_expand_ratio=4., + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.use_checkpoint = use_checkpoint + # build blocks + self.blocks = ModuleList([ + MBConvBlock( + in_channels=in_channels, + out_channels=in_channels, + expand_ratio=conv_expand_ratio, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path) + for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + resolution=resolution, + in_channels=in_channels, + out_channels=out_channels, + act_cfg=act_cfg) + self.resolution = self.downsample.out_resolution + else: + self.downsample = None + self.resolution = resolution + + def forward(self, x): + for block in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(block, x) + else: + x = block(x) + + if self.downsample is not None: + x = self.downsample(x) + return x + + +class MLP(BaseModule): + """MLP module for TinyViT. + + Args: + in_channels (int): The number of input channels. + hidden_channels (int, optional): The number of hidden channels. + Default: None. + out_channels (int, optional): The number of output channels. + Default: None. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + drop (float): Probability of an element to be zeroed. + Default: 0. + init_cfg (dict | list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + hidden_channels=None, + out_channels=None, + act_cfg=dict(type='GELU'), + drop=0., + init_cfg=None): + super().__init__(init_cfg=init_cfg) + out_channels = out_channels or in_channels + hidden_channels = hidden_channels or in_channels + self.norm = nn.LayerNorm(in_channels) + self.fc1 = nn.Linear(in_channels, hidden_channels) + self.fc2 = nn.Linear(hidden_channels, out_channels) + self.act = build_activation_layer(act_cfg) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.norm(x) + + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class TinyViTBlock(BaseModule): + """TinViT Block. + + Args: + in_channels (int): The number of input channels. + resolution (Tuple[int, int]): The resolution of the input feature. + num_heads (int): The number of heads in the multi-head attention. + window_size (int): The size of the window. + Default: 7. + mlp_ratio (float): The ratio of mlp hidden dim to embedding dim. + Default: 4. + drop (float): Probability of an element to be zeroed. + Default: 0. + drop_path (float): The drop path of the block. + Default: 0. + local_conv_size (int): The size of the local convolution. + Default: 3. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + """ + + def __init__(self, + in_channels, + resolution, + num_heads, + window_size=7, + mlp_ratio=4., + drop=0., + drop_path=0., + local_conv_size=3, + act_cfg=dict(type='GELU')): + super().__init__() + self.in_channels = in_channels + self.img_size = resolution + self.num_heads = num_heads + assert window_size > 0, 'window_size must be greater than 0' + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + + assert in_channels % num_heads == 0, \ + 'dim must be divisible by num_heads' + head_dim = in_channels // num_heads + + window_resolution = (window_size, window_size) + self.attn = LeAttention( + in_channels, + head_dim, + num_heads, + attn_ratio=1, + resolution=window_resolution) + + mlp_hidden_dim = int(in_channels * mlp_ratio) + self.mlp = MLP( + in_channels=in_channels, + hidden_channels=mlp_hidden_dim, + act_cfg=act_cfg, + drop=drop) + + self.local_conv = ConvBN2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=local_conv_size, + stride=1, + padding=local_conv_size // 2, + groups=in_channels) + + def forward(self, x): + H, W = self.img_size + B, L, C = x.shape + assert L == H * W, 'input feature has wrong size' + res_x = x + if H == self.window_size and W == self.window_size: + x = self.attn(x) + else: + x = x.view(B, H, W, C) + pad_b = (self.window_size - + H % self.window_size) % self.window_size + pad_r = (self.window_size - + W % self.window_size) % self.window_size + padding = pad_b > 0 or pad_r > 0 + + if padding: + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_size + nW = pW // self.window_size + # window partition + x = x.view(B, nH, self.window_size, nW, self.window_size, + C).transpose(2, 3).reshape( + B * nH * nW, self.window_size * self.window_size, C) + x = self.attn(x) + # window reverse + x = x.view(B, nH, nW, self.window_size, self.window_size, + C).transpose(2, 3).reshape(B, pH, pW, C) + + if padding: + x = x[:, :H, :W].contiguous() + + x = x.view(B, L, C) + + x = res_x + self.drop_path(x) + + x = x.transpose(1, 2).reshape(B, C, H, W) + x = self.local_conv(x) + x = x.view(B, C, L).transpose(1, 2) + + x = x + self.drop_path(self.mlp(x)) + return x + + +class BasicStage(BaseModule): + """Basic Stage for TinyViT. + + Args: + in_channels (int): The number of input channels. + resolution (Tuple[int, int]): The resolution of the input feature. + depth (int): The number of blocks in the stage. + num_heads (int): The number of heads in the multi-head attention. + window_size (int): The size of the window. + mlp_ratio (float): The ratio of mlp hidden dim to embedding dim. + Default: 4. + drop (float): Probability of an element to be zeroed. + Default: 0. + drop_path (float): The drop path of the block. + Default: 0. + downsample (None | nn.Module): The downsample operation. + Default: None. + use_checkpoint (bool): Whether to use checkpointing to save memory. + Default: False. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + init_cfg (dict | list[dict], optional): Initialization config dict. + Default: None. 
+ """ + + def __init__(self, + in_channels, + resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + drop=0., + drop_path=0., + downsample=None, + use_checkpoint=False, + local_conv_size=3, + out_channels=None, + act_cfg=dict(type='GELU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.use_checkpoint = use_checkpoint + # build blocks + self.blocks = ModuleList([ + TinyViTBlock( + in_channels=in_channels, + resolution=resolution, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + local_conv_size=local_conv_size, + act_cfg=act_cfg, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path) + for i in range(depth) + ]) + + # build patch merging layer + if downsample is not None: + self.downsample = downsample( + resolution=resolution, + in_channels=in_channels, + out_channels=out_channels, + act_cfg=act_cfg) + self.resolution = self.downsample.out_resolution + else: + self.downsample = None + self.resolution = resolution + + def forward(self, x): + for block in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(block, x) + else: + x = block(x) + + if self.downsample is not None: + x = self.downsample(x) + return x + + +@MODELS.register_module() +class TinyViT(BaseBackbone): + """TinyViT. + A PyTorch implementation of : `TinyViT: Fast Pretraining Distillation + for Small Vision Transformers`_ + + Inspiration from + https://github.com/microsoft/Cream/blob/main/TinyViT + + Args: + arch (str | dict): The architecture of TinyViT. + Default: '5m'. + img_size (tuple | int): The resolution of the input image. + Default: (224, 224) + window_size (list): The size of the window. + Default: [7, 7, 14, 7] + in_channels (int): The number of input channels. + Default: 3. + depths (list[int]): The depth of each stage. + Default: [2, 2, 6, 2]. + mlp_ratio (list[int]): The ratio of mlp hidden dim to embedding dim. + Default: 4. + drop_rate (float): Probability of an element to be zeroed. + Default: 0. + drop_path_rate (float): The drop path of the block. + Default: 0.1. + use_checkpoint (bool): Whether to use checkpointing to save memory. + Default: False. + mbconv_expand_ratio (int): The expand ratio of the mbconv. + Default: 4.0 + local_conv_size (int): The size of the local conv. + Default: 3. + layer_lr_decay (float): The layer lr decay. + Default: 1.0 + out_indices (int | list[int]): Output from which stages. + Default: -1 + frozen_stages (int | list[int]): Stages to be frozen (all param fixed). + Default: -0 + gap_before_final_nrom (bool): Whether to add a gap before the final + norm. Default: True. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (dict | list[dict], optional): Initialization config dict. + Default: None. 
+ """ + arch_settings = { + '5m': { + 'channels': [64, 128, 160, 320], + 'num_heads': [2, 4, 5, 10], + 'depths': [2, 2, 6, 2], + }, + '11m': { + 'channels': [64, 128, 256, 448], + 'num_heads': [2, 4, 8, 14], + 'depths': [2, 2, 6, 2], + }, + '21m': { + 'channels': [96, 192, 384, 576], + 'num_heads': [3, 6, 12, 18], + 'depths': [2, 2, 6, 2], + }, + } + + def __init__(self, + arch='5m', + img_size=(224, 224), + window_size=[7, 7, 14, 7], + in_channels=3, + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.1, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=1.0, + out_indices=-1, + frozen_stages=0, + gap_before_final_norm=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavaiable arch, please choose from ' \ + f'({set(self.arch_settings)} or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + assert 'channels' in arch and 'num_heads' in arch and \ + 'depths' in arch, 'The arch dict must have' \ + f'"channels", "num_heads", "window_sizes" ' \ + f'keys, but got {arch.keys()}' + + self.channels = arch['channels'] + self.num_heads = arch['num_heads'] + self.widow_sizes = window_size + self.img_size = img_size + self.depths = arch['depths'] + + self.num_stages = len(self.channels) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 4 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.frozen_stages = frozen_stages + self.gap_before_final_norm = gap_before_final_norm + self.layer_lr_decay = layer_lr_decay + + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dim=self.channels[0], + resolution=self.img_size, + act_cfg=dict(type='GELU')) + patches_resolution = self.patch_embed.patches_resolution + + # stochastic depth decay rule + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + + # build stages + self.stages = ModuleList() + for i in range(self.num_stages): + depth = self.depths[i] + channel = self.channels[i] + curr_resolution = (patches_resolution[0] // (2**i), + patches_resolution[1] // (2**i)) + drop_path = dpr[sum(self.depths[:i]):sum(self.depths[:i + 1])] + downsample = PatchMerging if (i < self.num_stages - 1) else None + out_channels = self.channels[min(i + 1, self.num_stages - 1)] + if i >= 1: + stage = BasicStage( + in_channels=channel, + resolution=curr_resolution, + depth=depth, + num_heads=self.num_heads[i], + window_size=self.widow_sizes[i], + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=drop_path, + downsample=downsample, + use_checkpoint=use_checkpoint, + local_conv_size=local_conv_size, + out_channels=out_channels, + act_cfg=act_cfg) + else: + stage = ConvStage( + in_channels=channel, + resolution=curr_resolution, + depth=depth, + act_cfg=act_cfg, + drop_path=drop_path, + downsample=downsample, + use_checkpoint=use_checkpoint, + out_channels=out_channels, + conv_expand_ratio=mbconv_expand_ratio) + self.stages.append(stage) + + # add output norm + if i in self.out_indices: + norm_layer = build_norm_layer(norm_cfg, out_channels)[1] + self.add_module(f'norm{i}', norm_layer) + + def set_layer_lr_decay(self, layer_lr_decay): + # TODO: add 
layer_lr_decay + pass + + def forward(self, x): + outs = [] + x = self.patch_embed(x) + + for i, stage in enumerate(self.stages): + x = stage(x) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + if self.gap_before_final_norm: + gap = x.mean(1) + outs.append(norm_layer(gap)) + else: + out = norm_layer(x) + # convert the (B,L,C) format into (B,C,H,W) format + # which would be better for the downstream tasks. + B, L, C = out.shape + out = out.view(B, *stage.resolution, C) + outs.append(out.permute(0, 3, 1, 2)) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + stage = self.stages[i] + stage.eval() + for param in stage.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(TinyViT, self).train(mode) + self._freeze_stages() diff --git a/mmpretrain/models/backbones/tnt.py b/mmpretrain/models/backbones/tnt.py new file mode 100644 index 0000000000000000000000000000000000000000..e1b241c1f6bc398157793748b7a457f0836daedb --- /dev/null +++ b/mmpretrain/models/backbones/tnt.py @@ -0,0 +1,368 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from ..utils import to_2tuple +from .base_backbone import BaseBackbone + + +class TransformerBlock(BaseModule): + """Implement a transformer block in TnTLayer. + + Args: + embed_dims (int): The feature dimension + num_heads (int): Parallel attention heads + ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer. + Default: 4 + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default 0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0. + drop_path_rate (float): stochastic depth rate. Default 0. + num_fcs (int): The number of fully-connected layers for FFNs. Default 2 + qkv_bias (bool): Enable bias for qkv if True. Default False + act_cfg (dict): The activation config for FFNs. Defaults to GELU. + norm_cfg (dict): Config dict for normalization layer. Default + layer normalization + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) or (n, batch, embed_dim). + (batch, n, embed_dim) is common case in CV. Defaults to False + init_cfg (dict, optional): Initialization config dict. 
Defaults to None + """ + + def __init__(self, + embed_dims, + num_heads, + ffn_ratio=4, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=False, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + batch_first=True, + init_cfg=None): + super(TransformerBlock, self).__init__(init_cfg=init_cfg) + + self.norm_attn = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = MultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + batch_first=batch_first) + + self.norm_ffn = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=embed_dims * ffn_ratio, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + if not qkv_bias: + self.attn.attn.in_proj_bias = None + + def forward(self, x): + x = self.attn(self.norm_attn(x), identity=x) + x = self.ffn(self.norm_ffn(x), identity=x) + return x + + +class TnTLayer(BaseModule): + """Implement one encoder layer in Transformer in Transformer. + + Args: + num_pixel (int): The pixel number in target patch transformed with + a linear projection in inner transformer + embed_dims_inner (int): Feature dimension in inner transformer block + embed_dims_outer (int): Feature dimension in outer transformer block + num_heads_inner (int): Parallel attention heads in inner transformer. + num_heads_outer (int): Parallel attention heads in outer transformer. + inner_block_cfg (dict): Extra config of inner transformer block. + Defaults to empty dict. + outer_block_cfg (dict): Extra config of outer transformer block. + Defaults to empty dict. + norm_cfg (dict): Config dict for normalization layer. Default + layer normalization + init_cfg (dict, optional): Initialization config dict. Defaults to None + """ + + def __init__(self, + num_pixel, + embed_dims_inner, + embed_dims_outer, + num_heads_inner, + num_heads_outer, + inner_block_cfg=dict(), + outer_block_cfg=dict(), + norm_cfg=dict(type='LN'), + init_cfg=None): + super(TnTLayer, self).__init__(init_cfg=init_cfg) + + self.inner_block = TransformerBlock( + embed_dims=embed_dims_inner, + num_heads=num_heads_inner, + **inner_block_cfg) + + self.norm_proj = build_norm_layer(norm_cfg, embed_dims_inner)[1] + self.projection = nn.Linear( + embed_dims_inner * num_pixel, embed_dims_outer, bias=True) + + self.outer_block = TransformerBlock( + embed_dims=embed_dims_outer, + num_heads=num_heads_outer, + **outer_block_cfg) + + def forward(self, pixel_embed, patch_embed): + pixel_embed = self.inner_block(pixel_embed) + + B, N, C = patch_embed.size() + patch_embed[:, 1:] = patch_embed[:, 1:] + self.projection( + self.norm_proj(pixel_embed).reshape(B, N - 1, -1)) + patch_embed = self.outer_block(patch_embed) + + return pixel_embed, patch_embed + + +class PixelEmbed(BaseModule): + """Image to Pixel Embedding. + + Args: + img_size (int | tuple): The size of input image + patch_size (int): The size of one patch + in_channels (int): The num of input channels + embed_dims_inner (int): The num of channels of the target patch + transformed with a linear projection in inner transformer + stride (int): The stride of the conv2d layer. We use a conv2d layer + and a unfold layer to implement image to pixel embedding. 
+ init_cfg (dict, optional): Initialization config dict + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dims_inner=48, + stride=4, + init_cfg=None): + super(PixelEmbed, self).__init__(init_cfg=init_cfg) + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + # patches_resolution property necessary for resizing + # positional embedding + patches_resolution = [ + img_size[0] // patch_size[0], img_size[1] // patch_size[1] + ] + num_patches = patches_resolution[0] * patches_resolution[1] + + self.img_size = img_size + self.num_patches = num_patches + self.embed_dims_inner = embed_dims_inner + + new_patch_size = [math.ceil(ps / stride) for ps in patch_size] + self.new_patch_size = new_patch_size + + self.proj = nn.Conv2d( + in_channels, + self.embed_dims_inner, + kernel_size=7, + padding=3, + stride=stride) + self.unfold = nn.Unfold( + kernel_size=new_patch_size, stride=new_patch_size) + + def forward(self, x, pixel_pos): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model " \ + f'({self.img_size[0]}*{self.img_size[1]}).' + x = self.proj(x) + x = self.unfold(x) + x = x.transpose(1, + 2).reshape(B * self.num_patches, self.embed_dims_inner, + self.new_patch_size[0], + self.new_patch_size[1]) + x = x + pixel_pos + x = x.reshape(B * self.num_patches, self.embed_dims_inner, + -1).transpose(1, 2) + return x + + +@MODELS.register_module() +class TNT(BaseBackbone): + """Transformer in Transformer. + + A PyTorch implement of: `Transformer in Transformer + `_ + + Inspiration from + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/tnt.py + + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size. Defaults to 224 + patch_size (int | tuple): The patch size. Deault to 16 + in_channels (int): Number of input channels. Defaults to 3 + ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer. + Default: 4 + qkv_bias (bool): Enable bias for qkv if True. Default False + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default 0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0. + drop_path_rate (float): stochastic depth rate. Default 0. + act_cfg (dict): The activation config for FFNs. Defaults to GELU. + norm_cfg (dict): Config dict for normalization layer. Default + layer normalization + first_stride (int): The stride of the conv2d layer. We use a conv2d + layer and a unfold layer to implement image to pixel embedding. + num_fcs (int): The number of fully-connected layers for FFNs. Default 2 + init_cfg (dict, optional): Initialization config dict + """ + arch_zoo = { + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims_outer': 384, + 'embed_dims_inner': 24, + 'num_layers': 12, + 'num_heads_outer': 6, + 'num_heads_inner': 4 + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims_outer': 640, + 'embed_dims_inner': 40, + 'num_layers': 12, + 'num_heads_outer': 10, + 'num_heads_inner': 4 + }) + } + + def __init__(self, + arch='b', + img_size=224, + patch_size=16, + in_channels=3, + ffn_ratio=4, + qkv_bias=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + first_stride=4, + num_fcs=2, + init_cfg=[ + dict(type='TruncNormal', layer='Linear', std=.02), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.) 
+ ]): + super(TNT, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims_outer', 'embed_dims_inner', 'num_layers', + 'num_heads_inner', 'num_heads_outer' + } + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims_inner = self.arch_settings['embed_dims_inner'] + self.embed_dims_outer = self.arch_settings['embed_dims_outer'] + # embed_dims for consistency with other models + self.embed_dims = self.embed_dims_outer + self.num_layers = self.arch_settings['num_layers'] + self.num_heads_inner = self.arch_settings['num_heads_inner'] + self.num_heads_outer = self.arch_settings['num_heads_outer'] + + self.pixel_embed = PixelEmbed( + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dims_inner=self.embed_dims_inner, + stride=first_stride) + num_patches = self.pixel_embed.num_patches + self.num_patches = num_patches + new_patch_size = self.pixel_embed.new_patch_size + num_pixel = new_patch_size[0] * new_patch_size[1] + + self.norm1_proj = build_norm_layer(norm_cfg, num_pixel * + self.embed_dims_inner)[1] + self.projection = nn.Linear(num_pixel * self.embed_dims_inner, + self.embed_dims_outer) + self.norm2_proj = build_norm_layer(norm_cfg, self.embed_dims_outer)[1] + + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims_outer)) + self.patch_pos = nn.Parameter( + torch.zeros(1, num_patches + 1, self.embed_dims_outer)) + self.pixel_pos = nn.Parameter( + torch.zeros(1, self.embed_dims_inner, new_patch_size[0], + new_patch_size[1])) + self.drop_after_pos = nn.Dropout(p=drop_rate) + + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, self.num_layers) + ] # stochastic depth decay rule + self.layers = ModuleList() + for i in range(self.num_layers): + block_cfg = dict( + ffn_ratio=ffn_ratio, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[i], + num_fcs=num_fcs, + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + batch_first=True) + self.layers.append( + TnTLayer( + num_pixel=num_pixel, + embed_dims_inner=self.embed_dims_inner, + embed_dims_outer=self.embed_dims_outer, + num_heads_inner=self.num_heads_inner, + num_heads_outer=self.num_heads_outer, + inner_block_cfg=block_cfg, + outer_block_cfg=block_cfg, + norm_cfg=norm_cfg)) + + self.norm = build_norm_layer(norm_cfg, self.embed_dims_outer)[1] + + trunc_normal_(self.cls_token, std=.02) + trunc_normal_(self.patch_pos, std=.02) + trunc_normal_(self.pixel_pos, std=.02) + + def forward(self, x): + B = x.shape[0] + pixel_embed = self.pixel_embed(x, self.pixel_pos) + + patch_embed = self.norm2_proj( + self.projection( + self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1)))) + patch_embed = torch.cat( + (self.cls_token.expand(B, -1, -1), patch_embed), dim=1) + patch_embed = patch_embed + self.patch_pos + patch_embed = self.drop_after_pos(patch_embed) + + for layer in self.layers: + pixel_embed, patch_embed = layer(pixel_embed, patch_embed) + + patch_embed = self.norm(patch_embed) + return (patch_embed[:, 0], ) diff --git a/mmpretrain/models/backbones/twins.py b/mmpretrain/models/backbones/twins.py new file mode 100644 index 0000000000000000000000000000000000000000..be55c02db1daa5cb37760f2066448b3fca2cb893 --- /dev/null +++ 
b/mmpretrain/models/backbones/twins.py @@ -0,0 +1,721 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d, build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import (constant_init, normal_init, + trunc_normal_init) +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.registry import MODELS +from ..utils import ConditionalPositionEncoding, MultiheadAttention + + +class GlobalSubsampledAttention(MultiheadAttention): + """Global Sub-sampled Attention (GSA) module. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + input_dims (int, optional): The input dimension, and if None, + use ``embed_dims``. Defaults to None. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + dropout_layer (dict): The dropout config before adding the shortcut. + Defaults to ``dict(type='Dropout', drop_prob=0.)``. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + proj_bias (bool) If True, add a learnable bias to output projection. + Defaults to True. + v_shortcut (bool): Add a shortcut from value to output. It's usually + used if ``input_dims`` is different from ``embed_dims``. + Defaults to False. + sr_ratio (float): The ratio of spatial reduction in attention modules. + Defaults to 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + norm_cfg=dict(type='LN'), + qkv_bias=True, + sr_ratio=1, + **kwargs): + super(GlobalSubsampledAttention, + self).__init__(embed_dims, num_heads, **kwargs) + + self.qkv_bias = qkv_bias + self.q = nn.Linear(self.input_dims, embed_dims, bias=qkv_bias) + self.kv = nn.Linear(self.input_dims, embed_dims * 2, bias=qkv_bias) + + # remove self.qkv, here split into self.q, self.kv + delattr(self, 'qkv') + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + # use a conv as the spatial-reduction operation, the kernel_size + # and stride in conv are equal to the sr_ratio. + self.sr = Conv2d( + in_channels=embed_dims, + out_channels=embed_dims, + kernel_size=sr_ratio, + stride=sr_ratio) + # The ret[0] of build_norm_layer is norm name. + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + + def forward(self, x, hw_shape): + B, N, C = x.shape + H, W = hw_shape + assert H * W == N, 'The product of h and w of hw_shape must be N, ' \ + 'which is the 2nd dim number of the input Tensor x.' + + q = self.q(x).reshape(B, N, self.num_heads, + C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr_ratio > 1: + x = x.permute(0, 2, 1).reshape(B, C, *hw_shape) # BNC_2_BCHW + x = self.sr(x) + x = x.reshape(B, C, -1).permute(0, 2, 1) # BCHW_2_BNC + x = self.norm(x) + + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, + self.head_dims).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn_drop = self.attn_drop if self.training else 0. 
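+        # q keeps all N = H*W tokens, while k/v are computed from the
+        # spatially reduced map, so their length is N // sr_ratio**2
+        # when sr_ratio > 1.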
+ x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop) + x = x.transpose(1, 2).reshape(B, N, self.embed_dims) + + x = self.proj(x) + x = self.out_drop(self.proj_drop(x)) + + if self.v_shortcut: + x = v.squeeze(1) + x + return x + + +class GSAEncoderLayer(BaseModule): + """Implements one encoder layer with GlobalSubsampledAttention(GSA). + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (float): The ratio of spatial reduction in attention modules. + Defaults to 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + sr_ratio=1., + init_cfg=None): + super(GSAEncoderLayer, self).__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] + self.attn = GlobalSubsampledAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=False) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate) + ) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x, hw_shape): + x = x + self.drop_path(self.attn(self.norm1(x), hw_shape)) + x = x + self.drop_path(self.ffn(self.norm2(x))) + return x + + +class LocallyGroupedSelfAttention(BaseModule): + """Locally-grouped Self Attention (LSA) module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. Default: 8 + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: False. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + window_size(int): Window size of LSA. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. 
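+
+    Example (an illustrative sketch; a 56x56 map divides evenly by the
+    window size, so no padding or masking is triggered):
+        >>> import torch
+        >>> attn = LocallyGroupedSelfAttention(
+        ...     embed_dims=64, num_heads=4, window_size=7)
+        >>> x = torch.rand(1, 56 * 56, 64)
+        >>> attn(x, (56, 56)).shape
+        torch.Size([1, 3136, 64])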
+ """ + + def __init__(self, + embed_dims, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + window_size=1, + init_cfg=None): + super(LocallyGroupedSelfAttention, self).__init__(init_cfg=init_cfg) + + assert embed_dims % num_heads == 0, \ + f'dim {embed_dims} should be divided by num_heads {num_heads}' + + self.embed_dims = embed_dims + self.num_heads = num_heads + head_dim = embed_dims // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + self.window_size = window_size + + def forward(self, x, hw_shape): + B, N, C = x.shape + H, W = hw_shape + x = x.view(B, H, W, C) + + # pad feature maps to multiples of Local-groups + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + # calculate attention mask for LSA + Hp, Wp = x.shape[1:-1] + _h, _w = Hp // self.window_size, Wp // self.window_size + mask = torch.zeros((1, Hp, Wp), device=x.device) + mask[:, -pad_b:, :].fill_(1) + mask[:, :, -pad_r:].fill_(1) + + # [B, _h, _w, window_size, window_size, C] + x = x.reshape(B, _h, self.window_size, _w, self.window_size, + C).transpose(2, 3) + mask = mask.reshape(1, _h, self.window_size, _w, + self.window_size).transpose(2, 3).reshape( + 1, _h * _w, + self.window_size * self.window_size) + # [1, _h*_w, window_size*window_size, window_size*window_size] + attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-1000.0)).masked_fill( + attn_mask == 0, float(0.0)) + + # [3, B, _w*_h, nhead, window_size*window_size, dim] + qkv = self.qkv(x).reshape(B, _h * _w, + self.window_size * self.window_size, 3, + self.num_heads, C // self.num_heads).permute( + 3, 0, 1, 4, 2, 5) + q, k, v = qkv[0], qkv[1], qkv[2] + # [B, _h*_w, n_head, window_size*window_size, window_size*window_size] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn + attn_mask.unsqueeze(2) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.window_size, + self.window_size, C) + x = attn.transpose(2, 3).reshape(B, _h * self.window_size, + _w * self.window_size, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LSAEncoderLayer(BaseModule): + """Implements one encoder layer with LocallyGroupedSelfAttention(LSA). + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. 
+ Default: dict(type='LN'). + window_size (int): Window size of LSA. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + qk_scale=None, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + window_size=1, + init_cfg=None): + + super(LSAEncoderLayer, self).__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] + self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads, + qkv_bias, qk_scale, + attn_drop_rate, drop_rate, + window_size) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=False) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate) + ) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x, hw_shape): + x = x + self.drop_path(self.attn(self.norm1(x), hw_shape)) + x = x + self.drop_path(self.ffn(self.norm2(x))) + return x + + +@MODELS.register_module() +class PCPVT(BaseModule): + """The backbone of Twins-PCPVT. + + This backbone is the implementation of `Twins: Revisiting the Design + of Spatial Attention in Vision Transformers + `_. + + Args: + arch (dict, str): PCPVT architecture, a str value in arch zoo or a + detailed configuration dict with 7 keys, and the length of all the + values in dict should be the same: + + - depths (List[int]): The number of encoder layers in each stage. + - embed_dims (List[int]): Embedding dimension in each stage. + - patch_sizes (List[int]): The patch sizes in each stage. + - num_heads (List[int]): Numbers of attention head in each stage. + - strides (List[int]): The strides in each stage. + - mlp_ratios (List[int]): The ratios of mlp in each stage. + - sr_ratios (List[int]): The ratios of GSA-encoder layers in each + stage. + + in_channels (int): Number of input channels. Defaults to 3. + out_indices (tuple[int]): Output from which stages. + Defaults to ``(3, )``. + qkv_bias (bool): Enable bias for qkv if True. Defaults to False. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + attn_drop_rate (float): The drop out rate for attention layer. + Defaults to 0.0 + drop_path_rate (float): Stochastic depth rate. Defaults to 0.0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + norm_after_stage(bool, List[bool]): Add extra norm after each stage. + Defaults to False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. 
+ + Examples: + >>> from mmpretrain.models import PCPVT + >>> import torch + >>> pcpvt_cfg = {'arch': "small", + >>> 'norm_after_stage': [False, False, False, True]} + >>> model = PCPVT(**pcpvt_cfg) + >>> x = torch.rand(1, 3, 224, 224) + >>> outputs = model(x) + >>> print(outputs[-1].shape) + torch.Size([1, 512, 7, 7]) + >>> pcpvt_cfg['norm_after_stage'] = [True, True, True, True] + >>> pcpvt_cfg['out_indices'] = (0, 1, 2, 3) + >>> model = PCPVT(**pcpvt_cfg) + >>> outputs = model(x) + >>> for feat in outputs: + >>> print(feat.shape) + torch.Size([1, 64, 56, 56]) + torch.Size([1, 128, 28, 28]) + torch.Size([1, 320, 14, 14]) + torch.Size([1, 512, 7, 7]) + """ + arch_zoo = { + **dict.fromkeys(['s', 'small'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [3, 4, 6, 3], + 'num_heads': [1, 2, 5, 8], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [8, 8, 4, 4], + 'sr_ratios': [8, 4, 2, 1]}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [3, 4, 18, 3], + 'num_heads': [1, 2, 5, 8], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [8, 8, 4, 4], + 'sr_ratios': [8, 4, 2, 1]}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [3, 8, 27, 3], + 'num_heads': [1, 2, 5, 8], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [8, 8, 4, 4], + 'sr_ratios': [8, 4, 2, 1]}), + } # yapf: disable + + essential_keys = { + 'embed_dims', 'depths', 'num_heads', 'patch_sizes', 'strides', + 'mlp_ratios', 'sr_ratios' + } + + def __init__(self, + arch, + in_channels=3, + out_indices=(3, ), + qkv_bias=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + norm_after_stage=False, + init_cfg=None): + super(PCPVT, self).__init__(init_cfg=init_cfg) + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + assert isinstance(arch, dict) and ( + set(arch) == self.essential_keys + ), f'Custom arch needs a dict with keys {self.essential_keys}.' 
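+            # Illustrative custom arch (values mirror the 'small' preset in
+            # ``arch_zoo``); every list needs one entry per stage:
+            #     arch = dict(embed_dims=[64, 128, 320, 512],
+            #                 depths=[3, 4, 6, 3],
+            #                 num_heads=[1, 2, 5, 8],
+            #                 patch_sizes=[4, 2, 2, 2],
+            #                 strides=[4, 2, 2, 2],
+            #                 mlp_ratios=[8, 8, 4, 4],
+            #                 sr_ratios=[8, 4, 2, 1])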
+ self.arch_settings = arch + + self.depths = self.arch_settings['depths'] + self.embed_dims = self.arch_settings['embed_dims'] + self.patch_sizes = self.arch_settings['patch_sizes'] + self.strides = self.arch_settings['strides'] + self.mlp_ratios = self.arch_settings['mlp_ratios'] + self.num_heads = self.arch_settings['num_heads'] + self.sr_ratios = self.arch_settings['sr_ratios'] + + self.num_extra_tokens = 0 # there is no cls-token in Twins + self.num_stage = len(self.depths) + for key, value in self.arch_settings.items(): + assert isinstance(value, list) and len(value) == self.num_stage, ( + 'Length of setting item in arch dict must be type of list and' + ' have the same length.') + + # patch_embeds + self.patch_embeds = ModuleList() + self.position_encoding_drops = ModuleList() + self.stages = ModuleList() + + for i in range(self.num_stage): + # use in_channels of the model in the first stage + if i == 0: + stage_in_channels = in_channels + else: + stage_in_channels = self.embed_dims[i - 1] + + self.patch_embeds.append( + PatchEmbed( + in_channels=stage_in_channels, + embed_dims=self.embed_dims[i], + conv_type='Conv2d', + kernel_size=self.patch_sizes[i], + stride=self.strides[i], + padding='corner', + norm_cfg=dict(type='LN'))) + + self.position_encoding_drops.append(nn.Dropout(p=drop_rate)) + + # PEGs + self.position_encodings = ModuleList([ + ConditionalPositionEncoding(embed_dim, embed_dim) + for embed_dim in self.embed_dims + ]) + + # stochastic depth + total_depth = sum(self.depths) + self.dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + cur = 0 + + for k in range(len(self.depths)): + _block = ModuleList([ + GSAEncoderLayer( + embed_dims=self.embed_dims[k], + num_heads=self.num_heads[k], + feedforward_channels=self.mlp_ratios[k] * + self.embed_dims[k], + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=self.dpr[cur + i], + num_fcs=2, + qkv_bias=qkv_bias, + act_cfg=dict(type='GELU'), + norm_cfg=norm_cfg, + sr_ratio=self.sr_ratios[k]) for i in range(self.depths[k]) + ]) + self.stages.append(_block) + cur += self.depths[k] + + self.out_indices = out_indices + + assert isinstance(norm_after_stage, (bool, list)) + if isinstance(norm_after_stage, bool): + self.norm_after_stage = [norm_after_stage] * self.num_stage + else: + self.norm_after_stage = norm_after_stage + assert len(self.norm_after_stage) == self.num_stage, \ + (f'Number of norm_after_stage({len(self.norm_after_stage)}) should' + f' be equal to the number of stages({self.num_stage}).') + + for i, has_norm in enumerate(self.norm_after_stage): + assert isinstance(has_norm, bool), 'norm_after_stage should be ' \ + 'bool or List[bool].' + if has_norm and norm_cfg is not None: + norm_layer = build_norm_layer(norm_cfg, self.embed_dims[i])[1] + else: + norm_layer = nn.Identity() + + self.add_module(f'norm_after_stage{i}', norm_layer) + + def init_weights(self): + if self.init_cfg is not None: + super(PCPVT, self).init_weights() + else: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) 
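+                # The conv branch below is He/Kaiming initialization in
+                # fan-out mode: std = sqrt(2 / fan_out), with
+                # fan_out = kernel_h * kernel_w * out_channels / groups.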
+ elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init( + m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + + def forward(self, x): + outputs = list() + + b = x.shape[0] + + for i in range(self.num_stage): + x, hw_shape = self.patch_embeds[i](x) + h, w = hw_shape + x = self.position_encoding_drops[i](x) + for j, blk in enumerate(self.stages[i]): + x = blk(x, hw_shape) + if j == 0: + x = self.position_encodings[i](x, hw_shape) + + norm_layer = getattr(self, f'norm_after_stage{i}') + x = norm_layer(x) + x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() + + if i in self.out_indices: + outputs.append(x) + + return tuple(outputs) + + +@MODELS.register_module() +class SVT(PCPVT): + """The backbone of Twins-SVT. + + This backbone is the implementation of `Twins: Revisiting the Design + of Spatial Attention in Vision Transformers + `_. + + Args: + arch (dict, str): SVT architecture, a str value in arch zoo or a + detailed configuration dict with 8 keys, and the length of all the + values in dict should be the same: + + - depths (List[int]): The number of encoder layers in each stage. + - embed_dims (List[int]): Embedding dimension in each stage. + - patch_sizes (List[int]): The patch sizes in each stage. + - num_heads (List[int]): Numbers of attention head in each stage. + - strides (List[int]): The strides in each stage. + - mlp_ratios (List[int]): The ratios of mlp in each stage. + - sr_ratios (List[int]): The ratios of GSA-encoder layers in each + stage. + - windiow_sizes (List[int]): The window sizes in LSA-encoder layers + in each stage. + + in_channels (int): Number of input channels. Defaults to 3. + out_indices (tuple[int]): Output from which stages. + Defaults to (3, ). + qkv_bias (bool): Enable bias for qkv if True. Defaults to False. + drop_rate (float): Dropout rate. Defaults to 0. + attn_drop_rate (float): Dropout ratio of attention weight. + Defaults to 0.0 + drop_path_rate (float): Stochastic depth rate. Defaults to 0.2. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + norm_after_stage(bool, List[bool]): Add extra norm after each stage. + Defaults to False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. 
+ + Examples: + >>> from mmpretrain.models import SVT + >>> import torch + >>> svt_cfg = {'arch': "small", + >>> 'norm_after_stage': [False, False, False, True]} + >>> model = SVT(**svt_cfg) + >>> x = torch.rand(1, 3, 224, 224) + >>> outputs = model(x) + >>> print(outputs[-1].shape) + torch.Size([1, 512, 7, 7]) + >>> svt_cfg["out_indices"] = (0, 1, 2, 3) + >>> svt_cfg["norm_after_stage"] = [True, True, True, True] + >>> model = SVT(**svt_cfg) + >>> output = model(x) + >>> for feat in output: + >>> print(feat.shape) + torch.Size([1, 64, 56, 56]) + torch.Size([1, 128, 28, 28]) + torch.Size([1, 320, 14, 14]) + torch.Size([1, 512, 7, 7]) + """ + arch_zoo = { + **dict.fromkeys(['s', 'small'], + {'embed_dims': [64, 128, 256, 512], + 'depths': [2, 2, 10, 4], + 'num_heads': [2, 4, 8, 16], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [4, 4, 4, 4], + 'sr_ratios': [8, 4, 2, 1], + 'window_sizes': [7, 7, 7, 7]}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': [96, 192, 384, 768], + 'depths': [2, 2, 18, 2], + 'num_heads': [3, 6, 12, 24], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [4, 4, 4, 4], + 'sr_ratios': [8, 4, 2, 1], + 'window_sizes': [7, 7, 7, 7]}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': [128, 256, 512, 1024], + 'depths': [2, 2, 18, 2], + 'num_heads': [4, 8, 16, 32], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [4, 4, 4, 4], + 'sr_ratios': [8, 4, 2, 1], + 'window_sizes': [7, 7, 7, 7]}), + } # yapf: disable + + essential_keys = { + 'embed_dims', 'depths', 'num_heads', 'patch_sizes', 'strides', + 'mlp_ratios', 'sr_ratios', 'window_sizes' + } + + def __init__(self, + arch, + in_channels=3, + out_indices=(3, ), + qkv_bias=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.0, + norm_cfg=dict(type='LN'), + norm_after_stage=False, + init_cfg=None): + super(SVT, self).__init__(arch, in_channels, out_indices, qkv_bias, + drop_rate, attn_drop_rate, drop_path_rate, + norm_cfg, norm_after_stage, init_cfg) + + self.window_sizes = self.arch_settings['window_sizes'] + + for k in range(self.num_stage): + for i in range(self.depths[k]): + # in even-numbered layers of each stage, replace GSA with LSA + if i % 2 == 0: + ffn_channels = self.mlp_ratios[k] * self.embed_dims[k] + self.stages[k][i] = \ + LSAEncoderLayer( + embed_dims=self.embed_dims[k], + num_heads=self.num_heads[k], + feedforward_channels=ffn_channels, + drop_rate=drop_rate, + norm_cfg=norm_cfg, + attn_drop_rate=attn_drop_rate, + drop_path_rate=self.dpr[sum(self.depths[:k])+i], + qkv_bias=qkv_bias, + window_size=self.window_sizes[k]) diff --git a/mmpretrain/models/backbones/van.py b/mmpretrain/models/backbones/van.py new file mode 100644 index 0000000000000000000000000000000000000000..c34dc3362f84ffa39151219f038f0c74ee0242e8 --- /dev/null +++ b/mmpretrain/models/backbones/van.py @@ -0,0 +1,434 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmcv.cnn.bricks.transformer import PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class MixFFN(BaseModule): + """An implementation of MixFFN of VAN. Refer to + mmdetection/mmdet/models/backbones/pvt.py. + + The differences between MixFFN & FFN: + 1. Use 1X1 Conv to replace Linear layer. 
+        2. Introduce 3X3 Depth-wise Conv to encode positional information.
+
+    Args:
+        embed_dims (int): The feature dimension. Same as
+            `MultiheadAttention`.
+        feedforward_channels (int): The hidden dimension of FFNs.
+        act_cfg (dict, optional): The activation config for FFNs.
+            Default: dict(type='GELU').
+        ffn_drop (float, optional): Probability of an element to be
+            zeroed in FFN. Default 0.0.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 feedforward_channels,
+                 act_cfg=dict(type='GELU'),
+                 ffn_drop=0.,
+                 init_cfg=None):
+        super(MixFFN, self).__init__(init_cfg=init_cfg)
+
+        self.embed_dims = embed_dims
+        self.feedforward_channels = feedforward_channels
+        self.act_cfg = act_cfg
+
+        self.fc1 = Conv2d(
+            in_channels=embed_dims,
+            out_channels=feedforward_channels,
+            kernel_size=1)
+        self.dwconv = Conv2d(
+            in_channels=feedforward_channels,
+            out_channels=feedforward_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=True,
+            groups=feedforward_channels)
+        self.act = build_activation_layer(act_cfg)
+        self.fc2 = Conv2d(
+            in_channels=feedforward_channels,
+            out_channels=embed_dims,
+            kernel_size=1)
+        self.drop = nn.Dropout(ffn_drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.dwconv(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class LKA(BaseModule):
+    """Large Kernel Attention(LKA) of VAN.
+
+    .. code:: text
+
+        DW_conv (depth-wise convolution)
+                        |
+                        |
+        DW_D_conv (depth-wise dilation convolution)
+                        |
+                        |
+        Transition Convolution (1×1 convolution)
+
+    Args:
+        embed_dims (int): Number of input channels.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self, embed_dims, init_cfg=None):
+        super(LKA, self).__init__(init_cfg=init_cfg)
+
+        # a spatial local convolution (depth-wise convolution)
+        self.DW_conv = Conv2d(
+            in_channels=embed_dims,
+            out_channels=embed_dims,
+            kernel_size=5,
+            padding=2,
+            groups=embed_dims)
+
+        # a spatial long-range convolution (depth-wise dilation convolution)
+        self.DW_D_conv = Conv2d(
+            in_channels=embed_dims,
+            out_channels=embed_dims,
+            kernel_size=7,
+            stride=1,
+            padding=9,
+            groups=embed_dims,
+            dilation=3)
+
+        self.conv1 = Conv2d(
+            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+
+    def forward(self, x):
+        u = x.clone()
+        attn = self.DW_conv(x)
+        attn = self.DW_D_conv(attn)
+        attn = self.conv1(attn)
+
+        return u * attn
+
+
+class SpatialAttention(BaseModule):
+    """Basic attention module in VANBlock.
+
+    Args:
+        embed_dims (int): Number of input channels.
+        act_cfg (dict, optional): The activation config for FFNs.
+            Default: dict(type='GELU').
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self, embed_dims, act_cfg=dict(type='GELU'), init_cfg=None):
+        super(SpatialAttention, self).__init__(init_cfg=init_cfg)
+
+        self.proj_1 = Conv2d(
+            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+        self.activation = build_activation_layer(act_cfg)
+        self.spatial_gating_unit = LKA(embed_dims)
+        self.proj_2 = Conv2d(
+            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+
+    def forward(self, x):
+        shortcut = x.clone()
+        x = self.proj_1(x)
+        x = self.activation(x)
+        x = self.spatial_gating_unit(x)
+        x = self.proj_2(x)
+        x = x + shortcut
+        return x
+
+
+class VANBlock(BaseModule):
+    """A block of VAN.
+
+    Args:
+        embed_dims (int): Number of input channels.
+ ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + drop_rate (float): Dropout rate after embedding. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='GELU'). + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-2. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + ffn_ratio=4., + drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='BN', eps=1e-5), + layer_scale_init_value=1e-2, + init_cfg=None): + super(VANBlock, self).__init__(init_cfg=init_cfg) + self.out_channels = embed_dims + + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = SpatialAttention(embed_dims, act_cfg=act_cfg) + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + mlp_hidden_dim = int(embed_dims * ffn_ratio) + self.mlp = MixFFN( + embed_dims=embed_dims, + feedforward_channels=mlp_hidden_dim, + act_cfg=act_cfg, + ffn_drop=drop_rate) + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) if layer_scale_init_value > 0 else None + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) if layer_scale_init_value > 0 else None + + def forward(self, x): + identity = x + x = self.norm1(x) + x = self.attn(x) + if self.layer_scale_1 is not None: + x = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * x + x = identity + self.drop_path(x) + + identity = x + x = self.norm2(x) + x = self.mlp(x) + if self.layer_scale_2 is not None: + x = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * x + x = identity + self.drop_path(x) + + return x + + +class VANPatchEmbed(PatchEmbed): + """Image to Patch Embedding of VAN. + + The differences between VANPatchEmbed & PatchEmbed: + 1. Use BN. + 2. Do not use 'flatten' and 'transpose'. + """ + + def __init__(self, *args, norm_cfg=dict(type='BN'), **kwargs): + super(VANPatchEmbed, self).__init__(*args, norm_cfg=norm_cfg, **kwargs) + + def forward(self, x): + """ + Args: + x (Tensor): Has shape (B, C, H, W). In most case, C is 3. + Returns: + tuple: Contains merged results and its spatial shape. + - x (Tensor): Has shape (B, out_h * out_w, embed_dims) + - out_size (tuple[int]): Spatial shape of x, arrange as + (out_h, out_w). + """ + + if self.adaptive_padding: + x = self.adaptive_padding(x) + + x = self.projection(x) + out_size = (x.shape[2], x.shape[3]) + if self.norm is not None: + x = self.norm(x) + return x, out_size + + +@MODELS.register_module() +class VAN(BaseBackbone): + """Visual Attention Network. + + A PyTorch implement of : `Visual Attention Network + `_ + + Inspiration from + https://github.com/Visual-Attention-Network/VAN-Classification + + Args: + arch (str | dict): Visual Attention Network architecture. + If use string, choose from 'tiny', 'small', 'base' and 'large'. + If use dict, it should have below keys: + + - **embed_dims** (List[int]): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **ffn_ratios** (List[int]): The number of expansion ratio of + feedforward network hidden layer channels. + + Defaults to 'tiny'. + patch_sizes (List[int | tuple]): The patch size in patch embeddings. + Defaults to [7, 3, 3, 3]. 
+ in_channels (int): The num of input channels. Defaults to 3. + drop_rate (float): Dropout rate after embedding. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + out_indices (Sequence[int]): Output from which stages. + Default: ``(3, )``. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + norm_cfg (dict): Config dict for normalization layer for all output + features. Defaults to ``dict(type='LN')`` + block_cfgs (Sequence[dict] | dict): The extra config of each block. + Defaults to empty dicts. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + + Examples: + >>> from mmpretrain.models import VAN + >>> import torch + >>> cfg = dict(arch='tiny') + >>> model = VAN(**cfg) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> outputs = model(inputs) + >>> for out in outputs: + >>> print(out.size()) + (1, 256, 7, 7) + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'embed_dims': [32, 64, 160, 256], + 'depths': [3, 3, 5, 2], + 'ffn_ratios': [8, 8, 4, 4]}), + **dict.fromkeys(['s', 'small'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [2, 2, 4, 2], + 'ffn_ratios': [8, 8, 4, 4]}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [3, 3, 12, 3], + 'ffn_ratios': [8, 8, 4, 4]}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [3, 5, 27, 3], + 'ffn_ratios': [8, 8, 4, 4]}), + } # yapf: disable + + def __init__(self, + arch='tiny', + patch_sizes=[7, 3, 3, 3], + in_channels=3, + drop_rate=0., + drop_path_rate=0., + out_indices=(3, ), + frozen_stages=-1, + norm_eval=False, + norm_cfg=dict(type='LN'), + block_cfgs=dict(), + init_cfg=None): + super(VAN, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'embed_dims', 'depths', 'ffn_ratios'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.ffn_ratios = self.arch_settings['ffn_ratios'] + self.num_stages = len(self.depths) + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + + total_depth = sum(self.depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + cur_block_idx = 0 + for i, depth in enumerate(self.depths): + patch_embed = VANPatchEmbed( + in_channels=in_channels if i == 0 else self.embed_dims[i - 1], + input_size=None, + embed_dims=self.embed_dims[i], + kernel_size=patch_sizes[i], + stride=patch_sizes[i] // 2 + 1, + padding=(patch_sizes[i] // 2, patch_sizes[i] // 2), + norm_cfg=dict(type='BN')) + + blocks = ModuleList([ + VANBlock( + embed_dims=self.embed_dims[i], + ffn_ratio=self.ffn_ratios[i], + drop_rate=drop_rate, + drop_path_rate=dpr[cur_block_idx + j], + **block_cfgs) for j in range(depth) + ]) + cur_block_idx += depth + norm = build_norm_layer(norm_cfg, self.embed_dims[i])[1] + + self.add_module(f'patch_embed{i + 1}', patch_embed) + 
self.add_module(f'blocks{i + 1}', blocks) + self.add_module(f'norm{i + 1}', norm) + + def train(self, mode=True): + super(VAN, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _freeze_stages(self): + for i in range(0, self.frozen_stages + 1): + # freeze patch embed + m = getattr(self, f'patch_embed{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze blocks + m = getattr(self, f'blocks{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze norm + m = getattr(self, f'norm{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x): + outs = [] + for i in range(self.num_stages): + patch_embed = getattr(self, f'patch_embed{i + 1}') + blocks = getattr(self, f'blocks{i + 1}') + norm = getattr(self, f'norm{i + 1}') + x, hw_shape = patch_embed(x) + for block in blocks: + x = block(x) + x = x.flatten(2).transpose(1, 2) + x = norm(x) + x = x.reshape(-1, *hw_shape, + block.out_channels).permute(0, 3, 1, 2).contiguous() + if i in self.out_indices: + outs.append(x) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/vgg.py b/mmpretrain/models/backbones/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..026b916256cf56cdf75d348ee07b0ceceffd9751 --- /dev/null +++ b/mmpretrain/models/backbones/vgg.py @@ -0,0 +1,183 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +def make_vgg_layer(in_channels, + out_channels, + num_blocks, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + dilation=1, + with_norm=False, + ceil_mode=False): + layers = [] + for _ in range(num_blocks): + layer = ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + dilation=dilation, + padding=dilation, + bias=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + layers.append(layer) + in_channels = out_channels + layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode)) + + return layers + + +@MODELS.register_module() +class VGG(BaseBackbone): + """VGG backbone. + + Args: + depth (int): Depth of vgg, from {11, 13, 16, 19}. + with_norm (bool): Use BatchNorm or not. + num_classes (int): number of classes for classification. + num_stages (int): VGG stages, normally 5. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int], optional): Output from which stages. + When it is None, the default behavior depends on whether + num_classes is specified. If num_classes <= 0, the default value is + (4, ), output the last feature map before classifier. If + num_classes > 0, the default value is (5, ), output the + classification score. Default: None. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + ceil_mode (bool): Whether to use ceil_mode of MaxPool. Default: False. + with_last_pool (bool): Whether to keep the last pooling before + classifier. Default: True. + """ + + # Parameters to build layers. 
Each element specifies the number of conv in + # each stage. For example, VGG11 contains 11 layers with learnable + # parameters. 11 is computed as 11 = (1 + 1 + 2 + 2 + 2) + 3, + # where 3 indicates the last three fully-connected layers. + arch_settings = { + 11: (1, 1, 2, 2, 2), + 13: (2, 2, 2, 2, 2), + 16: (2, 2, 3, 3, 3), + 19: (2, 2, 4, 4, 4) + } + + def __init__(self, + depth, + num_classes=-1, + num_stages=5, + dilations=(1, 1, 1, 1, 1), + out_indices=None, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + norm_eval=False, + ceil_mode=False, + with_last_pool=True, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict(type='Constant', val=1., layer=['_BatchNorm']), + dict(type='Normal', std=0.01, layer=['Linear']) + ]): + super(VGG, self).__init__(init_cfg) + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for vgg') + assert num_stages >= 1 and num_stages <= 5 + stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + assert len(dilations) == num_stages + + self.num_classes = num_classes + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + with_norm = norm_cfg is not None + + if out_indices is None: + out_indices = (5, ) if num_classes > 0 else (4, ) + assert max(out_indices) <= num_stages + self.out_indices = out_indices + + self.in_channels = 3 + start_idx = 0 + vgg_layers = [] + self.range_sub_modules = [] + for i, num_blocks in enumerate(self.stage_blocks): + num_modules = num_blocks + 1 + end_idx = start_idx + num_modules + dilation = dilations[i] + out_channels = 64 * 2**i if i < 4 else 512 + vgg_layer = make_vgg_layer( + self.in_channels, + out_channels, + num_blocks, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dilation=dilation, + with_norm=with_norm, + ceil_mode=ceil_mode) + vgg_layers.extend(vgg_layer) + self.in_channels = out_channels + self.range_sub_modules.append([start_idx, end_idx]) + start_idx = end_idx + if not with_last_pool: + vgg_layers.pop(-1) + self.range_sub_modules[-1][1] -= 1 + self.module_name = 'features' + self.add_module(self.module_name, nn.Sequential(*vgg_layers)) + + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, num_classes), + ) + + def forward(self, x): + outs = [] + vgg_layers = getattr(self, self.module_name) + for i in range(len(self.stage_blocks)): + for j in range(*self.range_sub_modules[i]): + vgg_layer = vgg_layers[j] + x = vgg_layer(x) + if i in self.out_indices: + outs.append(x) + if self.num_classes > 0: + x = x.view(x.size(0), -1) + x = self.classifier(x) + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + vgg_layers = getattr(self, self.module_name) + for i in range(self.frozen_stages): + for j in range(*self.range_sub_modules[i]): + m = vgg_layers[j] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(VGG, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpretrain/models/backbones/vig.py b/mmpretrain/models/backbones/vig.py new file mode 100644 index 0000000000000000000000000000000000000000..c1a7879bd99682c32cbd1e02079fe79e2c6a3d0a --- /dev/null +++ b/mmpretrain/models/backbones/vig.py @@ -0,0 +1,852 @@ +# Copyright 
(c) OpenMMLab. All rights reserved. +# modified from +# https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_activation_layer +from mmcv.cnn.bricks import DropPath +from mmengine.model import ModuleList, Sequential +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer + + +def get_2d_relative_pos_embed(embed_dim, grid_size): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, grid_size*grid_size] + """ + pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size) + relative_pos = 2 * np.matmul(pos_embed, + pos_embed.transpose()) / pos_embed.shape[1] + return relative_pos + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], + axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, + grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, + grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def xy_pairwise_distance(x, y): + """Compute pairwise distance of a point cloud. + + Args: + x: tensor (batch_size, num_points, num_dims) + y: tensor (batch_size, num_points, num_dims) + Returns: + pairwise distance: (batch_size, num_points, num_points) + """ + with torch.no_grad(): + xy_inner = -2 * torch.matmul(x, y.transpose(2, 1)) + x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True) + y_square = torch.sum(torch.mul(y, y), dim=-1, keepdim=True) + return x_square + xy_inner + y_square.transpose(2, 1) + + +def xy_dense_knn_matrix(x, y, k=16, relative_pos=None): + """Get KNN based on the pairwise distance. 
+ + Args: + x: (batch_size, num_dims, num_points, 1) + y: (batch_size, num_dims, num_points, 1) + k: int + relative_pos:Whether to use relative_pos + Returns: + nearest neighbors: + (batch_size, num_points, k) (batch_size, num_points, k) + """ + with torch.no_grad(): + x = x.transpose(2, 1).squeeze(-1) + y = y.transpose(2, 1).squeeze(-1) + batch_size, n_points, n_dims = x.shape + dist = xy_pairwise_distance(x.detach(), y.detach()) + if relative_pos is not None: + dist += relative_pos + _, nn_idx = torch.topk(-dist, k=k) + center_idx = torch.arange( + 0, n_points, device=x.device).repeat(batch_size, k, + 1).transpose(2, 1) + return torch.stack((nn_idx, center_idx), dim=0) + + +class DenseDilated(nn.Module): + """Find dilated neighbor from neighbor list. + + edge_index: (2, batch_size, num_points, k) + """ + + def __init__(self, k=9, dilation=1, use_stochastic=False, epsilon=0.0): + super(DenseDilated, self).__init__() + self.dilation = dilation + self.use_stochastic = use_stochastic + self.epsilon = epsilon + self.k = k + + def forward(self, edge_index): + if self.use_stochastic: + if torch.rand(1) < self.epsilon and self.training: + num = self.k * self.dilation + randnum = torch.randperm(num)[:self.k] + edge_index = edge_index[:, :, :, randnum] + else: + edge_index = edge_index[:, :, :, ::self.dilation] + else: + edge_index = edge_index[:, :, :, ::self.dilation] + return edge_index + + +class DenseDilatedKnnGraph(nn.Module): + """Find the neighbors' indices based on dilated knn.""" + + def __init__(self, k=9, dilation=1, use_stochastic=False, epsilon=0.0): + super(DenseDilatedKnnGraph, self).__init__() + self.dilation = dilation + self.use_stochastic = use_stochastic + self.epsilon = epsilon + self.k = k + self._dilated = DenseDilated(k, dilation, use_stochastic, epsilon) + + def forward(self, x, y=None, relative_pos=None): + if y is not None: + x = F.normalize(x, p=2.0, dim=1) + y = F.normalize(y, p=2.0, dim=1) + + edge_index = xy_dense_knn_matrix(x, y, self.k * self.dilation, + relative_pos) + else: + x = F.normalize(x, p=2.0, dim=1) + y = x.clone() + + edge_index = xy_dense_knn_matrix(x, y, self.k * self.dilation, + relative_pos) + return self._dilated(edge_index) + + +class BasicConv(Sequential): + + def __init__(self, + channels, + act_cfg, + norm_cfg=None, + graph_conv_bias=True, + drop=0.): + m = [] + for i in range(1, len(channels)): + m.append( + nn.Conv2d( + channels[i - 1], + channels[i], + 1, + bias=graph_conv_bias, + groups=4)) + if norm_cfg is not None: + m.append(build_norm_layer(norm_cfg, channels[-1])) + if act_cfg is not None: + m.append(build_activation_layer(act_cfg)) + if drop > 0: + m.append(nn.Dropout2d(drop)) + + super(BasicConv, self).__init__(*m) + + +def batched_index_select(x, idx): + r"""fetches neighbors features from a given neighbor idx + + Args: + x (Tensor): input feature Tensor + :math: + `\mathbf{X} \in \mathbb{R}^{B \times C \times N \times 1}`. + idx (Tensor): edge_idx + :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times l}`. + Returns: + Tensor: output neighbors features + :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times k}`. 
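+
+    Example (an illustrative sketch of the expected shapes):
+        >>> import torch
+        >>> x = torch.rand(2, 8, 49, 1)              # (B, C, N, 1) features
+        >>> idx = torch.randint(0, 49, (2, 49, 9))   # k=9 neighbors per node
+        >>> batched_index_select(x, idx).shape
+        torch.Size([2, 8, 49, 9])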
+ """ + batch_size, num_dims, num_vertices_reduced = x.shape[:3] + _, num_vertices, k = idx.shape + idx_base = torch.arange( + 0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices_reduced + idx = idx + idx_base + idx = idx.contiguous().view(-1) + + x = x.transpose(2, 1) + feature = x.contiguous().view(batch_size * num_vertices_reduced, + -1)[idx, :] + feature = feature.view(batch_size, num_vertices, k, + num_dims).permute(0, 3, 1, 2).contiguous() + return feature + + +class MRConv2d(nn.Module): + """Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) + for dense data type.""" + + def __init__(self, + in_channels, + out_channels, + act_cfg, + norm_cfg=None, + graph_conv_bias=True): + super(MRConv2d, self).__init__() + self.nn = BasicConv([in_channels * 2, out_channels], act_cfg, norm_cfg, + graph_conv_bias) + + def forward(self, x, edge_index, y=None): + x_i = batched_index_select(x, edge_index[1]) + if y is not None: + x_j = batched_index_select(y, edge_index[0]) + else: + x_j = batched_index_select(x, edge_index[0]) + x_j, _ = torch.max(x_j - x_i, -1, keepdim=True) + b, c, n, _ = x.shape + x = torch.cat([x.unsqueeze(2), x_j.unsqueeze(2)], + dim=2).reshape(b, 2 * c, n, _) + return self.nn(x) + + +class EdgeConv2d(nn.Module): + """Edge convolution layer (with activation, batch normalization) for dense + data type.""" + + def __init__(self, + in_channels, + out_channels, + act_cfg, + norm_cfg=None, + graph_conv_bias=True): + super(EdgeConv2d, self).__init__() + self.nn = BasicConv([in_channels * 2, out_channels], act_cfg, norm_cfg, + graph_conv_bias) + + def forward(self, x, edge_index, y=None): + x_i = batched_index_select(x, edge_index[1]) + if y is not None: + x_j = batched_index_select(y, edge_index[0]) + else: + x_j = batched_index_select(x, edge_index[0]) + max_value, _ = torch.max( + self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True) + return max_value + + +class GraphSAGE(nn.Module): + """GraphSAGE Graph Convolution (Paper: https://arxiv.org/abs/1706.02216) + for dense data type.""" + + def __init__(self, + in_channels, + out_channels, + act_cfg, + norm_cfg=None, + graph_conv_bias=True): + super(GraphSAGE, self).__init__() + self.nn1 = BasicConv([in_channels, in_channels], act_cfg, norm_cfg, + graph_conv_bias) + self.nn2 = BasicConv([in_channels * 2, out_channels], act_cfg, + norm_cfg, graph_conv_bias) + + def forward(self, x, edge_index, y=None): + if y is not None: + x_j = batched_index_select(y, edge_index[0]) + else: + x_j = batched_index_select(x, edge_index[0]) + x_j, _ = torch.max(self.nn1(x_j), -1, keepdim=True) + return self.nn2(torch.cat([x, x_j], dim=1)) + + +class GINConv2d(nn.Module): + """GIN Graph Convolution (Paper: https://arxiv.org/abs/1810.00826) for + dense data type.""" + + def __init__(self, + in_channels, + out_channels, + act_cfg, + norm_cfg=None, + graph_conv_bias=True): + super(GINConv2d, self).__init__() + self.nn = BasicConv([in_channels, out_channels], act_cfg, norm_cfg, + graph_conv_bias) + eps_init = 0.0 + self.eps = nn.Parameter(torch.Tensor([eps_init])) + + def forward(self, x, edge_index, y=None): + if y is not None: + x_j = batched_index_select(y, edge_index[0]) + else: + x_j = batched_index_select(x, edge_index[0]) + x_j = torch.sum(x_j, -1, keepdim=True) + return self.nn((1 + self.eps) * x + x_j) + + +class GraphConv2d(nn.Module): + """Static graph convolution layer.""" + + def __init__(self, + in_channels, + out_channels, + graph_conv_type, + act_cfg, + norm_cfg=None, + graph_conv_bias=True): + 
super(GraphConv2d, self).__init__() + if graph_conv_type == 'edge': + self.gconv = EdgeConv2d(in_channels, out_channels, act_cfg, + norm_cfg, graph_conv_bias) + elif graph_conv_type == 'mr': + self.gconv = MRConv2d(in_channels, out_channels, act_cfg, norm_cfg, + graph_conv_bias) + elif graph_conv_type == 'sage': + self.gconv = GraphSAGE(in_channels, out_channels, act_cfg, + norm_cfg, graph_conv_bias) + elif graph_conv_type == 'gin': + self.gconv = GINConv2d(in_channels, out_channels, act_cfg, + norm_cfg, graph_conv_bias) + else: + raise NotImplementedError( + 'graph_conv_type:{} is not supported'.format(graph_conv_type)) + + def forward(self, x, edge_index, y=None): + return self.gconv(x, edge_index, y) + + +class DyGraphConv2d(GraphConv2d): + """Dynamic graph convolution layer.""" + + def __init__(self, + in_channels, + out_channels, + k=9, + dilation=1, + graph_conv_type='mr', + act_cfg=dict(type='GELU'), + norm_cfg=None, + graph_conv_bias=True, + use_stochastic=False, + epsilon=0.2, + r=1): + super(DyGraphConv2d, + self).__init__(in_channels, out_channels, graph_conv_type, + act_cfg, norm_cfg, graph_conv_bias) + self.k = k + self.d = dilation + self.r = r + self.dilated_knn_graph = DenseDilatedKnnGraph(k, dilation, + use_stochastic, epsilon) + + def forward(self, x, relative_pos=None): + B, C, H, W = x.shape + y = None + if self.r > 1: + y = F.avg_pool2d(x, self.r, self.r) + y = y.reshape(B, C, -1, 1).contiguous() + x = x.reshape(B, C, -1, 1).contiguous() + edge_index = self.dilated_knn_graph(x, y, relative_pos) + x = super(DyGraphConv2d, self).forward(x, edge_index, y) + return x.reshape(B, -1, H, W).contiguous() + + +class Grapher(nn.Module): + """Grapher module with graph convolution and fc layers.""" + + def __init__(self, + in_channels, + k=9, + dilation=1, + graph_conv_type='mr', + act_cfg=dict(type='GELU'), + norm_cfg=None, + graph_conv_bias=True, + use_stochastic=False, + epsilon=0.2, + r=1, + n=196, + drop_path=0.0, + relative_pos=False): + super(Grapher, self).__init__() + self.channels = in_channels + self.n = n + self.r = r + self.fc1 = Sequential( + nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0), + build_norm_layer(dict(type='BN'), in_channels), + ) + self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, k, + dilation, graph_conv_type, act_cfg, + norm_cfg, graph_conv_bias, + use_stochastic, epsilon, r) + self.fc2 = Sequential( + nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0), + build_norm_layer(dict(type='BN'), in_channels), + ) + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity()
+
+        self.relative_pos = None
+        if relative_pos:
+            relative_pos_tensor = torch.from_numpy(
+                np.float32(
+                    get_2d_relative_pos_embed(in_channels, int(
+                        n**0.5)))).unsqueeze(0).unsqueeze(1)
+            relative_pos_tensor = F.interpolate(
+                relative_pos_tensor,
+                size=(n, n // (r * r)),
+                mode='bicubic',
+                align_corners=False)
+            self.relative_pos = nn.Parameter(
+                -relative_pos_tensor.squeeze(1), requires_grad=False)
+
+    def _get_relative_pos(self, relative_pos, H, W):
+        if relative_pos is None or H * W == self.n:
+            return relative_pos
+        else:
+            N = H * W
+            N_reduced = N // (self.r * self.r)
+            return F.interpolate(
+                relative_pos.unsqueeze(0), size=(N, N_reduced),
+                mode='bicubic').squeeze(0)
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        relative_pos = self._get_relative_pos(self.relative_pos, H, W)
+        shortcut = x
+        x = self.fc1(x)
+        x = self.graph_conv(x, relative_pos)
+        x = self.fc2(x)
+        x = self.drop_path(x) + shortcut
+        return x
+
+
+class FFN(nn.Module):
+    """A conv-based feed-forward block of ViG.
+
+    ``out_features`` and ``hidden_features`` both fall back to
+    ``in_features`` when they are not specified.
+    """
+
+    def __init__(self,
+                 in_features,
+                 hidden_features=None,
+                 out_features=None,
+                 act_cfg=dict(type='GELU'),
+                 drop_path=0.0):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = Sequential(
+            nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0),
+            build_norm_layer(dict(type='BN'), hidden_features),
+        )
+        self.act = build_activation_layer(act_cfg)
+        self.fc2 = Sequential(
+            nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0),
+            build_norm_layer(dict(type='BN'), out_features),
+        )
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x):
+        shortcut = x
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.fc2(x)
+        x = self.drop_path(x) + shortcut
+        return x
+
+
+@MODELS.register_module()
+class Vig(BaseBackbone):
+    """Vision GNN backbone.
+
+    A PyTorch implementation of `Vision GNN: An Image is Worth Graph of Nodes
+    <https://arxiv.org/abs/2206.00272>`_.
+
+    Modified from the official implementation
+    https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch
+
+    Args:
+        arch (str): Vision GNN architecture,
+            choose from 'tiny', 'small' and 'base'.
+        in_channels (int): The number of channels of input images.
+            Defaults to 3.
+        k (int): The number of KNN's k. Defaults to 9.
+        out_indices (Sequence | int): Output from which blocks.
+            Defaults to -1, means the last block.
+        act_cfg (dict): The config of activation functions.
+            Defaults to ``dict(type='GELU')``.
+        norm_cfg (dict): The config of normalization layers.
+            Defaults to ``dict(type='BN')``.
+        graph_conv_bias (bool): Whether to use bias in the convolution
+            layers in Grapher. Defaults to True.
+        graph_conv_type (str): The type of graph convolution, chosen
+            from 'edge', 'mr', 'sage' and 'gin'. Defaults to 'mr'.
+        epsilon (float): Probability of random arrangement in KNN. It only
+            works when ``use_dilation=True`` and ``use_stochastic=True``.
+            Defaults to 0.2.
+        use_dilation (bool): Whether to use dilation in KNN. Defaults to True.
+        use_stochastic (bool): Whether to use stochastic in KNN.
+            Defaults to False.
+        drop_path (float): Stochastic depth rate. Defaults to 0.0.
+        relative_pos (bool): Whether to use relative position embedding.
+            Defaults to False.
+        norm_eval (bool): Whether to set the normalization layer to eval mode.
+            Defaults to False.
+        frozen_stages (int): Blocks to be frozen (all param fixed).
+ Defaults to 0, which means not freezing any parameters. + init_cfg (dict, optional): The initialization configs. + Defaults to None. + """ # noqa: E501 + + arch_settings = { + 'tiny': dict(num_blocks=12, channels=192), + 'small': dict(num_blocks=16, channels=320), + 'base': dict(num_blocks=16, channels=640), + } + + def __init__(self, + arch, + in_channels=3, + k=9, + out_indices=-1, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='BN'), + graph_conv_bias=True, + graph_conv_type='mr', + epsilon=0.2, + use_dilation=True, + use_stochastic=False, + drop_path=0., + relative_pos=False, + norm_eval=False, + frozen_stages=0, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + arch = self.arch_settings[arch] + self.num_blocks = arch['num_blocks'] + channels = arch['channels'] + + if isinstance(out_indices, int): + out_indices = [out_indices] + elif isinstance(out_indices, tuple): + out_indices = list(out_indices) + elif not isinstance(out_indices, list): + raise TypeError('"out_indices" must by a tuple, list or int, ' + f'get {type(out_indices)} instead.') + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_blocks + index + assert 0 <= out_indices[i] <= self.num_blocks, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.stem = Sequential( + nn.Conv2d(in_channels, channels // 8, 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels // 8), + build_activation_layer(act_cfg), + nn.Conv2d(channels // 8, channels // 4, 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels // 4), + build_activation_layer(act_cfg), + nn.Conv2d(channels // 4, channels // 2, 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels // 2), + build_activation_layer(act_cfg), + nn.Conv2d(channels // 2, channels, 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels), + build_activation_layer(act_cfg), + nn.Conv2d(channels, channels, 3, stride=1, padding=1), + build_norm_layer(norm_cfg, channels), + ) + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path, self.num_blocks)] + # number of knn's k + num_knn = [ + int(x.item()) for x in torch.linspace(k, 2 * k, self.num_blocks) + ] + max_dilation = 196 // max(num_knn) + + self.pos_embed = nn.Parameter(torch.zeros(1, channels, 14, 14)) + + self.blocks = ModuleList([ + Sequential( + Grapher( + in_channels=channels, + k=num_knn[i], + dilation=min(i // 4 + + 1, max_dilation) if use_dilation else 1, + graph_conv_type=graph_conv_type, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + graph_conv_bias=graph_conv_bias, + use_stochastic=use_stochastic, + epsilon=epsilon, + drop_path=dpr[i], + relative_pos=relative_pos), + FFN(in_features=channels, + hidden_features=channels * 4, + act_cfg=act_cfg, + drop_path=dpr[i])) for i in range(self.num_blocks) + ]) + + self.norm_eval = norm_eval + self.frozen_stages = frozen_stages + + def forward(self, inputs): + outs = [] + x = self.stem(inputs) + self.pos_embed + + for i, block in enumerate(self.blocks): + x = block(x) + + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + self.stem.eval() + for i in range(self.frozen_stages): + m = self.blocks[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(Vig, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + +@MODELS.register_module() 
+class PyramidVig(BaseBackbone): + """Pyramid Vision GNN backbone. + + A PyTorch implementation of `Vision GNN: An Image is Worth Graph of Nodes + `_. + + Modified from the official implementation + https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch + + Args: + arch (str): Vision GNN architecture, choose from 'tiny', + 'small' and 'base'. + in_channels (int): The number of channels of input images. + Defaults to 3. + k (int): The number of KNN's k. Defaults to 9. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + act_cfg (dict): The config of activative functions. + Defaults to ``dict(type='GELU'))``. + norm_cfg (dict): The config of normalization layers. + Defaults to ``dict(type='BN')``. + graph_conv_bias (bool): Whether to use bias in the convolution + layers in Grapher. Defaults to True. + graph_conv_type (str): The type of graph convolution,choose + from 'edge', 'mr', 'sage' and 'gin'. Defaults to 'mr'. + epsilon (float): Probability of random arrangement in KNN. It only + works when ``use_stochastic=True``. Defaults to 0.2. + use_stochastic (bool): Whether to use stochastic in KNN. + Defaults to False. + drop_path (float): stochastic depth rate. Default 0.0 + norm_eval (bool): Whether to set the normalization layer to eval mode. + Defaults to False. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + init_cfg (dict, optional): The initialization configs. + Defaults to None. + """ # noqa: E501 + arch_settings = { + 'tiny': dict(blocks=[2, 2, 6, 2], channels=[48, 96, 240, 384]), + 'small': dict(blocks=[2, 2, 6, 2], channels=[80, 160, 400, 640]), + 'medium': dict(blocks=[2, 2, 16, 2], channels=[96, 192, 384, 768]), + 'base': dict(blocks=[2, 2, 18, 2], channels=[128, 256, 512, 1024]), + } + + def __init__(self, + arch, + in_channels=3, + k=9, + out_indices=-1, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='BN'), + graph_conv_bias=True, + graph_conv_type='mr', + epsilon=0.2, + use_stochastic=False, + drop_path=0., + norm_eval=False, + frozen_stages=0, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + arch = self.arch_settings[arch] + self.blocks = arch['blocks'] + self.num_blocks = sum(self.blocks) + self.num_stages = len(self.blocks) + channels = arch['channels'] + self.channels = channels + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' 
+ for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_stages + index + assert 0 <= out_indices[i] <= self.num_stages, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.stem = Sequential( + nn.Conv2d(in_channels, channels[0] // 2, 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels[0] // 2), + build_activation_layer(act_cfg), + nn.Conv2d(channels[0] // 2, channels[0], 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels[0]), + build_activation_layer(act_cfg), + nn.Conv2d(channels[0], channels[0], 3, stride=1, padding=1), + build_norm_layer(norm_cfg, channels[0]), + ) + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path, self.num_blocks)] + # number of knn's k + num_knn = [ + int(x.item()) for x in torch.linspace(k, k, self.num_blocks) + ] + max_dilation = 49 // max(num_knn) + + self.pos_embed = nn.Parameter( + torch.zeros(1, channels[0], 224 // 4, 224 // 4)) + HW = 224 // 4 * 224 // 4 + reduce_ratios = [4, 2, 1, 1] + + self.stages = ModuleList() + block_idx = 0 + for stage_idx, num_blocks in enumerate(self.blocks): + mid_channels = channels[stage_idx] + reduce_ratio = reduce_ratios[stage_idx] + blocks = [] + if stage_idx > 0: + blocks.append( + Sequential( + nn.Conv2d( + self.channels[stage_idx - 1], + mid_channels, + kernel_size=3, + stride=2, + padding=1), + build_norm_layer(norm_cfg, mid_channels), + )) + HW = HW // 4 + for _ in range(num_blocks): + blocks.append( + Sequential( + Grapher( + in_channels=mid_channels, + k=num_knn[block_idx], + dilation=min(block_idx // 4 + 1, max_dilation), + graph_conv_type=graph_conv_type, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + graph_conv_bias=graph_conv_bias, + use_stochastic=use_stochastic, + epsilon=epsilon, + r=reduce_ratio, + n=HW, + drop_path=dpr[block_idx], + relative_pos=True), + FFN(in_features=mid_channels, + hidden_features=mid_channels * 4, + act_cfg=act_cfg, + drop_path=dpr[block_idx]))) + block_idx += 1 + self.stages.append(Sequential(*blocks)) + + self.norm_eval = norm_eval + self.frozen_stages = frozen_stages + + def forward(self, inputs): + outs = [] + x = self.stem(inputs) + self.pos_embed + + for i, blocks in enumerate(self.stages): + x = blocks(x) + + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + self.stem.eval() + for i in range(self.frozen_stages): + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(PyramidVig, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpretrain/models/backbones/vision_transformer.py b/mmpretrain/models/backbones/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a54053c217d18824357ad7250cb6c52be212f15d --- /dev/null +++ b/mmpretrain/models/backbones/vision_transformer.py @@ -0,0 +1,537 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from ..utils import (MultiheadAttention, SwiGLUFFNFused, build_norm_layer, + resize_pos_embed, to_2tuple) +from .base_backbone import BaseBackbone + + +class TransformerEncoderLayer(BaseModule): + """Implements one encoder layer in Vision Transformer. + + Args: + embed_dims (int): The feature dimension + num_heads (int): Parallel attention heads + feedforward_channels (int): The hidden dimension for FFNs + layer_scale_init_value (float or torch.Tensor): Init value of layer + scale. Defaults to 0. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + ffn_type (str): Select the type of ffn layers. Defaults to 'origin'. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + layer_scale_init_value=0., + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + ffn_type='origin', + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg=None): + super(TransformerEncoderLayer, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + + self.attn = MultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias, + layer_scale_init_value=layer_scale_init_value) + + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + if ffn_type == 'origin': + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + layer_scale_init_value=layer_scale_init_value) + elif ffn_type == 'swiglu_fused': + self.ffn = SwiGLUFFNFused( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + layer_scale_init_value=layer_scale_init_value) + else: + raise NotImplementedError + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + def init_weights(self): + super(TransformerEncoderLayer, self).init_weights() + for m in self.ffn.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = self.ffn(self.ln2(x), identity=x) + return x + + +@MODELS.register_module() +class VisionTransformer(BaseBackbone): + """Vision Transformer. + + A PyTorch implement of : `An Image is Worth 16x16 Words: Transformers + for Image Recognition at Scale `_ + + Args: + arch (str | dict): Vision Transformer architecture. 
If use string, + choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small' + and 'deit-base'. If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + + Defaults to 'base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"cls_token"``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + layer_scale_init_value (float or torch.Tensor): Init value of layer + scale. Defaults to 0. + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
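+
+    Examples:
+        A minimal usage sketch, assuming the default ``arch='base'`` and
+        ``out_type='cls_token'``:
+
+        >>> import torch
+        >>> model = VisionTransformer(arch='base')
+        >>> x = torch.randn(1, 3, 224, 224)
+        >>> outs = model(x)
+        >>> outs[-1].shape  # class-token feature of the last layer
+        torch.Size([1, 768])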
+ """ + arch_zoo = { + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims': 768, + 'num_layers': 8, + 'num_heads': 8, + 'feedforward_channels': 768 * 3, + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 3072 + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': 4096 + }), + **dict.fromkeys( + ['h', 'huge'], + { + # The same as the implementation in MAE + # + 'embed_dims': 1280, + 'num_layers': 32, + 'num_heads': 16, + 'feedforward_channels': 5120 + }), + **dict.fromkeys( + ['eva-g', 'eva-giant'], + { + # The implementation in EVA + # + 'embed_dims': 1408, + 'num_layers': 40, + 'num_heads': 16, + 'feedforward_channels': 6144 + }), + **dict.fromkeys( + ['deit-t', 'deit-tiny'], { + 'embed_dims': 192, + 'num_layers': 12, + 'num_heads': 3, + 'feedforward_channels': 192 * 4 + }), + **dict.fromkeys( + ['deit-s', 'deit-small', 'dinov2-s', 'dinov2-small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 6, + 'feedforward_channels': 384 * 4 + }), + **dict.fromkeys( + ['deit-b', 'deit-base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 768 * 4 + }), + **dict.fromkeys( + ['dinov2-g', 'dinov2-giant'], { + 'embed_dims': 1536, + 'num_layers': 40, + 'num_heads': 24, + 'feedforward_channels': 6144 + }), + } + num_extra_tokens = 1 # class token + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} + + def __init__(self, + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=True, + out_type='cls_token', + with_cls_token=True, + frozen_stages=-1, + interpolate_mode='bicubic', + layer_scale_init_value=0., + patch_cfg=dict(), + layer_cfgs=dict(), + pre_norm=False, + init_cfg=None): + super(VisionTransformer, self).__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.img_size = to_2tuple(img_size) + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + bias=not pre_norm, # disable bias if pre_norm is used(e.g., CLIP) + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + + # Set cls token + self.with_cls_token = with_cls_token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when 
`out_type="cls_token"`.') + + # Set position embedding + self.interpolate_mode = interpolate_mode + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_extra_tokens, + self.embed_dims)) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + layer_scale_init_value=layer_scale_init_value, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + qkv_bias=qkv_bias, + norm_cfg=norm_cfg) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(TransformerEncoderLayer(**_layer_cfg)) + + self.frozen_stages = frozen_stages + if pre_norm: + self.pre_norm = build_norm_layer(norm_cfg, self.embed_dims) + else: + self.pre_norm = nn.Identity() + + self.final_norm = final_norm + if final_norm: + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + if self.out_type == 'avg_featmap': + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + # freeze stages only when self.frozen_stages > 0 + if self.frozen_stages > 0: + self._freeze_stages() + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + def init_weights(self): + super(VisionTransformer, self).init_weights() + + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if (not self.with_cls_token + and ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1): + # Remove cls token from state dict if it's not used. 
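+            # e.g. a checkpoint trained with a class token stores pos_embed
+            # as (1, 197, C) for a 14x14 patch grid; dropping index 0 leaves
+            # the (1, 196, C) patch-position table this cls-free model needs.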
+ state_dict[name] = state_dict[name][:, 1:] + ckpt_pos_embed_shape = state_dict[name].shape + + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + @staticmethod + def resize_pos_embed(*args, **kwargs): + """Interface for backward-compatibility.""" + return resize_pos_embed(*args, **kwargs) + + def _freeze_stages(self): + # freeze position embedding + if self.pos_embed is not None: + self.pos_embed.requires_grad = False + # set dropout to eval model + self.drop_after_pos.eval() + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + # freeze pre-norm + for param in self.pre_norm.parameters(): + param.requires_grad = False + # freeze cls_token + if self.cls_token is not None: + self.cls_token.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + # freeze the last layer norm + if self.frozen_stages == len(self.layers): + if self.final_norm: + self.ln1.eval() + for param in self.ln1.parameters(): + param.requires_grad = False + + if self.out_type == 'avg_featmap': + self.ln2.eval() + for param in self.ln2.parameters(): + param.requires_grad = False + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + x = self.pre_norm(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return self.ln2(patch_token.mean(dim=1)) + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. 
+ + Note: + The first depth is the stem module (``layer_depth=0``), and the + last depth is the subsequent module (``layer_depth=num_layers-1``) + """ + num_layers = self.num_layers + 2 + + if not param_name.startswith(prefix): + # For subsequent module like head + return num_layers - 1, num_layers + + param_name = param_name[len(prefix):] + + if param_name in ('cls_token', 'pos_embed'): + layer_depth = 0 + elif param_name.startswith('patch_embed'): + layer_depth = 0 + elif param_name.startswith('layers'): + layer_id = int(param_name.split('.')[1]) + layer_depth = layer_id + 1 + else: + layer_depth = num_layers - 1 + + return layer_depth, num_layers diff --git a/mmpretrain/models/backbones/vit_eva02.py b/mmpretrain/models/backbones/vit_eva02.py new file mode 100644 index 0000000000000000000000000000000000000000..20ec4b247bbdbfc209c353c8e001d34d71a3990c --- /dev/null +++ b/mmpretrain/models/backbones/vit_eva02.py @@ -0,0 +1,350 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn.bricks.drop import build_dropout +from mmengine.model import BaseModule, ModuleList + +from mmpretrain.registry import MODELS +from ..utils import (RotaryEmbeddingFast, SwiGLUFFN, build_norm_layer, + resize_pos_embed) +from .vision_transformer import VisionTransformer + + +class AttentionWithRoPE(BaseModule): + """Multi-head Attention Module with 2D sincos position embedding (RoPE). + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + qkv_bias (bool): If True, add a learnable bias to q and v. Note + that we follows the official implementation where ``k_bias`` + is 0. Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + proj_bias (bool) If True, add a learnable bias to output projection. + Defaults to True. + rope (:obj:`torch.nn.Module`, optional): If it is an object of the + ``RotaryEmbedding``, the rotation of the token position will be + performed before the softmax. Defaults to None. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. 
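+
+    Note:
+        A shape sketch for clarity: the input is a token sequence
+        ``(B, N, C)`` plus the ``(H, W)`` patch resolution. With
+        ``with_cls_token=True`` the rotary embedding is applied only to the
+        ``N - 1`` patch tokens of q and k; the class token is concatenated
+        back unchanged before the attention softmax.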
+ """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + qkv_bias=True, + qk_scale=None, + proj_bias=True, + rope=None, + with_cls_token=True, + init_cfg=None): + super(AttentionWithRoPE, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + self.num_heads = num_heads + self.head_dims = embed_dims // num_heads + self.scale = qk_scale or self.head_dims**-0.5 + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.with_cls_token = with_cls_token + + self.rope = rope + + def forward(self, x, patch_resolution): + B, N, _ = x.shape + + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(dim=0) + + if self.rope: + if self.with_cls_token: + q_t = q[:, :, 1:, :] + ro_q_t = self.rope(q_t, patch_resolution) + q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v) + + k_t = k[:, :, 1:, :] if self.with_cls_token else k + ro_k_t = self.rope(k_t, patch_resolution) + k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v) + else: + q = self.rope(q, patch_resolution) + k = self.rope(k, patch_resolution) + + q = q * self.scale + + attn = (q @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1).type_as(x) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class EVA02EndcoderLayer(BaseModule): + """Implements one encoder EVA02EndcoderLayer in EVA02. + + Args: + embed_dims (int): The feature dimension + num_heads (int): Parallel attention heads + feedforward_channels (int): The hidden dimension of FFNs. + sub_ln (bool): Whether to add the sub layer normalization + in the attention module. Defaults to False. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + proj_bias (bool): enable bias for projection in the attention module + if True. Defaults to True. + rope (:obj:`torch.nn.Module`, optional): RotaryEmbedding object + in the attention module. Defaults to None. + drop_rate (float): Dropout rate in the mlp module. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
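+
+    Note:
+        The layer is pre-norm and residual; ``forward`` below computes
+        ``x = x + DropPath(Attn(LN(x)))`` followed by
+        ``x = x + DropPath(SwiGLUFFN(LN(x)))``.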
+ """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + sub_ln=False, + attn_drop=0., + proj_drop=0., + qkv_bias=False, + qk_scale=None, + proj_bias=True, + rope=None, + with_cls_token=True, + drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + init_cfg=None): + super(EVA02EndcoderLayer, self).__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims) + + self.attn = AttentionWithRoPE( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop, + proj_drop=proj_drop, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + proj_bias=proj_bias, + rope=rope, + with_cls_token=with_cls_token) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate)) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims) + + if drop_rate > 0: + dropout_layer = dict(type='Dropout', drop_prob=drop_rate) + else: + dropout_layer = None + + if sub_ln: + ffn_norm = norm_cfg + else: + ffn_norm = None + + self.mlp = SwiGLUFFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + dropout_layer=dropout_layer, + norm_cfg=ffn_norm, + add_identity=False, + ) + + def forward(self, x, patch_resolution): + inputs = x + x = self.norm1(x) + x = self.attn(x, patch_resolution) + x = self.drop_path(x) + x = inputs + x + + inputs = x + x = self.norm2(x) + x = self.mlp(x) + x = self.drop_path(x) + x = inputs + x + + return x + + +@MODELS.register_module() +class ViTEVA02(VisionTransformer): + """EVA02 Vision Transformer. + + A PyTorch implement of : `EVA-02: A Visual Representation for Neon Genesis + `_ + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'tiny', 'small', 'base', 'large'. If use dict, + it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **mlp_ratio** (float): The ratio of the mlp module. + + Defaults to 'tiny'. + + sub_ln (bool): Whether to add the sub layer normalization in swiglu. + Defaults to False. + drop_rate (float): Probability of an element to be zeroed in the + mlp module. Defaults to 0. + attn_drop_rate (float): Probability of an element to be zeroed after + the softmax in the attention. Defaults to 0. + proj_drop_rate (float): Probability of an element to be zeroed after + projection in the attention. Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + **kwargs(dict, optional): Other args for Vision Transformer. 
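+
+    Examples:
+        A minimal usage sketch, assuming the default ``'tiny'`` settings
+        listed above and the inherited ``out_type='cls_token'``:
+
+        >>> import torch
+        >>> model = ViTEVA02(arch='tiny', img_size=224, patch_size=16)
+        >>> model(torch.randn(1, 3, 224, 224))[-1].shape
+        torch.Size([1, 192])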
+ """ + arch_zoo = { + **dict.fromkeys( + ['t', 'ti', 'tiny'], { + 'embed_dims': 192, + 'num_layers': 12, + 'num_heads': 3, + 'feedforward_channels': int(192 * 4 * 2 / 3) + }), + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 6, + 'feedforward_channels': int(384 * 4 * 2 / 3) + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': int(768 * 4 * 2 / 3) + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': int(1024 * 4 * 2 / 3) + }) + } + num_extra_tokens = 1 # class token + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} + + def __init__(self, + arch='tiny', + sub_ln=False, + drop_rate=0., + attn_drop_rate=0., + proj_drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + norm_cfg=dict(type='LN'), + with_cls_token=True, + layer_cfgs=dict(), + **kwargs): + # set essential args for Vision Transformer + kwargs.update( + arch=arch, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + with_cls_token=with_cls_token) + super(ViTEVA02, self).__init__(**kwargs) + + self.num_heads = self.arch_settings['num_heads'] + + # Set RoPE + head_dim = self.embed_dims // self.num_heads + self.rope = RotaryEmbeddingFast( + embed_dims=head_dim, patch_resolution=self.patch_resolution) + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.num_heads, + feedforward_channels=self. + arch_settings['feedforward_channels'], + sub_ln=sub_ln, + norm_cfg=norm_cfg, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_rate=drop_rate, + qkv_bias=qkv_bias, + rope=self.rope, + with_cls_token=with_cls_token, + drop_path_rate=dpr[i]) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(EVA02EndcoderLayer(**_layer_cfg)) + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + x = self.pre_norm(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x, patch_resolution) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/vit_sam.py b/mmpretrain/models/backbones/vit_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..9d0fc866531e2fe405948e1aa8738fd210bbfc34 --- /dev/null +++ b/mmpretrain/models/backbones/vit_sam.py @@ -0,0 +1,700 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Optional, Sequence, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from ..utils import LayerNorm2d, build_norm_layer, resize_pos_embed, to_2tuple +from .base_backbone import BaseBackbone + + +def window_partition(x: torch.Tensor, + window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """Partition into non-overlapping windows with padding if needed. + + Borrowed from https://github.com/facebookresearch/segment-anything/ + + Args: + x (torch.Tensor): Input tokens with [B, H, W, C]. + window_size (int): Window size. + + Returns: + Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + + - ``windows``: Windows after partition with + [B * num_windows, window_size, window_size, C]. + - ``(Hp, Wp)``: Padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows: torch.Tensor, window_size: int, + pad_hw: Tuple[int, int], + hw: Tuple[int, int]) -> torch.Tensor: + """Window unpartition into original sequences and removing padding. + + Borrowed from https://github.com/facebookresearch/segment-anything/ + + Args: + x (torch.Tensor): Input tokens with + [B * num_windows, window_size, window_size, C]. + window_size (int): Window size. + pad_hw (tuple): Padded height and width (Hp, Wp). + hw (tuple): Original height and width (H, W) before padding. + + Returns: + torch.Tensor: Unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, + rel_pos: torch.Tensor) -> torch.Tensor: + """Get relative positional embeddings according to the relative positions + of query and key sizes. + + Borrowed from https://github.com/facebookresearch/segment-anything/ + + Args: + q_size (int): Size of query q. + k_size (int): Size of key k. + rel_pos (torch.Tensor): Relative position embeddings (L, C). + + Returns: + torch.Tensor: Extracted positional embeddings according to relative + positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode='linear', + ) + rel_pos_resized = rel_pos_resized.reshape(-1, + max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. 
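+    # Worked example: with q_size = k_size = 7 the raw offsets q - k span
+    # [-6, 6]; adding the (k_size - 1) shift maps them onto table rows
+    # [0, 12], matching max_rel_dist = 2 * 7 - 1 = 13.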
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - + k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """Borrowed from https://github.com/facebookresearch/segment-anything/ + + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + attn (torch.Tensor): Attention map. + q (torch.Tensor): Query q in the attention layer with shape + (B, q_h * q_w, C). + rel_pos_h (torch.Tensor): Relative position embeddings (Lh, C) for + height axis. + rel_pos_w (torch.Tensor): Relative position embeddings (Lw, C) for + width axis. + q_size (tuple): Spatial sequence size of query q with (q_h, q_w). + k_size (tuple): Spatial sequence size of key k with (k_h, k_w). + + Returns: + torch.Tensor: Attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh) + rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw) + + attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings. + + Borrowed from https://github.com/facebookresearch/segment-anything/ + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + use_rel_pos (bool):Whether to use relative position embedding. + Defaults to False. + input_size (int, optional): Input resolution for calculating the + relative positional parameter size. Defaults to None. + """ + + def __init__( + self, + embed_dims: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = head_embed_dims**-0.5 + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dims, embed_dims) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert (input_size is not None), \ + 'Input size must be provided if using relative position embed.' 
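+            # Illustrative sizing: the tables initialized below hold one row
+            # per possible relative offset along an axis, e.g. a 64x64 token
+            # grid (1024px input / patch 16) gives 2 * 64 - 1 = 127 rows of
+            # width head_embed_dims each.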
+ # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter( + torch.zeros(2 * input_size[0] - 1, head_embed_dims)) + self.rel_pos_w = nn.Parameter( + torch.zeros(2 * input_size[1] - 1, head_embed_dims)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, + -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, + self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, + -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +class TransformerEncoderLayer(BaseModule): + """Encoder layer with window attention in Vision Transformer. + + Args: + embed_dims (int): The feature dimension + num_heads (int): Parallel attention heads + feedforward_channels (int): The hidden dimension for FFNs + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + use_rel_pos (bool):Whether to use relative position embedding. + Defaults to False. + window_size (int): Window size for window attention. Defaults to 0. + input_size (int, optional): Input resolution for calculating the + relative positional parameter size. Defaults to None. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
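+
+    Note:
+        Behaviour summary: when ``window_size > 0`` the ``(B, H, W, C)``
+        input is padded and split into ``window_size x window_size`` windows
+        before attention and stitched back afterwards; ``window_size = 0``
+        means global attention over the whole feature map, which ViTSAM uses
+        at its ``global_attn_indexes``.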
+ """ + + def __init__(self, + embed_dims: int, + num_heads: int, + feedforward_channels: int, + drop_rate: float = 0., + drop_path_rate: float = 0., + num_fcs: int = 2, + qkv_bias: bool = True, + act_cfg: dict = dict(type='GELU'), + norm_cfg: dict = dict(type='LN'), + use_rel_pos: bool = False, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + self.window_size = window_size + + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + + self.attn = Attention( + embed_dims=embed_dims, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + input_size=input_size if window_size == 0 else + (window_size, window_size), + ) + + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + def forward(self, x): + shortcut = x + x = self.ln1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + x = shortcut + x + + x = self.ffn(self.ln2(x), identity=x) + return x + + +@MODELS.register_module() +class ViTSAM(BaseBackbone): + """Vision Transformer as image encoder used in SAM. + + A PyTorch implement of backbone: `Segment Anything + `_ + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'base', 'large', 'huge'. If use dict, it should have + below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + - **global_attn_indexes** (int): The index of layers with global + attention. + + Defaults to 'base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + out_channels (int): The num of output channels, if equal to 0, the + channel reduction layer is disabled. Defaults to 256. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + out_type (str): The type of output features. Please choose from + + - ``"raw"`` or ``"featmap"``: The feature map tensor from the + patch tokens with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + + Defaults to ``"raw"``. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + use_abs_pos (bool): Whether to use absolute position embedding. + Defaults to True. + use_rel_pos (bool):Whether to use relative position embedding. + Defaults to True. + window_size (int): Window size for window attention. Defaults to 14. 
+ norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 3072, + 'global_attn_indexes': [2, 5, 8, 11] + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': 4096, + 'global_attn_indexes': [5, 11, 17, 23] + }), + **dict.fromkeys( + ['h', 'huge'], { + 'embed_dims': 1280, + 'num_layers': 32, + 'num_heads': 16, + 'feedforward_channels': 5120, + 'global_attn_indexes': [7, 15, 23, 31] + }), + } + OUT_TYPES = {'raw', 'featmap', 'avg_featmap'} + + def __init__(self, + arch: str = 'base', + img_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + out_channels: int = 256, + out_indices: int = -1, + out_type: str = 'raw', + drop_rate: float = 0., + drop_path_rate: float = 0., + qkv_bias: bool = True, + use_abs_pos: bool = True, + use_rel_pos: bool = True, + window_size: int = 14, + norm_cfg: dict = dict(type='LN', eps=1e-6), + frozen_stages: int = -1, + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + init_cfg: Optional[dict] = None): + super().__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.global_attn_indexes = self.arch_settings['global_attn_indexes'] + self.img_size = to_2tuple(img_size) + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + + self.use_abs_pos = use_abs_pos + self.interpolate_mode = interpolate_mode + if use_abs_pos: + # Set position embedding + self.pos_embed = nn.Parameter( + torch.zeros(1, *self.patch_resolution, self.embed_dims)) + self.drop_after_pos = nn.Dropout(p=drop_rate) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + + if use_rel_pos: + self._register_load_state_dict_pre_hook( + self._prepare_relative_position) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' 
\ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + drop_rate=drop_rate, + drop_path_rate=dpr[i], + qkv_bias=qkv_bias, + window_size=window_size + if i not in self.global_attn_indexes else 0, + input_size=self.patch_resolution, + use_rel_pos=use_rel_pos, + norm_cfg=norm_cfg) + _layer_cfg.update(layer_cfgs[i]) + if 'type' in _layer_cfg: + self.layers.append(MODELS.build(_layer_cfg)) + else: + self.layers.append(TransformerEncoderLayer(**_layer_cfg)) + + self.out_channels = out_channels + if self.out_channels > 0: + self.channel_reduction = nn.Sequential( + nn.Conv2d( + self.embed_dims, + out_channels, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_channels, eps=1e-6), + nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_channels, eps=1e-6), + ) + + # freeze stages only when self.frozen_stages > 0 + self.frozen_stages = frozen_stages + if self.frozen_stages > 0: + self._freeze_stages() + + def init_weights(self): + super().init_weights() + + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + + def _freeze_stages(self): + # freeze position embedding + if self.pos_embed is not None: + self.pos_embed.requires_grad = False + # set dropout to eval model + self.drop_after_pos.eval() + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze channel_reduction module + if self.frozen_stages == self.num_layers and self.out_channels > 0: + m = self.channel_reduction + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]: + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + x = x.view(B, patch_resolution[0], patch_resolution[1], + self.embed_dims) + + if self.use_abs_pos: + # 'resize_pos_embed' only supports 'pos_embed' with ndim==3, but + # in ViTSAM, the 'pos_embed' has 4 dimensions (1, H, W, C), so it + # is flattened. Besides, ViTSAM doesn't have any extra token. 
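+            # e.g. a 1024px ViTSAM-large input keeps a (1, 64, 64, 1024)
+            # pos_embed; flatten(1, 2) yields the (1, 4096, 1024) layout
+            # that resize_pos_embed expects, and the result is viewed back
+            # to (1, H, W, C) below.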
+ resized_pos_embed = resize_pos_embed( + self.pos_embed.flatten(1, 2), + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=0) + x = x + resized_pos_embed.view(1, *patch_resolution, + self.embed_dims) + x = self.drop_after_pos(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if i in self.out_indices: + # (B, H, W, C) -> (B, C, H, W) + x_reshape = x.permute(0, 3, 1, 2) + + if self.out_channels > 0: + x_reshape = self.channel_reduction(x_reshape) + outs.append(self._format_output(x_reshape)) + + return tuple(outs) + + def _format_output(self, x) -> torch.Tensor: + if self.out_type == 'raw' or self.out_type == 'featmap': + return x + elif self.out_type == 'avg_featmap': + # (B, C, H, W) -> (B, C, N) -> (B, N, C) + x = x.flatten(2).permute(0, 2, 1) + return x.mean(dim=1) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = ckpt_pos_embed_shape[1:3] + pos_embed_shape = self.patch_embed.init_out_size + + flattened_pos_embed = state_dict[name].flatten(1, 2) + resized_pos_embed = resize_pos_embed(flattened_pos_embed, + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, 0) + state_dict[name] = resized_pos_embed.view(1, *pos_embed_shape, + self.embed_dims) + + def _prepare_relative_position(self, state_dict, prefix, *args, **kwargs): + state_dict_model = self.state_dict() + all_keys = list(state_dict_model.keys()) + for key in all_keys: + if 'rel_pos_' in key: + ckpt_key = prefix + key + if ckpt_key not in state_dict: + continue + relative_position_pretrained = state_dict[ckpt_key] + relative_position_current = state_dict_model[key] + L1, _ = relative_position_pretrained.size() + L2, _ = relative_position_current.size() + if L1 != L2: + new_rel_pos = F.interpolate( + relative_position_pretrained.reshape(1, L1, + -1).permute( + 0, 2, 1), + size=L2, + mode='linear', + ) + new_rel_pos = new_rel_pos.reshape(-1, L2).permute(1, 0) + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info(f'Resize the {ckpt_key} from ' + f'{state_dict[ckpt_key].shape} to ' + f'{new_rel_pos.shape}') + state_dict[ckpt_key] = new_rel_pos + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. 
+ + Note: + The first depth is the stem module (``layer_depth=0``), and the + last depth is the subsequent module (``layer_depth=num_layers-1``) + """ + num_layers = self.num_layers + 2 + + if not param_name.startswith(prefix): + # For subsequent module like head + return num_layers - 1, num_layers + + param_name = param_name[len(prefix):] + + if param_name in ('cls_token', 'pos_embed'): + layer_depth = 0 + elif param_name.startswith('patch_embed'): + layer_depth = 0 + elif param_name.startswith('layers'): + layer_id = int(param_name.split('.')[1]) + layer_depth = layer_id + 1 + else: + layer_depth = num_layers - 1 + + return layer_depth, num_layers diff --git a/mmpretrain/models/backbones/xcit.py b/mmpretrain/models/backbones/xcit.py new file mode 100644 index 0000000000000000000000000000000000000000..392ebbedf457cc199b70afa1923ec0b698f7fd5b --- /dev/null +++ b/mmpretrain/models/backbones/xcit.py @@ -0,0 +1,770 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from functools import partial +from typing import Optional, Sequence, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks import ConvModule, DropPath +from mmcv.cnn.bricks.transformer import FFN +from mmengine.model import BaseModule, Sequential +from mmengine.model.weight_init import trunc_normal_ +from mmengine.utils import digit_version + +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer, to_2tuple +from .base_backbone import BaseBackbone + +if digit_version(torch.__version__) < digit_version('1.8.0'): + floor_div = torch.floor_divide +else: + floor_div = partial(torch.div, rounding_mode='floor') + + +class ClassAttntion(BaseModule): + """Class Attention Module. + + A PyTorch implementation of Class Attention Module introduced by: + `Going deeper with Image Transformers `_ + + taken from + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + with slight modifications to do CA + + Args: + dim (int): The feature dimension. + num_heads (int): Parallel attention heads. Defaults to 8. + qkv_bias (bool): enable bias for qkv if True. Defaults to False. + attn_drop (float): The drop out rate for attention output weights. + Defaults to 0. + proj_drop (float): The drop out rate for linear output weights. + Defaults to 0. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. + """ # noqa: E501 + + def __init__(self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + init_cfg=None): + + super(ClassAttntion, self).__init__(init_cfg=init_cfg) + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.k = nn.Linear(dim, dim, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + # We only need to calculate query of cls token. 
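+        # This keeps class attention linear in N: q is a single
+        # (B, num_heads, 1, C // num_heads) query while k and v cover all
+        # N tokens, so attn is (B, num_heads, 1, N) rather than the usual
+        # (B, num_heads, N, N).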
+ q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads, + C // self.num_heads).permute( + 0, 2, 1, 3) + k = self.k(x).reshape(B, N, self.num_heads, + C // self.num_heads).permute(0, 2, 1, 3) + + q = q * self.scale + v = self.v(x).reshape(B, N, self.num_heads, + C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C) + x_cls = self.proj(x_cls) + x_cls = self.proj_drop(x_cls) + + return x_cls + + +class PositionalEncodingFourier(BaseModule): + """Positional Encoding using a fourier kernel. + + A PyTorch implementation of Positional Encoding relying on + a fourier kernel introduced by: + `Attention is all you Need `_ + + Based on the `official XCiT code + `_ + + Args: + hidden_dim (int): The hidden feature dimension. Defaults to 32. + dim (int): The output feature dimension. Defaults to 768. + temperature (int): A control variable for position encoding. + Defaults to 10000. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + hidden_dim: int = 32, + dim: int = 768, + temperature: int = 10000, + init_cfg=None): + super(PositionalEncodingFourier, self).__init__(init_cfg=init_cfg) + + self.token_projection = ConvModule( + in_channels=hidden_dim * 2, + out_channels=dim, + kernel_size=1, + conv_cfg=None, + norm_cfg=None, + act_cfg=None) + self.scale = 2 * math.pi + self.temperature = temperature + self.hidden_dim = hidden_dim + self.dim = dim + self.eps = 1e-6 + + def forward(self, B: int, H: int, W: int): + device = self.token_projection.conv.weight.device + y_embed = torch.arange( + 1, H + 1, device=device).unsqueeze(1).repeat(1, 1, W).float() + x_embed = torch.arange(1, W + 1, device=device).repeat(1, H, 1).float() + y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale + + dim_t = torch.arange(self.hidden_dim, device=device).float() + dim_t = floor_div(dim_t, 2) + dim_t = self.temperature**(2 * dim_t / self.hidden_dim) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + [pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()], + dim=4).flatten(3) + pos_y = torch.stack( + [pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()], + dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + pos = self.token_projection(pos) + return pos.repeat(B, 1, 1, 1) # (B, C, H, W) + + +class ConvPatchEmbed(BaseModule): + """Patch Embedding using multiple convolution layers. + + Args: + img_size (int, tuple): input image size. + Defaults to 224, means the size is 224*224. + patch_size (int): The patch size in conv patch embedding. + Defaults to 16. + in_channels (int): The input channels of this module. + Defaults to 3. + embed_dims (int): The feature dimension + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='GELU')``. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. 
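+
+    Examples:
+        A shape sketch, assuming the defaults above:
+
+        >>> import torch
+        >>> embed = ConvPatchEmbed(img_size=224, patch_size=16,
+        ...                        embed_dims=768)
+        >>> tokens, (Hp, Wp) = embed(torch.randn(1, 3, 224, 224))
+        >>> tokens.shape, (Hp, Wp)
+        (torch.Size([1, 196, 768]), (14, 14))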
+ """ + + def __init__(self, + img_size: Union[int, tuple] = 224, + patch_size: int = 16, + in_channels: int = 3, + embed_dims: int = 768, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='GELU'), + init_cfg=None): + super(ConvPatchEmbed, self).__init__(init_cfg=init_cfg) + img_size = to_2tuple(img_size) + num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + conv = partial( + ConvModule, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + + layer = [] + if patch_size == 16: + layer.append( + conv(in_channels=in_channels, out_channels=embed_dims // 8)) + layer.append( + conv( + in_channels=embed_dims // 8, out_channels=embed_dims // 4)) + elif patch_size == 8: + layer.append( + conv(in_channels=in_channels, out_channels=embed_dims // 4)) + else: + raise ValueError('For patch embedding, the patch size must be 16 ' + f'or 8, but get patch size {self.patch_size}.') + + layer.append( + conv(in_channels=embed_dims // 4, out_channels=embed_dims // 2)) + layer.append( + conv( + in_channels=embed_dims // 2, + out_channels=embed_dims, + act_cfg=None, + )) + + self.proj = Sequential(*layer) + + def forward(self, x: torch.Tensor): + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + x = x.flatten(2).transpose(1, 2) # (B, N, C) + return x, (Hp, Wp) + + +class ClassAttentionBlock(BaseModule): + """Transformer block using Class Attention. + + Args: + dim (int): The feature dimension. + num_heads (int): Parallel attention heads. + mlp_ratio (float): The hidden dimension ratio for FFN. + Defaults to 4. + qkv_bias (bool): enable bias for qkv if True. Defaults to False. + drop (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): The initial value for layer scale. + Defaults to 1. + tokens_norm (bool): Whether to normalize all tokens or just the + cls_token in the CA. Defaults to False. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN', eps=1e-6)``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='GELU')``. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + drop=0., + attn_drop=0., + drop_path=0., + layer_scale_init_value=1., + tokens_norm=False, + norm_cfg=dict(type='LN', eps=1e-6), + act_cfg=dict(type='GELU'), + init_cfg=None): + + super(ClassAttentionBlock, self).__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, dim) + + self.attn = ClassAttntion( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
+
+        self.norm2 = build_norm_layer(norm_cfg, dim)
+
+        self.ffn = FFN(
+            embed_dims=dim,
+            feedforward_channels=int(dim * mlp_ratio),
+            act_cfg=act_cfg,
+            ffn_drop=drop,
+        )
+
+        if layer_scale_init_value > 0:
+            self.gamma1 = nn.Parameter(layer_scale_init_value *
+                                       torch.ones(dim))
+            self.gamma2 = nn.Parameter(layer_scale_init_value *
+                                       torch.ones(dim))
+        else:
+            self.gamma1, self.gamma2 = 1.0, 1.0
+
+        # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721  # noqa: E501
+        self.tokens_norm = tokens_norm
+
+    def forward(self, x):
+        x_norm1 = self.norm1(x)
+        x_attn = torch.cat([self.attn(x_norm1), x_norm1[:, 1:]], dim=1)
+        x = x + self.drop_path(self.gamma1 * x_attn)
+        if self.tokens_norm:
+            x = self.norm2(x)
+        else:
+            x = torch.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1)
+        x_res = x
+        cls_token = x[:, 0:1]
+        cls_token = self.gamma2 * self.ffn(cls_token, identity=0)
+        x = torch.cat([cls_token, x[:, 1:]], dim=1)
+        x = x_res + self.drop_path(x)
+        return x
+
+
+class LPI(BaseModule):
+    """Local Patch Interaction module.
+
+    A PyTorch implementation of the Local Patch Interaction module
+    as in XCiT introduced by `XCiT: Cross-Covariance Image Transformers
+    <https://arxiv.org/abs/2106.09681>`_
+
+    Local Patch Interaction module that allows explicit communication between
+    tokens in 3x3 windows to augment the implicit communication performed by
+    the block diagonal scatter attention. Implemented using 2 layers of
+    separable 3x3 convolutions with GELU and BatchNorm2d.
+
+    Args:
+        in_features (int): The input channels.
+        out_features (int, optional): The output channels. Defaults to None.
+        kernel_size (int): The kernel_size in ConvModule. Defaults to 3.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='BN')``.
+        act_cfg (dict): Config dict for activation layer.
+            Defaults to ``dict(type='GELU')``.
+        init_cfg (dict | list[dict], optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 in_features: int,
+                 out_features: Optional[int] = None,
+                 kernel_size: int = 3,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='GELU'),
+                 init_cfg=None):
+        super(LPI, self).__init__(init_cfg=init_cfg)
+
+        out_features = out_features or in_features
+        padding = kernel_size // 2
+
+        self.conv1 = ConvModule(
+            in_channels=in_features,
+            out_channels=in_features,
+            kernel_size=kernel_size,
+            padding=padding,
+            groups=in_features,
+            bias=True,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg,
+            order=('conv', 'act', 'norm'))
+
+        self.conv2 = ConvModule(
+            in_channels=in_features,
+            out_channels=out_features,
+            kernel_size=kernel_size,
+            padding=padding,
+            groups=out_features,
+            norm_cfg=None,
+            act_cfg=None)
+
+    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        B, N, C = x.shape
+        x = x.permute(0, 2, 1).reshape(B, C, H, W)
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = x.reshape(B, C, N).permute(0, 2, 1)
+        return x
+
+
+class XCA(BaseModule):
+    r"""Cross-Covariance Attention module.
+
+    A PyTorch implementation of the Cross-Covariance Attention module
+    as in XCiT introduced by `XCiT: Cross-Covariance Image Transformers
+    <https://arxiv.org/abs/2106.09681>`_
+
+    In Cross-Covariance Attention (XCA), the channels are updated using a
+    weighted sum. The weights are obtained from the (softmax normalized)
+    Cross-covariance matrix :math:`(Q^T \cdot K \in d_h \times d_h)`
+
+    Args:
+        dim (int): The feature dimension.
+        num_heads (int): Parallel attention heads. Defaults to 8.
+        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
+        attn_drop (float): The drop out rate for attention output weights.
+            Defaults to 0.
+        proj_drop (float): The drop out rate for linear output weights.
+            Defaults to 0.
+        init_cfg (dict | list[dict], optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 dim: int,
+                 num_heads: int = 8,
+                 qkv_bias: bool = False,
+                 attn_drop: float = 0.,
+                 proj_drop: float = 0.,
+                 init_cfg=None):
+        super(XCA, self).__init__(init_cfg=init_cfg)
+        self.num_heads = num_heads
+        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, N, C = x.shape
+        # (qkv, B, num_heads, channels per head, N)
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
+                                  C // self.num_heads).permute(2, 0, 3, 4, 1)
+        q, k, v = qkv.unbind(0)
+
+        # Paper section 3.2 l2-Normalization and temperature scaling
+        q = F.normalize(q, dim=-1)
+        k = F.normalize(k, dim=-1)
+        attn = (q @ k.transpose(-2, -1)) * self.temperature
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        # (B, num_heads, C', N) -> (B, N, num_heads, C') -> (B, N, C)
+        x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class XCABlock(BaseModule):
+    """Transformer block using XCA.
+
+    Args:
+        dim (int): The feature dimension.
+        num_heads (int): Parallel attention heads.
+        mlp_ratio (float): The hidden dimension ratio for FFNs.
+            Defaults to 4.
+        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
+        drop (float): Probability of an element to be zeroed
+            after the feed forward layer. Defaults to 0.
+        attn_drop (float): The drop out rate for attention output weights.
+            Defaults to 0.
+        drop_path (float): Stochastic depth rate. Defaults to 0.
+        layer_scale_init_value (float): The initial value for layer scale.
+            Defaults to 1.
+        bn_norm_cfg (dict): Config dict for batchnorm in LPI and
+            ConvPatchEmbed. Defaults to ``dict(type='BN')``.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='LN', eps=1e-6)``.
+        act_cfg (dict): Config dict for activation layer.
+            Defaults to ``dict(type='GELU')``.
+        init_cfg (dict | list[dict], optional): Initialization config dict.
+    """
+
+    def __init__(self,
+                 dim: int,
+                 num_heads: int,
+                 mlp_ratio: float = 4.,
+                 qkv_bias: bool = False,
+                 drop: float = 0.,
+                 attn_drop: float = 0.,
+                 drop_path: float = 0.,
+                 layer_scale_init_value: float = 1.,
+                 bn_norm_cfg=dict(type='BN'),
+                 norm_cfg=dict(type='LN', eps=1e-6),
+                 act_cfg=dict(type='GELU'),
+                 init_cfg=None):
+        super(XCABlock, self).__init__(init_cfg=init_cfg)
+
+        self.norm1 = build_norm_layer(norm_cfg, dim)
+        self.attn = XCA(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+        )
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
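+        # XCA above attends over channels, not tokens: the d_h x d_h
+        # cross-covariance matrix softmax(norm(q) @ norm(k)^T * temperature)
+        # mixes the C // num_heads channels of every token, so its cost is
+        # linear in the number of tokens N. The LPI module below restores the
+        # local token-to-token mixing that XCA alone does not provide.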
+
+        self.norm3 = build_norm_layer(norm_cfg, dim)
+        self.local_mp = LPI(
+            in_features=dim,
+            norm_cfg=bn_norm_cfg,
+            act_cfg=act_cfg,
+        )
+
+        self.norm2 = build_norm_layer(norm_cfg, dim)
+        self.ffn = FFN(
+            embed_dims=dim,
+            feedforward_channels=int(dim * mlp_ratio),
+            act_cfg=act_cfg,
+            ffn_drop=drop,
+        )
+
+        self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
+        self.gamma3 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
+        self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
+
+    def forward(self, x, H: int, W: int):
+        x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
+        # NOTE: the official code applies norm3 before norm2, so we keep the
+        # same order to stay consistent with loaded weights. See
+        # https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721  # noqa: E501
+        x = x + self.drop_path(
+            self.gamma3 * self.local_mp(self.norm3(x), H, W))
+        x = x + self.drop_path(
+            self.gamma2 * self.ffn(self.norm2(x), identity=0))
+        return x
+
+
+@MODELS.register_module()
+class XCiT(BaseBackbone):
+    """XCiT backbone.
+
+    A PyTorch implementation of the XCiT backbone introduced by:
+    `XCiT: Cross-Covariance Image Transformers
+    <https://arxiv.org/abs/2106.09681>`_
+
+    Args:
+        img_size (int, tuple): Input image size. Defaults to 224.
+        patch_size (int): Patch size. Defaults to 16.
+        in_channels (int): Number of input channels. Defaults to 3.
+        embed_dims (int): Embedding dimension. Defaults to 768.
+        depth (int): The depth of the vision transformer. Defaults to 12.
+        cls_attn_layers (int): Depth of Class attention layers.
+            Defaults to 2.
+        num_heads (int): Number of attention heads. Defaults to 12.
+        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
+            Defaults to 4.
+        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
+        drop_rate (float): Probability of an element to be zeroed
+            after the feed forward layer. Defaults to 0.
+        attn_drop_rate (float): The drop out rate for attention output
+            weights. Defaults to 0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        use_pos_embed (bool): Whether to use positional encoding.
+            Defaults to True.
+        layer_scale_init_value (float): The initial value for layer scale.
+            Defaults to 1.
+        tokens_norm (bool): Whether to normalize all tokens or just the
+            cls_token in the CA. Defaults to False.
+        out_type (str): The type of output features. Allowed values are
+            ``'raw'``, ``'featmap'``, ``'avg_featmap'`` and ``'cls_token'``.
+            Defaults to ``'cls_token'``.
+        out_indices (Sequence[int]): Output from which layers.
+            Defaults to (-1, ).
+        final_norm (bool): Whether to apply a final normalization layer
+            before formatting the last output. Defaults to True.
+        frozen_stages (int): Layers to be frozen (all param fixed), and 0
+            means to freeze the stem stage. Defaults to -1, which means
+            not freeze any parameters.
+        bn_norm_cfg (dict): Config dict for the batch norm layers in LPI and
+            ConvPatchEmbed. Defaults to ``dict(type='BN')``.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='LN', eps=1e-6)``.
+        act_cfg (dict): Config dict for activation layer.
+            Defaults to ``dict(type='GELU')``.
+        init_cfg (dict | list[dict], optional): Initialization config dict.
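+
+    Example:
+        A minimal usage sketch (assuming ``XCiT`` is exported from
+        ``mmpretrain.models``; with the defaults the only output is the
+        cls token of the last class-attention block):
+
+        >>> import torch
+        >>> from mmpretrain.models import XCiT
+        >>> model = XCiT(img_size=224, patch_size=16, out_type='cls_token')
+        >>> outs = model(torch.rand(1, 3, 224, 224))
+        >>> print(outs[-1].shape)
+        torch.Size([1, 768])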
+ """ + + def __init__(self, + img_size: Union[int, tuple] = 224, + patch_size: int = 16, + in_channels: int = 3, + embed_dims: int = 768, + depth: int = 12, + cls_attn_layers: int = 2, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = True, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + use_pos_embed: bool = True, + layer_scale_init_value: float = 1., + tokens_norm: bool = False, + out_type: str = 'cls_token', + out_indices: Sequence[int] = (-1, ), + final_norm: bool = True, + frozen_stages: int = -1, + bn_norm_cfg=dict(type='BN'), + norm_cfg=dict(type='LN', eps=1e-6), + act_cfg=dict(type='GELU'), + init_cfg=dict(type='TruncNormal', layer='Linear')): + super(XCiT, self).__init__(init_cfg=init_cfg) + + img_size = to_2tuple(img_size) + if (img_size[0] % patch_size != 0) or (img_size[1] % patch_size != 0): + raise ValueError(f'`patch_size` ({patch_size}) should divide ' + f'the image shape ({img_size}) evenly.') + + self.embed_dims = embed_dims + + assert out_type in ('raw', 'featmap', 'avg_featmap', 'cls_token') + self.out_type = out_type + + self.patch_embed = ConvPatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dims=embed_dims, + norm_cfg=bn_norm_cfg, + act_cfg=act_cfg, + ) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + self.use_pos_embed = use_pos_embed + if use_pos_embed: + self.pos_embed = PositionalEncodingFourier(dim=embed_dims) + self.pos_drop = nn.Dropout(p=drop_rate) + + self.xca_layers = nn.ModuleList() + self.ca_layers = nn.ModuleList() + self.num_layers = depth + cls_attn_layers + + for _ in range(depth): + self.xca_layers.append( + XCABlock( + dim=embed_dims, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + bn_norm_cfg=bn_norm_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + layer_scale_init_value=layer_scale_init_value, + )) + + for _ in range(cls_attn_layers): + self.ca_layers.append( + ClassAttentionBlock( + dim=embed_dims, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + layer_scale_init_value=layer_scale_init_value, + tokens_norm=tokens_norm, + )) + + if final_norm: + self.norm = build_norm_layer(norm_cfg, embed_dims) + + # Transform out_indices + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + out_indices = list(out_indices) + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, \ + f'Invalid out_indices {index}.' 
+        self.out_indices = out_indices
+
+        if frozen_stages > self.num_layers + 1:
+            raise ValueError('frozen_stages must be no greater than '
+                             f'{self.num_layers + 1}, but got '
+                             f'{frozen_stages}')
+        self.frozen_stages = frozen_stages
+
+    def init_weights(self):
+        super().init_weights()
+
+        if self.init_cfg is not None and self.init_cfg['type'] == 'Pretrained':
+            return
+
+        trunc_normal_(self.cls_token, std=.02)
+
+    def _freeze_stages(self):
+        if self.frozen_stages < 0:
+            return
+
+        # freeze position embedding
+        if self.use_pos_embed:
+            self.pos_embed.eval()
+            for param in self.pos_embed.parameters():
+                param.requires_grad = False
+        # freeze patch embedding
+        self.patch_embed.eval()
+        for param in self.patch_embed.parameters():
+            param.requires_grad = False
+        # set dropout to eval mode
+        self.pos_drop.eval()
+        # freeze cls_token; it is only used in the class-attention layers
+        if self.frozen_stages > len(self.xca_layers):
+            self.cls_token.requires_grad = False
+        # freeze layers
+        for i in range(1, self.frozen_stages):
+            if i <= len(self.xca_layers):
+                m = self.xca_layers[i - 1]
+            else:
+                m = self.ca_layers[i - len(self.xca_layers) - 1]
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+        # freeze the last layer norm if all stages are frozen
+        if self.frozen_stages == len(self.xca_layers) + len(self.ca_layers):
+            self.norm.eval()
+            for param in self.norm.parameters():
+                param.requires_grad = False
+
+    def forward(self, x):
+        outs = []
+        B = x.shape[0]
+        # x is (B, N, C). (Hp, Wp) is the patch resolution
+        x, (Hp, Wp) = self.patch_embed(x)
+
+        if self.use_pos_embed:
+            # (B, C, Hp, Wp) -> (B, C, N) -> (B, N, C)
+            pos_encoding = self.pos_embed(B, Hp, Wp)
+            x = x + pos_encoding.reshape(B, -1, x.size(1)).permute(0, 2, 1)
+        x = self.pos_drop(x)
+
+        for i, layer in enumerate(self.xca_layers):
+            x = layer(x, Hp, Wp)
+            if i in self.out_indices:
+                outs.append(self._format_output(x, (Hp, Wp), False))
+
+        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
+
+        for i, layer in enumerate(self.ca_layers):
+            x = layer(x)
+            if i == len(self.ca_layers) - 1:
+                x = self.norm(x)
+            if i + len(self.xca_layers) in self.out_indices:
+                outs.append(self._format_output(x, (Hp, Wp), True))
+
+        return tuple(outs)
+
+    def _format_output(self, x, hw, with_cls_token: bool):
+        if self.out_type == 'raw':
+            return x
+        if self.out_type == 'cls_token':
+            if not with_cls_token:
+                raise ValueError(
+                    'Cannot output cls_token since there is no cls_token.')
+            return x[:, 0]
+
+        patch_token = x[:, 1:] if with_cls_token else x
+        if self.out_type == 'featmap':
+            B = x.size(0)
+            # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
+            return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
+        if self.out_type == 'avg_featmap':
+            return patch_token.mean(dim=1)
+
+    def train(self, mode=True):
+        super().train(mode)
+        self._freeze_stages()
diff --git a/mmpretrain/models/builder.py b/mmpretrain/models/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ea4e25c8d6db3bbf07ab94ea08c08e474ec3595
--- /dev/null
+++ b/mmpretrain/models/builder.py
@@ -0,0 +1,39 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpretrain.registry import MODELS
+
+BACKBONES = MODELS
+NECKS = MODELS
+HEADS = MODELS
+LOSSES = MODELS
+CLASSIFIERS = MODELS
+RETRIEVER = MODELS
+
+
+def build_backbone(cfg):
+    """Build backbone."""
+    return BACKBONES.build(cfg)
+
+
+def build_neck(cfg):
+    """Build neck."""
+    return NECKS.build(cfg)
+
+
+def build_head(cfg):
+    """Build head."""
+    return HEADS.build(cfg)
+
+
+def build_loss(cfg):
+    """Build loss."""
+    return LOSSES.build(cfg)
+
+
+def build_classifier(cfg):
+    """Build classifier."""
+    return CLASSIFIERS.build(cfg)
+
+
+def build_retriever(cfg):
+    """Build retriever."""
+    return RETRIEVER.build(cfg)
diff --git a/mmpretrain/models/classifiers/__init__.py b/mmpretrain/models/classifiers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fa276ff5a2152beb93c4d1b42e6bbf4e2cbf822
--- /dev/null
+++ b/mmpretrain/models/classifiers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseClassifier
+from .hugging_face import HuggingFaceClassifier
+from .image import ImageClassifier
+from .timm import TimmClassifier
+
+__all__ = [
+    'BaseClassifier', 'ImageClassifier', 'TimmClassifier',
+    'HuggingFaceClassifier'
+]
diff --git a/mmpretrain/models/classifiers/base.py b/mmpretrain/models/classifiers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..a65fc213f4bfe271a9298b823ba38fc4ca9f57e1
--- /dev/null
+++ b/mmpretrain/models/classifiers/base.py
@@ -0,0 +1,108 @@
+# Copyright (c) OpenMMLab. All rights reserved.
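All six builder aliases above point at the same MODELS registry, so the legacy helpers are interchangeable with a plain MODELS.build call. A minimal sketch (editorial example; assumes a registered ResNet backbone config, but any registered type builds the same way):

    from mmpretrain.models import build_backbone

    # BACKBONES, NECKS, HEADS, ... are all aliases of the MODELS registry,
    # so the config's `type` key alone decides which class gets built.
    cfg = dict(type='ResNet', depth=18)
    backbone = build_backbone(cfg)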
+from abc import ABCMeta, abstractmethod
+from typing import List, Optional, Sequence
+
+import torch
+from mmengine.model import BaseModel
+from mmengine.structures import BaseDataElement
+
+
+class BaseClassifier(BaseModel, metaclass=ABCMeta):
+    """Base class for classifiers.
+
+    Args:
+        init_cfg (dict, optional): Initialization config dict.
+            Defaults to None.
+        data_preprocessor (dict, optional): The config for preprocessing input
+            data. If None, it will use "BaseDataPreprocessor" as type, see
+            :class:`mmengine.model.BaseDataPreprocessor` for more details.
+            Defaults to None.
+
+    Attributes:
+        init_cfg (dict): Initialization config dict.
+        data_preprocessor (:obj:`mmengine.model.BaseDataPreprocessor`): An
+            extra data pre-processing module, which processes data from
+            dataloader to the format accepted by :meth:`forward`.
+    """
+
+    def __init__(self,
+                 init_cfg: Optional[dict] = None,
+                 data_preprocessor: Optional[dict] = None):
+        super(BaseClassifier, self).__init__(
+            init_cfg=init_cfg, data_preprocessor=data_preprocessor)
+
+    @property
+    def with_neck(self) -> bool:
+        """Whether the classifier has a neck."""
+        return hasattr(self, 'neck') and self.neck is not None
+
+    @property
+    def with_head(self) -> bool:
+        """Whether the classifier has a head."""
+        return hasattr(self, 'head') and self.head is not None
+
+    @abstractmethod
+    def forward(self,
+                inputs: torch.Tensor,
+                data_samples: Optional[List[BaseDataElement]] = None,
+                mode: str = 'tensor'):
+        """The unified entry for a forward process in both training and test.
+
+        The method should accept three modes: "tensor", "predict" and "loss":
+
+        - "tensor": Forward the whole network and return tensor or tuple of
+          tensor without any post-processing, same as a common nn.Module.
+        - "predict": Forward and return the predictions, which are fully
+          processed to a list of :obj:`BaseDataElement`.
+        - "loss": Forward and return a dict of losses according to the given
+          inputs and data samples.
+
+        Note that this method handles neither back propagation nor
+        optimizer updating; both are done in :meth:`train_step`.
+
+        Args:
+            inputs (torch.Tensor): The input tensor with shape (N, C, ...)
+                in general.
+            data_samples (List[BaseDataElement], optional): The annotation
+                data of every sample. It's required if ``mode="loss"``.
+                Defaults to None.
+            mode (str): Return what kind of value. Defaults to 'tensor'.
+
+        Returns:
+            The return type depends on ``mode``.
+
+            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
+            - If ``mode="predict"``, return a list of
+              :obj:`mmengine.BaseDataElement`.
+            - If ``mode="loss"``, return a dict of tensor.
+        """
+        pass
+
+    def extract_feat(self, inputs: torch.Tensor):
+        """Extract features from the input tensor with shape (N, C, ...).
+
+        The sub-classes are recommended to implement this method to extract
+        features from backbone and neck.
+
+        Args:
+            inputs (Tensor): A batch of inputs. The shape of it should be
+                ``(num_samples, num_channels, *img_shape)``.
+        """
+        raise NotImplementedError
+
+    def extract_feats(self, multi_inputs: Sequence[torch.Tensor],
+                      **kwargs) -> list:
+        """Extract features from a sequence of input tensors.
+
+        Args:
+            multi_inputs (Sequence[torch.Tensor]): A sequence of input
+                tensors. It can be used in augmented inference.
+            **kwargs: Other keyword arguments accepted by
+                :meth:`extract_feat`.
+
+        Returns:
+            list: Features of every input tensor.
+        """
+        assert isinstance(multi_inputs, Sequence), \
+            '`extract_feats` is used for a sequence of input tensors. If ' \
+            'you want to extract features from a single input tensor, use ' \
+            '`extract_feat`.'
+        return [self.extract_feat(inputs, **kwargs) for inputs in multi_inputs]
diff --git a/mmpretrain/models/classifiers/hugging_face.py b/mmpretrain/models/classifiers/hugging_face.py
new file mode 100644
index 0000000000000000000000000000000000000000..26a8fda51b0d01ee54ba71665caedbb8a7bd842c
--- /dev/null
+++ b/mmpretrain/models/classifiers/hugging_face.py
@@ -0,0 +1,222 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import re
+from collections import OrderedDict
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmpretrain.registry import MODELS
+from mmpretrain.structures import DataSample
+from mmpretrain.utils import require
+from .base import BaseClassifier
+
+
+@MODELS.register_module()
+class HuggingFaceClassifier(BaseClassifier):
+    """Image classifier for HuggingFace models.
+
+    This class accepts all positional and keyword arguments of the API
+    ``from_pretrained`` (when ``pretrained=True``) and ``from_config`` (when
+    ``pretrained=False``) of `transformers.AutoModelForImageClassification`_
+    and uses it to create a model from hugging-face.
+
+    It can load hugging-face checkpoints directly, and the saved checkpoints
+    can also be loaded by hugging-face directly.
+
+    Please confirm that you have installed ``transformers`` if you want to
+    use it.
+
+    .. _transformers.AutoModelForImageClassification:
+        https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForImageClassification
+
+    Args:
+        model_name (str): The name of the model to use in hugging-face.
+        pretrained (bool): Whether to load pretrained checkpoint from
+            hugging-face. Defaults to False.
+        *args: Other positional arguments of the method
+            `from_pretrained` or `from_config`.
+        loss (dict): Config of classification loss. Defaults to
+            ``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
+        train_cfg (dict, optional): The training setting. The acceptable
+            fields are:
+
+            - augments (List[dict]): The batch augmentation methods to use.
+              More details can be found in
+              :mod:`mmpretrain.model.utils.augment`.
+
+            Defaults to None.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Defaults to False.
+        data_preprocessor (dict, optional): The config for preprocessing input
+            data. If None or no specified type, it will use
+            "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor`
+            for more details. Defaults to None.
+        init_cfg (dict, optional): the config to control the initialization.
+            Defaults to None.
+        **kwargs: Other keyword arguments of the method
+            `from_pretrained` or `from_config`.
+
+    Examples:
+        >>> import torch
+        >>> from mmpretrain.models import build_classifier
+        >>> cfg = dict(type='HuggingFaceClassifier', model_name='microsoft/resnet-50', pretrained=True)
+        >>> model = build_classifier(cfg)
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> out = model(inputs)
+        >>> print(out.shape)
+        torch.Size([1, 1000])
+    """  # noqa: E501
+
+    @require('transformers')
+    def __init__(self,
+                 model_name,
+                 pretrained=False,
+                 *model_args,
+                 loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+                 train_cfg: Optional[dict] = None,
+                 with_cp: bool = False,
+                 data_preprocessor: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None,
+                 **kwargs):
+        if data_preprocessor is None:
+            data_preprocessor = {}
+        # The build process is in MMEngine, so we need to add scope here.
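+        # Without the 'mmpretrain.' prefix the registry would resolve
+        # 'ClsDataPreprocessor' in whatever scope is active at build time,
+        # so the explicit scope keeps this working when the classifier is
+        # built from another OpenMMLab codebase.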
+ data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor') + + if train_cfg is not None and 'augments' in train_cfg: + # Set batch augmentations by `train_cfg` + data_preprocessor['batch_augments'] = train_cfg + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + from transformers import AutoConfig, AutoModelForImageClassification + if pretrained: + self.model = AutoModelForImageClassification.from_pretrained( + model_name, *model_args, **kwargs) + else: + config = AutoConfig.from_pretrained(model_name, *model_args, + **kwargs) + self.model = AutoModelForImageClassification.from_config(config) + + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module = loss + + self.with_cp = with_cp + if self.with_cp: + self.model.gradient_checkpointing_enable() + + self._register_state_dict_hook(self._remove_state_dict_prefix) + self._register_load_state_dict_pre_hook(self._add_state_dict_prefix) + + def forward(self, inputs, data_samples=None, mode='tensor'): + if mode == 'tensor': + return self.model(inputs).logits + elif mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + return self.predict(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, inputs: torch.Tensor): + raise NotImplementedError( + "The HuggingFaceClassifier doesn't support extract feature yet.") + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs): + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments of the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # The part can be traced by torch.fx + cls_score = self.model(inputs).logits + + # The part can not be traced by torch.fx + losses = self._get_loss(cls_score, data_samples, **kwargs) + return losses + + def _get_loss(self, cls_score: torch.Tensor, + data_samples: List[DataSample], **kwargs): + """Unpack data samples and compute loss.""" + # Unpack data samples and pack targets + if 'gt_score' in data_samples[0]: + # Batch augmentation may convert labels to one-hot format scores. + target = torch.stack([i.gt_score for i in data_samples]) + else: + target = torch.cat([i.gt_label for i in data_samples]) + + # compute loss + losses = dict() + loss = self.loss_module( + cls_score, target, avg_factor=cls_score.size(0), **kwargs) + losses['loss'] = loss + + return losses + + def predict(self, + inputs: torch.Tensor, + data_samples: Optional[List[DataSample]] = None): + """Predict results from a batch of inputs. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + + Returns: + List[DataSample]: The prediction results. + """ + # The part can be traced by torch.fx + cls_score = self.model(inputs).logits + + # The part can not be traced by torch.fx + predictions = self._get_predictions(cls_score, data_samples) + return predictions + + def _get_predictions(self, cls_score, data_samples): + """Post-process the output of head. + + Including softmax and set ``pred_label`` of data samples. 
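+
+        Note that the scores are softmax-normalized over classes, i.e. this
+        post-processing assumes a single-label classification task.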
+ """ + pred_scores = F.softmax(cls_score, dim=1) + pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() + + if data_samples is not None: + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + data_sample.set_pred_score(score).set_pred_label(label) + else: + data_samples = [] + for score, label in zip(pred_scores, pred_labels): + data_samples.append( + DataSample().set_pred_score(score).set_pred_label(label)) + + return data_samples + + @staticmethod + def _remove_state_dict_prefix(self, state_dict, prefix, local_metadata): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + new_key = re.sub(f'^{prefix}model.', prefix, k) + new_state_dict[new_key] = v + return new_state_dict + + @staticmethod + def _add_state_dict_prefix(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + new_prefix = prefix + 'model.' + for k in list(state_dict.keys()): + new_key = re.sub(f'^{prefix}', new_prefix, k) + state_dict[new_key] = state_dict[k] + del state_dict[k] diff --git a/mmpretrain/models/classifiers/image.py b/mmpretrain/models/classifiers/image.py new file mode 100644 index 0000000000000000000000000000000000000000..6d0edd7aed8ce34a11b6cbbbdf2034bbcd1c652b --- /dev/null +++ b/mmpretrain/models/classifiers/image.py @@ -0,0 +1,265 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseClassifier + + +@MODELS.register_module() +class ImageClassifier(BaseClassifier): + """Image classifiers for supervised classification task. + + Args: + backbone (dict): The backbone module. See + :mod:`mmpretrain.models.backbones`. + neck (dict, optional): The neck module to process features from + backbone. See :mod:`mmpretrain.models.necks`. Defaults to None. + head (dict, optional): The head module to do prediction and calculate + loss from processed features. See :mod:`mmpretrain.models.heads`. + Notice that if the head is not set, almost all methods cannot be + used except :meth:`extract_feat`. Defaults to None. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + train_cfg (dict, optional): The training setting. The acceptable + fields are: + + - augments (List[dict]): The batch augmentation methods to use. + More details can be found in + :mod:`mmpretrain.model.utils.augment`. + - probs (List[float], optional): The probability of every batch + augmentation methods. If None, choose evenly. Defaults to None. + + Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None or no specified type, it will use + "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for + more details. Defaults to None. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. 
+ """ + + def __init__(self, + backbone: dict, + neck: Optional[dict] = None, + head: Optional[dict] = None, + pretrained: Optional[str] = None, + train_cfg: Optional[dict] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if pretrained is not None: + init_cfg = dict(type='Pretrained', checkpoint=pretrained) + + data_preprocessor = data_preprocessor or {} + + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'ClsDataPreprocessor') + data_preprocessor.setdefault('batch_augments', train_cfg) + data_preprocessor = MODELS.build(data_preprocessor) + elif not isinstance(data_preprocessor, nn.Module): + raise TypeError('data_preprocessor should be a `dict` or ' + f'`nn.Module` instance, but got ' + f'{type(data_preprocessor)}') + + super(ImageClassifier, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + if not isinstance(backbone, nn.Module): + backbone = MODELS.build(backbone) + if neck is not None and not isinstance(neck, nn.Module): + neck = MODELS.build(neck) + if head is not None and not isinstance(head, nn.Module): + head = MODELS.build(head) + + self.backbone = backbone + self.neck = neck + self.head = head + + # If the model needs to load pretrain weights from a third party, + # the key can be modified with this hook + if hasattr(self.backbone, '_checkpoint_filter'): + self._register_load_state_dict_pre_hook( + self.backbone._checkpoint_filter) + + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + mode: str = 'tensor'): + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor(s) without any + post-processing, same as a common PyTorch Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="predict"``, return a list of + :obj:`mmpretrain.structures.DataSample`. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'tensor': + feats = self.extract_feat(inputs) + return self.head(feats) if self.with_head else feats + elif mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + return self.predict(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, inputs, stage='neck'): + """Extract features from the input tensor with shape (N, C, ...). + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + stage (str): Which stage to output the feature. Choose from: + + - "backbone": The output of backbone network. Returns a tuple + including multiple stages features. + - "neck": The output of neck module. Returns a tuple including + multiple stages features. + - "pre_logits": The feature before the final classification + linear layer. Usually returns a tensor. + + Defaults to "neck". 
+ + Returns: + tuple | Tensor: The output of specified stage. + The output depends on detailed implementation. In general, the + output of backbone and neck is a tuple and the output of + pre_logits is a tensor. + + Examples: + 1. Backbone output + + >>> import torch + >>> from mmengine import Config + >>> from mmpretrain.models import build_classifier + >>> + >>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model + >>> cfg.backbone.out_indices = (0, 1, 2, 3) # Output multi-scale feature maps + >>> model = build_classifier(cfg) + >>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone') + >>> for out in outs: + ... print(out.shape) + torch.Size([1, 64, 56, 56]) + torch.Size([1, 128, 28, 28]) + torch.Size([1, 256, 14, 14]) + torch.Size([1, 512, 7, 7]) + + 2. Neck output + + >>> import torch + >>> from mmengine import Config + >>> from mmpretrain.models import build_classifier + >>> + >>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model + >>> cfg.backbone.out_indices = (0, 1, 2, 3) # Output multi-scale feature maps + >>> model = build_classifier(cfg) + >>> + >>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck') + >>> for out in outs: + ... print(out.shape) + torch.Size([1, 64]) + torch.Size([1, 128]) + torch.Size([1, 256]) + torch.Size([1, 512]) + + 3. Pre-logits output (without the final linear classifier head) + + >>> import torch + >>> from mmengine import Config + >>> from mmpretrain.models import build_classifier + >>> + >>> cfg = Config.fromfile('configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py').model + >>> model = build_classifier(cfg) + >>> + >>> out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits') + >>> print(out.shape) # The hidden dims in head is 3072 + torch.Size([1, 3072]) + """ # noqa: E501 + assert stage in ['backbone', 'neck', 'pre_logits'], \ + (f'Invalid output stage "{stage}", please choose from "backbone", ' + '"neck" and "pre_logits"') + + x = self.backbone(inputs) + + if stage == 'backbone': + return x + + if self.with_neck: + x = self.neck(x) + if stage == 'neck': + return x + + assert self.with_head and hasattr(self.head, 'pre_logits'), \ + "No head or the head doesn't implement `pre_logits` method." + return self.head.pre_logits(x) + + def loss(self, inputs: torch.Tensor, + data_samples: List[DataSample]) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + feats = self.extract_feat(inputs) + return self.head.loss(feats, data_samples) + + def predict(self, + inputs: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + **kwargs) -> List[DataSample]: + """Predict results from a batch of inputs. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + """ + feats = self.extract_feat(inputs) + return self.head.predict(feats, data_samples, **kwargs) + + def get_layer_depth(self, param_name: str): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + + Returns: + Tuple[int, int]: The layer-wise depth and the max depth. 
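+
+        Note:
+            This method is consumed by layer-wise learning-rate decay
+            optimizer constructors (e.g.
+            ``LearningRateDecayOptimWrapperConstructor``), which scale each
+            parameter's learning rate according to its returned depth.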
+ """ + if hasattr(self.backbone, 'get_layer_depth'): + return self.backbone.get_layer_depth(param_name, 'backbone.') + else: + raise NotImplementedError( + f"The backbone {type(self.backbone)} doesn't " + 'support `get_layer_depth` by now.') diff --git a/mmpretrain/models/classifiers/timm.py b/mmpretrain/models/classifiers/timm.py new file mode 100644 index 0000000000000000000000000000000000000000..d777b2e039d848b01fc9c6b6eaae6619bebb8938 --- /dev/null +++ b/mmpretrain/models/classifiers/timm.py @@ -0,0 +1,209 @@ +# Copyright (c) OpenMMLab. All right reserved. +import re +from collections import OrderedDict +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from mmpretrain.utils import require +from .base import BaseClassifier + + +@MODELS.register_module() +class TimmClassifier(BaseClassifier): + """Image classifiers for pytorch-image-models (timm) model. + + This class accepts all positional and keyword arguments of the function + `timm.models.create_model `_ and use + it to create a model from pytorch-image-models. + + It can load checkpoints of timm directly, and the saved checkpoints also + can be directly load by timm. + + Please confirm that you have installed ``timm`` if you want to use it. + + Args: + *args: All positional arguments of the function + `timm.models.create_model`. + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + train_cfg (dict, optional): The training setting. The acceptable + fields are: + + - augments (List[dict]): The batch augmentation methods to use. + More details can be found in :mod:`mmpretrain.model.utils.augment`. + + Defaults to None. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None or no specified type, it will use + "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for + more details. Defaults to None. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + **kwargs: Other keyword arguments of the function + `timm.models.create_model`. + + Examples: + >>> import torch + >>> from mmpretrain.models import build_classifier + >>> cfg = dict(type='TimmClassifier', model_name='resnet50', pretrained=True) + >>> model = build_classifier(cfg) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> out = model(inputs) + >>> print(out.shape) + torch.Size([1, 1000]) + """ # noqa: E501 + + @require('timm') + def __init__(self, + *args, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + train_cfg: Optional[dict] = None, + with_cp: bool = False, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None, + **kwargs): + if data_preprocessor is None: + data_preprocessor = {} + # The build process is in MMEngine, so we need to add scope here. 
+ data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor') + + if train_cfg is not None and 'augments' in train_cfg: + # Set batch augmentations by `train_cfg` + data_preprocessor['batch_augments'] = train_cfg + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + from timm.models import create_model + self.model = create_model(*args, **kwargs) + + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module = loss + + self.with_cp = with_cp + if self.with_cp: + self.model.set_grad_checkpointing() + + self._register_state_dict_hook(self._remove_state_dict_prefix) + self._register_load_state_dict_pre_hook(self._add_state_dict_prefix) + + def forward(self, inputs, data_samples=None, mode='tensor'): + if mode == 'tensor': + return self.model(inputs) + elif mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + return self.predict(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, inputs: torch.Tensor): + if hasattr(self.model, 'forward_features'): + return self.model.forward_features(inputs) + else: + raise NotImplementedError( + f"The model {type(self.model)} doesn't support extract " + "feature because it don't have `forward_features` method.") + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs): + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments of the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # The part can be traced by torch.fx + cls_score = self.model(inputs) + + # The part can not be traced by torch.fx + losses = self._get_loss(cls_score, data_samples, **kwargs) + return losses + + def _get_loss(self, cls_score: torch.Tensor, + data_samples: List[DataSample], **kwargs): + """Unpack data samples and compute loss.""" + # Unpack data samples and pack targets + if 'gt_score' in data_samples[0]: + # Batch augmentation may convert labels to one-hot format scores. + target = torch.stack([i.gt_score for i in data_samples]) + else: + target = torch.cat([i.gt_label for i in data_samples]) + + # compute loss + losses = dict() + loss = self.loss_module(cls_score, target, **kwargs) + losses['loss'] = loss + + return losses + + def predict(self, + inputs: torch.Tensor, + data_samples: Optional[List[DataSample]] = None): + """Predict results from a batch of inputs. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + + Returns: + List[DataSample]: The prediction results. + """ + # The part can be traced by torch.fx + cls_score = self(inputs) + + # The part can not be traced by torch.fx + predictions = self._get_predictions(cls_score, data_samples) + return predictions + + def _get_predictions(self, cls_score, data_samples=None): + """Post-process the output of head. + + Including softmax and set ``pred_label`` of data samples. 
+ """ + pred_scores = F.softmax(cls_score, dim=1) + pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() + + if data_samples is not None: + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + data_sample.set_pred_score(score).set_pred_label(label) + else: + data_samples = [] + for score, label in zip(pred_scores, pred_labels): + data_samples.append( + DataSample().set_pred_score(score).set_pred_label(label)) + + return data_samples + + @staticmethod + def _remove_state_dict_prefix(self, state_dict, prefix, local_metadata): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + new_key = re.sub(f'^{prefix}model.', prefix, k) + new_state_dict[new_key] = v + return new_state_dict + + @staticmethod + def _add_state_dict_prefix(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + new_prefix = prefix + 'model.' + for k in list(state_dict.keys()): + new_key = re.sub(f'^{prefix}', new_prefix, k) + state_dict[new_key] = state_dict[k] + del state_dict[k] diff --git a/mmpretrain/models/heads/__init__.py b/mmpretrain/models/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4364fb5626f321196952bc07bc2f54e3788a0ebe --- /dev/null +++ b/mmpretrain/models/heads/__init__.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .beitv1_head import BEiTV1Head +from .beitv2_head import BEiTV2Head +from .cae_head import CAEHead +from .cls_head import ClsHead +from .conformer_head import ConformerHead +from .contrastive_head import ContrastiveHead +from .deit_head import DeiTClsHead +from .efficientformer_head import EfficientFormerClsHead +from .grounding_head import GroundingHead +from .itc_head import ITCHead +from .itm_head import ITMHead +from .itpn_clip_head import iTPNClipHead +from .latent_heads import LatentCrossCorrelationHead, LatentPredictHead +from .levit_head import LeViTClsHead +from .linear_head import LinearClsHead +from .mae_head import MAEPretrainHead +from .margin_head import ArcFaceClsHead +from .mim_head import MIMHead +from .mixmim_head import MixMIMPretrainHead +from .mocov3_head import MoCoV3Head +from .multi_label_cls_head import MultiLabelClsHead +from .multi_label_csra_head import CSRAClsHead +from .multi_label_linear_head import MultiLabelLinearClsHead +from .multi_task_head import MultiTaskHead +from .seq_gen_head import SeqGenerationHead +from .simmim_head import SimMIMHead +from .spark_head import SparKPretrainHead +from .stacked_head import StackedLinearClsHead +from .swav_head import SwAVHead +from .vig_head import VigClsHead +from .vision_transformer_head import VisionTransformerClsHead +from .vqa_head import VQAGenerationHead + +__all__ = [ + 'ClsHead', + 'LinearClsHead', + 'StackedLinearClsHead', + 'MultiLabelClsHead', + 'MultiLabelLinearClsHead', + 'VisionTransformerClsHead', + 'DeiTClsHead', + 'ConformerHead', + 'EfficientFormerClsHead', + 'ArcFaceClsHead', + 'CSRAClsHead', + 'MultiTaskHead', + 'LeViTClsHead', + 'VigClsHead', + 'BEiTV1Head', + 'BEiTV2Head', + 'CAEHead', + 'ContrastiveHead', + 'LatentCrossCorrelationHead', + 'LatentPredictHead', + 'MAEPretrainHead', + 'MixMIMPretrainHead', + 'SwAVHead', + 'MoCoV3Head', + 'MIMHead', + 'SimMIMHead', + 'SeqGenerationHead', + 'VQAGenerationHead', + 'ITCHead', + 'ITMHead', + 'GroundingHead', + 'iTPNClipHead', + 'SparKPretrainHead', +] diff --git a/mmpretrain/models/heads/__pycache__/__init__.cpython-311.pyc b/mmpretrain/models/heads/__pycache__/__init__.cpython-311.pyc new file mode 
diff --git a/mmpretrain/models/heads/beitv1_head.py b/mmpretrain/models/heads/beitv1_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..df422ea71c9090d1ab084bbc93c8889a4f2f402e
--- /dev/null
+++ b/mmpretrain/models/heads/beitv1_head.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+from mmengine.model import BaseModule
+
+from mmpretrain.registry import MODELS
+
+
+@MODELS.register_module()
+class BEiTV1Head(BaseModule):
+    """Head for BEiT v1 Pre-training.
+
+    Compute the logits and the cross entropy loss.
+
+    Args:
+        embed_dims (int): The dimension of embedding.
+        num_embed (int): The number of classification types.
+        loss (dict): The config of loss.
+        init_cfg (dict or List[dict], optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(
+        self,
+        embed_dims: int,
+        num_embed: int,
+        loss: dict,
+        init_cfg: Optional[Union[dict, List[dict]]] = dict(
+            type='TruncNormal', layer='Linear', std=0.02, bias=0)
+    ) -> None:
+        super().__init__(init_cfg=init_cfg)
+        self.cls_head = nn.Linear(embed_dims, num_embed)
+        self.loss_module = MODELS.build(loss)
+
+    def loss(self, feats: torch.Tensor, target: torch.Tensor,
+             mask: torch.Tensor) -> torch.Tensor:
+        """Generate loss.
+
+        Args:
+            feats (torch.Tensor): Features from backbone.
+            target (torch.Tensor): Target generated by target_generator.
+            mask (torch.Tensor): Generated mask for pretraining.
+        """
+        mask = mask.flatten(1).to(torch.bool)
+        target = torch.argmax(target, dim=1).flatten(1)
+        target = target[mask]
+
+        # remove cls_token
+        feats = feats[:, 1:]
+        logits = self.cls_head(feats[mask])
+
+        loss = self.loss_module(logits, target)
+        return loss
diff --git a/mmpretrain/models/heads/beitv2_head.py b/mmpretrain/models/heads/beitv2_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf677a2cf7c1a3964f1ba884a0ccae83f8b70a40
--- /dev/null
+++ b/mmpretrain/models/heads/beitv2_head.py
@@ -0,0 +1,57 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+from mmengine.model import BaseModule
+
+from mmpretrain.registry import MODELS
+
+
+@MODELS.register_module()
+class BEiTV2Head(BaseModule):
+    """Head for BEiT v2 Pre-training.
+
+    Compute the logits and the cross entropy loss.
+
+    Args:
+        embed_dims (int): The dimension of embedding.
+        num_embed (int): The number of classification types.
+ loss (dict): The config of loss. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + embed_dims: int, + num_embed: int, + loss: dict, + init_cfg: Optional[Union[dict, List[dict]]] = dict( + type='TruncNormal', layer='Linear', std=0.02, bias=0) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.cls_head = nn.Linear(embed_dims, num_embed) + self.loss_module = MODELS.build(loss) + + def loss(self, feats: torch.Tensor, feats_cls_pt: torch.Tensor, + target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + feats (torch.Tensor): Features from backbone. + feats_cls_pt (torch.Tensor): Features from class late layers for + pretraining. + target (torch.Tensor): Target generated by target_generator. + mask (torch.Tensor): Generated mask for pretraining. + """ + mask = mask.flatten(1).to(torch.bool) + target = target[mask] + + # shared cls head + logits = self.cls_head(feats[mask]) + logits_cls_pt = self.cls_head(feats_cls_pt[mask]) + + loss_1 = self.loss_module(logits, target) + loss_2 = self.loss_module(logits_cls_pt, target) + return loss_1, loss_2 diff --git a/mmpretrain/models/heads/cae_head.py b/mmpretrain/models/heads/cae_head.py new file mode 100644 index 0000000000000000000000000000000000000000..18a07f0a79297c35a39b9b2da0d25bf1eac6e70b --- /dev/null +++ b/mmpretrain/models/heads/cae_head.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CAEHead(BaseModule): + """Head for CAE Pre-training. + + Compute the align loss and the main loss. In addition, this head also + generates the prediction target generated by DALL-E. + + Args: + loss (dict): The config of loss. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + loss: dict, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.loss_module = MODELS.build(loss) + + @torch.no_grad() + def _generate_target(self, logits_target: torch.Tensor) -> torch.Tensor: + """Generate the reconstruction target. + + Args: + logits_target (torch.Tensor): The logits generated by DALL-E. + + Returns: + torch.Tensor: The logits target. + """ + target = torch.argmax(logits_target, dim=1) + return target.flatten(1) + + def loss(self, logits: torch.Tensor, logits_target: torch.Tensor, + latent_pred: torch.Tensor, latent_target: torch.Tensor, + mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Generate loss. + + Args: + logits (torch.Tensor): Logits generated by decoder. + logits_target (torch.Tensor): Target generated by DALL-E for + decoder prediction. + latent_pred (torch.Tensor): Latent prediction by regressor. + latent_target (torch.Tensor): Target for latent prediction, + generated by teacher. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The tuple of loss. + - ``loss_main`` (torch.Tensor): Cross entropy loss. + - ``loss_align`` (torch.Tensor): MSE loss.
+ """ + + target = self._generate_target(logits_target) # target features + target = target[mask].detach() + + # loss main for decoder, loss align for regressor + loss_main, loss_align = self.loss_module(logits, target, latent_pred, + latent_target) + + return (loss_main, loss_align) diff --git a/mmpretrain/models/heads/cls_head.py b/mmpretrain/models/heads/cls_head.py new file mode 100644 index 0000000000000000000000000000000000000000..4ac4c51804122adbc92df8c8748e4109e205110f --- /dev/null +++ b/mmpretrain/models/heads/cls_head.py @@ -0,0 +1,156 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +from mmpretrain.evaluation.metrics import Accuracy +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class ClsHead(BaseModule): + """Classification head. + + Args: + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + topk (int | Tuple[int]): Top-k accuracy. Defaults to ``(1, )``. + cal_acc (bool): Whether to calculate accuracy during training. + If you use batch augmentations like Mixup and CutMix during + training, it is pointless to calculate accuracy. + Defaults to False. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0), + topk: Union[int, Tuple[int]] = (1, ), + cal_acc: bool = False, + init_cfg: Optional[dict] = None): + super(ClsHead, self).__init__(init_cfg=init_cfg) + + self.topk = topk + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module = loss + self.cal_acc = cal_acc + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``ClsHead``, we just obtain the feature + of the last stage. + """ + # The ClsHead doesn't have other module, just return after unpacking. + return feats[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The ClsHead doesn't have the final classification head, + # just return the unpacked inputs. + return pre_logits + + def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. 
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # The part can be traced by torch.fx + cls_score = self(feats) + + # The part can not be traced by torch.fx + losses = self._get_loss(cls_score, data_samples, **kwargs) + return losses + + def _get_loss(self, cls_score: torch.Tensor, + data_samples: List[DataSample], **kwargs): + """Unpack data samples and compute loss.""" + # Unpack data samples and pack targets + if 'gt_score' in data_samples[0]: + # Batch augmentation may convert labels to one-hot format scores. + target = torch.stack([i.gt_score for i in data_samples]) + else: + target = torch.cat([i.gt_label for i in data_samples]) + + # compute loss + losses = dict() + loss = self.loss_module( + cls_score, target, avg_factor=cls_score.size(0), **kwargs) + losses['loss'] = loss + + # compute accuracy + if self.cal_acc: + assert target.ndim == 1, 'If you enable batch augmentation ' \ 'like mixup during training, `cal_acc` is pointless.' + acc = Accuracy.calculate(cls_score, target, topk=self.topk) + losses.update( + {f'accuracy_top-{k}': a + for k, a in zip(self.topk, acc)}) + + return losses + + def predict( + self, + feats: Tuple[torch.Tensor], + data_samples: Optional[List[Optional[DataSample]]] = None + ) -> List[DataSample]: + """Inference without augmentation. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample | None], optional): The annotation + data of every sample. If not None, set ``pred_label`` of + the input data samples. Defaults to None. + + Returns: + List[DataSample]: A list of data samples which contains the + predicted results. + """ + # The part can be traced by torch.fx + cls_score = self(feats) + + # The part can not be traced by torch.fx + predictions = self._get_predictions(cls_score, data_samples) + return predictions + + def _get_predictions(self, cls_score, data_samples): + """Post-process the output of head. + + Including softmax and setting ``pred_label`` of data samples. + """ + pred_scores = F.softmax(cls_score, dim=1) + pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() + + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(pred_scores.size(0))] + + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + if data_sample is None: + data_sample = DataSample() + + data_sample.set_pred_score(score).set_pred_label(label) + out_data_samples.append(data_sample) + return out_data_samples diff --git a/mmpretrain/models/heads/conformer_head.py b/mmpretrain/models/heads/conformer_head.py new file mode 100644 index 0000000000000000000000000000000000000000..eade90d567b5cb9189f62919ad9a6a0e9c47ae23 --- /dev/null +++ b/mmpretrain/models/heads/conformer_head.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Sequence, Tuple + +import torch +import torch.nn as nn + +from mmpretrain.evaluation.metrics import Accuracy +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .cls_head import ClsHead + + +@MODELS.register_module() +class ConformerHead(ClsHead): + """Conformer classifier head. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (Sequence[int]): Number of channels in the input + feature map.
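A minimal usage sketch for the ClsHead above (hypothetical shapes; it assumes mmpretrain is installed so the default CrossEntropyLoss resolves from the registry):

import torch
from mmpretrain.models.heads import ClsHead
from mmpretrain.structures import DataSample

head = ClsHead(cal_acc=True)          # loss defaults to CrossEntropyLoss
feats = (torch.randn(4, 10),)         # last-stage scores for 10 classes
samples = [DataSample().set_gt_label(i % 10) for i in range(4)]
losses = head.loss(feats, samples)    # {'loss': ..., 'accuracy_top-1': ...}
preds = head.predict(feats)           # DataSamples with pred_score/pred_label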
+ init_cfg (dict | optional): The extra init config of layers. + Defaults to use ``dict(type='Normal', layer='Linear', std=0.01)``. + """ + + def __init__( + self, + num_classes: int, + in_channels: Sequence[int], # [conv_dim, trans_dim] + init_cfg: dict = dict(type='TruncNormal', layer='Linear', std=.02), + **kwargs): + super(ConformerHead, self).__init__(init_cfg=init_cfg, **kwargs) + + self.in_channels = in_channels + self.num_classes = num_classes + self.init_cfg = init_cfg + + if self.num_classes <= 0: + raise ValueError( + f'num_classes={num_classes} must be a positive integer') + + self.conv_cls_head = nn.Linear(self.in_channels[0], num_classes) + self.trans_cls_head = nn.Linear(self.in_channels[1], num_classes) + + def pre_logits(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``ConformerHead``, we just obtain the + feature of the last stage. + """ + # The ConformerHead doesn't have other module, + # just return after unpacking. + return feats[-1] + + def forward(self, feats: Tuple[List[torch.Tensor]]) -> Tuple[torch.Tensor]: + """The forward process.""" + x = self.pre_logits(feats) + # There are two outputs in the Conformer model + assert len(x) == 2 + + conv_cls_score = self.conv_cls_head(x[0]) + tran_cls_score = self.trans_cls_head(x[1]) + + return conv_cls_score, tran_cls_score + + def predict(self, + feats: Tuple[List[torch.Tensor]], + data_samples: List[DataSample] = None) -> List[DataSample]: + """Inference without augmentation. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample], optional): The annotation + data of every samples. If not None, set ``pred_label`` of + the input data samples. Defaults to None. + + Returns: + List[DataSample]: A list of data samples which contains the + predicted results. + """ + # The part can be traced by torch.fx + conv_cls_score, tran_cls_score = self(feats) + cls_score = conv_cls_score + tran_cls_score + + # The part can not be traced by torch.fx + predictions = self._get_predictions(cls_score, data_samples) + return predictions + + def _get_loss(self, cls_score: Tuple[torch.Tensor], + data_samples: List[DataSample], **kwargs) -> dict: + """Unpack data samples and compute loss.""" + # Unpack data samples and pack targets + if 'gt_score' in data_samples[0]: + # Batch augmentation may convert labels to one-hot format scores. + target = torch.stack([i.gt_score for i in data_samples]) + else: + target = torch.cat([i.gt_label for i in data_samples]) + + # compute loss + losses = dict() + loss = sum([ + self.loss_module( + score, target, avg_factor=score.size(0), **kwargs) + for score in cls_score + ]) + losses['loss'] = loss + + # compute accuracy + if self.cal_acc: + assert target.ndim == 1, 'If you enable batch augmentation ' \ + 'like mixup during training, `cal_acc` is pointless.' 
+ acc = Accuracy.calculate( + cls_score[0] + cls_score[1], target, topk=self.topk) + losses.update( + {f'accuracy_top-{k}': a + for k, a in zip(self.topk, acc)}) + + return losses diff --git a/mmpretrain/models/heads/contrastive_head.py b/mmpretrain/models/heads/contrastive_head.py new file mode 100644 index 0000000000000000000000000000000000000000..6d1474aed59e2912ca4b5c24ce5a2430f50cb913 --- /dev/null +++ b/mmpretrain/models/heads/contrastive_head.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class ContrastiveHead(BaseModule): + """Head for contrastive learning. + + The contrastive loss is implemented in this head and is used in SimCLR, + MoCo, DenseCL, etc. + + Args: + loss (dict): Config dict for module of loss functions. + temperature (float): The temperature hyper-parameter that + controls the concentration level of the distribution. + Defaults to 0.1. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + loss: dict, + temperature: float = 0.1, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.loss_module = MODELS.build(loss) + self.temperature = temperature + + def loss(self, pos: torch.Tensor, neg: torch.Tensor) -> torch.Tensor: + """Forward function to compute contrastive loss. + + Args: + pos (torch.Tensor): Nx1 positive similarity. + neg (torch.Tensor): Nxk negative similarity. + + Returns: + torch.Tensor: The contrastive loss. + """ + N = pos.size(0) + logits = torch.cat((pos, neg), dim=1) + logits /= self.temperature + labels = torch.zeros((N, ), dtype=torch.long).to(pos.device) + + loss = self.loss_module(logits, labels) + return loss diff --git a/mmpretrain/models/heads/deit_head.py b/mmpretrain/models/heads/deit_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a96f6e152711d23646e02312218c0c85e96300e8 --- /dev/null +++ b/mmpretrain/models/heads/deit_head.py @@ -0,0 +1,72 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import List, Tuple + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .vision_transformer_head import VisionTransformerClsHead + + +@MODELS.register_module() +class DeiTClsHead(VisionTransformerClsHead): + """Distilled Vision Transformer classifier head. + + Comparing with the :class:`VisionTransformerClsHead`, this head adds an + extra linear layer to handle the dist token. The final classification score + is the average of both linear transformation results of ``cls_token`` and + ``dist_token``. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + hidden_dim (int, optional): Number of the dimensions for hidden layer. + Defaults to None, which means no extra hidden layer. + act_cfg (dict): The activation config. Only available during + pre-training. Defaults to ``dict(type='Tanh')``. + init_cfg (dict): The extra initialization configs. Defaults to + ``dict(type='Constant', layer='Linear', val=0)``. 
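A small sketch of the InfoNCE-style objective implemented by the ContrastiveHead above (dummy similarity tensors; the positive always sits at column 0 of the concatenated logits):

import torch
from mmpretrain.models.heads import ContrastiveHead

head = ContrastiveHead(loss=dict(type='CrossEntropyLoss'), temperature=0.1)
pos = torch.randn(8, 1)     # Nx1 positive similarities
neg = torch.randn(8, 16)    # Nxk negative similarities
loss = head.loss(pos, neg)  # cross entropy against label 0 (the positive)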
+ """ + + def _init_layers(self): + """"Init extra hidden linear layer to handle dist token if exists.""" + super(DeiTClsHead, self)._init_layers() + if self.hidden_dim is None: + head_dist = nn.Linear(self.in_channels, self.num_classes) + else: + head_dist = nn.Linear(self.hidden_dim, self.num_classes) + self.layers.add_module('head_dist', head_dist) + + def pre_logits(self, + feats: Tuple[List[torch.Tensor]]) -> Tuple[torch.Tensor]: + """The process before the final classification head. + + The input ``feats`` is a tuple of list of tensor, and each tensor is + the feature of a backbone stage. In ``DeiTClsHead``, we obtain the + feature of the last stage and forward in hidden layer if exists. + """ + feat = feats[-1] # Obtain feature of the last scale. + # For backward-compatibility with the previous ViT output + if len(feat) == 3: + _, cls_token, dist_token = feat + else: + cls_token, dist_token = feat + if self.hidden_dim is None: + return cls_token, dist_token + else: + cls_token = self.layers.act(self.layers.pre_logits(cls_token)) + dist_token = self.layers.act(self.layers.pre_logits(dist_token)) + return cls_token, dist_token + + def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor: + """The forward process.""" + if self.training: + warnings.warn('MMPretrain cannot train the ' + 'distilled version DeiT.') + cls_token, dist_token = self.pre_logits(feats) + # The final classification head. + cls_score = (self.layers.head(cls_token) + + self.layers.head_dist(dist_token)) / 2 + return cls_score diff --git a/mmpretrain/models/heads/efficientformer_head.py b/mmpretrain/models/heads/efficientformer_head.py new file mode 100644 index 0000000000000000000000000000000000000000..09aa05b28533028723f599881777939a48982319 --- /dev/null +++ b/mmpretrain/models/heads/efficientformer_head.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .cls_head import ClsHead + + +@MODELS.register_module() +class EfficientFormerClsHead(ClsHead): + """EfficientFormer classifier head. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + distillation (bool): Whether use a additional distilled head. + Defaults to True. + init_cfg (dict): The extra initialization configs. Defaults to + ``dict(type='Normal', layer='Linear', std=0.01)``. + """ + + def __init__(self, + num_classes, + in_channels, + distillation=True, + init_cfg=dict(type='Normal', layer='Linear', std=0.01), + *args, + **kwargs): + super(EfficientFormerClsHead, self).__init__( + init_cfg=init_cfg, *args, **kwargs) + self.in_channels = in_channels + self.num_classes = num_classes + self.dist = distillation + + if self.num_classes <= 0: + raise ValueError( + f'num_classes={num_classes} must be a positive integer') + + self.head = nn.Linear(self.in_channels, self.num_classes) + if self.dist: + self.dist_head = nn.Linear(self.in_channels, self.num_classes) + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. 
+ cls_score = self.head(pre_logits) + + if self.dist: + cls_score = (cls_score + self.dist_head(pre_logits)) / 2 + return cls_score + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In :obj:`EfficientFormerClsHead`, we just + obtain the feature of the last stage. + """ + # The EfficientFormerClsHead doesn't have other module, just return + # after unpacking. + return feats[-1] + + def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample]): The annotation data of + every sample. + **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + if self.dist: + raise NotImplementedError( + "MMPretrain doesn't support training" + ' the distilled version of EfficientFormer.') + else: + return super().loss(feats, data_samples, **kwargs) diff --git a/mmpretrain/models/heads/grounding_head.py b/mmpretrain/models/heads/grounding_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a47512ef5930dde51a7023a07c3412d759b6bd8c --- /dev/null +++ b/mmpretrain/models/heads/grounding_head.py @@ -0,0 +1,217 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +import torch.nn.functional as F +from mmengine.model import BaseModule + +from mmpretrain.models.utils.box_utils import (box_cxcywh_to_xyxy, + generalized_box_iou) +from mmpretrain.registry import MODELS, TOKENIZER + + +@MODELS.register_module() +class GroundingHead(BaseModule): + """Bbox coordinate generation head for multi-modal pre-trained tasks, + adapted from BLIP. Normally used for visual grounding. + + Args: + loss (dict): The config of loss. + decoder (dict): The config of the decoder. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None.
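A quick inference sketch for the EfficientFormerClsHead above, showing the distilled-head averaging (hypothetical channel sizes):

import torch
from mmpretrain.models.heads import EfficientFormerClsHead

head = EfficientFormerClsHead(num_classes=10, in_channels=64,
                              distillation=True)
feats = (torch.randn(2, 64),)  # last-stage features
scores = head(feats)           # (2, 10): mean of main and distilled logits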
+ """ + + def __init__( + self, + decoder: dict = None, + tokenizer: dict = None, + box_l1_loss_coeff=4.0, + box_giou_loss_coeff=2.0, + init_cfg: Optional[dict] = None, + ) -> None: + super(GroundingHead, self).__init__(init_cfg=init_cfg) + ''' init the decoder from med_config''' + self.decoder = None + if decoder: + self.decoder = MODELS.build(decoder) + self.loss_fn = torch.nn.CrossEntropyLoss( + reduction='none', ignore_index=-100) + + self.box_l1_loss_coeff = box_l1_loss_coeff + self.box_giou_loss_coeff = box_giou_loss_coeff + + if isinstance(tokenizer, dict): + self.tokenizer = TOKENIZER.build(tokenizer) + else: + self.tokenizer = tokenizer + + self.image_res = 640 + prefix_ids = torch.tensor( + self.tokenizer.convert_tokens_to_ids(['[unused339]'])) + target_ids = torch.tensor( + self.tokenizer.convert_tokens_to_ids( + [f'[unused{340+_}]' for _ in range(self.image_res + 1)])) + self.register_buffer('prefix_ids', prefix_ids) + self.register_buffer('target_ids', target_ids) + + bbox_prob_mask = torch.zeros(len(self.tokenizer)) + bbox_prob_mask[self.target_ids[0]:self.target_ids[-1] + 1] = 1 + bbox_prob_mask = (1.0 - bbox_prob_mask) * -10000.0 + self.register_buffer('bbox_prob_mask', bbox_prob_mask) + self.bin_start_idx = self.target_ids[0] + + def forward(self, text_embedding, text_embedding_mask, + encoder_hidden_states, encoder_attention_mask): + + # localize prompt token, text embedding + + merged_encode_hs = torch.cat([encoder_hidden_states, text_embedding], + 1) + merge_att_mask = torch.cat( + [encoder_attention_mask, text_embedding_mask], 1) + + loc_prompt = self.prompt.weight.T + loc_prompt = torch.repeat_interleave(loc_prompt, + merge_att_mask.shape[0], + 0).unsqueeze(1) + + loc_prompt_mask = torch.ones(loc_prompt.shape[:-1]).long().to( + loc_prompt.device) + + decoder_out = self.decoder( + inputs_embeds=loc_prompt, + attention_mask=loc_prompt_mask, + encoder_hidden_states=merged_encode_hs, + encoder_attention_mask=merge_att_mask, + output_hidden_states=True, + labels=None, + ) + decoder_hs = decoder_out.hidden_states[-1][:, 0, :] + box_pred = self.box_head(decoder_hs) + return decoder_out, decoder_hs, box_pred + + def loss(self, + text_embedding, + text_embedding_mask, + encoder_hidden_states, + encoder_attention_mask, + decoder_targets, + return_scores=False): + """Calculate losses from the extracted features. + + Args: + feats (dict): The features extracted from the backbone. + data_samples (List[BaseDataElement]): The annotation data of + every samples. 
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + merged_encode_hs = torch.cat([encoder_hidden_states, text_embedding], + 1) + merge_att_mask = torch.cat( + [encoder_attention_mask, text_embedding_mask], 1) + + answer_targets = (decoder_targets * + self.image_res).long() + self.bin_start_idx + prefix_ids = torch.repeat_interleave(self.prefix_ids, + merge_att_mask.shape[0], + 0).unsqueeze(-1) + prefix_ids = torch.cat([prefix_ids, answer_targets], dim=1) + + answer_output = self.decoder( + prefix_ids, + encoder_hidden_states=merged_encode_hs, + encoder_attention_mask=merge_att_mask, + labels=None, + return_dict=True, + ) + prob_mask = self.bbox_prob_mask.view(1, 1, + self.bbox_prob_mask.shape[-1]) + prediction_scores = answer_output.logits + prob_mask + + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = prefix_ids[:, 1:].contiguous() + vocab_size = len(self.tokenizer) + loss_seq_init = self.loss_fn( + shifted_prediction_scores.view(-1, vocab_size), labels.view(-1)) + + with torch.no_grad(): + pred_box = (torch.argmax( + prediction_scores[:, :-1, :].contiguous(), dim=-1) - + self.bin_start_idx) / self.image_res + weight_bbox = F.l1_loss( + pred_box, decoder_targets, reduction='none').clamp( + 0, 5) * self.box_l1_loss_coeff + weight_giou = (1 - torch.diag( + generalized_box_iou( + box_cxcywh_to_xyxy(pred_box), + box_cxcywh_to_xyxy(decoder_targets))) + ) * self.box_giou_loss_coeff + bs = text_embedding.shape[0] + loss_seq = loss_seq_init[:].view(bs, -1, 4) + loss_seq = loss_seq * weight_bbox + loss_seq = loss_seq * weight_giou.unsqueeze(1) + + loss_seq = loss_seq.mean() + + losses = { + 'loss_seq': loss_seq, + 'loss_seq_init': loss_seq_init.mean(), + 'loss': loss_seq, + 'box_l1': weight_bbox.mean(-1).mean().detach(), + 'box_giou': weight_giou.mean().detach() + } + + return losses + + def predict( + self, + text_embedding, + text_embedding_mask, + encoder_hidden_states, + encoder_attention_mask, + ): + """Generates the bbox coordinates at inference time.""" + + merged_encode_hs = torch.cat([encoder_hidden_states, text_embedding], + 1) + merge_att_mask = torch.cat( + [encoder_attention_mask, text_embedding_mask], 1) + + prefix_ids = torch.repeat_interleave(self.prefix_ids, + merge_att_mask.shape[0], + 0).unsqueeze(-1) + + for _ in range(4): + decoder_output = self.decoder( + prefix_ids, + encoder_hidden_states=merged_encode_hs, + encoder_attention_mask=merge_att_mask, + labels=None, + return_dict=True, + ) + prob_mask = self.bbox_prob_mask.view(1, 1, + self.bbox_prob_mask.shape[-1]) + prediction_scores = decoder_output.logits + prob_mask + + prefix_ids = torch.cat([ + prefix_ids, + torch.argmax(prediction_scores[:, -1, :], dim=-1).unsqueeze(1) + ], + dim=1) + + pred_box = self.process_bbox(prefix_ids[:, 1:]) # xywh 0-1 to xyxy 0-1 + + return pred_box + + @torch.no_grad() + def process_bbox(self, bbox): + bbox = bbox - self.bin_start_idx + bbox = torch.true_divide(bbox, self.image_res) + bbox = box_cxcywh_to_xyxy(bbox) + bbox = torch.clip(bbox, 0, 1) + assert torch.all(bbox <= 1) + return bbox diff --git a/mmpretrain/models/heads/itc_head.py b/mmpretrain/models/heads/itc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..006d52c76d9317809c7bb07519f4efb18716d8bd --- /dev/null +++ b/mmpretrain/models/heads/itc_head.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
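The head above discretizes normalized box coordinates into vocabulary bins and decodes them back after generation. A self-contained sketch of that round trip (the bin offset is illustrative, not the tokenizer's real ids):

import torch

image_res = 640                                    # number of coordinate bins
bin_start = 340                                    # illustrative first bin id
coords = torch.tensor([[0.25, 0.40, 0.75, 0.90]])  # cxcywh in [0, 1]
token_ids = (coords * image_res).long() + bin_start
decoded = (token_ids - bin_start).float() / image_res
# quantization error is at most one bin, i.e. 1/image_res per coordinate
assert torch.all((coords - decoded).abs() <= 1.0 / image_res)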
+from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.dist import all_gather +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class ITCHead(BaseModule): + """Image-text contrastive head for multi-modal pre-trained task. Adapted + from BLIP and ALBEF. Normally used for retrieval tasks. + + Args: + embed_dim (int): Embed channel size for queue. + queue_size (int): Queue size for image and text. Defaults to 57600. + temperature (float): Temperature to calculate the similarity. + Defaults to 0.07. + use_distill (bool): Whether to use distillation to calculate the + loss. Defaults to True. + alpha (float): Weight for momentum similarity. Defaults to 0.4. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + embed_dim: int, + queue_size: int = 57600, + temperature: float = 0.07, + use_distill: bool = True, + alpha: float = 0.4, + init_cfg: Optional[dict] = None): + super(ITCHead, self).__init__(init_cfg=init_cfg) + self.temp = nn.Parameter(temperature * torch.ones([])) + self.use_distill = use_distill + if self.use_distill: + # create the queue + self.register_buffer('image_queue', + torch.randn(embed_dim, queue_size)) + self.register_buffer('text_queue', + torch.randn(embed_dim, queue_size)) + self.register_buffer('idx_queue', torch.full((1, queue_size), + -100)) + self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long)) + + self.image_queue = F.normalize(self.image_queue, dim=0) + self.text_queue = F.normalize(self.text_queue, dim=0) + + self.queue_size = queue_size + # This value will be warmed up by `WarmupParamHook` + self.alpha = alpha + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + return feats[-1] + + def loss(self, feats: Tuple[torch.Tensor], data_samples, **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample]): The annotation data of + every sample. + **kwargs: Other keyword arguments to forward the loss module.
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + # The part can be traced by torch.fx + img_feats, text_feats, img_feats_m, text_feats_m = self(feats) + + img_feats_all = torch.cat( + [img_feats_m.t(), + self.image_queue.clone().detach()], dim=1) + text_feats_all = torch.cat( + [text_feats_m.t(), + self.text_queue.clone().detach()], dim=1) + + # The part can not be traced by torch.fx + losses = self._get_loss(img_feats, text_feats, img_feats_m, + text_feats_m, img_feats_all, text_feats_all, + data_samples, **kwargs) + return losses + + def _get_loss(self, img_feats, text_feats, img_feats_m, text_feats_m, + img_feats_all, text_feats_all, data_samples, **kwargs): + """Unpack data samples and compute loss.""" + + idx = torch.tensor([ds.image_id + for ds in data_samples]).to(img_feats.device) + idx = idx.view(-1, 1) + idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1) + pos_idx = torch.eq(idx, idx_all).float() + sim_targets = pos_idx / pos_idx.sum(1, keepdim=True) + + with torch.no_grad(): + if self.use_distill: + sim_i2t_m = img_feats_m @ text_feats_all / self.temp + sim_t2i_m = text_feats_m @ img_feats_all / self.temp + + sim_i2t_targets = ( + self.alpha * F.softmax(sim_i2t_m, dim=1) + + (1 - self.alpha) * sim_targets) + sim_t2i_targets = ( + self.alpha * F.softmax(sim_t2i_m, dim=1) + + (1 - self.alpha) * sim_targets) + + sim_i2t = img_feats @ text_feats_all / self.temp + sim_t2i = text_feats @ img_feats_all / self.temp + + if self.use_distill: + loss_i2t = -torch.sum( + F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean() + loss_t2i = -torch.sum( + F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1).mean() + else: + loss_i2t = -torch.sum( + F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean() + loss_t2i = -torch.sum( + F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean() + + # compute loss + losses = dict() + + losses['itc_loss'] = (loss_i2t + loss_t2i) / 2 + self._dequeue_and_enqueue(img_feats_m, text_feats_m, idx) + return losses + + @torch.no_grad() + def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None): + # gather keys before updating queue + image_feats = torch.cat(all_gather(image_feat)) + text_feats = torch.cat(all_gather(text_feat)) + + batch_size = image_feats.shape[0] + + ptr = int(self.queue_ptr) + assert self.queue_size % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.image_queue[:, ptr:ptr + batch_size] = image_feats.T + self.text_queue[:, ptr:ptr + batch_size] = text_feats.T + + if idxs is not None: + idxs = torch.cat(all_gather(idxs)) + self.idx_queue[:, ptr:ptr + batch_size] = idxs.T + + ptr = (ptr + batch_size) % self.queue_size # move pointer + self.queue_ptr[0] = ptr diff --git a/mmpretrain/models/heads/itm_head.py b/mmpretrain/models/heads/itm_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c7b42f3f684e2ffefd085b39360706a339017f4c --- /dev/null +++ b/mmpretrain/models/heads/itm_head.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
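A pure-PyTorch sketch of the momentum-distilled soft targets used in _get_loss above (stand-in tensors take the place of the momentum branch and the feature queues):

import torch
import torch.nn.functional as F

B, Q, D, temp, alpha = 4, 16, 8, 0.07, 0.4
img = F.normalize(torch.randn(B, D), dim=1)          # online image features
txt_all = F.normalize(torch.randn(D, B + Q), dim=0)  # batch text feats + queue
sim_targets = torch.zeros(B, B + Q)
sim_targets[:, :B] = torch.eye(B)                    # matched pairs on diagonal
sim_m = img @ txt_all / temp                         # stand-in momentum sims
soft = alpha * F.softmax(sim_m, dim=1) + (1 - alpha) * sim_targets
sim = img @ txt_all / temp
loss_i2t = -(F.log_softmax(sim, dim=1) * soft).sum(dim=1).mean()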
+from typing import Optional, Tuple + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.evaluation import Accuracy +from mmpretrain.registry import MODELS + + +class Pooler(nn.Module): + + def __init__(self, hidden_size): + super().__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@MODELS.register_module() +class ITMHead(BaseModule): + """Image-text matching head for multi-modal pre-trained task. Adapted + from BLIP and FLAVA. + + Args: + hidden_size (int): Hidden channel size of the input features. + with_pooler (bool): Whether a pooler is added. Defaults to True. + loss (dict): Config of the matching loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + cal_acc (bool): Whether to calculate accuracy during training. + If you use batch augmentations like Mixup and CutMix during + training, it is pointless to calculate accuracy. + Defaults to False. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + hidden_size: int, + with_pooler: bool = True, + loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0), + cal_acc: bool = False, + init_cfg: Optional[dict] = None): + super(ITMHead, self).__init__(init_cfg=init_cfg) + self.hidden_size = hidden_size + + if with_pooler: + self.pooler = Pooler(hidden_size=self.hidden_size) + else: + self.pooler = nn.Identity() + self.fc = nn.Linear(self.hidden_size, 2) + + self.loss_module = MODELS.build(loss) + self.cal_acc = cal_acc + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pooler(feats[-1]) + itm_logits = self.fc(pre_logits) + return itm_logits + + def loss(self, feats: Tuple[torch.Tensor], data_samples, **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample]): The annotation data of + every sample. + **kwargs: Other keyword arguments to forward the loss module.
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + # The part can be traced by torch.fx + itm_logits = self(feats) + + # deal with query + if itm_logits.ndim == 3: + itm_logits = itm_logits.mean(dim=1) + + # The part can not be traced by torch.fx + losses = self._get_loss(itm_logits, data_samples, **kwargs) + return losses + + def _get_loss(self, itm_logits: torch.Tensor, data_samples, **kwargs): + """Unpack data samples and compute loss.""" + # Unpack data samples and pack targets + # use `itm_label` here temporarily + target = torch.tensor([i.is_matched + for i in data_samples]).to(itm_logits.device) + + # compute loss + losses = dict() + + loss = self.loss_module( + itm_logits, target.long(), avg_factor=itm_logits.size(0), **kwargs) + losses['itm_loss'] = loss + + # compute accuracy + if self.cal_acc: + # topk is meaningless for matching task + acc = Accuracy.calculate(itm_logits, target) + # acc is wrapped with two lists of topk and thrs, + # which are unnecessary here + losses.update({'itm_accuracy': acc[0][0]}) + + return losses diff --git a/mmpretrain/models/heads/itpn_clip_head.py b/mmpretrain/models/heads/itpn_clip_head.py new file mode 100644 index 0000000000000000000000000000000000000000..52c49b8c013c5169d1d997b4db5030dd0bc6a540 --- /dev/null +++ b/mmpretrain/models/heads/itpn_clip_head.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.device import get_device +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class iTPNClipHead(BaseModule): + """Head for iTPN Pre-training using CLIP. + + Compute the logits and the cross entropy loss. + + Args: + embed_dims (int): The dimension of embedding. + num_embed (int): The number of classification types. + loss (dict): The config of loss. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + embed_dims: int, + num_embed: int, + loss: dict, + init_cfg: Optional[Union[dict, List[dict]]] = dict( + type='TruncNormal', layer='Linear', std=0.02, bias=0) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.cls_head = nn.Linear(embed_dims, num_embed) + self.loss_module = MODELS.build(loss) + + def loss(self, feats: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + feats (torch.Tensor): Features from backbone. + target (torch.Tensor): Target generated by target_generator. + mask (torch.Tensor): Generated mask for pretraining. + """ + mask = mask.to(get_device(), non_blocking=True) + mask = mask.flatten(1).to(torch.bool) + target = target[mask] + + # remove cls_token + # feats = feats[:, 1:] + logits = self.cls_head(feats[mask]) + + loss = self.loss_module(logits, target) + return loss diff --git a/mmpretrain/models/heads/latent_heads.py b/mmpretrain/models/heads/latent_heads.py new file mode 100644 index 0000000000000000000000000000000000000000..a9662b5d91c8534d1a2a7834e4b9e3ec37f552c1 --- /dev/null +++ b/mmpretrain/models/heads/latent_heads.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved.
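A minimal usage sketch for the ITMHead above (dummy token features; `is_matched` is set directly on the data samples, as _get_loss expects):

import torch
from mmpretrain.models.heads import ITMHead
from mmpretrain.structures import DataSample

head = ITMHead(hidden_size=768)
feats = (torch.randn(4, 16, 768),)   # token features; the pooler reads token 0
samples = []
for i in range(4):
    ds = DataSample()
    ds.is_matched = bool(i % 2)      # binary match / no-match target
    samples.append(ds)
losses = head.loss(feats, samples)   # {'itm_loss': ...}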
+from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmengine.dist import all_reduce, get_world_size +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class LatentPredictHead(BaseModule): + """Head for latent feature prediction. + + This head builds a predictor, which can be any registered neck component. + For example, BYOL and SimSiam call this head and build NonLinearNeck. + It also implements similarity loss between two forward features. + + Args: + loss (dict): Config dict for the loss. + predictor (dict): Config dict for the predictor. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + loss: dict, + predictor: dict, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.loss_module = MODELS.build(loss) + self.predictor = MODELS.build(predictor) + + def loss(self, input: torch.Tensor, + target: torch.Tensor) -> torch.Tensor: + """Forward head. + + Args: + input (torch.Tensor): NxC input features. + target (torch.Tensor): NxC target features. + + Returns: + torch.Tensor: The latent predict loss. + """ + pred = self.predictor([input])[0] + target = target.detach() + + loss = self.loss_module(pred, target) + + return loss + + +@MODELS.register_module() +class LatentCrossCorrelationHead(BaseModule): + """Head for latent feature cross correlation. + + Part of the code is borrowed from the official Barlow Twins + implementation. + + Args: + in_channels (int): Number of input channels. + loss (dict): Config dict for module of loss functions. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + loss: dict, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.world_size = get_world_size() + self.bn = nn.BatchNorm1d(in_channels, affine=False) + self.loss_module = MODELS.build(loss) + + def loss(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Forward head. + + Args: + input (torch.Tensor): NxC input features. + target (torch.Tensor): NxC target features. + + Returns: + torch.Tensor: The cross correlation loss. + """ + # cross-correlation matrix + cross_correlation_matrix = self.bn(input).T @ self.bn(target) + cross_correlation_matrix.div_(input.size(0) * self.world_size) + + all_reduce(cross_correlation_matrix) + + loss = self.loss_module(cross_correlation_matrix) + return loss diff --git a/mmpretrain/models/heads/levit_head.py b/mmpretrain/models/heads/levit_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a74d7ecc52caca0adca642e528f2861f9a0e5833 --- /dev/null +++ b/mmpretrain/models/heads/levit_head.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved.
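The cross-correlation loss above is the Barlow Twins objective; a self-contained single-GPU sketch of the computation it feeds to the loss module (λ = 5.1e-3 as in the Barlow Twins paper):

import torch

N, D = 8, 16
z1, z2 = torch.randn(N, D), torch.randn(N, D)
bn = torch.nn.BatchNorm1d(D, affine=False)
c = bn(z1).T @ bn(z2) / N                        # DxD cross-correlation matrix
on_diag = (torch.diagonal(c) - 1).pow(2).sum()   # pull diagonal toward 1
off_diag = (c - torch.diag(torch.diagonal(c))).pow(2).sum()  # push rest to 0
loss = on_diag + 5.1e-3 * off_diag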
+import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.models.heads import ClsHead +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer + + +class BatchNormLinear(BaseModule): + + def __init__(self, in_channels, out_channels, norm_cfg=dict(type='BN1d')): + super(BatchNormLinear, self).__init__() + self.bn = build_norm_layer(norm_cfg, in_channels) + self.linear = nn.Linear(in_channels, out_channels) + + @torch.no_grad() + def fuse(self): + w = self.bn.weight / (self.bn.running_var + self.bn.eps)**0.5 + b = self.bn.bias - self.bn.running_mean * \ + self.bn.weight / (self.bn.running_var + self.bn.eps) ** 0.5 + w = self.linear.weight * w[None, :] + b = (self.linear.weight @ b[:, None]).view(-1) + self.linear.bias + + self.linear.weight.data.copy_(w) + self.linear.bias.data.copy_(b) + return self.linear + + def forward(self, x): + x = self.bn(x) + x = self.linear(x) + return x + + +def fuse_parameters(module): + for child_name, child in module.named_children(): + if hasattr(child, 'fuse'): + setattr(module, child_name, child.fuse()) + else: + fuse_parameters(child) + + +@MODELS.register_module() +class LeViTClsHead(ClsHead): + + def __init__(self, + num_classes=1000, + distillation=True, + in_channels=None, + deploy=False, + **kwargs): + super(LeViTClsHead, self).__init__(**kwargs) + self.num_classes = num_classes + self.distillation = distillation + self.deploy = deploy + self.head = BatchNormLinear(in_channels, num_classes) + if distillation: + self.head_dist = BatchNormLinear(in_channels, num_classes) + + if self.deploy: + self.switch_to_deploy(self) + + def switch_to_deploy(self): + if self.deploy: + return + fuse_parameters(self) + self.deploy = True + + def forward(self, x): + x = self.pre_logits(x) + if self.distillation: + x = self.head(x), self.head_dist(x) # 2 16 384 -> 2 1000 + if not self.training: + x = (x[0] + x[1]) / 2 + else: + raise NotImplementedError("MMPretrain doesn't support " + 'training in distillation mode.') + else: + x = self.head(x) + return x diff --git a/mmpretrain/models/heads/linear_head.py b/mmpretrain/models/heads/linear_head.py new file mode 100644 index 0000000000000000000000000000000000000000..90b4c2b11eb0b2ba087fd438a32596cedb13cebb --- /dev/null +++ b/mmpretrain/models/heads/linear_head.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .cls_head import ClsHead + + +@MODELS.register_module() +class LinearClsHead(ClsHead): + """Linear classifier head. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + topk (int | Tuple[int]): Top-k accuracy. Defaults to ``(1, )``. + cal_acc (bool): Whether to calculate accuracy during training. + If you use batch augmentations like Mixup and CutMix during + training, it is pointless to calculate accuracy. + Defaults to False. + init_cfg (dict, optional): the config to control the initialization. + Defaults to ``dict(type='Normal', layer='Linear', std=0.01)``. 
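A sanity check for BatchNormLinear.fuse above: in eval mode the fused linear layer should reproduce the BN + linear pair exactly (the import path is assumed from the file shown above):

import torch
from mmpretrain.models.heads.levit_head import BatchNormLinear

bnl = BatchNormLinear(16, 8).eval()  # eval: BN uses running statistics
x = torch.randn(4, 16)
with torch.no_grad():
    y_ref = bnl(x)
fused = bnl.fuse()                   # folds the BN into the linear weights
assert torch.allclose(y_ref, fused(x), atol=1e-5)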
+ """ + + def __init__(self, + num_classes: int, + in_channels: int, + init_cfg: Optional[dict] = dict( + type='Normal', layer='Linear', std=0.01), + **kwargs): + super(LinearClsHead, self).__init__(init_cfg=init_cfg, **kwargs) + + self.in_channels = in_channels + self.num_classes = num_classes + + if self.num_classes <= 0: + raise ValueError( + f'num_classes={num_classes} must be a positive integer') + + self.fc = nn.Linear(self.in_channels, self.num_classes) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``LinearClsHead``, we just obtain the + feature of the last stage. + """ + # The LinearClsHead doesn't have other module, just return after + # unpacking. + return feats[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.fc(pre_logits) + return cls_score diff --git a/mmpretrain/models/heads/mae_head.py b/mmpretrain/models/heads/mae_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b76ecedd96dd84d34ce2d9cb6dfa4fe725ea870b --- /dev/null +++ b/mmpretrain/models/heads/mae_head.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class MAEPretrainHead(BaseModule): + """Head for MAE Pre-training. + + Args: + loss (dict): Config of loss. + norm_pix_loss (bool): Whether or not normalize target. + Defaults to False. + patch_size (int): Patch size. Defaults to 16. + in_channels (int): Number of input channels. Defaults to 3. + """ + + def __init__(self, + loss: dict, + norm_pix: bool = False, + patch_size: int = 16, + in_channels: int = 3) -> None: + super().__init__() + self.norm_pix = norm_pix + self.patch_size = patch_size + self.in_channels = in_channels + self.loss_module = MODELS.build(loss) + + def patchify(self, imgs: torch.Tensor) -> torch.Tensor: + r"""Split images into non-overlapped patches. + + Args: + imgs (torch.Tensor): A batch of images. The shape should + be :math:`(B, C, H, W)`. + + Returns: + torch.Tensor: Patchified images. The shape is + :math:`(B, L, \text{patch_size}^2 \times C)`. + """ + p = self.patch_size + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], self.in_channels, h, p, w, p)) + x = torch.einsum('nchpwq->nhwpqc', x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_channels)) + return x + + def unpatchify(self, x: torch.Tensor) -> torch.Tensor: + r"""Combine non-overlapped patches into images. + + Args: + x (torch.Tensor): The shape is + :math:`(B, L, \text{patch_size}^2 \times C)`. + + Returns: + torch.Tensor: The shape is :math:`(B, C, H, W)`. + """ + p = self.patch_size + h = w = int(x.shape[1]**.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_channels)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], self.in_channels, h * p, h * p)) + return imgs + + def construct_target(self, target: torch.Tensor) -> torch.Tensor: + """Construct the reconstruction target. + + In addition to splitting images into tokens, this module will also + normalize the image according to ``norm_pix``. 
+ + Args: + target (torch.Tensor): Image with the shape of B x C x H x W + + Returns: + torch.Tensor: Tokenized images with the shape of B x L x C + """ + target = self.patchify(target) + if self.norm_pix: + # normalize the target image + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.e-6)**.5 + + return target + + def loss(self, pred: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + pred (torch.Tensor): The reconstructed image. + target (torch.Tensor): The target image. + mask (torch.Tensor): The mask of the target image. + + Returns: + torch.Tensor: The reconstruction loss. + """ + target = self.construct_target(target) + loss = self.loss_module(pred, target, mask) + + return loss diff --git a/mmpretrain/models/heads/margin_head.py b/mmpretrain/models/heads/margin_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3a88bf8b3f4d19b233192a7578f49b750ff53ed5 --- /dev/null +++ b/mmpretrain/models/heads/margin_head.py @@ -0,0 +1,300 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.fileio import list_from_file +from mmengine.runner import autocast +from mmengine.utils import is_seq_of + +from mmpretrain.models.losses import convert_to_one_hot +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .cls_head import ClsHead + + +class NormProduct(nn.Linear): + """An enhanced linear layer with k clustering centers to calculate product + between normalized input and linear weight. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample + k (int): The number of clustering centers. Defaults to 1. + bias (bool): Whether there is bias. If set to ``False``, the + layer will not learn an additive bias. Defaults to ``True``. + feature_norm (bool): Whether to normalize the input feature. + Defaults to ``True``. + weight_norm (bool):Whether to normalize the weight. + Defaults to ``True``. + """ + + def __init__(self, + in_features: int, + out_features: int, + k=1, + bias: bool = False, + feature_norm: bool = True, + weight_norm: bool = True): + + super().__init__(in_features, out_features * k, bias=bias) + self.weight_norm = weight_norm + self.feature_norm = feature_norm + self.out_features = out_features + self.k = k + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.feature_norm: + input = F.normalize(input) + if self.weight_norm: + weight = F.normalize(self.weight) + else: + weight = self.weight + cosine_all = F.linear(input, weight, self.bias) + + if self.k == 1: + return cosine_all + else: + cosine_all = cosine_all.view(-1, self.out_features, self.k) + cosine, _ = torch.max(cosine_all, dim=2) + return cosine + + +@MODELS.register_module() +class ArcFaceClsHead(ClsHead): + """ArcFace classifier head. + + A PyTorch implementation of paper `ArcFace: Additive Angular Margin Loss + for Deep Face Recognition `_ and + `Sub-center ArcFace: Boosting Face Recognition by Large-Scale Noisy Web + Faces `_ + + Example: + To use ArcFace in config files. + + 1. use vanilla ArcFace + + .. 
code:: python
+
+            mode = dict(
+                backbone = xxx,
+                neck = xxxx,
+                head=dict(
+                    type='ArcFaceClsHead',
+                    num_classes=5000,
+                    in_channels=1024,
+                    loss = dict(type='CrossEntropyLoss', loss_weight=1.0),
+                    init_cfg=None),
+            )
+
+    2. use SubCenterArcFace with 3 sub-centers
+
+        .. code:: python
+
+            mode = dict(
+                backbone = xxx,
+                neck = xxxx,
+                head=dict(
+                    type='ArcFaceClsHead',
+                    num_classes=5000,
+                    in_channels=1024,
+                    num_subcenters=3,
+                    loss = dict(type='CrossEntropyLoss', loss_weight=1.0),
+                    init_cfg=None),
+            )
+
+    3. use SubCenterArcFace with CountPowerAdaptiveMargins
+
+        .. code:: python
+
+            mode = dict(
+                backbone = xxx,
+                neck = xxxx,
+                head=dict(
+                    type='ArcFaceClsHead',
+                    num_classes=5000,
+                    in_channels=1024,
+                    num_subcenters=3,
+                    loss = dict(type='CrossEntropyLoss', loss_weight=1.0),
+                    init_cfg=None),
+            )
+
+            custom_hooks = [dict(type='SetAdaptiveMarginsHook')]
+
+
+    Args:
+        num_classes (int): Number of categories excluding the background
+            category.
+        in_channels (int): Number of channels in the input feature map.
+        num_subcenters (int): Number of subcenters. Defaults to 1.
+        scale (float): Scale factor of output logit. Defaults to 64.0.
+        margins (float): The penalty margin. Could be the following formats:
+
+            - float: The margin, would be the same for all categories.
+            - Sequence[float]: The category-based margins list.
+            - str: A '.txt' file path which contains a list. Each line
+              represents the margin of a category, and the number in the
+              i-th row indicates the margin of the i-th class.
+
+            Defaults to 0.5.
+        easy_margin (bool): Avoid theta + m >= PI. Defaults to False.
+        loss (dict): Config of classification loss. Defaults to
+            ``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
+        init_cfg (dict, optional): the config to control the initialization.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 num_classes: int,
+                 in_channels: int,
+                 num_subcenters: int = 1,
+                 scale: float = 64.,
+                 margins: Optional[Union[float, Sequence[float], str]] = 0.50,
+                 easy_margin: bool = False,
+                 loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0),
+                 init_cfg: Optional[dict] = None):
+
+        super(ArcFaceClsHead, self).__init__(init_cfg=init_cfg)
+        if not isinstance(loss, nn.Module):
+            loss = MODELS.build(loss)
+        self.loss_module = loss
+
+        assert num_subcenters >= 1 and num_classes >= 0
+        self.in_channels = in_channels
+        self.num_classes = num_classes
+        self.num_subcenters = num_subcenters
+        self.scale = scale
+        self.easy_margin = easy_margin
+
+        self.norm_product = NormProduct(in_channels, num_classes,
+                                        num_subcenters)
+
+        if isinstance(margins, float):
+            margins = [margins] * num_classes
+        elif isinstance(margins, str) and margins.endswith('.txt'):
+            margins = [float(item) for item in list_from_file(margins)]
+        else:
+            assert is_seq_of(list(margins), (float, int)), (
+                'the attribute `margins` in ``ArcFaceClsHead`` should be a '
+                'float, a Sequence of float, or a ".txt" file path.')
+
+        assert len(margins) == num_classes, \
+            'The length of margins must be equal to num_classes.'
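For orientation, the angular-margin logits that `_get_logit_with_margin` produces further down can be traced on toy data. The margin 0.5 and scale 64.0 below are just the defaults quoted above; all sizes are made up:

```python
import torch
import torch.nn.functional as F

# Toy sketch of the additive angular margin (hypothetical sizes).
emb = F.normalize(torch.randn(4, 128))     # normalized features
w = F.normalize(torch.randn(10, 128))      # normalized class centers
cosine = (emb @ w.t()).clamp(-1, 1)        # cos(theta), clamped for acos
target = torch.tensor([0, 3, 7, 9])

m, s = 0.5, 64.0                           # margin and scale
phi = torch.cos(torch.acos(cosine) + m)    # cos(theta + m)
one_hot = F.one_hot(target, 10).float()
logits = s * (one_hot * phi + (1 - one_hot) * cosine)
```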
+
+        self.register_buffer(
+            'margins', torch.tensor(margins).float(), persistent=False)
+        # To make `phi` monotonically decreasing, refer to
+        # https://github.com/deepinsight/insightface/issues/108
+        sinm_m = torch.sin(math.pi - self.margins) * self.margins
+        threshold = torch.cos(math.pi - self.margins)
+        self.register_buffer('sinm_m', sinm_m, persistent=False)
+        self.register_buffer('threshold', threshold, persistent=False)
+
+    def set_margins(self, margins: Union[Sequence[float], float]) -> None:
+        """Set the margins of the ArcFace head.
+
+        Args:
+            margins (Union[Sequence[float], float]): The margins.
+        """
+        if isinstance(margins, float):
+            margins = [margins] * self.num_classes
+        assert is_seq_of(
+            list(margins), float) and (len(margins) == self.num_classes), (
+                f'margins must be Sequence[Union(float, int)], got {margins}')
+
+        self.margins = torch.tensor(
+            margins, device=self.margins.device, dtype=torch.float32)
+        self.sinm_m = torch.sin(self.margins) * self.margins
+        self.threshold = -torch.cos(self.margins)
+
+    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
+        """The process before the final classification head.
+
+        The input ``feats`` is a tuple of tensor, and each tensor is the
+        feature of a backbone stage. In ``ArcFaceHead``, we just obtain the
+        feature of the last stage.
+        """
+        # The ArcFaceHead doesn't have other module, just return after
+        # unpacking.
+        return feats[-1]
+
+    def _get_logit_with_margin(self, pre_logits, target):
+        """Add the arc margin to the cosine at the target index.
+
+        The target must be in index format.
+        """
+        assert target.dim() == 1 or (
+            target.dim() == 2 and target.shape[1] == 1), \
+            'The target must be in index format.'
+        cosine = self.norm_product(pre_logits)
+        phi = torch.cos(torch.acos(cosine) + self.margins)
+
+        if self.easy_margin:
+            # when cosine>0, choose phi
+            # when cosine<=0, choose cosine
+            phi = torch.where(cosine > 0, phi, cosine)
+        else:
+            # when cos>th, choose phi
+            # when cos<=th, choose cosine-mm
+            phi = torch.where(cosine > self.threshold, phi,
+                              cosine - self.sinm_m)
+
+        target = convert_to_one_hot(target, self.num_classes)
+        output = target * phi + (1 - target) * cosine
+        return output
+
+    def forward(self,
+                feats: Tuple[torch.Tensor],
+                target: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """The forward process."""
+        # Disable AMP
+        with autocast(enabled=False):
+            pre_logits = self.pre_logits(feats)
+
+            if target is None:
+                # when eval, logit is the cosine between W and pre_logits;
+                # cos(theta_yj) = (x/||x||) * (W/||W||)
+                logit = self.norm_product(pre_logits)
+            else:
+                # when training, add a margin to the pre_logits where target
+                # is True, then logit is the cosine between W and the new
+                # pre_logits
+                logit = self._get_logit_with_margin(pre_logits, target)
+
+        return self.scale * logit
+
+    def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample],
+             **kwargs) -> dict:
+        """Calculate losses from the classification score.
+
+        Args:
+            feats (tuple[Tensor]): The features extracted from the backbone.
+                Multiple stage inputs are acceptable but only the last stage
+                will be used to classify. The shape of every item should be
+                ``(num_samples, num_classes)``.
+            data_samples (List[DataSample]): The annotation data of
+                every sample.
+            **kwargs: Other keyword arguments to forward the loss module.
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # Unpack data samples and pack targets + label_target = torch.cat([i.gt_label for i in data_samples]) + if 'gt_score' in data_samples[0]: + # Batch augmentation may convert labels to one-hot format scores. + target = torch.stack([i.gt_score for i in data_samples]) + else: + target = label_target + + # the index format target would be used + cls_score = self(feats, label_target) + + # compute loss + losses = dict() + loss = self.loss_module( + cls_score, target, avg_factor=cls_score.size(0), **kwargs) + losses['loss'] = loss + + return losses diff --git a/mmpretrain/models/heads/mim_head.py b/mmpretrain/models/heads/mim_head.py new file mode 100644 index 0000000000000000000000000000000000000000..bda90c8198986ec9b2ff2d03db3350e1f1a25823 --- /dev/null +++ b/mmpretrain/models/heads/mim_head.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class MIMHead(BaseModule): + """Pre-training head for Masked Image Modeling. + + Args: + loss (dict): Config dict for module of loss functions. + """ + + def __init__(self, loss: dict) -> None: + super().__init__() + self.loss_module = MODELS.build(loss) + + def loss(self, + pred: torch.Tensor, + target: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward head. + + Args: + pred (torch.Tensor): Predictions with shape B x L x C. + target (torch.Tensor): Targets with shape B x L x C. + mask (torch.Tensor): Mask with shape B x L. + + Returns: + torch.Tensor: The loss tensor. + """ + loss = self.loss_module(pred, target, mask) + return loss diff --git a/mmpretrain/models/heads/mixmim_head.py b/mmpretrain/models/heads/mixmim_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a709630abb26bce1153596cec842da0912bab912 --- /dev/null +++ b/mmpretrain/models/heads/mixmim_head.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmpretrain.registry import MODELS +from .mae_head import MAEPretrainHead + + +@MODELS.register_module() +class MixMIMPretrainHead(MAEPretrainHead): + """Head for MixMIM Pre-training. + + Args: + loss (dict): Config of loss. + norm_pix_loss (bool): Whether or not normalize target. + Defaults to False. + patch_size (int): Patch size. Defaults to 16. + """ + + def __init__(self, + loss: dict, + norm_pix: bool = False, + patch_size: int = 16) -> None: + super().__init__(loss=loss, norm_pix=norm_pix, patch_size=patch_size) + + def loss(self, x_rec: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + pred (torch.Tensor): The reconstructed image. + target (torch.Tensor): The target image. + mask (torch.Tensor): The mask of the target image. + + Returns: + torch.Tensor: The reconstruction loss. 
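``MIMHead`` above leaves all arithmetic to the configured loss module. As an illustration only (this is not the actual ``PixelReconstructionLoss``), a masked per-patch MSE of the kind typically plugged in could be sketched as:

```python
import torch

def masked_mse(pred, target, mask):
    # per-patch squared error, averaged over masked patches only
    loss = ((pred - target) ** 2).mean(dim=-1)        # (B, L)
    return (loss * mask).sum() / mask.sum().clamp(min=1)

B, L, C = 2, 196, 768
pred, target = torch.rand(B, L, C), torch.rand(B, L, C)
mask = (torch.rand(B, L) > 0.25).float()              # 1 = masked patch
print(masked_mse(pred, target, mask))
```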
+ """ + target = self.construct_target(target) + + B, L, C = x_rec.shape + + # unmix tokens + x1_rec = x_rec[:B // 2] + x2_rec = x_rec[B // 2:] + + unmix_x_rec = x1_rec * mask + x2_rec.flip(0) * (1 - mask) + + loss_rec = self.loss_module(unmix_x_rec, target) + + return loss_rec diff --git a/mmpretrain/models/heads/mocov3_head.py b/mmpretrain/models/heads/mocov3_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c2bec2a6cc90247fab44d6d954a8a0c6ede0a812 --- /dev/null +++ b/mmpretrain/models/heads/mocov3_head.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmengine.dist import all_gather, get_rank +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class MoCoV3Head(BaseModule): + """Head for MoCo v3 Pre-training. + + This head builds a predictor, which can be any registered neck component. + It also implements latent contrastive loss between two forward features. + Part of the code is modified from: + ``_. + + Args: + predictor (dict): Config dict for module of predictor. + loss (dict): Config dict for module of loss functions. + temperature (float): The temperature hyper-parameter that + controls the concentration level of the distribution. + Defaults to 1.0. + """ + + def __init__(self, + predictor: dict, + loss: dict, + temperature: float = 1.0) -> None: + super().__init__() + self.predictor = MODELS.build(predictor) + self.loss_module = MODELS.build(loss) + self.temperature = temperature + + def loss(self, base_out: torch.Tensor, + momentum_out: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + base_out (torch.Tensor): NxC features from base_encoder. + momentum_out (torch.Tensor): NxC features from momentum_encoder. + + Returns: + torch.Tensor: The loss tensor. + """ + # predictor computation + pred = self.predictor([base_out])[0] + + # normalize + pred = nn.functional.normalize(pred, dim=1) + target = nn.functional.normalize(momentum_out, dim=1) + + # get negative samples + target = torch.cat(all_gather(target), dim=0) + + # Einstein sum is more intuitive + logits = torch.einsum('nc,mc->nm', [pred, target]) / self.temperature + + # generate labels + batch_size = logits.shape[0] + labels = (torch.arange(batch_size, dtype=torch.long) + + batch_size * get_rank()).to(logits.device) + + loss = self.loss_module(logits, labels) + return loss diff --git a/mmpretrain/models/heads/multi_label_cls_head.py b/mmpretrain/models/heads/multi_label_cls_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ca36bfe06e70e1e0f16a5dc4c161b186234f57ac --- /dev/null +++ b/mmpretrain/models/heads/multi_label_cls_head.py @@ -0,0 +1,155 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample, label_to_onehot + + +@MODELS.register_module() +class MultiLabelClsHead(BaseModule): + """Classification head for multilabel task. + + Args: + loss (dict): Config of classification loss. Defaults to + dict(type='CrossEntropyLoss', use_sigmoid=True). + thr (float, optional): Predictions with scores under the thresholds + are considered as negative. Defaults to None. + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. Defaults to None. + init_cfg (dict, optional): The extra init config of layers. 
+ Defaults to None. + + Notes: + If both ``thr`` and ``topk`` are set, use ``thr` to determine + positive predictions. If neither is set, use ``thr=0.5`` as + default. + """ + + def __init__(self, + loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True), + thr: Optional[float] = None, + topk: Optional[int] = None, + init_cfg: Optional[dict] = None): + super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg) + + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module = loss + + if thr is None and topk is None: + thr = 0.5 + + self.thr = thr + self.topk = topk + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``MultiLabelClsHead``, we just obtain + the feature of the last stage. + """ + # The MultiLabelClsHead doesn't have other module, just return after + # unpacking. + return feats[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The MultiLabelClsHead doesn't have the final classification head, + # just return the unpacked inputs. + return pre_logits + + def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # The part can be traced by torch.fx + cls_score = self(feats) + + # The part can not be traced by torch.fx + losses = self._get_loss(cls_score, data_samples, **kwargs) + return losses + + def _get_loss(self, cls_score: torch.Tensor, + data_samples: List[DataSample], **kwargs): + """Unpack data samples and compute loss.""" + num_classes = cls_score.size()[-1] + # Unpack data samples and pack targets + if 'gt_score' in data_samples[0]: + target = torch.stack([i.gt_score.float() for i in data_samples]) + else: + target = torch.stack([ + label_to_onehot(i.gt_label, num_classes) for i in data_samples + ]).float() + + # compute loss + losses = dict() + loss = self.loss_module( + cls_score, target, avg_factor=cls_score.size(0), **kwargs) + losses['loss'] = loss + + return losses + + def predict(self, + feats: Tuple[torch.Tensor], + data_samples: List[DataSample] = None) -> List[DataSample]: + """Inference without augmentation. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample], optional): The annotation + data of every samples. If not None, set ``pred_label`` of + the input data samples. Defaults to None. + + Returns: + List[DataSample]: A list of data samples which contains the + predicted results. 
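The note above fixes the precedence between ``thr`` and ``topk``; in isolation, the decision rule applied per sample is just (toy scores, hypothetical threshold):

```python
import torch

scores = torch.sigmoid(torch.randn(2, 5))   # multi-label scores
thr, topk = 0.5, None                       # thr wins whenever both are set

for score in scores:
    if thr is not None:
        label = torch.where(score >= thr)[0]   # threshold mode
    else:
        _, label = score.topk(topk)            # top-k mode
    print(label.tolist())
```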
+ """ + # The part can be traced by torch.fx + cls_score = self(feats) + + # The part can not be traced by torch.fx + predictions = self._get_predictions(cls_score, data_samples) + return predictions + + def _get_predictions(self, cls_score: torch.Tensor, + data_samples: List[DataSample]): + """Post-process the output of head. + + Including softmax and set ``pred_label`` of data samples. + """ + pred_scores = torch.sigmoid(cls_score) + + if data_samples is None: + data_samples = [DataSample() for _ in range(cls_score.size(0))] + + for data_sample, score in zip(data_samples, pred_scores): + if self.thr is not None: + # a label is predicted positive if larger than thr + label = torch.where(score >= self.thr)[0] + else: + # top-k labels will be predicted positive for any example + _, label = score.topk(self.topk) + data_sample.set_pred_score(score).set_pred_label(label) + + return data_samples diff --git a/mmpretrain/models/heads/multi_label_csra_head.py b/mmpretrain/models/heads/multi_label_csra_head.py new file mode 100644 index 0000000000000000000000000000000000000000..95a3a0e8b9d6c68c2f2c1da3c0c160c4c695cc7c --- /dev/null +++ b/mmpretrain/models/heads/multi_label_csra_head.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/Kevinz-code/CSRA +from typing import Tuple + +import torch +import torch.nn as nn +from mmengine.model import BaseModule, ModuleList + +from mmpretrain.registry import MODELS +from .multi_label_cls_head import MultiLabelClsHead + + +@MODELS.register_module() +class CSRAClsHead(MultiLabelClsHead): + """Class-specific residual attention classifier head. + + Please refer to the `Residual Attention: A Simple but Effective Method for + Multi-Label Recognition (ICCV 2021) `_ + for details. + + Args: + num_classes (int): Number of categories. + in_channels (int): Number of channels in the input feature map. + num_heads (int): Number of residual at tensor heads. + loss (dict): Config of classification loss. + lam (float): Lambda that combines global average and max pooling + scores. + init_cfg (dict, optional): The extra init config of layers. + Defaults to use ``dict(type='Normal', layer='Linear', std=0.01)``. + """ + temperature_settings = { # softmax temperature settings + 1: [1], + 2: [1, 99], + 4: [1, 2, 4, 99], + 6: [1, 2, 3, 4, 5, 99], + 8: [1, 2, 3, 4, 5, 6, 7, 99] + } + + def __init__(self, + num_classes: int, + in_channels: int, + num_heads: int, + lam: float, + init_cfg=dict(type='Normal', layer='Linear', std=0.01), + **kwargs): + assert num_heads in self.temperature_settings.keys( + ), 'The num of heads is not in temperature setting.' + assert lam > 0, 'Lambda should be between 0 and 1.' + super(CSRAClsHead, self).__init__(init_cfg=init_cfg, **kwargs) + self.temp_list = self.temperature_settings[num_heads] + self.csra_heads = ModuleList([ + CSRAModule(num_classes, in_channels, self.temp_list[i], lam) + for i in range(num_heads) + ]) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``CSRAClsHead``, we just obtain the + feature of the last stage. + """ + # The CSRAClsHead doesn't have other module, just return after + # unpacking. 
+ return feats[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + logit = sum([head(pre_logits) for head in self.csra_heads]) + return logit + + +class CSRAModule(BaseModule): + """Basic module of CSRA with different temperature. + + Args: + num_classes (int): Number of categories. + in_channels (int): Number of channels in the input feature map. + T (int): Temperature setting. + lam (float): Lambda that combines global average and max pooling + scores. + init_cfg (dict | optional): The extra init config of layers. + Defaults to use dict(type='Normal', layer='Linear', std=0.01). + """ + + def __init__(self, + num_classes: int, + in_channels: int, + T: int, + lam: float, + init_cfg=None): + + super(CSRAModule, self).__init__(init_cfg=init_cfg) + self.T = T # temperature + self.lam = lam # Lambda + self.head = nn.Conv2d(in_channels, num_classes, 1, bias=False) + self.softmax = nn.Softmax(dim=2) + + def forward(self, x): + score = self.head(x) / torch.norm( + self.head.weight, dim=1, keepdim=True).transpose(0, 1) + score = score.flatten(2) + base_logit = torch.mean(score, dim=2) + + if self.T == 99: # max-pooling + att_logit = torch.max(score, dim=2)[0] + else: + score_soft = self.softmax(score * self.T) + att_logit = torch.sum(score * score_soft, dim=2) + + return base_logit + self.lam * att_logit diff --git a/mmpretrain/models/heads/multi_label_linear_head.py b/mmpretrain/models/heads/multi_label_linear_head.py new file mode 100644 index 0000000000000000000000000000000000000000..81217ec55c54f23748b7e4ce8797509abfbb2ed3 --- /dev/null +++ b/mmpretrain/models/heads/multi_label_linear_head.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .multi_label_cls_head import MultiLabelClsHead + + +@MODELS.register_module() +class MultiLabelLinearClsHead(MultiLabelClsHead): + """Linear classification head for multilabel task. + + Args: + loss (dict): Config of classification loss. Defaults to + dict(type='CrossEntropyLoss', use_sigmoid=True). + thr (float, optional): Predictions with scores under the thresholds + are considered as negative. Defaults to None. + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. Defaults to None. + init_cfg (dict, optional): The extra init config of layers. + Defaults to use dict(type='Normal', layer='Linear', std=0.01). + + Notes: + If both ``thr`` and ``topk`` are set, use ``thr` to determine + positive predictions. If neither is set, use ``thr=0.5`` as + default. + """ + + def __init__(self, + num_classes: int, + in_channels: int, + loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True), + thr: Optional[float] = None, + topk: Optional[int] = None, + init_cfg: Optional[dict] = dict( + type='Normal', layer='Linear', std=0.01)): + super(MultiLabelLinearClsHead, self).__init__( + loss=loss, thr=thr, topk=topk, init_cfg=init_cfg) + + assert num_classes > 0, f'num_classes ({num_classes}) must be a ' \ + 'positive integer.' + + self.in_channels = in_channels + self.num_classes = num_classes + + self.fc = nn.Linear(self.in_channels, self.num_classes) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. 
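Each ``CSRAModule`` above combines plain average pooling with a temperature-sharpened attention pooling, with ``T == 99`` as a max-pooling sentinel. A compact sketch of that combination on made-up class scores:

```python
import torch

score = torch.randn(2, 20, 49)     # (B, num_classes, H*W) class scores
T, lam = 1, 0.1                    # toy temperature and lambda

base_logit = score.mean(dim=2)     # global average pooling
if T == 99:                        # 99 denotes max pooling
    att_logit = score.max(dim=2)[0]
else:                              # attention pooling, sharpened by T
    att_logit = (score * (score * T).softmax(dim=2)).sum(dim=2)
logit = base_logit + lam * att_logit
```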
In ``MultiLabelLinearClsHead``, we just
+        obtain the feature of the last stage.
+        """
+        # The MultiLabelLinearClsHead doesn't have other module, just return
+        # after unpacking.
+        return feats[-1]
+
+    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
+        """The forward process."""
+        pre_logits = self.pre_logits(feats)
+        # The final classification head.
+        cls_score = self.fc(pre_logits)
+        return cls_score
diff --git a/mmpretrain/models/heads/multi_task_head.py b/mmpretrain/models/heads/multi_task_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3515a2b1e0b2140a57f57a69416b2c462ecec871
--- /dev/null
+++ b/mmpretrain/models/heads/multi_task_head.py
@@ -0,0 +1,153 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Sequence, Tuple
+
+import torch
+import torch.nn as nn
+from mmengine.model import BaseModule, ModuleDict
+
+from mmpretrain.registry import MODELS
+from mmpretrain.structures import MultiTaskDataSample
+
+
+def loss_convertor(loss_func, task_name):
+
+    def wrapped(inputs, data_samples, **kwargs):
+        mask = torch.empty(len(data_samples), dtype=torch.bool)
+        task_data_samples = []
+        for i, data_sample in enumerate(data_samples):
+            assert isinstance(data_sample, MultiTaskDataSample)
+            sample_mask = task_name in data_sample
+            mask[i] = sample_mask
+            if sample_mask:
+                task_data_samples.append(data_sample.get(task_name))
+
+        if len(task_data_samples) == 0:
+            # This makes it possible to perform loss.backward when a
+            # task does not have gt_labels within a batch.
+            loss = (inputs[0] * 0).sum()
+            return {'loss': loss, 'mask_size': torch.tensor(0.)}
+
+        # Mask the inputs of the task
+        def mask_inputs(inputs, mask):
+            if isinstance(inputs, Sequence):
+                return type(inputs)(
+                    [mask_inputs(input, mask) for input in inputs])
+            elif isinstance(inputs, torch.Tensor):
+                return inputs[mask]
+
+        masked_inputs = mask_inputs(inputs, mask)
+        loss_output = loss_func(masked_inputs, task_data_samples, **kwargs)
+        loss_output['mask_size'] = mask.sum().to(torch.float)
+        return loss_output
+
+    return wrapped
+
+
+@MODELS.register_module()
+class MultiTaskHead(BaseModule):
+    """Multi task head.
+
+    Args:
+        task_heads (dict): Sub heads to use, the key will be used to rename
+            the loss components.
+        common_cfg (dict): The common settings for all heads. Defaults to an
+            empty dict.
+        init_cfg (dict, optional): The extra initialization settings.
+            Defaults to None.
+    """
+
+    def __init__(self, task_heads, init_cfg=None, **kwargs):
+        super(MultiTaskHead, self).__init__(init_cfg=init_cfg)
+
+        assert isinstance(task_heads, dict), 'The `task_heads` argument ' \
+            'should be a dict, whose keys are task names and values are ' \
+            'configs of head for the task.'
+
+        self.task_heads = ModuleDict()
+
+        for task_name, sub_head in task_heads.items():
+            if not isinstance(sub_head, nn.Module):
+                sub_head = MODELS.build(sub_head, default_args=kwargs)
+            sub_head.loss = loss_convertor(sub_head.loss, task_name)
+            self.task_heads[task_name] = sub_head
+
+    def forward(self, feats):
+        """The forward process."""
+        return {
+            task_name: head(feats)
+            for task_name, head in self.task_heads.items()
+        }
+
+    def loss(self, feats: Tuple[torch.Tensor],
+             data_samples: List[MultiTaskDataSample], **kwargs) -> dict:
+        """Calculate losses from the classification score.
+
+        Args:
+            feats (tuple[Tensor]): The features extracted from the backbone.
+            data_samples (List[MultiTaskDataSample]): The annotation data of
+                every sample.
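``loss_convertor`` above filters the batch per task and keeps the autograd graph alive when a task has no labels in the batch. In miniature:

```python
import torch

# Miniature of the per-task masking in `loss_convertor`: only samples
# that carry this task's label reach the sub-head's loss.
feats = (torch.randn(4, 8),)
has_task = torch.tensor([True, False, True, True])

if has_task.any():
    masked = tuple(f[has_task] for f in feats)   # 3 of 4 samples kept
else:
    # zero loss that still depends on the inputs, so backward() works
    loss = (feats[0] * 0).sum()
```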
+ **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components, each task loss + key will be prefixed by the task_name like "task1_loss" + """ + losses = dict() + for task_name, head in self.task_heads.items(): + head_loss = head.loss(feats, data_samples, **kwargs) + for k, v in head_loss.items(): + losses[f'{task_name}_{k}'] = v + return losses + + def predict( + self, + feats: Tuple[torch.Tensor], + data_samples: List[MultiTaskDataSample] = None + ) -> List[MultiTaskDataSample]: + """Inference without augmentation. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + data_samples (List[MultiTaskDataSample], optional): The annotation + data of every samples. If not None, set ``pred_label`` of + the input data samples. Defaults to None. + + Returns: + List[MultiTaskDataSample]: A list of data samples which contains + the predicted results. + """ + predictions_dict = dict() + + for task_name, head in self.task_heads.items(): + task_samples = None + if data_samples is not None: + task_samples = [ + data_sample.get(task_name, None) if data_sample else None + for data_sample in data_samples + ] + + task_samples = head.predict(feats, task_samples) + batch_size = len(task_samples) + predictions_dict[task_name] = task_samples + + if data_samples is None: + data_samples = [MultiTaskDataSample() for _ in range(batch_size)] + else: + data_samples = [ + MultiTaskDataSample() if data_sample is None else data_sample + for data_sample in data_samples + ] + + for task_name, task_samples in predictions_dict.items(): + for data_sample, task_sample in zip(data_samples, task_samples): + task_sample.set_field( + task_name in data_sample.tasks, + 'eval_mask', + field_type='metainfo') + + if task_name in data_sample.tasks: + data_sample.get(task_name).update(task_sample) + else: + data_sample.set_field(task_sample, task_name) + + return data_samples diff --git a/mmpretrain/models/heads/seq_gen_head.py b/mmpretrain/models/heads/seq_gen_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b2e9b10efe6e1e6a709cd870f0572f14bbd176ee --- /dev/null +++ b/mmpretrain/models/heads/seq_gen_head.py @@ -0,0 +1,188 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SeqGenerationHead(BaseModule): + """Generation head for multi-modal pre-trained task, adopted by BLIP. + Normally used for generation task. + + Args: + decoder (dict): Decoder for blip generation head. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__( + self, + decoder: dict, + ignore_index=-100, + loss: dict = dict(type='LabelSmoothLoss', label_smooth_val=0.1), + init_cfg: Optional[dict] = None, + ) -> None: + super(SeqGenerationHead, self).__init__(init_cfg=init_cfg) + self.decoder = MODELS.build(decoder) + self.loss_fn = MODELS.build(loss) + self.ignore_index = ignore_index + + def forward(self, input_ids: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, labels: torch.Tensor): + """Forward to get decoder output. + + Args: + input_ids (torch.Tensor): The tokenized input text tensor. + encoder_hidden_states (torch.Tensor): Hidden states from image + embeddings. + encoder_attention_mask (torch.Tensor): Image embeddings hidden + states attention mask. 
+            labels (torch.Tensor): Decoder target for calculating the loss.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of decoder outputs.
+        """
+
+        decoder_out = self.decoder(
+            input_ids=input_ids,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            labels=labels,
+            return_dict=True,
+        )
+        return decoder_out
+
+    def loss(self, input_ids, encoder_hidden_states, encoder_attention_mask,
+             labels):
+        """Calculate losses from the extracted features.
+
+        Args:
+            input_ids (torch.Tensor): The tokenized input text tensor.
+            encoder_hidden_states (torch.Tensor): Hidden states from image
+                embeddings.
+            encoder_attention_mask (torch.Tensor): Image embeddings hidden
+                states attention mask.
+            labels (torch.Tensor): Decoder target for calculating the loss.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components.
+        """
+
+        decoder_out = self(
+            input_ids=input_ids,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            labels=labels,
+        )
+        prediction_scores = decoder_out['logits']
+        # we are doing next-token prediction;
+        # shift prediction scores and input ids by one
+        shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+        labels = labels[:, 1:].contiguous()
+
+        vocab_size = prediction_scores.shape[-1]
+
+        # mask ignored index
+        if (labels == self.ignore_index).any():
+            labels = labels.view(-1).clone()
+            ignore_mask = (labels == self.ignore_index)
+            labels.masked_fill_(ignore_mask, 0)
+            weight = torch.logical_not(ignore_mask)
+            avg_factor = max(weight.sum(), 1)
+        else:
+            weight = None
+            avg_factor = labels.size(0)
+
+        lm_loss = self.loss_fn(
+            shifted_prediction_scores.view(-1, vocab_size),
+            labels,
+            weight=weight,
+            avg_factor=avg_factor,
+        )
+        losses = {
+            'seq_gen_lm_loss': lm_loss,
+        }
+
+        return losses
+
+    def predict(self,
+                input_ids,
+                encoder_hidden_states,
+                sep_token_id,
+                pad_token_id,
+                use_nucleus_sampling=False,
+                num_beams=3,
+                max_length=20,
+                min_length=2,
+                top_p=0.9,
+                repetition_penalty=1.0,
+                **kwargs):
+        """Decoder prediction method.
+
+        Args:
+            input_ids (torch.Tensor): The tokenized input text tensor.
+            encoder_hidden_states (torch.Tensor): Hidden states from image
+                embeddings.
+            sep_token_id (int): Token id of the separation token.
+            pad_token_id (int): Token id of the pad token.
+            use_nucleus_sampling (bool): Whether to use nucleus sampling in
+                prediction. Defaults to False.
+            num_beams (int): Number of beams used in prediction.
+                Defaults to 3.
+            max_length (int): Max length of generated text in prediction.
+                Defaults to 20.
+            min_length (int): Min length of generated text in prediction.
+                Defaults to 2.
+            top_p (float):
+                If < 1.0, only keep the top tokens with cumulative probability
+                >= top_p (nucleus filtering). Defaults to 0.9.
+            repetition_penalty (float): The parameter for repetition penalty.
+                Defaults to 1.0.
+            **kwargs: Other arguments that might be used in generation.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of generation outputs.
+        """
+        device = encoder_hidden_states.device
+
+        # TODO: In old versions of transformers, an additional repeat
+        # interleave of hidden states should be added here.
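The shift performed in ``loss`` above is the standard next-token alignment; isolated on toy tensors:

```python
import torch
import torch.nn.functional as F

# Next-token alignment: logits at step t score the token at step t+1.
logits = torch.randn(2, 6, 100)          # (B, seq_len, vocab)
labels = torch.randint(0, 100, (2, 6))

shifted_logits = logits[:, :-1, :]       # drop the final step
shifted_labels = labels[:, 1:]           # drop the leading (BOS) token
loss = F.cross_entropy(
    shifted_logits.reshape(-1, 100), shifted_labels.reshape(-1))
```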
+ image_atts = torch.ones( + encoder_hidden_states.size()[:-1], dtype=torch.long).to(device) + + model_kwargs = { + 'encoder_hidden_states': encoder_hidden_states, + 'encoder_attention_mask': image_atts, + } + model_kwargs.update(kwargs) + + if use_nucleus_sampling: + # nucleus sampling + outputs = self.decoder.generate( + input_ids=input_ids, + max_length=max_length, + min_length=min_length, + do_sample=True, + top_p=top_p, + num_return_sequences=1, + eos_token_id=sep_token_id, + pad_token_id=pad_token_id, + repetition_penalty=1.1, + **model_kwargs) + else: + # beam search + outputs = self.decoder.generate( + input_ids=input_ids, + max_length=max_length, + min_length=min_length, + num_beams=num_beams, + eos_token_id=sep_token_id, + pad_token_id=pad_token_id, + repetition_penalty=repetition_penalty, + **model_kwargs) + + return outputs diff --git a/mmpretrain/models/heads/simmim_head.py b/mmpretrain/models/heads/simmim_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b7af984c9eb4891e9f4281daf630355cafbb6cc7 --- /dev/null +++ b/mmpretrain/models/heads/simmim_head.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SimMIMHead(BaseModule): + """Head for SimMIM Pre-training. + + Args: + patch_size (int): Patch size of each token. + loss (dict): The config for loss. + """ + + def __init__(self, patch_size: int, loss: dict) -> None: + super().__init__() + self.patch_size = patch_size + self.loss_module = MODELS.build(loss) + + def loss(self, pred: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + This method will expand mask to the size of the original image. + + Args: + pred (torch.Tensor): The reconstructed image (B, C, H, W). + target (torch.Tensor): The target image (B, C, H, W). + mask (torch.Tensor): The mask of the target image. + + Returns: + torch.Tensor: The reconstruction loss. + """ + mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave( + self.patch_size, 2).unsqueeze(1).contiguous() + loss = self.loss_module(pred, target, mask) + + return loss diff --git a/mmpretrain/models/heads/spark_head.py b/mmpretrain/models/heads/spark_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a2748762ae50e1e085bd2ce536e95c6d52e51d9c --- /dev/null +++ b/mmpretrain/models/heads/spark_head.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SparKPretrainHead(BaseModule): + """Pre-training head for SparK. + + Args: + loss (dict): Config of loss. + norm_pix (bool): Whether or not normalize target. Defaults to True. + patch_size (int): Patch size, equal to downsample ratio of backbone. + Defaults to 32. + """ + + def __init__(self, + loss: dict, + norm_pix: bool = True, + patch_size: int = 32) -> None: + super().__init__() + self.norm_pix = norm_pix + self.patch_size = patch_size + self.loss = MODELS.build(loss) + + def patchify(self, imgs): + """Split images into non-overlapped patches. + + Args: + imgs (torch.Tensor): A batch of images, of shape B x C x H x W. + Returns: + torch.Tensor: Patchified images. The shape is B x L x D. 
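The SimMIM head above expands a patch-level mask to pixel resolution with two ``repeat_interleave`` calls; on a toy patch size this looks like:

```python
import torch

# Patch mask -> pixel mask, as in SimMIMHead.loss (toy patch size of 4;
# the real value comes from the `patch_size` config).
patch_size = 4
mask = torch.randint(0, 2, (1, 3, 3))
pixel_mask = mask.repeat_interleave(patch_size, 1) \
                 .repeat_interleave(patch_size, 2) \
                 .unsqueeze(1)   # (B, 1, 12, 12), ready to weight the loss
```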
+ """ + p = self.patch_size + assert len(imgs.shape + ) == 4 and imgs.shape[2] % p == 0 and imgs.shape[3] % p == 0 + + B, C, ori_h, ori_w = imgs.shape + h = ori_h // p + w = ori_w // p + x = imgs.reshape(shape=(B, C, h, p, w, p)) + x = torch.einsum('bchpwq->bhwpqc', x) + + # (B, f*f, downsample_raito*downsample_raito*3) + x = x.reshape(shape=(B, h * w, p**2 * C)) + return x + + def construct_target(self, target: torch.Tensor) -> torch.Tensor: + """Construct the reconstruction target. + + In addition to splitting images into tokens, this module will also + normalize the image according to ``norm_pix``. + Args: + target (torch.Tensor): Image with the shape of B x 3 x H x W + Returns: + torch.Tensor: Tokenized images with the shape of B x L x C + """ + target = self.patchify(target) + if self.norm_pix: + # normalize the target image + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.e-6)**.5 + + return target + + def forward(self, pred: torch.Tensor, target: torch.Tensor, + active_mask: torch.Tensor) -> torch.Tensor: + """Forward function of MAE head. + + Args: + pred (torch.Tensor): The reconstructed image. + target (torch.Tensor): The target image. + active_mask (torch.Tensor): The mask of the target image. + Returns: + torch.Tensor: The reconstruction loss. + """ + # (B, C, H, W) -> (B, L, C) and perform normalization + target = self.construct_target(target) + + # (B, C, H, W) -> (B, L, C) + pred = self.patchify(pred) + + # (B, 1, f, f) -> (B, L) + non_active_mask = active_mask.logical_not().int().view( + active_mask.shape[0], -1) + + # MSE loss on masked patches + loss = self.loss(pred, target, non_active_mask) + return loss diff --git a/mmpretrain/models/heads/stacked_head.py b/mmpretrain/models/heads/stacked_head.py new file mode 100644 index 0000000000000000000000000000000000000000..6cd819de8e8daf162bb906d5524871577754fa1f --- /dev/null +++ b/mmpretrain/models/heads/stacked_head.py @@ -0,0 +1,135 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmengine.model import BaseModule, ModuleList + +from mmpretrain.registry import MODELS +from .cls_head import ClsHead + + +class LinearBlock(BaseModule): + """Linear block for StackedLinearClsHead.""" + + def __init__(self, + in_channels, + out_channels, + dropout_rate=0., + norm_cfg=None, + act_cfg=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.fc = nn.Linear(in_channels, out_channels) + + self.norm = None + self.act = None + self.dropout = None + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, out_channels)[1] + if act_cfg is not None: + self.act = build_activation_layer(act_cfg) + if dropout_rate > 0: + self.dropout = nn.Dropout(p=dropout_rate) + + def forward(self, x): + """The forward process.""" + x = self.fc(x) + if self.norm is not None: + x = self.norm(x) + if self.act is not None: + x = self.act(x) + if self.dropout is not None: + x = self.dropout(x) + return x + + +@MODELS.register_module() +class StackedLinearClsHead(ClsHead): + """Classifier head with several hidden fc layer and a output fc layer. + + Args: + num_classes (int): Number of categories. + in_channels (int): Number of channels in the input feature map. + mid_channels (Sequence[int]): Number of channels in the hidden fc + layers. 
+ dropout_rate (float): Dropout rate after each hidden fc layer, + except the last layer. Defaults to 0. + norm_cfg (dict, optional): Config dict of normalization layer after + each hidden fc layer, except the last layer. Defaults to None. + act_cfg (dict, optional): Config dict of activation function after each + hidden layer, except the last layer. Defaults to use "ReLU". + """ + + def __init__(self, + num_classes: int, + in_channels: int, + mid_channels: Sequence[int], + dropout_rate: float = 0., + norm_cfg: Optional[Dict] = None, + act_cfg: Optional[Dict] = dict(type='ReLU'), + **kwargs): + super(StackedLinearClsHead, self).__init__(**kwargs) + self.num_classes = num_classes + self.in_channels = in_channels + if self.num_classes <= 0: + raise ValueError( + f'num_classes={num_classes} must be a positive integer') + + assert isinstance(mid_channels, Sequence), \ + f'`mid_channels` of StackedLinearClsHead should be a sequence, ' \ + f'instead of {type(mid_channels)}' + self.mid_channels = mid_channels + + self.dropout_rate = dropout_rate + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self._init_layers() + + def _init_layers(self): + """"Init layers.""" + self.layers = ModuleList() + in_channels = self.in_channels + for hidden_channels in self.mid_channels: + self.layers.append( + LinearBlock( + in_channels, + hidden_channels, + dropout_rate=self.dropout_rate, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + in_channels = hidden_channels + + self.layers.append( + LinearBlock( + self.mid_channels[-1], + self.num_classes, + dropout_rate=0., + norm_cfg=None, + act_cfg=None)) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. + """ + x = feats[-1] + for layer in self.layers[:-1]: + x = layer(x) + return x + + @property + def fc(self): + """Full connected layer.""" + return self.layers[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.fc(pre_logits) + return cls_score diff --git a/mmpretrain/models/heads/swav_head.py b/mmpretrain/models/heads/swav_head.py new file mode 100644 index 0000000000000000000000000000000000000000..8f3a30236e019822a166e25551f77feec8228d84 --- /dev/null +++ b/mmpretrain/models/heads/swav_head.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SwAVHead(BaseModule): + """Head for SwAV Pre-training. + + Args: + loss (dict): Config dict for module of loss functions. + """ + + def __init__(self, loss: dict) -> None: + super().__init__() + self.loss_module = MODELS.build(loss) + + def loss(self, pred: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + pred (torch.Tensor): NxC input features. + + Returns: + torch.Tensor: The SwAV loss. + """ + loss = self.loss_module(pred) + + return loss diff --git a/mmpretrain/models/heads/vig_head.py b/mmpretrain/models/heads/vig_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ecb984deb4b0b6bf162263a86771f2d3eba71cbd --- /dev/null +++ b/mmpretrain/models/heads/vig_head.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
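Functionally, the stacked head above is an ordinary MLP; a plain-PyTorch rendering with toy widths (all dimensions here invented):

```python
import torch
import torch.nn as nn

# Every hidden LinearBlock is fc -> norm -> act -> dropout; the final
# block is a bare Linear to the classes, with no norm/act/dropout.
blocks = nn.Sequential(
    nn.Linear(2048, 1024), nn.BatchNorm1d(1024), nn.ReLU(), nn.Dropout(0.1),
    nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.1),
    nn.Linear(512, 1000),            # output fc layer
)
logits = blocks(torch.randn(4, 2048))
```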
+from typing import Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer + +from mmpretrain.registry import MODELS +from .cls_head import ClsHead + + +@MODELS.register_module() +class VigClsHead(ClsHead): + """The classification head for Vision GNN. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + hidden_dim (int): The number of middle channels. Defaults to 1024. + act_cfg (dict): The config of activation function. + Defaults to ``dict(type='GELU')``. + dropout (float): The dropout rate. + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + num_classes: int, + in_channels: int, + hidden_dim: int = 1024, + act_cfg: dict = dict(type='GELU'), + dropout: float = 0., + **kwargs): + super().__init__(**kwargs) + + self.fc1 = nn.Linear(in_channels, hidden_dim) + self.bn = nn.BatchNorm1d(hidden_dim) + self.act = build_activation_layer(act_cfg) + self.drop = nn.Dropout(dropout) + self.fc2 = nn.Linear(hidden_dim, num_classes) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a stage_blocks stage. In ``VigClsHead``, we just obtain the + feature of the last stage. + """ + feats = feats[-1] + feats = self.fc1(feats) + feats = self.bn(feats) + feats = self.act(feats) + feats = self.drop(feats) + + return feats + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.fc2(pre_logits) + return cls_score diff --git a/mmpretrain/models/heads/vision_transformer_head.py b/mmpretrain/models/heads/vision_transformer_head.py new file mode 100644 index 0000000000000000000000000000000000000000..83e8fca125cd626c51abfcc87b28387f654618f9 --- /dev/null +++ b/mmpretrain/models/heads/vision_transformer_head.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from collections import OrderedDict +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer +from mmengine.model import Sequential +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from .cls_head import ClsHead + + +@MODELS.register_module() +class VisionTransformerClsHead(ClsHead): + """Vision Transformer classifier head. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + hidden_dim (int, optional): Number of the dimensions for hidden layer. + Defaults to None, which means no extra hidden layer. + act_cfg (dict): The activation config. Only available during + pre-training. Defaults to ``dict(type='Tanh')``. + init_cfg (dict): The extra initialization configs. Defaults to + ``dict(type='Constant', layer='Linear', val=0)``. 
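Depending on ``hidden_dim``, the ``_init_layers`` method below builds one of two layouts; spelled out as plain PyTorch with toy dimensions:

```python
import torch.nn as nn
from collections import OrderedDict

# The two layouts `_init_layers` can produce (768/3072/1000 are toy dims).
head_plain = nn.Sequential(OrderedDict([
    ('head', nn.Linear(768, 1000)),          # hidden_dim is None
]))
head_hidden = nn.Sequential(OrderedDict([
    ('pre_logits', nn.Linear(768, 3072)),    # hidden_dim = 3072
    ('act', nn.Tanh()),
    ('head', nn.Linear(3072, 1000)),
]))
```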
+ """ + + def __init__(self, + num_classes: int, + in_channels: int, + hidden_dim: Optional[int] = None, + act_cfg: dict = dict(type='Tanh'), + init_cfg: dict = dict(type='Constant', layer='Linear', val=0), + **kwargs): + super(VisionTransformerClsHead, self).__init__( + init_cfg=init_cfg, **kwargs) + self.in_channels = in_channels + self.num_classes = num_classes + self.hidden_dim = hidden_dim + self.act_cfg = act_cfg + + if self.num_classes <= 0: + raise ValueError( + f'num_classes={num_classes} must be a positive integer') + + self._init_layers() + + def _init_layers(self): + """"Init hidden layer if exists.""" + if self.hidden_dim is None: + layers = [('head', nn.Linear(self.in_channels, self.num_classes))] + else: + layers = [ + ('pre_logits', nn.Linear(self.in_channels, self.hidden_dim)), + ('act', build_activation_layer(self.act_cfg)), + ('head', nn.Linear(self.hidden_dim, self.num_classes)), + ] + self.layers = Sequential(OrderedDict(layers)) + + def init_weights(self): + """"Init weights of hidden layer if exists.""" + super(VisionTransformerClsHead, self).init_weights() + # Modified from ClassyVision + if hasattr(self.layers, 'pre_logits'): + # Lecun norm + trunc_normal_( + self.layers.pre_logits.weight, + std=math.sqrt(1 / self.layers.pre_logits.in_features)) + nn.init.zeros_(self.layers.pre_logits.bias) + + def pre_logits(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of list of tensor, and each tensor is + the feature of a backbone stage. In ``VisionTransformerClsHead``, we + obtain the feature of the last stage and forward in hidden layer if + exists. + """ + feat = feats[-1] # Obtain feature of the last scale. + # For backward-compatibility with the previous ViT output + cls_token = feat[-1] if isinstance(feat, list) else feat + if self.hidden_dim is None: + return cls_token + else: + x = self.layers.pre_logits(cls_token) + return self.layers.act(x) + + def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.layers.head(pre_logits) + return cls_score diff --git a/mmpretrain/models/heads/vqa_head.py b/mmpretrain/models/heads/vqa_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c7b5fe532874e2e8325caa3090d3be66b098ad46 --- /dev/null +++ b/mmpretrain/models/heads/vqa_head.py @@ -0,0 +1,246 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Union + +import mmengine +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class VQAGenerationHead(BaseModule): + """Generation head for multi-modal pre-trained task, adapted by BLIP. + Normally used for qa generation task (open-set) + + Args: + decoder (dict): Decoder for decoding answers. + inference_method (str): Inference method. One of 'rank', 'generate'. + - If 'rank', the model will return answers with the highest + probability from the answer list. + - If 'generate', the model will generate answers. + - Only for test, not for train / val. + num_beams (int): Number of beams for beam search. 1 means no beam + search. Only support when inference_method=='generate'. + Defaults to 3. + num_ans_candidates (int): Number of answer candidates, used to filter + out answers with low probability. 
Only supported when
+            inference_method=='rank'. Defaults to 128.
+        loss (dict or nn.Module): Config of loss or module of loss. Defaults
+            to ``nn.CrossEntropyLoss(reduction='none', ignore_index=-100)``.
+        init_cfg (dict, optional): the config to control the initialization.
+            Defaults to None.
+        answer_list_path (str, optional): Path to `answer_list.json`
+            (json file of an answer list). Required when
+            inference_method=='rank'.
+
+
+    TODO: `mmcls.LabelSmoothLoss` does not support the `ignore_index` param.
+    Now using `nn.CrossEntropyLoss`, without label_smoothing, in order to
+    maintain compatibility with torch < 1.10.0
+    """
+
+    def __init__(
+        self,
+        decoder: dict,
+        inference_method: str = 'generate',
+        num_beams: int = 3,
+        num_ans_candidates: int = 128,
+        loss: Union[dict, nn.Module] = nn.CrossEntropyLoss(
+            reduction='none', ignore_index=-100),
+        init_cfg: Optional[dict] = None,
+        answer_list_path: Optional[str] = None,
+    ) -> None:
+
+        super(VQAGenerationHead, self).__init__(init_cfg=init_cfg)
+        self.decoder = MODELS.build(decoder)
+
+        if inference_method == 'generate':
+            assert isinstance(num_beams, int), \
+                'for VQA `generate` mode, `num_beams` must be an int.'
+            self.num_beams = num_beams
+            self.num_ans_candidates = None
+            self.answer_list = None
+
+        elif inference_method == 'rank':
+            assert isinstance(num_ans_candidates, int), \
+                'for VQA `rank` mode, `num_ans_candidates` must be an int.'
+            assert isinstance(answer_list_path, str), \
+                'for VQA `rank` mode, `answer_list_path` must be set as ' \
+                'the path to `answer_list.json`.'
+            self.num_beams = None
+            self.answer_list = mmengine.load(answer_list_path)
+            if isinstance(self.answer_list, dict):
+                self.answer_list = list(self.answer_list.keys())
+            assert isinstance(self.answer_list, list) and all(
+                isinstance(item, str) for item in self.answer_list), \
+                'for VQA `rank` mode, `answer_list.json` must be a list of str'
+            self.num_ans_candidates = min(num_ans_candidates,
+                                          len(self.answer_list))
+
+        else:
+            raise AssertionError(
+                'for VQA, `inference_method` must be "generate" or "rank", '
+                'got {}.'.format(inference_method))
+
+        self.inference_method = inference_method
+        if not isinstance(loss, nn.Module):
+            loss = MODELS.build(loss)
+        self.loss_module = loss
+
+    def forward(self, feats: dict):
+        prediction_logits = self.decoder(
+            feats['answer_input_ids'],
+            attention_mask=feats['answer_attention_mask'],
+            encoder_hidden_states=feats['question_states'],
+            encoder_attention_mask=feats['question_atts'],
+            labels=feats['answer_targets'],
+            return_dict=True,
+            return_logits=True,  # directly return logits, not computing loss
+            reduction='none',
+        )
+        return prediction_logits
+
+    def loss(self, feats: dict, data_samples=None):
+        """Calculate losses from the extracted features.
+
+        Args:
+            feats (dict): The features extracted from the backbone.
+            data_samples (List[BaseDataElement]): The annotation data of
+                every sample.
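``predict_rank`` further down defines a local ``tile`` helper; its effect is that of ``repeat_interleave``, pairing each question state with its top-k candidate answers. A quick numerical check of that equivalence:

```python
import torch

# tile(x, dim=0, n_tile=3) from predict_rank, reproduced numerically:
# each row is repeated consecutively, like repeat_interleave.
x = torch.tensor([[1], [2]])
order = torch.cat([2 * torch.arange(3) + i for i in range(2)])
tiled = torch.index_select(x.repeat(3, 1), 0, order)
assert torch.equal(tiled, x.repeat_interleave(3, dim=0))
```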
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + shifted_prediction_scores = self(feats) + labels = feats['answer_targets'] + lm_loss = None + + # we are doing next-token prediction; + # shift prediction scores and input ids by one + labels = labels[:, 1:].contiguous() + lm_loss = self.loss_module( + shifted_prediction_scores.view(-1, + self.decoder.med_config.vocab_size), + labels.view(-1)) + lm_loss = lm_loss.view(shifted_prediction_scores.size(0), -1).sum(1) + # compute weighted loss + losses = dict() + loss = feats['answer_weight'] * lm_loss + loss = loss.sum() / feats['batch_size'] + losses['vqa_loss'] = loss + + return losses + + def predict_rank(self, feats: dict, data_samples=None): + """Predict rank in a close-set answer list.""" + question_states = feats['multimodal_embeds'] + question_atts = feats['question_atts'] + answer_candidates = feats['answer_candidates'] + assert answer_candidates is not None + + answer_ids = answer_candidates.input_ids + answer_atts = answer_candidates.attention_mask + num_ques = question_states.size(0) + start_ids = answer_ids[0, 0].repeat(num_ques, 1) # bos token + + start_output = self.decoder( + start_ids, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + return_dict=True, + reduction='none', + ) + logits = start_output.logits[:, 0, :] # first token's logit + + # topk_probs: top-k probability + # topk_ids: [num_question, k] + answer_first_token = answer_ids[:, 1] + prob_first_token = F.softmax( + logits, dim=1).index_select( + dim=1, index=answer_first_token) + topk_probs, topk_ids = prob_first_token.topk( + self.num_ans_candidates, dim=1) + + # answer input: [num_question*k, answer_len] + input_ids = [] + input_atts = [] + for b, topk_id in enumerate(topk_ids): + input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) + input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) + input_ids = torch.cat(input_ids, dim=0) + input_atts = torch.cat(input_atts, dim=0) + + targets_ids = input_ids.masked_fill(input_ids == feats['pad_token_id'], + -100) + + def tile(x, dim, n_tile): + init_dim = x.size(dim) + repeat_idx = [1] * x.dim() + repeat_idx[dim] = n_tile + x = x.repeat(*(repeat_idx)) + order_index = torch.LongTensor( + np.concatenate([ + init_dim * np.arange(n_tile) + i for i in range(init_dim) + ])) + return torch.index_select(x, dim, order_index.to(x.device)) + + # repeat encoder's output for top-k answers + question_states = tile(question_states, 0, self.num_ans_candidates) + question_atts = tile(question_atts, 0, self.num_ans_candidates) + + output = self.decoder( + input_ids, + attention_mask=input_atts, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + labels=targets_ids, + return_dict=True, + reduction='none', + ) + + log_probs_sum = -output.loss + log_probs_sum = log_probs_sum.view(num_ques, self.num_ans_candidates) + + max_topk_ids = log_probs_sum.argmax(dim=1) + max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids] + + answers = [self.answer_list[max_id] for max_id in max_ids] + + return answers + + def predict_generate(self, feats: dict, data_samples=None): + """Predict answers in a generation manner.""" + device = feats['multimodal_embeds'].device + question_states = feats['multimodal_embeds'] + question_atts = torch.ones( + question_states.size()[:-1], dtype=torch.long).to(device) + model_kwargs = { + 'encoder_hidden_states': question_states, + 'encoder_attention_mask': question_atts + } + + bos_ids = 
torch.full((feats['multimodal_embeds'].shape[0], 1),
+                             fill_value=feats['bos_token_id'],
+                             device=device)
+
+        outputs = self.decoder.generate(
+            input_ids=bos_ids,
+            max_length=10,
+            min_length=1,
+            num_beams=self.num_beams,
+            eos_token_id=feats['sep_token_id'],
+            pad_token_id=feats['pad_token_id'],
+            **model_kwargs)
+
+        return outputs
+
+    def predict(self, feats: dict, data_samples=None):
+        """Predict results from the extracted features."""
+        if self.inference_method == 'generate':
+            return self.predict_generate(feats, data_samples)
+        elif self.inference_method == 'rank':
+            return self.predict_rank(feats, data_samples)
diff --git a/mmpretrain/models/losses/__init__.py b/mmpretrain/models/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1b2ed725ef76df7e18bf9283ec84b3b12e3d2cf
--- /dev/null
+++ b/mmpretrain/models/losses/__init__.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .asymmetric_loss import AsymmetricLoss, asymmetric_loss
+from .cae_loss import CAELoss
+from .cosine_similarity_loss import CosineSimilarityLoss
+from .cross_correlation_loss import CrossCorrelationLoss
+from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
+                                 cross_entropy)
+from .focal_loss import FocalLoss, sigmoid_focal_loss
+from .label_smooth_loss import LabelSmoothLoss
+from .reconstruction_loss import PixelReconstructionLoss
+from .seesaw_loss import SeesawLoss
+from .swav_loss import SwAVLoss
+from .utils import (convert_to_one_hot, reduce_loss, weight_reduce_loss,
+                    weighted_loss)
+
+__all__ = [
+    'asymmetric_loss',
+    'AsymmetricLoss',
+    'cross_entropy',
+    'binary_cross_entropy',
+    'CrossEntropyLoss',
+    'reduce_loss',
+    'weight_reduce_loss',
+    'LabelSmoothLoss',
+    'weighted_loss',
+    'FocalLoss',
+    'sigmoid_focal_loss',
+    'convert_to_one_hot',
+    'SeesawLoss',
+    'CAELoss',
+    'CosineSimilarityLoss',
+    'CrossCorrelationLoss',
+    'PixelReconstructionLoss',
+    'SwAVLoss',
+]
diff --git a/mmpretrain/models/losses/asymmetric_loss.py b/mmpretrain/models/losses/asymmetric_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcc9707da8475b5e87d2b4f8a5a2cf669d7ffe2f
--- /dev/null
+++ b/mmpretrain/models/losses/asymmetric_loss.py
@@ -0,0 +1,149 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from mmpretrain.registry import MODELS
+from .utils import convert_to_one_hot, weight_reduce_loss
+
+
+def asymmetric_loss(pred,
+                    target,
+                    weight=None,
+                    gamma_pos=1.0,
+                    gamma_neg=4.0,
+                    clip=0.05,
+                    reduction='mean',
+                    avg_factor=None,
+                    use_sigmoid=True,
+                    eps=1e-8):
+    r"""asymmetric loss.
+ + Please refer to the `paper `__ for + details. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + target (torch.Tensor): The ground truth label of the prediction with + shape (N, \*). + weight (torch.Tensor, optional): Sample-wise loss weight with shape + (N, ). Defaults to None. + gamma_pos (float): positive focusing parameter. Defaults to 0.0. + gamma_neg (float): Negative focusing parameter. We usually set + gamma_neg > gamma_pos. Defaults to 4.0. + clip (float, optional): Probability margin. Defaults to 0.05. + reduction (str): The method used to reduce the loss. + Options are "none", "mean" and "sum". If reduction is 'none' , loss + is same shape as pred and label. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + use_sigmoid (bool): Whether the prediction uses sigmoid instead + of softmax. Defaults to True. + eps (float): The minimum value of the argument of logarithm. Defaults + to 1e-8. + + Returns: + torch.Tensor: Loss. + """ + assert pred.shape == \ + target.shape, 'pred and target should be in the same shape.' + + if use_sigmoid: + pred_sigmoid = pred.sigmoid() + else: + pred_sigmoid = nn.functional.softmax(pred, dim=-1) + + target = target.type_as(pred) + + if clip and clip > 0: + pt = (1 - pred_sigmoid + + clip).clamp(max=1) * (1 - target) + pred_sigmoid * target + else: + pt = (1 - pred_sigmoid) * (1 - target) + pred_sigmoid * target + asymmetric_weight = (1 - pt).pow(gamma_pos * target + gamma_neg * + (1 - target)) + loss = -torch.log(pt.clamp(min=eps)) * asymmetric_weight + if weight is not None: + assert weight.dim() == 1 + weight = weight.float() + if pred.dim() > 1: + weight = weight.reshape(-1, 1) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@MODELS.register_module() +class AsymmetricLoss(nn.Module): + """asymmetric loss. + + Args: + gamma_pos (float): positive focusing parameter. + Defaults to 0.0. + gamma_neg (float): Negative focusing parameter. We + usually set gamma_neg > gamma_pos. Defaults to 4.0. + clip (float, optional): Probability margin. Defaults to 0.05. + reduction (str): The method used to reduce the loss into + a scalar. + loss_weight (float): Weight of loss. Defaults to 1.0. + use_sigmoid (bool): Whether the prediction uses sigmoid instead + of softmax. Defaults to True. + eps (float): The minimum value of the argument of logarithm. Defaults + to 1e-8. + """ + + def __init__(self, + gamma_pos=0.0, + gamma_neg=4.0, + clip=0.05, + reduction='mean', + loss_weight=1.0, + use_sigmoid=True, + eps=1e-8): + super(AsymmetricLoss, self).__init__() + self.gamma_pos = gamma_pos + self.gamma_neg = gamma_neg + self.clip = clip + self.reduction = reduction + self.loss_weight = loss_weight + self.use_sigmoid = use_sigmoid + self.eps = eps + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None): + r"""asymmetric loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + target (torch.Tensor): The ground truth label of the prediction + with shape (N, \*), N or (N,1). + weight (torch.Tensor, optional): Sample-wise loss weight with shape + (N, \*). Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The method used to reduce the + loss into a scalar. Options are "none", "mean" and "sum". + Defaults to None. + + Returns: + torch.Tensor: Loss. 
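+
+        Example:
+            A minimal sketch with random multi-label targets (shapes and
+            values are illustrative, not prescribed by the API):
+
+            >>> import torch
+            >>> from mmpretrain.models import AsymmetricLoss
+            >>> loss = AsymmetricLoss(gamma_pos=0.0, gamma_neg=4.0, clip=0.05)
+            >>> pred = torch.randn(8, 5)
+            >>> target = torch.randint(0, 2, (8, 5))
+            >>> loss(pred, target).shape
+            torch.Size([])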
+ """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if target.dim() == 1 or (target.dim() == 2 and target.shape[1] == 1): + target = convert_to_one_hot(target.view(-1, 1), pred.shape[-1]) + loss_cls = self.loss_weight * asymmetric_loss( + pred, + target, + weight, + gamma_pos=self.gamma_pos, + gamma_neg=self.gamma_neg, + clip=self.clip, + reduction=reduction, + avg_factor=avg_factor, + use_sigmoid=self.use_sigmoid, + eps=self.eps) + return loss_cls diff --git a/mmpretrain/models/losses/cae_loss.py b/mmpretrain/models/losses/cae_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1dc081b603361e9b06c96cf836941fa971a4b4c4 --- /dev/null +++ b/mmpretrain/models/losses/cae_loss.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +from mmengine.model import BaseModule +from torch import nn + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CAELoss(BaseModule): + """Loss function for CAE. + + Compute the align loss and the main loss. + + Args: + lambd (float): The weight for the align loss. + """ + + def __init__(self, lambd: float) -> None: + super().__init__() + self.lambd = lambd + self.loss_cross_entropy = nn.CrossEntropyLoss() + self.loss_mse = nn.MSELoss() + + def forward( + self, logits: torch.Tensor, target: torch.Tensor, + latent_pred: torch.Tensor, + latent_target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward function of CAE Loss. + + Args: + logits (torch.Tensor): The outputs from the decoder. + target (torch.Tensor): The targets generated by dalle. + latent_pred (torch.Tensor): The latent prediction from the + regressor. + latent_target (torch.Tensor): The latent target from the teacher + network. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The main loss and align loss. + """ + loss_main = self.loss_cross_entropy(logits, target) + loss_align = self.loss_mse(latent_pred, + latent_target.detach()) * self.lambd + + return loss_main, loss_align diff --git a/mmpretrain/models/losses/cosine_similarity_loss.py b/mmpretrain/models/losses/cosine_similarity_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..f0a5931e24686bd560196e1e310fc283fc4c9d4d --- /dev/null +++ b/mmpretrain/models/losses/cosine_similarity_loss.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Optional + +import torch +from mmengine.model import BaseModule +from torch import nn + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CosineSimilarityLoss(BaseModule): + """Cosine similarity loss function. + + Compute the similarity between two features and optimize that similarity as + loss. + + Args: + shift_factor (float): The shift factor of cosine similarity. + Default: 0.0. + scale_factor (float): The scale factor of cosine similarity. + Default: 1.0. + """ + + def __init__(self, + shift_factor: float = 0.0, + scale_factor: float = 1.0) -> None: + super().__init__() + self.shift_factor = shift_factor + self.scale_factor = scale_factor + + def forward(self, + pred: torch.Tensor, + target: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward function of cosine similarity loss. + + Args: + pred (torch.Tensor): The predicted features. + target (torch.Tensor): The target features. + + Returns: + torch.Tensor: The cosine similarity loss. 
+ """ + pred_norm = nn.functional.normalize(pred, dim=-1) + target_norm = nn.functional.normalize(target, dim=-1) + loss = self.shift_factor - self.scale_factor * ( + pred_norm * target_norm).sum(dim=-1) + + if mask is None: + loss = loss.mean() + else: + loss = (loss * mask).sum() / mask.sum() + return loss diff --git a/mmpretrain/models/losses/cross_correlation_loss.py b/mmpretrain/models/losses/cross_correlation_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d26ce3ddbd7b41778cbf25147df39da256788dd1 --- /dev/null +++ b/mmpretrain/models/losses/cross_correlation_loss.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CrossCorrelationLoss(BaseModule): + """Cross correlation loss function. + + Compute the on-diagnal and off-diagnal loss. + + Args: + lambd (float): The weight for the off-diag loss. + """ + + def __init__(self, lambd: float = 0.0051) -> None: + super().__init__() + self.lambd = lambd + + def forward(self, cross_correlation_matrix: torch.Tensor) -> torch.Tensor: + """Forward function of cross correlation loss. + + Args: + cross_correlation_matrix (torch.Tensor): The cross correlation + matrix. + + Returns: + torch.Tensor: cross correlation loss. + """ + # loss + on_diag = torch.diagonal(cross_correlation_matrix).add_(-1).pow_( + 2).sum() + off_diag = self.off_diagonal(cross_correlation_matrix).pow_(2).sum() + loss = on_diag + self.lambd * off_diag + return loss + + def off_diagonal(self, x: torch.Tensor) -> torch.Tensor: + """Rreturn a flattened view of the off-diagonal elements of a square + matrix.""" + n, m = x.shape + assert n == m + return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() diff --git a/mmpretrain/models/losses/cross_entropy_loss.py b/mmpretrain/models/losses/cross_entropy_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..5d418beb812f8493668aeff99198555068a55435 --- /dev/null +++ b/mmpretrain/models/losses/cross_entropy_loss.py @@ -0,0 +1,209 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F + +from mmpretrain.registry import MODELS +from .utils import weight_reduce_loss + + +def cross_entropy(pred, + label, + weight=None, + reduction='mean', + avg_factor=None, + class_weight=None): + """Calculate the CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + label (torch.Tensor): The gt label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str): The method used to reduce the loss. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (torch.Tensor, optional): The weight for each class with + shape (C), C is the number of classes. Default None. + + Returns: + torch.Tensor: The calculated loss + """ + # element-wise losses + loss = F.cross_entropy(pred, label, weight=class_weight, reduction='none') + + # apply weights and do the reduction + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def soft_cross_entropy(pred, + label, + weight=None, + reduction='mean', + class_weight=None, + avg_factor=None): + """Calculate the Soft CrossEntropy loss. The label can be float. 
+ + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + label (torch.Tensor): The gt label of the prediction with shape (N, C). + When using "mixup", the label can be float. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str): The method used to reduce the loss. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (torch.Tensor, optional): The weight for each class with + shape (C), C is the number of classes. Default None. + + Returns: + torch.Tensor: The calculated loss + """ + # element-wise losses + loss = -label * F.log_softmax(pred, dim=-1) + if class_weight is not None: + loss *= class_weight + loss = loss.sum(dim=-1) + + # apply weights and do the reduction + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def binary_cross_entropy(pred, + label, + weight=None, + reduction='mean', + avg_factor=None, + class_weight=None, + pos_weight=None): + r"""Calculate the binary CrossEntropy loss with logits. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + label (torch.Tensor): The gt label with shape (N, \*). + weight (torch.Tensor, optional): Element-wise weight of loss with shape + (N, ). Defaults to None. + reduction (str): The method used to reduce the loss. + Options are "none", "mean" and "sum". If reduction is 'none' , loss + is same shape as pred and label. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (torch.Tensor, optional): The weight for each class with + shape (C), C is the number of classes. Default None. + pos_weight (torch.Tensor, optional): The positive weight for each + class with shape (C), C is the number of classes. Default None. + + Returns: + torch.Tensor: The calculated loss + """ + # Ensure that the size of class_weight is consistent with pred and label to + # avoid automatic boracast, + assert pred.dim() == label.dim() + + if class_weight is not None: + N = pred.size()[0] + class_weight = class_weight.repeat(N, 1) + loss = F.binary_cross_entropy_with_logits( + pred, + label.float(), # only accepts float type tensor + weight=class_weight, + pos_weight=pos_weight, + reduction='none') + + # apply weights and do the reduction + if weight is not None: + assert weight.dim() == 1 + weight = weight.float() + if pred.dim() > 1: + weight = weight.reshape(-1, 1) + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + return loss + + +@MODELS.register_module() +class CrossEntropyLoss(nn.Module): + """Cross entropy loss. + + Args: + use_sigmoid (bool): Whether the prediction uses sigmoid + of softmax. Defaults to False. + use_soft (bool): Whether to use the soft version of CrossEntropyLoss. + Defaults to False. + reduction (str): The method used to reduce the loss. + Options are "none", "mean" and "sum". Defaults to 'mean'. + loss_weight (float): Weight of the loss. Defaults to 1.0. + class_weight (List[float], optional): The weight for each class with + shape (C), C is the number of classes. Default None. + pos_weight (List[float], optional): The positive weight for each + class with shape (C), C is the number of classes. Only enabled in + BCE loss when ``use_sigmoid`` is True. Default None. 
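+
+    Example:
+        A minimal sketch of the hard-label and soft-label variants (shapes
+        are illustrative):
+
+        >>> import torch
+        >>> from mmpretrain.models import CrossEntropyLoss
+        >>> cls_score = torch.randn(4, 10)
+        >>> hard_label = torch.randint(0, 10, (4, ))
+        >>> CrossEntropyLoss()(cls_score, hard_label).shape
+        torch.Size([])
+        >>> # soft labels, e.g. the mixtures produced by mixup/cutmix
+        >>> soft_label = torch.softmax(torch.randn(4, 10), dim=-1)
+        >>> CrossEntropyLoss(use_soft=True)(cls_score, soft_label).shape
+        torch.Size([])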
+ """ + + def __init__(self, + use_sigmoid=False, + use_soft=False, + reduction='mean', + loss_weight=1.0, + class_weight=None, + pos_weight=None): + super(CrossEntropyLoss, self).__init__() + self.use_sigmoid = use_sigmoid + self.use_soft = use_soft + assert not ( + self.use_soft and self.use_sigmoid + ), 'use_sigmoid and use_soft could not be set simultaneously' + + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = class_weight + self.pos_weight = pos_weight + + if self.use_sigmoid: + self.cls_criterion = binary_cross_entropy + elif self.use_soft: + self.cls_criterion = soft_cross_entropy + else: + self.cls_criterion = cross_entropy + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + + # only BCE loss has pos_weight + if self.pos_weight is not None and self.use_sigmoid: + pos_weight = cls_score.new_tensor(self.pos_weight) + kwargs.update({'pos_weight': pos_weight}) + else: + pos_weight = None + + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss_cls diff --git a/mmpretrain/models/losses/focal_loss.py b/mmpretrain/models/losses/focal_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9d2cf5035aedfd923ae388b264a7457312b274fd --- /dev/null +++ b/mmpretrain/models/losses/focal_loss.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F + +from mmpretrain.registry import MODELS +from .utils import convert_to_one_hot, weight_reduce_loss + + +def sigmoid_focal_loss(pred, + target, + weight=None, + gamma=2.0, + alpha=0.25, + reduction='mean', + avg_factor=None): + r"""Sigmoid focal loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + target (torch.Tensor): The ground truth label of the prediction with + shape (N, \*). + weight (torch.Tensor, optional): Sample-wise loss weight with shape + (N, ). Defaults to None. + gamma (float): The gamma for calculating the modulating factor. + Defaults to 2.0. + alpha (float): A balanced form for Focal Loss. Defaults to 0.25. + reduction (str): The method used to reduce the loss. + Options are "none", "mean" and "sum". If reduction is 'none' , + loss is same shape as pred and label. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + + Returns: + torch.Tensor: Loss. + """ + assert pred.shape == \ + target.shape, 'pred and target should be in the same shape.' + pred_sigmoid = pred.sigmoid() + target = target.type_as(pred) + pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) + focal_weight = (alpha * target + (1 - alpha) * + (1 - target)) * pt.pow(gamma) + loss = F.binary_cross_entropy_with_logits( + pred, target, reduction='none') * focal_weight + if weight is not None: + assert weight.dim() == 1 + weight = weight.float() + if pred.dim() > 1: + weight = weight.reshape(-1, 1) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@MODELS.register_module() +class FocalLoss(nn.Module): + """Focal loss. 
+ + Args: + gamma (float): Focusing parameter in focal loss. + Defaults to 2.0. + alpha (float): The parameter in balanced form of focal + loss. Defaults to 0.25. + reduction (str): The method used to reduce the loss into + a scalar. Options are "none" and "mean". Defaults to 'mean'. + loss_weight (float): Weight of loss. Defaults to 1.0. + """ + + def __init__(self, + gamma=2.0, + alpha=0.25, + reduction='mean', + loss_weight=1.0): + + super(FocalLoss, self).__init__() + self.gamma = gamma + self.alpha = alpha + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None): + r"""Sigmoid focal loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + target (torch.Tensor): The ground truth label of the prediction + with shape (N, \*), N or (N,1). + weight (torch.Tensor, optional): Sample-wise loss weight with shape + (N, \*). Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The method used to reduce the + loss into a scalar. Options are "none", "mean" and "sum". + Defaults to None. + + Returns: + torch.Tensor: Loss. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if target.dim() == 1 or (target.dim() == 2 and target.shape[1] == 1): + target = convert_to_one_hot(target.view(-1, 1), pred.shape[-1]) + loss_cls = self.loss_weight * sigmoid_focal_loss( + pred, + target, + weight, + gamma=self.gamma, + alpha=self.alpha, + reduction=reduction, + avg_factor=avg_factor) + return loss_cls diff --git a/mmpretrain/models/losses/label_smooth_loss.py b/mmpretrain/models/losses/label_smooth_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..f117df33b07c05ee7516f0b99d985f0d001b2d31 --- /dev/null +++ b/mmpretrain/models/losses/label_smooth_loss.py @@ -0,0 +1,177 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .cross_entropy_loss import CrossEntropyLoss +from .utils import convert_to_one_hot + + +@MODELS.register_module() +class LabelSmoothLoss(nn.Module): + r"""Initializer for the label smoothed cross entropy loss. + + Refers to `Rethinking the Inception Architecture for Computer Vision + `_ + + This decreases gap between output scores and encourages generalization. + Labels provided to forward can be one-hot like vectors (NxC) or class + indices (Nx1). + And this accepts linear combination of one-hot like labels from mixup or + cutmix except multi-label task. + + Args: + label_smooth_val (float): The degree of label smoothing. + num_classes (int, optional): Number of classes. Defaults to None. + mode (str): Refers to notes, Options are 'original', 'classy_vision', + 'multi_label'. Defaults to 'original'. + use_sigmoid (bool, optional): Whether the prediction uses sigmoid of + softmax. Defaults to None, which means to use sigmoid in + "multi_label" mode and not use in other modes. + reduction (str): The method used to reduce the loss. + Options are "none", "mean" and "sum". Defaults to 'mean'. + loss_weight (float): Weight of the loss. Defaults to 1.0. + + Notes: + - if the mode is **"original"**, this will use the same label smooth + method as the original paper as: + + .. 
math:: + (1-\epsilon)\delta_{k, y} + \frac{\epsilon}{K} + + where :math:`\epsilon` is the ``label_smooth_val``, :math:`K` is the + ``num_classes`` and :math:`\delta_{k, y}` is Dirac delta, which + equals 1 for :math:`k=y` and 0 otherwise. + + - if the mode is **"classy_vision"**, this will use the same label + smooth method as the facebookresearch/ClassyVision repo as: + + .. math:: + \frac{\delta_{k, y} + \epsilon/K}{1+\epsilon} + + - if the mode is **"multi_label"**, this will accept labels from + multi-label task and smoothing them as: + + .. math:: + (1-2\epsilon)\delta_{k, y} + \epsilon + """ + + def __init__(self, + label_smooth_val, + num_classes=None, + use_sigmoid=None, + mode='original', + reduction='mean', + loss_weight=1.0, + class_weight=None, + pos_weight=None): + super().__init__() + self.num_classes = num_classes + self.loss_weight = loss_weight + + assert (isinstance(label_smooth_val, float) + and 0 <= label_smooth_val < 1), \ + f'LabelSmoothLoss accepts a float label_smooth_val ' \ + f'over [0, 1), but gets {label_smooth_val}' + self.label_smooth_val = label_smooth_val + + accept_reduction = {'none', 'mean', 'sum'} + assert reduction in accept_reduction, \ + f'LabelSmoothLoss supports reduction {accept_reduction}, ' \ + f'but gets {mode}.' + self.reduction = reduction + + accept_mode = {'original', 'classy_vision', 'multi_label'} + assert mode in accept_mode, \ + f'LabelSmoothLoss supports mode {accept_mode}, but gets {mode}.' + self.mode = mode + + self._eps = label_smooth_val + if mode == 'classy_vision': + self._eps = label_smooth_val / (1 + label_smooth_val) + + if mode == 'multi_label': + if not use_sigmoid: + from mmengine.logging import MMLogger + MMLogger.get_current_instance().warning( + 'For multi-label tasks, please set `use_sigmoid=True` ' + 'to use binary cross entropy.') + self.smooth_label = self.multilabel_smooth_label + use_sigmoid = True if use_sigmoid is None else use_sigmoid + else: + self.smooth_label = self.original_smooth_label + use_sigmoid = False if use_sigmoid is None else use_sigmoid + + self.ce = CrossEntropyLoss( + use_sigmoid=use_sigmoid, + use_soft=not use_sigmoid, + reduction=reduction, + class_weight=class_weight, + pos_weight=pos_weight) + + def generate_one_hot_like_label(self, label): + """This function takes one-hot or index label vectors and computes one- + hot like label vectors (float)""" + # check if targets are inputted as class integers + if label.dim() == 1 or (label.dim() == 2 and label.shape[1] == 1): + label = convert_to_one_hot(label.view(-1, 1), self.num_classes) + return label.float() + + def original_smooth_label(self, one_hot_like_label): + assert self.num_classes > 0 + smooth_label = one_hot_like_label * (1 - self._eps) + smooth_label += self._eps / self.num_classes + return smooth_label + + def multilabel_smooth_label(self, one_hot_like_label): + assert self.num_classes > 0 + smooth_label = torch.full_like(one_hot_like_label, self._eps) + smooth_label.masked_fill_(one_hot_like_label > 0, 1 - self._eps) + return smooth_label + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + r"""Label smooth loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + label (torch.Tensor): The ground truth label of the prediction + with shape (N, \*). + weight (torch.Tensor, optional): Sample-wise loss weight with shape + (N, \*). Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. 
+            reduction_override (str, optional): The method used to reduce the
+                loss into a scalar. Options are "none", "mean" and "sum".
+                Defaults to None.
+
+        Returns:
+            torch.Tensor: Loss.
+        """
+        if self.num_classes is not None:
+            assert self.num_classes == cls_score.shape[1], \
+                f'num_classes should equal to cls_score.shape[1], ' \
+                f'but got num_classes: {self.num_classes} and ' \
+                f'cls_score.shape[1]: {cls_score.shape[1]}'
+        else:
+            self.num_classes = cls_score.shape[1]
+
+        one_hot_like_label = self.generate_one_hot_like_label(label=label)
+        assert one_hot_like_label.shape == cls_score.shape, \
+            f'LabelSmoothLoss requires output and target ' \
+            f'to be same shape, but got output.shape: {cls_score.shape} ' \
+            f'and target.shape: {one_hot_like_label.shape}'
+
+        smoothed_label = self.smooth_label(one_hot_like_label)
+        return self.loss_weight * self.ce.forward(
+            cls_score,
+            smoothed_label,
+            weight=weight,
+            avg_factor=avg_factor,
+            reduction_override=reduction_override,
+            **kwargs)
diff --git a/mmpretrain/models/losses/reconstruction_loss.py b/mmpretrain/models/losses/reconstruction_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..40e6bfd707b8e378f1ec656cfb443c27e8bbdbb3
--- /dev/null
+++ b/mmpretrain/models/losses/reconstruction_loss.py
@@ -0,0 +1,67 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+import torch
+from mmengine.model import BaseModule
+
+from mmpretrain.registry import MODELS
+
+
+@MODELS.register_module()
+class PixelReconstructionLoss(BaseModule):
+    """Loss for the reconstruction of pixel in Masked Image Modeling.
+
+    This module measures the distance between the target image and the
+    reconstructed image and computes the loss to optimize the model.
+    Currently, this module only provides L1 and L2 loss to penalize the
+    reconstruction error. In addition, a mask can be passed in the
+    ``forward`` function to only apply loss on the visible region, like
+    that in MAE.
+
+    Args:
+        criterion (str): The loss used to penalize the reconstruction error.
+            Currently, only L1 and L2 loss are supported.
+        channel (int, optional): The number of channels to average the
+            reconstruction loss. If not None, the reconstruction loss
+            will be divided by the channel. Defaults to None.
+    """
+
+    def __init__(self, criterion: str, channel: Optional[int] = None) -> None:
+        super().__init__()
+
+        if criterion == 'L1':
+            self.penalty = torch.nn.L1Loss(reduction='none')
+        elif criterion == 'L2':
+            self.penalty = torch.nn.MSELoss(reduction='none')
+        else:
+            raise NotImplementedError(f'Currently, PixelReconstructionLoss \
+                only supports L1 and L2 loss, but got {criterion}')
+
+        self.channel = channel if channel is not None else 1
+
+    def forward(self,
+                pred: torch.Tensor,
+                target: torch.Tensor,
+                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Forward function to compute the reconstruction loss.
+
+        Args:
+            pred (torch.Tensor): The reconstructed image.
+            target (torch.Tensor): The target image.
+            mask (torch.Tensor): The mask of the target image.
+
+        Returns:
+            torch.Tensor: The reconstruction loss.
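+
+        Example:
+            A minimal MAE-style sketch (2 images, 196 patches of 768 values,
+            with a random patch-visibility mask; shapes are illustrative):
+
+            >>> import torch
+            >>> from mmpretrain.models import PixelReconstructionLoss
+            >>> loss = PixelReconstructionLoss(criterion='L2')
+            >>> pred = torch.randn(2, 196, 768)
+            >>> target = torch.randn(2, 196, 768)
+            >>> mask = torch.randint(0, 2, (2, 196)).float()
+            >>> loss(pred, target, mask).shape
+            torch.Size([])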
+ """ + loss = self.penalty(pred, target) + + # if the dim of the loss is 3, take the average of the loss + # along the last dim + if len(loss.shape) == 3: + loss = loss.mean(dim=-1) + + if mask is None: + loss = loss.mean() + else: + loss = (loss * mask).sum() / mask.sum() / self.channel + + return loss diff --git a/mmpretrain/models/losses/seesaw_loss.py b/mmpretrain/models/losses/seesaw_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..4aaaa451b41ea7e86b7efbfe1c0b6ce8b3756d80 --- /dev/null +++ b/mmpretrain/models/losses/seesaw_loss.py @@ -0,0 +1,173 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# migrate from mmdetection with modifications +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmpretrain.registry import MODELS +from .utils import weight_reduce_loss + + +def seesaw_ce_loss(cls_score, + labels, + weight, + cum_samples, + num_classes, + p, + q, + eps, + reduction='mean', + avg_factor=None): + """Calculate the Seesaw CrossEntropy loss. + + Args: + cls_score (torch.Tensor): The prediction with shape (N, C), + C is the number of classes. + labels (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor): Sample-wise loss weight. + cum_samples (torch.Tensor): Cumulative samples for each category. + num_classes (int): The number of classes. + p (float): The ``p`` in the mitigation factor. + q (float): The ``q`` in the compenstation factor. + eps (float): The minimal value of divisor to smooth + the computation of compensation factor + reduction (str, optional): The method used to reduce the loss. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + + Returns: + torch.Tensor: The calculated loss + """ + assert cls_score.size(-1) == num_classes + assert len(cum_samples) == num_classes + + onehot_labels = F.one_hot(labels, num_classes) + seesaw_weights = cls_score.new_ones(onehot_labels.size()) + + # mitigation factor + if p > 0: + sample_ratio_matrix = cum_samples[None, :].clamp( + min=1) / cum_samples[:, None].clamp(min=1) + index = (sample_ratio_matrix < 1.0).float() + sample_weights = sample_ratio_matrix.pow(p) * index + (1 - index + ) # M_{ij} + mitigation_factor = sample_weights[labels.long(), :] + seesaw_weights = seesaw_weights * mitigation_factor + + # compensation factor + if q > 0: + scores = F.softmax(cls_score.detach(), dim=1) + self_scores = scores[ + torch.arange(0, len(scores)).to(scores.device).long(), + labels.long()] + score_matrix = scores / self_scores[:, None].clamp(min=eps) + index = (score_matrix > 1.0).float() + compensation_factor = score_matrix.pow(q) * index + (1 - index) + seesaw_weights = seesaw_weights * compensation_factor + + cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels)) + + loss = F.cross_entropy(cls_score, labels, weight=None, reduction='none') + + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + return loss + + +@MODELS.register_module() +class SeesawLoss(nn.Module): + """Implementation of seesaw loss. + + Refers to `Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021) + `_ + + Args: + use_sigmoid (bool): Whether the prediction uses sigmoid of softmax. + Only False is supported. Defaults to False. + p (float): The ``p`` in the mitigation factor. + Defaults to 0.8. + q (float): The ``q`` in the compenstation factor. + Defaults to 2.0. + num_classes (int): The number of classes. 
+            Defaults to 1000 for the ImageNet dataset.
+        eps (float): The minimal value of divisor to smooth
+            the computation of the compensation factor. Defaults to 1e-2.
+        reduction (str): The method that reduces the loss to a scalar.
+            Options are "none", "mean" and "sum". Defaults to "mean".
+        loss_weight (float): The weight of the loss. Defaults to 1.0.
+    """
+
+    def __init__(self,
+                 use_sigmoid=False,
+                 p=0.8,
+                 q=2.0,
+                 num_classes=1000,
+                 eps=1e-2,
+                 reduction='mean',
+                 loss_weight=1.0):
+        super(SeesawLoss, self).__init__()
+        assert not use_sigmoid, '`use_sigmoid` is not supported'
+        self.use_sigmoid = False
+        self.p = p
+        self.q = q
+        self.num_classes = num_classes
+        self.eps = eps
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+
+        self.cls_criterion = seesaw_ce_loss
+
+        # cumulative samples for each category
+        self.register_buffer('cum_samples',
+                             torch.zeros(self.num_classes, dtype=torch.float))
+
+    def forward(self,
+                cls_score,
+                labels,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None):
+        """Forward function.
+
+        Args:
+            cls_score (torch.Tensor): The prediction with shape (N, C).
+            labels (torch.Tensor): The learning label of the prediction.
+            weight (torch.Tensor, optional): Sample-wise loss weight.
+            avg_factor (int, optional): Average factor that is used to average
+                the loss. Defaults to None.
+            reduction_override (str, optional): The method used to reduce the
+                loss. Options are "none", "mean" and "sum". Defaults to None.
+        Returns:
+            torch.Tensor: The calculated loss
+        """
+        assert reduction_override in (None, 'none', 'mean', 'sum'), \
+            f'The `reduction_override` should be one of (None, "none", ' \
+            f'"mean", "sum"), but got "{reduction_override}".'
+        assert cls_score.size(0) == labels.view(-1).size(0), \
+            f'Expected `labels` shape [{cls_score.size(0)}], ' \
+            f'but got {list(labels.size())}'
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+        assert cls_score.size(-1) == self.num_classes, \
+            f'The channel number of output ({cls_score.size(-1)}) does ' \
+            f'not match the `num_classes` of seesaw loss ({self.num_classes}).'
+
+        # accumulate the samples for each category
+        unique_labels = labels.unique()
+        for u_l in unique_labels:
+            inds_ = labels == u_l.item()
+            self.cum_samples[u_l] += inds_.sum()
+
+        if weight is not None:
+            weight = weight.float()
+        else:
+            weight = labels.new_ones(labels.size(), dtype=torch.float)
+
+        # calculate loss_cls_classes
+        loss_cls = self.loss_weight * self.cls_criterion(
+            cls_score, labels, weight, self.cum_samples, self.num_classes,
+            self.p, self.q, self.eps, reduction, avg_factor)
+
+        return loss_cls
diff --git a/mmpretrain/models/losses/swav_loss.py b/mmpretrain/models/losses/swav_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7dbb78e9bf6619cede65a874569072b863bdfa0
--- /dev/null
+++ b/mmpretrain/models/losses/swav_loss.py
@@ -0,0 +1,190 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from mmengine.dist import all_reduce
+from mmengine.model import BaseModule
+
+from mmpretrain.registry import MODELS
+
+
+@torch.no_grad()
+def distributed_sinkhorn(out: torch.Tensor, sinkhorn_iterations: int,
+                         world_size: int, epsilon: float) -> torch.Tensor:
+    """Apply the distributed Sinkhorn optimization on the scores matrix to
+    find the assignments.
+ + This function is modified from + https://github.com/facebookresearch/swav/blob/main/main_swav.py + + Args: + out (torch.Tensor): The scores matrix + sinkhorn_iterations (int): Number of iterations in Sinkhorn-Knopp + algorithm. + world_size (int): The world size of the process group. + epsilon (float): regularization parameter for Sinkhorn-Knopp algorithm. + + Returns: + torch.Tensor: Output of sinkhorn algorithm. + """ + eps_num_stab = 1e-12 + Q = torch.exp(out / epsilon).t( + ) # Q is K-by-B for consistency with notations from our paper + B = Q.shape[1] * world_size # number of samples to assign + K = Q.shape[0] # how many prototypes + + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + all_reduce(sum_Q) + Q /= sum_Q + + for it in range(sinkhorn_iterations): + # normalize each row: total weight per prototype must be 1/K + u = torch.sum(Q, dim=1, keepdim=True) + if len(torch.nonzero(u == 0)) > 0: + Q += eps_num_stab + u = torch.sum(Q, dim=1, keepdim=True, dtype=Q.dtype) + all_reduce(u) + Q /= u + Q /= K + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + + Q *= B # the columns must sum to 1 so that Q is an assignment + return Q.t() + + +class MultiPrototypes(BaseModule): + """Multi-prototypes for SwAV head. + + Args: + output_dim (int): The output dim from SwAV neck. + num_prototypes (List[int]): The number of prototypes needed. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + output_dim: int, + num_prototypes: List[int], + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__(init_cfg=init_cfg) + assert isinstance(num_prototypes, list) + self.num_heads = len(num_prototypes) + for i, k in enumerate(num_prototypes): + self.add_module('prototypes' + str(i), + nn.Linear(output_dim, k, bias=False)) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + """Run forward for every prototype.""" + out = [] + for i in range(self.num_heads): + out.append(getattr(self, 'prototypes' + str(i))(x)) + return out + + +@MODELS.register_module() +class SwAVLoss(BaseModule): + """The Loss for SwAV. + + This Loss contains clustering and sinkhorn algorithms to compute Q codes. + Part of the code is borrowed from `script + `_. + The queue is built in `engine/hooks/swav_hook.py`. + + Args: + feat_dim (int): feature dimension of the prototypes. + sinkhorn_iterations (int): number of iterations in Sinkhorn-Knopp + algorithm. Defaults to 3. + epsilon (float): regularization parameter for Sinkhorn-Knopp algorithm. + Defaults to 0.05. + temperature (float): temperature parameter in training loss. + Defaults to 0.1. + crops_for_assign (List[int]): list of crops id used for computing + assignments. Defaults to [0, 1]. + num_crops (List[int]): list of number of crops. Defaults to [2]. + num_prototypes (int): number of prototypes. Defaults to 3000. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. 
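+
+    Example:
+        A minimal single-process sketch (4 samples with 2 crops each and a
+        128-d feature space; shapes and sizes are illustrative). In
+        non-distributed runs the mmengine collective ops are no-ops, so this
+        also works on a single device:
+
+        >>> import torch
+        >>> from mmpretrain.models import SwAVLoss
+        >>> loss = SwAVLoss(feat_dim=128, num_prototypes=300)
+        >>> feats = torch.randn(8, 128)  # batch_size * sum(num_crops) rows
+        >>> loss(feats).shape
+        torch.Size([])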
+ """ + + def __init__(self, + feat_dim: int, + sinkhorn_iterations: int = 3, + epsilon: float = 0.05, + temperature: float = 0.1, + crops_for_assign: List[int] = [0, 1], + num_crops: List[int] = [2], + num_prototypes: int = 3000, + init_cfg: Optional[Union[List[dict], dict]] = None): + super().__init__(init_cfg=init_cfg) + self.sinkhorn_iterations = sinkhorn_iterations + self.epsilon = epsilon + self.temperature = temperature + self.crops_for_assign = crops_for_assign + self.num_crops = num_crops + self.use_queue = False + self.queue = None + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + + # prototype layer + self.prototypes = None + if isinstance(num_prototypes, list): + self.prototypes = MultiPrototypes(feat_dim, num_prototypes) + elif num_prototypes > 0: + self.prototypes = nn.Linear(feat_dim, num_prototypes, bias=False) + assert self.prototypes is not None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function of SwAV loss. + + Args: + x (torch.Tensor): NxC input features. + Returns: + torch.Tensor: The returned loss. + """ + # normalize the prototypes + with torch.no_grad(): + w = self.prototypes.weight.data.clone() + w = nn.functional.normalize(w, dim=1, p=2) + self.prototypes.weight.copy_(w) + + embedding, output = x, self.prototypes(x) + embedding = embedding.detach() + + bs = int(embedding.size(0) / sum(self.num_crops)) + loss = 0 + for i, crop_id in enumerate(self.crops_for_assign): + with torch.no_grad(): + out = output[bs * crop_id:bs * (crop_id + 1)].detach() + # time to use the queue + if self.queue is not None: + if self.use_queue or not torch.all(self.queue[i, + -1, :] == 0): + self.use_queue = True + out = torch.cat( + (torch.mm(self.queue[i], + self.prototypes.weight.t()), out)) + # fill the queue + self.queue[i, bs:] = self.queue[i, :-bs].clone() + self.queue[i, :bs] = embedding[crop_id * bs:(crop_id + 1) * + bs] + + # get assignments (batch_size * num_prototypes) + q = distributed_sinkhorn(out, self.sinkhorn_iterations, + self.world_size, self.epsilon)[-bs:] + + # cluster assignment prediction + subloss = 0 + for v in np.delete(np.arange(np.sum(self.num_crops)), crop_id): + x = output[bs * v:bs * (v + 1)] / self.temperature + subloss -= torch.mean( + torch.sum(q * nn.functional.log_softmax(x, dim=1), dim=1)) + loss += subloss / (np.sum(self.num_crops) - 1) + loss /= len(self.crops_for_assign) + return loss diff --git a/mmpretrain/models/losses/utils.py b/mmpretrain/models/losses/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a65b68a6590aa3fe10a023022c9c9c9bce51f935 --- /dev/null +++ b/mmpretrain/models/losses/utils.py @@ -0,0 +1,119 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + +import torch +import torch.nn.functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are "none", "mean" and "sum". + + Return: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + elif reduction_enum == 2: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. + reduction (str): Same as built-in losses of PyTorch. 
+ avg_factor (float): Average factor when computing the mean of losses. + + Returns: + Tensor: Processed loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + loss = loss.sum() / avg_factor + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + ``loss_func(pred, target, **kwargs)``. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like ``loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)``. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, avg_factor=2) + tensor(1.5000) + """ + + @functools.wraps(loss_func) + def wrapper(pred, + target, + weight=None, + reduction='mean', + avg_factor=None, + **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + return wrapper + + +def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor: + """This function converts target class indices to one-hot vectors, given + the number of classes. + + Args: + targets (Tensor): The ground truth label of the prediction + with shape (N, 1) + classes (int): the number of classes. + + Returns: + Tensor: Processed loss values. + """ + assert (torch.max(targets).item() < + classes), 'Class Index must be less than number of classes' + one_hot_targets = F.one_hot( + targets.long().squeeze(-1), num_classes=classes) + return one_hot_targets diff --git a/mmpretrain/models/multimodal/__init__.py b/mmpretrain/models/multimodal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73645f0f5e6898151380e87f3e40cb97b624b418 --- /dev/null +++ b/mmpretrain/models/multimodal/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
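+# The multimodal models rely on optional extra dependencies such as
+# `transformers` (installable via the "multimodal" extra of mmpretrain).
+# When they are missing, lightweight placeholder classes are registered
+# under the same names instead, so importing this package never fails and
+# a readable error is only raised if such a model is actually built.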
+from mmpretrain.utils.dependency import WITH_MULTIMODAL + +if WITH_MULTIMODAL: + from .blip import * # noqa: F401,F403 + from .blip2 import * # noqa: F401,F403 + from .chinese_clip import * # noqa: F401, F403 + from .clip import * # noqa: F401, F403 + from .flamingo import * # noqa: F401, F403 + from .llava import * # noqa: F401, F403 + from .minigpt4 import * # noqa: F401, F403 + from .ofa import * # noqa: F401, F403 + from .otter import * # noqa: F401, F403 +else: + from mmpretrain.registry import MODELS + from mmpretrain.utils.dependency import register_multimodal_placeholder + + register_multimodal_placeholder([ + 'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption', + 'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo', + 'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter', 'CLIP', + 'CLIPZeroShot' + ], MODELS) diff --git a/mmpretrain/models/multimodal/__pycache__/__init__.cpython-311.pyc b/mmpretrain/models/multimodal/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ba193c596a797e29802cc3a5c19d007db896361 Binary files /dev/null and b/mmpretrain/models/multimodal/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmpretrain/models/multimodal/blip/__init__.py b/mmpretrain/models/multimodal/blip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ebbc0da6e0d11c116d4575b6c981724e387e415a --- /dev/null +++ b/mmpretrain/models/multimodal/blip/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .blip_caption import BlipCaption +from .blip_grounding import BlipGrounding +from .blip_nlvr import BlipNLVR +from .blip_retrieval import BlipRetrieval +from .blip_vqa import BlipVQA +from .language_model import BertLMHeadModel, XBertEncoder, XBertLMHeadDecoder + +__all__ = [ + 'BertLMHeadModel', 'BlipCaption', 'BlipGrounding', 'BlipNLVR', + 'BlipRetrieval', 'BlipVQA', 'XBertEncoder', 'XBertLMHeadDecoder' +] diff --git a/mmpretrain/models/multimodal/blip/blip_caption.py b/mmpretrain/models/multimodal/blip/blip_caption.py new file mode 100644 index 0000000000000000000000000000000000000000..9af3e2408da8c6b3a55694a1323e6434dfc609e1 --- /dev/null +++ b/mmpretrain/models/multimodal/blip/blip_caption.py @@ -0,0 +1,184 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class BlipCaption(BaseModel): + """BLIP Caption. + + Args: + vision_encoder (dict): Encoder for extracting image features. + decoder_head (dict): The decoder head module to forward and + calculate loss from processed features. + tokenizer: (Optional[dict]): The config for tokenizer. + Defaults to None. + prompt (str): Prompt used for training and eval. + Defaults to ''. + max_txt_len (int): Max text length of input text. + num_captions (int): Number of captions to be generated for each image. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MutimodalDataPreprocessor" as type. + See :class:`MutimodalDataPreprocessor` for more details. + Defaults to None. + init_cfg (Optional[dict]): the config to control the initialization. + Defaults to None. 
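+
+    Example:
+        A hedged sketch of the high-level inference API (the model name and
+        image path are illustrative; query
+        ``mmpretrain.list_models(task='Image Caption')`` for the names
+        actually available in your installation):
+
+        >>> from mmpretrain import inference_model
+        >>> result = inference_model(
+        ...     'blip-base_3rdparty_caption', 'demo/cat-dog.png')  # doctest: +SKIP
+        >>> result['pred_caption']  # doctest: +SKIP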
+ """ + + def __init__(self, + vision_encoder: dict, + decoder_head: dict, + tokenizer: Optional[dict] = None, + prompt: str = '', + max_txt_len: int = 20, + num_captions: int = 1, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super(BlipCaption, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + self.tokenizer = TOKENIZER.build(tokenizer) + self.visual_encoder = MODELS.build(vision_encoder) + self.seq_gen_head = MODELS.build(decoder_head) + + self.prompt = prompt + self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 + self.max_txt_len = max_txt_len + self.num_captions = num_captions + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[List] = None, + mode: str = 'loss', + ): + """The unified entry for a forward process in both training and test. + The method should accept two modes: "predict" and "loss": + + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + images (torch.Tensor): pre_processed img tensor (N, C, ...). + data_samples (List[DataSample], optional): Data samples with + additional infos. + mode (str): Return what kind of value. Defaults to 'loss'. + + Returns: + The return type depends on ``mode``. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def predict(self, images, data_samples=None, **kwargs): + """Predict captions from a batch of inputs. + + Args: + images (torch.Tensor): The input images tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + + Returns: + List[DataSample]: Return list of data samples. + """ + # prepare inputs for decoder generation. 
+ image_embeds = self.visual_encoder(images)[0] + image_embeds = torch.repeat_interleave(image_embeds, self.num_captions, + 0) + + prompt = [self.prompt] * image_embeds.size(0) + prompt = self.tokenizer( + prompt, padding='longest', + return_tensors='pt').to(image_embeds.device) + + prompt.input_ids[:, 0] = self.tokenizer.bos_token_id + prompt.input_ids = prompt.input_ids[:, :-1] + + decoder_out = self.seq_gen_head.predict( + input_ids=prompt.input_ids, + encoder_hidden_states=image_embeds, + sep_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + output_attentions=True, + return_dict_in_generate=True, + ) + + decode_tokens = self.tokenizer.batch_decode( + decoder_out.sequences, skip_special_tokens=True) + + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(len(decode_tokens))] + + for data_sample, decode_token in zip(data_samples, decode_tokens): + if data_sample is None: + data_sample = DataSample() + data_sample.pred_caption = decode_token[len(self.prompt):] + out_data_samples.append(data_sample) + + return out_data_samples + + def loss(self, images, data_samples): + """Calculate losses from a batch of images and data samples. + + Args: + images (torch.Tensor): The input images tensor with shape + (N, C, ...) in general. + data_samples (List[ImageTextDataSample]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: a dictionary of loss components. + """ + image_embeds = self.visual_encoder(images)[0] + raw_text = [self.prompt + ds.gt_caption for ds in data_samples] + + text = self.tokenizer( + raw_text, + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors='pt', + ).to(image_embeds.device) + text.input_ids[:, 0] = self.tokenizer.bos_token_id + + # prepare targets for forwarding decoder + labels = text.input_ids.masked_fill( + text.input_ids == self.tokenizer.pad_token_id, -100) + labels[:, :self.prompt_length] = -100 + # forward decoder + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device) + + losses = self.seq_gen_head.loss( + input_ids=text.input_ids, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + labels=labels, + ) + return losses diff --git a/mmpretrain/models/multimodal/blip/blip_grounding.py b/mmpretrain/models/multimodal/blip/blip_grounding.py new file mode 100644 index 0000000000000000000000000000000000000000..cb087287220a91b3bfcd50acee244eb5dc118bac --- /dev/null +++ b/mmpretrain/models/multimodal/blip/blip_grounding.py @@ -0,0 +1,248 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from mmengine.model import BaseModel + +from mmpretrain.models.utils.box_utils import box_xyxy_to_cxcywh +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures.data_sample import DataSample + + +@MODELS.register_module() +class BlipGrounding(BaseModel): + """BLIP Grounding. + + Args: + visual_encoder (dict): Backbone for extracting image features. + text_encoder (dict): Backbone for extracting text features. + but we integrate the vqa text extractor + into the tokenizer part in datasets/transform/ + so we don't need text_backbone + multimodal_encoder (Optional[dict]): Backbone for extracting + multi-modal features. We apply this part as VQA fusion module. + neck (Optional[dict]): The neck module to process features from + backbone. Defaults to None. 
+        head (Optional[Union[List[dict], dict]]): The head module to calculate
+            loss from processed features. See :mod:`mmpretrain.models.heads`.
+            Notice that if the head is not set, the `loss` method cannot be
+            used. Defaults to None.
+        data_preprocessor (Optional[dict]): The config for preprocessing input
+            data. If None or no type is specified, it will use
+            "MultiModalDataPreprocessor" as type.
+            See :class:`MultiModalDataPreprocessor` for more details.
+            Defaults to None.
+        init_cfg (Optional[dict]): The config to control the initialization.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 tokenizer: Optional[dict] = None,
+                 visual_encoder: Optional[dict] = None,
+                 text_encoder: Optional[dict] = None,
+                 multimodal_encoder: Optional[dict] = None,
+                 head: Optional[Union[List[dict], dict]] = None,
+                 data_preprocessor: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None) -> None:
+        if data_preprocessor is None:
+            data_preprocessor = {}
+        if isinstance(data_preprocessor, dict):
+            data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
+            data_preprocessor = MODELS.build(data_preprocessor)
+
+        super(BlipGrounding, self).__init__(
+            init_cfg=init_cfg, data_preprocessor=data_preprocessor)
+
+        self.tokenizer = TOKENIZER.build(tokenizer)
+        self.prompt = 'localize instance: '
+        self.visual_encoder = MODELS.build(visual_encoder)
+        self.text_encoder = MODELS.build(text_encoder)
+        self.multimodal_encoder = MODELS.build(multimodal_encoder)
+        head.setdefault('tokenizer', self.tokenizer)
+        self.grounding_head = MODELS.build(head)
+
+    def forward(
+        self,
+        images: torch.Tensor,
+        data_samples: Optional[List[DataSample]] = None,
+        mode: str = 'loss',
+    ):
+        """The unified entry for a forward process in both training and test.
+        The method accepts two modes, "predict" and "loss":
+
+        - "loss": Forward and return a dict of losses according to the given
+          inputs and data samples.
+        - "predict": Forward and return the predicted boxes attached to the
+          data samples.
+
+        Note that this method doesn't handle back propagation or optimizer
+        updating, which are done in :meth:`train_step`.
+
+        Args:
+            images (torch.Tensor): The input image tensor with shape
+                (N, C, ...) in general.
+            data_samples (List[DataSample], optional): The annotation
+                data of every sample. It's required if ``mode="loss"``.
+                Defaults to None.
+            mode (str): Return what kind of value. Defaults to 'loss'.
+
+        Returns:
+            The return type depends on ``mode``.
+            - If ``mode="loss"``, return a dict of tensors.
+        """
+
+        if mode == 'loss':
+            return self.loss(images, data_samples)
+        elif mode == 'predict':
+            return self.predict(images, data_samples)
+        else:
+            raise RuntimeError(f'Invalid mode "{mode}".')
+
+    def extract_feat(self, images: torch.Tensor) -> torch.Tensor:
+        """Extract features from the input tensor with shape (N, C, ...).
+
+        Args:
+            images (Tensor): A batch of images. The shape of it should be
+                ``(num_samples, num_channels, *img_shape)``.
+
+        Returns:
+            image_embeds (Tensor): The output features.
+        """
+        image_embeds = self.visual_encoder(images)[0]
+        return image_embeds
+
+    def loss(
+        self,
+        images: torch.Tensor,
+        data_samples=None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
+        """Generate the training loss from the input tensor and data samples.
+
+        Args:
+            images (Tensor): A batch of images. The shape of it should be
+                ``(num_samples, num_channels, *img_shape)``.
+            data_samples (List[DataSample], optional): The annotation
+                data of every sample.
+
+        Returns:
+            dict[str, torch.Tensor]: A dictionary of loss components.
+ """ + + # extract image feature + image_embeds = self.extract_feat(images) + image_atts = image_embeds.new_ones( + image_embeds.size()[:-1], dtype=torch.long) + + raw_text = [] + box_targets = [] + for ds in data_samples: + + raw_text.append(ds.text) + box_t = copy.deepcopy(ds.box) * 1.0 + box_t[1] /= ds.img_shape[0] + box_t[3] /= ds.img_shape[0] + box_t[0] /= ds.img_shape[1] + box_t[2] /= ds.img_shape[1] + + box_targets.append(box_t) + + box_targets = image_embeds.new_tensor(np.stack(box_targets)) + box_targets = box_xyxy_to_cxcywh(box_targets) # xywh 0-1 + + text = self.tokenizer( + raw_text, + padding='longest', + truncation=True, + max_length=128, + return_tensors='pt', + ).to(image_embeds.device) + + text_embeds = self.text_encoder( + text.input_ids, + attention_mask=text.attention_mask, + mode='text', + return_dict=True) # bz, seq_len, hid + + # multimodal fusion + multimodal_embeds = self.multimodal_encoder( + encoder_embeds=text_embeds.last_hidden_state, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + # put answer from data_samples into tensor form + losses = self.grounding_head.loss( + text_embedding=multimodal_embeds.last_hidden_state, + text_embedding_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + decoder_targets=box_targets, + ) + + return losses + + def predict(self, images, data_samples=None): + """""" + + # extract image feature + image_embeds = self.extract_feat(images) + image_atts = image_embeds.new_ones( + image_embeds.size()[:-1], dtype=torch.long) + + raw_text = [] + for ds in data_samples: + raw_text.append(ds.text) + + text = self.tokenizer( + raw_text, + padding='longest', + truncation=True, + max_length=128, + return_tensors='pt', + ).to(image_embeds.device) + + text_embeds = self.text_encoder( + text.input_ids, + attention_mask=text.attention_mask, + mode='text', + return_dict=True) # bz, seq_len, hid + + # multimodal fusion + multimodal_embeds = self.multimodal_encoder( + encoder_embeds=text_embeds.last_hidden_state, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + # put answer from data_samples into tensor form + output_boxes = self.grounding_head.predict( + text_embedding=multimodal_embeds.last_hidden_state, + text_embedding_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + ) # xyxy 0-1 + + out_data_samples = [] + for bbox, data_sample, img in zip(output_boxes, data_samples, images): + if data_sample is None: + data_sample = DataSample() + + img_size = img.shape[-2:] + scale_factor = data_sample.get('scale_factor', (1, 1)) + bbox[0::2] = bbox[0::2] * img_size[1] / scale_factor[0] + bbox[1::2] = bbox[1::2] * img_size[0] / scale_factor[1] + bbox = bbox[None, :] + data_sample.pred_bboxes = bbox + + if 'gt_bboxes' in data_sample: + gt_bboxes = torch.Tensor(data_sample.get('gt_bboxes')) + gt_bboxes[:, 0::2] /= scale_factor[0] + gt_bboxes[:, 1::2] /= scale_factor[1] + data_sample.gt_bboxes = gt_bboxes + + out_data_samples.append(data_sample) + + return out_data_samples diff --git a/mmpretrain/models/multimodal/blip/blip_nlvr.py b/mmpretrain/models/multimodal/blip/blip_nlvr.py new file mode 100644 index 0000000000000000000000000000000000000000..f96e3cce237fd3b064c74264e8f907a8bd3a47ca --- /dev/null +++ b/mmpretrain/models/multimodal/blip/blip_nlvr.py @@ -0,0 +1,205 @@ +# 
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.model import BaseModel
+
+from mmpretrain.registry import MODELS, TOKENIZER
+
+
+@MODELS.register_module()
+class BlipNLVR(BaseModel):
+    """BLIP NLVR.
+
+    Args:
+        vision_backbone (dict): Backbone for extracting image features.
+        multimodal_backbone (dict): Backbone for fusing the paired image
+            features with the tokenized text. No separate text backbone is
+            needed, since text extraction is integrated into the tokenizer
+            part in datasets/transform/.
+        tokenizer (Optional[dict]): The config for the tokenizer.
+            Defaults to None.
+        max_txt_len (int): Max length of the tokenized text. Defaults to 35.
+        data_preprocessor (Optional[dict]): The config for preprocessing input
+            data. If None or no type is specified, it will use
+            "MultiModalDataPreprocessor" as type.
+            See :class:`MultiModalDataPreprocessor` for more details.
+            Defaults to None.
+        init_cfg (Optional[dict]): The config to control the initialization.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 vision_backbone: dict,
+                 multimodal_backbone: dict,
+                 tokenizer: Optional[dict] = None,
+                 max_txt_len: int = 35,
+                 data_preprocessor: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None):
+        if data_preprocessor is None:
+            data_preprocessor = {}
+        if isinstance(data_preprocessor, dict):
+            data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
+            data_preprocessor = MODELS.build(data_preprocessor)
+
+        super().__init__(
+            init_cfg=init_cfg, data_preprocessor=data_preprocessor)
+        if tokenizer is not None:
+            self.tokenizer = TOKENIZER.build(tokenizer)
+        self.vision_backbone = MODELS.build(vision_backbone)
+        self.multimodal_backbone = MODELS.build(multimodal_backbone)
+        self.max_txt_len = max_txt_len
+
+        # For simplicity, directly use the head definition here.
+        # If a more complex head is designed, move this and the loss to a new
+        # head module.
+        hidden_size = self.multimodal_backbone.config.hidden_size
+        self.head = nn.Sequential(
+            nn.Linear(hidden_size, hidden_size),
+            nn.ReLU(),
+            nn.Linear(hidden_size, 2),
+        )
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def preprocess_text(self, data_samples):
+
+        sample_item = data_samples[0]
+
+        if sample_item is not None and 'text' in sample_item:
+            texts = [sample.get('text') for sample in data_samples]
+        else:
+            return None
+
+        # tokenize the raw texts
+        texts = self.tokenizer(
+            texts,
+            padding='longest',
+            truncation=True,
+            max_length=self.max_txt_len,
+            return_tensors='pt',
+        ).to(self.device)
+
+        return texts
+
+    def forward(
+        self,
+        images: torch.Tensor,
+        data_samples: Optional[List] = None,
+        mode: str = 'tensor',
+    ):
+        """The unified entry for a forward process in both training and test.
+        The method accepts two modes, "predict" and "loss":
+
+        - "loss": Forward and return a dict of losses according to the given
+          images and data samples.
+        - "predict": Forward and return the predicted labels attached to the
+          data samples.
+
+        Note that this method doesn't handle back propagation or optimizer
+        updating, which are done in :meth:`train_step`.
+
+        Args:
+            images (torch.Tensor): A batch of paired images with shape
+                (B, T, C, H, W), where T is the number of images per sample
+                (2 for NLVR).
+            data_samples (List[DataSample], optional): The annotation data of
+                every sample, carrying the raw text to be tokenized.
+                Defaults to None.
+            mode (str): Return what kind of value. Defaults to 'tensor'.
+
+        Returns:
+            The return type depends on ``mode``.
+            - If ``mode="loss"``, return a dict of tensors.
+        """
+        # B, T, C, H, W to T*B, C, H, W
+        images = images.permute(1, 0, 2, 3, 4).flatten(0, 1)
+
+        if mode == 'loss':
+            return self.loss(images, data_samples)
+        elif mode == 'predict':
+            return self.predict(images, data_samples)
+        else:
+            raise RuntimeError(f'Invalid mode "{mode}".')
+
+    def predict(self, images, data_samples=None):
+        """Predict NLVR results for a batch of paired images and texts."""
+        # extract image features and tokenize the texts
+        image_embeds = self.vision_backbone(images)[0]
+        texts = self.preprocess_text(data_samples)
+        image_atts = torch.ones(
+            image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+
+        image0_embeds, image1_embeds = torch.split(image_embeds,
+                                                   texts.input_ids.size(0))
+
+        # multimodal fusion
+        multimodal_embeds = self.multimodal_backbone(
+            texts.input_ids,
+            attention_mask=texts.attention_mask,
+            encoder_hidden_states=[image0_embeds, image1_embeds],
+            encoder_attention_mask=[
+                image_atts[:image0_embeds.size(0)],
+                image_atts[image0_embeds.size(0):],
+            ],
+            return_dict=True,
+        )
+
+        # get prediction
+        outputs = self.head(multimodal_embeds.last_hidden_state[:, 0, :])
+
+        pred_scores = F.softmax(outputs, dim=1)
+
+        for pred_score, data_sample in zip(pred_scores, data_samples):
+            data_sample.set_pred_score(pred_score)
+            data_sample.set_pred_label(pred_score.argmax(dim=0))
+
+        return data_samples
+
+    def loss(self, images, data_samples):
+        """Calculate losses from a batch of inputs and data samples.
+
+        Args:
+            images (torch.Tensor): The input tensor with shape
+                (N, C, ...) in general.
+            data_samples (List[ImageTextDataSample]): The annotation data of
+                every sample.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components.
+        """
+        # extract image features and tokenize the texts
+        image_embeds = self.vision_backbone(images)[0]
+        texts = self.preprocess_text(data_samples)
+        image_atts = torch.ones(
+            image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+        image0_embeds, image1_embeds = torch.split(image_embeds,
+                                                   texts.input_ids.size(0))
+
+        # multimodal fusion
+        multimodal_embeds = self.multimodal_backbone(
+            texts.input_ids,
+            attention_mask=texts.attention_mask,
+            encoder_hidden_states=[image0_embeds, image1_embeds],
+            encoder_attention_mask=[
+                image_atts[:image0_embeds.size(0)],
+                image_atts[image0_embeds.size(0):],
+            ],
+            return_dict=True,
+        )
+
+        # get prediction
+        outputs = self.head(multimodal_embeds.last_hidden_state[:, 0, :])
+
+        targets = torch.tensor([i.gt_label
+                                for i in data_samples]).to(outputs.device)
+        loss = F.cross_entropy(outputs, targets)
+        return {'loss': loss}
diff --git a/mmpretrain/models/multimodal/blip/blip_retrieval.py b/mmpretrain/models/multimodal/blip/blip_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ebc2513de28d928bc6e2442929cbb402348b1ca
--- /dev/null
+++ b/mmpretrain/models/multimodal/blip/blip_retrieval.py
@@ -0,0 +1,716 @@
+# Copyright (c) OpenMMLab. All rights reserved.
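+# Note (added): this file implements BLIP image-text retrieval. Unimodal
+# encoders produce contrastive features, momentum copies of those encoders
+# provide soft alignment targets, and an image-text matching head re-ranks
+# the top-k contrastive candidates unless `fast_match` is enabled.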
+from collections import ChainMap
+from copy import deepcopy
+from typing import Dict, List, Optional, Tuple, Union
+
+import mmengine.dist as dist
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.model import BaseModel
+from torch import distributed as torch_dist
+
+from mmpretrain.registry import MODELS, TOKENIZER
+from mmpretrain.structures import DataSample
+from mmpretrain.utils import track_on_main_process
+
+
+def all_gather_concat(data: torch.Tensor) -> torch.Tensor:
+    """Gather tensors whose first-dimension sizes differ across ranks and
+    concat them into one tensor.
+
+    Note:
+        Only the first dimension should be different.
+
+    Args:
+        data (Tensor): Tensor to be gathered.
+
+    Returns:
+        torch.Tensor: The concatenated tensor.
+    """
+    if dist.get_world_size() == 1:
+        return data
+
+    # pad every rank's tensor to the maximum first-dimension size
+    data_size = torch.tensor(data.size(0), device=data.device)
+    sizes_list = dist.all_gather(data_size)
+
+    max_length = max(sizes_list)
+    size_diff = max_length.item() - data_size.item()
+    if size_diff:
+        padding = torch.zeros(
+            size_diff, *data.size()[1:], device=data.device, dtype=data.dtype)
+        data = torch.cat((data, padding))
+
+    gather_list = dist.all_gather(data)
+
+    # strip the padding before concatenation
+    all_data = []
+    for tensor, size in zip(gather_list, sizes_list):
+        all_data.append(tensor[:size])
+
+    return torch.concat(all_data)
+
+
+@MODELS.register_module()
+class BlipRetrieval(BaseModel):
+    """BLIP Retriever.
+
+    Args:
+        vision_backbone (dict): Backbone for extracting image features.
+        text_backbone (dict): Backbone for extracting text features.
+        multimodal_backbone (Optional[dict]): Backbone for extracting
+            multi-modal features.
+        vision_neck (Optional[dict]): The neck module to process image
+            features from the vision backbone. Defaults to None.
+        text_neck (Optional[dict]): The neck module to process text features
+            from the text backbone. Defaults to None.
+        head (Optional[Union[List[dict], dict]]): The head module to calculate
+            loss from processed single-modality features.
+            See :mod:`mmpretrain.models.heads`.
+            Notice that if the head is not set, the `loss` method cannot be
+            used. Defaults to None.
+        multimodal_head (Optional[Union[List[dict], dict]]): The multi-modal
+            head module to calculate loss from processed multimodal features.
+            See :mod:`mmpretrain.models.heads`.
+            Notice that if the head is not set, the `loss` method cannot be
+            used. Defaults to None.
+        momentum (float): Momentum used for momentum contrast.
+            Defaults to .995.
+        negative_all_rank (bool): Whether to sample negative data from all
+            ranks for image-text matching in training. Defaults to True.
+        temperature (float): Temperature parameter that controls the
+            concentration level of the distribution. Defaults to 0.07.
+        fast_match (bool): If False, select topk similarities as candidates
+            and compute the matching score. If True, return the similarity as
+            the matching score directly. Defaults to False.
+        topk (int): Select topk similarities as candidates for computing the
+            matching scores. Notice that this is not the topk in evaluation.
+            Defaults to 256.
+        data_preprocessor (Optional[dict]): The config for preprocessing input
+            data. If None or no type is specified, it will use
+            "MultiModalDataPreprocessor" as type.
+            See :class:`MultiModalDataPreprocessor` for more details.
+            Defaults to None.
+        init_cfg (Optional[dict]): The config to control the initialization.
+            Defaults to None.
+ """ + + def __init__(self, + vision_backbone: dict, + text_backbone: dict, + multimodal_backbone: Optional[dict] = None, + vision_neck: Optional[dict] = None, + text_neck: Optional[dict] = None, + head: Optional[Union[List[dict], dict]] = None, + multimodal_head: Optional[Union[List[dict], dict]] = None, + tokenizer: Optional[dict] = None, + momentum: float = .995, + negative_all_rank: bool = True, + temperature: float = 0.07, + fast_match: bool = False, + topk: int = 256, + max_txt_len: int = 20, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + self.vision_backbone = MODELS.build(vision_backbone) + self.text_backbone = MODELS.build(text_backbone) + + if multimodal_backbone is not None: + self.multimodal_backbone = MODELS.build(multimodal_backbone) + + if vision_neck is not None: + self.vision_neck = MODELS.build(vision_neck) + + if text_neck is not None: + self.text_neck = MODELS.build(text_neck) + + if head is not None: + self.head = MODELS.build(head) + + if multimodal_head is not None: + self.multimodal_head = MODELS.build(multimodal_head) + + if tokenizer is not None: + self.tokenizer = TOKENIZER.build(tokenizer) + + self.momentum = momentum + self.negative_all_rank = negative_all_rank + self.temp = nn.Parameter(temperature * torch.ones([])) + # Shares the same para + self.head.temp = self.temp + + # create the momentum encoder + self.vision_backbone_m = deepcopy(self.vision_backbone) + self.text_backbone_m = deepcopy(self.text_backbone) + + self.vision_neck_m = deepcopy(self.vision_neck) + self.text_neck_m = deepcopy(self.text_neck) + + self.model_pairs = [ + [self.vision_backbone, self.vision_backbone_m], + [self.text_backbone, self.text_backbone_m], + [self.vision_neck, self.vision_neck_m], + [self.text_neck, self.text_neck_m], + ] + self.copy_params() + + # multimodal backbone shares weights with text backbone in BLIP + # No need to set up + + # Notice that this topk is used for select k candidate to compute + # image-text score, but not the final metric topk in evaluation. + self.fast_match = fast_match + self.topk = topk + + self.max_txt_len = max_txt_len + + @property + def device(self): + return next(self.parameters()).device + + def preprocess_text(self, data_samples): + sample_item = data_samples[0] + + if sample_item is not None and 'text' in sample_item: + if isinstance(sample_item.get('text'), (list, tuple)): + texts = [] + for sample in data_samples: + texts.extend(sample.get('text')) + elif isinstance(sample_item.get('text'), str): + texts = [sample.get('text') for sample in data_samples] + else: + raise TypeError('text must be a string or a list of strings') + else: + return None + + # perform tokenize first if satisfied conditions + texts = self.tokenizer( + texts, + padding='max_length', + truncation=True, + max_length=self.max_txt_len, + return_tensors='pt', + ).to(self.device) + + return texts + + def forward(self, + images: torch.tensor = None, + data_samples: Optional[List[DataSample]] = None, + mode: str = 'tensor') -> Union[Tuple, dict]: + """The unified entry for a forward process in both training and test. 
+        The method should accept two modes: "tensor" and "loss":
+
+        - "tensor": Forward the whole network and return the tensors without
+          any post-processing, same as a common nn.Module.
+        - "loss": Forward and return a dict of losses according to the given
+          inputs and data samples.
+
+        Note that this method doesn't handle back propagation or optimizer
+        updating, which are done in :meth:`train_step`.
+
+        About the unified "predict" mode in other mm repos: image-text
+        retrieval cannot perform batch prediction, since evaluation has to go
+        through all the samples. A standard retrieval evaluation first
+        extracts and collects all features, and then predicts on all samples.
+        Therefore the `predict` mode here is retained only as a trigger to
+        remind users to choose the right configuration.
+
+        Args:
+            images (torch.Tensor): The input image tensor of shape
+                (N, C, ...) in general.
+            data_samples (List[DataSample], optional): The annotation
+                data of every sample. It's required if ``mode="loss"``.
+                Defaults to None.
+            mode (str): Return what kind of value. Defaults to 'tensor'.
+
+        Returns:
+            The return type depends on ``mode``.
+            - If ``mode="tensor"``, return a tuple.
+            - If ``mode="loss"``, return a dict of tensors.
+        """
+        if mode == 'tensor':
+            return self.extract_feat(images, data_samples)
+        elif mode == 'loss':
+            return self.loss(images, data_samples)
+        elif mode == 'predict':
+            return self.predict(images, data_samples)
+        else:
+            raise RuntimeError(f'Invalid mode "{mode}".')
+
+    def extract_feat(
+        self,
+        images: torch.Tensor = None,
+        data_samples: List[DataSample] = None,
+        return_texts=True,
+        return_embeds=None,
+    ) -> Dict[str, torch.Tensor]:
+        """Extract features from the input images and data samples.
+
+        Args:
+            images (tensor, optional): The images to extract features from.
+                Defaults to None.
+            data_samples (list, optional): The data samples containing texts
+                to extract features from. Defaults to None.
+            return_texts (bool): Whether to return the tokenized text and the
+                corresponding attention masks. Defaults to True.
+            return_embeds (bool): Whether to return the text embedding and
+                image embedding. Defaults to None, which means to use
+                ``self.fast_match``.
+
+        Returns:
+            Dict[str, torch.Tensor]: The extracted features keyed by name,
+            optionally together with the tokenized texts and the token-level
+            embeddings.
+        """
+        if data_samples is not None:
+            texts = self.preprocess_text(data_samples)
+        else:
+            texts = None
+
+        assert images is not None or texts is not None, \
+            'At least single modality should be passed as inputs.'
+
+        results = {}
+        if texts is not None and return_texts:
+            results.update({
+                'text_ids': texts.input_ids,
+                'text_attn_mask': texts.attention_mask,
+            })
+
+        if return_embeds is None:
+            return_embeds = not self.fast_match
+
+        # extract image features
+        if images is not None:
+            output = self._extract_feat(images, modality='images')
+            results['image_feat'] = output['image_feat']
+            if return_embeds:
+                results['image_embeds'] = output['image_embeds']
+
+        # extract text features
+        if texts is not None:
+            output = self._extract_feat(texts, modality='texts')
+            results['text_feat'] = output['text_feat']
+            if return_embeds:
+                results['text_embeds'] = output['text_embeds']
+
+        return results
+
+    def _extract_feat(self, inputs: Union[torch.Tensor, dict],
+                      modality: str) -> Tuple[torch.Tensor]:
+        """Extract features from a single modality.
+
+        Args:
+            inputs (Union[torch.Tensor, dict]): A batch of inputs.
+                For image, a tensor of shape (N, C, ...) in general.
+ For text, a dict of tokenized text inputs. + modality (str): Modality feature to be extracted. Only two + options are supported. + + - ``images``: Only extract image features, mostly used for + inference. + - ``texts``: Only extract text features, mostly used for + inference. + + Returns: + Tuple[torch.Tensor]: The output features. + """ + + if modality == 'images': + # extract image features + image_embeds = self.vision_backbone(inputs)[0] + image_feat = F.normalize( + self.vision_neck(image_embeds[:, 0, :]), dim=-1) + return {'image_embeds': image_embeds, 'image_feat': image_feat} + elif modality == 'texts': + # extract text features + text_output = self.text_backbone( + inputs.input_ids, + attention_mask=inputs.attention_mask, + token_type_ids=None, + return_dict=True, + mode='text', + ) + text_embeds = text_output.last_hidden_state + text_feat = F.normalize( + self.text_neck(text_embeds[:, 0, :]), dim=-1) + return {'text_embeds': text_embeds, 'text_feat': text_feat} + else: + raise RuntimeError(f'Invalid modality "{modality}".') + + def loss( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + ) -> Dict[str, torch.tensor]: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (dict): A batch of inputs. The input tensor with of + at least one modality. For image, the value is a tensor + of shape (N, C, ...) in general. + For text, the value is a dict of tokenized text inputs. + data_samples (Optional[List[DataSample]]): + The annotation data of every samples. Defaults to None. + + Returns: + Dict[str, torch.tensor]: a dictionary of loss components of + both head and multimodal head. + """ + output = self.extract_feat(images, data_samples, return_embeds=True) + + text_ids = output['text_ids'] + text_attn_mask = output['text_attn_mask'] + image_embeds = output['image_embeds'] + image_feat = output['image_feat'] + text_feat = output['text_feat'] + + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(self.device) + + # get momentum features + with torch.no_grad(): + self._momentum_update() + image_embeds_m = self.vision_backbone_m(images)[0] + image_feat_m = F.normalize( + self.vision_neck_m(image_embeds_m[:, 0, :]), dim=-1) + + text_output_m = self.text_backbone_m( + text_ids, + attention_mask=text_attn_mask, + token_type_ids=None, + return_dict=True, + mode='text', + ) + text_embeds_m = text_output_m.last_hidden_state + text_feat_m = F.normalize( + self.text_neck_m(text_embeds_m[:, 0, :]), dim=-1) + + loss = self.head.loss( + ([image_feat, text_feat, image_feat_m, text_feat_m], ), + data_samples) + + # prepare for itm + encoder_input_ids = text_ids.clone() + encoder_input_ids[:, + 0] = self.tokenizer.additional_special_tokens_ids[0] + output_pos = self.text_backbone( + encoder_input_ids, + attention_mask=text_attn_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + idx = torch.tensor([i.image_id for i in data_samples]).view(-1, 1) + bs = idx.size(0) + idxs = torch.cat(dist.all_gather(idx)) + if self.negative_all_rank: + # compute sample similarity + with torch.no_grad(): + mask = torch.eq(idx, idxs.t()).to(self.device) + + image_feat_world = torch.cat(dist.all_gather(image_feat)) + text_feat_world = torch.cat(dist.all_gather(text_feat)) + + sim_i2t = image_feat @ text_feat_world.t() / self.temp + sim_t2i = text_feat @ image_feat_world.t() / self.temp + + weights_i2t = F.softmax(sim_i2t, dim=1) + weights_i2t.masked_fill_(mask, 0) + + weights_t2i = 
F.softmax(sim_t2i, dim=1) + weights_t2i.masked_fill_(mask, 0) + + world_size = dist.get_world_size() + if world_size == 1: + image_embeds_world = image_embeds + else: + image_embeds_world = torch.cat( + torch_dist.nn.all_gather(image_embeds)) + + # select a negative image (from all ranks) for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds_world[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + + # select a negative text (from all ranks) for each image + input_ids_world = torch.cat(dist.all_gather(encoder_input_ids)) + att_mask_world = torch.cat(dist.all_gather(text_attn_mask)) + + text_ids_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_ids_neg.append(input_ids_world[neg_idx]) + text_atts_neg.append(att_mask_world[neg_idx]) + + text_ids_neg = torch.stack(text_ids_neg, dim=0) + text_atts_neg = torch.stack(text_atts_neg, dim=0) + + text_ids_all = torch.cat([encoder_input_ids, text_ids_neg], dim=0) + text_atts_all = torch.cat([text_attn_mask, text_atts_neg], dim=0) + + image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0) + image_atts_all = torch.cat([image_atts, image_atts], dim=0) + + output_neg = self.text_backbone( + text_ids_all, + attention_mask=text_atts_all, + encoder_hidden_states=image_embeds_all, + encoder_attention_mask=image_atts_all, + return_dict=True, + ) + + vl_embeddings = torch.cat( + [ + output_pos.last_hidden_state[:, 0, :], + output_neg.last_hidden_state[:, 0, :], + ], + dim=0, + ) + + # create false data samples + data_samples.extend( + [DataSample(is_matched=False) for _ in range(2 * bs)]) + loss_multimodal = self.multimodal_head.loss((vl_embeddings, ), + data_samples) + + return dict(ChainMap(loss, loss_multimodal)) + + def predict(self, images, data_samples, cal_i2t=True, cal_t2i=True): + feats = self.extract_feat(images, data_samples) + + return self.predict_all( + feats, data_samples, cal_i2t=cal_i2t, cal_t2i=cal_t2i) + + def predict_all(self, + feats, + data_samples, + num_images=None, + num_texts=None, + cal_i2t=True, + cal_t2i=True): + text_ids = feats['text_ids'] + text_ids[:, 0] = self.tokenizer.additional_special_tokens_ids[0] + text_attn_mask = feats['text_attn_mask'] + image_embeds = feats.get('image_embeds', None) + image_feat = feats['image_feat'] + text_feat = feats['text_feat'] + + num_images = num_images or image_feat.size(0) + num_texts = num_texts or text_feat.size(0) + + if not self.fast_match: + image_embeds_all = all_gather_concat(image_embeds)[:num_images] + else: + image_embeds_all = None + image_feat_all = all_gather_concat(image_feat)[:num_images] + text_feat_all = all_gather_concat(text_feat)[:num_texts] + text_ids_all = all_gather_concat(text_ids)[:num_texts] + text_attn_mask_all = all_gather_concat(text_attn_mask)[:num_texts] + + results = [] + if cal_i2t: + result_i2t = self.compute_score_matrix_i2t( + image_feat, + image_embeds, + text_feat_all, + text_ids_all, + text_attn_mask_all, + ) + results.append( + self._get_predictions(result_i2t, data_samples, mode='i2t')) + if cal_t2i: + result_t2i = self.compute_score_matrix_t2i( + image_feat_all, + image_embeds_all, + text_feat, + text_ids, + text_attn_mask, + ) + results.append( + self._get_predictions(result_t2i, data_samples, mode='t2i')) + return tuple(results) + + def compute_score_matrix_i2t(self, img_feats, img_embeds, text_feats, + text_ids, text_atts): + """Compare the score 
matrix for image-to-text retrieval. Every image + should compare to all the text features. + + Args: + img_feats (torch.Tensor): The input img feats tensor with shape + (M, C). M stands for numbers of samples on a single GPU. + img_embeds (torch.Tensor): The input img embeds tensor with shape + (M, C). M stands for numbers of samples on a single GPU. + text_feats (torch.Tensor): The input text feats tensor with shape + (N, C). N stands for numbers of all samples on all GPUs. + text_ids (torch.Tensor): The input tensor with shape (N, C). + text_atts (torch.Tensor): The input tensor with shape (N, C). + + Returns: + torch.Tensor: Score matrix of image-to-text retrieval. + """ + + # compute i2t sim matrix + sim_matrix_i2t = img_feats @ text_feats.t() + if self.fast_match: + return sim_matrix_i2t + + score_matrix_i2t = torch.full((img_feats.size(0), text_feats.size(0)), + -100.0).to(self.device) + for i in track_on_main_process( + range(img_feats.size(0)), 'Compute I2T scores...'): + sims = sim_matrix_i2t[i] + topk_sim, topk_idx = sims.topk(k=self.topk, dim=0) + + encoder_output = img_embeds[i].repeat(self.topk, 1, 1) + encoder_att = torch.ones( + encoder_output.size()[:-1], dtype=torch.long).to(self.device) + output = self.text_backbone( + text_ids[topk_idx], + attention_mask=text_atts[topk_idx], + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + ) + score = self.multimodal_head( + (output.last_hidden_state[:, 0, :], ))[:, 1] + score_matrix_i2t[i, topk_idx] = score + topk_sim + + return score_matrix_i2t + + def compute_score_matrix_t2i(self, img_feats, img_embeds, text_feats, + text_ids, text_atts): + """Compare the score matrix for text-to-image retrieval. Every text + should compare to all the image features. + + Args: + img_feats (torch.Tensor): The input img feats tensor with shape + (M, C). M stands for numbers of samples on a single GPU. + img_embeds (torch.Tensor): The input img embeds tensor with shape + (M, C). M stands for numbers of samples on a single GPU. + text_feats (torch.Tensor): The input text feats tensor with shape + (N, C). N stands for numbers of all samples on all GPUs. + text_ids (torch.Tensor): The input tensor with shape (M, C). + text_atts (torch.Tensor): The input tensor with shape (M, C). + + Returns: + torch.Tensor: Score matrix of text-to-image retrieval. + """ + + # compute t2i sim matrix + sim_matrix_t2i = text_feats @ img_feats.t() + if self.fast_match: + return sim_matrix_t2i + + score_matrix_t2i = torch.full((text_feats.size(0), img_feats.size(0)), + -100.0).to(self.device) + for i in track_on_main_process( + range(text_feats.size(0)), 'Compute T2I scores...'): + sims = sim_matrix_t2i[i] + topk_sim, topk_idx = sims.topk(k=self.topk, dim=0) + + encoder_output = img_embeds[topk_idx] + encoder_att = torch.ones( + encoder_output.size()[:-1], dtype=torch.long).to(self.device) + output = self.text_backbone( + text_ids[i].repeat(self.topk, 1), + attention_mask=text_atts[i].repeat(self.topk, 1), + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + ) + score = self.multimodal_head( + (output.last_hidden_state[:, 0, :], ))[:, 1] + score_matrix_t2i[i, topk_idx] = score + topk_sim + + return score_matrix_t2i + + def _get_predictions(self, + result: torch.Tensor, + data_samples: List[DataSample], + mode: str = 'i2t'): + """Post-process the output of retriever. + + Args: + result (torch.Tensor): Score matrix of single retrieve, + either from image or text. 
+            data_samples (List[DataSample], optional): The annotation
+                data of every sample.
+            mode (str): Retrieve mode, either `i2t` for image to text, or
+                `t2i` for text to image. Defaults to `i2t`.
+
+        Returns:
+            List[DataSample]: The raw data_samples with
+                the predicted results.
+        """
+
+        # create data samples if they do not exist
+        if data_samples is None:
+            data_samples = [DataSample() for _ in range(result.size(0))]
+        elif mode == 't2i':
+            # Process data samples to align with the num of texts.
+            new_data_samples = []
+            for sample in data_samples:
+                if isinstance(sample.text, (list, tuple)):
+                    texts = sample.text
+                else:
+                    texts = [sample.text]
+                for i, text in enumerate(texts):
+                    new_sample = DataSample(text=text)
+                    if 'gt_image_id' in sample:
+                        new_sample.gt_label = sample.gt_image_id[i]
+                    new_data_samples.append(new_sample)
+            assert len(new_data_samples) == result.size(0)
+            data_samples = new_data_samples
+        elif mode == 'i2t':
+            for sample in data_samples:
+                if 'gt_text_id' in sample:
+                    sample.gt_label = sample.gt_text_id
+        else:
+            raise ValueError(f'Type {mode} is not supported.')
+
+        for data_sample, score in zip(data_samples, result):
+            idx = score.argmax(keepdim=True).detach()
+
+            data_sample.set_pred_score(score)
+            data_sample.set_pred_label(idx)
+        return data_samples
+
+    # TODO: added temporarily
+    @torch.no_grad()
+    def copy_params(self):
+        for model_pair in self.model_pairs:
+            for param, param_m in zip(model_pair[0].parameters(),
+                                      model_pair[1].parameters()):
+                param_m.data.copy_(param.data)  # initialize
+                param_m.requires_grad = False  # not updated by gradient
+
+    @torch.no_grad()
+    def _momentum_update(self):
+        for model_pair in self.model_pairs:
+            for (name,
+                 param), (name_m,
+                          param_m) in zip(model_pair[0].named_parameters(),
+                                          model_pair[1].named_parameters()):
+                # hack to behave the same
+                if any([i in name for i in ['8', '9', '10', '11']
+                        ]) and 'layers' in name and any(
+                            [i in name for i in ['attn', 'ffn']]):
+                    param_m.data = param.data
+                else:
+                    param_m.data = param_m.data * self.momentum + \
+                        param.data * (1.0 - self.momentum)
diff --git a/mmpretrain/models/multimodal/blip/blip_vqa.py b/mmpretrain/models/multimodal/blip/blip_vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0f4e5861b5c92be302cc48eaa7a37264be63f93
--- /dev/null
+++ b/mmpretrain/models/multimodal/blip/blip_vqa.py
@@ -0,0 +1,265 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple, Union
+
+import torch
+from mmengine.model import BaseModel
+
+from mmpretrain.registry import MODELS, TOKENIZER
+from mmpretrain.structures import DataSample
+
+
+@MODELS.register_module()
+class BlipVQA(BaseModel):
+    """BLIP VQA.
+
+    Args:
+        tokenizer (dict): The config for the tokenizer.
+        vision_backbone (dict): Encoder for extracting image features.
+        multimodal_backbone (dict): Backbone for extracting
+            multi-modal features. We apply this part as the VQA fusion module.
+        head (dict): The head module to calculate
+            loss from processed features.
+        data_preprocessor (Optional[dict]): The config for preprocessing input
+            data. If None or no type is specified, it will use
+            `MultiModalDataPreprocessor` as type.
+            See :class:`MultiModalDataPreprocessor` for more details.
+            Defaults to None.
+        init_cfg (Optional[dict]): The config to control the initialization.
+            Defaults to None.
+ """ + + def __init__(self, + tokenizer: dict, + vision_backbone: dict, + multimodal_backbone: dict, + head: dict, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + + if data_preprocessor is None: + data_preprocessor = {} + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super(BlipVQA, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + self.tokenizer = TOKENIZER.build(tokenizer) + self.vision_backbone = MODELS.build(vision_backbone) + self.multimodal_backbone = MODELS.build(multimodal_backbone) + self.vqa_head = MODELS.build(head) + + @property + def device(self): + return next(self.parameters()).device + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + mode: str = 'loss', + ): + """The unified entry for a forward process in both training and test. + + - "loss": For training. Forward and return a dict of losses according + to the given inputs and data samples. Note that this method doesn't + handle neither back propagation nor optimizer updating, which are + done in the :meth:`train_step`. + - "predict": For testing. Forward and return a list of data_sample that + contains pred_answer for each question. + + Args: + images (Tensor): A batch of images. The shape of it should be + (B, C, H, W) for images and (B, T, C, H, W) for videos. + data_samples (List[DataSample], optional): The annotation data of + every samples. Required when ``mode="loss"``. Defaults to None. + mode (str): Return what kind of value. Defaults to 'loss'. + + Returns: + The return type depends on ``mode``. + - If ``mode="loss"``, return a dict of tensor. + - If ``mode="predict"``, return a list of `DataSample` + """ + + if mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, images: torch.Tensor) -> torch.Tensor: + """Extract features from the input tensor with shape (N, C, ..). + + Args: + images (Tensor): A batch of images. The shape of it should be + (B, C, H, W) for images and (B, T, C, H, W) for videos. + + Returns: + visual_embeds (Tensor): The output features. + """ + # extract visual feature + if images.ndim == 4: + visual_embeds = self.vision_backbone(images)[0] + elif images.ndim == 5: + # [batch, T, C, H, W] -> [batch * T, C, H, W] + bs = images.size(0) + images = images.reshape(-1, *images.shape[2:]) + visual_embeds = self.vision_backbone(images)[0] + # [batch * num_segs, L, dim] -> [batch, num_segs * L, dim] + visual_embeds = visual_embeds.reshape(bs, -1, + *visual_embeds.shape[2:]) + else: + raise ValueError( + f'Images with {images.ndim} dims is not supported.') + return visual_embeds + + def loss( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """generate train_loss from the input tensor and data_samples. + + Args: + images (Tensor): A batch of images. The shape of it should be + (B, C, H, W) for images and (B, T, C, H, W) for videos. + data_samples (List[DataSample], optional): The annotation + data of every samples. + + Returns: + Dict[torch.Tensor]: The losses features. 
+ """ + visual_embeds = self.extract_feat(images) + image_atts = torch.ones( + visual_embeds.size()[:-1], dtype=torch.long).to(self.device) + + questions = [] + for sample in data_samples: + questions.append(sample.get('question')) + questions = self.tokenizer( + questions, padding='longest', return_tensors='pt').to(self.device) + + questions.input_ids[:, 0] = \ + self.tokenizer.additional_special_tokens_ids[0] + + # multimodal fusion + multimodal_embeds = self.multimodal_backbone( + questions.input_ids, + attention_mask=questions.attention_mask, + encoder_hidden_states=visual_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + # put answer from data_samples into tensor form + answer_raw_text = [] + for sample in data_samples: + answer_raw_text.extend(sample.gt_answer) + answer = self.tokenizer( + answer_raw_text, padding='longest', + return_tensors='pt').to(self.device) + answer_targets = answer.input_ids.masked_fill( + answer.input_ids == self.tokenizer.pad_token_id, -100) + for sample in data_samples: + # follow BLIP setting, set answer_weight to 0.2 for VG dataset. + if not hasattr(sample, 'gt_answer_weight'): + sample.gt_answer_weight = torch.tensor([0.2]) + else: + sample.gt_answer_weight = torch.tensor(sample.gt_answer_weight) + answer_weight = torch.cat( + [sample.gt_answer_weight for sample in data_samples], + dim=0).to(self.device) + answer_count = torch.tensor( + [len(sample.gt_answer) for sample in data_samples]).to(self.device) + + question_states, question_atts = [], [] + for b, n in enumerate(answer_count): + question_states += [multimodal_embeds.last_hidden_state[b]] * n + question_atts += [questions.attention_mask[b]] * n + + question_states = torch.stack(question_states, dim=0).to(self.device) + question_atts = torch.stack(question_atts, dim=0).to(self.device) + + head_feats = dict( + answer_input_ids=answer.input_ids, + answer_attention_mask=answer.attention_mask, + answer_weight=answer_weight, + answer_targets=answer_targets, + question_states=question_states, + question_atts=question_atts, + batch_size=len(data_samples), + ) + + losses = self.vqa_head.loss(head_feats) + + return losses + + def predict( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + ): + """update data_samples that contain pred_answer for each question. + + Args: + images (Tensor): A batch of images. The shape of it should be + (B, C, H, W) for images and (B, T, C, H, W) for videos. + data_samples (List[DataSample], optional): The annotation + data of every samples. + + Returns: + Dict[torch.Tensor]: The losses features. 
+ """ + visual_embeds = self.extract_feat(images) + image_atts = torch.ones( + visual_embeds.size()[:-1], dtype=torch.long).to(self.device) + + questions = [] + for sample in data_samples: + questions.append(sample.get('question')) + questions = self.tokenizer( + questions, padding='longest', return_tensors='pt').to(self.device) + + questions.input_ids[:, 0] = \ + self.tokenizer.additional_special_tokens_ids[0] + + # multimodal fusion + multimodal_embeds = self.multimodal_backbone( + questions.input_ids, + attention_mask=questions.attention_mask, + encoder_hidden_states=visual_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + if self.vqa_head.inference_method == 'rank': + answer_candidates = self.tokenizer( + self.vqa_head.answer_list, + padding='longest', + return_tensors='pt').to(self.device) + answer_candidates.input_ids[:, 0] = self.tokenizer.bos_token_id + elif self.vqa_head.inference_method == 'generate': + answer_candidates = None + + head_feats = dict( + multimodal_embeds=multimodal_embeds.last_hidden_state, + question_atts=questions.attention_mask, + answer_candidates=answer_candidates, + bos_token_id=self.tokenizer.bos_token_id, + sep_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + + if self.vqa_head.inference_method == 'rank': + answers = self.vqa_head.predict(head_feats) + for answer, data_sample in zip(answers, data_samples): + data_sample.pred_answer = answer + + elif self.vqa_head.inference_method == 'generate': + outputs = self.vqa_head.predict(head_feats) + for output, data_sample in zip(outputs, data_samples): + data_sample.pred_answer = self.tokenizer.decode( + output, skip_special_tokens=True) + + return data_samples diff --git a/mmpretrain/models/multimodal/blip/language_model.py b/mmpretrain/models/multimodal/blip/language_model.py new file mode 100644 index 0000000000000000000000000000000000000000..48605a95f60550e970f893f55c4a43e03efb74df --- /dev/null +++ b/mmpretrain/models/multimodal/blip/language_model.py @@ -0,0 +1,1320 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +# flake8: noqa + +import math +from typing import Tuple + +import torch +import torch.nn as nn +from torch import Tensor, device + +try: + from transformers.activations import ACT2FN + from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions) + from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) + from transformers.models.bert.configuration_bert import BertConfig +except: + ACT2FN = None + BaseModelOutputWithPastAndCrossAttentions = None + BaseModelOutputWithPoolingAndCrossAttentions = None + CausalLMOutputWithCrossAttentions = None + PreTrainedModel = None + apply_chunking_to_forward = None + find_pruneable_heads_and_indices = None + prune_linear_layer = None + BertConfig = None + +from mmpretrain.registry import MODELS + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + + if config.add_type_embeddings: + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + + self.config = config + + def forward( + self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length: + seq_length + + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if token_type_ids is not None: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + else: + embeddings = inputs_embeds + + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
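+        # (added note) i.e. the [CLS] token; it is then passed through a
+        # linear layer and a tanh activation, matching the standard BERT
+        # pooler.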
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPreTrainedModel(PreTrainedModel): + """An abstract class to handle weights initialization and a simple + interface for downloading and loading pretrained models.""" + + config_class = BertConfig + base_model_prefix = 'bert' + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertSelfAttention(nn.Module): + + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, 'embedding_size'): + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / + config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + if (self.position_embedding_type == 'relative_key' + or self.position_embedding_type == 'relative_key_query'): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
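+        # (added note) In the cross-attention branch below, the queries come
+        # from `hidden_states` (text) while the keys and values are projected
+        # from `encoder_hidden_states` (e.g. image embeddings), so the
+        # encoder's attention mask replaces the text attention mask.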
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + + if (self.position_embedding_type == 'relative_key' + or self.position_embedding_type == 'relative_key_query'): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == 'relative_key': + relative_position_scores = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == 'relative_key_query': + relative_position_scores_query = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + 'bhrd,lrd->bhlr', key_layer, positional_embedding) + attention_scores = ( + attention_scores + relative_position_scores_query + + relative_position_scores_key) + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
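+        # (added note) Dropout is applied to the already-normalized attention
+        # probabilities, so a dropped entry removes that key's contribution
+        # entirely for this forward pass.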
+ attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = ((context_layer, attention_probs) if output_attentions else + (context_layer, )) + + outputs = outputs + (past_key_value, ) + return outputs + + +class BertSelfOutput(nn.Module): + + def __init__(self, config, twin=False, merge=False): + super().__init__() + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if twin: + self.dense0 = nn.Linear(config.hidden_size, config.hidden_size) + self.dense1 = nn.Linear(config.hidden_size, config.hidden_size) + else: + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if merge: + self.act = ACT2FN[config.hidden_act] + self.merge_layer = nn.Linear(config.hidden_size * 2, + config.hidden_size) + self.merge = True + else: + self.merge = False + + def forward(self, hidden_states, input_tensor): + if type(hidden_states) == list: + hidden_states0 = self.dense0(hidden_states[0]) + hidden_states1 = self.dense1(hidden_states[1]) + if self.merge: + hidden_states = self.merge_layer( + torch.cat([hidden_states0, hidden_states1], dim=-1)) + else: + hidden_states = (hidden_states0 + hidden_states1) / 2 + else: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config, is_cross_attention=False, layer_num=-1): + super().__init__() + is_nlvr = is_cross_attention and getattr(config, 'nlvr', False) + if is_nlvr: + self.self0 = BertSelfAttention(config, is_nlvr) + self.self1 = BertSelfAttention(config, is_nlvr) + else: + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput( + config, + twin=is_nlvr, + merge=(is_nlvr and layer_num >= 6), + ) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + if type(encoder_hidden_states) == list: + self_outputs0 = self.self0( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states[0], + encoder_attention_mask[0], + past_key_value, + output_attentions, + ) + self_outputs1 = self.self1( + 
hidden_states, + attention_mask, + head_mask, + encoder_hidden_states[1], + encoder_attention_mask[1], + past_key_value, + output_attentions, + ) + attention_output = self.output( + [self_outputs0[0], self_outputs1[0]], hidden_states) + + outputs = (attention_output, ) + self_outputs0[ + 1:] # add attentions if we output them + else: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + + # compatibility for ALBEF and BLIP + try: + # ALBEF & ALPRO + fusion_layer = self.config.fusion_layer + add_cross_attention = ( + fusion_layer <= layer_num and self.config.add_cross_attention) + + self.fusion_layer = fusion_layer + except AttributeError: + # BLIP + self.fusion_layer = self.config.num_hidden_layers + add_cross_attention = self.config.add_cross_attention + + # if self.config.add_cross_attention: + if self.config.add_cross_attention: + self.crossattention = BertAttention( + config, + is_cross_attention=self.config.add_cross_attention, + layer_num=layer_num, + ) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + # TODO line 482 in albef/models/xbert.py + # compatibility for ALBEF and BLIP + if mode in ['multimodal', 'fusion'] and hasattr( + self, 'crossattention'): + assert ( + encoder_hidden_states is not None + ), 
'encoder_hidden_states must be given for cross-attention layers' + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = (outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output, ) + outputs + + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multimodal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = (() if output_attentions + and self.config.add_cross_attention else None) + + next_decoder_cache = () if use_cache else None + + try: + # ALBEF + fusion_layer = self.config.fusion_layer + except AttributeError: + # BLIP + fusion_layer = self.config.num_hidden_layers + + if mode == 'text': + start_layer = 0 + # output_layer = self.config.fusion_layer + output_layer = fusion_layer + + elif mode == 'fusion': + # start_layer = self.config.fusion_layer + start_layer = fusion_layer + output_layer = self.config.num_hidden_layers + + elif mode == 'multimodal': + start_layer = 0 + output_layer = self.config.num_hidden_layers + + # compatibility for ALBEF and BLIP + # for i in range(self.config.num_hidden_layers): + for i in range(start_layer, output_layer): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + # TODO pay attention to this. + if self.gradient_checkpointing and self.training: + + if use_cache: + # TODO: logger here + # logger.warn( + # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+                    #     )
+                    use_cache = False
+
+                def create_custom_forward(module):
+
+                    def custom_forward(*inputs):
+                        return module(*inputs, past_key_value,
+                                      output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    mode=mode,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
+                    mode=mode,
+                )
+
+            hidden_states = layer_outputs[0]
+            if use_cache:
+                next_decoder_cache += (layer_outputs[-1], )
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (
+                    layer_outputs[1], )
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states, )
+
+        if not return_dict:
+            return tuple(v for v in [
+                hidden_states,
+                next_decoder_cache,
+                all_hidden_states,
+                all_self_attentions,
+                all_cross_attentions,
+            ] if v is not None)
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_decoder_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class BertPredictionHeadTransform(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        if isinstance(config.hidden_act, str):
+            self.transform_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.transform_act_fn = config.hidden_act
+        self.LayerNorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.transform_act_fn(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.transform = BertPredictionHeadTransform(config)
+
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
+        self.decoder = nn.Linear(
+            config.hidden_size, config.vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+
+    def forward(self, hidden_states):
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = BertLMPredictionHead(config)
+
+    def forward(self, sequence_output):
+        prediction_scores = self.predictions(sequence_output)
+        return prediction_scores
+
+
+@MODELS.register_module()
+class BertModel(BertPreTrainedModel):
+    """The model can behave as an encoder (with only self-attention) as well
+    as a decoder, in which case a layer of cross-attention is added between
+    the self-attention layers, following the architecture described in
+    `Attention is all you need <https://arxiv.org/abs/1706.03762>`__ by
+    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
+    Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as a decoder, the model needs to be initialized with the
+    :obj:`is_decoder` argument and :obj:`add_cross_attention` set to
+    :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an input
+    to the forward pass.
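+
+    A minimal usage sketch (illustrative only; the config values below are
+    assumptions for this example, not the checkpoints used elsewhere in this
+    repo)::
+
+        >>> import torch
+        >>> config = BertConfig(vocab_size=30524)
+        >>> model = BertModel(config, add_pooling_layer=False)
+        >>> input_ids = torch.randint(0, 30524, (1, 8))
+        >>> out = model(input_ids, mode='text', return_dict=True)
+        >>> out.last_hidden_state.shape
+        torch.Size([1, 8, 768])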
+    """
+
+    def __init__(self, config, add_pooling_layer=True):
+        if not isinstance(config, BertConfig):
+            config = BertConfig.from_dict(config)
+
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = BertEmbeddings(config)
+
+        self.encoder = BertEncoder(config)
+
+        self.pooler = BertPooler(config) if add_pooling_layer else None
+
+        self.init_weights()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """Prunes heads of the model.
+
+        heads_to_prune: dict of {layer_num: list of heads to prune in this
+        layer}. See base class PreTrainedModel.
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    def get_extended_attention_mask(
+        self,
+        attention_mask: Tensor,
+        input_shape: Tuple[int],
+        device: device,
+        is_decoder: bool,
+    ) -> Tensor:
+        """Makes broadcastable attention and causal masks so that future and
+        masked tokens are ignored.
+
+        Arguments:
+            attention_mask (:obj:`torch.Tensor`):
+                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+            input_shape (:obj:`Tuple[int]`):
+                The shape of the input to the model.
+            device (:obj:`torch.device`):
+                The device of the input to the model.
+
+        Returns:
+            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
+        """
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        if attention_mask.dim() == 3:
+            extended_attention_mask = attention_mask[:, None, :, :]
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask in addition to the padding mask
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            if is_decoder:
+                batch_size, seq_length = input_shape
+
+                seq_ids = torch.arange(seq_length, device=device)
+                causal_mask = (
+                    seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <=
+                    seq_ids[None, :, None])
+                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+                # causal and attention masks must have same type with pytorch version < 1.3
+                causal_mask = causal_mask.to(attention_mask.dtype)
+
+                if causal_mask.shape[1] < attention_mask.shape[1]:
+                    prefix_seq_len = attention_mask.shape[
+                        1] - causal_mask.shape[1]
+                    causal_mask = torch.cat(
+                        [
+                            torch.ones(
+                                (batch_size, seq_length, prefix_seq_len),
+                                device=device,
+                                dtype=causal_mask.dtype,
+                            ),
+                            causal_mask,
+                        ],
+                        axis=-1,
+                    )
+
+                extended_attention_mask = (
+                    causal_mask[:, None, :, :] *
+                    attention_mask[:, None, None, :])
+            else:
+                extended_attention_mask = attention_mask[:, None, None, :]
+        else:
+            raise ValueError(
+                'Wrong shape for input_ids (shape {}) or attention_mask (shape {})'
+                .format(input_shape, attention_mask.shape))
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
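+        # For example, a padding mask [1, 1, 0] becomes [0.0, 0.0, -10000.0]
+        # after the transform below, so masked positions contribute almost
+        # nothing after the softmax.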
+ extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
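+        is_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether to apply a causal self-attention mask on top of any
+            padding mask (decoder behaviour).
+        mode (:obj:`str`, `optional`, defaults to :obj:`'multimodal'`):
+            Which span of encoder layers to run: ``'text'`` runs only the
+            layers below the fusion layer, ``'fusion'`` runs only the fusion
+            layers, and ``'multimodal'`` runs all layers.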
+ """ + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds or encoder_embeds' + ) + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] + if past_key_values is not None else 0) + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + 
encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BaseEncoder(nn.Module): + """Base class for primitive encoders, such as ViT, TimeSformer, etc.""" + + def __init__(self): + super().__init__() + + def forward_features(self, samples, **kwargs): + raise NotImplementedError + + @property + def device(self): + return list(self.parameters())[0].device + + +@MODELS.register_module() +class XBertEncoder(BertModel, BaseEncoder): + + def __init__(self, med_config, from_pretrained=False): + + med_config = BertConfig.from_dict(med_config) + super().__init__(config=med_config, add_pooling_layer=False) + + def forward_automask(self, tokenized_text, visual_embeds, **kwargs): + image_atts = torch.ones( + visual_embeds.size()[:-1], dtype=torch.long).to(self.device) + + text = tokenized_text + text_output = super().forward( + text.input_ids, + attention_mask=text.attention_mask, + encoder_hidden_states=visual_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + return text_output + + def forward_text(self, tokenized_text, **kwargs): + text = tokenized_text + token_type_ids = kwargs.get('token_type_ids', None) + + text_output = super().forward( + text.input_ids, + attention_mask=text.attention_mask, + token_type_ids=token_type_ids, + return_dict=True, + mode='text', + ) + + return text_output + + +@MODELS.register_module() +class Linear(torch.nn.Linear): + """Wrapper for linear function.""" + + +@MODELS.register_module() +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction='mean', + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. 
+        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
+        Returns:
+        Example::
+            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+            >>> import torch
+            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+            >>> config = BertConfig.from_pretrained("bert-base-cased")
+            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+            >>> outputs = model(**inputs)
+            >>> prediction_logits = outputs.logits
+        """
+        return_dict = (
+            return_dict
+            if return_dict is not None else self.config.use_return_dict)
+        if labels is not None:
+            use_cache = False
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            is_decoder=is_decoder,
+            mode=mode,
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.cls(sequence_output)
+
+        if return_logits:
+            return prediction_scores[:, :-1, :].contiguous()
+
+        lm_loss = None
+        if labels is not None:
+            # we are doing next-token prediction; shift prediction scores and input ids by one
+            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+            labels = labels[:, 1:].contiguous()
+            loss_fct = torch.nn.CrossEntropyLoss(
+                reduction=reduction, label_smoothing=0.1)
+            lm_loss = loss_fct(
+                shifted_prediction_scores.view(-1, self.config.vocab_size),
+                labels.view(-1))
+            if reduction == 'none':
+                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
+
+        if not return_dict:
+            output = (prediction_scores, ) + outputs[2:]
+            return ((lm_loss, ) + output) if lm_loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=lm_loss,
+            logits=prediction_scores,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    def prepare_inputs_for_generation(self,
+                                      input_ids,
+                                      past=None,
+                                      attention_mask=None,
+                                      **model_kwargs):
+        input_shape = input_ids.shape
+        # if model is used as a decoder in encoder-decoder model,
+        # the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_shape)
+
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            'input_ids':
+            input_ids,
+            'attention_mask':
+            attention_mask,
+            'past_key_values':
+            past,
+            'encoder_hidden_states':
+            model_kwargs.get('encoder_hidden_states', None),
+            'encoder_attention_mask':
+            model_kwargs.get('encoder_attention_mask', None),
+            'is_decoder':
+            True,
+        }
+
+    def _reorder_cache(self, past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            reordered_past += (tuple(
+                past_state.index_select(0, beam_idx)
+                for past_state in layer_past), )
+        return reordered_past
+
+
+@MODELS.register_module()
+class XBertLMHeadDecoder(BertLMHeadModel):
+    """This class decouples the decoder forward logic from the VL model.
+
+    In this way, different VL models can share this decoder as long as they
+    feed encoder_embeds as required.
+    """
+
+    def __init__(self, med_config):
+        self.med_config = BertConfig.from_dict(med_config)
+        super(XBertLMHeadDecoder, self).__init__(config=self.med_config)
+
+    def generate_from_encoder(self,
+                              tokenized_prompt,
+                              visual_embeds,
+                              sep_token_id,
+                              pad_token_id,
+                              use_nucleus_sampling=False,
+                              num_beams=3,
+                              max_length=30,
+                              min_length=10,
+                              top_p=0.9,
+                              repetition_penalty=1.0,
+                              **kwargs):
+
+        if not use_nucleus_sampling:
+            visual_embeds = visual_embeds.repeat_interleave(num_beams, dim=0)
+
+        image_atts = torch.ones(
+            visual_embeds.size()[:-1], dtype=torch.long).to(self.device)
+
+        model_kwargs = {
+            'encoder_hidden_states': visual_embeds,
+            'encoder_attention_mask': image_atts,
+        }
+
+        if use_nucleus_sampling:
+            # nucleus sampling
+            outputs = self.generate(
+                input_ids=tokenized_prompt.input_ids,
+                max_length=max_length,
+                min_length=min_length,
+                do_sample=True,
+                top_p=top_p,
+                num_return_sequences=1,
+                eos_token_id=sep_token_id,
+                pad_token_id=pad_token_id,
+                repetition_penalty=1.1,
+                **model_kwargs)
+        else:
+            # beam search
+            outputs = self.generate(
+                input_ids=tokenized_prompt.input_ids,
+                max_length=max_length,
+                min_length=min_length,
+                num_beams=num_beams,
+                eos_token_id=sep_token_id,
+                pad_token_id=pad_token_id,
+                repetition_penalty=repetition_penalty,
+                **model_kwargs)
+
+        return outputs
diff --git a/mmpretrain/models/multimodal/blip2/Qformer.py b/mmpretrain/models/multimodal/blip2/Qformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b1c7d1e28711ae706ee4f3590cc5351c165fbae
--- /dev/null
+++ b/mmpretrain/models/multimodal/blip2/Qformer.py
@@ -0,0 +1,773 @@
+# flake8: noqa
+"""
+ * Copyright (c) 2023, salesforce.com, inc.
+""" +from typing import Tuple + +import torch +import torch.utils.checkpoint +from torch import Tensor, device, nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions) +from transformers.modeling_utils import apply_chunking_to_forward +from transformers.models.bert.configuration_bert import BertConfig +from transformers.utils import logging + +from mmpretrain.registry import MODELS +from ..blip.language_model import (BertAttention, BertIntermediate, + BertOnlyMLMHead, BertOutput, BertPooler, + BertPreTrainedModel) + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length: + seq_length + + past_key_values_length].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertLayer(nn.Module): + + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if (self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention) + self.has_cross_attention = True + else: + self.has_cross_attention = False + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached 
key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + assert ( + encoder_hidden_states is not None + ), 'encoder_hidden_states must be given for cross-attention layers' + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], + dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output, ) + outputs + + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = (() if output_attentions + and self.config.add_cross_attention else None) + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if getattr(self.config, 'gradient_checkpointing', + False) and self.training: + + if use_cache: + logger.warn( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' 
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+
+                    def custom_forward(*inputs):
+                        return module(*inputs, past_key_value,
+                                      output_attentions, query_length)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
+                    query_length,
+                )
+
+            hidden_states = layer_outputs[0]
+            if use_cache:
+                next_decoder_cache += (layer_outputs[-1], )
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (
+                    layer_outputs[1], )
+                all_cross_attentions = all_cross_attentions + (
+                    layer_outputs[2], )
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states, )
+
+        if not return_dict:
+            return tuple(v for v in [
+                hidden_states,
+                next_decoder_cache,
+                all_hidden_states,
+                all_self_attentions,
+                all_cross_attentions,
+            ] if v is not None)
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_decoder_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class BertModel(BertPreTrainedModel):
+    """The model can behave as an encoder (with only self-attention) as well
+    as a decoder, in which case a layer of cross-attention is added between
+    the self-attention layers, following the architecture described in
+    `Attention is all you need <https://arxiv.org/abs/1706.03762>`__ by
+    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
+    Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as a decoder, the model needs to be initialized with the
+    :obj:`is_decoder` argument and :obj:`add_cross_attention` set to
+    :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an input
+    to the forward pass.
+    """
+
+    def __init__(self, config, add_pooling_layer=False):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = BertEmbeddings(config)
+
+        self.encoder = BertEncoder(config)
+
+        self.pooler = BertPooler(config) if add_pooling_layer else None
+
+        self.init_weights()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """Prunes heads of the model.
+
+        heads_to_prune: dict of {layer_num: list of heads to prune in this
+        layer}. See base class PreTrainedModel.
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    def get_extended_attention_mask(
+        self,
+        attention_mask: Tensor,
+        input_shape: Tuple[int],
+        device: device,
+        is_decoder: bool,
+        has_query: bool = False,
+    ) -> Tensor:
+        """Makes broadcastable attention and causal masks so that future and
+        masked tokens are ignored.
+
+        Arguments:
+            attention_mask (:obj:`torch.Tensor`):
+                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+            input_shape (:obj:`Tuple[int]`):
+                The shape of the input to the model.
+            device (:obj:`torch.device`):
+                The device of the input to the model.
+
+        Returns:
+            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
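+
+        Note:
+            When ``has_query`` is True the causal mask is extended UniLM
+            style: the prepended query prefix stays visible to all positions,
+            while the text-to-text part of the mask remains causal.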
+ """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= + seq_ids[None, :, None]) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[ + 1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], + prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * + attention_mask[:, None, None, :]) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + 'Wrong shape for input_ids (shape {}) or attention_mask (shape {})' + .format(input_shape, attention_mask.shape)) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
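+        query_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, query_length, hidden_size)`, `optional`):
+            Learnable query token embeddings. They are prepended to the text
+            embeddings (or used on their own when :obj:`input_ids` is None)
+            and interact with :obj:`encoder_hidden_states` through
+            cross-attention.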
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + if input_ids is None: + assert ( + query_embeds is not None + ), 'You have to specify query_embeds when input_ids is None' + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - + self.config.query_length if past_key_values is not None else 0) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
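+        # In the decoder case the causal mask is built over the text tokens
+        # only (hence ``input_ids.shape`` below); ``has_query`` keeps the
+        # prepended query tokens visible instead of causally masked.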
+ if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + if self.cls is not None: + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + 
return_dict=None,
+        return_logits=False,
+        is_decoder=True,
+        reduction='mean',
+    ):
+        r"""
+        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4
+            tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
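+        Note that when :obj:`past_key_values` is provided, :obj:`query_embeds`
+        is dropped internally, since the key/value states of the query tokens
+        are already part of the cache.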
+ Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1]:, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, : + -1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss( + reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == 'none': + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((lm_loss, ) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + query_embeds, + past=None, + attention_mask=None, + **model_kwargs): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + 'input_ids': + input_ids, + 'query_embeds': + query_embeds, + 'attention_mask': + attention_mask, + 'past_key_values': + past, + 'encoder_hidden_states': + model_kwargs.get('encoder_hidden_states', None), + 'encoder_attention_mask': + model_kwargs.get('encoder_attention_mask', None), + 'is_decoder': + True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past + + +@MODELS.register_module() +class Qformer(BertLMHeadModel): + + def __init__(self, model_style: str, vision_model_width: int, + add_cross_attention: bool, cross_attention_freq: int, + 
num_query_token: int) -> None:
+
+        config = BertConfig.from_pretrained(model_style)
+        config.add_cross_attention = add_cross_attention
+        config.encoder_width = vision_model_width
+        config.cross_attention_freq = cross_attention_freq
+        config.query_length = num_query_token
+        super().__init__(config)
diff --git a/mmpretrain/models/multimodal/blip2/__init__.py b/mmpretrain/models/multimodal/blip2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5695f236caf74493fc6e851edbf2a4a05146b5f
--- /dev/null
+++ b/mmpretrain/models/multimodal/blip2/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .blip2_caption import Blip2Caption
+from .blip2_opt_vqa import Blip2VQA
+from .blip2_retriever import Blip2Retrieval
+from .modeling_opt import OPTForCausalLM
+from .Qformer import Qformer
+
+__all__ = [
+    'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'OPTForCausalLM', 'Qformer'
+]
diff --git a/mmpretrain/models/multimodal/blip2/blip2_caption.py b/mmpretrain/models/multimodal/blip2/blip2_caption.py
new file mode 100644
index 0000000000000000000000000000000000000000..acf694827152ad47efd61d58f8361ea23834d68e
--- /dev/null
+++ b/mmpretrain/models/multimodal/blip2/blip2_caption.py
@@ -0,0 +1,315 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional
+
+import torch
+from mmengine.model import BaseModel
+from torch import nn
+
+from mmpretrain.registry import MODELS, TOKENIZER
+from mmpretrain.structures import DataSample
+
+
+@MODELS.register_module()
+class Blip2Caption(BaseModel):
+    """BLIP2 Caption.
+
+    Module for BLIP2 Caption task.
+
+    Args:
+        vision_backbone (dict): The config dict for vision backbone.
+        text_backbone (dict): The config dict for text backbone.
+        multimodal_backbone (dict): The config dict for multimodal backbone.
+        vision_neck (dict): The config dict for vision neck.
+        tokenizer (Optional[dict]): The config for tokenizer.
+            Defaults to None.
+        prompt (str): Prompt used for training and eval.
+            Defaults to ''.
+        max_txt_len (int): Max text length of input text. Defaults to 20.
+        num_captions (int): Number of captions to be generated for each image.
+            Defaults to 1.
+        data_preprocessor (Optional[dict]): The config for preprocessing input
+            data. If None or no specified type, it will use
+            "MultiModalDataPreprocessor" as type.
+            See :class:`MultiModalDataPreprocessor` for more details.
+            Defaults to None.
+        init_cfg (Optional[dict]): The config to control the initialization.
+            Defaults to None.
+    """
+    _no_split_modules = ['BEiTViT', 'OPTDecoderLayer', 'BertLayer']
+
+    def __init__(self,
+                 vision_backbone: dict,
+                 text_backbone: dict,
+                 multimodal_backbone: dict,
+                 vision_neck: dict,
+                 tokenizer: Optional[dict] = None,
+                 prompt: str = '',
+                 max_txt_len: int = 20,
+                 num_captions: int = 1,
+                 data_preprocessor: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None) -> None:
+        if data_preprocessor is None:
+            data_preprocessor = {}
+        if isinstance(data_preprocessor, dict):
+            data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
+            data_preprocessor = MODELS.build(data_preprocessor)
+
+        super().__init__(
+            init_cfg=init_cfg, data_preprocessor=data_preprocessor)
+
+        self.tokenizer = TOKENIZER.build(tokenizer)
+        self.eos_token_id = self.tokenizer(
+            '\n', add_special_tokens=False).input_ids[0]
+
+        self.vision_backbone = MODELS.build(vision_backbone)
+        self.ln_vision_backbone = nn.LayerNorm(self.vision_backbone.embed_dims)
+
+        self.vision_neck = MODELS.build(vision_neck)
+
+        self.text_backbone = MODELS.build(text_backbone)
+
+        self.multimodal_backbone = MODELS.build(multimodal_backbone)
+        self.multimodal_backbone.cls = None
+        self.multimodal_backbone.bert.embeddings.word_embeddings = None
+        self.multimodal_backbone.bert.embeddings.position_embeddings = None
+        for layer in self.multimodal_backbone.bert.encoder.layer:
+            layer.output = None
+            layer.intermediate = None
+
+        self.prompt = prompt
+        self.max_txt_len = max_txt_len
+        self.num_captions = num_captions
+        prompt_tokens = self.tokenizer(prompt, return_tensors='pt')
+        self.prompt_length = prompt_tokens.attention_mask.sum(1)
+
+        self.query_tokens = nn.Parameter(
+            torch.zeros(1, self.multimodal_backbone.bert.config.query_length,
+                        self.multimodal_backbone.bert.config.hidden_size))
+        self.query_tokens.data.normal_(
+            mean=0.0,
+            std=self.multimodal_backbone.bert.config.initializer_range)
+
+        # freeze the text backbone
+        for _, param in self.text_backbone.named_parameters():
+            param.requires_grad = False
+
+        if hasattr(self, 'register_load_state_dict_post_hook'):
+            self.register_load_state_dict_post_hook(
+                self._ignore_loading_llm_keys_hook)
+
+        if hasattr(self, '_register_state_dict_hook'):
+            self._register_state_dict_hook(self._ignore_saving_llm_keys_hook)
+
+    def forward(self,
+                images: torch.Tensor,
+                data_samples: Optional[List] = None,
+                mode: str = 'loss'):
+        """The unified entry for a forward process in both training and test.
+        The method should accept two modes: "predict" and "loss":
+
+        - "predict": Forward and return the predictions, which are fully
+          processed to a list of :obj:`DataSample`.
+        - "loss": Forward and return a dict of losses according to the given
+          inputs and data samples.
+
+        Note that this method handles neither back propagation nor optimizer
+        updating, which are done in the :meth:`train_step`.
+
+        Args:
+            images (torch.Tensor): Pre-processed image tensor with shape
+                (N, C, ...).
+            data_samples (List[DataSample], optional): The annotation data
+                of every sample. Defaults to None.
+            mode (str): Return what kind of value. Defaults to 'loss'.
+
+        Returns:
+            The return type depends on ``mode``.
+            - If ``mode="loss"``, return a dict of tensor.
+            - If ``mode="predict"``, return a list of
+              :obj:`mmpretrain.structures.DataSample`.
+ """ + if mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def loss(self, + images: torch.Tensor, + data_samples: Optional[list] = None, + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + images (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``loss`` + method of :attr:`head`. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + + # extract image features + image_embeds = self.ln_vision_backbone(self.vision_backbone(images)[0]) + image_atts = torch.ones( + image_embeds.size()[:-1], + dtype=torch.long, + ).to(images.device) + + # distill image features to query tokens + query_tokens = self.query_tokens.expand(image_embeds.size(0), -1, -1) + query_outputs = self.multimodal_backbone.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + inputs_opt = self.vision_neck([query_outputs.last_hidden_state]) + attns_opt = torch.ones( + inputs_opt.size()[:-1], dtype=torch.long).to(images.device) + + self.tokenizer.padding_side = 'right' + + prompt = [ + self.prompt + data_sample.gt_caption + '\n' + for data_sample in data_samples + ] + + opt_tokens = self.tokenizer( + prompt, + return_tensors='pt', + padding='longest', + truncation=True, + max_length=self.max_txt_len, + ).to(images.device) + + targets = opt_tokens.input_ids.masked_fill( + opt_tokens.input_ids == self.tokenizer.pad_token_id, -100) + if self.prompt: + targets[:, :self.prompt_length] = -100 + + empty_targets = ( + torch.ones(attns_opt.size(), + dtype=torch.long).to(images.device).fill_(-100)) + targets = torch.cat([empty_targets, targets], dim=1) + + inputs_embeds = ( + self.text_backbone.model.decoder.embed_tokens( + opt_tokens.input_ids)) + inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1) + attention_mask = torch.cat([attns_opt, opt_tokens.attention_mask], + dim=1) + + outputs = self.text_backbone( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=targets, + ) + loss = outputs.loss + + return {'loss': loss} + + def predict(self, + images: torch.Tensor, + data_samples: Optional[list] = None, + **kwargs) -> List[DataSample]: + """Predict captions from a batch of inputs. + + Args: + images (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + + Returns: + List[DataSample]: Return list of data samples. 
+ """ + + # extract image features + image_embeds = self.ln_vision_backbone(self.vision_backbone(images)[0]) + image_atts = torch.ones( + image_embeds.size()[:-1], + dtype=torch.long, + ).to(images.device) + + # distill image features to query tokens + query_tokens = self.query_tokens.expand(image_embeds.size(0), -1, -1) + query_outputs = self.multimodal_backbone.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + inputs_opt = self.vision_neck([query_outputs.last_hidden_state]) + attns_opt = torch.ones( + inputs_opt.size()[:-1], dtype=torch.long).to(images.device) + + prompt = [self.prompt] * image_embeds.size(0) + + opt_tokens = self.tokenizer( + prompt, + return_tensors='pt', + padding='longest', + truncation=True, + max_length=self.max_txt_len, + ).to(images.device) + attention_mask = torch.cat([attns_opt, opt_tokens.attention_mask], + dim=1) + + inputs_embeds = ( + self.text_backbone.get_input_embeddings()(opt_tokens.input_ids)) + inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1) + + outputs = self.text_backbone.generate( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + do_sample=False, + top_p=0.9, + temperature=1., + num_beams=5, + max_new_tokens=self.max_txt_len, + min_length=1, + eos_token_id=self.eos_token_id, + repetition_penalty=1.0, + length_penalty=1.0, + num_return_sequences=self.num_captions, + ) + + output_text = self.tokenizer.batch_decode( + outputs, skip_special_tokens=True) + output_text = [text.strip() for text in output_text] + + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(len(output_text))] + + for data_sample, decode_token in zip(data_samples, output_text): + if data_sample is None: + data_sample = DataSample() + data_sample.pred_caption = decode_token + out_data_samples.append(data_sample) + + return out_data_samples + + @staticmethod + def _ignore_loading_llm_keys_hook(module, incompatible_keys): + """Avoid warning missing keys of the LLM model.""" + import re + llm_pattern = '^text_backbone' + for key in list(incompatible_keys.missing_keys): + if re.match(llm_pattern, key): + incompatible_keys.missing_keys.remove(key) + + @staticmethod + def _igonre_saving_llm_keys_hook(module, state_dict, prefix, metadata): + """Avoid saving llm state dict.""" + import re + llm_pattern = '^text_backbone' + keys = [k for k, _ in state_dict.items()] + for key in keys: + if re.match(llm_pattern, key): + state_dict.pop(key) diff --git a/mmpretrain/models/multimodal/blip2/blip2_opt_vqa.py b/mmpretrain/models/multimodal/blip2/blip2_opt_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..20e439fa826725a80462557faab8ae25a8e5660e --- /dev/null +++ b/mmpretrain/models/multimodal/blip2/blip2_opt_vqa.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .blip2_caption import Blip2Caption + + +@MODELS.register_module() +class Blip2VQA(Blip2Caption): + """BLIP2 VQA. + + Module for BLIP2 VQA task. For more details about the initialization + params, please refer to :class:`Blip2Caption`. + """ + + def predict(self, + images: torch.Tensor, + data_samples: Optional[list] = None, + **kwargs) -> List[DataSample]: + """Predict captions from a batch of inputs. + + Args: + images (torch.Tensor): The input tensor with shape + (N, C, ...) in general. 
+ data_samples (List[DataSample], optional): The annotation + data of every sample. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + + Returns: + List[DataSample]: A list of data samples with the predicted + answers. + """ + questions = [d.question for d in data_samples] + + # extract image features + image_embeds = self.ln_vision_backbone(self.vision_backbone(images)[0]) + image_atts = torch.ones( + image_embeds.size()[:-1], + dtype=torch.long, + ).to(images.device) + + # distill image features to query tokens + query_tokens = self.query_tokens.expand(image_embeds.size(0), -1, -1) + query_outputs = self.multimodal_backbone.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + inputs_opt = self.vision_neck([query_outputs.last_hidden_state]) + attns_opt = torch.ones( + inputs_opt.size()[:-1], dtype=torch.long).to(images.device) + + prompt = [self.prompt.format(q) for q in questions] + + # use left padding + self.tokenizer.padding_side = 'left' + + opt_tokens = self.tokenizer( + prompt, return_tensors='pt', padding='longest').to(images.device) + input_ids = opt_tokens.input_ids + attention_mask = torch.cat([attns_opt, opt_tokens.attention_mask], + dim=1) + + inputs_embeds = self.text_backbone.model.decoder.embed_tokens( + input_ids) + inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1) + + outputs = self.text_backbone.generate( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + do_sample=False, + num_beams=5, + max_new_tokens=self.max_txt_len, + min_length=1, + eos_token_id=self.eos_token_id, + length_penalty=-1.0, + ) + + output_text = self.tokenizer.batch_decode( + outputs, skip_special_tokens=True) + output_text = [text.strip() for text in output_text] + + out_data_samples = [] + for data_sample, decode_token in zip(data_samples, output_text): + data_sample.pred_answer = decode_token + out_data_samples.append(data_sample) + + return out_data_samples diff --git a/mmpretrain/models/multimodal/blip2/blip2_retriever.py b/mmpretrain/models/multimodal/blip2/blip2_retriever.py new file mode 100644 index 0000000000000000000000000000000000000000..e626404a4cde5798151a0fa9589716470ed928a9 --- /dev/null +++ b/mmpretrain/models/multimodal/blip2/blip2_retriever.py @@ -0,0 +1,505 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple, Union + +import mmengine.dist as dist +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.utils import track_iter_progress + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from ..blip.blip_retrieval import BlipRetrieval, all_gather_concat + + +@MODELS.register_module() +class Blip2Retrieval(BlipRetrieval): + """BLIP2 Retriever. + + Args: + vision_backbone (dict): Backbone for extracting image features. + text_backbone (dict): Backbone for extracting text features. + multimodal_backbone (Optional[dict]): Backbone for extracting + multi-modal features. + vision_neck (Optional[dict]): The neck module to process image features + from vision backbone. Defaults to None. + text_neck (Optional[dict]): The neck module to process text features + from text backbone. Defaults to None. + head (Optional[Union[List[dict], dict]]): The head module to calculate + loss from processed single modality features. + See :mod:`mmpretrain.models.heads`. + Notice that if the head is not set, the `loss` method cannot be used.
+ Defaults to None. + multimodal_head (Optional[Union[List[dict], dict]]): The multi-modal + head module to calculate loss from processed multimodal features. + See :mod:`mmpretrain.models.heads`. + Notice that if the head is not set, the `loss` method cannot be used. + Defaults to None. + tokenizer (Optional[dict]): The config for the tokenizer. Defaults to None. + temperature (float): Temperature parameter that controls the + concentration level of the distribution. Defaults to 0.07. + fast_match (bool): If False, select topk similarity as candidates and + compute the matching score. If True, return the similarity as the + matching score directly. Defaults to False. + topk (int): Select topk similarity as candidates for computing matching + scores. Notice that this is not the topk in evaluation. + Defaults to 256. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MultiModalDataPreprocessor" as type. + See :class:`MultiModalDataPreprocessor` for more details. + Defaults to None. + init_cfg (Optional[dict]): The config to control the initialization. + Defaults to None. + """ + + def __init__(self, + vision_backbone: dict, + text_backbone: Optional[dict] = None, + multimodal_backbone: Optional[dict] = None, + vision_neck: Optional[dict] = None, + text_neck: Optional[dict] = None, + head: Optional[Union[List[dict], dict]] = None, + multimodal_head: Optional[Union[List[dict], dict]] = None, + tokenizer: Optional[dict] = None, + temperature: float = 0.07, + fast_match: bool = False, + topk: int = 256, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None) -> None: + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + # Skip BlipRetrieval init + super(BlipRetrieval, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + self.vision_backbone = MODELS.build(vision_backbone) + self.ln_vision_backbone = nn.LayerNorm(self.vision_backbone.embed_dims) + self.tokenizer = TOKENIZER.build(tokenizer) + + if text_backbone is not None: + self.text_backbone = MODELS.build(text_backbone) + + if multimodal_backbone is not None: + self.multimodal_backbone = MODELS.build(multimodal_backbone) + self.multimodal_backbone.resize_token_embeddings( + len(self.tokenizer)) + self.query_tokens = nn.Parameter( + torch.zeros(1, self.multimodal_backbone.bert.config.query_length, + self.multimodal_backbone.bert.config.hidden_size)) + self.query_tokens.data.normal_( + mean=0.0, + std=self.multimodal_backbone.bert.config.initializer_range) + + if vision_neck is not None: + self.vision_neck = MODELS.build(vision_neck) + + if text_neck is not None: + self.text_neck = MODELS.build(text_neck) + + if head is not None: + self.head = MODELS.build(head) + + if multimodal_head is not None: + self.multimodal_head = MODELS.build(multimodal_head) + + self.temp = nn.Parameter(temperature * torch.ones([])) + + # Notice that this topk is used to select k candidates to compute + # image-text scores, not the final metric topk in evaluation. + self.fast_match = fast_match + self.topk = topk + + def _extract_feat(self, inputs: Union[torch.Tensor, dict], + modality: str) -> dict: + """Extract features from a single modality. + Args: + inputs (Union[torch.Tensor, dict]): A batch of inputs. + For image, a tensor of shape (N, C, ...) in general.
+ For text, a dict of tokenized text inputs. + modality (str): Modality feature to be extracted. Only two + options are supported. + + - ``images``: Only extract image features, mostly used for + inference. + - ``texts``: Only extract text features, mostly used for + inference. + Returns: + dict: The extracted features of the given modality. + """ + if modality == 'images': + # extract image features + # TODO: + # Add layernorm inside backbone and handle the concat outside + image_embeds = self.ln_vision_backbone( + self.vision_backbone(inputs)[0]) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(self.device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, + -1) + query_output = self.multimodal_backbone.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + use_cache=True, + return_dict=True, + ) + image_feat = F.normalize( + self.vision_neck([query_output.last_hidden_state]), dim=-1) + return { + 'image_embeds': image_embeds, + 'image_feat': image_feat, + 'query_output': query_output + } + elif modality == 'texts': + # extract text features + text_output = self.multimodal_backbone.bert( + inputs.input_ids, + attention_mask=inputs.attention_mask, + return_dict=True, + ) + text_embeds = text_output.last_hidden_state + text_feat = F.normalize( + self.text_neck([text_embeds[:, 0, :]]), dim=-1) + return {'text_embeds': text_embeds, 'text_feat': text_feat} + else: + raise RuntimeError(f'Invalid modality "{modality}".') + + def loss( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + ) -> Dict[str, torch.Tensor]: + """Calculate losses from a batch of inputs and data samples. + + Args: + images (torch.Tensor): The input images tensor with shape + (N, C, ...) in general. The paired texts are taken from + ``data_samples``. + data_samples (Optional[List[DataSample]]): + The annotation data of every sample. Defaults to None. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components from + both the head and the multimodal head.
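+ + Examples: + The in-batch ITC targets built below are the diagonal indices of + the current rank's block in the global similarity matrix; an + equivalent sketch with hypothetical sizes: + + >>> import torch + >>> rank, bs = 1, 4 # hypothetical rank / local batch size + >>> torch.arange(rank * bs, rank * bs + bs) + tensor([4, 5, 6, 7])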
+ """ + output = self.extract_feat(images, data_samples) + + text_ids = output['text_ids'] + text_attn_mask = output['text_attn_mask'] + image_embeds = output['image_embeds'] + image_feat = output['image_feat'] + text_feat = output['text_feat'] + query_output = output['query_output'] + + # ITC Loss + # B*world_size, num_query, D + image_feat_all = torch.cat(dist.all_gather(image_feat)) + # B*world_size, D + text_feat_all = torch.cat(dist.all_gather(text_feat)) + + # B, B*world_size, num_query + sim_q2t = torch.matmul( + image_feat.unsqueeze(1), text_feat_all.unsqueeze(-1)).squeeze() + + # image to text similarity + sim_i2t, _ = sim_q2t.max(-1) + sim_i2t = sim_i2t / self.temp + + # B, B*world_size, num_query + sim_t2q = torch.matmul( + text_feat.unsqueeze(1).unsqueeze(1), + image_feat_all.permute(0, 2, 1)).squeeze() + + # text-image similarity + sim_t2i, _ = sim_t2q.max(-1) + sim_t2i = sim_t2i / self.temp + + rank = dist.get_rank() + bs = images.size(0) + targets = torch.linspace( + rank * bs, rank * bs + bs - 1, bs, dtype=int).to(self.device) + + itc_loss = (F.cross_entropy(sim_i2t, targets, label_smoothing=0.1) + + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)) / 2 + + # prepare for itm + text_input_ids_world = torch.cat(dist.all_gather(text_ids)) + text_attention_mask_world = torch.cat(dist.all_gather(text_attn_mask)) + image_embeds_world = torch.cat(dist.all_gather(image_embeds)) + with torch.no_grad(): + weights_t2i = F.softmax(sim_t2i, dim=1) + 1e-4 + weights_t2i[:, rank * bs:rank * bs + bs].fill_diagonal_(0) + weights_i2t = F.softmax(sim_i2t, dim=1) + 1e-4 + weights_i2t[:, rank * bs:rank * bs + bs].fill_diagonal_(0) + + # select a negative image for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds_world[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + + # select a negative text for each image + text_ids_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_ids_neg.append(text_input_ids_world[neg_idx]) + text_atts_neg.append(text_attention_mask_world[neg_idx]) + + text_ids_neg = torch.stack(text_ids_neg, dim=0) + text_atts_neg = torch.stack(text_atts_neg, dim=0) + + text_ids_all = torch.cat([text_ids, text_ids, text_ids_neg], + dim=0) # pos, pos, neg + text_atts_all = torch.cat( + [text_attn_mask, text_attn_mask, text_atts_neg], + dim=0, + ) + + query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, + -1) + query_atts_itm = torch.ones( + query_tokens_itm.size()[:-1], dtype=torch.long).to(self.device) + attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1) + + image_embeds_all = torch.cat( + [image_embeds, image_embeds_neg, image_embeds], + dim=0) # pos, neg, pos + image_atts_all = torch.ones( + image_embeds_all.size()[:-1], dtype=torch.long).to(self.device) + + output_itm = self.multimodal_backbone.bert( + text_ids_all, + query_embeds=query_tokens_itm, + attention_mask=attention_mask_all, + encoder_hidden_states=image_embeds_all, + encoder_attention_mask=image_atts_all, + return_dict=True, + ) + + vl_embeddings = output_itm.last_hidden_state[:, :query_tokens_itm. 
+ size(1), :] + + # create data samples for the negative (unmatched) pairs + data_samples.extend( + [DataSample(is_matched=False) for _ in range(2 * bs)]) + loss_multimodal = self.multimodal_head.loss((vl_embeddings, ), + data_samples) + + # LM loss + decoder_input_ids = text_ids.clone() + decoder_input_ids[:, 0] = self.tokenizer.bos_token_id + labels = decoder_input_ids.masked_fill( + decoder_input_ids == self.tokenizer.pad_token_id, -100) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_atts = torch.ones( + query_tokens.size()[:-1], dtype=torch.long).to(self.device) + attention_mask = torch.cat([query_atts, text_attn_mask], dim=1) + lm_output = self.multimodal_backbone( + decoder_input_ids, + attention_mask=attention_mask, + past_key_values=query_output.past_key_values, + return_dict=True, + labels=labels, + ) + + return dict( + itc_loss=itc_loss, **loss_multimodal, lm_loss=lm_output.loss) + + def predict_all(self, + feats: Dict[str, torch.Tensor], + data_samples: List[DataSample], + num_images: int = None, + num_texts: int = None, + cal_i2t: bool = True, + cal_t2i: bool = True) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute similarity matrix between images and texts across all ranks. + + Args: + feats (Dict[str, torch.Tensor]): Features from the current rank. + data_samples (List[DataSample]): Data samples from the current + rank. + num_images (int, optional): Number of images to use. + Defaults to None. + num_texts (int, optional): Number of texts to use. + Defaults to None. + cal_i2t (bool, optional): Whether to compute image-to-text + similarity. Defaults to True. + cal_t2i (bool, optional): Whether to compute text-to-image + similarity. Defaults to True. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Image-to-text and text-to-image + similarity matrices. + """ + text_ids = feats['text_ids'] + text_attn_mask = feats['text_attn_mask'] + image_embeds = feats.get('image_embeds', None) + image_feat = feats['image_feat'] + text_feat = feats['text_feat'] + + num_images = num_images or image_feat.size(0) + num_texts = num_texts or text_feat.size(0) + + if not self.fast_match: + image_embeds_all = all_gather_concat(image_embeds)[:num_images] + else: + image_embeds_all = None + image_feat_all = all_gather_concat(image_feat)[:num_images] + text_feat_all = all_gather_concat(text_feat)[:num_texts] + text_ids_all = all_gather_concat(text_ids)[:num_texts] + text_attn_mask_all = all_gather_concat(text_attn_mask)[:num_texts] + + results = [] + if cal_i2t: + result_i2t = self.compute_score_matrix_i2t( + image_feat, + image_embeds, + text_feat_all, + text_ids_all, + text_attn_mask_all, + ) + results.append( + self._get_predictions(result_i2t, data_samples, mode='i2t')) + if cal_t2i: + result_t2i = self.compute_score_matrix_t2i( + image_feat_all, + image_embeds_all, + text_feat, + text_ids, + text_attn_mask, + ) + results.append( + self._get_predictions(result_t2i, data_samples, mode='t2i')) + return tuple(results) + + def compute_score_matrix_i2t(self, img_feats: torch.Tensor, + img_embeds: List[torch.Tensor], + text_feats: torch.Tensor, + text_ids: torch.Tensor, + text_atts: torch.Tensor) -> torch.Tensor: + """Compute the score matrix for image-to-text retrieval. Every image + is compared against all the text features. + + Args: + img_feats (torch.Tensor): The input tensor with shape (M, C). + M stands for the number of samples on a single GPU. + img_embeds (List[torch.Tensor]): Image features from each layer of + the vision backbone.
+ text_feats (torch.Tensor): The input tensor with shape (N, C). + N stands for the number of all samples on all GPUs. + text_ids (torch.Tensor): The input tensor with shape (N, C). + text_atts (torch.Tensor): The input tensor with shape (N, C). + + Returns: + torch.Tensor: Score matrix of image-to-text retrieval. + """ + + # compute i2t sim matrix + # TODO: check correctness + sim_matrix_i2t, _ = (img_feats @ text_feats.t()).max(1) + if self.fast_match: + return sim_matrix_i2t + + score_matrix_i2t = torch.full((img_feats.size(0), text_feats.size(0)), + -100.0).to(self.device) + + for i in track_iter_progress(range(img_feats.size(0))): + sims = sim_matrix_i2t[i] + topk_sim, topk_idx = sims.topk(k=self.topk, dim=0) + # get repeated image embeddings + encoder_output = img_embeds[i].repeat(self.topk, 1, 1) + encoder_att = torch.ones( + encoder_output.size()[:-1], dtype=torch.long).to(self.device) + # query embeds and attention masks + query_tokens = self.query_tokens.expand(encoder_output.shape[0], + -1, -1) + query_atts = torch.ones( + query_tokens.size()[:-1], dtype=torch.long).to(self.device) + attention_mask = torch.cat([query_atts, text_atts[topk_idx]], + dim=1) + output = self.multimodal_backbone.bert( + text_ids[topk_idx], + query_embeds=query_tokens, + attention_mask=attention_mask, + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + ) + score = self.multimodal_head( + (output.last_hidden_state[:, :query_tokens.size(1), :], + ))[:, :, 1].mean(dim=1) + score_matrix_i2t[i, topk_idx] = score + topk_sim + + return score_matrix_i2t + + def compute_score_matrix_t2i(self, img_feats: torch.Tensor, + img_embeds: List[torch.Tensor], + text_feats: torch.Tensor, + text_ids: torch.Tensor, + text_atts: torch.Tensor) -> torch.Tensor: + """Compute the score matrix for text-to-image retrieval. + + Every text is compared against all the image features. + + Args: + img_feats (torch.Tensor): The input tensor with shape (N, C). + N stands for the number of all samples on all GPUs. + img_embeds (List[torch.Tensor]): Image features from each layer of + the vision backbone. + text_feats (torch.Tensor): The input tensor with shape (M, C). + M stands for the number of samples on a single GPU. + text_ids (torch.Tensor): The input tensor with shape (M, C). + text_atts (torch.Tensor): The input tensor with shape (M, C). + + Returns: + torch.Tensor: Score matrix of text-to-image retrieval.
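+ + Examples: + A shape-only sketch of the fast-match similarity with + hypothetical sizes (8 images, 32 query tokens, feature dim 256, + 6 texts): + + >>> import torch + >>> img_feats = torch.randn(8, 32, 256) + >>> text_feats = torch.randn(6, 256) + >>> sim_i2t, _ = (img_feats @ text_feats.t()).max(1) + >>> sim_i2t.t().shape # text-to-image + torch.Size([6, 8])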
+ """ + + # compute t2i sim matrix + # TODO: check correctness + sim_matrix_i2t, _ = (img_feats @ text_feats.t()).max(1) + sim_matrix_t2i = sim_matrix_i2t.t() + if self.fast_match: + return sim_matrix_i2t + + score_matrix_t2i = torch.full((text_feats.size(0), img_feats.size(0)), + -100.0).to(self.device) + + for i in track_iter_progress(range(text_feats.size(0))): + sims = sim_matrix_t2i[i] + topk_sim, topk_idx = sims.topk(k=self.topk, dim=0) + # get topk image embeddings + encoder_output = img_embeds[topk_idx] + encoder_att = torch.ones( + encoder_output.size()[:-1], dtype=torch.long).to(self.device) + # get query embeds and attention masks + query_tokens = self.query_tokens.expand(encoder_output.shape[0], + -1, -1) + query_atts = torch.ones( + query_tokens.size()[:-1], dtype=torch.long).to(self.device) + attention_mask = torch.cat( + [query_atts, text_atts[i].repeat(self.topk, 1)], dim=1) + output = self.multimodal_backbone.bert( + text_ids[i].repeat(self.topk, 1), + query_embeds=query_tokens, + attention_mask=attention_mask, + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + ) + score = self.multimodal_head( + (output.last_hidden_state[:, :query_tokens.size(1), :], + ))[:, :, 1].mean(dim=1) + score_matrix_t2i[i, topk_idx] = score + topk_sim + + return score_matrix_t2i diff --git a/mmpretrain/models/multimodal/blip2/modeling_opt.py b/mmpretrain/models/multimodal/blip2/modeling_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..7cde0d76a2079a610bd71ed034c0c88940244e76 --- /dev/null +++ b/mmpretrain/models/multimodal/blip2/modeling_opt.py @@ -0,0 +1,1083 @@ +# flake8: noqa +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch OPT model.""" +import random +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.opt.configuration_opt import OPTConfig +from transformers.utils import (add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, logging, + replace_return_docstrings) + +from mmpretrain.models.utils import register_hf_model + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = 'facebook/opt-350m' +_CONFIG_FOR_DOC = 'OPTConfig' +_TOKENIZER_FOR_DOC = 'GPT2Tokenizer' + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] + +OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + 'facebook/opt-125m', + 'facebook/opt-350m', + 'facebook/opt-1.3b', + 'facebook/opt-2.7b', + 'facebook/opt-6.7b', + 'facebook/opt-13b', + 'facebook/opt-30b', + # See all OPT models at https://huggingface.co/models?filter=opt +] + + +def _make_causal_mask(input_ids_shape: torch.Size, + dtype: torch.dtype, + past_key_values_length: int = 0): + """Make causal mask used for bi-directional self-attention.""" + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], + dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, + tgt_len + past_key_values_length) + + +def _expand_mask(mask: torch.Tensor, + dtype: torch.dtype, + tgt_len: Optional[int] = None): + """Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, + src_seq_len]`.""" + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, + src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), + torch.finfo(dtype).min) + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """This module learns positional embeddings up to a fixed maximum size.""" + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + attention_mask = attention_mask.long() + + # create positions depending on attention_mask + positions = ( + torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * + attention_mask).long() - 1 + + # cut positions if `past_key_values_length` is > 0 + positions = positions[:, past_key_values_length:] + + return super().forward(positions + self.offset) + + +class OPTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper.""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}' + f' and `num_heads`: {num_heads}).') + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return (tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous()) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel.""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
+ # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, + bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f'Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is' + f' {attn_weights.size()}') + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f'Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}' + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + attention_mask) + attn_weights = torch.max( + attn_weights, + torch.tensor(torch.finfo(attn_weights.dtype).min)) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, + src_len) + + # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 + if attn_weights.dtype == torch.float16: + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads, ): + raise ValueError( + f'Head mask for a single layer should be of size {(self.num_heads,)}, but is' + f' {layer_head_mask.size()}') + attn_weights = layer_head_mask.view( + 1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, + src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, + src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, + tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, + tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, + self.head_dim): + raise ValueError( + f'`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is' + f' {attn_output.size()}') + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, + self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class OPTDecoderLayer(nn.Module): + + def __init__(self, config: OPTConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = OPTAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.do_layer_norm_before = config.do_layer_norm_before + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim) + self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, + torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + # 125m, 1.3B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.3B, ..., 175B applies layer norm BEFORE the feed-forward block + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER the feed-forward block + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (self_attn_weights, ) + + if use_cache: + outputs += (present_key_value, ) + + return outputs + + +OPT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage + and behavior. + + Parameters: + config ([`OPTConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+""" + + +@add_start_docstrings( + 'The bare OPT Model outputting raw hidden-states without any specific head on top.', + OPT_START_DOCSTRING, +) +class OPTPreTrainedModel(PreTrainedModel): + + config_class = OPTConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['OPTDecoderLayer'] + _keys_to_ignore_on_load_unexpected = [r'decoder\.version'] + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (OPTDecoder)): + module.gradient_checkpointing = value + + +OPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class OPTDecoder(OPTPreTrainedModel): + """Transformer decoder consisting of *config.num_hidden_layers* layers. + Each layer is a [`OPTDecoderLayer`] + + Args: + config: OPTConfig + """ + + def __init__(self, config: OPTConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.word_embed_proj_dim, + self.padding_idx) + self.embed_positions = OPTLearnedPositionalEmbedding( + config.max_position_embeddings, config.hidden_size) + + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = nn.Linear( + config.hidden_size, config.word_embed_proj_dim, bias=False) + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = nn.Linear( + config.word_embed_proj_dim, config.hidden_size, bias=False) + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm(config.hidden_size) + else: + self.final_layer_norm = None + + self.layers = nn.ModuleList( + [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, + inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + past_key_values_length=past_key_values_length, + ).to(inputs_embeds.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, + tgt_len=input_shape[-1]).to(inputs_embeds.device) + 
combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else + expanded_attn_mask + combined_attention_mask) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
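+ query_embeds (`torch.FloatTensor` of shape `(batch_size, query_length, hidden_size)`, *optional*): + BLIP-2 specific addition to the upstream OPT decoder: learned query embeddings that, when given, + are prepended to `inputs_embeds` before the position embeddings are added and decoding starts.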
+ """ + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time' + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + 'You have to specify either decoder_input_ids or decoder_inputs_embeds' + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if query_embeds is not None: + inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1) + input_shape = inputs_embeds.size()[:-1] + else: + input_shape = (batch_size, seq_length) + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + inputs_embeds.shape[:2], + dtype=torch.bool, + device=inputs_embeds.device) + pos_embeds = self.embed_positions(attention_mask, + past_key_values_length) + + # embed positions + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ['head_mask']): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f'The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for' + f' {head_mask.size()[0]}.') + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None) + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' 
+ ) + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + ) + else: + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=(head_mask[idx] + if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += ( + layer_outputs[2 if output_attentions else 1], ) + + if output_attentions: + all_self_attns += (layer_outputs[1], ) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v for v in + [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +@add_start_docstrings( + 'The bare OPT Model outputting raw hidden-states without any specific head on top.', + OPT_START_DOCSTRING, +) +class OPTModel(OPTPreTrainedModel): + + def __init__(self, config: OPTConfig): + super().__init__(config) + self.decoder = OPTDecoder(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + 
inputs_embeds=inputs_embeds, + query_embeds=query_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPast( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + ) + + +@register_hf_model() +class OPTForCausalLM(OPTPreTrainedModel): + _keys_to_ignore_on_load_missing = [r'lm_head.weight'] + + def __init__(self, config): + super().__init__(config) + self.model = OPTModel(config) + + # the lm_head weight is automatically tied to the embed tokens weight + self.lm_head = nn.Linear( + config.word_embed_proj_dim, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + reduction: Optional[str] = 'mean', + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import GPT2Tokenizer, OPTForCausalLM
+
+        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+        >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
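+
+        >>> # Note: unlike the upstream OPT head, this forward() also accepts
+        >>> # `reduction='none'`, in which case the returned loss is one
+        >>> # (token-summed) value per sample instead of the batch mean.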
+ ```""" + + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + query_embeds=query_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + if labels is not None: + logits = logits[:, -labels.size(1):, :] + + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction=reduction) + loss = loss_fct( + shift_logits.view(-1, self.config.vocab_size), + shift_labels.view(-1)) + if reduction == 'none': + loss = loss.view(shift_logits.size(0), -1).sum(1) + + if not return_dict: + output = (logits, ) + outputs[1:] + return (loss, ) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids=None, + inputs_embeds=None, + query_embeds=None, + past_key_values=None, + attention_mask=None, + use_cache=None, + **kwargs, + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + if input_ids is not None: + attention_mask = input_ids.new_ones(input_ids.shape) + if past_key_values: + input_ids = input_ids[:, -1:] + query_embeds = None + # first step, decoder_cached_states are empty + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + model_inputs = {'input_ids': input_ids} + + model_inputs.update({ + 'query_embeds': query_embeds, + 'attention_mask': attention_mask, + 'past_key_values': past_key_values, + 'use_cache': use_cache, + }) + return model_inputs + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past diff --git a/mmpretrain/models/multimodal/chinese_clip/__init__.py b/mmpretrain/models/multimodal/chinese_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..460e9e6a6be748113df029ad76bc0934ab7704d3 --- /dev/null +++ b/mmpretrain/models/multimodal/chinese_clip/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
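An editorial aside before the Chinese-CLIP package files: the `_reorder_cache` hook defined just above is the beam-search bookkeeping that keeps cached key/value tensors aligned with surviving hypotheses. A minimal, self-contained sketch of that operation on dummy tensors (all shapes and beam indices below are illustrative, not tied to any particular OPT checkpoint):

```python
import torch

# A fake cache: 2 layers, each holding a (key, value) pair of shape
# (num_beams, num_heads, seq_len, head_dim).
past = tuple(
    (torch.randn(4, 2, 5, 8), torch.randn(4, 2, 5, 8)) for _ in range(2))

# Suppose beam search keeps three continuations of beam 1 and one of beam 3.
beam_idx = torch.tensor([1, 1, 1, 3])

# This mirrors what `_reorder_cache` does: re-index the beam dimension of
# every cached tensor in every layer.
reordered = tuple(
    tuple(state.index_select(0, beam_idx) for state in layer)
    for layer in past)

# Slot 0 of the reordered cache now holds what used to be beam 1.
assert torch.equal(reordered[0][0][0], past[0][0][1])
```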
+from .bert import BertModelCN
+from .chinese_clip import ChineseCLIP, ModifiedResNet
+
+__all__ = ['ChineseCLIP', 'ModifiedResNet', 'BertModelCN']
diff --git a/mmpretrain/models/multimodal/chinese_clip/bert.py b/mmpretrain/models/multimodal/chinese_clip/bert.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e8dc7322a9aaddb0f5e02f8b70597ba08a8b925
--- /dev/null
+++ b/mmpretrain/models/multimodal/chinese_clip/bert.py
@@ -0,0 +1,263 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+
+# flake8: noqa
+import math
+
+import torch
+from torch import nn
+from torch.utils.checkpoint import checkpoint
+
+try:
+    from transformers.models.bert.configuration_bert import BertConfig
+except ImportError:
+    BertConfig = None
+
+from mmpretrain.registry import MODELS
+from ..blip.language_model import BertAttention, BertIntermediate, BertOutput
+
+
+def gelu(x):
+    """Original Implementation of the gelu activation function in Google Bert
+    repo when initially created.
+
+    For information: OpenAI GPT's gelu is slightly different (and gives
+    slightly different results):
+    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+    Also see https://arxiv.org/abs/1606.08415
+    """  # noqa
+    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+
+def gelu_new(x):
+    """Implementation of the gelu activation function currently in Google Bert
+    repo (identical to OpenAI GPT) https://arxiv.org/abs/1606.08415."""
+    return 0.5 * x * (1 + torch.tanh(
+        math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+
+
+def swish(x):
+    return x * torch.sigmoid(x)
+
+
+ACT2FN = {
+    'gelu': gelu,
+    'relu': torch.nn.functional.relu,
+    'swish': swish,
+    'gelu_new': gelu_new
+}
+
+
+class BertEmbeddings(nn.Module):
+    """Construct the embeddings from word, position and token_type
+    embeddings."""
+
+    def __init__(self, config):
+        super(BertEmbeddings, self).__init__()
+        self.word_embeddings = nn.Embedding(
+            config.vocab_size, config.hidden_size, padding_idx=0)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
+                                                config.hidden_size)
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
+                                                  config.hidden_size)
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model
+        # variable name and be able to load any TensorFlow checkpoint file
+        self.LayerNorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, input_ids, token_type_ids=None, position_ids=None):
+        seq_length = input_ids.size(1)
+        if position_ids is None:
+            position_ids = torch.arange(
+                seq_length, dtype=torch.long, device=input_ids.device)
+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros_like(input_ids)
+
+        words_embeddings = self.word_embeddings(input_ids)
+        position_embeddings = self.position_embeddings(position_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = words_embeddings + position_embeddings \
+            + token_type_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class BertLayer(nn.Module):
+
+    def __init__(self, config):
+        super(BertLayer, self).__init__()
+        self.attention = BertAttention(config)
+        self.intermediate = BertIntermediate(config)
+        self.output = BertOutput(config)
+
+    def forward(self, hidden_states, attention_mask=None, head_mask=None):
+        attention_outputs = self.attention(hidden_states, attention_mask,
+                                           head_mask)
+        attention_output = attention_outputs[0]
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        outputs = (layer_output, ) + attention_outputs[
+            1:]  # add attentions if we output them
+        if len(outputs) == 1:
+            return outputs[0]
+        return outputs
+
+
+class BertEncoder(nn.Module):
+
+    def __init__(self, config):
+        super(BertEncoder, self).__init__()
+        self.output_attentions = config.output_attentions
+        self.output_hidden_states = config.output_hidden_states
+        self.grad_checkpointing = False
+        self.layer = nn.ModuleList(
+            [BertLayer(config) for _ in range(config.num_hidden_layers)])
+
+    def forward(self, hidden_states, attention_mask=None, head_mask=None):
+        all_hidden_states = ()
+        all_attentions = ()
+        for i, layer_module in enumerate(self.layer):
+            if self.output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states, )
+
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                layer_outputs = checkpoint(layer_module, hidden_states,
+                                           attention_mask, head_mask[i])
+            else:
+                layer_outputs = layer_module(hidden_states, attention_mask,
+                                             head_mask[i])
+            if not isinstance(layer_outputs, tuple):
+                layer_outputs = (layer_outputs, )
+            hidden_states = layer_outputs[0]
+
+            if self.output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1], )
+
+        # Add last layer
+        if self.output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states, )
+
+        outputs = (hidden_states, )
+        if self.output_hidden_states:
+            outputs = outputs + (all_hidden_states, )
+        if self.output_attentions:
+            outputs = outputs + (all_attentions, )
+        # last-layer hidden state, (all hidden states), (all attentions)
+        return outputs
+
+
+class BertPreTrainedModel(nn.Module):
+    base_model_prefix = 'bert'
+
+    def __init__(self, config):
+        super(BertPreTrainedModel, self).__init__()
+        self.config = config
+
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            # Slightly different from the TF version
+            # which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(
+                mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+
+@MODELS.register_module()
+class BertModelCN(BertPreTrainedModel):
+    """The BERT model implementation for Chinese CLIP."""
+
+    def __init__(self, config):
+        config = BertConfig.from_dict(config)
+        super(BertModelCN, self).__init__(config)
+
+        self.embeddings = BertEmbeddings(config)
+        self.encoder = BertEncoder(config)
+
+        self.apply(self._init_weights)
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        if enable:
+            assert not self.config.output_attentions, \
+                'Grad checkpointing currently conflicts with ' \
+                'output_attentions for BertEncoder, ' \
+                'please set it to False in BertConfig'
+
+        self.encoder.grad_checkpointing = enable
+
+    def forward(self,
+                input_ids,
+                attention_mask=None,
+                token_type_ids=None,
+                position_ids=None,
+                head_mask=None):
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros_like(input_ids)
+
+        # We create a 4D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, to_seq_length]
+        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+        # this attention mask is simpler than the triangular masking of causal attention
+        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = extended_attention_mask.to(
+            dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(
+                    -1).unsqueeze(-1)
+                head_mask = head_mask.expand(self.config.num_hidden_layers, -1,
+                                             -1, -1, -1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(
+                    -1)  # We can specify head_mask for each layer
+            head_mask = head_mask.to(dtype=next(self.parameters(
+            )).dtype)  # switch to float if needed, for fp16 compatibility
+        else:
+            head_mask = [None] * self.config.num_hidden_layers
+
+        embedding_output = self.embeddings(
+            input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids)
+        encoder_outputs = self.encoder(
+            embedding_output, extended_attention_mask, head_mask=head_mask)
+        sequence_output = encoder_outputs[0]
+        # pooled_output = self.pooler(sequence_output)
+        pooled_output = None
+
+        # add hidden_states and attentions if they are here
+        outputs = (
+            sequence_output,
+            pooled_output,
+        ) + encoder_outputs[1:]
+
+        # sequence_output, pooled_output, (hidden_states), (attentions)
+        return outputs
diff --git a/mmpretrain/models/multimodal/chinese_clip/chinese_clip.py b/mmpretrain/models/multimodal/chinese_clip/chinese_clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..40af5643602685be4d0e37331609bdecae184de9
--- /dev/null
+++ b/mmpretrain/models/multimodal/chinese_clip/chinese_clip.py
@@ -0,0 +1,446 @@
+# Copyright (c) OpenMMLab. All rights reserved.
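Before the main Chinese-CLIP model file, it is worth unpacking the `(1.0 - extended_attention_mask) * -10000.0` idiom used in `BertModelCN.forward` above: instead of zeroing out attention probabilities, the padding mask becomes an additive bias on the raw attention scores, so masked key positions receive effectively zero probability after the softmax. A toy demonstration (the scores below are random placeholders, not model outputs):

```python
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0]])        # last key position is padding
extended = attention_mask[:, None, None, :].float()  # (B, 1, 1, L), broadcastable
extended = (1.0 - extended) * -10000.0               # 0 where attended, -1e4 where masked

scores = torch.randn(1, 2, 4, 4)                     # (B, heads, query_len, key_len)
probs = F.softmax(scores + extended, dim=-1)

# Attention paid to the padded key position is numerically zero everywhere.
print(probs[..., -1].max())
```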
+from collections import OrderedDict +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.model import BaseModel, BaseModule +from torch import nn + +from mmpretrain.datasets.categories import CIFAR100_CATEGORIES_CN +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from mmpretrain.utils import track_on_main_process +from .utils import OPENAI_PROMPT + +PROTOTYPE_MAP = {'cifar100': CIFAR100_CATEGORIES_CN} +PROMPT_MAP = {'openai': OPENAI_PROMPT} + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + self.downsample = nn.Sequential( + OrderedDict([('-1', nn.AvgPool2d(stride)), + ('0', + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False)), + ('1', nn.BatchNorm2d(planes * self.expansion))])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + + def __init__(self, + spacial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], + x.shape[2] * x.shape[3]).permute(2, 0, + 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + + return x[0] + + +@MODELS.register_module() +class ModifiedResNet(BaseModule): + """A modified ResNet contains the following changes: + + - Apply deep stem with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is + prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ # noqa + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth: int = 50, + base_channels: int = 64, + input_size: int = 224, + num_attn_heads: int = 32, + output_dim: int = 1024, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + self.input_size = input_size + self.block, stage_blocks = self.arch_settings[depth] + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, + base_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(base_channels // 2) + self.conv2 = nn.Conv2d( + base_channels // 2, + base_channels // 2, + kernel_size=3, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(base_channels // 2) + self.conv3 = nn.Conv2d( + base_channels // 2, + base_channels, + kernel_size=3, + padding=1, + bias=False) + self.bn3 = nn.BatchNorm2d(base_channels) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + # this is a *mutable* variable used during construction + self._inplanes = base_channels + self.layer1 = self._make_layer(base_channels, stage_blocks[0]) + self.layer2 = self._make_layer( + base_channels * 2, stage_blocks[1], stride=2) + self.layer3 = self._make_layer( + base_channels * 4, stage_blocks[2], stride=2) + self.layer4 = self._make_layer( + base_channels * 8, stage_blocks[3], stride=2) + + embed_dim = base_channels * 32 + self.attnpool = AttentionPool2d(input_size // 32, embed_dim, + num_attn_heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +@MODELS.register_module() +class ChineseCLIP(BaseModel): + """The implementation of `ChineseCLIP `_. + + Args: + vision_backbone (dict): Config dict for vision backbone. + text_backbone (dict): Config dict for text backbone. + tokenizer (dict): Config dict for text tokenizer. + proj_dim (int): Projection dimension for similarity computation. + text_prototype (str): Text prototype, which can be a key in + `PROTOTYPE_MAP` or list of text. + text_prompt (str): The prompt for text prototype. Defaults to 'openai'. + context_length (int): The context length to use. Defaults to 52. + data_preprocessor (Union[dict, nn.Module], optional): The config for + preprocessing input data. If None or no specified type, it will use + "MultiModalDataPreprocessor" as type. + See :class:`MultiModalDataPreprocessor` for more details. + Defaults to None. + init_cfg (dict, optional): The config to control the initialization. + Defaults to None. 
+ """ + + def __init__(self, + vision_backbone: dict, + text_backbone: dict, + tokenizer: dict, + proj_dim: int, + text_prototype: Union[str, List[str]], + text_prompt: str = 'openai', + context_length: int = 52, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + self.vision_backbone = MODELS.build(vision_backbone) + self.text_backbone = MODELS.build(text_backbone) + + if not isinstance(self.vision_backbone, ModifiedResNet): + self.vision_projection = nn.Parameter( + torch.empty(self.vision_backbone.embed_dims, proj_dim)) + text_hidden_size = text_backbone['config']['hidden_size'] + self.text_projection = nn.Parameter( + torch.empty(text_hidden_size, proj_dim)) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.tokenizer = TOKENIZER.build(tokenizer) + self.context_length = context_length + + # for zero-shot classification + if isinstance(text_prototype, + str) and text_prototype in PROTOTYPE_MAP.keys(): + self.prototype = PROTOTYPE_MAP[text_prototype] + else: + self.prototype = text_prototype + self.text_prototype_embeds = None + + self.prompt = PROMPT_MAP[text_prompt] + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[list] = None, + mode: str = 'predict', + **kwargs, + ): + """The unified entry for a forward process in both training and test. + The method accepts the following modes: + + - "predict": Forward and return a list of data samples contain the + predict results. + + Args: + images (torch.Tensor): the preprocessed image tensor of shape + ``(N, C, H, W)``. + data_samples (List[DataSample], optional): The annotation data + of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to 'predict'. + """ + if mode == 'predict': + return self.predict(images, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor: + """The function to extract image latent features.""" + if isinstance(self.vision_backbone, ModifiedResNet): + return self.vision_backbone(images) + return self.vision_backbone(images)[-1] @ self.vision_projection + + def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor: + """The function to extract text latent features.""" + pad_index = self.tokenizer.vocab['[PAD]'] + attn_mask = texts.ne(pad_index) + # [batch_size, seq_length, hidden_size] + x = self.text_backbone(texts, attention_mask=attn_mask)[0] + return x[:, 0, :] @ self.text_projection + + def extract_feat( + self, images: torch.Tensor, + texts: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """The function to extract image and text latent features, the input + image or text can not both be None.""" + + assert images is not None or texts is not None, \ + 'text and image cannot both be None!' 
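+        # Three call patterns are handled below: texts-only returns raw text
+        # features, images-only returns raw image features, and when both are
+        # given, each embedding is L2-normalized so that the dot products
+        # computed in `compute_similarity` are cosine similarities.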
+        if images is None:
+            return self.extract_text_feat(texts)
+        elif texts is None:
+            return self.extract_image_feat(images)
+
+        image_features = self.extract_image_feat(images)
+        text_features = self.extract_text_feat(texts)
+
+        image_features = image_features / image_features.norm(
+            dim=-1, keepdim=True)
+        text_features = text_features / text_features.norm(
+            dim=-1, keepdim=True)
+
+        return image_features, text_features
+
+    def compute_similarity(self, images, texts):
+        """Extract images and texts features and compute cosine similarity."""
+        image_features, text_features = self.extract_feat(
+            images=images, texts=texts)
+
+        # cosine similarity as logits
+        logit_scale = self.logit_scale.exp()
+        logits_per_image = logit_scale * image_features @ text_features.t()
+        logits_per_text = logits_per_image.t()
+
+        # shape (N, N)
+        return logits_per_image, logits_per_text
+
+    def predict(self,
+                images: torch.Tensor,
+                data_samples: DataSample = None) -> DataSample:
+        """Predict the classes of the input images.
+
+        The prediction is for zero-shot classification and the text prototypes
+        will be prepared in this function.
+
+        Args:
+            images (torch.Tensor): The input images.
+            data_samples (DataSample): The data samples with information from
+                dataset.
+
+        Returns:
+            DataSample: The results of prediction.
+        """
+
+        if self.text_prototype_embeds is None:
+            self.prepare_text_prototype(device=images.device)
+
+        image_features = self.extract_image_feat(images=images)
+        image_features /= image_features.norm(dim=-1, keepdim=True)
+
+        # cosine similarity as logits
+        logits_per_image = image_features @ self.text_prototype_embeds.to(
+            image_features.device) * self.logit_scale.exp()
+
+        pred_scores = F.softmax(logits_per_image, dim=1)
+        pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()
+
+        out_data_samples = []
+        if data_samples is None:
+            data_samples = [None for _ in range(pred_scores.size(0))]
+
+        for data_sample, score, label in zip(data_samples, pred_scores,
+                                             pred_labels):
+            if data_sample is None:
+                data_sample = DataSample()
+
+            data_sample.set_pred_score(score).set_pred_label(label)
+            out_data_samples.append(data_sample)
+        return out_data_samples
+
+    def prepare_text_prototype(self, device) -> None:
+        """The function to prepare text prototypes with prompt."""
+        class_embeddings = []
+        for classname in track_on_main_process(self.prototype,
+                                               'Prepare text prototype...'):
+            # format with class
+            texts = [prompt(classname) for prompt in self.prompt]
+            tokenized_texts = self.tokenize(texts)
+            class_features = self.extract_text_feat(tokenized_texts.to(device))
+            class_features /= class_features.norm(dim=-1, keepdim=True)
+            class_feature = class_features.mean(dim=0)
+            class_feature /= class_feature.norm()
+            class_embeddings.append(class_feature)
+        self.text_prototype_embeds = torch.stack(
+            class_embeddings, dim=1).to(device)
+
+    def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor:
+        """Returns the tokenized representation of the given input string(s),
+        padded or truncated to ``self.context_length`` (52 by default).
+
+        Args:
+            texts (Union[str, List[str]]): An input string or a list of input
+                strings to tokenize.
+
+        Returns:
+            torch.Tensor: Resulting tokens.
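+
+        Example (illustrative only; the exact ids depend on the tokenizer):
+
+            >>> tokens = self.tokenize(['一张猫的照片'])
+            >>> tokens.shape  # zero-padded to (num_texts, context_length)
+            torch.Size([1, 52])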
+ """ + if isinstance(texts, str): + texts = [texts] + + all_tokens = [] + for text in texts: + # adapt the text to Chinese BERT vocab + text = text.lower().replace('“', "\"").replace('”', "\"") + + # add special tokens + all_tokens.append( + [self.tokenizer.vocab['[CLS]']] + + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(text))[:self.context_length - 2] + + [self.tokenizer.vocab['[SEP]']]) + + result = torch.zeros( + len(all_tokens), self.context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + assert len(tokens) <= self.context_length + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/mmpretrain/models/multimodal/chinese_clip/utils.py b/mmpretrain/models/multimodal/chinese_clip/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6964722bd3dbb05a6a59a1dc2c57c0a6e8692c31 --- /dev/null +++ b/mmpretrain/models/multimodal/chinese_clip/utils.py @@ -0,0 +1,186 @@ +# Copyright (c) OpenMMLab. All rights reserved. +OPENAI_PROMPT = [ + lambda c: f'{c}的照片', + lambda c: f'质量差的{c}的照片', + lambda c: f'许多{c}的照片', + lambda c: f'{c}的雕塑', + lambda c: f'难以看到{c}的照片', + lambda c: f'{c}的低分辨率照片', + lambda c: f'{c}的渲染', + lambda c: f'涂鸦{c}', + lambda c: f'{c}的糟糕照片', + lambda c: f'{c}的裁剪照片', + lambda c: f'{c}的纹身', + lambda c: f'{c}的刺绣照片', + lambda c: f'很难看到{c}的照片', + lambda c: f'{c}的明亮照片', + lambda c: f'一张干净的{c}的照片', + lambda c: f'一张包含{c}的照片', + lambda c: f'{c}的深色照片', + lambda c: f'{c}的手绘画', + lambda c: f'我的{c}的照片', + lambda c: f'不自然的{c}的照片', + lambda c: f'一张酷的{c}的照片', + lambda c: f'{c}的特写照片', + lambda c: f'{c}的黑白照片', + lambda c: f'一幅{c}的画', + lambda c: f'一幅{c}的绘画', + lambda c: f'一张{c}的像素照片', + lambda c: f'{c}的雕像', + lambda c: f'一张{c}的明亮照片', + lambda c: f'{c}的裁剪照片', + lambda c: f'人造的{c}的照片', + lambda c: f'一张关于{c}的照片', + lambda c: f'损坏的{c}的jpeg照片', + lambda c: f'{c}的模糊照片', + lambda c: f'{c}的相片', + lambda c: f'一张{c}的好照片', + lambda c: f'{c}的渲染照', + lambda c: f'视频游戏中的{c}', + lambda c: f'一张{c}的照片', + lambda c: f'{c}的涂鸦', + lambda c: f'{c}的近距离照片', + lambda c: f'{c}的折纸', + lambda c: f'{c}在视频游戏中', + lambda c: f'{c}的草图', + lambda c: f'{c}的涂鸦照', + lambda c: f'{c}的折纸形状', + lambda c: f'低分辨率的{c}的照片', + lambda c: f'玩具{c}', + lambda c: f'{c}的副本', + lambda c: f'{c}的干净的照片', + lambda c: f'一张大{c}的照片', + lambda c: f'{c}的重现', + lambda c: f'一张漂亮的{c}的照片', + lambda c: f'一张奇怪的{c}的照片', + lambda c: f'模糊的{c}的照片', + lambda c: f'卡通{c}', + lambda c: f'{c}的艺术作品', + lambda c: f'{c}的素描', + lambda c: f'刺绣{c}', + lambda c: f'{c}的像素照', + lambda c: f'{c}的拍照', + lambda c: f'{c}的损坏的照片', + lambda c: f'高质量的{c}的照片', + lambda c: f'毛绒玩具{c}', + lambda c: f'漂亮的{c}的照片', + lambda c: f'小{c}的照片', + lambda c: f'照片是奇怪的{c}', + lambda c: f'漫画{c}', + lambda c: f'{c}的艺术照', + lambda c: f'{c}的图形', + lambda c: f'大{c}的照片', + lambda c: f'黑白的{c}的照片', + lambda c: f'{c}毛绒玩具', + lambda c: f'一张{c}的深色照片', + lambda c: f'{c}的摄影图', + lambda c: f'{c}的涂鸦照', + lambda c: f'玩具形状的{c}', + lambda c: f'拍了{c}的照片', + lambda c: f'酷酷的{c}的照片', + lambda c: f'照片里的小{c}', + lambda c: f'{c}的刺青', + lambda c: f'{c}的可爱的照片', + lambda c: f'一张{c}可爱的照片', + lambda c: f'{c}可爱图片', + lambda c: f'{c}酷炫图片', + lambda c: f'一张{c}的酷炫的照片', + lambda c: f'一张{c}的酷炫图片', + lambda c: f'这是{c}', + lambda c: f'{c}的好看照片', + lambda c: f'一张{c}的好看的图片', + lambda c: f'{c}的好看图片', + lambda c: f'{c}的照片。', + lambda c: f'质量差的{c}的照片。', + lambda c: f'许多{c}的照片。', + lambda c: f'{c}的雕塑。', + lambda c: f'难以看到{c}的照片。', + lambda c: f'{c}的低分辨率照片。', + lambda c: f'{c}的渲染。', + lambda c: f'涂鸦{c}。', + lambda c: f'{c}的糟糕照片。', + lambda c: f'{c}的裁剪照片。', + lambda 
c: f'{c}的纹身。', + lambda c: f'{c}的刺绣照片。', + lambda c: f'很难看到{c}的照片。', + lambda c: f'{c}的明亮照片。', + lambda c: f'一张干净的{c}的照片。', + lambda c: f'一张包含{c}的照片。', + lambda c: f'{c}的深色照片。', + lambda c: f'{c}的手绘画。', + lambda c: f'我的{c}的照片。', + lambda c: f'不自然的{c}的照片。', + lambda c: f'一张酷的{c}的照片。', + lambda c: f'{c}的特写照片。', + lambda c: f'{c}的黑白照片。', + lambda c: f'一幅{c}的画。', + lambda c: f'一幅{c}的绘画。', + lambda c: f'一张{c}的像素照片。', + lambda c: f'{c}的雕像。', + lambda c: f'一张{c}的明亮照片。', + lambda c: f'{c}的裁剪照片。', + lambda c: f'人造的{c}的照片。', + lambda c: f'一张关于{c}的照片。', + lambda c: f'损坏的{c}的jpeg照片。', + lambda c: f'{c}的模糊照片。', + lambda c: f'{c}的相片。', + lambda c: f'一张{c}的好照片。', + lambda c: f'{c}的渲染照。', + lambda c: f'视频游戏中的{c}。', + lambda c: f'一张{c}的照片。', + lambda c: f'{c}的涂鸦。', + lambda c: f'{c}的近距离照片。', + lambda c: f'{c}的折纸。', + lambda c: f'{c}在视频游戏中。', + lambda c: f'{c}的草图。', + lambda c: f'{c}的涂鸦照。', + lambda c: f'{c}的折纸形状。', + lambda c: f'低分辨率的{c}的照片。', + lambda c: f'玩具{c}。', + lambda c: f'{c}的副本。', + lambda c: f'{c}的干净的照片。', + lambda c: f'一张大{c}的照片。', + lambda c: f'{c}的重现。', + lambda c: f'一张漂亮的{c}的照片。', + lambda c: f'一张奇怪的{c}的照片。', + lambda c: f'模糊的{c}的照片。', + lambda c: f'卡通{c}。', + lambda c: f'{c}的艺术作品。', + lambda c: f'{c}的素描。', + lambda c: f'刺绣{c}。', + lambda c: f'{c}的像素照。', + lambda c: f'{c}的拍照。', + lambda c: f'{c}的损坏的照片。', + lambda c: f'高质量的{c}的照片。', + lambda c: f'毛绒玩具{c}。', + lambda c: f'漂亮的{c}的照片。', + lambda c: f'小{c}的照片。', + lambda c: f'照片是奇怪的{c}。', + lambda c: f'漫画{c}。', + lambda c: f'{c}的艺术照。', + lambda c: f'{c}的图形。', + lambda c: f'大{c}的照片。', + lambda c: f'黑白的{c}的照片。', + lambda c: f'{c}毛绒玩具。', + lambda c: f'一张{c}的深色照片。', + lambda c: f'{c}的摄影图。', + lambda c: f'{c}的涂鸦照。', + lambda c: f'玩具形状的{c}。', + lambda c: f'拍了{c}的照片。', + lambda c: f'酷酷的{c}的照片。', + lambda c: f'照片里的小{c}。', + lambda c: f'{c}的刺青。', + lambda c: f'{c}的可爱的照片。', + lambda c: f'一张{c}可爱的照片。', + lambda c: f'{c}可爱图片。', + lambda c: f'{c}酷炫图片。', + lambda c: f'一张{c}的酷炫的照片。', + lambda c: f'一张{c}的酷炫图片。', + lambda c: f'这是{c}。', + lambda c: f'{c}的好看照片。', + lambda c: f'一张{c}的好看的图片。', + lambda c: f'{c}的好看图片。', + lambda c: f'一种叫{c}的花的照片', + lambda c: f'一种叫{c}的食物的照片', + lambda c: f'{c}的卫星照片', +] diff --git a/mmpretrain/models/multimodal/clip/__init__.py b/mmpretrain/models/multimodal/clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7a117ea7ca57ce30d7ad304103220f7af84e7c0 --- /dev/null +++ b/mmpretrain/models/multimodal/clip/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..clip.clip import CLIP, CLIPZeroShot +from ..clip.clip_transformer import CLIPProjection, CLIPTransformer + +__all__ = ['CLIP', 'CLIPZeroShot', 'CLIPTransformer', 'CLIPProjection'] diff --git a/mmpretrain/models/multimodal/clip/clip.py b/mmpretrain/models/multimodal/clip/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..b509a63b3be964000232a006da33243d9f93f84b --- /dev/null +++ b/mmpretrain/models/multimodal/clip/clip.py @@ -0,0 +1,364 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
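A note on how the prompt list above is consumed: zero-shot class prototypes are built by formatting each class name with every template, encoding the prompts, and averaging the L2-normalized embeddings into a single re-normalized vector (this is what the `prepare_text_prototype` methods in the model files do). A self-contained sketch of that ensembling, where the seeded random `encode` is only a stand-in for a real text encoder such as `extract_text_feat`:

```python
import torch

def encode(texts):
    # Placeholder for a real text encoder; returns one 512-d embedding
    # per prompt (deterministic here only for reproducibility).
    g = torch.Generator().manual_seed(0)
    return torch.randn(len(texts), 512, generator=g)

def class_prototype(classname, templates):
    feats = encode([t(classname) for t in templates])
    feats = feats / feats.norm(dim=-1, keepdim=True)  # normalize each prompt
    proto = feats.mean(dim=0)                         # average over the ensemble
    return proto / proto.norm()                       # re-normalize the mean

templates = [lambda c: f'{c}的照片', lambda c: f'一张{c}的好照片']
print(class_prototype('猫', templates).shape)  # torch.Size([512])
```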
+from abc import abstractmethod
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from mmengine.model import BaseModel
+from torch import nn
+
+from mmpretrain.datasets.categories import (CIFAR100_CATEGORIES,
+                                            IMAGENET_SIMPLE_CATEGORIES)
+from mmpretrain.registry import MODELS, TOKENIZER
+from mmpretrain.structures import DataSample
+from mmpretrain.utils import track_on_main_process
+from .utils import (OPENAI_CIFAR100_PROMPT, OPENAI_IMAGENET_PROMPT,
+                    OPENAI_IMAGENET_PROMPT_SUB)
+
+CIFAR100_CATEGORIES = [' '.join(c.split('_')) for c in CIFAR100_CATEGORIES]
+PROTOTYPE_MAP = {
+    'imagenet': IMAGENET_SIMPLE_CATEGORIES,
+    'cifar100': CIFAR100_CATEGORIES,
+}
+PROMPT_MAP = {
+    'openai_imagenet': OPENAI_IMAGENET_PROMPT,
+    'openai_cifar100': OPENAI_CIFAR100_PROMPT,
+    'vanilla': [lambda c: f'a photo of a {c}'],
+    'openai_imagenet_sub': OPENAI_IMAGENET_PROMPT_SUB
+}
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward function."""
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+class CLIP(BaseModel):
+    """The implementation of `CLIP `_.
+
+    Args:
+        vision_backbone (dict): Config dict for vision backbone.
+        text_backbone (dict): Config dict for text backbone.
+        tokenizer (dict): Config dict for text tokenizer.
+        proj_dim (int): Projection dimension for similarity computation.
+        text_prototype (str): Text prototype, which can be a key in
+            `PROTOTYPE_MAP` or list of text.
+        text_prompt (str): The prompt for text prototype. Defaults to
+            'vanilla', which refers to 'a photo of a {cls}'.
+        context_length (int): The context length to use. Defaults to 77.
+        data_preprocessor (Union[dict, nn.Module], optional): The config for
+            preprocessing input data. If None or no specified type, it will use
+            "MultiModalDataPreprocessor" as type.
+            See :class:`MultiModalDataPreprocessor` for more details.
+            Defaults to None.
+        init_cfg (dict, optional): The config to control the initialization.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 vision_backbone: dict,
+                 projection: dict,
+                 text_backbone: dict,
+                 tokenizer: dict,
+                 vocab_size: int,
+                 transformer_width: int,
+                 proj_dim: int,
+                 context_length: int = 77,
+                 data_preprocessor: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None):
+        if data_preprocessor is None:
+            data_preprocessor = {}
+        data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
+        data_preprocessor = MODELS.build(data_preprocessor)
+
+        super().__init__(
+            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
+
+        self.context_length = context_length
+
+        # build the vision transformer
+        self.visual = MODELS.build(vision_backbone)
+
+        # build the visual projection
+        self.visual_proj = MODELS.build(projection)
+
+        # build attn_mask for causal attention
+        text_backbone['attn_mask'] = self.build_attention_mask()
+
+        # build the text transformer
+        self.transformer = MODELS.build(text_backbone)
+
+        self.vocab_size = vocab_size
+        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+        self.positional_embedding = nn.Parameter(
+            torch.empty(self.context_length, transformer_width))
+        self.ln_final = LayerNorm(transformer_width)
+
+        self.text_projection = nn.Parameter(
+            torch.empty(transformer_width, proj_dim))
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+        self.initialize_parameters()
+
+        self.tokenizer = TOKENIZER.build(tokenizer)
+
+        # CLIPTokenizer has no attribute named 'vocab', so set it manually
+        self.tokenizer.vocab = self.tokenizer.get_vocab()
+
+    def initialize_parameters(self) -> None:
+        """Initialize the parameters.
+
+        The pretrained weight will override the initialized parameters by this
+        function.
+        """
+        nn.init.normal_(self.token_embedding.weight, std=0.02)
+        nn.init.normal_(self.positional_embedding, std=0.01)
+
+        proj_std = (self.transformer.width**-0.5) * (
+            (2 * self.transformer.layers)**-0.5)
+        attn_std = self.transformer.width**-0.5
+        fc_std = (2 * self.transformer.width)**-0.5
+        for block in self.transformer.resblocks:
+            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+        if self.text_projection is not None:
+            nn.init.normal_(
+                self.text_projection, std=self.transformer.width**-0.5)
+
+    def build_attention_mask(self):
+        # lazily create causal attention mask,
+        # with full attention between the vision tokens
+        # pytorch uses additive attention mask; fill with -inf
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float('-inf'))
+        mask.triu_(1)  # zero out the lower diagonal
+        return mask
+
+    def forward(
+        self,
+        images: torch.Tensor,
+        data_samples: Optional[list] = None,
+        mode: str = 'predict',
+        **kwargs,
+    ):
+        """The unified entry for a forward process in both training and test.
+        The method accepts the following modes:
+
+        - "predict": Forward and return a list of data samples contain the
+          predict results.
+
+        Args:
+            images (torch.Tensor): the preprocessed image tensor of shape
+                ``(N, C, H, W)``.
+            data_samples (List[DataSample], optional): The annotation data
+                of every samples. Defaults to None.
+            mode (str): Return what kind of value. Defaults to 'predict'.
+        """
+        if mode == 'predict':
+            return self.predict(images, data_samples, **kwargs)
+        else:
+            raise RuntimeError(f'Invalid mode "{mode}".')
+
+    def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor:
+        """The function to extract image latent features."""
+        return self.visual_proj(self.visual(images))[0]
+
+    def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor:
+        """The function to extract text latent features."""
+        x = self.token_embedding(texts)  # [batch_size, n_ctx, d_model]
+
+        x = x + self.positional_embedding
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)[0]
+
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.ln_final(x)
+
+        # x.shape = [batch_size, n_ctx, transformer.width]
+        # take features from the eot embedding
+        # (eot_token is the highest number in each sequence)
+        x = x[torch.arange(x.shape[0]),
+              texts.argmax(dim=-1)] @ self.text_projection
+
+        return x
+
+    def extract_feat(
+            self, images: torch.Tensor,
+            texts: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
+        """The function to extract image and text latent features, the input
+        image or text cannot both be None."""
+
+        assert images is not None or texts is not None, \
+            'text and image cannot both be None!'
+        if images is None:
+            return self.extract_text_feat(texts)
+        elif texts is None:
+            return self.extract_image_feat(images)
+
+        image_features = self.extract_image_feat(images)
+        text_features = self.extract_text_feat(texts)
+
+        image_features = image_features / image_features.norm(
+            dim=-1, keepdim=True)
+        text_features = text_features / text_features.norm(
+            dim=-1, keepdim=True)
+
+        return image_features, text_features
+
+    def compute_similarity(self, images, texts):
+        """Extract images and texts features and compute cosine similarity."""
+        image_features, text_features = self.extract_feat(
+            images=images, texts=texts)
+
+        # cosine similarity as logits
+        logit_scale = self.logit_scale.exp()
+        logits_per_image = logit_scale * image_features @ text_features.t()
+        logits_per_text = logits_per_image.t()
+
+        # shape (N, N)
+        return logits_per_image, logits_per_text
+
+    @abstractmethod
+    def predict(self,
+                images: torch.Tensor,
+                data_samples: DataSample = None) -> DataSample:
+        raise NotImplementedError
+
+    def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor:
+        """Returns the tokenized representation of the given input string(s),
+        padded or truncated to ``self.context_length`` (77 by default).
+
+        Args:
+            texts (Union[str, List[str]]): An input string or a list of input
+                strings to tokenize.
+
+        Returns:
+            torch.Tensor: Resulting tokens.
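+
+        Example (illustrative only; the exact ids depend on the tokenizer):
+
+            >>> tokens = self.tokenize(['a photo of a cat'])
+            >>> tokens.shape  # zero-padded to (num_texts, context_length)
+            torch.Size([1, 77])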
+        """
+        if isinstance(texts, str):
+            texts = [texts]
+
+        all_tokens = []
+        for text in texts:
+            # NOTE: the lowercasing / quote normalization used for the Chinese
+            # BERT vocab in ChineseCLIP is intentionally skipped here.
+
+            # add special tokens
+            all_tokens.append(
+                [self.tokenizer.vocab['<|startoftext|>']
+                 ] +  # <|startoftext|> plays the role of the [CLS] token
+                self.tokenizer.convert_tokens_to_ids(
+                    self.tokenizer.tokenize(text))[:self.context_length - 2] +
+                [self.tokenizer.vocab['<|endoftext|>']])
+
+        result = torch.zeros(
+            len(all_tokens), self.context_length, dtype=torch.long)
+
+        for i, tokens in enumerate(all_tokens):
+            assert len(tokens) <= self.context_length
+            result[i, :len(tokens)] = torch.tensor(tokens)
+
+        return result
+
+
+@MODELS.register_module()
+class CLIPZeroShot(CLIP):
+
+    def __init__(
+        self,
+        vision_backbone: dict,
+        projection: dict,
+        text_backbone: dict,
+        tokenizer: dict,
+        vocab_size: int,
+        transformer_width: int,
+        proj_dim: int,
+        context_length: int = 77,
+        data_preprocessor: Optional[dict] = None,
+        init_cfg: Optional[dict] = None,
+        text_prototype: Union[str, List[str]] = 'imagenet',
+        text_prompt: str = 'vanilla',
+    ):
+        super(CLIPZeroShot,
+              self).__init__(vision_backbone, projection, text_backbone,
+                             tokenizer, vocab_size, transformer_width,
+                             proj_dim, context_length, data_preprocessor,
+                             init_cfg)
+
+        # for zero-shot classification
+        if isinstance(text_prototype,
+                      str) and text_prototype in PROTOTYPE_MAP.keys():
+            self.prototype = PROTOTYPE_MAP[text_prototype]
+        else:
+            self.prototype = text_prototype
+        self.text_prototype_embeds = None
+
+        self.prompt = PROMPT_MAP[text_prompt]
+
+    def predict(self,
+                images: torch.Tensor,
+                data_samples: DataSample = None) -> DataSample:
+        """Predict the classes of the input images.
+
+        The prediction is for zero-shot classification and the text prototypes
+        will be prepared in this function.
+
+        Args:
+            images (torch.Tensor): The input images.
+            data_samples (DataSample): The data samples with information from
+                dataset.
+
+        Returns:
+            DataSample: The results of prediction.
+ """ + + if self.text_prototype_embeds is None: + self.prepare_text_prototype(device=images.device) + + image_features = self.extract_image_feat(images=images) + image_features /= image_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_image = image_features @ self.text_prototype_embeds.to( + image_features.device) * self.logit_scale.exp() + + pred_scores = F.softmax(logits_per_image, dim=1) + pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() + + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(pred_scores.size(0))] + + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + if data_sample is None: + data_sample = DataSample() + + data_sample.set_pred_score(score).set_pred_label(label) + out_data_samples.append(data_sample) + return out_data_samples + + def prepare_text_prototype(self, device) -> None: + """The function to prepare text prototypes with prompt.""" + class_embeddings = [] + for classname in track_on_main_process(self.prototype, + 'Prepare text prototype...'): + # format with class + texts = [prompt(classname) for prompt in self.prompt] + tokenized_texts = self.tokenize(texts) + class_features = self.extract_text_feat(tokenized_texts.to(device)) + class_features /= class_features.norm(dim=-1, keepdim=True) + class_feature = class_features.mean(dim=0) + class_feature /= class_feature.norm() + class_embeddings.append(class_feature) + self.text_prototype_embeds = torch.stack( + class_embeddings, dim=1).to(device) diff --git a/mmpretrain/models/multimodal/clip/clip_transformer.py b/mmpretrain/models/multimodal/clip/clip_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4b5f76661cbc3317a04e17f11680266bc44ea3eb --- /dev/null +++ b/mmpretrain/models/multimodal/clip/clip_transformer.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/zejiangh/MILAN +from typing import Optional, Tuple + +import torch +from mmengine.model import BaseModule +from torch import nn + +from mmpretrain.models.utils.clip_generator_helper import \ + ResidualAttentionBlock +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CLIPTransformer(nn.Module): + """Transformer. + + Both visual and text branches use this transformer. + + Args: + width (int): The feature dimension. + layers (int): The number of layers. + heads (int): The number of attention heads. + attn_mask (torch.Tensor, optional): The attention mask. + """ + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: Optional[torch.Tensor] = None) -> None: + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList() + for _ in range(layers - 1): + self.resblocks.append( + ResidualAttentionBlock(width, heads, attn_mask)) + self.resblocks.append( + ResidualAttentionBlock( + width, heads, attn_mask, return_attention=True)) + + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward function.""" + z = [] + for idx, blk in enumerate(self.resblocks): + if idx < self.layers - 1: + x = blk(x) + z.append(x.permute(1, 0, 2)) + else: + x, attention = blk(x) + z.append(x.permute(1, 0, 2)) + return x, attention, z + + +@MODELS.register_module() +class CLIPProjection(BaseModule): + """Neck with CLIP Projection. + + Args: + in_channels (int): Number of channels in the input. + out_channels (int): Number of channels in the output. 
+        init_cfg (dict | list[dict], optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 init_cfg: Optional[dict] = None):
+        super(CLIPProjection, self).__init__(init_cfg=init_cfg)
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        scale = in_channels**-0.5
+        self.proj = nn.Parameter(scale *
+                                 torch.randn(in_channels, out_channels))
+
+    def forward(self, inputs: Tuple) -> Tuple[torch.Tensor]:
+        """Forward function.
+
+        Args:
+            inputs (Tuple): The features extracted from
+                the backbone. Multiple stage inputs are acceptable but only
+                the last stage will be used.
+        Returns:
+            Tuple[torch.Tensor]: A tuple of projected features.
+        """
+        if isinstance(inputs, tuple):
+            inputs = inputs[-1]
+            out = inputs @ self.proj
+        elif isinstance(inputs, torch.Tensor):
+            out = inputs @ self.proj
+        else:
+            raise TypeError('`CLIPProjection` neck inputs should be '
+                            'a tuple or a torch.Tensor')
+        return (out, )
diff --git a/mmpretrain/models/multimodal/clip/utils.py b/mmpretrain/models/multimodal/clip/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..65239bc37d6c26826c4fe1cbaffb35a45cd948fd
--- /dev/null
+++ b/mmpretrain/models/multimodal/clip/utils.py
@@ -0,0 +1,115 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+OPENAI_CIFAR100_PROMPT = [
+    lambda c: f'a photo of a {c}.',
+    lambda c: f'a blurry photo of a {c}.',
+    lambda c: f'a black and white photo of a {c}.',
+    lambda c: f'a low contrast photo of a {c}.',
+    lambda c: f'a high contrast photo of a {c}.',
+    lambda c: f'a bad photo of a {c}.',
+    lambda c: f'a good photo of a {c}.',
+    lambda c: f'a photo of a small {c}.',
+    lambda c: f'a photo of a big {c}.',
+    lambda c: f'a photo of the {c}.',
+    lambda c: f'a blurry photo of the {c}.',
+    lambda c: f'a black and white photo of the {c}.',
+    lambda c: f'a low contrast photo of the {c}.',
+    lambda c: f'a high contrast photo of the {c}.',
+    lambda c: f'a bad photo of the {c}.',
+    lambda c: f'a good photo of the {c}.',
+    lambda c: f'a photo of the small {c}.',
+    lambda c: f'a photo of the big {c}.',
+]
+
+OPENAI_IMAGENET_PROMPT_SUB = [
+    lambda c: f'itap of a {c}.',
+    lambda c: f'a bad photo of the {c}.',
+    lambda c: f'a origami {c}.',
+    lambda c: f'a photo of the large {c}.',
+    lambda c: f'a {c} in a video game.',
+    lambda c: f'art of the {c}.',
+    lambda c: f'a photo of the small {c}.',
+]
+
+OPENAI_IMAGENET_PROMPT = [
+    lambda c: f'a bad photo of a {c}.',
+    lambda c: f'a photo of many {c}.',
+    lambda c: f'a sculpture of a {c}.',
+    lambda c: f'a photo of the hard to see {c}.',
+    lambda c: f'a low resolution photo of the {c}.',
+    lambda c: f'a rendering of a {c}.',
+    lambda c: f'graffiti of a {c}.',
+    lambda c: f'a bad photo of the {c}.',
+    lambda c: f'a cropped photo of the {c}.',
+    lambda c: f'a tattoo of a {c}.',
+    lambda c: f'the embroidered {c}.',
+    lambda c: f'a photo of a hard to see {c}.',
+    lambda c: f'a bright photo of a {c}.',
+    lambda c: f'a photo of a clean {c}.',
+    lambda c: f'a photo of a dirty {c}.',
+    lambda c: f'a dark photo of the {c}.',
+    lambda c: f'a drawing of a {c}.',
+    lambda c: f'a photo of my {c}.',
+    lambda c: f'the plastic {c}.',
+    lambda c: f'a photo of the cool {c}.',
+    lambda c: f'a close-up photo of a {c}.',
+    lambda c: f'a black and white photo of the {c}.',
+    lambda c: f'a painting of the {c}.',
+    lambda c: f'a painting of a {c}.',
+    lambda c: f'a pixelated photo of the {c}.',
+    lambda c: f'a sculpture of the {c}.',
+    lambda c: f'a bright photo of the {c}.',
+    
lambda c: f'a cropped photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a photo of the dirty {c}.', + lambda c: f'a jpeg corrupted photo of a {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up photo of the {c}.', + lambda c: f'a photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a photo of the clean {c}.', + lambda c: f'a photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a photo of a nice {c}.', + lambda c: f'a photo of a weird {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted photo of the {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a photo of the nice {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a photo of a cool {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +] diff --git a/mmpretrain/models/multimodal/flamingo/__init__.py b/mmpretrain/models/multimodal/flamingo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e0bfd63b657f5f0f1517ad6d31bce2821cb372cd --- /dev/null +++ b/mmpretrain/models/multimodal/flamingo/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .adapter import FlamingoLMAdapter +from .flamingo import Flamingo + +__all__ = ['Flamingo', 'FlamingoLMAdapter'] diff --git a/mmpretrain/models/multimodal/flamingo/adapter.py b/mmpretrain/models/multimodal/flamingo/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..bef0e2f86bfbe81046bb25fa4b9915e4c4f9005a --- /dev/null +++ b/mmpretrain/models/multimodal/flamingo/adapter.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random + +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .modules import FlamingoLayer, GatedCrossAttentionBlock +from .utils import getattr_recursive, setattr_recursive + + +@MODELS.register_module() +class FlamingoLMAdapter: + """Mixin to add cross-attention layers to a language model.""" + + @classmethod + def extend_init( + cls, + base: object, + vis_hidden_size: int, + cross_attn_every_n_layers: int, + use_media_placement_augmentation: bool, + only_attend_previous: bool = False, + ): + """Initialize Flamingo by adding a new gated cross attn to the decoder. + + Store the media token id for computing the media locations. 
+ + Args: + base (object): Base module could be any object that represent + a instance of language model. + vis_hidden_size: (int): Hidden size of vision embeddings. + cross_attn_every_n_layers: (int): Additional cross attn for + every n layers. + use_media_placement_augmentation: (bool): Whether to use media + placement augmentation. + """ + base.set_decoder_layers_attr_name('model.layers') + gated_cross_attn_layers = nn.ModuleList([ + GatedCrossAttentionBlock( + dim=base.config.hidden_size, dim_visual=vis_hidden_size) if + (layer_idx + 1) % cross_attn_every_n_layers == 0 else None + for layer_idx, _ in enumerate(base._get_decoder_layers()) + ]) + base._set_decoder_layers( + nn.ModuleList([ + FlamingoLayer(gated_cross_attn_layer, decoder_layer) + for gated_cross_attn_layer, decoder_layer in zip( + gated_cross_attn_layers, base._get_decoder_layers()) + ])) + base.use_media_placement_augmentation = use_media_placement_augmentation # noqa + base.initialized_flamingo = True + base.only_attend_previous = only_attend_previous + return base + + def set_decoder_layers_attr_name(self, decoder_layers_attr_name): + """Set decoder layers attribute name.""" + self.decoder_layers_attr_name = decoder_layers_attr_name + + def _get_decoder_layers(self): + """Get decoder layers according to attribute name.""" + return getattr_recursive(self, self.decoder_layers_attr_name) + + def _set_decoder_layers(self, value): + """Set decoder layers according to attribute name.""" + setattr_recursive(self, self.decoder_layers_attr_name, value) + + def forward(self, *input, **kwargs): + """Condition the Flamingo layers on the media locations before forward + function.""" + input_ids = kwargs['input_ids'] if 'input_ids' in kwargs else input[0] + media_locations = input_ids == self.media_token_id + if self.only_attend_previous: + attend_previous = True + elif self.use_media_placement_augmentation: + attend_previous = (random.random() < 0.5) + else: + attend_previous = False + + for layer in self.get_decoder().layers: + layer.condition_media_locations(media_locations) + layer.condition_attend_previous(attend_previous) + + return super().forward( + *input, **kwargs) # Call the other parent's forward method + + def is_conditioned(self) -> bool: + """Check whether all decoder layers are already conditioned.""" + return all(layer.is_conditioned() + for layer in self._get_decoder_layers()) + + def clear_conditioned_layers(self): + """Clear all conditional layers.""" + for layer in self._get_decoder_layers(): + layer.condition_vis_x(None) + layer.condition_media_locations(None) + layer.condition_attend_previous(None) diff --git a/mmpretrain/models/multimodal/flamingo/flamingo.py b/mmpretrain/models/multimodal/flamingo/flamingo.py new file mode 100644 index 0000000000000000000000000000000000000000..729d6c741898e0ba88d59604f3d86e5ba0c539d9 --- /dev/null +++ b/mmpretrain/models/multimodal/flamingo/flamingo.py @@ -0,0 +1,323 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from typing import List, Optional + +import torch +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from .modules import PerceiverResampler +from .utils import ExtendModule + + +@MODELS.register_module() +class Flamingo(BaseModel): + """The Open Flamingo model for multiple tasks. + + Args: + vision_encoder (dict): The config of the vision encoder. + lang_encoder (dict): The config of the language encoder. + tokenizer (dict): The tokenizer to encode the text. 
+        task (str): The task to perform prediction.
+        zeroshot_prompt (str): Prompt used for zero-shot inference.
+            Defaults to '<image>Output:'.
+        shot_prompt_tmpl (str): Prompt used for few-shot inference.
+            Defaults to ``<image>Output:{caption}<|endofchunk|>``.
+        final_prompt_tmpl (str): Final part of prompt used for inference.
+            Defaults to '<image>Output:'.
+        generation_cfg (dict): The extra generation config, which accepts the
+            keyword arguments of :class:`~transformers.GenerationConfig`.
+            Defaults to an empty dict.
+        data_preprocessor (Optional[dict]): The config for preprocessing input
+            data. If None or no specified type, it will use
+            "MultiModalDataPreprocessor" as type.
+            See :class:`MultiModalDataPreprocessor` for more details.
+            Defaults to None.
+        init_cfg (dict, optional): The initialization config. Defaults to None.
+    """
+
+    support_tasks = {'caption', 'vqa'}
+    _no_split_modules = [
+        'TransformerEncoderLayer', 'PerceiverAttention',
+        'GatedCrossAttentionBlock', 'FlamingoLayer'
+    ]
+
+    def __init__(
+            self,
+            vision_encoder: dict,
+            lang_encoder: dict,
+            tokenizer: dict,
+            task: str = 'caption',
+            zeroshot_prompt: str = '<image>Output:',
+            shot_prompt_tmpl: str = '<image>Output:{caption}<|endofchunk|>',
+            final_prompt_tmpl: str = '<image>Output:',
+            generation_cfg: dict = dict(),
+            data_preprocessor: Optional[dict] = None,
+            init_cfg: Optional[dict] = None):
+        if data_preprocessor is None:
+            data_preprocessor = {}
+        if isinstance(data_preprocessor, dict):
+            data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
+            data_preprocessor = MODELS.build(data_preprocessor)
+
+        super().__init__(
+            init_cfg=init_cfg, data_preprocessor=data_preprocessor)
+
+        if task not in self.support_tasks:
+            raise ValueError(f'Unsupported task {task}, please select '
+                             f'the task from {self.support_tasks}.')
+        self.task = task
+
+        # init tokenizer
+        self.tokenizer = TOKENIZER.build(tokenizer)
+        # add Flamingo special tokens to the tokenizer
+        self.tokenizer.add_special_tokens(
+            {'additional_special_tokens': ['<|endofchunk|>', '<image>']})
+        self.tokenizer.bos_token_id = 1
+        if self.tokenizer.pad_token is None:
+            # Issue: GPT models don't have a pad token, which we use to
+            # modify labels for the loss.
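+            # The new pad token gets an embedding when the token embeddings
+            # are resized below, so padded positions can be masked out of the
+            # loss without colliding with a real vocabulary entry.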
+            self.tokenizer.add_special_tokens({'pad_token': '<PAD>'})
+
+        # Template to format the prompt input
+        self.zeroshot_prompt = zeroshot_prompt
+        self.shot_prompt_tmpl = shot_prompt_tmpl
+        self.final_prompt_tmpl = final_prompt_tmpl
+
+        # init vision encoder related modules
+        vision_encoder_weight = vision_encoder.pop('pretrained', None)
+        self.vision_encoder = MODELS.build(vision_encoder)
+        if vision_encoder_weight is not None:
+            from mmengine.runner.checkpoint import load_checkpoint
+            load_checkpoint(
+                self.vision_encoder,
+                vision_encoder_weight,
+                map_location='cpu',
+                revise_keys=[(r'^backbone\.', '')],
+            )
+            self.vision_encoder.is_init = True
+
+        self.perceiver = PerceiverResampler(dim=self.vision_encoder.embed_dims)
+
+        # init language encoder related modules
+        self.lang_encoder = ExtendModule(**lang_encoder)
+        self.lang_encoder.resize_token_embeddings(len(self.tokenizer))
+        self.lang_encoder.media_token_id = self.tokenizer.encode('<image>')[-1]
+
+        # other necessary parameters
+        self.eoc_token_id = self.tokenizer.encode('<|endofchunk|>')[-1]
+        self.generation_cfg = {
+            'num_beams': 1,
+            'max_new_tokens': None,
+            'temperature': 1.0,
+            'top_k': 0,
+            'top_p': 1.0,
+            'no_repeat_ngram_size': 0,
+            'prefix_allowed_tokens_fn': None,
+            'length_penalty': 1.0,
+            'num_return_sequences': 1,
+            'do_sample': False,
+            'early_stopping': False,
+            **generation_cfg,
+        }
+
+        if hasattr(self, 'register_load_state_dict_post_hook'):
+            self.register_load_state_dict_post_hook(self._load_adapter_hook)
+
+    def forward(
+        self,
+        images: torch.Tensor,
+        data_samples: Optional[List[DataSample]] = None,
+        mode: str = 'loss',
+    ):
+        """The unified entry for a forward process in both training and test.
+        The method accepts the following modes:
+
+        - "loss": Forward and return a dict of losses according to the given
+          inputs and data samples.
+        - "predict": Forward and return a list of data samples containing the
+          prediction results.
+
+        Note that this method doesn't handle either back propagation or
+        optimizer updating, which are done in the :meth:`train_step`.
+
+        Args:
+            images (torch.Tensor): The input image tensor with different ndim
+                according to the inputs.
+            data_samples (List[DataSample], optional): The annotation
+                data of every sample. It's required if ``mode="loss"``.
+                Defaults to None.
+            mode (str): Return what kind of value. Defaults to 'loss'.
+
+        Returns:
+            The return type depends on ``mode``.
+            - If ``mode="loss"``, return a dict of tensor.
+        """
+
+        if mode == 'loss':
+            return self.loss(images, data_samples)
+        elif mode == 'predict':
+            return self.predict(images, data_samples)
+        else:
+            raise RuntimeError(f'Invalid mode "{mode}".')
+
+    def extract_vision_feats(self, images: torch.Tensor) -> torch.Tensor:
+        """Extract vision features.
+
+        Args:
+            images (torch.Tensor): For zero-shot, the input images tensor is
+                with shape (B, C, H, W); for few-shot it is
+                (B, T_img, C, H, W) in general. Images in the same chunk
+                are collated along T_img. Video data is not supported yet.
+
+        Returns:
+            torch.Tensor: Return extracted features.
+        """
+        if images.ndim == 4:
+            # (B, C, H, W) -> (B, 1, C, H, W) for zero-shot.
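+            # A singleton time axis makes zero-shot inputs share the
+            # (B, T_img, C, H, W) layout that the few-shot path already uses.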
+ images = images.unsqueeze(1) + b, T = images.shape[:2] + # b T c h w -> (b T) c h w + images = images.view(b * T, *images.shape[-3:]) + + with torch.no_grad(): + vision_feats = self.vision_encoder(images)[-1][:, 1:] + + # (b T F) v d -> b T F v d Only support F=1 here + vision_feats = vision_feats.view(b, T, 1, *vision_feats.shape[-2:]) + + vision_feats = self.perceiver(vision_feats) # reshapes to (b, T, n, d) + return vision_feats + + def predict(self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + **generation_cfg): + """Predict generation results from a batch of inputs. + + Args: + images (torch.Tensor): For zero-shot, the input images tensor is + with shape (B, C, H, W), for few-shot, which is + (B, T_img, C, H, W) in general. Images in the same chunk + are collated along T_img. Video data is not supported yet. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **generation_cfg: Other keyword arguments accepted by the + ``generate`` method of :attr:`lang_encoder`. + + Returns: + List[DataSample]: Return list of data samples. + """ + # generation_cfg in prediction should be dominant + generation_cfg = {**self.generation_cfg, **generation_cfg} + num_beams = generation_cfg['num_beams'] + + if num_beams > 1: + images = images.repeat_interleave(num_beams, dim=0) + + # extra vision feats and set as language condition feats + vision_x = self.extract_vision_feats(images) + for layer in self.lang_encoder._get_decoder_layers(): + layer.condition_vis_x(vision_x) + + input_text = self.preprocess_text(data_samples, device=images.device) + + outputs = self.lang_encoder.generate( + input_text.input_ids, + attention_mask=input_text.attention_mask, + eos_token_id=self.eoc_token_id, + **generation_cfg) + + # clear conditioned layers for language models + self.lang_encoder.clear_conditioned_layers() + + # remove prefix + outputs = outputs[:, len(input_text.input_ids[0]):] + + return self.post_process(outputs, data_samples) + + def preprocess_text(self, data_samples: List[DataSample], + device: torch.device) -> List[DataSample]: + """Preprocess text in advance before fed into language model. + + Args: + data_samples (List[DataSample]): The annotation + data of every samples. Defaults to None. + device (torch.device): Device for text to put on. + + Returns: + List[DataSample]: Return list of data samples. + """ + prompts = [] + for sample in data_samples: + if 'shots' in sample: + # few-shot + shot_prompt = ''.join([ + self.shot_prompt_tmpl.format(**shot) + for shot in sample.get('shots') + ]) + else: + # zero-shot + shot_prompt = self.zeroshot_prompt + + # add final prompt + final_prompt = self.final_prompt_tmpl.format(**sample.to_dict()) + prompts.append(shot_prompt + final_prompt) + + self.tokenizer.padding_side = 'left' + input_text = self.tokenizer( + prompts, + padding='longest', + truncation=True, + return_tensors='pt', + max_length=2000, + ).to(device) + return input_text + + def post_process( + self, outputs: torch.Tensor, + data_samples: Optional[List[DataSample]]) -> List[DataSample]: + """Perform post process for outputs for different task. + + Args: + outputs (torch.Tensor): The generated outputs. + data_samples (List[DataSample], optional): The annotation + data of every samples. + + Returns: + List[DataSample]: Return list of data samples. 
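+
+        Note:
+            The decoded text may still contain prompt scaffolding; everything
+            from the first ``Output`` marker (caption) or the first
+            ``Question``/``Answer`` marker (VQA) onwards is discarded.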
+ """ + outputs = self.tokenizer.batch_decode( + outputs, skip_special_tokens=True) + + if data_samples is None: + data_samples = [DataSample() for _ in range(len(outputs))] + + for output, data_sample in zip(outputs, data_samples): + # remove text pattern + if self.task == 'caption': + data_sample.pred_caption = re.split('Output', output, + 1)[0].replace('"', '') + elif self.task == 'vqa': + data_sample.pred_answer = re.split('Question|Answer', output, + 1)[0] + + return data_samples + + @staticmethod + def _load_adapter_hook(module, incompatible_keys): + """Avoid warning missing keys except adapter keys.""" + adapter_patterns = [ + '^perceiver', + 'lang_encoder.*embed_tokens', + 'lang_encoder.*gated_cross_attn_layers', + 'lang_encoder.*rotary_emb', + ] + for key in list(incompatible_keys.missing_keys): + if not any(re.match(pattern, key) for pattern in adapter_patterns): + incompatible_keys.missing_keys.remove(key) + + for key in list(incompatible_keys.unexpected_keys): + if 'position_ids' in key: + incompatible_keys.unexpected_keys.remove(key) + if 'lang_encoder.gated_cross_attn_layers' in key: + incompatible_keys.unexpected_keys.remove(key) diff --git a/mmpretrain/models/multimodal/flamingo/modules.py b/mmpretrain/models/multimodal/flamingo/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..730c61b68a8d0fb799b7985636f09b6484ef99c2 --- /dev/null +++ b/mmpretrain/models/multimodal/flamingo/modules.py @@ -0,0 +1,398 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Taken from https://github.com/lucidrains/flamingo-pytorch.""" + +from typing import Optional + +import torch +from einops import rearrange, repeat +from torch import einsum, nn + + +def FeedForward(dim, mult: int = 4): + """Feedforward layers. + + Args: + mult (int): Layer expansion muliplier. Defaults to 4. + """ + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +class PerceiverAttention(nn.Module): + """Perceiver attetion layers. + + Args: + dim (int): Input dimensions. + dim_head (int): Number of dimension heads. Defaults to 64. + heads (int): Number of heads. Defaults to 8. + """ + + def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm_media = nn.LayerNorm(dim) + self.norm_latents = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x: torch.Tensor, latents: torch.Tensor): + """Forward function. + + Args: + x (torch.Tensor): image features of shape (b, T, n1, D). + latent (torch.Tensor): latent features of shape (b, T, n2, D). + """ + x = self.norm_media(x) + latents = self.norm_latents(latents) + + h = self.heads + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + q = rearrange(q, 'b t n (h d) -> b h t n d', h=h) + k = rearrange(k, 'b t n (h d) -> b h t n d', h=h) + v = rearrange(v, 'b t n (h d) -> b h t n d', h=h) + q = q * self.scale + + # attention + sim = einsum('... i d, ... j d -> ... i j', q, k) + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + out = einsum('... i j, ... j d -> ... 
i d', attn, v) + out = rearrange(out, 'b h t n d -> b t n (h d)', h=h) + return self.to_out(out) + + +class PerceiverResampler(nn.Module): + """Perceiver resampler layers. + + Args: + dim (int): Input dimensions. + depth (int): Depth of resampler. Defaults to 6. + dim_head (int): Number of dimension heads. Defaults to 64. + heads (int): Number of heads. Defaults to 8. + num_latents (int): Number of latents. Defaults to 64. + max_num_media (int, optional): Max number of media. + Defaults to None. + max_num_frames (int, optional): Max number of frames. + Defaults to None. + ff_mult (int): Feed forward multiplier. Defaults to 4. + """ + + def __init__( + self, + *, + dim: int, + depth: int = 6, + dim_head: int = 64, + heads: int = 8, + num_latents: int = 64, + max_num_media: Optional[int] = None, + max_num_frames: Optional[int] = None, + ff_mult: int = 4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + self.frame_embs = ( + nn.Parameter(torch.randn(max_num_frames, dim)) + if max_num_frames is not None else None) + self.media_time_embs = ( + nn.Parameter(torch.randn(max_num_media, 1, dim)) + if max_num_media is not None else None) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PerceiverAttention( + dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ])) + + self.norm = nn.LayerNorm(dim) + + def forward(self, x: torch.Tensor): + """Forward function for perceiver sampler. + + Args: + x (torch.Tensor): image features of shape (b, T, F, v, D) + + Returns: + torch.Tensor: shape (b, T, n, D) where n is self.num_latents + """ + b, T, F, v = x.shape[:4] + + # frame and media time embeddings + if self.frame_embs is not None: + frame_embs = repeat( + self.frame_embs[:F], 'F d -> b T F v d', b=b, T=T, v=v) + x = x + frame_embs + x = rearrange(x, 'b T F v d -> b T (F v) d' + ) # flatten the frame and spatial dimensions + if self.media_time_embs is not None: + x = x + self.media_time_embs[:T] + + # blocks + latents = repeat(self.latents, 'n d -> b T n d', b=b, T=T) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + return self.norm(latents) + + +class MaskedCrossAttention(nn.Module): + """Masked cross attention layers. + + Args: + dim (int): Input text feature dimensions. + dim_visual (int): Input visual feature dimensions. + dim_head (int): Number of dimension heads. Defaults to 64. + heads (int): Number of heads. Defaults to 8. + only_attend_immediate_media (bool): Whether attend immediate media. + Defaults to True. + """ + + def __init__( + self, + *, + dim: int, + dim_visual: int, + dim_head: int = 64, + heads: int = 8, + only_attend_immediate_media: bool = True, + ): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + # whether for text to only attend to immediate preceding image + # or all previous images + self.only_attend_immediate_media = only_attend_immediate_media + + def forward(self, + x: torch.Tensor, + media: torch.Tensor, + media_locations: Optional[torch.Tensor] = None, + attend_previous: bool = True): + """Forward function for perceiver sampler. + + Args: + x (torch.Tensor): text features of shape (B, T_txt, D_txt). 
+ media (torch.Tensor): image features of shape + (B, T_img, n, D_img) where n is the dim of the latents. + media_locations (torch.Tensor, optional): boolean mask identifying + the media tokens in x of shape (B, T_txt). Defaults to None. + attend_previous (bool): If false, ignores immediately preceding + image and starts attending when following image. + Defaults to True. + """ + _, T_img, n = media.shape[:3] + h = self.heads + + x = self.norm(x) + + q = self.to_q(x) + media = rearrange(media, 'b t n d -> b (t n) d') + + k, v = self.to_kv(media).chunk(2, dim=-1) + q = rearrange(q, 'b n (h d) -> b h n d', h=h) + k = rearrange(k, 'b n (h d) -> b h n d', h=h) + v = rearrange(v, 'b n (h d) -> b h n d', h=h) + + q = q * self.scale + + sim = einsum('... i d, ... j d -> ... i j', q, k) + + if media_locations is not None: + # at each boolean of True, increment the time counter + # (relative to media time) + text_time = media_locations.cumsum(dim=-1) + media_time = torch.arange(T_img, device=x.device) + 1 + + if not attend_previous: + text_time[~media_locations] += 1 + # make sure max is still the number of images in the sequence + text_time[text_time > repeat( + torch.count_nonzero(media_locations, dim=1), + 'b -> b i', + i=text_time.shape[1], + )] = 0 + + # text time must equal media time if only attending to most + # immediate image otherwise, as long as text time is greater than + # media time (if attending to all previous images / media) + mask_op = torch.eq if self.only_attend_immediate_media else torch.ge # noqa + + text_to_media_mask = mask_op( + rearrange(text_time, 'b i -> b 1 i 1'), + repeat(media_time, 'j -> 1 1 1 (j n)', n=n), + ) + sim = sim.masked_fill(~text_to_media_mask, + -torch.finfo(sim.dtype).max) + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + if media_locations is not None and self.only_attend_immediate_media: + # any text without a preceding media needs to have + # attention zeroed out + text_without_media_mask = text_time == 0 + text_without_media_mask = rearrange(text_without_media_mask, + 'b i -> b 1 i 1') + attn = attn.masked_fill(text_without_media_mask, 0.0) + + out = einsum('... i j, ... j d -> ... i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class GatedCrossAttentionBlock(nn.Module): + """Gated cross attention layers. + + Args: + dim (int): Input text feature dimensions. + dim_visual (int): Input visual feature dimensions. + dim_head (int): Number of dimension heads. Defaults to 64. + heads (int): Number of heads. Defaults to 8. + ff_mult (int): Feed forward multiplier. Defaults to 4. + only_attend_immediate_media (bool): Whether attend immediate media. + Defaults to True. + """ + + def __init__( + self, + *, + dim: int, + dim_visual: int, + dim_head: int = 64, + heads: int = 8, + ff_mult: int = 4, + only_attend_immediate_media: bool = True, + ): + super().__init__() + self.attn = MaskedCrossAttention( + dim=dim, + dim_visual=dim_visual, + dim_head=dim_head, + heads=heads, + only_attend_immediate_media=only_attend_immediate_media, + ) + self.attn_gate = nn.Parameter(torch.tensor([0.0])) + + self.ff = FeedForward(dim, mult=ff_mult) + self.ff_gate = nn.Parameter(torch.tensor([0.0])) + + def forward(self, + x: torch.Tensor, + media: torch.Tensor, + media_locations: Optional[torch.Tensor] = None, + attend_previous: bool = True): + """Forward function for perceiver sampler. + + Args: + x (torch.Tensor): text features of shape (B, T_txt, D_txt). 
+            media (torch.Tensor): image features of shape
+                (B, T_img, n, D_img) where n is the dim of the latents.
+            media_locations (torch.Tensor, optional): boolean mask identifying
+                the media tokens in x of shape (B, T_txt). Defaults to None.
+            attend_previous (bool): If False, ignore the immediately preceding
+                image and only start attending from the following image.
+                Defaults to True.
+        """
+        x = (
+            self.attn(
+                x,
+                media,
+                media_locations=media_locations,
+                attend_previous=attend_previous,
+            ) * self.attn_gate.tanh() + x)
+        x = self.ff(x) * self.ff_gate.tanh() + x
+
+        return x
+
+
+class FlamingoLayer(nn.Module):
+    """Flamingo layer that wraps a decoder layer together with an optional
+    gated cross attention block.
+
+    Args:
+        gated_cross_attn_layer (nn.Module): Gated cross attention layer.
+        decoder_layer (nn.Module): Decoder layer.
+    """
+
+    def __init__(self, gated_cross_attn_layer: nn.Module,
+                 decoder_layer: nn.Module):
+        super().__init__()
+        self.gated_cross_attn_layer = gated_cross_attn_layer
+        self.decoder_layer = decoder_layer
+        self.vis_x = None
+        self.media_locations = None
+
+    def is_conditioned(self) -> bool:
+        """Check whether the layer is conditioned."""
+        return self.vis_x is not None
+
+    def condition_vis_x(self, vis_x):
+        """Set condition vision features."""
+        self.vis_x = vis_x
+
+    def condition_media_locations(self, media_locations):
+        """Set condition media locations."""
+        self.media_locations = media_locations
+
+    def condition_attend_previous(self, attend_previous):
+        """Set attend previous."""
+        self.attend_previous = attend_previous
+
+    def forward(
+        self,
+        lang_x: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        **decoder_layer_kwargs,
+    ):
+        """Forward function.
+
+        Args:
+            lang_x (torch.Tensor): language inputs.
+            attention_mask (torch.Tensor, optional): text attention mask.
+                Defaults to None.
+            **decoder_layer_kwargs: Other decoder layer keyword arguments.
+        """
+        if self.gated_cross_attn_layer is None:
+            return self.decoder_layer(
+                lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
+
+        if self.vis_x is None:
+            raise ValueError('vis_x must be conditioned before forward pass')
+
+        if self.media_locations is None:
+            raise ValueError(
+                'media_locations must be conditioned before forward pass')
+
+        lang_x = self.gated_cross_attn_layer(
+            lang_x,
+            self.vis_x,
+            media_locations=self.media_locations,
+            attend_previous=self.attend_previous,
+        )
+        lang_x = self.decoder_layer(
+            lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
+        return lang_x
diff --git a/mmpretrain/models/multimodal/flamingo/utils.py b/mmpretrain/models/multimodal/flamingo/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1077e145a7daeeff1c769d837ec9c5aac0cf3d93
--- /dev/null
+++ b/mmpretrain/models/multimodal/flamingo/utils.py
@@ -0,0 +1,64 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Type
+
+from mmpretrain.registry import MODELS
+
+
+class ExtendModule:
+    """Combine the base language model with an adapter. This module creates
+    an instance of ``base`` with the extended functions from ``adapter``.
+
+    Args:
+        base (object): Base module could be any object that represents
+            an instance of a language model, or a dict that can build the
+            base module.
+        adapter (dict): Dict to build the adapter.
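+
+    Note:
+        ``__new__`` returns the mutated base instance itself, not an
+        ``ExtendModule`` object: the adapter class is mixed into the front
+        of the instance's MRO, so attribute lookup hits the adapter first
+        and falls back to the base language model.
+
+    Example:
+        A schematic config (the values here are illustrative assumptions,
+        not taken from a shipped config)::
+
+            lang_encoder = ExtendModule(
+                base=dict(type='AutoModelForCausalLM', name_or_path='...'),
+                adapter=dict(
+                    type='FlamingoLMAdapter',
+                    vis_hidden_size=1024,
+                    cross_attn_every_n_layers=4,
+                    use_media_placement_augmentation=False))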
+ """ + + def __new__(cls, base: object, adapter: dict): + + if isinstance(base, dict): + base = MODELS.build(base) + + adapter_module = MODELS.get(adapter.pop('type')) + cls.extend_instance(base, adapter_module) + return adapter_module.extend_init(base, **adapter) + + @classmethod + def extend_instance(cls, base: object, mixin: Type[Any]): + """Apply mixins to a class instance after creation. + + Args: + base (object): Base module instance. + mixin: (Type[Any]): Adapter class type to mixin. + """ + base_cls = base.__class__ + base_cls_name = base.__class__.__name__ + base.__class__ = type( + base_cls_name, (mixin, base_cls), + {}) # mixin needs to go first for our forward() logic to work + + +def getattr_recursive(obj, att): + """ + Return nested attribute of obj + Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c + """ + if att == '': + return obj + i = att.find('.') + if i < 0: + return getattr(obj, att) + else: + return getattr_recursive(getattr(obj, att[:i]), att[i + 1:]) + + +def setattr_recursive(obj, att, val): + """ + Set nested attribute of obj + Example: setattr_recursive(obj, 'a.b.c', val) + is equivalent to obj.a.b.c = val + """ + if '.' in att: + obj = getattr_recursive(obj, '.'.join(att.split('.')[:-1])) + setattr(obj, att.split('.')[-1], val) diff --git a/mmpretrain/models/multimodal/llava/__init__.py b/mmpretrain/models/multimodal/llava/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aef10d34d46fc3974744881c814068ae7d6f9357 --- /dev/null +++ b/mmpretrain/models/multimodal/llava/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .llava import Llava +from .modules import LlavaLlamaForCausalLM + +__all__ = ['Llava', 'LlavaLlamaForCausalLM'] diff --git a/mmpretrain/models/multimodal/llava/llava.py b/mmpretrain/models/multimodal/llava/llava.py new file mode 100644 index 0000000000000000000000000000000000000000..103d81296f03c322e2da8697dcca7f2e0e2822a2 --- /dev/null +++ b/mmpretrain/models/multimodal/llava/llava.py @@ -0,0 +1,258 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from typing import List, Optional + +import torch +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from ...utils import no_load_hf_pretrained_model +from .modules import LlavaLlamaForCausalLM + + +@MODELS.register_module() +class Llava(BaseModel): + """The LLaVA model for multiple tasks. + + Args: + vision_encoder (dict): The config of the vision encoder. + lang_encoder (dict): The config of the language encoder. + tokenizer (dict): The tokenizer to encode the text. + prompt_tmpl (str): Prompt template for inference. + task (int): The task to perform prediction. + use_im_start_end (bool): Whether to use the im_start and im_end tokens + mm_vision_select_layer (int): The index from vision encoder output. + Defaults to -1. + use_mm_proj (bool): Whether to enable multi-modal projection. + Defaults to True. + load_lang_pretrained (bool): Whether to load the pretrained model of + language encoder. Defaults to False. + generation_cfg (dict): The extra generation config, accept the keyword + arguments of [~`transformers.GenerationConfig`]. + Defaults to an empty dict. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MutimodalDataPreprocessor" as type. + See :class:`MutimodalDataPreprocessor` for more details. + Defaults to None. 
+        init_cfg (dict, optional): The initialization config. Defaults to None.
+    """
+
+    support_tasks = {'caption', 'vqa'}
+    im_patch_token = '<im_patch>'
+    im_start_token = '<im_start>'
+    im_end_token = '<im_end>'
+
+    def __init__(self,
+                 vision_encoder: dict,
+                 lang_encoder: dict,
+                 tokenizer: dict,
+                 mm_hidden_size: int,
+                 prompt_tmpl: str,
+                 task: str = 'caption',
+                 use_im_start_end: bool = False,
+                 mm_vision_select_layer: int = -1,
+                 use_mm_proj: bool = True,
+                 generation_cfg: dict = dict(),
+                 load_lang_pretrained: bool = False,
+                 data_preprocessor: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None):
+        if data_preprocessor is None:
+            data_preprocessor = {}
+        if isinstance(data_preprocessor, dict):
+            data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
+            data_preprocessor = MODELS.build(data_preprocessor)
+
+        super().__init__(
+            init_cfg=init_cfg, data_preprocessor=data_preprocessor)
+
+        if task not in self.support_tasks:
+            raise ValueError(f'Unsupported task {task}, please select '
+                             f'the task from {self.support_tasks}.')
+        self.task = task
+
+        # init tokenizer
+        self.tokenizer = TOKENIZER.build(tokenizer)
+        # add Llava special tokens to the tokenizer
+        self.tokenizer.add_tokens([self.im_patch_token], special_tokens=True)
+        if use_im_start_end:
+            self.tokenizer.add_tokens([self.im_start_token, self.im_end_token],
+                                      special_tokens=True)
+
+        # Template to format the prompt input
+        self.prompt_tmpl = prompt_tmpl
+
+        # init vision encoder related modules
+        vision_encoder_weight = vision_encoder.pop('pretrained', None)
+        vision_encoder = MODELS.build(vision_encoder)
+        if vision_encoder_weight is not None:
+            from mmengine.runner.checkpoint import load_checkpoint
+            load_checkpoint(
+                vision_encoder,
+                vision_encoder_weight,
+                map_location='cpu',
+                revise_keys=[(r'^backbone\.', '')],
+            )
+            vision_encoder.is_init = True
+
+        # init language encoder related modules
+        if load_lang_pretrained:
+            lang_encoder = MODELS.build(lang_encoder)
+        else:
+            with no_load_hf_pretrained_model():
+                lang_encoder = MODELS.build(lang_encoder)
+        lang_encoder.resize_token_embeddings(len(self.tokenizer))
+
+        self.model = LlavaLlamaForCausalLM(
+            vision_encoder=vision_encoder,
+            lang_encoder=lang_encoder,
+            mm_hidden_size=mm_hidden_size,
+            use_mm_proj=use_mm_proj,
+            use_im_start_end=use_im_start_end,
+            im_start_token=self.tokenizer.convert_tokens_to_ids(
+                self.im_start_token),
+            im_end_token=self.tokenizer.convert_tokens_to_ids(
+                self.im_end_token),
+            im_patch_token=self.tokenizer.convert_tokens_to_ids(
+                self.im_patch_token),
+            mm_vision_select_layer=mm_vision_select_layer)
+
+        self.generation_cfg = generation_cfg
+
+        if hasattr(self, 'register_load_state_dict_post_hook'):
+            self.register_load_state_dict_post_hook(self._load_ckpt_hook)
+
+    def forward(
+        self,
+        images: torch.Tensor,
+        data_samples: Optional[List[DataSample]] = None,
+        mode: str = 'loss',
+    ):
+        """The unified entry for a forward process in both training and test.
+
+        - "predict": Forward and return the predictions, which are fully
+          processed to a list of :obj:`DataSample`.
+        - "loss": Forward and return a dict of losses according to the given
+          inputs and data samples.
+
+        Note that this method doesn't handle either back propagation or
+        optimizer updating, which are done in the :meth:`train_step`.
+
+        Args:
+            images (torch.Tensor): The input image tensor with different ndim
+                according to the inputs.
+            data_samples (List[DataSample], optional): The annotation
+                data of every sample. It's required if ``mode="loss"``.
+                Defaults to None.
+ mode (str): Return what kind of value. Defaults to 'loss'. + + Returns: + The return type depends on ``mode``. + - If ``mode="loss"``, return a dict of tensor. + """ + + if mode == 'predict': + return self.predict(images, data_samples) + elif mode == 'loss': + raise NotImplementedError + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def predict(self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + **generation_cfg): + """Predict generation results from a batch of inputs. + + Args: + images (torch.Tensor): For zero-shot, the input images tensor is + with shape (B, C, H, W), for few-shot, which is + (B, T_img, C, H, W) in general. Images in the same chunk + are collated along T_img. Video data is not supported yet. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **generation_cfg: Other keyword arguments accepted by the + ``generate`` method of :attr:`lang_encoder`. + + Returns: + List[DataSample]: Return list of data samples. + """ + # generation_cfg in prediction should be dominant + generation_cfg = {**self.generation_cfg, **generation_cfg} + + input_text = self.preprocess_text(data_samples, device=images.device) + + outputs = self.model.generate( + input_text.input_ids, + attention_mask=input_text.attention_mask, + eos_token_id=self.tokenizer.eos_token_id, + images=images, + **generation_cfg) + + # remove prefix + outputs = outputs[:, len(input_text.input_ids[0]):] + + return self.post_process(outputs, data_samples) + + def preprocess_text(self, data_samples: List[DataSample], + device: torch.device) -> List[DataSample]: + """Preprocess text in advance before fed into language model. + + Args: + data_samples (List[DataSample]): The annotation + data of every samples. Defaults to None. + device (torch.device): Device for text to put on. + + Returns: + List[DataSample]: Return list of data samples. + """ + prompts = [] + for sample in data_samples: + final_prompt = self.prompt_tmpl.format(**sample.to_dict()) + prompts.append(final_prompt) + + self.tokenizer.padding_side = 'left' + input_text = self.tokenizer( + prompts, + padding='longest', + truncation=True, + return_tensors='pt', + max_length=2000, + ).to(device) + return input_text + + def post_process( + self, outputs: torch.Tensor, + data_samples: Optional[List[DataSample]]) -> List[DataSample]: + """Perform post process for outputs for different task. + + Args: + outputs (torch.Tensor): The generated outputs. + data_samples (List[DataSample], optional): The annotation + data of every samples. + + Returns: + List[DataSample]: Return list of data samples. 
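+
+        Note:
+            Unlike Flamingo, no marker stripping is applied here: the decoded
+            strings are written verbatim to ``pred_caption`` or
+            ``pred_answer`` depending on :attr:`task`.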
+ """ + outputs = self.tokenizer.batch_decode( + outputs, skip_special_tokens=True) + + if data_samples is None: + data_samples = [DataSample() for _ in range(len(outputs))] + + for output, data_sample in zip(outputs, data_samples): + # remove text pattern + if self.task == 'caption': + data_sample.pred_caption = output + elif self.task == 'vqa': + data_sample.pred_answer = output + + return data_samples + + @staticmethod + def _load_ckpt_hook(module, incompatible_keys): + """Avoid warning missing keys except lang_encoder keys.""" + for key in list(incompatible_keys.missing_keys): + if re.match('model.vision_tower', key): + incompatible_keys.missing_keys.remove(key) diff --git a/mmpretrain/models/multimodal/llava/modules.py b/mmpretrain/models/multimodal/llava/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..afa6eefadcbd73f630d8c842c80b83f229216c97 --- /dev/null +++ b/mmpretrain/models/multimodal/llava/modules.py @@ -0,0 +1,238 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from transformers import PreTrainedModel + +DEFAULT_IMAGE_TOKEN = '' +DEFAULT_IMAGE_PATCH_TOKEN = '' +DEFAULT_IM_START_TOKEN = '' +DEFAULT_IM_END_TOKEN = '' + + +class LlavaLlamaForCausalLM(PreTrainedModel): + + def __init__(self, + vision_encoder, + lang_encoder, + mm_hidden_size, + use_im_start_end=True, + use_mm_proj=True, + im_start_token: Optional[int] = None, + im_end_token: Optional[int] = None, + im_patch_token: Optional[int] = None, + mm_vision_select_layer: int = -1): + super().__init__(lang_encoder.config) + self.vision_tower = vision_encoder + self.lang_encoder = lang_encoder + + self.use_im_start_end = use_im_start_end + self.im_start_token = im_start_token + self.im_end_token = im_end_token + self.im_patch_token = im_patch_token + self.mm_hidden_size = mm_hidden_size + self.mm_vision_select_layer = mm_vision_select_layer + self.lang_hidden_size = lang_encoder.config.hidden_size + + if use_mm_proj and not hasattr(lang_encoder.model, 'mm_projector'): + mm_projector = nn.Linear(self.mm_hidden_size, + self.lang_hidden_size) + self.lang_encoder.model.add_module('mm_projector', mm_projector) + elif not use_mm_proj: + self.lang_encoder.model.add_module('mm_projector', nn.Identity()) + + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is 
not None else + self.config.output_hidden_states) + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + # decoder outputs consists of + # (dec_features, layer_state, dec_hidden, dec_attn) + if inputs_embeds is None: + inputs_embeds = self.lang_encoder.model.embed_tokens(input_ids) + + inputs_embeds = self.forward_vision_tower(input_ids, inputs_embeds, + images) + + return self.lang_encoder( + input_ids=None, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + labels=labels, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs): + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use + # them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + model_inputs = {'input_ids': input_ids} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + 'images': kwargs.get('images', None), + }) + return model_inputs + + def forward_vision_tower( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + images: Union[torch.FloatTensor, list, None] = None, + ): + if self.use_im_start_end: + assert self.im_start_token is not None + assert self.im_end_token is not None + if images is not None: + assert self.im_patch_token is not None + + if self.vision_tower is None or images is None or ( + input_ids.shape[1] == 1 and not self.training): + return inputs_embeds + + with torch.no_grad(): + if isinstance(images, (list, tuple)): + # variable length images + image_features = [] + for image in images: + feats = self.vision_tower(image.unsqueeze(0)) + image_feature = feats[self.mm_vision_select_layer][:, 1:] + image_features.append(image_feature) + else: + feats = self.vision_tower(images) + image_features = feats[self.mm_vision_select_layer][:, 1:] + + mm_projector = self.lang_encoder.model.mm_projector + if isinstance(images, (list, tuple)): + image_features = [ + mm_projector(image_feature)[0] + for image_feature in image_features + ] + else: + image_features = mm_projector(image_features) + + dummy_image_features = torch.zeros( + 256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype) + dummy_image_features = mm_projector(dummy_image_features) + + new_input_embeds = [] + cur_image_idx = 0 + for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds): + if (cur_input_ids != self.im_patch_token).all(): + # multimodal LLM, but the current sample is not multimodal + cur_input_embeds = cur_input_embeds + ( + 0. 
* dummy_image_features).sum() + new_input_embeds.append(cur_input_embeds) + cur_image_idx += 1 + continue + if self.use_im_start_end: + cur_image_features = image_features[cur_image_idx] + num_patches = cur_image_features.shape[0] + if (cur_input_ids == self.im_start_token).sum() != ( + cur_input_ids == self.im_end_token).sum(): + raise ValueError('The number of image start tokens and ' + 'image end tokens should be the same.') + image_start_tokens = torch.where( + cur_input_ids == self.im_start_token)[0] + for image_start_token_pos in image_start_tokens: + cur_image_features = image_features[cur_image_idx].to( + device=cur_input_embeds.device) + num_patches = cur_image_features.shape[0] + if cur_input_ids[image_start_token_pos + num_patches + + 1] != self.im_end_token: + raise ValueError('The image end token should follow ' + 'the image start token.') + cur_new_input_embeds = torch.cat( + (cur_input_embeds[:image_start_token_pos + 1], + cur_image_features, + cur_input_embeds[image_start_token_pos + num_patches + + 1:]), + dim=0) + cur_image_idx += 1 + new_input_embeds.append(cur_new_input_embeds) + else: + cur_image_features = image_features[cur_image_idx] + num_patches = cur_image_features.shape[0] + if (cur_input_ids == self.im_patch_token).sum() != num_patches: + print(f'Debug: num_patches: {num_patches}') + raise ValueError( + 'The number of image patch tokens should ' + 'be the same as the number of image patches.') + masked_indices = torch.where( + cur_input_ids == self.im_patch_token)[0] + mask_index_start = masked_indices[0] + if (masked_indices != torch.arange( + mask_index_start, + mask_index_start + num_patches, + device=masked_indices.device, + dtype=masked_indices.dtype)).any(): + raise ValueError( + 'The image patch tokens should be consecutive.') + cur_new_input_embeds = torch.cat( + (cur_input_embeds[:mask_index_start], cur_image_features, + cur_input_embeds[mask_index_start + num_patches:]), + dim=0) + new_input_embeds.append(cur_new_input_embeds) + cur_image_idx += 1 + inputs_embeds = torch.stack(new_input_embeds, dim=0) + + return inputs_embeds + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past diff --git a/mmpretrain/models/multimodal/minigpt4/__init__.py b/mmpretrain/models/multimodal/minigpt4/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5358bb1377ee6da7d848c06f3a249493645cdbf7 --- /dev/null +++ b/mmpretrain/models/multimodal/minigpt4/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .minigpt4 import MiniGPT4 + +__all__ = ['MiniGPT4'] diff --git a/mmpretrain/models/multimodal/minigpt4/minigpt4.py b/mmpretrain/models/multimodal/minigpt4/minigpt4.py new file mode 100644 index 0000000000000000000000000000000000000000..d25d0b6be36cbc52d9ae636e1b62e27f00bd2cbd --- /dev/null +++ b/mmpretrain/models/multimodal/minigpt4/minigpt4.py @@ -0,0 +1,410 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +import re +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from mmengine.logging import MMLogger +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class MiniGPT4(BaseModel): + """The multi-modality model of MiniGPT-4. 
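+
+    Images are encoded by an (optionally frozen) ViT, compressed into a
+    fixed number of query embeddings by a Q-Former, and mapped into the
+    LLaMA embedding space by a single trainable linear projection.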
+
+    The implementation of `MiniGPT-4 <https://arxiv.org/abs/2304.10592>`_.
+    Modified from
+    https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/models/mini_gpt4.py
+
+    Args:
+        vision_encoder (dict): The config for vision encoder.
+        q_former_model (dict): The config for Qformer.
+        lang_encoder (dict): The config for language model.
+        tokenizer (dict): The config for tokenizer.
+        task (str): To define the task, which controls the processing of text.
+            Defaults to 'caption'.
+        freeze_vit (bool): Freeze the training of ViT. Defaults to True.
+        freeze_q_former (bool): Freeze the training of Qformer. Defaults to
+            True.
+        num_query_token (int): Number of query tokens of Qformer. Defaults to
+            32.
+        prompt_template (dict): Multi-language prompt template of the model.
+            Defaults to
+            ``dict([('en', '###Ask: {} ###Answer: '), ('zh', '###问:{} ###答:')])``.
+        raw_prompts (dict): Prompts for training. Defaults to dict().
+        max_txt_len (int): Max token length while doing tokenization. Defaults
+            to 32.
+        end_sym (str): Ended symbol of the sequence. Defaults to '###'.
+        generation_cfg (dict): The config of text generation. Defaults to
+            dict().
+        data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
+            pre-processing data sampled by dataloader to the format accepted by
+            :meth:`forward`. Defaults to None.
+        init_cfg (dict): Initialization config dict. Defaults to None.
+    """  # noqa
+
+    def __init__(self,
+                 vision_encoder: dict,
+                 q_former_model: dict,
+                 lang_encoder: dict,
+                 tokenizer: dict,
+                 task: str = 'caption',
+                 freeze_vit: bool = True,
+                 freeze_q_former: bool = True,
+                 num_query_token: int = 32,
+                 prompt_template: dict = dict([('en',
+                                                '###Ask: {} ###Answer: '),
+                                               ('zh', '###问:{} ###答:')]),
+                 raw_prompts: dict = dict(),
+                 max_txt_len: int = 32,
+                 end_sym: str = '###',
+                 generation_cfg: dict = dict(),
+                 data_preprocessor: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None):
+        if data_preprocessor is None:
+            data_preprocessor = {}
+        data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
+        data_preprocessor = MODELS.build(data_preprocessor)
+
+        super().__init__(
+            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
+        self.task = task
+        logger = MMLogger.get_current_instance()
+
+        # build vision model
+        vision_encoder_weight = vision_encoder.pop('pretrained', None)
+        self.vision_encoder = MODELS.build(vision_encoder)
+        self.ln_vision = nn.LayerNorm(self.vision_encoder.embed_dims)
+
+        if vision_encoder_weight is not None:
+            from mmengine.runner.checkpoint import load_checkpoint
+            load_checkpoint(self.vision_encoder, vision_encoder_weight)
+            self.vision_encoder.is_init = True
+        if freeze_vit:
+            for name, param in self.ln_vision.named_parameters():
+                param.requires_grad = False
+            self.ln_vision = self.ln_vision.eval()
+        else:
+            logger.warning('Please check `frozen_stages` in the dict of '
+                           '`vision_encoder`. 
Also set it to -1 if you do '
+                           'not want to freeze the ViT.')
+
+        # build Qformer
+        q_former_model_weight = q_former_model.pop('pretrained', None)
+        self.q_former = MODELS.build(q_former_model)
+        self.q_former.cls = None
+        self.q_former.bert.embeddings.word_embeddings = None
+        self.q_former.bert.embeddings.position_embeddings = None
+        for layer in self.q_former.bert.encoder.layer:
+            layer.output = None
+            layer.intermediate = None
+
+        self.query_tokens = nn.Parameter(
+            torch.zeros(1, num_query_token, self.q_former.config.hidden_size))
+        self.query_tokens.data.normal_(
+            mean=0.0, std=self.q_former.config.initializer_range)
+
+        if q_former_model_weight is not None:
+            from mmengine.runner.checkpoint import CheckpointLoader
+            state_dict = CheckpointLoader.load_checkpoint(
+                q_former_model_weight)['state_dict']
+            self.load_state_dict(state_dict, strict=False)
+            # The ln_vision weights are also in the q-former checkpoint.
+            setattr(self.ln_vision, 'is_init', True)
+            setattr(self.q_former, 'is_init', True)
+
+        if freeze_q_former:
+            for name, param in self.q_former.named_parameters():
+                param.requires_grad = False
+            self.q_former.eval()
+            self.query_tokens.requires_grad = False
+
+        # build language model
+        self.llama_tokenizer = TOKENIZER.build(tokenizer)
+        self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
+
+        self.llama_model = MODELS.build(lang_encoder)
+        for name, param in self.llama_model.named_parameters():
+            param.requires_grad = False
+
+        # build linear projection layer
+        self.llama_proj = nn.Linear(self.q_former.config.hidden_size,
+                                    self.llama_model.config.hidden_size)
+        self.max_txt_len = max_txt_len
+        self.end_sym = end_sym
+        self.end_token_id = self.llama_tokenizer.encode(end_sym)[-1]
+
+        # set prompts
+        self.en_prompt_list, self.zh_prompt_list = [], []
+        if raw_prompts.get('en') is not None:
+            en_filtered_prompts = [
+                raw_prompt for raw_prompt in raw_prompts['en']
+                if '<ImageHere>' in raw_prompt
+            ]
+            self.en_prompt_list = [
+                prompt_template['en'].format(p) for p in en_filtered_prompts
+            ]
+        if raw_prompts.get('zh') is not None:
+            zh_filtered_prompts = [
+                raw_prompt for raw_prompt in raw_prompts['zh']
+                if '<ImageHere>' in raw_prompt
+            ]
+            self.zh_prompt_list = [
+                prompt_template['zh'].format(p) for p in zh_filtered_prompts
+            ]
+
+        # update generation configs
+        self.generation_cfg = dict(
+            max_new_tokens=300,
+            num_beams=1,
+            do_sample=True,
+            min_length=1,
+            top_p=0.9,
+            repetition_penalty=1.1,
+            length_penalty=1.0,
+            temperature=1.0)
+        self.generation_cfg.update(**generation_cfg)
+
+        if hasattr(self, 'register_load_state_dict_post_hook'):
+            self.register_load_state_dict_post_hook(self._load_llama_proj_hook)
+
+    def half(self):
+        """Cast the language model to half precision."""
+        self.llama_model = self.llama_model.half()
+        return self
+
+    def encode_img(self,
+                   images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """The function to encode the images."""
+        device = images.device
+        x = self.vision_encoder(images)[0]
+        image_embeds = self.ln_vision(x).to(device)
+        image_atts = torch.ones(
+            image_embeds.size()[:-1], dtype=torch.long).to(device)
+
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        query_output = self.q_former.bert(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            return_dict=True,
+        )
+
+        inputs_llama = self.llama_proj(query_output.last_hidden_state)
+        atts_llama = torch.ones(
+            inputs_llama.size()[:-1], dtype=torch.long).to(images.device)
+        return inputs_llama, atts_llama
+
+    def prompt_wrap(self, img_embeds: torch.Tensor, atts_img: torch.Tensor,
prompt: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
+        """The function to wrap the image and prompt.
+
+        Make sure that len(prompt) == img_embeds.shape[0].
+
+        Args:
+            img_embeds (torch.Tensor): The embedding of the input images.
+            atts_img (torch.Tensor): Attention map of the image embeddings.
+            prompt (List[str]): The prompt of the batch data.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: The embedding and attention map.
+        """
+        if len(prompt) > 0:
+            p_before_list, p_after_list = [], []
+            for pro in prompt:
+                p_before, p_after = pro.split('<ImageHere>')
+                p_before_list.append(p_before)
+                p_after_list.append(p_after)
+            p_before_tokens = self.llama_tokenizer(
+                p_before_list,
+                return_tensors='pt',
+                padding='longest',
+                add_special_tokens=False).to(img_embeds.device)
+            p_after_tokens = self.llama_tokenizer(
+                p_after_list,
+                return_tensors='pt',
+                padding='longest',
+                add_special_tokens=False).to(img_embeds.device)
+            p_before_embeds = self.llama_model.model.embed_tokens(
+                p_before_tokens.input_ids)
+            p_after_embeds = self.llama_model.model.embed_tokens(
+                p_after_tokens.input_ids)
+            wrapped_img_embeds = torch.cat(
+                [p_before_embeds, img_embeds, p_after_embeds], dim=1)
+            wrapped_atts_img = atts_img[:, :1].expand(
+                -1, wrapped_img_embeds.shape[1])
+            return wrapped_img_embeds, wrapped_atts_img
+        else:
+            return img_embeds, atts_img
+
+    def loss(self,
+             images: torch.Tensor,
+             data_samples: Optional[List[DataSample]] = None) -> dict:
+        """The forward function in training.
+
+        Args:
+            images (torch.Tensor): The input images.
+            data_samples (List[DataSample]): All elements required
+                during the forward function.
+
+        Returns:
+            Dict[str, torch.Tensor]: A dictionary of loss components.
+        """
+        img_embeds, atts_img = self.encode_img(images)
+
+        self.llama_tokenizer.padding_side = 'right'
+
+        prompts, texts = [], []
+        for t in data_samples:
+            chat_content = t.chat_content
+            split_mark = '###Answer: ' if t.lang == 'en' else '###答:'
+            prompt, text = chat_content.split(split_mark)
+            prompt += split_mark
+            text += self.end_sym
+            prompts.append(prompt)
+            texts.append(text)
+
+        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)
+
+        to_regress_tokens = self.llama_tokenizer(
+            texts,
+            return_tensors='pt',
+            padding='longest',
+            truncation=True,
+            max_length=self.max_txt_len,
+            add_special_tokens=False).to(images.device)
+
+        targets = to_regress_tokens.input_ids.masked_fill(
+            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id,
+            -100)
+
+        empty_targets = (
+            torch.ones([atts_img.shape[0], atts_img.shape[1] + 1],
+                       dtype=torch.long).to(images.device).fill_(
+                           -100)  # plus one for bos
+        )
+        targets = torch.cat([empty_targets, targets], dim=1)
+
+        batch_size = img_embeds.shape[0]
+        bos = torch.ones([batch_size, 1],
+                         dtype=to_regress_tokens.input_ids.dtype,
+                         device=to_regress_tokens.input_ids.device
+                         ) * self.llama_tokenizer.bos_token_id
+        bos_embeds = self.llama_model.model.embed_tokens(bos)
+        atts_bos = atts_img[:, :1]
+
+        to_regress_embeds = self.llama_model.model.embed_tokens(
+            to_regress_tokens.input_ids)
+        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds],
+                                  dim=1)
+        attention_mask = torch.cat(
+            [atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
+
+        outputs = self.llama_model(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            return_dict=True,
+            labels=targets,
+        )
+        loss = outputs.loss
+        return dict(loss=loss)
+
+    def predict(
+        self,
+        images: torch.Tensor,
+        data_samples: Optional[List[DataSample]] = None
+    )
-> List[DataSample]: + + with torch.no_grad(): + img_embeds, atts_img = self.encode_img(images) + + prompts = [ + random.choice(self.zh_prompt_list) if hasattr(t, 'lang') + and t.lang == 'zh' else random.choice(self.en_prompt_list) + for t in data_samples + ] + img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts) + + batch_size = img_embeds.shape[0] + bos = torch.ones( + [batch_size, 1], dtype=torch.long, + device=img_embeds.device) * self.llama_tokenizer.bos_token_id + bos_embeds = self.llama_model.model.embed_tokens(bos) + inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1) + + outputs = self.llama_model.generate( + inputs_embeds=inputs_embeds, + eos_token_id=self.end_token_id, + **self.generation_cfg) + + return self.post_process(outputs, data_samples) + + def post_process( + self, outputs: torch.Tensor, + data_samples: Optional[List[DataSample]]) -> List[DataSample]: + """Perform post process for outputs for different task. + + Args: + outputs (torch.Tensor): The generated outputs. + data_samples (List[DataSample], optional): The annotation + data of every samples. + + Returns: + List[DataSample]: Return list of data samples. + """ + outputs = self.llama_tokenizer.batch_decode( + outputs, skip_special_tokens=True) + + if data_samples is None: + data_samples = [DataSample() for _ in range(len(outputs))] + + for output, data_sample in zip(outputs, data_samples): + if self.task == 'caption': + output = output.split('###')[0] + data_sample.pred_caption = output + else: + # raw output + data_sample.pred_output = output + return data_samples + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[list] = None, + mode: str = 'predict', + **kwargs, + ): + """The unified entry for a forward process in both training and test. + The method accepts the following modes: + + - "predict": Forward and return a list of data samples contain the + predict results. + + Args: + images (torch.Tensor): the preprocessed image tensor of shape + ``(N, C, H, W)``. + data_samples (List[DataSample], optional): The annotation data + of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to 'predict'. + """ + if mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + @staticmethod + def _load_llama_proj_hook(module, incompatible_keys): + """Avoid warning missing keys except LLaMA projection keys.""" + proj_patterns = [ + 'vision_encoder.*', + 'ln_vision.*', + 'q_former.*', + 'query_tokens', + 'llama_model.*', + ] + for key in list(incompatible_keys.missing_keys): + if any(re.match(pattern, key) for pattern in proj_patterns): + incompatible_keys.missing_keys.remove(key) diff --git a/mmpretrain/models/multimodal/ofa/__init__.py b/mmpretrain/models/multimodal/ofa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bcb3f45f09b757304bfca3de2a94d217ff78d8d4 --- /dev/null +++ b/mmpretrain/models/multimodal/ofa/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .ofa import OFA
+from .ofa_modules import OFADecoder, OFAEncoder, OFAEncoderDecoder
+
+__all__ = ['OFAEncoderDecoder', 'OFA', 'OFAEncoder', 'OFADecoder']
diff --git a/mmpretrain/models/multimodal/ofa/ofa.py b/mmpretrain/models/multimodal/ofa/ofa.py
new file mode 100644
index 0000000000000000000000000000000000000000..e15787a60d66ac56308b320cdd73a7703a2a29bc
--- /dev/null
+++ b/mmpretrain/models/multimodal/ofa/ofa.py
@@ -0,0 +1,320 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import string
+from collections import defaultdict
+from functools import partial
+from typing import Optional, Union
+
+import mmengine
+import torch
+from mmengine.model import BaseModel
+
+from mmpretrain.datasets import CleanCaption
+from mmpretrain.registry import MODELS, TOKENIZER
+from mmpretrain.structures import DataSample
+from .ofa_modules import OFAEncoderDecoder
+
+
+class TreeNode:
+
+    def __init__(self):
+        self.child = defaultdict(TreeNode)
+
+
+class Trie:
+
+    def __init__(self, eos):
+        self.root = TreeNode()
+        self.eos = eos
+
+    def insert(self, word):
+        cur = self.root
+        for c in word:
+            cur = cur.child[c]
+
+    def get_next_layer(self, word):
+        cur = self.root
+        for c in word:
+            cur = cur.child.get(c)
+            if cur is None:
+                return [self.eos]
+        return list(cur.child.keys())
+
+
+def apply_constraint(
+    input_ids: torch.Tensor,
+    logits: torch.Tensor,
+    decoder_prompts: Optional[list],
+    num_beams: int,
+    constraint_trie: Trie = None,
+):
+    if decoder_prompts is None and constraint_trie is None:
+        return logits
+
+    mask = logits.new_zeros(logits[:, -1, :].size(), dtype=torch.bool)
+    input_ids = input_ids.view(-1, num_beams, input_ids.shape[-1])
+    for batch_id, beam_sent in enumerate(input_ids):
+        for beam_id, sent in enumerate(beam_sent):
+            if decoder_prompts is None:
+                prompt_len = 0
+            else:
+                prompt_len = len(decoder_prompts[batch_id])
+
+            if sent.size(0) - 1 < prompt_len:
+                allowed_tokens = [decoder_prompts[batch_id][sent.size(0) - 1]]
+                mask[batch_id * num_beams + beam_id, allowed_tokens] = True
+            elif constraint_trie is not None:
+                answer_tokens = [0] + sent[prompt_len + 1:].tolist()
+                allowed_tokens = constraint_trie.get_next_layer(answer_tokens)
+                mask[batch_id * num_beams + beam_id, allowed_tokens] = True
+            else:
+                mask[batch_id * num_beams + beam_id, :] = True
+    logits[:, -1, :].masked_fill_(~mask, float('-inf'))
+    return logits
+
+
+@MODELS.register_module()
+class OFA(BaseModel):
+    """The OFA model for multiple tasks.
+
+    Args:
+        encoder_cfg (dict): The config of the encoder, accept the keyword
+            arguments of :class:`OFAEncoder`.
+        decoder_cfg (dict): The config of the decoder, accept the keyword
+            arguments of :class:`OFADecoder`.
+        vocab_size (int): The size of the vocabulary.
+        embedding_dim (int): The embedding dimensions of both the encoder
+            and the decoder.
+        tokenizer (dict | PreTrainedTokenizer): The tokenizer to encode
+            the text.
+        task (str): The task name; supported tasks are "caption", "vqa" and
+            "refcoco".
+        prompt (str, optional): The prompt template for the following tasks.
+            If None, use the default prompt:
+
+            - **caption**: ' what does the image describe?'
+            - **refcoco**: ' which region does the text " {} " describe?'
+
+            Defaults to None.
+        ans2label (str | Sequence | None): The answer-to-label mapping for
+            the vqa task. If a string, it should be a pickle or json file.
+            The sequence constrains the output answers. Defaults to None,
+            which means no constraint.
+        generation_cfg (dict): The extra generation config, accept the keyword
+            arguments of :class:`~transformers.GenerationConfig`.
+            Defaults to an empty dict.
+        data_preprocessor (dict, optional): The config for preprocessing input
+            data. If None or no specified type, it will use
+            "MultiModalDataPreprocessor" as type. See :class:
+            `MultiModalDataPreprocessor` for more details. Defaults to None.
+        init_cfg (dict, optional): The initialization config. Defaults to None.
+    """
+    support_tasks = {'caption', 'vqa', 'refcoco'}
+
+    def __init__(
+        self,
+        encoder_cfg,
+        decoder_cfg,
+        vocab_size,
+        embedding_dim,
+        tokenizer,
+        task,
+        prompt=None,
+        ans2label: Union[dict, str, None] = None,
+        generation_cfg=dict(),
+        data_preprocessor: Optional[dict] = None,
+        init_cfg=None,
+    ):
+        if data_preprocessor is None:
+            data_preprocessor = {}
+        if isinstance(data_preprocessor, dict):
+            data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
+            data_preprocessor = MODELS.build(data_preprocessor)
+
+        super().__init__(
+            init_cfg=init_cfg, data_preprocessor=data_preprocessor)
+
+        if isinstance(tokenizer, dict):
+            self.tokenizer = TOKENIZER.build(tokenizer)
+        else:
+            self.tokenizer = tokenizer
+
+        if task not in self.support_tasks:
+            raise ValueError(f'Unsupported task {task}, please select '
+                             f'the task from {self.support_tasks}.')
+
+        self.prompt = prompt
+        self.task = task
+
+        if isinstance(ans2label, str):
+            self.ans2label = mmengine.load(ans2label)
+        else:
+            self.ans2label = ans2label
+
+        if self.task == 'vqa' and self.ans2label is not None:
+            self.constraint_trie = Trie(eos=self.tokenizer.eos_token_id)
+            answers = [f' {answer}' for answer in self.ans2label]
+            answer_tokens = self.tokenizer(answers, padding=False)
+            for answer_token in answer_tokens['input_ids']:
+                self.constraint_trie.insert(answer_token)
+        else:
+            self.constraint_trie = None
+
+        generation_cfg = {
+            'num_beams': 5,
+            'max_new_tokens': 20,
+            'no_repeat_ngram_size': 3,
+            **generation_cfg,
+        }
+        self.model = OFAEncoderDecoder(
+            encoder_cfg=encoder_cfg,
+            decoder_cfg=decoder_cfg,
+            padding_idx=self.tokenizer.pad_token_id,
+            vocab_size=vocab_size,
+            embedding_dim=embedding_dim,
+            generation_cfg=generation_cfg,
+        )
+
+    def forward(
+        self,
+        images: torch.Tensor,
+        data_samples: Optional[list] = None,
+        mode: str = 'predict',
+        **kwargs,
+    ):
+        """The unified entry for a forward process in both training and test.
+        The method accepts the following modes:
+
+        - "predict": Forward and return a list of data samples containing
+          the prediction results.
+
+        Args:
+            images (torch.Tensor): the preprocessed image tensor of shape
+                ``(N, C, H, W)``.
+            data_samples (List[DataSample], optional): The annotation data
+                of each sample. Defaults to None.
+            mode (str): Return what kind of value. Defaults to 'predict'.
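+
+        A minimal sketch of the predict entry (``model`` here is assumed to
+        be a fully built OFA captioning model; the decoded caption is
+        illustrative):
+
+        Examples:
+            >>> import torch
+            >>> from mmpretrain.structures import DataSample
+            >>> images = torch.rand(1, 3, 480, 480)
+            >>> samples = model(images, [DataSample()], mode='predict')  # doctest: +SKIP
+            >>> samples[0].pred_caption  # doctest: +SKIP
+            'an example caption'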
+ """ + if mode == 'predict': + return self.predict(images, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def predict( + self, + images, + data_samples=None, + post_process=True, + **generation_config, + ): + text_tokens = self.preprocess_text(data_samples, images.size(0), + images.device) + + if 'images_mask' in data_samples[0]: + images_mask = torch.tensor([ + sample.get('images_mask') for sample in data_samples + ]).bool().to(images.device) + else: + images_mask = None + + num_beams = generation_config.get( + 'num_beams', getattr(self.model.generation_config, 'num_beams')) + decoder_prompts = self.get_decoder_prompts(data_samples) + constrain_fn = partial( + apply_constraint, + constraint_trie=self.constraint_trie, + decoder_prompts=decoder_prompts, + num_beams=num_beams, + ) + + outputs = self.model.generate( + input_ids=text_tokens, + images=images, + images_mask=images_mask, + constrain_fn=constrain_fn, + **generation_config, + ) + + if decoder_prompts is not None: + # Remove the prefix decoder prompt. + for prompt_ids, token in zip(decoder_prompts, outputs): + token[1:len(prompt_ids) + 1] = self.tokenizer.pad_token_id + + if post_process: + return self.post_process(outputs, data_samples) + else: + return outputs + + def get_decoder_prompts(self, data_samples): + decoder_prompts = [] + if 'decoder_prompt' not in data_samples[0]: + return None + for sample in data_samples: + prompt = ' ' + sample.get('decoder_prompt') + prompt_ids = self.tokenizer(prompt, add_special_tokens=False) + prompt_ids = prompt_ids['input_ids'] + decoder_prompts.append(prompt_ids) + return decoder_prompts + + def preprocess_text(self, data_samples, batch_size, device): + if self.task == 'caption': + prompt = self.prompt or ' what does the image describe?' + prompts = [prompt] * batch_size + prompts = self.tokenizer(prompts, return_tensors='pt') + return prompts.input_ids.to(device) + elif self.task == 'vqa': + prompts = [] + for sample in data_samples: + assert 'question' in sample + prompt = ' ' + sample.get('question') + prompts.append(prompt) + prompts = self.tokenizer( + prompts, return_tensors='pt', padding=True) + return prompts.input_ids.to(device) + elif self.task == 'refcoco': + prompt_template = self.prompt or \ + ' which region does the text " {} " describe?' + prompts = [] + for sample in data_samples: + assert 'text' in sample + prompt = prompt_template.format(sample.get('text')) + prompts.append(prompt) + prompts = self.tokenizer( + prompts, return_tensors='pt', padding=True) + return prompts.input_ids.to(device) + + def post_process(self, outputs, data_samples): + + out_data_samples = [] + if data_samples is None: + data_samples = [None] * outputs.size(0) + + for data_sample, token in zip(data_samples, outputs): + if data_sample is None: + data_sample = DataSample() + + if self.task == 'caption': + text = self.tokenizer.decode(token, skip_special_tokens=True) + text = CleanCaption( + lowercase=False, + remove_chars=string.punctuation).clean(text) + data_sample.pred_caption = text + elif self.task == 'vqa': + text = self.tokenizer.decode(token, skip_special_tokens=True) + data_sample.pred_answer = text.strip() + elif self.task == 'refcoco': + bbox = token[1:5] - self.tokenizer.bin_offset + # During training, the bbox is normalized by 512. It's related + # to the `max_image_size` config in the official repo. 
+ bbox = bbox / self.tokenizer.num_bins * 512 + scale_factor = data_sample.get('scale_factor', (1, 1)) + bbox[0::2] /= scale_factor[0] + bbox[1::2] /= scale_factor[1] + data_sample.pred_bboxes = bbox.unsqueeze(0) + if 'gt_bboxes' in data_sample: + gt_bboxes = bbox.new_tensor(data_sample.gt_bboxes) + gt_bboxes[:, 0::2] /= scale_factor[0] + gt_bboxes[:, 1::2] /= scale_factor[1] + data_sample.gt_bboxes = gt_bboxes + out_data_samples.append(data_sample) + + return out_data_samples diff --git a/mmpretrain/models/multimodal/ofa/ofa_modules.py b/mmpretrain/models/multimodal/ofa/ofa_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..ef5c8533755739fb6b9f01211cbf10032544bf8b --- /dev/null +++ b/mmpretrain/models/multimodal/ofa/ofa_modules.py @@ -0,0 +1,1613 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from dataclasses import dataclass +from functools import partial +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule +from mmengine.utils import digit_version +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, ModelOutput, Seq2SeqLMOutput) +from transformers.modeling_utils import (GenerationConfig, GenerationMixin, + PretrainedConfig) + +from mmpretrain.registry import MODELS +from ...backbones.resnet import Bottleneck, ResNet + +if digit_version(torch.__version__) >= digit_version('1.10.0'): + torch_meshgrid = partial(torch.meshgrid, indexing='ij') +else: + torch_meshgrid = torch.meshgrid + + +def make_token_bucket_position(bucket_size, max_position=1024): + context_pos = torch.arange(max_position, dtype=torch.long)[:, None] + memory_pos = torch.arange(max_position, dtype=torch.long)[None, :] + relative_pos = context_pos - memory_pos + sign = torch.sign(relative_pos) + mid = bucket_size // 2 + abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), + mid - 1, torch.abs(relative_pos)) + log_pos = torch.ceil( + torch.log(abs_pos / mid) / math.log( + (max_position - 1) / mid) * (mid - 1)) + mid + log_pos = log_pos.int() + bucket_pos = torch.where(abs_pos.le(mid), relative_pos, + log_pos * sign).long() + return bucket_pos + bucket_size - 1 + + +def make_image_bucket_position(bucket_size, num_relative_distance): + coords_h = torch.arange(bucket_size) + coords_w = torch.arange(bucket_size) + # (2, h, w) + coords = torch.stack(torch_meshgrid([coords_h, coords_w])) + # (2, h*w) + coords_flatten = torch.flatten(coords, 1) + # (2, h*w, h*w) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + # (h*w, h*w, 2) + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += bucket_size - 1 # shift to start from 0 + relative_coords[:, :, 1] += bucket_size - 1 + relative_coords[:, :, 0] *= 2 * bucket_size - 1 + relative_position_index = torch.zeros( + size=(bucket_size * bucket_size + 1, ) * 2, + dtype=relative_coords.dtype) + # (h*w, h*w) + relative_position_index[1:, 1:] = relative_coords.sum(-1) + relative_position_index[0, 0:] = num_relative_distance - 3 + relative_position_index[0:, 0] = num_relative_distance - 2 + relative_position_index[0, 0] = num_relative_distance - 1 + return relative_position_index + + +def _make_causal_mask(input_ids_shape: torch.Size, + dtype: torch.dtype, + past_key_values_length: int = 0): + """Make causal mask used for uni-directional self-attention.""" + bsz, tgt_len = input_ids_shape + mask 
= torch.full((tgt_len, tgt_len), float('-inf'))
+    mask_cond = torch.arange(mask.size(-1))
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat(
+            [torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask],
+            dim=-1)
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len,
+                                         tgt_len + past_key_values_length)
+
+
+def _expand_mask(mask: torch.Tensor,
+                 dtype: torch.dtype,
+                 tgt_len: Optional[int] = None):
+    """Expands attention_mask from ``[B, L_s]`` to ``[B, 1, L_t, L_s]``.
+
+    Where ``B`` is the batch size, ``L_s`` is the source sequence length, and
+    ``L_t`` is the target sequence length.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len,
+                                                  src_len).to(dtype)
+    return expanded_mask.masked_fill(expanded_mask.bool(),
+                                     torch.finfo(dtype).min)
+
+
+class MultiheadAttention(BaseModule):
+    """Multi-head Attention Module for OFA.
+
+    Args:
+        embedding_dim (int): The embedding dimension of query.
+        num_heads (int): Parallel attention heads.
+        kdim (int, optional): The embedding dimension of key.
+            Defaults to None, which means the same as the `embedding_dim`.
+        vdim (int, optional): The embedding dimension of value.
+            Defaults to None, which means the same as the `embedding_dim`.
+        attn_drop (float): Dropout rate of the dropout layer after the
+            attention calculation of query and key. Defaults to 0.
+        qkv_bias (bool): If True, add a learnable bias to q, k, v.
+            Defaults to True.
+        scale_factor (float): The scale of qk will be
+            ``(head_dim * scale_factor) ** -0.5``. Defaults to 1.
+        proj_bias (bool): If True, add a learnable bias to output projection.
+            Defaults to True.
+        scale_heads (bool): If True, learn a scaling coefficient for each
+            attention head. Defaults to False.
+        init_cfg (dict, optional): The Config for initialization.
+            Defaults to None.
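+
+    A small shape sketch (illustrative sizes, not from any OFA config):
+
+    Examples:
+        >>> import torch
+        >>> attn = MultiheadAttention(embedding_dim=768, num_heads=12)
+        >>> query = torch.rand(2, 10, 768)
+        >>> out, weights, cache = attn(query, output_attentions=True)
+        >>> out.shape, weights.shape
+        (torch.Size([2, 10, 768]), torch.Size([2, 12, 10, 10]))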
+ """ + + def __init__(self, + embedding_dim, + num_heads, + kdim=None, + vdim=None, + attn_drop=0., + scale_factor=1., + qkv_bias=True, + proj_bias=True, + scale_heads=False, + init_cfg=None): + super(MultiheadAttention, self).__init__(init_cfg=init_cfg) + + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.kdim = kdim or embedding_dim + self.vdim = vdim or embedding_dim + + self.head_dim = embedding_dim // num_heads + self.scale = (self.head_dim * scale_factor)**-0.5 + + self.q_proj = nn.Linear(embedding_dim, embedding_dim, bias=qkv_bias) + self.k_proj = nn.Linear(self.kdim, embedding_dim, bias=qkv_bias) + self.v_proj = nn.Linear(self.vdim, embedding_dim, bias=qkv_bias) + self.out_proj = nn.Linear(embedding_dim, embedding_dim, bias=proj_bias) + + self.attn_drop = nn.Dropout(p=attn_drop) + + if scale_heads: + self.c_attn = nn.Parameter(torch.ones(num_heads)) + else: + self.c_attn = None + + def forward( + self, + query, + key_value=None, + attn_mask=None, + attn_bias=None, + past_key_value=None, + output_attentions=False, + ): + B, _, C = query.shape + assert C == self.head_dim * self.num_heads + + is_cross_attention = key_value is not None + if key_value is None: + key_value = query + + # (B, L, C) -> (B, num_heads, L, head_dims) + q = self.q_proj(query).reshape(B, -1, self.num_heads, + self.head_dim).transpose(1, 2) + + if is_cross_attention and past_key_value is not None: + # Reuse key and value in cross_attentions + k, v = past_key_value + else: + k = self.k_proj(key_value).reshape(B, -1, self.num_heads, + self.head_dim).transpose(1, 2) + v = self.v_proj(key_value).reshape(B, -1, self.num_heads, + self.head_dim).transpose(1, 2) + if past_key_value is not None: + past_key, past_value = past_key_value + k = torch.cat([past_key, k], dim=2) + v = torch.cat([past_value, v], dim=2) + + past_key_value = (k, v) + + attn_weights = q @ k.transpose(-2, -1) * self.scale + + if attn_bias is not None: + src_len = k.size(2) + attn_weights[:, :, -src_len:] += attn_bias[:, :, -src_len:] + + if attn_mask is not None: + attn_weights += attn_mask + attn_weights = torch.softmax(attn_weights, dim=-1) + attn = self.attn_drop(attn_weights) @ v + + if self.c_attn is not None: + attn = torch.einsum('bhlc,h->bhlc', attn, self.c_attn) + + # (B, num_heads, L, head_dims) -> (B, L, C) + attn = attn.transpose(1, 2).reshape(B, -1, self.embedding_dim) + attn = self.out_proj(attn) + + if output_attentions: + return attn, attn_weights, past_key_value + else: + return attn, None, past_key_value + + +@MODELS.register_module(force=True) +class OFAResNet(ResNet): + """ResNet module for OFA. + + The ResNet in OFA has only three stages. + """ + arch_settings = { + 50: (Bottleneck, (3, 4, 6)), + 101: (Bottleneck, (3, 4, 23)), + 152: (Bottleneck, (3, 8, 36)), + } + + def __init__(self, depth, *args, **kwargs): + super().__init__( + depth=depth, + *args, + num_stages=3, + out_indices=(2, ), + dilations=(1, 1, 1), + strides=(1, 2, 2), + **kwargs) + + +@dataclass +class OFAEncoderOutput(ModelOutput): + """OFA encoder outputs. + + Args: + last_hidden_state (torch.tensor): The hidden-states of the output at + the last layer of the model. The shape is (B, L, C). + hidden_states (Tuple[torch.tensor]): The initial embedding and the + output of each layer. The shape of every item is (B, L, C). + attentions (Tuple[torch.tensor]): The attention weights after the + attention softmax, used to compute the weighted average in the + self-attention heads. The shape of every item is + (B, num_heads, L, L). 
+ position_embedding (torch.tensor): The positional embeddings of the + inputs. The shape is (B, L, C). + """ + + last_hidden_state: torch.FloatTensor = None + padding_mask: torch.Tensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + position_embedding: Optional[torch.FloatTensor] = None + + +class OFAEncoderLayer(nn.Module): + """OFAEncoder layer block.""" + + def __init__(self, + embedding_dim, + num_heads, + dropout_rate=0., + drop_path_rate=0., + attn_drop=0., + act_drop=0., + scale_factor=2., + mlp_ratio=4., + scale_heads=True, + normformer=True, + pre_norm=True, + act_cfg=dict(type='GELU')): + super().__init__() + self.embedding_dim = embedding_dim + self.pre_norm = pre_norm + + self.attn = MultiheadAttention( + embedding_dim=embedding_dim, + num_heads=num_heads, + attn_drop=attn_drop, + scale_factor=scale_factor, + scale_heads=scale_heads, + ) + + mid_channels = int(embedding_dim * mlp_ratio) + self.fc1 = nn.Linear(embedding_dim, mid_channels) + self.fc2 = nn.Linear(mid_channels, embedding_dim) + self.act = MODELS.build(act_cfg) + self.act_drop = nn.Dropout( + act_drop) if act_drop > 0. else nn.Identity() + + # LayerNorm between attention block and ffn block. + self.attn_ln = nn.LayerNorm(embedding_dim) + self.ffn_ln = nn.LayerNorm(embedding_dim) + + # Extra LayerNorm + self.normformer = normformer + if self.normformer: + self.attn_mid_ln = nn.LayerNorm(embedding_dim) + self.ffn_mid_ln = nn.LayerNorm(mid_channels) + + self.dropout = nn.Dropout(dropout_rate) + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + def forward(self, + x, + attention_mask=None, + attn_bias=None, + output_attentions=False): + """Forward the encoder layer. + + Args: + x (torch.tensor): The input to the layer of shape ``(B, L, C)``. + attention_mask (torch.Tensor, optional): The attention mask of size + ``(B, 1, L, L)``, where padding elements are indicated by very + large negative values. Defaults to None. + attn_bias (torch.tensor, optional): The bias for positional + information. Defaults to None. + output_attentions (bool): Whether to return the attentions tensors + of the attention layer. + + Returns: + List[torch.tensor]: The first element is the encoded output of + shape ``(B, L, C)``. And the second element is the output + attentions if ``output_attentions=True``. 
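+
+        A rough shape sketch (illustrative sizes):
+
+        Examples:
+            >>> import torch
+            >>> layer = OFAEncoderLayer(embedding_dim=768, num_heads=12)
+            >>> x = torch.rand(2, 16, 768)
+            >>> x, attn = layer(x, output_attentions=True)
+            >>> x.shape, attn.shape
+            (torch.Size([2, 16, 768]), torch.Size([2, 12, 16, 16]))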
+ """ + residual = x + + # Attention block + if self.pre_norm: + x = self.attn_ln(x) + x, attn_weights, _ = self.attn( + query=x, + attn_mask=attention_mask, + attn_bias=attn_bias, + output_attentions=output_attentions) + if self.normformer: + x = self.attn_mid_ln(x) + x = self.dropout(x) + x = residual + self.drop_path(x) + if not self.pre_norm: + x = self.attn_ln(x) + + residual = x + + # FFN block + if self.pre_norm: + x = self.ffn_ln(x) + x = self.act(self.fc1(x)) + x = self.act_drop(x) + if self.normformer: + x = self.ffn_mid_ln(x) + x = self.fc2(x) + x = self.dropout(x) + x = residual + self.drop_path(x) + if not self.pre_norm: + x = self.ffn_ln(x) + + if output_attentions: + return [x, attn_weights] + else: + return [x] + + +class OFADecoderLayer(nn.Module): + """OFADecoder layer block.""" + + def __init__(self, + embedding_dim, + num_heads, + dropout_rate=0., + drop_path_rate=0., + attn_drop=0., + act_drop=0., + scale_factor=2., + mlp_ratio=4., + encoder_embed_dim=None, + scale_heads=True, + normformer=True, + pre_norm=True, + act_cfg=dict(type='GELU')): + super().__init__() + self.embedding_dim = embedding_dim + self.pre_norm = pre_norm + + self.self_attn = MultiheadAttention( + embedding_dim=embedding_dim, + num_heads=num_heads, + attn_drop=attn_drop, + scale_factor=scale_factor, + scale_heads=scale_heads, + ) + + self.cross_attn = MultiheadAttention( + embedding_dim=embedding_dim, + kdim=encoder_embed_dim, + vdim=encoder_embed_dim, + num_heads=num_heads, + attn_drop=attn_drop, + scale_factor=scale_factor, + scale_heads=scale_heads, + ) + + mid_channels = int(embedding_dim * mlp_ratio) + self.fc1 = nn.Linear(embedding_dim, mid_channels) + self.fc2 = nn.Linear(mid_channels, embedding_dim) + self.act = MODELS.build(act_cfg) + self.act_drop = nn.Dropout( + act_drop) if act_drop > 0. else nn.Identity() + + # LayerNorm between attention block and ffn block. + self.self_attn_ln = nn.LayerNorm(embedding_dim) + self.cross_attn_ln = nn.LayerNorm(embedding_dim) + self.ffn_ln = nn.LayerNorm(embedding_dim) + + # Extra LayerNorm + self.normformer = normformer + if self.normformer: + self.self_attn_mid_ln = nn.LayerNorm(embedding_dim) + self.cross_attn_mid_ln = nn.LayerNorm(embedding_dim) + self.ffn_mid_ln = nn.LayerNorm(mid_channels) + + self.dropout = nn.Dropout(dropout_rate) + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + def forward( + self, + x, + attention_mask=None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[List[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + self_attn_bias: Optional[torch.Tensor] = None, + cross_attn_bias: Optional[torch.Tensor] = None, + ): + """Forward the decoder layer. + + Args: + x (torch.tensor): The input to the layer of shape ``(B, L, C)``. + attention_mask (torch.Tensor, optional): The attention mask of size + ``(B, 1, L, L)``, where padding elements are indicated by very + large negative values. Defaults to None. + encoder_hidden_states (torch.Tensor, optional): The cross attention + input to the layer of size ``(B, L, C)``. Defaults to None. + encoder_attention_mask (torch.Tensor, optional): The cross + attention mask where padding elements are indicated by very + large negative values. Defaults to None. + past_key_value (Tuple[torch.tensor], optional): The cached past key + and value projection states. Defaults to none. 
+ output_attentions (bool): whether to return the attentions tensors + of all attention layers. Defaults to False. + use_cache (bool, optional): Whether to use cache. + Defaults to False. + self_attn_bias (torch.Tensor, optional): The self attention bias + for positional information. Defaults to None. + cross_attn_bias (torch.Tensor, optional): The cross attention bias + for positional information. Defaults to None. + + Returns: + List[torch.tensor]: The first element is the encoded output of + shape ``(B, L, C)``. The following two elements can be the output + self-attentions and cross-attentions if ``output_attentions=True``. + The following one element can be the cached past key and value + projection states. + """ + residual = x + + if past_key_value is not None: + self_past_key_value = past_key_value[:2] + cross_past_key_value = past_key_value[2:] + else: + self_past_key_value, cross_past_key_value = None, None + + # Self-Attention block + if self.pre_norm: + x = self.self_attn_ln(x) + x, self_attn_weights, present_key_value = self.self_attn( + query=x, + past_key_value=self_past_key_value, + attn_mask=attention_mask, + output_attentions=output_attentions, + attn_bias=self_attn_bias, + ) + if self.normformer: + x = self.self_attn_mid_ln(x) + x = self.dropout(x) + x = residual + self.drop_path(x) + if not self.pre_norm: + x = self.self_attn_ln(x) + + # Cross-Attention block + if encoder_hidden_states is not None: + residual = x + if self.pre_norm: + x = self.cross_attn_ln(x) + x, cross_attn_weights, cross_key_value = self.cross_attn.forward( + query=x, + key_value=encoder_hidden_states, + attn_mask=encoder_attention_mask, + past_key_value=cross_past_key_value, + output_attentions=output_attentions, + attn_bias=cross_attn_bias) + if self.normformer: + x = self.cross_attn_mid_ln(x) + x = self.dropout(x) + x = residual + self.drop_path(x) + if not self.pre_norm: + x = self.cross_attn_ln(x) + + present_key_value = present_key_value + cross_key_value + + residual = x + + # FFN block + if self.pre_norm: + x = self.ffn_ln(x) + x = self.act(self.fc1(x)) + x = self.act_drop(x) + if self.normformer: + x = self.ffn_mid_ln(x) + x = self.fc2(x) + x = self.dropout(x) + x = residual + self.drop_path(x) + if not self.pre_norm: + x = self.ffn_ln(x) + + outputs = [x] + + if output_attentions: + outputs.extend([self_attn_weights, cross_attn_weights]) + + if use_cache: + outputs.append(present_key_value) + + return outputs + + +class OFAEncoder(BaseModule): + """The encoder module of OFA. + + Args: + embed_tokens (nn.Embedding): The embedding module to embed the + input tokens. + embed_images (dict | nn.Module): The module to embed the input + images into features. The output number of channels should + be 1024. + num_layers (int): The number of encoder layers. Defaults to 6. + num_heads (int): The number of heads of attention. Defaults to 12. + dropout_rate (float): The prob of dropout for embedding and + transformer layers. Defaults to 0. + drop_path_rate (float): The prob of droppath for transformer layers. + Defaults to 0. + max_source_positions (int): The maximum length of the input tokens. + Defaults to 1024. + token_bucket_size (int): The token bucket size, it's used as the + maximum relative position index in relative position embedding + of input tokens. Defaults to 256. + image_bucket_size (int): The image bucket size, it's used to generate + the image relative position embedding table. It should be larger + than the shape of image feature map. Defaults to 42. 
+ attn_scale_factor (float): The scale factor to calculate qk scale in + attentions. Defaults to 2. + scale_embedding (bool): Whether to scale the embeddings by the square + root of the dimension. Defaults to False. + add_embedding_ln (bool): Whether to add an extra layer norm for token + embeddings. Defaults to True. + add_image_embedding_ln (bool): Whether to add an extra layer norm for + image embeddings. Defaults to True. + pre_norm (bool): Whether to do layer norm before attention and ffn + blocks in transformer layers. Defaults to True. + entangle_position_embedding (bool): Whether to add the position + embedding on the embeddings directly. Defaults to False. + init_cfg (dict, optional): The initialization config. Defaults to None. + """ + + def __init__( + self, + embed_tokens, + embed_images: dict, + num_layers=6, + num_heads=12, + dropout_rate=0., + drop_path_rate=0., + max_source_positions=1024, + token_bucket_size=256, + image_bucket_size=42, + attn_scale_factor=2., + scale_embedding=False, + add_embedding_ln=True, + add_type_embed=True, + add_image_embedding_ln=True, + pre_norm=True, + entangle_position_embedding=False, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + + self.num_layers = num_layers + embedding_dim = embed_tokens.embedding_dim + self.embedding_dim = embedding_dim + self.padding_idx = embed_tokens.padding_idx + self.max_source_positions = max_source_positions + self.num_heads = num_heads + + # Build embedding process components + self.embed_tokens = embed_tokens + self.embedding_scale = math.sqrt( + embedding_dim) if scale_embedding else 1.0 + + if not isinstance(embed_images, nn.Module): + self.embed_images = MODELS.build(embed_images) + else: + self.embed_images = embed_images + self.image_proj = nn.Linear(1024, embedding_dim) + + if add_embedding_ln: + self.embedding_ln = nn.LayerNorm(embedding_dim) + else: + self.embedding_ln = None + + if add_type_embed: + self.embed_type = nn.Embedding(2, embedding_dim) + else: + self.embed_type = None + + if add_image_embedding_ln: + self.image_embedding_ln = nn.LayerNorm(embedding_dim) + else: + self.image_embedding_ln = None + + self.entangle_position_embedding = entangle_position_embedding + + # Build position embedding + self.embed_positions = nn.Embedding(self.max_source_positions + 2, + embedding_dim) + self.pos_ln = nn.LayerNorm(embedding_dim) + self.embed_image_positions = nn.Embedding(image_bucket_size**2 + 1, + embedding_dim) + self.image_pos_ln = nn.LayerNorm(embedding_dim) + + self.pos_scaling = float(embedding_dim / num_heads * + attn_scale_factor)**-0.5 + self.pos_q_linear = nn.Linear(embedding_dim, embedding_dim) + self.pos_k_linear = nn.Linear(embedding_dim, embedding_dim) + + self.dropout = nn.Dropout( + dropout_rate) if dropout_rate > 0. 
else nn.Identity() + + # Register token relative position embedding table + self.token_bucket_size = token_bucket_size + token_num_rel_dis = 2 * token_bucket_size - 1 + token_rp_bucket = make_token_bucket_position(token_bucket_size, + self.max_source_positions) + self.register_buffer('token_rp_bucket', token_rp_bucket) + self.token_rel_pos_table_list = nn.ModuleList() + + # Register image relative position embedding table + self.image_bucket_size = image_bucket_size + image_num_rel_dis = (2 * image_bucket_size - + 1) * (2 * image_bucket_size - 1) + 3 + image_rp_bucket = make_image_bucket_position(image_bucket_size, + image_num_rel_dis) + self.register_buffer('image_rp_bucket', image_rp_bucket) + self.image_rel_pos_table_list = nn.ModuleList() + + # Build encoder layers + self.layers = nn.ModuleList() + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] + for index in range(self.num_layers): + layer = OFAEncoderLayer( + embedding_dim=embedding_dim, + num_heads=num_heads, + dropout_rate=dropout_rate, + drop_path_rate=dpr[index], + scale_factor=attn_scale_factor, + pre_norm=pre_norm, + ) + self.layers.append(layer) + token_pos_table = nn.Embedding(token_num_rel_dis, self.num_heads) + image_pos_table = nn.Embedding(image_num_rel_dis, self.num_heads) + nn.init.constant_(token_pos_table.weight, 0.) + nn.init.constant_(image_pos_table.weight, 0.) + self.token_rel_pos_table_list.append(token_pos_table) + self.image_rel_pos_table_list.append(image_pos_table) + + if pre_norm: + self.final_ln = nn.LayerNorm(embedding_dim) + else: + self.final_ln = None + + main_input_name = 'input_ids' + + def forward(self, + input_ids, + images, + images_mask, + output_attentions=False, + output_hidden_states=False, + sample_patch_num=None): + padding_mask = input_ids.eq(self.padding_idx) + has_pads = padding_mask.any() + token_embedding = self.embed_tokens(input_ids) + token_embedding = self.embedding_scale * token_embedding + + # Embed the token position + src_pos_idx = torch.arange(input_ids.size(-1), device=input_ids.device) + src_pos_idx = src_pos_idx.expand(*input_ids.shape).contiguous() + pos_embedding = self.embed_positions(src_pos_idx) + + # Embed the input tokens + x = self.process_embedding( + embedding=token_embedding, + type_tokens=input_ids.new_zeros(token_embedding.shape[:2]), + pos_embedding=pos_embedding, + embedding_ln=self.embedding_ln, + ) + pos_embedding = self.pos_ln(pos_embedding) + + # Embed the input images + if images is not None: + (image_tokens, image_padding_mask, image_position_ids, + image_pos_embedding) = self.get_image_tokens( + images, + sample_patch_num, + images_mask, + ) + image_embedding = self.image_proj(image_tokens) + + image_x = self.process_embedding( + embedding=image_embedding, + type_tokens=input_ids.new_ones(image_embedding.shape[:2]), + pos_embedding=image_pos_embedding, + embedding_ln=self.image_embedding_ln, + ) + image_pos_embedding = self.image_pos_ln(image_pos_embedding) + + x = torch.cat([image_x, x], dim=1) + padding_mask = torch.cat([image_padding_mask, padding_mask], dim=1) + pos_embedding = torch.cat([image_pos_embedding, pos_embedding], + dim=1) + + # account for padding while computing the representation + if has_pads: + x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) + + # Decoupled position embedding + B, L = pos_embedding.shape[:2] + pos_q = self.pos_q_linear(pos_embedding).view( + B, L, self.num_heads, -1).transpose(1, 2) * self.pos_scaling + pos_k = self.pos_k_linear(pos_embedding).view(B, L, self.num_heads, + -1).transpose(1, 
2) + abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + + all_hidden_states = [] if output_hidden_states else None + all_attentions = [] if output_attentions else None + + for idx, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states.append(x) + + self_attn_bias = abs_pos_bias.clone() + # Add decoupled position embedding for input tokens. + token_len = input_ids.size(1) + rel_pos_bias = self.get_rel_pos_bias(input_ids, idx) + self_attn_bias[:, :, -token_len:, -token_len:] += rel_pos_bias + + # Add decoupled position embedding for images + if images is not None: + token_len = image_tokens.size(1) + rel_pos_bias = self.get_image_rel_pos_bias( + image_position_ids, idx) + self_attn_bias[:, :, :token_len, :token_len] += rel_pos_bias + + if has_pads: + attention_mask = _expand_mask(padding_mask, dtype=x.dtype) + else: + attention_mask = None + + out = layer( + x, + attention_mask=attention_mask, + attn_bias=self_attn_bias, + output_attentions=output_attentions) + x = out[0] + + if output_attentions: + all_attentions.append(out[1]) + + if output_hidden_states: + all_hidden_states.append(x) + + if self.final_ln is not None: + x = self.final_ln(x) + + return OFAEncoderOutput( + last_hidden_state=x, # (B, L, C) + padding_mask=padding_mask, # (B, L) + position_embedding=pos_embedding, # (B, L, C) + hidden_states=all_hidden_states, # list of (B, L, C) + attentions=all_attentions, # list of (B, num_heads, L, head_dims) + ) + + def get_image_tokens(self, images, sample_patch_num, images_mask): + image_embedding = self.embed_images(images)[-1] + B, C, H, W = image_embedding.shape + num_patches = H * W + + padding_mask = images.new_zeros((B, num_patches)).bool() + position_col = torch.arange(W).unsqueeze(0) + position_row = torch.arange(H).unsqueeze(1) * self.image_bucket_size + position_idx = (position_col + position_row + 1).view(-1) + position_idx = position_idx.to(images.device).expand(B, num_patches) + + # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C) + image_embedding = image_embedding.flatten(2).transpose(1, 2) + if sample_patch_num is not None: + patch_orders = torch.stack([ + torch.randperm(num_patches)[:sample_patch_num] + for _ in range(B) + ]) + num_patches = sample_patch_num + image_embedding = image_embedding.gather( + dim=1, index=patch_orders.unsqueeze(2).expand(-1, -1, C)) + padding_mask = padding_mask.gather(1, patch_orders) + position_idx = position_idx.gather(1, patch_orders) + + pos_embedding = self.embed_image_positions(position_idx) + padding_mask[~images_mask] = True + return image_embedding, padding_mask, position_idx, pos_embedding + + def process_embedding(self, + embedding, + pos_embedding=None, + type_tokens=None, + embedding_ln=None): + if self.entangle_position_embedding and pos_embedding is not None: + embedding += pos_embedding + if self.embed_type is not None: + embedding += self.embed_type(type_tokens) + if embedding_ln is not None: + embedding = embedding_ln(embedding) + embedding = self.dropout(embedding) + + return embedding + + def get_rel_pos_bias(self, x, idx): + seq_len = x.size(1) + rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] + values = F.embedding(rp_bucket, + self.token_rel_pos_table_list[idx].weight) + values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1) + values = values.permute([0, 3, 1, 2]) + return values.contiguous() + + def get_image_rel_pos_bias(self, image_position_ids, idx): + bsz, seq_len = image_position_ids.shape + rp_bucket_size = self.image_rp_bucket.size(1) + + rp_bucket = 
self.image_rp_bucket.unsqueeze(0).expand(
+            bsz, rp_bucket_size, rp_bucket_size).gather(
+                1, image_position_ids[:, :, None].expand(
+                    bsz, seq_len, rp_bucket_size)).gather(
+                        2, image_position_ids[:, None, :].expand(
+                            bsz, seq_len, seq_len))
+        values = F.embedding(rp_bucket,
+                             self.image_rel_pos_table_list[idx].weight)
+        values = values.permute(0, 3, 1, 2)
+        return values
+
+
+class OFADecoder(BaseModule):
+    """The decoder module of OFA.
+
+    Args:
+        embed_tokens (nn.Embedding): The embedding module to embed the
+            input tokens.
+        num_layers (int): The number of decoder layers. Defaults to 6.
+        num_heads (int): The number of heads of attention. Defaults to 12.
+        dropout_rate (float): The prob of dropout for embedding and
+            transformer layers. Defaults to 0.
+        drop_path_rate (float): The prob of droppath for transformer layers.
+            Defaults to 0.
+        max_target_positions (int): The maximum length of the input tokens.
+            Defaults to 1024.
+        code_image_size (int): The resolution of the generated image in the
+            image infilling task. Defaults to 128.
+        token_bucket_size (int): The token bucket size; it's used as the
+            maximum relative position index in the relative position
+            embedding of input tokens. Defaults to 256.
+        image_bucket_size (int): The image bucket size; it's used to generate
+            the image relative position embedding table. It should be larger
+            than the spatial size of the image feature map. Defaults to 42.
+        attn_scale_factor (float): The scale factor to calculate qk scale in
+            attentions. Defaults to 2.
+        scale_embedding (bool): Whether to scale the embeddings by the square
+            root of the dimension. Defaults to False.
+        add_embedding_ln (bool): Whether to add an extra layer norm for token
+            embeddings. Defaults to True.
+        add_code_embedding_ln (bool): Whether to add an extra layer norm for
+            code embeddings. Defaults to True.
+        pre_norm (bool): Whether to do layer norm before attention and ffn
+            blocks in transformer layers. Defaults to True.
+        entangle_position_embedding (bool): Whether to add the position
+            embedding on the embeddings directly. Defaults to False.
+        share_input_output_embed (bool): Share the weights of the input token
+            embedding module and the output projection module.
+            Defaults to True.
+        init_cfg (dict, optional): The initialization config. Defaults to None.
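+
+    A construction sketch (the vocabulary size and dimensions below are
+    placeholders; real values come from the OFA configs):
+
+    Examples:
+        >>> import torch.nn as nn
+        >>> embed = nn.Embedding(59457, 768, padding_idx=1)
+        >>> decoder = OFADecoder(embed, num_layers=6, num_heads=12)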
+ """ + + def __init__( + self, + embed_tokens, + num_layers=6, + num_heads=12, + dropout_rate=0., + drop_layer_rate=0., + drop_path_rate=0., + max_target_positions=1024, + code_image_size=128, + token_bucket_size=256, + image_bucket_size=42, + attn_scale_factor=2., + scale_embedding=False, + add_embedding_ln=True, + add_code_embedding_ln=True, + pre_norm=True, + entangle_position_embedding=False, + share_input_output_embed=True, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + self._future_mask = torch.empty(0) + + self.num_layers = num_layers + embedding_dim = embed_tokens.embedding_dim + self.embedding_dim = embedding_dim + self.padding_idx = embed_tokens.padding_idx + self.max_target_positions = max_target_positions + self.num_heads = num_heads + + # Build embedding process components + self.embed_tokens = embed_tokens + self.embedding_scale = math.sqrt( + embedding_dim) if scale_embedding else 1.0 + + if add_embedding_ln: + self.embedding_ln = nn.LayerNorm(embedding_dim) + else: + self.embedding_ln = None + + if add_code_embedding_ln: + self.code_embedding_ln = nn.LayerNorm(embedding_dim) + else: + self.code_embedding_ln = None + + # Build position embedding + self.embed_positions = nn.Embedding(self.max_target_positions + 2, + embedding_dim) + self.pos_ln = nn.LayerNorm(embedding_dim) + self.embed_image_positions = nn.Embedding(image_bucket_size**2 + 1, + embedding_dim) + self.image_pos_ln = nn.LayerNorm(embedding_dim) + + self.pos_scaling = float(embedding_dim / num_heads * + attn_scale_factor)**-0.5 + self.self_pos_q_linear = nn.Linear(embedding_dim, embedding_dim) + self.self_pos_k_linear = nn.Linear(embedding_dim, embedding_dim) + self.cross_pos_q_linear = nn.Linear(embedding_dim, embedding_dim) + self.cross_pos_k_linear = nn.Linear(embedding_dim, embedding_dim) + + self.entangle_position_embedding = entangle_position_embedding + + self.dropout = nn.Dropout( + dropout_rate) if dropout_rate > 0. 
else nn.Identity() + if drop_layer_rate > 0.: + raise NotImplementedError + + # Register token relative position embedding table + self.token_bucket_size = token_bucket_size + token_num_rel_dis = 2 * token_bucket_size - 1 + token_rp_bucket = make_token_bucket_position(token_bucket_size) + self.register_buffer('token_rp_bucket', token_rp_bucket) + self.token_rel_pos_table_list = nn.ModuleList() + + # Register image relative position embedding table + self.image_bucket_size = image_bucket_size + image_num_rel_dis = (2 * image_bucket_size - + 1) * (2 * image_bucket_size - 1) + 3 + image_rp_bucket = make_image_bucket_position(image_bucket_size, + image_num_rel_dis) + self.register_buffer('image_rp_bucket', image_rp_bucket) + self.image_rel_pos_table_list = nn.ModuleList() + + self.window_size = code_image_size // 8 + + position_col = torch.arange(self.window_size).unsqueeze(0) + position_row = torch.arange( + self.window_size).unsqueeze(1) * self.image_bucket_size + image_position_idx = (position_col + position_row + 1) + image_position_idx = torch.cat( + [torch.tensor([0]), image_position_idx.view(-1)]) + image_position_idx = torch.cat( + [image_position_idx, + torch.tensor([1024] * 768)]) + self.register_buffer('image_position_idx', image_position_idx) + + # Build decoder layers + self.layers = nn.ModuleList() + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] + for index in range(self.num_layers): + layer = OFADecoderLayer( + embedding_dim=embedding_dim, + num_heads=num_heads, + dropout_rate=dropout_rate, + drop_path_rate=dpr[index], + scale_factor=attn_scale_factor, + pre_norm=pre_norm, + ) + self.layers.append(layer) + token_pos_table = nn.Embedding(token_num_rel_dis, self.num_heads) + image_pos_table = nn.Embedding(image_num_rel_dis, self.num_heads) + nn.init.constant_(token_pos_table.weight, 0.) + nn.init.constant_(image_pos_table.weight, 0.) 
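+            # Each decoder layer owns its own relative position bias tables,
+            # zero-initialized so training starts from the absolute position
+            # bias alone; they are looked up per layer index in
+            # `get_rel_pos_bias` and `get_image_rel_pos_bias` below.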
+ self.token_rel_pos_table_list.append(token_pos_table) + self.image_rel_pos_table_list.append(image_pos_table) + + if pre_norm: + self.final_ln = nn.LayerNorm(embedding_dim) + else: + self.final_ln = None + + # Build output projection + if share_input_output_embed: + self.output_projection = nn.Linear( + self.embed_tokens.weight.shape[1], + self.embed_tokens.weight.shape[0], + bias=False, + ) + self.output_projection.weight = self.embed_tokens.weight + else: + vocab_size = self.embed_tokens.num_embeddings + self.output_projection = nn.Linear( + embedding_dim, vocab_size, bias=False) + nn.init.normal_( + self.output_projection.weight, + mean=0, + std=embedding_dim**-0.5, + ) + + main_input_name = 'input_ids' + + def forward( + self, + input_ids: torch.Tensor = None, + attention_mask: torch.Tensor = None, + encoder_hidden_states: torch.Tensor = None, + encoder_attention_mask: torch.Tensor = None, + code_masks: Optional[torch.Tensor] = None, + encoder_pos_embedding: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + ): + + if past_key_values is not None and len(past_key_values) > 0: + B, _, L_past, _ = past_key_values[0][0].shape + L = L_past + 1 + else: + B, L = input_ids.shape + L_past = 0 + + # Embed the token position + target_pos_idx = torch.arange( + L, device=input_ids.device).expand([B, L]).contiguous() + pos_embedding = self.embed_positions(target_pos_idx) + + # Embed the code positions + if code_masks is not None and torch.any(code_masks): + image_position_idx = self.image_position_idx[:input_ids.size(1)] + image_position_idx = image_position_idx.unsqueeze(0).expand(B, L) + pos_embedding[code_masks] = self.embed_image_positions( + image_position_idx)[code_masks] + + # Self-attention position bias (B, num_heads, L_t, L_t) + self_abs_pos_bias = self.get_pos_info(self.pos_ln(pos_embedding)) + if code_masks is not None and torch.any(code_masks): + self_image_abs_pos_bias = self.get_pos_info( + self.image_pos_ln(pos_embedding)) + self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks] + + # Cross-attention position bias (B, num_heads, L_t, L_s) + cross_abs_pos_bias = self.get_pos_info( + self.pos_ln(pos_embedding), encoder_pos_embedding) + if code_masks is not None and torch.any(code_masks): + cross_image_abs_pos_bias = self.get_pos_info( + self.image_pos_ln(pos_embedding), encoder_pos_embedding) + cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[ + code_masks] + + all_prev_output_tokens = input_ids.clone() + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + cross_abs_pos_bias = cross_abs_pos_bias[:, :, -1:, :] + pos_embedding = pos_embedding[:, -1:, :] + + # Embed the input tokens + x = self.embed_tokens(input_ids) * self.embedding_scale + + if self.entangle_position_embedding: + x += pos_embedding + + if self.embedding_ln is not None: + if (code_masks is None or not code_masks.any() + or self.code_embedding_ln is None): + x = self.embedding_ln(x) + elif code_masks is not None and code_masks.all(): + x = self.code_embedding_ln(x) + else: + x[~code_masks] = self.embedding_ln(x[~code_masks]) + x[code_masks] = self.code_embedding_ln(x[code_masks]) + + x = self.dropout(x) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_ids.shape, x.dtype, L_past) + attention_mask = attention_mask.to(x.device) + + # decoder layers + all_hidden_states = [] if output_hidden_states 
else None + all_self_attns = [] if output_attentions else None + all_cross_attentions = [] if ( + output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = [] if use_cache else None + + for idx, layer in enumerate(self.layers): + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states.append(x) + + if past_key_values is not None and len(past_key_values) > 0: + past_key_value = past_key_values[idx] + else: + past_key_value = None + + self_attn_bias = self_abs_pos_bias.clone() + if code_masks is None or not code_masks.any(): + self_attn_bias += self.get_rel_pos_bias( + all_prev_output_tokens, idx) + elif code_masks is not None and code_masks.all(): + self_attn_bias += self.get_image_rel_pos_bias( + all_prev_output_tokens, idx) + else: + self_attn_bias[~code_masks] += self.get_rel_pos_bias( + all_prev_output_tokens, idx) + self_attn_bias[code_masks] += self.get_image_rel_pos_bias( + all_prev_output_tokens, idx) + + if past_key_value is not None: + self_attn_bias = self_attn_bias[:, :, -1:, :] + + out = layer( + x, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + self_attn_bias=self_attn_bias, + cross_attn_bias=cross_abs_pos_bias, + ) + x = out.pop(0) + + if output_attentions: + all_self_attns.append(out.pop(0)) + if encoder_hidden_states is not None: + all_cross_attentions.append(out.pop(0)) + + if use_cache: + next_decoder_cache.append(out.pop(0)) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (x, ) + + if self.final_ln is not None: + x = self.final_ln(x) + + x = self.output_projection(x) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=x, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + def _prepare_decoder_attention_mask( + self, + attention_mask, + input_shape, + dtype, + past_key_values_length, + ): + r""" + Create causal mask for unidirectional decoding. 
+ [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + """ + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + dtype, + past_key_values_length=past_key_values_length).to( + attention_mask.device) + + if attention_mask is not None: + # (B, L_s) -> (B, 1, L_t, L_s) + expanded_attention_mask = _expand_mask( + attention_mask, dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attention_mask if combined_attention_mask is None else + expanded_attention_mask + combined_attention_mask) + + return combined_attention_mask + + def get_pos_info(self, pos_embedding, src_pos_embedding=None): + B, tgt_len = pos_embedding.shape[:2] + if src_pos_embedding is not None: + src_len = src_pos_embedding.size(1) + pos_q = self.cross_pos_q_linear(pos_embedding).view( + B, tgt_len, self.num_heads, -1).transpose(1, 2) + pos_q = pos_q * self.pos_scaling + pos_k = self.cross_pos_k_linear(src_pos_embedding).view( + B, src_len, self.num_heads, -1).transpose(1, 2) + else: + pos_q = self.self_pos_q_linear(pos_embedding).view( + B, tgt_len, self.num_heads, -1).transpose(1, 2) + pos_q = pos_q * self.pos_scaling + pos_k = self.self_pos_k_linear(pos_embedding).view( + B, tgt_len, self.num_heads, -1).transpose(1, 2) + + abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + + return abs_pos_bias + + def get_rel_pos_bias(self, x, idx): + seq_len = x.size(1) + rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] + values = F.embedding(rp_bucket, + self.token_rel_pos_table_list[idx].weight) + values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1) + values = values.permute([0, 3, 1, 2]) + return values.contiguous() + + def get_image_rel_pos_bias(self, image_position_ids, idx): + bsz, seq_len = image_position_ids.shape + rp_bucket_size = self.image_rp_bucket.size(1) + + rp_bucket = self.image_rp_bucket.unsqueeze(0).expand( + bsz, rp_bucket_size, rp_bucket_size).gather( + 1, image_position_ids[:, :, None].expand( + bsz, seq_len, rp_bucket_size)).gather( + 2, image_position_ids[:, None, :].expand( + bsz, seq_len, seq_len)) + values = F.embedding(rp_bucket, + self.image_rel_pos_table_list[idx].weight) + values = values.permute(0, 3, 1, 2) + return values + + +class OFAEncoderDecoder(BaseModule, GenerationMixin): + """The OFA main architecture with an encoder and a decoder. + + Args: + encoder_cfg (dict): The config of the encoder, accept the keyword + arguments of :class:`OFAEncoder`. + decoder_cfg (dict): The config of the decoder, accept the keyword + arguments of :class:`OFADecoder`. + padding_idx (int): The index of the padding token. + vocab_size (int): The size of the vocabulary. + embedding_dim (int): The embedding dimensions of both the encoder + and the decoder. + generation_cfg (dict): The extra generation config, accept the keyword + arguments of :class:`~transformers.GenerationConfig`. + Defaults to an empty dict. + init_cfg (dict, optional): The initialization config. Defaults to None. 
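+
+    A minimal construction sketch (hypothetical configs; the real encoder
+    and decoder settings live in the mmpretrain OFA configs):
+
+    Examples:
+        >>> model = OFAEncoderDecoder(
+        ...     encoder_cfg=dict(embed_images=dict(type='OFAResNet', depth=101)),
+        ...     decoder_cfg=dict(),
+        ...     padding_idx=1,
+        ...     vocab_size=59457,
+        ...     embedding_dim=768)
+        >>> # tokens = model.generate(input_ids=input_ids, images=images)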
+ """ + base_model_prefix = '' + + def __init__( + self, + encoder_cfg, + decoder_cfg, + padding_idx, + vocab_size, + embedding_dim, + generation_cfg=dict(), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + + self.padding_idx = padding_idx + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + embed_tokens = nn.Embedding(vocab_size, embedding_dim, padding_idx) + + self.encoder = OFAEncoder(embed_tokens, **encoder_cfg) + self.decoder = OFADecoder(embed_tokens, **decoder_cfg) + + self.config = PretrainedConfig( + vocab_size=vocab_size, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + bos_token_id=0, + decoder_start_token_id=0, + pad_token_id=1, + eos_token_id=2, + forced_eos_token_id=2, + use_cache=False, + is_encoder_decoder=True, + ) + self.config.update(generation_cfg) + + self.generation_config = GenerationConfig.from_model_config( + self.config) + + @property + def device(self): + return next(self.parameters()).device + + def can_generate(self): + return True + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder.max_positions() + + def get_normalized_probs(self, net_output, log_probs: bool, sample=None): + """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, + sample) + + def get_normalized_probs_scriptable( + self, + net_output, + log_probs: bool, + sample=None, + ): + """Scriptable helper function for get_normalized_probs in. + + ~BaseFairseqModel. + """ + if hasattr(self, 'decoder'): + return self.decoder.get_normalized_probs(net_output, log_probs, + sample) + elif torch.is_tensor(net_output): + # syntactic sugar for simple models which don't have a decoder + # (e.g., the classification tutorial) + logits = net_output.float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + raise NotImplementedError + + main_input_name = 'input_ids' + + def forward(self, + input_ids=None, + images=None, + images_mask=None, + sample_patch_num=None, + decoder_input_ids=None, + code_masks=None, + attention_mask=None, + encoder_outputs=None, + past_key_values=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + constrain_fn=None, + return_dict=False): + """Forword the module. + + Args: + input_ids (torch.Tensor): The indices of the input tokens in the + vocabulary, and padding will be ignored by default. The indices + can be obtained using :class:`OFATokenizer`. + The shape is (B, L). + images (torch.Tensor): The input images. The shape is (B, 3, H, W). + images_mask (torch.Tensor): The mask of all available images. The + shape is (B, ). + sample_patch_num (int): The number of patches to sample for the + images. Defaults to None, which means to use all patches. + decoder_input_ids (torch.Tensor): The indices of the input tokens + for the decoder. + code_masks (torch.Tensor): The mask of all samples for image + generation. The shape is (B, ). + attention_mask (torch.Tensor): The attention mask for decoding. + The shape is (B, L). + encoder_outputs (OFAEncoderOutput): The encoder outputs with hidden + states, positional embeddings, and padding masks. + past_key_values (Tuple[Tuple[torch.Tensor]]): If use cache, the + parameter is a tuple of length ``num_layers``. 
Every item is + also a tuple with four tensors, two for the key and value of + self-attention, two for the key and value of cross-attention. + use_cache (bool): Whether to use cache for faster inference. + Defaults to False. + output_attentions (bool): Whether to output attention weights. + Defaults to False. + output_hidden_states (bool): Whether to output hidden states. + Defaults to False. + constrain_fn (Callable, optional): The function to constrain the + output logits. Defaults to None. + return_dict (bool): Not used, it's only for compat with the + interface of the ``generate`` of ``transformers``. + + Returns: + Seq2SeqLMOutput: + + - logits (``torch.Tensor``): The last decoder hidden states. + The shape is (B, L, C). + - past_key_values (``Tuple[Tuple[torch.Tensor]]``): The past keys + and values for faster inference. + - decoder_hidden_states (``Tuple[torch.Tensor]``): the decoder + hidden states of all layers. + - decoder_attentions (``Tuple[torch.Tensor]``): The self-attention + weights of all layers in the decoder. + - cross_attentions (``Tuple[torch.Tensor]``): The cross-attention + weights of all layers in the decoder. + - encoder_last_hidden_state (``torch.Tensor``): The last encoder + hidden states. + - encoder_hidden_states (``Tuple[torch.Tensor]``): The encoder + hidden states of all layers, including the embeddings. + - encoder_attentions (``Tuple[torch.Tensor]``): The self-attention + weights of all layers in the encoder. + """ + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + images=images, + images_mask=images_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + sample_patch_num=sample_patch_num, + ) + + if decoder_input_ids.eq(self.padding_idx).any(): + attention_mask = decoder_input_ids.eq(self.padding_idx) + + encoder_hidden_states = encoder_outputs.last_hidden_state + encoder_attention_mask = _expand_mask(encoder_outputs.padding_mask, + encoder_hidden_states.dtype, + decoder_input_ids.shape[-1]) + src_pos_embed = encoder_outputs.position_embedding + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + code_masks=code_masks, + encoder_pos_embedding=src_pos_embed, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + # The constrain operation for fine-tuned model in OFA is applied + # before log_softmax, therefore we cannot use + # `prefix_allowed_tokens_fn` to implement it. 
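+        # For illustration: with a constraint Trie built from the answers
+        # " yes" and " no", every logit outside those token sequences is
+        # masked to -inf at each decoding step, so beam search can only
+        # emit one of the allowed answers.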
+        if constrain_fn is not None:
+            logits = constrain_fn(decoder_input_ids,
+                                  decoder_outputs.last_hidden_state)
+        else:
+            logits = decoder_outputs.last_hidden_state
+
+        return Seq2SeqLMOutput(
+            logits=logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(self,
+                                      decoder_input_ids=None,
+                                      past=None,
+                                      attention_mask=None,
+                                      code_masks=None,
+                                      use_cache=False,
+                                      encoder_outputs=None,
+                                      constrain_fn=None,
+                                      **kwargs):
+        # the attention mask built here is only a placeholder; `forward`
+        # rebuilds it from the padding tokens when necessary
+        attention_mask = decoder_input_ids.new_zeros(decoder_input_ids.shape)
+
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            decoder_input_ids = decoder_input_ids[:, -1:]
+
+        return {
+            'input_ids': None,
+            'images': None,
+            'images_mask': None,
+            'sample_patch_num': None,
+            'attention_mask': attention_mask,
+            'encoder_outputs': encoder_outputs,
+            'past_key_values': past,
+            'decoder_input_ids': decoder_input_ids,
+            'code_masks': code_masks,
+            'use_cache': use_cache,
+            'constrain_fn': constrain_fn,
+        }
+
+    def _prepare_encoder_decoder_kwargs_for_generation(
+            self,
+            inputs_tensor: torch.Tensor,
+            model_kwargs,
+            model_input_name: Optional[str] = None):
+        # 1. get encoder
+        encoder = self.get_encoder()
+
+        # 2. prepare encoder args and encoder kwargs from model kwargs
+        irrelevant_prefix = [
+            'decoder_', 'cross_attn', 'use_cache', 'attention_mask',
+            'constrain_fn'
+        ]
+        encoder_kwargs = {
+            argument: value
+            for argument, value in model_kwargs.items()
+            if not any(argument.startswith(p) for p in irrelevant_prefix)
+        }
+
+        if encoder_kwargs.get('images_mask') is None:
+            encoder_kwargs['images_mask'] = torch.tensor(
+                [True] * inputs_tensor.size(0))
+
+        # 3. make sure that encoder returns `ModelOutput`
+        model_input_name = model_input_name or self.main_input_name
+        encoder_kwargs[model_input_name] = inputs_tensor
+        model_kwargs['encoder_outputs']: ModelOutput = encoder(
+            **encoder_kwargs)
+        model_kwargs['attention_mask'] = None
+
+        return model_kwargs
+
+    @staticmethod
+    def _reorder_cache(past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            reordered_past += (tuple(
+                past_state.index_select(0, beam_idx)
+                for past_state in layer_past), )
+        return reordered_past
+
+    @staticmethod
+    def _expand_inputs_for_generation(
+        input_ids: torch.LongTensor,
+        expand_size: int = 1,
+        is_encoder_decoder: bool = False,
+        attention_mask: Optional[torch.LongTensor] = None,
+        encoder_outputs: Optional[ModelOutput] = None,
+        **model_kwargs,
+    ):
+        expanded_return_idx = (
+            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(
+                1, expand_size).view(-1).to(input_ids.device))
+        input_ids = input_ids.index_select(0, expanded_return_idx)
+
+        if attention_mask is not None:
+            model_kwargs['attention_mask'] = attention_mask.index_select(
+                0, expanded_return_idx)
+
+        if is_encoder_decoder:
+            if encoder_outputs is None:
+                raise ValueError('If `is_encoder_decoder` is True, make '
+                                 'sure that `encoder_outputs` is defined.')
+            encoder_outputs['last_hidden_state'] = encoder_outputs.\
+                last_hidden_state.index_select(0, expanded_return_idx)
+            encoder_outputs['position_embedding'] = encoder_outputs.\
+                position_embedding.index_select(0, expanded_return_idx)
+            encoder_outputs['padding_mask'] = encoder_outputs.\
+                padding_mask.index_select(0, expanded_return_idx)
+            model_kwargs['encoder_outputs'] = encoder_outputs
+        return input_ids, model_kwargs
diff --git a/mmpretrain/models/multimodal/otter/__init__.py b/mmpretrain/models/multimodal/otter/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38a45a3d17458eae2471846b43498aa06cdfaac3
--- /dev/null
+++ b/mmpretrain/models/multimodal/otter/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .otter import Otter
+
+__all__ = ['Otter']
diff --git a/mmpretrain/models/multimodal/otter/otter.py b/mmpretrain/models/multimodal/otter/otter.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d30b509410fca6bb6bb61ba2756f851e388f944
--- /dev/null
+++ b/mmpretrain/models/multimodal/otter/otter.py
@@ -0,0 +1,143 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional
+
+import torch
+
+from mmpretrain.registry import MODELS, TOKENIZER
+from mmpretrain.structures import DataSample
+from ..flamingo.flamingo import ExtendModule, Flamingo, PerceiverResampler
+
+
+@MODELS.register_module()
+class Otter(Flamingo):
+    """The Otter model for multiple tasks.
+
+    Args:
+        vision_encoder (dict): The config of the vision encoder.
+        lang_encoder (dict): The config of the language encoder.
+        tokenizer (dict): The tokenizer to encode the text.
+        task (str): The task to perform prediction.
+        zeroshot_prompt (str): Prompt used for zero-shot inference.
+            Defaults to an empty string.
+        shot_prompt_tmpl (str): Prompt used for few-shot inference.
+            Defaults to ``<image>User:Please describe the image.
+            GPT:<answer>{caption}<|endofchunk|>``.
+        final_prompt_tmpl (str): Final part of prompt used for inference.
+            Defaults to '<image>User:Please describe the image. GPT:<answer>'.
+        generation_cfg (dict): The extra generation config, accept the keyword
+            arguments of [~`transformers.GenerationConfig`].
+            Defaults to an empty dict.
+        data_preprocessor (Optional[dict]): The config for preprocessing
+            input data. If None or no specified type, it will use
+            "MultiModalDataPreprocessor" as type.
+            See :class:`MultiModalDataPreprocessor` for more details.
+            Defaults to None.
+        init_cfg (dict, optional): The initialization config. Defaults to
+            None.
+    """
+
+    support_tasks = {'caption', 'vqa'}
+    _no_split_modules = [
+        'TransformerEncoderLayer', 'PerceiverAttention',
+        'GatedCrossAttentionBlock', 'FlamingoLayer'
+    ]
+
+    def __init__(
+            self,
+            vision_encoder: dict,
+            lang_encoder: dict,
+            tokenizer: dict,
+            task: str = 'caption',
+            zeroshot_prompt: str = '',
+            shot_prompt_tmpl: str = ('<image>User:Please describe the image. '
+                                     'GPT:<answer>{caption}<|endofchunk|>'),
+            final_prompt_tmpl: str = ('<image>User:Please describe the image. '
+                                      'GPT:<answer>'),
+            generation_cfg: dict = dict(),
+            data_preprocessor: Optional[dict] = None,
+            init_cfg: Optional[dict] = None):
+        if data_preprocessor is None:
+            data_preprocessor = {}
+        if isinstance(data_preprocessor, dict):
+            data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
+            data_preprocessor = MODELS.build(data_preprocessor)
+
+        super(Flamingo, self).__init__(
+            init_cfg=init_cfg, data_preprocessor=data_preprocessor)
+
+        if task not in self.support_tasks:
+            raise ValueError(f'Unsupported task {task}, please select '
+                             f'the task from {self.support_tasks}.')
+        self.task = task
+
+        # init tokenizer
+        self.tokenizer = TOKENIZER.build(tokenizer)
+        # add Otter special tokens to the tokenizer
+        self.tokenizer.add_special_tokens({
+            'additional_special_tokens':
+            ['<|endofchunk|>', '<image>', '<answer>']
+        })
+        self.tokenizer.bos_token_id = 1
+        if self.tokenizer.pad_token is None:
+            # Issue: GPT models don't have a pad token, which we use to
+            # modify labels for the loss.
+            self.tokenizer.add_special_tokens({'pad_token': '<PAD>'})
+
+        # Template to format the prompt input
+        self.zeroshot_prompt = zeroshot_prompt
+        self.shot_prompt_tmpl = shot_prompt_tmpl
+        self.final_prompt_tmpl = final_prompt_tmpl
+
+        # init vision encoder related modules
+        vision_encoder_weight = vision_encoder.pop('pretrained', None)
+        self.vision_encoder = MODELS.build(vision_encoder)
+        if vision_encoder_weight is not None:
+            from mmengine.runner.checkpoint import load_checkpoint
+            load_checkpoint(
+                self.vision_encoder,
+                vision_encoder_weight,
+                map_location='cpu',
+                revise_keys=[(r'^backbone\.', '')],
+            )
+            self.vision_encoder.is_init = True
+
+        self.perceiver = PerceiverResampler(dim=self.vision_encoder.embed_dims)
+
+        # init language encoder related modules
+        self.lang_encoder = ExtendModule(**lang_encoder)
+        self.lang_encoder.resize_token_embeddings(len(self.tokenizer))
+        self.lang_encoder.media_token_id = self.tokenizer.encode(
+            '<image>')[-1]
+
+        # other necessary parameters
+        self.eoc_token_id = self.tokenizer.encode('<|endofchunk|>')[-1]
+        self.generation_cfg = generation_cfg
+
+        if hasattr(self, 'register_load_state_dict_post_hook'):
+            self.register_load_state_dict_post_hook(self._load_adapter_hook)
+
+    def post_process(
+            self, outputs: torch.Tensor,
+            data_samples: Optional[List[DataSample]]) -> List[DataSample]:
+        """Perform post processing on the outputs for different tasks.
+
+        Args:
+            outputs (torch.Tensor): The generated outputs.
+            data_samples (List[DataSample], optional): The annotation
+                data of every sample.
+
+        Returns:
+            List[DataSample]: Return list of data samples.
+ """ + outputs = self.tokenizer.batch_decode( + outputs, skip_special_tokens=True) + + if data_samples is None: + data_samples = [DataSample() for _ in range(len(outputs))] + + for output, data_sample in zip(outputs, data_samples): + # remove text pattern + if self.task == 'caption': + data_sample.pred_caption = output + elif self.task == 'vqa': + data_sample.pred_answer = output + + return data_samples diff --git a/mmpretrain/models/necks/__init__.py b/mmpretrain/models/necks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2952a691758843436dd70ad6a11a390216ac724a --- /dev/null +++ b/mmpretrain/models/necks/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .beitv2_neck import BEiTV2Neck +from .cae_neck import CAENeck +from .densecl_neck import DenseCLNeck +from .gap import GlobalAveragePooling +from .gem import GeneralizedMeanPooling +from .hr_fuse import HRFuseScales +from .itpn_neck import iTPNPretrainDecoder +from .linear_neck import LinearNeck +from .mae_neck import ClsBatchNormNeck, MAEPretrainDecoder +from .milan_neck import MILANPretrainDecoder +from .mixmim_neck import MixMIMPretrainDecoder +from .mocov2_neck import MoCoV2Neck +from .nonlinear_neck import NonLinearNeck +from .simmim_neck import SimMIMLinearDecoder +from .spark_neck import SparKLightDecoder +from .swav_neck import SwAVNeck + +__all__ = [ + 'GlobalAveragePooling', + 'GeneralizedMeanPooling', + 'HRFuseScales', + 'LinearNeck', + 'BEiTV2Neck', + 'CAENeck', + 'DenseCLNeck', + 'MAEPretrainDecoder', + 'ClsBatchNormNeck', + 'MILANPretrainDecoder', + 'MixMIMPretrainDecoder', + 'MoCoV2Neck', + 'NonLinearNeck', + 'SimMIMLinearDecoder', + 'SwAVNeck', + 'iTPNPretrainDecoder', + 'SparKLightDecoder', +] diff --git a/mmpretrain/models/necks/__pycache__/__init__.cpython-311.pyc b/mmpretrain/models/necks/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f892de3761d6a89d117171f330745483ce9673e5 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/beitv2_neck.cpython-311.pyc b/mmpretrain/models/necks/__pycache__/beitv2_neck.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18ef46107230629a84ae7f5fedb8c06e9cdbaa77 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/beitv2_neck.cpython-311.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/cae_neck.cpython-311.pyc b/mmpretrain/models/necks/__pycache__/cae_neck.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..378b7f070cce8738207a2d0b51e81218621c9334 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/cae_neck.cpython-311.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/densecl_neck.cpython-311.pyc b/mmpretrain/models/necks/__pycache__/densecl_neck.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..850699c3547e3cb61d8384701339827028f24ea0 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/densecl_neck.cpython-311.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/gap.cpython-311.pyc b/mmpretrain/models/necks/__pycache__/gap.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f26d85f3a6b5a1b06889ef9812ba7f9da5fb0695 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/gap.cpython-311.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/gem.cpython-311.pyc 
b/mmpretrain/models/necks/__pycache__/gem.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ee9172f0aafc5dcd1730cb9a8d6d34a7367c221 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/gem.cpython-311.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/hr_fuse.cpython-311.pyc b/mmpretrain/models/necks/__pycache__/hr_fuse.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57ea4c45bef41a01b9d9b170a9ed5bf065652e70 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/hr_fuse.cpython-311.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/itpn_neck.cpython-311.pyc b/mmpretrain/models/necks/__pycache__/itpn_neck.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55e7568899ba938a754e1c5f64bdecad59724381 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/itpn_neck.cpython-311.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/linear_neck.cpython-311.pyc b/mmpretrain/models/necks/__pycache__/linear_neck.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cca0ef850c14ed72295d521af68c0107bbd6b29 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/linear_neck.cpython-311.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/mae_neck.cpython-311.pyc b/mmpretrain/models/necks/__pycache__/mae_neck.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c61b40158a211aad787f120d3d35328e8b621956 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/mae_neck.cpython-311.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/milan_neck.cpython-311.pyc b/mmpretrain/models/necks/__pycache__/milan_neck.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d5a759d131812434a9d3916af468a0a74d7696c Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/milan_neck.cpython-311.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/mixmim_neck.cpython-311.pyc b/mmpretrain/models/necks/__pycache__/mixmim_neck.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a6afb55df4f6a442d1a67744923222aa9f1ec67 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/mixmim_neck.cpython-311.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/mocov2_neck.cpython-311.pyc b/mmpretrain/models/necks/__pycache__/mocov2_neck.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a600a4dc32c363ecbcf805f8daa7edef1cc50084 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/mocov2_neck.cpython-311.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/nonlinear_neck.cpython-311.pyc b/mmpretrain/models/necks/__pycache__/nonlinear_neck.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c446379773916f9fcd7064981b58f01055fcb6c5 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/nonlinear_neck.cpython-311.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/simmim_neck.cpython-311.pyc b/mmpretrain/models/necks/__pycache__/simmim_neck.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29fb3a431c2bce3ce1d9f1bf8bbf0e4181f5c7e7 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/simmim_neck.cpython-311.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/spark_neck.cpython-311.pyc 
b/mmpretrain/models/necks/__pycache__/spark_neck.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..465933f70f714b34cbbfd5d9fe35b50677bc075a Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/spark_neck.cpython-311.pyc differ diff --git a/mmpretrain/models/necks/__pycache__/swav_neck.cpython-311.pyc b/mmpretrain/models/necks/__pycache__/swav_neck.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a39b7f1a753c3a9cc318ffa103c2e639e911f186 Binary files /dev/null and b/mmpretrain/models/necks/__pycache__/swav_neck.cpython-311.pyc differ diff --git a/mmpretrain/models/necks/beitv2_neck.py b/mmpretrain/models/necks/beitv2_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..745e3879f5e3a4b9269687797728354cb6cf7d4e --- /dev/null +++ b/mmpretrain/models/necks/beitv2_neck.py @@ -0,0 +1,153 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class BEiTV2Neck(BaseModule): + """Neck for BEiTV2 Pre-training. + + This module construct the decoder for the final prediction. + + Args: + num_layers (int): Number of encoder layers of neck. Defaults to 2. + early_layers (int): The layer index of the early output from the + backbone. Defaults to 9. + backbone_arch (str): Vision Transformer architecture. Defaults to base. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): The initialization value for the + learnable scaling of attention and FFN. Defaults to 0.1. + use_rel_pos_bias (bool): Whether to use unique relative position bias, + if False, use shared relative position bias defined in backbone. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
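+
+    Example:
+        >>> # A minimal sketch; the shapes assume a 'base' backbone on
+        >>> # 224x224 inputs (197 tokens including the cls_token).
+        >>> import torch
+        >>> neck = BEiTV2Neck(num_layers=2, early_layers=9)
+        >>> early_states = torch.rand(1, 197, 768)
+        >>> x = torch.rand(1, 197, 768)
+        >>> feats, feats_cls_pt = neck((early_states, x), rel_pos_bias=None)
+        >>> feats.shape, feats_cls_pt.shape
+        (torch.Size([1, 196, 768]), torch.Size([1, 196, 768]))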
+ """ + arch_zoo = { + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'depth': 12, + 'num_heads': 12, + 'feedforward_channels': 3072, + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'depth': 24, + 'num_heads': 16, + 'feedforward_channels': 4096, + }), + } + + def __init__( + self, + num_layers: int = 2, + early_layers: int = 9, + backbone_arch: str = 'base', + drop_rate: float = 0., + drop_path_rate: float = 0., + layer_scale_init_value: float = 0.1, + use_rel_pos_bias: bool = False, + norm_cfg: dict = dict(type='LN', eps=1e-6), + init_cfg: Optional[Union[dict, List[dict]]] = dict( + type='TruncNormal', layer='Linear', std=0.02, bias=0) + ) -> None: + super().__init__(init_cfg=init_cfg) + + if isinstance(backbone_arch, str): + backbone_arch = backbone_arch.lower() + assert backbone_arch in set(self.arch_zoo), \ + (f'Arch {backbone_arch} is not in default archs ' + f'{set(self.arch_zoo)}') + self.arch_settings = self.arch_zoo[backbone_arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(backbone_arch, dict) and essential_keys <= set( + backbone_arch + ), f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = backbone_arch + + # stochastic depth decay rule + self.early_layers = early_layers + depth = self.arch_settings['depth'] + dpr = np.linspace(0, drop_path_rate, + max(depth, early_layers + num_layers)) + + self.patch_aggregation = nn.ModuleList() + for i in range(early_layers, early_layers + num_layers): + _layer_cfg = dict( + embed_dims=self.arch_settings['embed_dims'], + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + drop_rate=drop_rate, + drop_path_rate=dpr[i], + norm_cfg=norm_cfg, + layer_scale_init_value=layer_scale_init_value, + window_size=None, + use_rel_pos_bias=use_rel_pos_bias) + self.patch_aggregation.append( + BEiTTransformerEncoderLayer(**_layer_cfg)) + + self.rescale_patch_aggregation_init_weight() + + embed_dims = self.arch_settings['embed_dims'] + _, norm = build_norm_layer(norm_cfg, embed_dims) + self.add_module('norm', norm) + + def rescale_patch_aggregation_init_weight(self): + """Rescale the initialized weights.""" + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.patch_aggregation): + rescale(layer.attn.proj.weight.data, + self.early_layers + layer_id + 1) + rescale(layer.ffn.layers[1].weight.data, + self.early_layers + layer_id + 1) + + def forward(self, inputs: Tuple[torch.Tensor], rel_pos_bias: torch.Tensor, + **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the latent prediction and final prediction. + + Args: + x (Tuple[torch.Tensor]): Features of tokens. + rel_pos_bias (torch.Tensor): Shared relative position bias table. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - ``x``: The final layer features from backbone, which are normed + in ``BEiTV2Neck``. + - ``x_cls_pt``: The early state features from backbone, which are + consist of final layer cls_token and early state patch_tokens + from backbone and sent to PatchAggregation layers in the neck. 
+ """ + + early_states, x = inputs[0], inputs[1] + x_cls_pt = torch.cat([x[:, [0]], early_states[:, 1:]], dim=1) + for layer in self.patch_aggregation: + x_cls_pt = layer(x_cls_pt, rel_pos_bias=rel_pos_bias) + + # shared norm + x, x_cls_pt = self.norm(x), self.norm(x_cls_pt) + + # remove cls_token + x = x[:, 1:] + x_cls_pt = x_cls_pt[:, 1:] + return x, x_cls_pt diff --git a/mmpretrain/models/necks/cae_neck.py b/mmpretrain/models/necks/cae_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..81fc30111362ca6f602a0d3f456fbc991926a99f --- /dev/null +++ b/mmpretrain/models/necks/cae_neck.py @@ -0,0 +1,273 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmcv.cnn.bricks.transformer import FFN +from mmengine.model import BaseModule +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer +from mmpretrain.registry import MODELS +from ..utils import CrossMultiheadAttention + + +class CAETransformerRegressorLayer(BaseModule): + """Transformer layer for the regressor of CAE. + + This module is different from conventional transformer encoder layer, for + its queries are the masked tokens, but its keys and values are the + concatenation of the masked and unmasked tokens. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): The number of heads in multi-head attention. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + num_fcs (int, optional): The number of fully-connected layers in + FFNs. Default: 2. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + drop_rate (float): The dropout rate. Defaults to 0.0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): The init value of gamma. + Defaults to 0.0. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. 
+ """ + + def __init__( + self, + embed_dims: int, + num_heads: int, + feedforward_channels: int, + num_fcs: int = 2, + qkv_bias: bool = False, + qk_scale: float = None, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + layer_scale_init_value: float = 0.0, + act_cfg: dict = dict(type='GELU'), + norm_cfg: dict = dict(type='LN', eps=1e-6) + ) -> None: + super().__init__() + + # NOTE: cross attention + _, self.norm1_q_cross = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + _, self.norm1_k_cross = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + _, self.norm1_v_cross = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + _, self.norm2_cross = build_norm_layer(norm_cfg, embed_dims, postfix=2) + self.cross_attn = CrossMultiheadAttention( + embed_dims, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop_rate, + proj_drop=drop_rate) + + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=None, + act_cfg=act_cfg, + add_identity=False) + + self.drop_path = DropPath(drop_prob=drop_path_rate) + + if layer_scale_init_value > 0: + self.gamma_1_cross = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) + self.gamma_2_cross = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) + else: + self.gamma_1_cross = nn.Parameter( + torch.ones((embed_dims)), requires_grad=False) + self.gamma_2_cross = nn.Parameter( + torch.ones((embed_dims)), requires_grad=False) + + def forward(self, x_q: torch.Tensor, x_kv: torch.Tensor, + pos_q: torch.Tensor, pos_k: torch.Tensor) -> torch.Tensor: + """Forward function.""" + x = x_q + self.drop_path(self.gamma_1_cross * self.cross_attn( + self.norm1_q_cross(x_q + pos_q), + k=self.norm1_k_cross(x_kv + pos_k), + v=self.norm1_v_cross(x_kv))) + x = self.norm2_cross(x) + x = x + self.drop_path(self.gamma_2_cross * self.ffn(x)) + + return x + + +@MODELS.register_module() +class CAENeck(BaseModule): + """Neck for CAE Pre-training. + + This module construct the latent prediction regressor and the decoder + for the latent prediction and final prediction. + + Args: + num_classes (int): The number of classes for final prediction. Defaults + to 8192. + embed_dims (int): The embed dims of latent feature in regressor and + decoder. Defaults to 768. + regressor_depth (int): The number of regressor blocks. Defaults to 6. + decoder_depth (int): The number of decoder blocks. Defaults to 8. + num_heads (int): The number of head in multi-head attention. Defaults + to 12. + mlp_ratio (int): The expand ratio of latent features in MLP. defaults + to 4. + qkv_bias (bool): Whether or not to use qkv bias. Defaults to True. + qk_scale (float, optional): The scale applied to the results of qk. + Defaults to None. + drop_rate (float): The dropout rate. Defaults to 0. + attn_drop_rate (float): The dropout rate in attention block. Defaults + to 0. + norm_cfg (dict): The config of normalization layer. Defaults to + dict(type='LN', eps=1e-6). + layer_scale_init_value (float, optional): The init value of gamma. + Defaults to None. + mask_tokens_num (int): The number of mask tokens. Defaults to 75. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
+ """ + + def __init__(self, + num_classes: int = 8192, + embed_dims: int = 768, + regressor_depth: int = 6, + decoder_depth: int = 8, + num_heads: int = 12, + mlp_ratio: int = 4, + qkv_bias: bool = True, + qk_scale: float = None, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + norm_cfg: dict = dict(type='LN', eps=1e-6), + layer_scale_init_value: float = None, + mask_tokens_num: int = 75, + init_cfg: dict = None) -> None: + super().__init__(init_cfg=init_cfg) + + self.num_features = self.embed_dim = embed_dims + self.mask_token_num = mask_tokens_num + + # regressor + regressor_drop_path_rates = [ + x.item() + for x in torch.linspace(0, drop_path_rate, regressor_depth) + ] + self.regressors = nn.ModuleList([ + CAETransformerRegressorLayer( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=mlp_ratio * embed_dims, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=regressor_drop_path_rates[i], + norm_cfg=norm_cfg, + layer_scale_init_value=layer_scale_init_value) + for i in range(regressor_depth) + ]) + + # decoder + decoder_drop_path_rates = [ + x.item() for x in torch.linspace(0, drop_path_rate, decoder_depth) + ] + self.decoders = nn.ModuleList([ + BEiTTransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=mlp_ratio * embed_dims, + layer_scale_init_value=layer_scale_init_value, + window_size=None, + # setting `use_rel_pos_bias` to False ignores the `window_size` + use_rel_pos_bias=False, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=decoder_drop_path_rates[i], + norm_cfg=norm_cfg) for i in range(decoder_depth) + ]) + + _, self.norm_regressor = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + _, self.norm_decoder = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + + self.head = nn.Linear( + embed_dims, num_classes) if num_classes > 0 else nn.Identity() + self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + + def init_weights(self) -> None: + """Initialization.""" + super().init_weights() + self.apply(self._init_weights) + trunc_normal_(self.mask_token, std=0.02) + trunc_normal_(self.head.weight, std=0.02) + + def _init_weights(self, m: nn.Module) -> None: + """Initialization.""" + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, x_unmasked: torch.Tensor, pos_embed_masked: torch.Tensor, + pos_embed_unmasked: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the latent prediction and final prediction. + + Args: + x_unmasked (torch.Tensor): Features of unmasked tokens. + pos_embed_masked (torch.Tensor): Position embedding of masked + tokens. + pos_embed_unmasked (torch.Tensor): Position embedding of unmasked + tokens. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - ``logits``: Final prediction. + - ``latent_pred``: Latent prediction. 
+ """ + x_masked = self.mask_token.expand(x_unmasked.shape[0], + self.mask_token_num, -1) + # regressor + for regressor in self.regressors: + x_masked = regressor( + x_masked, torch.cat([x_unmasked, x_masked], dim=1), + pos_embed_masked, + torch.cat([pos_embed_unmasked, pos_embed_masked], dim=1)) + x_masked = self.norm_regressor(x_masked) + latent_pred = x_masked + + # decoder + x_masked = x_masked + pos_embed_masked + for decoder in self.decoders: + x_masked = decoder(x_masked, rel_pos_bias=None) + x_masked = self.norm_decoder(x_masked) + + logits = self.head(x_masked) + + return logits, latent_pred diff --git a/mmpretrain/models/necks/densecl_neck.py b/mmpretrain/models/necks/densecl_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..bee9a2368d8917ece7b4b8ab8d1398ce951ede24 --- /dev/null +++ b/mmpretrain/models/necks/densecl_neck.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class DenseCLNeck(BaseModule): + """The non-linear neck of DenseCL. + + Single and dense neck in parallel: fc-relu-fc, conv-relu-conv. + Borrowed from the authors' `code `_. + + Args: + in_channels (int): Number of input channels. + hid_channels (int): Number of hidden channels. + out_channels (int): Number of output channels. + num_grid (int): The grid size of dense features. Defaults to None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + hid_channels: int, + out_channels: int, + num_grid: Optional[int] = None, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.mlp = nn.Sequential( + nn.Linear(in_channels, hid_channels), nn.ReLU(inplace=True), + nn.Linear(hid_channels, out_channels)) + + self.with_pool = True if num_grid is not None else False + if self.with_pool: + self.pool = nn.AdaptiveAvgPool2d((num_grid, num_grid)) + self.mlp2 = nn.Sequential( + nn.Conv2d(in_channels, hid_channels, 1), nn.ReLU(inplace=True), + nn.Conv2d(hid_channels, out_channels, 1)) + self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]: + """Forward function of neck. + + Args: + x (Tuple[torch.Tensor]): feature map of backbone. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - ``avgpooled_x``: Global feature vectors. + - ``x``: Dense feature vectors. + - ``avgpooled_x2``: Dense feature vectors for queue. + """ + assert len(x) == 1 + x = x[0] + + avgpooled_x = self.avgpool(x) + avgpooled_x = self.mlp(avgpooled_x.view(avgpooled_x.size(0), -1)) + + if self.with_pool: + x = self.pool(x) # sxs + x = self.mlp2(x) # sxs: bxdxsxs + avgpooled_x2 = self.avgpool2(x) # 1x1: bxdx1x1 + x = x.view(x.size(0), x.size(1), -1) # bxdxs^2 + avgpooled_x2 = avgpooled_x2.view(avgpooled_x2.size(0), -1) # bxd + return avgpooled_x, x, avgpooled_x2 diff --git a/mmpretrain/models/necks/gap.py b/mmpretrain/models/necks/gap.py new file mode 100644 index 0000000000000000000000000000000000000000..0877743ad1e5a75976feb14f5d34942c0b7b8ee4 --- /dev/null +++ b/mmpretrain/models/necks/gap.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class GlobalAveragePooling(nn.Module): + """Global Average Pooling neck. + + Note that we use `view` to remove extra channel after pooling. We do not + use `squeeze` as it will also remove the batch dimension when the tensor + has a batch dimension of size 1, which can lead to unexpected errors. + + Args: + dim (int): Dimensions of each sample channel, can be one of {1, 2, 3}. + Default: 2 + """ + + def __init__(self, dim=2): + super(GlobalAveragePooling, self).__init__() + assert dim in [1, 2, 3], 'GlobalAveragePooling dim only support ' \ + f'{1, 2, 3}, get {dim} instead.' + if dim == 1: + self.gap = nn.AdaptiveAvgPool1d(1) + elif dim == 2: + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + else: + self.gap = nn.AdaptiveAvgPool3d((1, 1, 1)) + + def init_weights(self): + pass + + def forward(self, inputs): + if isinstance(inputs, tuple): + outs = tuple([self.gap(x) for x in inputs]) + outs = tuple( + [out.view(x.size(0), -1) for out, x in zip(outs, inputs)]) + elif isinstance(inputs, torch.Tensor): + outs = self.gap(inputs) + outs = outs.view(inputs.size(0), -1) + else: + raise TypeError('neck inputs should be tuple or torch.tensor') + return outs diff --git a/mmpretrain/models/necks/gem.py b/mmpretrain/models/necks/gem.py new file mode 100644 index 0000000000000000000000000000000000000000..f5648be86303caa6f2c25786fe8c3058c2f98d7e --- /dev/null +++ b/mmpretrain/models/necks/gem.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import Tensor, nn +from torch.nn import functional as F +from torch.nn.parameter import Parameter + +from mmpretrain.registry import MODELS + + +def gem(x: Tensor, p: Parameter, eps: float = 1e-6, clamp=True) -> Tensor: + if clamp: + x = x.clamp(min=eps) + return F.avg_pool2d(x.pow(p), (x.size(-2), x.size(-1))).pow(1. / p) + + +@MODELS.register_module() +class GeneralizedMeanPooling(nn.Module): + """Generalized Mean Pooling neck. + + Note that we use `view` to remove extra channel after pooling. We do not + use `squeeze` as it will also remove the batch dimension when the tensor + has a batch dimension of size 1, which can lead to unexpected errors. + + Args: + p (float): Parameter value. Defaults to 3. + eps (float): epsilon. Defaults to 1e-6. + clamp (bool): Use clamp before pooling. Defaults to True + p_trainable (bool): Toggle whether Parameter p is trainable or not. + Defaults to True. 
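+
+    Example:
+        >>> # A minimal sketch; GeM computes (mean(x**p))**(1/p) over the
+        >>> # spatial dims, interpolating between average pooling (p=1)
+        >>> # and max pooling (p -> inf).
+        >>> import torch
+        >>> neck = GeneralizedMeanPooling(p=3.)
+        >>> outs = neck(torch.rand(2, 64, 7, 7))
+        >>> outs.shape
+        torch.Size([2, 64])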
+ """ + + def __init__(self, p=3., eps=1e-6, clamp=True, p_trainable=True): + assert p >= 1, "'p' must be a value greater than 1" + super(GeneralizedMeanPooling, self).__init__() + self.p = Parameter(torch.ones(1) * p, requires_grad=p_trainable) + self.eps = eps + self.clamp = clamp + self.p_trainable = p_trainable + + def forward(self, inputs): + if isinstance(inputs, tuple): + outs = tuple([ + gem(x, p=self.p, eps=self.eps, clamp=self.clamp) + for x in inputs + ]) + outs = tuple( + [out.view(x.size(0), -1) for out, x in zip(outs, inputs)]) + elif isinstance(inputs, torch.Tensor): + outs = gem(inputs, p=self.p, eps=self.eps, clamp=self.clamp) + outs = outs.view(inputs.size(0), -1) + else: + raise TypeError('neck inputs should be tuple or torch.tensor') + return outs diff --git a/mmpretrain/models/necks/hr_fuse.py b/mmpretrain/models/necks/hr_fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..4a97f86f9fb9e4cce89e950e54674d5ec3d9b1f7 --- /dev/null +++ b/mmpretrain/models/necks/hr_fuse.py @@ -0,0 +1,83 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn.bricks import ConvModule +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from ..backbones.resnet import Bottleneck, ResLayer + + +@MODELS.register_module() +class HRFuseScales(BaseModule): + """Fuse feature map of multiple scales in HRNet. + + Args: + in_channels (list[int]): The input channels of all scales. + out_channels (int): The channels of fused feature map. + Defaults to 2048. + norm_cfg (dict): dictionary to construct norm layers. + Defaults to ``dict(type='BN', momentum=0.1)``. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to ``dict(type='Normal', layer='Linear', std=0.01))``. + """ + + def __init__(self, + in_channels, + out_channels=2048, + norm_cfg=dict(type='BN', momentum=0.1), + init_cfg=dict(type='Normal', layer='Linear', std=0.01)): + super(HRFuseScales, self).__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.norm_cfg = norm_cfg + + block_type = Bottleneck + out_channels = [128, 256, 512, 1024] + + # Increase the channels on each resolution + # from C, 2C, 4C, 8C to 128, 256, 512, 1024 + increase_layers = [] + for i in range(len(in_channels)): + increase_layers.append( + ResLayer( + block_type, + in_channels=in_channels[i], + out_channels=out_channels[i], + num_blocks=1, + stride=1, + )) + self.increase_layers = nn.ModuleList(increase_layers) + + # Downsample feature maps in each scale. + downsample_layers = [] + for i in range(len(in_channels) - 1): + downsample_layers.append( + ConvModule( + in_channels=out_channels[i], + out_channels=out_channels[i + 1], + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + bias=False, + )) + self.downsample_layers = nn.ModuleList(downsample_layers) + + # The final conv block before final classifier linear layer. 
+ self.final_layer = ConvModule( + in_channels=out_channels[3], + out_channels=self.out_channels, + kernel_size=1, + norm_cfg=self.norm_cfg, + bias=False, + ) + + def forward(self, x): + assert isinstance(x, tuple) and len(x) == len(self.in_channels) + + feat = self.increase_layers[0](x[0]) + for i in range(len(self.downsample_layers)): + feat = self.downsample_layers[i](feat) + \ + self.increase_layers[i + 1](x[i + 1]) + + return (self.final_layer(feat), ) diff --git a/mmpretrain/models/necks/itpn_neck.py b/mmpretrain/models/necks/itpn_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..1a3626af634b185fef9b0b2fb47c1fdc15e1139b --- /dev/null +++ b/mmpretrain/models/necks/itpn_neck.py @@ -0,0 +1,388 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.models.backbones.hivit import BlockWithRPE +from mmpretrain.registry import MODELS +from ..backbones.vision_transformer import TransformerEncoderLayer +from ..utils import build_2d_sincos_position_embedding + + +class PatchSplit(nn.Module): + """The up-sample module used in neck (transformer pyramid network) + + Args: + dim (int): the input dimension (channel number). + fpn_dim (int): the fpn dimension (channel number). + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + """ + + def __init__(self, dim, fpn_dim, norm_cfg): + super().__init__() + _, self.norm = build_norm_layer(norm_cfg, dim) + self.reduction = nn.Linear(dim, fpn_dim * 4, bias=False) + self.fpn_dim = fpn_dim + + def forward(self, x): + B, N, H, W, C = x.shape + x = self.norm(x) + x = self.reduction(x) + x = x.reshape(B, N, H, W, 2, 2, + self.fpn_dim).permute(0, 1, 2, 4, 3, 5, + 6).reshape(B, N, 2 * H, 2 * W, + self.fpn_dim) + return x + + +@MODELS.register_module() +class iTPNPretrainDecoder(BaseModule): + """The neck module of iTPN (transformer pyramid network). + + Args: + num_patches (int): The number of total patches. Defaults to 196. + patch_size (int): Image patch size. Defaults to 16. + in_chans (int): The channel of input image. Defaults to 3. + embed_dim (int): Encoder's embedding dimension. Defaults to 512. + fpn_dim (int): The fpn dimension (channel number). + fpn_depth (int): The layer number of feature pyramid. + decoder_embed_dim (int): Decoder's embedding dimension. + Defaults to 512. + decoder_depth (int): The depth of decoder. Defaults to 8. + decoder_num_heads (int): Number of attention heads of decoder. + Defaults to 16. + mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim. + Defaults to 4. + norm_cfg (dict): Normalization layer. Defaults to LayerNorm. + reconstruction_type (str): The itpn supports 2 kinds of supervisions. + Defaults to 'pixel'. + num_outs (int): The output number of neck (transformer pyramid + network). Defaults to 3. + predict_feature_dim (int): The output dimension to supervision. + Defaults to None. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. 
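+
+    Example:
+        >>> # A minimal, construction-only sketch; the two supervision
+        >>> # variants are selected via `reconstruction_type`.
+        >>> pixel_neck = iTPNPretrainDecoder(reconstruction_type='pixel')
+        >>> clip_neck = iTPNPretrainDecoder(reconstruction_type='clip')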
+ """ + + def __init__(self, + num_patches: int = 196, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 512, + fpn_dim: int = 256, + fpn_depth: int = 2, + decoder_embed_dim: int = 512, + decoder_depth: int = 6, + decoder_num_heads: int = 16, + mlp_ratio: int = 4, + norm_cfg: dict = dict(type='LN', eps=1e-6), + reconstruction_type: str = 'pixel', + num_outs: int = 3, + qkv_bias: bool = True, + qk_scale: Optional[bool] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + predict_feature_dim: Optional[float] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.num_patches = num_patches + assert reconstruction_type in ['pixel', 'clip'], \ + 'iTPN method only support `pixel` and `clip`, ' \ + f'but got `{reconstruction_type}`.' + self.reconstruction_type = reconstruction_type + self.num_outs = num_outs + + self.build_transformer_pyramid( + num_outs=num_outs, + embed_dim=embed_dim, + fpn_dim=fpn_dim, + fpn_depth=fpn_depth, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + rpe=False, + norm_cfg=norm_cfg, + ) + + # merge the output + self.decoder_embed = nn.ModuleList() + self.decoder_embed.append( + nn.Sequential( + nn.LayerNorm(fpn_dim), + nn.Linear(fpn_dim, decoder_embed_dim, bias=True), + )) + + if self.num_outs >= 2: + self.decoder_embed.append( + nn.Sequential( + nn.LayerNorm(fpn_dim), + nn.Linear(fpn_dim, decoder_embed_dim // 4, bias=True), + )) + if self.num_outs >= 3: + self.decoder_embed.append( + nn.Sequential( + nn.LayerNorm(fpn_dim), + nn.Linear(fpn_dim, decoder_embed_dim // 16, bias=True), + )) + + if reconstruction_type == 'pixel': + self.mask_token = nn.Parameter( + torch.zeros(1, 1, decoder_embed_dim)) + + # create new position embedding, different from that in encoder + # and is not learnable + self.decoder_pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches, decoder_embed_dim), + requires_grad=False) + + self.decoder_blocks = nn.ModuleList([ + TransformerEncoderLayer( + decoder_embed_dim, + decoder_num_heads, + int(mlp_ratio * decoder_embed_dim), + qkv_bias=True, + norm_cfg=norm_cfg) for _ in range(decoder_depth) + ]) + + self.decoder_norm_name, decoder_norm = build_norm_layer( + norm_cfg, decoder_embed_dim, postfix=1) + self.add_module(self.decoder_norm_name, decoder_norm) + + # Used to map features to pixels + if predict_feature_dim is None: + predict_feature_dim = patch_size**2 * in_chans + self.decoder_pred = nn.Linear( + decoder_embed_dim, predict_feature_dim, bias=True) + else: + _, norm = build_norm_layer(norm_cfg, embed_dim) + self.add_module('norm', norm) + + def build_transformer_pyramid(self, + num_outs=3, + embed_dim=512, + fpn_dim=256, + fpn_depth=2, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + rpe=False, + norm_cfg=None): + Hp = None + mlvl_dims = {'4': embed_dim // 4, '8': embed_dim // 2, '16': embed_dim} + if num_outs > 1: + if embed_dim != fpn_dim: + self.align_dim_16tofpn = nn.Linear(embed_dim, fpn_dim) + else: + self.align_dim_16tofpn = None + self.fpn_modules = nn.ModuleList() + self.fpn_modules.append( + BlockWithRPE( + Hp, + fpn_dim, + 0, + mlp_ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=0., + rpe=rpe, + norm_cfg=norm_cfg)) + self.fpn_modules.append( + BlockWithRPE( + Hp, + fpn_dim, + 0, + mlp_ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=0., + rpe=False, 
+ norm_cfg=norm_cfg, + )) + + self.align_dim_16to8 = nn.Linear( + mlvl_dims['8'], fpn_dim, bias=False) + self.split_16to8 = PatchSplit(mlvl_dims['16'], fpn_dim, norm_cfg) + self.block_16to8 = nn.Sequential(*[ + BlockWithRPE( + Hp, + fpn_dim, + 0, + mlp_ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=0., + rpe=rpe, + norm_cfg=norm_cfg, + ) for _ in range(fpn_depth) + ]) + + if num_outs > 2: + self.align_dim_8to4 = nn.Linear( + mlvl_dims['4'], fpn_dim, bias=False) + self.split_8to4 = PatchSplit(fpn_dim, fpn_dim, norm_cfg) + self.block_8to4 = nn.Sequential(*[ + BlockWithRPE( + Hp, + fpn_dim, + 0, + mlp_ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=0., + rpe=rpe, + norm_cfg=norm_cfg, + ) for _ in range(fpn_depth) + ]) + self.fpn_modules.append( + BlockWithRPE( + Hp, + fpn_dim, + 0, + mlp_ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=0., + rpe=rpe, + norm_cfg=norm_cfg)) + + def init_weights(self) -> None: + """Initialize position embedding and mask token of MAE decoder.""" + super().init_weights() + + if self.reconstruction_type == 'pixel': + decoder_pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.decoder_pos_embed.shape[-1], + cls_token=False) + self.decoder_pos_embed.data.copy_(decoder_pos_embed.float()) + + torch.nn.init.normal_(self.mask_token, std=.02) + else: + self.rescale_init_weight() + + def rescale_init_weight(self) -> None: + """Rescale the initialized weights.""" + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.fpn_modules): + if isinstance(layer, BlockWithRPE): + if layer.attn is not None: + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + @property + def decoder_norm(self): + """The normalization layer of decoder.""" + return getattr(self, self.decoder_norm_name) + + def forward(self, + x: torch.Tensor, + ids_restore: torch.Tensor = None) -> torch.Tensor: + """The forward function. + + The process computes the visible patches' features vectors and the mask + tokens to output feature vectors, which will be used for + reconstruction. + + Args: + x (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + ids_restore (torch.Tensor): ids to restore original image. + + Returns: + torch.Tensor: The reconstructed feature vectors, which is of + shape B x (num_patches) x C. 
+ """ + + features = x[:2] + x = x[-1] + B, L, _ = x.shape + x = x[..., None, None, :] + Hp = Wp = math.sqrt(L) + + outs = [x] if self.align_dim_16tofpn is None else [ + self.align_dim_16tofpn(x) + ] + if self.num_outs >= 2: + x = self.block_16to8( + self.split_16to8(x) + self.align_dim_16to8(features[1])) + outs.append(x) + if self.num_outs >= 3: + x = self.block_8to4( + self.split_8to4(x) + self.align_dim_8to4(features[0])) + outs.append(x) + if self.num_outs > 3: + outs = [ + out.reshape(B, Hp, Wp, *out.shape[-3:]).permute( + 0, 5, 1, 3, 2, 4).reshape(B, -1, Hp * out.shape[-3], + Wp * out.shape[-2]).contiguous() + for out in outs + ] + if self.num_outs >= 4: + outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2)) + if self.num_outs >= 5: + outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2)) + + for i, out in enumerate(outs): + out = self.fpn_modules[i](out) + outs[i] = out + + if self.reconstruction_type == 'pixel': + feats = [] + for feat, layer in zip(outs, self.decoder_embed): + x = layer(feat).reshape(B, L, -1) + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat( + x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) + x = torch.cat([x, mask_tokens], dim=1) + x = torch.gather( + x, + dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) + feats.append(x) + x = feats.pop(0) + # add pos embed + x = x + self.decoder_pos_embed + + for i, feat in enumerate(feats): + x = x + feats[i] + # apply Transformer blocks + for i, blk in enumerate(self.decoder_blocks): + x = blk(x) + x = self.decoder_norm(x) + x = self.decoder_pred(x) + return x + else: + feats = [] + for feat, layer in zip(outs, self.decoder_embed): + x = layer(feat).reshape(B, L, -1) + feats.append(x) + x = feats.pop(0) + for i, feat in enumerate(feats): + x = x + feats[i] + + x = self.norm(x) + + return x diff --git a/mmpretrain/models/necks/linear_neck.py b/mmpretrain/models/necks/linear_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..bcdbee264325c8db0a054f765651a5dbadc968db --- /dev/null +++ b/mmpretrain/models/necks/linear_neck.py @@ -0,0 +1,88 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class LinearNeck(BaseModule): + """Linear neck with Dimension projection. + + Args: + in_channels (int): Number of channels in the input. + out_channels (int): Number of channels in the output. + gap_dim (int): Dimensions of each sample channel, can be one of + {0, 1, 2, 3}. Defaults to 0. + norm_cfg (dict, optional): dictionary to construct and + config norm layer. Defaults to dict(type='BN1d'). + act_cfg (dict, optional): dictionary to construct and + config activate layer. Defaults to None. + init_cfg (dict, optional): dictionary to initialize weights. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + gap_dim: int = 0, + norm_cfg: Optional[dict] = dict(type='BN1d'), + act_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + + self.in_channels = in_channels + self.out_channels = out_channels + self.norm_cfg = copy.deepcopy(norm_cfg) + self.act_cfg = copy.deepcopy(act_cfg) + + assert gap_dim in [0, 1, 2, 3], 'GlobalAveragePooling dim only ' \ + f'support {0, 1, 2, 3}, get {gap_dim} instead.' 
+ if gap_dim == 0: + self.gap = nn.Identity() + elif gap_dim == 1: + self.gap = nn.AdaptiveAvgPool1d(1) + elif gap_dim == 2: + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + elif gap_dim == 3: + self.gap = nn.AdaptiveAvgPool3d((1, 1, 1)) + + self.fc = nn.Linear(in_features=in_channels, out_features=out_channels) + + if norm_cfg: + self.norm = build_norm_layer(norm_cfg, out_channels)[1] + else: + self.norm = nn.Identity() + + if act_cfg: + self.act = build_activation_layer(act_cfg) + else: + self.act = nn.Identity() + + def forward(self, inputs: Union[Tuple, + torch.Tensor]) -> Tuple[torch.Tensor]: + """forward function. + + Args: + inputs (Union[Tuple, torch.Tensor]): The features extracted from + the backbone. Multiple stage inputs are acceptable but only + the last stage will be used. + + Returns: + Tuple[torch.Tensor]: A tuple of output features. + """ + assert isinstance(inputs, (tuple, torch.Tensor)), ( + 'The inputs of `LinearNeck` must be tuple or `torch.Tensor`, ' + f'but get {type(inputs)}.') + if isinstance(inputs, tuple): + inputs = inputs[-1] + + x = self.gap(inputs) + x = x.view(x.size(0), -1) + out = self.act(self.norm(self.fc(x))) + return (out, ) diff --git a/mmpretrain/models/necks/mae_neck.py b/mmpretrain/models/necks/mae_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..773692dcb3a94d85d2d2085360fd339493a24db3 --- /dev/null +++ b/mmpretrain/models/necks/mae_neck.py @@ -0,0 +1,188 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from ..backbones.vision_transformer import TransformerEncoderLayer +from ..utils import build_2d_sincos_position_embedding + + +@MODELS.register_module() +class MAEPretrainDecoder(BaseModule): + """Decoder for MAE Pre-training. + + Some of the code is borrowed from `https://github.com/facebookresearch/mae`. # noqa + + Args: + num_patches (int): The number of total patches. Defaults to 196. + patch_size (int): Image patch size. Defaults to 16. + in_chans (int): The channel of input image. Defaults to 3. + embed_dim (int): Encoder's embedding dimension. Defaults to 1024. + decoder_embed_dim (int): Decoder's embedding dimension. + Defaults to 512. + decoder_depth (int): The depth of decoder. Defaults to 8. + decoder_num_heads (int): Number of attention heads of decoder. + Defaults to 16. + mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim. + Defaults to 4. + norm_cfg (dict): Normalization layer. Defaults to LayerNorm. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. 
+ + Example: + >>> from mmpretrain.models import MAEPretrainDecoder + >>> import torch + >>> self = MAEPretrainDecoder() + >>> self.eval() + >>> inputs = torch.rand(1, 50, 1024) + >>> ids_restore = torch.arange(0, 196).unsqueeze(0) + >>> level_outputs = self.forward(inputs, ids_restore) + >>> print(tuple(level_outputs.shape)) + (1, 196, 768) + """ + + def __init__(self, + num_patches: int = 196, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 1024, + decoder_embed_dim: int = 512, + decoder_depth: int = 8, + decoder_num_heads: int = 16, + mlp_ratio: int = 4, + norm_cfg: dict = dict(type='LN', eps=1e-6), + predict_feature_dim: Optional[float] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.num_patches = num_patches + + # used to convert the dim of features from encoder to the dim + # compatible with that of decoder + self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + + # create new position embedding, different from that in encoder + # and is not learnable + self.decoder_pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches + 1, decoder_embed_dim), + requires_grad=False) + + self.decoder_blocks = nn.ModuleList([ + TransformerEncoderLayer( + decoder_embed_dim, + decoder_num_heads, + int(mlp_ratio * decoder_embed_dim), + qkv_bias=True, + norm_cfg=norm_cfg) for _ in range(decoder_depth) + ]) + + self.decoder_norm_name, decoder_norm = build_norm_layer( + norm_cfg, decoder_embed_dim, postfix=1) + self.add_module(self.decoder_norm_name, decoder_norm) + + # Used to map features to pixels + if predict_feature_dim is None: + predict_feature_dim = patch_size**2 * in_chans + self.decoder_pred = nn.Linear( + decoder_embed_dim, predict_feature_dim, bias=True) + + def init_weights(self) -> None: + """Initialize position embedding and mask token of MAE decoder.""" + super().init_weights() + + decoder_pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.decoder_pos_embed.shape[-1], + cls_token=True) + self.decoder_pos_embed.data.copy_(decoder_pos_embed.float()) + + torch.nn.init.normal_(self.mask_token, std=.02) + + @property + def decoder_norm(self): + """The normalization layer of decoder.""" + return getattr(self, self.decoder_norm_name) + + def forward(self, x: torch.Tensor, + ids_restore: torch.Tensor) -> torch.Tensor: + """The forward function. + + The process computes the visible patches' features vectors and the mask + tokens to output feature vectors, which will be used for + reconstruction. + + Args: + x (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + ids_restore (torch.Tensor): ids to restore original image. + + Returns: + torch.Tensor: The reconstructed feature vectors, which is of + shape B x (num_patches) x C. 
+ """ + # embed tokens + x = self.decoder_embed(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat( + x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) + x_ = torch.gather( + x_, + dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) + x = torch.cat([x[:, :1, :], x_], dim=1) + + # add pos embed + x = x + self.decoder_pos_embed + + # apply Transformer blocks + for blk in self.decoder_blocks: + x = blk(x) + x = self.decoder_norm(x) + + # predictor projection + x = self.decoder_pred(x) + + # remove cls token + x = x[:, 1:, :] + + return x + + +@MODELS.register_module() +class ClsBatchNormNeck(BaseModule): + """Normalize cls token across batch before head. + + This module is proposed by MAE, when running linear probing. + + Args: + input_features (int): The dimension of features. + affine (bool): a boolean value that when set to ``True``, this module + has learnable affine parameters. Defaults to False. + eps (float): a value added to the denominator for numerical stability. + Defaults to 1e-6. + init_cfg (Dict or List[Dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + input_features: int, + affine: bool = False, + eps: float = 1e-6, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg) + self.bn = nn.BatchNorm1d(input_features, affine=affine, eps=eps) + + def forward( + self, + inputs: Tuple[List[torch.Tensor]]) -> Tuple[List[torch.Tensor]]: + """The forward function.""" + # Only apply batch norm to cls_token + inputs = [self.bn(input_) for input_ in inputs] + return tuple(inputs) diff --git a/mmpretrain/models/necks/milan_neck.py b/mmpretrain/models/necks/milan_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..b48b76787231cfe9671e9f12900b6db1987a7e2a --- /dev/null +++ b/mmpretrain/models/necks/milan_neck.py @@ -0,0 +1,222 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +from torch import nn + +from mmpretrain.registry import MODELS +from ..backbones.vision_transformer import TransformerEncoderLayer +from ..utils import PromptMultiheadAttention +from .mae_neck import MAEPretrainDecoder + + +class PromptTransformerEncoderLayer(TransformerEncoderLayer): + """Prompt Transformer Encoder Layer for MILAN. + + This module is specific for the prompt encoder in MILAN. It will not update + the visible tokens from the encoder. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Defaults to 0.0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + qkv_bias (bool): Enable bias for qkv if True. Defaults to True. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Defaults to False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. 
+ """
+
+ def __init__(self,
+ embed_dims: int,
+ num_heads: int,
+ feedforward_channels: int,
+ drop_rate: float = 0.,
+ attn_drop_rate: float = 0.,
+ drop_path_rate: float = 0.,
+ num_fcs: int = 2,
+ qkv_bias: bool = True,
+ act_cfg: dict = dict(type='GELU'),
+ norm_cfg: dict = dict(type='LN'),
+ init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
+ super().__init__(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ feedforward_channels=feedforward_channels,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=drop_path_rate,
+ num_fcs=num_fcs,
+ qkv_bias=qkv_bias,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ init_cfg=init_cfg)
+ self.attn = PromptMultiheadAttention(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ attn_drop=attn_drop_rate,
+ proj_drop=drop_rate,
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+ qkv_bias=qkv_bias)
+
+ def forward(self, x: torch.Tensor, visible_tokens: torch.Tensor,
+ ids_restore: torch.Tensor) -> torch.Tensor:
+ """Forward function for `PromptMultiheadAttention`.
+
+ Args:
+ x (torch.Tensor): Mask token features with shape N x L_m x C.
+ visible_tokens (torch.Tensor): The visible token features from
+ the encoder with shape N x L_v x C.
+ ids_restore (torch.Tensor): The ids of all tokens in the original
+ image with shape N x L.
+
+ Returns:
+ torch.Tensor: Output features with shape N x L x C.
+ """
+ x = x + self.attn(self.norm1(x), visible_tokens, ids_restore)
+ x = self.ffn(self.norm2(x), identity=x)
+ return x
+
+
+@MODELS.register_module()
+class MILANPretrainDecoder(MAEPretrainDecoder):
+ """Prompt decoder for MILAN.
+
+ This decoder is used in MILAN pretraining and does not update the
+ visible tokens from the encoder.
+
+ Args:
+ num_patches (int): The number of total patches. Defaults to 196.
+ patch_size (int): Image patch size. Defaults to 16.
+ in_chans (int): The channel of input image. Defaults to 3.
+ embed_dim (int): Encoder's embedding dimension. Defaults to 1024.
+ decoder_embed_dim (int): Decoder's embedding dimension.
+ Defaults to 512.
+ decoder_depth (int): The depth of decoder. Defaults to 8.
+ decoder_num_heads (int): Number of attention heads of decoder.
+ Defaults to 16.
+ predict_feature_dim (int): The dimension of the feature to be
+ predicted. Defaults to 512.
+ mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
+ Defaults to 4.
+ norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
+ init_cfg (Union[List[dict], dict], optional): Initialization config
+ dict. Defaults to None.
+ """ + + def __init__(self, + num_patches: int = 196, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 1024, + decoder_embed_dim: int = 512, + decoder_depth: int = 8, + decoder_num_heads: int = 16, + predict_feature_dim: int = 512, + mlp_ratio: int = 4, + norm_cfg: dict = dict(type='LN', eps=1e-6), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + num_patches=num_patches, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + decoder_embed_dim=decoder_embed_dim, + decoder_depth=decoder_depth, + decoder_num_heads=decoder_num_heads, + mlp_ratio=mlp_ratio, + norm_cfg=norm_cfg, + init_cfg=init_cfg) + + # map the dim of features from decoder to the dim compatible with + # that of CLIP + self.decoder_pred = nn.Linear( + decoder_embed_dim, predict_feature_dim, bias=True) + + # use prompt transformer encoder layer, instead of the conventional + # transformer encoder layer + self.decoder_blocks = nn.ModuleList([ + PromptTransformerEncoderLayer( + decoder_embed_dim, + decoder_num_heads, + int(mlp_ratio * decoder_embed_dim), + qkv_bias=True, + norm_cfg=norm_cfg) for _ in range(decoder_depth) + ]) + + def forward(self, x: torch.Tensor, ids_restore: torch.Tensor, + ids_keep: torch.Tensor, + ids_dump: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + x (torch.Tensor): The input features, which is of shape (N, L, C). + ids_restore (torch.Tensor): The indices to restore these tokens + to the original image. + ids_keep (torch.Tensor): The indices of tokens to be kept. + ids_dump (torch.Tensor): The indices of tokens to be masked. + + Returns: + torch.Tensor: The reconstructed features, which is of shape + (N, L, C). + """ + # embed tokens + x = self.decoder_embed(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat( + x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) + x_ = torch.gather( + x_, + dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) + x = torch.cat([x[:, :1, :], x_], dim=1) + + # add pos embed + x = x + self.decoder_pos_embed + + # split mask tokens and visible tokens + visible_tokens = torch.cat([ + x[:, :1, :], + torch.gather( + x[:, 1:, :], + dim=1, + index=ids_keep.unsqueeze(-1).repeat(1, 1, x.shape[-1])) + ], + dim=1) + x = torch.gather( + x[:, 1:, :], + dim=1, + index=ids_dump.unsqueeze(-1).repeat(1, 1, x.shape[-1])) + + for blk in self.decoder_blocks: + x = blk(x, visible_tokens, ids_restore) + + # full sequence recovery + x_ = torch.cat([visible_tokens[:, 1:, :], x], dim=1) + x_ = torch.gather( + x_, + dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, + x.shape[-1])) # unshuffle + x = torch.cat([visible_tokens[:, :1, :], x_], dim=1) + + x = self.decoder_norm(x) + + # predictor projection + x = self.decoder_pred(x) + + return x diff --git a/mmpretrain/models/necks/mixmim_neck.py b/mmpretrain/models/necks/mixmim_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..8d67ee2bd6b48136f2ae6b298e11bd7758fa414b --- /dev/null +++ b/mmpretrain/models/necks/mixmim_neck.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from ..utils import build_2d_sincos_position_embedding +from .mae_neck import MAEPretrainDecoder + + +@MODELS.register_module() +class MixMIMPretrainDecoder(MAEPretrainDecoder): + """Decoder for MixMIM Pretraining. 
+
+ Some of the code is borrowed from `https://github.com/Sense-X/MixMIM`. # noqa
+
+ Args:
+ num_patches (int): The number of total patches. Defaults to 196.
+ patch_size (int): Image patch size. Defaults to 16.
+ in_chans (int): The channel of input image. Defaults to 3.
+ embed_dim (int): Encoder's embedding dimension. Defaults to 1024.
+ encoder_stride (int): The output stride of MixMIM backbone. Defaults
+ to 32.
+ decoder_embed_dim (int): Decoder's embedding dimension.
+ Defaults to 512.
+ decoder_depth (int): The depth of decoder. Defaults to 8.
+ decoder_num_heads (int): Number of attention heads of decoder.
+ Defaults to 16.
+ mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
+ Defaults to 4.
+ norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
+ init_cfg (Union[List[dict], dict], optional): Initialization config
+ dict. Defaults to None.
+ """
+
+ def __init__(self,
+ num_patches: int = 196,
+ patch_size: int = 16,
+ in_chans: int = 3,
+ embed_dim: int = 1024,
+ encoder_stride: int = 32,
+ decoder_embed_dim: int = 512,
+ decoder_depth: int = 8,
+ decoder_num_heads: int = 16,
+ mlp_ratio: int = 4,
+ norm_cfg: dict = dict(type='LN', eps=1e-6),
+ init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
+
+ super().__init__(
+ num_patches=num_patches,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ decoder_embed_dim=decoder_embed_dim,
+ decoder_depth=decoder_depth,
+ decoder_num_heads=decoder_num_heads,
+ mlp_ratio=mlp_ratio,
+ norm_cfg=norm_cfg,
+ init_cfg=init_cfg)
+
+ self.decoder_pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches, decoder_embed_dim),
+ requires_grad=False)
+ self.decoder_pred = nn.Linear(decoder_embed_dim, encoder_stride**2 * 3)
+
+ def init_weights(self) -> None:
+ """Initialize position embedding and mask token of MixMIM decoder."""
+ super(MAEPretrainDecoder, self).init_weights()
+
+ decoder_pos_embed = build_2d_sincos_position_embedding(
+ int(self.num_patches**.5),
+ self.decoder_pos_embed.shape[-1],
+ cls_token=False)
+ self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())
+
+ torch.nn.init.normal_(self.mask_token, std=.02)
+
+ def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ """Forward function.
+
+ Args:
+ x (torch.Tensor): The input features, which is of shape (N, L, C).
+ mask (torch.Tensor): The tensor indicating which tokens are
+ masked.
+
+ Returns:
+ torch.Tensor: The reconstructed features, which is of shape
+ (N, L, C).
+ """
+
+ x = self.decoder_embed(x)
+ B, L, C = x.shape
+
+ mask_tokens = self.mask_token.expand(B, L, -1)
+ x1 = x * (1 - mask) + mask_tokens * mask
+ x2 = x * mask + mask_tokens * (1 - mask)
+ x = torch.cat([x1, x2], dim=0)
+
+ # add pos embed
+ x = x + self.decoder_pos_embed
+
+ # apply Transformer blocks
+ for blk in self.decoder_blocks:
+ x = blk(x)
+ x = self.decoder_norm(x)
+
+ # predictor projection
+ x = self.decoder_pred(x)
+
+ return x
diff --git a/mmpretrain/models/necks/mocov2_neck.py b/mmpretrain/models/necks/mocov2_neck.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ad9107812eb9aaaaff8cbc1a7d5c3d39e92dfa1
--- /dev/null
+++ b/mmpretrain/models/necks/mocov2_neck.py
@@ -0,0 +1,52 @@
+# Copyright (c) OpenMMLab. All rights reserved.
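The MAE, MILAN and MixMIM decoders above all rebuild the full patch sequence from a shorter visible sequence; MAE and MILAN do it by appending mask tokens and unshuffling with `torch.gather` over `ids_restore`, while MixMIM swaps tokens between the two mixed images with a binary mask. A minimal, self-contained sketch of the gather-based unshuffle (all shapes are illustrative, not taken from the configs above):

```python
import torch

# Illustrative sizes: B=2 images, L=4 patches, C=3 channels, 2 visible tokens.
B, L, C = 2, 4, 3
ids_shuffle = torch.stack([torch.randperm(L) for _ in range(B)])  # random keep order
ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation per sample

visible = torch.randn(B, 2, C)                         # encoder output (cls token omitted)
mask_tokens = torch.zeros(B, L - visible.shape[1], C)  # a learnable parameter in the real model
x = torch.cat([visible, mask_tokens], dim=1)           # still in shuffled order: visible first

# gather moves every token back to its original patch position
x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, C))
print(x.shape)  # torch.Size([2, 4, 3])
```

In the actual decoders the cls token is split off before the gather and concatenated back afterwards, which is why the slices `x[:, :1, :]` and `x[:, 1:, :]` appear around every `torch.gather` call above.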
+from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class MoCoV2Neck(BaseModule): + """The non-linear neck of MoCo v2: fc-relu-fc. + + Args: + in_channels (int): Number of input channels. + hid_channels (int): Number of hidden channels. + out_channels (int): Number of output channels. + with_avg_pool (bool): Whether to apply the global + average pooling after backbone. Defaults to True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + hid_channels: int, + out_channels: int, + with_avg_pool: bool = True, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg) + self.with_avg_pool = with_avg_pool + if with_avg_pool: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.mlp = nn.Sequential( + nn.Linear(in_channels, hid_channels), nn.ReLU(inplace=True), + nn.Linear(hid_channels, out_channels)) + + def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]: + """Forward function. + + Args: + x (Tuple[torch.Tensor]): The feature map of backbone. + + Returns: + Tuple[torch.Tensor]: The output features. + """ + assert len(x) == 1 + x = x[0] + if self.with_avg_pool: + x = self.avgpool(x) + return (self.mlp(x.view(x.size(0), -1)), ) diff --git a/mmpretrain/models/necks/nonlinear_neck.py b/mmpretrain/models/necks/nonlinear_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..ef684d39d1f7f5dc7361ccbf631d3ce712d65ac5 --- /dev/null +++ b/mmpretrain/models/necks/nonlinear_neck.py @@ -0,0 +1,115 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class NonLinearNeck(BaseModule): + """The non-linear neck. + + Structure: fc-bn-[relu-fc-bn] where the substructure in [] can be repeated. + For the default setting, the repeated time is 1. + The neck can be used in many algorithms, e.g., SimCLR, BYOL, SimSiam. + + Args: + in_channels (int): Number of input channels. + hid_channels (int): Number of hidden channels. + out_channels (int): Number of output channels. + num_layers (int): Number of fc layers. Defaults to 2. + with_bias (bool): Whether to use bias in fc layers (except for the + last). Defaults to False. + with_last_bn (bool): Whether to add the last BN layer. + Defaults to True. + with_last_bn_affine (bool): Whether to have learnable affine parameters + in the last BN layer (set False for SimSiam). Defaults to True. + with_last_bias (bool): Whether to use bias in the last fc layer. + Defaults to False. + with_avg_pool (bool): Whether to apply the global average pooling + after backbone. Defaults to True. + norm_cfg (dict): Dictionary to construct and config norm layer. + Defaults to dict(type='SyncBN'). + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ """ + + def __init__( + self, + in_channels: int, + hid_channels: int, + out_channels: int, + num_layers: int = 2, + with_bias: bool = False, + with_last_bn: bool = True, + with_last_bn_affine: bool = True, + with_last_bias: bool = False, + with_avg_pool: bool = True, + norm_cfg: dict = dict(type='SyncBN'), + init_cfg: Optional[Union[dict, List[dict]]] = [ + dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ] + ) -> None: + super(NonLinearNeck, self).__init__(init_cfg) + self.with_avg_pool = with_avg_pool + if with_avg_pool: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.relu = nn.ReLU(inplace=True) + self.fc0 = nn.Linear(in_channels, hid_channels, bias=with_bias) + self.bn0 = build_norm_layer(norm_cfg, hid_channels)[1] + + self.fc_names = [] + self.bn_names = [] + for i in range(1, num_layers): + this_channels = out_channels if i == num_layers - 1 \ + else hid_channels + if i != num_layers - 1: + self.add_module( + f'fc{i}', + nn.Linear(hid_channels, this_channels, bias=with_bias)) + self.add_module(f'bn{i}', + build_norm_layer(norm_cfg, this_channels)[1]) + self.bn_names.append(f'bn{i}') + else: + self.add_module( + f'fc{i}', + nn.Linear( + hid_channels, this_channels, bias=with_last_bias)) + if with_last_bn: + self.add_module( + f'bn{i}', + build_norm_layer( + dict(**norm_cfg, affine=with_last_bn_affine), + this_channels)[1]) + self.bn_names.append(f'bn{i}') + else: + self.bn_names.append(None) + self.fc_names.append(f'fc{i}') + + def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]: + """Forward function. + + Args: + x (Tuple[torch.Tensor]): The feature map of backbone. + + Returns: + Tuple[torch.Tensor]: The output features. + """ + assert len(x) == 1 + x = x[0] + if self.with_avg_pool: + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc0(x) + x = self.bn0(x) + for fc_name, bn_name in zip(self.fc_names, self.bn_names): + fc = getattr(self, fc_name) + x = self.relu(x) + x = fc(x) + if bn_name is not None: + bn = getattr(self, bn_name) + x = bn(x) + return (x, ) diff --git a/mmpretrain/models/necks/simmim_neck.py b/mmpretrain/models/necks/simmim_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..cb1e29bcf195ecb800a22a2c43917e62718b5ffe --- /dev/null +++ b/mmpretrain/models/necks/simmim_neck.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SimMIMLinearDecoder(BaseModule): + """Linear Decoder For SimMIM pretraining. + + This neck reconstructs the original image from the shrunk feature map. + + Args: + in_channels (int): Channel dimension of the feature map. + encoder_stride (int): The total stride of the encoder. + """ + + def __init__(self, in_channels: int, encoder_stride: int) -> None: + super().__init__() + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=encoder_stride**2 * 3, + kernel_size=1), + nn.PixelShuffle(encoder_stride), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + x = self.decoder(x) + return x diff --git a/mmpretrain/models/necks/spark_neck.py b/mmpretrain/models/necks/spark_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..ac129da389711f900e4444fae38fdbc7ae91b9e5 --- /dev/null +++ b/mmpretrain/models/necks/spark_neck.py @@ -0,0 +1,169 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
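`SimMIMLinearDecoder` above leans entirely on `nn.PixelShuffle` for reconstruction: the 1x1 convolution expands the channel dimension to `encoder_stride**2 * 3`, and PixelShuffle rearranges those channels into spatial pixels. A quick shape check, assuming illustrative sizes (768 channels on a 7x7 map, e.g. stride-32 features of a 224x224 input; none of these numbers are fixed by the code above):

```python
import torch
import torch.nn as nn

# Mirror of the SimMIM-style reconstruction head with assumed sizes.
in_channels, encoder_stride = 768, 32
decoder = nn.Sequential(
    nn.Conv2d(in_channels, encoder_stride**2 * 3, kernel_size=1),
    nn.PixelShuffle(encoder_stride),  # (N, s*s*3, H, W) -> (N, 3, s*H, s*W)
)
feat = torch.randn(2, in_channels, 7, 7)  # downsampled feature map
img = decoder(feat)
print(img.shape)  # torch.Size([2, 3, 224, 224])
```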
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from mmengine.model import BaseModule
+
+from mmpretrain.registry import MODELS
+from ..utils import build_norm_layer
+
+
+def is_pow2n(x):
+ return x > 0 and (x & (x - 1) == 0)
+
+
+class ConvBlock2x(BaseModule):
+ """The definition of convolution block."""
+
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ mid_channels: int,
+ norm_cfg: dict,
+ act_cfg: dict,
+ last_act: bool,
+ init_cfg: Optional[dict] = None) -> None:
+ super().__init__(init_cfg=init_cfg)
+
+ self.conv1 = nn.Conv2d(in_channels, mid_channels, 3, 1, 1, bias=False)
+ self.norm1 = build_norm_layer(norm_cfg, mid_channels)
+ self.activate1 = MODELS.build(act_cfg)
+
+ self.conv2 = nn.Conv2d(mid_channels, out_channels, 3, 1, 1, bias=False)
+ self.norm2 = build_norm_layer(norm_cfg, out_channels)
+ self.activate2 = MODELS.build(act_cfg) if last_act else nn.Identity()
+
+ def forward(self, x: torch.Tensor):
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.activate1(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = self.activate2(out)
+ return out
+
+
+class DecoderConvModule(BaseModule):
+ """The convolution module of decoder with upsampling."""
+
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ mid_channels: int,
+ kernel_size: int = 4,
+ scale_factor: int = 2,
+ num_conv_blocks: int = 1,
+ norm_cfg: dict = dict(type='SyncBN'),
+ act_cfg: dict = dict(type='ReLU6'),
+ last_act: bool = True,
+ init_cfg: Optional[dict] = None):
+ super().__init__(init_cfg=init_cfg)
+
+ assert (kernel_size - scale_factor >= 0) and\
+ (kernel_size - scale_factor) % 2 == 0,\
+ f'kernel_size should be greater than or equal to scale_factor '\
+ f'and (kernel_size - scale_factor) should be even numbers, '\
+ f'while the kernel size is {kernel_size} and scale_factor is '\
+ f'{scale_factor}.'
+
+ padding = (kernel_size - scale_factor) // 2
+ self.upsample = nn.ConvTranspose2d(
+ in_channels,
+ in_channels,
+ kernel_size=kernel_size,
+ stride=scale_factor,
+ padding=padding,
+ bias=True)
+
+ conv_blocks_list = [
+ ConvBlock2x(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ mid_channels=mid_channels,
+ norm_cfg=norm_cfg,
+ last_act=last_act,
+ act_cfg=act_cfg) for _ in range(num_conv_blocks)
+ ]
+ self.conv_blocks = nn.Sequential(*conv_blocks_list)
+
+ def forward(self, x):
+ x = self.upsample(x)
+ return self.conv_blocks(x)
+
+
+@MODELS.register_module()
+class SparKLightDecoder(BaseModule):
+ """The decoder for SparK, which upsamples the feature maps.
+
+ Args:
+ feature_dim (int): The dimension of feature map.
+ upsample_ratio (int): The upsampling ratio, equal to the
+ downsample_ratio of the algorithm.
+ mid_channels (int): The middle channel of `DecoderConvModule`. Defaults
+ to 0.
+ kernel_size (int): The kernel size of `ConvTranspose2d` in
+ `DecoderConvModule`. Defaults to 4.
+ scale_factor (int): The scale_factor of `ConvTranspose2d` in
+ `DecoderConvModule`. Defaults to 2.
+ num_conv_blocks (int): The number of convolution blocks in
+ `DecoderConvModule`. Defaults to 1.
+ norm_cfg (dict): Normalization config. Defaults to dict(type='SyncBN').
+ act_cfg (dict): Activation config. Defaults to dict(type='ReLU6').
+ last_act (bool): Whether to apply the last activation in
+ `DecoderConvModule`. Defaults to False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """ + + def __init__( + self, + feature_dim: int, + upsample_ratio: int, + mid_channels: int = 0, + kernel_size: int = 4, + scale_factor: int = 2, + num_conv_blocks: int = 1, + norm_cfg: dict = dict(type='SyncBN'), + act_cfg: dict = dict(type='ReLU6'), + last_act: bool = False, + init_cfg: Optional[dict] = [ + dict(type='Kaiming', layer=['Conv2d', 'ConvTranspose2d']), + dict(type='TruncNormal', std=0.02, layer=['Linear']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'LayerNorm', 'SyncBatchNorm']) + ], + ): + super().__init__(init_cfg=init_cfg) + self.feature_dim = feature_dim + + assert is_pow2n(upsample_ratio) + n = round(math.log2(upsample_ratio)) + channels = [feature_dim // 2**i for i in range(n + 1)] + + self.decoder = nn.ModuleList([ + DecoderConvModule( + in_channels=c_in, + out_channels=c_out, + mid_channels=c_in if mid_channels == 0 else mid_channels, + kernel_size=kernel_size, + scale_factor=scale_factor, + num_conv_blocks=num_conv_blocks, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + last_act=last_act) + for (c_in, c_out) in zip(channels[:-1], channels[1:]) + ]) + self.proj = nn.Conv2d( + channels[-1], 3, kernel_size=1, stride=1, bias=True) + + def forward(self, to_dec): + x = 0 + for i, d in enumerate(self.decoder): + if i < len(to_dec) and to_dec[i] is not None: + x = x + to_dec[i] + x = self.decoder[i](x) + return self.proj(x) diff --git a/mmpretrain/models/necks/swav_neck.py b/mmpretrain/models/necks/swav_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..807ae8b9b3155e9dd14ef95fe5fca526919ee11d --- /dev/null +++ b/mmpretrain/models/necks/swav_neck.py @@ -0,0 +1,93 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SwAVNeck(BaseModule): + """The non-linear neck of SwAV: fc-bn-relu-fc-normalization. + + Args: + in_channels (int): Number of input channels. + hid_channels (int): Number of hidden channels. + out_channels (int): Number of output channels. + with_avg_pool (bool): Whether to apply the global average pooling after + backbone. Defaults to True. + with_l2norm (bool): whether to normalize the output after projection. + Defaults to True. + norm_cfg (dict): Dictionary to construct and config norm layer. + Defaults to dict(type='SyncBN'). + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__( + self, + in_channels: int, + hid_channels: int, + out_channels: int, + with_avg_pool: bool = True, + with_l2norm: bool = True, + norm_cfg: dict = dict(type='SyncBN'), + init_cfg: Optional[Union[dict, List[dict]]] = [ + dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ] + ) -> None: + super().__init__(init_cfg) + self.with_avg_pool = with_avg_pool + self.with_l2norm = with_l2norm + if with_avg_pool: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + if out_channels == 0: + self.projection_neck = nn.Identity() + elif hid_channels == 0: + self.projection_neck = nn.Linear(in_channels, out_channels) + else: + self.norm = build_norm_layer(norm_cfg, hid_channels)[1] + self.projection_neck = nn.Sequential( + nn.Linear(in_channels, hid_channels), + self.norm, + nn.ReLU(inplace=True), + nn.Linear(hid_channels, out_channels), + ) + + def forward_projection(self, x: torch.Tensor) -> torch.Tensor: + """Compute projection. 
+
+ Args:
+ x (torch.Tensor): The feature vectors after pooling.
+
+ Returns:
+ torch.Tensor: The output features with projection or L2-norm.
+ """
+ x = self.projection_neck(x)
+ if self.with_l2norm:
+ x = nn.functional.normalize(x, dim=1, p=2)
+ return x
+
+ def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
+ """Forward function.
+
+ Args:
+ x (List[torch.Tensor]): list of feature maps; its length equals
+ the number of crops.
+
+ Returns:
+ torch.Tensor: The projection vectors.
+ """
+ avg_out = []
+ for _x in x:
+ _x = _x[0]
+ if self.with_avg_pool:
+ _out = self.avgpool(_x)
+ avg_out.append(_out)
+ feat_vec = torch.cat(avg_out) # [sum(num_crops) * N, C]
+ feat_vec = feat_vec.view(feat_vec.size(0), -1)
+ output = self.forward_projection(feat_vec)
+ return output
diff --git a/mmpretrain/models/peft/__init__.py b/mmpretrain/models/peft/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f43e14890fdbff8b64a9046a5c7f06d62cfec8d
--- /dev/null
+++ b/mmpretrain/models/peft/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .lora import LoRAModel
+
+__all__ = [
+ 'LoRAModel',
+]
diff --git a/mmpretrain/models/peft/__pycache__/__init__.cpython-311.pyc b/mmpretrain/models/peft/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b47ccb0ebed63f605b7ec809757c32ba5530cd8
Binary files /dev/null and b/mmpretrain/models/peft/__pycache__/__init__.cpython-311.pyc differ
diff --git a/mmpretrain/models/peft/__pycache__/lora.cpython-311.pyc b/mmpretrain/models/peft/__pycache__/lora.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dac39e54bd697d1d276213cd6c3f3918c62eb665
Binary files /dev/null and b/mmpretrain/models/peft/__pycache__/lora.cpython-311.pyc differ
diff --git a/mmpretrain/models/peft/lora.py b/mmpretrain/models/peft/lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae1bae7fdd23bbeb3fa4ff58fde2f6d1176de8b6
--- /dev/null
+++ b/mmpretrain/models/peft/lora.py
@@ -0,0 +1,205 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import re
+from typing import Any, List
+
+import torch
+from mmengine.logging import print_log
+from mmengine.model import BaseModule
+from torch import nn
+
+from mmpretrain.registry import MODELS
+
+
+class LoRALinear(nn.Module):
+ r"""Implements LoRA in a linear layer.
+
+ Args:
+ original_layer (nn.Linear): The linear layer to be finetuned.
+ alpha (int): The scale factor of LoRA. Defaults to 1.
+ rank (int): The rank of LoRA. Defaults to 0.
+ drop_rate (float): The drop out rate for LoRA. Defaults to 0.
+
+ Note:
+ The forward process of LoRA linear layer is:
+
+ .. math::
+ y = W_0 x + BAx \cdot (\alpha / r)
+
+ Where :math:`x` is the input, :math:`y` is the output,
+ :math:`W_0` is the parameter of the original layer,
+ :math:`A` and :math:`B` are the low-rank decomposition matrices,
+ :math:`\alpha` is the scale factor and :math:`r` is the rank.
+ """
+
+ def __init__(self,
+ original_layer: nn.Linear,
+ alpha: int = 1,
+ rank: int = 0,
+ drop_rate: float = 0.):
+ super(LoRALinear, self).__init__()
+ in_features = original_layer.in_features
+ out_features = original_layer.out_features
+
+ self.lora_dropout = nn.Dropout(drop_rate)
+ self.lora_down = nn.Linear(in_features, rank, bias=False)
+ self.lora_up = nn.Linear(rank, out_features, bias=False)
+ self.scaling = alpha / rank
+
+ nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_up.weight)
+
+ self.original_layer = original_layer
+
+ def forward(self, x: torch.Tensor):
+ out = self.original_layer(x)
+
+ lora_x = self.lora_dropout(x)
+ lora_out = self.lora_up(self.lora_down(lora_x)) * self.scaling
+
+ return out + lora_out
+
+
+@MODELS.register_module()
+class LoRAModel(BaseModule):
+ """Implements LoRA in a module.
+
+ A PyTorch implementation of `LoRA: Low-Rank Adaptation
+ of Large Language Models <https://arxiv.org/abs/2106.09685>`_
+
+ Args:
+ module (dict): The config of the module to be finetuned. See
+ :mod:`mmpretrain.models`
+ alpha (int): The scale factor of LoRA. Defaults to 1.
+ rank (int): The rank of LoRA. Defaults to 0.
+ drop_rate (float): The drop out rate for LoRA. Defaults to 0.
+ targets (List[dict]): The target layers to be applied with the LoRA.
+ Defaults to an empty list. Targets can be specified by a regular
+ expression or a suffix.
+
+ Examples:
+ >>> model = LoRAModel(
+ ... module=dict(type='VisionTransformer', arch='b'),
+ ... alpha=4,
+ ... rank=4,
+ ... drop_rate=0.1,
+ ... targets=[
+ ... dict(type='.*qkv'), # regular expression
+ ... dict(type='proj', alpha=8, rank=8), # suffix
+ ... ])
+ """
+
+ def __init__(self,
+ module: dict,
+ alpha: int = 1,
+ rank: int = 0,
+ drop_rate: float = 0.,
+ targets: List[dict] = list()):
+
+ super().__init__()
+
+ module = MODELS.build(module)
+ module.init_weights()
+
+ self.module = module
+ self.alpha = alpha
+ self.rank = rank
+ self.drop_rate = drop_rate
+
+ assert len(targets) != 0, \
+ 'The length of target layers should not be 0.'
+
+ self.targets = targets
+
+ self.applied = False
+ self.apply_lora()
+
+ if not self.applied:
+ raise ValueError(
+ 'No lora layer is replaced.
Please check targets.') + + self._set_lora_trainable() + self._register_state_dict_hooks() + + def apply_lora(self): + """Apply LoRA to target layers.""" + module_names = [k for k, _ in self.module.named_modules()] + for module_name in module_names: + for target in self.targets: + target_name = target['type'] + target_alpha = target.get('alpha', self.alpha) + target_rank = target.get('rank', self.rank) + target_drop_rate = target.get('drop_rate', self.drop_rate) + + if re.fullmatch(target_name, module_name) or \ + module_name.endswith(target_name): + current_module = self.module.get_submodule(module_name) + if isinstance(current_module, nn.Linear): + print_log( + f'Set LoRA for {module_name} ' + f'with alpha: {target_alpha}, ' + f'rank: {target_rank}, ' + f'drop rate: {target_drop_rate}', + logger='current') + + self._replace_module(module_name, current_module, + target_alpha, target_rank, + target_drop_rate) + self.applied = True + + def _replace_module(self, module_name: str, current_module: nn.Module, + alpha: int, rank: int, drop_rate: float): + """Replace target layer with LoRA linear layer in place.""" + parent_module_name = '.'.join(module_name.split('.')[:-1]) + parent_module = self.module.get_submodule(parent_module_name) + + target_name = module_name.split('.')[-1] + target_module = LoRALinear(current_module, alpha, rank, drop_rate) + setattr(parent_module, target_name, target_module) + + def _set_lora_trainable(self): + """Set only the lora parameters trainable.""" + for name, param in self.named_parameters(): + if '.lora_' in name: + param.requires_grad = True + else: + param.requires_grad = False + + def _register_state_dict_hooks(self): + """Register state dict hooks. + + Register state dict saving hooks to save only the lora parameters to + the state dict. And register state dict loading hooks to handle the + incompatible keys while loading the state dict. + """ + + def _state_dict_hook(module, state_dict, prefix, local_metadata): + """Save only the lora parameters to the state dict.""" + keys = [k for k, _ in state_dict.items()] + for key in keys: + if '.lora_' not in key: + state_dict.pop(key) + + self._register_state_dict_hook(_state_dict_hook) + + def _load_state_dict_post_hook(module, incompatible_keys): + """Handle the incompatible keys while loading the state dict.""" + missing_keys = incompatible_keys.missing_keys.copy() + for key in missing_keys: + if '.lora_' not in key: + incompatible_keys.missing_keys.remove(key) + + unexpected_keys = incompatible_keys.unexpected_keys.copy() + for key in unexpected_keys: + if '.lora_' not in key: + incompatible_keys.unexpected_keys.remove(key) + + self.register_load_state_dict_post_hook(_load_state_dict_post_hook) + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) + + def __getattr__(self, name: str) -> Any: + try: + return super(LoRAModel, self).__getattr__(name) + except AttributeError: + return self.module.__getattribute__(name) diff --git a/mmpretrain/models/retrievers/__init__.py b/mmpretrain/models/retrievers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..593b637d6eb7e44184fdf6ceb70470253639b013 --- /dev/null +++ b/mmpretrain/models/retrievers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
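Because `lora_up` is zero-initialized, a freshly wrapped `LoRALinear` computes exactly the same output as the layer it wraps, so finetuning starts from the pretrained function and only the low-rank branch learns a delta. A small sanity check of that property, assuming the file above is importable as `mmpretrain.models.peft.lora`:

```python
import torch
import torch.nn as nn
from mmpretrain.models.peft.lora import LoRALinear

layer = nn.Linear(16, 8)
lora = LoRALinear(layer, alpha=4, rank=4, drop_rate=0.)
lora.eval()  # keep the dropout branch deterministic

x = torch.randn(2, 16)
# lora_up starts at zero, so the low-rank branch contributes nothing yet
assert torch.allclose(lora(x), layer(x))
```

After `_set_lora_trainable`, only parameters whose names contain `.lora_` remain trainable, which is also what makes the state-dict hooks above sufficient for checkpointing just the adapter weights.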
+from .base import BaseRetriever +from .image2image import ImageToImageRetriever + +__all__ = ['BaseRetriever', 'ImageToImageRetriever'] diff --git a/mmpretrain/models/retrievers/__pycache__/__init__.cpython-311.pyc b/mmpretrain/models/retrievers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0565c033f05dec7f9735ea5170d46fa2368fc9ff Binary files /dev/null and b/mmpretrain/models/retrievers/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmpretrain/models/retrievers/__pycache__/base.cpython-311.pyc b/mmpretrain/models/retrievers/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddba37a3ba7465181fc09e9c2533619a80b89e4e Binary files /dev/null and b/mmpretrain/models/retrievers/__pycache__/base.cpython-311.pyc differ diff --git a/mmpretrain/models/retrievers/__pycache__/image2image.cpython-311.pyc b/mmpretrain/models/retrievers/__pycache__/image2image.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db9e334dc40c657b4e4e88d9f9f3cda4fed3daa0 Binary files /dev/null and b/mmpretrain/models/retrievers/__pycache__/image2image.cpython-311.pyc differ diff --git a/mmpretrain/models/retrievers/base.py b/mmpretrain/models/retrievers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..15816798f3fadc612b51634994178eb5f8860fb8 --- /dev/null +++ b/mmpretrain/models/retrievers/base.py @@ -0,0 +1,151 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List, Optional, Union + +import torch +from mmengine.model import BaseModel +from mmengine.structures import BaseDataElement +from torch.utils.data import DataLoader + + +class BaseRetriever(BaseModel, metaclass=ABCMeta): + """Base class for retriever. + + Args: + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None, it will use "BaseDataPreprocessor" as type, see + :class:`mmengine.model.BaseDataPreprocessor` for more details. + Defaults to None. + prototype (Union[DataLoader, dict, str, torch.Tensor]): Database to be + retrieved. The following four types are supported. + + - DataLoader: The original dataloader serves as the prototype. + - dict: The configuration to construct Dataloader. + - str: The path of the saved vector. + - torch.Tensor: The saved tensor whose dimension should be dim. + + Attributes: + prototype (Union[DataLoader, dict, str, torch.Tensor]): Database to be + retrieved. The following four types are supported. + + - DataLoader: The original dataloader serves as the prototype. + - dict: The configuration to construct Dataloader. + - str: The path of the saved vector. + - torch.Tensor: The saved tensor whose dimension should be dim. + + data_preprocessor (:obj:`mmengine.model.BaseDataPreprocessor`): An + extra data pre-processing module, which processes data from + dataloader to the format accepted by :meth:`forward`. 
+ """
+
+ def __init__(
+ self,
+ prototype: Union[DataLoader, dict, str, torch.Tensor] = None,
+ data_preprocessor: Optional[dict] = None,
+ init_cfg: Optional[dict] = None,
+ ):
+ super(BaseRetriever, self).__init__(
+ init_cfg=init_cfg, data_preprocessor=data_preprocessor)
+ self.prototype = prototype
+ self.prototype_inited = False
+
+ @abstractmethod
+ def forward(self,
+ inputs: torch.Tensor,
+ data_samples: Optional[List[BaseDataElement]] = None,
+ mode: str = 'loss'):
+ """The unified entry for a forward process in both training and test.
+
+ The method should accept three modes: "tensor", "predict" and "loss":
+
+ - "tensor": Forward the whole network and return tensor without any
+ post-processing, same as a common nn.Module.
+ - "predict": Forward and return the predictions, which are fully
+ processed to a list of :obj:`DataSample`.
+ - "loss": Forward and return a dict of losses according to the given
+ inputs and data samples.
+
+ Note that this method does not handle back-propagation or optimizer
+ updating; those are done in :meth:`train_step`.
+
+ Args:
+ inputs (torch.Tensor, tuple): The input tensor with shape
+ (N, C, ...) in general.
+ data_samples (List[DataSample], optional): The annotation
+ data of each sample. It's required if ``mode="loss"``.
+ Defaults to None.
+ mode (str): Return what kind of value. Defaults to 'loss'.
+
+ Returns:
+ The return type depends on ``mode``.
+
+ - If ``mode="tensor"``, return a tensor.
+ - If ``mode="predict"``, return a list of
+ :obj:`mmpretrain.structures.DataSample`.
+ - If ``mode="loss"``, return a dict of tensor.
+ """
+ pass
+
+ def extract_feat(self, inputs: torch.Tensor):
+ """Extract features from the input tensor with shape (N, C, ...).
+
+ The sub-classes are recommended to implement this method to extract
+ features from backbone and neck.
+
+ Args:
+ inputs (Tensor): A batch of inputs. The shape of it should be
+ ``(num_samples, num_channels, *img_shape)``.
+ """
+ raise NotImplementedError
+
+ def loss(self, inputs: torch.Tensor,
+ data_samples: List[BaseDataElement]) -> dict:
+ """Calculate losses from a batch of inputs and data samples.
+
+ Args:
+ inputs (torch.Tensor): The input tensor with shape
+ (N, C, ...) in general.
+ data_samples (List[DataSample]): The annotation data of
+ each sample.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ raise NotImplementedError
+
+ def predict(self,
+ inputs: tuple,
+ data_samples: Optional[List[BaseDataElement]] = None,
+ **kwargs) -> List[BaseDataElement]:
+ """Predict results from the extracted features.
+
+ Args:
+ inputs (tuple): The features extracted from the backbone.
+ data_samples (List[BaseDataElement], optional): The annotation
+ data of each sample. Defaults to None.
+ **kwargs: Other keyword arguments accepted by the ``predict``
+ method of :attr:`head`.
+ """
+ raise NotImplementedError
+
+ def matching(self, inputs: torch.Tensor):
+ """Compare the prototype and calculate the similarity.
+
+ Args:
+ inputs (torch.Tensor): The input tensor with shape (N, C).
+ """
+ raise NotImplementedError
+
+ def prepare_prototype(self):
+ """Prepare the prototype before prediction."""
+ raise NotImplementedError
+
+ def dump_prototype(self, path):
+ """Save the features extracted from the prototype to the specified
+ path.
+
+ Args:
+ path (str): Path to save feature.
+ """ + raise NotImplementedError diff --git a/mmpretrain/models/retrievers/image2image.py b/mmpretrain/models/retrievers/image2image.py new file mode 100644 index 0000000000000000000000000000000000000000..a00c1dceb102ee692c44090b62dcfa19dc441f3b --- /dev/null +++ b/mmpretrain/models/retrievers/image2image.py @@ -0,0 +1,314 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, List, Optional, Union + +import mmengine.dist as dist +import torch +import torch.nn as nn +from mmengine.runner import Runner +from torch.utils.data import DataLoader + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from mmpretrain.utils import track_on_main_process +from .base import BaseRetriever + + +@MODELS.register_module() +class ImageToImageRetriever(BaseRetriever): + """Image To Image Retriever for supervised retrieval task. + + Args: + image_encoder (Union[dict, List[dict]]): Encoder for extracting + features. + prototype (Union[DataLoader, dict, str, torch.Tensor]): Database to be + retrieved. The following four types are supported. + + - DataLoader: The original dataloader serves as the prototype. + - dict: The configuration to construct Dataloader. + - str: The path of the saved vector. + - torch.Tensor: The saved tensor whose dimension should be dim. + + head (dict, optional): The head module to calculate loss from + processed features. See :mod:`mmpretrain.models.heads`. Notice + that if the head is not set, `loss` method cannot be used. + Defaults to None. + similarity_fn (Union[str, Callable]): The way that the similarity + is calculated. If `similarity` is callable, it is used directly + as the measure function. If it is a string, the appropriate + method will be used. The larger the calculated value, the + greater the similarity. Defaults to "cosine_similarity". + train_cfg (dict, optional): The training setting. The acceptable + fields are: + + - augments (List[dict]): The batch augmentation methods to use. + More details can be found in + :mod:`mmpretrain.model.utils.augment`. + + Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None or no specified type, it will use + "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for + more details. Defaults to None. + topk (int): Return the topk of the retrieval result. `-1` means + return all. Defaults to -1. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + image_encoder: Union[dict, List[dict]], + prototype: Union[DataLoader, dict, str, torch.Tensor], + head: Optional[dict] = None, + pretrained: Optional[str] = None, + similarity_fn: Union[str, Callable] = 'cosine_similarity', + train_cfg: Optional[dict] = None, + data_preprocessor: Optional[dict] = None, + topk: int = -1, + init_cfg: Optional[dict] = None): + + if data_preprocessor is None: + data_preprocessor = {} + # The build process is in MMEngine, so we need to add scope here. 
+ data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor')
+
+ if train_cfg is not None and 'augments' in train_cfg:
+ # Set batch augmentations by `train_cfg`
+ data_preprocessor['batch_augments'] = train_cfg
+
+ super(ImageToImageRetriever, self).__init__(
+ init_cfg=init_cfg, data_preprocessor=data_preprocessor)
+
+ if not isinstance(image_encoder, nn.Module):
+ image_encoder = MODELS.build(image_encoder)
+ if head is not None and not isinstance(head, nn.Module):
+ head = MODELS.build(head)
+
+ self.image_encoder = image_encoder
+ self.head = head
+
+ self.similarity = similarity_fn
+
+ assert isinstance(prototype, (str, torch.Tensor, dict, DataLoader)), (
+ 'The `prototype` in `ImageToImageRetriever` must be a path, '
+ 'a torch.Tensor, a dataloader or a dataloader dict format config.')
+ self.prototype = prototype
+ self.prototype_inited = False
+ self.topk = topk
+
+ @property
+ def similarity_fn(self):
+ """Returns a function that calculates the similarity."""
+ # If self.similarity is callable, return it directly
+ if isinstance(self.similarity, Callable):
+ return self.similarity
+
+ if self.similarity == 'cosine_similarity':
+ # a is a tensor with shape (N, C)
+ # b is a tensor with shape (M, C)
+ # "cosine_similarity" will get the matrix of similarity
+ # with shape (N, M).
+ # The higher the score, the greater the similarity.
+ return lambda a, b: torch.cosine_similarity(
+ a.unsqueeze(1), b.unsqueeze(0), dim=-1)
+ else:
+ raise RuntimeError(f'Invalid function "{self.similarity}".')
+
+ def forward(self,
+ inputs: torch.Tensor,
+ data_samples: Optional[List[DataSample]] = None,
+ mode: str = 'tensor'):
+ """The unified entry for a forward process in both training and test.
+
+ The method should accept three modes: "tensor", "predict" and "loss":
+
+ - "tensor": Forward the whole network and return tensor without any
+ post-processing, same as a common nn.Module.
+ - "predict": Forward and return the predictions, which are fully
+ processed to a list of :obj:`DataSample`.
+ - "loss": Forward and return a dict of losses according to the given
+ inputs and data samples.
+
+ Note that this method does not handle back-propagation or optimizer
+ updating; those are done in :meth:`train_step`.
+
+ Args:
+ inputs (torch.Tensor, tuple): The input tensor with shape
+ (N, C, ...) in general.
+ data_samples (List[DataSample], optional): The annotation
+ data of each sample. It's required if ``mode="loss"``.
+ Defaults to None.
+ mode (str): Return what kind of value. Defaults to 'tensor'.
+
+ Returns:
+ The return type depends on ``mode``.
+
+ - If ``mode="tensor"``, return a tensor.
+ - If ``mode="predict"``, return a list of
+ :obj:`mmpretrain.structures.DataSample`.
+ - If ``mode="loss"``, return a dict of tensor.
+ """
+ if mode == 'tensor':
+ return self.extract_feat(inputs)
+ elif mode == 'loss':
+ return self.loss(inputs, data_samples)
+ elif mode == 'predict':
+ return self.predict(inputs, data_samples)
+ else:
+ raise RuntimeError(f'Invalid mode "{mode}".')
+
+ def extract_feat(self, inputs):
+ """Extract features from the input tensor with shape (N, C, ...).
+
+ Args:
+ inputs (Tensor): A batch of inputs. The shape of it should be
+ ``(num_samples, num_channels, *img_shape)``.
+ Returns:
+ Tensor: The output of encoder.
+ """
+
+ feat = self.image_encoder(inputs)
+ return feat
+
+ def loss(self, inputs: torch.Tensor,
+ data_samples: List[DataSample]) -> dict:
+ """Calculate losses from a batch of inputs and data samples.
+ + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + feats = self.extract_feat(inputs) + return self.head.loss(feats, data_samples) + + def matching(self, inputs: torch.Tensor): + """Compare the prototype and calculate the similarity. + + Args: + inputs (torch.Tensor): The input tensor with shape (N, C). + Returns: + dict: a dictionary of score and prediction label based on fn. + """ + sim = self.similarity_fn(inputs, self.prototype_vecs) + sorted_sim, indices = torch.sort(sim, descending=True, dim=-1) + predictions = dict( + score=sim, pred_label=indices, pred_score=sorted_sim) + return predictions + + def predict(self, + inputs: tuple, + data_samples: Optional[List[DataSample]] = None, + **kwargs) -> List[DataSample]: + """Predict results from the extracted features. + + Args: + inputs (tuple): The features extracted from the backbone. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + Returns: + List[DataSample]: the raw data_samples with + the predicted results + """ + if not self.prototype_inited: + self.prepare_prototype() + + feats = self.extract_feat(inputs) + if isinstance(feats, tuple): + feats = feats[-1] + + # Matching of similarity + result = self.matching(feats) + return self._get_predictions(result, data_samples) + + def _get_predictions(self, result, data_samples): + """Post-process the output of retriever.""" + pred_scores = result['score'] + pred_labels = result['pred_label'] + if self.topk != -1: + topk = min(self.topk, pred_scores.size()[-1]) + pred_labels = pred_labels[:, :topk] + + if data_samples is not None: + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + data_sample.set_pred_score(score).set_pred_label(label) + else: + data_samples = [] + for score, label in zip(pred_scores, pred_labels): + data_samples.append( + DataSample().set_pred_score(score).set_pred_label(label)) + return data_samples + + def _get_prototype_vecs_from_dataloader(self, data_loader): + """get prototype_vecs from dataloader.""" + self.eval() + num = len(data_loader.dataset) + + prototype_vecs = None + for data_batch in track_on_main_process(data_loader, + 'Prepare prototype'): + data = self.data_preprocessor(data_batch, False) + feat = self(**data) + if isinstance(feat, tuple): + feat = feat[-1] + + if prototype_vecs is None: + dim = feat.shape[-1] + prototype_vecs = torch.zeros(num, dim) + for i, data_sample in enumerate(data_batch['data_samples']): + sample_idx = data_sample.get('sample_idx') + prototype_vecs[sample_idx] = feat[i] + + assert prototype_vecs is not None + dist.all_reduce(prototype_vecs) + return prototype_vecs + + def _get_prototype_vecs_from_path(self, proto_path): + """get prototype_vecs from prototype path.""" + data = [None] + if dist.is_main_process(): + data[0] = torch.load(proto_path) + dist.broadcast_object_list(data, src=0) + prototype_vecs = data[0] + assert prototype_vecs is not None + return prototype_vecs + + @torch.no_grad() + def prepare_prototype(self): + """Used in meta testing. This function will be called before the meta + testing. Obtain the vector based on the prototype. 
+ + - torch.Tensor: The prototype vector is the prototype + - str: The path of the extracted feature path, parse data structure, + and generate the prototype feature vector set + - Dataloader or config: Extract and save the feature vectors according + to the dataloader + """ + device = next(self.image_encoder.parameters()).device + if isinstance(self.prototype, torch.Tensor): + prototype_vecs = self.prototype + elif isinstance(self.prototype, str): + prototype_vecs = self._get_prototype_vecs_from_path(self.prototype) + elif isinstance(self.prototype, (dict, DataLoader)): + loader = Runner.build_dataloader(self.prototype) + prototype_vecs = self._get_prototype_vecs_from_dataloader(loader) + + self.register_buffer( + 'prototype_vecs', prototype_vecs.to(device), persistent=False) + self.prototype_inited = True + + def dump_prototype(self, path): + """Save the features extracted from the prototype to specific path. + + Args: + path (str): Path to save feature. + """ + if not self.prototype_inited: + self.prepare_prototype() + torch.save(self.prototype_vecs, path) diff --git a/mmpretrain/models/selfsup/__init__.py b/mmpretrain/models/selfsup/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..08c1ed59ddcf51924b3de7df1c995bf84c6bb753 --- /dev/null +++ b/mmpretrain/models/selfsup/__init__.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .barlowtwins import BarlowTwins +from .base import BaseSelfSupervisor +from .beit import VQKD, BEiT, BEiTPretrainViT +from .byol import BYOL +from .cae import CAE, CAEPretrainViT, DALLEEncoder +from .densecl import DenseCL +from .eva import EVA +from .itpn import iTPN, iTPNHiViT +from .mae import MAE, MAEHiViT, MAEViT +from .maskfeat import HOGGenerator, MaskFeat, MaskFeatViT +from .mff import MFF, MFFViT +from .milan import MILAN, CLIPGenerator, MILANViT +from .mixmim import MixMIM, MixMIMPretrainTransformer +from .moco import MoCo +from .mocov3 import MoCoV3, MoCoV3ViT +from .simclr import SimCLR +from .simmim import SimMIM, SimMIMSwinTransformer +from .simsiam import SimSiam +from .spark import SparK +from .swav import SwAV + +__all__ = [ + 'BaseSelfSupervisor', + 'BEiTPretrainViT', + 'VQKD', + 'CAEPretrainViT', + 'DALLEEncoder', + 'MAEViT', + 'MAEHiViT', + 'iTPNHiViT', + 'iTPN', + 'HOGGenerator', + 'MaskFeatViT', + 'CLIPGenerator', + 'MILANViT', + 'MixMIMPretrainTransformer', + 'MoCoV3ViT', + 'SimMIMSwinTransformer', + 'MoCo', + 'MoCoV3', + 'BYOL', + 'SimCLR', + 'SimSiam', + 'BEiT', + 'CAE', + 'MAE', + 'MaskFeat', + 'MILAN', + 'MixMIM', + 'SimMIM', + 'EVA', + 'DenseCL', + 'BarlowTwins', + 'SwAV', + 'SparK', + 'MFF', + 'MFFViT', +] diff --git a/mmpretrain/models/selfsup/__pycache__/__init__.cpython-311.pyc b/mmpretrain/models/selfsup/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26ed18a8ca0706c0486ccd230f33158b20f86386 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/barlowtwins.cpython-311.pyc b/mmpretrain/models/selfsup/__pycache__/barlowtwins.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da6403f7cfbe93fbc207cf8c019c52f0fc8bc8b1 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/barlowtwins.cpython-311.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/base.cpython-311.pyc b/mmpretrain/models/selfsup/__pycache__/base.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..4fe3096afa5ee77f5607723454345549d774fe77 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/base.cpython-311.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/beit.cpython-311.pyc b/mmpretrain/models/selfsup/__pycache__/beit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73edb64a8a3d19a2a0e4315a78ad7b1db4bf1afe Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/beit.cpython-311.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/byol.cpython-311.pyc b/mmpretrain/models/selfsup/__pycache__/byol.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6123f06e858fddde6eaacff98ae8c2405c927729 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/byol.cpython-311.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/cae.cpython-311.pyc b/mmpretrain/models/selfsup/__pycache__/cae.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3acbf632fa25e574de2542f2df2292cbd632fe3 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/cae.cpython-311.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/densecl.cpython-311.pyc b/mmpretrain/models/selfsup/__pycache__/densecl.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae773355d4d97a2d520d7a7b83765afbbf7ca8dd Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/densecl.cpython-311.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/eva.cpython-311.pyc b/mmpretrain/models/selfsup/__pycache__/eva.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86de60297bf33eb2f711b416b941fd4a2afe7a07 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/eva.cpython-311.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/itpn.cpython-311.pyc b/mmpretrain/models/selfsup/__pycache__/itpn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cf3191ca2812398dcabddc7df10d95b760cfed6 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/itpn.cpython-311.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/mae.cpython-311.pyc b/mmpretrain/models/selfsup/__pycache__/mae.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a412b57083a6e571414d04ba36416a9411878bb Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/mae.cpython-311.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/maskfeat.cpython-311.pyc b/mmpretrain/models/selfsup/__pycache__/maskfeat.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a4c60ddefb5b16e566d9d6822c1ffb8d8dd2a52 Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/maskfeat.cpython-311.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/mff.cpython-311.pyc b/mmpretrain/models/selfsup/__pycache__/mff.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f84c9da2fed0936d0b7b2fbfc8fa80bbaa143bac Binary files /dev/null and b/mmpretrain/models/selfsup/__pycache__/mff.cpython-311.pyc differ diff --git a/mmpretrain/models/selfsup/__pycache__/milan.cpython-311.pyc b/mmpretrain/models/selfsup/__pycache__/milan.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea5d8aa9beb5f8416ecf0a110ed29a6c734f293c Binary files /dev/null and 
diff --git a/mmpretrain/models/selfsup/barlowtwins.py b/mmpretrain/models/selfsup/barlowtwins.py new file mode 100644 index 0000000000000000000000000000000000000000..4c75cd0caca6ab2dc4c4a14e365fda5daa9bdb83 --- /dev/null +++ b/mmpretrain/models/selfsup/barlowtwins.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List + +import torch + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class BarlowTwins(BaseSelfSupervisor): + """BarlowTwins. + + Implementation of `Barlow Twins: Self-Supervised Learning via Redundancy + Reduction <https://arxiv.org/abs/2103.03230>`_. + Part of the code is borrowed from: + `<https://github.com/facebookresearch/barlowtwins>`_.
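+ + Example (editor's sketch of the redundancy-reduction objective, not + part of the original code):: + + >>> import torch + >>> z1, z2 = torch.randn(8, 16), torch.randn(8, 16) + >>> z1 = (z1 - z1.mean(0)) / z1.std(0) # normalize per dimension + >>> z2 = (z2 - z2.mean(0)) / z2.std(0) + >>> c = z1.T @ z2 / z1.size(0) # cross-correlation matrix, 16 x 16 + >>> on_diag = (torch.diagonal(c) - 1).pow(2).sum() + >>> off_diag = (c - torch.diag(torch.diagonal(c))).pow(2).sum() + >>> loss = on_diag + 0.0051 * off_diag # lambda from the paper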
+ """ + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + img_v1 = inputs[0] + img_v2 = inputs[1] + + z1 = self.neck(self.backbone(img_v1))[0] # NxC + z2 = self.neck(self.backbone(img_v2))[0] # NxC + + loss = self.head.loss(z1, z2) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/base.py b/mmpretrain/models/selfsup/base.py new file mode 100644 index 0000000000000000000000000000000000000000..9d53a72871dff7b4fc59cd591686350026a875bb --- /dev/null +++ b/mmpretrain/models/selfsup/base.py @@ -0,0 +1,179 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List, Optional, Union + +import torch +from mmengine.model import BaseModel +from torch import nn + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample + + +class BaseSelfSupervisor(BaseModel, metaclass=ABCMeta): + """BaseModel for Self-Supervised Learning. + + All self-supervised algorithms should inherit this module. + + Args: + backbone (dict): The backbone module. See + :mod:`mmpretrain.models.backbones`. + neck (dict, optional): The neck module to process features from + backbone. See :mod:`mmpretrain.models.necks`. Defaults to None. + head (dict, optional): The head module to do prediction and calculate + loss from processed features. See :mod:`mmpretrain.models.heads`. + Notice that if the head is not set, almost all methods cannot be + used except :meth:`extract_feat`. Defaults to None. + target_generator: (dict, optional): The target_generator module to + generate targets for self-supervised learning optimization, such as + HOG, extracted features from other modules(DALL-E, CLIP), etc. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (Union[dict, nn.Module], optional): The config for + preprocessing input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. 
+ """ + + def __init__(self, + backbone: dict, + neck: Optional[dict] = None, + head: Optional[dict] = None, + target_generator: Optional[dict] = None, + pretrained: Optional[str] = None, + data_preprocessor: Optional[Union[dict, nn.Module]] = None, + init_cfg: Optional[dict] = None): + if pretrained is not None: + init_cfg = dict(type='Pretrained', checkpoint=pretrained) + + data_preprocessor = data_preprocessor or {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'SelfSupDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + elif not isinstance(data_preprocessor, nn.Module): + raise TypeError('data_preprocessor should be a `dict` or ' + f'`nn.Module` instance, but got ' + f'{type(data_preprocessor)}') + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + if not isinstance(backbone, nn.Module): + backbone = MODELS.build(backbone) + if neck is not None and not isinstance(neck, nn.Module): + neck = MODELS.build(neck) + if head is not None and not isinstance(head, nn.Module): + head = MODELS.build(head) + if target_generator is not None and not isinstance( + target_generator, nn.Module): + target_generator = MODELS.build(target_generator) + + self.backbone = backbone + self.neck = neck + self.head = head + self.target_generator = target_generator + + @property + def with_neck(self) -> bool: + """Check if the model has a neck module.""" + return hasattr(self, 'neck') and self.neck is not None + + @property + def with_head(self) -> bool: + """Check if the model has a head module.""" + return hasattr(self, 'head') and self.head is not None + + @property + def with_target_generator(self) -> bool: + """Check if the model has a target_generator module.""" + return hasattr( + self, 'target_generator') and self.target_generator is not None + + def forward(self, + inputs: Union[torch.Tensor, List[torch.Tensor]], + data_samples: Optional[List[DataSample]] = None, + mode: str = 'tensor'): + """The unified entry for a forward process in both training and test. + + The method currently accepts two modes: "tensor" and "loss": + + - "tensor": Forward the backbone network and return the feature + tensor(s) tensor without any post-processing, same as a common + PyTorch Module. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Args: + inputs (torch.Tensor or List[torch.Tensor]): The input tensor with + shape (N, C, ...) in general. + data_samples (List[DataSample], optional): The other data of + every samples. It's required for some algorithms + if ``mode="loss"``. Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'tensor': + feats = self.extract_feat(inputs) + return feats + elif mode == 'loss': + return self.loss(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, inputs: torch.Tensor): + """Extract features from the input tensor with shape (N, C, ...). + + The default behavior is extracting features from backbone. + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + + Returns: + tuple | Tensor: The output feature tensor(s). 
+ """ + x = self.backbone(inputs) + return x + + @abstractmethod + def loss(self, inputs: torch.Tensor, + data_samples: List[DataSample]) -> dict: + """Calculate losses from a batch of inputs and data samples. + + This is a abstract method, and subclass should overwrite this methods + if needed. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + raise NotImplementedError + + def get_layer_depth(self, param_name: str): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + + Returns: + Tuple[int, int]: The layer-wise depth and the max depth. + """ + if hasattr(self.backbone, 'get_layer_depth'): + return self.backbone.get_layer_depth(param_name, 'backbone.') + else: + raise NotImplementedError( + f"The backbone {type(self.backbone)} doesn't " + 'support `get_layer_depth` by now.') diff --git a/mmpretrain/models/selfsup/beit.py b/mmpretrain/models/selfsup/beit.py new file mode 100644 index 0000000000000000000000000000000000000000..c301f7d5cae07370f26b4cd531190b8c3c90e24b --- /dev/null +++ b/mmpretrain/models/selfsup/beit.py @@ -0,0 +1,357 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, List, Optional, Tuple, Union + +import torch +from einops import rearrange +from mmengine.model import BaseModule +from mmengine.model.weight_init import trunc_normal_ +from torch import nn + +from mmpretrain.models.backbones import BEiTViT +from mmpretrain.models.utils import NormEMAVectorQuantizer, resize_pos_embed +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class VQKD(BaseModule): + """Vector-Quantized Knowledge Distillation. + + The module only contains encoder and VectorQuantizer part + Modified from https://github.com/microsoft/unilm/blob/master/beit2/modeling_vqkd.py + + Args: + encoder_config (dict): The config of encoder. + decoder_config (dict, optional): The config of decoder. Currently, + VQKD only support to build encoder. Defaults to None. + num_embed (int): Number of embedding vectors in the codebook. Defaults + to 8192. + embed_dims (int) : The dimension of embedding vectors in the codebook. + Defaults to 32. + decay (float): The decay parameter of EMA. Defaults to 0.99. + beta (float): The mutiplier for VectorQuantizer loss. Defaults to 1. + quantize_kmeans_init (bool): Whether to use k-means to initialize the + VectorQuantizer. Defaults to True. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. 
+ """ # noqa: E501 + + def __init__(self, + encoder_config: dict, + decoder_config: Optional[dict] = None, + num_embed: int = 8192, + embed_dims: int = 32, + decay: float = 0.99, + beta: float = 1.0, + quantize_kmeans_init: bool = True, + init_cfg: Optional[dict] = None) -> None: + super().__init__(init_cfg=init_cfg) + + self.encoder = BEiTViT(**encoder_config) + if decoder_config is not None: + self.decoder = BEiTViT(**decoder_config) + + self.quantize = NormEMAVectorQuantizer( + num_embed=num_embed, + embed_dims=embed_dims, + beta=beta, + decay=decay, + kmeans_init=quantize_kmeans_init, + ) + + # task layer + self.encode_task_layer = nn.Sequential( + nn.Linear(self.encoder.arch_settings['embed_dims'], + self.encoder.arch_settings['embed_dims']), nn.Tanh(), + nn.Linear(self.encoder.arch_settings['embed_dims'], embed_dims)) + + def get_tokens(self, x: torch.Tensor) -> dict: + """Get tokens for beit pre-training.""" + _, embed_ind, _ = self.encode(x) + output = {} + output['token'] = embed_ind.view(x.shape[0], -1) + output['input_img'] = x + + return output + + def encode( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Encode the input images and get corresponding results.""" + encoder_features = self.encoder(x)[0] + B, C, N1, N2 = encoder_features.shape + encoder_features = encoder_features.permute(0, 2, 3, + 1).reshape(B, N1 * N2, C) + + with torch.cuda.amp.autocast(enabled=False): + to_quantizer_features = self.encode_task_layer( + encoder_features.type_as(self.encode_task_layer[-1].weight)) + + N = to_quantizer_features.shape[1] + h, w = int(math.sqrt(N)), int(math.sqrt(N)) + + to_quantizer_features = rearrange( + to_quantizer_features, 'b (h w) c -> b c h w', h=h, + w=w) # reshape for quantizer + quantize, loss, embed_ind = self.quantize(to_quantizer_features) + + return quantize, embed_ind, loss + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """The forward function. + + Currently, only support to get tokens. + """ + return self.get_tokens(x)['token'] + + +@MODELS.register_module() +class BEiTPretrainViT(BEiTViT): + """Vision Transformer for BEiT pre-training. + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'small', 'base' and 'large'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + + Defaults to 'base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. 
Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. Defaults to ``"avg_featmap"``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + use_abs_pos_emb (bool): Whether or not use absolute position embedding. + Defaults to False. + use_rel_pos_bias (bool): Whether or not use relative position bias. + Defaults to False. + use_shared_rel_pos_bias (bool): Whether or not use shared relative + position bias. Defaults to True. + layer_scale_init_value (float): The initialization value for + the learnable scaling of attention and FFN. Defaults to 0.1. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + arch: str = 'base', + img_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + out_indices: int = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + out_type: str = 'raw', + frozen_stages: int = -1, + use_abs_pos_emb: bool = False, + use_rel_pos_bias: bool = False, + use_shared_rel_pos_bias: bool = True, + layer_scale_init_value: int = 0.1, + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(padding=0), + layer_cfgs: dict = dict(), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + out_type=out_type, + with_cls_token=True, + frozen_stages=frozen_stages, + use_abs_pos_emb=use_abs_pos_emb, + use_shared_rel_pos_bias=use_shared_rel_pos_bias, + use_rel_pos_bias=use_rel_pos_bias, + layer_scale_init_value=layer_scale_init_value, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding and cls token.""" + super().init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. 
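+ # (Editor's note, illustrative) with init_cfg=dict(type='Pretrained', + # checkpoint='...'), weights come from the checkpoint, and the early + # return below skips the manual trunc_normal_ and rescale initialization.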
+ return + + trunc_normal_(self.cls_token, std=0.02) + trunc_normal_(self.mask_token, std=0.02) + self.rescale_init_weight() + + def rescale_init_weight(self) -> None: + """Rescale the initialized weights.""" + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.layers): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.ffn.layers[1].weight.data, layer_id + 1) + + def forward(self, x: torch.Tensor, + mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor]: + """The BEiT style forward function. + + The function supports two kinds of forward behaviors. If the ``mask`` is + not ``None``, the forward function will be executed as masked image + modeling pre-training; if the ``mask`` is ``None``, the forward + function will call ``super().forward()``, which extracts features from + images without a mask. + + Args: + x (torch.Tensor): Input images, which is of shape (B x C x H x W). + mask (torch.Tensor, optional): Mask for input, which is of shape + (B x patch_resolution[0] x patch_resolution[1]). + + Returns: + Tuple[torch.Tensor]: Hidden features. + """ + if mask is None: + return super().forward(x) + + else: + x, patch_resolution = self.patch_embed(x) + + # replace the masked visual tokens by mask_token + B, L, _ = x.shape + mask_token = self.mask_token.expand(B, L, -1) + w = mask.flatten(1).unsqueeze(-1).type_as(mask_token) + x = x * (1. - w) + mask_token * w + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + self.shared_rel_pos_bias = self.rel_pos_bias().to( + mask.device) if self.rel_pos_bias is not None else None + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x, rel_pos_bias=self.shared_rel_pos_bias) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.norm1(x) + + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + +@MODELS.register_module() +class BEiT(BaseSelfSupervisor): + """BEiT v1/v2. + + Implementation of `BEiT: BERT Pre-Training of Image Transformers + <https://arxiv.org/abs/2106.08254>`_ and `BEiT v2: Masked Image Modeling + with Vector-Quantized Visual Tokenizers + <https://arxiv.org/abs/2208.06366>`_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components.
+ """ + mask = torch.stack([data_sample.mask for data_sample in data_samples]) + + img_latent = self.backbone(inputs[0], mask) + + # inputs[1] is the target image + with torch.no_grad(): + target = self.target_generator(inputs[1]) + target = target.detach() + + if self.with_neck: + # BEiT v2 + feats, feats_cls_pt = self.neck( + img_latent, rel_pos_bias=self.backbone.shared_rel_pos_bias) + loss = self.head.loss(feats, feats_cls_pt, target, mask) + else: + # BEiT v1 + loss = self.head.loss(img_latent[0], target, mask) + + if isinstance(loss, torch.Tensor): + losses = dict(loss=loss) + return losses + elif isinstance(loss, Tuple): + # the loss_1 and loss_2 are general reconstruction loss (patch + # feature vectors from last layer of backbone) and early state + # reconstruction loss (patch feature vectors from intermediate + # layer of backbone) + loss_1, loss_2 = loss[0], loss[1] + losses = dict() + # the key with prefix 'loss', like loss_1 and loss_2, will be used + # as the final criterion + losses['loss_1'] = loss_1 + losses['loss_2'] = loss_2 + return losses diff --git a/mmpretrain/models/selfsup/byol.py b/mmpretrain/models/selfsup/byol.py new file mode 100644 index 0000000000000000000000000000000000000000..803e4005da8620b0e5a93fb29cb65e90a78f345f --- /dev/null +++ b/mmpretrain/models/selfsup/byol.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import CosineEMA +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class BYOL(BaseSelfSupervisor): + """BYOL. + + Implementation of `Bootstrap Your Own Latent: A New Approach to + Self-Supervised Learning `_. + + Args: + backbone (dict): Config dict for module of backbone. + neck (dict): Config dict for module of deep features + to compact feature vectors. + head (dict): Config dict for module of head functions. + base_momentum (float): The base momentum coefficient for the target + network. Defaults to 0.004. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing + input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. + init_cfg (Union[List[dict], dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: dict, + head: dict, + base_momentum: float = 0.004, + pretrained: Optional[str] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + # create momentum model + self.target_net = CosineEMA( + nn.Sequential(self.backbone, self.neck), momentum=base_momentum) + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. 
+ """ + assert isinstance(inputs, list) + img_v1 = inputs[0] + img_v2 = inputs[1] + # compute online features + proj_online_v1 = self.neck(self.backbone(img_v1))[0] + proj_online_v2 = self.neck(self.backbone(img_v2))[0] + # compute target features + with torch.no_grad(): + # update the target net + self.target_net.update_parameters( + nn.Sequential(self.backbone, self.neck)) + + proj_target_v1 = self.target_net(img_v1)[0] + proj_target_v2 = self.target_net(img_v2)[0] + + loss_1 = self.head.loss(proj_online_v1, proj_target_v2) + loss_2 = self.head.loss(proj_online_v2, proj_target_v1) + + losses = dict(loss=2. * (loss_1 + loss_2)) + return losses diff --git a/mmpretrain/models/selfsup/cae.py b/mmpretrain/models/selfsup/cae.py new file mode 100644 index 0000000000000000000000000000000000000000..67ac09188e9bf97cdbea63378aa4facb1e8348ab --- /dev/null +++ b/mmpretrain/models/selfsup/cae.py @@ -0,0 +1,472 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Part of code is modified from BEiT +# https://github.com/microsoft/unilm/blob/master/beit/dall_e/encoder.py +import math +from collections import OrderedDict +from functools import partial +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.models.backbones import BEiTViT +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import build_2d_sincos_position_embedding +from .base import BaseSelfSupervisor + + +class Conv2d(nn.Module): + """Rewrite Conv2d module according to DALL-E code.""" + + def __init__(self, + n_in: int, + n_out: int, + kw: int, + use_float16: bool = True, + device: torch.device = torch.device('cpu'), + requires_grad: bool = False) -> None: + super().__init__() + + w = torch.empty((n_out, n_in, kw, kw), + dtype=torch.float32, + device=device, + requires_grad=requires_grad) + w.normal_(std=1 / math.sqrt(n_in * kw**2)) + + b = torch.zeros((n_out, ), + dtype=torch.float32, + device=device, + requires_grad=requires_grad) + self.kw = kw + self.w, self.b = nn.Parameter(w), nn.Parameter(b) + self.use_float16 = use_float16 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_float16 and 'cuda' in self.w.device.type: + if x.dtype != torch.float16: + x = x.half() + + w, b = self.w.half(), self.b.half() + else: + if x.dtype != torch.float32: + x = x.float() + + w, b = self.w, self.b + + return F.conv2d(x, w, b, padding=(self.kw - 1) // 2) + + +class EncoderBlock(nn.Module): + """Rewrite EncoderBlock module according to DALL-E code.""" + + def __init__(self, + n_in: int, + n_out: int, + n_layers: int, + device: torch.device = None, + requires_grad: bool = False) -> None: + super().__init__() + self.n_hid = n_out // 4 + self.post_gain = 1 / (n_layers**2) + + make_conv = partial(Conv2d, device=device, requires_grad=requires_grad) + self.id_path = make_conv(n_in, n_out, + 1) if n_in != n_out else nn.Identity() + self.res_path = nn.Sequential( + OrderedDict([ + ('relu_1', nn.ReLU()), + ('conv_1', make_conv(n_in, self.n_hid, 3)), + ('relu_2', nn.ReLU()), + ('conv_2', make_conv(self.n_hid, self.n_hid, 3)), + ('relu_3', nn.ReLU()), + ('conv_3', make_conv(self.n_hid, self.n_hid, 3)), + ('relu_4', nn.ReLU()), + ('conv_4', make_conv(self.n_hid, n_out, 1)), + ])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.id_path(x) + self.post_gain * self.res_path(x) + + 
+@MODELS.register_module(name='DALL-E') +class DALLEEncoder(BaseModule): + """DALL-E Encoder for feature extraction. + + Args: + group_count (int): Number of groups in DALL-E encoder. Defaults to 4. + n_hid (int): Dimension of hidden layers. Defaults to 256. + n_blk_per_group (int): Number of blocks per group. Defaults to 2. + input_channels: (int): The channels of input images. Defaults to 3. + vocab_size (int): Vocabulary size, indicating the number of classes. + Defaults to 8192. + device (torch.device): Device of parameters. Defaults to + ``torch.device('cpu')``. + requires_grad (bool): Require gradient or not. Defaults to False. + init_cfg (Union[List[dict], dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + group_count: int = 4, + n_hid: int = 256, + n_blk_per_group: int = 2, + input_channels: int = 3, + vocab_size: int = 8192, + device: torch.device = torch.device('cpu'), + requires_grad: bool = False, + init_cfg: Union[dict, List[dict], None] = None): + super().__init__(init_cfg=init_cfg) + self.input_channels = input_channels + + blk_range = range(n_blk_per_group) + n_layers = group_count * n_blk_per_group + make_conv = partial(Conv2d, device=device, requires_grad=requires_grad) + make_blk = partial( + EncoderBlock, + n_layers=n_layers, + device=device, + requires_grad=requires_grad) + + self.blocks = nn.Sequential( + OrderedDict([ + ('input', make_conv(input_channels, 1 * n_hid, 7)), + ('group_1', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', make_blk(1 * n_hid, 1 * n_hid)) + for i in blk_range], + ('pool', nn.MaxPool2d(kernel_size=2)), + ]))), + ('group_2', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', + make_blk(1 * n_hid if i == 0 else 2 * n_hid, + 2 * n_hid)) for i in blk_range], + ('pool', nn.MaxPool2d(kernel_size=2)), + ]))), + ('group_3', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', + make_blk(2 * n_hid if i == 0 else 4 * n_hid, + 4 * n_hid)) for i in blk_range], + ('pool', nn.MaxPool2d(kernel_size=2)), + ]))), + ('group_4', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', + make_blk(4 * n_hid if i == 0 else 8 * n_hid, + 8 * n_hid)) for i in blk_range], + ]))), + ('output', + nn.Sequential( + OrderedDict([ + ('relu', nn.ReLU()), + ('conv', + make_conv( + 8 * n_hid, vocab_size, 1, use_float16=False)), + ]))), + ])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function of DALL-E encoder. + + Args: + x (torch.Tensor): The input images with shape (B, C, H, W). + + Returns: + torch.Tensor: The output with shape (B, vocab_size, h, w). + """ + x = x.float() + if len(x.shape) != 4: + raise ValueError(f'input shape {x.shape} is not 4d') + if x.shape[1] != self.input_channels: + raise ValueError(f'input has {x.shape[1]} channels but model \ + built for {self.input_channels}') + if x.dtype != torch.float32: + raise ValueError('input must have dtype torch.float32') + + return self.blocks(x) + + +@MODELS.register_module() +class CAEPretrainViT(BEiTViT): + """Vision Transformer for CAE pre-training and the implementation is based + on BEiTViT. + + Args: + arch (str | dict): Vision Transformer architecture. Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. 
+ bias (bool | str): The option to add learnable bias for q, k, v. If bias + is True, it will add learnable bias. If bias is 'qv_bias', it will + only add learnable bias for q, v. If bias is False, it will not add + bias for q, k, v. Defaults to 'qv_bias'. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add an additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. Defaults to ``"avg_featmap"``. + interpolate_mode (str): Select the interpolate mode for position + embedding vector resize. Defaults to "bicubic". + layer_scale_init_value (float, optional): The init value of gamma in + BEiTTransformerEncoderLayer. + patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + arch: str = 'b', + img_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + out_indices: int = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + bias: bool = 'qv_bias', + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + out_type: str = 'raw', + frozen_stages: int = -1, + use_abs_pos_emb: bool = True, + use_rel_pos_bias: bool = False, + use_shared_rel_pos_bias: bool = False, + layer_scale_init_value: Optional[float] = None, + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + init_cfg: dict = [ + dict(type='Constant', val=1, layer=['LayerNorm']), + dict(type='TruncNormal', std=0.02, layer=['Conv2d']), + dict(type='Xavier', distribution='uniform', layer=['Linear']) + ] + ) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + bias=bias, + norm_cfg=norm_cfg, + final_norm=final_norm, + out_type=out_type, + with_cls_token=True, + frozen_stages=frozen_stages, + use_abs_pos_emb=use_abs_pos_emb, + use_rel_pos_bias=use_rel_pos_bias, + use_shared_rel_pos_bias=use_shared_rel_pos_bias, + layer_scale_init_value=layer_scale_init_value, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + self.pos_embed.requires_grad = False + self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding and cls token.""" + super().init_weights() + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # initialize position embedding in backbone + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.pos_embed.shape[-1], + cls_token=True) + self.pos_embed.data.copy_(pos_embed.float()) + + trunc_normal_(self.cls_token, std=.02) + + def forward(self, x: torch.Tensor, + mask: Optional[torch.Tensor]) -> torch.Tensor: + """Generate features for masked images.
+ + This function generates masked images and gets the hidden features for + visible patches. + + The function supports two kinds of forward behaviors. If the ``mask`` is + not ``None``, the forward function will be executed as masked image + modeling pre-training; if the ``mask`` is ``None``, the forward + function will call ``super().forward()``, which extracts features from + images without a mask. + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (torch.Tensor, optional): Mask for input, which is of shape + B x L. + + Returns: + torch.Tensor: Hidden features. + """ + if mask is None: + return super().forward(x) + + else: + x, _ = self.patch_embed(x) + batch_size, _, dim = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + + # NOTE: unmasked embeddings + x_unmasked = x[~mask].reshape(batch_size, -1, dim) + x_unmasked = torch.cat((cls_tokens, x_unmasked), dim=1) + + pos_embed = self.pos_embed.expand(batch_size, self.num_patches + 1, + dim) + pos_embed_unmasked = pos_embed[:, 1:][~mask].reshape( + batch_size, -1, dim) + pos_embed_unmasked = torch.cat( + (pos_embed[:, :1], pos_embed_unmasked), dim=1) + x_unmasked = x_unmasked + pos_embed_unmasked + + x_unmasked = self.drop_after_pos(x_unmasked) + + for i, layer in enumerate(self.layers): + x_unmasked = layer(x=x_unmasked, rel_pos_bias=None) + + if i == len(self.layers) - 1 and self.final_norm: + x_unmasked = self.norm1(x_unmasked) + + return x_unmasked + + +@MODELS.register_module() +class CAE(BaseSelfSupervisor): + """CAE. + + Implementation of `Context Autoencoder for Self-Supervised Representation + Learning <https://arxiv.org/abs/2202.03026>`_. + + Args: + backbone (dict): Config dict for module of backbone. + neck (dict): Config dict for module of neck. + head (dict): Config dict for module of head functions. + target_generator: (dict, optional): The target_generator module to + generate targets for self-supervised learning optimization, such as + HOG, extracted features from other modules(DALL-E, CLIP), etc. + base_momentum (float): The base momentum coefficient for the target + network. Defaults to 0.0. + data_preprocessor (dict, optional): The config for preprocessing + input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. + init_cfg (Union[List[dict], dict], optional): Config dict for weight + initialization. Defaults to None.
+ """ + + def __init__(self, + backbone: dict, + neck: dict, + head: dict, + target_generator: Optional[dict] = None, + base_momentum: float = 0.0, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + target_generator=target_generator, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + self.momentum = base_momentum + self.teacher = MODELS.build(backbone) + + def init_weights(self) -> None: + """Initialize weights.""" + super().init_weights() + + # init the weights of teacher with those of backbone + for param_backbone, param_teacher in zip(self.backbone.parameters(), + self.teacher.parameters()): + param_teacher.detach() + param_teacher.data.copy_(param_backbone.data) + param_teacher.requires_grad = False + + def momentum_update(self) -> None: + """Momentum update of the teacher network.""" + for param_bacbone, param_teacher in zip(self.backbone.parameters(), + self.teacher.parameters()): + param_teacher.data = param_teacher.data * self.momentum + \ + param_bacbone.data * (1. - self.momentum) + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + mask = torch.stack([data_sample.mask for data_sample in data_samples]) + mask = mask.flatten(1).to(torch.bool) + + unmasked = self.backbone(inputs[0], mask) + + # get the latent prediction for the masked patches + with torch.no_grad(): + # inputs[0] is the prediction image + latent_target = self.teacher(inputs[0], ~mask) + latent_target = latent_target[:, 1:, :] + self.momentum_update() + + pos_embed = self.backbone.pos_embed.expand(inputs[0].shape[0], -1, -1) + pos_embed_masked = pos_embed[:, + 1:][mask].reshape(inputs[0].shape[0], -1, + pos_embed.shape[-1]) + pos_embed_unmasked = pos_embed[:, 1:][~mask].reshape( + inputs[0].shape[0], -1, pos_embed.shape[-1]) + + # input the unmasked tokens and masked tokens to the decoder + logits, latent_pred = self.neck(unmasked[:, 1:], pos_embed_masked, + pos_embed_unmasked) + + logits = logits.view(-1, logits.shape[-1]) + # inputs[1] is the target image + logits_target = self.target_generator(inputs[1]) + loss_main, loss_align = self.head.loss(logits, logits_target, + latent_pred, latent_target, + mask) + losses = dict() + + losses['loss'] = loss_main + loss_align + losses['main'] = loss_main + losses['align'] = loss_align + return losses diff --git a/mmpretrain/models/selfsup/densecl.py b/mmpretrain/models/selfsup/densecl.py new file mode 100644 index 0000000000000000000000000000000000000000..c969af17fa921a119f6b05b5a319e104f6422494 --- /dev/null +++ b/mmpretrain/models/selfsup/densecl.py @@ -0,0 +1,203 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.dist import all_gather +from mmengine.model import ExponentialMovingAverage + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import batch_shuffle_ddp, batch_unshuffle_ddp +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class DenseCL(BaseSelfSupervisor): + """DenseCL. + + Implementation of `Dense Contrastive Learning for Self-Supervised Visual + Pre-Training <https://arxiv.org/abs/2011.09157>`_. + Borrowed from the authors' code: `<https://github.com/WXinlong/DenseCL>`_. + The loss_lambda warmup is in `engine/hooks/densecl_hook.py`. + + Args: + backbone (dict): Config dict for module of backbone. + neck (dict): Config dict for the neck module, which maps deep features + to compact feature vectors. + head (dict): Config dict for the head module. + queue_len (int): Number of negative keys maintained in the queue. + Defaults to 65536. + feat_dim (int): Dimension of compact feature vectors. Defaults to 128. + momentum (float): Momentum coefficient for the momentum-updated + encoder. Defaults to 0.001 (equivalent to an EMA decay of 0.999). + loss_lambda (float): Loss weight for the single and dense contrastive + loss. Defaults to 0.5. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing + input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. + init_cfg (Union[List[dict], dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: dict, + head: dict, + queue_len: int = 65536, + feat_dim: int = 128, + momentum: float = 0.001, + loss_lambda: float = 0.5, + pretrained: Optional[str] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + # create momentum model + self.encoder_k = ExponentialMovingAverage( + nn.Sequential(self.backbone, self.neck), momentum) + + self.queue_len = queue_len + self.loss_lambda = loss_lambda + + # create the queue + self.register_buffer('queue', torch.randn(feat_dim, queue_len)) + self.queue = nn.functional.normalize(self.queue, dim=0) + self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long)) + + # create the second queue for dense output + self.register_buffer('queue2', torch.randn(feat_dim, queue_len)) + self.queue2 = nn.functional.normalize(self.queue2, dim=0) + self.register_buffer('queue2_ptr', torch.zeros(1, dtype=torch.long)) + + @torch.no_grad() + def _dequeue_and_enqueue(self, keys: torch.Tensor) -> None: + """Update queue.""" + # gather keys before updating queue + keys = torch.cat(all_gather(keys), dim=0) + + batch_size = keys.shape[0] + + ptr = int(self.queue_ptr) + assert self.queue_len % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1) + ptr = (ptr + batch_size) % self.queue_len # move pointer + + self.queue_ptr[0] = ptr + + @torch.no_grad() + def _dequeue_and_enqueue2(self, keys: torch.Tensor) -> None: + """Update queue2.""" + # gather keys before updating queue + keys = torch.cat(all_gather(keys), dim=0) + + batch_size = keys.shape[0] + + ptr = int(self.queue2_ptr) +
assert self.queue_len % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.queue2[:, ptr:ptr + batch_size] = keys.transpose(0, 1) + ptr = (ptr + batch_size) % self.queue_len # move pointer + + self.queue2_ptr[0] = ptr + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + im_q = inputs[0] + im_k = inputs[1] + # compute query features + q_b = self.backbone(im_q) # backbone features + q, q_grid, q2 = self.neck(q_b) # queries: NxC; NxCxS^2 + q_b = q_b[0] + q_b = q_b.view(q_b.size(0), q_b.size(1), -1) + + q = nn.functional.normalize(q, dim=1) + q2 = nn.functional.normalize(q2, dim=1) + q_grid = nn.functional.normalize(q_grid, dim=1) + q_b = nn.functional.normalize(q_b, dim=1) + + # compute key features + with torch.no_grad(): # no gradient to keys + # update the key encoder + self.encoder_k.update_parameters( + nn.Sequential(self.backbone, self.neck)) + + # shuffle for making use of BN + im_k, idx_unshuffle = batch_shuffle_ddp(im_k) + + k_b = self.encoder_k.module[0](im_k) # backbone features + k, k_grid, k2 = self.encoder_k.module[1](k_b) # keys: NxC; NxCxS^2 + k_b = k_b[0] + k_b = k_b.view(k_b.size(0), k_b.size(1), -1) + + k = nn.functional.normalize(k, dim=1) + k2 = nn.functional.normalize(k2, dim=1) + k_grid = nn.functional.normalize(k_grid, dim=1) + k_b = nn.functional.normalize(k_b, dim=1) + + # undo shuffle + k = batch_unshuffle_ddp(k, idx_unshuffle) + k2 = batch_unshuffle_ddp(k2, idx_unshuffle) + k_grid = batch_unshuffle_ddp(k_grid, idx_unshuffle) + k_b = batch_unshuffle_ddp(k_b, idx_unshuffle) + + # compute logits + # Einstein sum is more intuitive + # positive logits: Nx1 + l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) + # negative logits: NxK + l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) + + # feat point set sim + backbone_sim_matrix = torch.matmul(q_b.permute(0, 2, 1), k_b) + densecl_sim_ind = backbone_sim_matrix.max(dim=2)[1] # NxS^2 + + indexed_k_grid = torch.gather(k_grid, 2, + densecl_sim_ind.unsqueeze(1).expand( + -1, k_grid.size(1), -1)) # NxCxS^2 + densecl_sim_q = (q_grid * indexed_k_grid).sum(1) # NxS^2 + + # dense positive logits: NS^2X1 + l_pos_dense = densecl_sim_q.view(-1).unsqueeze(-1) + + q_grid = q_grid.permute(0, 2, 1) + q_grid = q_grid.reshape(-1, q_grid.size(2)) + # dense negative logits: NS^2xK + l_neg_dense = torch.einsum( + 'nc,ck->nk', [q_grid, self.queue2.clone().detach()]) + + loss_single = self.head.loss(l_pos, l_neg) + loss_dense = self.head.loss(l_pos_dense, l_neg_dense) + + losses = dict() + losses['loss_single'] = loss_single * (1 - self.loss_lambda) + losses['loss_dense'] = loss_dense * self.loss_lambda + + self._dequeue_and_enqueue(k) + self._dequeue_and_enqueue2(k2) + + return losses diff --git a/mmpretrain/models/selfsup/eva.py b/mmpretrain/models/selfsup/eva.py new file mode 100644 index 0000000000000000000000000000000000000000..30779bec491ae7c95b6540cdc7d71a875da572de --- /dev/null +++ b/mmpretrain/models/selfsup/eva.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
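The dense matching step in DenseCL's loss above is the subtle part: each query grid vector is paired with its most similar key grid vector via an argmax over backbone-feature similarities, and the positive logit is their dot product. A shape-level sketch with toy tensors (an editor's addition, not part of the commit):

import torch

N, C, S2 = 2, 8, 49                      # batch, channels, spatial positions
q_grid = torch.randn(N, C, S2)           # query grid features (normalized)
k_grid = torch.randn(N, C, S2)           # key grid features
q_b = torch.randn(N, C, S2)              # query backbone features
k_b = torch.randn(N, C, S2)              # key backbone features

sim = torch.matmul(q_b.permute(0, 2, 1), k_b)       # N x S2 x S2 similarity
idx = sim.max(dim=2)[1]                             # best-matching key per query position
matched_k = torch.gather(
    k_grid, 2, idx.unsqueeze(1).expand(-1, C, -1))  # N x C x S2
l_pos_dense = (q_grid * matched_k).sum(1).view(-1).unsqueeze(-1)  # (N*S2) x 1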
+from typing import Dict, List + +import torch + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class EVA(BaseSelfSupervisor): + """EVA. + + Implementation of `EVA: Exploring the Limits of Masked Visual + Representation Learning at Scale <https://arxiv.org/abs/2211.07636>`_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + + clip_feature, _ = self.target_generator(inputs) + + latent, mask, ids_restore = self.backbone(inputs) + pred = self.neck(latent, ids_restore) + + clip_feature = clip_feature[:, 1:, :] + loss = self.head.loss(pred, clip_feature, mask) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/itpn.py b/mmpretrain/models/selfsup/itpn.py new file mode 100644 index 0000000000000000000000000000000000000000..488a99631820e866c8cc743168b65f237fa136b2 --- /dev/null +++ b/mmpretrain/models/selfsup/itpn.py @@ -0,0 +1,359 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.models.backbones.hivit import BlockWithRPE, HiViT, PatchMerge +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import build_2d_sincos_position_embedding +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class iTPNHiViT(HiViT): + """HiViT for iTPN pre-training. + + Args: + img_size (int | tuple): Input image size. Defaults to 224. + patch_size (int | tuple): The patch size. Defaults to 16. + inner_patches (int): Inner patch. Defaults to 4. + stem_mlp_ratio (int): Ratio of MLP hidden dim to embedding dim + in the first two stages. Defaults to 3. + mlp_ratio (int): Ratio of MLP hidden dim to embedding dim in + the last stage. Defaults to 4. + qkv_bias (bool): Enable bias for qkv projections if True. + qk_scale (float): The scaling factor applied after q@k. Defaults + to None. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + ape (bool): If True, add absolute position embedding to + the patch embedding. + rpe (bool): If True, add relative position embedding to + the patch embedding. + layer_scale_init_value (float): Layer-scale init values. Defaults to 0. + mask_ratio (float): The ratio of total number of patches to be masked. + Defaults to 0.75. + reconstruction_type (str): The reconstruction target of self-supervised + learning. Defaults to 'pixel'.
+ """ + + def __init__( + self, + arch='base', + img_size: int = 224, + patch_size: int = 16, + inner_patches: int = 4, + stem_mlp_ratio: int = 3., + mlp_ratio: int = 4., + qkv_bias: bool = True, + qk_scale: Optional[bool] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + ape: bool = True, + rpe: bool = False, + layer_scale_init_value: float = 0.0, + mask_ratio: float = 0.75, + reconstruction_type: str = 'pixel', + **kwargs, + ): + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + inner_patches=inner_patches, + stem_mlp_ratio=stem_mlp_ratio, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + ape=ape, + rpe=rpe, + layer_scale_init_value=layer_scale_init_value, + **kwargs, + ) + + self.pos_embed.requires_grad = False + self.mask_ratio = mask_ratio + + assert reconstruction_type in ['pixel', 'clip'], \ + 'iTPN method only support `pixel` and `clip`, ' \ + f'but got `{reconstruction_type}`.' + self.reconstruction_type = reconstruction_type + self.num_patches = self.patch_embed.num_patches + + if reconstruction_type == 'clip': + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding and cls token.""" + super().apply(self._init_weights) + + if self.reconstruction_type == 'clip': + trunc_normal_(self.mask_token, std=0.02) + self.rescale_init_weight() + else: + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.pos_embed.shape[-1], + cls_token=False) + self.pos_embed.data.copy_(pos_embed.float()) + + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + def rescale_init_weight(self) -> None: + """Rescale the initialized weights.""" + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + if isinstance(layer, BlockWithRPE): + if layer.attn is not None: + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def masking_id(self, batch_size, mask_ratio): + N, L = batch_size, self.pos_embed.size(1) + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand( + N, L, device=self.pos_embed.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=self.pos_embed.device) + mask[:, :ids_keep.size(1)] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return ids_keep, ids_restore, mask + + def forward_pixel( + self, + x: torch.Tensor, + mask: Optional[bool] = True + ) -> Tuple[Tuple, torch.Tensor, torch.Tensor]: + """Generate features for masked images. + + The function supports two kind of forward behaviors. 
If the ``mask`` is + ``True``, the function will generate a mask to randomly mask some + patches and get the hidden features for the visible patches, which means + the function will be executed as masked image modeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extracts features from images without + a mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): Whether the forward function generates a + ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + """ + if mask is None or mask is False: + return super().forward(x) + + else: + B, C, H, W = x.shape + ids_keep, ids_restore, mask = self.masking_id(B, self.mask_ratio) + + x = self.patch_embed(x) + + x = torch.gather( + x, + dim=1, + index=ids_keep[:, :, None, None, + None].expand(-1, -1, *x.shape[2:])) + + outs = [] + for blk in self.blocks[:-self.num_main_blocks]: + if isinstance(blk, PatchMerge): + outs.append(x) + x = blk(x) + + x = x[..., 0, 0, :] + if self.ape: + pos_embed = self.interpolate_pos_encoding(x, H, W) + pos_embed = torch.gather( + pos_embed.expand(B, -1, -1), + dim=1, + index=ids_keep[:, :, None].expand(-1, -1, + pos_embed.shape[2]), + ) + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks[-self.num_main_blocks:]: + x = blk(x) + + outs.append(x) + + return (tuple(outs), mask, ids_restore) + + def forward_clip(self, + x: torch.Tensor, + mask: Optional[bool] = True) -> Tuple: + """Generate features for masked images. + + The function supports two kinds of forward behaviors. If the ``mask`` is + ``True``, the function will generate a mask to randomly mask some + patches and get the hidden features for the visible patches, which means + the function will be executed as masked image modeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extracts features from images without + a mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): Whether the forward function generates a + ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + """ + if mask is None or mask is False: + return super().forward(x) + + else: + B, C, H, W = x.shape + x = self.patch_embed(x) + + outs = [] + for blk in self.blocks[:-self.num_main_blocks]: + if isinstance(blk, PatchMerge): + outs.append(x) + x = blk(x) + + x = x[..., 0, 0, :] + B, L, _ = x.shape + mask_token = self.mask_token.expand(B, L, -1) + w = mask.flatten(1).unsqueeze(-1).type_as(mask_token) + x = x * (1.
- w) + mask_token * w + + if self.ape: + pos_embed = self.interpolate_pos_encoding(x, H, W) + x = x + pos_embed + x = self.pos_drop(x) + + rpe_index = True if self.rpe else None + + for blk in self.blocks[-self.num_main_blocks:]: + x = blk(x, rpe_index) + + outs.append(x) + + return tuple(outs) + + def forward(self, x: torch.Tensor, mask: Optional[bool] = True) -> Tuple: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + ``True``, the function will generate mask to masking some patches + randomly and get the hidden features for visible patches, which means + the function will be executed as masked imagemodeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward function + generating ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + """ + + if self.reconstruction_type == 'pixel': + return self.forward_pixel(x, mask) + return self.forward_clip(x, mask) + + +@MODELS.register_module() +class iTPN(BaseSelfSupervisor): + """iTPN. + + Implementation of `iTPN: Integrally Pre-Trained Transformer Pyramid + Networks `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + + if self.backbone.reconstruction_type == 'pixel': + latent, mask, ids_restore = self.backbone(inputs) + pred = self.neck(latent, ids_restore) + + loss = self.head.loss(pred, inputs, mask) + else: + mask = torch.stack( + [data_sample.mask for data_sample in data_samples]) + + img_latent = self.backbone(inputs[0], mask) + + # inputs[1] is the target image + with torch.no_grad(): + target = self.target_generator(inputs[1])[0] + target = target.detach() + + # iTPN contains a neck module + feats = self.neck(img_latent) + loss = self.head.loss(feats, target[:, 1:, :], mask) + + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/mae.py b/mmpretrain/models/selfsup/mae.py new file mode 100644 index 0000000000000000000000000000000000000000..01bc5bc5134e02488556eacd8cfc30c2fae44fea --- /dev/null +++ b/mmpretrain/models/selfsup/mae.py @@ -0,0 +1,416 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import torch + +from mmpretrain.models import HiViT, VisionTransformer +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import build_2d_sincos_position_embedding +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class MAEViT(VisionTransformer): + """Vision Transformer for MAE pre-training. 
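Editor's note on the pixel branch of `iTPN.loss` above: the decoder output is scored only on masked patches. Below is a toy sketch of that masked reconstruction objective; all shapes are made up, and a plain per-patch MSE stands in for whatever the configured head actually computes (the real head may, for example, normalize target patches first).

```python
import torch

# Toy shapes, illustrating a masked per-patch MSE (not the library head).
B, L, patch_dim = 2, 196, 768
pred = torch.randn(B, L, patch_dim)       # decoder output, one vector per patch
target = torch.randn(B, L, patch_dim)     # patchified ground-truth pixels
mask = (torch.rand(B, L) > 0.25).float()  # 1 = masked / to be reconstructed

loss_per_patch = ((pred - target) ** 2).mean(dim=-1)  # (B, L)
loss = (loss_per_patch * mask).sum() / mask.sum()     # average over masked only
print(loss)
```

Dividing by `mask.sum()` keeps the loss scale independent of how many patches happened to be masked.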
+ + A PyTorch implement of: `An Image is Worth 16x16 Words: Transformers + for Image Recognition at Scale `_. + This module implements the patch masking in MAE and initialize the + position embedding with sine-cosine position embedding. + + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. Defaults to ``"avg_featmap"``. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + mask_ratio (bool): The ratio of total number of patches to be masked. + Defaults to 0.75. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'b', + img_size: int = 224, + patch_size: int = 16, + out_indices: Union[Sequence, int] = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + out_type: str = 'raw', + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + mask_ratio: float = 0.75, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + out_type=out_type, + with_cls_token=True, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + + # position embedding is not learnable during pretraining + self.pos_embed.requires_grad = False + self.mask_ratio = mask_ratio + self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding and cls token.""" + super().init_weights() + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.pos_embed.shape[-1], + cls_token=True) + self.pos_embed.data.copy_(pos_embed.float()) + + w = self.patch_embed.projection.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + torch.nn.init.normal_(self.cls_token, std=.02) + + def random_masking( + self, + x: torch.Tensor, + mask_ratio: float = 0.75 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate the mask for MAE Pre-training. 
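As a side note on `init_weights` above: it freezes `pos_embed` and overwrites it with a fixed 2-D sine-cosine table via `build_2d_sincos_position_embedding`. A self-contained approximation of that helper follows; the library version may order the two grid axes differently, so treat this as a sketch rather than the exact implementation.

```python
import torch

def sincos_pos_embed_2d(grid_size: int, dim: int, cls_token: bool = False):
    """Fixed 2D sine-cosine position embedding, shape (1, grid*grid(+1), dim)."""
    assert dim % 4 == 0, 'dim must be divisible by 4'
    coords = torch.arange(grid_size, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(coords, coords, indexing='ij')
    pos_dim = dim // 4
    omega = 1.0 / (10000 ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim))
    out_w = grid_w.flatten()[:, None] * omega[None, :]
    out_h = grid_h.flatten()[:, None] * omega[None, :]
    pos = torch.cat([torch.sin(out_w), torch.cos(out_w),
                     torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
    if cls_token:
        # the cls token gets an all-zero position vector
        pos = torch.cat([torch.zeros(1, 1, dim), pos], dim=1)
    return pos

pe = sincos_pos_embed_2d(14, 768, cls_token=True)
print(pe.shape)  # torch.Size([1, 197, 768])
```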
+ + Args: + x (torch.Tensor): Image with data augmentation applied, which is + of shape B x L x C. + mask_ratio (float): The mask ratio of total patches. + Defaults to 0.75. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: masked image, mask + and the ids to restore original image. + + - ``x_masked`` (torch.Tensor): masked image. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather( + x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + def forward( + self, + x: torch.Tensor, + mask: Optional[bool] = True + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + ``True``, the function will generate mask to masking some patches + randomly and get the hidden features for visible patches, which means + the function will be executed as masked imagemodeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward function + generating ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + """ + if mask is None or False: + return super().forward(x) + + else: + B = x.shape[0] + x = self.patch_embed(x)[0] + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + x, mask, ids_restore = self.random_masking(x, self.mask_ratio) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + for _, layer in enumerate(self.layers): + x = layer(x) + # Use final norm + x = self.norm1(x) + + return (x, mask, ids_restore) + + +@MODELS.register_module() +class MAE(BaseSelfSupervisor): + """MAE. + + Implementation of `Masked Autoencoders Are Scalable Vision Learners + `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. 
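The double `argsort` in `random_masking` above is the whole trick: `argsort(noise)` draws a random permutation per sample, and `argsort`-ing that permutation again yields its inverse (`ids_restore`), used later to unshuffle tokens. A toy run with 8 patches, assuming nothing beyond the code above:

```python
import torch

torch.manual_seed(0)
N, L, D = 2, 8, 4                      # tiny example: 8 patches of dim 4
x = torch.arange(N * L * D, dtype=torch.float32).reshape(N, L, D)
mask_ratio = 0.75
len_keep = int(L * (1 - mask_ratio))   # 2 visible patches per sample

noise = torch.rand(N, L)
ids_shuffle = torch.argsort(noise, dim=1)   # random permutation per sample
ids_restore = torch.argsort(ids_shuffle, dim=1)  # its inverse
ids_keep = ids_shuffle[:, :len_keep]

x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).repeat(1, 1, D))
mask = torch.ones(N, L)
mask[:, :len_keep] = 0
mask = torch.gather(mask, 1, ids_restore)   # back in original patch order

print(x_masked.shape)   # (2, 2, 4): only visible patches enter the encoder
print(mask.sum(dim=1))  # 6 masked patches per sample
```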
+ """ + # ids_restore: the same as that in original repo, which is used + # to recover the original order of tokens in decoder. + latent, mask, ids_restore = self.backbone(inputs) + pred = self.neck(latent, ids_restore) + loss = self.head.loss(pred, inputs, mask) + losses = dict(loss=loss) + return losses + + +@MODELS.register_module() +class MAEHiViT(HiViT): + """HiViT for MAE pre-training. + + A PyTorch implement of: `HiViT: A Simple and More Efficient Design + of Hierarchical Vision Transformer `_. + This module implements the patch masking in MAE and initialize the + position embedding with sine-cosine position embedding. + + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + Defaults to 4, to downsample 4x at the first stage + inner_patches (int): The inner patches within a token + Defaults to 4 + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + ape (bool): the absolute position embedding + rpe (bool): the relative position embedding + Defaults to False + layer_scale_init_value (float): the layer scale init value + mask_ratio (bool): The ratio of total number of patches to be masked. + Defaults to 0.75. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'b', + img_size: int = 224, + patch_size: int = 16, + inner_patches: int = 4, + out_indices: Union[list, int] = [23], + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + ape: bool = True, + rpe: bool = False, + layer_scale_init_value: float = 0.0, + mask_ratio: float = 0.75, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + inner_patches=inner_patches, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + ape=ape, + rpe=rpe, + layer_scale_init_value=layer_scale_init_value, + init_cfg=init_cfg) + + self.pos_embed.requires_grad = False + self.mask_ratio = mask_ratio + self.num_patches = self.patch_embed.num_patches + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding.""" + super().apply(self._init_weights) + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.pos_embed.shape[-1], + cls_token=False) + self.pos_embed.data.copy_(pos_embed.float()) + + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + def masking_id( + self, batch_size, + mask_ratio) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate the mask for MAE Pre-training. + + Args: + batch_size: The batch size of input data + mask_ratio: The mask ratio of total patches. + Defaults to 0.75. 
+ + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: the ids + for the tokens retained, the ids to restore original image, + and the mask + """ + N, L = batch_size, self.pos_embed.size(1) + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand( + N, L, device=self.pos_embed.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=self.pos_embed.device) + mask[:, :ids_keep.size(1)] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return ids_keep, ids_restore, mask + + def forward( + self, + x: torch.Tensor, + mask: Optional[bool] = True + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + ``True``, the function will generate mask to masking some patches + randomly and get the hidden features for visible patches, which means + the function will be executed as masked imagemodeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward function + generating ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + """ + if mask is None or False: + return super().forward(x) + + else: + B, C, H, W = x.shape + ids_keep, ids_restore, mask = self.masking_id(B, self.mask_ratio) + + x = self.patch_embed(x) + + x = torch.gather( + x, + dim=1, + index=ids_keep[:, :, None, None, + None].expand(-1, -1, *x.shape[2:])) + + for blk in self.blocks[:-self.num_main_blocks]: + x = blk(x) + + x = x[..., 0, 0, :] + if self.ape: + pos_embed = self.interpolate_pos_encoding(x, H, W) + pos_embed = torch.gather( + pos_embed.expand(B, -1, -1), + dim=1, + index=ids_keep[:, :, None].expand(-1, -1, + pos_embed.shape[2]), + ) + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks[-self.num_main_blocks:]: + x = blk(x) + + return (x, mask, ids_restore) diff --git a/mmpretrain/models/selfsup/maskfeat.py b/mmpretrain/models/selfsup/maskfeat.py new file mode 100644 index 0000000000000000000000000000000000000000..fd9f0b296c44cdffe7f2a40caae04de0104abd60 --- /dev/null +++ b/mmpretrain/models/selfsup/maskfeat.py @@ -0,0 +1,336 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, List, Optional, Sequence, Union + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +from mmpretrain.models import VisionTransformer +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class HOGGenerator(BaseModule): + """Generate HOG feature for images. 
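Note the five-dimensional gather in `MAEHiViT.forward` above: HiViT tokens carry an inner patch grid, so the sequence is (B, L, inner, inner, C) and the keep-ids must be broadcast across the three trailing dimensions before `torch.gather`. In isolation, with toy shapes:

```python
import torch

B, L, inner, C = 2, 8, 4, 16
x = torch.randn(B, L, inner, inner, C)
ids_keep = torch.argsort(torch.rand(B, L), dim=1)[:, :2]  # keep 2 tokens

# broadcast the keep-ids over (inner, inner, C) before gathering on dim 1
index = ids_keep[:, :, None, None, None].expand(-1, -1, *x.shape[2:])
visible = torch.gather(x, dim=1, index=index)
print(visible.shape)  # (2, 2, 4, 4, 16)
```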
+ + This module is used in MaskFeat to generate HOG feature. The code is + modified from file `slowfast/models/operators.py + `_. + Here is the link of `HOG wikipedia + `_. + + Args: + nbins (int): Number of bin. Defaults to 9. + pool (float): Number of cell. Defaults to 8. + gaussian_window (int): Size of gaussian kernel. Defaults to 16. + """ + + def __init__(self, + nbins: int = 9, + pool: int = 8, + gaussian_window: int = 16) -> None: + super().__init__() + self.nbins = nbins + self.pool = pool + self.pi = math.pi + weight_x = torch.FloatTensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]) + weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1).contiguous() + weight_y = weight_x.transpose(2, 3).contiguous() + self.register_buffer('weight_x', weight_x) + self.register_buffer('weight_y', weight_y) + + self.gaussian_window = gaussian_window + if gaussian_window: + gaussian_kernel = self.get_gaussian_kernel(gaussian_window, + gaussian_window // 2) + self.register_buffer('gaussian_kernel', gaussian_kernel) + + def get_gaussian_kernel(self, kernlen: int, std: int) -> torch.Tensor: + """Returns a 2D Gaussian kernel array.""" + + def _gaussian_fn(kernlen: int, std: int) -> torch.Tensor: + n = torch.arange(0, kernlen).float() + n -= n.mean() + n /= std + w = torch.exp(-0.5 * n**2) + return w + + kernel_1d = _gaussian_fn(kernlen, std) + kernel_2d = kernel_1d[:, None] * kernel_1d[None, :] + return kernel_2d / kernel_2d.sum() + + def _reshape(self, hog_feat: torch.Tensor) -> torch.Tensor: + """Reshape HOG Features for output.""" + hog_feat = hog_feat.flatten(1, 2) + self.unfold_size = hog_feat.shape[-1] // 14 + hog_feat = hog_feat.permute(0, 2, 3, 1) + hog_feat = hog_feat.unfold(1, self.unfold_size, + self.unfold_size).unfold( + 2, self.unfold_size, self.unfold_size) + hog_feat = hog_feat.flatten(1, 2).flatten(2) + return hog_feat + + @torch.no_grad() + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Generate hog feature for each batch images. + + Args: + x (torch.Tensor): Input images of shape (N, 3, H, W). + + Returns: + torch.Tensor: Hog features. 
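The buffers registered in `__init__` above are ordinary Sobel kernels repeated once per colour channel (hence `groups=3`), and the Gaussian window is the outer product of a 1-D Gaussian, exactly as `get_gaussian_kernel` builds it. A standalone check of the gradient step on a toy 32x32 input:

```python
import torch
import torch.nn.functional as F

# Fixed 3x3 Sobel filters, one per colour channel.
weight_x = torch.tensor([[1., 0., -1.], [2., 0., -2.], [1., 0., -1.]])
weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1).contiguous()
weight_y = weight_x.transpose(2, 3).contiguous()

img = torch.randn(1, 3, 32, 32)
img = F.pad(img, (1, 1, 1, 1), mode='reflect')   # keep spatial size after conv
gx = F.conv2d(img, weight_x, groups=3)           # horizontal gradients
gy = F.conv2d(img, weight_y, groups=3)           # vertical gradients
magnitude = torch.stack([gx, gy], dim=-1).norm(dim=-1)
print(magnitude.shape)  # (1, 3, 32, 32)
```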
+ """ + # input is RGB image with shape [B 3 H W] + self.h, self.w = x.size(-2), x.size(-1) + x = F.pad(x, pad=(1, 1, 1, 1), mode='reflect') + gx_rgb = F.conv2d( + x, self.weight_x, bias=None, stride=1, padding=0, groups=3) + gy_rgb = F.conv2d( + x, self.weight_y, bias=None, stride=1, padding=0, groups=3) + norm_rgb = torch.stack([gx_rgb, gy_rgb], dim=-1).norm(dim=-1) + phase = torch.atan2(gx_rgb, gy_rgb) + phase = phase / self.pi * self.nbins # [-9, 9] + + b, c, h, w = norm_rgb.shape + out = torch.zeros((b, c, self.nbins, h, w), + dtype=torch.float, + device=x.device) + phase = phase.view(b, c, 1, h, w) + norm_rgb = norm_rgb.view(b, c, 1, h, w) + if self.gaussian_window: + if h != self.gaussian_window: + assert h % self.gaussian_window == 0, 'h {} gw {}'.format( + h, self.gaussian_window) + repeat_rate = h // self.gaussian_window + temp_gaussian_kernel = self.gaussian_kernel.repeat( + [repeat_rate, repeat_rate]) + else: + temp_gaussian_kernel = self.gaussian_kernel + norm_rgb *= temp_gaussian_kernel + + out.scatter_add_(2, phase.floor().long() % self.nbins, norm_rgb) + + out = out.unfold(3, self.pool, self.pool) + out = out.unfold(4, self.pool, self.pool) + out = out.sum(dim=[-1, -2]) + + self.out = F.normalize(out, p=2, dim=2) + + return self._reshape(self.out) + + def generate_hog_image(self, hog_out: torch.Tensor) -> np.ndarray: + """Generate HOG image according to HOG features.""" + assert hog_out.size(0) == 1 and hog_out.size(1) == 3, \ + 'Check the input batch size and the channcel number, only support'\ + '"batch_size = 1".' + hog_image = np.zeros([self.h, self.w]) + cell_gradient = np.array(hog_out.mean(dim=1).squeeze().detach().cpu()) + cell_width = self.pool / 2 + max_mag = np.array(cell_gradient).max() + angle_gap = 360 / self.nbins + + for x in range(cell_gradient.shape[1]): + for y in range(cell_gradient.shape[2]): + cell_grad = cell_gradient[:, x, y] + cell_grad /= max_mag + angle = 0 + for magnitude in cell_grad: + angle_radian = math.radians(angle) + x1 = int(x * self.pool + + magnitude * cell_width * math.cos(angle_radian)) + y1 = int(y * self.pool + + magnitude * cell_width * math.sin(angle_radian)) + x2 = int(x * self.pool - + magnitude * cell_width * math.cos(angle_radian)) + y2 = int(y * self.pool - + magnitude * cell_width * math.sin(angle_radian)) + magnitude = 0 if magnitude < 0 else magnitude + cv2.line(hog_image, (y1, x1), (y2, x2), + int(255 * math.sqrt(magnitude))) + angle += angle_gap + return hog_image + + +@MODELS.register_module() +class MaskFeatViT(VisionTransformer): + """Vision Transformer for MaskFeat pre-training. + + A PyTorch implement of: `Masked Feature Prediction for Self-Supervised + Visual Pre-Training `_. + + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). 
+ - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. Defaults to ``"avg_featmap"``. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'b', + img_size: int = 224, + patch_size: int = 16, + out_indices: Union[Sequence, int] = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + out_type: str = 'raw', + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + out_type=out_type, + with_cls_token=True, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + + self.mask_token = nn.parameter.Parameter( + torch.zeros(1, 1, self.embed_dims), requires_grad=True) + self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + def init_weights(self) -> None: + """Initialize position embedding, mask token and cls token.""" + super().init_weights() + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + + nn.init.trunc_normal_(self.cls_token, std=.02) + nn.init.trunc_normal_(self.mask_token, std=.02) + nn.init.trunc_normal_(self.pos_embed, std=.02) + + self.apply(self._init_weights) + + def _init_weights(self, m: torch.nn.Module) -> None: + if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x: torch.Tensor, + mask: Optional[torch.Tensor]) -> torch.Tensor: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + not ``None``, the forward function will be executed as masked image + modeling pre-training; if the ``mask`` is ``None``, the forward + function will call ``super().forward()``, which extract features from + images without mask. + + Args: + x (torch.Tensor): Input images. + mask (torch.Tensor, optional): Input masks. + + Returns: + torch.Tensor: Features with cls_tokens. 
+ """ + if mask is None: + return super().forward(x) + + else: + B = x.shape[0] + x = self.patch_embed(x)[0] + + # masking: length -> length * mask_ratio + B, L, _ = x.shape + mask_tokens = self.mask_token.expand(B, L, -1) + mask = mask.unsqueeze(-1) + x = x * (1 - mask.int()) + mask_tokens * mask + + # append cls token + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.drop_after_pos(x) + + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.norm1(x) + + return x + + +@MODELS.register_module() +class MaskFeat(BaseSelfSupervisor): + """MaskFeat. + + Implementation of `Masked Feature Prediction for Self-Supervised Visual + Pre-Training `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + mask = torch.stack([data_sample.mask for data_sample in data_samples]) + mask = mask.flatten(1).bool() + + latent = self.backbone(inputs, mask) + B, L, C = latent.shape + pred = self.neck((latent.view(B * L, C), )) + pred = pred[0].view(B, L, -1) + hog = self.target_generator(inputs) + + # remove cls_token before compute loss + loss = self.head.loss(pred[:, 1:], hog, mask) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/mff.py b/mmpretrain/models/selfsup/mff.py new file mode 100644 index 0000000000000000000000000000000000000000..268505805777399c632643fa9ac1e4be6fc271c6 --- /dev/null +++ b/mmpretrain/models/selfsup/mff.py @@ -0,0 +1,194 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F + +from mmpretrain.models.selfsup.mae import MAE, MAEViT +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class MFFViT(MAEViT): + """Vision Transformer for MFF Pretraining. + + This class inherits all these functionalities from ``MAEViT``, and + add multi-level feature fusion to it. For more details, you can + refer to `Improving Pixel-based MIM by Reducing Wasted Modeling + Capability`. + + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). 
+ - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. Defaults to ``"avg_featmap"``. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + mask_ratio (bool): The ratio of total number of patches to be masked. + Defaults to 0.75. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'b', + img_size: int = 224, + patch_size: int = 16, + out_indices: Union[Sequence, int] = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + out_type: str = 'raw', + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + mask_ratio: float = 0.75, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + out_type=out_type, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + mask_ratio=mask_ratio, + init_cfg=init_cfg) + proj_layers = [ + torch.nn.Linear(self.embed_dims, self.embed_dims) + for _ in range(len(self.out_indices) - 1) + ] + self.proj_layers = torch.nn.ModuleList(proj_layers) + self.proj_weights = torch.nn.Parameter( + torch.ones(len(self.out_indices)).view(-1, 1, 1, 1)) + if len(self.out_indices) == 1: + self.proj_weights.requires_grad = False + + def forward( + self, + x: torch.Tensor, + mask: Optional[bool] = True + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + ``True``, the function will generate mask to masking some patches + randomly and get the hidden features for visible patches, which means + the function will be executed as masked imagemodeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward function + generating ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. 
+ """ + if mask is None or False: + return super().forward(x) + + else: + B = x.shape[0] + x = self.patch_embed(x)[0] + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + x, mask, ids_restore = self.random_masking(x, self.mask_ratio) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + res = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + if i != self.out_indices[-1]: + proj_x = self.proj_layers[self.out_indices.index(i)](x) + else: + proj_x = x + res.append(proj_x) + res = torch.stack(res) + proj_weights = F.softmax(self.proj_weights, dim=0) + res = res * proj_weights + res = res.sum(dim=0) + + # Use final norm + x = self.norm1(res) + return (x, mask, ids_restore, proj_weights.view(-1)) + + +@MODELS.register_module() +class MFF(MAE): + """MFF. + + Implementation of `Improving Pixel-based MIM by Reducing Wasted Modeling + Capability`. + """ + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + # ids_restore: the same as that in original repo, which is used + # to recover the original order of tokens in decoder. + latent, mask, ids_restore, weights = self.backbone(inputs) + pred = self.neck(latent, ids_restore) + loss = self.head.loss(pred, inputs, mask) + weight_params = { + f'weight_{i}': weights[i] + for i in range(weights.size(0)) + } + losses = dict(loss=loss) + losses.update(weight_params) + return losses diff --git a/mmpretrain/models/selfsup/milan.py b/mmpretrain/models/selfsup/milan.py new file mode 100644 index 0000000000000000000000000000000000000000..fdf86737af3499e6f6309aa5c5ddadef00f63740 --- /dev/null +++ b/mmpretrain/models/selfsup/milan.py @@ -0,0 +1,202 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from mmengine.runner.checkpoint import _load_checkpoint + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import build_clip_model +from .base import BaseSelfSupervisor +from .mae import MAEViT + + +@MODELS.register_module() +class CLIPGenerator(nn.Module): + """Get the features and attention from the last layer of CLIP. + + This module is used to generate target features in masked image modeling. + + Args: + tokenizer_path (str): The path of the checkpoint of CLIP. + """ + + def __init__(self, tokenizer_path: str) -> None: + super().__init__() + self.tokenizer_path = tokenizer_path + self.tokenizer = build_clip_model( + _load_checkpoint(self.tokenizer_path), False) + + @torch.no_grad() + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the features and attention from the last layer of CLIP. + + Args: + x (torch.Tensor): The input image, which is of shape (N, 3, H, W). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The features and attention from + the last layer of CLIP, which are of shape (N, L, C) and (N, L, L), + respectively. 
+ """ + # use the visual branch of CLIP to get the features + assert self.tokenizer is not None, 'Please check whether the ' \ + '`self.tokenizer` is initialized correctly.' + + clip_features = self.tokenizer.encode_image(x) + return clip_features + + +@MODELS.register_module() +class MILANViT(MAEViT): + """Vision Transformer for MILAN pre-training. + + Implementation of the encoder for `MILAN: Masked Image Pretraining on + Language Assisted Representation `_. + + This module inherits from MAEViT and only overrides the forward function + and replace random masking with attention masking. + """ + + def attention_masking( + self, x: torch.Tensor, mask_ratio: float, importance: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate attention mask for MILAN. + + This is what is different from MAEViT, which uses random masking. + Attention masking generates attention mask for MILAN, according to + importance. The higher the importance, the more likely the patch is + kept. + + Args: + x (torch.Tensor): Input images, which is of shape B x L x C. + mask_ratio (float): The ratio of patches to be masked. + importance (torch.Tensor): Importance of each patch, which is of + shape B x L. + + Returns: + Tuple[torch.Tensor, ...]: + + - ``x_masked``: masked image + - ``ids_restore``: the ids to restore original image + - ``ids_keep``: ids of the kept patches + - ``ids_dump``: ids of the removed patches + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = importance.to(x.device) # large is keep, small is remove + + # sort noise for each sample + ids_shuffle = torch.multinomial(noise, L, replacement=False) + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + ids_dump = ids_shuffle[:, len_keep:] + x_masked = torch.gather( + x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, ids_restore, ids_keep, ids_dump + + def forward( + self, + x: torch.Tensor, + importance: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the + ``importance`` is ``None``, the function generates mask and masks some + patches randomly and get the hidden features for visible patches. The + mask is generated by importance. The higher the importance, the more + likely the patch is kept. The importance is calculated by CLIP. + The higher the CLIP score, the more likely the patch is kept. The CLIP + score is calculated by cross attention between the class token and all + other tokens from the last layer. + If the ``importance`` is ``torch.Tensor``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + importance (torch.Tensor, optional): Importance of each patch, + which is of shape B x L. + + Returns: + Tuple[torch.Tensor, ...]: masked image, the ids to restore original + image, ids of the kept patches, ids of the removed patches. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. 
+ - ``ids_restore`` (torch.Tensor): ids to restore original image. + - ``ids_keep`` (torch.Tensor): ids of the kept patches. + - ``ids_dump`` (torch.Tensor): ids of the removed patches. + """ + if importance is None: + return super(MAEViT, self).forward(x) + + else: + B = x.shape[0] + x = self.patch_embed(x)[0] + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + x, ids_restore, ids_keep, ids_dump = self.attention_masking( + x, self.mask_ratio, importance) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + for _, layer in enumerate(self.layers): + x = layer(x) + # Use final norm + x = self.norm1(x) + + return x, ids_restore, ids_keep, ids_dump + + +@MODELS.register_module() +class MILAN(BaseSelfSupervisor): + """MILAN. + + Implementation of `MILAN: Masked Image Pretraining on Language Assisted + Representation `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, importance=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + # ids_restore: the same as that in original repo, which is used + # to recover the original order of tokens in decoder. + clip_feature, importance = self.target_generator(inputs) + importance = importance[:, 0, 1:] + latent, ids_restore, ids_keep, ids_dump = self.backbone( + inputs, importance) + pred = self.neck(latent, ids_restore, ids_keep, ids_dump) + + loss = self.head.loss(pred, clip_feature) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/mixmim.py b/mmpretrain/models/selfsup/mixmim.py new file mode 100644 index 0000000000000000000000000000000000000000..b202f836f64358369276a9b85795fb6eec769fb7 --- /dev/null +++ b/mmpretrain/models/selfsup/mixmim.py @@ -0,0 +1,263 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from mmpretrain.models.backbones import MixMIMTransformer +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import build_2d_sincos_position_embedding +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class MixMIMPretrainTransformer(MixMIMTransformer): + """MixMIM backbone for MixMIM pre-training. + + A PyTorch implement of : ` MixMIM: Mixed and Masked Image + Modeling for Efficient Visual Representation Learning + `_ + + Args: + arch (str | dict): MixMIM architecture. If use string, + choose from 'base','large' and 'huge'. + If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + + Defaults to 'base'. + mlp_ratio (int): The mlp ratio in FFN. Defaults to 4. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to mlp_ratio + the most common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. 
+ Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + window_size (list): The height and width of the window. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + patch_cfg (dict): Extra config dict for patch embedding. + Defaults to an empty dict. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + attn_drop_rate (float): Attention drop rate. Defaults to 0. + use_checkpoint (bool): Whether use the checkpoint to reduce GPU memory + cost. Defaults to False. + mask_ratio (bool): The base ratio of total number of patches to be + masked. Defaults to 0.5. + range_mask_ratio (float): The range of mask ratio. + Defaults to 0. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'base', + mlp_ratio: float = 4, + img_size: int = 224, + patch_size: int = 4, + in_channels: int = 3, + window_size: List = [14, 14, 14, 7], + qkv_bias: bool = True, + patch_cfg: dict = dict(), + norm_cfg: dict = dict(type='LN'), + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + attn_drop_rate: float = 0.0, + use_checkpoint: bool = False, + mask_ratio: float = 0.5, + range_mask_ratio: float = 0.0, + init_cfg: Optional[dict] = None) -> None: + + super().__init__( + arch=arch, + mlp_ratio=mlp_ratio, + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + window_size=window_size, + qkv_bias=qkv_bias, + patch_cfg=patch_cfg, + norm_cfg=norm_cfg, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + attn_drop_rate=attn_drop_rate, + use_checkpoint=use_checkpoint, + init_cfg=init_cfg) + + self.mask_ratio = mask_ratio + self.range_mask_ratio = range_mask_ratio + + def init_weights(self): + """Initialize position embedding, patch embedding.""" + super(MixMIMTransformer, self).init_weights() + + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.absolute_pos_embed.shape[-1], + cls_token=False) + self.absolute_pos_embed.data.copy_(pos_embed.float()) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def random_masking(self, + x: torch.Tensor, + mask_ratio: float = 0.5) -> Tuple[torch.Tensor]: + """Generate the mask for MixMIM Pretraining. + + Args: + x (torch.Tensor): Image with data augmentation applied, which is + of shape B x L x C. + mask_ratio (float): The mask ratio of total patches. + Defaults to 0.5. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + - mask_s1 (torch.Tensor): mask with stride of + self.encoder_stride // 8. + - mask_s2 (torch.Tensor): mask with stride of + self.encoder_stride // 4. + - mask_s3 (torch.Tensor): mask with stride of + self.encoder_stride // 2. + - mask (torch.Tensor): mask with stride of + self.encoder_stride. 
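This `random_masking` differs from MAE's: one mask is shared by the whole batch and is resized with nearest-neighbour interpolation so that every resolution stage masks the same spatial region. A toy run of the resizing idea (assuming a 7x7 grid at the final stage, as with a 224-px input and stride 32):

```python
import torch
import torch.nn.functional as F

out_H = out_W = 7                 # mask resolution at the last stage
seq_l = out_H * out_W
mask = torch.zeros(1, 1, seq_l)
noise = torch.rand(1, 1, seq_l)
mask_idx = torch.argsort(noise, dim=2)[:, :, :int(seq_l * 0.5)]
mask.scatter_(2, mask_idx, 1)     # 1 = masked
mask = mask.reshape(1, 1, out_H, out_W)

# nearest-neighbour upsampling keeps the masked region (and ratio) identical
mask_s3 = F.interpolate(mask, size=(out_H * 2, out_W * 2), mode='nearest')
mask_s2 = F.interpolate(mask, size=(out_H * 4, out_W * 4), mode='nearest')
print(mask.mean(), mask_s2.mean())  # same mask ratio at every scale
```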
+ """ + + B, C, H, W = x.shape + out_H = H // self.encoder_stride + out_W = W // self.encoder_stride + s3_H, s3_W = out_H * 2, out_W * 2 + s2_H, s2_W = out_H * 4, out_W * 4 + s1_H, s1_W = out_H * 8, out_W * 8 + + seq_l = out_H * out_W + # use a shared mask for a batch images + mask = torch.zeros([1, 1, seq_l], device=x.device) + + mask_ratio = mask_ratio + random.uniform(0.0, self.range_mask_ratio) + noise = torch.rand(1, 1, seq_l, device=x.device) # noise in [0, 1] + # ascend: small is keep, large is removed + mask_idx = torch.argsort(noise, dim=2)[:, :, :int(seq_l * mask_ratio)] + mask.scatter_(2, mask_idx, 1) + mask = mask.reshape(1, 1, out_H, out_W) + mask_s1 = F.interpolate(mask, size=(s1_H, s1_W), mode='nearest') + mask_s2 = F.interpolate(mask, size=(s2_H, s2_W), mode='nearest') + mask_s3 = F.interpolate(mask, size=(s3_H, s3_W), mode='nearest') + + mask = mask.reshape(1, out_H * out_W, 1).contiguous() + mask_s1 = mask_s1.reshape(1, s1_H * s1_W, 1).contiguous() + mask_s2 = mask_s2.reshape(1, s2_H * s2_W, 1).contiguous() + mask_s3 = mask_s3.reshape(1, s3_H * s3_W, 1).contiguous() + + return mask_s1, mask_s2, mask_s3, mask + + def forward(self, + x: torch.Tensor, + mask: Optional[bool] = True) -> Tuple[torch.Tensor]: + """Generate features for masked images. + + This function generates mask and masks some patches randomly and get + the hidden features for visible patches. + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward containing + ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - x (torch.Tensor): hidden features, which is of shape + B x L x C. + - mask_s4 (torch.Tensor): the mask tensor for the last layer. + """ + if mask is None or False: + return super().forward(x) + + else: + mask_s1, mask_s2, mask_s3, mask_s4 = self.random_masking( + x, self.mask_ratio) + + x, _ = self.patch_embed(x) + + x = x * (1. - mask_s1) + x.flip(0) * mask_s1 + x = x + self.absolute_pos_embed + x = self.drop_after_pos(x) + + for idx, layer in enumerate(self.layers): + if idx == 0: + x = layer(x, attn_mask=mask_s1) + elif idx == 1: + x = layer(x, attn_mask=mask_s2) + elif idx == 2: + x = layer(x, attn_mask=mask_s3) + elif idx == 3: + x = layer(x, attn_mask=mask_s4) + + x = self.norm(x) + + return x, mask_s4 + + +@MODELS.register_module() +class MixMIM(BaseSelfSupervisor): + """MixMIM. + + Implementation of `MixMIM: Mixed and Masked Image Modeling for Efficient + Visual Representation Learning. `_. + """ + + def __init__(self, + backbone: dict, + neck: Optional[dict] = None, + head: Optional[dict] = None, + pretrained: Optional[str] = None, + data_preprocessor: Optional[Union[dict, nn.Module]] = None, + init_cfg: Optional[dict] = None): + + head.update(dict(patch_size=neck['encoder_stride'])) + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. 
+ """ + latent, mask = self.backbone(inputs) + x_rec = self.neck(latent, mask) + loss = self.head.loss(x_rec, inputs, mask) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/moco.py b/mmpretrain/models/selfsup/moco.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff4cf8fd6d0d6bca4724965d3b6d09543317748 --- /dev/null +++ b/mmpretrain/models/selfsup/moco.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.dist import all_gather +from mmengine.model import ExponentialMovingAverage + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import batch_shuffle_ddp, batch_unshuffle_ddp +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class MoCo(BaseSelfSupervisor): + """MoCo. + + Implementation of `Momentum Contrast for Unsupervised Visual + Representation Learning `_. + Part of the code is borrowed from: + ``_. + + Args: + backbone (dict): Config dict for module of backbone. + neck (dict): Config dict for module of deep features to compact feature + vectors. + head (dict): Config dict for module of head functions. + queue_len (int): Number of negative keys maintained in the + queue. Defaults to 65536. + feat_dim (int): Dimension of compact feature vectors. + Defaults to 128. + momentum (float): Momentum coefficient for the momentum-updated + encoder. Defaults to 0.001. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing + input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. + init_cfg (Union[List[dict], dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: dict, + head: dict, + queue_len: int = 65536, + feat_dim: int = 128, + momentum: float = 0.001, + pretrained: Optional[str] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + # create momentum model + self.encoder_k = ExponentialMovingAverage( + nn.Sequential(self.backbone, self.neck), momentum) + + # create the queue + self.queue_len = queue_len + self.register_buffer('queue', torch.randn(feat_dim, queue_len)) + self.queue = nn.functional.normalize(self.queue, dim=0) + self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long)) + + @torch.no_grad() + def _dequeue_and_enqueue(self, keys: torch.Tensor) -> None: + """Update queue.""" + # gather keys before updating queue + keys = torch.cat(all_gather(keys), dim=0) + + batch_size = keys.shape[0] + + ptr = int(self.queue_ptr) + assert self.queue_len % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1) + ptr = (ptr + batch_size) % self.queue_len # move pointer + + self.queue_ptr[0] = ptr + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. 
+ data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + im_q = inputs[0] + im_k = inputs[1] + # compute query features from encoder_q + q = self.neck(self.backbone(im_q))[0] # queries: NxC + q = nn.functional.normalize(q, dim=1) + + # compute key features + with torch.no_grad(): # no gradient to keys + # update the key encoder + self.encoder_k.update_parameters( + nn.Sequential(self.backbone, self.neck)) + + # shuffle for making use of BN + im_k, idx_unshuffle = batch_shuffle_ddp(im_k) + + k = self.encoder_k(im_k)[0] # keys: NxC + k = nn.functional.normalize(k, dim=1) + + # undo shuffle + k = batch_unshuffle_ddp(k, idx_unshuffle) + + # compute logits + # Einstein sum is more intuitive + # positive logits: Nx1 + l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) + # negative logits: NxK + l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) + + loss = self.head.loss(l_pos, l_neg) + # update the queue + self._dequeue_and_enqueue(k) + + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/mocov3.py b/mmpretrain/models/selfsup/mocov3.py new file mode 100644 index 0000000000000000000000000000000000000000..61b803387fdc129bc29056ee369fa3ad36c13e07 --- /dev/null +++ b/mmpretrain/models/selfsup/mocov3.py @@ -0,0 +1,215 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from functools import reduce +from operator import mul +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.models.backbones import VisionTransformer +from mmpretrain.models.utils import (build_2d_sincos_position_embedding, + to_2tuple) +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import CosineEMA +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class MoCoV3ViT(VisionTransformer): + """Vision Transformer for MoCoV3 pre-training. + + A pytorch implement of: `An Images is Worth 16x16 Words: Transformers for + Image Recognition at Scale `_. + + Part of the code is modified from: + ``_. + + Args: + stop_grad_conv1 (bool): whether to stop the gradient of + convolution layer in `PatchEmbed`. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. 
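A note on `_dequeue_and_enqueue` in the MoCo class above: it treats the queue as a ring buffer over the column dimension. A single-process sketch of one update step follows (the real method first `all_gather`s the keys across ranks, and the batch size must divide the queue length so the pointer always lands on a slice boundary):

```python
import torch

feat_dim, queue_len, batch = 8, 32, 4
queue = torch.nn.functional.normalize(torch.randn(feat_dim, queue_len), dim=0)
queue_ptr = 0

keys = torch.nn.functional.normalize(torch.randn(batch, feat_dim), dim=1)
assert queue_len % batch == 0          # for simplicity, as in the code above
queue[:, queue_ptr:queue_ptr + batch] = keys.T   # overwrite the oldest slice
queue_ptr = (queue_ptr + batch) % queue_len      # advance, wrapping around
print(queue_ptr)  # 4
```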
+ """ + + def __init__(self, + stop_grad_conv1: bool = False, + frozen_stages: int = -1, + norm_eval: bool = False, + init_cfg: Optional[Union[dict, List[dict]]] = None, + **kwargs) -> None: + + # add MoCoV3 ViT-small arch + self.arch_zoo.update( + dict.fromkeys( + ['mocov3-s', 'mocov3-small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 1536, + })) + + super().__init__(init_cfg=init_cfg, **kwargs) + self.patch_size = kwargs['patch_size'] + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + self.init_cfg = init_cfg + + if stop_grad_conv1: + self.patch_embed.projection.weight.requires_grad = False + self.patch_embed.projection.bias.requires_grad = False + + self._freeze_stages() + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding, qkv layers and cls + token.""" + super().init_weights() + + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + + # Use fixed 2D sin-cos position embedding + pos_emb = build_2d_sincos_position_embedding( + patches_resolution=self.patch_resolution, + embed_dims=self.embed_dims, + cls_token=True) + self.pos_embed.data.copy_(pos_emb) + self.pos_embed.requires_grad = False + + # xavier_uniform initialization for PatchEmbed + val = math.sqrt( + 6. / float(3 * reduce(mul, to_2tuple(self.patch_size), 1) + + self.embed_dims)) + nn.init.uniform_(self.patch_embed.projection.weight, -val, val) + nn.init.zeros_(self.patch_embed.projection.bias) + + # initialization for linear layers + for name, m in self.named_modules(): + if isinstance(m, nn.Linear): + if 'qkv' in name: + # treat the weights of Q, K, V separately + val = math.sqrt( + 6. / + float(m.weight.shape[0] // 3 + m.weight.shape[1])) + nn.init.uniform_(m.weight, -val, val) + else: + nn.init.xavier_uniform_(m.weight) + nn.init.zeros_(m.bias) + nn.init.normal_(self.cls_token, std=1e-6) + + def _freeze_stages(self) -> None: + """Freeze patch_embed layer, some parameters and stages.""" + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + self.cls_token.requires_grad = False + self.pos_embed.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + if i == (self.num_layers) and self.final_norm: + for param in getattr(self, 'norm1').parameters(): + param.requires_grad = False + + def train(self, mode: bool = True) -> None: + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + +@MODELS.register_module() +class MoCoV3(BaseSelfSupervisor): + """MoCo v3. + + Implementation of `An Empirical Study of Training Self-Supervised Vision + Transformers `_. + + Args: + backbone (dict): Config dict for module of backbone + neck (dict): Config dict for module of deep features to compact feature + vectors. + head (dict): Config dict for module of head functions. + base_momentum (float): Momentum coefficient for the momentum-updated + encoder. Defaults to 0.01. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing + input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. 
+            See :class:`SelfSupDataPreprocessor` for more details.
+            Defaults to None.
+        init_cfg (Union[List[dict], dict], optional): Config dict for weight
+            initialization. Defaults to None.
+    """
+
+    def __init__(self,
+                 backbone: dict,
+                 neck: dict,
+                 head: dict,
+                 base_momentum: float = 0.01,
+                 pretrained: Optional[str] = None,
+                 data_preprocessor: Optional[dict] = None,
+                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
+        super().__init__(
+            backbone=backbone,
+            neck=neck,
+            head=head,
+            pretrained=pretrained,
+            data_preprocessor=data_preprocessor,
+            init_cfg=init_cfg)
+
+        # create momentum model
+        self.momentum_encoder = CosineEMA(
+            nn.Sequential(self.backbone, self.neck), momentum=base_momentum)
+
+    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
+             **kwargs) -> Dict[str, torch.Tensor]:
+        """The forward function in training.
+
+        Args:
+            inputs (List[torch.Tensor]): The input images.
+            data_samples (List[DataSample]): All elements required
+                during the forward function.
+
+        Returns:
+            Dict[str, torch.Tensor]: A dictionary of loss components.
+        """
+        assert isinstance(inputs, list)
+        view_1 = inputs[0]
+        view_2 = inputs[1]
+
+        # compute query features, [N, C] each
+        q1 = self.neck(self.backbone(view_1))[0]
+        q2 = self.neck(self.backbone(view_2))[0]
+
+        # compute key features, [N, C] each, no gradient
+        with torch.no_grad():
+            # update momentum encoder
+            self.momentum_encoder.update_parameters(
+                nn.Sequential(self.backbone, self.neck))
+
+            k1 = self.momentum_encoder(view_1)[0]
+            k2 = self.momentum_encoder(view_2)[0]
+
+        loss = self.head.loss(q1, k2) + self.head.loss(q2, k1)
+
+        losses = dict(loss=loss)
+        return losses
diff --git a/mmpretrain/models/selfsup/simclr.py b/mmpretrain/models/selfsup/simclr.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b19ab4053de21a865fbaf864f654ff3ad8840f1
--- /dev/null
+++ b/mmpretrain/models/selfsup/simclr.py
@@ -0,0 +1,98 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Dict, List, Tuple
+
+import torch
+from mmengine.dist import all_gather, get_rank
+
+from mmpretrain.registry import MODELS
+from mmpretrain.structures import DataSample
+from .base import BaseSelfSupervisor
+
+
+class GatherLayer(torch.autograd.Function):
+    """Gather tensors from all processes, supporting backward propagation."""
+
+    @staticmethod
+    def forward(ctx: Any, input: torch.Tensor) -> Tuple[List]:
+        ctx.save_for_backward(input)
+        output = all_gather(input)
+        return tuple(output)
+
+    @staticmethod
+    def backward(ctx: Any, *grads: torch.Tensor) -> torch.Tensor:
+        input, = ctx.saved_tensors
+        grad_out = torch.zeros_like(input)
+        grad_out[:] = grads[get_rank()]
+        return grad_out
+
+
+@MODELS.register_module()
+class SimCLR(BaseSelfSupervisor):
+    """SimCLR.
+
+    Implementation of `A Simple Framework for Contrastive Learning of Visual
+    Representations <https://arxiv.org/abs/2002.05709>`_.
+    """
+
+    @staticmethod
+    def _create_buffer(
+        batch_size: int, device: torch.device
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Compute the mask and the index of positive samples.
+
+        Args:
+            batch_size (int): The batch size.
+            device (torch.device): The device of backend.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+            - The mask for feature selection.
+            - The index of positive samples.
+            - The mask of negative samples.
+ """ + mask = 1 - torch.eye(batch_size * 2, dtype=torch.uint8).to(device) + pos_idx = ( + torch.arange(batch_size * 2).to(device), + 2 * torch.arange(batch_size, dtype=torch.long).unsqueeze(1).repeat( + 1, 2).view(-1, 1).squeeze().to(device)) + neg_mask = torch.ones((batch_size * 2, batch_size * 2 - 1), + dtype=torch.uint8).to(device) + neg_mask[pos_idx] = 0 + return mask, pos_idx, neg_mask + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + inputs = torch.stack(inputs, 1) + inputs = inputs.reshape((inputs.size(0) * 2, inputs.size(2), + inputs.size(3), inputs.size(4))) + x = self.backbone(inputs) + z = self.neck(x)[0] # (2n)xd + + z = z / (torch.norm(z, p=2, dim=1, keepdim=True) + 1e-10) + z = torch.cat(GatherLayer.apply(z), dim=0) # (2N)xd + assert z.size(0) % 2 == 0 + N = z.size(0) // 2 + s = torch.matmul(z, z.permute(1, 0)) # (2N)x(2N) + mask, pos_idx, neg_mask = self._create_buffer(N, s.device) + + # remove diagonal, (2N)x(2N-1) + s = torch.masked_select(s, mask == 1).reshape(s.size(0), -1) + positive = s[pos_idx].unsqueeze(1) # (2N)x1 + + # select negative, (2N)x(2N-2) + negative = torch.masked_select(s, neg_mask == 1).reshape(s.size(0), -1) + + loss = self.head.loss(positive, negative) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/simmim.py b/mmpretrain/models/selfsup/simmim.py new file mode 100644 index 0000000000000000000000000000000000000000..635a3297df2c3f361b8a63f1ea7c5d1f9c34c28b --- /dev/null +++ b/mmpretrain/models/selfsup/simmim.py @@ -0,0 +1,194 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.models import SwinTransformer +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class SimMIMSwinTransformer(SwinTransformer): + """Swin Transformer for SimMIM pre-training. + + Args: + Args: + arch (str | dict): Swin Transformer architecture + Defaults to 'T'. + img_size (int | tuple): The size of input image. + Defaults to 224. + in_channels (int): The num of input channels. + Defaults to 3. + drop_rate (float): Dropout rate after embedding. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. + Defaults to 0.1. + out_indices (tuple): Layers to be outputted. Defaults to (3, ). + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + norm_cfg (dict): Config dict for normalization layer at end + of backbone. 
+        stage_cfgs (Sequence | dict): Extra config dict for each
+            stage. Defaults to empty dict.
+        patch_cfg (dict): Extra config dict for patch embedding.
+            Defaults to empty dict.
+        pad_small_map (bool): If True, pad the small feature map to the window
+            size, which is commonly used in detection and segmentation. If
+            False, avoid shifting window and shrink the window size to the
+            size of feature map, which is commonly used in classification.
+            Defaults to False.
+        init_cfg (dict, optional): The Config for initialization.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 arch: Union[str, dict] = 'T',
+                 img_size: Union[Tuple[int, int], int] = 224,
+                 in_channels: int = 3,
+                 drop_rate: float = 0.,
+                 drop_path_rate: float = 0.1,
+                 out_indices: tuple = (3, ),
+                 use_abs_pos_embed: bool = False,
+                 with_cp: bool = False,
+                 frozen_stages: int = -1,
+                 norm_eval: bool = False,
+                 norm_cfg: dict = dict(type='LN'),
+                 stage_cfgs: Union[Sequence, dict] = dict(),
+                 patch_cfg: dict = dict(),
+                 pad_small_map: bool = False,
+                 init_cfg: Optional[dict] = None) -> None:
+        super().__init__(
+            arch=arch,
+            img_size=img_size,
+            in_channels=in_channels,
+            drop_rate=drop_rate,
+            drop_path_rate=drop_path_rate,
+            out_indices=out_indices,
+            use_abs_pos_embed=use_abs_pos_embed,
+            with_cp=with_cp,
+            frozen_stages=frozen_stages,
+            norm_eval=norm_eval,
+            norm_cfg=norm_cfg,
+            stage_cfgs=stage_cfgs,
+            patch_cfg=patch_cfg,
+            pad_small_map=pad_small_map,
+            init_cfg=init_cfg)
+
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+
+    def init_weights(self) -> None:
+        """Initialize weights."""
+        super().init_weights()
+
+        if (isinstance(self.init_cfg, dict)
+                and self.init_cfg['type'] == 'Pretrained'):
+            # Suppress default init if use pretrained model.
+            return
+
+        if self.use_abs_pos_embed:
+            trunc_normal_(self.absolute_pos_embed, std=0.02)
+
+        trunc_normal_(self.mask_token, mean=0, std=.02)
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        """Initialize weights."""
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def forward(self, x: torch.Tensor,
+                mask: Optional[torch.Tensor]) -> Sequence[torch.Tensor]:
+        """Generate features for masked images.
+
+        The function supports two kinds of forward behavior. If the ``mask``
+        is not ``None``, the forward function will be executed as masked image
+        modeling pre-training; if the ``mask`` is ``None``, the forward
+        function will call ``super().forward()``, which extracts features from
+        images without mask.
+
+        Args:
+            x (torch.Tensor): Input images.
+            mask (torch.Tensor, optional): Masks for images.
+
+        Returns:
+            tuple: A tuple containing features from multi-stages.
+        """
+        if mask is None:
+            return super().forward(x)
+
+        else:
+            x, hw_shape = self.patch_embed(x)
+            B, L, _ = x.shape
+
+            mask_token = self.mask_token.expand(B, L, -1)
+            w = mask.flatten(1).unsqueeze(-1).type_as(mask_token)
+            x = x * (1. - w) + mask_token * w
+
+            if self.use_abs_pos_embed:
+                x = x + self.absolute_pos_embed
+
+            x = self.drop_after_pos(x)
+
+            outs = []
+            for i, stage in enumerate(self.stages):
+                x, hw_shape = stage(x, hw_shape)
+                if i in self.out_indices:
+                    norm_layer = getattr(self, f'norm{i}')
+                    out = norm_layer(x)
+                    out = out.view(-1, *hw_shape,
+                                   stage.out_channels).permute(0, 3, 1,
+                                                               2).contiguous()
+                    outs.append(out)
+
+            return tuple(outs)
+
+
+@MODELS.register_module()
+class SimMIM(BaseSelfSupervisor):
+    """SimMIM.
+
+    Implementation of `SimMIM: A Simple Framework for Masked Image Modeling
+    <https://arxiv.org/abs/2111.09886>`_.
+    """
+
+    def extract_feat(self, inputs: torch.Tensor):
+        return self.backbone(inputs, mask=None)
+
+    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
+             **kwargs) -> Dict[str, torch.Tensor]:
+        """The forward function in training.
+
+        Args:
+            inputs (torch.Tensor): The input images.
+            data_samples (List[DataSample]): All elements required
+                during the forward function.
+
+        Returns:
+            Dict[str, torch.Tensor]: A dictionary of loss components.
+        """
+        mask = torch.stack([data_sample.mask for data_sample in data_samples])
+
+        img_latent = self.backbone(inputs, mask)
+        img_rec = self.neck(img_latent[0])
+        loss = self.head.loss(img_rec, inputs, mask)
+        losses = dict(loss=loss)
+
+        return losses
diff --git a/mmpretrain/models/selfsup/simsiam.py b/mmpretrain/models/selfsup/simsiam.py
new file mode 100644
index 0000000000000000000000000000000000000000..a502cd770d0b497368dc7fc1d93caac01ec65db1
--- /dev/null
+++ b/mmpretrain/models/selfsup/simsiam.py
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List
+
+import torch
+
+from mmpretrain.registry import MODELS
+from mmpretrain.structures import DataSample
+from .base import BaseSelfSupervisor
+
+
+@MODELS.register_module()
+class SimSiam(BaseSelfSupervisor):
+    """SimSiam.
+
+    Implementation of `Exploring Simple Siamese Representation Learning
+    <https://arxiv.org/abs/2011.10566>`_. The operation of fixing the learning
+    rate of the predictor is in `engine/hooks/simsiam_hook.py`.
+    """
+
+    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
+             **kwargs) -> Dict[str, torch.Tensor]:
+        """The forward function in training.
+
+        Args:
+            inputs (List[torch.Tensor]): The input images.
+            data_samples (List[DataSample]): All elements required
+                during the forward function.
+
+        Returns:
+            Dict[str, torch.Tensor]: A dictionary of loss components.
+        """
+        assert isinstance(inputs, list)
+        img_v1 = inputs[0]
+        img_v2 = inputs[1]
+
+        z1 = self.neck(self.backbone(img_v1))[0]  # NxC
+        z2 = self.neck(self.backbone(img_v2))[0]  # NxC
+
+        loss_1 = self.head.loss(z1, z2)
+        loss_2 = self.head.loss(z2, z1)
+
+        losses = dict(loss=0.5 * (loss_1 + loss_2))
+        return losses
diff --git a/mmpretrain/models/selfsup/spark.py b/mmpretrain/models/selfsup/spark.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5570a5a9b17212aa400c3c6518a8e75a5c8c6c2
--- /dev/null
+++ b/mmpretrain/models/selfsup/spark.py
@@ -0,0 +1,163 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+from mmengine.model.weight_init import trunc_normal_
+
+from mmpretrain.registry import MODELS
+from mmpretrain.structures import DataSample
+from ..utils.norm import build_norm_layer
+from ..utils.sparse_modules import SparseHelper
+from .base import BaseSelfSupervisor
+
+
+@MODELS.register_module()
+class SparK(BaseSelfSupervisor):
+    """Implementation of SparK.
+
+    Implementation of `Designing BERT for Convolutional Networks: Sparse and
+    Hierarchical Masked Modeling <https://arxiv.org/abs/2301.03580>`_.
+
+    Modified from
+    https://github.com/keyu-tian/SparK/blob/main/pretrain/spark.py
+    """
+
+    def __init__(
+        self,
+        backbone: dict,
+        neck: dict,
+        head: dict,
+        pretrained: Optional[str] = None,
+        data_preprocessor: Optional[dict] = None,
+        input_size: int = 224,
+        downsample_ratio: int = 32,
+        mask_ratio: float = 0.6,
+        enc_dec_norm_cfg=dict(type='SparseSyncBatchNorm2d'),
+        enc_dec_norm_dim: int = 2048,
+        init_cfg: Optional[dict] = None,
+    ) -> None:
+        super().__init__(
+            backbone=backbone,
+            neck=neck,
+            head=head,
+            pretrained=pretrained,
+            data_preprocessor=data_preprocessor,
+            init_cfg=init_cfg)
+        self.input_size = input_size
+        self.downsample_ratio = downsample_ratio
+        feature_map_size = input_size // downsample_ratio
+        self.feature_map_size = feature_map_size
+
+        self.mask_ratio = mask_ratio
+        self.len_keep = round(feature_map_size * feature_map_size *
+                              (1 - mask_ratio))
+
+        self.enc_dec_norm_cfg = enc_dec_norm_cfg
+        self.enc_dec_norms = nn.ModuleList()
+        self.enc_dec_projectors = nn.ModuleList()
+        self.mask_tokens = nn.ParameterList()
+
+        proj_out_dim = self.neck.feature_dim
+        for i in range(len(self.backbone.out_indices)):
+            enc_dec_norm = build_norm_layer(self.enc_dec_norm_cfg,
+                                            enc_dec_norm_dim)
+            self.enc_dec_norms.append(enc_dec_norm)
+
+            kernel_size = 1 if i <= 0 else 3
+            proj_layer = nn.Conv2d(
+                enc_dec_norm_dim,
+                proj_out_dim,
+                kernel_size=kernel_size,
+                stride=1,
+                padding=kernel_size // 2,
+                bias=True)
+            if i == 0 and enc_dec_norm_dim == proj_out_dim:
+                proj_layer = nn.Identity()
+            self.enc_dec_projectors.append(proj_layer)
+
+            mask_token = nn.Parameter(torch.zeros(1, enc_dec_norm_dim, 1, 1))
+            trunc_normal_(mask_token, mean=0, std=.02, a=-.02, b=.02)
+            self.mask_tokens.append(mask_token)
+
+            enc_dec_norm_dim //= 2
+            proj_out_dim //= 2
+            feature_map_size *= 2
+
+    def mask(self,
+             shape: torch.Size,
+             device: Union[torch.device, str],
+             generator: Optional[torch.Generator] = None):
+        """Mask generation.
+
+        Args:
+            shape (torch.Size): The shape of the input images.
+            device (Union[torch.device, str]): The device of the tensor.
+            generator (torch.Generator, optional): Generator for random
+                functions. Defaults to None.
+
+        Returns:
+            torch.Tensor: The generated mask.
+        """
+        B, C, H, W = shape
+        f = self.feature_map_size
+        idx = torch.rand(B, f * f, generator=generator).argsort(dim=1)
+        idx = idx[:, :self.len_keep].to(device)  # (B, len_keep)
+        return torch.zeros(
+            B, f * f, dtype=torch.bool, device=device).scatter_(
+                dim=1, index=idx, value=True).view(B, 1, f, f)
+
+    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
+             **kwargs) -> Dict[str, torch.Tensor]:
+        """The forward function in training.
+
+        Args:
+            inputs (torch.Tensor): The input images.
+            data_samples (List[DataSample]): All elements required
+                during the forward function.
+
+        Returns:
+            Dict[str, torch.Tensor]: A dictionary of loss components.
+ """ + + # active mask of feature map, (B, 1, f, f) + active_mask_feature_map = self.mask(inputs.shape, inputs.device) + SparseHelper._cur_active = active_mask_feature_map + + # active mask of original input, (B, 1, H, W) + active_mask_origin = active_mask_feature_map.repeat_interleave( + self.downsample_raito, + 2).repeat_interleave(self.downsample_raito, 3) + masked_img = inputs * active_mask_origin + + # get hierarchical encoded sparse features in a list + # containing four feature maps + feature_maps = self.backbone(masked_img) + + # from the smallest feature map to the largest + feature_maps = list(feature_maps) + feature_maps.reverse() + + cur_active = active_mask_feature_map + feature_maps_to_dec = [] + for i, feature_map in enumerate(feature_maps): + if feature_map is not None: + # fill in empty positions with [mask] embeddings + feature_map = self.enc_dec_norms[i](feature_map) + mask_token = self.mask_tokens[i].expand_as(feature_map) + feature_map = torch.where( + cur_active.expand_as(feature_map), feature_map, + mask_token.to(feature_map.dtype)) + feature_map = self.enc_dec_projectors[i](feature_map) + feature_maps_to_dec.append(feature_map) + + # dilate the mask map + cur_active = cur_active.repeat_interleave( + 2, dim=2).repeat_interleave( + 2, dim=3) + + # decode and reconstruct + rec_img = self.neck(feature_maps_to_dec) + + # compute loss + loss = self.head(rec_img, inputs, active_mask_feature_map) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/swav.py b/mmpretrain/models/selfsup/swav.py new file mode 100644 index 0000000000000000000000000000000000000000..efe0eab483319bd2dfde8929a2285e684cd3fc38 --- /dev/null +++ b/mmpretrain/models/selfsup/swav.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List + +import torch + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class SwAV(BaseSelfSupervisor): + """SwAV. + + Implementation of `Unsupervised Learning of Visual Features by Contrasting + Cluster Assignments `_. + + The queue is built in ``mmpretrain/engine/hooks/swav_hook.py``. + """ + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """Forward computation during training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + # multi-res forward passes + idx_crops = torch.cumsum( + torch.unique_consecutive( + torch.tensor([input.shape[-1] for input in inputs]), + return_counts=True)[1], 0) + start_idx = 0 + output = [] + for end_idx in idx_crops: + _out = self.backbone(torch.cat(inputs[start_idx:end_idx])) + output.append(_out) + start_idx = end_idx + output = self.neck(output) + + loss = self.head.loss(output) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/tta/__init__.py b/mmpretrain/models/tta/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..568e64ffdc743b4694045f39a46deb5083b2688a --- /dev/null +++ b/mmpretrain/models/tta/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .score_tta import AverageClsScoreTTA + +__all__ = ['AverageClsScoreTTA'] diff --git a/mmpretrain/models/tta/__pycache__/__init__.cpython-311.pyc b/mmpretrain/models/tta/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..751f0365fcfdbbf898ccd40c77c5636b8dc5eb54 Binary files /dev/null and b/mmpretrain/models/tta/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmpretrain/models/tta/__pycache__/score_tta.cpython-311.pyc b/mmpretrain/models/tta/__pycache__/score_tta.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb0cdfbbdeb48db6aac5cf7dc0604f9213b7ac61 Binary files /dev/null and b/mmpretrain/models/tta/__pycache__/score_tta.cpython-311.pyc differ diff --git a/mmpretrain/models/tta/score_tta.py b/mmpretrain/models/tta/score_tta.py new file mode 100644 index 0000000000000000000000000000000000000000..5b8a0786577c6cdb5076957df0ed60aac9d307cb --- /dev/null +++ b/mmpretrain/models/tta/score_tta.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine.model import BaseTTAModel + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class AverageClsScoreTTA(BaseTTAModel): + + def merge_preds( + self, + data_samples_list: List[List[DataSample]], + ) -> List[DataSample]: + """Merge predictions of enhanced data to one prediction. + + Args: + data_samples_list (List[List[DataSample]]): List of predictions + of all enhanced data. + + Returns: + List[DataSample]: Merged prediction. + """ + merged_data_samples = [] + for data_samples in data_samples_list: + merged_data_samples.append(self._merge_single_sample(data_samples)) + return merged_data_samples + + def _merge_single_sample(self, data_samples): + merged_data_sample: DataSample = data_samples[0].new() + merged_score = sum(data_sample.pred_score + for data_sample in data_samples) / len(data_samples) + merged_data_sample.set_pred_score(merged_score) + return merged_data_sample diff --git a/mmpretrain/models/utils/__init__.py b/mmpretrain/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e59d71d524308cbda3f4f693d1fb066b4a5981fa --- /dev/null +++ b/mmpretrain/models/utils/__init__.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
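# ---------------------------------------------------------------------------
# [Editor's sketch] Among the utilities re-exported by this package,
# ``build_2d_sincos_position_embedding`` is what ``MoCoV3ViT.init_weights``
# earlier in this diff uses for its fixed position embedding. Assuming the
# call signature shown in that file (the shapes here are illustrative):
#
#     from mmpretrain.models.utils import build_2d_sincos_position_embedding
#     pos = build_2d_sincos_position_embedding(
#         patches_resolution=(14, 14), embed_dims=384, cls_token=True)
#     # a fixed (non-learned) table with one row per patch plus the cls token
# ---------------------------------------------------------------------------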
+from mmpretrain.utils.dependency import WITH_MULTIMODAL +from .attention import (BEiTAttention, ChannelMultiheadAttention, + CrossMultiheadAttention, LeAttention, + MultiheadAttention, PromptMultiheadAttention, + ShiftWindowMSA, WindowMSA, WindowMSAV2) +from .batch_augments import CutMix, Mixup, RandomBatchAugment, ResizeMix +from .batch_shuffle import batch_shuffle_ddp, batch_unshuffle_ddp +from .channel_shuffle import channel_shuffle +from .clip_generator_helper import QuickGELU, build_clip_model +from .data_preprocessor import (ClsDataPreprocessor, + MultiModalDataPreprocessor, + SelfSupDataPreprocessor, + TwoNormDataPreprocessor, VideoDataPreprocessor) +from .ema import CosineEMA +from .embed import (HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed, + resize_relative_position_bias_table) +from .helpers import is_tracing, to_2tuple, to_3tuple, to_4tuple, to_ntuple +from .inverted_residual import InvertedResidual +from .layer_scale import LayerScale +from .make_divisible import make_divisible +from .norm import GRN, LayerNorm2d, build_norm_layer +from .position_encoding import (ConditionalPositionEncoding, + PositionEncodingFourier, RotaryEmbeddingFast, + build_2d_sincos_position_embedding) +from .res_layer_extra_norm import ResLayerExtraNorm +from .se_layer import SELayer +from .sparse_modules import (SparseAvgPooling, SparseBatchNorm2d, SparseConv2d, + SparseHelper, SparseLayerNorm2D, SparseMaxPooling, + SparseSyncBatchNorm2d) +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .vector_quantizer import NormEMAVectorQuantizer + +__all__ = [ + 'channel_shuffle', + 'make_divisible', + 'InvertedResidual', + 'SELayer', + 'to_ntuple', + 'to_2tuple', + 'to_3tuple', + 'to_4tuple', + 'PatchEmbed', + 'PatchMerging', + 'HybridEmbed', + 'RandomBatchAugment', + 'ShiftWindowMSA', + 'is_tracing', + 'MultiheadAttention', + 'ConditionalPositionEncoding', + 'resize_pos_embed', + 'resize_relative_position_bias_table', + 'ClsDataPreprocessor', + 'Mixup', + 'CutMix', + 'ResizeMix', + 'BEiTAttention', + 'LayerScale', + 'WindowMSA', + 'WindowMSAV2', + 'ChannelMultiheadAttention', + 'PositionEncodingFourier', + 'LeAttention', + 'GRN', + 'LayerNorm2d', + 'build_norm_layer', + 'CrossMultiheadAttention', + 'build_2d_sincos_position_embedding', + 'PromptMultiheadAttention', + 'NormEMAVectorQuantizer', + 'build_clip_model', + 'batch_shuffle_ddp', + 'batch_unshuffle_ddp', + 'SelfSupDataPreprocessor', + 'TwoNormDataPreprocessor', + 'VideoDataPreprocessor', + 'CosineEMA', + 'ResLayerExtraNorm', + 'MultiModalDataPreprocessor', + 'QuickGELU', + 'SwiGLUFFN', + 'SwiGLUFFNFused', + 'RotaryEmbeddingFast', + 'SparseAvgPooling', + 'SparseConv2d', + 'SparseHelper', + 'SparseMaxPooling', + 'SparseBatchNorm2d', + 'SparseLayerNorm2D', + 'SparseSyncBatchNorm2d', +] + +if WITH_MULTIMODAL: + from .huggingface import (no_load_hf_pretrained_model, register_hf_model, + register_hf_tokenizer) + from .tokenizer import (Blip2Tokenizer, BlipTokenizer, FullTokenizer, + OFATokenizer) + + __all__.extend([ + 'BlipTokenizer', 'OFATokenizer', 'Blip2Tokenizer', 'register_hf_model', + 'register_hf_tokenizer', 'no_load_hf_pretrained_model', 'FullTokenizer' + ]) diff --git a/mmpretrain/models/utils/__pycache__/__init__.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..964a43ae095b413aa4b5f18d475feb4ca23329e2 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/__init__.cpython-311.pyc differ diff --git 
a/mmpretrain/models/utils/__pycache__/attention.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e53588193ac7a3773127d2f7c2cbbf60af96ff7 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/attention.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/batch_shuffle.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/batch_shuffle.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c102f860a7e1ed30c552a12f88d4c6fe98e4ce4a Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/batch_shuffle.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/box_utils.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/box_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98c02de76ac867ba709329130efc8bbdf417f10f Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/box_utils.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/channel_shuffle.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/channel_shuffle.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36d77c3f5e5e64c25475f600f21256eda8466799 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/channel_shuffle.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/clip_generator_helper.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/clip_generator_helper.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf947a89de812579cbff1c7c2013aa33587295d9 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/clip_generator_helper.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/data_preprocessor.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/data_preprocessor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..331e2a66075b7698b201cf1caba709541da180d4 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/data_preprocessor.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/ema.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/ema.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9fe1effceaeb8fb276001bcf09717c001887050 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/ema.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/embed.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/embed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15af2123f95d6b596eb647ab73ffb7434b693d89 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/embed.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/helpers.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/helpers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29fe93ff01e501ee1713b7a08997f1e5962f4a47 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/helpers.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/inverted_residual.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/inverted_residual.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76ca7712447ce7771b7f619c836b6779e9ae33bb Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/inverted_residual.cpython-311.pyc differ diff --git 
a/mmpretrain/models/utils/__pycache__/layer_scale.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/layer_scale.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06b5b5a89462dfe2fdc1ab2ccdd8aa56d241411d Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/layer_scale.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/make_divisible.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/make_divisible.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..303dc7a76a1704bae9bdf633d8a0c58336c7e1a6 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/make_divisible.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/norm.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/norm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fba272a93ba326dbf6c2867883799da2d49dbe3a Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/norm.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/position_encoding.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/position_encoding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95208c21f5f494459f34b3d9b99d3aaf8d588e05 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/position_encoding.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/res_layer_extra_norm.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/res_layer_extra_norm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aaa615dc1dfbdefb902ea897599fbe23e1a54304 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/res_layer_extra_norm.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/se_layer.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/se_layer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f228eae55f4966718d13a629b1b5dc3c02bde873 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/se_layer.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/sparse_modules.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/sparse_modules.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f8f1e74512837fc6147825efea2ae2ca125c0e4 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/sparse_modules.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/swiglu_ffn.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/swiglu_ffn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..583d308c4bd2dc7e6a20a6889af3dd042165f66f Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/swiglu_ffn.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/__pycache__/vector_quantizer.cpython-311.pyc b/mmpretrain/models/utils/__pycache__/vector_quantizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5f6c5fe5249abb49a45ff6b49c9541de56d7599 Binary files /dev/null and b/mmpretrain/models/utils/__pycache__/vector_quantizer.cpython-311.pyc differ diff --git a/mmpretrain/models/utils/attention.py b/mmpretrain/models/utils/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..e92f6054dd83881b508ac5e87d9034cd86b3a36c --- /dev/null +++ b/mmpretrain/models/utils/attention.py @@ -0,0 +1,1129 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
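# ---------------------------------------------------------------------------
# [Editor's sketch] The attention.py file below selects between
# ``F.scaled_dot_product_attention`` (torch >= 2.0) and a pure-Python
# fallback. The computation both perform, written out with hypothetical
# shapes and dropout/masking disabled:
#
#     import torch
#     q = torch.randn(2, 8, 16, 64)   # (batch, heads, tokens, head_dim)
#     k = torch.randn(2, 8, 16, 64)
#     v = torch.randn(2, 8, 16, 64)
#     attn = torch.softmax(q @ k.transpose(-2, -1) / q.size(-1)**0.5, dim=-1)
#     out = attn @ v                  # == F.scaled_dot_product_attention(q, k, v)
# ---------------------------------------------------------------------------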
+import itertools
+import warnings
+from functools import partial
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn.bricks.drop import build_dropout
+from mmengine.model import BaseModule
+from mmengine.model.weight_init import trunc_normal_
+from mmengine.utils import digit_version
+
+from mmpretrain.registry import MODELS
+from .helpers import to_2tuple
+from .layer_scale import LayerScale
+
+# After PyTorch v1.10.0, using torch.meshgrid without the ``indexing``
+# argument raises an extra warning. For more details, refer to
+# https://github.com/pytorch/pytorch/issues/50276
+if digit_version(torch.__version__) >= digit_version('1.10.0'):
+    torch_meshgrid = partial(torch.meshgrid, indexing='ij')
+else:
+    torch_meshgrid = torch.meshgrid
+
+
+def scaled_dot_product_attention_pyimpl(query,
+                                        key,
+                                        value,
+                                        attn_mask=None,
+                                        dropout_p=0.,
+                                        scale=None,
+                                        is_causal=False):
+    # ``scale`` is the divisor applied to the attention logits; it defaults
+    # to sqrt(head_dim).
+    scale = scale or query.size(-1)**0.5
+    if is_causal and attn_mask is None:
+        # build a lower-triangular causal mask when none is provided
+        attn_mask = torch.ones(
+            query.size(-2), key.size(-2), dtype=torch.bool).tril(diagonal=0)
+    if attn_mask is not None and attn_mask.dtype == torch.bool:
+        # convert the boolean mask to an additive float mask
+        attn_mask = torch.zeros_like(
+            attn_mask, dtype=query.dtype).masked_fill(~attn_mask,
+                                                      -float('inf'))
+
+    attn_weight = query @ key.transpose(-2, -1) / scale
+    if attn_mask is not None:
+        attn_weight += attn_mask
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    attn_weight = torch.dropout(attn_weight, dropout_p, True)
+    return attn_weight @ value
+
+
+if digit_version(torch.__version__) >= digit_version('2.0.0'):
+    scaled_dot_product_attention = F.scaled_dot_product_attention
+else:
+    scaled_dot_product_attention = scaled_dot_product_attention_pyimpl
+
+
+class WindowMSA(BaseModule):
+    """Window based multi-head self-attention (W-MSA) module with relative
+    position bias.
+
+    Args:
+        embed_dims (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
+            Defaults to True.
+        qk_scale (float, optional): Override default qk scale of
+            ``head_dim ** -0.5`` if set. Defaults to None.
+        attn_drop (float, optional): Dropout ratio of attention weight.
+            Defaults to 0.
+        proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
+        init_cfg (dict, optional): The extra config for initialization.
+            Defaults to None.
+ """ + + def __init__(self, + embed_dims, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0., + init_cfg=None): + + super().__init__(init_cfg) + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = qk_scale or head_embed_dims**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # About 2x faster than original impl + Wh, Ww = self.window_size + rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww) + rel_position_index = rel_index_coords + rel_index_coords.T + rel_position_index = rel_position_index.flip(1).contiguous() + self.register_buffer('relative_position_index', rel_position_index) + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def init_weights(self): + super(WindowMSA, self).init_weights() + + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww, + Wh*Ww), value should be between (-inf, 0]. + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @staticmethod + def double_step_seq(step1, len1, step2, len2): + seq1 = torch.arange(0, step1 * len1, step1) + seq2 = torch.arange(0, step2 * len2, step2) + return (seq1[:, None] + seq2[None, :]).reshape(1, -1) + + +class WindowMSAV2(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Based on implementation on Swin Transformer V2 original repo. Refers to + https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer_v2.py + for more details. + + Args: + embed_dims (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + attn_drop (float): Dropout ratio of attention weight. + Defaults to 0. + proj_drop (float): Dropout ratio of output. Defaults to 0. 
+ cpb_mlp_hidden_dims (int): The hidden dimensions of the continuous + relative position bias network. Defaults to 512. + pretrained_window_size (tuple(int)): The height and width of the window + in pre-training. Defaults to (0, 0), which means not load + pretrained model. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + window_size, + num_heads, + qkv_bias=True, + attn_drop=0., + proj_drop=0., + cpb_mlp_hidden_dims=512, + pretrained_window_size=(0, 0), + init_cfg=None): + + super().__init__(init_cfg) + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + + # Use small network for continuous relative position bias + self.cpb_mlp = nn.Sequential( + nn.Linear( + in_features=2, out_features=cpb_mlp_hidden_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear( + in_features=cpb_mlp_hidden_dims, + out_features=num_heads, + bias=False)) + + # Add learnable scalar for cosine attention + self.logit_scale = nn.Parameter( + torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) + + # get relative_coords_table + relative_coords_h = torch.arange( + -(self.window_size[0] - 1), + self.window_size[0], + dtype=torch.float32) + relative_coords_w = torch.arange( + -(self.window_size[1] - 1), + self.window_size[1], + dtype=torch.float32) + relative_coords_table = torch.stack( + torch_meshgrid([relative_coords_h, relative_coords_w])).permute( + 1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= ( + pretrained_window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= ( + pretrained_window_size[1] - 1) + else: + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) + self.register_buffer('relative_coords_table', relative_coords_table) + + # get pair-wise relative position index + # for each token inside the window + indexes_h = torch.arange(self.window_size[0]) + indexes_w = torch.arange(self.window_size[1]) + coordinates = torch.stack( + torch_meshgrid([indexes_h, indexes_w]), dim=0) # 2, Wh, Ww + coordinates = torch.flatten(coordinates, start_dim=1) # 2, Wh*Ww + # 2, Wh*Ww, Wh*Ww + relative_coordinates = coordinates[:, :, None] - coordinates[:, + None, :] + relative_coordinates = relative_coordinates.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + + relative_coordinates[:, :, 0] += self.window_size[ + 0] - 1 # shift to start from 0 + relative_coordinates[:, :, 1] += self.window_size[1] - 1 + relative_coordinates[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coordinates.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer('relative_position_index', + relative_position_index) + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(embed_dims)) + self.v_bias = nn.Parameter(torch.zeros(embed_dims)) + else: + self.q_bias = None + self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor, 
Optional): mask with shape of (num_windows, Wh*Ww, + Wh*Ww), value should be between (-inf, 0]. + """ + B_, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat( + (self.q_bias, + torch.zeros_like(self.v_bias, + requires_grad=False), self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + # cosine attention + attn = ( + F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) + logit_scale = torch.clamp( + self.logit_scale, max=np.log(1. / 0.01)).exp() + attn = attn * logit_scale + + relative_position_bias_table = self.cpb_mlp( + self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +@MODELS.register_module() +class ShiftWindowMSA(BaseModule): + """Shift Window Multihead Self-Attention Module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. + shift_size (int, optional): The shift step of each window towards + right-bottom. If zero, act as regular window-msa. Defaults to 0. + dropout_layer (dict, optional): The dropout_layer used before output. + Defaults to dict(type='DropPath', drop_prob=0.). + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + window_msa (Callable): To build a window multi-head attention module. + Defaults to :class:`WindowMSA`. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + **kwargs: Other keyword arguments to build the window multi-head + attention module. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + shift_size=0, + dropout_layer=dict(type='DropPath', drop_prob=0.), + pad_small_map=False, + window_msa=WindowMSA, + init_cfg=None, + **kwargs): + super().__init__(init_cfg) + + self.shift_size = shift_size + self.window_size = window_size + assert 0 <= self.shift_size < self.window_size + + self.w_msa = window_msa( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=to_2tuple(self.window_size), + **kwargs, + ) + + self.drop = build_dropout(dropout_layer) + self.pad_small_map = pad_small_map + + def forward(self, query, hw_shape): + B, L, C = query.shape + H, W = hw_shape + assert L == H * W, f"The query length {L} doesn't match the input "\ + f'shape ({H}, {W}).' 
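# ---------------------------------------------------------------------------
# [Editor's note] The remainder of ``ShiftWindowMSA.forward`` below: pad H/W
# up to multiples of ``window_size``, cyclically roll the feature map by
# ``-shift_size``, partition it into windows, run W-MSA with an additive mask
# that blocks attention across the rolled boundary, then roll back and crop
# the padding. Shape sketch for ``window_partition`` (hypothetical sizes):
#     (B=2, H=56, W=56, C=96) with window_size=7
#     -> (2 * (56 // 7) * (56 // 7), 7, 7, 96) == (128, 7, 7, 96)
# ---------------------------------------------------------------------------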
+ query = query.view(B, H, W, C) + + window_size = self.window_size + shift_size = self.shift_size + + if min(H, W) == window_size: + # If not pad small feature map, avoid shifting when the window size + # is equal to the size of feature map. It's to align with the + # behavior of the original implementation. + shift_size = shift_size if self.pad_small_map else 0 + elif min(H, W) < window_size: + # In the original implementation, the window size will be shrunk + # to the size of feature map. The behavior is different with + # swin-transformer for downstream tasks. To support dynamic input + # shape, we don't allow this feature. + assert self.pad_small_map, \ + f'The input shape ({H}, {W}) is smaller than the window ' \ + f'size ({window_size}). Please set `pad_small_map=True`, or ' \ + 'decrease the `window_size`.' + + pad_r = (window_size - W % window_size) % window_size + pad_b = (window_size - H % window_size) % window_size + query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b)) + + H_pad, W_pad = query.shape[1], query.shape[2] + + # cyclic shift + if shift_size > 0: + query = torch.roll( + query, shifts=(-shift_size, -shift_size), dims=(1, 2)) + + attn_mask = self.get_attn_mask((H_pad, W_pad), + window_size=window_size, + shift_size=shift_size, + device=query.device) + + # nW*B, window_size, window_size, C + query_windows = self.window_partition(query, window_size) + # nW*B, window_size*window_size, C + query_windows = query_windows.view(-1, window_size**2, C) + + # W-MSA/SW-MSA (nW*B, window_size*window_size, C) + attn_windows = self.w_msa(query_windows, mask=attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, window_size, window_size, C) + + # B H' W' C + shifted_x = self.window_reverse(attn_windows, H_pad, W_pad, + window_size) + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, shifts=(shift_size, shift_size), dims=(1, 2)) + else: + x = shifted_x + + if H != H_pad or W != W_pad: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + x = self.drop(x) + + return x + + @staticmethod + def window_reverse(windows, H, W, window_size): + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + @staticmethod + def window_partition(x, window_size): + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows + + @staticmethod + def get_attn_mask(hw_shape, window_size, shift_size, device=None): + if shift_size > 0: + img_mask = torch.zeros(1, *hw_shape, 1, device=device) + h_slices = (slice(0, -window_size), slice(-window_size, + -shift_size), + slice(-shift_size, None)) + w_slices = (slice(0, -window_size), slice(-window_size, + -shift_size), + slice(-shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + # nW, window_size, window_size, 1 + mask_windows = ShiftWindowMSA.window_partition( + img_mask, window_size) + mask_windows = mask_windows.view(-1, window_size * window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0) + attn_mask = attn_mask.masked_fill(attn_mask == 0, 0.0) + else: + attn_mask = None + return attn_mask + + +class 
MultiheadAttention(BaseModule):
+    """Multi-head Attention Module.
+
+    This module implements multi-head attention that supports different input
+    dims and embed dims. It also supports a shortcut from ``value``, which is
+    useful if input dims is not the same as embed dims.
+
+    Args:
+        embed_dims (int): The embedding dimension.
+        num_heads (int): Parallel attention heads.
+        input_dims (int, optional): The input dimension, and if None,
+            use ``embed_dims``. Defaults to None.
+        attn_drop (float): Dropout rate of the dropout layer after the
+            attention calculation of query and key. Defaults to 0.
+        proj_drop (float): Dropout rate of the dropout layer after the
+            output projection. Defaults to 0.
+        dropout_layer (dict): The dropout config before adding the shortcut.
+            Defaults to ``dict(type='Dropout', drop_prob=0.)``.
+        qkv_bias (bool): If True, add a learnable bias to q, k, v.
+            Defaults to True.
+        qk_scale (float, optional): Override default qk scale of
+            ``head_dim ** -0.5`` if set. Defaults to None.
+        proj_bias (bool): If True, add a learnable bias to output projection.
+            Defaults to True.
+        v_shortcut (bool): Add a shortcut from value to output. It's usually
+            used if ``input_dims`` is different from ``embed_dims``.
+            Defaults to False.
+        use_layer_scale (bool): Whether to use layer scale. Defaults to False.
+        layer_scale_init_value (float or torch.Tensor): Init value of layer
+            scale. Defaults to 0.
+        init_cfg (dict, optional): The Config for initialization.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 num_heads,
+                 input_dims=None,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 dropout_layer=dict(type='Dropout', drop_prob=0.),
+                 qkv_bias=True,
+                 qk_scale=None,
+                 proj_bias=True,
+                 v_shortcut=False,
+                 use_layer_scale=False,
+                 layer_scale_init_value=0.,
+                 init_cfg=None):
+        super(MultiheadAttention, self).__init__(init_cfg=init_cfg)
+
+        self.input_dims = input_dims or embed_dims
+        self.embed_dims = embed_dims
+        self.num_heads = num_heads
+        self.v_shortcut = v_shortcut
+
+        self.head_dims = embed_dims // num_heads
+        if qk_scale is not None:
+            # ``scaled_dot_product_attention_pyimpl`` divides the logits by
+            # ``scale``, so pass the reciprocal of the requested multiplier.
+            self.scaled_dot_product_attention = partial(
+                scaled_dot_product_attention_pyimpl, scale=1. / qk_scale)
+        else:
+            self.scaled_dot_product_attention = scaled_dot_product_attention
+
+        self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
+        self.attn_drop = attn_drop
+        self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        self.out_drop = build_dropout(dropout_layer)
+
+        if use_layer_scale:
+            warnings.warn('The `use_layer_scale` in `MultiheadAttention` will '
+                          'be deprecated. Please use `layer_scale_init_value` '
+                          'to control whether using layer scale or not.')
+
+        if use_layer_scale or (layer_scale_init_value > 0):
+            layer_scale_init_value = layer_scale_init_value or 1e-5
+            self.gamma1 = LayerScale(
+                embed_dims, layer_scale_init_value=layer_scale_init_value)
+        else:
+            self.gamma1 = nn.Identity()
+
+    def forward(self, x):
+        B, N, _ = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
+                                  self.head_dims).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]
+
+        # attention dropout is only applied at training time
+        attn_drop = self.attn_drop if self.training else 0.
+        x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
+        x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
+
+        x = self.proj(x)
+        x = self.out_drop(self.gamma1(self.proj_drop(x)))
+
+        if self.v_shortcut:
+            x = v.squeeze(1) + x
+        return x
+
+
+class BEiTAttention(BaseModule):
+    """Window based multi-head self-attention (W-MSA) module with relative
+    position bias.
+
+    The initial implementation is in MMSegmentation.
+
+    Args:
+        embed_dims (int): Number of input channels.
+        num_heads (int): Number of attention heads.
+        window_size (tuple[int, int]): The height and width of the window.
+        use_rel_pos_bias (bool): Whether to use unique relative position bias,
+            if False, use shared relative position bias defined in backbone.
+        bias (str): The option to add learnable bias for q, k, v. If bias is
+            True, it will add learnable bias. If bias is 'qv_bias', it will
+            only add learnable bias for q, v. If bias is False, it will not
+            add bias for q, k, v. Defaults to 'qv_bias'.
+        qk_scale (float | None, optional): Override default qk scale of
+            head_dim ** -0.5 if set. Default: None.
+        attn_drop_rate (float): Dropout ratio of attention weight.
+            Default: 0.0.
+        proj_drop_rate (float): Dropout ratio of output. Default: 0.
+        init_cfg (dict | None, optional): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 num_heads,
+                 window_size,
+                 use_rel_pos_bias,
+                 bias='qv_bias',
+                 qk_scale=None,
+                 attn_drop_rate=0.,
+                 proj_drop_rate=0.,
+                 init_cfg=None,
+                 **kwargs):
+        super().__init__(init_cfg=init_cfg)
+        self.embed_dims = embed_dims
+        self.num_heads = num_heads
+        head_embed_dims = embed_dims // num_heads
+        self.bias = bias
+        self.scale = qk_scale or head_embed_dims**-0.5
+
+        qkv_bias = bias
+        if bias == 'qv_bias':
+            self._init_qv_bias()
+            qkv_bias = False
+
+        if window_size is None:
+            assert not use_rel_pos_bias
+        else:
+            assert isinstance(window_size, tuple)
+        self.window_size = window_size
+        self.use_rel_pos_bias = use_rel_pos_bias
+        self._init_rel_pos_embedding()
+
+        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop_rate)
+        self.proj = nn.Linear(embed_dims, embed_dims)
+        self.proj_drop = nn.Dropout(proj_drop_rate)
+
+    def _init_qv_bias(self):
+        self.q_bias = nn.Parameter(torch.zeros(self.embed_dims))
+        self.v_bias = nn.Parameter(torch.zeros(self.embed_dims))
+
+    def _init_rel_pos_embedding(self):
+        if self.use_rel_pos_bias:
+            Wh, Ww = self.window_size
+            # cls to token & token to cls & cls to cls
+            self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3
+            # relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH)
+            self.relative_position_bias_table = nn.Parameter(
+                torch.zeros(self.num_relative_distance, self.num_heads))
+
+            # get pair-wise relative position index for
+            # each token inside the window
+            coords_h = torch.arange(Wh)
+            coords_w = torch.arange(Ww)
+            # coords shape is (2, Wh, Ww)
+            coords = torch.stack(torch_meshgrid([coords_h, coords_w]))
+            # coords_flatten shape is (2, Wh*Ww)
+            coords_flatten = torch.flatten(coords, 1)
+            relative_coords = (
+                coords_flatten[:, :, None] - coords_flatten[:, None, :])
+            # relative_coords shape is (Wh*Ww, Wh*Ww, 2)
+            relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+            # shift to start from 0
+            relative_coords[:, :, 0] += Wh - 1
+            relative_coords[:, :, 1] += Ww - 1
+            relative_coords[:, :, 0] *= 2 * Ww - 1
+            relative_position_index = torch.zeros(
+                size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype)
+            # relative_position_index shape is (Wh*Ww + 1, Wh*Ww + 1)
+            relative_position_index[1:, 1:] = relative_coords.sum(-1)
+            relative_position_index[0, 0:] = self.num_relative_distance - 3
+            relative_position_index[0:, 0] = self.num_relative_distance - 2
+            relative_position_index[0, 0] = self.num_relative_distance - 1
+
+            self.register_buffer('relative_position_index',
+                                 relative_position_index)
+        else:
+            self.window_size = None
+            self.relative_position_bias_table = None
+            self.relative_position_index = None
+
+    def init_weights(self):
+        super().init_weights()
+        if self.use_rel_pos_bias:
+            trunc_normal_(self.relative_position_bias_table, std=0.02)
+
+    def forward(self, x, rel_pos_bias=None):
+        """
+        Args:
+            x (tensor): input features with shape of (num_windows*B, N, C).
+            rel_pos_bias (tensor): input relative position bias with shape of
+                (num_heads, N, N).
+        """
+        B, N, C = x.shape
+
+        if self.bias == 'qv_bias':
+            k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
+            qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias))
+            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+        else:
+            qkv = self.qkv(x)
+
+        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        if self.relative_position_bias_table is not None:
+            Wh = self.window_size[0]
+            Ww = self.window_size[1]
+            relative_position_bias = self.relative_position_bias_table[
+                self.relative_position_index.view(-1)].view(
+                    Wh * Ww + 1, Wh * Ww + 1, -1)
+            relative_position_bias = relative_position_bias.permute(
+                2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+            attn = attn + relative_position_bias.unsqueeze(0)
+
+        if rel_pos_bias is not None:
+            # use shared relative position bias
+            attn = attn + rel_pos_bias
+
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class ChannelMultiheadAttention(BaseModule):
+    """Channel Multihead Self-attention Module.
+
+    This module implements channel multi-head attention that supports
+    different input dims and embed dims.
+
+    Args:
+        embed_dims (int): The embedding dimension.
+        num_heads (int): Parallel attention heads.
+        input_dims (int, optional): The input dimension, and if None,
+            use ``embed_dims``. Defaults to None.
+        attn_drop (float): Dropout rate of the dropout layer after the
+            attention calculation of query and key. Defaults to 0.
+        proj_drop (float): Dropout rate of the dropout layer after the
+            output projection. Defaults to 0.
+        dropout_layer (dict): The dropout config before adding the shortcut.
+            Defaults to ``dict(type='Dropout', drop_prob=0.)``.
+        qkv_bias (bool): If True, add a learnable bias to q, k, v.
+            Defaults to False.
+        proj_bias (bool): If True, add a learnable bias to output projection.
+            Defaults to True.
+        qk_scale_type (str): The scale type of qk scale.
+            Defaults to 'learnable'. It can be 'learnable', 'fixed' or 'none'.
+        qk_scale (float, optional): If qk_scale_type is set to 'none', this
+            should be specified with a valid float number. Defaults to None.
+        v_shortcut (bool): Add a shortcut from value to output. It's usually
+            used if ``input_dims`` is different from ``embed_dims``.
+            Defaults to False.
+        init_cfg (dict, optional): The Config for initialization.
+            Defaults to None.
+ """ + + def __init__(self, + embed_dims, + num_heads=8, + input_dims=None, + attn_drop=0., + proj_drop=0., + dropout_layer=dict(type='Dropout', drop_prob=0.), + qkv_bias=False, + proj_bias=True, + qk_scale_type='learnable', + qk_scale=None, + v_shortcut=False, + init_cfg=None): + super().__init__(init_cfg) + + self.input_dims = input_dims or embed_dims + self.embed_dims = embed_dims + self.num_heads = num_heads + self.v_shortcut = v_shortcut + + self.head_dims = embed_dims // num_heads + if qk_scale_type == 'learnable': + self.scale = nn.Parameter(torch.ones(num_heads, 1, 1)) + elif qk_scale_type == 'fixed': + self.scale = self.head_dims**-0.5 + elif qk_scale_type == 'none': + assert qk_scale is not None + self.scale = qk_scale + + self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.out_drop = build_dropout(dropout_layer) + + def forward(self, x): + B, N, _ = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + self.head_dims).permute(2, 0, 3, 1, 4) + + q, k, v = [item.transpose(-2, -1) for item in [qkv[0], qkv[1], qkv[2]]] + + q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, self.embed_dims) + x = self.proj(x) + x = self.out_drop(self.proj_drop(x)) + + if self.v_shortcut: + x = qkv[2].squeeze(1) + x + return x + + +class LeAttention(BaseModule): + """LeViT Attention. Multi-head attention with attention bias, which is + proposed in `LeViT: a Vision Transformer in ConvNet’s Clothing for Faster + Inference`_ + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. Default: 8. + key_dim (int): Dimension of key. Default: None. + attn_ratio (int): Ratio of attention heads. Default: 8. + resolution (tuple[int]): Input resolution. Default: (16, 16). + init_cfg (dict, optional): The Config for initialization. 
+ """ + + def __init__(self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=(14, 14), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + # (h, w) + assert isinstance(resolution, tuple) and len(resolution) == 2 + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + + self.norm = nn.LayerNorm(dim) + self.qkv = nn.Linear(dim, h) + self.proj = nn.Linear(self.dh, dim) + + points = list( + itertools.product(range(resolution[0]), range(resolution[1]))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer( + 'attention_bias_idxs', + torch.LongTensor(idxs).view(N, N), + persistent=False) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,N,C) + B, N, _ = x.shape + + # Normalization + x = self.norm(x) + + qkv = self.qkv(x) + # (B, N, num_heads, d) + q, k, v = qkv.view(B, N, self.num_heads, + -1).split([self.key_dim, self.key_dim, self.d], + dim=3) + # (B, num_heads, N, d) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = ((q @ k.transpose(-2, -1)) * self.scale + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab)) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class CrossMultiheadAttention(BaseModule): + """Cross attention between queries and the union of keys and values. + + This module is different from ``MultiheadAttention``, for the attention + is computed between queries and the union of keys and values. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + """ + + def __init__(self, + embed_dims: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0., + proj_drop: float = 0.) 
-> None: + super().__init__() + self.num_heads = num_heads + head_dim = embed_dims // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.q = nn.Linear(embed_dims, embed_dims, bias=False) + self.k = nn.Linear(embed_dims, embed_dims, bias=False) + self.v = nn.Linear(embed_dims, embed_dims, bias=False) + + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(embed_dims)) + self.v_bias = nn.Parameter(torch.zeros(embed_dims)) + else: + self.q_bias = None + self.k_bias = None + self.v_bias = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, + x: torch.Tensor, + k: torch.Tensor = None, + v: torch.Tensor = None) -> None: + """Forward function.""" + B, N, _ = x.shape + + N_k = k.shape[1] + N_v = v.shape[1] + + q_bias, k_bias, v_bias = None, None, None + if self.q_bias is not None: + q_bias = self.q_bias + k_bias = torch.zeros_like(self.v_bias, requires_grad=False) + v_bias = self.v_bias + + q = F.linear( + input=x, weight=self.q.weight, bias=q_bias) # (B, N_q, dim) + k = F.linear( + input=k, weight=self.k.weight, bias=k_bias) # (B, N_k, dim) + v = F.linear(input=v, weight=self.v.weight, bias=v_bias) + + q = q.reshape(B, N, 1, self.num_heads, + -1).permute(2, 0, 3, 1, + 4).squeeze(0) # (B, num_heads, N_q, dim) + k = k.reshape(B, N_k, 1, self.num_heads, + -1).permute(2, 0, 3, 1, + 4).squeeze(0) # (B, num_heads, N_k, dim) + v = v.reshape(B, N_v, 1, self.num_heads, + -1).permute(2, 0, 3, 1, + 4).squeeze(0) # (B, num_heads, N_v, dim) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class PromptMultiheadAttention(MultiheadAttention): + """Prompt Multihead Attention for MILAN. + + This module is specific for the prompt encoder in MILAN. It will not update + the visible tokens from the encoder. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + input_dims (int, optional): The input dimension, and if None, + use ``embed_dims``. Defaults to None. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + dropout_layer (dict): The dropout config before adding the shortcut. + Defaults to ``dict(type='Dropout', drop_prob=0.)``. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + proj_bias (bool) If True, add a learnable bias to output projection. + Defaults to True. + v_shortcut (bool): Add a shortcut from value to output. It's usually + used if ``input_dims`` is different from ``embed_dims``. + Defaults to False. + return_attention (bool): If True, return the attention map, computed by + the cross attention between the class token and all other tokens. + Defaults to False. + init_cfg (Union[List[dict], dict], optional): The Config for + initialization. Defaults to None. 
+ """ + + def __init__(self, + embed_dims: int, + num_heads: int, + input_dims: Optional[int] = None, + attn_drop: float = 0, + proj_drop: float = 0, + dropout_layer: dict = dict(type='Dropout', drop_prob=0.), + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + proj_bias: bool = True, + v_shortcut: bool = False, + use_layer_scale: bool = False, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + embed_dims=embed_dims, + num_heads=num_heads, + input_dims=input_dims, + attn_drop=attn_drop, + proj_drop=proj_drop, + dropout_layer=dropout_layer, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + proj_bias=proj_bias, + v_shortcut=v_shortcut, + use_layer_scale=use_layer_scale, + init_cfg=init_cfg) + # no longer need qkv + del self.qkv + + # to project the mask tokens + self.q = nn.Linear(embed_dims, embed_dims, bias=qkv_bias) + # to project al the tokens + self.kv = nn.Linear(embed_dims, embed_dims * 2, bias=qkv_bias) + + def forward(self, x: torch.Tensor, visible_tokens: torch.Tensor, + ids_restore: torch.Tensor) -> torch.Tensor: + """Forward function for `PromptMultiheadAttention`. + + Args: + x (torch.Tensor): Mask token features with shape N x L_m x C. + visible_tokens (torch.Tensor): The visible tokens features from + encoder with shape N x L_v x C. + ids_restore (torch.Tensor): The ids of all tokens in the original + image with shape N x L. + + Returns: + torch Tensor: Output features with shape N x L x C. + """ + x_ = torch.cat([visible_tokens[:, 1:, :], x], dim=1) + assert x_.shape[1] == ids_restore.shape[1] + x_ = torch.gather( + x_, + dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[-1])) + x_ = torch.cat([visible_tokens[:, :1, :], x_], dim=1) + + # full sequence shape + B, _, _ = x_.shape + q = self.q(x).reshape(B, x.shape[1], self.num_heads, + self.head_dims).permute(0, 2, 1, 3) + kv = self.kv(x_).reshape(B, x_.shape[1], 2, self.num_heads, + self.head_dims).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn_drop = self.attn_drop if self.training else 0. + attn = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop) + x = attn.transpose(1, 2).reshape(B, x.shape[1], self.embed_dims) + + x = self.proj(x) + x = self.out_drop(self.gamma1(self.proj_drop(x))) + return x diff --git a/mmpretrain/models/utils/batch_augments/__init__.py b/mmpretrain/models/utils/batch_augments/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2fbc4e179608767f667ca1075e5134dbecb8c38d --- /dev/null +++ b/mmpretrain/models/utils/batch_augments/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .cutmix import CutMix
+from .mixup import Mixup
+from .resizemix import ResizeMix
+from .wrapper import RandomBatchAugment
+
+__all__ = ('RandomBatchAugment', 'CutMix', 'Mixup', 'ResizeMix')
diff --git a/mmpretrain/models/utils/batch_augments/__pycache__/__init__.cpython-311.pyc b/mmpretrain/models/utils/batch_augments/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..822487f85acdf94ceef9b419fc3386128c5395db
Binary files /dev/null and b/mmpretrain/models/utils/batch_augments/__pycache__/__init__.cpython-311.pyc differ
diff --git a/mmpretrain/models/utils/batch_augments/__pycache__/cutmix.cpython-311.pyc b/mmpretrain/models/utils/batch_augments/__pycache__/cutmix.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b28fdfb2c718eda0eafdd0cc5856b4ab7b249ab1
Binary files /dev/null and b/mmpretrain/models/utils/batch_augments/__pycache__/cutmix.cpython-311.pyc differ
diff --git a/mmpretrain/models/utils/batch_augments/__pycache__/mixup.cpython-311.pyc b/mmpretrain/models/utils/batch_augments/__pycache__/mixup.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ee6c61159e34a2aeeee9c8b79ad6c370a835cb2
Binary files /dev/null and b/mmpretrain/models/utils/batch_augments/__pycache__/mixup.cpython-311.pyc differ
diff --git a/mmpretrain/models/utils/batch_augments/__pycache__/resizemix.cpython-311.pyc b/mmpretrain/models/utils/batch_augments/__pycache__/resizemix.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a35c358eb810819764fbd74d2d721b9e860b8f85
Binary files /dev/null and b/mmpretrain/models/utils/batch_augments/__pycache__/resizemix.cpython-311.pyc differ
diff --git a/mmpretrain/models/utils/batch_augments/__pycache__/wrapper.cpython-311.pyc b/mmpretrain/models/utils/batch_augments/__pycache__/wrapper.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a378eee2b2b775ce00e0fd9245704fe1c90fb2bc
Binary files /dev/null and b/mmpretrain/models/utils/batch_augments/__pycache__/wrapper.cpython-311.pyc differ
diff --git a/mmpretrain/models/utils/batch_augments/cutmix.py b/mmpretrain/models/utils/batch_augments/cutmix.py
new file mode 100644
index 0000000000000000000000000000000000000000..665427bf5e2ff3a5ae9d656e7d642db8b72acabb
--- /dev/null
+++ b/mmpretrain/models/utils/batch_augments/cutmix.py
@@ -0,0 +1,157 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+
+from mmpretrain.registry import BATCH_AUGMENTS
+from .mixup import Mixup
+
+
+@BATCH_AUGMENTS.register_module()
+class CutMix(Mixup):
+    r"""CutMix batch augmentation.
+
+    CutMix is a method to improve the network's generalization capability. It's
+    proposed in `CutMix: Regularization Strategy to Train Strong Classifiers
+    with Localizable Features `
+
+    With this method, patches are cut and pasted among training images where
+    the ground truth labels are also mixed proportionally to the area of the
+    patches.
+
+    Args:
+        alpha (float): Parameters for Beta distribution to generate the
+            mixing ratio. It should be a positive number. More details
+            can be found in :class:`Mixup`.
+        cutmix_minmax (List[float], optional): The min/max area ratio of the
+            patches. If not None, the bounding-box of patches is uniformly
+            sampled within this ratio range, and the ``alpha`` will be
+            ignored. Otherwise, the bounding-box is generated according to the
+            ``alpha``. Defaults to None.
+        correct_lam (bool): Whether to apply lambda correction when the
+            cutmix bbox is clipped by image borders. Defaults to True.
+
+    .. note::
+        If the ``cutmix_minmax`` is None, how to generate the bounding-box of
+        patches according to the ``alpha``?
+
+        First, generate a :math:`\lambda`, details can be found in
+        :class:`Mixup`. And then, the area ratio of the bounding-box
+        is calculated by:
+
+        .. math::
+            \text{ratio} = \sqrt{1-\lambda}
+    """
+
+    def __init__(self,
+                 alpha: float,
+                 cutmix_minmax: Optional[List[float]] = None,
+                 correct_lam: bool = True):
+        super().__init__(alpha=alpha)
+
+        self.cutmix_minmax = cutmix_minmax
+        self.correct_lam = correct_lam
+
+    def rand_bbox_minmax(
+            self,
+            img_shape: Tuple[int, int],
+            count: Optional[int] = None) -> Tuple[int, int, int, int]:
+        """Min-Max CutMix bounding-box. Inspired by the Darknet cutmix
+        implementation, it generates a random rectangular bbox based on
+        min/max percent values applied to each dimension of the input image.
+
+        Typical ``cutmix_minmax`` values are around 0.2-0.3 for the min and
+        0.8-0.9 for the max.
+
+        Args:
+            img_shape (tuple): Image shape as tuple
+            count (int, optional): Number of bbox to generate. Defaults to None
+        """
+        assert len(self.cutmix_minmax) == 2
+        img_h, img_w = img_shape
+        cut_h = np.random.randint(
+            int(img_h * self.cutmix_minmax[0]),
+            int(img_h * self.cutmix_minmax[1]),
+            size=count)
+        cut_w = np.random.randint(
+            int(img_w * self.cutmix_minmax[0]),
+            int(img_w * self.cutmix_minmax[1]),
+            size=count)
+        yl = np.random.randint(0, img_h - cut_h, size=count)
+        xl = np.random.randint(0, img_w - cut_w, size=count)
+        yu = yl + cut_h
+        xu = xl + cut_w
+        return yl, yu, xl, xu
+
+    def rand_bbox(self,
+                  img_shape: Tuple[int, int],
+                  lam: float,
+                  margin: float = 0.,
+                  count: Optional[int] = None) -> Tuple[int, int, int, int]:
+        """Standard CutMix bounding-box that generates a random square bbox
+        based on lambda value. This implementation includes support for
+        enforcing a border margin as percent of bbox dimensions.
+
+        Args:
+            img_shape (tuple): Image shape as tuple
+            lam (float): Cutmix lambda value
+            margin (float): Percentage of bbox dimension to enforce as margin
+                (reduce amount of box outside image). Defaults to 0.
+            count (int, optional): Number of bbox to generate. Defaults to None
+        """
+        ratio = np.sqrt(1 - lam)
+        img_h, img_w = img_shape
+        cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
+        margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
+        cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
+        cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
+        yl = np.clip(cy - cut_h // 2, 0, img_h)
+        yh = np.clip(cy + cut_h // 2, 0, img_h)
+        xl = np.clip(cx - cut_w // 2, 0, img_w)
+        xh = np.clip(cx + cut_w // 2, 0, img_w)
+        return yl, yh, xl, xh
+
+    def cutmix_bbox_and_lam(self,
+                            img_shape: Tuple[int, int],
+                            lam: float,
+                            count: Optional[int] = None) -> tuple:
+        """Generate bbox and apply lambda correction.
+
+        Args:
+            img_shape (tuple): Image shape as tuple
+            lam (float): Cutmix lambda value
+            count (int, optional): Number of bbox to generate. Defaults to None
+        """
+        if self.cutmix_minmax is not None:
+            yl, yu, xl, xu = self.rand_bbox_minmax(img_shape, count=count)
+        else:
+            yl, yu, xl, xu = self.rand_bbox(img_shape, lam, count=count)
+        if self.correct_lam or self.cutmix_minmax is not None:
+            bbox_area = (yu - yl) * (xu - xl)
+            lam = 1. - bbox_area / float(img_shape[0] * img_shape[1])
+        return (yl, yu, xl, xu), lam
+
+    def mix(self, batch_inputs: torch.Tensor,
+            batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Mix the batch inputs and batch one-hot format ground truth.
+
+        Args:
+            batch_inputs (Tensor): A batch of images tensor in the shape of
+                ``(N, C, H, W)``.
+            batch_scores (Tensor): A batch of one-hot format labels in the
+                shape of ``(N, num_classes)``.
+
+        Returns:
+            Tuple[Tensor, Tensor]: The mixed inputs and labels.
+        """
+        lam = np.random.beta(self.alpha, self.alpha)
+        batch_size = batch_inputs.size(0)
+        img_shape = batch_inputs.shape[-2:]
+        index = torch.randperm(batch_size)
+
+        (y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam)
+        batch_inputs[:, :, y1:y2, x1:x2] = batch_inputs[index, :, y1:y2, x1:x2]
+        mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :]
+
+        return batch_inputs, mixed_scores
diff --git a/mmpretrain/models/utils/batch_augments/mixup.py b/mmpretrain/models/utils/batch_augments/mixup.py
new file mode 100644
index 0000000000000000000000000000000000000000..bedb2c3e5b6e62595e50f7494eeda7c14827b391
--- /dev/null
+++ b/mmpretrain/models/utils/batch_augments/mixup.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import numpy as np
+import torch
+
+from mmpretrain.registry import BATCH_AUGMENTS
+
+
+@BATCH_AUGMENTS.register_module()
+class Mixup:
+    r"""Mixup batch augmentation.
+
+    Mixup is a method to reduce the memorization of corrupt labels and
+    increase the robustness to adversarial examples. It's proposed in
+    `mixup: Beyond Empirical Risk Minimization
+    `_
+
+    Args:
+        alpha (float): Parameters for Beta distribution to generate the
+            mixing ratio. It should be a positive number. More details
+            are in the note.
+
+    Note:
+        The :math:`\alpha` (``alpha``) determines a random distribution
+        :math:`Beta(\alpha, \alpha)`. For each batch of data, we sample
+        a mixing ratio (marked as :math:`\lambda`, ``lam``) from the random
+        distribution.
+    """
+
+    def __init__(self, alpha: float):
+        assert isinstance(alpha, float) and alpha > 0
+
+        self.alpha = alpha
+
+    def mix(self, batch_inputs: torch.Tensor,
+            batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Mix the batch inputs and batch one-hot format ground truth.
+
+        Args:
+            batch_inputs (Tensor): A batch of images tensor in the shape of
+                ``(N, C, H, W)``.
+            batch_scores (Tensor): A batch of one-hot format labels in the
+                shape of ``(N, num_classes)``.
+
+        Returns:
+            Tuple[Tensor, Tensor]: The mixed inputs and labels.
+        """
+        lam = np.random.beta(self.alpha, self.alpha)
+        batch_size = batch_inputs.size(0)
+        index = torch.randperm(batch_size)
+
+        mixed_inputs = lam * batch_inputs + (1 - lam) * batch_inputs[index, :]
+        mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :]
+
+        return mixed_inputs, mixed_scores
+
+    def __call__(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor):
+        """Mix the batch inputs and batch data samples."""
+        assert batch_score.ndim == 2, \
+            'The input `batch_score` should be a one-hot format tensor, '\
+            'which shape should be ``(N, num_classes)``.'
+
+        mixed_inputs, mixed_score = self.mix(batch_inputs, batch_score.float())
+        return mixed_inputs, mixed_score
diff --git a/mmpretrain/models/utils/batch_augments/resizemix.py b/mmpretrain/models/utils/batch_augments/resizemix.py
new file mode 100644
index 0000000000000000000000000000000000000000..89cfb72033e75065502a594f17124eb1f471116f
--- /dev/null
+++ b/mmpretrain/models/utils/batch_augments/resizemix.py
@@ -0,0 +1,95 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from mmpretrain.registry import BATCH_AUGMENTS
+from .cutmix import CutMix
+
+
+@BATCH_AUGMENTS.register_module()
+class ResizeMix(CutMix):
+    r"""ResizeMix Random Paste layer for a batch of data.
+
+    The ResizeMix will resize an image to a small patch and paste it on
+    another image. It's proposed in `ResizeMix: Mixing Data with Preserved
+    Object Information and True Labels `_
+
+    Args:
+        alpha (float): Parameters for Beta distribution to generate the
+            mixing ratio. It should be a positive number. More details
+            can be found in :class:`Mixup`.
+        lam_min (float): The minimum value of lam. Defaults to 0.1.
+        lam_max (float): The maximum value of lam. Defaults to 0.8.
+        interpolation (str): Algorithm used for upsampling:
+            'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' |
+            'area'. Defaults to 'bilinear'.
+        prob (float): The probability to execute resizemix. It should be in
+            range [0, 1]. Defaults to 1.0.
+        cutmix_minmax (List[float], optional): The min/max area ratio of the
+            patches. If not None, the bounding-box of patches is uniformly
+            sampled within this ratio range, and the ``alpha`` will be
+            ignored. Otherwise, the bounding-box is generated according to the
+            ``alpha``. Defaults to None.
+        correct_lam (bool): Whether to apply lambda correction when the
+            cutmix bbox is clipped by image borders. Defaults to True.
+        **kwargs: Any other parameters accepted by :class:`CutMix`.
+
+    Note:
+        The :math:`\lambda` (``lam``) is the mixing ratio. It's a random
+        variable which follows :math:`Beta(\alpha, \alpha)` and is mapped
+        to the range [``lam_min``, ``lam_max``]:
+
+        .. math::
+            \lambda = Beta(\alpha, \alpha)
+            \times (\lambda_{max} - \lambda_{min}) + \lambda_{min}
+
+        And the resize ratio of source images is calculated by :math:`\lambda`:
+
+        .. math::
+            \text{ratio} = \sqrt{1-\lambda}
+    """
+
+    def __init__(self,
+                 alpha: float,
+                 lam_min: float = 0.1,
+                 lam_max: float = 0.8,
+                 interpolation: str = 'bilinear',
+                 cutmix_minmax: Optional[List[float]] = None,
+                 correct_lam: bool = True):
+        super().__init__(
+            alpha=alpha, cutmix_minmax=cutmix_minmax, correct_lam=correct_lam)
+        self.lam_min = lam_min
+        self.lam_max = lam_max
+        self.interpolation = interpolation
+
+    def mix(self, batch_inputs: torch.Tensor,
+            batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Mix the batch inputs and batch one-hot format ground truth.
+
+        Args:
+            batch_inputs (Tensor): A batch of images tensor in the shape of
+                ``(N, C, H, W)``.
+            batch_scores (Tensor): A batch of one-hot format labels in the
+                shape of ``(N, num_classes)``.
+
+        Returns:
+            Tuple[Tensor, Tensor]: The mixed inputs and labels.
+ """ + lam = np.random.beta(self.alpha, self.alpha) + lam = lam * (self.lam_max - self.lam_min) + self.lam_min + img_shape = batch_inputs.shape[-2:] + batch_size = batch_inputs.size(0) + index = torch.randperm(batch_size) + + (y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam) + batch_inputs[:, :, y1:y2, x1:x2] = F.interpolate( + batch_inputs[index], + size=(y2 - y1, x2 - x1), + mode=self.interpolation, + align_corners=False) + mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :] + + return batch_inputs, mixed_scores diff --git a/mmpretrain/models/utils/batch_augments/wrapper.py b/mmpretrain/models/utils/batch_augments/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..10e5304c3ca1a42428870ea5a00416007ca2e35c --- /dev/null +++ b/mmpretrain/models/utils/batch_augments/wrapper.py @@ -0,0 +1,74 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, Union + +import numpy as np +import torch + +from mmpretrain.registry import BATCH_AUGMENTS + + +class RandomBatchAugment: + """Randomly choose one batch augmentation to apply. + + Args: + augments (Callable | dict | list): configs of batch + augmentations. + probs (float | List[float] | None): The probabilities of each batch + augmentations. If None, choose evenly. Defaults to None. + + Example: + >>> import torch + >>> import torch.nn.functional as F + >>> from mmpretrain.models import RandomBatchAugment + >>> augments_cfg = [ + ... dict(type='CutMix', alpha=1.), + ... dict(type='Mixup', alpha=1.) + ... ] + >>> batch_augment = RandomBatchAugment(augments_cfg, probs=[0.5, 0.3]) + >>> imgs = torch.rand(16, 3, 32, 32) + >>> label = F.one_hot(torch.randint(0, 10, (16, )), num_classes=10) + >>> imgs, label = batch_augment(imgs, label) + + .. note :: + + To decide which batch augmentation will be used, it picks one of + ``augments`` based on the probabilities. In the example above, the + probability to use CutMix is 0.5, to use Mixup is 0.3, and to do + nothing is 0.2. + """ + + def __init__(self, augments: Union[Callable, dict, list], probs=None): + if not isinstance(augments, (tuple, list)): + augments = [augments] + + self.augments = [] + for aug in augments: + if isinstance(aug, dict): + self.augments.append(BATCH_AUGMENTS.build(aug)) + else: + self.augments.append(aug) + + if isinstance(probs, float): + probs = [probs] + + if probs is not None: + assert len(augments) == len(probs), \ + '``augments`` and ``probs`` must have same lengths. ' \ + f'Got {len(augments)} vs {len(probs)}.' + assert sum(probs) <= 1, \ + 'The total probability of batch augments exceeds 1.' + self.augments.append(None) + probs.append(1 - sum(probs)) + + self.probs = probs + + def __call__(self, batch_input: torch.Tensor, batch_score: torch.Tensor): + """Randomly apply batch augmentations to the batch inputs and batch + data samples.""" + aug_index = np.random.choice(len(self.augments), p=self.probs) + aug = self.augments[aug_index] + + if aug is not None: + return aug(batch_input, batch_score) + else: + return batch_input, batch_score.float() diff --git a/mmpretrain/models/utils/batch_shuffle.py b/mmpretrain/models/utils/batch_shuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b03c5fec5f99295daed2872feff73dfc238140 --- /dev/null +++ b/mmpretrain/models/utils/batch_shuffle.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Tuple + +import torch +from mmengine.dist import all_gather, broadcast, get_rank + + +@torch.no_grad() +def batch_shuffle_ddp(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Batch shuffle, for making use of BatchNorm. + + Args: + x (torch.Tensor): Data in each GPU. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Output of shuffle operation. + - x_gather[idx_this]: Shuffled data. + - idx_unshuffle: Index for restoring. + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = torch.cat(all_gather(x), dim=0) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # random shuffle index + idx_shuffle = torch.randperm(batch_size_all) + + # broadcast to all gpus + broadcast(idx_shuffle, src=0) + + # index for restoring + idx_unshuffle = torch.argsort(idx_shuffle) + + # shuffled index for this gpu + gpu_idx = get_rank() + idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this], idx_unshuffle + + +@torch.no_grad() +def batch_unshuffle_ddp(x: torch.Tensor, + idx_unshuffle: torch.Tensor) -> torch.Tensor: + """Undo batch shuffle. + + Args: + x (torch.Tensor): Data in each GPU. + idx_unshuffle (torch.Tensor): Index for restoring. + + Returns: + torch.Tensor: Output of unshuffle operation. + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = torch.cat(all_gather(x), dim=0) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # restored index for this gpu + gpu_idx = get_rank() + idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this] diff --git a/mmpretrain/models/utils/box_utils.py b/mmpretrain/models/utils/box_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..79db516c990f51a7c952404d932b6de022684fb4 --- /dev/null +++ b/mmpretrain/models/utils/box_utils.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torchvision.ops.boxes as boxes + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2.0, (y0 + y1) / 2.0, (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +def box_iou(boxes1, boxes2): + """Return intersection-over-union (Jaccard index) between two sets of + boxes. + + Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + + Args: + boxes1 (Tensor[N, 4]): first set of boxes + boxes2 (Tensor[M, 4]): second set of boxes + + Returns: + Tensor[N, M]: the NxM matrix containing the pairwise IoU values for + every element in boxes1 and boxes2 + """ + return boxes.box_iou(boxes1, boxes2) + + +def generalized_box_iou(boxes1, boxes2): + """Return generalized intersection-over-union (Jaccard index) between two + sets of boxes. + + Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``. 
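+
+    The generalized IoU extends plain IoU with a penalty term:
+    ``GIoU = IoU - |C \ (A ∪ B)| / |C|``, where ``C`` is the smallest box
+    enclosing both inputs, so the result stays informative even for
+    non-overlapping boxes.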
+ + Args: + boxes1 (Tensor[N, 4]): first set of boxes + boxes2 (Tensor[M, 4]): second set of boxes + + Returns: + Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU + values for every element in boxes1 and boxes2 + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + + return boxes.generalized_box_iou(boxes1, boxes2) diff --git a/mmpretrain/models/utils/channel_shuffle.py b/mmpretrain/models/utils/channel_shuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..27006a8065db35a14c4207ce6613104374b064ad --- /dev/null +++ b/mmpretrain/models/utils/channel_shuffle.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +def channel_shuffle(x, groups): + """Channel Shuffle operation. + + This function enables cross-group information flow for multiple groups + convolution layers. + + Args: + x (Tensor): The input tensor. + groups (int): The number of groups to divide the input tensor + in the channel dimension. + + Returns: + Tensor: The output tensor after channel shuffle operation. + """ + + batch_size, num_channels, height, width = x.size() + assert (num_channels % groups == 0), ('num_channels should be ' + 'divisible by groups') + channels_per_group = num_channels // groups + + x = x.view(batch_size, groups, channels_per_group, height, width) + x = torch.transpose(x, 1, 2).contiguous() + x = x.view(batch_size, -1, height, width) + + return x diff --git a/mmpretrain/models/utils/clip_generator_helper.py b/mmpretrain/models/utils/clip_generator_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..4f67f0ed6976585a20e15787fc6b94c41082d33d --- /dev/null +++ b/mmpretrain/models/utils/clip_generator_helper.py @@ -0,0 +1,394 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/zejiangh/MILAN +from collections import OrderedDict +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from mmengine.logging import MMLogger +from torch import nn + +from mmpretrain.registry import MODELS + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +@MODELS.register_module() +class QuickGELU(nn.Module): + """A faster version of GELU.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + """Residual Attention Block (RAB). + + This module implements the same function as the MultiheadAttention, + but with a different interface, which is mainly used + in CLIP. + + Args: + d_model (int): The feature dimension. + n_head (int): The number of attention heads. + attn_mask (torch.Tensor, optional): The attention mask. + Defaults to None. 
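+        return_attention (bool): Whether to also return the attention weights
+            from the attention layer. Defaults to False.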
+ """ + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: Optional[torch.Tensor] = None, + return_attention: bool = False) -> None: + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + self.return_attention = return_attention + + def attention(self, x: torch.Tensor) -> torch.Tensor: + """Attention function.""" + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + if self.return_attention: + return self.attn( + x, + x, + x, + need_weights=self.return_attention, + attn_mask=self.attn_mask) + else: + return self.attn( + x, + x, + x, + need_weights=self.return_attention, + attn_mask=self.attn_mask)[0] + + def forward( + self, x: torch.Tensor + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Forward function.""" + if self.return_attention: + x_, attention = self.attention(self.ln_1(x)) + x = x + x_ + x = x + self.mlp(self.ln_2(x)) + return x, attention + else: + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + """Transformer. + + Both visual and text branches use this transformer. + + Args: + width (int): The feature dimension. + layers (int): The number of layers. + heads (int): The number of attention heads. + attn_mask (torch.Tensor, optional): The attention mask. + """ + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: Optional[torch.Tensor] = None) -> None: + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList() + for _ in range(layers - 1): + self.resblocks.append( + ResidualAttentionBlock(width, heads, attn_mask)) + self.resblocks.append( + ResidualAttentionBlock( + width, heads, attn_mask, return_attention=True)) + + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward function.""" + z = [] + for idx, blk in enumerate(self.resblocks): + if idx < self.layers - 1: + x = blk(x) + z.append(x.permute(1, 0, 2)) + else: + x, attention = blk(x) + z.append(x.permute(1, 0, 2)) + return x, attention, z + + +class VisionTransformer(nn.Module): + """Vision Transformer for CLIP. + + Args: + input_resolution (int): The image size. + patch_size (int): The patch size. + width (int): The feature dimension. + layers (int): The number of layers. + heads (int): The number of attention heads. + out_dim (int): The output dimension. + fineturn (bool): Whether to fineturn the model. + average_target (bool): Whether to average the target. 
+ """ + + def __init__(self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + finetune=False, + average_targets: int = 1) -> None: + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.finetune = finetune + if finetune is False: + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.average_targets = average_targets + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward function.""" + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([ + self.class_embedding.to(x.dtype) + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x + ], + dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x, attention, z = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x) + if self.proj is not None: + x = x @ self.proj + + return x, attention + + +class CLIP(nn.Module): + """CLIP. + + Args: + embed_dim (int): The embedding dimension. + image_resolution (int): The image size. + vision_layers (int): The number of layers in the vision transformer. + vision_width (int): The feature dimension in the vision transformer. + vision_patch_size (int): The patch size in the vision transformer. + context_length (int): The context length. + vocab_size (int): The vocabulary size. + transformer_width (int): The feature dimension in the text transformer. + transformer_heads (int): The number of attention heads in the + text transformer. + transformer_layers (int): The number of layers in the text transformer. + fineturn (bool): Whether to fineturn the model. + average_target (bool): Whether to average the target. 
+ """ + + def __init__( + self, + embed_dim: int, + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + finetune: bool = False, + average_targets: int = 1, + ) -> None: + super().__init__() + + self.context_length = context_length + + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + finetune=finetune, + average_targets=average_targets, + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask()) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self) -> None: + """Initialize the parameters. + + The pretrained weight will override the initialized parameters by this + function. + """ + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers)**-0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width)**-0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_( + self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self) -> torch.Tensor: + """Build the attention mask.""" + # lazily create causal attention mask, with full attention between the + # vision tokens pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self) -> torch.dtype: + """Get the dtype.""" + return self.visual.conv1.weight.dtype + + def encode_image(self, + image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode the image. + + Get the feature and attention mask from the last layer of the visual + branch of CLIP. + + Args: + image (torch.Tensor): The image tensor with shape NCHW. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The feature and attention mask. + """ + return self.visual(image.type(self.dtype)) + + +def build_clip_model(state_dict: dict, + finetune: bool = False, + average_targets: int = 1) -> nn.Module: + """Build the CLIP model. + + Args: + state_dict (dict): The pretrained state dict. + finetune (bool): Whether to fineturn the model. + average_targets (bool): Whether to average the target. + + Returns: + nn.Module: The CLIP model. 
+ """ + vit = 'visual.proj' in state_dict + + if vit: + vision_width = state_dict['visual.conv1.weight'].shape[0] + vision_layers = len([ + k for k in state_dict.keys() + if k.startswith('visual.') and k.endswith('.attn.in_proj_weight') + ]) + vision_patch_size = state_dict['visual.conv1.weight'].shape[-1] + grid_size = round( + (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5) + image_resolution = vision_patch_size * grid_size + + embed_dim = state_dict['text_projection'].shape[1] + context_length = state_dict['positional_embedding'].shape[0] + vocab_size = state_dict['token_embedding.weight'].shape[0] + transformer_width = state_dict['ln_final.weight'].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len( + set( + k.split('.')[2] for k in state_dict + if k.startswith('transformer.resblocks'))) + + model = CLIP( + embed_dim, + image_resolution, + vision_layers, + vision_width, + vision_patch_size, + context_length, + vocab_size, + transformer_width, + transformer_heads, + transformer_layers, + finetune, + average_targets, + ) + + for key in ['input_resolution', 'context_length', 'vocab_size']: + if key in state_dict: + del state_dict[key] + + msg = model.load_state_dict(state_dict, strict=False) + MMLogger.get_current_instance().info(f'Load CLIP model: {msg}') + return model.eval() diff --git a/mmpretrain/models/utils/data_preprocessor.py b/mmpretrain/models/utils/data_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..c407bd4c9361b9fae329854d4a36dab929fef143 --- /dev/null +++ b/mmpretrain/models/utils/data_preprocessor.py @@ -0,0 +1,620 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from numbers import Number +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F +from mmengine.model import (BaseDataPreprocessor, ImgDataPreprocessor, + stack_batch) + +from mmpretrain.registry import MODELS +from mmpretrain.structures import (DataSample, MultiTaskDataSample, + batch_label_to_onehot, cat_batch_labels, + tensor_split) +from .batch_augments import RandomBatchAugment + + +@MODELS.register_module() +class ClsDataPreprocessor(BaseDataPreprocessor): + """Image pre-processor for classification tasks. + + Comparing with the :class:`mmengine.model.ImgDataPreprocessor`, + + 1. It won't do normalization if ``mean`` is not specified. + 2. It does normalization and color space conversion after stacking batch. + 3. It supports batch augmentations like mixup and cutmix. + + It provides the data pre-processing as follows + + - Collate and move data to the target device. + - Pad inputs to the maximum size of current batch with defined + ``pad_value``. The padding size can be divisible by a defined + ``pad_size_divisor`` + - Stack inputs to batch_inputs. + - Convert inputs from bgr to rgb if the shape of input is (3, H, W). + - Normalize image with defined std and mean. + - Do batch augmentations like Mixup and Cutmix during training. + + Args: + mean (Sequence[Number], optional): The pixel mean of R, G, B channels. + Defaults to None. + std (Sequence[Number], optional): The pixel standard deviation of + R, G, B channels. Defaults to None. + pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + pad_value (Number): The padded pixel value. Defaults to 0. + to_rgb (bool): whether to convert image from BGR to RGB. + Defaults to False. 
+ to_onehot (bool): Whether to generate one-hot format gt-labels and set + to data samples. Defaults to False. + num_classes (int, optional): The number of classes. Defaults to None. + batch_augments (dict, optional): The batch augmentations settings, + including "augments" and "probs". For more details, see + :class:`mmpretrain.models.RandomBatchAugment`. + """ + + def __init__(self, + mean: Sequence[Number] = None, + std: Sequence[Number] = None, + pad_size_divisor: int = 1, + pad_value: Number = 0, + to_rgb: bool = False, + to_onehot: bool = False, + num_classes: Optional[int] = None, + batch_augments: Optional[dict] = None): + super().__init__() + self.pad_size_divisor = pad_size_divisor + self.pad_value = pad_value + self.to_rgb = to_rgb + self.to_onehot = to_onehot + self.num_classes = num_classes + + if mean is not None: + assert std is not None, 'To enable the normalization in ' \ + 'preprocessing, please specify both `mean` and `std`.' + # Enable the normalization in preprocessing. + self._enable_normalize = True + self.register_buffer('mean', + torch.tensor(mean).view(-1, 1, 1), False) + self.register_buffer('std', + torch.tensor(std).view(-1, 1, 1), False) + else: + self._enable_normalize = False + + if batch_augments: + self.batch_augments = RandomBatchAugment(**batch_augments) + if not self.to_onehot: + from mmengine.logging import MMLogger + MMLogger.get_current_instance().info( + 'Because batch augmentations are enabled, the data ' + 'preprocessor automatically enables the `to_onehot` ' + 'option to generate one-hot format labels.') + self.to_onehot = True + else: + self.batch_augments = None + + def forward(self, data: dict, training: bool = False) -> dict: + """Perform normalization, padding, bgr2rgb conversion and batch + augmentation based on ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. + + Returns: + dict: Data in the same format as the model input. + """ + inputs = self.cast_data(data['inputs']) + + if isinstance(inputs, torch.Tensor): + # The branch if use `default_collate` as the collate_fn in the + # dataloader. + + # ------ To RGB ------ + if self.to_rgb and inputs.size(1) == 3: + inputs = inputs.flip(1) + + # -- Normalization --- + inputs = inputs.float() + if self._enable_normalize: + inputs = (inputs - self.mean) / self.std + + # ------ Padding ----- + if self.pad_size_divisor > 1: + h, w = inputs.shape[-2:] + + target_h = math.ceil( + h / self.pad_size_divisor) * self.pad_size_divisor + target_w = math.ceil( + w / self.pad_size_divisor) * self.pad_size_divisor + pad_h = target_h - h + pad_w = target_w - w + inputs = F.pad(inputs, (0, pad_w, 0, pad_h), 'constant', + self.pad_value) + else: + # The branch if use `pseudo_collate` as the collate_fn in the + # dataloader. 
+ + processed_inputs = [] + for input_ in inputs: + # ------ To RGB ------ + if self.to_rgb and input_.size(0) == 3: + input_ = input_.flip(0) + + # -- Normalization --- + input_ = input_.float() + if self._enable_normalize: + input_ = (input_ - self.mean) / self.std + + processed_inputs.append(input_) + # Combine padding and stack + inputs = stack_batch(processed_inputs, self.pad_size_divisor, + self.pad_value) + + data_samples = data.get('data_samples', None) + sample_item = data_samples[0] if data_samples is not None else None + + if isinstance(sample_item, DataSample): + batch_label = None + batch_score = None + + if 'gt_label' in sample_item: + gt_labels = [sample.gt_label for sample in data_samples] + batch_label, label_indices = cat_batch_labels(gt_labels) + batch_label = batch_label.to(self.device) + if 'gt_score' in sample_item: + gt_scores = [sample.gt_score for sample in data_samples] + batch_score = torch.stack(gt_scores).to(self.device) + elif self.to_onehot and 'gt_label' in sample_item: + assert batch_label is not None, \ + 'Cannot generate onehot format labels because no labels.' + num_classes = self.num_classes or sample_item.get( + 'num_classes') + assert num_classes is not None, \ + 'Cannot generate one-hot format labels because not set ' \ + '`num_classes` in `data_preprocessor`.' + batch_score = batch_label_to_onehot( + batch_label, label_indices, num_classes).to(self.device) + + # ----- Batch Augmentations ---- + if (training and self.batch_augments is not None + and batch_score is not None): + inputs, batch_score = self.batch_augments(inputs, batch_score) + + # ----- scatter labels and scores to data samples --- + if batch_label is not None: + for sample, label in zip( + data_samples, tensor_split(batch_label, + label_indices)): + sample.set_gt_label(label) + if batch_score is not None: + for sample, score in zip(data_samples, batch_score): + sample.set_gt_score(score) + elif isinstance(sample_item, MultiTaskDataSample): + data_samples = self.cast_data(data_samples) + + return {'inputs': inputs, 'data_samples': data_samples} + + +@MODELS.register_module() +class SelfSupDataPreprocessor(ImgDataPreprocessor): + """Image pre-processor for operations, like normalization and bgr to rgb. + + Compared with the :class:`mmengine.ImgDataPreprocessor`, this module + supports ``inputs`` as torch.Tensor or a list of torch.Tensor. + """ + + def __init__(self, + mean: Optional[Sequence[Union[float, int]]] = None, + std: Optional[Sequence[Union[float, int]]] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + to_rgb: bool = False, + bgr_to_rgb: bool = False, + rgb_to_bgr: bool = False, + non_blocking: Optional[bool] = False): + super().__init__( + mean=mean, + std=std, + pad_size_divisor=pad_size_divisor, + pad_value=pad_value, + bgr_to_rgb=bgr_to_rgb, + rgb_to_bgr=rgb_to_bgr, + non_blocking=non_blocking) + + self._channel_conversion = to_rgb or bgr_to_rgb or rgb_to_bgr + + def forward( + self, + data: dict, + training: bool = False + ) -> Tuple[List[torch.Tensor], Optional[list]]: + """Performs normalization and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. If + subclasses override this method, they can perform different + preprocessing strategies for training and testing based on the + value of ``training``. + Returns: + Tuple[torch.Tensor, Optional[list]]: Data in the same format as the + model input. 
+ """ + assert isinstance(data, + dict), 'Please use default_collate in dataloader, \ + instead of pseudo_collate.' + + data = [val for _, val in data.items()] + batch_inputs, batch_data_samples = self.cast_data(data) + + # Here is what is different from :class:`mmengine.ImgDataPreprocessor` + # Since there are multiple views for an image for some algorithms, + # e.g. SimCLR, each item in inputs is a list, containing multi-views + # for an image. + if isinstance(batch_inputs, list): + # channel transform + if self._channel_conversion: + batch_inputs = [ + _input[:, [2, 1, 0], ...] for _input in batch_inputs + ] + + # convert to float after channel conversion to ensure efficiency + batch_inputs = [_input.float() for _input in batch_inputs] + + # normalization. + if self._enable_normalize: + batch_inputs = [(_input - self.mean) / self.std + for _input in batch_inputs] + else: + # channel transform + if self._channel_conversion: + batch_inputs = batch_inputs[:, [2, 1, 0], ...] + + # convert to float after channel conversion to ensure efficiency + batch_inputs = batch_inputs.float() + + # normalization. + if self._enable_normalize: + batch_inputs = (batch_inputs - self.mean) / self.std + + return {'inputs': batch_inputs, 'data_samples': batch_data_samples} + + +@MODELS.register_module() +class TwoNormDataPreprocessor(SelfSupDataPreprocessor): + """Image pre-processor for CAE, BEiT v1/v2, etc. + + Compared with the :class:`mmselfsup.SelfSupDataPreprocessor`, this module + will normalize the prediction image and target image with different + normalization parameters. + + Args: + mean (Sequence[float or int], optional): The pixel mean of image + channels. If ``to_rgb=True`` it means the mean value of R, G, B + channels. If the length of `mean` is 1, it means all channels have + the same mean value, or the input is a gray image. If it is not + specified, images will not be normalized. Defaults to None. + std (Sequence[float or int], optional): The pixel standard deviation of + image channels. If ``to_rgb=True`` it means the standard deviation + of R, G, B channels. If the length of `std` is 1, it means all + channels have the same standard deviation, or the input is a gray + image. If it is not specified, images will not be normalized. + Defaults to None. + second_mean (Sequence[float or int], optional): The description is + like ``mean``, it can be customized for targe image. Defaults to + None. + second_std (Sequence[float or int], optional): The description is + like ``std``, it can be customized for targe image. Defaults to + None. + pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + pad_value (float or int): The padded pixel value. Defaults to 0. + to_rgb (bool): whether to convert image from BGR to RGB. + Defaults to False. + non_blocking (bool): Whether block current process when transferring + data to device. Defaults to False. 
+ """ + + def __init__(self, + mean: Optional[Sequence[Union[float, int]]] = None, + std: Optional[Sequence[Union[float, int]]] = None, + second_mean: Sequence[Union[float, int]] = None, + second_std: Sequence[Union[float, int]] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + to_rgb: bool = False, + non_blocking: Optional[bool] = False): + super().__init__( + mean=mean, + std=std, + pad_size_divisor=pad_size_divisor, + pad_value=pad_value, + to_rgb=to_rgb, + non_blocking=non_blocking) + assert (second_mean is not None) and (second_std is not None), ( + 'mean and std should not be None while using ' + '`TwoNormDataPreprocessor`') + assert len(second_mean) == 3 or len(second_mean) == 1, ( + '`mean` should have 1 or 3 values, to be compatible with ' + f'RGB or gray image, but got {len(second_mean)} values') + assert len(second_std) == 3 or len(second_std) == 1, ( + '`std` should have 1 or 3 values, to be compatible with RGB ' + f'or gray image, but got {len(std)} values') + + self.register_buffer('second_mean', + torch.tensor(second_mean).view(-1, 1, 1), False) + self.register_buffer('second_std', + torch.tensor(second_std).view(-1, 1, 1), False) + + def forward( + self, + data: dict, + training: bool = False + ) -> Tuple[List[torch.Tensor], Optional[list]]: + """Performs normalization and bgr2rgb conversion based on + ``BaseDataPreprocessor``. The ``batch_inputs`` in forward function is a + list. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. If + subclasses override this method, they can perform different + preprocessing strategies for training and testing based on the + value of ``training``. + Returns: + Tuple[torch.Tensor, Optional[list]]: Data in the same format as the + model input. + """ + data = [val for _, val in data.items()] + batch_inputs, batch_data_samples = self.cast_data(data) + # channel transform + if self._channel_conversion: + batch_inputs = [ + _input[:, [2, 1, 0], ...] for _input in batch_inputs + ] + + # convert to float after channel conversion to ensure efficiency + batch_inputs = [_input.float() for _input in batch_inputs] + + # Normalization. Here is what is different from + # :class:`mmselfsup.SelfSupDataPreprocessor`. Normalize the target + # image and prediction image with different normalization params + if self._enable_normalize: + batch_inputs = [ + (batch_inputs[0] - self.mean) / self.std, + (batch_inputs[1] - self.second_mean) / self.second_std + ] + + return {'inputs': batch_inputs, 'data_samples': batch_data_samples} + + +@MODELS.register_module() +class VideoDataPreprocessor(BaseDataPreprocessor): + """Video pre-processor for operations, like normalization and bgr to rgb + conversion . + + Compared with the :class:`mmaction.ActionDataPreprocessor`, this module + supports ``inputs`` as torch.Tensor or a list of torch.Tensor. + + Args: + mean (Sequence[float or int, optional): The pixel mean of channels + of images or stacked optical flow. Defaults to None. + std (Sequence[float or int], optional): The pixel standard deviation + of channels of images or stacked optical flow. Defaults to None. + pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + pad_value (float or int): The padded pixel value. Defaults to 0. + to_rgb (bool): Whether to convert image from BGR to RGB. + Defaults to False. + format_shape (str): Format shape of input data. + Defaults to ``'NCHW'``. 
+ """ + + def __init__(self, + mean: Optional[Sequence[Union[float, int]]] = None, + std: Optional[Sequence[Union[float, int]]] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + to_rgb: bool = False, + format_shape: str = 'NCHW') -> None: + super().__init__() + self.pad_size_divisor = pad_size_divisor + self.pad_value = pad_value + self.to_rgb = to_rgb + self.format_shape = format_shape + + if mean is not None: + assert std is not None, 'To enable the normalization in ' \ + 'preprocessing, please specify both ' \ + '`mean` and `std`.' + # Enable the normalization in preprocessing. + self._enable_normalize = True + if self.format_shape == 'NCHW': + normalizer_shape = (-1, 1, 1) + elif self.format_shape == 'NCTHW': + normalizer_shape = (-1, 1, 1, 1) + else: + raise ValueError(f'Invalid format shape: {format_shape}') + + self.register_buffer( + 'mean', + torch.tensor(mean, dtype=torch.float32).view(normalizer_shape), + False) + self.register_buffer( + 'std', + torch.tensor(std, dtype=torch.float32).view(normalizer_shape), + False) + else: + self._enable_normalize = False + + def forward( + self, + data: dict, + training: bool = False + ) -> Tuple[List[torch.Tensor], Optional[list]]: + """Performs normalization、padding and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. If + subclasses override this method, they can perform different + preprocessing strategies for training and testing based on the + value of ``training``. + Returns: + Tuple[List[torch.Tensor], Optional[list]]: Data in the same format + as the model input. + """ + + data = [val for _, val in data.items()] + batch_inputs, batch_data_samples = self.cast_data(data) + + if isinstance(batch_inputs, list): + # channel transform + if self.to_rgb: + if self.format_shape == 'NCHW': + batch_inputs = [ + _input[..., [2, 1, 0], :, :] for _input in batch_inputs + ] + elif self.format_shape == 'NCTHW': + batch_inputs = [ + _input[..., [2, 1, 0], :, :, :] + for _input in batch_inputs + ] + else: + raise ValueError( + f'Invalid format shape: {self.format_shape}') + + # convert to float after channel conversion to ensure efficiency + batch_inputs = [_input.float() for _input in batch_inputs] + + # normalization + if self._enable_normalize: + batch_inputs = [(_input - self.mean) / self.std + for _input in batch_inputs] + + else: + # channel transform + if self.to_rgb: + if self.format_shape == 'NCHW': + batch_inputs = batch_inputs[..., [2, 1, 0], :, :] + elif self.format_shape == 'NCTHW': + batch_inputs = batch_inputs[..., [2, 1, 0], :, :, :] + else: + raise ValueError( + f'Invalid format shape: {self.format_shape}') + + # convert to float after channel conversion to ensure efficiency + batch_inputs = batch_inputs.float() + + # normalization + if self._enable_normalize: + batch_inputs = (batch_inputs - self.mean) / self.std + + return {'inputs': batch_inputs, 'data_samples': batch_data_samples} + + +@MODELS.register_module() +class MultiModalDataPreprocessor(BaseDataPreprocessor): + """Data pre-processor for image-text multimodality tasks. + + It provides the data pre-processing as follows + + - Collate and move data to the target device. + - Pad inputs to the maximum size of current batch with defined + ``pad_value``. The padding size can be divisible by a defined + ``pad_size_divisor`` + - Stack inputs to batch_inputs. + - Convert inputs from bgr to rgb if the shape of input is (3, H, W). 
+    - Normalize image with the defined std and mean.
+
+    Args:
+        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
+            Defaults to None.
+        std (Sequence[Number], optional): The pixel standard deviation of
+            R, G, B channels. Defaults to None.
+        pad_size_divisor (int): The size of padded image should be
+            divisible by ``pad_size_divisor``. Defaults to 1.
+        pad_value (Number): The padded pixel value. Defaults to 0.
+        to_rgb (bool): Whether to convert image from BGR to RGB.
+            Defaults to False.
+    """
+
+    def __init__(
+        self,
+        mean: Optional[Sequence[Number]] = None,
+        std: Optional[Sequence[Number]] = None,
+        pad_size_divisor: int = 1,
+        pad_value: Number = 0,
+        to_rgb: bool = False,
+    ):
+        super().__init__()
+        self.pad_size_divisor = pad_size_divisor
+        self.pad_value = pad_value
+        self.to_rgb = to_rgb
+
+        if mean is not None:
+            assert std is not None, 'To enable the normalization in ' \
+                'preprocessing, please specify both `mean` and `std`.'
+            # Enable the normalization in preprocessing.
+            self._enable_normalize = True
+            self.register_buffer('mean',
+                                 torch.tensor(mean).view(-1, 1, 1), False)
+            self.register_buffer('std',
+                                 torch.tensor(std).view(-1, 1, 1), False)
+        else:
+            self._enable_normalize = False
+
+    def forward(self, data: dict, training: bool = False) -> dict:
+        """Perform normalization, padding, bgr2rgb conversion and batch
+        augmentation based on ``BaseDataPreprocessor``.
+
+        Args:
+            data (dict): Data sampled from the dataloader.
+            training (bool): Whether to enable training time augmentation.
+
+        Returns:
+            dict: Data in the same format as the model input.
+        """
+        data = self.cast_data(data)
+
+        imgs = data.get('inputs', None)
+
+        def _process_img(img):
+            # ------ To RGB ------
+            if self.to_rgb and img.size(1) == 3:
+                img = img.flip(1)
+
+            # -- Normalization ---
+            img = img.float()
+            if self._enable_normalize:
+                img = (img - self.mean) / self.std
+
+            # ------ Padding -----
+            if self.pad_size_divisor > 1:
+                h, w = img.shape[-2:]
+
+                target_h = math.ceil(
+                    h / self.pad_size_divisor) * self.pad_size_divisor
+                target_w = math.ceil(
+                    w / self.pad_size_divisor) * self.pad_size_divisor
+                pad_h = target_h - h
+                pad_w = target_w - w
+                img = F.pad(img, (0, pad_w, 0, pad_h), 'constant',
+                            self.pad_value)
+            return img
+
+        if isinstance(imgs, torch.Tensor):
+            imgs = _process_img(imgs)
+        elif isinstance(imgs, Sequence):
+            # B, T, C, H, W
+            imgs = torch.stack([_process_img(img) for img in imgs], dim=1)
+        elif imgs is not None:
+            raise ValueError(f'{type(imgs)} is not supported for imgs inputs.')
+
+        data_samples = data.get('data_samples', None)
+
+        return {'images': imgs, 'data_samples': data_samples}
diff --git a/mmpretrain/models/utils/ema.py b/mmpretrain/models/utils/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..63c5006bbb0d9ff967b3cce7d3b5ada0cc683468
--- /dev/null
+++ b/mmpretrain/models/utils/ema.py
@@ -0,0 +1,87 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from math import cos, pi
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from mmengine.logging import MessageHub
+from mmengine.model import ExponentialMovingAverage
+
+from mmpretrain.registry import MODELS
+
+
+@MODELS.register_module()
+class CosineEMA(ExponentialMovingAverage):
+    r"""CosineEMA is implemented for updating the momentum parameter, used in
+    BYOL, MoCoV3, etc.
+
+    All parameters are updated by the formula below:
+
+    .. math::
+
+        X'_{t+1} = (1 - m) * X'_t + m * X_t
+
+    where :math:`m` is the momentum parameter. It is updated with cosine
annealing following:
+
+    .. math::
+        m = m_{end} - (m_{end} - m_{start}) * (\cos\frac{k\pi}{K} + 1) / 2
+
+    where :math:`k` is the current step and :math:`K` is the total steps.
+
+    .. note::
+        This :attr:`momentum` argument is different from the one used in
+        optimizer classes and the conventional notion of momentum.
+        Mathematically, :math:`X'_{t}` is the moving average and :math:`X_t`
+        is the new observed value. The value of momentum is usually a small
+        number, allowing observed values to slowly update the ema parameters.
+        See also :external:py:class:`torch.nn.BatchNorm2d`.
+
+    Args:
+        model (nn.Module): The model to be averaged.
+        momentum (float): The start momentum value. Defaults to 0.004.
+        end_momentum (float): The end momentum value for cosine annealing.
+            Defaults to 0.
+        interval (int): Interval between two updates. Defaults to 1.
+        device (torch.device, optional): If provided, the averaged model will
+            be stored on the :attr:`device`. Defaults to None.
+        update_buffers (bool): If True, it will compute running averages for
+            both the parameters and the buffers of the model. Defaults to
+            False.
+    """
+
+    def __init__(self,
+                 model: nn.Module,
+                 momentum: float = 0.004,
+                 end_momentum: float = 0.,
+                 interval: int = 1,
+                 device: Optional[torch.device] = None,
+                 update_buffers: bool = False) -> None:
+        super().__init__(
+            model=model,
+            momentum=momentum,
+            interval=interval,
+            device=device,
+            update_buffers=update_buffers)
+        self.end_momentum = end_momentum
+
+    def avg_func(self, averaged_param: torch.Tensor,
+                 source_param: torch.Tensor, steps: int) -> None:
+        """Compute the moving average of the parameters using the cosine
+        momentum strategy. ``averaged_param`` is updated in place.
+
+        Args:
+            averaged_param (Tensor): The averaged parameters.
+            source_param (Tensor): The source parameters.
+            steps (int): The number of times the parameters have been
+                updated.
+        """
+        message_hub = MessageHub.get_current_instance()
+        max_iters = message_hub.get_info('max_iters')
+        cosine_annealing = (cos(pi * steps / float(max_iters)) + 1) / 2
+        momentum = self.end_momentum - (self.end_momentum -
+                                        self.momentum) * cosine_annealing
+        averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)
diff --git a/mmpretrain/models/utils/embed.py b/mmpretrain/models/utils/embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..8299f9a06789768b26ea58260a2984024fbf801d
--- /dev/null
+++ b/mmpretrain/models/utils/embed.py
@@ -0,0 +1,423 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Sequence
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmcv.cnn.bricks.transformer import AdaptivePadding
+from mmengine.model import BaseModule
+
+from .helpers import to_2tuple
+
+
+def resize_pos_embed(pos_embed,
+                     src_shape,
+                     dst_shape,
+                     mode='bicubic',
+                     num_extra_tokens=1):
+    """Resize pos_embed weights.
+
+    Args:
+        pos_embed (torch.Tensor): Position embedding weights with shape
+            [1, L, C].
+        src_shape (tuple): The resolution of the downsampled origin training
+            image, in format (H, W).
+        dst_shape (tuple): The resolution of the downsampled new training
+            image, in format (H, W).
+        mode (str): Algorithm used for upsampling. Choose one from 'nearest',
+            'linear', 'bilinear', 'bicubic' and 'trilinear'.
+            Defaults to 'bicubic'.
+ num_extra_tokens (int): The number of extra tokens, such as cls_token. + Defaults to 1. + + Returns: + torch.Tensor: The resized pos_embed of shape [1, L_new, C] + """ + if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]: + return pos_embed + assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]' + _, L, C = pos_embed.shape + src_h, src_w = src_shape + assert L == src_h * src_w + num_extra_tokens, \ + f"The length of `pos_embed` ({L}) doesn't match the expected " \ + f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the' \ + '`img_size` argument.' + extra_tokens = pos_embed[:, :num_extra_tokens] + + src_weight = pos_embed[:, num_extra_tokens:] + src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2) + + # The cubic interpolate algorithm only accepts float32 + dst_weight = F.interpolate( + src_weight.float(), size=dst_shape, align_corners=False, mode=mode) + dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2) + dst_weight = dst_weight.to(src_weight.dtype) + + return torch.cat((extra_tokens, dst_weight), dim=1) + + +def resize_relative_position_bias_table(src_shape, dst_shape, table, num_head): + """Resize relative position bias table. + + Args: + src_shape (int): The resolution of downsampled origin training + image, in format (H, W). + dst_shape (int): The resolution of downsampled new training + image, in format (H, W). + table (tensor): The relative position bias of the pretrained model. + num_head (int): Number of attention heads. + + Returns: + torch.Tensor: The resized relative position bias table. + """ + from scipy import interpolate + + def geometric_progression(a, r, n): + return a * (1.0 - r**n) / (1.0 - r) + + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_shape // 2) + if gp > dst_shape // 2: + right = q + else: + left = q + + dis = [] + cur = 1 + for i in range(src_shape // 2): + dis.append(cur) + cur += q**(i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + y = r_ids + [0] + dis + + t = dst_shape // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + + all_rel_pos_bias = [] + + for i in range(num_head): + z = table[:, i].view(src_shape, src_shape).float().numpy() + f_cubic = interpolate.interp2d(x, y, z, kind='cubic') + all_rel_pos_bias.append( + torch.Tensor(f_cubic(dx, + dy)).contiguous().view(-1, + 1).to(table.device)) + new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + return new_rel_pos_bias + + +class PatchEmbed(BaseModule): + """Image to Patch Embedding. + + We use a conv layer to implement PatchEmbed. + + Args: + img_size (int | tuple): The size of input image. Default: 224 + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None + conv_cfg (dict, optional): The config dict for conv layers. + Default: None + init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. + Default: None + """ + + def __init__(self, + img_size=224, + in_channels=3, + embed_dims=768, + norm_cfg=None, + conv_cfg=None, + init_cfg=None): + super(PatchEmbed, self).__init__(init_cfg) + warnings.warn('The `PatchEmbed` in mmpretrain will be deprecated. ' + 'Please use `mmcv.cnn.bricks.transformer.PatchEmbed`. 
' + "It's more general and supports dynamic input shape") + + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, tuple): + if len(img_size) == 1: + img_size = to_2tuple(img_size[0]) + assert len(img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(img_size)}' + + self.img_size = img_size + self.embed_dims = embed_dims + + # Use conv layer to embed + conv_cfg = conv_cfg or dict() + _conv_cfg = dict( + type='Conv2d', kernel_size=16, stride=16, padding=0, dilation=1) + _conv_cfg.update(conv_cfg) + self.projection = build_conv_layer(_conv_cfg, in_channels, embed_dims) + + # Calculate how many patches a input image is splited to. + h_out, w_out = [(self.img_size[i] + 2 * self.projection.padding[i] - + self.projection.dilation[i] * + (self.projection.kernel_size[i] - 1) - 1) // + self.projection.stride[i] + 1 for i in range(2)] + + self.patches_resolution = (h_out, w_out) + self.num_patches = h_out * w_out + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't " \ + f'match model ({self.img_size[0]}*{self.img_size[1]}).' + # The output size is (B, N, D), where N=H*W/P/P, D is embid_dim + x = self.projection(x).flatten(2).transpose(1, 2) + + if self.norm is not None: + x = self.norm(x) + + return x + + +# Modified from pytorch-image-models +class HybridEmbed(BaseModule): + """CNN Feature Map Embedding. + + Extract feature map from CNN, flatten, + project to embedding dim. + + Args: + backbone (nn.Module): CNN backbone + img_size (int | tuple): The size of input image. Default: 224 + feature_size (int | tuple, optional): Size of feature map extracted by + CNN backbone. Default: None + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + conv_cfg (dict, optional): The config dict for conv layers. + Default: None. + init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + backbone, + img_size=224, + feature_size=None, + in_channels=3, + embed_dims=768, + conv_cfg=None, + init_cfg=None): + super(HybridEmbed, self).__init__(init_cfg) + assert isinstance(backbone, nn.Module) + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, tuple): + if len(img_size) == 1: + img_size = to_2tuple(img_size[0]) + assert len(img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(img_size)}' + + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # FIXME this is hacky, but most reliable way of + # determining the exact dim of the output feature + # map for all networks, the feature metadata has + # reliable channel and stride info, but using + # stride to calc feature dim requires info about padding of + # each stage that isn't captured. 
+                training = backbone.training
+                if training:
+                    backbone.eval()
+                o = self.backbone(
+                    torch.zeros(1, in_channels, img_size[0], img_size[1]))
+                if isinstance(o, (list, tuple)):
+                    # last feature if backbone outputs list/tuple of features
+                    o = o[-1]
+                feature_size = o.shape[-2:]
+                feature_dim = o.shape[1]
+                backbone.train(training)
+        else:
+            feature_size = to_2tuple(feature_size)
+            if hasattr(self.backbone, 'feature_info'):
+                feature_dim = self.backbone.feature_info.channels()[-1]
+            else:
+                feature_dim = self.backbone.num_features
+        self.num_patches = feature_size[0] * feature_size[1]
+
+        # Use conv layer to embed
+        conv_cfg = conv_cfg or dict()
+        _conv_cfg = dict(
+            type='Conv2d', kernel_size=1, stride=1, padding=0, dilation=1)
+        _conv_cfg.update(conv_cfg)
+        self.projection = build_conv_layer(_conv_cfg, feature_dim, embed_dims)
+
+    def forward(self, x):
+        x = self.backbone(x)
+        if isinstance(x, (list, tuple)):
+            # last feature if backbone outputs list/tuple of features
+            x = x[-1]
+        x = self.projection(x).flatten(2).transpose(1, 2)
+        return x
+
+
+class PatchMerging(BaseModule):
+    """Merge patch feature map.
+
+    Modified from mmcv, and this module supports specifying whether to use
+    post-norm.
+
+    This layer groups the feature map by kernel_size, and applies norm and
+    linear layers to the grouped feature map (used in Swin Transformer). Our
+    implementation uses :class:`torch.nn.Unfold` to merge patches, which is
+    about 25% faster than the original implementation. However, we need to
+    modify pretrained models for compatibility.
+
+    Args:
+        in_channels (int): The num of input channels. It should be fully
+            covered by the filter and stride you specified.
+        out_channels (int): The num of output channels.
+        kernel_size (int | tuple, optional): the kernel size in the unfold
+            layer. Defaults to 2.
+        stride (int | tuple, optional): the stride of the sliding blocks in
+            the unfold layer. Defaults to None, which means it will be set to
+            ``kernel_size``.
+        padding (int | tuple | string): The padding length of
+            embedding conv. When it is a string, it means the mode
+            of adaptive padding, support "same" and "corner" now.
+            Defaults to "corner".
+        dilation (int | tuple, optional): dilation parameter in the unfold
+            layer. Defaults to 1.
+        bias (bool, optional): Whether to add bias in linear layer or not.
+            Defaults to False.
+        norm_cfg (dict, optional): Config dict for normalization layer.
+            Defaults to ``dict(type='LN')``.
+        use_post_norm (bool): Whether to use post normalization here.
+            Defaults to False.
+        init_cfg (dict, optional): The extra config for initialization.
+            Defaults to None.
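+
+    Example (a minimal sketch with assumed sizes)::
+
+        >>> import torch
+        >>> merge = PatchMerging(in_channels=96, out_channels=192)
+        >>> x = torch.rand(1, 56 * 56, 96)
+        >>> out, out_size = merge(x, (56, 56))
+        >>> tuple(out.shape), out_size
+        ((1, 784, 192), (28, 28))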
+ """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=2, + stride=None, + padding='corner', + dilation=1, + bias=False, + norm_cfg=dict(type='LN'), + use_post_norm=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.use_post_norm = use_post_norm + + if stride: + stride = stride + else: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adaptive_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of unfold + padding = 0 + else: + self.adaptive_padding = None + + padding = to_2tuple(padding) + self.sampler = nn.Unfold( + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride) + + sample_dim = kernel_size[0] * kernel_size[1] * in_channels + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + if norm_cfg is not None: + # build pre or post norm layer based on different channels + if self.use_post_norm: + self.norm = build_norm_layer(norm_cfg, out_channels)[1] + else: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + def forward(self, x, input_size): + """ + Args: + x (Tensor): Has shape (B, H*W, C_in). + input_size (tuple[int]): The spatial shape of x, arrange as (H, W). + Default: None. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) + - out_size (tuple[int]): Spatial shape of x, arrange as + (Merged_H, Merged_W). + """ + B, L, C = x.shape + assert isinstance(input_size, Sequence), f'Expect ' \ + f'input_size is ' \ + f'`Sequence` ' \ + f'but get {input_size}' + + H, W = input_size + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + + if self.adaptive_padding: + x = self.adaptive_padding(x) + H, W = x.shape[-2:] + + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) + x = self.sampler(x) + + out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * + (self.sampler.kernel_size[0] - 1) - + 1) // self.sampler.stride[0] + 1 + out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * + (self.sampler.kernel_size[1] - 1) - + 1) // self.sampler.stride[1] + 1 + + output_size = (out_h, out_w) + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + + if self.use_post_norm: + # use post-norm here + x = self.reduction(x) + x = self.norm(x) if self.norm else x + else: + x = self.norm(x) if self.norm else x + x = self.reduction(x) + + return x, output_size diff --git a/mmpretrain/models/utils/helpers.py b/mmpretrain/models/utils/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..971f45054e5edac15c71aa64ddd26164bf404d22 --- /dev/null +++ b/mmpretrain/models/utils/helpers.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import collections.abc +import warnings +from itertools import repeat + +import torch +from mmengine.utils import digit_version + + +def is_tracing() -> bool: + """Determine whether the model is called during the tracing of code with + ``torch.jit.trace``.""" + if digit_version(torch.__version__) >= digit_version('1.6.0'): + on_trace = torch.jit.is_tracing() + # In PyTorch 1.6, torch.jit.is_tracing has a bug. + # Refers to https://github.com/pytorch/pytorch/issues/42448 + if isinstance(on_trace, bool): + return on_trace + else: + return torch._C._is_tracing() + else: + warnings.warn( + 'torch.jit.is_tracing is only supported after v1.6.0. ' + 'Therefore is_tracing returns False automatically. Please ' + 'set on_trace manually if you are using trace.', UserWarning) + return False + + +# From PyTorch internals +def _ntuple(n): + """A `to_tuple` function generator. + + It returns a function, this function will repeat the input to a tuple of + length ``n`` if the input is not an Iterable object, otherwise, return the + input directly. + + Args: + n (int): The number of the target length. + """ + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple diff --git a/mmpretrain/models/utils/huggingface.py b/mmpretrain/models/utils/huggingface.py new file mode 100644 index 0000000000000000000000000000000000000000..a44d6daaf1cc4c51579fd849fb84ee1a5cc6e7d2 --- /dev/null +++ b/mmpretrain/models/utils/huggingface.py @@ -0,0 +1,100 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import contextlib +from typing import Optional + +import transformers +from mmengine.registry import Registry +from transformers import AutoConfig, PreTrainedModel +from transformers.models.auto.auto_factory import _BaseAutoModelClass + +from mmpretrain.registry import MODELS, TOKENIZER + + +def register_hf_tokenizer( + cls: Optional[type] = None, + registry: Registry = TOKENIZER, +): + """Register HuggingFace-style PreTrainedTokenizerBase class.""" + if cls is None: + + # use it as a decorator: @register_hf_tokenizer() + def _register(cls): + register_hf_tokenizer(cls=cls) + return cls + + return _register + + def from_pretrained(**kwargs): + if ('pretrained_model_name_or_path' not in kwargs + and 'name_or_path' not in kwargs): + raise TypeError( + f'{cls.__name__}.from_pretrained() missing required ' + "argument 'pretrained_model_name_or_path' or 'name_or_path'.") + # `pretrained_model_name_or_path` is too long for config, + # add an alias name `name_or_path` here. 
+        name_or_path = kwargs.pop('pretrained_model_name_or_path',
+                                  kwargs.pop('name_or_path'))
+        return cls.from_pretrained(name_or_path, **kwargs)
+
+    registry._register_module(
+        module=from_pretrained, module_name=cls.__name__)
+    return cls
+
+
+_load_hf_pretrained_model = True
+
+
+@contextlib.contextmanager
+def no_load_hf_pretrained_model():
+    global _load_hf_pretrained_model
+    _load_hf_pretrained_model = False
+    yield
+    _load_hf_pretrained_model = True
+
+
+def register_hf_model(
+    cls: Optional[type] = None,
+    registry: Registry = MODELS,
+):
+    """Register HuggingFace-style PreTrainedModel class."""
+    if cls is None:
+
+        # use it as a decorator: @register_hf_model()
+        def _register(cls):
+            register_hf_model(cls=cls)
+            return cls
+
+        return _register
+
+    if issubclass(cls, _BaseAutoModelClass):
+        get_config = AutoConfig.from_pretrained
+        from_config = cls.from_config
+    elif issubclass(cls, PreTrainedModel):
+        get_config = cls.config_class.from_pretrained
+        from_config = cls
+    else:
+        raise TypeError('`cls` is neither a HuggingFace auto model nor a '
+                        'pretrained model.')
+
+    def build(**kwargs):
+        if ('pretrained_model_name_or_path' not in kwargs
+                and 'name_or_path' not in kwargs):
+            raise TypeError(
+                f'{cls.__name__} missing required argument '
+                '`pretrained_model_name_or_path` or `name_or_path`.')
+        # `pretrained_model_name_or_path` is too long for config,
+        # add an alias name `name_or_path` here.
+        name_or_path = kwargs.pop('pretrained_model_name_or_path',
+                                  kwargs.pop('name_or_path'))
+
+        if kwargs.pop('load_pretrained', True) and _load_hf_pretrained_model:
+            model = cls.from_pretrained(name_or_path, **kwargs)
+            setattr(model, 'is_init', True)
+            return model
+        else:
+            cfg = get_config(name_or_path, **kwargs)
+            return from_config(cfg)
+
+    registry._register_module(module=build, module_name=cls.__name__)
+    return cls
+
+
+register_hf_model(transformers.AutoModelForCausalLM)
diff --git a/mmpretrain/models/utils/inverted_residual.py b/mmpretrain/models/utils/inverted_residual.py
new file mode 100644
index 0000000000000000000000000000000000000000..8387b21251aacff8efcb1b048e37ecdfa1299b2b
--- /dev/null
+++ b/mmpretrain/models/utils/inverted_residual.py
@@ -0,0 +1,125 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import ConvModule
+from mmcv.cnn.bricks import DropPath
+from mmengine.model import BaseModule
+
+from .se_layer import SELayer
+
+
+class InvertedResidual(BaseModule):
+    """Inverted Residual Block.
+
+    Args:
+        in_channels (int): The input channels of this module.
+        out_channels (int): The output channels of this module.
+        mid_channels (int): The input channels of the depthwise convolution.
+        kernel_size (int): The kernel size of the depthwise convolution.
+            Defaults to 3.
+        stride (int): The stride of the depthwise convolution. Defaults to 1.
+        se_cfg (dict, optional): Config dict for se layer. Defaults to None,
+            which means no se layer.
+        conv_cfg (dict): Config dict for convolution layer. Defaults to None,
+            which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='BN')``.
+        act_cfg (dict): Config dict for activation layer.
+            Defaults to ``dict(type='ReLU')``.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Defaults to False.
+        init_cfg (dict | list[dict], optional): Initialization config dict.
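+
+    Example (a minimal sketch with assumed channel numbers)::
+
+        >>> import torch
+        >>> block = InvertedResidual(32, 32, mid_channels=96)
+        >>> x = torch.rand(1, 32, 56, 56)
+        >>> tuple(block(x).shape)
+        (1, 32, 56, 56)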
+ """ + + def __init__(self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + stride=1, + se_cfg=None, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + drop_path_rate=0., + with_cp=False, + init_cfg=None): + super(InvertedResidual, self).__init__(init_cfg) + self.with_res_shortcut = (stride == 1 and in_channels == out_channels) + assert stride in [1, 2] + self.with_cp = with_cp + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0 else nn.Identity() + self.with_se = se_cfg is not None + self.with_expand_conv = (mid_channels != in_channels) + + if self.with_se: + assert isinstance(se_cfg, dict) + + if self.with_expand_conv: + self.expand_conv = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.depthwise_conv = ConvModule( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=mid_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if self.with_se: + self.se = SELayer(**se_cfg) + self.linear_conv = ConvModule( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, x): + """Forward function. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor. + """ + + def _inner_forward(x): + out = x + + if self.with_expand_conv: + out = self.expand_conv(out) + + out = self.depthwise_conv(out) + + if self.with_se: + out = self.se(out) + + out = self.linear_conv(out) + + if self.with_res_shortcut: + return x + self.drop_path(out) + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out diff --git a/mmpretrain/models/utils/layer_scale.py b/mmpretrain/models/utils/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..bb480a15ce35570a5fcfe060c25ef676730430a7 --- /dev/null +++ b/mmpretrain/models/utils/layer_scale.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +import torch +import torch.nn as nn + + +class LayerScale(nn.Module): + """LayerScale layer. + + Args: + dim (int): Dimension of input features. + layer_scale_init_value (float or torch.Tensor): Init value of layer + scale. Defaults to 1e-5. + inplace (bool): inplace: can optionally do the + operation in-place. Defaults to False. + data_format (str): The input data format, could be 'channels_last' + or 'channels_first', representing (B, C, H, W) and + (B, N, C) format data respectively. Defaults to 'channels_last'. + """ + + def __init__(self, + dim: int, + layer_scale_init_value: Union[float, torch.Tensor] = 1e-5, + inplace: bool = False, + data_format: str = 'channels_last'): + super().__init__() + assert data_format in ('channels_last', 'channels_first'), \ + "'data_format' could only be channels_last or channels_first." 
+        self.inplace = inplace
+        self.data_format = data_format
+        self.weight = nn.Parameter(torch.ones(dim) * layer_scale_init_value)
+
+    def forward(self, x):
+        if self.data_format == 'channels_first':
+            if self.inplace:
+                return x.mul_(self.weight.view(-1, 1, 1))
+            else:
+                return x * self.weight.view(-1, 1, 1)
+        return x.mul_(self.weight) if self.inplace else x * self.weight
diff --git a/mmpretrain/models/utils/make_divisible.py b/mmpretrain/models/utils/make_divisible.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ec74689e37d4a9d605a595adb0cca1da88aa19a
--- /dev/null
+++ b/mmpretrain/models/utils/make_divisible.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
+    """Make divisible function.
+
+    This function rounds the channel number to the nearest value that is
+    divisible by the divisor, without going below ``min_ratio`` times the
+    original value.
+
+    Args:
+        value (int): The original channel number.
+        divisor (int): The divisor to fully divide the channel number.
+        min_value (int, optional): The minimum value of the output channel.
+            Default: None, means that the minimum value equals the divisor.
+        min_ratio (float): The minimum ratio of the rounded channel
+            number to the original channel number. Default: 0.9.
+    Returns:
+        int: The modified output channel number.
+    """
+
+    if min_value is None:
+        min_value = divisor
+    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
+    # Make sure that the rounding does not go down by more than
+    # (1 - min_ratio).
+    if new_value < min_ratio * value:
+        new_value += divisor
+    return new_value
diff --git a/mmpretrain/models/utils/norm.py b/mmpretrain/models/utils/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b890a0c6ec654f00e4bb4cd148158eaeba7599d
--- /dev/null
+++ b/mmpretrain/models/utils/norm.py
@@ -0,0 +1,133 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmpretrain.registry import MODELS
+
+
+@MODELS.register_module()
+class GRN(nn.Module):
+    """Global Response Normalization Module.
+
+    Comes from `ConvNeXt V2: Co-designing and Scaling ConvNets with Masked
+    Autoencoders <http://arxiv.org/abs/2301.00808>`_
+
+    Args:
+        in_channels (int): The number of channels of the input tensor.
+        eps (float): A value added to the denominator for numerical stability.
+            Defaults to 1e-6.
+    """
+
+    def __init__(self, in_channels, eps=1e-6):
+        super().__init__()
+        self.in_channels = in_channels
+        self.gamma = nn.Parameter(torch.zeros(in_channels))
+        self.beta = nn.Parameter(torch.zeros(in_channels))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor, data_format='channel_first'):
+        """Forward method.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+            data_format (str): The format of the input tensor. If
+                ``"channel_first"``, the shape of the input tensor should be
+                (B, C, H, W). If ``"channel_last"``, the shape of the input
+                tensor should be (B, H, W, C). Defaults to "channel_first".
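+
+        Example (a minimal sketch)::
+
+            >>> import torch
+            >>> grn = GRN(64)
+            >>> x = torch.rand(2, 64, 7, 7)
+            >>> tuple(grn(x).shape)
+            (2, 64, 7, 7)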
+ """ + if data_format == 'channel_last': + gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + nx = gx / (gx.mean(dim=-1, keepdim=True) + self.eps) + x = self.gamma * (x * nx) + self.beta + x + elif data_format == 'channel_first': + gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True) + nx = gx / (gx.mean(dim=1, keepdim=True) + self.eps) + x = self.gamma.view(1, -1, 1, 1) * (x * nx) + self.beta.view( + 1, -1, 1, 1) + x + return x + + +@MODELS.register_module('LN2d') +class LayerNorm2d(nn.LayerNorm): + """LayerNorm on channels for 2d images. + + Args: + num_channels (int): The number of channels of the input tensor. + eps (float): a value added to the denominator for numerical stability. + Defaults to 1e-5. + elementwise_affine (bool): a boolean value that when set to ``True``, + this module has learnable per-element affine parameters initialized + to ones (for weights) and zeros (for biases). Defaults to True. + """ + + def __init__(self, num_channels: int, **kwargs) -> None: + super().__init__(num_channels, **kwargs) + self.num_channels = self.normalized_shape[0] + + def forward(self, x, data_format='channel_first'): + """Forward method. + + Args: + x (torch.Tensor): The input tensor. + data_format (str): The format of the input tensor. If + ``"channel_first"``, the shape of the input tensor should be + (B, C, H, W). If ``"channel_last"``, the shape of the input + tensor should be (B, H, W, C). Defaults to "channel_first". + """ + assert x.dim() == 4, 'LayerNorm2d only supports inputs with shape ' \ + f'(N, C, H, W), but got tensor with shape {x.shape}' + if data_format == 'channel_last': + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, + self.eps) + elif data_format == 'channel_first': + x = x.permute(0, 2, 3, 1) + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, + self.eps) + # If the output is discontiguous, it may cause some unexpected + # problem in the downstream tasks + x = x.permute(0, 3, 1, 2).contiguous() + return x + + +def build_norm_layer(cfg: dict, num_features: int) -> nn.Module: + """Build normalization layer. + + Args: + cfg (dict): The norm layer config, which should contain: + + - type (str): Layer type. + - layer args: Args needed to instantiate a norm layer. + + num_features (int): Number of input channels. + + Returns: + nn.Module: The created norm layer. + """ + if not isinstance(cfg, dict): + raise TypeError('cfg must be a dict') + if 'type' not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + norm_layer = MODELS.get(layer_type) + if norm_layer is None: + raise KeyError(f'Cannot find {layer_type} in registry under scope ' + f'name {MODELS.scope}') + + requires_grad = cfg_.pop('requires_grad', True) + cfg_.setdefault('eps', 1e-5) + + if layer_type != 'GN': + layer = norm_layer(num_features, **cfg_) + else: + layer = norm_layer(num_channels=num_features, **cfg_) + + if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'): + layer._specify_ddp_gpu_num(1) + + for param in layer.parameters(): + param.requires_grad = requires_grad + + return layer diff --git a/mmpretrain/models/utils/position_encoding.py b/mmpretrain/models/utils/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..07a3c486a25a84633d7e50463dd8b09f1c222837 --- /dev/null +++ b/mmpretrain/models/utils/position_encoding.py @@ -0,0 +1,247 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import math
+from functools import partial
+from typing import Optional, Sequence, Union
+
+import torch
+import torch.nn as nn
+from mmengine.model import BaseModule
+from mmengine.utils import digit_version
+
+from ..utils import to_2tuple
+
+# Since PyTorch v1.10.0, using torch.meshgrid without the indexing
+# argument raises an extra warning. For more details,
+# refer to https://github.com/pytorch/pytorch/issues/50276
+if digit_version(torch.__version__) >= digit_version('1.10.0'):
+    torch_meshgrid = partial(torch.meshgrid, indexing='ij')
+else:
+    torch_meshgrid = torch.meshgrid
+
+
+class ConditionalPositionEncoding(BaseModule):
+    """The Conditional Position Encoding (CPE) module.
+
+    The CPE is the implementation of `Conditional Positional Encodings
+    for Vision Transformers <https://arxiv.org/abs/2102.10882>`_.
+
+    Args:
+        in_channels (int): Number of input channels.
+        embed_dims (int): The feature dimension. Default: 768.
+        stride (int): Stride of conv layer. Default: 1.
+    """
+
+    def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None):
+        super(ConditionalPositionEncoding, self).__init__(init_cfg=init_cfg)
+        self.proj = nn.Conv2d(
+            in_channels,
+            embed_dims,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            bias=True,
+            groups=embed_dims)
+        self.stride = stride
+
+    def forward(self, x, hw_shape):
+        B, N, C = x.shape
+        H, W = hw_shape
+        feat_token = x
+        # convert (B, N, C) to (B, C, H, W)
+        cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W).contiguous()
+        if self.stride == 1:
+            x = self.proj(cnn_feat) + cnn_feat
+        else:
+            x = self.proj(cnn_feat)
+        x = x.flatten(2).transpose(1, 2)
+        return x
+
+
+class PositionEncodingFourier(BaseModule):
+    """The Position Encoding Fourier (PEF) module.
+
+    The PEF is adopted from `EdgeNeXt <https://arxiv.org/abs/2206.10589>`_.
+
+    Args:
+        in_channels (int): Number of input channels.
+            Default: 32
+        embed_dims (int): The feature dimension.
+            Default: 768.
+        temperature (int): Temperature.
+            Default: 10000.
+        dtype (torch.dtype): The data type.
+            Default: torch.float32.
+        init_cfg (dict): The config dict for initializing the module.
+            Default: None.
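+
+    Example (a minimal, illustrative sketch)::
+
+        >>> pos_enc = PositionEncodingFourier(in_channels=32, embed_dims=768)
+        >>> pos = pos_enc((2, 14, 14))  # (B, H, W) of the feature map
+        >>> tuple(pos.shape)
+        (2, 768, 14, 14)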
+ """ + + def __init__(self, + in_channels=32, + embed_dims=768, + temperature=10000, + dtype=torch.float32, + init_cfg=None): + super(PositionEncodingFourier, self).__init__(init_cfg=init_cfg) + self.proj = nn.Conv2d(in_channels * 2, embed_dims, kernel_size=1) + self.scale = 2 * math.pi + self.in_channels = in_channels + self.embed_dims = embed_dims + self.dtype = dtype + + if digit_version(torch.__version__) < digit_version('1.8.0'): + floor_div = torch.floor_divide + else: + floor_div = partial(torch.div, rounding_mode='floor') + dim_t = torch.arange(in_channels, dtype=self.dtype) + self.dim_t = temperature**(2 * floor_div(dim_t, 2) / in_channels) + + def forward(self, bhw_shape): + B, H, W = bhw_shape + mask = torch.zeros(B, H, W).bool().to(self.proj.weight.device) + not_mask = ~mask + eps = 1e-6 + y_embed = not_mask.cumsum(1, dtype=self.dtype) + x_embed = not_mask.cumsum(2, dtype=self.dtype) + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = self.dim_t.to(mask.device) + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + pos = self.proj(pos) + + return pos + + +def build_2d_sincos_position_embedding( + patches_resolution: Union[int, Sequence[int]], + embed_dims: int, + temperature: Optional[int] = 10000., + cls_token: Optional[bool] = False) -> torch.Tensor: + """The function is to build position embedding for model to obtain the + position information of the image patches. + + Args: + patches_resolution (Union[int, Sequence[int]]): The resolution of each + patch. + embed_dims (int): The dimension of the embedding vector. + temperature (int, optional): The temperature parameter. Defaults to + 10000. + cls_token (bool, optional): Whether to concatenate class token. + Defaults to False. + + Returns: + torch.Tensor: The position embedding vector. + """ + + if isinstance(patches_resolution, int): + patches_resolution = (patches_resolution, patches_resolution) + + h, w = patches_resolution + grid_w = torch.arange(w, dtype=torch.float32) + grid_h = torch.arange(h, dtype=torch.float32) + grid_w, grid_h = torch_meshgrid(grid_w, grid_h) + assert embed_dims % 4 == 0, \ + 'Embed dimension must be divisible by 4.' + pos_dim = embed_dims // 4 + + omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim + omega = 1. / (temperature**omega) + out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega]) + out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega]) + + pos_emb = torch.cat( + [ + torch.sin(out_w), + torch.cos(out_w), + torch.sin(out_h), + torch.cos(out_h) + ], + dim=1, + )[None, :, :] + + if cls_token: + cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32) + pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1) + + return pos_emb + + +class RotaryEmbeddingFast(BaseModule): + """Implements 2D rotary embedding (RoPE) for image tokens. Position + encoding is implemented with sin and cos functions, + + .. math:: + Pos_{cos} = cos(\frac{t}{\theta^{\frac{2i}{d}}} \\ + Pos_{sin} = sin(\frac{t}{\theta^{\frac{2i}{d}}} + Args: + embed_dims (int): The feature dimension for each head. + patch_resolution (int | tuple): The resolution of the + image, in format (H, W). 
+ theta (float): The hyperparameter for position coding. + Defaults to 10000. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims, + patch_resolution, + theta=10000., + init_cfg=None): + super(RotaryEmbeddingFast, self).__init__(init_cfg=init_cfg) + + self.half_dim = embed_dims // 2 + self.patch_resolution = to_2tuple(patch_resolution) + self.theta = theta + + freqs_cos, freqs_sin = self.compute_position_embedding() + self.register_buffer('freqs_cos', freqs_cos) + self.register_buffer('freqs_sin', freqs_sin) + + def compute_position_embedding(self): + frequency = self.theta**( + torch.arange(0, self.half_dim, 2).float() / self.half_dim) + frequency = 1. / frequency + + h, w = self.patch_resolution + th = torch.arange(h) / h * self.half_dim + tw = torch.arange(w) / w * self.half_dim + + position_h = (th[:, None] @ frequency[None, :]).repeat(1, 2) + position_w = (tw[:, None] @ frequency[None, :]).repeat(1, 2) + + height = position_h[:, None, :].expand(h, w, self.half_dim) + width = position_w[None, :, :].expand(h, w, self.half_dim) + position = torch.cat((height, width), dim=-1) + + freqs_cos = position.cos().view(-1, position.shape[-1]) + freqs_sin = position.sin().view(-1, position.shape[-1]) + + return freqs_cos, freqs_sin + + def forward(self, x, patch_resolution): + # Check whether the patch resolution is the predefined size + patch_resolution = to_2tuple(patch_resolution) + if patch_resolution != self.patch_resolution: + self.patch_resolution = patch_resolution + freqs_cos, freqs_sin = self.compute_position_embedding() + self.register_buffer('freqs_cos', freqs_cos.to(x.device)) + self.register_buffer('freqs_sin', freqs_sin.to(x.device)) + + batch, num_heads, num_patches, dim = x.shape + + inputs = x + x = x.reshape(batch, num_heads, num_patches, -1, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + x = x.reshape(batch, num_heads, num_patches, dim) + + return inputs * self.freqs_cos + x * self.freqs_sin diff --git a/mmpretrain/models/utils/res_layer_extra_norm.py b/mmpretrain/models/utils/res_layer_extra_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..37e387ba9795ec528bd210dab75bd05abdc0addf --- /dev/null +++ b/mmpretrain/models/utils/res_layer_extra_norm.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .norm import build_norm_layer + +try: + from mmdet.models.backbones import ResNet + from mmdet.models.roi_heads.shared_heads.res_layer import ResLayer + from mmdet.registry import MODELS + + @MODELS.register_module() + class ResLayerExtraNorm(ResLayer): + """Add extra norm to original ``ResLayer``.""" + + def __init__(self, *args, **kwargs): + super(ResLayerExtraNorm, self).__init__(*args, **kwargs) + + block = ResNet.arch_settings[kwargs['depth']][0] + self.add_module( + 'norm', + build_norm_layer(self.norm_cfg, + 64 * 2**self.stage * block.expansion)) + + def forward(self, x): + """Forward function.""" + res_layer = getattr(self, f'layer{self.stage + 1}') + norm = getattr(self, 'norm') + x = res_layer(x) + out = norm(x) + return out + +except ImportError: + ResLayerExtraNorm = None diff --git a/mmpretrain/models/utils/se_layer.py b/mmpretrain/models/utils/se_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..20290171008c2fd6f7a9e14e444f23b8375abe22 --- /dev/null +++ b/mmpretrain/models/utils/se_layer.py @@ -0,0 +1,80 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from mmengine.utils import is_tuple_of + +from .make_divisible import make_divisible + + +class SELayer(BaseModule): + """Squeeze-and-Excitation Module. + + Args: + channels (int): The input (and output) channels of the SE layer. + squeeze_channels (None or int): The intermediate channel number of + SElayer. Default: None, means the value of ``squeeze_channels`` + is ``make_divisible(channels // ratio, divisor)``. + ratio (int): Squeeze ratio in SELayer, the intermediate channel will + be ``make_divisible(channels // ratio, divisor)``. Only used when + ``squeeze_channels`` is None. Default: 16. + divisor(int): The divisor to true divide the channel number. Only + used when ``squeeze_channels`` is None. Default: 8. + conv_cfg (None or dict): Config dict for convolution layer. Default: + None, which means using conv2d. + return_weight(bool): Whether to return the weight. Default: False. + act_cfg (dict or Sequence[dict]): Config dict for activation layer. + If act_cfg is a dict, two activation layers will be configurated + by this dict. If act_cfg is a sequence of dicts, the first + activation layer will be configurated by the first dict and the + second activation layer will be configurated by the second dict. + Default: (dict(type='ReLU'), dict(type='Sigmoid')) + """ + + def __init__(self, + channels, + squeeze_channels=None, + ratio=16, + divisor=8, + bias='auto', + conv_cfg=None, + act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')), + return_weight=False, + init_cfg=None): + super(SELayer, self).__init__(init_cfg) + if isinstance(act_cfg, dict): + act_cfg = (act_cfg, act_cfg) + assert len(act_cfg) == 2 + assert is_tuple_of(act_cfg, dict) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + if squeeze_channels is None: + squeeze_channels = make_divisible(channels // ratio, divisor) + assert isinstance(squeeze_channels, int) and squeeze_channels > 0, \ + '"squeeze_channels" should be a positive integer, but get ' + \ + f'{squeeze_channels} instead.' + self.return_weight = return_weight + self.conv1 = ConvModule( + in_channels=channels, + out_channels=squeeze_channels, + kernel_size=1, + stride=1, + bias=bias, + conv_cfg=conv_cfg, + act_cfg=act_cfg[0]) + self.conv2 = ConvModule( + in_channels=squeeze_channels, + out_channels=channels, + kernel_size=1, + stride=1, + bias=bias, + conv_cfg=conv_cfg, + act_cfg=act_cfg[1]) + + def forward(self, x): + out = self.global_avgpool(x) + out = self.conv1(out) + out = self.conv2(out) + if self.return_weight: + return out + else: + return x * out diff --git a/mmpretrain/models/utils/sparse_modules.py b/mmpretrain/models/utils/sparse_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..dd6bf345399bbb9c1c2ec4af6c19cfe7adf9beb6 --- /dev/null +++ b/mmpretrain/models/utils/sparse_modules.py @@ -0,0 +1,149 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) ByteDance, Inc. and its affiliates. All rights reserved. 
+# Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS + + +class SparseHelper: + """The helper to compute sparse operation with pytorch, such as sparse + convlolution, sparse batch norm, etc.""" + + _cur_active: torch.Tensor = None + + @staticmethod + def _get_active_map_or_index(H: int, + returning_active_map: bool = True + ) -> torch.Tensor: + """Get current active map with (B, 1, f, f) shape or index format.""" + # _cur_active with shape (B, 1, f, f) + downsample_raito = H // SparseHelper._cur_active.shape[-1] + active_ex = SparseHelper._cur_active.repeat_interleave( + downsample_raito, 2).repeat_interleave(downsample_raito, 3) + return active_ex if returning_active_map else active_ex.squeeze( + 1).nonzero(as_tuple=True) + + @staticmethod + def sp_conv_forward(self, x: torch.Tensor) -> torch.Tensor: + """Sparse convolution forward function.""" + x = super(type(self), self).forward(x) + + # (b, c, h, w) *= (b, 1, h, w), mask the output of conv + x *= SparseHelper._get_active_map_or_index( + H=x.shape[2], returning_active_map=True) + return x + + @staticmethod + def sp_bn_forward(self, x: torch.Tensor) -> torch.Tensor: + """Sparse batch norm forward function.""" + active_index = SparseHelper._get_active_map_or_index( + H=x.shape[2], returning_active_map=False) + + # (b, c, h, w) -> (b, h, w, c) + x_permuted = x.permute(0, 2, 3, 1) + + # select the features on non-masked positions to form flatten features + # with shape (n, c) + x_flattened = x_permuted[active_index] + + # use BN1d to normalize this flatten feature (n, c) + x_flattened = super(type(self), self).forward(x_flattened) + + # generate output + output = torch.zeros_like(x_permuted, dtype=x_flattened.dtype) + output[active_index] = x_flattened + + # (b, h, w, c) -> (b, c, h, w) + output = output.permute(0, 3, 1, 2) + return output + + +class SparseConv2d(nn.Conv2d): + """hack: override the forward function. + See `sp_conv_forward` above for more details + """ + forward = SparseHelper.sp_conv_forward + + +class SparseMaxPooling(nn.MaxPool2d): + """hack: override the forward function. + See `sp_conv_forward` above for more details + """ + forward = SparseHelper.sp_conv_forward + + +class SparseAvgPooling(nn.AvgPool2d): + """hack: override the forward function. + See `sp_conv_forward` above for more details + """ + forward = SparseHelper.sp_conv_forward + + +@MODELS.register_module() +class SparseBatchNorm2d(nn.BatchNorm1d): + """hack: override the forward function. + See `sp_bn_forward` above for more details + """ + forward = SparseHelper.sp_bn_forward + + +@MODELS.register_module() +class SparseSyncBatchNorm2d(nn.SyncBatchNorm): + """hack: override the forward function. + See `sp_bn_forward` above for more details + """ + forward = SparseHelper.sp_bn_forward + + +@MODELS.register_module('SparseLN2d') +class SparseLayerNorm2D(nn.LayerNorm): + """Implementation of sparse LayerNorm on channels for 2d images.""" + + def forward(self, + x: torch.Tensor, + data_format='channel_first') -> torch.Tensor: + """Sparse layer norm forward function with 2D data. + + Args: + x (torch.Tensor): The input tensor. + data_format (str): The format of the input tensor. If + ``"channel_first"``, the shape of the input tensor should be + (B, C, H, W). If ``"channel_last"``, the shape of the input + tensor should be (B, H, W, C). Defaults to "channel_first". 
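+
+        Returns:
+            torch.Tensor: The normalized tensor; only positions marked
+                active in ``SparseHelper._cur_active`` are normalized, and
+                masked positions are filled with zeros.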
+ """ + assert x.dim() == 4, ( + f'LayerNorm2d only supports inputs with shape ' + f'(N, C, H, W), but got tensor with shape {x.shape}') + if data_format == 'channel_last': + index = SparseHelper._get_active_map_or_index( + H=x.shape[1], returning_active_map=False) + + # select the features on non-masked positions to form flatten + # features with shape (n, c) + x_flattened = x[index] + # use LayerNorm to normalize this flatten feature (n, c) + x_flattened = super().forward(x_flattened) + + # generate output + x = torch.zeros_like(x, dtype=x_flattened.dtype) + x[index] = x_flattened + elif data_format == 'channel_first': + index = SparseHelper._get_active_map_or_index( + H=x.shape[2], returning_active_map=False) + x_permuted = x.permute(0, 2, 3, 1) + + # select the features on non-masked positions to form flatten + # features with shape (n, c) + x_flattened = x_permuted[index] + # use LayerNorm to normalize this flatten feature (n, c) + x_flattened = super().forward(x_flattened) + + # generate output + x = torch.zeros_like(x_permuted, dtype=x_flattened.dtype) + x[index] = x_flattened + x = x.permute(0, 3, 1, 2).contiguous() + else: + raise NotImplementedError + return x diff --git a/mmpretrain/models/utils/swiglu_ffn.py b/mmpretrain/models/utils/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..20b4591f4f09ae185dd28e432dff7919d98d3a50 --- /dev/null +++ b/mmpretrain/models/utils/swiglu_ffn.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks.drop import build_dropout + +from .layer_scale import LayerScale +from .norm import build_norm_layer + + +class SwiGLUFFN(nn.Module): + """SwiGLU FFN layer. + + Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py + """ # noqa + + def __init__( + self, + embed_dims: int, + feedforward_channels: Optional[int] = None, + out_dims: Optional[int] = None, + layer_scale_init_value: float = 0., + bias: bool = True, + dropout_layer: Optional[dict] = None, + norm_cfg: Optional[dict] = None, + add_identity: bool = True, + ) -> None: + super().__init__() + self.embed_dims = embed_dims + self.out_dims = out_dims or embed_dims + hidden_dims = feedforward_channels or embed_dims + + self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias) + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, hidden_dims) + else: + self.norm = nn.Identity() + + self.w3 = nn.Linear(hidden_dims, self.out_dims, bias=bias) + + if layer_scale_init_value > 0: + self.gamma2 = LayerScale( + dim=embed_dims, layer_scale_init_value=layer_scale_init_value) + else: + self.gamma2 = nn.Identity() + + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else torch.nn.Identity() + self.add_identity = add_identity + + def forward(self, + x: torch.Tensor, + identity: Optional[torch.Tensor] = None) -> torch.Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + hidden = self.norm(hidden) + out = self.w3(hidden) + out = self.gamma2(out) + out = self.dropout_layer(out) + + if self.out_dims != self.embed_dims or not self.add_identity: + # due to the dimension inconsistence or user setting + # not to apply residual operation + return out + + if identity is None: + identity = x + return identity + out + + +class SwiGLUFFNFused(SwiGLUFFN): + """SwiGLU FFN layer with fusing. 
+ + Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py + """ # noqa + + def __init__( + self, + embed_dims: int, + feedforward_channels: Optional[int] = None, + out_dims: Optional[int] = None, + layer_scale_init_value: float = 0., + bias: bool = True, + ) -> None: + out_dims = out_dims or embed_dims + feedforward_channels = feedforward_channels or embed_dims + feedforward_channels = (int(feedforward_channels * 2 / 3) + 7) // 8 * 8 + super().__init__( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + out_dims=out_dims, + layer_scale_init_value=layer_scale_init_value, + bias=bias, + ) diff --git a/mmpretrain/models/utils/tokenizer.py b/mmpretrain/models/utils/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5b8a324bad00ff03a9ce24dc4cff222e379f1520 --- /dev/null +++ b/mmpretrain/models/utils/tokenizer.py @@ -0,0 +1,187 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import collections +import os + +from mmengine.fileio import list_from_file +from transformers import (AutoTokenizer, BartTokenizer, BasicTokenizer, + BertTokenizer, BertTokenizerFast, LlamaTokenizer, + WordpieceTokenizer) + +from mmpretrain.registry import TOKENIZER +from .huggingface import register_hf_tokenizer + +register_hf_tokenizer(AutoTokenizer) +register_hf_tokenizer(LlamaTokenizer) + + +@register_hf_tokenizer() +class BlipTokenizer(BertTokenizerFast): + """BlipTokenizer inherits BertTokenizerFast (fast, Rust-based).""" + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ): + os.environ['TOKENIZERS_PARALLELISM'] = 'true' + + tokenizer = super().from_pretrained( + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ) + tokenizer.add_special_tokens({'bos_token': '[DEC]'}) + tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']}) + return tokenizer + + +@register_hf_tokenizer() +class Blip2Tokenizer(BertTokenizer): + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ): + tokenizer = super().from_pretrained( + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ) + tokenizer.add_special_tokens({'bos_token': '[DEC]'}) + return tokenizer + + +@register_hf_tokenizer() +class OFATokenizer(BartTokenizer): + + vocab_files_names = { + 'vocab_file': 'vocab.json', + 'merges_file': 'merges.txt' + } + + pretrained_vocab_files_map = { + 'vocab_file': { + 'OFA-Sys/OFA-tiny': + 'https://huggingface.co/OFA-Sys/OFA-tiny/blob/main/vocab.json', + 'OFA-Sys/OFA-medium': + 'https://huggingface.co/OFA-Sys/OFA-medium/blob/main/vocab.json', + 'OFA-Sys/OFA-base': + 'https://huggingface.co/OFA-Sys/OFA-base/blob/main/vocab.json', + 'OFA-Sys/OFA-large': + 'https://huggingface.co/OFA-Sys/OFA-large/blob/main/vocab.json', + }, + 'merges_file': { + 'OFA-Sys/OFA-tiny': + 'https://huggingface.co/OFA-Sys/OFA-tiny/blob/main/merges.txt', + 'OFA-Sys/OFA-medium': + 'https://huggingface.co/OFA-Sys/OFA-medium/blob/main/merges.txt', + 'OFA-Sys/OFA-base': + 'https://huggingface.co/OFA-Sys/OFA-base/blob/main/merges.txt', + 'OFA-Sys/OFA-large': + 'https://huggingface.co/OFA-Sys/OFA-large/blob/main/merges.txt', + }, + } + + max_model_input_sizes = { + 'OFA-Sys/OFA-tiny': 1024, + 'OFA-Sys/OFA-medium': 1024, + 'OFA-Sys/OFA-base': 1024, + 'OFA-Sys/OFA-large': 1024, + } + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ): + num_bins = kwargs.pop('num_bins', 1000) +
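# NOTE (added commentary, not in the upstream OFA source): the block below
# extends the BART vocabulary with 8192 discrete image-code tokens followed by
# `num_bins` coordinate-bin tokens; `code_offset` and `bin_offset` record where
# each group starts so token ids can later be mapped back to code/bin indices.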
tokenizer = super().from_pretrained( + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ) + length = len(tokenizer) + tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)]) + tokenizer.code_offset = length + tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(num_bins)]) + tokenizer.bin_offset = length + 8192 + tokenizer.num_bins = num_bins + return tokenizer + + +@TOKENIZER.register_module() +class FullTokenizer(BertTokenizer): + """Runs end-to-end tokenization.""" + + def __init__(self, vocab_file, do_lower_case=True): + self.vocab = self.load_vocab(vocab_file) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.wordpiece_tokenizer = WordpieceTokenizer( + vocab=self.vocab, unk_token='[UNK]', max_input_chars_per_word=200) + + def load_vocab(self, vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + vocab_list = list_from_file(vocab_file) + for token in vocab_list: + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + def tokenize(self, text): + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + + return split_tokens + + def convert_by_vocab(self, vocab, items): + """Converts a sequence of [tokens|ids] using the vocab.""" + output = [] + for item in items: + output.append(vocab[item]) + return output + + def convert_tokens_to_ids(self, tokens): + return self.convert_by_vocab(self.vocab, tokens) + + def convert_ids_to_tokens(self, ids): + return self.convert_by_vocab(self.inv_vocab, ids) + + @staticmethod + def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True): + """Converts a sequence of tokens (string) into a single string.""" + + def clean_up_tokenization(out_string): + """Clean up a list of simple English tokenization artifacts like + spaces before punctuation and abbreviated forms.""" + out_string = ( + out_string.replace(' .', '.').replace(' ?', '?').replace( + ' !', '!').replace(' ,', ',').replace(" ' ", "'").replace( + " n't", "n't").replace(" 'm", "'m").replace( + " 's", "'s").replace(" 've", + "'ve").replace(" 're", "'re")) + return out_string + + text = ' '.join(tokens).replace(' ##', '').strip() + if clean_up_tokenization_spaces: + clean_text = clean_up_tokenization(text) + return clean_text + else: + return text + + def vocab_size(self): + return len(self.vocab) diff --git a/mmpretrain/models/utils/vector_quantizer.py b/mmpretrain/models/utils/vector_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7c2ea89339e190d0d19bf5c89b60c1d4bab8fad5 --- /dev/null +++ b/mmpretrain/models/utils/vector_quantizer.py @@ -0,0 +1,232 @@ +# Copyright (c) OpenMMLab. All rights reserved.
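Before the implementation, a minimal numeric sketch (illustrative values only) of the EMA update rule that the `ema_inplace`/`norm_ema_inplace` helpers defined below apply to the codebook:

import torch
import torch.nn.functional as F

avg = torch.ones(4)      # moving average, e.g. one codebook row
new = torch.zeros(4)     # fresh statistics computed from the batch
decay = 0.99
# avg <- decay * avg + (1 - decay) * new, exactly as in ema_inplace
avg.mul_(decay).add_(new, alpha=1 - decay)   # -> tensor([0.99, 0.99, 0.99, 0.99])
# the "norm" variant additionally re-projects onto the unit sphere
avg = F.normalize(avg, p=2, dim=-1)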
+# Copyright (c) 2022 Microsoft +# Modified from +# https://github.com/microsoft/unilm/blob/master/beit2/norm_ema_quantizer.py +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from mmengine.dist import all_reduce + + +def ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, + decay: torch.Tensor) -> None: + """Update moving average.""" + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def norm_ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, + decay: torch.Tensor) -> None: + """Update moving average with norm data.""" + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + moving_avg.data.copy_(F.normalize(moving_avg.data, p=2, dim=-1)) + + +def sample_vectors(samples: torch.Tensor, num: int) -> torch.Tensor: + """Sample vectors according to the given number.""" + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num, ), device=device) + + return samples[indices] + + +def kmeans(samples: torch.Tensor, + num_clusters: int, + num_iters: int = 10, + use_cosine_sim: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + """Run k-means algorithm.""" + dim, dtype, _ = samples.shape[-1], samples.dtype, samples.device + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + if use_cosine_sim: + dists = samples @ means.t() + else: + diffs = rearrange(samples, 'n d -> n () d') \ + - rearrange(means, 'c d -> () c d') + dists = -(diffs**2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + if use_cosine_sim: + new_means = F.normalize(new_means, p=2, dim=-1) + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EmbeddingEMA(nn.Module): + """The codebook of embedding vectors. + + Args: + num_tokens (int): Number of embedding vectors in the codebook. + codebook_dim (int) : The dimension of embedding vectors in the + codebook. + kmeans_init (bool): Whether to use k-means to initialize the + VectorQuantizer. Defaults to True. + codebook_init_path (str): The initialization checkpoint for codebook. + Defaults to None. 
+ """ + + def __init__(self, + num_tokens: int, + codebook_dim: int, + kmeans_init: bool = True, + codebook_init_path: Optional[str] = None): + super().__init__() + self.num_tokens = num_tokens + self.codebook_dim = codebook_dim + if codebook_init_path is None: + if not kmeans_init: + weight = torch.randn(num_tokens, codebook_dim) + weight = F.normalize(weight, p=2, dim=-1) + else: + weight = torch.zeros(num_tokens, codebook_dim) + self.register_buffer('initted', torch.Tensor([not kmeans_init])) + else: + print(f'load init codebook weight from {codebook_init_path}') + codebook_ckpt_weight = torch.load( + codebook_init_path, map_location='cpu') + weight = codebook_ckpt_weight.clone() + self.register_buffer('initted', torch.Tensor([True])) + + self.weight = nn.Parameter(weight, requires_grad=False) + self.update = True + + @torch.jit.ignore + def init_embed_(self, data: torch.Tensor) -> None: + """Initialize embedding vectors of codebook.""" + if self.initted: + return + print('Performing K-means init for codebook') + embed, _ = kmeans(data, self.num_tokens, 10, use_cosine_sim=True) + self.weight.data.copy_(embed) + self.initted.data.copy_(torch.Tensor([True])) + + def forward(self, embed_id: torch.Tensor) -> torch.Tensor: + """Get embedding vectors.""" + return F.embedding(embed_id, self.weight) + + +class NormEMAVectorQuantizer(nn.Module): + """Normed EMA vector quantizer module. + + Args: + num_embed (int): Number of embedding vectors in the codebook. Defaults + to 8192. + embed_dims (int) : The dimension of embedding vectors in the codebook. + Defaults to 32. + beta (float): The mutiplier for VectorQuantizer embedding loss. + Defaults to 1. + decay (float): The decay parameter of EMA. Defaults to 0.99. + statistic_code_usage (bool): Whether to use cluster_size to record + statistic. Defaults to True. + kmeans_init (bool): Whether to use k-means to initialize the + VectorQuantizer. Defaults to True. + codebook_init_path (str): The initialization checkpoint for codebook. + Defaults to None. 
+ """ + + def __init__(self, + num_embed: int, + embed_dims: int, + beta: float, + decay: float = 0.99, + statistic_code_usage: bool = True, + kmeans_init: bool = True, + codebook_init_path: Optional[str] = None) -> None: + super().__init__() + self.codebook_dim = embed_dims + self.num_tokens = num_embed + self.beta = beta + self.decay = decay + + # learnable = True if orthogonal_reg_weight > 0 else False + self.embedding = EmbeddingEMA( + num_tokens=self.num_tokens, + codebook_dim=self.codebook_dim, + kmeans_init=kmeans_init, + codebook_init_path=codebook_init_path) + + self.statistic_code_usage = statistic_code_usage + if statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(num_embed)) + + def reset_cluster_size(self, device): + + if self.statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(self.num_tokens)) + self.cluster_size = self.cluster_size.to(device) + + def forward(self, z): + """Forward function.""" + # reshape z -> (batch, height, width, channel) + z = rearrange(z, 'b c h w -> b h w c') + z = F.normalize(z, p=2, dim=-1) + z_flattened = z.reshape(-1, self.codebook_dim) + + self.embedding.init_embed_(z_flattened) + + # 'n d -> d n' + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + + if not self.training: + with torch.no_grad(): + cluster_size = encodings.sum(0) + all_reduce(cluster_size) + ema_inplace(self.cluster_size, cluster_size, self.decay) + + if self.training and self.embedding.update: + # update cluster size with EMA + bins = encodings.sum(0) + all_reduce(bins) + ema_inplace(self.cluster_size, bins, self.decay) + + zero_mask = (bins == 0) + bins = bins.masked_fill(zero_mask, 1.) + + embed_sum = z_flattened.t() @ encodings + all_reduce(embed_sum) + + embed_normalized = (embed_sum / bins.unsqueeze(0)).t() + embed_normalized = F.normalize(embed_normalized, p=2, dim=-1) + embed_normalized = torch.where(zero_mask[..., None], + self.embedding.weight, + embed_normalized) + + # Update embedding vectors with EMA + norm_ema_inplace(self.embedding.weight, embed_normalized, + self.decay) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w') + return z_q, loss, encoding_indices diff --git a/mmpretrain/registry.py b/mmpretrain/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..cac2bdad725b9adf5c345d58e5e4a0320b3ddcd4 --- /dev/null +++ b/mmpretrain/registry.py @@ -0,0 +1,195 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""MMPretrain provides 21 registry nodes to support using modules across +projects. Each node is a child of the root registry in MMEngine. + +More details can be found at +https://mmengine.readthedocs.io/en/latest/tutorials/registry.html. 
+""" + +from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS +from mmengine.registry import DATASETS as MMENGINE_DATASETS +from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR +from mmengine.registry import HOOKS as MMENGINE_HOOKS +from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS +from mmengine.registry import LOOPS as MMENGINE_LOOPS +from mmengine.registry import METRICS as MMENGINE_METRICS +from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS +from mmengine.registry import MODELS as MMENGINE_MODELS +from mmengine.registry import \ + OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS +from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS +from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS +from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS +from mmengine.registry import \ + RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS +from mmengine.registry import RUNNERS as MMENGINE_RUNNERS +from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS +from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS +from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS +from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS +from mmengine.registry import \ + WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS +from mmengine.registry import Registry + +__all__ = [ + 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'LOOPS', 'HOOKS', 'LOG_PROCESSORS', + 'OPTIMIZERS', 'OPTIM_WRAPPERS', 'OPTIM_WRAPPER_CONSTRUCTORS', + 'PARAM_SCHEDULERS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', + 'MODEL_WRAPPERS', 'WEIGHT_INITIALIZERS', 'BATCH_AUGMENTS', 'TASK_UTILS', + 'METRICS', 'EVALUATORS', 'VISUALIZERS', 'VISBACKENDS' +] + +####################################################################### +# mmpretrain.engine # +####################################################################### + +# Runners like `EpochBasedRunner` and `IterBasedRunner` +RUNNERS = Registry( + 'runner', + parent=MMENGINE_RUNNERS, + locations=['mmpretrain.engine'], +) +# Runner constructors that define how to initialize runners +RUNNER_CONSTRUCTORS = Registry( + 'runner constructor', + parent=MMENGINE_RUNNER_CONSTRUCTORS, + locations=['mmpretrain.engine'], +) +# Loops which define the training or test process, like `EpochBasedTrainLoop` +LOOPS = Registry( + 'loop', + parent=MMENGINE_LOOPS, + locations=['mmpretrain.engine'], +) +# Hooks to add additional functions during running, like `CheckpointHook` +HOOKS = Registry( + 'hook', + parent=MMENGINE_HOOKS, + locations=['mmpretrain.engine'], +) +# Log processors to process the scalar log data. +LOG_PROCESSORS = Registry( + 'log processor', + parent=MMENGINE_LOG_PROCESSORS, + locations=['mmpretrain.engine'], +) +# Optimizers to optimize the model weights, like `SGD` and `Adam`. +OPTIMIZERS = Registry( + 'optimizer', + parent=MMENGINE_OPTIMIZERS, + locations=['mmpretrain.engine'], +) +# Optimizer wrappers to enhance the optimization process. +OPTIM_WRAPPERS = Registry( + 'optimizer_wrapper', + parent=MMENGINE_OPTIM_WRAPPERS, + locations=['mmpretrain.engine'], +) +# Optimizer constructors to customize the hyperparameters of optimizers. +OPTIM_WRAPPER_CONSTRUCTORS = Registry( + 'optimizer wrapper constructor', + parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS, + locations=['mmpretrain.engine'], +) +# Parameter schedulers to dynamically adjust optimization parameters. 
+PARAM_SCHEDULERS = Registry( + 'parameter scheduler', + parent=MMENGINE_PARAM_SCHEDULERS, + locations=['mmpretrain.engine'], +) + +####################################################################### +# mmpretrain.datasets # +####################################################################### + +# Datasets like `ImageNet` and `CIFAR10`. +DATASETS = Registry( + 'dataset', + parent=MMENGINE_DATASETS, + locations=['mmpretrain.datasets'], +) +# Samplers to sample the dataset. +DATA_SAMPLERS = Registry( + 'data sampler', + parent=MMENGINE_DATA_SAMPLERS, + locations=['mmpretrain.datasets'], +) +# Transforms to process the samples from the dataset. +TRANSFORMS = Registry( + 'transform', + parent=MMENGINE_TRANSFORMS, + locations=['mmpretrain.datasets'], +) + +####################################################################### +# mmpretrain.models # +####################################################################### + +# Neural network modules inheriting `nn.Module`. +MODELS = Registry( + 'model', + parent=MMENGINE_MODELS, + locations=['mmpretrain.models'], +) +# Model wrappers like 'MMDistributedDataParallel' +MODEL_WRAPPERS = Registry( + 'model_wrapper', + parent=MMENGINE_MODEL_WRAPPERS, + locations=['mmpretrain.models'], +) +# Weight initialization methods like uniform, xavier. +WEIGHT_INITIALIZERS = Registry( + 'weight initializer', + parent=MMENGINE_WEIGHT_INITIALIZERS, + locations=['mmpretrain.models'], +) +# Batch augmentations like `Mixup` and `CutMix`. +BATCH_AUGMENTS = Registry( + 'batch augment', + locations=['mmpretrain.models'], +) +# Task-specific modules like anchor generators and box coders +TASK_UTILS = Registry( + 'task util', + parent=MMENGINE_TASK_UTILS, + locations=['mmpretrain.models'], +) +# Tokenizer to encode sequence +TOKENIZER = Registry( + 'tokenizer', + locations=['mmpretrain.models'], +) + +####################################################################### +# mmpretrain.evaluation # +####################################################################### + +# Metrics to evaluate the model prediction results. +METRICS = Registry( + 'metric', + parent=MMENGINE_METRICS, + locations=['mmpretrain.evaluation'], +) +# Evaluators to define the evaluation process. +EVALUATORS = Registry( + 'evaluator', + parent=MMENGINE_EVALUATOR, + locations=['mmpretrain.evaluation'], +) + +####################################################################### +# mmpretrain.visualization # +####################################################################### + +# Visualizers to display task-specific results. +VISUALIZERS = Registry( + 'visualizer', + parent=MMENGINE_VISUALIZERS, + locations=['mmpretrain.visualization'], +) +# Backends to save the visualization results, like TensorBoard, WandB. +VISBACKENDS = Registry( + 'vis_backend', + parent=MMENGINE_VISBACKENDS, + locations=['mmpretrain.visualization'], +) diff --git a/mmpretrain/structures/__init__.py b/mmpretrain/structures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e7de863087d9d07800ff119d3c8b941059ef3886 --- /dev/null +++ b/mmpretrain/structures/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
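A minimal usage sketch of the registry nodes defined above; the `TinyHead` module and its config are hypothetical, for illustration only:

import torch.nn as nn
from mmpretrain.registry import MODELS

@MODELS.register_module()
class TinyHead(nn.Module):
    """Hypothetical module, registered only to illustrate the mechanism."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.fc = nn.LazyLinear(num_classes)

    def forward(self, x):
        return self.fc(x)

# Configs refer to modules by their registered name; the runner builds them
# through the registry like this:
head = MODELS.build(dict(type='TinyHead', num_classes=5))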
+from .data_sample import DataSample +from .multi_task_data_sample import MultiTaskDataSample +from .utils import (batch_label_to_onehot, cat_batch_labels, format_label, + format_score, label_to_onehot, tensor_split) + +__all__ = [ + 'DataSample', 'batch_label_to_onehot', 'cat_batch_labels', 'tensor_split', + 'MultiTaskDataSample', 'label_to_onehot', 'format_label', 'format_score' +] diff --git a/mmpretrain/structures/__pycache__/__init__.cpython-311.pyc b/mmpretrain/structures/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05b28c26b9670898a9b7fe4239e2e8f8bc34e9d8 Binary files /dev/null and b/mmpretrain/structures/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmpretrain/structures/__pycache__/data_sample.cpython-311.pyc b/mmpretrain/structures/__pycache__/data_sample.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d44e7d2a404c8a2e8ffb3e7054624372819ce19 Binary files /dev/null and b/mmpretrain/structures/__pycache__/data_sample.cpython-311.pyc differ diff --git a/mmpretrain/structures/__pycache__/multi_task_data_sample.cpython-311.pyc b/mmpretrain/structures/__pycache__/multi_task_data_sample.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93d4f07e8d19eca2eed1138cc4e321e012c67071 Binary files /dev/null and b/mmpretrain/structures/__pycache__/multi_task_data_sample.cpython-311.pyc differ diff --git a/mmpretrain/structures/__pycache__/utils.cpython-311.pyc b/mmpretrain/structures/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8912f29a08ad05b39983ad14e3e4d3a3bbf7b4e0 Binary files /dev/null and b/mmpretrain/structures/__pycache__/utils.cpython-311.pyc differ diff --git a/mmpretrain/structures/data_sample.py b/mmpretrain/structures/data_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..ce588b8ba13811afdb2bb3300d42f221a6f2df7f --- /dev/null +++ b/mmpretrain/structures/data_sample.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from multiprocessing.reduction import ForkingPickler +from typing import Union + +import numpy as np +import torch +from mmengine.structures import BaseDataElement + +from .utils import LABEL_TYPE, SCORE_TYPE, format_label, format_score + + +class DataSample(BaseDataElement): + """A general data structure interface. + + It's used as the interface between different components. + + The following fields are convention names in MMPretrain, and we will set or + get these fields in data transforms, models, and metrics if needed. You can + also set any new fields for your need. + + Meta fields: + img_shape (Tuple): The shape of the corresponding input image. + ori_shape (Tuple): The original shape of the corresponding image. + sample_idx (int): The index of the sample in the dataset. + num_classes (int): The number of all categories. + + Data fields: + gt_label (tensor): The ground truth label. + gt_score (tensor): The ground truth score. + pred_label (tensor): The predicted label. + pred_score (tensor): The predicted score. + mask (tensor): The mask used in masked image modeling. 
+ + Examples: + >>> import torch + >>> from mmpretrain.structures import DataSample + >>> + >>> img_meta = dict(img_shape=(960, 720), num_classes=5) + >>> data_sample = DataSample(metainfo=img_meta) + >>> data_sample.set_gt_label(3) + >>> print(data_sample) + + >>> + >>> # For multi-label data + >>> data_sample = DataSample().set_gt_label([0, 1, 4]) + >>> print(data_sample) + + >>> + >>> # Set one-hot format score + >>> data_sample = DataSample().set_pred_score([0.1, 0.1, 0.6, 0.1]) + >>> print(data_sample) + + >>> + >>> # Set custom field + >>> data_sample = DataSample() + >>> data_sample.my_field = [1, 2, 3] + >>> print(data_sample) + + >>> print(data_sample.my_field) + [1, 2, 3] + """ + + def set_gt_label(self, value: LABEL_TYPE) -> 'DataSample': + """Set ``gt_label``.""" + self.set_field(format_label(value), 'gt_label', dtype=torch.Tensor) + return self + + def set_gt_score(self, value: SCORE_TYPE) -> 'DataSample': + """Set ``gt_score``.""" + score = format_score(value) + self.set_field(score, 'gt_score', dtype=torch.Tensor) + if hasattr(self, 'num_classes'): + assert len(score) == self.num_classes, \ + f'The length of score {len(score)} should be '\ + f'equal to the num_classes {self.num_classes}.' + else: + self.set_field( + name='num_classes', value=len(score), field_type='metainfo') + return self + + def set_pred_label(self, value: LABEL_TYPE) -> 'DataSample': + """Set ``pred_label``.""" + self.set_field(format_label(value), 'pred_label', dtype=torch.Tensor) + return self + + def set_pred_score(self, value: SCORE_TYPE): + """Set ``pred_score``.""" + score = format_score(value) + self.set_field(score, 'pred_score', dtype=torch.Tensor) + if hasattr(self, 'num_classes'): + assert len(score) == self.num_classes, \ + f'The length of score {len(score)} should be '\ + f'equal to the num_classes {self.num_classes}.' + else: + self.set_field( + name='num_classes', value=len(score), field_type='metainfo') + return self + + def set_mask(self, value: Union[torch.Tensor, np.ndarray]): + if isinstance(value, np.ndarray): + value = torch.from_numpy(value) + elif not isinstance(value, torch.Tensor): + raise TypeError(f'Invalid mask type {type(value)}') + self.set_field(value, 'mask', dtype=torch.Tensor) + return self + + def __repr__(self) -> str: + """Represent the object.""" + + def dump_items(items, prefix=''): + return '\n'.join(f'{prefix}{k}: {v}' for k, v in items) + + repr_ = '' + if len(self._metainfo_fields) > 0: + repr_ += '\n\nMETA INFORMATION\n' + repr_ += dump_items(self.metainfo_items(), prefix=' ' * 4) + if len(self._data_fields) > 0: + repr_ += '\n\nDATA FIELDS\n' + repr_ += dump_items(self.items(), prefix=' ' * 4) + + repr_ = f'<{self.__class__.__name__}({repr_}\n\n) at {hex(id(self))}>' + return repr_ + + +def _reduce_datasample(data_sample): + """reduce DataSample.""" + attr_dict = data_sample.__dict__ + convert_keys = [] + for k, v in attr_dict.items(): + if isinstance(v, torch.Tensor): + attr_dict[k] = v.numpy() + convert_keys.append(k) + return _rebuild_datasample, (attr_dict, convert_keys) + + +def _rebuild_datasample(attr_dict, convert_keys): + """rebuild DataSample.""" + data_sample = DataSample() + for k in convert_keys: + attr_dict[k] = torch.from_numpy(attr_dict[k]) + data_sample.__dict__ = attr_dict + return data_sample + + +# Due to the multi-processing strategy of PyTorch, DataSample may consume many +# file descriptors because it contains multiple tensors.
Here we overwrite the +# reduce function of DataSample in ForkingPickler and convert these tensors to +# np.ndarray during pickling. It may slightly influence the performance of +# dataloader. +ForkingPickler.register(DataSample, _reduce_datasample) diff --git a/mmpretrain/structures/multi_task_data_sample.py b/mmpretrain/structures/multi_task_data_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..f00993861bfb4f35fb7d145198f81c5e9f0a5993 --- /dev/null +++ b/mmpretrain/structures/multi_task_data_sample.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from mmengine.structures import BaseDataElement + + +class MultiTaskDataSample(BaseDataElement): + + @property + def tasks(self): + return self._data_fields diff --git a/mmpretrain/structures/utils.py b/mmpretrain/structures/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a4f9e95ef6fd557b9d0bdf5f017a7b73ba250453 --- /dev/null +++ b/mmpretrain/structures/utils.py @@ -0,0 +1,153 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Sequence, Union + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.utils import is_str + +if hasattr(torch, 'tensor_split'): + tensor_split = torch.tensor_split +else: + # A simple implementation of `tensor_split`. + def tensor_split(input: torch.Tensor, indices: list): + outs = [] + for start, end in zip([0] + indices, indices + [input.size(0)]): + outs.append(input[start:end]) + return outs + + +LABEL_TYPE = Union[torch.Tensor, np.ndarray, Sequence, int] +SCORE_TYPE = Union[torch.Tensor, np.ndarray, Sequence] + + +def format_label(value: LABEL_TYPE) -> torch.Tensor: + """Convert various python types to label-format tensor. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int`. + + Args: + value (torch.Tensor | numpy.ndarray | Sequence | int): Label value. + + Returns: + :obj:`torch.Tensor`: The formatted label tensor. + """ + + # Handle single number + if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0: + value = int(value.item()) + + if isinstance(value, np.ndarray): + value = torch.from_numpy(value).to(torch.long) + elif isinstance(value, Sequence) and not is_str(value): + value = torch.tensor(value).to(torch.long) + elif isinstance(value, int): + value = torch.LongTensor([value]) + elif not isinstance(value, torch.Tensor): + raise TypeError(f'Type {type(value)} is not an available label type.') + assert value.ndim == 1, \ + f'The dims of value should be 1, but got {value.ndim}.' + + return value + + +def format_score(value: SCORE_TYPE) -> torch.Tensor: + """Convert various python types to score-format tensor. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`. + + Args: + value (torch.Tensor | numpy.ndarray | Sequence): Score values. + + Returns: + :obj:`torch.Tensor`: The formatted score tensor. + """ + + if isinstance(value, np.ndarray): + value = torch.from_numpy(value).float() + elif isinstance(value, Sequence) and not is_str(value): + value = torch.tensor(value).float() + elif not isinstance(value, torch.Tensor): + raise TypeError(f'Type {type(value)} is not an available score type.') + assert value.ndim == 1, \ + f'The dims of value should be 1, but got {value.ndim}.' + + return value + + +def cat_batch_labels(elements: List[torch.Tensor]): + """Concatenate a batch of label tensors into one tensor. + + Args: + elements (List[tensor]): A batch of labels.
+ + Returns: + Tuple[torch.Tensor, List[int]]: The first item is the concatenated label + tensor, and the second item is the split indices of every sample. + """ + labels = [] + splits = [0] + for element in elements: + labels.append(element) + splits.append(splits[-1] + element.size(0)) + batch_label = torch.cat(labels) + return batch_label, splits[1:-1] + + +def batch_label_to_onehot(batch_label, split_indices, num_classes): + """Convert a concatenated label tensor to onehot format. + + Args: + batch_label (torch.Tensor): A concatenated label tensor from multiple + samples. + split_indices (List[int]): The split indices of every sample. + num_classes (int): The number of classes. + + Returns: + torch.Tensor: The onehot format label tensor. + + Examples: + >>> import torch + >>> from mmpretrain.structures import batch_label_to_onehot + >>> # Assume a concatenated label from 3 samples. + >>> # label 1: [0, 1], label 2: [0, 2, 4], label 3: [3, 1] + >>> batch_label = torch.tensor([0, 1, 0, 2, 4, 3, 1]) + >>> split_indices = [2, 5] + >>> batch_label_to_onehot(batch_label, split_indices, num_classes=5) + tensor([[1, 1, 0, 0, 0], + [1, 0, 1, 0, 1], + [0, 1, 0, 1, 0]]) + """ + sparse_onehot_list = F.one_hot(batch_label, num_classes) + onehot_list = [ + sparse_onehot.sum(0) + for sparse_onehot in tensor_split(sparse_onehot_list, split_indices) + ] + return torch.stack(onehot_list) + + +def label_to_onehot(label: LABEL_TYPE, num_classes: int): + """Convert a label to onehot format tensor. + + Args: + label (LABEL_TYPE): Label value. + num_classes (int): The number of classes. + + Returns: + torch.Tensor: The onehot format label tensor. + + Examples: + >>> import torch + >>> from mmpretrain.structures import label_to_onehot + >>> # Single-label + >>> label_to_onehot(1, num_classes=5) + tensor([0, 1, 0, 0, 0]) + >>> # Multi-label + >>> label_to_onehot([0, 2, 3], num_classes=5) + tensor([1, 0, 1, 1, 0]) + """ + label = format_label(label) + sparse_onehot = F.one_hot(label, num_classes) + return sparse_onehot.sum(0) diff --git a/mmpretrain/utils/__init__.py b/mmpretrain/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..991e3217d2f1e5926028e6c9c79e450e30404a33 --- /dev/null +++ b/mmpretrain/utils/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved.
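A small round-trip sketch of the batch-label helpers above (expected outputs shown as comments):

import torch
from mmpretrain.structures import batch_label_to_onehot, cat_batch_labels

labels = [torch.tensor([0, 1]), torch.tensor([2]), torch.tensor([1, 3])]
batch, splits = cat_batch_labels(labels)
# batch -> tensor([0, 1, 2, 1, 3]), splits -> [2, 3]
onehot = batch_label_to_onehot(batch, splits, num_classes=4)
# tensor([[1, 1, 0, 0],
#         [0, 0, 1, 0],
#         [0, 1, 0, 1]])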
+from .analyze import load_json_log +from .collect_env import collect_env +from .dependency import require +from .misc import get_ori_model +from .progress import track, track_on_main_process +from .setup_env import register_all_modules + +__all__ = [ + 'collect_env', 'register_all_modules', 'track_on_main_process', + 'load_json_log', 'get_ori_model', 'track', 'require' +] diff --git a/mmpretrain/utils/__pycache__/__init__.cpython-311.pyc b/mmpretrain/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c0546bf892d773d43f976f872101b5534a043b7 Binary files /dev/null and b/mmpretrain/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmpretrain/utils/__pycache__/analyze.cpython-311.pyc b/mmpretrain/utils/__pycache__/analyze.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72a40de9202ecca594f7e4648a7c60675345866c Binary files /dev/null and b/mmpretrain/utils/__pycache__/analyze.cpython-311.pyc differ diff --git a/mmpretrain/utils/__pycache__/collect_env.cpython-311.pyc b/mmpretrain/utils/__pycache__/collect_env.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..478f1fc67fe69c7f8653196ead1989485102d99d Binary files /dev/null and b/mmpretrain/utils/__pycache__/collect_env.cpython-311.pyc differ diff --git a/mmpretrain/utils/__pycache__/dependency.cpython-311.pyc b/mmpretrain/utils/__pycache__/dependency.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90ff062f158407ab98f1dff6085f22c48217c4ed Binary files /dev/null and b/mmpretrain/utils/__pycache__/dependency.cpython-311.pyc differ diff --git a/mmpretrain/utils/__pycache__/misc.cpython-311.pyc b/mmpretrain/utils/__pycache__/misc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b15aa1d6028fd06898d9130dc72015d70bcafee1 Binary files /dev/null and b/mmpretrain/utils/__pycache__/misc.cpython-311.pyc differ diff --git a/mmpretrain/utils/__pycache__/progress.cpython-311.pyc b/mmpretrain/utils/__pycache__/progress.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50532ee1102635cad9b7e8812e8b2d3a93723374 Binary files /dev/null and b/mmpretrain/utils/__pycache__/progress.cpython-311.pyc differ diff --git a/mmpretrain/utils/__pycache__/setup_env.cpython-311.pyc b/mmpretrain/utils/__pycache__/setup_env.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c90c14987db8390c95d1ce81ee77bd45e29db07e Binary files /dev/null and b/mmpretrain/utils/__pycache__/setup_env.cpython-311.pyc differ diff --git a/mmpretrain/utils/analyze.py b/mmpretrain/utils/analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..a933591618951e1e49558f4f5cbbdf9c49a76bfe --- /dev/null +++ b/mmpretrain/utils/analyze.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json + + +def load_json_log(json_log): + """load and convert json_logs to log_dicts. + + Args: + json_log (str): The path of the json log file. + + Returns: + dict: The result dict contains two items, "train" and "val", for + the training log and validate log. + + Example: + An example output: + + .. code-block:: python + + { + 'train': [ + {"lr": 0.1, "time": 0.02, "epoch": 1, "step": 100}, + {"lr": 0.1, "time": 0.02, "epoch": 1, "step": 200}, + {"lr": 0.1, "time": 0.02, "epoch": 1, "step": 300}, + ... 
+ ], + 'val': [ + {"accuracy/top1": 32.1, "step": 1}, + {"accuracy/top1": 50.2, "step": 2}, + {"accuracy/top1": 60.3, "step": 3}, + ... + ] + } + """ + log_dict = dict(train=[], val=[]) + with open(json_log, 'r') as log_file: + for line in log_file: + log = json.loads(line.strip()) + # A hack trick to determine whether the line is a training log. + mode = 'train' if 'lr' in log else 'val' + log_dict[mode].append(log) + + return log_dict diff --git a/mmpretrain/utils/collect_env.py b/mmpretrain/utils/collect_env.py new file mode 100644 index 0000000000000000000000000000000000000000..988451ec530e8d21ec3d5a087a3bb7f7b66fd223 --- /dev/null +++ b/mmpretrain/utils/collect_env.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +from mmengine.utils import get_git_hash +from mmengine.utils.dl_utils import collect_env as collect_base_env + +import mmpretrain + + +def collect_env(with_torch_compiling_info=False): + """Collect the information of the running environments.""" + env_info = collect_base_env() + env_info['MMCV'] = mmcv.__version__ + if not with_torch_compiling_info: + env_info.pop('PyTorch compiling details') + env_info['MMPreTrain'] = mmpretrain.__version__ + '+' + get_git_hash()[:7] + return env_info diff --git a/mmpretrain/utils/dependency.py b/mmpretrain/utils/dependency.py new file mode 100644 index 0000000000000000000000000000000000000000..0e3d8ae5df7a6968f26e0563e80a7d37a2e2cd68 --- /dev/null +++ b/mmpretrain/utils/dependency.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from functools import wraps +from inspect import isfunction + +from importlib_metadata import PackageNotFoundError, distribution +from mmengine.utils import digit_version + + +def satisfy_requirement(dep): + pat = '(' + '|'.join(['>=', '==', '>']) + ')' + parts = re.split(pat, dep, maxsplit=1) + parts = [p.strip() for p in parts] + package = parts[0] + if len(parts) > 1: + op, version = parts[1:] + op = { + '>=': '__ge__', + '==': '__eq__', + '>': '__gt__', + '<': '__lt__', + '<=': '__le__' + }[op] + else: + op, version = None, None + + try: + dist = distribution(package) + if op is None or getattr(digit_version(dist.version), op)( + digit_version(version)): + return True + except PackageNotFoundError: + pass + + return False + + +def require(dep, install=None): + """A function wrapper for extra package requirements. + + Args: + dep (str): The dependency package name, like ``transformers`` + or ``transformers>=4.28.0``. + install (str, optional): The installation command hint. Defaults + to None, which means to use "pip install dep".
+ """ + + def wrapper(fn): + assert isfunction(fn) + + @wraps(fn) + def ask_install(*args, **kwargs): + name = fn.__qualname__.replace('.__init__', '') + ins = install or f'pip install "{dep}"' + raise ImportError( + f'{name} requires {dep}, please install it by `{ins}`.') + + if satisfy_requirement(dep): + fn._verify_require = getattr(fn, '_verify_require', lambda: None) + return fn + + ask_install._verify_require = ask_install + return ask_install + + return wrapper + + +WITH_MULTIMODAL = all( + satisfy_requirement(item) + for item in ['pycocotools', 'transformers>=4.28.0']) + + +def register_multimodal_placeholder(names, registry): + for name in names: + + def ask_install(*args, **kwargs): + raise ImportError( + f'{name} requires extra multi-modal dependencies, please ' + 'install it by `pip install "mmpretrain[multimodal]"` ' + 'or `pip install -e ".[multimodal]"`.') + + registry.register_module(name=name, module=ask_install) diff --git a/mmpretrain/utils/misc.py b/mmpretrain/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..cc532679943689233be76e9a8f74da8ed822443e --- /dev/null +++ b/mmpretrain/utils/misc.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmengine.model import is_model_wrapper + + +def get_ori_model(model: nn.Module) -> nn.Module: + """Get original model if the input model is a model wrapper. + + Args: + model (nn.Module): A model may be a model wrapper. + + Returns: + nn.Module: The model without model wrapper. + """ + if is_model_wrapper(model): + return model.module + else: + return model diff --git a/mmpretrain/utils/progress.py b/mmpretrain/utils/progress.py new file mode 100644 index 0000000000000000000000000000000000000000..b23f976a42fc3a2f6e38f025f01041deb5608405 --- /dev/null +++ b/mmpretrain/utils/progress.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import mmengine.dist as dist +import rich.progress as progress +from rich.live import Live + +disable_progress_bar = False +global_progress = progress.Progress( + '{task.description}', + progress.BarColumn(), + progress.TaskProgressColumn(show_speed=True), + progress.TimeRemainingColumn(), +) +global_live = Live(global_progress, refresh_per_second=10) + + +def track(sequence, description: str = '', total: Optional[float] = None): + if disable_progress_bar: + yield from sequence + else: + global_live.start() + task_id = global_progress.add_task(description, total=total) + task = global_progress._tasks[task_id] + try: + yield from global_progress.track(sequence, task_id=task_id) + finally: + if task.total is None: + global_progress.update(task_id, total=task.completed) + if all(task.finished for task in global_progress.tasks): + global_live.stop() + for task_id in global_progress.task_ids: + global_progress.remove_task(task_id) + + +def track_on_main_process(sequence, description='', total=None): + if not dist.is_main_process() or disable_progress_bar: + yield from sequence + else: + yield from track(sequence, total=total, description=description) diff --git a/mmpretrain/utils/setup_env.py b/mmpretrain/utils/setup_env.py new file mode 100644 index 0000000000000000000000000000000000000000..1b57b848c98a75c7a1b5854c800ecc2dd5da6df8 --- /dev/null +++ b/mmpretrain/utils/setup_env.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import datetime +import warnings + +from mmengine import DefaultScope + + +def register_all_modules(init_default_scope: bool = True) -> None: + """Register all modules in mmpretrain into the registries. + + Args: + init_default_scope (bool): Whether to initialize the mmpretrain default + scope. If True, the global default scope will be set to + `mmpretrain`, and all registries will build modules from + mmpretrain's registry node. To understand more about the registry, + please refer to + https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md + Defaults to True. + """ # noqa: E501 + import mmpretrain.datasets # noqa: F401,F403 + import mmpretrain.engine # noqa: F401,F403 + import mmpretrain.evaluation # noqa: F401,F403 + import mmpretrain.models # noqa: F401,F403 + import mmpretrain.structures # noqa: F401,F403 + import mmpretrain.visualization # noqa: F401,F403 + + if not init_default_scope: + return + + current_scope = DefaultScope.get_current_instance() + if current_scope is None: + DefaultScope.get_instance('mmpretrain', scope_name='mmpretrain') + elif current_scope.scope_name != 'mmpretrain': + warnings.warn( + f'The current default scope "{current_scope.scope_name}" ' + 'is not "mmpretrain", `register_all_modules` will force ' + 'the current default scope to be "mmpretrain". If this is ' + 'not expected, please set `init_default_scope=False`.') + # avoid name conflict + new_instance_name = f'mmpretrain-{datetime.datetime.now()}' + DefaultScope.get_instance(new_instance_name, scope_name='mmpretrain') diff --git a/mmpretrain/version.py b/mmpretrain/version.py new file mode 100644 index 0000000000000000000000000000000000000000..8f8c8b7f0a9b327459cc5a62b1e8cc9a44c48894 --- /dev/null +++ b/mmpretrain/version.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved + +__version__ = '1.1.1' + + +def parse_version_info(version_str): + """Parse a version string into a tuple. + + Args: + version_str (str): The version string. + Returns: + tuple[int | str]: The version info, e.g., "1.3.0" is parsed into + (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1'). + """ + version_info = [] + for x in version_str.split('.'): + if x.isdigit(): + version_info.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + version_info.append(int(patch_version[0])) + version_info.append(f'rc{patch_version[1]}') + return tuple(version_info) + + +version_info = parse_version_info(__version__) + +__all__ = ['__version__', 'version_info', 'parse_version_info'] diff --git a/mmpretrain/visualization/__init__.py b/mmpretrain/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0dbeecfb070193f479b248dca3e98311577410a1 --- /dev/null +++ b/mmpretrain/visualization/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .utils import create_figure, get_adaptive_scale +from .visualizer import UniversalVisualizer + +__all__ = ['UniversalVisualizer', 'get_adaptive_scale', 'create_figure'] diff --git a/mmpretrain/visualization/utils.py b/mmpretrain/visualization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..91a1d81f1449dfbfb7ff5198eb6dc25a6386ed48 --- /dev/null +++ b/mmpretrain/visualization/utils.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved.
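A quick sanity check of the adaptive-scale rule implemented in the file below (expected values shown as comments, assuming the defaults min_scale=0.3 and max_scale=3.0):

from mmpretrain.visualization import get_adaptive_scale

get_adaptive_scale((224, 224))   # 1.0: a short edge of 224 maps to scale 1
get_adaptive_scale((448, 896))   # 2.0: scales linearly with the short edge
get_adaptive_scale((32, 32))     # 0.3: clamped to min_scale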
+from typing import TYPE_CHECKING, Tuple + +if TYPE_CHECKING: + from matplotlib.figure import Figure + + +def get_adaptive_scale(img_shape: Tuple[int, int], + min_scale: float = 0.3, + max_scale: float = 3.0) -> float: + """Get adaptive scale according to image shape. + + The target scale depends on the short edge length of the image. If the + short edge length equals 224, the output is 1.0, and the output scales + linearly according to the short edge length. + + You can also specify the minimum scale and the maximum scale to limit the + linear scale. + + Args: + img_shape (Tuple[int, int]): The shape of the canvas image. + min_scale (float): The minimum scale. Defaults to 0.3. + max_scale (float): The maximum scale. Defaults to 3.0. + + Returns: + float: The adaptive scale. + """ + short_edge_length = min(img_shape) + scale = short_edge_length / 224. + return min(max(scale, min_scale), max_scale) + + +def create_figure(*args, margin=False, **kwargs) -> 'Figure': + """Create an independent figure. + + Different from the :func:`plt.figure`, the figure from this function won't + be managed by matplotlib. It has a + :obj:`matplotlib.backends.backend_agg.FigureCanvasAgg` canvas, and + therefore you can use the ``canvas`` attribute to access the drawn image. + + Args: + *args: All positional arguments of :class:`matplotlib.figure.Figure`. + margin: Whether to reserve the white edges of the figure. + Defaults to False. + **kwargs: All keyword arguments of :class:`matplotlib.figure.Figure`. + + Returns: + matplotlib.figure.Figure: The created figure. + """ + from matplotlib.backends.backend_agg import FigureCanvasAgg + from matplotlib.figure import Figure + + figure = Figure(*args, **kwargs) + FigureCanvasAgg(figure) + + if not margin: + # remove white edges by setting the subplot margins + figure.subplots_adjust(left=0, right=1, bottom=0, top=1) + + return figure diff --git a/mmpretrain/visualization/visualizer.py b/mmpretrain/visualization/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5d18ca87f6bc246b4defe17281ae87c4464e1b89 --- /dev/null +++ b/mmpretrain/visualization/visualizer.py @@ -0,0 +1,777 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence, Tuple, Union + +import mmcv +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.dataset import BaseDataset +from mmengine.dist import master_only +from mmengine.visualization import Visualizer +from mmengine.visualization.utils import img_from_canvas + +from mmpretrain.registry import VISUALIZERS +from mmpretrain.structures import DataSample +from .utils import create_figure, get_adaptive_scale + + +@VISUALIZERS.register_module() +class UniversalVisualizer(Visualizer): + """Universal Visualizer for multiple tasks. + + Args: + name (str): Name of the instance. Defaults to 'visualizer'. + image (np.ndarray, optional): The original image to draw. The format + should be RGB. Defaults to None. + vis_backends (list, optional): Visual backend config list. + Defaults to None. + save_dir (str, optional): Save file dir for all storage backends. + If it is None, the backend storage will not save any data. + fig_save_cfg (dict): Keyword parameters of figure for saving. + Defaults to empty dict. + fig_show_cfg (dict): Keyword parameters of figure for showing. + Defaults to empty dict.
+ """ + DEFAULT_TEXT_CFG = { + 'family': 'monospace', + 'color': 'white', + 'bbox': dict(facecolor='black', alpha=0.5, boxstyle='Round'), + 'verticalalignment': 'top', + 'horizontalalignment': 'left', + } + + @master_only + def visualize_cls(self, + image: np.ndarray, + data_sample: DataSample, + classes: Optional[Sequence[str]] = None, + draw_gt: bool = True, + draw_pred: bool = True, + draw_score: bool = True, + resize: Optional[int] = None, + rescale_factor: Optional[float] = None, + text_cfg: dict = dict(), + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: str = '', + step: int = 0) -> None: + """Visualize image classification result. + + This method will draw an text box on the input image to visualize the + information about image classification, like the ground-truth label and + prediction label. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + classes (Sequence[str], optional): The categories names. + Defaults to None. + draw_gt (bool): Whether to draw ground-truth labels. + Defaults to True. + draw_pred (bool): Whether to draw prediction labels. + Defaults to True. + draw_score (bool): Whether to draw the prediction scores + of prediction categories. Defaults to True. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + rescale_factor (float, optional): Rescale the image by the rescale + factor before visualization. Defaults to None. + text_cfg (dict): Extra text setting, which accepts + arguments of :meth:`mmengine.Visualizer.draw_texts`. + Defaults to an empty dict. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. 
+ """ + if self.dataset_meta is not None: + classes = classes or self.dataset_meta.get('classes', None) + + if resize is not None: + h, w = image.shape[:2] + if w < h: + image = mmcv.imresize(image, (resize, resize * h // w)) + else: + image = mmcv.imresize(image, (resize * w // h, resize)) + elif rescale_factor is not None: + image = mmcv.imrescale(image, rescale_factor) + + texts = [] + self.set_image(image) + + if draw_gt and 'gt_label' in data_sample: + idx = data_sample.gt_label.tolist() + class_labels = [''] * len(idx) + if classes is not None: + class_labels = [f' ({classes[i]})' for i in idx] + labels = [str(idx[i]) + class_labels[i] for i in range(len(idx))] + prefix = 'Ground truth: ' + texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels)) + + if draw_pred and 'pred_label' in data_sample: + idx = data_sample.pred_label.tolist() + score_labels = [''] * len(idx) + class_labels = [''] * len(idx) + if draw_score and 'pred_score' in data_sample: + score_labels = [ + f', {data_sample.pred_score[i].item():.2f}' for i in idx + ] + + if classes is not None: + class_labels = [f' ({classes[i]})' for i in idx] + + labels = [ + str(idx[i]) + score_labels[i] + class_labels[i] + for i in range(len(idx)) + ] + prefix = 'Prediction: ' + texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels)) + + img_scale = get_adaptive_scale(image.shape[:2]) + text_cfg = { + 'size': int(img_scale * 7), + **self.DEFAULT_TEXT_CFG, + **text_cfg, + } + self.ax_save.text( + img_scale * 5, + img_scale * 5, + '\n'.join(texts), + **text_cfg, + ) + drawn_img = self.get_image() + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img + + @master_only + def visualize_image_retrieval(self, + image: np.ndarray, + data_sample: DataSample, + prototype_dataset: BaseDataset, + topk: int = 1, + draw_score: bool = True, + resize: Optional[int] = None, + text_cfg: dict = dict(), + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: Optional[str] = '', + step: int = 0) -> None: + """Visualize image retrieval result. + + This method will draw the input image and the images retrieved from the + prototype dataset. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + prototype_dataset (:obj:`BaseDataset`): The prototype dataset. + It should have `get_data_info` method and return a dict + includes `img_path`. + draw_score (bool): Whether to draw the match scores of the + retrieved images. Defaults to True. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + text_cfg (dict): Extra text setting, which accepts arguments of + :func:`plt.text`. Defaults to an empty dict. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. 
It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. + """ + text_cfg = {**self.DEFAULT_TEXT_CFG, **text_cfg} + if resize is not None: + image = mmcv.imrescale(image, (resize, resize)) + + match_scores, indices = torch.topk(data_sample.pred_score, k=topk) + + figure = create_figure(margin=True) + gs = figure.add_gridspec(2, topk) + query_plot = figure.add_subplot(gs[0, :]) + query_plot.axis(False) + query_plot.imshow(image) + + for k, (score, sample_idx) in enumerate(zip(match_scores, indices)): + sample = prototype_dataset.get_data_info(sample_idx.item()) + value_image = mmcv.imread(sample['img_path'])[..., ::-1] + value_plot = figure.add_subplot(gs[1, k]) + value_plot.axis(False) + value_plot.imshow(value_image) + if draw_score: + value_plot.text( + 5, + 5, + f'{score:.2f}', + **text_cfg, + ) + drawn_img = img_from_canvas(figure.canvas) + self.set_image(drawn_img) + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img + + def add_mask_to_image( + self, + image: np.ndarray, + data_sample: DataSample, + resize: Union[int, Tuple[int]] = 224, + color: Union[str, Tuple[int]] = 'black', + alpha: Union[int, float] = 0.8, + ) -> np.ndarray: + if isinstance(resize, int): + resize = (resize, resize) + + image = mmcv.imresize(image, resize) + self.set_image(image) + + if isinstance(data_sample.mask, np.ndarray): + data_sample.mask = torch.tensor(data_sample.mask) + mask = data_sample.mask.float()[None, None, ...] + mask_ = F.interpolate(mask, image.shape[:2], mode='nearest')[0, 0] + + self.draw_binary_masks(mask_.bool(), colors=color, alphas=alpha) + + drawn_img = self.get_image() + return drawn_img + + @master_only + def visualize_masked_image(self, + image: np.ndarray, + data_sample: DataSample, + resize: Union[int, Tuple[int]] = 224, + color: Union[str, Tuple[int]] = 'black', + alpha: Union[int, float] = 0.8, + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: str = '', + step: int = 0) -> None: + """Visualize masked image. + + This method will draw an image with a binary mask. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + resize (int | Tuple[int]): Resize the input image to the specified + shape. Defaults to 224. + color (str | Tuple[int]): The color of the binary mask. + Defaults to "black". + alpha (int | float): The transparency of the mask. Defaults to 0.8. + show (bool): Whether to display the drawn image in a window, please + confirm you are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string.
+    @master_only
+    def visualize_image_caption(self,
+                                image: np.ndarray,
+                                data_sample: DataSample,
+                                resize: Optional[int] = None,
+                                text_cfg: dict = dict(),
+                                show: bool = False,
+                                wait_time: float = 0,
+                                out_file: Optional[str] = None,
+                                name: Optional[str] = '',
+                                step: int = 0) -> None:
+        """Visualize image caption result.
+
+        This method will draw the input image and its predicted caption.
+
+        Args:
+            image (np.ndarray): The image to draw. The format should be RGB.
+            data_sample (:obj:`DataSample`): The annotation of the image.
+            resize (int, optional): Resize the long edge of the image to the
+                specified length before visualization. Defaults to None.
+            text_cfg (dict): Extra text setting, which accepts arguments of
+                :func:`plt.text`. Defaults to an empty dict.
+            show (bool): Whether to display the drawn image in a window,
+                please confirm you are able to access the graphical
+                interface. Defaults to False.
+            wait_time (float): The display time (s). Defaults to 0, which
+                means "forever".
+            out_file (str, optional): Extra path to save the visualization
+                result. If specified, the visualizer will only save the
+                result image to the out_file and ignore its storage
+                backends. Defaults to None.
+            name (str): The image identifier. It's useful when using the
+                storage backends of the visualizer to save or display the
+                image. Defaults to an empty string.
+            step (int): The global step value. It's useful to record a
+                series of visualization results for the same image with the
+                storage backends. Defaults to 0.
+
+        Returns:
+            np.ndarray: The visualization image.
+        """
+        text_cfg = {**self.DEFAULT_TEXT_CFG, **text_cfg}
+
+        if resize is not None:
+            h, w = image.shape[:2]
+            if w < h:
+                image = mmcv.imresize(image, (resize, resize * h // w))
+            else:
+                image = mmcv.imresize(image, (resize * w // h, resize))
+
+        self.set_image(image)
+
+        img_scale = get_adaptive_scale(image.shape[:2])
+        text_cfg = {
+            'size': int(img_scale * 7),
+            **self.DEFAULT_TEXT_CFG,
+            **text_cfg,
+        }
+        self.ax_save.text(
+            img_scale * 5,
+            img_scale * 5,
+            data_sample.get('pred_caption'),
+            wrap=True,
+            **text_cfg,
+        )
+        drawn_img = self.get_image()
+
+        if show:
+            self.show(drawn_img, win_name=name, wait_time=wait_time)
+
+        if out_file is not None:
+            # save the image to the target file instead of vis_backends
+            mmcv.imwrite(drawn_img[..., ::-1], out_file)
+        else:
+            self.add_image(name, drawn_img, step=step)
+
+        return drawn_img
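# NOTE (illustrative sketch, not part of this commit): the caption/VQA
# methods above size their overlay text via `get_adaptive_scale(shape)` and
# `size=int(scale * 7)`. A sketch of the idea, assuming the usual
# "short edge relative to 224 px, clamped" heuristic; the exact constants in
# the upstream helper may differ.
def adaptive_scale(img_shape, min_scale=0.3, max_scale=3.0):
    # Larger images get proportionally larger text, bounded on both ends so
    # the overlay stays readable across resolutions.
    scale = min(img_shape) / 224.
    return min(max(scale, min_scale), max_scale)


assert adaptive_scale((224, 448)) == 1.0   # short edge 224 -> unit scale
assert adaptive_scale((64, 64)) == 0.3     # tiny images are clamped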
+    @master_only
+    def visualize_vqa(self,
+                      image: np.ndarray,
+                      data_sample: DataSample,
+                      resize: Optional[int] = None,
+                      text_cfg: dict = dict(),
+                      show: bool = False,
+                      wait_time: float = 0,
+                      out_file: Optional[str] = None,
+                      name: Optional[str] = '',
+                      step: int = 0) -> None:
+        """Visualize visual question answering result.
+
+        This method will draw the input image, the question and the answer.
+
+        Args:
+            image (np.ndarray): The image to draw. The format should be RGB.
+            data_sample (:obj:`DataSample`): The annotation of the image.
+            resize (int, optional): Resize the long edge of the image to the
+                specified length before visualization. Defaults to None.
+            text_cfg (dict): Extra text setting, which accepts arguments of
+                :func:`plt.text`. Defaults to an empty dict.
+            show (bool): Whether to display the drawn image in a window,
+                please confirm you are able to access the graphical
+                interface. Defaults to False.
+            wait_time (float): The display time (s). Defaults to 0, which
+                means "forever".
+            out_file (str, optional): Extra path to save the visualization
+                result. If specified, the visualizer will only save the
+                result image to the out_file and ignore its storage
+                backends. Defaults to None.
+            name (str): The image identifier. It's useful when using the
+                storage backends of the visualizer to save or display the
+                image. Defaults to an empty string.
+            step (int): The global step value. It's useful to record a
+                series of visualization results for the same image with the
+                storage backends. Defaults to 0.
+
+        Returns:
+            np.ndarray: The visualization image.
+        """
+        text_cfg = {**self.DEFAULT_TEXT_CFG, **text_cfg}
+
+        if resize is not None:
+            h, w = image.shape[:2]
+            if w < h:
+                image = mmcv.imresize(image, (resize, resize * h // w))
+            else:
+                image = mmcv.imresize(image, (resize * w // h, resize))
+
+        self.set_image(image)
+
+        img_scale = get_adaptive_scale(image.shape[:2])
+        text_cfg = {
+            'size': int(img_scale * 7),
+            **self.DEFAULT_TEXT_CFG,
+            **text_cfg,
+        }
+        text = (f'Q: {data_sample.get("question")}\n'
+                f'A: {data_sample.get("pred_answer")}')
+        self.ax_save.text(
+            img_scale * 5,
+            img_scale * 5,
+            text,
+            wrap=True,
+            **text_cfg,
+        )
+        drawn_img = self.get_image()
+
+        if show:
+            self.show(drawn_img, win_name=name, wait_time=wait_time)
+
+        if out_file is not None:
+            # save the image to the target file instead of vis_backends
+            mmcv.imwrite(drawn_img[..., ::-1], out_file)
+        else:
+            self.add_image(name, drawn_img, step=step)
+
+        return drawn_img
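# NOTE (illustrative sketch, not part of this commit): a minimal call of
# `visualize_vqa` above. The question/answer fields are set with mmengine's
# generic `BaseDataElement.set_field`; the field names mirror the
# `data_sample.get(...)` calls in the method. The visualizer class name and
# all contents are assumptions for illustration.
import numpy as np
from mmpretrain.structures import DataSample
from mmpretrain.visualization import UniversalVisualizer

visualizer = UniversalVisualizer()
image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
sample = DataSample()
sample.set_field('How many cats are there?', 'question')
sample.set_field('two', 'pred_answer')
# Renders "Q: ..." / "A: ..." in the top-left corner of the image.
visualizer.visualize_vqa(image, sample, resize=480, show=False, name='vqa')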
+    @master_only
+    def visualize_visual_grounding(self,
+                                   image: np.ndarray,
+                                   data_sample: DataSample,
+                                   resize: Optional[int] = None,
+                                   text_cfg: dict = dict(),
+                                   show: bool = False,
+                                   wait_time: float = 0,
+                                   out_file: Optional[str] = None,
+                                   name: Optional[str] = '',
+                                   line_width: Union[int, float] = 3,
+                                   bbox_color: Union[str, tuple] = 'green',
+                                   step: int = 0) -> None:
+        """Visualize visual grounding result.
+
+        This method will draw the input image, the grounding text and the
+        predicted (and, if present, ground-truth) bounding boxes.
+
+        Args:
+            image (np.ndarray): The image to draw. The format should be RGB.
+            data_sample (:obj:`DataSample`): The annotation of the image.
+            resize (int, optional): Resize the long edge of the image to the
+                specified length before visualization. Defaults to None.
+            text_cfg (dict): Extra text setting, which accepts arguments of
+                :func:`plt.text`. Defaults to an empty dict.
+            show (bool): Whether to display the drawn image in a window,
+                please confirm you are able to access the graphical
+                interface. Defaults to False.
+            wait_time (float): The display time (s). Defaults to 0, which
+                means "forever".
+            out_file (str, optional): Extra path to save the visualization
+                result. If specified, the visualizer will only save the
+                result image to the out_file and ignore its storage
+                backends. Defaults to None.
+            name (str): The image identifier. It's useful when using the
+                storage backends of the visualizer to save or display the
+                image. Defaults to an empty string.
+            line_width (int | float): The line width of the drawn bounding
+                boxes. Defaults to 3.
+            bbox_color (str | tuple): The edge color of the predicted boxes.
+                Defaults to 'green'.
+            step (int): The global step value. It's useful to record a
+                series of visualization results for the same image with the
+                storage backends. Defaults to 0.
+
+        Returns:
+            np.ndarray: The visualization image.
+        """
+        text_cfg = {**self.DEFAULT_TEXT_CFG, **text_cfg}
+
+        gt_bboxes = data_sample.get('gt_bboxes')
+        pred_bboxes = data_sample.get('pred_bboxes')
+        if resize is not None:
+            h, w = image.shape[:2]
+            if w < h:
+                image, w_scale, h_scale = mmcv.imresize(
+                    image, (resize, resize * h // w), return_scale=True)
+            else:
+                image, w_scale, h_scale = mmcv.imresize(
+                    image, (resize * w // h, resize), return_scale=True)
+            pred_bboxes[:, ::2] *= w_scale
+            pred_bboxes[:, 1::2] *= h_scale
+            if gt_bboxes is not None:
+                gt_bboxes[:, ::2] *= w_scale
+                gt_bboxes[:, 1::2] *= h_scale
+
+        self.set_image(image)
+        # Avoid the line-width limit in the base classes.
+        self._default_font_size = 1e3
+        self.draw_bboxes(
+            pred_bboxes, line_widths=line_width, edge_colors=bbox_color)
+        if gt_bboxes is not None:
+            self.draw_bboxes(
+                gt_bboxes, line_widths=line_width, edge_colors='blue')
+
+        img_scale = get_adaptive_scale(image.shape[:2])
+        text_cfg = {
+            'size': int(img_scale * 7),
+            **self.DEFAULT_TEXT_CFG,
+            **text_cfg,
+        }
+
+        text_positions = pred_bboxes[:, :2] + line_width
+        for i in range(pred_bboxes.size(0)):
+            self.ax_save.text(
+                text_positions[i, 0] + line_width,
+                text_positions[i, 1] + line_width,
+                data_sample.get('text'),
+                **text_cfg,
+            )
+        drawn_img = self.get_image()
+
+        if show:
+            self.show(drawn_img, win_name=name, wait_time=wait_time)
+
+        if out_file is not None:
+            # save the image to the target file instead of vis_backends
+            mmcv.imwrite(drawn_img[..., ::-1], out_file)
+        else:
+            self.add_image(name, drawn_img, step=step)
+
+        return drawn_img
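# NOTE (self-contained sketch, not part of this commit): when
# `visualize_visual_grounding` resizes the image it rescales the box
# coordinates by the same factors. In an (N, 4) xyxy tensor the x-coordinates
# sit at even column indices and the y-coordinates at odd ones; the numbers
# below are illustrative.
import torch

boxes = torch.tensor([[10., 20., 110., 220.]])   # xyxy in a 400x200 image
w_scale, h_scale = 512 / 400, 256 / 200          # resize 400x200 -> 512x256
boxes[:, ::2] *= w_scale                         # x1, x2
boxes[:, 1::2] *= h_scale                        # y1, y2
assert torch.allclose(boxes, torch.tensor([[12.8, 25.6, 140.8, 281.6]]))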
+    @master_only
+    def visualize_t2i_retrieval(self,
+                                text: str,
+                                data_sample: DataSample,
+                                prototype_dataset: BaseDataset,
+                                topk: int = 1,
+                                draw_score: bool = True,
+                                text_cfg: dict = dict(),
+                                fig_cfg: dict = dict(),
+                                show: bool = False,
+                                wait_time: float = 0,
+                                out_file: Optional[str] = None,
+                                name: Optional[str] = '',
+                                step: int = 0) -> None:
+        """Visualize Text-To-Image retrieval result.
+
+        This method will draw the input text and the images retrieved from
+        the prototype dataset.
+
+        Args:
+            text (str): The query text used to retrieve images.
+            data_sample (:obj:`DataSample`): The annotation of the query.
+            prototype_dataset (:obj:`BaseDataset`): The prototype dataset.
+                It should have a `get_data_info` method that returns a dict
+                including `img_path`.
+            topk (int): To visualize the topk matching items. Defaults to 1.
+            draw_score (bool): Whether to draw the match scores of the
+                retrieved images. Defaults to True.
+            text_cfg (dict): Extra text setting, which accepts arguments of
+                :func:`plt.text`. Defaults to an empty dict.
+            fig_cfg (dict): Extra figure setting, which accepts arguments of
+                :func:`plt.Figure`. Defaults to an empty dict.
+            show (bool): Whether to display the drawn image in a window,
+                please confirm you are able to access the graphical
+                interface. Defaults to False.
+            wait_time (float): The display time (s). Defaults to 0, which
+                means "forever".
+            out_file (str, optional): Extra path to save the visualization
+                result. If specified, the visualizer will only save the
+                result image to the out_file and ignore its storage
+                backends. Defaults to None.
+            name (str): The image identifier. It's useful when using the
+                storage backends of the visualizer to save or display the
+                image. Defaults to an empty string.
+            step (int): The global step value. It's useful to record a
+                series of visualization results for the same image with the
+                storage backends. Defaults to 0.
+
+        Returns:
+            np.ndarray: The visualization image.
+        """
+        text_cfg = {**self.DEFAULT_TEXT_CFG, **text_cfg}
+
+        match_scores, indices = torch.topk(data_sample.pred_score, k=topk)
+
+        figure = create_figure(margin=True, **fig_cfg)
+        figure.suptitle(text)
+        gs = figure.add_gridspec(1, topk)
+
+        for k, (score, sample_idx) in enumerate(zip(match_scores, indices)):
+            sample = prototype_dataset.get_data_info(sample_idx.item())
+            value_image = mmcv.imread(sample['img_path'])[..., ::-1]
+            value_plot = figure.add_subplot(gs[0, k])
+            value_plot.axis(False)
+            value_plot.imshow(value_image)
+            if draw_score:
+                value_plot.text(
+                    5,
+                    5,
+                    f'{score:.2f}',
+                    **text_cfg,
+                )
+        drawn_img = img_from_canvas(figure.canvas)
+        self.set_image(drawn_img)
+
+        if show:
+            self.show(drawn_img, win_name=name, wait_time=wait_time)
+
+        if out_file is not None:
+            # save the image to the target file instead of vis_backends
+            mmcv.imwrite(drawn_img[..., ::-1], out_file)
+        else:
+            self.add_image(name, drawn_img, step=step)
+
+        return drawn_img
+    @master_only
+    def visualize_i2t_retrieval(self,
+                                image: np.ndarray,
+                                data_sample: DataSample,
+                                prototype_dataset: Sequence[str],
+                                topk: int = 1,
+                                draw_score: bool = True,
+                                resize: Optional[int] = None,
+                                text_cfg: dict = dict(),
+                                show: bool = False,
+                                wait_time: float = 0,
+                                out_file: Optional[str] = None,
+                                name: str = '',
+                                step: int = 0) -> None:
+        """Visualize Image-To-Text retrieval result.
+
+        This method will draw the input image and the texts retrieved from
+        the prototype dataset.
+
+        Args:
+            image (np.ndarray): The image to draw. The format should be RGB.
+            data_sample (:obj:`DataSample`): The annotation of the image.
+            prototype_dataset (Sequence[str]): The prototype dataset.
+                It should be a list of texts.
+            topk (int): To visualize the topk matching items. Defaults to 1.
+            draw_score (bool): Whether to draw the matching scores of the
+                retrieved texts. Defaults to True.
+            resize (int, optional): Resize the short edge of the image to
+                the specified length before visualization. Defaults to None.
+            text_cfg (dict): Extra text setting, which accepts
+                arguments of :meth:`mmengine.Visualizer.draw_texts`.
+                Defaults to an empty dict.
+            show (bool): Whether to display the drawn image in a window,
+                please confirm you are able to access the graphical
+                interface. Defaults to False.
+            wait_time (float): The display time (s). Defaults to 0, which
+                means "forever".
+            out_file (str, optional): Extra path to save the visualization
+                result. If specified, the visualizer will only save the
+                result image to the out_file and ignore its storage
+                backends. Defaults to None.
+            name (str): The image identifier. It's useful when using the
+                storage backends of the visualizer to save or display the
+                image. Defaults to an empty string.
+            step (int): The global step value. It's useful to record a
+                series of visualization results for the same image with the
+                storage backends. Defaults to 0.
+
+        Returns:
+            np.ndarray: The visualization image.
+        """
+        if resize is not None:
+            h, w = image.shape[:2]
+            if w < h:
+                image = mmcv.imresize(image, (resize, resize * h // w))
+            else:
+                image = mmcv.imresize(image, (resize * w // h, resize))
+
+        self.set_image(image)
+
+        match_scores, indices = torch.topk(data_sample.pred_score, k=topk)
+        texts = []
+        for score, sample_idx in zip(match_scores, indices):
+            text = prototype_dataset[sample_idx.item()]
+            if draw_score:
+                text = f'{score:.2f} ' + text
+            texts.append(text)
+
+        img_scale = get_adaptive_scale(image.shape[:2])
+        text_cfg = {
+            'size': int(img_scale * 7),
+            **self.DEFAULT_TEXT_CFG,
+            **text_cfg,
+        }
+        self.ax_save.text(
+            img_scale * 5,
+            img_scale * 5,
+            '\n'.join(texts),
+            **text_cfg,
+        )
+        drawn_img = self.get_image()
+
+        if show:
+            self.show(drawn_img, win_name=name, wait_time=wait_time)
+
+        if out_file is not None:
+            # save the image to the target file instead of vis_backends
+            mmcv.imwrite(drawn_img[..., ::-1], out_file)
+        else:
+            self.add_image(name, drawn_img, step=step)
+
+        return drawn_img
diff --git a/mmseg/.DS_Store b/mmseg/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..9bcbea1ccb91c857e6bc9898adb429b998431647
Binary files /dev/null and b/mmseg/.DS_Store differ
diff --git a/mmseg/__init__.py b/mmseg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fcb84e8c4f986121ba9d782b384477129f75ff6
--- /dev/null
+++ b/mmseg/__init__.py
@@ -0,0 +1,74 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import mmcv
+import mmengine
+from packaging.version import parse
+
+from .version import __version__, version_info
+
+MMCV_MIN = '2.0.0rc4'
+MMCV_MAX = '2.2.0'
+MMENGINE_MIN = '0.5.0'
+MMENGINE_MAX = '1.0.0'
+
+
+def digit_version(version_str: str, length: int = 4):
+    """Convert a version string into a tuple of integers.
+
+    This method is usually used for comparing two versions. For pre-release
+    versions: alpha < beta < rc.
+
+    Args:
+        version_str (str): The version string.
+        length (int): The maximum number of version levels. Default: 4.
+
+    Returns:
+        tuple[int]: The version info in digits (integers).
+    """
+    version = parse(version_str)
+    assert version.release, f'failed to parse version {version_str}'
+    release = list(version.release)
+    release = release[:length]
+    if len(release) < length:
+        release = release + [0] * (length - len(release))
+    if version.is_prerelease:
+        mapping = {'a': -3, 'b': -2, 'rc': -1}
+        val = -4
+        # version.pre can be None
+        if version.pre:
+            if version.pre[0] not in mapping:
+                warnings.warn(f'unknown prerelease version {version.pre[0]}, '
+                              'version checking may go wrong')
+            else:
+                val = mapping[version.pre[0]]
+            release.extend([val, version.pre[-1]])
+        else:
+            release.extend([val, 0])
+
+    elif version.is_postrelease:
+        release.extend([1, version.post])
+    else:
+        release.extend([0, 0])
+    return tuple(release)
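# NOTE (worked examples, not part of this commit): the behaviour of
# `digit_version` above, derived by tracing the function. Pre-release tags
# map to negative sentinel values so that a < b < rc < final compares
# correctly as plain tuples.
assert digit_version('0.5.0') == (0, 5, 0, 0, 0, 0)
assert digit_version('2.0.0rc4') == (2, 0, 0, 0, -1, 4)     # 'rc' -> -1
assert digit_version('2.0.0rc4') < digit_version('2.0.0')   # rc < final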
+
+
+mmcv_min_version = digit_version(MMCV_MIN)
+mmcv_max_version = digit_version(MMCV_MAX)
+mmcv_version = digit_version(mmcv.__version__)
+
+
+assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \
+    f'MMCV=={mmcv.__version__} is used but incompatible. ' \
+    f'Please install mmcv>={MMCV_MIN}, <{MMCV_MAX}.'
+
+mmengine_min_version = digit_version(MMENGINE_MIN)
+mmengine_max_version = digit_version(MMENGINE_MAX)
+mmengine_version = digit_version(mmengine.__version__)
+
+assert (mmengine_min_version <= mmengine_version < mmengine_max_version), \
+    f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
+    f'Please install mmengine>={MMENGINE_MIN}, <{MMENGINE_MAX}.'
+
+__all__ = ['__version__', 'version_info', 'digit_version']
diff --git a/mmseg/__pycache__/__init__.cpython-311.pyc b/mmseg/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2192bfb46ae096facb650211aa970e90fbb0cba
Binary files /dev/null and b/mmseg/__pycache__/__init__.cpython-311.pyc differ
diff --git a/mmseg/__pycache__/version.cpython-311.pyc b/mmseg/__pycache__/version.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5554bf13d2cb6afb907050838bd7f53438c78ba3
Binary files /dev/null and b/mmseg/__pycache__/version.cpython-311.pyc differ
diff --git a/mmseg/apis/__init__.py b/mmseg/apis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b50a266319c9cf74cb8b13afcff564248c058732
--- /dev/null
+++ b/mmseg/apis/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .inference import inference_model, init_model, show_result_pyplot
+from .mmseg_inferencer import MMSegInferencer
+from .remote_sense_inferencer import RSImage, RSInferencer
+
+__all__ = [
+    'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer',
+    'RSInferencer', 'RSImage'
+]
diff --git a/mmseg/apis/__pycache__/__init__.cpython-311.pyc b/mmseg/apis/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b8375bd4c98cf74787e6f56acce4b7bc2a32b8c
Binary files /dev/null and b/mmseg/apis/__pycache__/__init__.cpython-311.pyc differ
diff --git a/mmseg/apis/__pycache__/inference.cpython-311.pyc b/mmseg/apis/__pycache__/inference.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac628039c368929a80bd73c49875a03faa0251c5
Binary files /dev/null and b/mmseg/apis/__pycache__/inference.cpython-311.pyc differ
diff --git a/mmseg/apis/__pycache__/mmseg_inferencer.cpython-311.pyc b/mmseg/apis/__pycache__/mmseg_inferencer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70bdf35a871d3d2afa12f6ac805dd5617c7d1464
Binary files /dev/null and b/mmseg/apis/__pycache__/mmseg_inferencer.cpython-311.pyc differ
diff --git a/mmseg/apis/__pycache__/remote_sense_inferencer.cpython-311.pyc b/mmseg/apis/__pycache__/remote_sense_inferencer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ed49973c4eb540f3dcfb8f8ecb09d27e0bf57fab
Binary files /dev/null and b/mmseg/apis/__pycache__/remote_sense_inferencer.cpython-311.pyc differ
diff --git a/mmseg/apis/__pycache__/utils.cpython-311.pyc b/mmseg/apis/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec9a24eb0d834e0a3d114a6a962f974e9f885bde
Binary files /dev/null and b/mmseg/apis/__pycache__/utils.cpython-311.pyc differ
diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..aab11d14f4becc43d4c2ecd3772417e4923bd20e
--- /dev/null
+++ b/mmseg/apis/inference.py
@@ -0,0 +1,189 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from pathlib import Path
+from typing import Optional, Union
+
+import mmcv
+import numpy as np
+import torch
+from mmengine import Config
+from mmengine.registry import init_default_scope
+from mmengine.runner import load_checkpoint
+from mmengine.utils import mkdir_or_exist
+
+from mmseg.models import BaseSegmentor
+from mmseg.registry import MODELS
+from mmseg.structures import SegDataSample
+from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette
+from mmseg.visualization import SegLocalVisualizer
+from .utils import ImageType, _preprare_data
+
+
+def init_model(config: Union[str, Path, Config],
+               checkpoint: Optional[str] = None,
+               device: str = 'cuda:0',
+               cfg_options: Optional[dict] = None):
+    """Initialize a segmentor from config file.
+
+    Args:
+        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file
+            path, :obj:`Path`, or the config object.
+        checkpoint (str, optional): Checkpoint path. If left as None, the
+            model will not load any weights.
+        device (str, optional): CPU/CUDA device option. Default 'cuda:0'.
+            Use 'cpu' for loading model on CPU.
+        cfg_options (dict, optional): Options to override some settings in
+            the used config.
+
+    Returns:
+        nn.Module: The constructed segmentor.
+    """
+    if isinstance(config, (str, Path)):
+        config = Config.fromfile(config)
+    elif not isinstance(config, Config):
+        raise TypeError('config must be a filename or Config object, '
+                        'but got {}'.format(type(config)))
+    if cfg_options is not None:
+        config.merge_from_dict(cfg_options)
+    if config.model.type == 'EncoderDecoder':
+        if 'init_cfg' in config.model.backbone:
+            config.model.backbone.init_cfg = None
+    elif config.model.type == 'MultimodalEncoderDecoder':
+        for k, v in config.model.items():
+            if isinstance(v, dict) and 'init_cfg' in v:
+                config.model[k].init_cfg = None
+    config.model.pretrained = None
+    config.model.train_cfg = None
+    init_default_scope(config.get('default_scope', 'mmseg'))
+
+    model = MODELS.build(config.model)
+    if checkpoint is not None:
+        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
+        dataset_meta = checkpoint['meta'].get('dataset_meta', None)
+        # save the dataset_meta in the model for convenience
+        if 'dataset_meta' in checkpoint.get('meta', {}):
+            # mmseg 1.x
+            model.dataset_meta = dataset_meta
+        elif 'CLASSES' in checkpoint.get('meta', {}):
+            # < mmseg 1.x
+            classes = checkpoint['meta']['CLASSES']
+            palette = checkpoint['meta']['PALETTE']
+            model.dataset_meta = {'classes': classes, 'palette': palette}
+        else:
+            warnings.simplefilter('once')
+            warnings.warn(
+                'dataset_meta or class names are not saved in the '
+                "checkpoint's meta data, classes and palette will be "
+                'set according to num_classes')
+            num_classes = model.decode_head.num_classes
+            dataset_name = None
+            for name in dataset_aliases.keys():
+                if len(get_classes(name)) == num_classes:
+                    dataset_name = name
+                    break
+            if dataset_name is None:
+                warnings.warn(
+                    'No suitable dataset found, use Cityscapes by default')
+                dataset_name = 'cityscapes'
+            model.dataset_meta = {
+                'classes': get_classes(dataset_name),
+                'palette': get_palette(dataset_name)
+            }
+    model.cfg = config  # save the config in the model for convenience
+    model.to(device)
+    model.eval()
+    return model
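# NOTE (illustrative sketch, not part of this commit): a minimal end-to-end
# use of the functional API in this file. The config and checkpoint paths are
# placeholders, not files shipped by this commit.
from mmseg.apis import inference_model, init_model, show_result_pyplot

model = init_model(
    'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py',
    'checkpoints/pspnet_cityscapes.pth',   # hypothetical local checkpoint
    device='cpu')
result = inference_model(model, 'demo/demo.png')
vis = show_result_pyplot(model, 'demo/demo.png', result, show=False,
                         out_file='work_dirs/demo_pred.png')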
+
+
+def inference_model(model: BaseSegmentor,
+                    img: ImageType) -> Union[SegDataSample, SampleList]:
+    """Inference image(s) with the segmentor.
+
+    Args:
+        model (nn.Module): The loaded segmentor.
+        img (str/ndarray or list[str/ndarray]): Either image files or loaded
+            images.
+
+    Returns:
+        :obj:`SegDataSample` or list[:obj:`SegDataSample`]:
+        If img is a list or tuple, a list of results of the same length is
+        returned, otherwise the segmentation result is returned directly.
+    """
+    # prepare data
+    data, is_batch = _preprare_data(img, model)
+
+    # forward the model
+    with torch.no_grad():
+        results = model.test_step(data)
+
+    return results if is_batch else results[0]
+
+
+def show_result_pyplot(model: BaseSegmentor,
+                       img: Union[str, np.ndarray],
+                       result: SegDataSample,
+                       opacity: float = 0.5,
+                       title: str = '',
+                       draw_gt: bool = True,
+                       draw_pred: bool = True,
+                       wait_time: float = 0,
+                       show: bool = True,
+                       with_labels: Optional[bool] = True,
+                       save_dir=None,
+                       out_file=None):
+    """Visualize the segmentation results on the image.
+
+    Args:
+        model (nn.Module): The loaded segmentor.
+        img (str or np.ndarray): Image filename or loaded image.
+        result (SegDataSample): The prediction SegDataSample result.
+        opacity (float): Opacity of painted segmentation map.
+            Defaults to 0.5. Must be in (0, 1] range.
+        title (str): The title of pyplot figure. Defaults to ''.
+        draw_gt (bool): Whether to draw GT SegDataSample. Defaults to True.
+        draw_pred (bool): Whether to draw Prediction SegDataSample.
+            Defaults to True.
+        wait_time (float): The interval of show (s). 0 is the special value
+            that means "forever". Defaults to 0.
+        show (bool): Whether to display the drawn image. Defaults to True.
+        with_labels (bool, optional): Add semantic labels in visualization
+            result. Defaults to True.
+        save_dir (str, optional): Save file dir for all storage backends.
+            If it is None, the backend storage will not save any data.
+        out_file (str, optional): Path to output file. Defaults to None.
+
+    Returns:
+        np.ndarray: the drawn image whose channel order is RGB.
+    """
+    if hasattr(model, 'module'):
+        model = model.module
+    if isinstance(img, str):
+        image = mmcv.imread(img, channel_order='rgb')
+    else:
+        image = img
+    if save_dir is not None:
+        mkdir_or_exist(save_dir)
+    # init visualizer
+    visualizer = SegLocalVisualizer(
+        vis_backends=[dict(type='LocalVisBackend')],
+        save_dir=save_dir,
+        alpha=opacity)
+    visualizer.dataset_meta = dict(
+        classes=model.dataset_meta['classes'],
+        palette=model.dataset_meta['palette'])
+    visualizer.add_datasample(
+        name=title,
+        image=image,
+        data_sample=result,
+        draw_gt=draw_gt,
+        draw_pred=draw_pred,
+        wait_time=wait_time,
+        out_file=out_file,
+        show=show,
+        with_labels=with_labels)
+    vis_img = visualizer.get_image()
+
+    return vis_img
diff --git a/mmseg/apis/mmseg_inferencer.py b/mmseg/apis/mmseg_inferencer.py
new file mode 100644
index 0000000000000000000000000000000000000000..02a198b516a71c1f5a0833955607ba4ecc05bf13
--- /dev/null
+++ b/mmseg/apis/mmseg_inferencer.py
@@ -0,0 +1,382 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp +import warnings +from typing import List, Optional, Sequence, Union + +import mmcv +import mmengine +import numpy as np +import torch +import torch.nn as nn +from mmcv.transforms import Compose +from mmengine.infer.infer import BaseInferencer, ModelType +from mmengine.model import revert_sync_batchnorm +from mmengine.registry import init_default_scope +from mmengine.runner.checkpoint import _load_checkpoint_to_model +from PIL import Image + +from mmseg.structures import SegDataSample +from mmseg.utils import ConfigType, SampleList, get_classes, get_palette +from mmseg.visualization import SegLocalVisualizer + +InputType = Union[str, np.ndarray] +InputsType = Union[InputType, Sequence[InputType]] +PredType = Union[SegDataSample, SampleList] + + +class MMSegInferencer(BaseInferencer): + """Semantic segmentation inferencer, provides inference and visualization + interfaces. Note: MMEngine >= 0.5.0 is required. + + Args: + model (str, optional): Path to the config file or the model name + defined in metafile. Take the `mmseg metafile `_ + as an example the `model` could be + "fcn_r50-d8_4xb2-40k_cityscapes-512x1024", and the weights of model + will be download automatically. If use config file, like + "configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py", the + `weights` should be defined. + weights (str, optional): Path to the checkpoint. If it is not specified + and model is a model name of metafile, the weights will be loaded + from metafile. Defaults to None. + classes (list, optional): Input classes for result rendering, as the + prediction of segmentation model is a segment map with label + indices, `classes` is a list which includes items responding to the + label indices. If classes is not defined, visualizer will take + `cityscapes` classes by default. Defaults to None. + palette (list, optional): Input palette for result rendering, which is + a list of color palette responding to the classes. If palette is + not defined, visualizer will take `cityscapes` palette by default. + Defaults to None. + dataset_name (str, optional): `Dataset name or alias `_ + visulizer will use the meta information of the dataset i.e. classes + and palette, but the `classes` and `palette` have higher priority. + Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + scope (str, optional): The scope of the model. Defaults to 'mmseg'. 
+    """  # noqa
+
+    preprocess_kwargs: set = set()
+    forward_kwargs: set = {'mode', 'out_dir'}
+    visualize_kwargs: set = {
+        'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis',
+        'with_labels'
+    }
+    postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
+
+    def __init__(self,
+                 model: Union[ModelType, str],
+                 weights: Optional[str] = None,
+                 classes: Optional[Union[str, List]] = None,
+                 palette: Optional[Union[str, List]] = None,
+                 dataset_name: Optional[str] = None,
+                 device: Optional[str] = None,
+                 scope: Optional[str] = 'mmseg') -> None:
+        # A global counter tracking the number of images processed, for
+        # naming of the output images
+        self.num_visualized_imgs = 0
+        self.num_pred_imgs = 0
+        init_default_scope(scope if scope else 'mmseg')
+        super().__init__(
+            model=model, weights=weights, device=device, scope=scope)
+
+        if device == 'cpu' or not torch.cuda.is_available():
+            self.model = revert_sync_batchnorm(self.model)
+
+        assert isinstance(self.visualizer, SegLocalVisualizer)
+        self.visualizer.set_dataset_meta(classes, palette, dataset_name)
+
+    def _load_weights_to_model(self, model: nn.Module,
+                               checkpoint: Optional[dict],
+                               cfg: Optional[ConfigType]) -> None:
+        """Loading model weights and meta information from cfg and
+        checkpoint.
+
+        Subclasses could override this method to load extra meta information
+        from ``checkpoint`` and ``cfg`` to model.
+
+        Args:
+            model (nn.Module): Model to load weights and meta information.
+            checkpoint (dict, optional): The loaded checkpoint.
+            cfg (Config or ConfigDict, optional): The loaded config.
+        """
+
+        if checkpoint is not None:
+            _load_checkpoint_to_model(model, checkpoint)
+            checkpoint_meta = checkpoint.get('meta', {})
+            # save the dataset_meta in the model for convenience
+            if 'dataset_meta' in checkpoint_meta:
+                # mmsegmentation 1.x
+                model.dataset_meta = {
+                    'classes': checkpoint_meta['dataset_meta'].get('classes'),
+                    'palette': checkpoint_meta['dataset_meta'].get('palette')
+                }
+            elif 'CLASSES' in checkpoint_meta:
+                # mmsegmentation 0.x
+                classes = checkpoint_meta['CLASSES']
+                palette = checkpoint_meta.get('PALETTE', None)
+                model.dataset_meta = {'classes': classes, 'palette': palette}
+            else:
+                warnings.warn(
+                    'dataset_meta or class names are not saved in the '
+                    "checkpoint's meta data, use classes of Cityscapes by "
+                    'default.')
+                model.dataset_meta = {
+                    'classes': get_classes('cityscapes'),
+                    'palette': get_palette('cityscapes')
+                }
+        else:
+            warnings.warn('Checkpoint is not loaded, and the inference '
+                          'result is calculated by the randomly initialized '
+                          'model!')
+            warnings.warn(
+                'weights is None, use cityscapes classes by default.')
+            model.dataset_meta = {
+                'classes': get_classes('cityscapes'),
+                'palette': get_palette('cityscapes')
+            }
+    def __call__(self,
+                 inputs: InputsType,
+                 return_datasamples: bool = False,
+                 batch_size: int = 1,
+                 return_vis: bool = False,
+                 show: bool = False,
+                 wait_time: int = 0,
+                 out_dir: str = '',
+                 img_out_dir: str = 'vis',
+                 pred_out_dir: str = 'pred',
+                 **kwargs) -> dict:
+        """Call the inferencer.
+
+        Args:
+            inputs (Union[list, str, np.ndarray]): Inputs for the inferencer.
+            return_datasamples (bool): Whether to return results as
+                :obj:`SegDataSample`. Defaults to False.
+            batch_size (int): Batch size. Defaults to 1.
+            return_vis (bool): Whether to return the visualization results.
+                Defaults to False.
+            show (bool): Whether to display the rendered color segmentation
+                mask in a popup window. Defaults to False.
+            wait_time (float): The interval of show (s). Defaults to 0.
+            out_dir (str): Output directory of inference results. Defaults
+                to ''.
+            img_out_dir (str): Subdirectory of `out_dir`, used to save the
+                rendered color segmentation mask, so `out_dir` must be
+                defined if you would like to save the predicted mask.
+                Defaults to 'vis'.
+            pred_out_dir (str): Subdirectory of `out_dir`, used to save the
+                predicted mask file, so `out_dir` must be defined if you
+                would like to save the predicted mask. Defaults to 'pred'.
+            **kwargs: Other keyword arguments passed to :meth:`preprocess`,
+                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
+                Each key in kwargs should be in the corresponding set of
+                ``preprocess_kwargs``, ``forward_kwargs``,
+                ``visualize_kwargs`` and ``postprocess_kwargs``.
+
+        Returns:
+            dict: Inference and visualization results.
+        """
+
+        if out_dir != '':
+            pred_out_dir = osp.join(out_dir, pred_out_dir)
+            img_out_dir = osp.join(out_dir, img_out_dir)
+        else:
+            pred_out_dir = ''
+            img_out_dir = ''
+
+        return super().__call__(
+            inputs=inputs,
+            return_datasamples=return_datasamples,
+            batch_size=batch_size,
+            show=show,
+            wait_time=wait_time,
+            img_out_dir=img_out_dir,
+            pred_out_dir=pred_out_dir,
+            return_vis=return_vis,
+            **kwargs)
+
+    def visualize(self,
+                  inputs: list,
+                  preds: List[dict],
+                  return_vis: bool = False,
+                  show: bool = False,
+                  wait_time: int = 0,
+                  img_out_dir: str = '',
+                  opacity: float = 0.8,
+                  with_labels: Optional[bool] = True) -> List[np.ndarray]:
+        """Visualize predictions.
+
+        Args:
+            inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
+            preds (Any): Predictions of the model.
+            return_vis (bool): Whether to return the visualization results.
+                Defaults to False.
+            show (bool): Whether to display the image in a popup window.
+                Defaults to False.
+            wait_time (float): The interval of show (s). Defaults to 0.
+            img_out_dir (str): Output directory of the rendered prediction,
+                i.e. the color segmentation mask. Defaults to ''.
+            opacity (int, float): The transparency of segmentation mask.
+                Defaults to 0.8.
+            with_labels (bool, optional): Add semantic labels in
+                visualization result. Defaults to True.
+
+        Returns:
+            List[np.ndarray]: Visualization results.
+        """
+        if not show and img_out_dir == '' and not return_vis:
+            return None
+        if self.visualizer is None:
+            raise ValueError('Visualization needs the "visualizer" term '
+                             'defined in the config, but got None.')
+
+        self.visualizer.set_dataset_meta(**self.model.dataset_meta)
+        self.visualizer.alpha = opacity
+
+        results = []
+
+        for single_input, pred in zip(inputs, preds):
+            if isinstance(single_input, str):
+                img_bytes = mmengine.fileio.get(single_input)
+                img = mmcv.imfrombytes(img_bytes)
+                img = img[:, :, ::-1]
+                img_name = osp.basename(single_input)
+            elif isinstance(single_input, np.ndarray):
+                img = single_input.copy()
+                img_num = str(self.num_visualized_imgs).zfill(8) + '_vis'
+                img_name = f'{img_num}.jpg'
+            else:
+                raise ValueError('Unsupported input type: '
+                                 f'{type(single_input)}')
+
+            out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \
+                else None
+
+            self.visualizer.add_datasample(
+                img_name,
+                img,
+                pred,
+                show=show,
+                wait_time=wait_time,
+                draw_gt=False,
+                draw_pred=True,
+                out_file=out_file,
+                with_labels=with_labels)
+            if return_vis:
+                results.append(self.visualizer.get_image())
+            self.num_visualized_imgs += 1
+
+        return results if return_vis else None
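# NOTE (illustrative sketch, not part of this commit): a minimal use of the
# inferencer tying `__call__` and `visualize` together. The metafile model
# name comes from the class docstring above; 'demo/demo.png' and 'outputs'
# are placeholder paths. With out_dir='outputs', rendered masks land in
# 'outputs/vis/' and raw predictions in 'outputs/pred/'.
from mmseg.apis import MMSegInferencer

inferencer = MMSegInferencer(model='fcn_r50-d8_4xb2-40k_cityscapes-512x1024')
result = inferencer('demo/demo.png', out_dir='outputs', return_vis=True)
# result['predictions']: per-pixel label indices (np.ndarray)
# result['visualization']: the rendered color mask (np.ndarray)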
+
+    def postprocess(self,
+                    preds: PredType,
+                    visualization: List[np.ndarray],
+                    return_datasample: bool = False,
+                    pred_out_dir: str = '') -> dict:
+        """Process the predictions and visualization results from
+        ``forward`` and ``visualize``.
+
+        This method should be responsible for the following tasks:
+
+        1. Pack the predictions and visualization results and return them.
+        2. Save the predictions, if needed.
+
+        Args:
+            preds (List[Dict]): Predictions of the model.
+            visualization (List[np.ndarray]): The list of rendered color
+                segmentation masks.
+            return_datasample (bool): Whether to return results as
+                datasamples. Defaults to False.
+            pred_out_dir (str): Directory to save the inference results
+                without visualization. If left as empty, no file will be
+                saved. Defaults to ''.
+
+        Returns:
+            dict: Inference and visualization results with key
+            ``predictions`` and ``visualization``
+
+            - ``visualization`` (Any): Returned by :meth:`visualize`
+            - ``predictions`` (List[np.ndarray], np.ndarray): Returned by
+              :meth:`forward` and processed in :meth:`postprocess`.
+              If ``return_datasample=False``, it will be the segmentation
+              mask with label indices.
+        """
+        if return_datasample:
+            if len(preds) == 1:
+                return preds[0]
+            else:
+                return preds
+
+        results_dict = {}
+
+        results_dict['predictions'] = []
+        results_dict['visualization'] = []
+
+        for i, pred in enumerate(preds):
+            pred_data = dict()
+            if 'pred_sem_seg' in pred.keys():
+                pred_data['sem_seg'] = pred.pred_sem_seg.numpy().data[0]
+            elif 'pred_depth_map' in pred.keys():
+                pred_data['depth_map'] = pred.pred_depth_map.numpy().data[0]
+
+            if visualization is not None:
+                vis = visualization[i]
+                results_dict['visualization'].append(vis)
+            if pred_out_dir != '':
+                mmengine.mkdir_or_exist(pred_out_dir)
+                for key, data in pred_data.items():
+                    post_fix = '_pred.png' if key == 'sem_seg' else '_pred.npy'
+                    img_name = str(self.num_pred_imgs).zfill(8) + post_fix
+                    img_path = osp.join(pred_out_dir, img_name)
+                    if key == 'sem_seg':
+                        output = Image.fromarray(data.astype(np.uint8))
+                        output.save(img_path)
+                    else:
+                        np.save(img_path, data)
+            pred_data = next(iter(pred_data.values()))
+            results_dict['predictions'].append(pred_data)
+            self.num_pred_imgs += 1
+
+        if len(results_dict['predictions']) == 1:
+            results_dict['predictions'] = results_dict['predictions'][0]
+            if visualization is not None:
+                results_dict['visualization'] = \
+                    results_dict['visualization'][0]
+        return results_dict
+
+    def _init_pipeline(self, cfg: ConfigType) -> Compose:
+        """Initialize the test pipeline.
+
+        Return a pipeline to handle various input data, such as ``str``,
+        ``np.ndarray``. It is an abstract method in BaseInferencer, and
+        should be implemented in subclasses.
+
+        The returned pipeline will be used to process a single data item.
+        It will be used in :meth:`preprocess` like this:
+
+        .. code-block:: python
+
+            def preprocess(self, inputs, batch_size, **kwargs):
+                ...
+                dataset = map(self.pipeline, dataset)
+                ...
+        """
+        pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+        # Loading annotations is not applicable at inference time
+        for transform in ('LoadAnnotations', 'LoadDepthAnnotation'):
+            idx = self._get_transform_idx(pipeline_cfg, transform)
+            if idx != -1:
+                del pipeline_cfg[idx]
+
+        load_img_idx = self._get_transform_idx(pipeline_cfg,
+                                               'LoadImageFromFile')
+        if load_img_idx == -1:
+            raise ValueError(
+                'LoadImageFromFile is not found in the test pipeline')
+        pipeline_cfg[load_img_idx]['type'] = 'InferencerLoader'
+        return Compose(pipeline_cfg)
+
+    def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
+        """Returns the index of the transform in a pipeline.
+
+        If the transform is not found, returns -1.
+        """
+        for i, transform in enumerate(pipeline_cfg):
+            if transform['type'] == name:
+                return i
+        return -1
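# NOTE (self-contained sketch, not part of this commit): the pipeline surgery
# performed by `_init_pipeline` above: annotation-loading transforms are
# dropped and the image loader is swapped for an inferencer-friendly one.
# The pipeline below is a made-up minimal example.
pipeline_cfg = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs'),
]


def get_transform_idx(cfg, name):
    # Mirrors MMSegInferencer._get_transform_idx: -1 when absent.
    return next((i for i, t in enumerate(cfg) if t['type'] == name), -1)


idx = get_transform_idx(pipeline_cfg, 'LoadAnnotations')
if idx != -1:
    del pipeline_cfg[idx]                  # no labels at inference time
pipeline_cfg[get_transform_idx(pipeline_cfg, 'LoadImageFromFile')]['type'] = \
    'InferencerLoader'                     # accepts str or np.ndarray inputs
assert [t['type'] for t in pipeline_cfg] == ['InferencerLoader',
                                             'PackSegInputs']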
diff --git a/mmseg/apis/remote_sense_inferencer.py b/mmseg/apis/remote_sense_inferencer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6726c6ae3464b3911f7e69b14a0baf35cffc66d0
--- /dev/null
+++ b/mmseg/apis/remote_sense_inferencer.py
@@ -0,0 +1,279 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import threading
+from queue import Queue
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+from mmengine import Config
+from mmengine.model import BaseModel
+from mmengine.registry import init_default_scope
+from mmengine.runner import load_checkpoint
+
+try:
+    from osgeo import gdal
+except ImportError:
+    gdal = None
+
+from mmseg.registry import MODELS
+from .utils import _preprare_data
+
+
+class RSImage:
+    """Remote sensing image class.
+
+    Args:
+        image (str or gdal.Dataset): Image file path or gdal.Dataset.
+    """
+
+    def __init__(self, image):
+        self.dataset = gdal.Open(image, gdal.GA_ReadOnly) if isinstance(
+            image, str) else image
+        assert isinstance(self.dataset, gdal.Dataset), \
+            f'{image} is not an image'
+        self.width = self.dataset.RasterXSize
+        self.height = self.dataset.RasterYSize
+        self.channel = self.dataset.RasterCount
+        self.trans = self.dataset.GetGeoTransform()
+        self.proj = self.dataset.GetProjection()
+        self.band_list = []
+        self.band_list.extend(
+            self.dataset.GetRasterBand(c + 1) for c in range(self.channel))
+        self.grids = []
+
+    def read(self, grid: Optional[List] = None) -> np.ndarray:
+        """Read image data. If grid is None, read the whole image.
+
+        Args:
+            grid (Optional[List], optional): Grid to read. Defaults to None.
+
+        Returns:
+            np.ndarray: Image data.
+        """
+        if grid is None:
+            return np.einsum('ijk->jki', self.dataset.ReadAsArray())
+        assert len(
+            grid) >= 4, 'grid must be a list containing at least 4 elements'
+        data = self.dataset.ReadAsArray(*grid[:4])
+        if data.ndim == 2:
+            data = data[np.newaxis, ...]
+        return np.einsum('ijk->jki', data)
+
+    def write(self, data: Optional[np.ndarray], grid: Optional[List] = None):
+        """Write image data.
+
+        Args:
+            data (Optional[np.ndarray], optional): Data to write.
+                Defaults to None.
+            grid (Optional[List], optional): Grid to write. Defaults to None.
+
+        Raises:
+            ValueError: Either grid or data must be provided.
+        """
+        if grid is not None:
+            assert len(grid) == 8, 'grid must be a list of 8 elements'
+            for band in self.band_list:
+                band.WriteArray(
+                    data[grid[5]:grid[5] + grid[7], grid[4]:grid[4] + grid[6]],
+                    grid[0] + grid[4], grid[1] + grid[5])
+        elif data is not None:
+            for i in range(self.channel):
+                self.band_list[i].WriteArray(data[..., i])
+        else:
+            raise ValueError('Either grid or data must be provided.')
+
+    def create_seg_map(self, output_path: Optional[str] = None):
+        """Create a single-band GeoTIFF with the same size, transform and
+        projection as this image, for storing the segmentation map."""
+        if output_path is None:
+            output_path = 'output_label.tif'
+        driver = gdal.GetDriverByName('GTiff')
+        seg_map = driver.Create(output_path, self.width, self.height, 1,
+                                gdal.GDT_Byte)
+        seg_map.SetGeoTransform(self.trans)
+        seg_map.SetProjection(self.proj)
+        seg_map_img = RSImage(seg_map)
+        seg_map_img.path = output_path
+        return seg_map_img
+
+    def create_grids(self,
+                     window_size: Tuple[int, int],
+                     stride: Tuple[int, int] = (0, 0)):
+        """Create grids for image inference.
+
+        Args:
+            window_size (Tuple[int, int]): the size of the sliding window.
+            stride (Tuple[int, int], optional): the stride of the sliding
+                window. Defaults to (0, 0).
+ + Raises: + AssertionError: window_size must be a tuple of 2 elements. + AssertionError: stride must be a tuple of 2 elements. + """ + assert len( + window_size) == 2, 'window_size must be a tuple of 2 elements' + assert len(stride) == 2, 'stride must be a tuple of 2 elements' + win_w, win_h = window_size + stride_x, stride_y = stride + + stride_x = win_w if stride_x == 0 else stride_x + stride_y = win_h if stride_y == 0 else stride_y + + x_half_overlap = (win_w - stride_x + 1) // 2 + y_half_overlap = (win_h - stride_y + 1) // 2 + + for y in range(0, self.height, stride_y): + y_end = y + win_h >= self.height + y_offset = self.height - win_h if y_end else y + y_size = win_h + y_crop_off = 0 if y_offset == 0 else y_half_overlap + y_crop_size = y_size if y_end else win_h - y_crop_off + + for x in range(0, self.width, stride_x): + x_end = x + win_w >= self.width + x_offset = self.width - win_w if x_end else x + x_size = win_w + x_crop_off = 0 if x_offset == 0 else x_half_overlap + x_crop_size = x_size if x_end else win_w - x_crop_off + + self.grids.append([ + x_offset, y_offset, x_size, y_size, x_crop_off, y_crop_off, + x_crop_size, y_crop_size + ]) + + +class RSInferencer: + """Remote sensing inference class. + + Args: + model (BaseModel): The loaded model. + batch_size (int, optional): Batch size. Defaults to 1. + thread (int, optional): Number of threads. Defaults to 1. + """ + + def __init__(self, model: BaseModel, batch_size: int = 1, thread: int = 1): + self.model = model + self.batch_size = batch_size + self.END_FLAG = object() + self.read_buffer = Queue(self.batch_size) + self.write_buffer = Queue(self.batch_size) + self.thread = thread + + @classmethod + def from_config_path(cls, + config_path: str, + checkpoint_path: str, + batch_size: int = 1, + thread: int = 1, + device: Optional[str] = 'cpu'): + """Initialize a segmentor from config file. + + Args: + config_path (str): Config file path. + checkpoint_path (str): Checkpoint path. + batch_size (int, optional): Batch size. Defaults to 1. + """ + init_default_scope('mmseg') + cfg = Config.fromfile(config_path) + model = MODELS.build(cfg.model) + model.cfg = cfg + load_checkpoint(model, checkpoint_path, map_location='cpu') + model.to(device) + model.eval() + return cls(model, batch_size, thread) + + @classmethod + def from_model(cls, + model: BaseModel, + checkpoint_path: Optional[str] = None, + batch_size: int = 1, + thread: int = 1, + device: Optional[str] = 'cpu'): + """Initialize a segmentor from model. + + Args: + model (BaseModel): The loaded model. + checkpoint_path (Optional[str]): Checkpoint path. + batch_size (int, optional): Batch size. Defaults to 1. + """ + if checkpoint_path is not None: + load_checkpoint(model, checkpoint_path, map_location='cpu') + model.to(device) + return cls(model, batch_size, thread) + + def read(self, + image: RSImage, + window_size: Tuple[int, int], + strides: Tuple[int, int] = (0, 0)): + """Load image data to read buffer. + + Args: + image (RSImage): The image to read. + window_size (Tuple[int, int]): The size of the sliding window. + strides (Tuple[int, int], optional): The stride of the sliding + window. Defaults to (0, 0). 
+ """ + image.create_grids(window_size, strides) + for grid in image.grids: + self.read_buffer.put([grid, image.read(grid=grid)]) + self.read_buffer.put(self.END_FLAG) + + def inference(self): + """Inference image data from read buffer and put the result to write + buffer.""" + while True: + item = self.read_buffer.get() + if item == self.END_FLAG: + self.read_buffer.put(self.END_FLAG) + self.write_buffer.put(item) + break + data, _ = _preprare_data(item[1], self.model) + with torch.no_grad(): + result = self.model.test_step(data) + item[1] = result[0].pred_sem_seg.cpu().data.numpy()[0] + self.write_buffer.put(item) + self.read_buffer.task_done() + + def write(self, image: RSImage, output_path: Optional[str] = None): + """Write image data from write buffer. + + Args: + image (RSImage): The image to write. + output_path (Optional[str], optional): The path to save the + segmentation map. Defaults to None. + """ + seg_map = image.create_seg_map(output_path) + while True: + item = self.write_buffer.get() + if item == self.END_FLAG: + break + seg_map.write(data=item[1], grid=item[0]) + self.write_buffer.task_done() + + def run(self, + image: RSImage, + window_size: Tuple[int, int], + strides: Tuple[int, int] = (0, 0), + output_path: Optional[str] = None): + """Run inference with multi-threading. + + Args: + image (RSImage): The image to inference. + window_size (Tuple[int, int]): The size of the sliding window. + strides (Tuple[int, int], optional): The stride of the sliding + window. Defaults to (0, 0). + output_path (Optional[str], optional): The path to save the + segmentation map. Defaults to None. + """ + read_thread = threading.Thread( + target=self.read, args=(image, window_size, strides)) + read_thread.start() + inference_threads = [] + for _ in range(self.thread): + inference_thread = threading.Thread(target=self.inference) + inference_thread.start() + inference_threads.append(inference_thread) + write_thread = threading.Thread( + target=self.write, args=(image, output_path)) + write_thread.start() + read_thread.join() + for inference_thread in inference_threads: + inference_thread.join() + write_thread.join() diff --git a/mmseg/apis/utils.py b/mmseg/apis/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4cf877566028dbb2b966c2888b1ebd1a5f57c330 --- /dev/null +++ b/mmseg/apis/utils.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from collections import defaultdict
+from typing import Sequence, Union
+
+import numpy as np
+from mmengine.dataset import Compose
+from mmengine.model import BaseModel
+
+ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
+
+
+def _preprare_data(imgs: ImageType, model: BaseModel):
+
+    cfg = model.cfg
+    # Iterate over a copy: removing items from the list being iterated
+    # would skip the element that follows each removal.
+    for t in list(cfg.test_pipeline):
+        if t.get('type') == 'LoadAnnotations':
+            cfg.test_pipeline.remove(t)
+
+    is_batch = True
+    if not isinstance(imgs, (list, tuple)):
+        imgs = [imgs]
+        is_batch = False
+
+    if isinstance(imgs[0], np.ndarray):
+        cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray'
+
+    # TODO: Consider using the singleton pattern to avoid building
+    # a pipeline for each inference
+    pipeline = Compose(cfg.test_pipeline)
+
+    data = defaultdict(list)
+    for img in imgs:
+        if isinstance(img, np.ndarray):
+            data_ = dict(img=img)
+        else:
+            data_ = dict(img_path=img)
+        data_ = pipeline(data_)
+        data['inputs'].append(data_['inputs'])
+        data['data_samples'].append(data_['data_samples'])
+
+    return data, is_batch
diff --git a/mmseg/datasets/.DS_Store b/mmseg/datasets/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..9489b33ad40b163ae09fa18a492872d30f925fec
Binary files /dev/null and b/mmseg/datasets/.DS_Store differ
diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2bdb63d016664bf76c93e2c3ee6f5386905064c
--- /dev/null
+++ b/mmseg/datasets/__init__.py
@@ -0,0 +1,64 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# yapf: disable
+from .ade import ADE20KDataset
+from .basesegdataset import BaseCDDataset, BaseSegDataset
+from .bdd100k import BDD100KDataset
+from .chase_db1 import ChaseDB1Dataset
+from .cityscapes import CityscapesDataset
+from .coco_stuff import COCOStuffDataset
+from .dark_zurich import DarkZurichDataset
+from .dataset_wrappers import MultiImageMixDataset
+from .decathlon import DecathlonDataset
+from .drive import DRIVEDataset
+from .dsdl import DSDLSegDataset
+from .hrf import HRFDataset
+from .isaid import iSAIDDataset
+from .isprs import ISPRSDataset
+from .levir import LEVIRCDDataset
+from .lip import LIPDataset
+from .loveda import LoveDADataset
+from .mapillary import MapillaryDataset_v1, MapillaryDataset_v2
+from .night_driving import NightDrivingDataset
+from .nyu import NYUDataset
+from .pascal_context import PascalContextDataset, PascalContextDataset59
+from .potsdam import PotsdamDataset
+from .refuge import REFUGEDataset
+from .stare import STAREDataset
+from .synapse import SynapseDataset
+# yapf: disable
+from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
+                         BioMedical3DRandomCrop, BioMedical3DRandomFlip,
+                         BioMedicalGaussianBlur, BioMedicalGaussianNoise,
+                         BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
+                         LoadAnnotations, LoadBiomedicalAnnotation,
+                         LoadBiomedicalData, LoadBiomedicalImageFromFile,
+                         LoadImageFromNDArray, LoadMultipleRSImageFromFile,
+                         LoadSingleRSImageFromFile, PackSegInputs,
+                         PhotoMetricDistortion, RandomCrop, RandomCutOut,
+                         RandomMosaic, RandomRotate, RandomRotFlip, Rerange,
+                         ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
+                         SegRescale)
+from .voc import PascalVOCDataset
+
+# yapf: enable
+__all__ = [
+    'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
+    'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
+    'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
+    'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
+    
'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset', + 'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', + 'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion', + 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', + 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple', + 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', + 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', + 'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge', + 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur', + 'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip', + 'SynapseDataset', 'REFUGEDataset', 'MapillaryDataset_v1', + 'MapillaryDataset_v2', 'Albu', 'LEVIRCDDataset', + 'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile', + 'ConcatCDInput', 'BaseCDDataset', 'DSDLSegDataset', 'BDD100KDataset', + 'NYUDataset' +] diff --git a/mmseg/datasets/__pycache__/__init__.cpython-311.pyc b/mmseg/datasets/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22a9efd444804e31573b19ae2f4d5b952543d578 Binary files /dev/null and b/mmseg/datasets/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/ade.cpython-311.pyc b/mmseg/datasets/__pycache__/ade.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5018d78d48c4b56f555957e292e8a199004faa66 Binary files /dev/null and b/mmseg/datasets/__pycache__/ade.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/basesegdataset.cpython-311.pyc b/mmseg/datasets/__pycache__/basesegdataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..942edb204e86d8ef3d06cb50b874ede9032f5ede Binary files /dev/null and b/mmseg/datasets/__pycache__/basesegdataset.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/bdd100k.cpython-311.pyc b/mmseg/datasets/__pycache__/bdd100k.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f75c075f3a07665411ab254037ac19e9b880048a Binary files /dev/null and b/mmseg/datasets/__pycache__/bdd100k.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/chase_db1.cpython-311.pyc b/mmseg/datasets/__pycache__/chase_db1.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53a5a6c8922fef51bf996df9d76788bc3d65957b Binary files /dev/null and b/mmseg/datasets/__pycache__/chase_db1.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/cityscapes.cpython-311.pyc b/mmseg/datasets/__pycache__/cityscapes.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4172835ea4ee08f3faf9543461899067b1ab2e7 Binary files /dev/null and b/mmseg/datasets/__pycache__/cityscapes.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/coco_stuff.cpython-311.pyc b/mmseg/datasets/__pycache__/coco_stuff.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..245ec1e001af516915bba1bc900bb499853f1a3c Binary files /dev/null and b/mmseg/datasets/__pycache__/coco_stuff.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/dark_zurich.cpython-311.pyc b/mmseg/datasets/__pycache__/dark_zurich.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2e9d58695304cdd967267656e22704ed0675f15 Binary files /dev/null and b/mmseg/datasets/__pycache__/dark_zurich.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/dataset_wrappers.cpython-311.pyc 
b/mmseg/datasets/__pycache__/dataset_wrappers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..541610ed0f9ac88ab5878602580524025edebd16 Binary files /dev/null and b/mmseg/datasets/__pycache__/dataset_wrappers.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/decathlon.cpython-311.pyc b/mmseg/datasets/__pycache__/decathlon.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04c6013b82579350aa3ef8f299d471336fbcee0b Binary files /dev/null and b/mmseg/datasets/__pycache__/decathlon.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/drive.cpython-311.pyc b/mmseg/datasets/__pycache__/drive.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2817a250950ed52c54d1a2e7ea1fb0890bfe512b Binary files /dev/null and b/mmseg/datasets/__pycache__/drive.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/dsdl.cpython-311.pyc b/mmseg/datasets/__pycache__/dsdl.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ee34423fc0020dcf40364abe3aedeeee3ba9495 Binary files /dev/null and b/mmseg/datasets/__pycache__/dsdl.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/hrf.cpython-311.pyc b/mmseg/datasets/__pycache__/hrf.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe421fc5252393fd8f8ab7a61d2989f75d8ae3e0 Binary files /dev/null and b/mmseg/datasets/__pycache__/hrf.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/isaid.cpython-311.pyc b/mmseg/datasets/__pycache__/isaid.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0568fbd08939f3d700fa9d14a4fffb2ea75c89fb Binary files /dev/null and b/mmseg/datasets/__pycache__/isaid.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/isprs.cpython-311.pyc b/mmseg/datasets/__pycache__/isprs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..678c1a1ce54b018b62f8c870cdad71c9352af627 Binary files /dev/null and b/mmseg/datasets/__pycache__/isprs.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/levir.cpython-311.pyc b/mmseg/datasets/__pycache__/levir.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c2d5b58542753629b33a61ae4c1e9d088015ff4 Binary files /dev/null and b/mmseg/datasets/__pycache__/levir.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/lip.cpython-311.pyc b/mmseg/datasets/__pycache__/lip.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5fc264d2b5a472750cdfb4ad9b817fdb9febf05 Binary files /dev/null and b/mmseg/datasets/__pycache__/lip.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/loveda.cpython-311.pyc b/mmseg/datasets/__pycache__/loveda.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cc84eabe79edde946e9df8b43aaefb4677679e4 Binary files /dev/null and b/mmseg/datasets/__pycache__/loveda.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/mapillary.cpython-311.pyc b/mmseg/datasets/__pycache__/mapillary.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2a3b2623aca0536ddb4894c9505ce8b9e396f50 Binary files /dev/null and b/mmseg/datasets/__pycache__/mapillary.cpython-311.pyc differ diff --git a/mmseg/datasets/__pycache__/night_driving.cpython-311.pyc b/mmseg/datasets/__pycache__/night_driving.cpython-311.pyc new file mode 100644 index 
diff --git a/mmseg/datasets/ade.py b/mmseg/datasets/ade.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9bdae7421205f25d39441381d6492e9208a4714
--- /dev/null
+++ b/mmseg/datasets/ade.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmseg.registry import DATASETS
+from .basesegdataset import BaseSegDataset
+
+
+@DATASETS.register_module()
+class ADE20KDataset(BaseSegDataset):
+    """ADE20K dataset.
+
+    In segmentation map annotation for ADE20K, 0 stands for background, which
+    is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
+    The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
+    '.png'.
+ """ + METAINFO = dict( + classes=('wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', + 'bed ', 'windowpane', 'grass', 'cabinet', 'sidewalk', + 'person', 'earth', 'door', 'table', 'mountain', 'plant', + 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', + 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', + 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', + 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', + 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', + 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', + 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', + 'screen door', 'stairway', 'river', 'bridge', 'bookcase', + 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', + 'bench', 'countertop', 'stove', 'palm', 'kitchen island', + 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', + 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', + 'chandelier', 'awning', 'streetlight', 'booth', + 'television receiver', 'airplane', 'dirt track', 'apparel', + 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', + 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', + 'conveyer belt', 'canopy', 'washer', 'plaything', + 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', + 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', + 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', + 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', + 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', + 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate', + 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', + 'clock', 'flag'), + palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], + [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], + [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], + [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], + [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], + [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], + [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], + [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], + [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], + [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], + [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], + [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], + [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], + [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], + [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], + [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], + [255, 204, 0], [255, 0, 143], [0, 255, 
235], [133, 255, 0],
+                 [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+                 [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+                 [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+                 [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+                 [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+                 [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+                 [102, 255, 0], [92, 0, 255]])
+
+    def __init__(self,
+                 img_suffix='.jpg',
+                 seg_map_suffix='.png',
+                 reduce_zero_label=True,
+                 **kwargs) -> None:
+        super().__init__(
+            img_suffix=img_suffix,
+            seg_map_suffix=seg_map_suffix,
+            reduce_zero_label=reduce_zero_label,
+            **kwargs)
diff --git a/mmseg/datasets/basesegdataset.py b/mmseg/datasets/basesegdataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c4668c1f561961fb27642fb7c1ac702f626cbb7
--- /dev/null
+++ b/mmseg/datasets/basesegdataset.py
@@ -0,0 +1,552 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os.path as osp
+from typing import Callable, Dict, List, Optional, Sequence, Union
+
+import mmengine
+import mmengine.fileio as fileio
+import numpy as np
+from mmengine.dataset import BaseDataset, Compose
+
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class BaseSegDataset(BaseDataset):
+    """Custom dataset for semantic segmentation. An example of file structure
+    is as follows.
+
+    .. code-block:: none
+
+        ├── data
+        │   ├── my_dataset
+        │   │   ├── img_dir
+        │   │   │   ├── train
+        │   │   │   │   ├── xxx{img_suffix}
+        │   │   │   │   ├── yyy{img_suffix}
+        │   │   │   │   ├── zzz{img_suffix}
+        │   │   │   ├── val
+        │   │   ├── ann_dir
+        │   │   │   ├── train
+        │   │   │   │   ├── xxx{seg_map_suffix}
+        │   │   │   │   ├── yyy{seg_map_suffix}
+        │   │   │   │   ├── zzz{seg_map_suffix}
+        │   │   │   ├── val
+
+    The img/gt_semantic_seg pair of BaseSegDataset should share the same name
+    except for the suffix. A valid img/gt_semantic_seg filename pair should be
+    like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also
+    included in the suffix). If split is given, then ``xxx`` is specified in
+    the txt file. Otherwise, all files in ``img_dir/`` and ``ann_dir`` will be
+    loaded. Please refer to ``docs/en/tutorials/new_dataset.md`` for more
+    details.
+
+
+    Args:
+        ann_file (str): Annotation file path. Defaults to ''.
+        metainfo (dict, optional): Meta information for dataset, such as
+            specified classes to load. Defaults to None.
+        data_root (str, optional): The root directory for ``data_prefix`` and
+            ``ann_file``. Defaults to None.
+        data_prefix (dict, optional): Prefix for training data. Defaults to
+            dict(img_path=None, seg_map_path=None).
+        img_suffix (str): Suffix of images. Default: '.jpg'
+        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+        filter_cfg (dict, optional): Config for filter data. Defaults to None.
+        indices (int or Sequence[int], optional): Support using first few
+            data in annotation file to facilitate training/testing on a
+            smaller dataset. Defaults to None which means using all
+            ``data_infos``.
+        serialize_data (bool, optional): Whether to hold memory using
+            serialized objects, when enabled, data loader workers can use
+            shared RAM from master process instead of making a copy. Defaults
+            to True.
+        pipeline (list, optional): Processing pipeline. Defaults to [].
+        test_mode (bool, optional): ``test_mode=True`` means in test phase.
+            Defaults to False.
+        lazy_init (bool, optional): Whether to load annotation during
+            instantiation. In some cases, such as visualization, only the
+            meta information of the dataset is needed, so it is not necessary
+            to load the annotation file. ``Basedataset`` can skip loading
+            annotations to save time by setting ``lazy_init=True``. Defaults
+            to False.
+        max_refetch (int, optional): If ``Basedataset.prepare_data`` gets a
+            None img, the maximum extra number of cycles to get a valid
+            image. Defaults to 1000.
+        ignore_index (int): The label index to be ignored. Default: 255
+        reduce_zero_label (bool): Whether to mark label zero as ignored.
+            Default to False.
+        backend_args (dict, Optional): Arguments to instantiate a file backend.
+            See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
+            for details. Defaults to None.
+            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
+    """
+    METAINFO: dict = dict()
+
+    def __init__(self,
+                 ann_file: str = '',
+                 img_suffix='.jpg',
+                 seg_map_suffix='.png',
+                 metainfo: Optional[dict] = None,
+                 data_root: Optional[str] = None,
+                 data_prefix: dict = dict(img_path='', seg_map_path=''),
+                 filter_cfg: Optional[dict] = None,
+                 indices: Optional[Union[int, Sequence[int]]] = None,
+                 serialize_data: bool = True,
+                 pipeline: List[Union[dict, Callable]] = [],
+                 test_mode: bool = False,
+                 lazy_init: bool = False,
+                 max_refetch: int = 1000,
+                 ignore_index: int = 255,
+                 reduce_zero_label: bool = False,
+                 backend_args: Optional[dict] = None) -> None:
+
+        self.img_suffix = img_suffix
+        self.seg_map_suffix = seg_map_suffix
+        self.ignore_index = ignore_index
+        self.reduce_zero_label = reduce_zero_label
+        self.backend_args = backend_args.copy() if backend_args else None
+
+        self.data_root = data_root
+        self.data_prefix = copy.copy(data_prefix)
+        self.ann_file = ann_file
+        self.filter_cfg = copy.deepcopy(filter_cfg)
+        self._indices = indices
+        self.serialize_data = serialize_data
+        self.test_mode = test_mode
+        self.max_refetch = max_refetch
+        self.data_list: List[dict] = []
+        self.data_bytes: np.ndarray
+
+        # Set meta information.
+        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))
+
+        # Get label map for custom classes
+        new_classes = self._metainfo.get('classes', None)
+        self.label_map = self.get_label_map(new_classes)
+        self._metainfo.update(
+            dict(
+                label_map=self.label_map,
+                reduce_zero_label=self.reduce_zero_label))
+
+        # Update palette based on label map or generate palette
+        # if it is not defined
+        updated_palette = self._update_palette()
+        self._metainfo.update(dict(palette=updated_palette))
+
+        # Join paths.
+        if self.data_root is not None:
+            self._join_prefix()
+
+        # Build pipeline.
+        self.pipeline = Compose(pipeline)
+        # Fully initialize the dataset.
+        if not lazy_init:
+            self.full_init()
+
+        if test_mode:
+            assert self._metainfo.get('classes') is not None, \
+                'dataset metainfo `classes` should be specified when testing'
+
+    @classmethod
+    def get_label_map(cls,
+                      new_classes: Optional[Sequence] = None
+                      ) -> Union[Dict, None]:
+        """Require label mapping.
+
+        The ``label_map`` is a dictionary, its keys are the old label ids and
+        its values are the new label ids, and is used for changing pixel
+        labels in load_annotations. If and only if old classes in cls.METAINFO
+        is not equal to new classes in self._metainfo and neither of them is
+        None, `label_map` is not None.
+
+        Args:
+            new_classes (list, tuple, optional): The new class names from
+                metainfo. Default to None.
+
+        Returns:
+            dict, optional: The mapping from old classes in cls.METAINFO to
+                new classes in self._metainfo
+        """
+        old_classes = cls.METAINFO.get('classes', None)
+        if (new_classes is not None and old_classes is not None
+                and list(new_classes) != list(old_classes)):
+
+            label_map = {}
+            if not set(new_classes).issubset(cls.METAINFO['classes']):
+                raise ValueError(
+                    f'new classes {new_classes} is not a '
+                    f'subset of classes {old_classes} in METAINFO.')
+            for i, c in enumerate(old_classes):
+                if c not in new_classes:
+                    label_map[i] = 255
+                else:
+                    label_map[i] = new_classes.index(c)
+            return label_map
+        else:
+            return None
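The remapping contract above is easiest to see on a concrete subclass. A small illustration (not part of the diff) using the ``CityscapesDataset`` added later in this patch, where 'road' is old id 0, 'sidewalk' is 1 and 'car' is 13:

```python
from mmseg.datasets import CityscapesDataset

# Keep only two of the 19 Cityscapes classes; every dropped class maps to 255.
label_map = CityscapesDataset.get_label_map(new_classes=('road', 'car'))
print(label_map[0])   # 0   -> 'road' becomes new id 0
print(label_map[13])  # 1   -> 'car' becomes new id 1
print(label_map[1])   # 255 -> 'sidewalk' is dropped, i.e. ignored
```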
+
+    def _update_palette(self) -> list:
+        """Update palette after loading metainfo.
+
+        If the length of the palette is equal to the number of classes, just
+        return the palette. If the palette is not defined, it will randomly
+        generate a palette. If classes are updated by the user, it will
+        return the subset of the palette.
+
+        Returns:
+            Sequence: Palette for current dataset.
+        """
+        palette = self._metainfo.get('palette', [])
+        classes = self._metainfo.get('classes', [])
+        # palette matches classes
+        if len(palette) == len(classes):
+            return palette
+
+        if len(palette) == 0:
+            # Get random state before set seed, and restore
+            # random state later.
+            # It will prevent loss of randomness, as the palette
+            # may be different in each iteration if not specified.
+            # See: https://github.com/open-mmlab/mmdetection/issues/5844
+            state = np.random.get_state()
+            np.random.seed(42)
+            # random palette
+            new_palette = np.random.randint(
+                0, 255, size=(len(classes), 3)).tolist()
+            np.random.set_state(state)
+        elif len(palette) >= len(classes) and self.label_map is not None:
+            new_palette = []
+            # return subset of palette
+            for old_id, new_id in sorted(
+                    self.label_map.items(), key=lambda x: x[1]):
+                if new_id != 255:
+                    new_palette.append(palette[old_id])
+            new_palette = type(palette)(new_palette)
+        else:
+            raise ValueError('palette does not match classes '
+                             f'as metainfo is {self._metainfo}.')
+        return new_palette
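The save/seed/restore dance around NumPy's RNG is the point of the random branch above: the generated palette is reproducible across runs, but the caller's random stream is left untouched. A standalone sketch of the same trick:

```python
import numpy as np


def make_random_palette(num_classes: int) -> list:
    """Reproducible random palette that does not disturb the global RNG."""
    state = np.random.get_state()  # remember the caller's RNG state
    np.random.seed(42)             # same fixed seed as _update_palette above
    palette = np.random.randint(0, 255, size=(num_classes, 3)).tolist()
    np.random.set_state(state)     # restore, so later draws stay random
    return palette


print(make_random_palette(3))  # identical three RGB triples on every run
```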
+ """ + data_list = [] + img_dir = self.data_prefix.get('img_path', None) + ann_dir = self.data_prefix.get('seg_map_path', None) + if not osp.isdir(self.ann_file) and self.ann_file: + assert osp.isfile(self.ann_file), \ + f'Failed to load `ann_file` {self.ann_file}' + lines = mmengine.list_from_file( + self.ann_file, backend_args=self.backend_args) + for line in lines: + img_name = line.strip() + data_info = dict( + img_path=osp.join(img_dir, img_name + self.img_suffix)) + if ann_dir is not None: + seg_map = img_name + self.seg_map_suffix + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['label_map'] = self.label_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + else: + _suffix_len = len(self.img_suffix) + for img in fileio.list_dir_or_file( + dir_path=img_dir, + list_dir=False, + suffix=self.img_suffix, + recursive=True, + backend_args=self.backend_args): + data_info = dict(img_path=osp.join(img_dir, img)) + if ann_dir is not None: + seg_map = img[:-_suffix_len] + self.seg_map_suffix + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['label_map'] = self.label_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + data_list = sorted(data_list, key=lambda x: x['img_path']) + return data_list + + +@DATASETS.register_module() +class BaseCDDataset(BaseDataset): + """Custom dataset for change detection. An example of file structure is as + followed. + + .. code-block:: none + + ├── data + │ ├── my_dataset + │ │ ├── img_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{img_suffix} + │ │ │ │ ├── yyy{img_suffix} + │ │ │ │ ├── zzz{img_suffix} + │ │ │ ├── val + │ │ ├── img_dir2 + │ │ │ ├── train + │ │ │ │ ├── xxx{img_suffix} + │ │ │ │ ├── yyy{img_suffix} + │ │ │ │ ├── zzz{img_suffix} + │ │ │ ├── val + │ │ ├── ann_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{seg_map_suffix} + │ │ │ │ ├── yyy{seg_map_suffix} + │ │ │ │ ├── zzz{seg_map_suffix} + │ │ │ ├── val + + The image names in img_dir and img_dir2 should be consistent. + The img/gt_semantic_seg pair of BaseSegDataset should be of the same + except suffix. A valid img/gt_semantic_seg filename pair should be like + ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included + in the suffix). If split is given, then ``xxx`` is specified in txt file. + Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded. + Please refer to ``docs/en/tutorials/new_dataset.md`` for more details. + + + Args: + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as + specify classes to load. Defaults to None. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to None. + data_prefix (dict, optional): Prefix for training data. Defaults to + dict(img_path=None, img_path2=None, seg_map_path=None). + img_suffix (str): Suffix of images. Default: '.jpg' + img_suffix2 (str): Suffix of images. Default: '.jpg' + seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' + filter_cfg (dict, optional): Config for filter data. Defaults to None. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Defaults to None which means using all ``data_infos``. 
+
+
+@DATASETS.register_module()
+class BaseCDDataset(BaseDataset):
+    """Custom dataset for change detection. An example of file structure is
+    as follows.
+
+    .. code-block:: none
+
+        ├── data
+        │   ├── my_dataset
+        │   │   ├── img_dir
+        │   │   │   ├── train
+        │   │   │   │   ├── xxx{img_suffix}
+        │   │   │   │   ├── yyy{img_suffix}
+        │   │   │   │   ├── zzz{img_suffix}
+        │   │   │   ├── val
+        │   │   ├── img_dir2
+        │   │   │   ├── train
+        │   │   │   │   ├── xxx{img_suffix}
+        │   │   │   │   ├── yyy{img_suffix}
+        │   │   │   │   ├── zzz{img_suffix}
+        │   │   │   ├── val
+        │   │   ├── ann_dir
+        │   │   │   ├── train
+        │   │   │   │   ├── xxx{seg_map_suffix}
+        │   │   │   │   ├── yyy{seg_map_suffix}
+        │   │   │   │   ├── zzz{seg_map_suffix}
+        │   │   │   ├── val
+
+    The image names in img_dir and img_dir2 should be consistent.
+    The img/gt_semantic_seg pair of BaseCDDataset should share the same name
+    except for the suffix. A valid img/gt_semantic_seg filename pair should be
+    like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also
+    included in the suffix). If split is given, then ``xxx`` is specified in
+    the txt file. Otherwise, all files in ``img_dir/`` and ``ann_dir`` will be
+    loaded. Please refer to ``docs/en/tutorials/new_dataset.md`` for more
+    details.
+
+
+    Args:
+        ann_file (str): Annotation file path. Defaults to ''.
+        metainfo (dict, optional): Meta information for dataset, such as
+            specified classes to load. Defaults to None.
+        data_root (str, optional): The root directory for ``data_prefix`` and
+            ``ann_file``. Defaults to None.
+        data_prefix (dict, optional): Prefix for training data. Defaults to
+            dict(img_path=None, img_path2=None, seg_map_path=None).
+        img_suffix (str): Suffix of images. Default: '.jpg'
+        img_suffix2 (str): Suffix of images in ``img_dir2``. Default: '.jpg'
+        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+        filter_cfg (dict, optional): Config for filter data. Defaults to None.
+        indices (int or Sequence[int], optional): Support using first few
+            data in annotation file to facilitate training/testing on a
+            smaller dataset. Defaults to None which means using all
+            ``data_infos``.
+        serialize_data (bool, optional): Whether to hold memory using
+            serialized objects, when enabled, data loader workers can use
+            shared RAM from master process instead of making a copy. Defaults
+            to True.
+        pipeline (list, optional): Processing pipeline. Defaults to [].
+        test_mode (bool, optional): ``test_mode=True`` means in test phase.
+            Defaults to False.
+        lazy_init (bool, optional): Whether to load annotation during
+            instantiation. In some cases, such as visualization, only the
+            meta information of the dataset is needed, so it is not necessary
+            to load the annotation file. ``Basedataset`` can skip loading
+            annotations to save time by setting ``lazy_init=True``. Defaults
+            to False.
+        max_refetch (int, optional): If ``Basedataset.prepare_data`` gets a
+            None img, the maximum extra number of cycles to get a valid
+            image. Defaults to 1000.
+        ignore_index (int): The label index to be ignored. Default: 255
+        reduce_zero_label (bool): Whether to mark label zero as ignored.
+            Default to False.
+        backend_args (dict, Optional): Arguments to instantiate a file backend.
+            See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
+            for details. Defaults to None.
+            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
+    """
+    METAINFO: dict = dict()
+
+    def __init__(self,
+                 ann_file: str = '',
+                 img_suffix='.jpg',
+                 img_suffix2='.jpg',
+                 seg_map_suffix='.png',
+                 metainfo: Optional[dict] = None,
+                 data_root: Optional[str] = None,
+                 data_prefix: dict = dict(
+                     img_path='', img_path2='', seg_map_path=''),
+                 filter_cfg: Optional[dict] = None,
+                 indices: Optional[Union[int, Sequence[int]]] = None,
+                 serialize_data: bool = True,
+                 pipeline: List[Union[dict, Callable]] = [],
+                 test_mode: bool = False,
+                 lazy_init: bool = False,
+                 max_refetch: int = 1000,
+                 ignore_index: int = 255,
+                 reduce_zero_label: bool = False,
+                 backend_args: Optional[dict] = None) -> None:
+
+        self.img_suffix = img_suffix
+        self.img_suffix2 = img_suffix2
+        self.seg_map_suffix = seg_map_suffix
+        self.ignore_index = ignore_index
+        self.reduce_zero_label = reduce_zero_label
+        self.backend_args = backend_args.copy() if backend_args else None
+
+        self.data_root = data_root
+        self.data_prefix = copy.copy(data_prefix)
+        self.ann_file = ann_file
+        self.filter_cfg = copy.deepcopy(filter_cfg)
+        self._indices = indices
+        self.serialize_data = serialize_data
+        self.test_mode = test_mode
+        self.max_refetch = max_refetch
+        self.data_list: List[dict] = []
+        self.data_bytes: np.ndarray
+
+        # Set meta information.
+        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))
+
+        # Get label map for custom classes
+        new_classes = self._metainfo.get('classes', None)
+        self.label_map = self.get_label_map(new_classes)
+        self._metainfo.update(
+            dict(
+                label_map=self.label_map,
+                reduce_zero_label=self.reduce_zero_label))
+
+        # Update palette based on label map or generate palette
+        # if it is not defined
+        updated_palette = self._update_palette()
+        self._metainfo.update(dict(palette=updated_palette))
+
+        # Join paths.
+        if self.data_root is not None:
+            self._join_prefix()
+
+        # Build pipeline.
+        self.pipeline = Compose(pipeline)
+        # Fully initialize the dataset.
+        if not lazy_init:
+            self.full_init()
+
+        if test_mode:
+            assert self._metainfo.get('classes') is not None, \
+                'dataset metainfo `classes` should be specified when testing'
+
+    @classmethod
+    def get_label_map(cls,
+                      new_classes: Optional[Sequence] = None
+                      ) -> Union[Dict, None]:
+        """Require label mapping.
+
+        The ``label_map`` is a dictionary, its keys are the old label ids and
+        its values are the new label ids, and is used for changing pixel
+        labels in load_annotations. If and only if old classes in cls.METAINFO
+        is not equal to new classes in self._metainfo and neither of them is
+        None, `label_map` is not None.
+
+        Args:
+            new_classes (list, tuple, optional): The new class names from
+                metainfo. Default to None.
+
+        Returns:
+            dict, optional: The mapping from old classes in cls.METAINFO to
+                new classes in self._metainfo
+        """
+        old_classes = cls.METAINFO.get('classes', None)
+        if (new_classes is not None and old_classes is not None
+                and list(new_classes) != list(old_classes)):
+
+            label_map = {}
+            if not set(new_classes).issubset(cls.METAINFO['classes']):
+                raise ValueError(
+                    f'new classes {new_classes} is not a '
+                    f'subset of classes {old_classes} in METAINFO.')
+            for i, c in enumerate(old_classes):
+                if c not in new_classes:
+                    label_map[i] = 255
+                else:
+                    label_map[i] = new_classes.index(c)
+            return label_map
+        else:
+            return None
+
+    def _update_palette(self) -> list:
+        """Update palette after loading metainfo.
+
+        If the length of the palette is equal to the number of classes, just
+        return the palette. If the palette is not defined, it will randomly
+        generate a palette. If classes are updated by the user, it will
+        return the subset of the palette.
+
+        Returns:
+            Sequence: Palette for current dataset.
+        """
+        palette = self._metainfo.get('palette', [])
+        classes = self._metainfo.get('classes', [])
+        # palette matches classes
+        if len(palette) == len(classes):
+            return palette
+
+        if len(palette) == 0:
+            # Get random state before set seed, and restore
+            # random state later.
+            # It will prevent loss of randomness, as the palette
+            # may be different in each iteration if not specified.
+            # See: https://github.com/open-mmlab/mmdetection/issues/5844
+            state = np.random.get_state()
+            np.random.seed(42)
+            # random palette
+            new_palette = np.random.randint(
+                0, 255, size=(len(classes), 3)).tolist()
+            np.random.set_state(state)
+        elif len(palette) >= len(classes) and self.label_map is not None:
+            new_palette = []
+            # return subset of palette
+            for old_id, new_id in sorted(
+                    self.label_map.items(), key=lambda x: x[1]):
+                if new_id != 255:
+                    new_palette.append(palette[old_id])
+            new_palette = type(palette)(new_palette)
+        else:
+            raise ValueError('palette does not match classes '
+                             f'as metainfo is {self._metainfo}.')
+        return new_palette
+
+    def load_data_list(self) -> List[dict]:
+        """Load annotation from directory or annotation file.
+
+        Returns:
+            list[dict]: All data info of dataset.
+        """
+        data_list = []
+        img_dir = self.data_prefix.get('img_path', None)
+        img_dir2 = self.data_prefix.get('img_path2', None)
+        ann_dir = self.data_prefix.get('seg_map_path', None)
+        if osp.isfile(self.ann_file):
+            lines = mmengine.list_from_file(
+                self.ann_file, backend_args=self.backend_args)
+            for line in lines:
+                img_name = line.strip()
+                if '.' in osp.basename(img_name):
+                    img_name, img_ext = osp.splitext(img_name)
+                    self.img_suffix = img_ext
+                    self.img_suffix2 = img_ext
+                data_info = dict(
+                    img_path=osp.join(img_dir, img_name + self.img_suffix),
+                    img_path2=osp.join(img_dir2, img_name + self.img_suffix2))
+
+                if ann_dir is not None:
+                    seg_map = img_name + self.seg_map_suffix
+                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
+                data_info['label_map'] = self.label_map
+                data_info['reduce_zero_label'] = self.reduce_zero_label
+                data_info['seg_fields'] = []
+                data_list.append(data_info)
+        else:
+            for img in fileio.list_dir_or_file(
+                    dir_path=img_dir,
+                    list_dir=False,
+                    suffix=self.img_suffix,
+                    recursive=True,
+                    backend_args=self.backend_args):
+                if '.' in osp.basename(img):
+                    img, img_ext = osp.splitext(img)
+                    self.img_suffix = img_ext
+                    self.img_suffix2 = img_ext
+                data_info = dict(
+                    img_path=osp.join(img_dir, img + self.img_suffix),
+                    img_path2=osp.join(img_dir2, img + self.img_suffix2))
+                if ann_dir is not None:
+                    seg_map = img + self.seg_map_suffix
+                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
+                data_info['label_map'] = self.label_map
+                data_info['reduce_zero_label'] = self.reduce_zero_label
+                data_info['seg_fields'] = []
+                data_list.append(data_info)
+        data_list = sorted(data_list, key=lambda x: x['img_path'])
+        return data_list
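For change detection the only structural difference is the second temporal image: each sample carries ``img_path``/``img_path2`` plus one binary mask. An illustrative entry as ``load_data_list`` would produce it for a LEVIR-CD style layout (the paths are placeholders, not taken from this diff):

```python
# Shape of one data_info entry emitted by BaseCDDataset.load_data_list
data_info = dict(
    img_path='data/LEVIR-CD/train/A/train_1.png',     # pre-change image
    img_path2='data/LEVIR-CD/train/B/train_1.png',    # post-change image
    seg_map_path='data/LEVIR-CD/train/label/train_1.png',
    label_map=None,             # no class remapping requested
    reduce_zero_label=False,
    seg_fields=[])
```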
diff --git a/mmseg/datasets/bdd100k.py b/mmseg/datasets/bdd100k.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ae70b5cb29f2b34c5804129c85622bfcca6767d
--- /dev/null
+++ b/mmseg/datasets/bdd100k.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmseg.datasets.basesegdataset import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class BDD100KDataset(BaseSegDataset):
+    METAINFO = dict(
+        classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+                 'traffic light', 'traffic sign', 'vegetation', 'terrain',
+                 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
+                 'motorcycle', 'bicycle'),
+        palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70],
+                 [102, 102, 156], [190, 153, 153], [153, 153, 153],
+                 [250, 170, 30], [220, 220, 0], [107, 142, 35],
+                 [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0],
+                 [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
+                 [0, 0, 230], [119, 11, 32]])
+
+    def __init__(self,
+                 img_suffix='.jpg',
+                 seg_map_suffix='.png',
+                 reduce_zero_label=False,
+                 **kwargs) -> None:
+        super().__init__(
+            img_suffix=img_suffix,
+            seg_map_suffix=seg_map_suffix,
+            reduce_zero_label=reduce_zero_label,
+            **kwargs)
diff --git a/mmseg/datasets/chase_db1.py b/mmseg/datasets/chase_db1.py
new file mode 100644
index 0000000000000000000000000000000000000000..626ddf75e9a2a10a09ca1f298f12f4290268d504
--- /dev/null
+++ b/mmseg/datasets/chase_db1.py
@@ -0,0 +1,32 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmengine.fileio as fileio
+
+from mmseg.registry import DATASETS
+from .basesegdataset import BaseSegDataset
+
+
+@DATASETS.register_module()
+class ChaseDB1Dataset(BaseSegDataset):
+    """Chase_db1 dataset.
+
+    In segmentation map annotation for Chase_db1, 0 stands for background,
+    which is one of the 2 categories. ``reduce_zero_label`` is fixed to
+    False. The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is
+    fixed to '_1stHO.png'.
+ """ + METAINFO = dict( + classes=('background', 'vessel'), + palette=[[120, 120, 120], [6, 230, 230]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='_1stHO.png', + reduce_zero_label=False, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) + assert fileio.exists( + self.data_prefix['img_path'], backend_args=self.backend_args) diff --git a/mmseg/datasets/cityscapes.py b/mmseg/datasets/cityscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..f494d62424a39581961ab705b3308e7e07bee110 --- /dev/null +++ b/mmseg/datasets/cityscapes.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class CityscapesDataset(BaseSegDataset): + """Cityscapes dataset. + + The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is + fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset. + """ + METAINFO = dict( + classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', + 'traffic light', 'traffic sign', 'vegetation', 'terrain', + 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', + 'motorcycle', 'bicycle'), + palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], + [190, 153, 153], [153, 153, 153], [250, 170, + 30], [220, 220, 0], + [107, 142, 35], [152, 251, 152], [70, 130, 180], + [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70], + [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]]) + + def __init__(self, + img_suffix='_leftImg8bit.png', + seg_map_suffix='_gtFine_labelTrainIds.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) diff --git a/mmseg/datasets/coco_stuff.py b/mmseg/datasets/coco_stuff.py new file mode 100644 index 0000000000000000000000000000000000000000..1e1574d9702330cc5b10bab084841df61e7121ff --- /dev/null +++ b/mmseg/datasets/coco_stuff.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class COCOStuffDataset(BaseSegDataset): + """COCO-Stuff dataset. + + In segmentation map annotation for COCO-Stuff, Train-IDs of the 10k version + are from 1 to 171, where 0 is the ignore index, and Train-ID of COCO Stuff + 164k is from 0 to 170, where 255 is the ignore index. So, they are all 171 + semantic categories. ``reduce_zero_label`` is set to True and False for the + 10k and 164k versions, respectively. The ``img_suffix`` is fixed to '.jpg', + and ``seg_map_suffix`` is fixed to '.png'. 
+ """ + METAINFO = dict( + classes=( + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', + 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', + 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', + 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', + 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', + 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet', + 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile', + 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain', + 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble', + 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower', + 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel', + 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal', + 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', + 'paper', 'pavement', 'pillow', 'plant-other', 'plastic', + 'platform', 'playingfield', 'railing', 'railroad', 'river', 'road', + 'rock', 'roof', 'rug', 'salad', 'sand', 'sea', 'shelf', + 'sky-other', 'skyscraper', 'snow', 'solid-other', 'stairs', + 'stone', 'straw', 'structural-other', 'table', 'tent', + 'textile-other', 'towel', 'tree', 'vegetable', 'wall-brick', + 'wall-concrete', 'wall-other', 'wall-panel', 'wall-stone', + 'wall-tile', 'wall-wood', 'water-other', 'waterdrops', + 'window-blind', 'window-other', 'wood'), + palette=[[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], + [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64], + [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], + [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192], + [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192], + [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], + [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160], + [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0], + [0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128], + [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160], + [0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128], + [128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192], + [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160], + [64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0], + [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192], + [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160], + [64, 32, 128], [128, 192, 192], [0, 0, 160], [192, 160, 128], + [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128], + [64, 128, 96], [64, 160, 0], [0, 64, 0], [192, 128, 224], + [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0], + [0, 192, 0], [192, 128, 96], [192, 96, 128], [0, 64, 128], + [64, 0, 96], [64, 224, 128], [128, 64, 0], [192, 0, 224], + [64, 96, 128], [128, 192, 128], [64, 0, 224], [192, 224, 128], + [128, 192, 64], [192, 0, 96], [192, 96, 0], [128, 64, 192], + [0, 128, 96], [0, 224, 0], [64, 64, 64], [128, 128, 224], + [0, 96, 0], [64, 192, 
192], [0, 128, 224], [128, 224, 0],
+                 [64, 192, 64], [128, 128, 96], [128, 32, 128], [64, 0, 192],
+                 [0, 64, 96], [0, 160, 128], [192, 0, 64], [128, 64, 224],
+                 [0, 32, 128], [192, 128, 192], [0, 64, 224], [128, 160, 128],
+                 [192, 128, 0], [128, 64, 32], [128, 32, 64], [192, 0, 128],
+                 [64, 192, 32], [0, 160, 64], [64, 0, 0], [192, 192, 160],
+                 [0, 32, 64], [64, 128, 128], [64, 192, 160], [128, 160, 64],
+                 [64, 128, 0], [192, 192, 32], [128, 96, 192], [64, 0, 128],
+                 [64, 64, 32], [0, 224, 192], [192, 0, 0], [192, 64, 160],
+                 [0, 96, 192], [192, 128, 128], [64, 64, 160],
+                 [128, 224, 192], [192, 128, 64], [192, 64, 32], [128, 96, 64],
+                 [192, 0, 192], [0, 192, 32], [64, 224, 64], [64, 0, 64],
+                 [128, 192, 160], [64, 96, 64], [64, 128, 192], [0, 192, 160],
+                 [192, 224, 64], [64, 128, 64], [128, 192, 32], [192, 32, 192],
+                 [64, 64, 192], [0, 64, 32], [64, 160, 192], [192, 64, 64],
+                 [128, 64, 160], [64, 32, 192], [192, 192, 192], [0, 64, 160],
+                 [192, 160, 192], [192, 192, 0], [128, 64, 96], [192, 32, 64],
+                 [192, 64, 128], [64, 192, 96], [64, 160, 64], [64, 64, 0]])
+
+    def __init__(self,
+                 img_suffix='.jpg',
+                 seg_map_suffix='_labelTrainIds.png',
+                 **kwargs) -> None:
+        super().__init__(
+            img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
diff --git a/mmseg/datasets/dark_zurich.py b/mmseg/datasets/dark_zurich.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b5393fa9e5047e81790f91829cfe4b7f33cc707
--- /dev/null
+++ b/mmseg/datasets/dark_zurich.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmseg.registry import DATASETS
+from .cityscapes import CityscapesDataset
+
+
+@DATASETS.register_module()
+class DarkZurichDataset(CityscapesDataset):
+    """DarkZurichDataset dataset."""
+
+    def __init__(self,
+                 img_suffix='_rgb_anon.png',
+                 seg_map_suffix='_gt_labelTrainIds.png',
+                 **kwargs) -> None:
+        super().__init__(
+            img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
diff --git a/mmseg/datasets/dataset_wrappers.py b/mmseg/datasets/dataset_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..082c116ff4582ecc7064dba1aba3c164dd556af5
--- /dev/null
+++ b/mmseg/datasets/dataset_wrappers.py
@@ -0,0 +1,136 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import collections
+import copy
+from typing import List, Optional, Sequence, Union
+
+from mmengine.dataset import ConcatDataset, force_full_init
+
+from mmseg.registry import DATASETS, TRANSFORMS
+
+
+@DATASETS.register_module()
+class MultiImageMixDataset:
+    """A wrapper of multiple images mixed dataset.
+
+    Suitable for training on multiple images mixed data augmentation like
+    mosaic and mixup.
+
+    Args:
+        dataset (ConcatDataset or dict): The dataset to be mixed.
+        pipeline (Sequence[dict]): Sequence of transform object or
+            config dict to be composed.
+        skip_type_keys (list[str], optional): Sequence of transform type
+            strings to be skipped in the pipeline. Default to None.
+ """ + + def __init__(self, + dataset: Union[ConcatDataset, dict], + pipeline: Sequence[dict], + skip_type_keys: Optional[List[str]] = None, + lazy_init: bool = False) -> None: + assert isinstance(pipeline, collections.abc.Sequence) + + if isinstance(dataset, dict): + self.dataset = DATASETS.build(dataset) + elif isinstance(dataset, ConcatDataset): + self.dataset = dataset + else: + raise TypeError( + 'elements in datasets sequence should be config or ' + f'`ConcatDataset` instance, but got {type(dataset)}') + + if skip_type_keys is not None: + assert all([ + isinstance(skip_type_key, str) + for skip_type_key in skip_type_keys + ]) + self._skip_type_keys = skip_type_keys + + self.pipeline = [] + self.pipeline_types = [] + for transform in pipeline: + if isinstance(transform, dict): + self.pipeline_types.append(transform['type']) + transform = TRANSFORMS.build(transform) + self.pipeline.append(transform) + else: + raise TypeError('pipeline must be a dict') + + self._metainfo = self.dataset.metainfo + self.num_samples = len(self.dataset) + + self._fully_initialized = False + if not lazy_init: + self.full_init() + + @property + def metainfo(self) -> dict: + """Get the meta information of the multi-image-mixed dataset. + + Returns: + dict: The meta information of multi-image-mixed dataset. + """ + return copy.deepcopy(self._metainfo) + + def full_init(self): + """Loop to ``full_init`` each dataset.""" + if self._fully_initialized: + return + + self.dataset.full_init() + self._ori_len = len(self.dataset) + self._fully_initialized = True + + @force_full_init + def get_data_info(self, idx: int) -> dict: + """Get annotation by index. + + Args: + idx (int): Global index of ``ConcatDataset``. + + Returns: + dict: The idx-th annotation of the datasets. + """ + return self.dataset.get_data_info(idx) + + @force_full_init + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + results = copy.deepcopy(self.dataset[idx]) + for (transform, transform_type) in zip(self.pipeline, + self.pipeline_types): + if self._skip_type_keys is not None and \ + transform_type in self._skip_type_keys: + continue + + if hasattr(transform, 'get_indices'): + indices = transform.get_indices(self.dataset) + if not isinstance(indices, collections.abc.Sequence): + indices = [indices] + mix_results = [ + copy.deepcopy(self.dataset[index]) for index in indices + ] + results['mix_results'] = mix_results + + results = transform(results) + + if 'mix_results' in results: + results.pop('mix_results') + + return results + + def update_skip_type_keys(self, skip_type_keys): + """Update skip_type_keys. + + It is called by an external hook. + + Args: + skip_type_keys (list[str], optional): Sequence of type + string to be skip pipeline. + """ + assert all([ + isinstance(skip_type_key, str) for skip_type_key in skip_type_keys + ]) + self._skip_type_keys = skip_type_keys diff --git a/mmseg/datasets/decathlon.py b/mmseg/datasets/decathlon.py new file mode 100644 index 0000000000000000000000000000000000000000..26aa4ef0d7f44e55d4400ed6151ea1f6cb3930ec --- /dev/null +++ b/mmseg/datasets/decathlon.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os.path as osp +from typing import List + +from mmengine.fileio import load + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class DecathlonDataset(BaseSegDataset): + """Dataset for Dacathlon dataset. + + The dataset.json format is shown as follows + + .. 
diff --git a/mmseg/datasets/decathlon.py b/mmseg/datasets/decathlon.py
new file mode 100644
index 0000000000000000000000000000000000000000..26aa4ef0d7f44e55d4400ed6151ea1f6cb3930ec
--- /dev/null
+++ b/mmseg/datasets/decathlon.py
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os.path as osp
+from typing import List
+
+from mmengine.fileio import load
+
+from mmseg.registry import DATASETS
+from .basesegdataset import BaseSegDataset
+
+
+@DATASETS.register_module()
+class DecathlonDataset(BaseSegDataset):
+    """Dataset for the Medical Segmentation Decathlon.
+
+    The dataset.json format is shown as follows:
+
+    .. code-block:: none
+
+        {
+            "name": "BRATS",
+            "tensorImageSize": "4D",
+            "modality":
+            {
+                "0": "FLAIR",
+                "1": "T1w",
+                "2": "t1gd",
+                "3": "T2w"
+            },
+            "labels": {
+                "0": "background",
+                "1": "edema",
+                "2": "non-enhancing tumor",
+                "3": "enhancing tumour"
+            },
+            "numTraining": 484,
+            "numTest": 266,
+            "training":
+            [
+                {
+                    "image": "./imagesTr/BRATS_306.nii.gz",
+                    "label": "./labelsTr/BRATS_306.nii.gz"
+                    ...
+                }
+            ],
+            "test":
+            [
+                "./imagesTs/BRATS_557.nii.gz"
+                ...
+            ]
+        }
+    """
+
+    def load_data_list(self) -> List[dict]:
+        """Load annotation from directory or annotation file.
+
+        Returns:
+            list[dict]: All data info of dataset.
+        """
+        # `self.ann_file` denotes the absolute annotation file path if
+        # `self.data_root=None` or relative path if
+        # `self.data_root=/path/to/data/`.
+        annotations = load(self.ann_file)
+        if not isinstance(annotations, dict):
+            raise TypeError(f'The annotations loaded from annotation file '
+                            f'should be a dict, but got {type(annotations)}!')
+        raw_data_list = annotations[
+            'training'] if not self.test_mode else annotations['test']
+        data_list = []
+        for raw_data_info in raw_data_list:
+            # `[2:]` removes the leading './' in the file path, which would
+            # break loading from cloud storage.
+            if isinstance(raw_data_info, dict):
+                data_info = dict(
+                    img_path=osp.join(self.data_root,
+                                      raw_data_info['image'][2:]))
+                data_info['seg_map_path'] = osp.join(
+                    self.data_root, raw_data_info['label'][2:])
+            else:
+                data_info = dict(
+                    img_path=osp.join(self.data_root, raw_data_info[2:]))
+            data_info['label_map'] = self.label_map
+            data_info['reduce_zero_label'] = self.reduce_zero_label
+            data_info['seg_fields'] = []
+            data_list.append(data_info)
+        annotations.pop('training')
+        annotations.pop('test')
+
+        metainfo = copy.deepcopy(annotations)
+        metainfo['classes'] = [*metainfo['labels'].values()]
+        # Meta information load from annotation file will not influence the
+        # existed meta information load from `BaseDataset.METAINFO` and
+        # `metainfo` arguments defined in constructor.
+        for k, v in metainfo.items():
+            self._metainfo.setdefault(k, v)
+
+        return data_list
diff --git a/mmseg/datasets/drive.py b/mmseg/datasets/drive.py
new file mode 100644
index 0000000000000000000000000000000000000000..76c0160a6b6bf4a56ff135620ff0b08dc086d1d9
--- /dev/null
+++ b/mmseg/datasets/drive.py
@@ -0,0 +1,32 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmengine.fileio as fileio
+
+from mmseg.registry import DATASETS
+from .basesegdataset import BaseSegDataset
+
+
+@DATASETS.register_module()
+class DRIVEDataset(BaseSegDataset):
+    """DRIVE dataset.
+
+    In segmentation map annotation for DRIVE, 0 stands for background, which
+    is one of the 2 categories. ``reduce_zero_label`` is fixed to False. The
+    ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+    '_manual1.png'.
+    """
+    METAINFO = dict(
+        classes=('background', 'vessel'),
+        palette=[[120, 120, 120], [6, 230, 230]])
+
+    def __init__(self,
+                 img_suffix='.png',
+                 seg_map_suffix='_manual1.png',
+                 reduce_zero_label=False,
+                 **kwargs) -> None:
+        super().__init__(
+            img_suffix=img_suffix,
+            seg_map_suffix=seg_map_suffix,
+            reduce_zero_label=reduce_zero_label,
+            **kwargs)
+        assert fileio.exists(
+            self.data_prefix['img_path'], backend_args=self.backend_args)
diff --git a/mmseg/datasets/dsdl.py b/mmseg/datasets/dsdl.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf7e4e61b5fdd4bcb34617c8e53b93829def443a
--- /dev/null
+++ b/mmseg/datasets/dsdl.py
@@ -0,0 +1,116 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from typing import Dict, List, Optional, Sequence, Union
+
+from mmseg.registry import DATASETS
+from .basesegdataset import BaseSegDataset
+
+try:
+    from dsdl.dataset import DSDLDataset
+except ImportError:
+    DSDLDataset = None
+
+
+@DATASETS.register_module()
+class DSDLSegDataset(BaseSegDataset):
+    """Dataset for dsdl segmentation.
+
+    Args:
+        specific_key_path(dict): Path of a specific key which cannot be
+            loaded by its field name.
+        pre_transform(dict): Pre-transform functions applied before loading.
+        used_labels(sequence): List of classes actually used in training;
+            must be a subset of the class domain.
+    """
+
+    METAINFO = {}
+
+    def __init__(self,
+                 specific_key_path: Dict = {},
+                 pre_transform: Dict = {},
+                 used_labels: Optional[Sequence] = None,
+                 **kwargs) -> None:
+
+        if DSDLDataset is None:
+            raise RuntimeError(
+                'Package dsdl is not installed. Please run "pip install dsdl".'
+            )
+        self.used_labels = used_labels
+
+        loc_config = dict(type='LocalFileReader', working_dir='')
+        if kwargs.get('data_root'):
+            kwargs['ann_file'] = os.path.join(kwargs['data_root'],
+                                              kwargs['ann_file'])
+        required_fields = ['Image', 'LabelMap']
+
+        self.dsdldataset = DSDLDataset(
+            dsdl_yaml=kwargs['ann_file'],
+            location_config=loc_config,
+            required_fields=required_fields,
+            specific_key_path=specific_key_path,
+            transform=pre_transform,
+        )
+        BaseSegDataset.__init__(self, **kwargs)
+
+    def load_data_list(self) -> List[Dict]:
+        """Load data info from a dsdl yaml file named as ``self.ann_file``
+
+        Returns:
+            List[dict]: A list of data list.
+        """
+
+        if self.used_labels:
+            self._metainfo['classes'] = tuple(self.used_labels)
+            self.label_map = self.get_label_map(self.used_labels)
+        else:
+            self._metainfo['classes'] = tuple(['background'] +
+                                              self.dsdldataset.class_names)
+        data_list = []
+
+        for i, data in enumerate(self.dsdldataset):
+            datainfo = dict(
+                img_path=os.path.join(self.data_prefix['img_path'],
+                                      data['Image'][0].location),
+                seg_map_path=os.path.join(self.data_prefix['seg_map_path'],
+                                          data['LabelMap'][0].location),
+                label_map=self.label_map,
+                reduce_zero_label=self.reduce_zero_label,
+                seg_fields=[],
+            )
+            data_list.append(datainfo)
+
+        return data_list
+
+    def get_label_map(self,
+                      new_classes: Optional[Sequence] = None
+                      ) -> Union[Dict, None]:
+        """Require label mapping.
+
+        The ``label_map`` is a dictionary, its keys are the old label ids and
+        its values are the new label ids, and is used for changing pixel
+        labels in load_annotations. If and only if old classes in class_dom
+        is not equal to new classes in args and neither of them is None,
+        `label_map` is not None.
+
+        Args:
+            new_classes (list, tuple, optional): The new class names from
+                metainfo. Default to None.
+
+        Returns:
+            dict, optional: The mapping from old classes to new classes.
+ """ + old_classes = ['background'] + self.dsdldataset.class_names + if (new_classes is not None and old_classes is not None + and list(new_classes) != list(old_classes)): + + label_map = {} + if not set(new_classes).issubset(old_classes): + raise ValueError( + f'new classes {new_classes} is not a ' + f'subset of classes {old_classes} in class_dom.') + for i, c in enumerate(old_classes): + if c not in new_classes: + label_map[i] = 255 + else: + label_map[i] = new_classes.index(c) + return label_map + else: + return None diff --git a/mmseg/datasets/hrf.py b/mmseg/datasets/hrf.py new file mode 100644 index 0000000000000000000000000000000000000000..fd669cce26420b7e2c810ecace247a9e09350a5d --- /dev/null +++ b/mmseg/datasets/hrf.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmengine.fileio as fileio + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class HRFDataset(BaseSegDataset): + """HRF dataset. + + In segmentation map annotation for HRF, 0 stands for background, which is + included in 2 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '.png'. + """ + METAINFO = dict( + classes=('background', 'vessel'), + palette=[[120, 120, 120], [6, 230, 230]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) + assert fileio.exists( + self.data_prefix['img_path'], backend_args=self.backend_args) diff --git a/mmseg/datasets/isaid.py b/mmseg/datasets/isaid.py new file mode 100644 index 0000000000000000000000000000000000000000..61942ec1ea33e76c65c22d8e7fc71fb8194841dd --- /dev/null +++ b/mmseg/datasets/isaid.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmengine.fileio as fileio + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class iSAIDDataset(BaseSegDataset): + """ iSAID: A Large-scale Dataset for Instance Segmentation in Aerial Images + In segmentation map annotation for iSAID dataset, which is included + in 16 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '_manual1.png'. 
+ """ + + METAINFO = dict( + classes=('background', 'ship', 'store_tank', 'baseball_diamond', + 'tennis_court', 'basketball_court', 'Ground_Track_Field', + 'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter', + 'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane', + 'Harbor'), + palette=[[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127], + [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127], + [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127], + [0, 127, 191], [0, 127, 255], [0, 100, 155]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='_instance_color_RGB.png', + ignore_index=255, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + ignore_index=ignore_index, + **kwargs) + assert fileio.exists( + self.data_prefix['img_path'], backend_args=self.backend_args) diff --git a/mmseg/datasets/isprs.py b/mmseg/datasets/isprs.py new file mode 100644 index 0000000000000000000000000000000000000000..30af53c569b05c9be1218e9a58655c36c8aa9931 --- /dev/null +++ b/mmseg/datasets/isprs.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class ISPRSDataset(BaseSegDataset): + """ISPRS dataset. + + In segmentation map annotation for ISPRS, 0 is the ignore index. + ``reduce_zero_label`` should be set to True. The ``img_suffix`` and + ``seg_map_suffix`` are both fixed to '.png'. + """ + METAINFO = dict( + classes=('impervious_surface', 'building', 'low_vegetation', 'tree', + 'car', 'clutter'), + palette=[[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], + [255, 255, 0], [255, 0, 0]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=True, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) diff --git a/mmseg/datasets/levir.py b/mmseg/datasets/levir.py new file mode 100644 index 0000000000000000000000000000000000000000..f467481bad70a426381842dba61d85576c196eaf --- /dev/null +++ b/mmseg/datasets/levir.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from mmseg.registry import DATASETS +from .basesegdataset import BaseCDDataset + + +@DATASETS.register_module() +class LEVIRCDDataset(BaseCDDataset): + """ISPRS dataset. + + In segmentation map annotation for ISPRS, 0 is to ignore index. + ``reduce_zero_label`` should be set to True. The ``img_suffix`` and + ``seg_map_suffix`` are both fixed to '.png'. + """ + + METAINFO = dict( + classes=('background', 'changed'), + palette=[[0, 0, 0], [255, 255, 255]]) + + def __init__(self, + img_suffix='.png', + img_suffix2='.png', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + img_suffix2=img_suffix2, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) diff --git a/mmseg/datasets/lip.py b/mmseg/datasets/lip.py new file mode 100644 index 0000000000000000000000000000000000000000..3a32a193aff990ae9f819d4a0a1be82df1d049cb --- /dev/null +++ b/mmseg/datasets/lip.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class LIPDataset(BaseSegDataset): + """LIP dataset. + + The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to + '.png'. 
+ """ + METAINFO = dict( + classes=('Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', + 'UpperClothes', 'Dress', 'Coat', 'Socks', 'Pants', + 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', + 'Right-arm', 'Left-leg', 'Right-leg', 'Left-shoe', + 'Right-shoe'), + palette=( + [0, 0, 0], + [128, 0, 0], + [255, 0, 0], + [0, 85, 0], + [170, 0, 51], + [255, 85, 0], + [0, 0, 85], + [0, 119, 221], + [85, 85, 0], + [0, 85, 85], + [85, 51, 0], + [52, 86, 128], + [0, 128, 0], + [0, 0, 255], + [51, 170, 221], + [0, 255, 255], + [85, 255, 170], + [170, 255, 85], + [255, 255, 0], + [255, 170, 0], + )) + + def __init__(self, + img_suffix='.jpg', + seg_map_suffix='.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) diff --git a/mmseg/datasets/loveda.py b/mmseg/datasets/loveda.py new file mode 100644 index 0000000000000000000000000000000000000000..5c16db503adee6f1a1cac67e1dc72ff873ccd5ea --- /dev/null +++ b/mmseg/datasets/loveda.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class LoveDADataset(BaseSegDataset): + """LoveDA dataset. + + In segmentation map annotation for LoveDA, 0 is the ignore index. + ``reduce_zero_label`` should be set to True. The ``img_suffix`` and + ``seg_map_suffix`` are both fixed to '.png'. + """ + METAINFO = dict( + classes=('background', 'building', 'road', 'water', 'barren', 'forest', + 'agricultural'), + palette=[[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255], + [159, 129, 183], [0, 255, 0], [255, 195, 128]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=True, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) diff --git a/mmseg/datasets/mapillary.py b/mmseg/datasets/mapillary.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2947338ec79b3d8558cee0387a2a84e41f0421 --- /dev/null +++ b/mmseg/datasets/mapillary.py @@ -0,0 +1,176 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class MapillaryDataset_v1(BaseSegDataset): + """Mapillary Vistas Dataset. + + Dataset paper link: + http://ieeexplore.ieee.org/document/8237796/ + + v1.2 contain 66 object classes. + (37 instance-specific) + + v2.0 contain 124 object classes. + (70 instance-specific, 46 stuff, 8 void or crowd). + + The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is + fixed to '.png' for Mapillary Vistas Dataset. 
+ """ + METAINFO = dict( + classes=('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', + 'Barrier', 'Wall', 'Bike Lane', 'Crosswalk - Plain', + 'Curb Cut', 'Parking', 'Pedestrian Area', 'Rail Track', + 'Road', 'Service Lane', 'Sidewalk', 'Bridge', 'Building', + 'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist', + 'Other Rider', 'Lane Marking - Crosswalk', + 'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow', + 'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench', + 'Bike Rack', 'Billboard', 'Catch Basin', 'CCTV Camera', + 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', + 'Phone Booth', 'Pothole', 'Street Light', 'Pole', + 'Traffic Sign Frame', 'Utility Pole', 'Traffic Light', + 'Traffic Sign (Back)', 'Traffic Sign (Front)', 'Trash Can', + 'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan', 'Motorcycle', + 'On Rails', 'Other Vehicle', 'Trailer', 'Truck', + 'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled'), + palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153], + [180, 165, 180], [90, 120, 150], [102, 102, 156], + [128, 64, 255], [140, 140, 200], [170, 170, 170], + [250, 170, 160], [96, 96, 96], + [230, 150, 140], [128, 64, 128], [110, 110, 110], + [244, 35, 232], [150, 100, 100], [70, 70, 70], [150, 120, 90], + [220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200], + [200, 128, 128], [255, 255, 255], [64, 170, + 64], [230, 160, 50], + [70, 130, 180], [190, 255, 255], [152, 251, 152], + [107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30], + [100, 140, 180], [220, 220, 220], [220, 128, 128], + [222, 40, 40], [100, 170, 30], [40, 40, 40], [33, 33, 33], + [100, 128, 160], [142, 0, 0], [70, 100, 150], [210, 170, 100], + [153, 153, 153], [128, 128, 128], [0, 0, 80], [250, 170, 30], + [192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32], + [150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90], + [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110], + [0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, + 10], [0, 0, 0]]) + + def __init__(self, + img_suffix='.jpg', + seg_map_suffix='.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) + + +@DATASETS.register_module() +class MapillaryDataset_v2(BaseSegDataset): + """Mapillary Vistas Dataset. + + Dataset paper link: + http://ieeexplore.ieee.org/document/8237796/ + + v1.2 contain 66 object classes. + (37 instance-specific) + + v2.0 contain 124 object classes. + (70 instance-specific, 46 stuff, 8 void or crowd). + + The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is + fixed to '.png' for Mapillary Vistas Dataset. 
+ """ + METAINFO = dict( + classes=( + 'Bird', 'Ground Animal', 'Ambiguous Barrier', 'Concrete Block', + 'Curb', 'Fence', 'Guard Rail', 'Barrier', 'Road Median', + 'Road Side', 'Lane Separator', 'Temporary Barrier', 'Wall', + 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Driveway', + 'Parking', 'Parking Aisle', 'Pedestrian Area', 'Rail Track', + 'Road', 'Road Shoulder', 'Service Lane', 'Sidewalk', + 'Traffic Island', 'Bridge', 'Building', 'Garage', 'Tunnel', + 'Person', 'Person Group', 'Bicyclist', 'Motorcyclist', + 'Other Rider', 'Lane Marking - Dashed Line', + 'Lane Marking - Straight Line', 'Lane Marking - Zigzag Line', + 'Lane Marking - Ambiguous', 'Lane Marking - Arrow (Left)', + 'Lane Marking - Arrow (Other)', 'Lane Marking - Arrow (Right)', + 'Lane Marking - Arrow (Split Left or Straight)', + 'Lane Marking - Arrow (Split Right or Straight)', + 'Lane Marking - Arrow (Straight)', 'Lane Marking - Crosswalk', + 'Lane Marking - Give Way (Row)', + 'Lane Marking - Give Way (Single)', + 'Lane Marking - Hatched (Chevron)', + 'Lane Marking - Hatched (Diagonal)', 'Lane Marking - Other', + 'Lane Marking - Stop Line', 'Lane Marking - Symbol (Bicycle)', + 'Lane Marking - Symbol (Other)', 'Lane Marking - Text', + 'Lane Marking (only) - Dashed Line', + 'Lane Marking (only) - Crosswalk', 'Lane Marking (only) - Other', + 'Lane Marking (only) - Test', 'Mountain', 'Sand', 'Sky', 'Snow', + 'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench', 'Bike Rack', + 'Catch Basin', 'CCTV Camera', 'Fire Hydrant', 'Junction Box', + 'Mailbox', 'Manhole', 'Parking Meter', 'Phone Booth', 'Pothole', + 'Signage - Advertisement', 'Signage - Ambiguous', 'Signage - Back', + 'Signage - Information', 'Signage - Other', 'Signage - Store', + 'Street Light', 'Pole', 'Pole Group', 'Traffic Sign Frame', + 'Utility Pole', 'Traffic Cone', 'Traffic Light - General (Single)', + 'Traffic Light - Pedestrians', 'Traffic Light - General (Upright)', + 'Traffic Light - General (Horizontal)', 'Traffic Light - Cyclists', + 'Traffic Light - Other', 'Traffic Sign - Ambiguous', + 'Traffic Sign (Back)', 'Traffic Sign - Direction (Back)', + 'Traffic Sign - Direction (Front)', 'Traffic Sign (Front)', + 'Traffic Sign - Parking', 'Traffic Sign - Temporary (Back)', + 'Traffic Sign - Temporary (Front)', 'Trash Can', 'Bicycle', 'Boat', + 'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', + 'Trailer', 'Truck', 'Vehicle Group', 'Wheeled Slow', 'Water Valve', + 'Car Mount', 'Dynamic', 'Ego Vehicle', 'Ground', 'Static', + 'Unlabeled'), + palette=[[165, 42, 42], [0, 192, 0], [250, 170, 31], [250, 170, 32], + [196, 196, 196], [190, 153, 153], [180, 165, 180], + [90, 120, 150], [250, 170, 33], [250, 170, 34], + [128, 128, 128], [250, 170, 35], [102, 102, 156], + [128, 64, 255], [140, 140, 200], [170, 170, 170], + [250, 170, 36], [250, 170, 160], [250, 170, 37], [96, 96, 96], + [230, 150, 140], [128, 64, 128], [110, 110, 110], + [110, 110, 110], [244, 35, 232], [128, 196, + 128], [150, 100, 100], + [70, 70, 70], [150, 150, 150], [150, 120, 90], [220, 20, 60], + [220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200], + [255, 255, 255], [255, 255, 255], [250, 170, 29], + [250, 170, 28], [250, 170, 26], [250, 170, + 25], [250, 170, 24], + [250, 170, 22], [250, 170, 21], [250, 170, + 20], [255, 255, 255], + [250, 170, 19], [250, 170, 18], [250, 170, + 12], [250, 170, 11], + [255, 255, 255], [255, 255, 255], [250, 170, 16], + [250, 170, 15], [250, 170, 15], [255, 255, 255], + [255, 255, 255], [255, 255, 255], [255, 255, 255], + [64, 170, 64], 
[230, 160, 50], + [70, 130, 180], [190, 255, 255], [152, 251, 152], + [107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30], + [100, 140, 180], [220, 128, 128], [222, 40, + 40], [100, 170, 30], + [40, 40, 40], [33, 33, 33], [100, 128, 160], [20, 20, 255], + [142, 0, 0], [70, 100, 150], [250, 171, 30], [250, 172, 30], + [250, 173, 30], [250, 174, 30], [250, 175, + 30], [250, 176, 30], + [210, 170, 100], [153, 153, 153], [153, 153, 153], + [128, 128, 128], [0, 0, 80], [210, 60, 60], [250, 170, 30], + [250, 170, 30], [250, 170, 30], [250, 170, + 30], [250, 170, 30], + [250, 170, 30], [192, 192, 192], [192, 192, 192], + [192, 192, 192], [220, 220, 0], [220, 220, 0], [0, 0, 196], + [192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32], + [150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90], + [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110], + [0, 0, 70], [0, 0, 142], [0, 0, 192], [170, 170, 170], + [32, 32, 32], [111, 74, 0], [120, 10, 10], [81, 0, 81], + [111, 111, 0], [0, 0, 0]]) + + def __init__(self, + img_suffix='.jpg', + seg_map_suffix='.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) diff --git a/mmseg/datasets/night_driving.py b/mmseg/datasets/night_driving.py new file mode 100644 index 0000000000000000000000000000000000000000..3ead91ec77cbd8e3f0a870dee3462549183e9c9b --- /dev/null +++ b/mmseg/datasets/night_driving.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .cityscapes import CityscapesDataset + + +@DATASETS.register_module() +class NightDrivingDataset(CityscapesDataset): + """NightDrivingDataset dataset.""" + + def __init__(self, + img_suffix='_leftImg8bit.png', + seg_map_suffix='_gtCoarse_labelTrainIds.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) diff --git a/mmseg/datasets/nyu.py b/mmseg/datasets/nyu.py new file mode 100644 index 0000000000000000000000000000000000000000..fcfda46647d25b5d16425af97a06ffb8c1f81bca --- /dev/null +++ b/mmseg/datasets/nyu.py @@ -0,0 +1,123 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import List + +import mmengine.fileio as fileio + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class NYUDataset(BaseSegDataset): + """NYU depth estimation dataset. The file structure should be. + + .. code-block:: none + + ├── data + │ ├── nyu + │ │ ├── images + │ │ │ ├── train + │ │ │ │ ├── scene_xxx.jpg + │ │ │ │ ├── ... + │ │ │ ├── test + │ │ ├── annotations + │ │ │ ├── train + │ │ │ │ ├── scene_xxx.png + │ │ │ │ ├── ... + │ │ │ ├── test + + Args: + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as + specify classes to load. Defaults to None. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to None. + data_prefix (dict, optional): Prefix for training data. Defaults to + dict(img_path='images', depth_map_path='annotations'). + img_suffix (str): Suffix of images. Default: '.jpg' + seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' + filter_cfg (dict, optional): Config for filter data. Defaults to None. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Defaults to None which means using all ``data_infos``. 
+        serialize_data (bool, optional): Whether to hold memory using
+            serialized objects. When enabled, data loader workers can use
+            shared RAM from the master process instead of making a copy.
+            Defaults to True.
+        pipeline (list, optional): Processing pipeline. Defaults to [].
+        test_mode (bool, optional): ``test_mode=True`` means in test phase.
+            Defaults to False.
+        lazy_init (bool, optional): Whether to load annotation during
+            instantiation. In some cases, such as visualization, only the
+            meta information of the dataset is needed, so it is not necessary
+            to load the annotation file. ``Basedataset`` can skip loading
+            annotations to save time by setting ``lazy_init=True``.
+            Defaults to False.
+        max_refetch (int, optional): If ``Basedataset.prepare_data`` gets a
+            None img, the maximum extra number of cycles to get a valid
+            image. Defaults to 1000.
+        ignore_index (int): The label index to be ignored. Default: 255
+        reduce_zero_label (bool): Whether to mark label zero as ignored.
+            Defaults to False.
+        backend_args (dict, Optional): Arguments to instantiate a file backend.
+            See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
+            for details. Defaults to None.
+            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
+    """
+    METAINFO = dict(
+        classes=('printer_room', 'bathroom', 'living_room', 'study',
+                 'conference_room', 'study_room', 'kitchen', 'home_office',
+                 'bedroom', 'dinette', 'playroom', 'indoor_balcony',
+                 'laundry_room', 'basement', 'excercise_room', 'foyer',
+                 'home_storage', 'cafe', 'furniture_store', 'office_kitchen',
+                 'student_lounge', 'dining_room', 'reception_room',
+                 'computer_lab', 'classroom', 'office', 'bookstore'))
+
+    def __init__(self,
+                 data_prefix=dict(
+                     img_path='images', depth_map_path='annotations'),
+                 img_suffix='.jpg',
+                 depth_map_suffix='.png',
+                 **kwargs) -> None:
+        super().__init__(
+            data_prefix=data_prefix,
+            img_suffix=img_suffix,
+            seg_map_suffix=depth_map_suffix,
+            **kwargs)
+
+    def _get_category_id_from_filename(self, image_fname: str) -> int:
+        """Retrieve the category ID from the given image filename."""
+        image_fname = osp.basename(image_fname)
+        position = image_fname.find(next(filter(str.isdigit, image_fname)), 0)
+        category_name = image_fname[:position - 1]
+        if category_name not in self._metainfo['classes']:
+            return -1
+        else:
+            return self._metainfo['classes'].index(category_name)
+
+    def load_data_list(self) -> List[dict]:
+        """Load annotation from directory or annotation file.
+
+        Returns:
+            list[dict]: All data info of dataset.
+        """
+        data_list = []
+        img_dir = self.data_prefix.get('img_path', None)
+        ann_dir = self.data_prefix.get('depth_map_path', None)
+
+        _suffix_len = len(self.img_suffix)
+        for img in fileio.list_dir_or_file(
+                dir_path=img_dir,
+                list_dir=False,
+                suffix=self.img_suffix,
+                recursive=True,
+                backend_args=self.backend_args):
+            data_info = dict(img_path=osp.join(img_dir, img))
+            if ann_dir is not None:
+                depth_map = img[:-_suffix_len] + self.seg_map_suffix
+                data_info['depth_map_path'] = osp.join(ann_dir, depth_map)
+            data_info['seg_fields'] = []
+            data_info['category_id'] = self._get_category_id_from_filename(img)
+            data_list.append(data_info)
+        data_list = sorted(data_list, key=lambda x: x['img_path'])
+        return data_list
diff --git a/mmseg/datasets/pascal_context.py b/mmseg/datasets/pascal_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..82d00a9b3086a0db81457ab9b2f79c79de4ffaa8
--- /dev/null
+++ b/mmseg/datasets/pascal_context.py
@@ -0,0 +1,116 @@
+# Copyright (c) OpenMMLab. All rights reserved.
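The scene-category lookup in `NYUDataset._get_category_id_from_filename` above works by finding the first digit in the file name and treating everything before the preceding underscore as the scene name. A minimal standalone sketch of that rule (the class tuple and file names here are hypothetical examples, not taken from the dataset):

```python
# Sketch of the filename-to-category rule used by NYUDataset above.
classes = ('bathroom', 'living_room', 'kitchen')

def category_id(image_fname: str) -> int:
    # Index of the first digit, e.g. 'kitchen_0001.jpg' -> 8.
    position = image_fname.find(next(filter(str.isdigit, image_fname)), 0)
    # Drop the trailing '_' before the digits to recover the scene name.
    category_name = image_fname[:position - 1]
    return classes.index(category_name) if category_name in classes else -1

assert category_id('kitchen_0001.jpg') == 2
assert category_id('hallway_0002.jpg') == -1  # unknown scene -> -1
```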
+import mmengine.fileio as fileio + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class PascalContextDataset(BaseSegDataset): + """PascalContext dataset. + + In segmentation map annotation for PascalContext, 0 stands for background, + which is included in 60 categories. ``reduce_zero_label`` is fixed to + False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is + fixed to '.png'. + + Args: + ann_file (str): Annotation file path. + """ + + METAINFO = dict( + classes=('background', 'aeroplane', 'bag', 'bed', 'bedclothes', + 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle', + 'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling', + 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 'dog', + 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', + 'horse', 'keyboard', 'light', 'motorbike', 'mountain', + 'mouse', 'person', 'plate', 'platform', 'pottedplant', 'road', + 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', + 'sofa', 'table', 'track', 'train', 'tree', 'truck', + 'tvmonitor', 'wall', 'water', 'window', 'wood'), + palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]) + + def __init__(self, + ann_file='', + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + ann_file=ann_file, + reduce_zero_label=reduce_zero_label, + **kwargs) + assert fileio.exists(self.data_prefix['img_path'], self.backend_args) + + +@DATASETS.register_module() +class PascalContextDataset59(BaseSegDataset): + """PascalContext dataset. + + In segmentation map annotation for PascalContext, 0 stands for background, + which is included in 60 categories. ``reduce_zero_label`` is fixed to + True. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is + fixed to '.png'. + Noted: If the background is 255 and the ids of categories are from 0 to 58, + ``reduce_zero_label`` needs to be set to False. + + Args: + ann_file (str): Annotation file path. 
+ """ + METAINFO = dict( + classes=('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', + 'bird', 'boat', 'book', 'bottle', 'building', 'bus', + 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', + 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence', + 'floor', 'flower', 'food', 'grass', 'ground', 'horse', + 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', + 'person', 'plate', 'platform', 'pottedplant', 'road', 'rock', + 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', + 'table', 'track', 'train', 'tree', 'truck', 'tvmonitor', + 'wall', 'water', 'window', 'wood'), + palette=[[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], + [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]) + + def __init__(self, + ann_file='', + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_zero_label=True, + **kwargs): + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + ann_file=ann_file, + reduce_zero_label=reduce_zero_label, + **kwargs) + assert fileio.exists(self.data_prefix['img_path'], self.backend_args) diff --git a/mmseg/datasets/potsdam.py b/mmseg/datasets/potsdam.py new file mode 100644 index 0000000000000000000000000000000000000000..6892de3dd29fda569527342377c6e83ce0d972bf --- /dev/null +++ b/mmseg/datasets/potsdam.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class PotsdamDataset(BaseSegDataset): + """ISPRS Potsdam dataset. + + In segmentation map annotation for Potsdam dataset, 0 is the ignore index. + ``reduce_zero_label`` should be set to True. The ``img_suffix`` and + ``seg_map_suffix`` are both fixed to '.png'. + """ + METAINFO = dict( + classes=('impervious_surface', 'building', 'low_vegetation', 'tree', + 'car', 'clutter'), + palette=[[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], + [255, 255, 0], [255, 0, 0]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=True, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) diff --git a/mmseg/datasets/refuge.py b/mmseg/datasets/refuge.py new file mode 100644 index 0000000000000000000000000000000000000000..4016a825a37cdd0162f9c3e72df2fcabc6984991 --- /dev/null +++ b/mmseg/datasets/refuge.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmengine.fileio as fileio + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class REFUGEDataset(BaseSegDataset): + """REFUGE dataset. 
+
+    In segmentation map annotation for REFUGE, 0 stands for background, which
+    is included in the 3 categories. ``reduce_zero_label`` is fixed to False.
+    The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+    '.png'.
+    """
+    METAINFO = dict(
+        classes=('background', 'Optic Cup', 'Optic Disc'),
+        palette=[[120, 120, 120], [6, 230, 230], [56, 59, 120]])
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(
+            img_suffix='.png',
+            seg_map_suffix='.png',
+            reduce_zero_label=False,
+            **kwargs)
+        assert fileio.exists(
+            self.data_prefix['img_path'], backend_args=self.backend_args)
diff --git a/mmseg/datasets/stare.py b/mmseg/datasets/stare.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b997bb785f20a9225c8b7e3f9b0522bc5e5ed99
--- /dev/null
+++ b/mmseg/datasets/stare.py
@@ -0,0 +1,32 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmengine.fileio as fileio
+
+from mmseg.registry import DATASETS
+from .basesegdataset import BaseSegDataset
+
+
+@DATASETS.register_module()
+class STAREDataset(BaseSegDataset):
+    """STARE dataset.
+
+    In segmentation map annotation for STARE, 0 stands for background, which
+    is included in 2 categories. ``reduce_zero_label`` is fixed to False. The
+    ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+    '.ah.png'.
+    """
+    METAINFO = dict(
+        classes=('background', 'vessel'),
+        palette=[[120, 120, 120], [6, 230, 230]])
+
+    def __init__(self,
+                 img_suffix='.png',
+                 seg_map_suffix='.ah.png',
+                 reduce_zero_label=False,
+                 **kwargs) -> None:
+        super().__init__(
+            img_suffix=img_suffix,
+            seg_map_suffix=seg_map_suffix,
+            reduce_zero_label=reduce_zero_label,
+            **kwargs)
+        assert fileio.exists(
+            self.data_prefix['img_path'], backend_args=self.backend_args)
diff --git a/mmseg/datasets/synapse.py b/mmseg/datasets/synapse.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f83b6415046667fb24086083c43083040f4487c
--- /dev/null
+++ b/mmseg/datasets/synapse.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmseg.registry import DATASETS
+from .basesegdataset import BaseSegDataset
+
+
+@DATASETS.register_module()
+class SynapseDataset(BaseSegDataset):
+    """Synapse dataset.
+
+    Before preprocessing, the Synapse dataset has a total of 13 foreground
+    categories, background excluded. After preprocessing, 8 foreground
+    categories are kept while the other 5 are merged into background. The
+    ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
+    '.png'.
+    """
+    METAINFO = dict(
+        classes=('background', 'aorta', 'gallbladder', 'left_kidney',
+                 'right_kidney', 'liver', 'pancreas', 'spleen', 'stomach'),
+        palette=[[0, 0, 0], [0, 0, 255], [0, 255, 0], [255, 0, 0],
+                 [0, 255, 255], [255, 0, 255], [255, 255, 0], [60, 255, 255],
+                 [240, 240, 240]])
+
+    def __init__(self,
+                 img_suffix='.jpg',
+                 seg_map_suffix='.png',
+                 **kwargs) -> None:
+        super().__init__(
+            img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
diff --git a/mmseg/datasets/transforms/__init__.py b/mmseg/datasets/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..125f07081810c980ebc6ded077bcf5dfd955cfcf
--- /dev/null
+++ b/mmseg/datasets/transforms/__init__.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
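Several datasets above (LoveDA, Potsdam) pass `reduce_zero_label=True`, while STARE and REFUGE keep it False. The convention, implemented later in this diff by `LoadAnnotations._load_seg_map`, is that the original label 0 becomes the ignore index 255 and every remaining label shifts down by one. A minimal numpy sketch of that remapping:

```python
import numpy as np

# Sketch of the reduce_zero_label remapping convention (mirrors the
# logic in LoadAnnotations._load_seg_map further down in this diff).
seg = np.array([[0, 1, 2],
                [3, 0, 255]], dtype=np.uint8)

seg[seg == 0] = 255    # original background -> ignore index
seg = seg - 1          # shift remaining labels down (uint8 wraps 255 -> 254)
seg[seg == 254] = 255  # wrapped ignore values stay ignored

print(seg)
# [[255   0   1]
#  [  2 255 255]]
```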
+from .formatting import PackSegInputs +from .loading import (LoadAnnotations, LoadBiomedicalAnnotation, + LoadBiomedicalData, LoadBiomedicalImageFromFile, + LoadDepthAnnotation, LoadImageFromNDArray, + LoadMultipleRSImageFromFile, LoadSingleRSImageFromFile) +# yapf: disable +from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad, + BioMedical3DRandomCrop, BioMedical3DRandomFlip, + BioMedicalGaussianBlur, BioMedicalGaussianNoise, + BioMedicalRandomGamma, ConcatCDInput, GenerateEdge, + PhotoMetricDistortion, RandomCrop, RandomCutOut, + RandomDepthMix, RandomFlip, RandomMosaic, + RandomRotate, RandomRotFlip, Rerange, Resize, + ResizeShortestEdge, ResizeToMultiple, RGB2Gray, + SegRescale) + +# yapf: enable +__all__ = [ + 'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale', + 'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', + 'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', + 'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', + 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', + 'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur', + 'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad', + 'RandomRotFlip', 'Albu', 'LoadSingleRSImageFromFile', 'ConcatCDInput', + 'LoadMultipleRSImageFromFile', 'LoadDepthAnnotation', 'RandomDepthMix', + 'RandomFlip', 'Resize' +] diff --git a/mmseg/datasets/transforms/__pycache__/__init__.cpython-311.pyc b/mmseg/datasets/transforms/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bfc8001ebf18c59513e292c581fc62ac10d7a6c Binary files /dev/null and b/mmseg/datasets/transforms/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmseg/datasets/transforms/__pycache__/formatting.cpython-311.pyc b/mmseg/datasets/transforms/__pycache__/formatting.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a3b96dba0126957699e13fa3f3ceb48f5f0fb8b Binary files /dev/null and b/mmseg/datasets/transforms/__pycache__/formatting.cpython-311.pyc differ diff --git a/mmseg/datasets/transforms/__pycache__/loading.cpython-311.pyc b/mmseg/datasets/transforms/__pycache__/loading.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0081b30c696c228f2d28c3ae8997cca255757dc Binary files /dev/null and b/mmseg/datasets/transforms/__pycache__/loading.cpython-311.pyc differ diff --git a/mmseg/datasets/transforms/__pycache__/transforms.cpython-311.pyc b/mmseg/datasets/transforms/__pycache__/transforms.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7666fb6bcc17d507ea828e97d639148c93a6352c Binary files /dev/null and b/mmseg/datasets/transforms/__pycache__/transforms.cpython-311.pyc differ diff --git a/mmseg/datasets/transforms/formatting.py b/mmseg/datasets/transforms/formatting.py new file mode 100644 index 0000000000000000000000000000000000000000..bd250551e98ffc9decaa2e168943821501844c1f --- /dev/null +++ b/mmseg/datasets/transforms/formatting.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import numpy as np +from mmcv.transforms import to_tensor +from mmcv.transforms.base import BaseTransform +from mmengine.structures import PixelData + +from mmseg.registry import TRANSFORMS +from mmseg.structures import SegDataSample + + +@TRANSFORMS.register_module() +class PackSegInputs(BaseTransform): + """Pack the inputs data for the semantic segmentation. 
+
+    The ``img_meta`` item is always populated. The contents of the
+    ``img_meta`` dictionary depend on ``meta_keys``. By default this includes:
+
+        - ``img_path``: filename of the image
+
+        - ``ori_shape``: original shape of the image as a tuple (h, w, c)
+
+        - ``img_shape``: shape of the image input to the network as a tuple \
+            (h, w, c). Note that images may be zero padded on the \
+            bottom/right if the batch tensor is larger than this shape.
+
+        - ``pad_shape``: shape of padded images
+
+        - ``scale_factor``: a float indicating the preprocessing scale
+
+        - ``flip``: a boolean indicating if image flip transform was used
+
+        - ``flip_direction``: the flipping direction
+
+    Args:
+        meta_keys (Sequence[str], optional): Meta keys to be packed from
+            ``SegDataSample`` and collected in ``data[img_metas]``.
+            Default: ``('img_path', 'seg_map_path', 'ori_shape', 'img_shape',
+            'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+            'reduce_zero_label')``
+    """
+
+    def __init__(self,
+                 meta_keys=('img_path', 'seg_map_path', 'ori_shape',
+                            'img_shape', 'pad_shape', 'scale_factor', 'flip',
+                            'flip_direction', 'reduce_zero_label')):
+        self.meta_keys = meta_keys
+
+    def transform(self, results: dict) -> dict:
+        """Method to pack the input data.
+
+        Args:
+            results (dict): Result dict from the data pipeline.
+
+        Returns:
+            dict:
+
+            - 'inputs' (obj:`torch.Tensor`): The forward data of models.
+            - 'data_sample' (obj:`SegDataSample`): The annotation info of the
+                sample.
+        """
+        packed_results = dict()
+        if 'img' in results:
+            img = results['img']
+            if len(img.shape) < 3:
+                img = np.expand_dims(img, -1)
+            if not img.flags.c_contiguous:
+                img = to_tensor(np.ascontiguousarray(img.transpose(2, 0, 1)))
+            else:
+                img = img.transpose(2, 0, 1)
+                img = to_tensor(img).contiguous()
+            packed_results['inputs'] = img
+
+        data_sample = SegDataSample()
+        if 'gt_seg_map' in results:
+            if len(results['gt_seg_map'].shape) == 2:
+                data = to_tensor(results['gt_seg_map'][None,
+                                                       ...].astype(np.int64))
+            else:
+                warnings.warn('Expected a 2D ground truth segmentation map, '
+                              'but got one with shape '
+                              f'{results["gt_seg_map"].shape}')
+                data = to_tensor(results['gt_seg_map'].astype(np.int64))
+            gt_sem_seg_data = dict(data=data)
+            data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
+
+        if 'gt_edge_map' in results:
+            gt_edge_data = dict(
+                data=to_tensor(results['gt_edge_map'][None,
+                                                      ...].astype(np.int64)))
+            data_sample.set_data(dict(gt_edge_map=PixelData(**gt_edge_data)))
+
+        if 'gt_depth_map' in results:
+            gt_depth_data = dict(
+                data=to_tensor(results['gt_depth_map'][None, ...]))
+            data_sample.set_data(dict(gt_depth_map=PixelData(**gt_depth_data)))
+
+        img_meta = {}
+        for key in self.meta_keys:
+            if key in results:
+                img_meta[key] = results[key]
+        data_sample.set_metainfo(img_meta)
+        packed_results['data_samples'] = data_sample
+
+        return packed_results
+
+    def __repr__(self) -> str:
+        repr_str = self.__class__.__name__
+        repr_str += f'(meta_keys={self.meta_keys})'
+        return repr_str
diff --git a/mmseg/datasets/transforms/loading.py b/mmseg/datasets/transforms/loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..438b5527f08d4aa7b66a7ba972af05f34dd192ff
--- /dev/null
+++ b/mmseg/datasets/transforms/loading.py
@@ -0,0 +1,704 @@
+# Copyright (c) OpenMMLab. All rights reserved.
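A quick sketch of what `PackSegInputs` above produces for a toy results dict (the shapes and values are made up, and this assumes the class is importable from `mmseg.datasets.transforms` as exported earlier in this diff):

```python
import numpy as np
from mmseg.datasets.transforms import PackSegInputs

pack = PackSegInputs()
results = dict(
    img=np.zeros((4, 4, 3), dtype=np.uint8),      # HWC image
    gt_seg_map=np.zeros((4, 4), dtype=np.uint8),  # 2-D label map
    img_path='demo.jpg', ori_shape=(4, 4), img_shape=(4, 4))

packed = pack.transform(results)
print(packed['inputs'].shape)                        # torch.Size([3, 4, 4]), CHW
print(packed['data_samples'].gt_sem_seg.data.shape)  # torch.Size([1, 4, 4])
```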
+import warnings
+from typing import Dict, Optional, Union
+
+import mmcv
+import mmengine.fileio as fileio
+import numpy as np
+from mmcv.transforms import BaseTransform
+from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
+from mmcv.transforms import LoadImageFromFile
+
+from mmseg.registry import TRANSFORMS
+from mmseg.utils import datafrombytes
+
+try:
+    from osgeo import gdal
+except ImportError:
+    gdal = None
+
+
+@TRANSFORMS.register_module()
+class LoadAnnotations(MMCV_LoadAnnotations):
+    """Load annotations for semantic segmentation provided by dataset.
+
+    The annotation format is as follows:
+
+    .. code-block:: python
+
+        {
+            # Filename of semantic segmentation ground truth file.
+            'seg_map_path': 'a/b/c'
+        }
+
+    After this module, the annotation has been changed to the format below:
+
+    .. code-block:: python
+
+        {
+            # List of str.
+            'seg_fields': List
+            # In uint8 type.
+            'gt_seg_map': np.ndarray (H, W)
+        }
+
+    Required Keys:
+
+    - seg_map_path (str): Path of semantic segmentation ground truth file.
+
+    Added Keys:
+
+    - seg_fields (List)
+    - gt_seg_map (np.uint8)
+
+    Args:
+        reduce_zero_label (bool, optional): Whether to reduce all label
+            values by 1. Usually used for datasets where 0 is the background
+            label. Defaults to None.
+        imdecode_backend (str): The image decoding backend type. The backend
+            argument for :func:``mmcv.imfrombytes``.
+            See :func:``mmcv.imfrombytes`` for details.
+            Defaults to 'pillow'.
+        backend_args (dict): Arguments to instantiate a file backend.
+            See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
+            for details. Defaults to None.
+            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
+    """
+
+    def __init__(
+        self,
+        reduce_zero_label=None,
+        backend_args=None,
+        imdecode_backend='pillow',
+    ) -> None:
+        super().__init__(
+            with_bbox=False,
+            with_label=False,
+            with_seg=True,
+            with_keypoints=False,
+            imdecode_backend=imdecode_backend,
+            backend_args=backend_args)
+        self.reduce_zero_label = reduce_zero_label
+        if self.reduce_zero_label is not None:
+            warnings.warn('`reduce_zero_label` will be deprecated; '
+                          'if you would like to ignore the zero label, '
+                          'please set `reduce_zero_label=True` when the '
+                          'dataset is initialized')
+        self.imdecode_backend = imdecode_backend
+
+    def _load_seg_map(self, results: dict) -> None:
+        """Private function to load semantic segmentation annotations.
+
+        Args:
+            results (dict): Result dict from :obj:``mmcv.BaseDataset``.
+
+        Returns:
+            dict: The dict contains loaded semantic segmentation annotations.
+ """ + + img_bytes = fileio.get( + results['seg_map_path'], backend_args=self.backend_args) + gt_semantic_seg = mmcv.imfrombytes( + img_bytes, flag='unchanged', + backend=self.imdecode_backend).squeeze().astype(np.uint8) + + # reduce zero_label + if self.reduce_zero_label is None: + self.reduce_zero_label = results['reduce_zero_label'] + assert self.reduce_zero_label == results['reduce_zero_label'], \ + 'Initialize dataset with `reduce_zero_label` as ' \ + f'{results["reduce_zero_label"]} but when load annotation ' \ + f'the `reduce_zero_label` is {self.reduce_zero_label}' + if self.reduce_zero_label: + # avoid using underflow conversion + gt_semantic_seg[gt_semantic_seg == 0] = 255 + gt_semantic_seg = gt_semantic_seg - 1 + gt_semantic_seg[gt_semantic_seg == 254] = 255 + # modify if custom classes + if results.get('label_map', None) is not None: + # Add deep copy to solve bug of repeatedly + # replace `gt_semantic_seg`, which is reported in + # https://github.com/open-mmlab/mmsegmentation/pull/1445/ + gt_semantic_seg_copy = gt_semantic_seg.copy() + for old_id, new_id in results['label_map'].items(): + gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id + results['gt_seg_map'] = gt_semantic_seg + results['seg_fields'].append('gt_seg_map') + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(reduce_zero_label={self.reduce_zero_label}, ' + repr_str += f"imdecode_backend='{self.imdecode_backend}', " + repr_str += f'backend_args={self.backend_args})' + return repr_str + + +@TRANSFORMS.register_module() +class LoadImageFromNDArray(LoadImageFromFile): + """Load an image from ``results['img']``. + + Similar with :obj:`LoadImageFromFile`, but the image has been loaded as + :obj:`np.ndarray` in ``results['img']``. Can be used when loading image + from webcam. + + Required Keys: + + - img + + Modified Keys: + + - img + - img_path + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + """ + + def transform(self, results: dict) -> dict: + """Transform function to add image meta information. + + Args: + results (dict): Result dict with Webcam read image in + ``results['img']``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + img = results['img'] + if self.to_float32: + img = img.astype(np.float32) + + results['img_path'] = None + results['img'] = img + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + return results + + +@TRANSFORMS.register_module() +class LoadBiomedicalImageFromFile(BaseTransform): + """Load an biomedical mage from file. + + Required Keys: + + - img_path + + Added Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities, and data type is float32 + if set to_float32 = True, or float64 if decode_backend is 'nifti' and + to_float32 is False. + - img_shape + - ori_shape + + Args: + decode_backend (str): The data decoding backend type. Options are + 'numpy'and 'nifti', and there is a convention that when backend is + 'nifti' the axis of data loaded is XYZ, and when backend is + 'numpy', the the axis is ZYX. The data will be transposed if the + backend is 'nifti'. Defaults to 'nifti'. + to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z. + Defaults to False. + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. 
+            If set to False, the loaded image is a float64 array.
+            Defaults to True.
+        backend_args (dict, Optional): Arguments to instantiate a file backend.
+            See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
+            for details. Defaults to None.
+            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
+    """
+
+    def __init__(self,
+                 decode_backend: str = 'nifti',
+                 to_xyz: bool = False,
+                 to_float32: bool = True,
+                 backend_args: Optional[dict] = None) -> None:
+        self.decode_backend = decode_backend
+        self.to_xyz = to_xyz
+        self.to_float32 = to_float32
+        self.backend_args = backend_args.copy() if backend_args else None
+
+    def transform(self, results: Dict) -> Dict:
+        """Functions to load image.
+
+        Args:
+            results (dict): Result dict from :obj:``mmcv.BaseDataset``.
+
+        Returns:
+            dict: The dict contains loaded image and meta information.
+        """
+
+        filename = results['img_path']
+
+        data_bytes = fileio.get(filename, self.backend_args)
+        img = datafrombytes(data_bytes, backend=self.decode_backend)
+
+        if self.to_float32:
+            img = img.astype(np.float32)
+
+        if len(img.shape) == 3:
+            img = img[None, ...]
+
+        if self.decode_backend == 'nifti':
+            img = img.transpose(0, 3, 2, 1)
+
+        if self.to_xyz:
+            img = img.transpose(0, 3, 2, 1)
+
+        results['img'] = img
+        results['img_shape'] = img.shape[1:]
+        results['ori_shape'] = img.shape[1:]
+        return results
+
+    def __repr__(self):
+        repr_str = (f'{self.__class__.__name__}('
+                    f"decode_backend='{self.decode_backend}', "
+                    f'to_xyz={self.to_xyz}, '
+                    f'to_float32={self.to_float32}, '
+                    f'backend_args={self.backend_args})')
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class LoadBiomedicalAnnotation(BaseTransform):
+    """Load ``seg_map`` annotation provided by biomedical dataset.
+
+    The annotation format is as follows:
+
+    .. code-block:: python
+
+        {
+            'gt_seg_map': np.ndarray (X, Y, Z) or (Z, Y, X)
+        }
+
+    Required Keys:
+
+    - seg_map_path
+
+    Added Keys:
+
+    - gt_seg_map (np.ndarray): Biomedical seg map with shape (Z, Y, X) by
+        default, and data type is float32 if set to_float32 = True, or
+        float64 if decode_backend is 'nifti' and to_float32 is False.
+
+    Args:
+        decode_backend (str): The data decoding backend type. Options are
+            'numpy' and 'nifti', and there is a convention that when backend
+            is 'nifti' the axis order of the loaded data is XYZ, and when
+            backend is 'numpy', the axis order is ZYX. The data will be
+            transposed if the backend is 'nifti'. Defaults to 'nifti'.
+        to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z.
+            Defaults to False.
+        to_float32 (bool): Whether to convert the loaded seg map to a float32
+            numpy array. If set to False, the loaded image is a float64 array.
+            Defaults to True.
+        backend_args (dict, Optional): Arguments to instantiate a file backend.
+            See :class:`mmengine.fileio` for details.
+            Defaults to None.
+            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
+    """
+
+    def __init__(self,
+                 decode_backend: str = 'nifti',
+                 to_xyz: bool = False,
+                 to_float32: bool = True,
+                 backend_args: Optional[dict] = None) -> None:
+        super().__init__()
+        self.decode_backend = decode_backend
+        self.to_xyz = to_xyz
+        self.to_float32 = to_float32
+        self.backend_args = backend_args.copy() if backend_args else None
+
+    def transform(self, results: Dict) -> Dict:
+        """Functions to load image.
+
+        Args:
+            results (dict): Result dict from :obj:``mmcv.BaseDataset``.
+
+        Returns:
+            dict: The dict contains loaded image and meta information.
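The axis juggling in the biomedical loaders is easiest to follow with concrete shapes: a 'nifti'-decoded volume arrives as (X, Y, Z), gains a leading modality axis, and is transposed to the (N, Z, Y, X) default; `to_xyz=True` applies the same transpose again, flipping it back. A sketch with dummy shapes (the sizes are arbitrary placeholders):

```python
import numpy as np

vol = np.zeros((160, 192, 128), dtype=np.float32)  # hypothetical (X, Y, Z) volume
img = vol[None, ...]             # add modality axis -> (N, X, Y, Z)
img = img.transpose(0, 3, 2, 1)  # 'nifti' backend -> (N, Z, Y, X), the default
assert img.shape == (1, 128, 192, 160)

img = img.transpose(0, 3, 2, 1)  # to_xyz=True repeats the transpose -> (N, X, Y, Z)
assert img.shape == (1, 160, 192, 128)
```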
+ """ + data_bytes = fileio.get(results['seg_map_path'], self.backend_args) + gt_seg_map = datafrombytes(data_bytes, backend=self.decode_backend) + + if self.to_float32: + gt_seg_map = gt_seg_map.astype(np.float32) + + if self.decode_backend == 'nifti': + gt_seg_map = gt_seg_map.transpose(2, 1, 0) + + if self.to_xyz: + gt_seg_map = gt_seg_map.transpose(2, 1, 0) + + results['gt_seg_map'] = gt_seg_map + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f"decode_backend='{self.decode_backend}', " + f'to_xyz={self.to_xyz}, ' + f'to_float32={self.to_float32}, ' + f'backend_args={self.backend_args})') + return repr_str + + +@TRANSFORMS.register_module() +class LoadBiomedicalData(BaseTransform): + """Load an biomedical image and annotation from file. + + The loading data format is as the following: + + .. code-block:: python + + { + 'img': np.ndarray data[:-1, X, Y, Z] + 'seg_map': np.ndarray data[-1, X, Y, Z] + } + + + Required Keys: + + - img_path + + Added Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. + - img_shape + - ori_shape + + Args: + with_seg (bool): Whether to parse and load the semantic segmentation + annotation. Defaults to False. + decode_backend (str): The data decoding backend type. Options are + 'numpy'and 'nifti', and there is a convention that when backend is + 'nifti' the axis of data loaded is XYZ, and when backend is + 'numpy', the the axis is ZYX. The data will be transposed if the + backend is 'nifti'. Defaults to 'nifti'. + to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z. + Defaults to False. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + + def __init__(self, + with_seg=False, + decode_backend: str = 'numpy', + to_xyz: bool = False, + backend_args: Optional[dict] = None) -> None: # noqa + self.with_seg = with_seg + self.decode_backend = decode_backend + self.to_xyz = to_xyz + self.backend_args = backend_args.copy() if backend_args else None + + def transform(self, results: Dict) -> Dict: + """Functions to load image. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded image and meta information. 
+ """ + data_bytes = fileio.get(results['img_path'], self.backend_args) + data = datafrombytes(data_bytes, backend=self.decode_backend) + # img is 4D data (N, X, Y, Z), N is the number of protocol + img = data[:-1, :] + + if self.decode_backend == 'nifti': + img = img.transpose(0, 3, 2, 1) + + if self.to_xyz: + img = img.transpose(0, 3, 2, 1) + + results['img'] = img + results['img_shape'] = img.shape[1:] + results['ori_shape'] = img.shape[1:] + + if self.with_seg: + gt_seg_map = data[-1, :] + if self.decode_backend == 'nifti': + gt_seg_map = gt_seg_map.transpose(2, 1, 0) + + if self.to_xyz: + gt_seg_map = gt_seg_map.transpose(2, 1, 0) + results['gt_seg_map'] = gt_seg_map + return results + + def __repr__(self) -> str: + repr_str = (f'{self.__class__.__name__}(' + f'with_seg={self.with_seg}, ' + f"decode_backend='{self.decode_backend}', " + f'to_xyz={self.to_xyz}, ' + f'backend_args={self.backend_args})') + return repr_str + + +@TRANSFORMS.register_module() +class InferencerLoader(BaseTransform): + """Load an image from ``results['img']``. + + Similar with :obj:`LoadImageFromFile`, but the image has been loaded as + :obj:`np.ndarray` in ``results['img']``. Can be used when loading image + from webcam. + + Required Keys: + + - img + + Modified Keys: + + - img + - img_path + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + """ + + def __init__(self, **kwargs) -> None: + super().__init__() + self.from_file = TRANSFORMS.build( + dict(type='LoadImageFromFile', **kwargs)) + self.from_ndarray = TRANSFORMS.build( + dict(type='LoadImageFromNDArray', **kwargs)) + + def transform(self, single_input: Union[str, np.ndarray, dict]) -> dict: + """Transform function to add image meta information. + + Args: + results (dict): Result dict with Webcam read image in + ``results['img']``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + if isinstance(single_input, str): + inputs = dict(img_path=single_input) + elif isinstance(single_input, np.ndarray): + inputs = dict(img=single_input) + elif isinstance(single_input, dict): + inputs = single_input + else: + raise NotImplementedError + + if 'img' in inputs: + return self.from_ndarray(inputs) + return self.from_file(inputs) + + +@TRANSFORMS.register_module() +class LoadSingleRSImageFromFile(BaseTransform): + """Load a Remote Sensing mage from file. + + Required Keys: + + - img_path + + Modified Keys: + + - img + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is a float64 array. + Defaults to True. + """ + + def __init__(self, to_float32: bool = True): + self.to_float32 = to_float32 + + if gdal is None: + raise RuntimeError('gdal is not installed') + + def transform(self, results: Dict) -> Dict: + """Functions to load image. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded image and meta information. 
+ """ + + filename = results['img_path'] + ds = gdal.Open(filename) + if ds is None: + raise Exception(f'Unable to open file: {filename}') + img = np.einsum('ijk->jki', ds.ReadAsArray()) + + if self.to_float32: + img = img.astype(np.float32) + + results['img'] = img + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'to_float32={self.to_float32})') + return repr_str + + +@TRANSFORMS.register_module() +class LoadMultipleRSImageFromFile(BaseTransform): + """Load two Remote Sensing mage from file. + + Required Keys: + + - img_path + - img_path2 + + Modified Keys: + + - img + - img2 + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is a float64 array. + Defaults to True. + """ + + def __init__(self, to_float32: bool = True): + if gdal is None: + raise RuntimeError('gdal is not installed') + self.to_float32 = to_float32 + + def transform(self, results: Dict) -> Dict: + """Functions to load image. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + filename = results['img_path'] + filename2 = results['img_path2'] + + ds = gdal.Open(filename) + ds2 = gdal.Open(filename2) + + if ds is None: + raise Exception(f'Unable to open file: {filename}') + if ds2 is None: + raise Exception(f'Unable to open file: {filename2}') + + img = np.einsum('ijk->jki', ds.ReadAsArray()) + img2 = np.einsum('ijk->jki', ds2.ReadAsArray()) + + if self.to_float32: + img = img.astype(np.float32) + img2 = img2.astype(np.float32) + + if img.shape != img2.shape: + raise Exception(f'Image shapes do not match:' + f' {img.shape} vs {img2.shape}') + + results['img'] = img + results['img2'] = img2 + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'to_float32={self.to_float32})') + return repr_str + + +@TRANSFORMS.register_module() +class LoadDepthAnnotation(BaseTransform): + """Load ``depth_map`` annotation provided by depth estimation dataset. + + The annotation format is as the following: + + .. code-block:: python + + { + 'gt_depth_map': np.ndarray [Y, X] + } + + Required Keys: + + - seg_depth_path + + Added Keys: + + - gt_depth_map (np.ndarray): Depth map with shape (Y, X) by + default, and data type is float32 if set to_float32 = True. + - depth_rescale_factor (float): The rescale factor of depth map, which + can be used to recover the original value of depth map. + + Args: + decode_backend (str): The data decoding backend type. Options are + 'numpy', 'nifti', and 'cv2'. Defaults to 'cv2'. + to_float32 (bool): Whether to convert the loaded depth map to a float32 + numpy array. If set to False, the loaded image is an uint16 array. + Defaults to True. + depth_rescale_factor (float): Factor to rescale the depth value to + limit the range. Defaults to 1.0. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See :class:`mmengine.fileio` for details. + Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. 
+ """ + + def __init__(self, + decode_backend: str = 'cv2', + to_float32: bool = True, + depth_rescale_factor: float = 1.0, + backend_args: Optional[dict] = None) -> None: + super().__init__() + self.decode_backend = decode_backend + self.to_float32 = to_float32 + self.depth_rescale_factor = depth_rescale_factor + self.backend_args = backend_args.copy() if backend_args else None + + def transform(self, results: Dict) -> Dict: + """Functions to load depth map. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded depth map. + """ + data_bytes = fileio.get(results['depth_map_path'], self.backend_args) + gt_depth_map = datafrombytes(data_bytes, backend=self.decode_backend) + + if self.to_float32: + gt_depth_map = gt_depth_map.astype(np.float32) + + gt_depth_map *= self.depth_rescale_factor + results['gt_depth_map'] = gt_depth_map + results['seg_fields'].append('gt_depth_map') + results['depth_rescale_factor'] = self.depth_rescale_factor + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f"decode_backend='{self.decode_backend}', " + f'to_float32={self.to_float32}, ' + f'backend_args={self.backend_args})') + return repr_str diff --git a/mmseg/datasets/transforms/transforms.py b/mmseg/datasets/transforms/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..082ae5b4401dce3b90bab888bd754ee164094b88 --- /dev/null +++ b/mmseg/datasets/transforms/transforms.py @@ -0,0 +1,2514 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import inspect +import warnings +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import cv2 +import mmcv +import mmengine +import numpy as np +from mmcv.transforms import RandomFlip as MMCV_RandomFlip +from mmcv.transforms import Resize as MMCV_Resize +from mmcv.transforms.base import BaseTransform +from mmcv.transforms.utils import cache_randomness +from mmengine.utils import is_tuple_of +from numpy import random +from scipy.ndimage import gaussian_filter + +from mmseg.datasets.dataset_wrappers import MultiImageMixDataset +from mmseg.registry import TRANSFORMS + +try: + import albumentations + from albumentations import Compose + ALBU_INSTALLED = True +except ImportError: + albumentations = None + Compose = None + ALBU_INSTALLED = False + + +@TRANSFORMS.register_module() +class ResizeToMultiple(BaseTransform): + """Resize images & seg to multiple of divisor. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - img_shape + - pad_shape + + Args: + size_divisor (int): images and gt seg maps need to resize to multiple + of size_divisor. Default: 32. + interpolation (str, optional): The interpolation mode of image resize. + Default: None + """ + + def __init__(self, size_divisor=32, interpolation=None): + self.size_divisor = size_divisor + self.interpolation = interpolation + + def transform(self, results: dict) -> dict: + """Call function to resize images, semantic segmentation map to + multiple of size divisor. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img_shape', 'pad_shape' keys are updated. + """ + # Align image to multiple of size divisor. 
+ img = results['img'] + img = mmcv.imresize_to_multiple( + img, + self.size_divisor, + scale_factor=1, + interpolation=self.interpolation + if self.interpolation else 'bilinear') + + results['img'] = img + results['img_shape'] = img.shape[:2] + results['pad_shape'] = img.shape[:2] + + # Align segmentation map to multiple of size divisor. + for key in results.get('seg_fields', []): + gt_seg = results[key] + gt_seg = mmcv.imresize_to_multiple( + gt_seg, + self.size_divisor, + scale_factor=1, + interpolation='nearest') + results[key] = gt_seg + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(size_divisor={self.size_divisor}, ' + f'interpolation={self.interpolation})') + return repr_str + + +@TRANSFORMS.register_module() +class Rerange(BaseTransform): + """Rerange the image pixel value. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + min_value (float or int): Minimum value of the reranged image. + Default: 0. + max_value (float or int): Maximum value of the reranged image. + Default: 255. + """ + + def __init__(self, min_value=0, max_value=255): + assert isinstance(min_value, float) or isinstance(min_value, int) + assert isinstance(max_value, float) or isinstance(max_value, int) + assert min_value < max_value + self.min_value = min_value + self.max_value = max_value + + def transform(self, results: dict) -> dict: + """Call function to rerange images. + + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Reranged results. + """ + + img = results['img'] + img_min_value = np.min(img) + img_max_value = np.max(img) + + assert img_min_value < img_max_value + # rerange to [0, 1] + img = (img - img_min_value) / (img_max_value - img_min_value) + # rerange to [min_value, max_value] + img = img * (self.max_value - self.min_value) + self.min_value + results['img'] = img + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(min_value={self.min_value}, max_value={self.max_value})' + return repr_str + + +@TRANSFORMS.register_module() +class CLAHE(BaseTransform): + """Use CLAHE method to process the image. + + See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J]. + Graphics Gems, 1994:474-485.` for more information. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + clip_limit (float): Threshold for contrast limiting. Default: 40.0. + tile_grid_size (tuple[int]): Size of grid for histogram equalization. + Input image will be divided into equally sized rectangular tiles. + It defines the number of tiles in row and column. Default: (8, 8). + """ + + def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)): + assert isinstance(clip_limit, (float, int)) + self.clip_limit = clip_limit + assert is_tuple_of(tile_grid_size, int) + assert len(tile_grid_size) == 2 + self.tile_grid_size = tile_grid_size + + def transform(self, results: dict) -> dict: + """Call function to Use CLAHE method process images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. 
+ """ + + for i in range(results['img'].shape[2]): + results['img'][:, :, i] = mmcv.clahe( + np.array(results['img'][:, :, i], dtype=np.uint8), + self.clip_limit, self.tile_grid_size) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(clip_limit={self.clip_limit}, ' \ + f'tile_grid_size={self.tile_grid_size})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomCrop(BaseTransform): + """Random crop the image & seg. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - img_shape + - gt_seg_map + + + Args: + crop_size (Union[int, Tuple[int, int]]): Expected size after cropping + with the format of (h, w). If set to an integer, then cropping + width and height are equal to this integer. + cat_max_ratio (float): The maximum ratio that single category could + occupy. + ignore_index (int): The label index to be ignored. Default: 255 + """ + + def __init__(self, + crop_size: Union[int, Tuple[int, int]], + cat_max_ratio: float = 1., + ignore_index: int = 255): + super().__init__() + assert isinstance(crop_size, int) or ( + isinstance(crop_size, tuple) and len(crop_size) == 2 + ), 'The expected crop_size is an integer, or a tuple containing two ' + 'intergers' + + if isinstance(crop_size, int): + crop_size = (crop_size, crop_size) + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size + self.cat_max_ratio = cat_max_ratio + self.ignore_index = ignore_index + + @cache_randomness + def crop_bbox(self, results: dict) -> tuple: + """get a crop bounding box. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + tuple: Coordinates of the cropped image. + """ + + def generate_crop_bbox(img: np.ndarray) -> tuple: + """Randomly get a crop bounding box. + + Args: + img (np.ndarray): Original input image. + + Returns: + tuple: Coordinates of the cropped image. + """ + + margin_h = max(img.shape[0] - self.crop_size[0], 0) + margin_w = max(img.shape[1] - self.crop_size[1], 0) + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] + + return crop_y1, crop_y2, crop_x1, crop_x2 + + img = results['img'] + crop_bbox = generate_crop_bbox(img) + if self.cat_max_ratio < 1.: + # Repeat 10 times + for _ in range(10): + seg_temp = self.crop(results['gt_seg_map'], crop_bbox) + labels, cnt = np.unique(seg_temp, return_counts=True) + cnt = cnt[labels != self.ignore_index] + if len(cnt) > 1 and np.max(cnt) / np.sum( + cnt) < self.cat_max_ratio: + break + crop_bbox = generate_crop_bbox(img) + + return crop_bbox + + def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray: + """Crop from ``img`` + + Args: + img (np.ndarray): Original input image. + crop_bbox (tuple): Coordinates of the cropped image. + + Returns: + np.ndarray: The cropped image. + """ + + crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox + img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] + return img + + def transform(self, results: dict) -> dict: + """Transform function to randomly crop images, semantic segmentation + maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. 
+ """ + + img = results['img'] + crop_bbox = self.crop_bbox(results) + + # crop the image + img = self.crop(img, crop_bbox) + + # crop semantic seg + for key in results.get('seg_fields', []): + results[key] = self.crop(results[key], crop_bbox) + + results['img'] = img + results['img_shape'] = img.shape[:2] + return results + + def __repr__(self): + return self.__class__.__name__ + f'(crop_size={self.crop_size})' + + +@TRANSFORMS.register_module() +class RandomRotate(BaseTransform): + """Rotate the image & seg. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + + Args: + prob (float): The rotation probability. + degree (float, tuple[float]): Range of degrees to select from. If + degree is a number instead of tuple like (min, max), + the range of degree will be (``-degree``, ``+degree``) + pad_val (float, optional): Padding value of image. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. + center (tuple[float], optional): Center point (w, h) of the rotation in + the source image. If not specified, the center of the image will be + used. Default: None. + auto_bound (bool): Whether to adjust the image size to cover the whole + rotated image. Default: False + """ + + def __init__(self, + prob, + degree, + pad_val=0, + seg_pad_val=255, + center=None, + auto_bound=False): + self.prob = prob + assert prob >= 0 and prob <= 1 + if isinstance(degree, (float, int)): + assert degree > 0, f'degree {degree} should be positive' + self.degree = (-degree, degree) + else: + self.degree = degree + assert len(self.degree) == 2, f'degree {self.degree} should be a ' \ + f'tuple of (min, max)' + self.pal_val = pad_val + self.seg_pad_val = seg_pad_val + self.center = center + self.auto_bound = auto_bound + + @cache_randomness + def generate_degree(self): + return np.random.rand() < self.prob, np.random.uniform( + min(*self.degree), max(*self.degree)) + + def transform(self, results: dict) -> dict: + """Call function to rotate image, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Rotated results. + """ + + rotate, degree = self.generate_degree() + if rotate: + # rotate image + results['img'] = mmcv.imrotate( + results['img'], + angle=degree, + border_value=self.pal_val, + center=self.center, + auto_bound=self.auto_bound) + + # rotate segs + for key in results.get('seg_fields', []): + results[key] = mmcv.imrotate( + results[key], + angle=degree, + border_value=self.seg_pad_val, + center=self.center, + auto_bound=self.auto_bound, + interpolation='nearest') + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' \ + f'degree={self.degree}, ' \ + f'pad_val={self.pal_val}, ' \ + f'seg_pad_val={self.seg_pad_val}, ' \ + f'center={self.center}, ' \ + f'auto_bound={self.auto_bound})' + return repr_str + + +@TRANSFORMS.register_module() +class RGB2Gray(BaseTransform): + """Convert RGB image to grayscale image. + + Required Keys: + + - img + + Modified Keys: + + - img + - img_shape + + This transform calculate the weighted mean of input image channels with + ``weights`` and then expand the channels to ``out_channels``. When + ``out_channels`` is None, the number of output channels is the same as + input channels. + + Args: + out_channels (int): Expected number of output channels after + transforming. Default: None. + weights (tuple[float]): The weights to calculate the weighted mean. + Default: (0.299, 0.587, 0.114). 
+ """ + + def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)): + assert out_channels is None or out_channels > 0 + self.out_channels = out_channels + assert isinstance(weights, tuple) + for item in weights: + assert isinstance(item, (float, int)) + self.weights = weights + + def transform(self, results: dict) -> dict: + """Call function to convert RGB image to grayscale image. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with grayscale image. + """ + img = results['img'] + assert len(img.shape) == 3 + assert img.shape[2] == len(self.weights) + weights = np.array(self.weights).reshape((1, 1, -1)) + img = (img * weights).sum(2, keepdims=True) + if self.out_channels is None: + img = img.repeat(weights.shape[2], axis=2) + else: + img = img.repeat(self.out_channels, axis=2) + + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(out_channels={self.out_channels}, ' \ + f'weights={self.weights})' + return repr_str + + +@TRANSFORMS.register_module() +class AdjustGamma(BaseTransform): + """Using gamma correction to process the image. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + gamma (float or int): Gamma value used in gamma correction. + Default: 1.0. + """ + + def __init__(self, gamma=1.0): + assert isinstance(gamma, float) or isinstance(gamma, int) + assert gamma > 0 + self.gamma = gamma + inv_gamma = 1.0 / gamma + self.table = np.array([(i / 255.0)**inv_gamma * 255 + for i in np.arange(256)]).astype('uint8') + + def transform(self, results: dict) -> dict: + """Call function to process the image with gamma correction. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + + results['img'] = mmcv.lut_transform( + np.array(results['img'], dtype=np.uint8), self.table) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(gamma={self.gamma})' + + +@TRANSFORMS.register_module() +class SegRescale(BaseTransform): + """Rescale semantic segmentation maps. + + Required Keys: + + - gt_seg_map + + Modified Keys: + + - gt_seg_map + + Args: + scale_factor (float): The scale factor of the final output. + """ + + def __init__(self, scale_factor=1): + self.scale_factor = scale_factor + + def transform(self, results: dict) -> dict: + """Call function to scale the semantic segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with semantic segmentation map scaled. + """ + for key in results.get('seg_fields', []): + if self.scale_factor != 1: + results[key] = mmcv.imrescale( + results[key], self.scale_factor, interpolation='nearest') + return results + + def __repr__(self): + return self.__class__.__name__ + f'(scale_factor={self.scale_factor})' + + +@TRANSFORMS.register_module() +class PhotoMetricDistortion(BaseTransform): + """Apply photometric distortion to image sequentially, every transformation + is applied with a probability of 0.5. The position of random contrast is in + second or second to last. + + 1. random brightness + 2. random contrast (mode 0) + 3. convert color from BGR to HSV + 4. random saturation + 5. random hue + 6. convert color from HSV to BGR + 7. random contrast (mode 1) + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + brightness_delta (int): delta of brightness. + contrast_range (tuple): range of contrast. 
+ saturation_range (tuple): range of saturation. + hue_delta (int): delta of hue. + """ + + def __init__(self, + brightness_delta: int = 32, + contrast_range: Sequence[float] = (0.5, 1.5), + saturation_range: Sequence[float] = (0.5, 1.5), + hue_delta: int = 18): + self.brightness_delta = brightness_delta + self.contrast_lower, self.contrast_upper = contrast_range + self.saturation_lower, self.saturation_upper = saturation_range + self.hue_delta = hue_delta + + def convert(self, + img: np.ndarray, + alpha: int = 1, + beta: int = 0) -> np.ndarray: + """Multiple with alpha and add beat with clip. + + Args: + img (np.ndarray): The input image. + alpha (int): Image weights, change the contrast/saturation + of the image. Default: 1 + beta (int): Image bias, change the brightness of the + image. Default: 0 + + Returns: + np.ndarray: The transformed image. + """ + + img = img.astype(np.float32) * alpha + beta + img = np.clip(img, 0, 255) + return img.astype(np.uint8) + + def brightness(self, img: np.ndarray) -> np.ndarray: + """Brightness distortion. + + Args: + img (np.ndarray): The input image. + Returns: + np.ndarray: Image after brightness change. + """ + + if random.randint(2): + return self.convert( + img, + beta=random.uniform(-self.brightness_delta, + self.brightness_delta)) + return img + + def contrast(self, img: np.ndarray) -> np.ndarray: + """Contrast distortion. + + Args: + img (np.ndarray): The input image. + Returns: + np.ndarray: Image after contrast change. + """ + + if random.randint(2): + return self.convert( + img, + alpha=random.uniform(self.contrast_lower, self.contrast_upper)) + return img + + def saturation(self, img: np.ndarray) -> np.ndarray: + """Saturation distortion. + + Args: + img (np.ndarray): The input image. + Returns: + np.ndarray: Image after saturation change. + """ + + if random.randint(2): + img = mmcv.bgr2hsv(img) + img[:, :, 1] = self.convert( + img[:, :, 1], + alpha=random.uniform(self.saturation_lower, + self.saturation_upper)) + img = mmcv.hsv2bgr(img) + return img + + def hue(self, img: np.ndarray) -> np.ndarray: + """Hue distortion. + + Args: + img (np.ndarray): The input image. + Returns: + np.ndarray: Image after hue change. + """ + + if random.randint(2): + img = mmcv.bgr2hsv(img) + img[:, :, + 0] = (img[:, :, 0].astype(int) + + random.randint(-self.hue_delta, self.hue_delta)) % 180 + img = mmcv.hsv2bgr(img) + return img + + def transform(self, results: dict) -> dict: + """Transform function to perform photometric distortion on images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images distorted. + """ + + img = results['img'] + # random brightness + img = self.brightness(img) + + # mode == 0 --> do random contrast first + # mode == 1 --> do random contrast last + mode = random.randint(2) + if mode == 1: + img = self.contrast(img) + + # random saturation + img = self.saturation(img) + + # random hue + img = self.hue(img) + + # random contrast + if mode == 0: + img = self.contrast(img) + + results['img'] = img + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(brightness_delta={self.brightness_delta}, ' + f'contrast_range=({self.contrast_lower}, ' + f'{self.contrast_upper}), ' + f'saturation_range=({self.saturation_lower}, ' + f'{self.saturation_upper}), ' + f'hue_delta={self.hue_delta})') + return repr_str + + +@TRANSFORMS.register_module() +class RandomCutOut(BaseTransform): + """CutOut operation. 
+ + Randomly drop some regions of image used in + `Cutout `_. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + + Args: + prob (float): cutout probability. + n_holes (int | tuple[int, int]): Number of regions to be dropped. + If it is given as a list, number of holes will be randomly + selected from the closed interval [`n_holes[0]`, `n_holes[1]`]. + cutout_shape (tuple[int, int] | list[tuple[int, int]]): The candidate + shape of dropped regions. It can be `tuple[int, int]` to use a + fixed cutout shape, or `list[tuple[int, int]]` to randomly choose + shape from the list. + cutout_ratio (tuple[float, float] | list[tuple[float, float]]): The + candidate ratio of dropped regions. It can be `tuple[float, float]` + to use a fixed ratio or `list[tuple[float, float]]` to randomly + choose ratio from the list. Please note that `cutout_shape` + and `cutout_ratio` cannot be both given at the same time. + fill_in (tuple[float, float, float] | tuple[int, int, int]): The value + of pixel to fill in the dropped regions. Default: (0, 0, 0). + seg_fill_in (int): The labels of pixel to fill in the dropped regions. + If seg_fill_in is None, skip. Default: None. + """ + + def __init__(self, + prob, + n_holes, + cutout_shape=None, + cutout_ratio=None, + fill_in=(0, 0, 0), + seg_fill_in=None): + + assert 0 <= prob and prob <= 1 + assert (cutout_shape is None) ^ (cutout_ratio is None), \ + 'Either cutout_shape or cutout_ratio should be specified.' + assert (isinstance(cutout_shape, (list, tuple)) + or isinstance(cutout_ratio, (list, tuple))) + if isinstance(n_holes, tuple): + assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1] + else: + n_holes = (n_holes, n_holes) + if seg_fill_in is not None: + assert (isinstance(seg_fill_in, int) and 0 <= seg_fill_in + and seg_fill_in <= 255) + self.prob = prob + self.n_holes = n_holes + self.fill_in = fill_in + self.seg_fill_in = seg_fill_in + self.with_ratio = cutout_ratio is not None + self.candidates = cutout_ratio if self.with_ratio else cutout_shape + if not isinstance(self.candidates, list): + self.candidates = [self.candidates] + + @cache_randomness + def do_cutout(self): + return np.random.rand() < self.prob + + @cache_randomness + def generate_patches(self, results): + cutout = self.do_cutout() + + h, w, _ = results['img'].shape + if cutout: + n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1) + else: + n_holes = 0 + x1_lst = [] + y1_lst = [] + index_lst = [] + for _ in range(n_holes): + x1_lst.append(np.random.randint(0, w)) + y1_lst.append(np.random.randint(0, h)) + index_lst.append(np.random.randint(0, len(self.candidates))) + return cutout, n_holes, x1_lst, y1_lst, index_lst + + def transform(self, results: dict) -> dict: + """Call function to drop some regions of image.""" + cutout, n_holes, x1_lst, y1_lst, index_lst = self.generate_patches( + results) + if cutout: + h, w, c = results['img'].shape + for i in range(n_holes): + x1 = x1_lst[i] + y1 = y1_lst[i] + index = index_lst[i] + if not self.with_ratio: + cutout_w, cutout_h = self.candidates[index] + else: + cutout_w = int(self.candidates[index][0] * w) + cutout_h = int(self.candidates[index][1] * h) + + x2 = np.clip(x1 + cutout_w, 0, w) + y2 = np.clip(y1 + cutout_h, 0, h) + results['img'][y1:y2, x1:x2, :] = self.fill_in + + if self.seg_fill_in is not None: + for key in results.get('seg_fields', []): + results[key][y1:y2, x1:x2] = self.seg_fill_in + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += 
f'(prob={self.prob}, ' + repr_str += f'n_holes={self.n_holes}, ' + repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio + else f'cutout_shape={self.candidates}, ') + repr_str += f'fill_in={self.fill_in}, ' + repr_str += f'seg_fill_in={self.seg_fill_in})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomRotFlip(BaseTransform): + """Rotate and flip the image & seg or just rotate the image & seg. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + + Args: + rotate_prob (float): The probability of rotate image. + flip_prob (float): The probability of rotate&flip image. + degree (float, tuple[float]): Range of degrees to select from. If + degree is a number instead of tuple like (min, max), + the range of degree will be (``-degree``, ``+degree``) + """ + + def __init__(self, rotate_prob=0.5, flip_prob=0.5, degree=(-20, 20)): + self.rotate_prob = rotate_prob + self.flip_prob = flip_prob + assert 0 <= rotate_prob <= 1 and 0 <= flip_prob <= 1 + if isinstance(degree, (float, int)): + assert degree > 0, f'degree {degree} should be positive' + self.degree = (-degree, degree) + else: + self.degree = degree + assert len(self.degree) == 2, f'degree {self.degree} should be a ' \ + f'tuple of (min, max)' + + def random_rot_flip(self, results: dict) -> dict: + k = np.random.randint(0, 4) + results['img'] = np.rot90(results['img'], k) + for key in results.get('seg_fields', []): + results[key] = np.rot90(results[key], k) + axis = np.random.randint(0, 2) + results['img'] = np.flip(results['img'], axis=axis).copy() + for key in results.get('seg_fields', []): + results[key] = np.flip(results[key], axis=axis).copy() + return results + + def random_rotate(self, results: dict) -> dict: + angle = np.random.uniform(min(*self.degree), max(*self.degree)) + results['img'] = mmcv.imrotate(results['img'], angle=angle) + for key in results.get('seg_fields', []): + results[key] = mmcv.imrotate(results[key], angle=angle) + return results + + def transform(self, results: dict) -> dict: + """Call function to rotate or rotate & flip image, semantic + segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Rotated or rotated & flipped results. + """ + rotate_flag = 0 + if random.random() < self.rotate_prob: + results = self.random_rotate(results) + rotate_flag = 1 + if random.random() < self.flip_prob and rotate_flag == 0: + results = self.random_rot_flip(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(rotate_prob={self.rotate_prob}, ' \ + f'flip_prob={self.flip_prob}, ' \ + f'degree={self.degree})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomFlip(MMCV_RandomFlip): + """Flip the image & bbox & segmentation map. Added or Updated + keys: flip, flip_direction, img, gt_bboxes, gt_seg_map, and gt_depth_map. + There are 3 flip modes: + + - ``prob`` is float, ``direction`` is string: the image will be + ``direction``ly flipped with probability of ``prob`` . + E.g., ``prob=0.5``, ``direction='horizontal'``, + then image will be horizontally flipped with probability of 0.5. + + - ``prob`` is float, ``direction`` is list of string: the image will + be ``direction[i]``ly flipped with probability of + ``prob/len(direction)``. + E.g., ``prob=0.5``, ``direction=['horizontal', 'vertical']``, + then image will be horizontally flipped with probability of 0.25, + vertically with probability of 0.25. 
+ + - ``prob`` is list of float, ``direction`` is list of string: + given ``len(prob) == len(direction)``, the image will + be ``direction[i]``ly flipped with probability of ``prob[i]``. + E.g., ``prob=[0.3, 0.5]``, ``direction=['horizontal', + 'vertical']``, then image will be horizontally flipped with + probability of 0.3, vertically with probability of 0.5. + + Required Keys: + + - img + - gt_bboxes (optional) + - gt_seg_map (optional) + - gt_depth_map (optional) + + Modified Keys: + + - img + - gt_bboxes (optional) + - gt_seg_map (optional) + - gt_depth_map (optional) + + Added Keys: + + - flip + - flip_direction + - swap_seg_labels (optional) + + Args: + prob (float | list[float], optional): The flipping probability. + Defaults to None. + direction(str | list[str]): The flipping direction. Options + If input is a list, the length must equal ``prob``. Each + element in ``prob`` indicates the flip probability of + corresponding direction. Defaults to 'horizontal'. + swap_seg_labels (list, optional): The label pair need to be swapped + for ground truth, like 'left arm' and 'right arm' need to be + swapped after horizontal flipping. For example, ``[(1, 5)]``, + where 1/5 is the label of the left/right arm. Defaults to None. + """ + + def _flip(self, results: dict) -> None: + """Flip images, bounding boxes and semantic segmentation map.""" + # flip image + results['img'] = mmcv.imflip( + results['img'], direction=results['flip_direction']) + + img_shape = results['img'].shape[:2] + + # flip bboxes + if results.get('gt_bboxes', None) is not None: + results['gt_bboxes'] = self._flip_bbox(results['gt_bboxes'], + img_shape, + results['flip_direction']) + + # flip seg map + for key in results.get('seg_fields', []): + if results.get(key, None) is not None: + results[key] = self._flip_seg_map( + results[key], direction=results['flip_direction']).copy() + results['swap_seg_labels'] = self.swap_seg_labels + + +@TRANSFORMS.register_module() +class Resize(MMCV_Resize): + """Resize images & seg & depth map. + + This transform resizes the input image according to ``scale`` or + ``scale_factor``. Seg map, depth map and other relative annotations are + then resized with the same scale factor. + if ``scale`` and ``scale_factor`` are both set, it will use ``scale`` to + resize. + + Required Keys: + + - img + - gt_seg_map (optional) + - gt_depth_map (optional) + + Modified Keys: + + - img + - gt_seg_map + - gt_depth_map + + Added Keys: + + - scale + - scale_factor + - keep_ratio + + Args: + scale (int or tuple): Images scales for resizing. Defaults to None + scale_factor (float or tuple[float]): Scale factors for resizing. + Defaults to None. + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. Defaults to False. + clip_object_border (bool): Whether to clip the objects + outside the border of the image. In some dataset like MOT17, the gt + bboxes are allowed to cross the border of images. Therefore, we + don't need to clip the gt bboxes in these cases. Defaults to True. + backend (str): Image resize backend, choices are 'cv2' and 'pillow'. + These two backends generates slightly different results. Defaults + to 'cv2'. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. Defaults + to 'bilinear'. 
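+
+ Example (an illustrative sketch; the input and target sizes are
+ assumptions):
+
+ .. code-block:: python
+
+ import numpy as np
+ resize = Resize(scale=(1024, 512), keep_ratio=True)
+ results = dict(img=np.zeros((256, 512, 3), dtype=np.uint8))
+ results = resize(results) # img -> (512, 1024, 3), ratio preserved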
+ """ + + def _resize_seg(self, results: dict) -> None: + """Resize semantic segmentation map with ``results['scale']``.""" + for seg_key in results.get('seg_fields', []): + if results.get(seg_key, None) is not None: + if self.keep_ratio: + gt_seg = mmcv.imrescale( + results[seg_key], + results['scale'], + interpolation='nearest', + backend=self.backend) + else: + gt_seg = mmcv.imresize( + results[seg_key], + results['scale'], + interpolation='nearest', + backend=self.backend) + results[seg_key] = gt_seg + + +@TRANSFORMS.register_module() +class RandomMosaic(BaseTransform): + """Mosaic augmentation. Given 4 images, mosaic transform combines them into + one output image. The output image is composed of the parts from each sub- + image. + + .. code:: text + + mosaic transform + center_x + +------------------------------+ + | pad | pad | + | +-----------+ | + | | | | + | | image1 |--------+ | + | | | | | + | | | image2 | | + center_y |----+-------------+-----------| + | | cropped | | + |pad | image3 | image4 | + | | | | + +----|-------------+-----------+ + | | + +-------------+ + + The mosaic transform steps are as follows: + 1. Choose the mosaic center as the intersections of 4 images + 2. Get the left top image according to the index, and randomly + sample another 3 images from the custom dataset. + 3. Sub image will be cropped if image is larger than mosaic patch + + Required Keys: + + - img + - gt_seg_map + - mix_results + + Modified Keys: + + - img + - img_shape + - ori_shape + - gt_seg_map + + Args: + prob (float): mosaic probability. + img_scale (Sequence[int]): Image size after mosaic pipeline of + a single image. The size of the output image is four times + that of a single image. The output image comprises 4 single images. + Default: (640, 640). + center_ratio_range (Sequence[float]): Center ratio range of mosaic + output. Default: (0.5, 1.5). + pad_val (int): Pad value. Default: 0. + seg_pad_val (int): Pad value of segmentation map. Default: 255. + """ + + def __init__(self, + prob, + img_scale=(640, 640), + center_ratio_range=(0.5, 1.5), + pad_val=0, + seg_pad_val=255): + assert 0 <= prob and prob <= 1 + assert isinstance(img_scale, tuple) + self.prob = prob + self.img_scale = img_scale + self.center_ratio_range = center_ratio_range + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + + @cache_randomness + def do_mosaic(self): + return np.random.rand() < self.prob + + def transform(self, results: dict) -> dict: + """Call function to make a mosaic of image. + + Args: + results (dict): Result dict. + + Returns: + dict: Result dict with mosaic transformed. + """ + mosaic = self.do_mosaic() + if mosaic: + results = self._mosaic_transform_img(results) + results = self._mosaic_transform_seg(results) + return results + + def get_indices(self, dataset: MultiImageMixDataset) -> list: + """Call function to collect indices. + + Args: + dataset (:obj:`MultiImageMixDataset`): The dataset. + + Returns: + list: indices. + """ + + indices = [random.randint(0, len(dataset)) for _ in range(3)] + return indices + + @cache_randomness + def generate_mosaic_center(self): + # mosaic center x, y + center_x = int( + random.uniform(*self.center_ratio_range) * self.img_scale[1]) + center_y = int( + random.uniform(*self.center_ratio_range) * self.img_scale[0]) + return center_x, center_y + + def _mosaic_transform_img(self, results: dict) -> dict: + """Mosaic transform function. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. 
+ """ + + assert 'mix_results' in results + if len(results['img'].shape) == 3: + c = results['img'].shape[2] + mosaic_img = np.full( + (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), c), + self.pad_val, + dtype=results['img'].dtype) + else: + mosaic_img = np.full( + (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)), + self.pad_val, + dtype=results['img'].dtype) + + # mosaic center x, y + self.center_x, self.center_y = self.generate_mosaic_center() + center_position = (self.center_x, self.center_y) + + loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right') + for i, loc in enumerate(loc_strs): + if loc == 'top_left': + result_patch = copy.deepcopy(results) + else: + result_patch = copy.deepcopy(results['mix_results'][i - 1]) + + img_i = result_patch['img'] + h_i, w_i = img_i.shape[:2] + # keep_ratio resize + scale_ratio_i = min(self.img_scale[0] / h_i, + self.img_scale[1] / w_i) + img_i = mmcv.imresize( + img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i))) + + # compute the combine parameters + paste_coord, crop_coord = self._mosaic_combine( + loc, center_position, img_i.shape[:2][::-1]) + x1_p, y1_p, x2_p, y2_p = paste_coord + x1_c, y1_c, x2_c, y2_c = crop_coord + + # crop and paste image + mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c] + + results['img'] = mosaic_img + results['img_shape'] = mosaic_img.shape + results['ori_shape'] = mosaic_img.shape + + return results + + def _mosaic_transform_seg(self, results: dict) -> dict: + """Mosaic transform function for label annotations. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. + """ + + assert 'mix_results' in results + for key in results.get('seg_fields', []): + mosaic_seg = np.full( + (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)), + self.seg_pad_val, + dtype=results[key].dtype) + + # mosaic center x, y + center_position = (self.center_x, self.center_y) + + loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right') + for i, loc in enumerate(loc_strs): + if loc == 'top_left': + result_patch = copy.deepcopy(results) + else: + result_patch = copy.deepcopy(results['mix_results'][i - 1]) + + gt_seg_i = result_patch[key] + h_i, w_i = gt_seg_i.shape[:2] + # keep_ratio resize + scale_ratio_i = min(self.img_scale[0] / h_i, + self.img_scale[1] / w_i) + gt_seg_i = mmcv.imresize( + gt_seg_i, + (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)), + interpolation='nearest') + + # compute the combine parameters + paste_coord, crop_coord = self._mosaic_combine( + loc, center_position, gt_seg_i.shape[:2][::-1]) + x1_p, y1_p, x2_p, y2_p = paste_coord + x1_c, y1_c, x2_c, y2_c = crop_coord + + # crop and paste image + mosaic_seg[y1_p:y2_p, x1_p:x2_p] = \ + gt_seg_i[y1_c:y2_c, x1_c:x2_c] + + results[key] = mosaic_seg + + return results + + def _mosaic_combine(self, loc: str, center_position_xy: Sequence[float], + img_shape_wh: Sequence[int]) -> tuple: + """Calculate global coordinate of mosaic image and local coordinate of + cropped sub-image. + + Args: + loc (str): Index for the sub-image, loc in ('top_left', + 'top_right', 'bottom_left', 'bottom_right'). + center_position_xy (Sequence[float]): Mixing center for 4 images, + (x, y). + img_shape_wh (Sequence[int]): Width and height of sub-image + + Returns: + tuple[tuple[float]]: Corresponding coordinate of pasting and + cropping + - paste_coord (tuple): paste corner coordinate in mosaic image. + - crop_coord (tuple): crop corner coordinate in mosaic image. 
+ """ + + assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right') + if loc == 'top_left': + # index0 to top left part of image + x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \ + max(center_position_xy[1] - img_shape_wh[1], 0), \ + center_position_xy[0], \ + center_position_xy[1] + crop_coord = img_shape_wh[0] - (x2 - x1), img_shape_wh[1] - ( + y2 - y1), img_shape_wh[0], img_shape_wh[1] + + elif loc == 'top_right': + # index1 to top right part of image + x1, y1, x2, y2 = center_position_xy[0], \ + max(center_position_xy[1] - img_shape_wh[1], 0), \ + min(center_position_xy[0] + img_shape_wh[0], + self.img_scale[1] * 2), \ + center_position_xy[1] + crop_coord = 0, img_shape_wh[1] - (y2 - y1), min( + img_shape_wh[0], x2 - x1), img_shape_wh[1] + + elif loc == 'bottom_left': + # index2 to bottom left part of image + x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \ + center_position_xy[1], \ + center_position_xy[0], \ + min(self.img_scale[0] * 2, center_position_xy[1] + + img_shape_wh[1]) + crop_coord = img_shape_wh[0] - (x2 - x1), 0, img_shape_wh[0], min( + y2 - y1, img_shape_wh[1]) + + else: + # index3 to bottom right part of image + x1, y1, x2, y2 = center_position_xy[0], \ + center_position_xy[1], \ + min(center_position_xy[0] + img_shape_wh[0], + self.img_scale[1] * 2), \ + min(self.img_scale[0] * 2, center_position_xy[1] + + img_shape_wh[1]) + crop_coord = 0, 0, min(img_shape_wh[0], + x2 - x1), min(y2 - y1, img_shape_wh[1]) + + paste_coord = x1, y1, x2, y2 + return paste_coord, crop_coord + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'img_scale={self.img_scale}, ' + repr_str += f'center_ratio_range={self.center_ratio_range}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'seg_pad_val={self.pad_val})' + return repr_str + + +@TRANSFORMS.register_module() +class GenerateEdge(BaseTransform): + """Generate Edge for CE2P approach. + + Edge will be used to calculate loss of + `CE2P `_. + + Modified from https://github.com/liutinglt/CE2P/blob/master/dataset/target_generation.py # noqa:E501 + + Required Keys: + + - img_shape + - gt_seg_map + + Added Keys: + - gt_edge_map (np.ndarray, uint8): The edge annotation generated from the + seg map by extracting border between different semantics. + + Args: + edge_width (int): The width of edge. Default to 3. + ignore_index (int): Index that will be ignored. Default to 255. + """ + + def __init__(self, edge_width: int = 3, ignore_index: int = 255) -> None: + super().__init__() + self.edge_width = edge_width + self.ignore_index = ignore_index + + def transform(self, results: Dict) -> Dict: + """Call function to generate edge from segmentation map. + + Args: + results (dict): Result dict. + + Returns: + dict: Result dict with edge mask. 
+ """ + h, w = results['img_shape'] + edge = np.zeros((h, w), dtype=np.uint8) + seg_map = results['gt_seg_map'] + + # down + edge_down = edge[1:h, :] + edge_down[(seg_map[1:h, :] != seg_map[:h - 1, :]) + & (seg_map[1:h, :] != self.ignore_index) & + (seg_map[:h - 1, :] != self.ignore_index)] = 1 + # left + edge_left = edge[:, :w - 1] + edge_left[(seg_map[:, :w - 1] != seg_map[:, 1:w]) + & (seg_map[:, :w - 1] != self.ignore_index) & + (seg_map[:, 1:w] != self.ignore_index)] = 1 + # up_left + edge_upleft = edge[:h - 1, :w - 1] + edge_upleft[(seg_map[:h - 1, :w - 1] != seg_map[1:h, 1:w]) + & (seg_map[:h - 1, :w - 1] != self.ignore_index) & + (seg_map[1:h, 1:w] != self.ignore_index)] = 1 + # up_right + edge_upright = edge[:h - 1, 1:w] + edge_upright[(seg_map[:h - 1, 1:w] != seg_map[1:h, :w - 1]) + & (seg_map[:h - 1, 1:w] != self.ignore_index) & + (seg_map[1:h, :w - 1] != self.ignore_index)] = 1 + + kernel = cv2.getStructuringElement(cv2.MORPH_RECT, + (self.edge_width, self.edge_width)) + edge = cv2.dilate(edge, kernel) + + results['gt_edge_map'] = edge + results['edge_width'] = self.edge_width + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'edge_width={self.edge_width}, ' + repr_str += f'ignore_index={self.ignore_index})' + return repr_str + + +@TRANSFORMS.register_module() +class ResizeShortestEdge(BaseTransform): + """Resize the image and mask while keeping the aspect ratio unchanged. + + Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/augmentation_impl.py#L130 # noqa:E501 + Copyright (c) Facebook, Inc. and its affiliates. + Licensed under the Apache-2.0 License + + This transform attempts to scale the shorter edge to the given + `scale`, as long as the longer edge does not exceed `max_size`. + If `max_size` is reached, then downscale so that the longer + edge does not exceed `max_size`. + + Required Keys: + + - img + - gt_seg_map (optional) + + Modified Keys: + + - img + - img_shape + - gt_seg_map (optional)) + + Added Keys: + + - scale + - scale_factor + - keep_ratio + + + Args: + scale (Union[int, Tuple[int, int]]): The target short edge length. + If it's tuple, will select the min value as the short edge length. + max_size (int): The maximum allowed longest edge length. + """ + + def __init__(self, scale: Union[int, Tuple[int, int]], + max_size: int) -> None: + super().__init__() + self.scale = scale + self.max_size = max_size + + # Create a empty Resize object + self.resize = TRANSFORMS.build({ + 'type': 'Resize', + 'scale': 0, + 'keep_ratio': True + }) + + def _get_output_shape(self, img, short_edge_length) -> Tuple[int, int]: + """Compute the target image shape with the given `short_edge_length`. + + Args: + img (np.ndarray): The input image. + short_edge_length (Union[int, Tuple[int, int]]): The target short + edge length. If it's tuple, will select the min value as the + short edge length. 
+ """ + h, w = img.shape[:2] + if isinstance(short_edge_length, int): + size = short_edge_length * 1.0 + elif isinstance(short_edge_length, tuple): + size = min(short_edge_length) * 1.0 + scale = size / min(h, w) + if h < w: + new_h, new_w = size, scale * w + else: + new_h, new_w = scale * h, size + + if max(new_h, new_w) > self.max_size: + scale = self.max_size * 1.0 / max(new_h, new_w) + new_h *= scale + new_w *= scale + + new_h = int(new_h + 0.5) + new_w = int(new_w + 0.5) + return (new_w, new_h) + + def transform(self, results: Dict) -> Dict: + self.resize.scale = self._get_output_shape(results['img'], self.scale) + return self.resize(results) + + +@TRANSFORMS.register_module() +class BioMedical3DRandomCrop(BaseTransform): + """Crop the input patch for medical image & segmentation mask. + + Required Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X), + N is the number of modalities, and data type is float32. + - gt_seg_map (np.ndarray, optional): Biomedical semantic segmentation mask + with shape (Z, Y, X). + + Modified Keys: + + - img + - img_shape + - gt_seg_map (optional) + + Args: + crop_shape (Union[int, Tuple[int, int, int]]): Expected size after + cropping with the format of (z, y, x). If set to an integer, + then cropping width and height are equal to this integer. + keep_foreground (bool): If keep_foreground is True, it will sample a + voxel of foreground classes randomly, and will take it as the + center of the crop bounding-box. Default to True. + """ + + def __init__(self, + crop_shape: Union[int, Tuple[int, int, int]], + keep_foreground: bool = True): + super().__init__() + assert isinstance(crop_shape, int) or ( + isinstance(crop_shape, tuple) and len(crop_shape) == 3 + ), 'The expected crop_shape is an integer, or a tuple containing ' + 'three integers' + + if isinstance(crop_shape, int): + crop_shape = (crop_shape, crop_shape, crop_shape) + assert crop_shape[0] > 0 and crop_shape[1] > 0 and crop_shape[2] > 0 + self.crop_shape = crop_shape + self.keep_foreground = keep_foreground + + def random_sample_location(self, seg_map: np.ndarray) -> dict: + """sample foreground voxel when keep_foreground is True. + + Args: + seg_map (np.ndarray): gt seg map. + + Returns: + dict: Coordinates of selected foreground voxel. + """ + num_samples = 10000 + # at least 1% of the class voxels need to be selected, + # otherwise it may be too sparse + min_percent_coverage = 0.01 + class_locs = {} + foreground_classes = [] + all_classes = np.unique(seg_map) + for c in all_classes: + if c == 0: + # to avoid the segmentation mask full of background 0 + # and the class_locs is just void dictionary {} when it return + # there add a void list for background 0. + class_locs[c] = [] + else: + all_locs = np.argwhere(seg_map == c) + target_num_samples = min(num_samples, len(all_locs)) + target_num_samples = max( + target_num_samples, + int(np.ceil(len(all_locs) * min_percent_coverage))) + + selected = all_locs[np.random.choice( + len(all_locs), target_num_samples, replace=False)] + class_locs[c] = selected + foreground_classes.append(c) + + selected_voxel = None + if len(foreground_classes) > 0: + selected_class = np.random.choice(foreground_classes) + voxels_of_that_class = class_locs[selected_class] + selected_voxel = voxels_of_that_class[np.random.choice( + len(voxels_of_that_class))] + + return selected_voxel + + def random_generate_crop_bbox(self, margin_z: int, margin_y: int, + margin_x: int) -> tuple: + """Randomly get a crop bounding box. 
+ + Args: + seg_map (np.ndarray): Ground truth segmentation map. + + Returns: + tuple: Coordinates of the cropped image. + """ + offset_z = np.random.randint(0, margin_z + 1) + offset_y = np.random.randint(0, margin_y + 1) + offset_x = np.random.randint(0, margin_x + 1) + crop_z1, crop_z2 = offset_z, offset_z + self.crop_shape[0] + crop_y1, crop_y2 = offset_y, offset_y + self.crop_shape[1] + crop_x1, crop_x2 = offset_x, offset_x + self.crop_shape[2] + + return crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2 + + def generate_margin(self, results: dict) -> tuple: + """Generate margin of crop bounding-box. + + If keep_foreground is True, it will sample a voxel of foreground + classes randomly, and will take it as the center of the bounding-box, + and return the margin between of the bounding-box and image. + If keep_foreground is False, it will return the difference from crop + shape and image shape. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + tuple: The margin for 3 dimensions of crop bounding-box and image. + """ + + seg_map = results['gt_seg_map'] + if self.keep_foreground: + selected_voxel = self.random_sample_location(seg_map) + if selected_voxel is None: + # this only happens if some image does not contain + # foreground voxels at all + warnings.warn(f'case does not contain any foreground classes' + f': {results["img_path"]}') + margin_z = max(seg_map.shape[0] - self.crop_shape[0], 0) + margin_y = max(seg_map.shape[1] - self.crop_shape[1], 0) + margin_x = max(seg_map.shape[2] - self.crop_shape[2], 0) + else: + margin_z = max(0, selected_voxel[0] - self.crop_shape[0] // 2) + margin_y = max(0, selected_voxel[1] - self.crop_shape[1] // 2) + margin_x = max(0, selected_voxel[2] - self.crop_shape[2] // 2) + margin_z = max( + 0, min(seg_map.shape[0] - self.crop_shape[0], margin_z)) + margin_y = max( + 0, min(seg_map.shape[1] - self.crop_shape[1], margin_y)) + margin_x = max( + 0, min(seg_map.shape[2] - self.crop_shape[2], margin_x)) + else: + margin_z = max(seg_map.shape[0] - self.crop_shape[0], 0) + margin_y = max(seg_map.shape[1] - self.crop_shape[1], 0) + margin_x = max(seg_map.shape[2] - self.crop_shape[2], 0) + + return margin_z, margin_y, margin_x + + def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray: + """Crop from ``img`` + + Args: + img (np.ndarray): Original input image. + crop_bbox (tuple): Coordinates of the cropped image. + + Returns: + np.ndarray: The cropped image. + """ + crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox + if len(img.shape) == 3: + # crop seg map + img = img[crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2] + else: + # crop image + assert len(img.shape) == 4 + img = img[:, crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2] + return img + + def transform(self, results: dict) -> dict: + """Transform function to randomly crop images, semantic segmentation + maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. 
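+
+ Example (a hypothetical sketch; the random volume, mask, and
+ ``img_path`` below are assumptions):
+
+ .. code-block:: python
+
+ import numpy as np
+ crop = BioMedical3DRandomCrop(crop_shape=(16, 64, 64))
+ results = dict(
+ img=np.random.rand(1, 32, 128, 128).astype(np.float32),
+ gt_seg_map=np.random.randint(0, 2, (32, 128, 128)),
+ img_path='case_0000.nii.gz')
+ results = crop(results) # img -> (1, 16, 64, 64), seg -> (16, 64, 64)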
+ """ + margin = self.generate_margin(results) + crop_bbox = self.random_generate_crop_bbox(*margin) + + # crop the image + img = results['img'] + results['img'] = self.crop(img, crop_bbox) + results['img_shape'] = results['img'].shape[1:] + + # crop semantic seg + seg_map = results['gt_seg_map'] + results['gt_seg_map'] = self.crop(seg_map, crop_bbox) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(crop_shape={self.crop_shape})' + + +@TRANSFORMS.register_module() +class BioMedicalGaussianNoise(BaseTransform): + """Add random Gaussian noise to image. + + Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L53 # noqa:E501 + + Copyright (c) German Cancer Research Center (DKFZ) + Licensed under the Apache License, Version 2.0 + + Required Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X), + N is the number of modalities, and data type is float32. + + Modified Keys: + + - img + + Args: + prob (float): Probability to add Gaussian noise for + each sample. Default to 0.1. + mean (float): Mean or “centre” of the distribution. Default to 0.0. + std (float): Standard deviation of distribution. Default to 0.1. + """ + + def __init__(self, + prob: float = 0.1, + mean: float = 0.0, + std: float = 0.1) -> None: + super().__init__() + assert 0.0 <= prob <= 1.0 and std >= 0.0 + self.prob = prob + self.mean = mean + self.std = std + + def transform(self, results: Dict) -> Dict: + """Call function to add random Gaussian noise to image. + + Args: + results (dict): Result dict. + + Returns: + dict: Result dict with random Gaussian noise. + """ + if np.random.rand() < self.prob: + rand_std = np.random.uniform(0, self.std) + noise = np.random.normal( + self.mean, rand_std, size=results['img'].shape) + # noise is float64 array, convert to the results['img'].dtype + noise = noise.astype(results['img'].dtype) + results['img'] = results['img'] + noise + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'mean={self.mean}, ' + repr_str += f'std={self.std})' + return repr_str + + +@TRANSFORMS.register_module() +class BioMedicalGaussianBlur(BaseTransform): + """Add Gaussian blur with random sigma to image. + + Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L81 # noqa:E501 + + Copyright (c) German Cancer Research Center (DKFZ) + Licensed under the Apache License, Version 2.0 + + Required Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X), + N is the number of modalities, and data type is float32. + + Modified Keys: + + - img + + Args: + sigma_range (Tuple[float, float]|float): range to randomly + select sigma value. Default to (0.5, 1.0). + prob (float): Probability to apply Gaussian blur + for each sample. Default to 0.2. + prob_per_channel (float): Probability to apply Gaussian blur + for each channel (axis N of the image). Default to 0.5. + different_sigma_per_channel (bool): whether to use different + sigma for each channel (axis N of the image). Default to True. + different_sigma_per_axis (bool): whether to use different + sigma for axis Z, X and Y of the image. Default to True. 
+ """ + + def __init__(self, + sigma_range: Tuple[float, float] = (0.5, 1.0), + prob: float = 0.2, + prob_per_channel: float = 0.5, + different_sigma_per_channel: bool = True, + different_sigma_per_axis: bool = True) -> None: + super().__init__() + assert 0.0 <= prob <= 1.0 + assert 0.0 <= prob_per_channel <= 1.0 + assert isinstance(sigma_range, Sequence) and len(sigma_range) == 2 + self.sigma_range = sigma_range + self.prob = prob + self.prob_per_channel = prob_per_channel + self.different_sigma_per_channel = different_sigma_per_channel + self.different_sigma_per_axis = different_sigma_per_axis + + def _get_valid_sigma(self, value_range) -> Tuple[float, ...]: + """Ensure the `value_range` to be either a single value or a sequence + of two values. If the `value_range` is a sequence, generate a random + value with `[value_range[0], value_range[1]]` based on uniform + sampling. + + Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/augmentations/utils.py#L625 # noqa:E501 + + Args: + value_range (tuple|list|float|int): the input value range + """ + if (isinstance(value_range, (list, tuple))): + if (value_range[0] == value_range[1]): + value = value_range[0] + else: + orig_type = type(value_range[0]) + value = np.random.uniform(value_range[0], value_range[1]) + value = orig_type(value) + return value + + def _gaussian_blur(self, data_sample: np.ndarray) -> np.ndarray: + """Random generate sigma and apply Gaussian Blur to the data + Args: + data_sample (np.ndarray): data sample with multiple modalities, + the data shape is (N, Z, Y, X) + """ + sigma = None + for c in range(data_sample.shape[0]): + if np.random.rand() < self.prob_per_channel: + # if no `sigma` is generated, generate one + # if `self.different_sigma_per_channel` is True, + # re-generate random sigma for each channel + if (sigma is None or self.different_sigma_per_channel): + if (not self.different_sigma_per_axis): + sigma = self._get_valid_sigma(self.sigma_range) + else: + sigma = [ + self._get_valid_sigma(self.sigma_range) + for _ in data_sample.shape[1:] + ] + # apply gaussian filter with `sigma` + data_sample[c] = gaussian_filter( + data_sample[c], sigma, order=0) + return data_sample + + def transform(self, results: Dict) -> Dict: + """Call function to add random Gaussian blur to image. + + Args: + results (dict): Result dict. + + Returns: + dict: Result dict with random Gaussian noise. + """ + if np.random.rand() < self.prob: + results['img'] = self._gaussian_blur(results['img']) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'prob_per_channel={self.prob_per_channel}, ' + repr_str += f'sigma_range={self.sigma_range}, ' + repr_str += 'different_sigma_per_channel=' \ + f'{self.different_sigma_per_channel}, ' + repr_str += 'different_sigma_per_axis=' \ + f'{self.different_sigma_per_axis})' + return repr_str + + +@TRANSFORMS.register_module() +class BioMedicalRandomGamma(BaseTransform): + """Using random gamma correction to process the biomedical image. + + Modified from + https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/color_transforms.py#L132 # noqa:E501 + With licence: Apache 2.0 + + Required Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X), + N is the number of modalities, and data type is float32. + + Modified Keys: + - img + + Args: + prob (float): The probability to perform this transform. Default: 0.5. 
+ gamma_range (Tuple[float]): Range of gamma values. Default: (0.5, 2). + invert_image (bool): Whether invert the image before applying gamma + augmentation. Default: False. + per_channel (bool): Whether perform the transform each channel + individually. Default: False + retain_stats (bool): Gamma transformation will alter the mean and std + of the data in the patch. If retain_stats=True, the data will be + transformed to match the mean and standard deviation before gamma + augmentation. Default: False. + """ + + def __init__(self, + prob: float = 0.5, + gamma_range: Tuple[float] = (0.5, 2), + invert_image: bool = False, + per_channel: bool = False, + retain_stats: bool = False): + assert 0 <= prob and prob <= 1 + assert isinstance(gamma_range, tuple) and len(gamma_range) == 2 + assert isinstance(invert_image, bool) + assert isinstance(per_channel, bool) + assert isinstance(retain_stats, bool) + self.prob = prob + self.gamma_range = gamma_range + self.invert_image = invert_image + self.per_channel = per_channel + self.retain_stats = retain_stats + + @cache_randomness + def _do_gamma(self): + """Whether do adjust gamma for image.""" + return np.random.rand() < self.prob + + def _adjust_gamma(self, img: np.array): + """Gamma adjustment for image. + + Args: + img (np.array): Input image before gamma adjust. + + Returns: + np.arrays: Image after gamma adjust. + """ + + if self.invert_image: + img = -img + + def _do_adjust(img): + if retain_stats_here: + img_mean = img.mean() + img_std = img.std() + if np.random.random() < 0.5 and self.gamma_range[0] < 1: + gamma = np.random.uniform(self.gamma_range[0], 1) + else: + gamma = np.random.uniform( + max(self.gamma_range[0], 1), self.gamma_range[1]) + img_min = img.min() + img_range = img.max() - img_min # range + img = np.power(((img - img_min) / float(img_range + 1e-7)), + gamma) * img_range + img_min + if retain_stats_here: + img = img - img.mean() + img = img / (img.std() + 1e-8) * img_std + img = img + img_mean + return img + + if not self.per_channel: + retain_stats_here = self.retain_stats + img = _do_adjust(img) + else: + for c in range(img.shape[0]): + img[c] = _do_adjust(img[c]) + if self.invert_image: + img = -img + return img + + def transform(self, results: dict) -> dict: + """Call function to perform random gamma correction + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with random gamma correction performed. + """ + do_gamma = self._do_gamma() + + if do_gamma: + results['img'] = self._adjust_gamma(results['img']) + else: + pass + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'gamma_range={self.gamma_range},' + repr_str += f'invert_image={self.invert_image},' + repr_str += f'per_channel={self.per_channel},' + repr_str += f'retain_stats={self.retain_stats}' + return repr_str + + +@TRANSFORMS.register_module() +class BioMedical3DPad(BaseTransform): + """Pad the biomedical 3d image & biomedical 3d semantic segmentation maps. + + Required Keys: + + - img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. + + Modified Keys: + + - img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. 
+
+ Added Keys:
+
+ - pad_shape (Tuple[int, int, int]): The padded shape.
+
+ Args:
+ pad_shape (Tuple[int, int, int]): Fixed padding size.
+ Expected padding shape (Z, Y, X).
+ pad_val (float): Padding value for biomedical image.
+ The padding mode is set to "constant". The value
+ to be filled in padding area. Default: 0.
+ seg_pad_val (int): Padding value for biomedical 3d semantic
+ segmentation maps. The padding mode is set to "constant".
+ The value to be filled in padding area. Default: 0.
+ """
+
+ def __init__(self,
+ pad_shape: Tuple[int, int, int],
+ pad_val: float = 0.,
+ seg_pad_val: int = 0) -> None:
+
+ # check pad_shape: it should be a (Z, Y, X) triple
+ assert pad_shape is not None
+ assert len(pad_shape) == 3
+
+ self.pad_shape = pad_shape
+ self.pad_val = pad_val
+ self.seg_pad_val = seg_pad_val
+
+ def _pad_img(self, results: dict) -> None:
+ """Pad images according to ``self.pad_shape``
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: The dict contains the padded image and shape
+ information.
+ """
+ padded_img = self._to_pad(
+ results['img'], pad_shape=self.pad_shape, pad_val=self.pad_val)
+
+ results['img'] = padded_img
+ results['pad_shape'] = padded_img.shape[1:]
+
+ def _pad_seg(self, results: dict) -> None:
+ """Pad semantic segmentation map according to ``self.pad_shape`` if
+ ``gt_seg_map`` is not None in results dict.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Update the padded gt seg map in dict.
+ """
+ if results.get('gt_seg_map', None) is not None:
+ pad_gt_seg = self._to_pad(
+ results['gt_seg_map'][None, ...],
+ pad_shape=results['pad_shape'],
+ pad_val=self.seg_pad_val)
+ # strip the channel dimension that was added before padding,
+ # restoring the documented (Z, Y, X) shape
+ results['gt_seg_map'] = pad_gt_seg[0]
+
+ @staticmethod
+ def _to_pad(img: np.ndarray,
+ pad_shape: Tuple[int, int, int],
+ pad_val: Union[int, float] = 0) -> np.ndarray:
+ """Pad the given 3d image to a certain shape with specified padding
+ value.
+
+ Args:
+ img (ndarray): Biomedical image with shape (N, Z, Y, X)
+ to be padded. N is the number of modalities.
+ pad_shape (Tuple[int,int,int]): Expected padding shape (Z, Y, X).
+ pad_val (float, int): Values to be filled in padding areas
+ and the padding_mode is set to 'constant'. Default: 0.
+
+ Returns:
+ ndarray: The padded image.
+ """
+ # compute pad width for each spatial axis (Z, Y, X)
+ d = max(pad_shape[0] - img.shape[1], 0)
+ pad_d = (d // 2, d - d // 2)
+ h = max(pad_shape[1] - img.shape[2], 0)
+ pad_h = (h // 2, h - h // 2)
+ w = max(pad_shape[2] - img.shape[3], 0)
+ pad_w = (w // 2, w - w // 2)
+
+ pad_list = [(0, 0), pad_d, pad_h, pad_w]
+
+ img = np.pad(img, pad_list, mode='constant', constant_values=pad_val)
+ return img
+
+ def transform(self, results: dict) -> dict:
+ """Call function to pad images, semantic segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Updated result dict.
+ """
+ self._pad_img(results)
+ self._pad_seg(results)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(pad_shape={self.pad_shape}, '
+ repr_str += f'pad_val={self.pad_val}, '
+ repr_str += f'seg_pad_val={self.seg_pad_val})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class BioMedical3DRandomFlip(BaseTransform):
+ """Flip biomedical 3D images and segmentations.
+ + Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/spatial_transforms.py # noqa:E501 + + Copyright 2021 Division of + Medical Image Computing, German Cancer Research Center (DKFZ) and Applied + Computer Vision Lab, Helmholtz Imaging Platform. + Licensed under the Apache-2.0 License. + + Required Keys: + + - img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. + + Modified Keys: + + - img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. + + Added Keys: + + - do_flip + - flip_axes + + Args: + prob (float): Flipping probability. + axes (Tuple[int, ...]): Flipping axes with order 'ZXY'. + swap_label_pairs (Optional[List[Tuple[int, int]]]): + The segmentation label pairs that are swapped when flipping. + """ + + def __init__(self, + prob: float, + axes: Tuple[int, ...], + swap_label_pairs: Optional[List[Tuple[int, int]]] = None): + self.prob = prob + self.axes = axes + self.swap_label_pairs = swap_label_pairs + assert prob >= 0 and prob <= 1 + if axes is not None: + assert max(axes) <= 2 + + @staticmethod + def _flip(img, direction: Tuple[bool, bool, bool]) -> np.ndarray: + if direction[0]: + img[:, :] = img[:, ::-1] + if direction[1]: + img[:, :, :] = img[:, :, ::-1] + if direction[2]: + img[:, :, :, :] = img[:, :, :, ::-1] + return img + + def _do_flip(self, img: np.ndarray) -> Tuple[bool, bool, bool]: + """Call function to determine which axis to flip. + + Args: + img (np.ndarry): Image or segmentation map array. + Returns: + tuple: Flip action, whether to flip on the z, x, and y axes. + """ + flip_c, flip_x, flip_y = False, False, False + if self.axes is not None: + flip_c = 0 in self.axes and np.random.rand() < self.prob + flip_x = 1 in self.axes and np.random.rand() < self.prob + if len(img.shape) == 4: + flip_y = 2 in self.axes and np.random.rand() < self.prob + return flip_c, flip_x, flip_y + + def _swap_label(self, seg: np.ndarray) -> np.ndarray: + out = seg.copy() + for first, second in self.swap_label_pairs: + first_area = (seg == first) + second_area = (seg == second) + out[first_area] = second + out[second_area] = first + return out + + def transform(self, results: Dict) -> Dict: + """Call function to flip and swap pair labels. + + Args: + results (dict): Result dict. + Returns: + dict: Flipped results, 'do_flip', 'flip_axes' keys are added into + result dict. 
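+
+ Example (an illustrative sketch; shapes follow the Required Keys
+ above and the random data is an assumption):
+
+ .. code-block:: python
+
+ import numpy as np
+ flip = BioMedical3DRandomFlip(prob=1.0, axes=(0, 1, 2))
+ results = dict(
+ img=np.random.rand(1, 8, 16, 16).astype(np.float32),
+ gt_seg_map=np.random.randint(0, 2, (8, 16, 16)))
+ results = flip(results) # adds 'do_flip' and 'flip_axes'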
+ """ + # get actual flipped axis + if 'do_flip' not in results: + results['do_flip'] = self._do_flip(results['img']) + if 'flip_axes' not in results: + results['flip_axes'] = self.axes + # flip image + results['img'] = self._flip( + results['img'], direction=results['do_flip']) + # flip seg + if results['gt_seg_map'] is not None: + if results['gt_seg_map'].shape != results['img'].shape: + results['gt_seg_map'] = results['gt_seg_map'][None, :] + results['gt_seg_map'] = self._flip( + results['gt_seg_map'], direction=results['do_flip']) + results['gt_seg_map'] = results['gt_seg_map'].squeeze() + # swap label pairs + if self.swap_label_pairs is not None: + results['gt_seg_map'] = self._swap_label(results['gt_seg_map']) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, axes={self.axes}, ' \ + f'swap_label_pairs={self.swap_label_pairs})' + return repr_str + + +@TRANSFORMS.register_module() +class Albu(BaseTransform): + """Albumentation augmentation. Adds custom transformations from + Albumentations library. Please, visit + `https://albumentations.readthedocs.io` to get more information. An example + of ``transforms`` is as followed: + + .. code-block:: + [ + dict( + type='ShiftScaleRotate', + shift_limit=0.0625, + scale_limit=0.0, + rotate_limit=0, + interpolation=1, + p=0.5), + dict( + type='RandomBrightnessContrast', + brightness_limit=[0.1, 0.3], + contrast_limit=[0.1, 0.3], + p=0.2), + dict(type='ChannelShuffle', p=0.1), + dict( + type='OneOf', + transforms=[ + dict(type='Blur', blur_limit=3, p=1.0), + dict(type='MedianBlur', blur_limit=3, p=1.0) + ], + p=0.1), + ] + Args: + transforms (list[dict]): A list of albu transformations + keymap (dict): Contains {'input key':'albumentation-style key'} + update_pad_shape (bool): Whether to update padding shape according to \ + the output shape of the last transform + """ + + def __init__(self, + transforms: List[dict], + keymap: Optional[dict] = None, + update_pad_shape: bool = False): + if not ALBU_INSTALLED: + raise ImportError( + 'albumentations is not installed, ' + 'we suggest install albumentation by ' + '"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa + ) + + # Args will be modified later, copying it will be safer + transforms = copy.deepcopy(transforms) + + self.transforms = transforms + self.keymap = keymap + self.update_pad_shape = update_pad_shape + + self.aug = Compose([self.albu_builder(t) for t in self.transforms]) + + if not keymap: + self.keymap_to_albu = {'img': 'image', 'gt_seg_map': 'mask'} + else: + self.keymap_to_albu = copy.deepcopy(keymap) + self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()} + + def albu_builder(self, cfg: dict) -> object: + """Build a callable object from a dict containing albu arguments. + + Args: + cfg (dict): Config dict. It should at least contain the key "type". + + Returns: + Callable: A callable object. 
+ """ + + assert isinstance(cfg, dict) and 'type' in cfg + args = cfg.copy() + + obj_type = args.pop('type') + if mmengine.is_str(obj_type): + if not ALBU_INSTALLED: + raise ImportError( + 'albumentations is not installed, ' + 'we suggest install albumentation by ' + '"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa + ) + obj_cls = getattr(albumentations, obj_type) + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError( + f'type must be a valid type or str, but got {type(obj_type)}') + + if 'transforms' in args: + args['transforms'] = [ + self.albu_builder(t) for t in args['transforms'] + ] + + return obj_cls(**args) + + @staticmethod + def mapper(d: dict, keymap: dict): + """Dictionary mapper. + + Renames keys according to keymap provided. + Args: + d (dict): old dict + keymap (dict): {'old_key':'new_key'} + Returns: + dict: new dict. + """ + + updated_dict = {} + for k, _ in zip(d.keys(), d.values()): + new_k = keymap.get(k, k) + updated_dict[new_k] = d[k] + return updated_dict + + def transform(self, results): + # dict to albumentations format + results = self.mapper(results, self.keymap_to_albu) + + # Convert to RGB since Albumentations works with RGB images + results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_BGR2RGB) + + results = self.aug(**results) + + # Convert back to BGR + results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_RGB2BGR) + + # back to the original format + results = self.mapper(results, self.keymap_back) + + # update final shape + if self.update_pad_shape: + results['pad_shape'] = results['img'].shape + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + f'(transforms={self.transforms})' + return repr_str + + +@TRANSFORMS.register_module() +class ConcatCDInput(BaseTransform): + """Concat images for change detection. + + Required Keys: + + - img + - img2 + + Args: + input_keys (tuple): Input image keys for change detection. + Default: ('img', 'img2'). + """ + + def __init__(self, input_keys=('img', 'img2')): + self.input_keys = input_keys + + def transform(self, results: dict) -> dict: + img = [] + for input_key in self.input_keys: + img.append(results.pop(input_key)) + results['img'] = np.concatenate(img, axis=2) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(input_keys={self.input_keys}, ' + return repr_str + + +@TRANSFORMS.register_module() +class RandomDepthMix(BaseTransform): + """This class implements the RandomDepthMix transform. + + Args: + prob (float): Probability of applying the transformation. + Defaults to 0.25. + mix_scale_ratio (float): Ratio to scale the mix width. + Defaults to 0.75. 
+ """ + + def __init__( + self, + prob: float = 0.25, + mix_scale_ratio: float = 0.75, + ): + super().__init__() + + self.prob = prob + self.mix_scale_ratio = mix_scale_ratio + + def transform(self, results: dict) -> dict: + if random.random() > self.prob: + return results + + h, w = results['img_shape'][:2] + left = int(w * random.random()) + width_ratio = self.mix_scale_ratio * random.random() + width = int(max(1, (w - left) * width_ratio)) + + img = results['img'] + depth_rescale_factor = results.get('depth_rescale_factor', 1) + depth_map = results['gt_depth_map'] / depth_rescale_factor + + if img.ndim == 3: + for c in range(img.shape[-1]): + img[:, left:left + width, c] = depth_map[:, left:left + width] + elif img.ndim == 2: + img[:, left:left + width] = depth_map[:, left:left + width] + else: + raise ValueError(f'Invalid image shape ({img.shape})') + + results['img'] = img + return results diff --git a/mmseg/datasets/voc.py b/mmseg/datasets/voc.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5d6025c03760953a82f80e337185afc51f1386 --- /dev/null +++ b/mmseg/datasets/voc.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +import mmengine.fileio as fileio + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class PascalVOCDataset(BaseSegDataset): + """Pascal VOC dataset. + + Args: + split (str): Split txt file for Pascal VOC. + """ + METAINFO = dict( + classes=('background', 'aeroplane', 'bicycle', 'bird', 'boat', + 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', + 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', + 'sofa', 'train', 'tvmonitor'), + palette=[[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], + [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], + [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], + [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], + [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], + [0, 64, 128]]) + + def __init__(self, + ann_file, + img_suffix='.jpg', + seg_map_suffix='.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + ann_file=ann_file, + **kwargs) + assert fileio.exists(self.data_prefix['img_path'], + self.backend_args) and osp.isfile(self.ann_file) diff --git a/mmseg/engine/.DS_Store b/mmseg/engine/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..44cc8a586158746674279ef49f82e0b3cdb7914e Binary files /dev/null and b/mmseg/engine/.DS_Store differ diff --git a/mmseg/engine/__init__.py b/mmseg/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..98139a0047fd2f076d659ba5aed2cd3452dbd235 --- /dev/null +++ b/mmseg/engine/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .hooks import SegVisualizationHook +from .optimizers import (ForceDefaultOptimWrapperConstructor, + LayerDecayOptimizerConstructor, + LearningRateDecayOptimizerConstructor) +from .schedulers import PolyLRRatio + +__all__ = [ + 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor', + 'SegVisualizationHook', 'PolyLRRatio', + 'ForceDefaultOptimWrapperConstructor' +] diff --git a/mmseg/engine/hooks/__init__.py b/mmseg/engine/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6048088a7fd322890ced17569e855acee826eca --- /dev/null +++ b/mmseg/engine/hooks/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .visualization_hook import SegVisualizationHook + +__all__ = ['SegVisualizationHook'] diff --git a/mmseg/engine/hooks/visualization_hook.py b/mmseg/engine/hooks/visualization_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..ea238c6969183eee8f31bf0bd97f81c89e73a327 --- /dev/null +++ b/mmseg/engine/hooks/visualization_hook.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings +from typing import Optional, Sequence + +import mmcv +import mmengine.fileio as fileio +from mmengine.hooks import Hook +from mmengine.runner import Runner +from mmengine.visualization import Visualizer + +from mmseg.registry import HOOKS +from mmseg.structures import SegDataSample + + +@HOOKS.register_module() +class SegVisualizationHook(Hook): + """Segmentation Visualization Hook. Used to visualize validation and + testing process prediction results. + + In the testing phase: + + 1. If ``show`` is True, it means that only the prediction results are + visualized without storing data, so ``vis_backends`` needs to + be excluded. + + Args: + draw (bool): whether to draw prediction results. If it is False, + it means that no drawing will be done. Defaults to False. + interval (int): The interval of visualization. Defaults to 50. + show (bool): Whether to display the drawn image. Default to False. + wait_time (float): The interval of show (s). Defaults to 0. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + + def __init__(self, + draw: bool = False, + interval: int = 50, + show: bool = False, + wait_time: float = 0., + backend_args: Optional[dict] = None): + self._visualizer: Visualizer = Visualizer.get_current_instance() + self.interval = interval + self.show = show + if self.show: + # No need to think about vis backends. + self._visualizer._vis_backends = {} + warnings.warn('The show is True, it means that only ' + 'the prediction results are visualized ' + 'without storing data, so vis_backends ' + 'needs to be excluded.') + + self.wait_time = wait_time + self.backend_args = backend_args.copy() if backend_args else None + self.draw = draw + if not self.draw: + warnings.warn('The draw is False, it means that the ' + 'hook for visualization will not take ' + 'effect. The results will NOT be ' + 'visualized or stored.') + + def _after_iter(self, + runner: Runner, + batch_idx: int, + data_batch: dict, + outputs: Sequence[SegDataSample], + mode: str = 'val') -> None: + """Run after every ``self.interval`` validation iterations. + + Args: + runner (:obj:`Runner`): The runner of the validation process. + batch_idx (int): The index of the current batch in the val loop. 
+            data_batch (dict): Data from dataloader.
+            outputs (Sequence[:obj:`SegDataSample`]): Outputs from model.
+            mode (str): Current mode of runner. Defaults to 'val'.
+        """
+        if self.draw is False or mode == 'train':
+            return
+
+        if self.every_n_inner_iters(batch_idx, self.interval):
+            for output in outputs:
+                img_path = output.img_path
+                img_bytes = fileio.get(
+                    img_path, backend_args=self.backend_args)
+                img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
+                window_name = f'{mode}_{osp.basename(img_path)}'
+
+                self._visualizer.add_datasample(
+                    window_name,
+                    img,
+                    data_sample=output,
+                    show=self.show,
+                    wait_time=self.wait_time,
+                    step=runner.iter)
diff --git a/mmseg/engine/optimizers/__init__.py b/mmseg/engine/optimizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4cf58741febfc20ea33664ea8e1b1ac68bbb327
--- /dev/null
+++ b/mmseg/engine/optimizers/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .force_default_constructor import ForceDefaultOptimWrapperConstructor
+from .layer_decay_optimizer_constructor import (
+    LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor)
+
+__all__ = [
+    'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
+    'ForceDefaultOptimWrapperConstructor'
+]
diff --git a/mmseg/engine/optimizers/force_default_constructor.py b/mmseg/engine/optimizers/force_default_constructor.py
new file mode 100644
index 0000000000000000000000000000000000000000..12c642ad411bfd547d63c894c84636e2f1896128
--- /dev/null
+++ b/mmseg/engine/optimizers/force_default_constructor.py
@@ -0,0 +1,255 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+from mmengine.logging import print_log
+from mmengine.optim import DefaultOptimWrapperConstructor
+from mmengine.utils.dl_utils import mmcv_full_available
+from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
+from torch.nn import GroupNorm, LayerNorm
+
+from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
+
+
+@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
+class ForceDefaultOptimWrapperConstructor(DefaultOptimWrapperConstructor):
+    """Default constructor with forced optimizer settings.
+
+    This constructor extends the default constructor to add an option for
+    forcing default optimizer settings. This is useful for ensuring that
+    certain parameters or layers strictly adhere to pre-defined default
+    settings, regardless of any custom settings specified.
+
+    By default, each parameter shares the same optimizer settings, and we
+    provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
+    It is a dict and may contain various fields like 'custom_keys',
+    'bias_lr_mult', etc., as well as the additional field
+    `force_default_settings` which allows for enforcing default settings on
+    optimizer parameters.
+
+    - ``custom_keys`` (dict): Specifies parameter-wise settings by keys. If
+      one of the keys in ``custom_keys`` is a substring of the name of one
+      parameter, then the setting of the parameter will be specified by
+      ``custom_keys[key]`` and other settings like ``bias_lr_mult`` etc. will
+      be ignored. It should be noted that the aforementioned ``key`` is the
+      longest key that is a substring of the name of the parameter. If there
+      are multiple matched keys with the same length, then the key that comes
+      first in alphabetical order will be chosen.
+ ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult`` + and ``decay_mult``. See Example 2 below. + - ``bias_lr_mult`` (float): It will be multiplied to the learning + rate for all bias parameters (except for those in normalization + layers and offset layers of DCN). + - ``bias_decay_mult`` (float): It will be multiplied to the weight + decay for all bias parameters (except for those in + normalization layers, depthwise conv layers, offset layers of DCN). + - ``norm_decay_mult`` (float): It will be multiplied to the weight + decay for all weight and bias parameters of normalization + layers. + - ``flat_decay_mult`` (float): It will be multiplied to the weight + decay for all one-dimensional parameters + - ``dwconv_decay_mult`` (float): It will be multiplied to the weight + decay for all weight and bias parameters of depthwise conv + layers. + - ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning + rate for parameters of offset layer in the deformable convs + of a model. + - ``bypass_duplicate`` (bool): If true, the duplicate parameters + would not be added into optimizer. Defaults to False. + - ``force_default_settings`` (bool): If true, this will override any + custom settings defined by ``custom_keys`` and enforce the use of + default settings for optimizer parameters like ``bias_lr_mult``. + This is particularly useful when you want to ensure that certain layers + or parameters adhere strictly to the pre-defined default settings. + + Note: + + 1. If the option ``dcn_offset_lr_mult`` is used, the constructor will + override the effect of ``bias_lr_mult`` in the bias of offset layer. + So be careful when using both ``bias_lr_mult`` and + ``dcn_offset_lr_mult``. If you wish to apply both of them to the offset + layer in deformable convs, set ``dcn_offset_lr_mult`` to the original + ``dcn_offset_lr_mult`` * ``bias_lr_mult``. + + 2. If the option ``dcn_offset_lr_mult`` is used, the constructor will + apply it to all the DCN layers in the model. So be careful when the + model contains multiple DCN layers in places other than backbone. + + 3. When the option ``force_default_settings`` is true, it will override + any custom settings provided in ``custom_keys``. This ensures that the + default settings for the optimizer parameters are used. + + Args: + optim_wrapper_cfg (dict): The config dict of the optimizer wrapper. + + Required fields of ``optim_wrapper_cfg`` are + + - ``type``: class name of the OptimizerWrapper + - ``optimizer``: The configuration of optimizer. + + Optional fields of ``optim_wrapper_cfg`` are + + - any arguments of the corresponding optimizer wrapper type, + e.g., accumulative_counts, clip_grad, etc. + + Required fields of ``optimizer`` are + + - `type`: class name of the optimizer. + + Optional fields of ``optimizer`` are + + - any arguments of the corresponding optimizer type, e.g., + lr, weight_decay, momentum, etc. + + paramwise_cfg (dict, optional): Parameter-wise options. + + Example 1: + >>> model = torch.nn.modules.Conv1d(1, 1, 1) + >>> optim_wrapper_cfg = dict( + >>> dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01, + >>> momentum=0.9, weight_decay=0.0001)) + >>> paramwise_cfg = dict(norm_decay_mult=0.) 
+ >>> optim_wrapper_builder = DefaultOptimWrapperConstructor( + >>> optim_wrapper_cfg, paramwise_cfg) + >>> optim_wrapper = optim_wrapper_builder(model) + + Example 2: + >>> # assume model have attribute model.backbone and model.cls_head + >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict( + >>> type='SGD', lr=0.01, weight_decay=0.95)) + >>> paramwise_cfg = dict(custom_keys={ + >>> 'backbone': dict(lr_mult=0.1, decay_mult=0.9)}) + >>> optim_wrapper_builder = DefaultOptimWrapperConstructor( + >>> optim_wrapper_cfg, paramwise_cfg) + >>> optim_wrapper = optim_wrapper_builder(model) + >>> # Then the `lr` and `weight_decay` for model.backbone is + >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for + >>> # model.cls_head is (0.01, 0.95). + """ + + def add_params(self, + params: List[dict], + module: nn.Module, + prefix: str = '', + is_dcn_module: Optional[Union[int, float]] = None) -> None: + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (list[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + prefix (str): The prefix of the module + is_dcn_module (int|float|None): If the current module is a + submodule of DCN, `is_dcn_module` will be passed to + control conv_offset layer's learning rate. Defaults to None. + """ + # get param-wise options + custom_keys = self.paramwise_cfg.get('custom_keys', {}) + # first sort with alphabet order and then sort with reversed len of str + sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) + + bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', None) + bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None) + norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None) + dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', None) + flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None) + bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False) + dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', None) + force_default_settings = self.paramwise_cfg.get( + 'force_default_settings', False) + + # special rules for norm layers and depth-wise conv layers + is_norm = isinstance(module, + (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) + is_dwconv = ( + isinstance(module, torch.nn.Conv2d) + and module.in_channels == module.groups) + + for name, param in module.named_parameters(recurse=False): + param_group = {'params': [param]} + if bypass_duplicate and self._is_in(param_group, params): + print_log( + f'{prefix} is duplicate. It is skipped since ' + f'bypass_duplicate={bypass_duplicate}', + logger='current', + level=logging.WARNING) + continue + if not param.requires_grad: + params.append(param_group) + continue + + # if the parameter match one of the custom keys, ignore other rules + is_custom = False + for key in sorted_keys: + if key in f'{prefix}.{name}': + is_custom = True + lr_mult = custom_keys[key].get('lr_mult', 1.) + param_group['lr'] = self.base_lr * lr_mult + if self.base_wd is not None: + decay_mult = custom_keys[key].get('decay_mult', 1.) 
+ param_group['weight_decay'] = self.base_wd * decay_mult + # add custom settings to param_group + for k, v in custom_keys[key].items(): + param_group[k] = v + break + + if not is_custom or force_default_settings: + # bias_lr_mult affects all bias parameters + # except for norm.bias dcn.conv_offset.bias + if name == 'bias' and not ( + is_norm or is_dcn_module) and bias_lr_mult is not None: + param_group['lr'] = self.base_lr * bias_lr_mult + + if (prefix.find('conv_offset') != -1 and is_dcn_module + and dcn_offset_lr_mult is not None + and isinstance(module, torch.nn.Conv2d)): + # deal with both dcn_offset's bias & weight + param_group['lr'] = self.base_lr * dcn_offset_lr_mult + + # apply weight decay policies + if self.base_wd is not None: + # norm decay + if is_norm and norm_decay_mult is not None: + param_group[ + 'weight_decay'] = self.base_wd * norm_decay_mult + # bias lr and decay + elif (name == 'bias' and not is_dcn_module + and bias_decay_mult is not None): + param_group[ + 'weight_decay'] = self.base_wd * bias_decay_mult + # depth-wise conv + elif is_dwconv and dwconv_decay_mult is not None: + param_group[ + 'weight_decay'] = self.base_wd * dwconv_decay_mult + # flatten parameters except dcn offset + elif (param.ndim == 1 and not is_dcn_module + and flat_decay_mult is not None): + param_group[ + 'weight_decay'] = self.base_wd * flat_decay_mult + params.append(param_group) + for key, value in param_group.items(): + if key == 'params': + continue + full_name = f'{prefix}.{name}' if prefix else name + print_log( + f'paramwise_options -- {full_name}:{key}={value}', + logger='current') + + if mmcv_full_available(): + from mmcv.ops import DeformConv2d, ModulatedDeformConv2d + is_dcn_module = isinstance(module, + (DeformConv2d, ModulatedDeformConv2d)) + else: + is_dcn_module = False + for child_name, child_mod in module.named_children(): + child_prefix = f'{prefix}.{child_name}' if prefix else child_name + self.add_params( + params, + child_mod, + prefix=child_prefix, + is_dcn_module=is_dcn_module) diff --git a/mmseg/engine/optimizers/layer_decay_optimizer_constructor.py b/mmseg/engine/optimizers/layer_decay_optimizer_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..fdae3ca698c65879056b969f04185f80452ff8d0 --- /dev/null +++ b/mmseg/engine/optimizers/layer_decay_optimizer_constructor.py @@ -0,0 +1,207 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import warnings + +from mmengine.dist import get_dist_info +from mmengine.logging import print_log +from mmengine.optim import DefaultOptimWrapperConstructor + +from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS + + +def get_layer_id_for_convnext(var_name, max_layer_id): + """Get the layer id to set the different learning rates in ``layer_wise`` + decay_type. + + Args: + var_name (str): The key of the model. + max_layer_id (int): Maximum number of backbone layers. + + Returns: + int: The id number corresponding to different learning rate in + ``LearningRateDecayOptimizerConstructor``. 
+ """ + + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.downsample_layers'): + stage_id = int(var_name.split('.')[2]) + if stage_id == 0: + layer_id = 0 + elif stage_id == 1: + layer_id = 2 + elif stage_id == 2: + layer_id = 3 + elif stage_id == 3: + layer_id = max_layer_id + return layer_id + elif var_name.startswith('backbone.stages'): + stage_id = int(var_name.split('.')[2]) + block_id = int(var_name.split('.')[3]) + if stage_id == 0: + layer_id = 1 + elif stage_id == 1: + layer_id = 2 + elif stage_id == 2: + layer_id = 3 + block_id // 3 + elif stage_id == 3: + layer_id = max_layer_id + return layer_id + else: + return max_layer_id + 1 + + +def get_stage_id_for_convnext(var_name, max_stage_id): + """Get the stage id to set the different learning rates in ``stage_wise`` + decay_type. + + Args: + var_name (str): The key of the model. + max_stage_id (int): Maximum number of backbone layers. + + Returns: + int: The id number corresponding to different learning rate in + ``LearningRateDecayOptimizerConstructor``. + """ + + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.downsample_layers'): + return 0 + elif var_name.startswith('backbone.stages'): + stage_id = int(var_name.split('.')[2]) + return stage_id + 1 + else: + return max_stage_id - 1 + + +def get_layer_id_for_vit(var_name, max_layer_id): + """Get the layer id to set the different learning rates. + + Args: + var_name (str): The key of the model. + num_max_layer (int): Maximum number of backbone layers. + + Returns: + int: Returns the layer id of the key. + """ + + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.patch_embed'): + return 0 + elif var_name.startswith('backbone.layers'): + layer_id = int(var_name.split('.')[2]) + return layer_id + 1 + else: + return max_layer_id - 1 + + +@OPTIM_WRAPPER_CONSTRUCTORS.register_module() +class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor): + """Different learning rates are set for different layers of backbone. + + Note: Currently, this optimizer constructor is built for ConvNeXt, + BEiT and MAE. + """ + + def add_params(self, params, module, **kwargs): + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (list[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + """ + + parameter_groups = {} + print_log(f'self.paramwise_cfg is {self.paramwise_cfg}') + num_layers = self.paramwise_cfg.get('num_layers') + 2 + decay_rate = self.paramwise_cfg.get('decay_rate') + decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise') + print_log('Build LearningRateDecayOptimizerConstructor ' + f'{decay_type} {decay_rate} - {num_layers}') + weight_decay = self.base_wd + for name, param in module.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith('.bias') or name in ( + 'pos_embed', 'cls_token'): + group_name = 'no_decay' + this_weight_decay = 0. 
+ else: + group_name = 'decay' + this_weight_decay = weight_decay + if 'layer_wise' in decay_type: + if 'ConvNeXt' in module.backbone.__class__.__name__: + layer_id = get_layer_id_for_convnext( + name, self.paramwise_cfg.get('num_layers')) + print_log(f'set param {name} as id {layer_id}') + elif 'BEiT' in module.backbone.__class__.__name__ or \ + 'MAE' in module.backbone.__class__.__name__: + layer_id = get_layer_id_for_vit(name, num_layers) + print_log(f'set param {name} as id {layer_id}') + else: + raise NotImplementedError() + elif decay_type == 'stage_wise': + if 'ConvNeXt' in module.backbone.__class__.__name__: + layer_id = get_stage_id_for_convnext(name, num_layers) + print_log(f'set param {name} as id {layer_id}') + else: + raise NotImplementedError() + group_name = f'layer_{layer_id}_{group_name}' + + if group_name not in parameter_groups: + scale = decay_rate**(num_layers - layer_id - 1) + + parameter_groups[group_name] = { + 'weight_decay': this_weight_decay, + 'params': [], + 'param_names': [], + 'lr_scale': scale, + 'group_name': group_name, + 'lr': scale * self.base_lr, + } + + parameter_groups[group_name]['params'].append(param) + parameter_groups[group_name]['param_names'].append(name) + rank, _ = get_dist_info() + if rank == 0: + to_display = {} + for key in parameter_groups: + to_display[key] = { + 'param_names': parameter_groups[key]['param_names'], + 'lr_scale': parameter_groups[key]['lr_scale'], + 'lr': parameter_groups[key]['lr'], + 'weight_decay': parameter_groups[key]['weight_decay'], + } + print_log(f'Param groups = {json.dumps(to_display, indent=2)}') + params.extend(parameter_groups.values()) + + +@OPTIM_WRAPPER_CONSTRUCTORS.register_module() +class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor): + """Different learning rates are set for different layers of backbone. + + Note: Currently, this optimizer constructor is built for BEiT, + and it will be deprecated. + Please use ``LearningRateDecayOptimizerConstructor`` instead. + """ + + def __init__(self, optim_wrapper_cfg, paramwise_cfg): + warnings.warn('DeprecationWarning: Original ' + 'LayerDecayOptimizerConstructor of BEiT ' + 'will be deprecated. Please use ' + 'LearningRateDecayOptimizerConstructor instead, ' + 'and set decay_type = layer_wise_vit in paramwise_cfg.') + paramwise_cfg.update({'decay_type': 'layer_wise_vit'}) + warnings.warn('DeprecationWarning: Layer_decay_rate will ' + 'be deleted, please use decay_rate instead.') + paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate') + super().__init__(optim_wrapper_cfg, paramwise_cfg) diff --git a/mmseg/engine/schedulers/__init__.py b/mmseg/engine/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3cd3f6211345bb3627b76d683291f48efd934a77 --- /dev/null +++ b/mmseg/engine/schedulers/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .poly_ratio_scheduler import PolyLRRatio + +__all__ = ['PolyLRRatio'] diff --git a/mmseg/engine/schedulers/poly_ratio_scheduler.py b/mmseg/engine/schedulers/poly_ratio_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..057203acc9cc9fc72306d2039669b90f35704436 --- /dev/null +++ b/mmseg/engine/schedulers/poly_ratio_scheduler.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
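A minimal sketch (not part of the diff): an `optim_wrapper` config that exercises `LearningRateDecayOptimizerConstructor` as documented above. The ConvNeXt backbone and the hyper-parameter values are assumptions for illustration.

```python
# With decay_type='layer_wise' and a ConvNeXt backbone, add_params assigns
# layer i the scale decay_rate ** (num_layers + 2 - i - 1).
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=1e-4, weight_decay=0.05),
    constructor='LearningRateDecayOptimizerConstructor',
    paramwise_cfg=dict(
        num_layers=12,           # backbone depth (assumed)
        decay_rate=0.9,          # per-layer lr decay factor (assumed)
        decay_type='layer_wise'))
```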
+from typing import Optional
+
+from mmengine.optim.scheduler import PolyLR
+
+from mmseg.registry import PARAM_SCHEDULERS
+
+
+@PARAM_SCHEDULERS.register_module()
+class PolyLRRatio(PolyLR):
+    """Implements polynomial learning rate decay with ratio.
+
+    This scheduler adjusts the learning rate of each parameter group
+    following a polynomial decay equation. The decay can occur in
+    conjunction with external parameter adjustments made outside this
+    scheduler.
+
+    Args:
+        optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
+        eta_min (float): Minimum learning rate at the end of scheduling.
+            Defaults to 0.
+        eta_min_ratio (float, optional): The ratio of the minimum parameter
+            value to the base parameter value. Either `eta_min` or
+            `eta_min_ratio` should be specified. Defaults to None.
+        power (float): The power of the polynomial. Defaults to 1.0.
+        begin (int): Step at which to start updating the parameters.
+            Defaults to 0.
+        end (int): Step at which to stop updating the parameters.
+            Defaults to INF.
+        last_step (int): The index of last step. Used for resuming without
+            a state dict. Defaults to -1.
+        by_epoch (bool): Whether the scheduled parameters are updated by
+            epochs. Defaults to True.
+        verbose (bool): Whether to print the value for each update.
+            Defaults to False.
+    """
+
+    def __init__(self, eta_min_ratio: Optional[float] = None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.eta_min_ratio = eta_min_ratio
+
+    def _get_value(self):
+        """Compute value using chainable form of the scheduler."""
+
+        if self.last_step == 0:
+            return [
+                group[self.param_name] for group in self.optimizer.param_groups
+            ]
+
+        param_groups_value = []
+        for base_value, param_group in zip(self.base_values,
+                                           self.optimizer.param_groups):
+            eta_min = self.eta_min if self.eta_min_ratio is None else \
+                base_value * self.eta_min_ratio
+            step_ratio = (1 - 1 /
+                          (self.total_iters - self.last_step + 1))**self.power
+            step_value = (param_group[self.param_name] -
+                          eta_min) * step_ratio + eta_min
+            param_groups_value.append(step_value)
+
+        return param_groups_value
diff --git a/mmseg/evaluation/__init__.py b/mmseg/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..82b3a8d68d3aefcc23542fc1006eaddde05ca2ab
--- /dev/null
+++ b/mmseg/evaluation/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .metrics import CityscapesMetric, DepthMetric, IoUMetric
+
+__all__ = ['IoUMetric', 'CityscapesMetric', 'DepthMetric']
diff --git a/mmseg/evaluation/metrics/__init__.py b/mmseg/evaluation/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..848d4713dc8c0b6a08569d536bb72bd04ca1b1cc
--- /dev/null
+++ b/mmseg/evaluation/metrics/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .citys_metric import CityscapesMetric
+from .depth_metric import DepthMetric
+from .iou_metric import IoUMetric
+
+__all__ = ['IoUMetric', 'CityscapesMetric', 'DepthMetric']
diff --git a/mmseg/evaluation/metrics/citys_metric.py b/mmseg/evaluation/metrics/citys_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..32984653c3fa9c13d8c6a7402033001012b5031f
--- /dev/null
+++ b/mmseg/evaluation/metrics/citys_metric.py
@@ -0,0 +1,158 @@
+# Copyright (c) OpenMMLab. All rights reserved.
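A minimal sketch (not part of the diff): PolyLRRatio in a schedule config. With `eta_min_ratio=0.1`, a group whose base lr is 1e-3 decays toward 1e-4 (0.1 × its own base) instead of a shared absolute floor; the step counts below are assumptions.

```python
param_scheduler = [
    dict(
        type='PolyLRRatio',
        eta_min_ratio=0.1,   # floor is 10% of each group's base lr
        power=0.9,
        begin=0,
        end=80000,           # assumed total iterations
        by_epoch=False)
]
```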
+import os.path as osp
+import shutil
+from collections import OrderedDict
+from typing import Dict, Optional, Sequence
+
+try:
+    import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval  # noqa
+    import cityscapesscripts.helpers.labels as CSLabels
+except ImportError:
+    CSLabels = None
+    CSEval = None
+
+import numpy as np
+from mmengine.dist import is_main_process, master_only
+from mmengine.evaluator import BaseMetric
+from mmengine.logging import MMLogger, print_log
+from mmengine.utils import mkdir_or_exist
+from PIL import Image
+
+from mmseg.registry import METRICS
+
+
+@METRICS.register_module()
+class CityscapesMetric(BaseMetric):
+    """Cityscapes evaluation metric.
+
+    Args:
+        output_dir (str): The directory for output predictions.
+        ignore_index (int): Index that will be ignored in evaluation.
+            Default: 255.
+        format_only (bool): Only format results for submission without
+            performing evaluation. It is useful when you want to format the
+            result to a specific format and submit it to the test server.
+            Defaults to False.
+        keep_results (bool): Whether to keep the results. When ``format_only``
+            is True, ``keep_results`` must be True. Defaults to False.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+    """
+
+    def __init__(self,
+                 output_dir: str,
+                 ignore_index: int = 255,
+                 format_only: bool = False,
+                 keep_results: bool = False,
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None,
+                 **kwargs) -> None:
+        super().__init__(collect_device=collect_device, prefix=prefix)
+        if CSEval is None:
+            raise ImportError('Please run "pip install cityscapesscripts" to '
+                              'install cityscapesscripts first.')
+        self.output_dir = output_dir
+        self.ignore_index = ignore_index
+
+        self.format_only = format_only
+        if format_only:
+            assert keep_results, (
+                'When format_only is True, the results must be kept, please '
+                f'set keep_results to True, but got {keep_results}')
+        self.keep_results = keep_results
+        self.prefix = prefix
+        if is_main_process():
+            mkdir_or_exist(self.output_dir)
+
+    @master_only
+    def __del__(self) -> None:
+        """Clean up."""
+        if not self.keep_results:
+            shutil.rmtree(self.output_dir)
+
+    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
+        """Process one batch of data and data_samples.
+
+        The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch (dict): A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from the model.
+ """ + mkdir_or_exist(self.output_dir) + + for data_sample in data_samples: + pred_label = data_sample['pred_sem_seg']['data'][0].cpu().numpy() + # when evaluating with official cityscapesscripts, + # labelIds should be used + pred_label = self._convert_to_label_id(pred_label) + basename = osp.splitext(osp.basename(data_sample['img_path']))[0] + png_filename = osp.abspath( + osp.join(self.output_dir, f'{basename}.png')) + output = Image.fromarray(pred_label.astype(np.uint8)).convert('P') + output.save(png_filename) + if self.format_only: + # format_only always for test dataset without ground truth + gt_filename = '' + else: + # when evaluating with official cityscapesscripts, + # **_gtFine_labelIds.png is used + gt_filename = data_sample['seg_map_path'].replace( + 'labelTrainIds.png', 'labelIds.png') + self.results.append((png_filename, gt_filename)) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): Testing results of the dataset. + + Returns: + dict[str: float]: Cityscapes evaluation results. + """ + logger: MMLogger = MMLogger.get_current_instance() + if self.format_only: + logger.info(f'results are saved to {osp.dirname(self.output_dir)}') + return OrderedDict() + + msg = 'Evaluating in Cityscapes style' + if logger is None: + msg = '\n' + msg + print_log(msg, logger=logger) + + eval_results = dict() + print_log( + f'Evaluating results under {self.output_dir} ...', logger=logger) + + CSEval.args.evalInstLevelScore = True + CSEval.args.predictionPath = osp.abspath(self.output_dir) + CSEval.args.evalPixelAccuracy = True + CSEval.args.JSONOutput = False + + pred_list, gt_list = zip(*results) + metric = dict() + eval_results.update( + CSEval.evaluateImgLists(pred_list, gt_list, CSEval.args)) + metric['averageScoreCategories'] = eval_results[ + 'averageScoreCategories'] + metric['averageScoreInstCategories'] = eval_results[ + 'averageScoreInstCategories'] + return metric + + @staticmethod + def _convert_to_label_id(result): + """Convert trainId to id for cityscapes.""" + if isinstance(result, str): + result = np.load(result) + result_copy = result.copy() + for trainId, label in CSLabels.trainId2label.items(): + result_copy[result == trainId] = label.id + + return result_copy diff --git a/mmseg/evaluation/metrics/depth_metric.py b/mmseg/evaluation/metrics/depth_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..621d4a31c9fe69cdbf83790e8f320218f755557a --- /dev/null +++ b/mmseg/evaluation/metrics/depth_metric.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from collections import OrderedDict, defaultdict +from typing import Dict, List, Optional, Sequence + +import cv2 +import numpy as np +import torch +from mmengine.dist import is_main_process +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger, print_log +from mmengine.utils import mkdir_or_exist +from prettytable import PrettyTable +from torch import Tensor + +from mmseg.registry import METRICS + + +@METRICS.register_module() +class DepthMetric(BaseMetric): + """Depth estimation evaluation metric. + + Args: + depth_metrics (List[str], optional): List of metrics to compute. If + not specified, defaults to all metrics in self.METRICS. + min_depth_eval (float): Minimum depth value for evaluation. + Defaults to 0.0. + max_depth_eval (float): Maximum depth value for evaluation. + Defaults to infinity. 
+        crop_type (str, optional): Specifies the type of cropping to be used
+            during evaluation. This option can affect how the evaluation mask
+            is generated. Currently, 'nyu_crop' is supported, but other
+            types can be added in the future. Defaults to None if no cropping
+            should be applied.
+        depth_scale_factor (float): Factor to scale the depth values.
+            Defaults to 1.0.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        output_dir (str): The directory for output predictions. Defaults to
+            None.
+        format_only (bool): Only format results for submission without
+            performing evaluation. It is useful when you want to save the
+            result to a specific format and submit it to the test server.
+            Defaults to False.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+    """
+    METRICS = ('d1', 'd2', 'd3', 'abs_rel', 'sq_rel', 'rmse', 'rmse_log',
+               'log10', 'silog')
+
+    def __init__(self,
+                 depth_metrics: Optional[List[str]] = None,
+                 min_depth_eval: float = 0.0,
+                 max_depth_eval: float = float('inf'),
+                 crop_type: Optional[str] = None,
+                 depth_scale_factor: float = 1.0,
+                 collect_device: str = 'cpu',
+                 output_dir: Optional[str] = None,
+                 format_only: bool = False,
+                 prefix: Optional[str] = None,
+                 **kwargs) -> None:
+        super().__init__(collect_device=collect_device, prefix=prefix)
+
+        if depth_metrics is None:
+            self.metrics = self.METRICS
+        elif isinstance(depth_metrics, (tuple, list)):
+            for metric in depth_metrics:
+                assert metric in self.METRICS, f'the metric {metric} is not ' \
+                    f'supported. Please use metrics in {self.METRICS}'
+            self.metrics = depth_metrics
+
+        # Validate crop_type, if provided
+        assert crop_type in [
+            None, 'nyu_crop'
+        ], (f'Invalid value for crop_type: {crop_type}. Supported values are '
+            'None or \'nyu_crop\'.')
+        self.crop_type = crop_type
+        self.min_depth_eval = min_depth_eval
+        self.max_depth_eval = max_depth_eval
+        self.output_dir = output_dir
+        if self.output_dir and is_main_process():
+            mkdir_or_exist(self.output_dir)
+        self.format_only = format_only
+        self.depth_scale_factor = depth_scale_factor
+
+    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
+        """Process one batch of data and data_samples.
+
+        The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch (dict): A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from the model.
+ """ + for data_sample in data_samples: + pred_label = data_sample['pred_depth_map']['data'].squeeze() + # format_only always for test dataset without ground truth + if not self.format_only: + gt_depth = data_sample['gt_depth_map']['data'].squeeze().to( + pred_label) + + eval_mask = self._get_eval_mask(gt_depth) + self.results.append( + (gt_depth[eval_mask], pred_label[eval_mask])) + # format_result + if self.output_dir is not None: + basename = osp.splitext(osp.basename( + data_sample['img_path']))[0] + png_filename = osp.abspath( + osp.join(self.output_dir, f'{basename}.png')) + output_mask = pred_label.cpu().numpy( + ) * self.depth_scale_factor + + cv2.imwrite(png_filename, output_mask.astype(np.uint16), + [cv2.IMWRITE_PNG_COMPRESSION, 0]) + + def _get_eval_mask(self, gt_depth: Tensor): + """Generates an evaluation mask based on ground truth depth and + cropping. + + Args: + gt_depth (Tensor): Ground truth depth map. + + Returns: + Tensor: Boolean mask where evaluation should be performed. + """ + valid_mask = torch.logical_and(gt_depth > self.min_depth_eval, + gt_depth < self.max_depth_eval) + + if self.crop_type == 'nyu_crop': + # this implementation is adapted from + # https://github.com/zhyever/Monocular-Depth-Estimation-Toolbox/blob/main/depth/datasets/nyu.py # noqa + crop_mask = torch.zeros_like(valid_mask) + crop_mask[45:471, 41:601] = 1 + else: + crop_mask = torch.ones_like(valid_mask) + + eval_mask = torch.logical_and(valid_mask, crop_mask) + return eval_mask + + @staticmethod + def _calc_all_metrics(gt_depth, pred_depth): + """Computes final evaluation metrics based on accumulated results.""" + assert gt_depth.shape == pred_depth.shape + + thresh = torch.max((gt_depth / pred_depth), (pred_depth / gt_depth)) + diff = pred_depth - gt_depth + diff_log = torch.log(pred_depth) - torch.log(gt_depth) + + d1 = torch.sum(thresh < 1.25).float() / len(thresh) + d2 = torch.sum(thresh < 1.25**2).float() / len(thresh) + d3 = torch.sum(thresh < 1.25**3).float() / len(thresh) + + abs_rel = torch.mean(torch.abs(diff) / gt_depth) + sq_rel = torch.mean(torch.pow(diff, 2) / gt_depth) + + rmse = torch.sqrt(torch.mean(torch.pow(diff, 2))) + rmse_log = torch.sqrt(torch.mean(torch.pow(diff_log, 2))) + + log10 = torch.mean( + torch.abs(torch.log10(pred_depth) - torch.log10(gt_depth))) + silog = torch.sqrt( + torch.pow(diff_log, 2).mean() - + 0.5 * torch.pow(diff_log.mean(), 2)) + + return { + 'd1': d1.item(), + 'd2': d2.item(), + 'd3': d3.item(), + 'abs_rel': abs_rel.item(), + 'sq_rel': sq_rel.item(), + 'rmse': rmse.item(), + 'rmse_log': rmse_log.item(), + 'log10': log10.item(), + 'silog': silog.item() + } + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. The keys + are identical with self.metrics. 
+ """ + logger: MMLogger = MMLogger.get_current_instance() + if self.format_only: + logger.info(f'results are saved to {osp.dirname(self.output_dir)}') + return OrderedDict() + + metrics = defaultdict(list) + for gt_depth, pred_depth in results: + for key, value in self._calc_all_metrics(gt_depth, + pred_depth).items(): + metrics[key].append(value) + metrics = {k: sum(metrics[k]) / len(metrics[k]) for k in self.metrics} + + table_data = PrettyTable() + for key, val in metrics.items(): + table_data.add_column(key, [round(val, 5)]) + + print_log('results:', logger) + print_log('\n' + table_data.get_string(), logger=logger) + + return metrics diff --git a/mmseg/evaluation/metrics/iou_metric.py b/mmseg/evaluation/metrics/iou_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..16014c74001d7295f9fff8f03ef185077e3f613b --- /dev/null +++ b/mmseg/evaluation/metrics/iou_metric.py @@ -0,0 +1,286 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from collections import OrderedDict +from typing import Dict, List, Optional, Sequence + +import numpy as np +import torch +from mmengine.dist import is_main_process +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger, print_log +from mmengine.utils import mkdir_or_exist +from PIL import Image +from prettytable import PrettyTable + +from mmseg.registry import METRICS + + +@METRICS.register_module() +class IoUMetric(BaseMetric): + """IoU evaluation metric. + + Args: + ignore_index (int): Index that will be ignored in evaluation. + Default: 255. + iou_metrics (list[str] | str): Metrics to be calculated, the options + includes 'mIoU', 'mDice' and 'mFscore'. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + beta (int): Determines the weight of recall in the combined score. + Default: 1. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + output_dir (str): The directory for output prediction. Defaults to + None. + format_only (bool): Only format result for results commit without + perform evaluation. It is useful when you want to save the result + to a specific format and submit it to the test server. + Defaults to False. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + """ + + def __init__(self, + ignore_index: int = 255, + iou_metrics: List[str] = ['mIoU'], + nan_to_num: Optional[int] = None, + beta: int = 1, + collect_device: str = 'cpu', + output_dir: Optional[str] = None, + format_only: bool = False, + prefix: Optional[str] = None, + **kwargs) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + + self.ignore_index = ignore_index + self.metrics = iou_metrics + self.nan_to_num = nan_to_num + self.beta = beta + self.output_dir = output_dir + if self.output_dir and is_main_process(): + mkdir_or_exist(self.output_dir) + self.format_only = format_only + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data and data_samples. + + The processed results should be stored in ``self.results``, which will + be used to compute the metrics when all batches have been processed. 
+ + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + num_classes = len(self.dataset_meta['classes']) + for data_sample in data_samples: + pred_label = data_sample['pred_sem_seg']['data'].squeeze() + # format_only always for test dataset without ground truth + if not self.format_only: + label = data_sample['gt_sem_seg']['data'].squeeze().to( + pred_label) + self.results.append( + self.intersect_and_union(pred_label, label, num_classes, + self.ignore_index)) + # format_result + if self.output_dir is not None: + basename = osp.splitext(osp.basename( + data_sample['img_path']))[0] + png_filename = osp.abspath( + osp.join(self.output_dir, f'{basename}.png')) + output_mask = pred_label.cpu().numpy() + # The index range of official ADE20k dataset is from 0 to 150. + # But the index range of output is from 0 to 149. + # That is because we set reduce_zero_label=True. + if data_sample.get('reduce_zero_label', False): + output_mask = output_mask + 1 + output = Image.fromarray(output_mask.astype(np.uint8)) + output.save(png_filename) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. The key + mainly includes aAcc, mIoU, mAcc, mDice, mFscore, mPrecision, + mRecall. + """ + logger: MMLogger = MMLogger.get_current_instance() + if self.format_only: + logger.info(f'results are saved to {osp.dirname(self.output_dir)}') + return OrderedDict() + # convert list of tuples to tuple of lists, e.g. + # [(A_1, B_1, C_1, D_1), ..., (A_n, B_n, C_n, D_n)] to + # ([A_1, ..., A_n], ..., [D_1, ..., D_n]) + results = tuple(zip(*results)) + assert len(results) == 4 + + total_area_intersect = sum(results[0]) + total_area_union = sum(results[1]) + total_area_pred_label = sum(results[2]) + total_area_label = sum(results[3]) + ret_metrics = self.total_area_to_metrics( + total_area_intersect, total_area_union, total_area_pred_label, + total_area_label, self.metrics, self.nan_to_num, self.beta) + + class_names = self.dataset_meta['classes'] + + # summary table + ret_metrics_summary = OrderedDict({ + ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + }) + metrics = dict() + for key, val in ret_metrics_summary.items(): + if key == 'aAcc': + metrics[key] = val + else: + metrics['m' + key] = val + + # each class table + ret_metrics.pop('aAcc', None) + ret_metrics_class = OrderedDict({ + ret_metric: np.round(ret_metric_value * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + }) + ret_metrics_class.update({'Class': class_names}) + ret_metrics_class.move_to_end('Class', last=False) + class_table_data = PrettyTable() + for key, val in ret_metrics_class.items(): + class_table_data.add_column(key, val) + + print_log('per class results:', logger) + print_log('\n' + class_table_data.get_string(), logger=logger) + + return metrics + + @staticmethod + def intersect_and_union(pred_label: torch.tensor, label: torch.tensor, + num_classes: int, ignore_index: int): + """Calculate Intersection and Union. + + Args: + pred_label (torch.tensor): Prediction segmentation map + or predict result filename. The shape is (H, W). + label (torch.tensor): Ground truth segmentation map + or label filename. 
The shape is (H, W). + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + + Returns: + torch.Tensor: The intersection of prediction and ground truth + histogram on all classes. + torch.Tensor: The union of prediction and ground truth histogram on + all classes. + torch.Tensor: The prediction histogram on all classes. + torch.Tensor: The ground truth histogram on all classes. + """ + + mask = (label != ignore_index) + pred_label = pred_label[mask] + label = label[mask] + + intersect = pred_label[pred_label == label] + area_intersect = torch.histc( + intersect.float(), bins=(num_classes), min=0, + max=num_classes - 1).cpu() + area_pred_label = torch.histc( + pred_label.float(), bins=(num_classes), min=0, + max=num_classes - 1).cpu() + area_label = torch.histc( + label.float(), bins=(num_classes), min=0, + max=num_classes - 1).cpu() + area_union = area_pred_label + area_label - area_intersect + return area_intersect, area_union, area_pred_label, area_label + + @staticmethod + def total_area_to_metrics(total_area_intersect: np.ndarray, + total_area_union: np.ndarray, + total_area_pred_label: np.ndarray, + total_area_label: np.ndarray, + metrics: List[str] = ['mIoU'], + nan_to_num: Optional[int] = None, + beta: int = 1): + """Calculate evaluation metrics + Args: + total_area_intersect (np.ndarray): The intersection of prediction + and ground truth histogram on all classes. + total_area_union (np.ndarray): The union of prediction and ground + truth histogram on all classes. + total_area_pred_label (np.ndarray): The prediction histogram on + all classes. + total_area_label (np.ndarray): The ground truth histogram on + all classes. + metrics (List[str] | str): Metrics to be evaluated, 'mIoU' and + 'mDice'. + nan_to_num (int, optional): If specified, NaN values will be + replaced by the numbers defined by the user. Default: None. + beta (int): Determines the weight of recall in the combined score. + Default: 1. + Returns: + Dict[str, np.ndarray]: per category evaluation metrics, + shape (num_classes, ). + """ + + def f_score(precision, recall, beta=1): + """calculate the f-score value. + + Args: + precision (float | torch.Tensor): The precision value. + recall (float | torch.Tensor): The recall value. + beta (int): Determines the weight of recall in the combined + score. Default: 1. + + Returns: + [torch.tensor]: The f-score value. 
+ """ + score = (1 + beta**2) * (precision * recall) / ( + (beta**2 * precision) + recall) + return score + + if isinstance(metrics, str): + metrics = [metrics] + allowed_metrics = ['mIoU', 'mDice', 'mFscore'] + if not set(metrics).issubset(set(allowed_metrics)): + raise KeyError(f'metrics {metrics} is not supported') + + all_acc = total_area_intersect.sum() / total_area_label.sum() + ret_metrics = OrderedDict({'aAcc': all_acc}) + for metric in metrics: + if metric == 'mIoU': + iou = total_area_intersect / total_area_union + acc = total_area_intersect / total_area_label + ret_metrics['IoU'] = iou + ret_metrics['Acc'] = acc + elif metric == 'mDice': + dice = 2 * total_area_intersect / ( + total_area_pred_label + total_area_label) + acc = total_area_intersect / total_area_label + ret_metrics['Dice'] = dice + ret_metrics['Acc'] = acc + elif metric == 'mFscore': + precision = total_area_intersect / total_area_pred_label + recall = total_area_intersect / total_area_label + f_value = torch.tensor([ + f_score(x[0], x[1], beta) for x in zip(precision, recall) + ]) + ret_metrics['Fscore'] = f_value + ret_metrics['Precision'] = precision + ret_metrics['Recall'] = recall + + ret_metrics = { + metric: value.numpy() + for metric, value in ret_metrics.items() + } + if nan_to_num is not None: + ret_metrics = OrderedDict({ + metric: np.nan_to_num(metric_value, nan=nan_to_num) + for metric, metric_value in ret_metrics.items() + }) + return ret_metrics diff --git a/mmseg/models/.DS_Store b/mmseg/models/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..64389d97bf96b7040c9faee47ea9393f4847e95a Binary files /dev/null and b/mmseg/models/.DS_Store differ diff --git a/mmseg/models/__init__.py b/mmseg/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a98951283c1ac4047c5f5ca3cdc827a43c42cf60 --- /dev/null +++ b/mmseg/models/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .assigners import * # noqa: F401,F403 +from .backbones import * # noqa: F401,F403 +from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone, + build_head, build_loss, build_segmentor) +from .data_preprocessor import SegDataPreProcessor +from .decode_heads import * # noqa: F401,F403 +from .losses import * # noqa: F401,F403 +from .necks import * # noqa: F401,F403 +from .segmentors import * # noqa: F401,F403 +from .text_encoder import * # noqa: F401,F403 + +__all__ = [ + 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone', + 'build_head', 'build_loss', 'build_segmentor', 'SegDataPreProcessor' +] diff --git a/mmseg/models/__pycache__/__init__.cpython-311.pyc b/mmseg/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b5b8618be040e389f5d88783aed144872be7d12 Binary files /dev/null and b/mmseg/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmseg/models/__pycache__/builder.cpython-311.pyc b/mmseg/models/__pycache__/builder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7af83bb56eca6fd8811a2ff44ef03a6507bec59 Binary files /dev/null and b/mmseg/models/__pycache__/builder.cpython-311.pyc differ diff --git a/mmseg/models/__pycache__/data_preprocessor.cpython-311.pyc b/mmseg/models/__pycache__/data_preprocessor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b36097c85067468fc281de6469e3eb2d4dd40645 Binary files /dev/null and b/mmseg/models/__pycache__/data_preprocessor.cpython-311.pyc differ diff --git a/mmseg/models/assigners/__init__.py b/mmseg/models/assigners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d49b1b18b9e3e6d4e3b19c48eb1c80cbb1205f69 --- /dev/null +++ b/mmseg/models/assigners/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
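A minimal sketch (not part of the diff): the assigners imported below reduce, at their core, to scipy's Hungarian solver. This shows the matching step used by the `HungarianAssigner` defined next, on a hand-made cost matrix.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Rows: query predictions; columns: ground-truth instances.
# Entries: weighted sums of match costs (classification + dice, etc.).
cost = np.array([
    [0.90, 0.10, 0.80],
    [0.20, 0.70, 0.60],
    [0.50, 0.40, 0.05],
])
query_inds, label_inds = linear_sum_assignment(cost)
print(query_inds, label_inds)  # [0 1 2] [1 0 2]; total cost 0.10+0.20+0.05
```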
+from .base_assigner import BaseAssigner +from .hungarian_assigner import HungarianAssigner +from .match_cost import ClassificationCost, CrossEntropyLossCost, DiceCost + +__all__ = [ + 'BaseAssigner', + 'HungarianAssigner', + 'ClassificationCost', + 'CrossEntropyLossCost', + 'DiceCost', +] diff --git a/mmseg/models/assigners/__pycache__/__init__.cpython-311.pyc b/mmseg/models/assigners/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c9bd5140f724de4d0bb3167ba14c585fd13f32b Binary files /dev/null and b/mmseg/models/assigners/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmseg/models/assigners/__pycache__/base_assigner.cpython-311.pyc b/mmseg/models/assigners/__pycache__/base_assigner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1ea252dcd28b48f4c0c56993aa0fc0fc3c0b9c9 Binary files /dev/null and b/mmseg/models/assigners/__pycache__/base_assigner.cpython-311.pyc differ diff --git a/mmseg/models/assigners/__pycache__/hungarian_assigner.cpython-311.pyc b/mmseg/models/assigners/__pycache__/hungarian_assigner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34136cecb01e243796eb045cf7b8d0b4d2efffaa Binary files /dev/null and b/mmseg/models/assigners/__pycache__/hungarian_assigner.cpython-311.pyc differ diff --git a/mmseg/models/assigners/__pycache__/match_cost.cpython-311.pyc b/mmseg/models/assigners/__pycache__/match_cost.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8bf3326f1b7e29dd72687571a72fd25ae22adf9 Binary files /dev/null and b/mmseg/models/assigners/__pycache__/match_cost.cpython-311.pyc differ diff --git a/mmseg/models/assigners/base_assigner.py b/mmseg/models/assigners/base_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..97895cdac2789a62c3e8a381caaf944679f1e5a4 --- /dev/null +++ b/mmseg/models/assigners/base_assigner.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import Optional + +from mmengine.structures import InstanceData + + +class BaseAssigner(metaclass=ABCMeta): + """Base assigner that assigns masks to ground truth class labels.""" + + @abstractmethod + def assign(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + gt_instances_ignore: Optional[InstanceData] = None, + **kwargs): + """Assign masks to either a ground truth class label or a negative + label.""" diff --git a/mmseg/models/assigners/hungarian_assigner.py b/mmseg/models/assigners/hungarian_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..28868f0a04e7feaf3de20e39fac5059d789047d3 --- /dev/null +++ b/mmseg/models/assigners/hungarian_assigner.py @@ -0,0 +1,86 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Union + +import torch +from mmengine import ConfigDict +from mmengine.structures import InstanceData +from scipy.optimize import linear_sum_assignment +from torch.cuda.amp import autocast + +from mmseg.registry import TASK_UTILS +from .base_assigner import BaseAssigner + + +@TASK_UTILS.register_module() +class HungarianAssigner(BaseAssigner): + """Computes one-to-one matching between prediction masks and ground truth. + + This class uses bipartite matching-based assignment to computes an + assignment between the prediction masks and the ground truth. The + assignment result is based on the weighted sum of match costs. 
The + Hungarian algorithm is used to calculate the best matching with the + minimum cost. The prediction masks that are not matched are classified + as background. + + Args: + match_costs (ConfigDict|List[ConfigDict]): Match cost configs. + """ + + def __init__( + self, match_costs: Union[List[Union[dict, ConfigDict]], dict, + ConfigDict] + ) -> None: + + if isinstance(match_costs, dict): + match_costs = [match_costs] + elif isinstance(match_costs, list): + assert len(match_costs) > 0, \ + 'match_costs must not be an empty list.' + + self.match_costs = [ + TASK_UTILS.build(match_cost) for match_cost in match_costs + ] + + def assign(self, pred_instances: InstanceData, gt_instances: InstanceData, + **kwargs): + """Computes one-to-one matching based on the weighted costs. + + This method assigns each query prediction to a ground truth or + background. The assignment first calculates the cost of assigning + each category to each query mask, and then uses the + Hungarian algorithm to find the matching with the minimum total + cost. + + Args: + pred_instances (InstanceData): Instances of model + predictions. It includes "masks", with shape + (n, h, w) or (n, l), and "cls", with shape (n, num_classes+1) + gt_instances (InstanceData): Ground truth of instance + annotations. It includes "labels", with shape (k, ), + and "masks", with shape (k, h, w) or (k, l). + + Returns: + matched_query_inds (Tensor): The indices of matched queries. + matched_label_inds (Tensor): The indices of matched labels. + """ + # compute weighted cost + cost_list = [] + with autocast(enabled=False): + for match_cost in self.match_costs: + cost = match_cost( + pred_instances=pred_instances, gt_instances=gt_instances) + cost_list.append(cost) + cost = torch.stack(cost_list).sum(dim=0) + + device = cost.device + # do Hungarian matching on CPU using linear_sum_assignment + cost = cost.detach().cpu() + if linear_sum_assignment is None: + raise ImportError('Please run "pip install scipy" ' + 'to install scipy first.') + + matched_query_inds, matched_label_inds = linear_sum_assignment(cost) + matched_query_inds = torch.from_numpy(matched_query_inds).to(device) + matched_label_inds = torch.from_numpy(matched_label_inds).to(device) + + return matched_query_inds, matched_label_inds diff --git a/mmseg/models/assigners/match_cost.py b/mmseg/models/assigners/match_cost.py new file mode 100644 index 0000000000000000000000000000000000000000..560df852902fa7a2167cc7cfdf86595bf8d6e3f8 --- /dev/null +++ b/mmseg/models/assigners/match_cost.py @@ -0,0 +1,231 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractmethod +from typing import Union + +import torch +import torch.nn.functional as F +from mmengine.structures import InstanceData +from torch import Tensor + +from mmseg.registry import TASK_UTILS + + +class BaseMatchCost: + """Base match cost class. + + Args: + weight (Union[float, int]): Cost weight. Defaults to 1. + """ + + def __init__(self, weight: Union[float, int] = 1.) -> None: + self.weight = weight + + @abstractmethod + def __call__(self, pred_instances: InstanceData, + gt_instances: InstanceData, **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (InstanceData): Instances of model predictions. + It often includes "labels" and "scores". + gt_instances (InstanceData): Ground truth of instance + annotations. It usually includes "labels". + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts).
+ """ + pass + + +@TASK_UTILS.register_module() +class ClassificationCost(BaseMatchCost): + """ClsSoftmaxCost. + + Args: + weight (Union[float, int]): Cost weight. Defaults to 1. + + Examples: + >>> from mmseg.models.assigners import ClassificationCost + >>> import torch + >>> self = ClassificationCost() + >>> cls_pred = torch.rand(4, 3) + >>> gt_labels = torch.tensor([0, 1, 2]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(cls_pred, gt_labels) + tensor([[-0.3430, -0.3525, -0.3045], + [-0.3077, -0.2931, -0.3992], + [-0.3664, -0.3455, -0.2881], + [-0.3343, -0.2701, -0.3956]]) + """ + + def __init__(self, weight: Union[float, int] = 1) -> None: + super().__init__(weight=weight) + + def __call__(self, pred_instances: InstanceData, + gt_instances: InstanceData, **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (InstanceData): "scores" inside is + predicted classification logits, of shape + (num_queries, num_class). + gt_instances (InstanceData): "labels" inside should have + shape (num_gt, ). + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + assert hasattr(pred_instances, 'scores'), \ + "pred_instances must contain 'scores'" + assert hasattr(gt_instances, 'labels'), \ + "gt_instances must contain 'labels'" + pred_scores = pred_instances.scores + gt_labels = gt_instances.labels + + pred_scores = pred_scores.softmax(-1) + cls_cost = -pred_scores[:, gt_labels] + + return cls_cost * self.weight + + +@TASK_UTILS.register_module() +class DiceCost(BaseMatchCost): + """Cost of mask assignments based on dice losses. + + Args: + pred_act (bool): Whether to apply sigmoid to mask_pred. + Defaults to False. + eps (float): Defaults to 1e-3. + naive_dice (bool): If True, use the naive dice loss + in which the power of the number in the denominator is + the first power. If False, use the second power that + is adopted by K-Net and SOLO. Defaults to True. + weight (Union[float, int]): Cost weight. Defaults to 1. + """ + + def __init__(self, + pred_act: bool = False, + eps: float = 1e-3, + naive_dice: bool = True, + weight: Union[float, int] = 1.) -> None: + super().__init__(weight=weight) + self.pred_act = pred_act + self.eps = eps + self.naive_dice = naive_dice + + def _binary_mask_dice_loss(self, mask_preds: Tensor, + gt_masks: Tensor) -> Tensor: + """ + Args: + mask_preds (Tensor): Mask prediction in shape (num_queries, *). + gt_masks (Tensor): Ground truth in shape (num_gt, *) + store 0 or 1, 0 for negative class and 1 for + positive class. + + Returns: + Tensor: Dice cost matrix in shape (num_queries, num_gt). + """ + mask_preds = mask_preds.flatten(1) + gt_masks = gt_masks.flatten(1).float() + numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks) + if self.naive_dice: + denominator = mask_preds.sum(-1)[:, None] + \ + gt_masks.sum(-1)[None, :] + else: + denominator = mask_preds.pow(2).sum(1)[:, None] + \ + gt_masks.pow(2).sum(1)[None, :] + loss = 1 - (numerator + self.eps) / (denominator + self.eps) + return loss + + def __call__(self, pred_instances: InstanceData, + gt_instances: InstanceData, **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (InstanceData): Predicted instances which + must contain "masks". + gt_instances (InstanceData): Ground truth which must contain + "mask". + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). 
+ """ + assert hasattr(pred_instances, 'masks'), \ + "pred_instances must contain 'masks'" + assert hasattr(gt_instances, 'masks'), \ + "gt_instances must contain 'masks'" + pred_masks = pred_instances.masks + gt_masks = gt_instances.masks + + if self.pred_act: + pred_masks = pred_masks.sigmoid() + dice_cost = self._binary_mask_dice_loss(pred_masks, gt_masks) + return dice_cost * self.weight + + +@TASK_UTILS.register_module() +class CrossEntropyLossCost(BaseMatchCost): + """CrossEntropyLossCost. + + Args: + use_sigmoid (bool): Whether the prediction uses sigmoid + of softmax. Defaults to True. + weight (Union[float, int]): Cost weight. Defaults to 1. + """ + + def __init__(self, + use_sigmoid: bool = True, + weight: Union[float, int] = 1.) -> None: + super().__init__(weight=weight) + self.use_sigmoid = use_sigmoid + + def _binary_cross_entropy(self, cls_pred: Tensor, + gt_labels: Tensor) -> Tensor: + """ + Args: + cls_pred (Tensor): The prediction with shape (num_queries, 1, *) or + (num_queries, *). + gt_labels (Tensor): The learning label of prediction with + shape (num_gt, *). + + Returns: + Tensor: Cross entropy cost matrix in shape (num_queries, num_gt). + """ + cls_pred = cls_pred.flatten(1).float() + gt_labels = gt_labels.flatten(1).float() + n = cls_pred.shape[1] + pos = F.binary_cross_entropy_with_logits( + cls_pred, torch.ones_like(cls_pred), reduction='none') + neg = F.binary_cross_entropy_with_logits( + cls_pred, torch.zeros_like(cls_pred), reduction='none') + cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \ + torch.einsum('nc,mc->nm', neg, 1 - gt_labels) + cls_cost = cls_cost / n + + return cls_cost + + def __call__(self, pred_instances: InstanceData, + gt_instances: InstanceData, **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData`): Predicted instances which + must contain ``masks``. + gt_instances (:obj:`InstanceData`): Ground truth which must contain + ``masks``. + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + assert hasattr(pred_instances, 'masks'), \ + "pred_instances must contain 'masks'" + assert hasattr(gt_instances, 'masks'), \ + "gt_instances must contain 'masks'" + pred_masks = pred_instances.masks + gt_masks = gt_instances.masks + if self.use_sigmoid: + cls_cost = self._binary_cross_entropy(pred_masks, gt_masks) + else: + raise NotImplementedError + + return cls_cost * self.weight diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..784d3dfdb709f6b63f042836b0fe4047271e05a5 --- /dev/null +++ b/mmseg/models/backbones/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .beit import BEiT +from .bisenetv1 import BiSeNetV1 +from .bisenetv2 import BiSeNetV2 +from .cgnet import CGNet +from .ddrnet import DDRNet +from .erfnet import ERFNet +from .fast_scnn import FastSCNN +from .hrnet import HRNet +from .icnet import ICNet +from .mae import MAE +from .mit import MixVisionTransformer +from .mobilenet_v2 import MobileNetV2 +from .mobilenet_v3 import MobileNetV3 +from .mscan import MSCAN +from .pidnet import PIDNet +from .resnest import ResNeSt +from .resnet import ResNet, ResNetV1c, ResNetV1d +from .resnext import ResNeXt +from .stdc import STDCContextPathNet, STDCNet +from .swin import SwinTransformer +from .timm_backbone import TIMMBackbone +from .twins import PCPVT, SVT +from .unet import UNet +from .vit import VisionTransformer +from .vpd import VPD + +__all__ = [ + 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', + 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3', + 'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', + 'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT', + 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN', + 'DDRNet', 'VPD' +] diff --git a/mmseg/models/backbones/__pycache__/__init__.cpython-311.pyc b/mmseg/models/backbones/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0c1fb80d70690a5348c222a0863cd7b696da07c Binary files /dev/null and b/mmseg/models/backbones/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/beit.cpython-311.pyc b/mmseg/models/backbones/__pycache__/beit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..387887c971eeaba8e6b3846a3cd47221bfaa3ea3 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/beit.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/bisenetv1.cpython-311.pyc b/mmseg/models/backbones/__pycache__/bisenetv1.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7865d8e9e274f64b6bc535c58d36d6ed4e8cf00e Binary files /dev/null and b/mmseg/models/backbones/__pycache__/bisenetv1.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/bisenetv2.cpython-311.pyc b/mmseg/models/backbones/__pycache__/bisenetv2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ccfbe9bc737b092a0a2e26881a8e0a8bab870f4 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/bisenetv2.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/cgnet.cpython-311.pyc b/mmseg/models/backbones/__pycache__/cgnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f57e5a909aea1c4e3072504b5ffece116fe06746 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/cgnet.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/ddrnet.cpython-311.pyc b/mmseg/models/backbones/__pycache__/ddrnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e6a72f4704ee21f972d1f0c8fe0212541fac6ed Binary files /dev/null and b/mmseg/models/backbones/__pycache__/ddrnet.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/erfnet.cpython-311.pyc b/mmseg/models/backbones/__pycache__/erfnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ba87614a951707cc89b34ed45022a3febc460f2 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/erfnet.cpython-311.pyc differ diff --git 
a/mmseg/models/backbones/__pycache__/fast_scnn.cpython-311.pyc b/mmseg/models/backbones/__pycache__/fast_scnn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7f738a6530ba444d015a851b6a71cc6ce45c63f Binary files /dev/null and b/mmseg/models/backbones/__pycache__/fast_scnn.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/hrnet.cpython-311.pyc b/mmseg/models/backbones/__pycache__/hrnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..527959d1283cc5d0d68df43cf3b4ec260d05b053 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/hrnet.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/icnet.cpython-311.pyc b/mmseg/models/backbones/__pycache__/icnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c25a75691c85e2751c302835ce898fe2f69c31f Binary files /dev/null and b/mmseg/models/backbones/__pycache__/icnet.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/mae.cpython-311.pyc b/mmseg/models/backbones/__pycache__/mae.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88f00a23d3e1af5f7328450275b9ba0c36447d7a Binary files /dev/null and b/mmseg/models/backbones/__pycache__/mae.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/mit.cpython-311.pyc b/mmseg/models/backbones/__pycache__/mit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e641b66563247c88bcbbbaec089c7d0447fdbe54 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/mit.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/mobilenet_v2.cpython-311.pyc b/mmseg/models/backbones/__pycache__/mobilenet_v2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd8b87325ebe1340a129bbd441ffda1d90f92c05 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/mobilenet_v2.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/mobilenet_v3.cpython-311.pyc b/mmseg/models/backbones/__pycache__/mobilenet_v3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9008fc7f98cc9663cd67295ad9a0b4fdba18dfe Binary files /dev/null and b/mmseg/models/backbones/__pycache__/mobilenet_v3.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/mscan.cpython-311.pyc b/mmseg/models/backbones/__pycache__/mscan.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4361b999a688b64803154433c9a819a2ffd8046 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/mscan.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/pidnet.cpython-311.pyc b/mmseg/models/backbones/__pycache__/pidnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09797022401261302a6c4b457fcd20559078f7d0 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/pidnet.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/resnest.cpython-311.pyc b/mmseg/models/backbones/__pycache__/resnest.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7d7027d49be6b181c31fa56528370aa270e58f5 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/resnest.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/resnet.cpython-311.pyc b/mmseg/models/backbones/__pycache__/resnet.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..9b111e3799e721b5a470a5e7a612619bcadfa9f3 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/resnet.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/resnext.cpython-311.pyc b/mmseg/models/backbones/__pycache__/resnext.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da73f31ec77b81471da685bec29716c4f3dd880c Binary files /dev/null and b/mmseg/models/backbones/__pycache__/resnext.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/stdc.cpython-311.pyc b/mmseg/models/backbones/__pycache__/stdc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9c6071057d8f921e65d33af46c7622a98a48da2 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/stdc.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/swin.cpython-311.pyc b/mmseg/models/backbones/__pycache__/swin.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7852d5940afd75f3c858d9a919a2aff69091253a Binary files /dev/null and b/mmseg/models/backbones/__pycache__/swin.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/timm_backbone.cpython-311.pyc b/mmseg/models/backbones/__pycache__/timm_backbone.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..186e8fbb74af767b3bd186f0622b3d0f8c59e68c Binary files /dev/null and b/mmseg/models/backbones/__pycache__/timm_backbone.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/twins.cpython-311.pyc b/mmseg/models/backbones/__pycache__/twins.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9712bb255f742a5394df2ea85036acddf8eddcfb Binary files /dev/null and b/mmseg/models/backbones/__pycache__/twins.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/unet.cpython-311.pyc b/mmseg/models/backbones/__pycache__/unet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee6fb472b81be6c9276cf7f5dfa52a93b7e6904e Binary files /dev/null and b/mmseg/models/backbones/__pycache__/unet.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/vit.cpython-311.pyc b/mmseg/models/backbones/__pycache__/vit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..471814d015c86576203291851c6df07d3120d94e Binary files /dev/null and b/mmseg/models/backbones/__pycache__/vit.cpython-311.pyc differ diff --git a/mmseg/models/backbones/__pycache__/vpd.cpython-311.pyc b/mmseg/models/backbones/__pycache__/vpd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d0129c34b421b0d976da1ab5f5ab20dac45fff4 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/vpd.cpython-311.pyc differ diff --git a/mmseg/models/backbones/beit.py b/mmseg/models/backbones/beit.py new file mode 100644 index 0000000000000000000000000000000000000000..e5da71e729256a9dd12b70d32886c9db27d9fa3c --- /dev/null +++ b/mmseg/models/backbones/beit.py @@ -0,0 +1,554 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import warnings + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import (constant_init, kaiming_init, + trunc_normal_) +from mmengine.runner.checkpoint import _load_checkpoint +from scipy import interpolate +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.utils import _pair as to_2tuple + +from mmseg.registry import MODELS +from ..utils import PatchEmbed +from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer + + +class BEiTAttention(BaseModule): + """Window-based multi-head self-attention (W-MSA) module with relative + position bias. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): The height and width of the window. + bias (bool): The option to add learnable bias for q, k, v. If bias is + True, it will add learnable bias. If bias is 'qv_bias', it will only + add learnable bias for q, v. If bias is False, it will not add bias + for q, k, v. Defaults to 'qv_bias'. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float): Dropout ratio of output. Default: 0. + init_cfg (dict | None, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + bias='qv_bias', + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None, + **kwargs): + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.bias = bias + self.scale = qk_scale or head_embed_dims**-0.5 + + qkv_bias = bias + if bias == 'qv_bias': + self._init_qv_bias() + qkv_bias = False + + self.window_size = window_size + self._init_rel_pos_embedding() + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + + def _init_qv_bias(self): + self.q_bias = nn.Parameter(torch.zeros(self.embed_dims)) + self.v_bias = nn.Parameter(torch.zeros(self.embed_dims)) + + def _init_rel_pos_embedding(self): + Wh, Ww = self.window_size + # cls to token & token to cls & cls to cls + self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3 + # relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH) + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, self.num_heads)) + + # get pair-wise relative position index for + # each token inside the window + coords_h = torch.arange(Wh) + coords_w = torch.arange(Ww) + # coords shape is (2, Wh, Ww) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + # coords_flatten shape is (2, Wh*Ww) + coords_flatten = torch.flatten(coords, 1) + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :]) + # relative_coords shape is (Wh*Ww, Wh*Ww, 2) + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + # shift to start from 0 + relative_coords[:, :, 0] += Wh - 1 + relative_coords[:, :, 1] += Ww - 1 + relative_coords[:, :, 0] *= 2 * Ww - 1 + relative_position_index = torch.zeros( + size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype) + #
relative_position_index shape is (Wh*Ww+1, Wh*Ww+1) + relative_position_index[1:, 1:] = relative_coords.sum(-1) + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer('relative_position_index', + relative_position_index) + + def init_weights(self): + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x): + """ + Args: + x (tensor): input features with shape of (num_windows*B, N, C). + """ + B, N, C = x.shape + + if self.bias == 'qv_bias': + k_bias = torch.zeros_like(self.v_bias, requires_grad=False) + qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + else: + qkv = self.qkv(x) + + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + if self.relative_position_bias_table is not None: + Wh = self.window_size[0] + Ww = self.window_size[1] + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + Wh * Ww + 1, Wh * Ww + 1, -1) + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww+1, Wh*Ww+1 + attn = attn + relative_position_bias.unsqueeze(0) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class BEiTTransformerEncoderLayer(VisionTransformerEncoderLayer): + """Implements one encoder layer in BEiT. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + bias (bool): The option to add learnable bias for q, k, v. If bias is + True, it will add learnable bias. If bias is 'qv_bias', it will only + add learnable bias for q, v. If bias is False, it will not add bias + for q, k, v. Defaults to 'qv_bias'. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + window_size (tuple[int], optional): The height and width of the window. + Default: None. + init_values (float, optional): Initialize the values of BEiTAttention + and FFN with learnable scaling. Default: None.
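+ + Examples: + >>> # Illustrative sketch only (not from the original file): with + >>> # window_size=(14, 14) the layer expects 14 * 14 patch tokens + >>> # plus one cls token per sequence. + >>> import torch + >>> layer = BEiTTransformerEncoderLayer( + ... embed_dims=768, num_heads=12, feedforward_channels=3072, + ... window_size=(14, 14), init_values=0.1) + >>> x = torch.rand(1, 14 * 14 + 1, 768) + >>> layer(x).shape + torch.Size([1, 197, 768])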
+ """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + bias='qv_bias', + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + window_size=None, + attn_cfg=dict(), + ffn_cfg=dict(add_identity=False), + init_values=None): + attn_cfg.update(dict(window_size=window_size, qk_scale=None)) + + super().__init__( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=feedforward_channels, + attn_drop_rate=attn_drop_rate, + drop_path_rate=0., + drop_rate=0., + num_fcs=num_fcs, + qkv_bias=bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + attn_cfg=attn_cfg, + ffn_cfg=ffn_cfg) + + # NOTE: drop path for stochastic depth, we shall see if + # this is better than dropout here + dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate) + self.drop_path = build_dropout( + dropout_layer) if dropout_layer else nn.Identity() + self.gamma_1 = nn.Parameter( + init_values * torch.ones(embed_dims), requires_grad=True) + self.gamma_2 = nn.Parameter( + init_values * torch.ones(embed_dims), requires_grad=True) + + def build_attn(self, attn_cfg): + self.attn = BEiTAttention(**attn_cfg) + + def forward(self, x): + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) + x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x))) + return x + + +@MODELS.register_module() +class BEiT(BaseModule): + """BERT Pre-Training of Image Transformers. + + Args: + img_size (int | tuple): Input image size. Default: 224. + patch_size (int): The patch size. Default: 16. + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): Embedding dimension. Default: 768. + num_layers (int): Depth of transformer. Default: 12. + num_heads (int): Number of attention heads. Default: 12. + mlp_ratio (int): Ratio of mlp hidden dim to embedding dim. + Default: 4. + out_indices (list | tuple | int): Output from which stages. + Default: -1. + qv_bias (bool): Enable bias for qv if True. Default: True. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.0. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + patch_norm (bool): Whether to add a norm in PatchEmbed Block. + Default: False. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Default: False. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + pretrained (str, optional): Model pretrained path. Default: None. + init_values (float): Initialize the values of BEiTAttention and FFN + with learnable scaling. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
+ """ + + def __init__(self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + out_indices=-1, + qv_bias=True, + attn_drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + patch_norm=False, + final_norm=False, + num_fcs=2, + norm_eval=False, + pretrained=None, + init_values=0.1, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, tuple): + if len(img_size) == 1: + img_size = to_2tuple(img_size[0]) + assert len(img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(img_size)}' + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.in_channels = in_channels + self.img_size = img_size + self.patch_size = patch_size + self.norm_eval = norm_eval + self.pretrained = pretrained + self.num_layers = num_layers + self.embed_dims = embed_dims + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.attn_drop_rate = attn_drop_rate + self.drop_path_rate = drop_path_rate + self.num_fcs = num_fcs + self.qv_bias = qv_bias + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.patch_norm = patch_norm + self.init_values = init_values + self.window_size = (img_size[0] // patch_size, + img_size[1] // patch_size) + self.patch_shape = self.window_size + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + + self._build_patch_embedding() + self._build_layers() + + if isinstance(out_indices, int): + if out_indices == -1: + out_indices = num_layers - 1 + self.out_indices = [out_indices] + elif isinstance(out_indices, list) or isinstance(out_indices, tuple): + self.out_indices = out_indices + else: + raise TypeError('out_indices must be type of int, list or tuple') + + self.final_norm = final_norm + if final_norm: + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + + def _build_patch_embedding(self): + """Build patch embedding layer.""" + self.patch_embed = PatchEmbed( + in_channels=self.in_channels, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=self.patch_size, + stride=self.patch_size, + padding=0, + norm_cfg=self.norm_cfg if self.patch_norm else None, + init_cfg=None) + + def _build_layers(self): + """Build transformer encoding layers.""" + + dpr = [ + x.item() + for x in torch.linspace(0, self.drop_path_rate, self.num_layers) + ] + self.layers = ModuleList() + for i in range(self.num_layers): + self.layers.append( + BEiTTransformerEncoderLayer( + embed_dims=self.embed_dims, + num_heads=self.num_heads, + feedforward_channels=self.mlp_ratio * self.embed_dims, + attn_drop_rate=self.attn_drop_rate, + drop_path_rate=dpr[i], + num_fcs=self.num_fcs, + bias='qv_bias' if self.qv_bias else False, + act_cfg=self.act_cfg, + norm_cfg=self.norm_cfg, + window_size=self.window_size, + init_values=self.init_values)) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def _geometric_sequence_interpolation(self, src_size, dst_size, sequence, + num): + """Get new sequence via geometric sequence interpolation. 
+ + Args: + src_size (int): Pos_embedding size in pre-trained model. + dst_size (int): Pos_embedding size in the current model. + sequence (tensor): The relative position bias of the pretrain + model after removing the extra tokens. + num (int): Number of attention heads. + Returns: + new_sequence (tensor): The pre-trained relative position bias + interpolated to the size of the current model via + geometric sequence interpolation. + """ + + def geometric_progression(a, r, n): + return a * (1.0 - r**n) / (1.0 - r) + + # Binary search for the common ratio q of the geometric progression. + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + # The position of each interpolated point is determined + # by the ratio q found by the binary search above. + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q**(i + 1) + r_ids = [-_ for _ in reversed(dis)] + x = r_ids + [0] + dis + y = r_ids + [0] + dis + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + # Fit a cubic interpolation function per attention head and resample + # it on the regular target grid. (scipy.interpolate.interp2d is + # deprecated in recent SciPy releases.) + new_sequence = [] + for i in range(num): + z = sequence[:, i].view(src_size, src_size).float().numpy() + f = interpolate.interp2d(x, y, z, kind='cubic') + new_sequence.append( + torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(sequence)) + new_sequence = torch.cat(new_sequence, dim=-1) + return new_sequence + + def resize_rel_pos_embed(self, checkpoint): + """Resize relative pos_embed weights. + + This function is modified from + https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/checkpoint.py. # noqa: E501 + Copyright (c) Microsoft Corporation + Licensed under the MIT License + Args: + checkpoint (dict): Key and value of the pretrain model. + Returns: + state_dict (dict): Interpolate the relative pos_embed weights + in the pre-train model to the current model size. + """ + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + all_keys = list(state_dict.keys()) + for key in all_keys: + if 'relative_position_index' in key: + state_dict.pop(key) + # Geometric sequence interpolation is adopted so that the + # sampling stays dense near the center of pos_bias and becomes + # sparser towards the edge area. + if 'relative_position_bias_table' in key: + rel_pos_bias = state_dict[key] + src_num_pos, num_attn_heads = rel_pos_bias.size() + dst_num_pos, _ = self.state_dict()[key].size() + dst_patch_shape = self.patch_shape + if dst_patch_shape[0] != dst_patch_shape[1]: + raise NotImplementedError() + # Count the number of extra tokens.
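+ # For BEiT these are the 3 special relative positions added in + # `_init_rel_pos_embedding`: cls-to-token, token-to-cls and + # cls-to-cls.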
+ num_extra_tokens = dst_num_pos - ( + dst_patch_shape[0] * 2 - 1) * ( + dst_patch_shape[1] * 2 - 1) + src_size = int((src_num_pos - num_extra_tokens)**0.5) + dst_size = int((dst_num_pos - num_extra_tokens)**0.5) + if src_size != dst_size: + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + new_rel_pos_bias = self._geometric_sequence_interpolation( + src_size, dst_size, rel_pos_bias, num_attn_heads) + new_rel_pos_bias = torch.cat( + (new_rel_pos_bias, extra_tokens), dim=0) + state_dict[key] = new_rel_pos_bias + + return state_dict + + def init_weights(self): + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + if (isinstance(self.init_cfg, dict) + and self.init_cfg.get('type') == 'Pretrained'): + checkpoint = _load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + state_dict = self.resize_rel_pos_embed(checkpoint) + self.load_state_dict(state_dict, False) + elif self.init_cfg is not None: + super().init_weights() + else: + # We only implement the 'jax_impl' initialization implemented at + # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 + # Copyright 2019 Ross Wightman + # Licensed under the Apache License, Version 2.0 (the "License") + trunc_normal_(self.cls_token, std=.02) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + if 'ffn' in n: + nn.init.normal_(m.bias, mean=0., std=1e-6) + else: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + + def forward(self, inputs): + B = inputs.shape[0] + + x, hw_shape = self.patch_embed(inputs) + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i == len(self.layers) - 1: + if self.final_norm: + x = self.norm1(x) + if i in self.out_indices: + # Remove class token and reshape token for decoder head + out = x[:, 1:] + B, _, C = out.shape + out = out.reshape(B, hw_shape[0], hw_shape[1], + C).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return tuple(outs) + + def train(self, mode=True): + super().train(mode) + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.LayerNorm): + m.eval() diff --git a/mmseg/models/backbones/bisenetv1.py b/mmseg/models/backbones/bisenetv1.py new file mode 100644 index 0000000000000000000000000000000000000000..ca58bf9c597836937bc384739ff77001b5402942 --- /dev/null +++ b/mmseg/models/backbones/bisenetv1.py @@ -0,0 +1,332 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +class SpatialPath(BaseModule): + """Spatial Path to preserve the spatial size of the original input image + and encode affluent spatial information. + + Args: + in_channels(int): The number of channels of input + image. Default: 3. 
+ num_channels (Tuple[int]): The number of channels of + each layer in Spatial Path. + Default: (64, 64, 64, 128). + Returns: + x (torch.Tensor): Feature map for Feature Fusion Module. + """ + + def __init__(self, + in_channels=3, + num_channels=(64, 64, 64, 128), + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert len(num_channels) == 4, 'Length of input channels \ + of Spatial Path must be 4!' + + self.layers = [] + for i in range(len(num_channels)): + layer_name = f'layer{i + 1}' + self.layers.append(layer_name) + if i == 0: + self.add_module( + layer_name, + ConvModule( + in_channels=in_channels, + out_channels=num_channels[i], + kernel_size=7, + stride=2, + padding=3, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + elif i == len(num_channels) - 1: + self.add_module( + layer_name, + ConvModule( + in_channels=num_channels[i - 1], + out_channels=num_channels[i], + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + else: + self.add_module( + layer_name, + ConvModule( + in_channels=num_channels[i - 1], + out_channels=num_channels[i], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, x): + for i, layer_name in enumerate(self.layers): + layer_stage = getattr(self, layer_name) + x = layer_stage(x) + return x + + +class AttentionRefinementModule(BaseModule): + """Attention Refinement Module (ARM) to refine the features of each stage. + + Args: + in_channels (int): The number of input channels. + out_channel (int): The number of output channels. + Returns: + x_out (torch.Tensor): Feature map of Attention Refinement Module. + """ + + def __init__(self, + in_channels, + out_channel, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.conv_layer = ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.atten_conv_layer = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + ConvModule( + in_channels=out_channel, + out_channels=out_channel, + kernel_size=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), nn.Sigmoid()) + + def forward(self, x): + x = self.conv_layer(x) + x_atten = self.atten_conv_layer(x) + x_out = x * x_atten + return x_out + + +class ContextPath(BaseModule): + """Context Path to provide sufficient receptive field. + + Args: + backbone_cfg (dict): Config of backbone of + Context Path. + context_channels (Tuple[int]): The channel numbers + of various modules in Context Path. + Default: (128, 256, 512). + align_corners (bool, optional): The align_corners argument of + resize operation. Default: False. + Returns: + x_16_up, x_32_up (torch.Tensor, torch.Tensor): Two feature maps + upsampled from the 1/16 and 1/32 downsampled + feature maps. These two feature maps are used for Feature + Fusion Module and Auxiliary Head. + """ + + def __init__(self, + backbone_cfg, + context_channels=(128, 256, 512), + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert len(context_channels) == 3, 'Length of input channels \ + of Context Path must be 3!'
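+ + # The backbone wrapped here must return four stage feature maps at + # 1/4, 1/8, 1/16 and 1/32 resolution; they are unpacked in `forward` + # below.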
+ self.backbone = MODELS.build(backbone_cfg) + + self.align_corners = align_corners + self.arm16 = AttentionRefinementModule(context_channels[1], + context_channels[0]) + self.arm32 = AttentionRefinementModule(context_channels[2], + context_channels[0]) + self.conv_head32 = ConvModule( + in_channels=context_channels[0], + out_channels=context_channels[0], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv_head16 = ConvModule( + in_channels=context_channels[0], + out_channels=context_channels[0], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.gap_conv = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + ConvModule( + in_channels=context_channels[2], + out_channels=context_channels[0], + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, x): + x_4, x_8, x_16, x_32 = self.backbone(x) + x_gap = self.gap_conv(x_32) + + x_32_arm = self.arm32(x_32) + x_32_sum = x_32_arm + x_gap + x_32_up = resize(input=x_32_sum, size=x_16.shape[2:], mode='nearest') + x_32_up = self.conv_head32(x_32_up) + + x_16_arm = self.arm16(x_16) + x_16_sum = x_16_arm + x_32_up + x_16_up = resize(input=x_16_sum, size=x_8.shape[2:], mode='nearest') + x_16_up = self.conv_head16(x_16_up) + + return x_16_up, x_32_up + + +class FeatureFusionModule(BaseModule): + """Feature Fusion Module to fuse the low-level output feature of Spatial + Path and the high-level output feature of Context Path. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + Returns: + x_out (torch.Tensor): Feature map of Feature Fusion Module. + """ + + def __init__(self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.conv1 = ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + self.conv_atten = nn.Sequential( + ConvModule( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), nn.Sigmoid()) + + def forward(self, x_sp, x_cp): + x_concat = torch.cat([x_sp, x_cp], dim=1) + x_fuse = self.conv1(x_concat) + x_atten = self.gap(x_fuse) + # Note: No BN and more 1x1 conv in paper. + x_atten = self.conv_atten(x_atten) + x_atten = x_fuse * x_atten + x_out = x_atten + x_fuse + return x_out + + +@MODELS.register_module() +class BiSeNetV1(BaseModule): + """BiSeNetV1 backbone. + + This backbone is the implementation of `BiSeNet: Bilateral + Segmentation Network for Real-time Semantic + Segmentation <https://arxiv.org/abs/1808.00897>`_. + + Args: + backbone_cfg (dict): Config of backbone of + Context Path. + in_channels (int): The number of channels of input + image. Default: 3. + spatial_channels (Tuple[int]): Channel numbers of + various layers in Spatial Path. + Default: (64, 64, 64, 128). + context_channels (Tuple[int]): Channel numbers of + various modules in Context Path. + Default: (128, 256, 512). + out_indices (Tuple[int] | int, optional): Output from which stages. + Default: (0, 1, 2). + align_corners (bool, optional): The align_corners argument of + the resize operation. + Default: False.
+ out_channels (int): The number of output channels. + It must be the same as `in_channels` of decode_head. + Default: 256. + """ + + def __init__(self, + backbone_cfg, + in_channels=3, + spatial_channels=(64, 64, 64, 128), + context_channels=(128, 256, 512), + out_indices=(0, 1, 2), + align_corners=False, + out_channels=256, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU'), + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + assert len(spatial_channels) == 4, 'Length of input channels \ + of Spatial Path must be 4!' + + assert len(context_channels) == 3, 'Length of input channels \ + of Context Path must be 3!' + + self.out_indices = out_indices + self.align_corners = align_corners + self.context_path = ContextPath(backbone_cfg, context_channels, + self.align_corners) + self.spatial_path = SpatialPath(in_channels, spatial_channels) + self.ffm = FeatureFusionModule(context_channels[1], out_channels) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + def forward(self, x): + # stole refactoring code from Coin Cheung, thanks + x_context8, x_context16 = self.context_path(x) + x_spatial = self.spatial_path(x) + x_fuse = self.ffm(x_spatial, x_context8) + + outs = [x_fuse, x_context8, x_context16] + outs = [outs[i] for i in self.out_indices] + return tuple(outs) diff --git a/mmseg/models/backbones/bisenetv2.py b/mmseg/models/backbones/bisenetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..32aa49822f7d0c3bd4839b3796a15689e1f4cbc0 --- /dev/null +++ b/mmseg/models/backbones/bisenetv2.py @@ -0,0 +1,622 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, + build_activation_layer, build_norm_layer) +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +class DetailBranch(BaseModule): + """Detail Branch with wide channels and shallow layers to capture low-level + details and generate high-resolution feature representation. + + Args: + detail_channels (Tuple[int]): Channel numbers of each stage + in Detail Branch; the paper uses 3 stages. + Default: (64, 64, 128). + in_channels (int): Number of channels of input image. Default: 3. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): Feature map of Detail Branch.
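+ + Examples: + >>> # Illustrative sketch only (not from the original file): three + >>> # stride-2 stages reduce the input to 1/8 resolution with + >>> # detail_channels[-1] channels. + >>> import torch + >>> self = DetailBranch() + >>> x = torch.rand(1, 3, 64, 64) + >>> self(x).shape + torch.Size([1, 128, 8, 8])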
+ """ + + def __init__(self, + detail_channels=(64, 64, 128), + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + detail_branch = [] + for i in range(len(detail_channels)): + if i == 0: + detail_branch.append( + nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=detail_channels[i], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=detail_channels[i], + out_channels=detail_channels[i], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg))) + else: + detail_branch.append( + nn.Sequential( + ConvModule( + in_channels=detail_channels[i - 1], + out_channels=detail_channels[i], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=detail_channels[i], + out_channels=detail_channels[i], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=detail_channels[i], + out_channels=detail_channels[i], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg))) + self.detail_branch = nn.ModuleList(detail_branch) + + def forward(self, x): + for stage in self.detail_branch: + x = stage(x) + return x + + +class StemBlock(BaseModule): + """Stem Block at the beginning of Semantic Branch. + + Args: + in_channels (int): Number of input channels. + Default: 3. + out_channels (int): Number of output channels. + Default: 16. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): First feature map in Semantic Branch. + """ + + def __init__(self, + in_channels=3, + out_channels=16, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.conv_first = ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.convs = nn.Sequential( + ConvModule( + in_channels=out_channels, + out_channels=out_channels // 2, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=out_channels // 2, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.pool = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1, ceil_mode=False) + self.fuse_last = ConvModule( + in_channels=out_channels * 2, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x): + x = self.conv_first(x) + x_left = self.convs(x) + x_right = self.pool(x) + x = self.fuse_last(torch.cat([x_left, x_right], dim=1)) + return x + + +class GELayer(BaseModule): + """Gather-and-Expansion Layer. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + exp_ratio (int): Expansion ratio for middle channels. + Default: 6. + stride (int): Stride of GELayer. 
Default: 1. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): Intermediate feature map in + Semantic Branch. + """ + + def __init__(self, + in_channels, + out_channels, + exp_ratio=6, + stride=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + mid_channel = in_channels * exp_ratio + self.conv1 = ConvModule( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if stride == 1: + self.dwconv = nn.Sequential( + # ReLU in ConvModule not shown in paper + ConvModule( + in_channels=in_channels, + out_channels=mid_channel, + kernel_size=3, + stride=stride, + padding=1, + groups=in_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.shortcut = None + else: + self.dwconv = nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=mid_channel, + kernel_size=3, + stride=stride, + padding=1, + groups=in_channels, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + # ReLU in ConvModule not shown in paper + ConvModule( + in_channels=mid_channel, + out_channels=mid_channel, + kernel_size=3, + stride=1, + padding=1, + groups=mid_channel, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ) + self.shortcut = nn.Sequential( + DepthwiseSeparableConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + padding=1, + dw_norm_cfg=norm_cfg, + dw_act_cfg=None, + pw_norm_cfg=norm_cfg, + pw_act_cfg=None, + )) + + self.conv2 = nn.Sequential( + ConvModule( + in_channels=mid_channel, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None, + )) + + self.act = build_activation_layer(act_cfg) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.dwconv(x) + x = self.conv2(x) + if self.shortcut is not None: + shortcut = self.shortcut(identity) + x = x + shortcut + else: + x = x + identity + x = self.act(x) + return x + + +class CEBlock(BaseModule): + """Context Embedding Block for large receptive field in Semantic Branch. + + Args: + in_channels (int): Number of input channels. + Default: 3. + out_channels (int): Number of output channels. + Default: 16. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): Last feature map in Semantic Branch.
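+ + Examples: + >>> # Illustrative sketch only (not from the original file); + >>> # in_channels must equal out_channels so the pooled context can + >>> # be broadcast-added back to the identity map. + >>> import torch + >>> self = CEBlock(in_channels=16, out_channels=16) + >>> x = torch.rand(1, 16, 8, 8) + >>> self(x).shape + torch.Size([1, 16, 8, 8])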
+ """ + + def __init__(self, + in_channels=3, + out_channels=16, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.gap = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + build_norm_layer(norm_cfg, self.in_channels)[1]) + self.conv_gap = ConvModule( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + # Note: in paper here is naive conv2d, no bn-relu + self.conv_last = ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x): + identity = x + x = self.gap(x) + x = self.conv_gap(x) + x = identity + x + x = self.conv_last(x) + return x + + +class SemanticBranch(BaseModule): + """Semantic Branch which is lightweight with narrow channels and deep + layers to obtain high-level semantic context. + + Args: + semantic_channels(Tuple[int]): Size of channel numbers of + various stages in Semantic Branch. + Default: (16, 32, 64, 128). + in_channels (int): Number of channels of input image. Default: 3. + exp_ratio (int): Expansion ratio for middle channels. + Default: 6. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + semantic_outs (List[torch.Tensor]): List of several feature maps + for auxiliary heads (Booster) and Bilateral + Guided Aggregation Layer. + """ + + def __init__(self, + semantic_channels=(16, 32, 64, 128), + in_channels=3, + exp_ratio=6, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.semantic_channels = semantic_channels + self.semantic_stages = [] + for i in range(len(semantic_channels)): + stage_name = f'stage{i + 1}' + self.semantic_stages.append(stage_name) + if i == 0: + self.add_module( + stage_name, + StemBlock(self.in_channels, semantic_channels[i])) + elif i == (len(semantic_channels) - 1): + self.add_module( + stage_name, + nn.Sequential( + GELayer(semantic_channels[i - 1], semantic_channels[i], + exp_ratio, 2), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1))) + else: + self.add_module( + stage_name, + nn.Sequential( + GELayer(semantic_channels[i - 1], semantic_channels[i], + exp_ratio, 2), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1))) + + self.add_module(f'stage{len(semantic_channels)}_CEBlock', + CEBlock(semantic_channels[-1], semantic_channels[-1])) + self.semantic_stages.append(f'stage{len(semantic_channels)}_CEBlock') + + def forward(self, x): + semantic_outs = [] + for stage_name in self.semantic_stages: + semantic_stage = getattr(self, stage_name) + x = semantic_stage(x) + semantic_outs.append(x) + return semantic_outs + + +class BGALayer(BaseModule): + """Bilateral Guided Aggregation Layer to fuse the complementary information + from both Detail Branch and Semantic Branch. + + Args: + out_channels (int): Number of output channels. + Default: 128. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). 
+ act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + output (torch.Tensor): Output feature map for Segment heads. + """ + + def __init__(self, + out_channels=128, + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.out_channels = out_channels + self.align_corners = align_corners + self.detail_dwconv = nn.Sequential( + DepthwiseSeparableConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + dw_norm_cfg=norm_cfg, + dw_act_cfg=None, + pw_norm_cfg=None, + pw_act_cfg=None, + )) + self.detail_down = nn.Sequential( + ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)) + self.semantic_conv = nn.Sequential( + ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None)) + self.semantic_dwconv = nn.Sequential( + DepthwiseSeparableConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + dw_norm_cfg=norm_cfg, + dw_act_cfg=None, + pw_norm_cfg=None, + pw_act_cfg=None, + )) + self.conv = ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + inplace=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + + def forward(self, x_d, x_s): + detail_dwconv = self.detail_dwconv(x_d) + detail_down = self.detail_down(x_d) + semantic_conv = self.semantic_conv(x_s) + semantic_dwconv = self.semantic_dwconv(x_s) + semantic_conv = resize( + input=semantic_conv, + size=detail_dwconv.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv) + fuse_2 = detail_down * torch.sigmoid(semantic_dwconv) + fuse_2 = resize( + input=fuse_2, + size=fuse_1.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + output = self.conv(fuse_1 + fuse_2) + return output + + +@MODELS.register_module() +class BiSeNetV2(BaseModule): + """BiSeNetV2: Bilateral Network with Guided Aggregation for + Real-time Semantic Segmentation. + + This backbone is the implementation of + `BiSeNetV2 <https://arxiv.org/abs/2004.02147>`_. + + Args: + in_channels (int): Number of channels of the input image. Default: 3. + detail_channels (Tuple[int], optional): Channels of each stage + in Detail Branch. Default: (64, 64, 128). + semantic_channels (Tuple[int], optional): Channels of each stage + in Semantic Branch. Default: (16, 32, 64, 128). + See Table 1 and Figure 3 of the paper for more details. + semantic_expansion_ratio (int, optional): The expansion factor + expanding channel number of middle channels in Semantic Branch. + Default: 6. + bga_channels (int, optional): Number of middle channels in + Bilateral Guided Aggregation Layer. Default: 128. + out_indices (Tuple[int] | int, optional): Output from which stages. + Default: (0, 1, 2, 3, 4). + align_corners (bool, optional): The align_corners argument of + resize operation in Bilateral Guided Aggregation Layer. + Default: False. + conv_cfg (dict | None): Config of conv layers. + Default: None.
+ norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels=3, + detail_channels=(64, 64, 128), + semantic_channels=(16, 32, 64, 128), + semantic_expansion_ratio=6, + bga_channels=128, + out_indices=(0, 1, 2, 3, 4), + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + if init_cfg is None: + init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ] + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_indices = out_indices + self.detail_channels = detail_channels + self.semantic_channels = semantic_channels + self.semantic_expansion_ratio = semantic_expansion_ratio + self.bga_channels = bga_channels + self.align_corners = align_corners + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.detail = DetailBranch(self.detail_channels, self.in_channels) + self.semantic = SemanticBranch(self.semantic_channels, + self.in_channels, + self.semantic_expansion_ratio) + self.bga = BGALayer(self.bga_channels, self.align_corners) + + def forward(self, x): + # stole refactoring code from Coin Cheung, thanks + x_detail = self.detail(x) + x_semantic_lst = self.semantic(x) + x_head = self.bga(x_detail, x_semantic_lst[-1]) + outs = [x_head] + x_semantic_lst[:-1] + outs = [outs[i] for i in self.out_indices] + return tuple(outs) diff --git a/mmseg/models/backbones/cgnet.py b/mmseg/models/backbones/cgnet.py new file mode 100644 index 0000000000000000000000000000000000000000..b74b494f53466d1c608e50d088632aa952a5e534 --- /dev/null +++ b/mmseg/models/backbones/cgnet.py @@ -0,0 +1,372 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer +from mmengine.model import BaseModule +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmseg.registry import MODELS + + +class GlobalContextExtractor(nn.Module): + """Global Context Extractor for CGNet. + + This class is employed to refine the joint feature of both local feature + and surrounding context. + + Args: + channel (int): Number of input feature channels. + reduction (int): Reductions for global context extractor. Default: 16. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, channel, reduction=16, with_cp=False): + super().__init__() + self.channel = channel + self.reduction = reduction + assert reduction >= 1 and channel >= reduction + self.with_cp = with_cp + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), nn.Sigmoid()) + + def forward(self, x): + + def _inner_forward(x): + num_batch, num_channel = x.size()[:2] + y = self.avg_pool(x).view(num_batch, num_channel) + y = self.fc(y).view(num_batch, num_channel, 1, 1) + return x * y + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class ContextGuidedBlock(nn.Module): + """Context Guided Block for CGNet. 
+ + This class consists of four components: local feature extractor, + surrounding feature extractor, joint feature extractor and global + context extractor. + + Args: + in_channels (int): Number of input feature channels. + out_channels (int): Number of output feature channels. + dilation (int): Dilation rate for surrounding context extractor. + Default: 2. + reduction (int): Reduction for global context extractor. Default: 16. + skip_connect (bool): Add input to output or not. Default: True. + downsample (bool): Downsample the input to 1/2 or not. Default: False. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='PReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + in_channels, + out_channels, + dilation=2, + reduction=16, + skip_connect=True, + downsample=False, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='PReLU'), + with_cp=False): + super().__init__() + self.with_cp = with_cp + self.downsample = downsample + + channels = out_channels if downsample else out_channels // 2 + if 'type' in act_cfg and act_cfg['type'] == 'PReLU': + act_cfg['num_parameters'] = channels + kernel_size = 3 if downsample else 1 + stride = 2 if downsample else 1 + padding = (kernel_size - 1) // 2 + + self.conv1x1 = ConvModule( + in_channels, + channels, + kernel_size, + stride, + padding, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.f_loc = build_conv_layer( + conv_cfg, + channels, + channels, + kernel_size=3, + padding=1, + groups=channels, + bias=False) + self.f_sur = build_conv_layer( + conv_cfg, + channels, + channels, + kernel_size=3, + padding=dilation, + groups=channels, + dilation=dilation, + bias=False) + + self.bn = build_norm_layer(norm_cfg, 2 * channels)[1] + self.activate = nn.PReLU(2 * channels) + + if downsample: + self.bottleneck = build_conv_layer( + conv_cfg, + 2 * channels, + out_channels, + kernel_size=1, + bias=False) + + self.skip_connect = skip_connect and not downsample + self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp) + + def forward(self, x): + + def _inner_forward(x): + out = self.conv1x1(x) + loc = self.f_loc(out) + sur = self.f_sur(out) + + joi_feat = torch.cat([loc, sur], 1) # the joint feature + joi_feat = self.bn(joi_feat) + joi_feat = self.activate(joi_feat) + if self.downsample: + joi_feat = self.bottleneck(joi_feat) # channel = out_channels + # f_glo is employed to refine the joint feature + out = self.f_glo(joi_feat) + + if self.skip_connect: + return x + out + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class InputInjection(nn.Module): + """Downsampling module for CGNet.""" + + def __init__(self, num_downsampling): + super().__init__() + self.pool = nn.ModuleList() + for i in range(num_downsampling): + self.pool.append(nn.AvgPool2d(3, stride=2, padding=1)) + + def forward(self, x): + for pool in self.pool: + x = pool(x) + return x + + +@MODELS.register_module() +class CGNet(BaseModule): + """CGNet backbone. + + This backbone is the implementation of `A Light-weight Context Guided + Network for Semantic Segmentation <https://arxiv.org/abs/1811.08201>`_.
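+
+    Example (illustrative; the expected shapes are not part of the original
+    docstring but follow from ``forward`` below with the default
+    ``num_channels=(32, 64, 128)``, e.g. stage 0 concatenates the stem
+    output with the injected input: 32 + 3 = 35 channels at 1/2 resolution):
+
+        >>> import torch
+        >>> from mmseg.models import CGNet
+        >>> self = CGNet(in_channels=3, num_channels=(32, 64, 128))
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 64, 64)
+        >>> outputs = self.forward(inputs)
+        >>> for out in outputs:
+        ...     print(tuple(out.shape))
+        (1, 35, 32, 32)
+        (1, 131, 16, 16)
+        (1, 256, 8, 8)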
+ + Args: + in_channels (int): Number of input image channels. Normally 3. + num_channels (tuple[int]): Numbers of feature channels at each stages. + Default: (32, 64, 128). + num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2. + Default: (3, 21). + dilations (tuple[int]): Dilation rate for surrounding context + extractors at stage 1 and stage 2. Default: (2, 4). + reductions (tuple[int]): Reductions for global context extractors at + stage 1 and stage 2. Default: (8, 16). + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='PReLU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels=3, + num_channels=(32, 64, 128), + num_blocks=(3, 21), + dilations=(2, 4), + reductions=(8, 16), + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='PReLU'), + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None): + + super().__init__(init_cfg) + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer=['Conv2d', 'Linear']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']), + dict(type='Constant', val=0, layer='PReLU') + ] + else: + raise TypeError('pretrained must be a str or None') + + self.in_channels = in_channels + self.num_channels = num_channels + assert isinstance(self.num_channels, tuple) and len( + self.num_channels) == 3 + self.num_blocks = num_blocks + assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2 + self.dilations = dilations + assert isinstance(self.dilations, tuple) and len(self.dilations) == 2 + self.reductions = reductions + assert isinstance(self.reductions, tuple) and len(self.reductions) == 2 + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU': + self.act_cfg['num_parameters'] = num_channels[0] + self.norm_eval = norm_eval + self.with_cp = with_cp + + cur_channels = in_channels + self.stem = nn.ModuleList() + for i in range(3): + self.stem.append( + ConvModule( + cur_channels, + num_channels[0], + 3, + 2 if i == 0 else 1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + cur_channels = num_channels[0] + + self.inject_2x = InputInjection(1) # down-sample for Input, factor=2 + self.inject_4x = InputInjection(2) # down-sample for Input, factor=4 + + cur_channels += in_channels + self.norm_prelu_0 = nn.Sequential( + build_norm_layer(norm_cfg, cur_channels)[1], + nn.PReLU(cur_channels)) + + # stage 1 + self.level1 = nn.ModuleList() + for i in 
range(num_blocks[0]): + self.level1.append( + ContextGuidedBlock( + cur_channels if i == 0 else num_channels[1], + num_channels[1], + dilations[0], + reductions[0], + downsample=(i == 0), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + with_cp=with_cp)) # CG block + + cur_channels = 2 * num_channels[1] + in_channels + self.norm_prelu_1 = nn.Sequential( + build_norm_layer(norm_cfg, cur_channels)[1], + nn.PReLU(cur_channels)) + + # stage 2 + self.level2 = nn.ModuleList() + for i in range(num_blocks[1]): + self.level2.append( + ContextGuidedBlock( + cur_channels if i == 0 else num_channels[2], + num_channels[2], + dilations[1], + reductions[1], + downsample=(i == 0), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + with_cp=with_cp)) # CG block + + cur_channels = 2 * num_channels[2] + self.norm_prelu_2 = nn.Sequential( + build_norm_layer(norm_cfg, cur_channels)[1], + nn.PReLU(cur_channels)) + + def forward(self, x): + output = [] + + # stage 0 + inp_2x = self.inject_2x(x) + inp_4x = self.inject_4x(x) + for layer in self.stem: + x = layer(x) + x = self.norm_prelu_0(torch.cat([x, inp_2x], 1)) + output.append(x) + + # stage 1 + for i, layer in enumerate(self.level1): + x = layer(x) + if i == 0: + down1 = x + x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1)) + output.append(x) + + # stage 2 + for i, layer in enumerate(self.level2): + x = layer(x) + if i == 0: + down2 = x + x = self.norm_prelu_2(torch.cat([down2, x], 1)) + output.append(x) + + return output + + def train(self, mode=True): + """Convert the model into training mode while keeping the + normalization layers frozen.""" + super().train(mode) + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval() has effect on BatchNorm layers only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmseg/models/backbones/ddrnet.py b/mmseg/models/backbones/ddrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4508aade82b484abfcca593825649031db7cbdd0 --- /dev/null +++ b/mmseg/models/backbones/ddrnet.py @@ -0,0 +1,222 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule, build_norm_layer +from mmengine.model import BaseModule + +from mmseg.models.utils import DAPPM, BasicBlock, Bottleneck, resize +from mmseg.registry import MODELS +from mmseg.utils import OptConfigType + + +@MODELS.register_module() +class DDRNet(BaseModule): + """DDRNet backbone. + + This backbone is the implementation of `Deep Dual-resolution Networks for + Real-time and Accurate Semantic Segmentation of Road Scenes + <https://arxiv.org/abs/2101.06085>`_. + Modified from https://github.com/ydhongHIT/DDRNet. + + Args: + in_channels (int): Number of input image channels. Default: 3. + channels (int): The base channels of DDRNet. Default: 32. + ppm_channels (int): The channels of PPM module. Default: 128. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + norm_cfg (dict): Config dict to build norm layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU', inplace=True). + init_cfg (dict, optional): Initialization config dict. + Default: None.
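+
+    Example (illustrative; the expected shape is derived from the code
+    below, assuming the default ``channels=32`` and an expansion factor
+    of 2 in the ``Bottleneck`` block imported above, so the fused output
+    has ``channels * 4 = 128`` channels at 1/8 input resolution):
+
+        >>> import torch
+        >>> from mmseg.models import DDRNet
+        >>> self = DDRNet(channels=32)
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 64, 64)
+        >>> output = self.forward(inputs)
+        >>> tuple(output.shape)
+        (1, 128, 8, 8)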
+ """ + + def __init__(self, + in_channels: int = 3, + channels: int = 32, + ppm_channels: int = 128, + align_corners: bool = False, + norm_cfg: OptConfigType = dict(type='BN', requires_grad=True), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + + self.in_channels = in_channels + self.ppm_channels = ppm_channels + + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.align_corners = align_corners + + # stage 0-2 + self.stem = self._make_stem_layer(in_channels, channels, num_blocks=2) + self.relu = nn.ReLU() + + # low resolution(context) branch + self.context_branch_layers = nn.ModuleList() + for i in range(3): + self.context_branch_layers.append( + self._make_layer( + block=BasicBlock if i < 2 else Bottleneck, + inplanes=channels * 2**(i + 1), + planes=channels * 8 if i > 0 else channels * 4, + num_blocks=2 if i < 2 else 1, + stride=2)) + + # bilateral fusion + self.compression_1 = ConvModule( + channels * 4, + channels * 2, + kernel_size=1, + norm_cfg=self.norm_cfg, + act_cfg=None) + self.down_1 = ConvModule( + channels * 2, + channels * 4, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=None) + + self.compression_2 = ConvModule( + channels * 8, + channels * 2, + kernel_size=1, + norm_cfg=self.norm_cfg, + act_cfg=None) + self.down_2 = nn.Sequential( + ConvModule( + channels * 2, + channels * 4, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + ConvModule( + channels * 4, + channels * 8, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=None)) + + # high resolution(spatial) branch + self.spatial_branch_layers = nn.ModuleList() + for i in range(3): + self.spatial_branch_layers.append( + self._make_layer( + block=BasicBlock if i < 2 else Bottleneck, + inplanes=channels * 2, + planes=channels * 2, + num_blocks=2 if i < 2 else 1, + )) + + self.spp = DAPPM( + channels * 16, ppm_channels, channels * 4, num_scales=5) + + def _make_stem_layer(self, in_channels, channels, num_blocks): + layers = [ + ConvModule( + in_channels, + channels, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + ConvModule( + channels, + channels, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + ] + + layers.extend([ + self._make_layer(BasicBlock, channels, channels, num_blocks), + nn.ReLU(), + self._make_layer( + BasicBlock, channels, channels * 2, num_blocks, stride=2), + nn.ReLU(), + ]) + + return nn.Sequential(*layers) + + def _make_layer(self, block, inplanes, planes, num_blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(self.norm_cfg, planes * block.expansion)[1]) + + layers = [ + block( + in_channels=inplanes, + channels=planes, + stride=stride, + downsample=downsample) + ] + inplanes = planes * block.expansion + for i in range(1, num_blocks): + layers.append( + block( + in_channels=inplanes, + channels=planes, + stride=1, + norm_cfg=self.norm_cfg, + act_cfg_out=None if i == num_blocks - 1 else self.act_cfg)) + + return nn.Sequential(*layers) + + def forward(self, x): + """Forward function.""" + out_size = (x.shape[-2] // 8, x.shape[-1] // 8) + + # stage 0-2 + x = self.stem(x) + + # stage3 + x_c = self.context_branch_layers[0](x) + x_s = 
self.spatial_branch_layers[0](x) + comp_c = self.compression_1(self.relu(x_c)) + x_c += self.down_1(self.relu(x_s)) + x_s += resize( + comp_c, + size=out_size, + mode='bilinear', + align_corners=self.align_corners) + if self.training: + temp_context = x_s.clone() + + # stage4 + x_c = self.context_branch_layers[1](self.relu(x_c)) + x_s = self.spatial_branch_layers[1](self.relu(x_s)) + comp_c = self.compression_2(self.relu(x_c)) + x_c += self.down_2(self.relu(x_s)) + x_s += resize( + comp_c, + size=out_size, + mode='bilinear', + align_corners=self.align_corners) + + # stage5 + x_s = self.spatial_branch_layers[2](self.relu(x_s)) + x_c = self.context_branch_layers[2](self.relu(x_c)) + x_c = self.spp(x_c) + x_c = resize( + x_c, + size=out_size, + mode='bilinear', + align_corners=self.align_corners) + + return (temp_context, x_s + x_c) if self.training else x_s + x_c diff --git a/mmseg/models/backbones/erfnet.py b/mmseg/models/backbones/erfnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2c5ec672a086b5d67568514140023ce402eef92f --- /dev/null +++ b/mmseg/models/backbones/erfnet.py @@ -0,0 +1,329 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +class DownsamplerBlock(BaseModule): + """Downsampler block of ERFNet. + + This module is a little different from basical ConvModule. + The features from Conv and MaxPool layers are + concatenated before BatchNorm. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.conv = build_conv_layer( + self.conv_cfg, + in_channels, + out_channels - in_channels, + kernel_size=3, + stride=2, + padding=1) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.bn = build_norm_layer(self.norm_cfg, out_channels)[1] + self.act = build_activation_layer(self.act_cfg) + + def forward(self, input): + conv_out = self.conv(input) + pool_out = self.pool(input) + pool_out = resize( + input=pool_out, + size=conv_out.size()[2:], + mode='bilinear', + align_corners=False) + output = torch.cat([conv_out, pool_out], 1) + output = self.bn(output) + output = self.act(output) + return output + + +class NonBottleneck1d(BaseModule): + """Non-bottleneck block of ERFNet. + + Args: + channels (int): Number of channels in Non-bottleneck block. + drop_rate (float): Probability of an element to be zeroed. + Default 0. + dilation (int): Dilation rate for last two conv layers. + Default 1. + num_conv_layer (int): Number of 3x1 and 1x3 convolution layers. + Default 2. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). 
+ init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + channels, + drop_rate=0, + dilation=1, + num_conv_layer=2, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.act = build_activation_layer(self.act_cfg) + + self.convs_layers = nn.ModuleList() + for conv_layer in range(num_conv_layer): + first_conv_padding = (1, 0) if conv_layer == 0 else (dilation, 0) + first_conv_dilation = 1 if conv_layer == 0 else (dilation, 1) + second_conv_padding = (0, 1) if conv_layer == 0 else (0, dilation) + second_conv_dilation = 1 if conv_layer == 0 else (1, dilation) + + self.convs_layers.append( + build_conv_layer( + self.conv_cfg, + channels, + channels, + kernel_size=(3, 1), + stride=1, + padding=first_conv_padding, + bias=True, + dilation=first_conv_dilation)) + self.convs_layers.append(self.act) + self.convs_layers.append( + build_conv_layer( + self.conv_cfg, + channels, + channels, + kernel_size=(1, 3), + stride=1, + padding=second_conv_padding, + bias=True, + dilation=second_conv_dilation)) + self.convs_layers.append( + build_norm_layer(self.norm_cfg, channels)[1]) + if conv_layer == 0: + self.convs_layers.append(self.act) + else: + self.convs_layers.append(nn.Dropout(p=drop_rate)) + + def forward(self, input): + output = input + for conv in self.convs_layers: + output = conv(output) + output = self.act(output + input) + return output + + +class UpsamplerBlock(BaseModule): + """Upsampler block of ERFNet. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.conv = nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + bias=True) + self.bn = build_norm_layer(self.norm_cfg, out_channels)[1] + self.act = build_activation_layer(self.act_cfg) + + def forward(self, input): + output = self.conv(input) + output = self.bn(output) + output = self.act(output) + return output + + +@MODELS.register_module() +class ERFNet(BaseModule): + """ERFNet backbone. + + This backbone is the implementation of `ERFNet: Efficient Residual + Factorized ConvNet for Real-time SemanticSegmentation + `_. + + Args: + in_channels (int): The number of channels of input + image. Default: 3. + enc_downsample_channels (Tuple[int]): Size of channel + numbers of various Downsampler block in encoder. + Default: (16, 64, 128). + enc_stage_non_bottlenecks (Tuple[int]): Number of stages of + Non-bottleneck block in encoder. + Default: (5, 8). + enc_non_bottleneck_dilations (Tuple[int]): Dilation rate of each + stage of Non-bottleneck block of encoder. + Default: (2, 4, 8, 16). 
+ enc_non_bottleneck_channels (Tuple[int]): Size of channel + numbers of various Non-bottleneck block in encoder. + Default: (64, 128). + dec_upsample_channels (Tuple[int]): Size of channel numbers of + various Deconvolution block in decoder. + Default: (64, 16). + dec_stages_non_bottleneck (Tuple[int]): Number of stages of + Non-bottleneck block in decoder. + Default: (2, 2). + dec_non_bottleneck_channels (Tuple[int]): Size of channel + numbers of various Non-bottleneck block in decoder. + Default: (64, 16). + drop_rate (float): Probability of an element to be zeroed. + Default 0.1. + """ + + def __init__(self, + in_channels=3, + enc_downsample_channels=(16, 64, 128), + enc_stage_non_bottlenecks=(5, 8), + enc_non_bottleneck_dilations=(2, 4, 8, 16), + enc_non_bottleneck_channels=(64, 128), + dec_upsample_channels=(64, 16), + dec_stages_non_bottleneck=(2, 2), + dec_non_bottleneck_channels=(64, 16), + dropout_ratio=0.1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU'), + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + assert len(enc_downsample_channels) \ + == len(dec_upsample_channels)+1, 'Number of downsample\ + block of encoder does not \ + match number of upsample block of decoder!' + assert len(enc_downsample_channels) \ + == len(enc_stage_non_bottlenecks)+1, 'Number of \ + downsample block of encoder does not match \ + number of Non-bottleneck block of encoder!' + assert len(enc_downsample_channels) \ + == len(enc_non_bottleneck_channels)+1, 'Number of \ + downsample block of encoder does not match \ + number of channels of Non-bottleneck block of encoder!' + assert enc_stage_non_bottlenecks[-1] \ + % len(enc_non_bottleneck_dilations) == 0, 'Number of \ + Non-bottleneck block of encoder does not match \ + number of Non-bottleneck block of encoder!' + assert len(dec_upsample_channels) \ + == len(dec_stages_non_bottleneck), 'Number of \ + upsample block of decoder does not match \ + number of Non-bottleneck block of decoder!' + assert len(dec_stages_non_bottleneck) \ + == len(dec_non_bottleneck_channels), 'Number of \ + Non-bottleneck block of decoder does not match \ + number of channels of Non-bottleneck block of decoder!' + + self.in_channels = in_channels + self.enc_downsample_channels = enc_downsample_channels + self.enc_stage_non_bottlenecks = enc_stage_non_bottlenecks + self.enc_non_bottleneck_dilations = enc_non_bottleneck_dilations + self.enc_non_bottleneck_channels = enc_non_bottleneck_channels + self.dec_upsample_channels = dec_upsample_channels + self.dec_stages_non_bottleneck = dec_stages_non_bottleneck + self.dec_non_bottleneck_channels = dec_non_bottleneck_channels + self.dropout_ratio = dropout_ratio + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.encoder.append( + DownsamplerBlock(self.in_channels, enc_downsample_channels[0])) + + for i in range(len(enc_downsample_channels) - 1): + self.encoder.append( + DownsamplerBlock(enc_downsample_channels[i], + enc_downsample_channels[i + 1])) + # Last part of encoder is some dilated NonBottleneck1d blocks. 
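+            # With the defaults this appends enc_stage_non_bottlenecks[-1] = 8
+            # blocks whose dilation rates cycle through (2, 4, 8, 16) twice;
+            # the assert above guarantees the cycle length divides the block
+            # count evenly.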
+ if i == len(enc_downsample_channels) - 2: + iteration_times = int(enc_stage_non_bottlenecks[-1] / + len(enc_non_bottleneck_dilations)) + for j in range(iteration_times): + for k in range(len(enc_non_bottleneck_dilations)): + self.encoder.append( + NonBottleneck1d(enc_downsample_channels[-1], + self.dropout_ratio, + enc_non_bottleneck_dilations[k])) + else: + for j in range(enc_stage_non_bottlenecks[i]): + self.encoder.append( + NonBottleneck1d(enc_downsample_channels[i + 1], + self.dropout_ratio)) + + for i in range(len(dec_upsample_channels)): + if i == 0: + self.decoder.append( + UpsamplerBlock(enc_downsample_channels[-1], + dec_non_bottleneck_channels[i])) + else: + self.decoder.append( + UpsamplerBlock(dec_non_bottleneck_channels[i - 1], + dec_non_bottleneck_channels[i])) + for j in range(dec_stages_non_bottleneck[i]): + self.decoder.append( + NonBottleneck1d(dec_non_bottleneck_channels[i])) + + def forward(self, x): + for enc in self.encoder: + x = enc(x) + for dec in self.decoder: + x = dec(x) + return [x] diff --git a/mmseg/models/backbones/fast_scnn.py b/mmseg/models/backbones/fast_scnn.py new file mode 100644 index 0000000000000000000000000000000000000000..6ff7a3191d2fee904c5200e0a526214a65f58b32 --- /dev/null +++ b/mmseg/models/backbones/fast_scnn.py @@ -0,0 +1,408 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from mmengine.model import BaseModule + +from mmseg.models.decode_heads.psp_head import PPM +from mmseg.registry import MODELS +from ..utils import InvertedResidual, resize + + +class LearningToDownsample(nn.Module): + """Learning to downsample module. + + Args: + in_channels (int): Number of input channels. + dw_channels (tuple[int]): Number of output channels of the first and + the second depthwise conv (dwconv) layers. + out_channels (int): Number of output channels of the whole + 'learning to downsample' module. + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + act_cfg (dict): Config of activation layers. Default: + dict(type='ReLU') + dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config + of depthwise ConvModule. If it is 'default', it will be the same + as `act_cfg`. Default: None. + """ + + def __init__(self, + in_channels, + dw_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + dw_act_cfg=None): + super().__init__() + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.dw_act_cfg = dw_act_cfg + dw_channels1 = dw_channels[0] + dw_channels2 = dw_channels[1] + + self.conv = ConvModule( + in_channels, + dw_channels1, + 3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.dsconv1 = DepthwiseSeparableConvModule( + dw_channels1, + dw_channels2, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + dw_act_cfg=self.dw_act_cfg) + + self.dsconv2 = DepthwiseSeparableConvModule( + dw_channels2, + out_channels, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + dw_act_cfg=self.dw_act_cfg) + + def forward(self, x): + x = self.conv(x) + x = self.dsconv1(x) + x = self.dsconv2(x) + return x + + +class GlobalFeatureExtractor(nn.Module): + """Global feature extractor module. + + Args: + in_channels (int): Number of input channels of the GFE module. 
+ Default: 64 + block_channels (tuple[int]): Tuple of ints. Each int specifies the + number of output channels of each Inverted Residual module. + Default: (64, 96, 128) + out_channels(int): Number of output channels of the GFE module. + Default: 128 + expand_ratio (int): Adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + Default: 6 + num_blocks (tuple[int]): Tuple of ints. Each int specifies the + number of times each Inverted Residual module is repeated. + The repeated Inverted Residual modules are called a 'group'. + Default: (3, 3, 3) + strides (tuple[int]): Tuple of ints. Each int specifies + the downsampling factor of each 'group'. + Default: (2, 2, 1) + pool_scales (tuple[int]): Tuple of ints. Each int specifies + the parameter required in 'global average pooling' within PPM. + Default: (1, 2, 3, 6) + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + act_cfg (dict): Config of activation layers. Default: + dict(type='ReLU') + align_corners (bool): align_corners argument of F.interpolate. + Default: False + """ + + def __init__(self, + in_channels=64, + block_channels=(64, 96, 128), + out_channels=128, + expand_ratio=6, + num_blocks=(3, 3, 3), + strides=(2, 2, 1), + pool_scales=(1, 2, 3, 6), + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False): + super().__init__() + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + assert len(block_channels) == len(num_blocks) == 3 + self.bottleneck1 = self._make_layer(in_channels, block_channels[0], + num_blocks[0], strides[0], + expand_ratio) + self.bottleneck2 = self._make_layer(block_channels[0], + block_channels[1], num_blocks[1], + strides[1], expand_ratio) + self.bottleneck3 = self._make_layer(block_channels[1], + block_channels[2], num_blocks[2], + strides[2], expand_ratio) + self.ppm = PPM( + pool_scales, + block_channels[2], + block_channels[2] // 4, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=align_corners) + + self.out = ConvModule( + block_channels[2] * 2, + out_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def _make_layer(self, + in_channels, + out_channels, + blocks, + stride=1, + expand_ratio=6): + layers = [ + InvertedResidual( + in_channels, + out_channels, + stride, + expand_ratio, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + ] + for i in range(1, blocks): + layers.append( + InvertedResidual( + out_channels, + out_channels, + 1, + expand_ratio, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + return nn.Sequential(*layers) + + def forward(self, x): + x = self.bottleneck1(x) + x = self.bottleneck2(x) + x = self.bottleneck3(x) + x = torch.cat([x, *self.ppm(x)], dim=1) + x = self.out(x) + return x + + +class FeatureFusionModule(nn.Module): + """Feature fusion module. + + Args: + higher_in_channels (int): Number of input channels of the + higher-resolution branch. + lower_in_channels (int): Number of input channels of the + lower-resolution branch. + out_channels (int): Number of output channels. + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + dwconv_act_cfg (dict): Config of activation layers in 3x3 conv. + Default: dict(type='ReLU'). + conv_act_cfg (dict): Config of activation layers in the two 1x1 conv. + Default: None. 
+ align_corners (bool): align_corners argument of F.interpolate. + Default: False. + """ + + def __init__(self, + higher_in_channels, + lower_in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dwconv_act_cfg=dict(type='ReLU'), + conv_act_cfg=None, + align_corners=False): + super().__init__() + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.dwconv_act_cfg = dwconv_act_cfg + self.conv_act_cfg = conv_act_cfg + self.align_corners = align_corners + self.dwconv = ConvModule( + lower_in_channels, + out_channels, + 3, + padding=1, + groups=out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.dwconv_act_cfg) + self.conv_lower_res = ConvModule( + out_channels, + out_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.conv_act_cfg) + + self.conv_higher_res = ConvModule( + higher_in_channels, + out_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.conv_act_cfg) + + self.relu = nn.ReLU(True) + + def forward(self, higher_res_feature, lower_res_feature): + lower_res_feature = resize( + lower_res_feature, + size=higher_res_feature.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + lower_res_feature = self.dwconv(lower_res_feature) + lower_res_feature = self.conv_lower_res(lower_res_feature) + + higher_res_feature = self.conv_higher_res(higher_res_feature) + out = higher_res_feature + lower_res_feature + return self.relu(out) + + +@MODELS.register_module() +class FastSCNN(BaseModule): + """Fast-SCNN Backbone. + + This backbone is the implementation of `Fast-SCNN: Fast Semantic + Segmentation Network `_. + + Args: + in_channels (int): Number of input image channels. Default: 3. + downsample_dw_channels (tuple[int]): Number of output channels after + the first conv layer & the second conv layer in + Learning-To-Downsample (LTD) module. + Default: (32, 48). + global_in_channels (int): Number of input channels of + Global Feature Extractor(GFE). + Equal to number of output channels of LTD. + Default: 64. + global_block_channels (tuple[int]): Tuple of integers that describe + the output channels for each of the MobileNet-v2 bottleneck + residual blocks in GFE. + Default: (64, 96, 128). + global_block_strides (tuple[int]): Tuple of integers + that describe the strides (downsampling factors) for each of the + MobileNet-v2 bottleneck residual blocks in GFE. + Default: (2, 2, 1). + global_out_channels (int): Number of output channels of GFE. + Default: 128. + higher_in_channels (int): Number of input channels of the higher + resolution branch in FFM. + Equal to global_in_channels. + Default: 64. + lower_in_channels (int): Number of input channels of the lower + resolution branch in FFM. + Equal to global_out_channels. + Default: 128. + fusion_out_channels (int): Number of output channels of FFM. + Default: 128. + out_indices (tuple): Tuple of indices of list + [higher_res_features, lower_res_features, fusion_output]. + Often set to (0,1,2) to enable aux. heads. + Default: (0, 1, 2). + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + act_cfg (dict): Config of activation layers. Default: + dict(type='ReLU') + align_corners (bool): align_corners argument of F.interpolate. + Default: False + dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config + of depthwise ConvModule. If it is 'default', it will be the same + as `act_cfg`. Default: None. 
+ init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels=3, + downsample_dw_channels=(32, 48), + global_in_channels=64, + global_block_channels=(64, 96, 128), + global_block_strides=(2, 2, 1), + global_out_channels=128, + higher_in_channels=64, + lower_in_channels=128, + fusion_out_channels=128, + out_indices=(0, 1, 2), + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False, + dw_act_cfg=None, + init_cfg=None): + + super().__init__(init_cfg) + + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ] + + if global_in_channels != higher_in_channels: + raise AssertionError('Global Input Channels must be the same \ + with Higher Input Channels!') + elif global_out_channels != lower_in_channels: + raise AssertionError('Global Output Channels must be the same \ + with Lower Input Channels!') + + self.in_channels = in_channels + self.downsample_dw_channels1 = downsample_dw_channels[0] + self.downsample_dw_channels2 = downsample_dw_channels[1] + self.global_in_channels = global_in_channels + self.global_block_channels = global_block_channels + self.global_block_strides = global_block_strides + self.global_out_channels = global_out_channels + self.higher_in_channels = higher_in_channels + self.lower_in_channels = lower_in_channels + self.fusion_out_channels = fusion_out_channels + self.out_indices = out_indices + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.align_corners = align_corners + self.learning_to_downsample = LearningToDownsample( + in_channels, + downsample_dw_channels, + global_in_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + dw_act_cfg=dw_act_cfg) + self.global_feature_extractor = GlobalFeatureExtractor( + global_in_channels, + global_block_channels, + global_out_channels, + strides=self.global_block_strides, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.feature_fusion = FeatureFusionModule( + higher_in_channels, + lower_in_channels, + fusion_out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + dwconv_act_cfg=self.act_cfg, + align_corners=self.align_corners) + + def forward(self, x): + higher_res_features = self.learning_to_downsample(x) + lower_res_features = self.global_feature_extractor(higher_res_features) + fusion_output = self.feature_fusion(higher_res_features, + lower_res_features) + + outs = [higher_res_features, lower_res_features, fusion_output] + outs = [outs[i] for i in self.out_indices] + return tuple(outs) diff --git a/mmseg/models/backbones/hrnet.py b/mmseg/models/backbones/hrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2da755e731cfea911d47729f455c54c3d38a68e4 --- /dev/null +++ b/mmseg/models/backbones/hrnet.py @@ -0,0 +1,642 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule, ModuleList, Sequential +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmseg.registry import MODELS +from ..utils import Upsample, resize +from .resnet import BasicBlock, Bottleneck + + +class HRModule(BaseModule): + """High-Resolution Module for HRNet. + + In this module, every branch has 4 BasicBlocks/Bottlenecks. 
Fusion/Exchange + is in this module. + """ + + def __init__(self, + num_branches, + blocks, + num_blocks, + in_channels, + num_channels, + multiscale_output=True, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + block_init_cfg=None, + init_cfg=None): + super().__init__(init_cfg) + self.block_init_cfg = block_init_cfg + self._check_branches(num_branches, num_blocks, in_channels, + num_channels) + + self.in_channels = in_channels + self.num_branches = num_branches + + self.multiscale_output = multiscale_output + self.norm_cfg = norm_cfg + self.conv_cfg = conv_cfg + self.with_cp = with_cp + self.branches = self._make_branches(num_branches, blocks, num_blocks, + num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=False) + + def _check_branches(self, num_branches, num_blocks, in_channels, + num_channels): + """Check branches configuration.""" + if num_branches != len(num_blocks): + error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \ + f'{len(num_blocks)})' + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \ + f'{len(num_channels)})' + raise ValueError(error_msg) + + if num_branches != len(in_channels): + error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \ + f'{len(in_channels)})' + raise ValueError(error_msg) + + def _make_one_branch(self, + branch_index, + block, + num_blocks, + num_channels, + stride=1): + """Build one branch.""" + downsample = None + if stride != 1 or \ + self.in_channels[branch_index] != \ + num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + build_conv_layer( + self.conv_cfg, + self.in_channels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(self.norm_cfg, num_channels[branch_index] * + block.expansion)[1]) + + layers = [] + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index], + stride, + downsample=downsample, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=self.block_init_cfg)) + self.in_channels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index], + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=self.block_init_cfg)) + + return Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + """Build multiple branch.""" + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return ModuleList(branches) + + def _make_fuse_layers(self): + """Build fuse layer.""" + if self.num_branches == 1: + return None + + num_branches = self.num_branches + in_channels = self.in_channels + fuse_layers = [] + num_out_branches = num_branches if self.multiscale_output else 1 + for i in range(num_out_branches): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, in_channels[i])[1], + # we set align_corners=False for HRNet + Upsample( + scale_factor=2**(j - i), + mode='bilinear', + align_corners=False))) + elif j == i: + 
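+                    # same-resolution pair: fused as an identity in forward(),
+                    # so no transform layer is built (None placeholder)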
fuse_layer.append(None) + else: + conv_downsamples = [] + for k in range(i - j): + if k == i - j - 1: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[i])[1])) + else: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[j], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[j])[1], + nn.ReLU(inplace=False))) + fuse_layer.append(nn.Sequential(*conv_downsamples)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def forward(self, x): + """Forward function.""" + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = 0 + for j in range(self.num_branches): + if i == j: + y += x[j] + elif j > i: + y = y + resize( + self.fuse_layers[i][j](x[j]), + size=x[i].shape[2:], + mode='bilinear', + align_corners=False) + else: + y += self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + return x_fuse + + +@MODELS.register_module() +class HRNet(BaseModule): + """HRNet backbone. + + This backbone is the implementation of `High-Resolution Representations + for Labeling Pixels and Regions `_. + + Args: + extra (dict): Detailed configuration for each stage of HRNet. + There must be 4 stages, the configuration for each stage must have + 5 keys: + + - num_modules (int): The number of HRModule in this stage. + - num_branches (int): The number of branches in the HRModule. + - block (str): The type of convolution block. + - num_blocks (tuple): The number of blocks in each branch. + The length must be equal to num_branches. + - num_channels (tuple): The number of channels in each branch. + The length must be equal to num_branches. + in_channels (int): Number of input image channels. Normally 3. + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Use `BN` by default. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: False. + multiscale_output (bool): Whether to output multi-level features + produced by multiple branches. If False, only the first level + feature will be output. Default: True. + pretrained (str, optional): Model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
+ + Example: + >>> from mmseg.models import HRNet + >>> import torch + >>> extra = dict( + >>> stage1=dict( + >>> num_modules=1, + >>> num_branches=1, + >>> block='BOTTLENECK', + >>> num_blocks=(4, ), + >>> num_channels=(64, )), + >>> stage2=dict( + >>> num_modules=1, + >>> num_branches=2, + >>> block='BASIC', + >>> num_blocks=(4, 4), + >>> num_channels=(32, 64)), + >>> stage3=dict( + >>> num_modules=4, + >>> num_branches=3, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4), + >>> num_channels=(32, 64, 128)), + >>> stage4=dict( + >>> num_modules=3, + >>> num_branches=4, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4, 4), + >>> num_channels=(32, 64, 128, 256))) + >>> self = HRNet(extra, in_channels=1) + >>> self.eval() + >>> inputs = torch.rand(1, 1, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 32, 8, 8) + (1, 64, 4, 4) + (1, 128, 2, 2) + (1, 256, 1, 1) + """ + + blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} + + def __init__(self, + extra, + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + with_cp=False, + frozen_stages=-1, + zero_init_residual=False, + multiscale_output=True, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg) + + self.pretrained = pretrained + self.zero_init_residual = zero_init_residual + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + # Assert configurations of 4 stages are in extra + assert 'stage1' in extra and 'stage2' in extra \ + and 'stage3' in extra and 'stage4' in extra + # Assert whether the length of `num_blocks` and `num_channels` are + # equal to `num_branches` + for i in range(4): + cfg = extra[f'stage{i + 1}'] + assert len(cfg['num_blocks']) == cfg['num_branches'] and \ + len(cfg['num_channels']) == cfg['num_branches'] + + self.extra = extra + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + self.frozen_stages = frozen_stages + + # stem net + self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1) + self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2) + + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + 64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + self.conv_cfg, + 64, + 64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.relu = nn.ReLU(inplace=True) + + # stage 1 + self.stage1_cfg = self.extra['stage1'] + num_channels = self.stage1_cfg['num_channels'][0] + block_type = self.stage1_cfg['block'] + num_blocks = self.stage1_cfg['num_blocks'][0] + + block = self.blocks_dict[block_type] + stage1_out_channels = num_channels * block.expansion + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + + # stage 2 + self.stage2_cfg = self.extra['stage2'] + num_channels = self.stage2_cfg['num_channels'] + block_type = 
self.stage2_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [channel * block.expansion for channel in num_channels] + self.transition1 = self._make_transition_layer([stage1_out_channels], + num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + # stage 3 + self.stage3_cfg = self.extra['stage3'] + num_channels = self.stage3_cfg['num_channels'] + block_type = self.stage3_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [channel * block.expansion for channel in num_channels] + self.transition2 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + # stage 4 + self.stage4_cfg = self.extra['stage4'] + num_channels = self.stage4_cfg['num_channels'] + block_type = self.stage4_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [channel * block.expansion for channel in num_channels] + self.transition3 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multiscale_output=multiscale_output) + + self._freeze_stages() + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: the normalization layer named "norm2" """ + return getattr(self, self.norm2_name) + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + """Make transition layer.""" + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + num_channels_cur_layer[i])[1], + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv_downsamples = [] + for j in range(i + 1 - num_branches_pre): + in_channels = num_channels_pre_layer[-1] + out_channels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else in_channels + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, out_channels)[1], + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv_downsamples)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + """Make each layer.""" + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + build_conv_layer( + self.conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(self.norm_cfg, planes * block.expansion)[1]) + + layers = [] + block_init_cfg = None + if self.pretrained is None and not hasattr( + self, 'init_cfg') and self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm3')) + + layers.append( + block( + inplanes, + planes, 
+ stride, + downsample=downsample, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=block_init_cfg)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block( + inplanes, + planes, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=block_init_cfg)) + + return Sequential(*layers) + + def _make_stage(self, layer_config, in_channels, multiscale_output=True): + """Make each stage.""" + num_modules = layer_config['num_modules'] + num_branches = layer_config['num_branches'] + num_blocks = layer_config['num_blocks'] + num_channels = layer_config['num_channels'] + block = self.blocks_dict[layer_config['block']] + + hr_modules = [] + block_init_cfg = None + if self.pretrained is None and not hasattr( + self, 'init_cfg') and self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm3')) + + for i in range(num_modules): + # multi_scale_output is only used for the last module + if not multiscale_output and i == num_modules - 1: + reset_multiscale_output = False + else: + reset_multiscale_output = True + + hr_modules.append( + HRModule( + num_branches, + block, + num_blocks, + in_channels, + num_channels, + reset_multiscale_output, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + block_init_cfg=block_init_cfg)) + + return Sequential(*hr_modules), in_channels + + def _freeze_stages(self): + """Freeze stages param and norm stats.""" + if self.frozen_stages >= 0: + + self.norm1.eval() + self.norm2.eval() + for m in [self.conv1, self.norm1, self.conv2, self.norm2]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + if i == 1: + m = getattr(self, f'layer{i}') + t = getattr(self, f'transition{i}') + elif i == 4: + m = getattr(self, f'stage{i}') + else: + m = getattr(self, f'stage{i}') + t = getattr(self, f'transition{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + t.eval() + for param in t.parameters(): + param.requires_grad = False + + def forward(self, x): + """Forward function.""" + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['num_branches']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['num_branches']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['num_branches']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + return y_list + + def train(self, mode=True): + """Convert the model into training mode while keeping the normalization + layers frozen.""" + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval() has an effect on BatchNorm layers only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmseg/models/backbones/icnet.py b/mmseg/models/backbones/icnet.py new file mode 100644 index
0000000000000000000000000000000000000000..8ff3448569c5a3ec82a12726767fcbb48b3870d2 --- /dev/null +++ b/mmseg/models/backbones/icnet.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..decode_heads.psp_head import PPM +from ..utils import resize + + +@MODELS.register_module() +class ICNet(BaseModule): + """ICNet for Real-Time Semantic Segmentation on High-Resolution Images. + + This backbone is the implementation of + `ICNet <https://arxiv.org/abs/1704.08545>`_. + + Args: + backbone_cfg (dict): Config dict to build backbone. Usually it is + ResNet but it can also be other backbones. + in_channels (int): The number of input image channels. Default: 3. + layer_channels (Sequence[int]): The numbers of feature channels at + layer 2 and layer 4 of the backbone (the defaults assume ResNet; + other backbones may need different values). + Default: (512, 2048). + light_branch_middle_channels (int): The number of channels of the + middle layer in light branch. Default: 32. + psp_out_channels (int): The number of channels of the output of PSP + module. Default: 512. + out_channels (Sequence[int]): The numbers of output feature channels + at each branch. Default: (64, 256, 256). + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. Default: (1, 2, 3, 6). + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN'). + act_cfg (dict): Dictionary to construct and config act layer. + Default: dict(type='ReLU'). + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + backbone_cfg, + in_channels=3, + layer_channels=(512, 2048), + light_branch_middle_channels=32, + psp_out_channels=512, + out_channels=(64, 256, 256), + pool_scales=(1, 2, 3, 6), + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU'), + align_corners=False, + init_cfg=None): + if backbone_cfg is None: + raise TypeError('backbone_cfg must be passed from config file!') + if init_cfg is None: + init_cfg = [ + dict(type='Kaiming', mode='fan_out', layer='Conv2d'), + dict(type='Constant', val=1, layer='_BatchNorm'), + dict(type='Normal', mean=0.01, layer='Linear') + ] + super().__init__(init_cfg=init_cfg) + self.align_corners = align_corners + self.backbone = MODELS.build(backbone_cfg) + + # Note: Default `ceil_mode` is False in nn.MaxPool2d, set + # `ceil_mode=True` to keep information in the corner of feature map.
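+ # Editorial illustration (not in the original file): with kernel_size=3, + # stride=2 and padding=1, a 64x64 map pools to 33x33 when ceil_mode=True + # but to 32x32 when ceil_mode=False, e.g. + # nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=True)( + # torch.zeros(1, 1, 64, 64)).shape -> (1, 1, 33, 33) + # so the bottom/right border of the feature map is kept.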
+ self.backbone.maxpool = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1, ceil_mode=True) + + self.psp_modules = PPM( + pool_scales=pool_scales, + in_channels=layer_channels[1], + channels=psp_out_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + align_corners=align_corners) + + self.psp_bottleneck = ConvModule( + layer_channels[1] + len(pool_scales) * psp_out_channels, + psp_out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.conv_sub1 = nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=light_branch_middle_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg), + ConvModule( + in_channels=light_branch_middle_channels, + out_channels=light_branch_middle_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg), + ConvModule( + in_channels=light_branch_middle_channels, + out_channels=out_channels[0], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg)) + + self.conv_sub2 = ConvModule( + layer_channels[0], + out_channels[1], + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + + self.conv_sub4 = ConvModule( + psp_out_channels, + out_channels[2], + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + + def forward(self, x): + output = [] + + # sub 1 + output.append(self.conv_sub1(x)) + + # sub 2 + x = resize( + x, + scale_factor=0.5, + mode='bilinear', + align_corners=self.align_corners) + x = self.backbone.stem(x) + x = self.backbone.maxpool(x) + x = self.backbone.layer1(x) + x = self.backbone.layer2(x) + output.append(self.conv_sub2(x)) + + # sub 4 + x = resize( + x, + scale_factor=0.5, + mode='bilinear', + align_corners=self.align_corners) + x = self.backbone.layer3(x) + x = self.backbone.layer4(x) + psp_outs = self.psp_modules(x) + [x] + psp_outs = torch.cat(psp_outs, dim=1) + x = self.psp_bottleneck(psp_outs) + + output.append(self.conv_sub4(x)) + + return output diff --git a/mmseg/models/backbones/mae.py b/mmseg/models/backbones/mae.py new file mode 100644 index 0000000000000000000000000000000000000000..a1f243f0857b9aca5454e8c1410075bff9281285 --- /dev/null +++ b/mmseg/models/backbones/mae.py @@ -0,0 +1,260 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from mmengine.model import ModuleList +from mmengine.model.weight_init import (constant_init, kaiming_init, + trunc_normal_) +from mmengine.runner.checkpoint import _load_checkpoint +from torch.nn.modules.batchnorm import _BatchNorm + +from mmseg.registry import MODELS +from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer + + +class MAEAttention(BEiTAttention): + """Multi-head self-attention with relative position bias used in MAE. + + This module is different from ``BEiTAttention`` by initializing the + relative bias table with zeros. + """ + + def init_weights(self): + """Initialize relative position bias with zeros.""" + + # As MAE initializes relative position bias as zeros and this class + # inherited from BEiT which initializes relative position bias + # with `trunc_normal`, `init_weights` here does + # nothing and just passes directly + + pass + + +class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer): + """Implements one encoder layer in Vision Transformer. + + This module is different from ``BEiTTransformerEncoderLayer`` by replacing + ``BEiTAttention`` with ``MAEAttention``.
+ """ + + def build_attn(self, attn_cfg): + self.attn = MAEAttention(**attn_cfg) + + +@MODELS.register_module() +class MAE(BEiT): + """VisionTransformer with support for patch. + + Args: + img_size (int | tuple): Input image size. Default: 224. + patch_size (int): The patch size. Default: 16. + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): embedding dimension. Default: 768. + num_layers (int): depth of transformer. Default: 12. + num_heads (int): number of attention heads. Default: 12. + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + out_indices (list | tuple | int): Output from which stages. + Default: -1. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): stochastic depth rate. Default 0.0. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + patch_norm (bool): Whether to add a norm in PatchEmbed Block. + Default: False. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Default: False. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + pretrained (str, optional): model pretrained path. Default: None. + init_values (float): Initialize the values of Attention and FFN + with learnable scaling. Defaults to 0.1. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + out_indices=-1, + attn_drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + patch_norm=False, + final_norm=False, + num_fcs=2, + norm_eval=False, + pretrained=None, + init_values=0.1, + init_cfg=None): + super().__init__( + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dims=embed_dims, + num_layers=num_layers, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + out_indices=out_indices, + qv_bias=False, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + patch_norm=patch_norm, + final_norm=final_norm, + num_fcs=num_fcs, + norm_eval=norm_eval, + pretrained=pretrained, + init_values=init_values, + init_cfg=init_cfg) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + + self.num_patches = self.patch_shape[0] * self.patch_shape[1] + self.pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches + 1, embed_dims)) + + def _build_layers(self): + dpr = [ + x.item() + for x in torch.linspace(0, self.drop_path_rate, self.num_layers) + ] + self.layers = ModuleList() + for i in range(self.num_layers): + self.layers.append( + MAETransformerEncoderLayer( + embed_dims=self.embed_dims, + num_heads=self.num_heads, + feedforward_channels=self.mlp_ratio * self.embed_dims, + attn_drop_rate=self.attn_drop_rate, + drop_path_rate=dpr[i], + num_fcs=self.num_fcs, + bias=True, + act_cfg=self.act_cfg, + norm_cfg=self.norm_cfg, + window_size=self.patch_shape, + init_values=self.init_values)) + + def fix_init_weight(self): + """Rescale the initialization according to layer id. + + This function is copied from https://github.com/microsoft/unilm/blob/master/beit/modeling_pretrain.py. 
# noqa: E501 + Copyright (c) Microsoft Corporation + Licensed under the MIT License + """ + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.layers): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.ffn.layers[1].weight.data, layer_id + 1) + + def init_weights(self): + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + self.fix_init_weight() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg.get('type') == 'Pretrained'): + checkpoint = _load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + state_dict = self.resize_rel_pos_embed(checkpoint) + state_dict = self.resize_abs_pos_embed(state_dict) + self.load_state_dict(state_dict, False) + elif self.init_cfg is not None: + super().init_weights() + else: + # We only implement the 'jax_impl' initialization implemented at + # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 + # Copyright 2019 Ross Wightman + # Licensed under the Apache License, Version 2.0 (the "License") + trunc_normal_(self.cls_token, std=.02) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + if 'ffn' in n: + nn.init.normal_(m.bias, mean=0., std=1e-6) + else: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) 
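+ # Editorial sketch (not part of the original class): the + # `resize_abs_pos_embed` method below bicubically interpolates a + # pre-trained absolute position embedding to a new patch grid. A + # self-contained equivalent with illustrative shapes: + # + # import torch + # import torch.nn.functional as F + # pos_embed = torch.randn(1, 1 + 14 * 14, 768) # cls token + 14x14 grid + # cls_tok, grid = pos_embed[:, :1], pos_embed[:, 1:] + # grid = grid.reshape(1, 14, 14, 768).permute(0, 3, 1, 2) + # grid = F.interpolate( + # grid, size=(32, 32), mode='bicubic', align_corners=False) + # grid = grid.permute(0, 2, 3, 1).flatten(1, 2) + # pos_embed = torch.cat((cls_tok, grid), dim=1) # (1, 1 + 32 * 32, 768)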
+ + def resize_abs_pos_embed(self, state_dict): + if 'pos_embed' in state_dict: + pos_embed_checkpoint = state_dict['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches + # height (== width) for the checkpoint position embedding + orig_size = int( + (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5) + # height (== width) for the new position embedding + new_size = int(self.num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, + embedding_size).permute( + 0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size, new_size), + mode='bicubic', + align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict['pos_embed'] = new_pos_embed + return state_dict + + def forward(self, inputs): + B = inputs.shape[0] + + x, hw_shape = self.patch_embed(inputs) + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i == len(self.layers) - 1: + if self.final_norm: + x = self.norm1(x) + if i in self.out_indices: + out = x[:, 1:] + B, _, C = out.shape + out = out.reshape(B, hw_shape[0], hw_shape[1], + C).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return tuple(outs) diff --git a/mmseg/models/backbones/mit.py b/mmseg/models/backbones/mit.py new file mode 100644 index 0000000000000000000000000000000000000000..66556bdfca2b0bcb180afd23c2923c68b9ff3a69 --- /dev/null +++ b/mmseg/models/backbones/mit.py @@ -0,0 +1,450 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import MultiheadAttention +from mmengine.model import BaseModule, ModuleList, Sequential +from mmengine.model.weight_init import (constant_init, normal_init, + trunc_normal_init) + +from mmseg.registry import MODELS +from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw + + +class MixFFN(BaseModule): + """An implementation of MixFFN of Segformer. + + The differences between MixFFN & FFN: + 1. Use 1X1 Conv to replace Linear layer. + 2. Introduce 3X3 Conv to encode positional information. + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. Defaults: 256. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='ReLU') + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. 
+ """ + + def __init__(self, + embed_dims, + feedforward_channels, + act_cfg=dict(type='GELU'), + ffn_drop=0., + dropout_layer=None, + init_cfg=None): + super().__init__(init_cfg) + + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.act_cfg = act_cfg + self.activate = build_activation_layer(act_cfg) + + in_channels = embed_dims + fc1 = Conv2d( + in_channels=in_channels, + out_channels=feedforward_channels, + kernel_size=1, + stride=1, + bias=True) + # 3x3 depth wise conv to provide positional encode information + pe_conv = Conv2d( + in_channels=feedforward_channels, + out_channels=feedforward_channels, + kernel_size=3, + stride=1, + padding=(3 - 1) // 2, + bias=True, + groups=feedforward_channels) + fc2 = Conv2d( + in_channels=feedforward_channels, + out_channels=in_channels, + kernel_size=1, + stride=1, + bias=True) + drop = nn.Dropout(ffn_drop) + layers = [fc1, pe_conv, self.activate, drop, fc2, drop] + self.layers = Sequential(*layers) + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else torch.nn.Identity() + + def forward(self, x, hw_shape, identity=None): + out = nlc_to_nchw(x, hw_shape) + out = self.layers(out) + out = nchw_to_nlc(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +class EfficientMultiheadAttention(MultiheadAttention): + """An implementation of Efficient Multi-head Attention of Segformer. + + This module is modified from MultiheadAttention which is a module from + mmcv.cnn.bricks.transformer. + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. + Default: 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. Default: None. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: False. + qkv_bias (bool): enable bias for qkv if True. Default True. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head + Attention of Segformer. Default: 1. + """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + dropout_layer=None, + init_cfg=None, + batch_first=True, + qkv_bias=False, + norm_cfg=dict(type='LN'), + sr_ratio=1): + super().__init__( + embed_dims, + num_heads, + attn_drop, + proj_drop, + dropout_layer=dropout_layer, + init_cfg=init_cfg, + batch_first=batch_first, + bias=qkv_bias) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = Conv2d( + in_channels=embed_dims, + out_channels=embed_dims, + kernel_size=sr_ratio, + stride=sr_ratio) + # The ret[0] of build_norm_layer is norm name. + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + + # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa + from mmseg import digit_version, mmcv_version + if mmcv_version < digit_version('1.3.17'): + warnings.warn('The legacy version of forward function in' + 'EfficientMultiheadAttention is deprecated in' + 'mmcv>=1.3.17 and will no longer support in the' + 'future. 
Please upgrade your mmcv.') + self.forward = self.legacy_forward + + def forward(self, x, hw_shape, identity=None): + + x_q = x + if self.sr_ratio > 1: + x_kv = nlc_to_nchw(x, hw_shape) + x_kv = self.sr(x_kv) + x_kv = nchw_to_nlc(x_kv) + x_kv = self.norm(x_kv) + else: + x_kv = x + + if identity is None: + identity = x_q + + # Because the dataflow('key', 'query', 'value') of + # ``torch.nn.MultiheadAttention`` is (num_query, batch, + # embed_dims), We should adjust the shape of dataflow from + # batch_first (batch, num_query, embed_dims) to num_query_first + # (num_query ,batch, embed_dims), and recover ``attn_output`` + # from num_query_first to batch_first. + if self.batch_first: + x_q = x_q.transpose(0, 1) + x_kv = x_kv.transpose(0, 1) + + out = self.attn(query=x_q, key=x_kv, value=x_kv)[0] + + if self.batch_first: + out = out.transpose(0, 1) + + return identity + self.dropout_layer(self.proj_drop(out)) + + def legacy_forward(self, x, hw_shape, identity=None): + """multi head attention forward in mmcv version < 1.3.17.""" + + x_q = x + if self.sr_ratio > 1: + x_kv = nlc_to_nchw(x, hw_shape) + x_kv = self.sr(x_kv) + x_kv = nchw_to_nlc(x_kv) + x_kv = self.norm(x_kv) + else: + x_kv = x + + if identity is None: + identity = x_q + + # `need_weights=True` will let nn.MultiHeadAttention + # `return attn_output, attn_output_weights.sum(dim=1) / num_heads` + # The `attn_output_weights.sum(dim=1)` may cause cuda error. So, we set + # `need_weights=False` to ignore `attn_output_weights.sum(dim=1)`. + # This issue - `https://github.com/pytorch/pytorch/issues/37583` report + # the error that large scale tensor sum operation may cause cuda error. + out = self.attn(query=x_q, key=x_kv, value=x_kv, need_weights=False)[0] + + return identity + self.dropout_layer(self.proj_drop(out)) + + +class TransformerEncoderLayer(BaseModule): + """Implements one encoder layer in Segformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed. + after the feed forward layer. Default 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0. + drop_path_rate (float): stochastic depth rate. Default 0.0. + qkv_bias (bool): enable bias for qkv if True. + Default: True. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: False. + init_cfg (dict, optional): Initialization config dict. + Default:None. + sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head + Attention of Segformer. Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + batch_first=True, + sr_ratio=1, + with_cp=False): + super().__init__() + + # The ret[0] of build_norm_layer is norm name. 
+ self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + + self.attn = EfficientMultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + batch_first=batch_first, + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio) + + # The ret[0] of build_norm_layer is norm name. + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + + self.ffn = MixFFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + self.with_cp = with_cp + + def forward(self, x, hw_shape): + + def _inner_forward(x): + x = self.attn(self.norm1(x), hw_shape, identity=x) + x = self.ffn(self.norm2(x), hw_shape, identity=x) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +@MODELS.register_module() +class MixVisionTransformer(BaseModule): + """The backbone of Segformer. + + This backbone is the implementation of `SegFormer: Simple and + Efficient Design for Semantic Segmentation with + Transformers <https://arxiv.org/abs/2105.15203>`_. + Args: + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): Embedding dimension. Default: 64. + num_stages (int): The number of stages. Default: 4. + num_layers (Sequence[int]): The number of transformer encode layers in + each stage. Default: [3, 4, 6, 3]. + num_heads (Sequence[int]): The attention heads of each transformer + encode layer. Default: [1, 2, 4, 8]. + patch_sizes (Sequence[int]): The patch_size of each overlapped patch + embedding. Default: [7, 3, 3, 3]. + strides (Sequence[int]): The stride of each overlapped patch embedding. + Default: [4, 2, 2, 2]. + sr_ratios (Sequence[int]): The spatial reduction rate of each + transformer encode layer. Default: [8, 4, 2, 1]. + out_indices (Sequence[int] | int): Output from which stages. + Default: (0, 1, 2, 3). + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + qkv_bias (bool): Enable bias for qkv if True. Default: True. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0 + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): stochastic depth rate. Default 0.0 + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + pretrained (str, optional): model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False.
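+ Example (editorial; an illustrative shape check assuming the defaults + above, not part of the original docstring): + >>> import torch + >>> model = MixVisionTransformer(embed_dims=64) + >>> outs = model(torch.rand(1, 3, 224, 224)) + >>> [tuple(o.shape) for o in outs] + [(1, 64, 56, 56), (1, 128, 28, 28), (1, 256, 14, 14), (1, 512, 7, 7)]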
+ """ + + def __init__(self, + in_channels=3, + embed_dims=64, + num_stages=4, + num_layers=[3, 4, 6, 3], + num_heads=[1, 2, 4, 8], + patch_sizes=[7, 3, 3, 3], + strides=[4, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + out_indices=(0, 1, 2, 3), + mlp_ratio=4, + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN', eps=1e-6), + pretrained=None, + init_cfg=None, + with_cp=False): + super().__init__(init_cfg=init_cfg) + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.embed_dims = embed_dims + self.num_stages = num_stages + self.num_layers = num_layers + self.num_heads = num_heads + self.patch_sizes = patch_sizes + self.strides = strides + self.sr_ratios = sr_ratios + self.with_cp = with_cp + assert num_stages == len(num_layers) == len(num_heads) \ + == len(patch_sizes) == len(strides) == len(sr_ratios) + + self.out_indices = out_indices + assert max(out_indices) < self.num_stages + + # transformer encoder + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(num_layers)) + ] # stochastic num_layer decay rule + + cur = 0 + self.layers = ModuleList() + for i, num_layer in enumerate(num_layers): + embed_dims_i = embed_dims * num_heads[i] + patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims_i, + kernel_size=patch_sizes[i], + stride=strides[i], + padding=patch_sizes[i] // 2, + norm_cfg=norm_cfg) + layer = ModuleList([ + TransformerEncoderLayer( + embed_dims=embed_dims_i, + num_heads=num_heads[i], + feedforward_channels=mlp_ratio * embed_dims_i, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[cur + idx], + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + sr_ratio=sr_ratios[i]) for idx in range(num_layer) + ]) + in_channels = embed_dims_i + # The ret[0] of build_norm_layer is norm name. + norm = build_norm_layer(norm_cfg, embed_dims_i)[1] + self.layers.append(ModuleList([patch_embed, layer, norm])) + cur += num_layer + + def init_weights(self): + if self.init_cfg is None: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, nn.LayerNorm): + constant_init(m, val=1.0, bias=0.) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init( + m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + else: + super().init_weights() + + def forward(self, x): + outs = [] + + for i, layer in enumerate(self.layers): + x, hw_shape = layer[0](x) + for block in layer[1]: + x = block(x, hw_shape) + x = layer[2](x) + x = nlc_to_nchw(x, hw_shape) + if i in self.out_indices: + outs.append(x) + + return outs diff --git a/mmseg/models/backbones/mobilenet_v2.py b/mmseg/models/backbones/mobilenet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..1c21b5df97dade148136e8b0e6b039512f9e03f9 --- /dev/null +++ b/mmseg/models/backbones/mobilenet_v2.py @@ -0,0 +1,197 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import warnings + +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmseg.registry import MODELS +from ..utils import InvertedResidual, make_divisible + + +@MODELS.register_module() +class MobileNetV2(BaseModule): + """MobileNetV2 backbone. + + This backbone is the implementation of + `MobileNetV2: Inverted Residuals and Linear Bottlenecks + <https://arxiv.org/abs/1801.04381>`_. + + Args: + widen_factor (float): Width multiplier, multiply number of + channels in each layer by this amount. Default: 1.0. + strides (Sequence[int], optional): Strides of the first block of each + layer. If not specified, default config in ``arch_settings`` will + be used. + dilations (Sequence[int]): Dilation of each layer. + out_indices (None or Sequence[int]): Output from which stages. + Default: (1, 2, 4, 6). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + # Parameters to build layers. 3 parameters are needed to construct a + # layer, from left to right: expand_ratio, channel, num_blocks. + arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4], + [6, 96, 3], [6, 160, 3], [6, 320, 1]] + + def __init__(self, + widen_factor=1., + strides=(1, 2, 2, 2, 1, 2, 1), + dilations=(1, 1, 1, 1, 1, 1, 1), + out_indices=(1, 2, 4, 6), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + self.widen_factor = widen_factor + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == len(self.arch_settings) + self.out_indices = out_indices + for index in out_indices: + if index not in range(0, 7): + raise ValueError('the item in out_indices must in ' + f'range(0, 7). But received {index}') + + if frozen_stages not in range(-1, 7): + raise ValueError('frozen_stages must be in range(-1, 7).
' f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.in_channels = make_divisible(32 * widen_factor, 8) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.layers = [] + + for i, layer_cfg in enumerate(self.arch_settings): + expand_ratio, channel, num_blocks = layer_cfg + stride = self.strides[i] + dilation = self.dilations[i] + out_channels = make_divisible(channel * widen_factor, 8) + inverted_res_layer = self.make_layer( + out_channels=out_channels, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + expand_ratio=expand_ratio) + layer_name = f'layer{i + 1}' + self.add_module(layer_name, inverted_res_layer) + self.layers.append(layer_name) + + def make_layer(self, out_channels, num_blocks, stride, dilation, + expand_ratio): + """Stack InvertedResidual blocks to build a layer for MobileNetV2. + + Args: + out_channels (int): out_channels of block. + num_blocks (int): Number of blocks. + stride (int): Stride of the first block. + dilation (int): Dilation of the first block. + expand_ratio (int): Expand the number of channels of the + hidden layer in InvertedResidual by this ratio. + """ + layers = [] + for i in range(num_blocks): + layers.append( + InvertedResidual( + self.in_channels, + out_channels, + stride if i == 0 else 1, + expand_ratio=expand_ratio, + dilation=dilation if i == 0 else 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + if len(outs) == 1: + return outs[0] + else: + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmseg/models/backbones/mobilenet_v3.py b/mmseg/models/backbones/mobilenet_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..1efb6e097472d53a5269e52a39ff2cae48e834db --- /dev/null +++ b/mmseg/models/backbones/mobilenet_v3.py @@ -0,0 +1,267 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from mmcv.cnn import ConvModule +from mmcv.cnn.bricks import Conv2dAdaptivePadding +from mmengine.model import BaseModule +from mmengine.utils import is_tuple_of +from torch.nn.modules.batchnorm import _BatchNorm + +from mmseg.registry import MODELS +from ..utils import InvertedResidualV3 as InvertedResidual + + +@MODELS.register_module() +class MobileNetV3(BaseModule): + """MobileNetV3 backbone. + + This backbone is the implementation of `Searching for MobileNetV3 + <https://arxiv.org/abs/1905.02244>`_. + + Args: + arch (str): Architecture of MobileNetV3, from {'small', 'large'}. + Default: 'small'.
+ conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + out_indices (tuple[int]): Output from which layer. + Default: (0, 1, 12). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + reduction_factor (int): Divisor applied to the mid and out channels of + the later blocks (index >= 12 for 'large', >= 8 for 'small'), as + used in the layer-building code below. Default: 1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. + Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + # Parameters to build each block: + # [kernel size, mid channels, out channels, with_se, act type, stride] + arch_settings = { + 'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4 + [3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8 + [3, 88, 24, False, 'ReLU', 1], + [5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16 + [5, 240, 40, True, 'HSwish', 1], + [5, 240, 40, True, 'HSwish', 1], + [5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16 + [5, 144, 48, True, 'HSwish', 1], + [5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32 + [5, 576, 96, True, 'HSwish', 1], + [5, 576, 96, True, 'HSwish', 1]], + 'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2 + [3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4 + [3, 72, 24, False, 'ReLU', 1], + [5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8 + [5, 120, 40, True, 'ReLU', 1], + [5, 120, 40, True, 'ReLU', 1], + [3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16 + [3, 200, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16 + [3, 672, 112, True, 'HSwish', 1], + [5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32 + [5, 960, 160, True, 'HSwish', 1], + [5, 960, 160, True, 'HSwish', 1]] + } # yapf: disable + + def __init__(self, + arch='small', + conv_cfg=None, + norm_cfg=dict(type='BN'), + out_indices=(0, 1, 12), + frozen_stages=-1, + reduction_factor=1, + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + assert arch in self.arch_settings + assert isinstance(reduction_factor, int) and reduction_factor > 0 + assert is_tuple_of(out_indices, int) + for index in out_indices: + if index not in range(0, len(self.arch_settings[arch]) + 2): + raise ValueError( + 'the item in out_indices must in ' + f'range(0, {len(self.arch_settings[arch])+2}). ' + f'But received {index}') + + if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2): + raise ValueError('frozen_stages must be in range(-1, ' + f'{len(self.arch_settings[arch])+2}).
' + f'But received {frozen_stages}') + self.arch = arch + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.reduction_factor = reduction_factor + self.norm_eval = norm_eval + self.with_cp = with_cp + self.layers = self._make_layer() + + def _make_layer(self): + layers = [] + + # build the first layer (layer0) + in_channels = 16 + layer = ConvModule( + in_channels=3, + out_channels=in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=dict(type='Conv2dAdaptivePadding'), + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + self.add_module('layer0', layer) + layers.append('layer0') + + layer_setting = self.arch_settings[self.arch] + for i, params in enumerate(layer_setting): + (kernel_size, mid_channels, out_channels, with_se, act, + stride) = params + + if self.arch == 'large' and i >= 12 or self.arch == 'small' and \ + i >= 8: + mid_channels = mid_channels // self.reduction_factor + out_channels = out_channels // self.reduction_factor + + if with_se: + se_cfg = dict( + channels=mid_channels, + ratio=4, + act_cfg=(dict(type='ReLU'), + dict(type='HSigmoid', bias=3.0, divisor=6.0))) + else: + se_cfg = None + + layer = InvertedResidual( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + with_expand_conv=(in_channels != mid_channels), + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type=act), + with_cp=self.with_cp) + in_channels = out_channels + layer_name = f'layer{i + 1}' + self.add_module(layer_name, layer) + layers.append(layer_name) + + # build the last layer + # block5 layer12 os=32 for small model + # block6 layer16 os=32 for large model + layer = ConvModule( + in_channels=in_channels, + out_channels=576 if self.arch == 'small' else 960, + kernel_size=1, + stride=1, + dilation=4, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + layer_name = f'layer{len(layer_setting) + 1}' + self.add_module(layer_name, layer) + layers.append(layer_name) + + # next, convert backbone MobileNetV3 to a semantic segmentation version + if self.arch == 'small': + self.layer4.depthwise_conv.conv.stride = (1, 1) + self.layer9.depthwise_conv.conv.stride = (1, 1) + for i in range(4, len(layers)): + layer = getattr(self, layers[i]) + if isinstance(layer, InvertedResidual): + modified_module = layer.depthwise_conv.conv + else: + modified_module = layer.conv + + if i < 9: + modified_module.dilation = (2, 2) + pad = 2 + else: + modified_module.dilation = (4, 4) + pad = 4 + + if not isinstance(modified_module, Conv2dAdaptivePadding): + # Adjust padding + pad *= (modified_module.kernel_size[0] - 1) // 2 + modified_module.padding = (pad, pad) + else: + self.layer7.depthwise_conv.conv.stride = (1, 1) + self.layer13.depthwise_conv.conv.stride = (1, 1) + for i in range(7, len(layers)): + layer = getattr(self, layers[i]) + if isinstance(layer, InvertedResidual): + modified_module = layer.depthwise_conv.conv + else: + modified_module = layer.conv + + if i < 13: + modified_module.dilation = (2, 2) + pad = 2 + else: + modified_module.dilation = (4, 4) + pad = 4 + + if not isinstance(modified_module, Conv2dAdaptivePadding): + # Adjust padding + pad *= (modified_module.kernel_size[0] - 1) // 2 + modified_module.padding = (pad, pad) + + return layers + + def forward(self, x): + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = 
layer(x) + if i in self.out_indices: + outs.append(x) + return outs + + def _freeze_stages(self): + for i in range(self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmseg/models/backbones/mscan.py b/mmseg/models/backbones/mscan.py new file mode 100644 index 0000000000000000000000000000000000000000..7150cb7a1c13d11dcdcc6fbbc72931154853929e --- /dev/null +++ b/mmseg/models/backbones/mscan.py @@ -0,0 +1,467 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Originally from https://github.com/visual-attention-network/segnext +# Licensed under the Apache License, Version 2.0 (the "License") +import math +import warnings + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule +from mmengine.model.weight_init import (constant_init, normal_init, + trunc_normal_init) + +from mmseg.registry import MODELS + + +class Mlp(BaseModule): + """Multi Layer Perceptron (MLP) Module. + + Args: + in_features (int): The dimension of input features. + hidden_features (int): The dimension of hidden features. + Defaults: None. + out_features (int): The dimension of output features. + Defaults: None. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). + drop (float): The number of dropout rate in MLP block. + Defaults: 0.0. + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_cfg=dict(type='GELU'), + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.dwconv = nn.Conv2d( + hidden_features, + hidden_features, + 3, + 1, + 1, + bias=True, + groups=hidden_features) + self.act = build_activation_layer(act_cfg) + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.drop = nn.Dropout(drop) + + def forward(self, x): + """Forward function.""" + + x = self.fc1(x) + + x = self.dwconv(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + + return x + + +class StemConv(BaseModule): + """Stem Block at the beginning of Semantic Branch. + + Args: + in_channels (int): The dimension of input channels. + out_channels (int): The dimension of output channels. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Defaults: dict(type='SyncBN', requires_grad=True). 
+ """ + + def __init__(self, + in_channels, + out_channels, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='SyncBN', requires_grad=True)): + super().__init__() + + self.proj = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels // 2, + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1)), + build_norm_layer(norm_cfg, out_channels // 2)[1], + build_activation_layer(act_cfg), + nn.Conv2d( + out_channels // 2, + out_channels, + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1)), + build_norm_layer(norm_cfg, out_channels)[1], + ) + + def forward(self, x): + """Forward function.""" + + x = self.proj(x) + _, _, H, W = x.size() + x = x.flatten(2).transpose(1, 2) + return x, H, W + + +class MSCAAttention(BaseModule): + """Attention Module in Multi-Scale Convolutional Attention Module (MSCA). + + Args: + channels (int): The dimension of channels. + kernel_sizes (list): The size of attention + kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]]. + paddings (list): The number of + corresponding padding value in attention module. + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + """ + + def __init__(self, + channels, + kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + paddings=[2, [0, 3], [0, 5], [0, 10]]): + super().__init__() + self.conv0 = nn.Conv2d( + channels, + channels, + kernel_size=kernel_sizes[0], + padding=paddings[0], + groups=channels) + for i, (kernel_size, + padding) in enumerate(zip(kernel_sizes[1:], paddings[1:])): + kernel_size_ = [kernel_size, kernel_size[::-1]] + padding_ = [padding, padding[::-1]] + conv_name = [f'conv{i}_1', f'conv{i}_2'] + for i_kernel, i_pad, i_conv in zip(kernel_size_, padding_, + conv_name): + self.add_module( + i_conv, + nn.Conv2d( + channels, + channels, + tuple(i_kernel), + padding=i_pad, + groups=channels)) + self.conv3 = nn.Conv2d(channels, channels, 1) + + def forward(self, x): + """Forward function.""" + + u = x.clone() + + attn = self.conv0(x) + + # Multi-Scale Feature extraction + attn_0 = self.conv0_1(attn) + attn_0 = self.conv0_2(attn_0) + + attn_1 = self.conv1_1(attn) + attn_1 = self.conv1_2(attn_1) + + attn_2 = self.conv2_1(attn) + attn_2 = self.conv2_2(attn_2) + + attn = attn + attn_0 + attn_1 + attn_2 + # Channel Mixing + attn = self.conv3(attn) + + # Convolutional Attention + x = attn * u + + return x + + +class MSCASpatialAttention(BaseModule): + """Spatial Attention Module in Multi-Scale Convolutional Attention Module + (MSCA). + + Args: + in_channels (int): The dimension of channels. + attention_kernel_sizes (list): The size of attention + kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]]. + attention_kernel_paddings (list): The number of + corresponding padding value in attention module. + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). 
+ """ + + def __init__(self, + in_channels, + attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]], + act_cfg=dict(type='GELU')): + super().__init__() + self.proj_1 = nn.Conv2d(in_channels, in_channels, 1) + self.activation = build_activation_layer(act_cfg) + self.spatial_gating_unit = MSCAAttention(in_channels, + attention_kernel_sizes, + attention_kernel_paddings) + self.proj_2 = nn.Conv2d(in_channels, in_channels, 1) + + def forward(self, x): + """Forward function.""" + + shorcut = x.clone() + x = self.proj_1(x) + x = self.activation(x) + x = self.spatial_gating_unit(x) + x = self.proj_2(x) + x = x + shorcut + return x + + +class MSCABlock(BaseModule): + """Basic Multi-Scale Convolutional Attention Block. It leverage the large- + kernel attention (LKA) mechanism to build both channel and spatial + attention. In each branch, it uses two depth-wise strip convolutions to + approximate standard depth-wise convolutions with large kernels. The kernel + size for each branch is set to 7, 11, and 21, respectively. + + Args: + channels (int): The dimension of channels. + attention_kernel_sizes (list): The size of attention + kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]]. + attention_kernel_paddings (list): The number of + corresponding padding value in attention module. + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + mlp_ratio (float): The ratio of multiple input dimension to + calculate hidden feature in MLP layer. Defaults: 4.0. + drop (float): The number of dropout rate in MLP block. + Defaults: 0.0. + drop_path (float): The ratio of drop paths. + Defaults: 0.0. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Defaults: dict(type='SyncBN', requires_grad=True). + """ + + def __init__(self, + channels, + attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]], + mlp_ratio=4., + drop=0., + drop_path=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='SyncBN', requires_grad=True)): + super().__init__() + self.norm1 = build_norm_layer(norm_cfg, channels)[1] + self.attn = MSCASpatialAttention(channels, attention_kernel_sizes, + attention_kernel_paddings, act_cfg) + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = build_norm_layer(norm_cfg, channels)[1] + mlp_hidden_channels = int(channels * mlp_ratio) + self.mlp = Mlp( + in_features=channels, + hidden_features=mlp_hidden_channels, + act_cfg=act_cfg, + drop=drop) + layer_scale_init_value = 1e-2 + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones(channels), requires_grad=True) + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones(channels), requires_grad=True) + + def forward(self, x, H, W): + """Forward function.""" + + B, N, C = x.shape + x = x.permute(0, 2, 1).view(B, C, H, W) + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * + self.attn(self.norm1(x))) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * + self.mlp(self.norm2(x))) + x = x.view(B, C, N).permute(0, 2, 1) + return x + + +class OverlapPatchEmbed(BaseModule): + """Image to Patch Embedding. + + Args: + patch_size (int): The patch size. + Defaults: 7. + stride (int): Stride of the convolutional layer. + Default: 4. + in_channels (int): The number of input channels. + Defaults: 3. + embed_dims (int): The dimensions of embedding. 
+ Defaults: 768. + norm_cfg (dict): Config dict for normalization layer. + Defaults: dict(type='SyncBN', requires_grad=True). + """ + + def __init__(self, + patch_size=7, + stride=4, + in_channels=3, + embed_dim=768, + norm_cfg=dict(type='SyncBN', requires_grad=True)): + super().__init__() + + self.proj = nn.Conv2d( + in_channels, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=patch_size // 2) + self.norm = build_norm_layer(norm_cfg, embed_dim)[1] + + def forward(self, x): + """Forward function.""" + + x = self.proj(x) + _, _, H, W = x.shape + x = self.norm(x) + + x = x.flatten(2).transpose(1, 2) + + return x, H, W + + +@MODELS.register_module() +class MSCAN(BaseModule): + """SegNeXt Multi-Scale Convolutional Attention Network (MSCAN) backbone. + + This backbone is the implementation of `SegNeXt: Rethinking + Convolutional Attention Design for Semantic + Segmentation <https://arxiv.org/abs/2209.08575>`_. + Inspiration from https://github.com/visual-attention-network/segnext. + + Args: + in_channels (int): The number of input channels. Defaults: 3. + embed_dims (list[int]): Embedding dimension. + Defaults: [64, 128, 256, 512]. + mlp_ratios (list[int]): Ratio of mlp hidden dim to embedding dim. + Defaults: [4, 4, 4, 4]. + drop_rate (float): Dropout rate. Defaults: 0. + drop_path_rate (float): Stochastic depth rate. Defaults: 0. + depths (list[int]): Depths of each MSCAN stage. + Default: [3, 4, 6, 3]. + num_stages (int): MSCAN stages. Default: 4. + attention_kernel_sizes (list): Size of attention kernel in + Attention Module (Figure 2(b) of original paper). + Defaults: [5, [1, 7], [1, 11], [1, 21]]. + attention_kernel_paddings (list): Size of attention paddings + in Attention Module (Figure 2(b) of original paper). + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + norm_cfg (dict): Config of norm layers. + Defaults: dict(type='SyncBN', requires_grad=True). + pretrained (str, optional): model pretrained path. + Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None.
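+ Example (editorial; an illustrative shape check assuming the defaults + above, with plain BN so it runs outside distributed training): + >>> import torch + >>> model = MSCAN(norm_cfg=dict(type='BN', requires_grad=True)) + >>> outs = model(torch.rand(1, 3, 224, 224)) + >>> [tuple(o.shape) for o in outs] + [(1, 64, 56, 56), (1, 128, 28, 28), (1, 256, 14, 14), (1, 512, 7, 7)]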
+ """ + + def __init__(self, + in_channels=3, + embed_dims=[64, 128, 256, 512], + mlp_ratios=[4, 4, 4, 4], + drop_rate=0., + drop_path_rate=0., + depths=[3, 4, 6, 3], + num_stages=4, + attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]], + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='SyncBN', requires_grad=True), + pretrained=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.depths = depths + self.num_stages = num_stages + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + cur = 0 + + for i in range(num_stages): + if i == 0: + patch_embed = StemConv(3, embed_dims[0], norm_cfg=norm_cfg) + else: + patch_embed = OverlapPatchEmbed( + patch_size=7 if i == 0 else 3, + stride=4 if i == 0 else 2, + in_channels=in_channels if i == 0 else embed_dims[i - 1], + embed_dim=embed_dims[i], + norm_cfg=norm_cfg) + + block = nn.ModuleList([ + MSCABlock( + channels=embed_dims[i], + attention_kernel_sizes=attention_kernel_sizes, + attention_kernel_paddings=attention_kernel_paddings, + mlp_ratio=mlp_ratios[i], + drop=drop_rate, + drop_path=dpr[cur + j], + act_cfg=act_cfg, + norm_cfg=norm_cfg) for j in range(depths[i]) + ]) + norm = nn.LayerNorm(embed_dims[i]) + cur += depths[i] + + setattr(self, f'patch_embed{i + 1}', patch_embed) + setattr(self, f'block{i + 1}', block) + setattr(self, f'norm{i + 1}', norm) + + def init_weights(self): + """Initialize modules of MSCAN.""" + + print('init cfg', self.init_cfg) + if self.init_cfg is None: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, nn.LayerNorm): + constant_init(m, val=1.0, bias=0.) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init( + m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + else: + super().init_weights() + + def forward(self, x): + """Forward function.""" + + B = x.shape[0] + outs = [] + + for i in range(self.num_stages): + patch_embed = getattr(self, f'patch_embed{i + 1}') + block = getattr(self, f'block{i + 1}') + norm = getattr(self, f'norm{i + 1}') + x, H, W = patch_embed(x) + for blk in block: + x = blk(x, H, W) + x = norm(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + return outs diff --git a/mmseg/models/backbones/pidnet.py b/mmseg/models/backbones/pidnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0b711a373701c0771c5c5997bbb8e5b345d70924 --- /dev/null +++ b/mmseg/models/backbones/pidnet.py @@ -0,0 +1,522 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule
+from mmengine.runner import CheckpointLoader
+from torch import Tensor
+
+from mmseg.registry import MODELS
+from mmseg.utils import OptConfigType
+from ..utils import DAPPM, PAPPM, BasicBlock, Bottleneck
+
+
+class PagFM(BaseModule):
+    """Pixel-attention-guided fusion module.
+
+    Args:
+        in_channels (int): The number of input channels.
+        channels (int): The number of channels.
+        after_relu (bool): Whether to use ReLU before attention.
+            Default: False.
+        with_channel (bool): Whether to use channel attention.
+            Default: False.
+        upsample_mode (str): The mode of upsample. Default: 'bilinear'.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU', inplace=True).
+        init_cfg (dict): Config dict for initialization. Default: None.
+    """
+
+    def __init__(self,
+                 in_channels: int,
+                 channels: int,
+                 after_relu: bool = False,
+                 with_channel: bool = False,
+                 upsample_mode: str = 'bilinear',
+                 norm_cfg: OptConfigType = dict(type='BN'),
+                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
+                 init_cfg: OptConfigType = None):
+        super().__init__(init_cfg)
+        self.after_relu = after_relu
+        self.with_channel = with_channel
+        self.upsample_mode = upsample_mode
+        self.f_i = ConvModule(
+            in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None)
+        self.f_p = ConvModule(
+            in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None)
+        if with_channel:
+            self.up = ConvModule(
+                channels, in_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
+        if after_relu:
+            self.relu = MODELS.build(act_cfg)
+
+    def forward(self, x_p: Tensor, x_i: Tensor) -> Tensor:
+        """Forward function.
+
+        Args:
+            x_p (Tensor): The feature map from P branch.
+            x_i (Tensor): The feature map from I branch.
+
+        Returns:
+            Tensor: The feature map with pixel-attention-guided fusion.
+        """
+        if self.after_relu:
+            x_p = self.relu(x_p)
+            x_i = self.relu(x_i)
+
+        f_i = self.f_i(x_i)
+        f_i = F.interpolate(
+            f_i,
+            size=x_p.shape[2:],
+            mode=self.upsample_mode,
+            align_corners=False)
+
+        f_p = self.f_p(x_p)
+
+        if self.with_channel:
+            sigma = torch.sigmoid(self.up(f_p * f_i))
+        else:
+            sigma = torch.sigmoid(torch.sum(f_p * f_i, dim=1).unsqueeze(1))
+
+        x_i = F.interpolate(
+            x_i,
+            size=x_p.shape[2:],
+            mode=self.upsample_mode,
+            align_corners=False)
+
+        out = sigma * x_i + (1 - sigma) * x_p
+        return out
+
+
+class Bag(BaseModule):
+    """Boundary-attention-guided fusion module.
+
+    Args:
+        in_channels (int): The number of input channels.
+        out_channels (int): The number of output channels.
+        kernel_size (int): The kernel size of the convolution. Default: 3.
+        padding (int): The padding of the convolution. Default: 1.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU', inplace=True).
+        conv_cfg (dict): Config dict for convolution layer.
+            Default: dict(order=('norm', 'act', 'conv')).
+        init_cfg (dict): Config dict for initialization. Default: None.
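+
+    Example (a minimal sketch; the three inputs are assumed to share one
+    shape, as they do at PIDNet's fusion stage):
+        >>> import torch
+        >>> bag = Bag(64, 64)
+        >>> bag.eval()
+        >>> x_p = torch.rand(2, 64, 16, 16)
+        >>> x_i = torch.rand(2, 64, 16, 16)
+        >>> x_d = torch.rand(2, 64, 16, 16)
+        >>> tuple(bag(x_p, x_i, x_d).shape)
+        (2, 64, 16, 16)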
+ """ + + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + padding: int = 1, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + conv_cfg: OptConfigType = dict(order=('norm', 'act', 'conv')), + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + + self.conv = ConvModule( + in_channels, + out_channels, + kernel_size, + padding=padding, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + + def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor: + """Forward function. + + Args: + x_p (Tensor): The featrue map from P branch. + x_i (Tensor): The featrue map from I branch. + x_d (Tensor): The featrue map from D branch. + + Returns: + Tensor: The feature map with boundary-attention-guided fusion. + """ + sigma = torch.sigmoid(x_d) + return self.conv(sigma * x_p + (1 - sigma) * x_i) + + +class LightBag(BaseModule): + """Light Boundary-attention-guided fusion module. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. Default: None. + init_cfg (dict): Config dict for initialization. Default: None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = None, + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + self.f_p = ConvModule( + in_channels, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.f_i = ConvModule( + in_channels, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor: + """Forward function. + Args: + x_p (Tensor): The featrue map from P branch. + x_i (Tensor): The featrue map from I branch. + x_d (Tensor): The featrue map from D branch. + + Returns: + Tensor: The feature map with light boundary-attention-guided + fusion. + """ + sigma = torch.sigmoid(x_d) + + f_p = self.f_p((1 - sigma) * x_i + x_p) + f_i = self.f_i(x_i + sigma * x_p) + + return f_p + f_i + + +@MODELS.register_module() +class PIDNet(BaseModule): + """PIDNet backbone. + + This backbone is the implementation of `PIDNet: A Real-time Semantic + Segmentation Network Inspired from PID Controller + `_. + Modified from https://github.com/XuJiacong/PIDNet. + + Licensed under the MIT License. + + Args: + in_channels (int): The number of input channels. Default: 3. + channels (int): The number of channels in the stem layer. Default: 64. + ppm_channels (int): The number of channels in the PPM layer. + Default: 96. + num_stem_blocks (int): The number of blocks in the stem layer. + Default: 2. + num_branch_blocks (int): The number of blocks in the branch layer. + Default: 3. + align_corners (bool): The align_corners argument of F.interpolate. + Default: False. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU', inplace=True). + init_cfg (dict): Config dict for initialization. Default: None. 
+ """ + + def __init__(self, + in_channels: int = 3, + channels: int = 64, + ppm_channels: int = 96, + num_stem_blocks: int = 2, + num_branch_blocks: int = 3, + align_corners: bool = False, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + init_cfg: OptConfigType = None, + **kwargs): + super().__init__(init_cfg) + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.align_corners = align_corners + + # stem layer + self.stem = self._make_stem_layer(in_channels, channels, + num_stem_blocks) + self.relu = nn.ReLU() + + # I Branch + self.i_branch_layers = nn.ModuleList() + for i in range(3): + self.i_branch_layers.append( + self._make_layer( + block=BasicBlock if i < 2 else Bottleneck, + in_channels=channels * 2**(i + 1), + channels=channels * 8 if i > 0 else channels * 4, + num_blocks=num_branch_blocks if i < 2 else 2, + stride=2)) + + # P Branch + self.p_branch_layers = nn.ModuleList() + for i in range(3): + self.p_branch_layers.append( + self._make_layer( + block=BasicBlock if i < 2 else Bottleneck, + in_channels=channels * 2, + channels=channels * 2, + num_blocks=num_stem_blocks if i < 2 else 1)) + self.compression_1 = ConvModule( + channels * 4, + channels * 2, + kernel_size=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None) + self.compression_2 = ConvModule( + channels * 8, + channels * 2, + kernel_size=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None) + self.pag_1 = PagFM(channels * 2, channels) + self.pag_2 = PagFM(channels * 2, channels) + + # D Branch + if num_stem_blocks == 2: + self.d_branch_layers = nn.ModuleList([ + self._make_single_layer(BasicBlock, channels * 2, channels), + self._make_layer(Bottleneck, channels, channels, 1) + ]) + channel_expand = 1 + spp_module = PAPPM + dfm_module = LightBag + act_cfg_dfm = None + else: + self.d_branch_layers = nn.ModuleList([ + self._make_single_layer(BasicBlock, channels * 2, + channels * 2), + self._make_single_layer(BasicBlock, channels * 2, channels * 2) + ]) + channel_expand = 2 + spp_module = DAPPM + dfm_module = Bag + act_cfg_dfm = act_cfg + + self.diff_1 = ConvModule( + channels * 4, + channels * channel_expand, + kernel_size=3, + padding=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None) + self.diff_2 = ConvModule( + channels * 8, + channels * 2, + kernel_size=3, + padding=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None) + + self.spp = spp_module( + channels * 16, ppm_channels, channels * 4, num_scales=5) + self.dfm = dfm_module( + channels * 4, channels * 4, norm_cfg=norm_cfg, act_cfg=act_cfg_dfm) + + self.d_branch_layers.append( + self._make_layer(Bottleneck, channels * 2, channels * 2, 1)) + + def _make_stem_layer(self, in_channels: int, channels: int, + num_blocks: int) -> nn.Sequential: + """Make stem layer. + + Args: + in_channels (int): Number of input channels. + channels (int): Number of output channels. + num_blocks (int): Number of blocks. + + Returns: + nn.Sequential: The stem layer. 
+ """ + + layers = [ + ConvModule( + in_channels, + channels, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + ConvModule( + channels, + channels, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + ] + + layers.append( + self._make_layer(BasicBlock, channels, channels, num_blocks)) + layers.append(nn.ReLU()) + layers.append( + self._make_layer( + BasicBlock, channels, channels * 2, num_blocks, stride=2)) + layers.append(nn.ReLU()) + + return nn.Sequential(*layers) + + def _make_layer(self, + block: BasicBlock, + in_channels: int, + channels: int, + num_blocks: int, + stride: int = 1) -> nn.Sequential: + """Make layer for PIDNet backbone. + Args: + block (BasicBlock): Basic block. + in_channels (int): Number of input channels. + channels (int): Number of output channels. + num_blocks (int): Number of blocks. + stride (int): Stride of the first block. Default: 1. + + Returns: + nn.Sequential: The Branch Layer. + """ + downsample = None + if stride != 1 or in_channels != channels * block.expansion: + downsample = ConvModule( + in_channels, + channels * block.expansion, + kernel_size=1, + stride=stride, + norm_cfg=self.norm_cfg, + act_cfg=None) + + layers = [block(in_channels, channels, stride, downsample)] + in_channels = channels * block.expansion + for i in range(1, num_blocks): + layers.append( + block( + in_channels, + channels, + stride=1, + act_cfg_out=None if i == num_blocks - 1 else self.act_cfg)) + return nn.Sequential(*layers) + + def _make_single_layer(self, + block: Union[BasicBlock, Bottleneck], + in_channels: int, + channels: int, + stride: int = 1) -> nn.Module: + """Make single layer for PIDNet backbone. + Args: + block (BasicBlock or Bottleneck): Basic block or Bottleneck. + in_channels (int): Number of input channels. + channels (int): Number of output channels. + stride (int): Stride of the first block. Default: 1. + + Returns: + nn.Module + """ + + downsample = None + if stride != 1 or in_channels != channels * block.expansion: + downsample = ConvModule( + in_channels, + channels * block.expansion, + kernel_size=1, + stride=stride, + norm_cfg=self.norm_cfg, + act_cfg=None) + return block( + in_channels, channels, stride, downsample, act_cfg_out=None) + + def init_weights(self): + """Initialize the weights in backbone. + + Since the D branch is not initialized by the pre-trained model, we + initialize it with the same method as the ResNet. + """ + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + if self.init_cfg is not None: + assert 'checkpoint' in self.init_cfg, f'Only support ' \ + f'specify `Pretrained` in ' \ + f'`init_cfg` in ' \ + f'{self.__class__.__name__} ' + ckpt = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], map_location='cpu') + self.load_state_dict(ckpt, strict=False) + + def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]: + """Forward function. + + Args: + x (Tensor): Input tensor with shape (B, C, H, W). + + Returns: + Tensor or tuple[Tensor]: If self.training is True, return + tuple[Tensor], else return Tensor. 
+ """ + w_out = x.shape[-1] // 8 + h_out = x.shape[-2] // 8 + + # stage 0-2 + x = self.stem(x) + + # stage 3 + x_i = self.relu(self.i_branch_layers[0](x)) + x_p = self.p_branch_layers[0](x) + x_d = self.d_branch_layers[0](x) + + comp_i = self.compression_1(x_i) + x_p = self.pag_1(x_p, comp_i) + diff_i = self.diff_1(x_i) + x_d += F.interpolate( + diff_i, + size=[h_out, w_out], + mode='bilinear', + align_corners=self.align_corners) + if self.training: + temp_p = x_p.clone() + + # stage 4 + x_i = self.relu(self.i_branch_layers[1](x_i)) + x_p = self.p_branch_layers[1](self.relu(x_p)) + x_d = self.d_branch_layers[1](self.relu(x_d)) + + comp_i = self.compression_2(x_i) + x_p = self.pag_2(x_p, comp_i) + diff_i = self.diff_2(x_i) + x_d += F.interpolate( + diff_i, + size=[h_out, w_out], + mode='bilinear', + align_corners=self.align_corners) + if self.training: + temp_d = x_d.clone() + + # stage 5 + x_i = self.i_branch_layers[2](x_i) + x_p = self.p_branch_layers[2](self.relu(x_p)) + x_d = self.d_branch_layers[2](self.relu(x_d)) + + x_i = self.spp(x_i) + x_i = F.interpolate( + x_i, + size=[h_out, w_out], + mode='bilinear', + align_corners=self.align_corners) + out = self.dfm(x_p, x_i, x_d) + return (temp_p, out, temp_d) if self.training else out diff --git a/mmseg/models/backbones/resnest.py b/mmseg/models/backbones/resnest.py new file mode 100644 index 0000000000000000000000000000000000000000..3cc380b4460915f476ffc1febcfc145a94fc7c7a --- /dev/null +++ b/mmseg/models/backbones/resnest.py @@ -0,0 +1,318 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmseg.registry import MODELS +from ..utils import ResLayer +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNetV1d + + +class RSoftmax(nn.Module): + """Radix Softmax module in ``SplitAttentionConv2d``. + + Args: + radix (int): Radix of input. + groups (int): Groups of input. + """ + + def __init__(self, radix, groups): + super().__init__() + self.radix = radix + self.groups = groups + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttentionConv2d(nn.Module): + """Split-Attention Conv2d in ResNeSt. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int | tuple[int]): Same as nn.Conv2d. + stride (int | tuple[int]): Same as nn.Conv2d. + padding (int | tuple[int]): Same as nn.Conv2d. + dilation (int | tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels. Default: 4. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. + dcn (dict): Config dict for DCN. Default: None. 
+ """ + + def __init__(self, + in_channels, + channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + radix=2, + reduction_factor=4, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None): + super().__init__() + inter_channels = max(in_channels * radix // reduction_factor, 32) + self.radix = radix + self.groups = groups + self.channels = channels + self.with_dcn = dcn is not None + self.dcn = dcn + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + if self.with_dcn and not fallback_on_stride: + assert conv_cfg is None, 'conv_cfg must be None for DCN' + conv_cfg = dcn + self.conv = build_conv_layer( + conv_cfg, + in_channels, + channels * radix, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups * radix, + bias=False) + self.norm0_name, norm0 = build_norm_layer( + norm_cfg, channels * radix, postfix=0) + self.add_module(self.norm0_name, norm0) + self.relu = nn.ReLU(inplace=True) + self.fc1 = build_conv_layer( + None, channels, inter_channels, 1, groups=self.groups) + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, inter_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.fc2 = build_conv_layer( + None, inter_channels, channels * radix, 1, groups=self.groups) + self.rsoftmax = RSoftmax(radix, groups) + + @property + def norm0(self): + """nn.Module: the normalization layer named "norm0" """ + return getattr(self, self.norm0_name) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def forward(self, x): + x = self.conv(x) + x = self.norm0(x) + x = self.relu(x) + + batch, rchannel = x.shape[:2] + batch = x.size(0) + if self.radix > 1: + splits = x.view(batch, self.radix, -1, *x.shape[2:]) + gap = splits.sum(dim=1) + else: + gap = x + gap = F.adaptive_avg_pool2d(gap, 1) + gap = self.fc1(gap) + + gap = self.norm1(gap) + gap = self.relu(gap) + + atten = self.fc2(gap) + atten = self.rsoftmax(atten).view(batch, -1, 1, 1) + + if self.radix > 1: + attens = atten.view(batch, self.radix, -1, *atten.shape[2:]) + out = torch.sum(attens * splits, dim=1) + else: + out = atten * x + return out.contiguous() + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeSt. + + Args: + inplane (int): Input planes of this block. + planes (int): Middle planes of this block. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels in + SplitAttentionConv2d. Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + kwargs (dict): Key word arguments for base class. 
+ """ + expansion = 4 + + def __init__(self, + inplanes, + planes, + groups=1, + base_width=4, + base_channels=64, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + """Bottleneck block for ResNeSt.""" + super().__init__(inplanes, planes, **kwargs) + + if groups == 1: + width = self.planes + else: + width = math.floor(self.planes * + (base_width / base_channels)) * groups + + self.avg_down_stride = avg_down_stride and self.conv2_stride > 1 + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width, postfix=1) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.inplanes, + width, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.with_modulated_dcn = False + self.conv2 = SplitAttentionConv2d( + width, + width, + kernel_size=3, + stride=1 if self.avg_down_stride else self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + radix=radix, + reduction_factor=reduction_factor, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + dcn=self.dcn) + delattr(self, self.norm2_name) + + if self.avg_down_stride: + self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1) + + self.conv3 = build_conv_layer( + self.conv_cfg, + width, + self.planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + + if self.avg_down_stride: + out = self.avd_layer(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class ResNeSt(ResNetV1d): + """ResNeSt backbone. + + This backbone is the implementation of `ResNeSt: + Split-Attention Networks `_. + + Args: + groups (int): Number of groups of Bottleneck. Default: 1 + base_width (int): Base width of Bottleneck. Default: 4 + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels in + SplitAttentionConv2d. Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + kwargs (dict): Keyword arguments for ResNet. 
+ """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)), + 200: (Bottleneck, (3, 24, 36, 3)) + } + + def __init__(self, + groups=1, + base_width=4, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + self.groups = groups + self.base_width = base_width + self.radix = radix + self.reduction_factor = reduction_factor + self.avg_down_stride = avg_down_stride + super().__init__(**kwargs) + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer( + groups=self.groups, + base_width=self.base_width, + base_channels=self.base_channels, + radix=self.radix, + reduction_factor=self.reduction_factor, + avg_down_stride=self.avg_down_stride, + **kwargs) diff --git a/mmseg/models/backbones/resnet.py b/mmseg/models/backbones/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9226c90d85c938e76f322e58643ee9d7b17ba27b --- /dev/null +++ b/mmseg/models/backbones/resnet.py @@ -0,0 +1,712 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer +from mmengine.model import BaseModule +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmseg.registry import MODELS +from ..utils import ResLayer + + +class BasicBlock(BaseModule): + """Basic block for ResNet.""" + + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_cfg=None): + super().__init__(init_cfg) + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, planes, planes, 3, padding=1, bias=False) + self.add_module(self.norm2_name, norm2) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Bottleneck(BaseModule): + """Bottleneck block for ResNet. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is + "caffe", the stride-two layer is the first 1x1 conv layer. 
+ """ + + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_cfg=None): + super().__init__(init_cfg) + assert style in ['pytorch', 'caffe'] + assert dcn is None or isinstance(dcn, dict) + assert plugins is None or isinstance(plugins, list) + if plugins is not None: + allowed_position = ['after_conv1', 'after_conv2', 'after_conv3'] + assert all(p['position'] in allowed_position for p in plugins) + + self.inplanes = inplanes + self.planes = planes + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.dcn = dcn + self.with_dcn = dcn is not None + self.plugins = plugins + self.with_plugins = plugins is not None + + if self.with_plugins: + # collect plugins for conv1/conv2/conv3 + self.after_conv1_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv1' + ] + self.after_conv2_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv2' + ] + self.after_conv3_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv3' + ] + + if self.style == 'pytorch': + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + norm_cfg, planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = dcn.pop('fallback_on_stride', False) + if not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + conv_cfg, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + self.conv2 = build_conv_layer( + dcn, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + conv_cfg, + planes, + planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + if self.with_plugins: + self.after_conv1_plugin_names = self.make_block_plugins( + planes, self.after_conv1_plugins) + self.after_conv2_plugin_names = self.make_block_plugins( + planes, self.after_conv2_plugins) + self.after_conv3_plugin_names = self.make_block_plugins( + planes * self.expansion, self.after_conv3_plugins) + + def make_block_plugins(self, in_channels, plugins): + """make plugins for block. + + Args: + in_channels (int): Input channels of plugin. + plugins (list[dict]): List of plugins cfg to build. + + Returns: + list[str]: List of the names of plugin. 
+ """ + assert isinstance(plugins, list) + plugin_names = [] + for plugin in plugins: + plugin = plugin.copy() + name, layer = build_plugin_layer( + plugin, + in_channels=in_channels, + postfix=plugin.pop('postfix', '')) + assert not hasattr(self, name), f'duplicate plugin {name}' + self.add_module(name, layer) + plugin_names.append(name) + return plugin_names + + def forward_plugin(self, x, plugin_names): + """Forward function for plugins.""" + out = x + for name in plugin_names: + out = getattr(self, name)(x) + return out + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + @property + def norm3(self): + """nn.Module: normalization layer after the third convolution layer""" + return getattr(self, self.norm3_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class ResNet(BaseModule): + """ResNet backbone. + + This backbone is the improved implementation of `Deep Residual Learning + for Image Recognition `_. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Number of stem channels. Default: 64. + base_channels (int): Number of base channels of res layer. Default: 64. + num_stages (int): Resnet stages, normally 4. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: (1, 2, 2, 2). + dilations (Sequence[int]): Dilation of each stage. + Default: (1, 1, 1, 1). + out_indices (Sequence[int]): Output from which stages. + Default: (0, 1, 2, 3). + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. Default: 'pytorch'. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): Dictionary to construct and config conv layer. + When conv_cfg is None, cfg will be set to dict(type='Conv2d'). + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. 
+ dcn (dict | None): Dictionary to construct and config DCN conv layer. + When dcn is not None, conv_cfg must be None. Default: None. + stage_with_dcn (Sequence[bool]): Whether to set DCN conv for each + stage. The length of stage_with_dcn is equal to num_stages. + Default: (False, False, False, False). + plugins (list[dict]): List of plugins for stages, each dict contains: + + - cfg (dict, required): Cfg dict to build plugin. + + - position (str, required): Position inside block to insert plugin, + options: 'after_conv1', 'after_conv2', 'after_conv3'. + + - stages (tuple[bool], optional): Stages to apply plugin, length + should be same as 'num_stages'. + Default: None. + multi_grid (Sequence[int]|None): Multi grid dilation rates of last + stage. Default: None. + contract_dilation (bool): Whether contract first dilation of each layer + Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + pretrained (str, optional): model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Example: + >>> from mmseg.models import ResNet + >>> import torch + >>> self = ResNet(depth=18) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 8, 8) + (1, 128, 4, 4) + (1, 256, 2, 2) + (1, 512, 1, 1) + """ + + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth, + in_channels=3, + stem_channels=64, + base_channels=64, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + dcn=None, + stage_with_dcn=(False, False, False, False), + plugins=None, + multi_grid=None, + contract_dilation=False, + with_cp=False, + zero_init_residual=True, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg) + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + + self.pretrained = pretrained + self.zero_init_residual = zero_init_residual + block_init_cfg = None + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + block = self.arch_settings[depth][0] + if self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm3')) + else: + raise TypeError('pretrained must be a str or None') + + self.depth = depth + self.stem_channels = stem_channels + self.base_channels = base_channels + 
self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.dcn = dcn + self.stage_with_dcn = stage_with_dcn + if dcn is not None: + assert len(stage_with_dcn) == num_stages + self.plugins = plugins + self.multi_grid = multi_grid + self.contract_dilation = contract_dilation + self.block, stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + self.inplanes = stem_channels + + self._make_stem_layer(in_channels, stem_channels) + + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = strides[i] + dilation = dilations[i] + dcn = self.dcn if self.stage_with_dcn[i] else None + if plugins is not None: + stage_plugins = self.make_stage_plugins(plugins, i) + else: + stage_plugins = None + # multi grid is applied to last layer only + stage_multi_grid = multi_grid if i == len( + self.stage_blocks) - 1 else None + planes = base_channels * 2**i + res_layer = self.make_res_layer( + block=self.block, + inplanes=self.inplanes, + planes=planes, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + dcn=dcn, + plugins=stage_plugins, + multi_grid=stage_multi_grid, + contract_dilation=contract_dilation, + init_cfg=block_init_cfg) + self.inplanes = planes * self.block.expansion + layer_name = f'layer{i+1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = self.block.expansion * base_channels * 2**( + len(self.stage_blocks) - 1) + + def make_stage_plugins(self, plugins, stage_idx): + """make plugins for ResNet 'stage_idx'th stage . + + Currently we support to insert 'context_block', + 'empirical_attention_block', 'nonlocal_block' into the backbone like + ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of + Bottleneck. + + An example of plugins format could be : + >>> plugins=[ + ... dict(cfg=dict(type='xxx', arg1='xxx'), + ... stages=(False, True, True, True), + ... position='after_conv2'), + ... dict(cfg=dict(type='yyy'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='1'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='2'), + ... stages=(True, True, True, True), + ... position='after_conv3') + ... ] + >>> self = ResNet(depth=18) + >>> stage_plugins = self.make_stage_plugins(plugins, 0) + >>> assert len(stage_plugins) == 3 + + Suppose 'stage_idx=0', the structure of blocks in the stage would be: + conv1-> conv2->conv3->yyy->zzz1->zzz2 + Suppose 'stage_idx=1', the structure of blocks in the stage would be: + conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2 + + If stages is missing, the plugin would be applied to all stages. + + Args: + plugins (list[dict]): List of plugins cfg to build. The postfix is + required if multiple same type plugins are inserted. 
+ stage_idx (int): Index of stage to build + + Returns: + list[dict]: Plugins for current stage + """ + stage_plugins = [] + for plugin in plugins: + plugin = plugin.copy() + stages = plugin.pop('stages', None) + assert stages is None or len(stages) == self.num_stages + # whether to insert plugin into current stage + if stages is None or stages[stage_idx]: + stage_plugins.append(plugin) + + return stage_plugins + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer(**kwargs) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def _make_stem_layer(self, in_channels, stem_channels): + """Make stem layer for ResNet.""" + if self.deep_stem: + self.stem = nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels)[1], + nn.ReLU(inplace=True)) + else: + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def _freeze_stages(self): + """Freeze stages param and norm stats.""" + if self.frozen_stages >= 0: + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + else: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f'layer{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x): + """Forward function.""" + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep normalization layer + freezed.""" + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + +@MODELS.register_module() +class ResNetV1c(ResNet): + """ResNetV1c variant described in [1]_. + + Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv in + the input stem with three 3x3 convs. For more details please refer to `Bag + of Tricks for Image Classification with Convolutional Neural Networks + `_. + """ + + def __init__(self, **kwargs): + super().__init__(deep_stem=True, avg_down=False, **kwargs) + + +@MODELS.register_module() +class ResNetV1d(ResNet): + """ResNetV1d variant described in [1]_. 
+ + Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in + the input stem with three 3x3 convs. And in the downsampling block, a 2x2 + avg_pool with stride 2 is added before conv, whose stride is changed to 1. + """ + + def __init__(self, **kwargs): + super().__init__(deep_stem=True, avg_down=True, **kwargs) diff --git a/mmseg/models/backbones/resnext.py b/mmseg/models/backbones/resnext.py new file mode 100644 index 0000000000000000000000000000000000000000..67a244a12f61b78ee12e89e8b45868781208614c --- /dev/null +++ b/mmseg/models/backbones/resnext.py @@ -0,0 +1,150 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmseg.registry import MODELS +from ..utils import ResLayer +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNet + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeXt. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is + "caffe", the stride-two layer is the first 1x1 conv layer. + """ + + def __init__(self, + inplanes, + planes, + groups=1, + base_width=4, + base_channels=64, + **kwargs): + super().__init__(inplanes, planes, **kwargs) + + if groups == 1: + width = self.planes + else: + width = math.floor(self.planes * + (base_width / base_channels)) * groups + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, width, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.inplanes, + width, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + self.with_modulated_dcn = False + if self.with_dcn: + fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + if not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + self.conv_cfg, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + self.conv2 = build_conv_layer( + self.dcn, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + width, + self.planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + +@MODELS.register_module() +class ResNeXt(ResNet): + """ResNeXt backbone. + + This backbone is the implementation of `Aggregated + Residual Transformations for Deep Neural + Networks `_. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Normally 3. + num_stages (int): Resnet stages, normally 4. + groups (int): Group of resnext. + base_width (int): Base width of resnext. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + frozen_stages (int): Stages to be frozen (all param fixed). 
-1 means + not freezing any parameters. + norm_cfg (dict): dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + zero_init_residual (bool): whether to use zero init for last norm layer + in resblocks to let them behave as identity. + + Example: + >>> from mmseg.models import ResNeXt + >>> import torch + >>> self = ResNeXt(depth=50) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 256, 8, 8) + (1, 512, 4, 4) + (1, 1024, 2, 2) + (1, 2048, 1, 1) + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, groups=1, base_width=4, **kwargs): + self.groups = groups + self.base_width = base_width + super().__init__(**kwargs) + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``""" + return ResLayer( + groups=self.groups, + base_width=self.base_width, + base_channels=self.base_channels, + **kwargs) diff --git a/mmseg/models/backbones/stdc.py b/mmseg/models/backbones/stdc.py new file mode 100644 index 0000000000000000000000000000000000000000..758a3c92e07dc8d2051f670adf00d163019d758c --- /dev/null +++ b/mmseg/models/backbones/stdc.py @@ -0,0 +1,422 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Modified from https://github.com/MichaelFan01/STDC-Seg.""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmseg.registry import MODELS +from ..utils import resize +from .bisenetv1 import AttentionRefinementModule + + +class STDCModule(BaseModule): + """STDCModule. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels before scaling. + stride (int): The number of stride for the first conv layer. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): The activation config for conv layers. + num_convs (int): Numbers of conv layers. + fusion_type (str): Type of fusion operation. Default: 'add'. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
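+
+    Example (a minimal sketch; with ``stride=1`` and ``fusion_type='add'``
+    the input and output widths must match so the residual sum is valid):
+        >>> import torch
+        >>> m = STDCModule(64, 64, stride=1, norm_cfg=dict(type='BN'))
+        >>> m.eval()
+        >>> tuple(m(torch.rand(2, 64, 16, 16)).shape)
+        (2, 64, 16, 16)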
+ """ + + def __init__(self, + in_channels, + out_channels, + stride, + norm_cfg=None, + act_cfg=None, + num_convs=4, + fusion_type='add', + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert num_convs > 1 + assert fusion_type in ['add', 'cat'] + self.stride = stride + self.with_downsample = True if self.stride == 2 else False + self.fusion_type = fusion_type + + self.layers = ModuleList() + conv_0 = ConvModule( + in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg) + + if self.with_downsample: + self.downsample = ConvModule( + out_channels // 2, + out_channels // 2, + kernel_size=3, + stride=2, + padding=1, + groups=out_channels // 2, + norm_cfg=norm_cfg, + act_cfg=None) + + if self.fusion_type == 'add': + self.layers.append(nn.Sequential(conv_0, self.downsample)) + self.skip = Sequential( + ConvModule( + in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=1, + groups=in_channels, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + in_channels, + out_channels, + 1, + norm_cfg=norm_cfg, + act_cfg=None)) + else: + self.layers.append(conv_0) + self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + else: + self.layers.append(conv_0) + + for i in range(1, num_convs): + out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i + self.layers.append( + ConvModule( + out_channels // 2**i, + out_channels // out_factor, + kernel_size=3, + stride=1, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, inputs): + if self.fusion_type == 'add': + out = self.forward_add(inputs) + else: + out = self.forward_cat(inputs) + return out + + def forward_add(self, inputs): + layer_outputs = [] + x = inputs.clone() + for layer in self.layers: + x = layer(x) + layer_outputs.append(x) + if self.with_downsample: + inputs = self.skip(inputs) + + return torch.cat(layer_outputs, dim=1) + inputs + + def forward_cat(self, inputs): + x0 = self.layers[0](inputs) + layer_outputs = [x0] + for i, layer in enumerate(self.layers[1:]): + if i == 0: + if self.with_downsample: + x = layer(self.downsample(x0)) + else: + x = layer(x0) + else: + x = layer(x) + layer_outputs.append(x) + if self.with_downsample: + layer_outputs[0] = self.skip(x0) + return torch.cat(layer_outputs, dim=1) + + +class FeatureFusionModule(BaseModule): + """Feature Fusion Module. This module is different from FeatureFusionModule + in BiSeNetV1. It uses two ConvModules in `self.attention` whose inter + channel number is calculated by given `scale_factor`, while + FeatureFusionModule in BiSeNetV1 only uses one ConvModule in + `self.conv_atten`. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + scale_factor (int): The number of channel scale factor. + Default: 4. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): The activation config for conv layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
+ """ + + def __init__(self, + in_channels, + out_channels, + scale_factor=4, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + channels = out_channels // scale_factor + self.conv0 = ConvModule( + in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg) + self.attention = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + ConvModule( + out_channels, + channels, + 1, + norm_cfg=None, + bias=False, + act_cfg=act_cfg), + ConvModule( + channels, + out_channels, + 1, + norm_cfg=None, + bias=False, + act_cfg=None), nn.Sigmoid()) + + def forward(self, spatial_inputs, context_inputs): + inputs = torch.cat([spatial_inputs, context_inputs], dim=1) + x = self.conv0(inputs) + attn = self.attention(x) + x_attn = x * attn + return x_attn + x + + +@MODELS.register_module() +class STDCNet(BaseModule): + """This backbone is the implementation of `Rethinking BiSeNet For Real-time + Semantic Segmentation `_. + + Args: + stdc_type (int): The type of backbone structure, + `STDCNet1` and`STDCNet2` denotes two main backbones in paper, + whose FLOPs is 813M and 1446M, respectively. + in_channels (int): The num of input_channels. + channels (tuple[int]): The output channels for each stage. + bottleneck_type (str): The type of STDC Module type, the value must + be 'add' or 'cat'. + norm_cfg (dict): Config dict for normalization layer. + act_cfg (dict): The activation config for conv layers. + num_convs (int): Numbers of conv layer at each STDC Module. + Default: 4. + with_final_conv (bool): Whether add a conv layer at the Module output. + Default: True. + pretrained (str, optional): Model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Example: + >>> import torch + >>> stdc_type = 'STDCNet1' + >>> in_channels = 3 + >>> channels = (32, 64, 256, 512, 1024) + >>> bottleneck_type = 'cat' + >>> inputs = torch.rand(1, 3, 1024, 2048) + >>> self = STDCNet(stdc_type, in_channels, + ... channels, bottleneck_type).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 256, 128, 256]) + outputs[1].shape = torch.Size([1, 512, 64, 128]) + outputs[2].shape = torch.Size([1, 1024, 32, 64]) + """ + + arch_settings = { + 'STDCNet1': [(2, 1), (2, 1), (2, 1)], + 'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)] + } + + def __init__(self, + stdc_type, + in_channels, + channels, + bottleneck_type, + norm_cfg, + act_cfg, + num_convs=4, + with_final_conv=False, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert stdc_type in self.arch_settings, \ + f'invalid structure {stdc_type} for STDCNet.' + assert bottleneck_type in ['add', 'cat'],\ + f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}' + + assert len(channels) == 5,\ + f'invalid channels length {len(channels)} for STDCNet.' 
+ + self.in_channels = in_channels + self.channels = channels + self.stage_strides = self.arch_settings[stdc_type] + self.prtrained = pretrained + self.num_convs = num_convs + self.with_final_conv = with_final_conv + + self.stages = ModuleList([ + ConvModule( + self.in_channels, + self.channels[0], + kernel_size=3, + stride=2, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + self.channels[0], + self.channels[1], + kernel_size=3, + stride=2, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + ]) + # `self.num_shallow_features` is the number of shallow modules in + # `STDCNet`, which is noted as `Stage1` and `Stage2` in original paper. + # They are both not used for following modules like Attention + # Refinement Module and Feature Fusion Module. + # Thus they would be cut from `outs`. Please refer to Figure 4 + # of original paper for more details. + self.num_shallow_features = len(self.stages) + + for strides in self.stage_strides: + idx = len(self.stages) - 1 + self.stages.append( + self._make_stage(self.channels[idx], self.channels[idx + 1], + strides, norm_cfg, act_cfg, bottleneck_type)) + # After appending, `self.stages` is a ModuleList including several + # shallow modules and STDCModules. + # (len(self.stages) == + # self.num_shallow_features + len(self.stage_strides)) + if self.with_final_conv: + self.final_conv = ConvModule( + self.channels[-1], + max(1024, self.channels[-1]), + 1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def _make_stage(self, in_channels, out_channels, strides, norm_cfg, + act_cfg, bottleneck_type): + layers = [] + for i, stride in enumerate(strides): + layers.append( + STDCModule( + in_channels if i == 0 else out_channels, + out_channels, + stride, + norm_cfg, + act_cfg, + num_convs=self.num_convs, + fusion_type=bottleneck_type)) + return Sequential(*layers) + + def forward(self, x): + outs = [] + for stage in self.stages: + x = stage(x) + outs.append(x) + if self.with_final_conv: + outs[-1] = self.final_conv(outs[-1]) + outs = outs[self.num_shallow_features:] + return tuple(outs) + + +@MODELS.register_module() +class STDCContextPathNet(BaseModule): + """STDCNet with Context Path. The `outs` below is a list of three feature + maps from deep to shallow, whose height and width is from small to big, + respectively. The biggest feature map of `outs` is outputted for + `STDCHead`, where Detail Loss would be calculated by Detail Ground-truth. + The other two feature maps are used for Attention Refinement Module, + respectively. Besides, the biggest feature map of `outs` and the last + output of Attention Refinement Module are concatenated for Feature Fusion + Module. Then, this fusion feature map `feat_fuse` would be outputted for + `decode_head`. More details please refer to Figure 4 of original paper. + + Args: + backbone_cfg (dict): Config dict for stdc backbone. + last_in_channels (tuple(int)), The number of channels of last + two feature maps from stdc backbone. Default: (1024, 512). + out_channels (int): The channels of output feature maps. + Default: 128. + ffm_cfg (dict): Config dict for Feature Fusion Module. Default: + `dict(in_channels=512, out_channels=256, scale_factor=4)`. + upsample_mode (str): Algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'nearest'``. + align_corners (str): align_corners argument of F.interpolate. It + must be `None` if upsample_mode is ``'nearest'``. Default: None. + norm_cfg (dict): Config dict for normalization layer. 
+ Default: dict(type='BN'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Return: + outputs (tuple): The tuple of list of output feature map for + auxiliary heads and decoder head. + """ + + def __init__(self, + backbone_cfg, + last_in_channels=(1024, 512), + out_channels=128, + ffm_cfg=dict( + in_channels=512, out_channels=256, scale_factor=4), + upsample_mode='nearest', + align_corners=None, + norm_cfg=dict(type='BN'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.backbone = MODELS.build(backbone_cfg) + self.arms = ModuleList() + self.convs = ModuleList() + for channels in last_in_channels: + self.arms.append(AttentionRefinementModule(channels, out_channels)) + self.convs.append( + ConvModule( + out_channels, + out_channels, + 3, + padding=1, + norm_cfg=norm_cfg)) + self.conv_avg = ConvModule( + last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg) + + self.ffm = FeatureFusionModule(**ffm_cfg) + + self.upsample_mode = upsample_mode + self.align_corners = align_corners + + def forward(self, x): + outs = list(self.backbone(x)) + avg = F.adaptive_avg_pool2d(outs[-1], 1) + avg_feat = self.conv_avg(avg) + + feature_up = resize( + avg_feat, + size=outs[-1].shape[2:], + mode=self.upsample_mode, + align_corners=self.align_corners) + arms_out = [] + for i in range(len(self.arms)): + x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up + feature_up = resize( + x_arm, + size=outs[len(outs) - 1 - i - 1].shape[2:], + mode=self.upsample_mode, + align_corners=self.align_corners) + feature_up = self.convs[i](feature_up) + arms_out.append(feature_up) + + feat_fuse = self.ffm(outs[0], arms_out[1]) + + # The `outputs` has four feature maps. + # `outs[0]` is outputted for `STDCHead` auxiliary head. + # Two feature maps of `arms_out` are outputted for auxiliary head. + # `feat_fuse` is outputted for decoder head. + outputs = [outs[0]] + list(arms_out) + [feat_fuse] + return tuple(outputs) diff --git a/mmseg/models/backbones/swin.py b/mmseg/models/backbones/swin.py new file mode 100644 index 0000000000000000000000000000000000000000..67b28a96e15fe81e8213d67518d664383a4fd255 --- /dev/null +++ b/mmseg/models/backbones/swin.py @@ -0,0 +1,757 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, build_dropout +from mmengine.logging import print_log +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import (constant_init, trunc_normal_, + trunc_normal_init) +from mmengine.runner import CheckpointLoader +from mmengine.utils import to_2tuple + +from mmseg.registry import MODELS +from ..utils.embed import PatchEmbed, PatchMerging + + +class WindowMSA(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): The height and width of the window. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. 
+ Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + init_cfg (dict | None, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = qk_scale or head_embed_dims**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # About 2x faster than original impl + Wh, Ww = self.window_size + rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww) + rel_position_index = rel_index_coords + rel_index_coords.T + rel_position_index = rel_position_index.flip(1).contiguous() + self.register_buffer('relative_position_index', rel_position_index) + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + + self.softmax = nn.Softmax(dim=-1) + + def init_weights(self): + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor | None, Optional): mask with shape of (num_windows, + Wh*Ww, Wh*Ww), value should be between (-inf, 0]. + """ + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + # make torchscript happy (cannot use tensor as tuple) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @staticmethod + def double_step_seq(step1, len1, step2, len2): + seq1 = torch.arange(0, step1 * len1, step1) + seq2 = torch.arange(0, step2 * len2, step2) + return (seq1[:, None] + seq2[None, :]).reshape(1, -1) + + +class ShiftWindowMSA(BaseModule): + """Shifted Window Multihead Self-Attention Module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. + shift_size (int, optional): The shift step of each window towards + right-bottom. If zero, act as regular window-msa. Defaults to 0. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Defaults: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. 
+ Defaults: 0. + proj_drop_rate (float, optional): Dropout ratio of output. + Defaults: 0. + dropout_layer (dict, optional): The dropout_layer used before output. + Defaults: dict(type='DropPath', drop_prob=0.). + init_cfg (dict, optional): The extra config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + shift_size=0, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0, + proj_drop_rate=0, + dropout_layer=dict(type='DropPath', drop_prob=0.), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.window_size = window_size + self.shift_size = shift_size + assert 0 <= self.shift_size < self.window_size + + self.w_msa = WindowMSA( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=to_2tuple(window_size), + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=proj_drop_rate, + init_cfg=None) + + self.drop = build_dropout(dropout_layer) + + def forward(self, query, hw_shape): + B, L, C = query.shape + H, W = hw_shape + assert L == H * W, 'input feature has wrong size' + query = query.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b)) + H_pad, W_pad = query.shape[1], query.shape[2] + + # cyclic shift + if self.shift_size > 0: + shifted_query = torch.roll( + query, + shifts=(-self.shift_size, -self.shift_size), + dims=(1, 2)) + + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device) + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + # nW, window_size, window_size, 1 + mask_windows = self.window_partition(img_mask) + mask_windows = mask_windows.view( + -1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + else: + shifted_query = query + attn_mask = None + + # nW*B, window_size, window_size, C + query_windows = self.window_partition(shifted_query) + # nW*B, window_size*window_size, C + query_windows = query_windows.view(-1, self.window_size**2, C) + + # W-MSA/SW-MSA (nW*B, window_size*window_size, C) + attn_windows = self.w_msa(query_windows, mask=attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + + # B H' W' C + shifted_x = self.window_reverse(attn_windows, H_pad, W_pad) + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + x = self.drop(x) + return x + + def window_reverse(self, windows, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + window_size = self.window_size + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, 
window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + def window_partition(self, x): + """ + Args: + x: (B, H, W, C) + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + window_size = self.window_size + x = x.view(B, H // window_size, window_size, W // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows + + +class SwinBlock(BaseModule): + """" + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + window_size (int, optional): The local window scale. Default: 7. + shift (bool, optional): whether to shift window or not. Default False. + qkv_bias (bool, optional): enable bias for qkv if True. Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop_rate (float, optional): Dropout rate. Default: 0. + attn_drop_rate (float, optional): Attention dropout rate. Default: 0. + drop_path_rate (float, optional): Stochastic depth rate. Default: 0. + act_cfg (dict, optional): The config dict of activation function. + Default: dict(type='GELU'). + norm_cfg (dict, optional): The config dict of normalization. + Default: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + init_cfg (dict | list | None, optional): The init config. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + window_size=7, + shift=False, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + + self.with_cp = with_cp + + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = ShiftWindowMSA( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=window_size, + shift_size=window_size // 2 if shift else 0, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + init_cfg=None) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=2, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=True, + init_cfg=None) + + def forward(self, x, hw_shape): + + def _inner_forward(x): + identity = x + x = self.norm1(x) + x = self.attn(x, hw_shape) + + x = x + identity + + identity = x + x = self.norm2(x) + x = self.ffn(x, identity=identity) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class SwinBlockSequence(BaseModule): + """Implements one stage in Swin Transformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + depth (int): The number of blocks in this stage. + window_size (int, optional): The local window scale. Default: 7. + qkv_bias (bool, optional): enable bias for qkv if True. Default: True. 
+ qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop_rate (float, optional): Dropout rate. Default: 0. + attn_drop_rate (float, optional): Attention dropout rate. Default: 0. + drop_path_rate (float | list[float], optional): Stochastic depth + rate. Default: 0. + downsample (BaseModule | None, optional): The downsample operation + module. Default: None. + act_cfg (dict, optional): The config dict of activation function. + Default: dict(type='GELU'). + norm_cfg (dict, optional): The config dict of normalization. + Default: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + init_cfg (dict | list | None, optional): The init config. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + depth, + window_size=7, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + downsample=None, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(drop_path_rate, list): + drop_path_rates = drop_path_rate + assert len(drop_path_rates) == depth + else: + drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)] + + self.blocks = ModuleList() + for i in range(depth): + block = SwinBlock( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=feedforward_channels, + window_size=window_size, + shift=False if i % 2 == 0 else True, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rates[i], + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + init_cfg=None) + self.blocks.append(block) + + self.downsample = downsample + + def forward(self, x, hw_shape): + for block in self.blocks: + x = block(x, hw_shape) + + if self.downsample: + x_down, down_hw_shape = self.downsample(x, hw_shape) + return x_down, down_hw_shape, x, hw_shape + else: + return x, hw_shape, x, hw_shape + + +@MODELS.register_module() +class SwinTransformer(BaseModule): + """Swin Transformer backbone. + + This backbone is the implementation of `Swin Transformer: + Hierarchical Vision Transformer using Shifted + Windows `_. + Inspiration from https://github.com/microsoft/Swin-Transformer. + + Args: + pretrain_img_size (int | tuple[int]): The size of input image when + pretrain. Defaults: 224. + in_channels (int): The num of input channels. + Defaults: 3. + embed_dims (int): The feature dimension. Default: 96. + patch_size (int | tuple[int]): Patch size. Default: 4. + window_size (int): Window size. Default: 7. + mlp_ratio (int | float): Ratio of mlp hidden dim to embedding dim. + Default: 4. + depths (tuple[int]): Depths of each Swin Transformer stage. + Default: (2, 2, 6, 2). + num_heads (tuple[int]): Parallel attention heads of each Swin + Transformer stage. Default: (3, 6, 12, 24). + strides (tuple[int]): The patch merging or patch embedding stride of + each Swin Transformer stage. (In swin, we set kernel size equal to + stride.) Default: (4, 2, 2, 2). + out_indices (tuple[int]): Output from which stages. + Default: (0, 1, 2, 3). + qkv_bias (bool, optional): If True, add a learnable bias to query, key, + value. Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. 
+        patch_norm (bool): Whether to add a norm layer for patch embed and
+            patch merging. Default: True.
+        drop_rate (float): Dropout rate. Default: 0.
+        attn_drop_rate (float): Attention dropout rate. Default: 0.
+        drop_path_rate (float): Stochastic depth rate. Default: 0.1.
+        use_abs_pos_embed (bool): If True, add absolute position embedding to
+            the patch embedding. Default: False.
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='GELU').
+        norm_cfg (dict): Config dict for normalization layer at the
+            output of the backbone. Default: dict(type='LN').
+        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+            will save some memory while slowing down the training speed.
+            Default: False.
+        pretrained (str, optional): Model pretrained path. Default: None.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval
+            mode). -1 means not freezing any parameters. Default: -1.
+        init_cfg (dict, optional): The Config for initialization.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 pretrain_img_size=224,
+                 in_channels=3,
+                 embed_dims=96,
+                 patch_size=4,
+                 window_size=7,
+                 mlp_ratio=4,
+                 depths=(2, 2, 6, 2),
+                 num_heads=(3, 6, 12, 24),
+                 strides=(4, 2, 2, 2),
+                 out_indices=(0, 1, 2, 3),
+                 qkv_bias=True,
+                 qk_scale=None,
+                 patch_norm=True,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.1,
+                 use_abs_pos_embed=False,
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='LN'),
+                 with_cp=False,
+                 pretrained=None,
+                 frozen_stages=-1,
+                 init_cfg=None):
+        self.frozen_stages = frozen_stages
+
+        if isinstance(pretrain_img_size, int):
+            pretrain_img_size = to_2tuple(pretrain_img_size)
+        elif isinstance(pretrain_img_size, tuple):
+            if len(pretrain_img_size) == 1:
+                pretrain_img_size = to_2tuple(pretrain_img_size[0])
+            assert len(pretrain_img_size) == 2, \
+                f'The size of the image should have length 1 or 2, ' \
+                f'but got {len(pretrain_img_size)}'
+
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be specified at the same time'
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
+                          'please use "init_cfg" instead')
+            init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is None:
+            init_cfg = init_cfg
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+        super().__init__(init_cfg=init_cfg)
+
+        num_layers = len(depths)
+        self.out_indices = out_indices
+        self.use_abs_pos_embed = use_abs_pos_embed
+
+        assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
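+        # The first stride doubles as the patch-embedding kernel size, so
+        # the image is split into non-overlapping patch_size x patch_size
+        # patches; with the defaults this gives (224 // 4) ** 2 = 3136 tokens
+        # of dimension embed_dims, which also sizes the optional absolute
+        # position embedding below.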
+ + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=strides[0], + padding='corner', + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None) + + if self.use_abs_pos_embed: + patch_row = pretrain_img_size[0] // patch_size + patch_col = pretrain_img_size[1] // patch_size + num_patches = patch_row * patch_col + self.absolute_pos_embed = nn.Parameter( + torch.zeros((1, num_patches, embed_dims))) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + # set stochastic depth decay rule + total_depth = sum(depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] + + self.stages = ModuleList() + in_channels = embed_dims + for i in range(num_layers): + if i < num_layers - 1: + downsample = PatchMerging( + in_channels=in_channels, + out_channels=2 * in_channels, + stride=strides[i + 1], + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None) + else: + downsample = None + + stage = SwinBlockSequence( + embed_dims=in_channels, + num_heads=num_heads[i], + feedforward_channels=int(mlp_ratio * in_channels), + depth=depths[i], + window_size=window_size, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])], + downsample=downsample, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + init_cfg=None) + self.stages.append(stage) + if downsample: + in_channels = downsample.out_channels + + self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)] + # Add a norm layer for each output + for i in out_indices: + layer = build_norm_layer(norm_cfg, self.num_features[i])[1] + layer_name = f'norm{i}' + self.add_module(layer_name, layer) + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super().train(mode) + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + if self.use_abs_pos_embed: + self.absolute_pos_embed.requires_grad = False + self.drop_after_pos.eval() + + for i in range(1, self.frozen_stages + 1): + + if (i - 1) in self.out_indices: + norm_layer = getattr(self, f'norm{i-1}') + norm_layer.eval() + for param in norm_layer.parameters(): + param.requires_grad = False + + m = self.stages[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self): + if self.init_cfg is None: + print_log(f'No pre-trained weights for ' + f'{self.__class__.__name__}, ' + f'training start from scratch') + if self.use_abs_pos_embed: + trunc_normal_(self.absolute_pos_embed, std=0.02) + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, nn.LayerNorm): + constant_init(m, val=1.0, bias=0.) 
+ else: + assert 'checkpoint' in self.init_cfg, f'Only support ' \ + f'specify `Pretrained` in ' \ + f'`init_cfg` in ' \ + f'{self.__class__.__name__} ' + ckpt = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + if 'state_dict' in ckpt: + _state_dict = ckpt['state_dict'] + elif 'model' in ckpt: + _state_dict = ckpt['model'] + else: + _state_dict = ckpt + + state_dict = OrderedDict() + for k, v in _state_dict.items(): + if k.startswith('backbone.'): + state_dict[k[9:]] = v + else: + state_dict[k] = v + + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # reshape absolute position embedding + if state_dict.get('absolute_pos_embed') is not None: + absolute_pos_embed = state_dict['absolute_pos_embed'] + N1, L, C1 = absolute_pos_embed.size() + N2, C2, H, W = self.absolute_pos_embed.size() + if N1 != N2 or C1 != C2 or L != H * W: + print_log('Error in loading absolute_pos_embed, pass') + else: + state_dict['absolute_pos_embed'] = absolute_pos_embed.view( + N2, H, W, C2).permute(0, 3, 1, 2).contiguous() + + # interpolate position bias table if needed + relative_position_bias_table_keys = [ + k for k in state_dict.keys() + if 'relative_position_bias_table' in k + ] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + if table_key in self.state_dict(): + table_current = self.state_dict()[table_key] + L1, nH1 = table_pretrained.size() + L2, nH2 = table_current.size() + if nH1 != nH2: + print_log(f'Error in loading {table_key}, pass') + elif L1 != L2: + S1 = int(L1**0.5) + S2 = int(L2**0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 0).reshape( + 1, nH1, S1, S1), + size=(S2, S2), + mode='bicubic') + state_dict[table_key] = table_pretrained_resized.view( + nH2, L2).permute(1, 0).contiguous() + + # load state_dict + self.load_state_dict(state_dict, strict=False) + + def forward(self, x): + x, hw_shape = self.patch_embed(x) + + if self.use_abs_pos_embed: + x = x + self.absolute_pos_embed + x = self.drop_after_pos(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape, out, out_hw_shape = stage(x, hw_shape) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(out) + out = out.view(-1, *out_hw_shape, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + return outs diff --git a/mmseg/models/backbones/timm_backbone.py b/mmseg/models/backbones/timm_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..1eef302bddeac3cee71412bcb481b68b796e515f --- /dev/null +++ b/mmseg/models/backbones/timm_backbone.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +try: + import timm +except ImportError: + timm = None + +from mmengine.model import BaseModule +from mmengine.registry import MODELS as MMENGINE_MODELS + +from mmseg.registry import MODELS + + +@MODELS.register_module() +class TIMMBackbone(BaseModule): + """Wrapper to use backbones from timm library. More details can be found in + `timm `_ . + + Args: + model_name (str): Name of timm model to instantiate. + pretrained (bool): Load pretrained weights if True. + checkpoint_path (str): Path of checkpoint to load after + model is initialized. + in_channels (int): Number of input image channels. Default: 3. + init_cfg (dict, optional): Initialization config dict + **kwargs: Other timm & model specific arguments. 
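+
+    Example:
+        >>> # Illustrative usage only; it assumes the optional `timm`
+        >>> # package is installed and uses 'resnet18' as a stand-in
+        >>> # model name.
+        >>> import torch
+        >>> from mmseg.models import TIMMBackbone
+        >>> model = TIMMBackbone(model_name='resnet18', pretrained=False)
+        >>> feats = model(torch.rand(1, 3, 224, 224))
+        >>> isinstance(feats, (list, tuple))
+        True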
+ """ + + def __init__( + self, + model_name, + features_only=True, + pretrained=True, + checkpoint_path='', + in_channels=3, + init_cfg=None, + **kwargs, + ): + if timm is None: + raise RuntimeError('timm is not installed') + super().__init__(init_cfg) + if 'norm_layer' in kwargs: + kwargs['norm_layer'] = MMENGINE_MODELS.get(kwargs['norm_layer']) + self.timm_model = timm.create_model( + model_name=model_name, + features_only=features_only, + pretrained=pretrained, + in_chans=in_channels, + checkpoint_path=checkpoint_path, + **kwargs, + ) + + # Make unused parameters None + self.timm_model.global_pool = None + self.timm_model.fc = None + self.timm_model.classifier = None + + # Hack to use pretrained weights from timm + if pretrained or checkpoint_path: + self._is_init = True + + def forward(self, x): + features = self.timm_model(x) + return features diff --git a/mmseg/models/backbones/twins.py b/mmseg/models/backbones/twins.py new file mode 100644 index 0000000000000000000000000000000000000000..b6a6eea795cf53bee6b52ece80d5d90ecc969970 --- /dev/null +++ b/mmseg/models/backbones/twins.py @@ -0,0 +1,588 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import FFN +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import (constant_init, normal_init, + trunc_normal_init) +from torch.nn.modules.batchnorm import _BatchNorm + +from mmseg.models.backbones.mit import EfficientMultiheadAttention +from mmseg.registry import MODELS +from ..utils.embed import PatchEmbed + + +class GlobalSubsampledAttention(EfficientMultiheadAttention): + """Global Sub-sampled Attention (Spatial Reduction Attention) + + This module is modified from EfficientMultiheadAttention, + which is a module from mmseg.models.backbones.mit.py. + Specifically, there is no difference between + `GlobalSubsampledAttention` and `EfficientMultiheadAttention`, + `GlobalSubsampledAttention` is built as a brand new class + because it is renamed as `Global sub-sampled attention (GSA)` + in paper. + + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. + Default: 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. Default: None. + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dims) + or (n, batch, embed_dims). Default: False. + qkv_bias (bool): enable bias for qkv if True. Default: True. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (int): The ratio of spatial reduction of GSA of PCPVT. + Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. 
+ """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + dropout_layer=None, + batch_first=True, + qkv_bias=True, + norm_cfg=dict(type='LN'), + sr_ratio=1, + init_cfg=None): + super().__init__( + embed_dims, + num_heads, + attn_drop=attn_drop, + proj_drop=proj_drop, + dropout_layer=dropout_layer, + batch_first=batch_first, + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio, + init_cfg=init_cfg) + + +class GSAEncoderLayer(BaseModule): + """Implements one encoder layer with GSA. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (float): Kernel_size of conv in Attention modules. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + sr_ratio=1., + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] + self.attn = GlobalSubsampledAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=False) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate) + ) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x, hw_shape): + x = x + self.drop_path(self.attn(self.norm1(x), hw_shape, identity=0.)) + x = x + self.drop_path(self.ffn(self.norm2(x))) + return x + + +class LocallyGroupedSelfAttention(BaseModule): + """Locally-grouped Self Attention (LSA) module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. Default: 8 + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: False. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + window_size(int): Window size of LSA. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. 
+ """ + + def __init__(self, + embed_dims, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + window_size=1, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + assert embed_dims % num_heads == 0, f'dim {embed_dims} should be ' \ + f'divided by num_heads ' \ + f'{num_heads}.' + self.embed_dims = embed_dims + self.num_heads = num_heads + head_dim = embed_dims // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + self.window_size = window_size + + def forward(self, x, hw_shape): + b, n, c = x.shape + h, w = hw_shape + x = x.view(b, h, w, c) + + # pad feature maps to multiples of Local-groups + pad_l = pad_t = 0 + pad_r = (self.window_size - w % self.window_size) % self.window_size + pad_b = (self.window_size - h % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + # calculate attention mask for LSA + Hp, Wp = x.shape[1:-1] + _h, _w = Hp // self.window_size, Wp // self.window_size + mask = torch.zeros((1, Hp, Wp), device=x.device) + mask[:, -pad_b:, :].fill_(1) + mask[:, :, -pad_r:].fill_(1) + + # [B, _h, _w, window_size, window_size, C] + x = x.reshape(b, _h, self.window_size, _w, self.window_size, + c).transpose(2, 3) + mask = mask.reshape(1, _h, self.window_size, _w, + self.window_size).transpose(2, 3).reshape( + 1, _h * _w, + self.window_size * self.window_size) + # [1, _h*_w, window_size*window_size, window_size*window_size] + attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-1000.0)).masked_fill( + attn_mask == 0, float(0.0)) + + # [3, B, _w*_h, nhead, window_size*window_size, dim] + qkv = self.qkv(x).reshape(b, _h * _w, + self.window_size * self.window_size, 3, + self.num_heads, c // self.num_heads).permute( + 3, 0, 1, 4, 2, 5) + q, k, v = qkv[0], qkv[1], qkv[2] + # [B, _h*_w, n_head, window_size*window_size, window_size*window_size] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn + attn_mask.unsqueeze(2) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(b, _h, _w, self.window_size, + self.window_size, c) + x = attn.transpose(2, 3).reshape(b, _h * self.window_size, + _w * self.window_size, c) + if pad_r > 0 or pad_b > 0: + x = x[:, :h, :w, :].contiguous() + + x = x.reshape(b, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LSAEncoderLayer(BaseModule): + """Implements one encoder layer in Twins-SVT. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). 
+ window_size (int): Window size of LSA. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + qk_scale=None, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + window_size=1, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] + self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads, + qkv_bias, qk_scale, + attn_drop_rate, drop_rate, + window_size) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=False) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate) + ) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x, hw_shape): + x = x + self.drop_path(self.attn(self.norm1(x), hw_shape)) + x = x + self.drop_path(self.ffn(self.norm2(x))) + return x + + +class ConditionalPositionEncoding(BaseModule): + """The Conditional Position Encoding (CPE) module. + + The CPE is the implementation of 'Conditional Positional Encodings + for Vision Transformers '_. + + Args: + in_channels (int): Number of input channels. + embed_dims (int): The feature dimension. Default: 768. + stride (int): Stride of conv layer. Default: 1. + """ + + def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.proj = nn.Conv2d( + in_channels, + embed_dims, + kernel_size=3, + stride=stride, + padding=1, + bias=True, + groups=embed_dims) + self.stride = stride + + def forward(self, x, hw_shape): + b, n, c = x.shape + h, w = hw_shape + feat_token = x + cnn_feat = feat_token.transpose(1, 2).view(b, c, h, w) + if self.stride == 1: + x = self.proj(cnn_feat) + cnn_feat + else: + x = self.proj(cnn_feat) + x = x.flatten(2).transpose(1, 2) + return x + + +@MODELS.register_module() +class PCPVT(BaseModule): + """The backbone of Twins-PCPVT. + + This backbone is the implementation of `Twins: Revisiting the Design + of Spatial Attention in Vision Transformers + `_. + + Args: + in_channels (int): Number of input channels. Default: 3. + embed_dims (list): Embedding dimension. Default: [64, 128, 256, 512]. + patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2]. + strides (list): The strides. Default: [4, 2, 2, 2]. + num_heads (int): Number of attention heads. Default: [1, 2, 4, 8]. + mlp_ratios (int): Ratio of mlp hidden dim to embedding dim. + Default: [4, 4, 4, 4]. + out_indices (tuple[int]): Output from which stages. + Default: (0, 1, 2, 3). + qkv_bias (bool): Enable bias for qkv if True. Default: False. + drop_rate (float): Probability of an element to be zeroed. + Default 0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.0 + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + depths (list): Depths of each stage. Default [3, 4, 6, 3] + sr_ratios (list): Kernel_size of conv in each Attn module in + Transformer encoder layer. Default: [8, 4, 2, 1]. + norm_after_stage(bool): Add extra norm. Default False. + init_cfg (dict, optional): The Config for initialization. 
+ Defaults to None. + """ + + def __init__(self, + in_channels=3, + embed_dims=[64, 128, 256, 512], + patch_sizes=[4, 2, 2, 2], + strides=[4, 2, 2, 2], + num_heads=[1, 2, 4, 8], + mlp_ratios=[4, 4, 4, 4], + out_indices=(0, 1, 2, 3), + qkv_bias=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + norm_after_stage=False, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + self.depths = depths + + # patch_embed + self.patch_embeds = ModuleList() + self.position_encoding_drops = ModuleList() + self.layers = ModuleList() + + for i in range(len(depths)): + self.patch_embeds.append( + PatchEmbed( + in_channels=in_channels if i == 0 else embed_dims[i - 1], + embed_dims=embed_dims[i], + conv_type='Conv2d', + kernel_size=patch_sizes[i], + stride=strides[i], + padding='corner', + norm_cfg=norm_cfg)) + + self.position_encoding_drops.append(nn.Dropout(p=drop_rate)) + + self.position_encodings = ModuleList([ + ConditionalPositionEncoding(embed_dim, embed_dim) + for embed_dim in embed_dims + ]) + + # transformer encoder + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + cur = 0 + + for k in range(len(depths)): + _block = ModuleList([ + GSAEncoderLayer( + embed_dims=embed_dims[k], + num_heads=num_heads[k], + feedforward_channels=mlp_ratios[k] * embed_dims[k], + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=dpr[cur + i], + num_fcs=2, + qkv_bias=qkv_bias, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + sr_ratio=sr_ratios[k]) for i in range(depths[k]) + ]) + self.layers.append(_block) + cur += depths[k] + + self.norm_name, norm = build_norm_layer( + norm_cfg, embed_dims[-1], postfix=1) + + self.out_indices = out_indices + self.norm_after_stage = norm_after_stage + if self.norm_after_stage: + self.norm_list = ModuleList() + for dim in embed_dims: + self.norm_list.append(build_norm_layer(norm_cfg, dim)[1]) + + def init_weights(self): + if self.init_cfg is not None: + super().init_weights() + else: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init( + m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + + def forward(self, x): + outputs = list() + + b = x.shape[0] + + for i in range(len(self.depths)): + x, hw_shape = self.patch_embeds[i](x) + h, w = hw_shape + x = self.position_encoding_drops[i](x) + for j, blk in enumerate(self.layers[i]): + x = blk(x, hw_shape) + if j == 0: + x = self.position_encodings[i](x, hw_shape) + if self.norm_after_stage: + x = self.norm_list[i](x) + x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() + + if i in self.out_indices: + outputs.append(x) + + return tuple(outputs) + + +@MODELS.register_module() +class SVT(PCPVT): + """The backbone of Twins-SVT. 
+
+    This backbone is the implementation of `Twins: Revisiting the Design
+    of Spatial Attention in Vision Transformers
+    <https://arxiv.org/abs/2104.13840>`_.
+
+    Args:
+        in_channels (int): Number of input channels. Default: 3.
+        embed_dims (list): Embedding dimension. Default: [64, 128, 256].
+        patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2].
+        strides (list): The strides. Default: [4, 2, 2, 2].
+        num_heads (list): Number of attention heads. Default: [1, 2, 4].
+        mlp_ratios (list): Ratio of mlp hidden dim to embedding dim.
+            Default: [4, 4, 4].
+        out_indices (tuple[int]): Output from which stages.
+            Default: (0, 1, 2, 3).
+        qkv_bias (bool): Enable bias for qkv if True. Default: False.
+        drop_rate (float): Dropout rate. Default: 0.
+        attn_drop_rate (float): Dropout ratio of attention weight.
+            Default: 0.
+        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='LN').
+        depths (list): Depths of each stage. Default: [4, 4, 4].
+        sr_ratios (list): Kernel_size of conv in each Attn module in
+            Transformer encoder layer. Default: [4, 2, 1].
+        windiow_sizes (list): Window size of LSA. Default: [7, 7, 7].
+        norm_after_stage (bool): Add an extra norm after each stage.
+            Default: True.
+        init_cfg (dict, optional): The Config for initialization.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 in_channels=3,
+                 embed_dims=[64, 128, 256],
+                 patch_sizes=[4, 2, 2, 2],
+                 strides=[4, 2, 2, 2],
+                 num_heads=[1, 2, 4],
+                 mlp_ratios=[4, 4, 4],
+                 out_indices=(0, 1, 2, 3),
+                 qkv_bias=False,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.2,
+                 norm_cfg=dict(type='LN'),
+                 depths=[4, 4, 4],
+                 sr_ratios=[4, 2, 1],
+                 windiow_sizes=[7, 7, 7],
+                 norm_after_stage=True,
+                 pretrained=None,
+                 init_cfg=None):
+        super().__init__(in_channels, embed_dims, patch_sizes, strides,
+                         num_heads, mlp_ratios, out_indices, qkv_bias,
+                         drop_rate, attn_drop_rate, drop_path_rate, norm_cfg,
+                         depths, sr_ratios, norm_after_stage, pretrained,
+                         init_cfg)
+        # transformer encoder
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+        ]  # stochastic depth decay rule
+
+        for k in range(len(depths)):
+            for i in range(depths[k]):
+                if i % 2 == 0:
+                    self.layers[k][i] = \
+                        LSAEncoderLayer(
+                            embed_dims=embed_dims[k],
+                            num_heads=num_heads[k],
+                            feedforward_channels=mlp_ratios[k] * embed_dims[k],
+                            drop_rate=drop_rate,
+                            attn_drop_rate=attn_drop_rate,
+                            drop_path_rate=dpr[sum(depths[:k]) + i],
+                            qkv_bias=qkv_bias,
+                            window_size=windiow_sizes[k])
diff --git a/mmseg/models/backbones/unet.py b/mmseg/models/backbones/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..545921db8e14668e454f5834f9a1618fe0c04ffe
--- /dev/null
+++ b/mmseg/models/backbones/unet.py
@@ -0,0 +1,436 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
+from mmengine.model import BaseModule
+from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
+
+from mmseg.registry import MODELS
+from ..utils import UpConvBlock, Upsample
+
+
+class BasicConvBlock(nn.Module):
+    """Basic convolutional block for UNet.
+
+    This module consists of several plain convolutional layers.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+ num_convs (int): Number of convolutional layers. Default: 2. + stride (int): Whether use stride convolution to downsample + the input feature map. If stride=2, it only uses stride convolution + in the first convolutional layer to downsample the input feature + map. Options are 1 or 2. Default: 1. + dilation (int): Whether use dilated convolution to expand the + receptive field. Set dilation rate of each convolutional layer and + the dilation rate of the first convolutional layer is always 1. + Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + num_convs=2, + stride=1, + dilation=1, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + dcn=None, + plugins=None): + super().__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.with_cp = with_cp + convs = [] + for i in range(num_convs): + convs.append( + ConvModule( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride if i == 0 else 1, + dilation=1 if i == 0 else dilation, + padding=1 if i == 0 else dilation, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + self.convs = nn.Sequential(*convs) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.convs, x) + else: + out = self.convs(x) + return out + + +@MODELS.register_module() +class DeconvModule(nn.Module): + """Deconvolution upsample module in decoder for UNet (2X upsample). + + This module uses deconvolution to upsample feature map in the decoder + of UNet. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + kernel_size (int): Kernel size of the convolutional layer. Default: 4. + """ + + def __init__(self, + in_channels, + out_channels, + with_cp=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + *, + kernel_size=4, + scale_factor=2): + super().__init__() + + assert (kernel_size - scale_factor >= 0) and\ + (kernel_size - scale_factor) % 2 == 0,\ + f'kernel_size should be greater than or equal to scale_factor '\ + f'and (kernel_size - scale_factor) should be even numbers, '\ + f'while the kernel size is {kernel_size} and scale_factor is '\ + f'{scale_factor}.' 
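+
+        # With stride = scale_factor and padding = (kernel_size -
+        # scale_factor) // 2, the transposed conv output size works out to
+        # (in - 1) * stride - 2 * padding + kernel_size = in * scale_factor,
+        # e.g. the default kernel_size=4, scale_factor=2 maps 32x32 to 64x64.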
+
+        stride = scale_factor
+        padding = (kernel_size - scale_factor) // 2
+        self.with_cp = with_cp
+        deconv = nn.ConvTranspose2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding)
+
+        norm_name, norm = build_norm_layer(norm_cfg, out_channels)
+        activate = build_activation_layer(act_cfg)
+        self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
+
+    def forward(self, x):
+        """Forward function."""
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(self.deconv_upsamping, x)
+        else:
+            out = self.deconv_upsamping(x)
+        return out
+
+
+@MODELS.register_module()
+class InterpConv(nn.Module):
+    """Interpolation upsample module in decoder for UNet.
+
+    This module uses interpolation to upsample feature map in the decoder
+    of UNet. It consists of one interpolation upsample layer and one
+    convolutional layer. It can be one interpolation upsample layer followed
+    by one convolutional layer (conv_first=False) or one convolutional layer
+    followed by one interpolation upsample layer (conv_first=True).
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        conv_first (bool): Whether to place the convolutional layer before
+            the interpolation upsample layer. Default: False, which means
+            the interpolation upsample layer is followed by the
+            convolutional layer.
+        kernel_size (int): Kernel size of the convolutional layer. Default: 1.
+        stride (int): Stride of the convolutional layer. Default: 1.
+        padding (int): Padding of the convolutional layer. Default: 0.
+        upsample_cfg (dict): Interpolation config of the upsample layer.
+            Default: dict(
+                scale_factor=2, mode='bilinear', align_corners=False).
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 with_cp=False,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 *,
+                 conv_cfg=None,
+                 conv_first=False,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+                 upsample_cfg=dict(
+                     scale_factor=2, mode='bilinear', align_corners=False)):
+        super().__init__()
+
+        self.with_cp = with_cp
+        conv = ConvModule(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        upsample = Upsample(**upsample_cfg)
+        if conv_first:
+            self.interp_upsample = nn.Sequential(conv, upsample)
+        else:
+            self.interp_upsample = nn.Sequential(upsample, conv)
+
+    def forward(self, x):
+        """Forward function."""
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(self.interp_upsample, x)
+        else:
+            out = self.interp_upsample(x)
+        return out
+
+
+@MODELS.register_module()
+class UNet(BaseModule):
+    """UNet backbone.
+
+    This backbone is the implementation of `U-Net: Convolutional Networks
+    for Biomedical Image Segmentation
+    <https://arxiv.org/abs/1505.04597>`_.
+
+    Args:
+        in_channels (int): Number of input image channels. Default: 3.
+        base_channels (int): Number of base channels of each stage.
+            The output channels of the first stage. Default: 64.
+        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
+        strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
+            len(strides) is equal to num_stages.
Normally the stride of the + first stage in encoder is 1. If strides[i]=2, it uses stride + convolution to downsample in the correspondence encoder stage. + Default: (1, 1, 1, 1, 1). + enc_num_convs (Sequence[int]): Number of convolutional layers in the + convolution block of the correspondence encoder stage. + Default: (2, 2, 2, 2, 2). + dec_num_convs (Sequence[int]): Number of convolutional layers in the + convolution block of the correspondence decoder stage. + Default: (2, 2, 2, 2). + downsamples (Sequence[int]): Whether use MaxPool to downsample the + feature map after the first stage of encoder + (stages: [1, num_stages)). If the correspondence encoder stage use + stride convolution (strides[i]=2), it will never use MaxPool to + downsample, even downsamples[i-1]=True. + Default: (True, True, True, True). + enc_dilations (Sequence[int]): Dilation rate of each stage in encoder. + Default: (1, 1, 1, 1, 1). + dec_dilations (Sequence[int]): Dilation rate of each stage in decoder. + Default: (1, 1, 1, 1). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + upsample_cfg (dict): The upsample config of the upsample module in + decoder. Default: dict(type='InterpConv'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + + Notice: + The input image size should be divisible by the whole downsample rate + of the encoder. More detail of the whole downsample rate can be found + in UNet._check_input_divisible. + """ + + def __init__(self, + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + norm_eval=False, + dcn=None, + plugins=None, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' 
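+        # With the defaults (num_stages=5, strides all 1, four max-pool
+        # downsamples) the whole downsample rate checked in
+        # _check_input_divisible() is 2 ** 4 = 16, so input heights and
+        # widths must be multiples of 16.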
+        assert len(strides) == num_stages, \
+            'The length of strides should be equal to num_stages, '\
+            f'while the strides is {strides}, the length of '\
+            f'strides is {len(strides)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(enc_num_convs) == num_stages, \
+            'The length of enc_num_convs should be equal to num_stages, '\
+            f'while the enc_num_convs is {enc_num_convs}, the length of '\
+            f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(dec_num_convs) == (num_stages-1), \
+            'The length of dec_num_convs should be equal to (num_stages-1), '\
+            f'while the dec_num_convs is {dec_num_convs}, the length of '\
+            f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(downsamples) == (num_stages-1), \
+            'The length of downsamples should be equal to (num_stages-1), '\
+            f'while the downsamples is {downsamples}, the length of '\
+            f'downsamples is {len(downsamples)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(enc_dilations) == num_stages, \
+            'The length of enc_dilations should be equal to num_stages, '\
+            f'while the enc_dilations is {enc_dilations}, the length of '\
+            f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(dec_dilations) == (num_stages-1), \
+            'The length of dec_dilations should be equal to (num_stages-1), '\
+            f'while the dec_dilations is {dec_dilations}, the length of '\
+            f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
+            f'{num_stages}.'
+        self.num_stages = num_stages
+        self.strides = strides
+        self.downsamples = downsamples
+        self.norm_eval = norm_eval
+        self.base_channels = base_channels
+
+        self.encoder = nn.ModuleList()
+        self.decoder = nn.ModuleList()
+
+        for i in range(num_stages):
+            enc_conv_block = []
+            if i != 0:
+                if strides[i] == 1 and downsamples[i - 1]:
+                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
+                upsample = (strides[i] != 1 or downsamples[i - 1])
+                self.decoder.append(
+                    UpConvBlock(
+                        conv_block=BasicConvBlock,
+                        in_channels=base_channels * 2**i,
+                        skip_channels=base_channels * 2**(i - 1),
+                        out_channels=base_channels * 2**(i - 1),
+                        num_convs=dec_num_convs[i - 1],
+                        stride=1,
+                        dilation=dec_dilations[i - 1],
+                        with_cp=with_cp,
+                        conv_cfg=conv_cfg,
+                        norm_cfg=norm_cfg,
+                        act_cfg=act_cfg,
+                        upsample_cfg=upsample_cfg if upsample else None,
+                        dcn=None,
+                        plugins=None))
+
+            enc_conv_block.append(
+                BasicConvBlock(
+                    in_channels=in_channels,
+                    out_channels=base_channels * 2**i,
+                    num_convs=enc_num_convs[i],
+                    stride=strides[i],
+                    dilation=enc_dilations[i],
+                    with_cp=with_cp,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg,
+                    dcn=None,
+                    plugins=None))
+            self.encoder.append(nn.Sequential(*enc_conv_block))
+            in_channels = base_channels * 2**i
+
+    def forward(self, x):
+        self._check_input_divisible(x)
+        enc_outs = []
+        for enc in self.encoder:
+            x = enc(x)
+            enc_outs.append(x)
+        dec_outs = [x]
+        for i in reversed(range(len(self.decoder))):
+            x = self.decoder[i](enc_outs[i], x)
+            dec_outs.append(x)
+
+        return dec_outs
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keeping the
+        normalization layers frozen."""
+        super().train(mode)
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval() only affects BatchNorm layers
+                if isinstance(m, _BatchNorm):
+                    m.eval()
+
+    def _check_input_divisible(self, x):
+        h, w = x.shape[-2:]
+        whole_downsample_rate = 1
+        for i in range(1, self.num_stages):
+            if self.strides[i] == 2 or self.downsamples[i -
1]: + whole_downsample_rate *= 2 + assert (h % whole_downsample_rate == 0) \ + and (w % whole_downsample_rate == 0),\ + f'The input image size {(h, w)} should be divisible by the whole '\ + f'downsample rate {whole_downsample_rate}, when num_stages is '\ + f'{self.num_stages}, strides is {self.strides}, and downsamples '\ + f'is {self.downsamples}.' diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..dd0f688fcc46680b13904a26f14269b3d19d6ce3 --- /dev/null +++ b/mmseg/models/backbones/vit.py @@ -0,0 +1,501 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmengine.logging import print_log +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import (constant_init, kaiming_init, + trunc_normal_) +from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.utils import _pair as to_2tuple + +from mmseg.registry import MODELS +from ..utils import PatchEmbed, resize + + +class TransformerEncoderLayer(BaseModule): + """Implements one encoder layer in Vision Transformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): enable bias for qkv if True. Default: True + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: True. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. 
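+
+    Example:
+        A minimal usage sketch; the shapes are illustrative assumptions
+        (e.g. 197 tokens for a 224x224 input with 16x16 patches plus a cls
+        token):
+
+        >>> import torch
+        >>> layer = TransformerEncoderLayer(
+        ...     embed_dims=768, num_heads=12, feedforward_channels=3072)
+        >>> x = torch.randn(2, 197, 768)  # (batch, num_tokens, embed_dims)
+        >>> layer(x).shape  # residual attention + FFN preserve the shape
+        torch.Size([2, 197, 768])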
+ """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + batch_first=True, + attn_cfg=dict(), + ffn_cfg=dict(), + with_cp=False): + super().__init__() + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + + attn_cfg.update( + dict( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + batch_first=batch_first, + bias=qkv_bias)) + + self.build_attn(attn_cfg) + + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + self.add_module(self.norm2_name, norm2) + + ffn_cfg.update( + dict( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate) + if drop_path_rate > 0 else None, + act_cfg=act_cfg)) + self.build_ffn(ffn_cfg) + self.with_cp = with_cp + + def build_attn(self, attn_cfg): + self.attn = MultiheadAttention(**attn_cfg) + + def build_ffn(self, ffn_cfg): + self.ffn = FFN(**ffn_cfg) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + @property + def norm2(self): + return getattr(self, self.norm2_name) + + def forward(self, x): + + def _inner_forward(x): + x = self.attn(self.norm1(x), identity=x) + x = self.ffn(self.norm2(x), identity=x) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +@MODELS.register_module() +class VisionTransformer(BaseModule): + """Vision Transformer. + + This backbone is the implementation of `An Image is Worth 16x16 Words: + Transformers for Image Recognition at + Scale `_. + + Args: + img_size (int | tuple): Input image size. Default: 224. + patch_size (int): The patch size. Default: 16. + patch_pad (str | int | None): The padding method in patch embedding. + Default: 'corner'. + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): embedding dimension. Default: 768. + num_layers (int): depth of transformer. Default: 12. + num_heads (int): number of attention heads. Default: 12. + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + out_origin (bool): Whether to output the original input embedding. + Default: False + out_indices (list | tuple | int): Output from which stages. + Default: -1. + qkv_bias (bool): enable bias for qkv if True. Default: True. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0 + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): stochastic depth rate. Default 0.0 + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Default: True. + output_cls_token (bool): Whether output the cls_token. If set True, + `with_cls_token` must be True. Default: False. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + patch_bias (dict): Whether use bias in convolution of PatchEmbed Block. + Default: True. + patch_norm (bool): Whether to add a norm in PatchEmbed Block. + Default: False. + pre_norm (bool): Whether to add a norm before Transformer Layers. + Default: False. 
+ final_norm (bool): Whether to add a additional layer to normalize + final feature map. Default: False. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Default: bicubic. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. + frozen_exclude (List): List of parameters that are not to be frozen. + Default: ["all"], "all" means there are no frozen parameters. + pretrained (str, optional): model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + img_size=224, + patch_size=16, + patch_pad='corner', + in_channels=3, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + out_origin=False, + out_indices=-1, + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + with_cls_token=True, + output_cls_token=False, + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + patch_norm=False, + patch_bias=False, + pre_norm=False, + final_norm=False, + interpolate_mode='bicubic', + num_fcs=2, + norm_eval=False, + with_cp=False, + frozen_exclude=['all'], + pretrained=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, tuple): + if len(img_size) == 1: + img_size = to_2tuple(img_size[0]) + assert len(img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(img_size)}' + + if output_cls_token: + assert with_cls_token is True, f'with_cls_token must be True if' \ + f'set output_cls_token to True, but got {with_cls_token}' + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.img_size = img_size + self.patch_size = patch_size + self.interpolate_mode = interpolate_mode + self.norm_eval = norm_eval + self.with_cp = with_cp + self.pretrained = pretrained + self.out_origin = out_origin + self.frozen_exclude = frozen_exclude + + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + padding=patch_pad, + bias=patch_bias, + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None, + ) + + num_patches = (img_size[0] // patch_size) * \ + (img_size[1] // patch_size) + + self.with_cls_token = with_cls_token + self.output_cls_token = output_cls_token + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, embed_dims)) + self.drop_after_pos = nn.Dropout(p=drop_rate) + self.pre_norm = pre_norm + + if self.pre_norm: + self.pre_ln_name, pre_ln = build_norm_layer( + norm_cfg, embed_dims, postfix='_pre') + self.add_module(self.pre_ln_name, pre_ln) + + if isinstance(out_indices, int): + if out_indices == -1: + out_indices = num_layers - 1 + self.out_indices = [out_indices] + elif 
isinstance(out_indices, list) or isinstance(out_indices, tuple): + self.out_indices = out_indices + else: + raise TypeError('out_indices must be type of int, list or tuple') + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, num_layers) + ] # stochastic depth decay rule + + self.layers = ModuleList() + for i in range(num_layers): + self.layers.append( + TransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=mlp_ratio * embed_dims, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + num_fcs=num_fcs, + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + batch_first=True)) + + self.final_norm = final_norm + if final_norm: + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + + self._freeze() + + @property + def pre_ln(self): + return getattr(self, self.pre_ln_name) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def init_weights(self): + if isinstance(self.init_cfg, dict) and \ + self.init_cfg.get('type') in ['Pretrained', 'Pretrained_Part']: + checkpoint = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + + if self.init_cfg.get('type') == 'Pretrained': + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + elif self.init_cfg.get('type') == 'Pretrained_Part': + state_dict = checkpoint.copy() + para_prefix = 'image_encoder' + prefix_len = len(para_prefix) + 1 + for k, v in checkpoint.items(): + state_dict.pop(k) + if para_prefix in k: + state_dict[k[prefix_len:]] = v + + if 'pos_embed' in state_dict.keys(): + if self.pos_embed.shape != state_dict['pos_embed'].shape: + print_log(msg=f'Resize the pos_embed shape from ' + f'{state_dict["pos_embed"].shape} to ' + f'{self.pos_embed.shape}') + h, w = self.img_size + pos_size = int( + math.sqrt(state_dict['pos_embed'].shape[1] - 1)) + state_dict['pos_embed'] = self.resize_pos_embed( + state_dict['pos_embed'], + (h // self.patch_size, w // self.patch_size), + (pos_size, pos_size), self.interpolate_mode) + + load_state_dict(self, state_dict, strict=False, logger=None) + elif self.init_cfg is not None: + super().init_weights() + else: + # We only implement the 'jax_impl' initialization implemented at + # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + if 'ffn' in n: + nn.init.normal_(m.bias, mean=0., std=1e-6) + else: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + + def _freeze(self): + if 'all' in self.frozen_exclude: + return + for name, param in self.named_parameters(): + if not any([exclude in name for exclude in self.frozen_exclude]): + param.requires_grad = False + + def _pos_embeding(self, patched_img, hw_shape, pos_embed): + """Positioning embeding method. + + Resize the pos_embed, if the input image size doesn't match + the training size. + Args: + patched_img (torch.Tensor): The patched image, it should be + shape of [B, L1, C]. + hw_shape (tuple): The downsampled image resolution. 
+ pos_embed (torch.Tensor): The pos_embed weighs, it should be + shape of [B, L2, c]. + Return: + torch.Tensor: The pos encoded image feature. + """ + assert patched_img.ndim == 3 and pos_embed.ndim == 3, \ + 'the shapes of patched_img and pos_embed must be [B, L, C]' + x_len, pos_len = patched_img.shape[1], pos_embed.shape[1] + if x_len != pos_len: + if pos_len == (self.img_size[0] // self.patch_size) * ( + self.img_size[1] // self.patch_size) + 1: + pos_h = self.img_size[0] // self.patch_size + pos_w = self.img_size[1] // self.patch_size + else: + raise ValueError( + 'Unexpected shape of pos_embed, got {}.'.format( + pos_embed.shape)) + pos_embed = self.resize_pos_embed(pos_embed, hw_shape, + (pos_h, pos_w), + self.interpolate_mode) + return self.drop_after_pos(patched_img + pos_embed) + + @staticmethod + def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode): + """Resize pos_embed weights. + + Resize pos_embed using bicubic interpolate method. + Args: + pos_embed (torch.Tensor): Position embedding weights. + input_shpae (tuple): Tuple for (downsampled input image height, + downsampled input image width). + pos_shape (tuple): The resolution of downsampled origin training + image. + mode (str): Algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'nearest'`` + Return: + torch.Tensor: The resized pos_embed of shape [B, L_new, C] + """ + assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]' + pos_h, pos_w = pos_shape + cls_token_weight = pos_embed[:, 0] + pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):] + pos_embed_weight = pos_embed_weight.reshape( + 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2) + pos_embed_weight = resize( + pos_embed_weight, size=input_shpae, align_corners=False, mode=mode) + cls_token_weight = cls_token_weight.unsqueeze(1) + pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2) + pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1) + return pos_embed + + def forward(self, inputs): + B = inputs.shape[0] + + x, hw_shape = self.patch_embed(inputs) + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = self._pos_embeding(x, hw_shape, self.pos_embed) + + if not self.with_cls_token: + # Remove class token for transformer encoder input + x = x[:, 1:] + + if self.pre_norm: + x = self.pre_ln(x) + + outs = [] + if self.out_origin: + if self.with_cls_token: + # Remove class token and reshape token for decoder head + out = x[:, 1:] + else: + out = x + B, _, C = out.shape + out = out.reshape(B, hw_shape[0], hw_shape[1], + C).permute(0, 3, 1, 2).contiguous() + if self.output_cls_token: + out = [out, x[:, 0]] + outs.append(out) + + for i, layer in enumerate(self.layers): + x = layer(x) + if i == len(self.layers) - 1: + if self.final_norm: + x = self.norm1(x) + if i in self.out_indices: + if self.with_cls_token: + # Remove class token and reshape token for decoder head + out = x[:, 1:] + else: + out = x + B, _, C = out.shape + out = out.reshape(B, hw_shape[0], hw_shape[1], + C).permute(0, 3, 1, 2).contiguous() + if self.output_cls_token: + out = [out, x[:, 0]] + outs.append(out) + + return tuple(outs) + + def train(self, mode=True): + super().train(mode) + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.LayerNorm): + m.eval() diff --git a/mmseg/models/backbones/vpd.py b/mmseg/models/backbones/vpd.py new file mode 100644 index 
0000000000000000000000000000000000000000..e0536d31c64f82fb66117d9ebd2161d5f2df57bd
--- /dev/null
+++ b/mmseg/models/backbones/vpd.py
@@ -0,0 +1,395 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# ------------------------------------------------------------------------------
+# Adapted from https://github.com/wl-zhao/VPD/blob/main/vpd/models.py
+# Original licence: MIT License
+# ------------------------------------------------------------------------------
+
+import math
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.model import BaseModule
+from mmengine.runner import CheckpointLoader, load_checkpoint
+
+from mmseg.registry import MODELS
+from mmseg.utils import ConfigType, OptConfigType
+
+try:
+    from ldm.modules.diffusionmodules.util import timestep_embedding
+    from ldm.util import instantiate_from_config
+    has_ldm = True
+except ImportError:
+    has_ldm = False
+
+
+def register_attention_control(model, controller):
+    """Registers a control function to manage attention within a model.
+
+    Args:
+        model: The model to which attention is to be registered.
+        controller: The control function responsible for managing attention.
+    """
+
+    def ca_forward(self, place_in_unet):
+        """Custom forward method for attention.
+
+        Args:
+            self: Reference to the current object.
+            place_in_unet: The location in UNet (down/mid/up).
+
+        Returns:
+            The modified forward method.
+        """
+
+        def forward(x, context=None, mask=None):
+            h = self.heads
+            is_cross = context is not None
+            # default to self-attention; `context or x` would call bool()
+            # on a Tensor and raise, so test for None explicitly
+            context = x if context is None else context
+
+            q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
+            # split heads: (b, n, h*d) -> (b*h, n, d), batch-major in b
+            q, k, v = (
+                tensor.reshape(tensor.shape[0], tensor.shape[1], h, -1)
+                .permute(0, 2, 1, 3)
+                .reshape(tensor.shape[0] * h, tensor.shape[1], -1)
+                for tensor in (q, k, v))
+
+            sim = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+
+            if mask is not None:
+                # tile the mask per head, keeping the batch-major layout
+                mask = mask.flatten(1).unsqueeze(1).repeat_interleave(
+                    h, dim=0)
+                max_neg_value = -torch.finfo(sim.dtype).max
+                sim.masked_fill_(~mask, max_neg_value)
+
+            attn = sim.softmax(dim=-1)
+            # average the attention maps over heads before storing them
+            attn_mean = attn.view(-1, h, *attn.shape[1:]).mean(1)
+            controller(attn_mean, is_cross, place_in_unet)
+
+            out = torch.matmul(attn, v)
+            # merge heads back: (b*h, n, d) -> (b, n, h*d)
+            out = (out.reshape(-1, h, out.shape[1], out.shape[2])
+                   .permute(0, 2, 1, 3)
+                   .reshape(-1, out.shape[1], out.shape[2] * h))
+            return self.to_out(out)
+
+        return forward
+
+    def register_recr(net_, count, place_in_unet):
+        """Recursive function to register the custom forward method to all
+        CrossAttention layers.
+
+        Args:
+            net_: The network layer currently being processed.
+            count: The current count of layers processed.
+            place_in_unet: The location in UNet (down/mid/up).
+
+        Returns:
+            The updated count of layers processed.
+        """
+        if net_.__class__.__name__ == 'CrossAttention':
+            net_.forward = ca_forward(net_, place_in_unet)
+            return count + 1
+        if hasattr(net_, 'children'):
+            return sum(
+                register_recr(child, 0, place_in_unet)
+                for child in net_.children())
+        return count
+
+    # each entry is (module, place); `net` is already the child module
+    cross_att_count = sum(
+        register_recr(net, 0, place) for net, place in [
+            (child, 'down') if 'input_blocks' in name else
+            (child, 'up') if 'output_blocks' in name else
+            (child, 'mid') if 'middle_block' in name else
+            (None, None)  # children that hold no attention layers
+            for name, child in model.diffusion_model.named_children()
+        ] if net is not None)
+
+    controller.num_att_layers = cross_att_count
+
+
+class AttentionStore:
+    """A class for storing attention information in the UNet model.
+
+    Attributes:
+        base_size (int): Base size for storing attention information.
+ max_size (int): Maximum size for storing attention information. + """ + + def __init__(self, base_size=64, max_size=None): + """Initialize AttentionStore with default or custom sizes.""" + self.reset() + self.base_size = base_size + self.max_size = max_size or (base_size // 2) + self.num_att_layers = -1 + + @staticmethod + def get_empty_store(): + """Returns an empty store for holding attention values.""" + return { + key: [] + for key in [ + 'down_cross', 'mid_cross', 'up_cross', 'down_self', 'mid_self', + 'up_self' + ] + } + + def reset(self): + """Resets the step and attention stores to their initial states.""" + self.cur_step = 0 + self.cur_att_layer = 0 + self.step_store = self.get_empty_store() + self.attention_store = {} + + def forward(self, attn, is_cross: bool, place_in_unet: str): + """Processes a single forward step, storing the attention. + + Args: + attn: The attention tensor. + is_cross (bool): Whether it's cross attention. + place_in_unet (str): The location in UNet (down/mid/up). + + Returns: + The unmodified attention tensor. + """ + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + if attn.shape[1] <= (self.max_size)**2: + self.step_store[key].append(attn) + return attn + + def between_steps(self): + """Processes and stores attention information between steps.""" + if not self.attention_store: + self.attention_store = self.step_store + else: + for key in self.attention_store: + self.attention_store[key] = [ + stored + step for stored, step in zip( + self.attention_store[key], self.step_store[key]) + ] + self.step_store = self.get_empty_store() + + def get_average_attention(self): + """Calculates and returns the average attention across all steps.""" + return { + key: [item for item in self.step_store[key]] + for key in self.step_store + } + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + """Allows the class instance to be callable.""" + return self.forward(attn, is_cross, place_in_unet) + + @property + def num_uncond_att_layers(self): + """Returns the number of unconditional attention layers (default is + 0).""" + return 0 + + def step_callback(self, x_t): + """A placeholder for a step callback. + + Returns the input unchanged. + """ + return x_t + + +class UNetWrapper(nn.Module): + """A wrapper for UNet with optional attention mechanisms. + + Args: + unet (nn.Module): The UNet model to wrap + use_attn (bool): Whether to use attention. Defaults to True + base_size (int): Base size for the attention store. Defaults to 512 + max_attn_size (int, optional): Maximum size for the attention store. + Defaults to None + attn_selector (str): The types of attention to use. + Defaults to 'up_cross+down_cross' + """ + + def __init__(self, + unet, + use_attn=True, + base_size=512, + max_attn_size=None, + attn_selector='up_cross+down_cross'): + super().__init__() + + assert has_ldm, 'To use UNetWrapper, please install required ' \ + 'packages via `pip install -r requirements/optional.txt`.' 
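+
+        # Note: `register_attention_control` below monkey-patches every
+        # CrossAttention layer of the wrapped UNet so that each forward pass
+        # records its attention maps in `self.attention_store`; the maps
+        # chosen by `attn_selector` are later concatenated onto the
+        # multi-scale outputs in `_append_attn_to_output`.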
+ + self.unet = unet + self.attention_store = AttentionStore( + base_size=base_size // 8, max_size=max_attn_size) + self.attn_selector = attn_selector.split('+') + self.use_attn = use_attn + self.init_sizes(base_size) + if self.use_attn: + register_attention_control(unet, self.attention_store) + + def init_sizes(self, base_size): + """Initialize sizes based on the base size.""" + self.size16 = base_size // 32 + self.size32 = base_size // 16 + self.size64 = base_size // 8 + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """Forward pass through the model.""" + diffusion_model = self.unet.diffusion_model + if self.use_attn: + self.attention_store.reset() + hs, emb, out_list = self._unet_forward(x, timesteps, context, y, + diffusion_model) + if self.use_attn: + self._append_attn_to_output(out_list) + return out_list[::-1] + + def _unet_forward(self, x, timesteps, context, y, diffusion_model): + hs = [] + t_emb = timestep_embedding( + timesteps, diffusion_model.model_channels, repeat_only=False) + emb = diffusion_model.time_embed(t_emb) + h = x.type(diffusion_model.dtype) + for module in diffusion_model.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = diffusion_model.middle_block(h, emb, context) + out_list = [] + for i_out, module in enumerate(diffusion_model.output_blocks): + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + if i_out in [1, 4, 7]: + out_list.append(h) + h = h.type(x.dtype) + out_list.append(h) + return hs, emb, out_list + + def _append_attn_to_output(self, out_list): + avg_attn = self.attention_store.get_average_attention() + attns = {self.size16: [], self.size32: [], self.size64: []} + for k in self.attn_selector: + for up_attn in avg_attn[k]: + size = int(math.sqrt(up_attn.shape[1])) + up_attn = up_attn.transpose(-1, -2).reshape( + *up_attn.shape[:2], size, -1) + attns[size].append(up_attn) + attn16 = torch.stack(attns[self.size16]).mean(0) + attn32 = torch.stack(attns[self.size32]).mean(0) + attn64 = torch.stack(attns[self.size64]).mean(0) if len( + attns[self.size64]) > 0 else None + out_list[1] = torch.cat([out_list[1], attn16], dim=1) + out_list[2] = torch.cat([out_list[2], attn32], dim=1) + if attn64 is not None: + out_list[3] = torch.cat([out_list[3], attn64], dim=1) + + +class TextAdapter(nn.Module): + """A PyTorch Module that serves as a text adapter. + + This module takes text embeddings and adjusts them based on a scaling + factor gamma. + """ + + def __init__(self, text_dim=768): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(text_dim, text_dim), nn.GELU(), + nn.Linear(text_dim, text_dim)) + + def forward(self, texts, gamma): + texts_after = self.fc(texts) + texts = texts + gamma * texts_after + return texts + + +@MODELS.register_module() +class VPD(BaseModule): + """VPD (Visual Perception Diffusion) model. + + .. _`VPD`: https://arxiv.org/abs/2303.02153 + + Args: + diffusion_cfg (dict): Configuration for diffusion model. + class_embed_path (str): Path for class embeddings. + unet_cfg (dict, optional): Configuration for U-Net. + gamma (float, optional): Gamma for text adaptation. Defaults to 1e-4. + class_embed_select (bool, optional): If True, enables class embedding + selection. Defaults to False. + pad_shape (Optional[Union[int, List[int]]], optional): Padding shape. + Defaults to None. + pad_val (Union[int, List[int]], optional): Padding value. + Defaults to 0. + init_cfg (dict, optional): Configuration for network initialization. 
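+
+    Example:
+        A minimal configuration sketch; the ``target`` string and all paths
+        are placeholders rather than values shipped with this file, and the
+        optional ``ldm`` dependency must be installed:
+
+        >>> # vpd = VPD(
+        >>> #     diffusion_cfg=dict(
+        >>> #         target='ldm.models.diffusion.ddpm.LatentDiffusion',
+        >>> #         params=dict(...),  # stable-diffusion settings
+        >>> #         checkpoint='path/to/sd.ckpt'),
+        >>> #     class_embed_path='path/to/class_embeddings.pth',
+        >>> #     unet_cfg=dict(use_attn=True))
+        >>> # feats = vpd(torch.randn(1, 3, 512, 512))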
+ """ + + def __init__(self, + diffusion_cfg: ConfigType, + class_embed_path: str, + unet_cfg: OptConfigType = dict(), + gamma: float = 1e-4, + class_embed_select=False, + pad_shape: Optional[Union[int, List[int]]] = None, + pad_val: Union[int, List[int]] = 0, + init_cfg: OptConfigType = None): + + super().__init__(init_cfg=init_cfg) + + assert has_ldm, 'To use VPD model, please install required packages' \ + ' via `pip install -r requirements/optional.txt`.' + + if pad_shape is not None: + if not isinstance(pad_shape, (list, tuple)): + pad_shape = (pad_shape, pad_shape) + + self.pad_shape = pad_shape + self.pad_val = pad_val + + # diffusion model + diffusion_checkpoint = diffusion_cfg.pop('checkpoint', None) + sd_model = instantiate_from_config(diffusion_cfg) + if diffusion_checkpoint is not None: + load_checkpoint(sd_model, diffusion_checkpoint, strict=False) + + self.encoder_vq = sd_model.first_stage_model + self.unet = UNetWrapper(sd_model.model, **unet_cfg) + + # class embeddings & text adapter + class_embeddings = CheckpointLoader.load_checkpoint(class_embed_path) + text_dim = class_embeddings.size(-1) + self.text_adapter = TextAdapter(text_dim=text_dim) + self.class_embed_select = class_embed_select + if class_embed_select: + class_embeddings = torch.cat( + (class_embeddings, class_embeddings.mean(dim=0, + keepdims=True)), + dim=0) + self.register_buffer('class_embeddings', class_embeddings) + self.gamma = nn.Parameter(torch.ones(text_dim) * gamma) + + def forward(self, x): + """Extract features from images.""" + + # calculate cross-attn map + if self.class_embed_select: + if isinstance(x, (tuple, list)): + x, class_ids = x[:2] + class_ids = class_ids.tolist() + else: + class_ids = [-1] * x.size(0) + class_embeddings = self.class_embeddings[class_ids] + c_crossattn = self.text_adapter(class_embeddings, self.gamma) + c_crossattn = c_crossattn.unsqueeze(1) + else: + class_embeddings = self.class_embeddings + c_crossattn = self.text_adapter(class_embeddings, self.gamma) + c_crossattn = c_crossattn.unsqueeze(0).repeat(x.size(0), 1, 1) + + # pad to required input shape for pretrained diffusion model + if self.pad_shape is not None: + pad_width = max(0, self.pad_shape[1] - x.shape[-1]) + pad_height = max(0, self.pad_shape[0] - x.shape[-2]) + x = F.pad(x, (0, pad_width, 0, pad_height), value=self.pad_val) + + # forward the denoising model + with torch.no_grad(): + latents = self.encoder_vq.encode(x).mode().detach() + t = torch.ones((x.shape[0], ), device=x.device).long() + outs = self.unet(latents, t, context=c_crossattn) + + return outs diff --git a/mmseg/models/builder.py b/mmseg/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..081c646b49b8ff1ea6c42d1ea4e24e63cdf6b43a --- /dev/null +++ b/mmseg/models/builder.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import warnings
+
+from mmseg.registry import MODELS
+
+BACKBONES = MODELS
+NECKS = MODELS
+HEADS = MODELS
+LOSSES = MODELS
+SEGMENTORS = MODELS
+
+
+def build_backbone(cfg):
+    """Build backbone."""
+    warnings.warn('``build_backbone`` will be deprecated soon, please use '
+                  '``mmseg.registry.MODELS.build()`` instead')
+    return BACKBONES.build(cfg)
+
+
+def build_neck(cfg):
+    """Build neck."""
+    warnings.warn('``build_neck`` will be deprecated soon, please use '
+                  '``mmseg.registry.MODELS.build()`` instead')
+    return NECKS.build(cfg)
+
+
+def build_head(cfg):
+    """Build head."""
+    warnings.warn('``build_head`` will be deprecated soon, please use '
+                  '``mmseg.registry.MODELS.build()`` instead')
+    return HEADS.build(cfg)
+
+
+def build_loss(cfg):
+    """Build loss."""
+    warnings.warn('``build_loss`` will be deprecated soon, please use '
+                  '``mmseg.registry.MODELS.build()`` instead')
+    return LOSSES.build(cfg)
+
+
+def build_segmentor(cfg, train_cfg=None, test_cfg=None):
+    """Build segmentor."""
+    if train_cfg is not None or test_cfg is not None:
+        warnings.warn(
+            'train_cfg and test_cfg are deprecated, '
+            'please specify them in model', UserWarning)
+    assert cfg.get('train_cfg') is None or train_cfg is None, \
+        'train_cfg specified in both outer field and model field '
+    assert cfg.get('test_cfg') is None or test_cfg is None, \
+        'test_cfg specified in both outer field and model field '
+    return SEGMENTORS.build(
+        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
diff --git a/mmseg/models/data_preprocessor.py b/mmseg/models/data_preprocessor.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d32bc647b7d48183590408e36ec42ea36aea91c
--- /dev/null
+++ b/mmseg/models/data_preprocessor.py
@@ -0,0 +1,151 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from numbers import Number
+from typing import Any, Dict, List, Optional, Sequence
+
+import torch
+from mmengine.model import BaseDataPreprocessor
+
+from mmseg.registry import MODELS
+from mmseg.utils import stack_batch
+
+
+@MODELS.register_module()
+class SegDataPreProcessor(BaseDataPreprocessor):
+    """Image pre-processor for segmentation tasks.
+
+    Compared with :class:`mmengine.ImgDataPreprocessor`,
+
+    1. It won't do normalization if ``mean`` is not specified.
+    2. It does normalization and color space conversion after stacking batch.
+    3. It supports batch augmentations like mixup and cutmix.
+
+    It provides the data pre-processing as follows:
+
+    - Collate and move data to the target device.
+    - Pad inputs to the input size with defined ``pad_val``, and pad seg map
+      with defined ``seg_pad_val``.
+    - Stack inputs to batch_inputs.
+    - Convert inputs from bgr to rgb if the shape of input is (3, H, W).
+    - Normalize image with defined std and mean.
+    - Do batch augmentations like Mixup and Cutmix during training.
+
+    Args:
+        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
+            Defaults to None.
+        std (Sequence[Number], optional): The pixel standard deviation of
+            R, G, B channels. Defaults to None.
+        size (tuple, optional): Fixed padding size.
+        size_divisor (int, optional): The divisor of padded size.
+        pad_val (float, optional): Padding value. Default: 0.
+        seg_pad_val (float, optional): Padding value of segmentation map.
+            Default: 255.
+        padding_mode (str): Type of padding. Default: constant.
+            - constant: pads with a constant value, this value is specified
+              with pad_val.
+        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
+            Defaults to False.
+        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
+            Defaults to False.
+        batch_augments (list[dict], optional): Batch-level augmentations.
+        test_cfg (dict, optional): The padding size config in testing; if not
+            specified, the `size` and `size_divisor` params are used as
+            defaults. Defaults to None, only supports keys `size` or
+            `size_divisor`.
+    """
+
+    def __init__(
+        self,
+        mean: Sequence[Number] = None,
+        std: Sequence[Number] = None,
+        size: Optional[tuple] = None,
+        size_divisor: Optional[int] = None,
+        pad_val: Number = 0,
+        seg_pad_val: Number = 255,
+        bgr_to_rgb: bool = False,
+        rgb_to_bgr: bool = False,
+        batch_augments: Optional[List[dict]] = None,
+        test_cfg: dict = None,
+    ):
+        super().__init__()
+        self.size = size
+        self.size_divisor = size_divisor
+        self.pad_val = pad_val
+        self.seg_pad_val = seg_pad_val
+
+        assert not (bgr_to_rgb and rgb_to_bgr), (
+            '`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time')
+        self.channel_conversion = rgb_to_bgr or bgr_to_rgb
+
+        if mean is not None:
+            assert std is not None, 'To enable the normalization in ' \
+                                    'preprocessing, please specify both ' \
+                                    '`mean` and `std`.'
+            # Enable the normalization in preprocessing.
+            self._enable_normalize = True
+            self.register_buffer('mean',
+                                 torch.tensor(mean).view(-1, 1, 1), False)
+            self.register_buffer('std',
+                                 torch.tensor(std).view(-1, 1, 1), False)
+        else:
+            self._enable_normalize = False
+
+        # TODO: support batch augmentations.
+        self.batch_augments = batch_augments
+
+        # Support different padding methods in testing
+        self.test_cfg = test_cfg
+
+    def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
+        """Perform normalization, padding and bgr2rgb conversion based on
+        ``BaseDataPreprocessor``.
+
+        Args:
+            data (dict): data sampled from dataloader.
+            training (bool): Whether to enable training time augmentation.
+
+        Returns:
+            Dict: Data in the same format as the model input.
+        """
+        data = self.cast_data(data)  # type: ignore
+        inputs = data['inputs']
+        data_samples = data.get('data_samples', None)
+        # TODO: whether normalize should be after stack_batch
+        if self.channel_conversion and inputs[0].size(0) == 3:
+            inputs = [_input[[2, 1, 0], ...] for _input in inputs]
+
+        inputs = [_input.float() for _input in inputs]
+        if self._enable_normalize:
+            inputs = [(_input - self.mean) / self.std for _input in inputs]
+
+        if training:
+            assert data_samples is not None, ('During training, '
+                                              '`data_samples` must be '
+                                              'defined.')
+            inputs, data_samples = stack_batch(
+                inputs=inputs,
+                data_samples=data_samples,
+                size=self.size,
+                size_divisor=self.size_divisor,
+                pad_val=self.pad_val,
+                seg_pad_val=self.seg_pad_val)
+
+            if self.batch_augments is not None:
+                inputs, data_samples = self.batch_augments(
+                    inputs, data_samples)
+        else:
+            img_size = inputs[0].shape[1:]
+            assert all(input_.shape[1:] == img_size for input_ in inputs), \
+                'The image size in a batch should be the same.'
+ # pad images when testing + if self.test_cfg: + inputs, padded_samples = stack_batch( + inputs=inputs, + size=self.test_cfg.get('size', None), + size_divisor=self.test_cfg.get('size_divisor', None), + pad_val=self.pad_val, + seg_pad_val=self.seg_pad_val) + for data_sample, pad_info in zip(data_samples, padded_samples): + data_sample.set_metainfo({**pad_info}) + else: + inputs = torch.stack(inputs, dim=0) + + return dict(inputs=inputs, data_samples=data_samples) diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4229763816e4100ab6718e4698a21ce92199371b --- /dev/null +++ b/mmseg/models/decode_heads/__init__.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .ann_head import ANNHead +from .apc_head import APCHead +from .aspp_head import ASPPHead +from .cc_head import CCHead +from .da_head import DAHead +from .ddr_head import DDRHead +from .dm_head import DMHead +from .dnl_head import DNLHead +from .dpt_head import DPTHead +from .ema_head import EMAHead +from .enc_head import EncHead +from .fcn_head import FCNHead +from .fpn_head import FPNHead +from .gc_head import GCHead +from .ham_head import LightHamHead +from .isa_head import ISAHead +from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator +from .lraspp_head import LRASPPHead +from .mask2former_head import Mask2FormerHead +from .maskformer_head import MaskFormerHead +from .nl_head import NLHead +from .ocr_head import OCRHead +from .pid_head import PIDHead +from .point_head import PointHead +from .psa_head import PSAHead +from .psp_head import PSPHead +from .san_head import SideAdapterCLIPHead +from .segformer_head import SegformerHead +from .segmenter_mask_head import SegmenterMaskTransformerHead +from .sep_aspp_head import DepthwiseSeparableASPPHead +from .sep_fcn_head import DepthwiseSeparableFCNHead +from .setr_mla_head import SETRMLAHead +from .setr_up_head import SETRUPHead +from .stdc_head import STDCHead +from .uper_head import UPerHead +from .vpd_depth_head import VPDDepthHead + +__all__ = [ + 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead', + 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead', + 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead', + 'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', + 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead', + 'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead', + 'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead', + 'LightHamHead', 'PIDHead', 'DDRHead', 'VPDDepthHead', 'SideAdapterCLIPHead' +] diff --git a/mmseg/models/decode_heads/__pycache__/__init__.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c583c63a6a1b08e7d0835f7c276762d89e50ddd3 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/ann_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/ann_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59d9714a3eb812d1010d0831ca48437c339055d0 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/ann_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/apc_head.cpython-311.pyc 
b/mmseg/models/decode_heads/__pycache__/apc_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0f475e06931089457f283aede8f59bb38c2041a Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/apc_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/aspp_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/aspp_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f15aabc28e1c409598f8eb88b3456b892816cd70 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/aspp_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/cascade_decode_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/cascade_decode_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa0183a1e3ebba586f8cbbd96df4f48c53ce58a0 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/cascade_decode_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/cc_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/cc_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e68f6e99433acc600cd598f30ee36a02d825dad3 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/cc_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/da_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/da_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd6658e8a94f68c5566a019d2bec416c8c2f3d7f Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/da_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/ddr_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/ddr_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e709ea6b4dd148c761d509d597cc810eb8dfc58 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/ddr_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/decode_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/decode_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0cdfb122fb2930aaf9f7f66302f509735041ab9 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/decode_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/dm_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/dm_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..701f7c902df3163da9878396b8793bcadaee22ee Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/dm_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/dnl_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/dnl_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e91b8d6fe27026eac2a8ec13cc928d175cb33d9 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/dnl_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/dpt_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/dpt_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..193c80fa1e9753862e169f63517c7a66b056f3e6 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/dpt_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/ema_head.cpython-311.pyc 
b/mmseg/models/decode_heads/__pycache__/ema_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3764d3f7172f9fca5f376b86ae07c991729a16f Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/ema_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/enc_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/enc_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..355eaf4166f3a2cabb72d12f3bdb4e495bad7e00 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/enc_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/fcn_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/fcn_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88915c660169b1fd67af907521e5612b24a6ef46 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/fcn_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/fpn_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/fpn_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbe7a01a729d6ae347907c2806bdef75cd3b15dd Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/fpn_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/gc_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/gc_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dd1a87a1c318fa3ba18ba91aa7fdea00c68b86b Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/gc_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/ham_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/ham_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57383a99ed58400ab3c4c89f8341539033cd9346 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/ham_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/isa_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/isa_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66c38cd75f4466a0182fdb836ca3ef9c3880f187 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/isa_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/knet_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/knet_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc0226976b38e8616ac7675d69dcc96a345ddb57 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/knet_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/lraspp_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/lraspp_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c90bd302ccd4e9d4b9b05c0497a9b8c5fa19200 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/lraspp_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/mask2former_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/mask2former_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4f4575e2f7c8038fc17a4908e34adb8ce44f4de Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/mask2former_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/maskformer_head.cpython-311.pyc 
b/mmseg/models/decode_heads/__pycache__/maskformer_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cee7b247a586d25d56f58f992eb1ba46265b6cbf Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/maskformer_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/nl_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/nl_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95c259ea26ec1d0d6542c09701f45d07d9bef929 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/nl_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/ocr_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/ocr_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a221a876d517e7e755eef33cc867b2a587231f6 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/ocr_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/pid_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/pid_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7425bffefcbad42f1a839524da89c0ce09bb8071 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/pid_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/point_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/point_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f76e9ce825dd51d37aa5b6fb5e30c3e17bcccca4 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/point_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/psa_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/psa_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e9a49473e4f37b6b5b1a7f2ffebba6483c07a8e Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/psa_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/psp_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/psp_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59145847a4c161f3bdea14cd60b4cb83a1a3794b Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/psp_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/san_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/san_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11cd96fd195bae501bcc5fd7e6f6e0ca77849a62 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/san_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/segformer_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/segformer_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffde6b29d779081c432faddd7fada2baf5c9f57a Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/segformer_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/segmenter_mask_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/segmenter_mask_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..391d537e1d32e454a070ca4d84a2e5bfeeafc475 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/segmenter_mask_head.cpython-311.pyc differ diff --git 
a/mmseg/models/decode_heads/__pycache__/sep_aspp_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/sep_aspp_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2056a0d99dd2635f9f10bf6777b2e75e2310033 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/sep_aspp_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/sep_fcn_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/sep_fcn_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..feb0bb3dfdb878fe70e4aa525f3b8b7129160a23 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/sep_fcn_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/setr_mla_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/setr_mla_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bb506ba700451ea69b59338af48f57a899413cd Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/setr_mla_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/setr_up_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/setr_up_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8f66c0fb7861ab45141abdfa631466d5a20beae Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/setr_up_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/stdc_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/stdc_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fd0ee36db0878ae80eca035fc8d8509eb097603 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/stdc_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/uper_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/uper_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..955348a292869e3056c5e1de7044df83639cef48 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/uper_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/vpd_depth_head.cpython-311.pyc b/mmseg/models/decode_heads/__pycache__/vpd_depth_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..898ee3114b60de80a6af8a5d228e034fc22163bd Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/vpd_depth_head.cpython-311.pyc differ diff --git a/mmseg/models/decode_heads/ann_head.py b/mmseg/models/decode_heads/ann_head.py new file mode 100644 index 0000000000000000000000000000000000000000..2b40ef5aa1da0bc2473597fedca5b3f33973beb0 --- /dev/null +++ b/mmseg/models/decode_heads/ann_head.py @@ -0,0 +1,245 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from .decode_head import BaseDecodeHead + + +class PPMConcat(nn.ModuleList): + """Pyramid Pooling Module that only concat the features of each layer. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. 
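+
+    Example:
+        A minimal sketch; the input shape is an illustrative assumption:
+
+        >>> import torch
+        >>> ppm = PPMConcat(pool_scales=(1, 3, 6, 8))
+        >>> feats = torch.randn(2, 16, 32, 32)
+        >>> ppm(feats).shape  # 1*1 + 3*3 + 6*6 + 8*8 = 110 pooled positions
+        torch.Size([2, 16, 110])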
+ """ + + def __init__(self, pool_scales=(1, 3, 6, 8)): + super().__init__( + [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales]) + + def forward(self, feats): + """Forward function.""" + ppm_outs = [] + for ppm in self: + ppm_out = ppm(feats) + ppm_outs.append(ppm_out.view(*feats.shape[:2], -1)) + concat_outs = torch.cat(ppm_outs, dim=2) + return concat_outs + + +class SelfAttentionBlock(_SelfAttentionBlock): + """Make a ANN used SelfAttentionBlock. + + Args: + low_in_channels (int): Input channels of lower level feature, + which is the key feature for self-attention. + high_in_channels (int): Input channels of higher level feature, + which is the query feature for self-attention. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + share_key_query (bool): Whether share projection weight between key + and query projection. + query_scale (int): The scale of query feature map. + key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module of key feature. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. + """ + + def __init__(self, low_in_channels, high_in_channels, channels, + out_channels, share_key_query, query_scale, key_pool_scales, + conv_cfg, norm_cfg, act_cfg): + key_psp = PPMConcat(key_pool_scales) + if query_scale > 1: + query_downsample = nn.MaxPool2d(kernel_size=query_scale) + else: + query_downsample = None + super().__init__( + key_in_channels=low_in_channels, + query_in_channels=high_in_channels, + channels=channels, + out_channels=out_channels, + share_key_query=share_key_query, + query_downsample=query_downsample, + key_downsample=key_psp, + key_query_num_convs=1, + key_query_norm=True, + value_out_num_convs=1, + value_out_norm=False, + matmul_norm=True, + with_out=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + +class AFNB(nn.Module): + """Asymmetric Fusion Non-local Block(AFNB) + + Args: + low_in_channels (int): Input channels of lower level feature, + which is the key feature for self-attention. + high_in_channels (int): Input channels of higher level feature, + which is the query feature for self-attention. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + and query projection. + query_scales (tuple[int]): The scales of query feature map. + Default: (1,) + key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module of key feature. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. 
+ """ + + def __init__(self, low_in_channels, high_in_channels, channels, + out_channels, query_scales, key_pool_scales, conv_cfg, + norm_cfg, act_cfg): + super().__init__() + self.stages = nn.ModuleList() + for query_scale in query_scales: + self.stages.append( + SelfAttentionBlock( + low_in_channels=low_in_channels, + high_in_channels=high_in_channels, + channels=channels, + out_channels=out_channels, + share_key_query=False, + query_scale=query_scale, + key_pool_scales=key_pool_scales, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.bottleneck = ConvModule( + out_channels + high_in_channels, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, low_feats, high_feats): + """Forward function.""" + priors = [stage(high_feats, low_feats) for stage in self.stages] + context = torch.stack(priors, dim=0).sum(dim=0) + output = self.bottleneck(torch.cat([context, high_feats], 1)) + return output + + +class APNB(nn.Module): + """Asymmetric Pyramid Non-local Block (APNB) + + Args: + in_channels (int): Input channels of key/query feature, + which is the key feature for self-attention. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + query_scales (tuple[int]): The scales of query feature map. + Default: (1,) + key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module of key feature. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. + """ + + def __init__(self, in_channels, channels, out_channels, query_scales, + key_pool_scales, conv_cfg, norm_cfg, act_cfg): + super().__init__() + self.stages = nn.ModuleList() + for query_scale in query_scales: + self.stages.append( + SelfAttentionBlock( + low_in_channels=in_channels, + high_in_channels=in_channels, + channels=channels, + out_channels=out_channels, + share_key_query=True, + query_scale=query_scale, + key_pool_scales=key_pool_scales, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.bottleneck = ConvModule( + 2 * in_channels, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, feats): + """Forward function.""" + priors = [stage(feats, feats) for stage in self.stages] + context = torch.stack(priors, dim=0).sum(dim=0) + output = self.bottleneck(torch.cat([context, feats], 1)) + return output + + +@MODELS.register_module() +class ANNHead(BaseDecodeHead): + """Asymmetric Non-local Neural Networks for Semantic Segmentation. + + This head is the implementation of `ANNNet + `_. + + Args: + project_channels (int): Projection channels for Nonlocal. + query_scales (tuple[int]): The scales of query feature map. + Default: (1,) + key_pool_scales (tuple[int]): The pooling scales of key feature map. + Default: (1, 3, 6, 8). 
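+
+    Example:
+        A minimal sketch wiring the head to the last two backbone stages
+        (channel counts, feature sizes and class number are arbitrary):
+
+        >>> import torch
+        >>> head = ANNHead(
+        ...     in_channels=[512, 1024],
+        ...     in_index=[2, 3],
+        ...     channels=256,
+        ...     project_channels=64,
+        ...     num_classes=19)
+        >>> feats = [torch.randn(2, c, 128 // 2**i, 128 // 2**i)
+        ...          for i, c in enumerate([128, 256, 512, 1024])]
+        >>> tuple(head(feats).shape)  # logits at the coarsest feature scale
+        (2, 19, 16, 16)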
+ """ + + def __init__(self, + project_channels, + query_scales=(1, ), + key_pool_scales=(1, 3, 6, 8), + **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + assert len(self.in_channels) == 2 + low_in_channels, high_in_channels = self.in_channels + self.project_channels = project_channels + self.fusion = AFNB( + low_in_channels=low_in_channels, + high_in_channels=high_in_channels, + out_channels=high_in_channels, + channels=project_channels, + query_scales=query_scales, + key_pool_scales=key_pool_scales, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.bottleneck = ConvModule( + high_in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.context = APNB( + in_channels=self.channels, + out_channels=self.channels, + channels=project_channels, + query_scales=query_scales, + key_pool_scales=key_pool_scales, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + low_feats, high_feats = self._transform_inputs(inputs) + output = self.fusion(low_feats, high_feats) + output = self.dropout(output) + output = self.bottleneck(output) + output = self.context(output) + output = self.cls_seg(output) + + return output diff --git a/mmseg/models/decode_heads/apc_head.py b/mmseg/models/decode_heads/apc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..728f39659c63680944306fddc9e33b7c9172c1ba --- /dev/null +++ b/mmseg/models/decode_heads/apc_head.py @@ -0,0 +1,159 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class ACM(nn.Module): + """Adaptive Context Module used in APCNet. + + Args: + pool_scale (int): Pooling scale used in Adaptive Context + Module to extract region features. + fusion (bool): Add one conv to fuse residual feature. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict | None): Config of conv layers. + norm_cfg (dict | None): Config of norm layers. + act_cfg (dict): Config of activation layers. 
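+
+    Example:
+        An illustrative sketch (channel counts and pooling scale are
+        arbitrary; ``None`` configs fall back to plain convs):
+
+        >>> import torch
+        >>> acm = ACM(
+        ...     pool_scale=2,
+        ...     fusion=True,
+        ...     in_channels=512,
+        ...     channels=128,
+        ...     conv_cfg=None,
+        ...     norm_cfg=None,
+        ...     act_cfg=dict(type='ReLU'))
+        >>> x = torch.randn(2, 512, 32, 32)
+        >>> tuple(acm(x).shape)  # spatial size is preserved
+        (2, 128, 32, 32)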
+ """ + + def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg, + norm_cfg, act_cfg): + super().__init__() + self.pool_scale = pool_scale + self.fusion = fusion + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.pooled_redu_conv = ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.input_redu_conv = ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.global_info = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0) + + self.residual_conv = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + if self.fusion: + self.fusion_conv = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, x): + """Forward function.""" + pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale) + # [batch_size, channels, h, w] + x = self.input_redu_conv(x) + # [batch_size, channels, pool_scale, pool_scale] + pooled_x = self.pooled_redu_conv(pooled_x) + batch_size = x.size(0) + # [batch_size, pool_scale * pool_scale, channels] + pooled_x = pooled_x.view(batch_size, self.channels, + -1).permute(0, 2, 1).contiguous() + # [batch_size, h * w, pool_scale * pool_scale] + affinity_matrix = self.gla(x + resize( + self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:]) + ).permute(0, 2, 3, 1).reshape( + batch_size, -1, self.pool_scale**2) + affinity_matrix = F.sigmoid(affinity_matrix) + # [batch_size, h * w, channels] + z_out = torch.matmul(affinity_matrix, pooled_x) + # [batch_size, channels, h * w] + z_out = z_out.permute(0, 2, 1).contiguous() + # [batch_size, channels, h, w] + z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3)) + z_out = self.residual_conv(z_out) + z_out = F.relu(z_out + x) + if self.fusion: + z_out = self.fusion_conv(z_out) + + return z_out + + +@MODELS.register_module() +class APCHead(BaseDecodeHead): + """Adaptive Pyramid Context Network for Semantic Segmentation. + + This head is the implementation of + `APCNet `_. + + Args: + pool_scales (tuple[int]): Pooling scales used in Adaptive Context + Module. Default: (1, 2, 3, 6). + fusion (bool): Add one conv to fuse residual feature. 
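+
+    Example:
+        A minimal sketch feeding the last of four backbone stages into the
+        head (all channel counts and sizes are arbitrary):
+
+        >>> import torch
+        >>> head = APCHead(
+        ...     in_channels=2048,
+        ...     in_index=3,
+        ...     channels=512,
+        ...     pool_scales=(1, 2, 3, 6),
+        ...     num_classes=19)
+        >>> feats = [torch.randn(2, c, 64 // 2**i, 64 // 2**i)
+        ...          for i, c in enumerate([256, 512, 1024, 2048])]
+        >>> tuple(head(feats).shape)
+        (2, 19, 8, 8)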
+ """ + + def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs): + super().__init__(**kwargs) + assert isinstance(pool_scales, (list, tuple)) + self.pool_scales = pool_scales + self.fusion = fusion + acm_modules = [] + for pool_scale in self.pool_scales: + acm_modules.append( + ACM(pool_scale, + self.fusion, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.acm_modules = nn.ModuleList(acm_modules) + self.bottleneck = ConvModule( + self.in_channels + len(pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + acm_outs = [x] + for acm_module in self.acm_modules: + acm_outs.append(acm_module(x)) + acm_outs = torch.cat(acm_outs, dim=1) + output = self.bottleneck(acm_outs) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/aspp_head.py b/mmseg/models/decode_heads/aspp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7185d7de58d35ef17e5d54e0e75b045e8724c4 --- /dev/null +++ b/mmseg/models/decode_heads/aspp_head.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class ASPPModule(nn.ModuleList): + """Atrous Spatial Pyramid Pooling (ASPP) Module. + + Args: + dilations (tuple[int]): Dilation rate of each layer. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg, + act_cfg): + super().__init__() + self.dilations = dilations + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + for dilation in dilations: + self.append( + ConvModule( + self.in_channels, + self.channels, + 1 if dilation == 1 else 3, + dilation=dilation, + padding=0 if dilation == 1 else dilation, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + def forward(self, x): + """Forward function.""" + aspp_outs = [] + for aspp_module in self: + aspp_outs.append(aspp_module(x)) + + return aspp_outs + + +@MODELS.register_module() +class ASPPHead(BaseDecodeHead): + """Rethinking Atrous Convolution for Semantic Image Segmentation. + + This head is the implementation of `DeepLabV3 + `_. + + Args: + dilations (tuple[int]): Dilation rates for ASPP module. + Default: (1, 6, 12, 18). 
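+
+    Example:
+        A minimal sketch on the last of four backbone stages (all channel
+        counts and sizes are arbitrary):
+
+        >>> import torch
+        >>> head = ASPPHead(
+        ...     in_channels=2048,
+        ...     in_index=3,
+        ...     channels=512,
+        ...     dilations=(1, 6, 12, 18),
+        ...     num_classes=19)
+        >>> feats = [torch.randn(2, c, 64 // 2**i, 64 // 2**i)
+        ...          for i, c in enumerate([256, 512, 1024, 2048])]
+        >>> tuple(head(feats).shape)
+        (2, 19, 8, 8)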
+    """
+
+    def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
+        super().__init__(**kwargs)
+        assert isinstance(dilations, (list, tuple))
+        self.dilations = dilations
+        self.image_pool = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            ConvModule(
+                self.in_channels,
+                self.channels,
+                1,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg))
+        self.aspp_modules = ASPPModule(
+            dilations,
+            self.in_channels,
+            self.channels,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.bottleneck = ConvModule(
+            (len(dilations) + 1) * self.channels,
+            self.channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+    def _forward_feature(self, inputs):
+        """Forward function for feature maps before classifying each pixel
+        with ``self.cls_seg`` fc.
+
+        Args:
+            inputs (list[Tensor]): List of multi-level img features.
+
+        Returns:
+            feats (Tensor): A tensor of shape (batch_size, self.channels,
+                H, W) which is feature map for last layer of decoder head.
+        """
+        x = self._transform_inputs(inputs)
+        aspp_outs = [
+            resize(
+                self.image_pool(x),
+                size=x.size()[2:],
+                mode='bilinear',
+                align_corners=self.align_corners)
+        ]
+        aspp_outs.extend(self.aspp_modules(x))
+        aspp_outs = torch.cat(aspp_outs, dim=1)
+        feats = self.bottleneck(aspp_outs)
+        return feats
+
+    def forward(self, inputs):
+        """Forward function."""
+        output = self._forward_feature(inputs)
+        output = self.cls_seg(output)
+        return output
diff --git a/mmseg/models/decode_heads/cascade_decode_head.py b/mmseg/models/decode_heads/cascade_decode_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe2bcb9302235e3881696dff6657e3e7fb12609b
--- /dev/null
+++ b/mmseg/models/decode_heads/cascade_decode_head.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+from typing import List
+
+from torch import Tensor
+
+from mmseg.utils import ConfigType
+from .decode_head import BaseDecodeHead
+
+
+class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
+    """Base class for cascade decode head used in
+    :class:`CascadeEncoderDecoder`."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @abstractmethod
+    def forward(self, inputs, prev_output):
+        """Placeholder of forward function."""
+        pass
+
+    def loss(self, inputs: List[Tensor], prev_output: Tensor,
+             batch_data_samples: List[dict], train_cfg: ConfigType) -> dict:
+        """Forward function for training.
+
+        Args:
+            inputs (List[Tensor]): List of multi-level img features.
+            prev_output (Tensor): The output of previous decode head.
+            batch_data_samples (List[:obj:`SegDataSample`]): The seg
+                data samples. It usually includes information such
+                as `metainfo` and `gt_sem_seg`.
+            train_cfg (dict): The training config.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+        seg_logits = self.forward(inputs, prev_output)
+        losses = self.loss_by_feat(seg_logits, batch_data_samples)
+
+        return losses
+
+    def predict(self, inputs: List[Tensor], prev_output: Tensor,
+                batch_img_metas: List[dict], test_cfg: ConfigType):
+        """Forward function for testing.
+
+        Args:
+            inputs (List[Tensor]): List of multi-level img features.
+            prev_output (Tensor): The output of previous decode head.
+            batch_img_metas (dict): List Image info where each dict may also
+                contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
+                'ori_shape', and 'pad_shape'.
+ For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Output segmentation map. + """ + seg_logits = self.forward(inputs, prev_output) + + return self.predict_by_feat(seg_logits, batch_img_metas) diff --git a/mmseg/models/decode_heads/cc_head.py b/mmseg/models/decode_heads/cc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..e9075a2648d77f6bca6bb29f3e7db52a329f7afb --- /dev/null +++ b/mmseg/models/decode_heads/cc_head.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmseg.registry import MODELS +from .fcn_head import FCNHead + +try: + from mmcv.ops import CrissCrossAttention +except ModuleNotFoundError: + CrissCrossAttention = None + + +@MODELS.register_module() +class CCHead(FCNHead): + """CCNet: Criss-Cross Attention for Semantic Segmentation. + + This head is the implementation of `CCNet + `_. + + Args: + recurrence (int): Number of recurrence of Criss Cross Attention + module. Default: 2. + """ + + def __init__(self, recurrence=2, **kwargs): + if CrissCrossAttention is None: + raise RuntimeError('Please install mmcv-full for ' + 'CrissCrossAttention ops') + super().__init__(num_convs=2, **kwargs) + self.recurrence = recurrence + self.cca = CrissCrossAttention(self.channels) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + for _ in range(self.recurrence): + output = self.cca(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/da_head.py b/mmseg/models/decode_heads/da_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d87214365d2f8695b60ccab0c1850669ff8dd295 --- /dev/null +++ b/mmseg/models/decode_heads/da_head.py @@ -0,0 +1,184 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from mmcv.cnn import ConvModule, Scale +from torch import Tensor, nn + +from mmseg.registry import MODELS +from mmseg.utils import SampleList, add_prefix +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from .decode_head import BaseDecodeHead + + +class PAM(_SelfAttentionBlock): + """Position Attention Module (PAM) + + Args: + in_channels (int): Input channels of key/query feature. + channels (int): Output channels of key/query transform. 
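+
+    Example:
+        An illustrative sketch (all sizes are arbitrary); the residual
+        design keeps the input shape:
+
+        >>> import torch
+        >>> pam = PAM(in_channels=64, channels=16)
+        >>> x = torch.randn(2, 64, 32, 32)
+        >>> tuple(pam(x).shape)
+        (2, 64, 32, 32)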
+ """ + + def __init__(self, in_channels, channels): + super().__init__( + key_in_channels=in_channels, + query_in_channels=in_channels, + channels=channels, + out_channels=in_channels, + share_key_query=False, + query_downsample=None, + key_downsample=None, + key_query_num_convs=1, + key_query_norm=False, + value_out_num_convs=1, + value_out_norm=False, + matmul_norm=False, + with_out=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None) + + self.gamma = Scale(0) + + def forward(self, x): + """Forward function.""" + out = super().forward(x, x) + + out = self.gamma(out) + x + return out + + +class CAM(nn.Module): + """Channel Attention Module (CAM)""" + + def __init__(self): + super().__init__() + self.gamma = Scale(0) + + def forward(self, x): + """Forward function.""" + batch_size, channels, height, width = x.size() + proj_query = x.view(batch_size, channels, -1) + proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max( + energy, -1, keepdim=True)[0].expand_as(energy) - energy + attention = F.softmax(energy_new, dim=-1) + proj_value = x.view(batch_size, channels, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(batch_size, channels, height, width) + + out = self.gamma(out) + x + return out + + +@MODELS.register_module() +class DAHead(BaseDecodeHead): + """Dual Attention Network for Scene Segmentation. + + This head is the implementation of `DANet + `_. + + Args: + pam_channels (int): The channels of Position Attention Module(PAM). + """ + + def __init__(self, pam_channels, **kwargs): + super().__init__(**kwargs) + self.pam_channels = pam_channels + self.pam_in_conv = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.pam = PAM(self.channels, pam_channels) + self.pam_out_conv = ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.pam_conv_seg = nn.Conv2d( + self.channels, self.num_classes, kernel_size=1) + + self.cam_in_conv = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.cam = CAM() + self.cam_out_conv = ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.cam_conv_seg = nn.Conv2d( + self.channels, self.num_classes, kernel_size=1) + + def pam_cls_seg(self, feat): + """PAM feature classification.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.pam_conv_seg(feat) + return output + + def cam_cls_seg(self, feat): + """CAM feature classification.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.cam_conv_seg(feat) + return output + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + pam_feat = self.pam_in_conv(x) + pam_feat = self.pam(pam_feat) + pam_feat = self.pam_out_conv(pam_feat) + pam_out = self.pam_cls_seg(pam_feat) + + cam_feat = self.cam_in_conv(x) + cam_feat = self.cam(cam_feat) + cam_feat = self.cam_out_conv(cam_feat) + cam_out = self.cam_cls_seg(cam_feat) + + feat_sum = pam_feat + cam_feat + pam_cam_out = self.cls_seg(feat_sum) + + return pam_cam_out, pam_out, cam_out + + def predict(self, inputs, batch_img_metas: List[dict], test_cfg, + **kwargs) -> List[Tensor]: + """Forward function for testing, only 
``pam_cam`` is used.""" + seg_logits = self.forward(inputs)[0] + return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs) + + def loss_by_feat(self, seg_logit: Tuple[Tensor], + batch_data_samples: SampleList, **kwargs) -> dict: + """Compute ``pam_cam``, ``pam``, ``cam`` loss.""" + pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit + loss = dict() + loss.update( + add_prefix( + super().loss_by_feat(pam_cam_seg_logit, batch_data_samples), + 'pam_cam')) + loss.update( + add_prefix(super().loss_by_feat(pam_seg_logit, batch_data_samples), + 'pam')) + loss.update( + add_prefix(super().loss_by_feat(cam_seg_logit, batch_data_samples), + 'cam')) + return loss diff --git a/mmseg/models/decode_heads/ddr_head.py b/mmseg/models/decode_heads/ddr_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ba26d6503c09d7efb3ca6664c7baf59c9e6e3ce9 --- /dev/null +++ b/mmseg/models/decode_heads/ddr_head.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple, Union + +import torch.nn as nn +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer +from torch import Tensor + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.models.losses import accuracy +from mmseg.models.utils import resize +from mmseg.registry import MODELS +from mmseg.utils import OptConfigType, SampleList + + +@MODELS.register_module() +class DDRHead(BaseDecodeHead): + """Decode head for DDRNet. + + Args: + in_channels (int): Number of input channels. + channels (int): Number of output channels. + num_classes (int): Number of classes. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict, optional): Config dict for activation layer. + Default: dict(type='ReLU', inplace=True). 
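+
+    Example:
+        A minimal sketch at inference time, when the head consumes only the
+        context-branch feature (during training it expects a tuple of the
+        auxiliary and context features; all sizes here are arbitrary):
+
+        >>> import torch
+        >>> head = DDRHead(in_channels=128, channels=64, num_classes=19)
+        >>> _ = head.eval()  # use running BatchNorm statistics
+        >>> x_context = torch.randn(2, 128, 32, 32)
+        >>> tuple(head(x_context).shape)
+        (2, 19, 32, 32)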
+ """ + + def __init__(self, + in_channels: int, + channels: int, + num_classes: int, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + **kwargs): + super().__init__( + in_channels, + channels, + num_classes=num_classes, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs) + + self.head = self._make_base_head(self.in_channels, self.channels) + self.aux_head = self._make_base_head(self.in_channels // 2, + self.channels) + self.aux_cls_seg = nn.Conv2d( + self.channels, self.out_channels, kernel_size=1) + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward( + self, + inputs: Union[Tensor, + Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]: + if self.training: + c3_feat, c5_feat = inputs + x_c = self.head(c5_feat) + x_c = self.cls_seg(x_c) + x_s = self.aux_head(c3_feat) + x_s = self.aux_cls_seg(x_s) + + return x_c, x_s + else: + x_c = self.head(inputs) + x_c = self.cls_seg(x_c) + return x_c + + def _make_base_head(self, in_channels: int, + channels: int) -> nn.Sequential: + layers = [ + ConvModule( + in_channels, + channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + order=('norm', 'act', 'conv')), + build_norm_layer(self.norm_cfg, channels)[1], + build_activation_layer(self.act_cfg), + ] + + return nn.Sequential(*layers) + + def loss_by_feat(self, seg_logits: Tuple[Tensor], + batch_data_samples: SampleList) -> dict: + loss = dict() + context_logit, spatial_logit = seg_logits + seg_label = self._stack_batch_gt(batch_data_samples) + + context_logit = resize( + context_logit, + size=seg_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + spatial_logit = resize( + spatial_logit, + size=seg_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + seg_label = seg_label.squeeze(1) + + loss['loss_context'] = self.loss_decode[0](context_logit, seg_label) + loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label) + loss['acc_seg'] = accuracy( + context_logit, seg_label, ignore_index=self.ignore_index) + + return loss diff --git a/mmseg/models/decode_heads/decode_head.py b/mmseg/models/decode_heads/decode_head.py new file mode 100644 index 0000000000000000000000000000000000000000..179d871fd18d1af3e06a62e1e731572fb85683e2 --- /dev/null +++ b/mmseg/models/decode_heads/decode_head.py @@ -0,0 +1,366 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from abc import ABCMeta, abstractmethod +from typing import List, Tuple + +import torch +import torch.nn as nn +from mmengine.model import BaseModule +from torch import Tensor + +from mmseg.structures import build_pixel_sampler +from mmseg.utils import ConfigType, SampleList +from ..builder import build_loss +from ..losses import accuracy +from ..utils import resize + + +class BaseDecodeHead(BaseModule, metaclass=ABCMeta): + """Base class for BaseDecodeHead. + + 1. The ``init_weights`` method is used to initialize decode_head's + model parameters. After segmentor initialization, ``init_weights`` + is triggered when ``segmentor.init_weights()`` is called externally. + + 2. 
The ``loss`` method is used to calculate the loss of decode_head,
+    which includes two steps: (1) the decode_head model performs forward
+    propagation to obtain the feature maps (2) The ``loss_by_feat`` method
+    is called based on the feature maps to calculate the loss.
+
+    .. code:: text
+
+        loss(): forward() -> loss_by_feat()
+
+    3. The ``predict`` method is used to predict segmentation results,
+    which includes two steps: (1) the decode_head model performs forward
+    propagation to obtain the feature maps (2) The ``predict_by_feat`` method
+    is called based on the feature maps to predict segmentation results
+    including post-processing.
+
+    .. code:: text
+
+        predict(): forward() -> predict_by_feat()
+
+    Args:
+        in_channels (int|Sequence[int]): Input channels.
+        channels (int): Channels after modules, before conv_seg.
+        num_classes (int): Number of classes.
+        out_channels (int): Output channels of conv_seg. Default: None.
+        threshold (float): Threshold for binary segmentation in the case of
+            `num_classes==1`. Default: None.
+        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
+        conv_cfg (dict|None): Config of conv layers. Default: None.
+        norm_cfg (dict|None): Config of norm layers. Default: None.
+        act_cfg (dict): Config of activation layers.
+            Default: dict(type='ReLU')
+        in_index (int|Sequence[int]): Input feature index. Default: -1
+        input_transform (str|None): Transformation type of input features.
+            Options: 'resize_concat', 'multiple_select', None.
+            'resize_concat': Multiple feature maps will be resized to the
+                same size as the first one and then concatenated together.
+                Usually used in FCN head of HRNet.
+            'multiple_select': Multiple feature maps will be bundled into
+                a list and passed into decode head.
+            None: Only one select feature map is allowed.
+            Default: None.
+        loss_decode (dict | Sequence[dict]): Config of decode loss.
+            The `loss_name` is a property of the corresponding loss function
+            which could be shown in training log. If you want this loss
+            item to be included into the backward graph, `loss_` must be the
+            prefix of the name. Defaults to 'loss_ce'.
+            e.g. dict(type='CrossEntropyLoss'),
+            [dict(type='CrossEntropyLoss', loss_name='loss_ce'),
+            dict(type='DiceLoss', loss_name='loss_dice')]
+            Default: dict(type='CrossEntropyLoss').
+        ignore_index (int | None): The label index to be ignored. When using
+            masked BCE loss, ignore_index should be set to None. Default: 255.
+        sampler (dict|None): The config of segmentation map sampler.
+            Default: None.
+        align_corners (bool): align_corners argument of F.interpolate.
+            Default: False.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
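+
+    Example:
+        A minimal sketch of a concrete subclass (``ToyHead`` is hypothetical
+        and only illustrates the contract: implement ``forward``, reuse
+        ``_transform_inputs`` and finish with ``cls_seg``):
+
+        >>> import torch
+        >>> import torch.nn as nn
+        >>> class ToyHead(BaseDecodeHead):
+        ...     def __init__(self, **kwargs):
+        ...         super().__init__(**kwargs)
+        ...         self.conv = nn.Conv2d(self.in_channels, self.channels,
+        ...                               3, padding=1)
+        ...     def forward(self, inputs):
+        ...         x = self._transform_inputs(inputs)
+        ...         return self.cls_seg(self.conv(x))
+        >>> head = ToyHead(in_channels=64, channels=32, num_classes=19)
+        >>> tuple(head([torch.randn(2, 64, 16, 16)]).shape)
+        (2, 19, 16, 16)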
+    """
+
+    def __init__(self,
+                 in_channels,
+                 channels,
+                 *,
+                 num_classes,
+                 out_channels=None,
+                 threshold=None,
+                 dropout_ratio=0.1,
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 act_cfg=dict(type='ReLU'),
+                 in_index=-1,
+                 input_transform=None,
+                 loss_decode=dict(
+                     type='CrossEntropyLoss',
+                     use_sigmoid=False,
+                     loss_weight=1.0),
+                 ignore_index=255,
+                 sampler=None,
+                 align_corners=False,
+                 init_cfg=dict(
+                     type='Normal', std=0.01, override=dict(name='conv_seg'))):
+        super().__init__(init_cfg)
+        self._init_inputs(in_channels, in_index, input_transform)
+        self.channels = channels
+        self.dropout_ratio = dropout_ratio
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self.in_index = in_index
+
+        self.ignore_index = ignore_index
+        self.align_corners = align_corners
+
+        if out_channels is None:
+            if num_classes == 2:
+                warnings.warn('For binary segmentation, we suggest using '
+                              '`out_channels = 1` to define the output '
+                              'channels of segmentor, and use `threshold` '
+                              'to convert `seg_logits` into a prediction '
+                              'applying a threshold')
+            out_channels = num_classes
+
+        if out_channels != num_classes and out_channels != 1:
+            raise ValueError(
+                'out_channels should be equal to num_classes, '
+                'except binary segmentation set out_channels == 1 and '
+                f'num_classes == 2, but got out_channels={out_channels} '
+                f'and num_classes={num_classes}')
+
+        if out_channels == 1 and threshold is None:
+            threshold = 0.3
+            warnings.warn('threshold is not defined for binary, and defaults '
+                          'to 0.3')
+        self.num_classes = num_classes
+        self.out_channels = out_channels
+        self.threshold = threshold
+
+        if isinstance(loss_decode, dict):
+            self.loss_decode = build_loss(loss_decode)
+        elif isinstance(loss_decode, (list, tuple)):
+            self.loss_decode = nn.ModuleList()
+            for loss in loss_decode:
+                self.loss_decode.append(build_loss(loss))
+        else:
+            raise TypeError('loss_decode must be a dict or sequence of dict, '
+                            f'but got {type(loss_decode)}')
+
+        if sampler is not None:
+            self.sampler = build_pixel_sampler(sampler, context=self)
+        else:
+            self.sampler = None
+
+        self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
+        if dropout_ratio > 0:
+            self.dropout = nn.Dropout2d(dropout_ratio)
+        else:
+            self.dropout = None
+
+    def extra_repr(self):
+        """Extra repr."""
+        s = f'input_transform={self.input_transform}, ' \
+            f'ignore_index={self.ignore_index}, ' \
+            f'align_corners={self.align_corners}'
+        return s
+
+    def _init_inputs(self, in_channels, in_index, input_transform):
+        """Check and initialize input transforms.
+
+        The in_channels, in_index and input_transform must match.
+        Specifically, when input_transform is None, only single feature map
+        will be selected. So in_channels and in_index must be of type int.
+        When input_transform is not None, in_channels and in_index must be
+        sequences of the same length.
+
+        Args:
+            in_channels (int|Sequence[int]): Input channels.
+            in_index (int|Sequence[int]): Input feature index.
+            input_transform (str|None): Transformation type of input features.
+                Options: 'resize_concat', 'multiple_select', None.
+                'resize_concat': Multiple feature maps will be resized to the
+                    same size as the first one and then concatenated together.
+                    Usually used in FCN head of HRNet.
+                'multiple_select': Multiple feature maps will be bundled into
+                    a list and passed into decode head.
+                None: Only one select feature map is allowed.
+ """ + + if input_transform is not None: + assert input_transform in ['resize_concat', 'multiple_select'] + self.input_transform = input_transform + self.in_index = in_index + if input_transform is not None: + assert isinstance(in_channels, (list, tuple)) + assert isinstance(in_index, (list, tuple)) + assert len(in_channels) == len(in_index) + if input_transform == 'resize_concat': + self.in_channels = sum(in_channels) + else: + self.in_channels = in_channels + else: + assert isinstance(in_channels, int) + assert isinstance(in_index, int) + self.in_channels = in_channels + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + Tensor: The transformed inputs + """ + + if self.input_transform == 'resize_concat': + inputs = [inputs[i] for i in self.in_index] + upsampled_inputs = [ + resize( + input=x, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == 'multiple_select': + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + @abstractmethod + def forward(self, inputs): + """Placeholder of forward function.""" + pass + + def cls_seg(self, feat): + """Classify each pixel.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.conv_seg(feat) + return output + + def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList, + train_cfg: ConfigType) -> dict: + """Forward function for training. + + Args: + inputs (Tuple[Tensor]): List of multi-level img features. + batch_data_samples (list[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `img_metas` or `gt_semantic_seg`. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + seg_logits = self.forward(inputs) + losses = self.loss_by_feat(seg_logits, batch_data_samples) + return losses + + def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict], + test_cfg: ConfigType) -> Tensor: + """Forward function for prediction. + + Args: + inputs (Tuple[Tensor]): List of multi-level img features. + batch_img_metas (dict): List Image info where each dict may also + contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Outputs segmentation logits map. + """ + seg_logits = self.forward(inputs) + + return self.predict_by_feat(seg_logits, batch_img_metas) + + def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor: + gt_semantic_segs = [ + data_sample.gt_sem_seg.data for data_sample in batch_data_samples + ] + return torch.stack(gt_semantic_segs, dim=0) + + def loss_by_feat(self, seg_logits: Tensor, + batch_data_samples: SampleList) -> dict: + """Compute segmentation loss. + + Args: + seg_logits (Tensor): The output from decode head forward function. + batch_data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_sem_seg`. 
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + seg_label = self._stack_batch_gt(batch_data_samples) + loss = dict() + seg_logits = resize( + input=seg_logits, + size=seg_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + if self.sampler is not None: + seg_weight = self.sampler.sample(seg_logits, seg_label) + else: + seg_weight = None + seg_label = seg_label.squeeze(1) + + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode( + seg_logits, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + else: + loss[loss_decode.loss_name] += loss_decode( + seg_logits, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + + loss['acc_seg'] = accuracy( + seg_logits, seg_label, ignore_index=self.ignore_index) + return loss + + def predict_by_feat(self, seg_logits: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Transform a batch of output seg_logits to the input shape. + + Args: + seg_logits (Tensor): The output from decode head forward function. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + Tensor: Outputs segmentation logits map. + """ + + if isinstance(batch_img_metas[0]['img_shape'], torch.Size): + # slide inference + size = batch_img_metas[0]['img_shape'] + elif 'pad_shape' in batch_img_metas[0]: + size = batch_img_metas[0]['pad_shape'][:2] + else: + size = batch_img_metas[0]['img_shape'] + + seg_logits = resize( + input=seg_logits, + size=size, + mode='bilinear', + align_corners=self.align_corners) + return seg_logits diff --git a/mmseg/models/decode_heads/dm_head.py b/mmseg/models/decode_heads/dm_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7694abd8ac3a470d543c580bd97adceb5b647f7c --- /dev/null +++ b/mmseg/models/decode_heads/dm_head.py @@ -0,0 +1,141 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer + +from mmseg.registry import MODELS +from .decode_head import BaseDecodeHead + + +class DCM(nn.Module): + """Dynamic Convolutional Module used in DMNet. + + Args: + filter_size (int): The filter size of generated convolution kernel + used in Dynamic Convolutional Module. + fusion (bool): Add one conv to fuse DCM output feature. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict | None): Config of conv layers. + norm_cfg (dict | None): Config of norm layers. + act_cfg (dict): Config of activation layers. 
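+
+    Example:
+        An illustrative sketch (all sizes are arbitrary); the dynamic filter
+        is generated per image and applied depthwise:
+
+        >>> import torch
+        >>> dcm = DCM(
+        ...     filter_size=3,
+        ...     fusion=False,
+        ...     in_channels=512,
+        ...     channels=128,
+        ...     conv_cfg=None,
+        ...     norm_cfg=None,
+        ...     act_cfg=dict(type='ReLU'))
+        >>> x = torch.randn(2, 512, 32, 32)
+        >>> tuple(dcm(x).shape)  # spatial size is preserved
+        (2, 128, 32, 32)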
+ """ + + def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg, + norm_cfg, act_cfg): + super().__init__() + self.filter_size = filter_size + self.fusion = fusion + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1, + 0) + + self.input_redu_conv = ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + if self.norm_cfg is not None: + self.norm = build_norm_layer(self.norm_cfg, self.channels)[1] + else: + self.norm = None + self.activate = build_activation_layer(self.act_cfg) + + if self.fusion: + self.fusion_conv = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, x): + """Forward function.""" + generated_filter = self.filter_gen_conv( + F.adaptive_avg_pool2d(x, self.filter_size)) + x = self.input_redu_conv(x) + b, c, h, w = x.shape + # [1, b * c, h, w], c = self.channels + x = x.view(1, b * c, h, w) + # [b * c, 1, filter_size, filter_size] + generated_filter = generated_filter.view(b * c, 1, self.filter_size, + self.filter_size) + pad = (self.filter_size - 1) // 2 + if (self.filter_size - 1) % 2 == 0: + p2d = (pad, pad, pad, pad) + else: + p2d = (pad + 1, pad, pad + 1, pad) + x = F.pad(input=x, pad=p2d, mode='constant', value=0) + # [1, b * c, h, w] + output = F.conv2d(input=x, weight=generated_filter, groups=b * c) + # [b, c, h, w] + output = output.view(b, c, h, w) + if self.norm is not None: + output = self.norm(output) + output = self.activate(output) + + if self.fusion: + output = self.fusion_conv(output) + + return output + + +@MODELS.register_module() +class DMHead(BaseDecodeHead): + """Dynamic Multi-scale Filters for Semantic Segmentation. + + This head is the implementation of + `DMNet `_. + + Args: + filter_sizes (tuple[int]): The size of generated convolutional filters + used in Dynamic Convolutional Module. Default: (1, 3, 5, 7). + fusion (bool): Add one conv to fuse DCM output feature. + """ + + def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs): + super().__init__(**kwargs) + assert isinstance(filter_sizes, (list, tuple)) + self.filter_sizes = filter_sizes + self.fusion = fusion + dcm_modules = [] + for filter_size in self.filter_sizes: + dcm_modules.append( + DCM(filter_size, + self.fusion, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.dcm_modules = nn.ModuleList(dcm_modules) + self.bottleneck = ConvModule( + self.in_channels + len(filter_sizes) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + dcm_outs = [x] + for dcm_module in self.dcm_modules: + dcm_outs.append(dcm_module(x)) + dcm_outs = torch.cat(dcm_outs, dim=1) + output = self.bottleneck(dcm_outs) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/dnl_head.py b/mmseg/models/decode_heads/dnl_head.py new file mode 100644 index 0000000000000000000000000000000000000000..248c11814108d02e88fa7e0cada061b3366e33ff --- /dev/null +++ b/mmseg/models/decode_heads/dnl_head.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch
+from mmcv.cnn import NonLocal2d
+from torch import nn
+
+from mmseg.registry import MODELS
+from .fcn_head import FCNHead
+
+
+class DisentangledNonLocal2d(NonLocal2d):
+    """Disentangled Non-Local Blocks.
+
+    Args:
+        temperature (float): Temperature to adjust attention. Default: 0.05
+    """
+
+    def __init__(self, *args, temperature, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.temperature = temperature
+        self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)
+
+    def embedded_gaussian(self, theta_x, phi_x):
+        """Embedded gaussian with temperature."""
+
+        # NonLocal2d pairwise_weight: [N, HxW, HxW]
+        pairwise_weight = torch.matmul(theta_x, phi_x)
+        if self.use_scale:
+            # theta_x.shape[-1] is `self.inter_channels`
+            pairwise_weight /= torch.tensor(
+                theta_x.shape[-1],
+                dtype=torch.float,
+                device=pairwise_weight.device)**torch.tensor(
+                    0.5, device=pairwise_weight.device)
+        pairwise_weight /= torch.tensor(
+            self.temperature, device=pairwise_weight.device)
+        pairwise_weight = pairwise_weight.softmax(dim=-1)
+        return pairwise_weight
+
+    def forward(self, x):
+        # x: [N, C, H, W]
+        n = x.size(0)
+
+        # g_x: [N, HxW, C]
+        g_x = self.g(x).view(n, self.inter_channels, -1)
+        g_x = g_x.permute(0, 2, 1)
+
+        # theta_x: [N, HxW, C], phi_x: [N, C, HxW]
+        if self.mode == 'gaussian':
+            theta_x = x.view(n, self.in_channels, -1)
+            theta_x = theta_x.permute(0, 2, 1)
+            if self.sub_sample:
+                phi_x = self.phi(x).view(n, self.in_channels, -1)
+            else:
+                phi_x = x.view(n, self.in_channels, -1)
+        elif self.mode == 'concatenation':
+            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
+            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
+        else:
+            theta_x = self.theta(x).view(n, self.inter_channels, -1)
+            theta_x = theta_x.permute(0, 2, 1)
+            phi_x = self.phi(x).view(n, self.inter_channels, -1)
+
+        # subtract mean
+        theta_x -= theta_x.mean(dim=-2, keepdim=True)
+        phi_x -= phi_x.mean(dim=-1, keepdim=True)
+
+        pairwise_func = getattr(self, self.mode)
+        # pairwise_weight: [N, HxW, HxW]
+        pairwise_weight = pairwise_func(theta_x, phi_x)
+
+        # y: [N, HxW, C]
+        y = torch.matmul(pairwise_weight, g_x)
+        # y: [N, C, H, W]
+        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
+                                                    *x.size()[2:])
+
+        # unary_mask: [N, 1, HxW]
+        unary_mask = self.conv_mask(x)
+        unary_mask = unary_mask.view(n, 1, -1)
+        unary_mask = unary_mask.softmax(dim=-1)
+        # unary_x: [N, 1, C]
+        unary_x = torch.matmul(unary_mask, g_x)
+        # unary_x: [N, C, 1, 1]
+        unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
+            n, self.inter_channels, 1, 1)
+
+        output = x + self.conv_out(y + unary_x)
+
+        return output
+
+
+@MODELS.register_module()
+class DNLHead(FCNHead):
+    """Disentangled Non-Local Neural Networks.
+
+    This head is the implementation of `DNLNet
+    <https://arxiv.org/abs/2006.06668>`_.
+
+    Args:
+        reduction (int): Reduction factor of projection transform. Default: 2.
+        use_scale (bool): Whether to scale pairwise_weight by
+            sqrt(1/inter_channels). Default: True.
+        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
+            'dot_product'. Default: 'embedded_gaussian'.
+        temperature (float): Temperature to adjust attention.
Default: 0.05 + """ + + def __init__(self, + reduction=2, + use_scale=True, + mode='embedded_gaussian', + temperature=0.05, + **kwargs): + super().__init__(num_convs=2, **kwargs) + self.reduction = reduction + self.use_scale = use_scale + self.mode = mode + self.temperature = temperature + self.dnl_block = DisentangledNonLocal2d( + in_channels=self.channels, + reduction=self.reduction, + use_scale=self.use_scale, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + mode=self.mode, + temperature=self.temperature) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + output = self.dnl_block(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/dpt_head.py b/mmseg/models/decode_heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d2cfd89daa4df48601e930cfd158dcf3c9a6a837 --- /dev/null +++ b/mmseg/models/decode_heads/dpt_head.py @@ -0,0 +1,294 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, Linear, build_activation_layer +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class ReassembleBlocks(BaseModule): + """ViTPostProcessBlock, process cls_token in ViT backbone output and + rearrange the feature vector to feature map. + + Args: + in_channels (int): ViT feature channels. Default: 768. + out_channels (List): output channels of each stage. + Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + init_cfg (dict, optional): Initialization config dict. Default: None. 
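+
+    Example:
+        A minimal sketch with the default configuration and four ViT stages
+        of 32x32 tokens (e.g. a 512x512 image at patch size 16; the batch
+        size is arbitrary):
+
+        >>> import torch
+        >>> blocks = ReassembleBlocks()
+        >>> inputs = [[torch.randn(1, 768, 32, 32), torch.randn(1, 768)]
+        ...           for _ in range(4)]
+        >>> outs = blocks(inputs)
+        >>> [tuple(o.shape) for o in outs]  # doctest: +NORMALIZE_WHITESPACE
+        [(1, 96, 128, 128), (1, 192, 64, 64),
+         (1, 384, 32, 32), (1, 768, 16, 16)]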
+ """ + + def __init__(self, + in_channels=768, + out_channels=[96, 192, 384, 768], + readout_type='ignore', + patch_size=16, + init_cfg=None): + super().__init__(init_cfg) + + assert readout_type in ['ignore', 'add', 'project'] + self.readout_type = readout_type + self.patch_size = patch_size + + self.projects = nn.ModuleList([ + ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + act_cfg=None, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + if self.readout_type == 'project': + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + Linear(2 * in_channels, in_channels), + build_activation_layer(dict(type='GELU')))) + + def forward(self, inputs): + assert isinstance(inputs, list) + out = [] + for i, x in enumerate(inputs): + assert len(x) == 2 + x, cls_token = x[0], x[1] + feature_shape = x.shape + if self.readout_type == 'project': + x = x.flatten(2).permute((0, 2, 1)) + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + x = x.permute(0, 2, 1).reshape(feature_shape) + elif self.readout_type == 'add': + x = x.flatten(2) + cls_token.unsqueeze(-1) + x = x.reshape(feature_shape) + else: + pass + x = self.projects[i](x) + x = self.resize_layers[i](x) + out.append(x) + return out + + +class PreActResidualConvUnit(BaseModule): + """ResidualConvUnit, pre-activate residual unit. + + Args: + in_channels (int): number of channels in the input feature map. + act_cfg (dict): dictionary to construct and config activation layer. + norm_cfg (dict): dictionary to construct and config norm layer. + stride (int): stride of the first block. Default: 1 + dilation (int): dilation rate for convs layers. Default: 1. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__(self, + in_channels, + act_cfg, + norm_cfg, + stride=1, + dilation=1, + init_cfg=None): + super().__init__(init_cfg) + + self.conv1 = ConvModule( + in_channels, + in_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=False, + order=('act', 'conv', 'norm')) + + self.conv2 = ConvModule( + in_channels, + in_channels, + 3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=False, + order=('act', 'conv', 'norm')) + + def forward(self, inputs): + inputs_ = inputs.clone() + x = self.conv1(inputs) + x = self.conv2(x) + return x + inputs_ + + +class FeatureFusionBlock(BaseModule): + """FeatureFusionBlock, merge feature map from different stages. + + Args: + in_channels (int): Input channels. + act_cfg (dict): The activation config for ResidualConvUnit. + norm_cfg (dict): Config dict for normalization layer. + expand (bool): Whether expand the channels in post process block. + Default: False. + align_corners (bool): align_corner setting for bilinear upsample. + Default: True. + init_cfg (dict, optional): Initialization config dict. Default: None. 
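+
+    Example:
+        An illustrative sketch fusing a skip feature into the running path
+        (all sizes are arbitrary):
+
+        >>> import torch
+        >>> ffb = FeatureFusionBlock(
+        ...     in_channels=256,
+        ...     act_cfg=dict(type='ReLU'),
+        ...     norm_cfg=dict(type='BN'))
+        >>> _ = ffb.eval()  # use running BatchNorm statistics
+        >>> x = torch.randn(1, 256, 16, 16)
+        >>> skip = torch.randn(1, 256, 16, 16)
+        >>> tuple(ffb(x, skip).shape)  # upsampled by 2 and projected
+        (1, 256, 32, 32)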
+ """ + + def __init__(self, + in_channels, + act_cfg, + norm_cfg, + expand=False, + align_corners=True, + init_cfg=None): + super().__init__(init_cfg) + + self.in_channels = in_channels + self.expand = expand + self.align_corners = align_corners + + self.out_channels = in_channels + if self.expand: + self.out_channels = in_channels // 2 + + self.project = ConvModule( + self.in_channels, + self.out_channels, + kernel_size=1, + act_cfg=None, + bias=True) + + self.res_conv_unit1 = PreActResidualConvUnit( + in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + self.res_conv_unit2 = PreActResidualConvUnit( + in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + + def forward(self, *inputs): + x = inputs[0] + if len(inputs) == 2: + if x.shape != inputs[1].shape: + res = resize( + inputs[1], + size=(x.shape[2], x.shape[3]), + mode='bilinear', + align_corners=False) + else: + res = inputs[1] + x = x + self.res_conv_unit1(res) + x = self.res_conv_unit2(x) + x = resize( + x, + scale_factor=2, + mode='bilinear', + align_corners=self.align_corners) + x = self.project(x) + return x + + +@MODELS.register_module() +class DPTHead(BaseDecodeHead): + """Vision Transformers for Dense Prediction. + + This head is implemented of `DPT `_. + + Args: + embed_dims (int): The embed dimension of the ViT backbone. + Default: 768. + post_process_channels (List): Out channels of post process conv + layers. Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + expand_channels (bool): Whether expand the channels in post process + block. Default: False. + act_cfg (dict): The activation config for residual conv unit. + Default dict(type='ReLU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). 
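+
+    Example:
+        A minimal sketch mirroring a typical DPT wiring (four ViT stages of
+        32x32 tokens; all channel counts and sizes are arbitrary):
+
+        >>> import torch
+        >>> head = DPTHead(
+        ...     in_channels=(768, 768, 768, 768),
+        ...     channels=256,
+        ...     num_classes=150,
+        ...     input_transform='multiple_select',
+        ...     in_index=(0, 1, 2, 3))
+        >>> _ = head.eval()  # use running BatchNorm statistics
+        >>> inputs = [[torch.randn(1, 768, 32, 32), torch.randn(1, 768)]
+        ...           for _ in range(4)]
+        >>> tuple(head(inputs).shape)  # 8x the token grid
+        (1, 150, 256, 256)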
+    """
+
+    def __init__(self,
+                 embed_dims=768,
+                 post_process_channels=[96, 192, 384, 768],
+                 readout_type='ignore',
+                 patch_size=16,
+                 expand_channels=False,
+                 act_cfg=dict(type='ReLU'),
+                 norm_cfg=dict(type='BN'),
+                 **kwargs):
+        super().__init__(**kwargs)
+
+        self.expand_channels = expand_channels
+        self.reassemble_blocks = ReassembleBlocks(embed_dims,
+                                                  post_process_channels,
+                                                  readout_type, patch_size)
+
+        self.post_process_channels = [
+            channel * math.pow(2, i) if expand_channels else channel
+            for i, channel in enumerate(post_process_channels)
+        ]
+        self.convs = nn.ModuleList()
+        for channel in self.post_process_channels:
+            self.convs.append(
+                ConvModule(
+                    channel,
+                    self.channels,
+                    kernel_size=3,
+                    padding=1,
+                    act_cfg=None,
+                    bias=False))
+        self.fusion_blocks = nn.ModuleList()
+        for _ in range(len(self.convs)):
+            self.fusion_blocks.append(
+                FeatureFusionBlock(self.channels, act_cfg, norm_cfg))
+        self.fusion_blocks[0].res_conv_unit1 = None
+        self.project = ConvModule(
+            self.channels,
+            self.channels,
+            kernel_size=3,
+            padding=1,
+            norm_cfg=norm_cfg)
+        self.num_fusion_blocks = len(self.fusion_blocks)
+        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
+        self.num_post_process_channels = len(self.post_process_channels)
+        assert self.num_fusion_blocks == self.num_reassemble_blocks
+        assert self.num_reassemble_blocks == self.num_post_process_channels
+
+    def forward(self, inputs):
+        assert len(inputs) == self.num_reassemble_blocks
+        x = self._transform_inputs(inputs)
+        x = self.reassemble_blocks(x)
+        x = [self.convs[i](feature) for i, feature in enumerate(x)]
+        out = self.fusion_blocks[0](x[-1])
+        for i in range(1, len(self.fusion_blocks)):
+            out = self.fusion_blocks[i](out, x[-(i + 1)])
+        out = self.project(out)
+        out = self.cls_seg(out)
+        return out
diff --git a/mmseg/models/decode_heads/ema_head.py b/mmseg/models/decode_heads/ema_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab8dbb0c29b9b533dad962e48d71ae055f20aa07
--- /dev/null
+++ b/mmseg/models/decode_heads/ema_head.py
@@ -0,0 +1,169 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+
+from mmseg.registry import MODELS
+from .decode_head import BaseDecodeHead
+
+
+def reduce_mean(tensor):
+    """Reduce mean when distributed training."""
+    if not (dist.is_available() and dist.is_initialized()):
+        return tensor
+    tensor = tensor.clone()
+    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
+    return tensor
+
+
+class EMAModule(nn.Module):
+    """Expectation Maximization Attention Module used in EMANet.
+
+    Args:
+        channels (int): Channels of the whole module.
+        num_bases (int): Number of bases.
+        num_stages (int): Number of the EM iterations.
+        momentum (float): Momentum to update the bases.
+    """
+
+    def __init__(self, channels, num_bases, num_stages, momentum):
+        super().__init__()
+        assert num_stages >= 1, 'num_stages must be at least 1!'
+        self.num_bases = num_bases
+        self.num_stages = num_stages
+        self.momentum = momentum
+
+        bases = torch.zeros(1, channels, self.num_bases)
+        bases.normal_(0, math.sqrt(2.
/ self.num_bases)) + # [1, channels, num_bases] + bases = F.normalize(bases, dim=1, p=2) + self.register_buffer('bases', bases) + + def forward(self, feats): + """Forward function.""" + batch_size, channels, height, width = feats.size() + # [batch_size, channels, height*width] + feats = feats.view(batch_size, channels, height * width) + # [batch_size, channels, num_bases] + bases = self.bases.repeat(batch_size, 1, 1) + + with torch.no_grad(): + for i in range(self.num_stages): + # [batch_size, height*width, num_bases] + attention = torch.einsum('bcn,bck->bnk', feats, bases) + attention = F.softmax(attention, dim=2) + # l1 norm + attention_normed = F.normalize(attention, dim=1, p=1) + # [batch_size, channels, num_bases] + bases = torch.einsum('bcn,bnk->bck', feats, attention_normed) + # l2 norm + bases = F.normalize(bases, dim=1, p=2) + + feats_recon = torch.einsum('bck,bnk->bcn', bases, attention) + feats_recon = feats_recon.view(batch_size, channels, height, width) + + if self.training: + bases = bases.mean(dim=0, keepdim=True) + bases = reduce_mean(bases) + # l2 norm + bases = F.normalize(bases, dim=1, p=2) + self.bases = (1 - + self.momentum) * self.bases + self.momentum * bases + + return feats_recon + + +@MODELS.register_module() +class EMAHead(BaseDecodeHead): + """Expectation Maximization Attention Networks for Semantic Segmentation. + + This head is the implementation of `EMANet + `_. + + Args: + ema_channels (int): EMA module channels + num_bases (int): Number of bases. + num_stages (int): Number of the EM iterations. + concat_input (bool): Whether concat the input and output of convs + before classification layer. Default: True + momentum (float): Momentum to update the base. Default: 0.1. + """ + + def __init__(self, + ema_channels, + num_bases, + num_stages, + concat_input=True, + momentum=0.1, + **kwargs): + super().__init__(**kwargs) + self.ema_channels = ema_channels + self.num_bases = num_bases + self.num_stages = num_stages + self.concat_input = concat_input + self.momentum = momentum + self.ema_module = EMAModule(self.ema_channels, self.num_bases, + self.num_stages, self.momentum) + + self.ema_in_conv = ConvModule( + self.in_channels, + self.ema_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + # project (0, inf) -> (-inf, inf) + self.ema_mid_conv = ConvModule( + self.ema_channels, + self.ema_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=None, + act_cfg=None) + for param in self.ema_mid_conv.parameters(): + param.requires_grad = False + + self.ema_out_conv = ConvModule( + self.ema_channels, + self.ema_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=None) + self.bottleneck = ConvModule( + self.ema_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if self.concat_input: + self.conv_cat = ConvModule( + self.in_channels + self.channels, + self.channels, + kernel_size=3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + feats = self.ema_in_conv(x) + identity = feats + feats = self.ema_mid_conv(feats) + recon = self.ema_module(feats) + recon = F.relu(recon, inplace=True) + recon = self.ema_out_conv(recon) + output = F.relu(identity + recon, inplace=True) + output = self.bottleneck(output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = 
self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/enc_head.py b/mmseg/models/decode_heads/enc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ef48fb6995365ba374b29ea265608087500f27dc --- /dev/null +++ b/mmseg/models/decode_heads/enc_head.py @@ -0,0 +1,197 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_norm_layer +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.utils import ConfigType, SampleList +from ..builder import build_loss +from ..utils import Encoding, resize +from .decode_head import BaseDecodeHead + + +class EncModule(nn.Module): + """Encoding Module used in EncNet. + + Args: + in_channels (int): Input channels. + num_codes (int): Number of code words. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg): + super().__init__() + self.encoding_project = ConvModule( + in_channels, + in_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + # TODO: resolve this hack + # change to 1d + if norm_cfg is not None: + encoding_norm_cfg = norm_cfg.copy() + if encoding_norm_cfg['type'] in ['BN', 'IN']: + encoding_norm_cfg['type'] += '1d' + else: + encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace( + '2d', '1d') + else: + # fallback to BN1d + encoding_norm_cfg = dict(type='BN1d') + self.encoding = nn.Sequential( + Encoding(channels=in_channels, num_codes=num_codes), + build_norm_layer(encoding_norm_cfg, num_codes)[1], + nn.ReLU(inplace=True)) + self.fc = nn.Sequential( + nn.Linear(in_channels, in_channels), nn.Sigmoid()) + + def forward(self, x): + """Forward function.""" + encoding_projection = self.encoding_project(x) + encoding_feat = self.encoding(encoding_projection).mean(dim=1) + batch_size, channels, _, _ = x.size() + gamma = self.fc(encoding_feat) + y = gamma.view(batch_size, channels, 1, 1) + output = F.relu_(x + x * y) + return encoding_feat, output + + +@MODELS.register_module() +class EncHead(BaseDecodeHead): + """Context Encoding for Semantic Segmentation. + + This head is the implementation of `EncNet + `_. + + Args: + num_codes (int): Number of code words. Default: 32. + use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to + regularize the training. Default: True. + add_lateral (bool): Whether use lateral connection to fuse features. + Default: False. + loss_se_decode (dict): Config of decode loss. + Default: dict(type='CrossEntropyLoss', use_sigmoid=True). 
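+
+    Example (illustrative shape sketch; the channel sizes and class count
+    below are assumptions for the example, not values from this repo):
+
+        >>> import torch
+        >>> head = EncHead(in_channels=[512, 1024, 2048],
+        ...                in_index=[0, 1, 2],
+        ...                channels=512,
+        ...                num_classes=19)
+        >>> feats = [torch.randn(2, c, 64 // 2**i, 64 // 2**i)
+        ...          for i, c in enumerate([512, 1024, 2048])]
+        >>> seg_logit, se_logit = head(feats)
+        >>> seg_logit.shape, se_logit.shape
+        (torch.Size([2, 19, 16, 16]), torch.Size([2, 19]))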
+ """ + + def __init__(self, + num_codes=32, + use_se_loss=True, + add_lateral=False, + loss_se_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=0.2), + **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + self.use_se_loss = use_se_loss + self.add_lateral = add_lateral + self.num_codes = num_codes + self.bottleneck = ConvModule( + self.in_channels[-1], + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if add_lateral: + self.lateral_convs = nn.ModuleList() + for in_channels in self.in_channels[:-1]: # skip the last one + self.lateral_convs.append( + ConvModule( + in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.fusion = ConvModule( + len(self.in_channels) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.enc_module = EncModule( + self.channels, + num_codes=num_codes, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if self.use_se_loss: + self.loss_se_decode = build_loss(loss_se_decode) + self.se_layer = nn.Linear(self.channels, self.num_classes) + + def forward(self, inputs): + """Forward function.""" + inputs = self._transform_inputs(inputs) + feat = self.bottleneck(inputs[-1]) + if self.add_lateral: + laterals = [ + resize( + lateral_conv(inputs[i]), + size=feat.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + feat = self.fusion(torch.cat([feat, *laterals], 1)) + encode_feat, output = self.enc_module(feat) + output = self.cls_seg(output) + if self.use_se_loss: + se_output = self.se_layer(encode_feat) + return output, se_output + else: + return output + + def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict], + test_cfg: ConfigType): + """Forward function for testing, ignore se_loss.""" + if self.use_se_loss: + seg_logits = self.forward(inputs)[0] + else: + seg_logits = self.forward(inputs) + return self.predict_by_feat(seg_logits, batch_img_metas) + + @staticmethod + def _convert_to_onehot_labels(seg_label, num_classes): + """Convert segmentation label to onehot. + + Args: + seg_label (Tensor): Segmentation label of shape (N, H, W). + num_classes (int): Number of classes. + + Returns: + Tensor: Onehot labels of shape (N, num_classes). + """ + + batch_size = seg_label.size(0) + onehot_labels = seg_label.new_zeros((batch_size, num_classes)) + for i in range(batch_size): + hist = seg_label[i].float().histc( + bins=num_classes, min=0, max=num_classes - 1) + onehot_labels[i] = hist > 0 + return onehot_labels + + def loss_by_feat(self, seg_logit: Tuple[Tensor], + batch_data_samples: SampleList, **kwargs) -> dict: + """Compute segmentation and semantic encoding loss.""" + seg_logit, se_seg_logit = seg_logit + loss = dict() + loss.update(super().loss_by_feat(seg_logit, batch_data_samples)) + + seg_label = self._stack_batch_gt(batch_data_samples) + se_loss = self.loss_se_decode( + se_seg_logit, + self._convert_to_onehot_labels(seg_label, self.num_classes)) + loss['loss_se'] = se_loss + return loss diff --git a/mmseg/models/decode_heads/fcn_head.py b/mmseg/models/decode_heads/fcn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..341801888368d307da6b926a2c89f72b6b06476d --- /dev/null +++ b/mmseg/models/decode_heads/fcn_head.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from .decode_head import BaseDecodeHead + + +@MODELS.register_module() +class FCNHead(BaseDecodeHead): + """Fully Convolution Networks for Semantic Segmentation. + + This head is implemented of `FCNNet `_. + + Args: + num_convs (int): Number of convs in the head. Default: 2. + kernel_size (int): The kernel size for convs in the head. Default: 3. + concat_input (bool): Whether concat the input and output of convs + before classification layer. + dilation (int): The dilation rate for convs in the head. Default: 1. + """ + + def __init__(self, + num_convs=2, + kernel_size=3, + concat_input=True, + dilation=1, + **kwargs): + assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int) + self.num_convs = num_convs + self.concat_input = concat_input + self.kernel_size = kernel_size + super().__init__(**kwargs) + if num_convs == 0: + assert self.in_channels == self.channels + + conv_padding = (kernel_size // 2) * dilation + convs = [] + convs.append( + ConvModule( + self.in_channels, + self.channels, + kernel_size=kernel_size, + padding=conv_padding, + dilation=dilation, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + for i in range(num_convs - 1): + convs.append( + ConvModule( + self.channels, + self.channels, + kernel_size=kernel_size, + padding=conv_padding, + dilation=dilation, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + if num_convs == 0: + self.convs = nn.Identity() + else: + self.convs = nn.Sequential(*convs) + if self.concat_input: + self.conv_cat = ConvModule( + self.in_channels + self.channels, + self.channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + x = self._transform_inputs(inputs) + feats = self.convs(x) + if self.concat_input: + feats = self.conv_cat(torch.cat([x, feats], dim=1)) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/fpn_head.py b/mmseg/models/decode_heads/fpn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..25f481fe81c5f4f0aa37903aaf135dc63c930bf8 --- /dev/null +++ b/mmseg/models/decode_heads/fpn_head.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import Upsample, resize +from .decode_head import BaseDecodeHead + + +@MODELS.register_module() +class FPNHead(BaseDecodeHead): + """Panoptic Feature Pyramid Networks. + + This head is the implementation of `Semantic FPN + `_. + + Args: + feature_strides (tuple[int]): The strides for input feature maps. + stack_lateral. All strides suppose to be power of 2. The first + one is of largest resolution. 
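+
+    Example (illustrative; the strides and channel sizes are assumed):
+
+        >>> import torch
+        >>> head = FPNHead(feature_strides=[4, 8, 16, 32],
+        ...                in_channels=[256, 256, 256, 256],
+        ...                in_index=[0, 1, 2, 3],
+        ...                channels=128,
+        ...                num_classes=19)
+        >>> feats = [torch.randn(2, 256, 64 // 2**i, 64 // 2**i)
+        ...          for i in range(4)]
+        >>> head(feats).shape  # logits at the stride-4 resolution
+        torch.Size([2, 19, 64, 64])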
+ """ + + def __init__(self, feature_strides, **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + self.scale_heads = nn.ModuleList() + for i in range(len(feature_strides)): + head_length = max( + 1, + int(np.log2(feature_strides[i]) - np.log2(feature_strides[0]))) + scale_head = [] + for k in range(head_length): + scale_head.append( + ConvModule( + self.in_channels[i] if k == 0 else self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + if feature_strides[i] != feature_strides[0]: + scale_head.append( + Upsample( + scale_factor=2, + mode='bilinear', + align_corners=self.align_corners)) + self.scale_heads.append(nn.Sequential(*scale_head)) + + def forward(self, inputs): + + x = self._transform_inputs(inputs) + + output = self.scale_heads[0](x[0]) + for i in range(1, len(self.feature_strides)): + # non inplace + output = output + resize( + self.scale_heads[i](x[i]), + size=output.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/gc_head.py b/mmseg/models/decode_heads/gc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..14f0ef021c1143d493e17f347f1f4da1145470b8 --- /dev/null +++ b/mmseg/models/decode_heads/gc_head.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcv.cnn import ContextBlock + +from mmseg.registry import MODELS +from .fcn_head import FCNHead + + +@MODELS.register_module() +class GCHead(FCNHead): + """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond. + + This head is the implementation of `GCNet + `_. + + Args: + ratio (float): Multiplier of channels ratio. Default: 1/4. + pooling_type (str): The pooling type of context aggregation. + Options are 'att', 'avg'. Default: 'avg'. + fusion_types (tuple[str]): The fusion type for feature fusion. + Options are 'channel_add', 'channel_mul'. Default: ('channel_add',) + """ + + def __init__(self, + ratio=1 / 4., + pooling_type='att', + fusion_types=('channel_add', ), + **kwargs): + super().__init__(num_convs=2, **kwargs) + self.ratio = ratio + self.pooling_type = pooling_type + self.fusion_types = fusion_types + self.gc_block = ContextBlock( + in_channels=self.channels, + ratio=self.ratio, + pooling_type=self.pooling_type, + fusion_types=self.fusion_types) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + output = self.gc_block(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/ham_head.py b/mmseg/models/decode_heads/ham_head.py new file mode 100644 index 0000000000000000000000000000000000000000..073d8011b05dc8c5e8d48cc8b77484a27f7b2100 --- /dev/null +++ b/mmseg/models/decode_heads/ham_head.py @@ -0,0 +1,255 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# Originally from https://github.com/visual-attention-network/segnext +# Licensed under the Apache License, Version 2.0 (the "License") +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.device import get_device + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class Matrix_Decomposition_2D_Base(nn.Module): + """Base class of 2D Matrix Decomposition. + + Args: + MD_S (int): The number of spatial coefficient in + Matrix Decomposition, it may be used for calculation + of the number of latent dimension D in Matrix + Decomposition. Defaults: 1. + MD_R (int): The number of latent dimension R in + Matrix Decomposition. Defaults: 64. + train_steps (int): The number of iteration steps in + Multiplicative Update (MU) rule to solve Non-negative + Matrix Factorization (NMF) in training. Defaults: 6. + eval_steps (int): The number of iteration steps in + Multiplicative Update (MU) rule to solve Non-negative + Matrix Factorization (NMF) in evaluation. Defaults: 7. + inv_t (int): Inverted multiple number to make coefficient + smaller in softmax. Defaults: 100. + rand_init (bool): Whether to initialize randomly. + Defaults: True. + """ + + def __init__(self, + MD_S=1, + MD_R=64, + train_steps=6, + eval_steps=7, + inv_t=100, + rand_init=True): + super().__init__() + + self.S = MD_S + self.R = MD_R + + self.train_steps = train_steps + self.eval_steps = eval_steps + + self.inv_t = inv_t + + self.rand_init = rand_init + + def _build_bases(self, B, S, D, R, device=None): + raise NotImplementedError + + def local_step(self, x, bases, coef): + raise NotImplementedError + + def local_inference(self, x, bases): + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + coef = torch.bmm(x.transpose(1, 2), bases) + coef = F.softmax(self.inv_t * coef, dim=-1) + + steps = self.train_steps if self.training else self.eval_steps + for _ in range(steps): + bases, coef = self.local_step(x, bases, coef) + + return bases, coef + + def compute_coef(self, x, bases, coef): + raise NotImplementedError + + def forward(self, x, return_bases=False): + """Forward Function.""" + B, C, H, W = x.shape + + # (B, C, H, W) -> (B * S, D, N) + D = C // self.S + N = H * W + x = x.view(B * self.S, D, N) + if not self.rand_init and not hasattr(self, 'bases'): + bases = self._build_bases(1, self.S, D, self.R, device=x.device) + self.register_buffer('bases', bases) + + # (S, D, R) -> (B * S, D, R) + if self.rand_init: + bases = self._build_bases(B, self.S, D, self.R, device=x.device) + else: + bases = self.bases.repeat(B, 1, 1) + + bases, coef = self.local_inference(x, bases) + + # (B * S, N, R) + coef = self.compute_coef(x, bases, coef) + + # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N) + x = torch.bmm(bases, coef.transpose(1, 2)) + + # (B * S, D, N) -> (B, C, H, W) + x = x.view(B, C, H, W) + + return x + + +class NMF2D(Matrix_Decomposition_2D_Base): + """Non-negative Matrix Factorization (NMF) module. + + It is inherited from ``Matrix_Decomposition_2D_Base`` module. 
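+
+    A minimal usage sketch (sizes assumed; NMF expects non-negative input,
+    which the preceding ReLU in ``Hamburger`` guarantees):
+
+        >>> import torch
+        >>> nmf = NMF2D(dict(MD_R=16))
+        >>> x = torch.rand(2, 64, 32, 32)  # non-negative features
+        >>> nmf(x).shape                   # low-rank reconstruction of x
+        torch.Size([2, 64, 32, 32])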
+ """ + + def __init__(self, args=dict()): + super().__init__(**args) + + self.inv_t = 1 + + def _build_bases(self, B, S, D, R, device=None): + """Build bases in initialization.""" + if device is None: + device = get_device() + bases = torch.rand((B * S, D, R)).to(device) + bases = F.normalize(bases, dim=1) + + return bases + + def local_step(self, x, bases, coef): + """Local step in iteration to renew bases and coefficient.""" + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + numerator = torch.bmm(x.transpose(1, 2), bases) + # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R) + denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) + # Multiplicative Update + coef = coef * numerator / (denominator + 1e-6) + + # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R) + numerator = torch.bmm(x, coef) + # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R) + denominator = bases.bmm(coef.transpose(1, 2).bmm(coef)) + # Multiplicative Update + bases = bases * numerator / (denominator + 1e-6) + + return bases, coef + + def compute_coef(self, x, bases, coef): + """Compute coefficient.""" + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + numerator = torch.bmm(x.transpose(1, 2), bases) + # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R) + denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) + # multiplication update + coef = coef * numerator / (denominator + 1e-6) + + return coef + + +class Hamburger(nn.Module): + """Hamburger Module. It consists of one slice of "ham" (matrix + decomposition) and two slices of "bread" (linear transformation). + + Args: + ham_channels (int): Input and output channels of feature. + ham_kwargs (dict): Config of matrix decomposition module. + norm_cfg (dict | None): Config of norm layers. + """ + + def __init__(self, + ham_channels=512, + ham_kwargs=dict(), + norm_cfg=None, + **kwargs): + super().__init__() + + self.ham_in = ConvModule( + ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None) + + self.ham = NMF2D(ham_kwargs) + + self.ham_out = ConvModule( + ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None) + + def forward(self, x): + enjoy = self.ham_in(x) + enjoy = F.relu(enjoy, inplace=True) + enjoy = self.ham(enjoy) + enjoy = self.ham_out(enjoy) + ham = F.relu(x + enjoy, inplace=True) + + return ham + + +@MODELS.register_module() +class LightHamHead(BaseDecodeHead): + """SegNeXt decode head. + + This decode head is the implementation of `SegNeXt: Rethinking + Convolutional Attention Design for Semantic + Segmentation `_. + Inspiration from https://github.com/visual-attention-network/segnext. + + Specifically, LightHamHead is inspired by HamNet from + `Is Attention Better Than Matrix Decomposition? + `. + + Args: + ham_channels (int): input channels for Hamburger. + Defaults: 512. + ham_kwargs (int): kwagrs for Ham. Defaults: dict(). 
+ """ + + def __init__(self, ham_channels=512, ham_kwargs=dict(), **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + self.ham_channels = ham_channels + + self.squeeze = ConvModule( + sum(self.in_channels), + self.ham_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.hamburger = Hamburger(ham_channels, ham_kwargs, **kwargs) + + self.align = ConvModule( + self.ham_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + inputs = self._transform_inputs(inputs) + + inputs = [ + resize( + level, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for level in inputs + ] + + inputs = torch.cat(inputs, dim=1) + # apply a conv block to squeeze feature map + x = self.squeeze(inputs) + # apply hamburger module + x = self.hamburger(x) + + # apply a conv block to align feature map + output = self.align(x) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/isa_head.py b/mmseg/models/decode_heads/isa_head.py new file mode 100644 index 0000000000000000000000000000000000000000..355f215f39007d0153c2fdb3b22a40e7f11a01e3 --- /dev/null +++ b/mmseg/models/decode_heads/isa_head.py @@ -0,0 +1,143 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from .decode_head import BaseDecodeHead + + +class SelfAttentionBlock(_SelfAttentionBlock): + """Self-Attention Module. + + Args: + in_channels (int): Input channels of key/query feature. + channels (int): Output channels of key/query transform. + conv_cfg (dict | None): Config of conv layers. + norm_cfg (dict | None): Config of norm layers. + act_cfg (dict | None): Config of activation layers. + """ + + def __init__(self, in_channels, channels, conv_cfg, norm_cfg, act_cfg): + super().__init__( + key_in_channels=in_channels, + query_in_channels=in_channels, + channels=channels, + out_channels=in_channels, + share_key_query=False, + query_downsample=None, + key_downsample=None, + key_query_num_convs=2, + key_query_norm=True, + value_out_num_convs=1, + value_out_norm=False, + matmul_norm=True, + with_out=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.output_project = self.build_project( + in_channels, + in_channels, + num_convs=1, + use_conv_module=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x): + """Forward function.""" + context = super().forward(x, x) + return self.output_project(context) + + +@MODELS.register_module() +class ISAHead(BaseDecodeHead): + """Interlaced Sparse Self-Attention for Semantic Segmentation. + + This head is the implementation of `ISA + `_. + + Args: + isa_channels (int): The channels of ISA Module. + down_factor (tuple[int]): The local group size of ISA. 
+ """ + + def __init__(self, isa_channels, down_factor=(8, 8), **kwargs): + super().__init__(**kwargs) + self.down_factor = down_factor + + self.in_conv = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.global_relation = SelfAttentionBlock( + self.channels, + isa_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.local_relation = SelfAttentionBlock( + self.channels, + isa_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.out_conv = ConvModule( + self.channels * 2, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x_ = self._transform_inputs(inputs) + x = self.in_conv(x_) + residual = x + + n, c, h, w = x.size() + loc_h, loc_w = self.down_factor # size of local group in H- and W-axes + glb_h, glb_w = math.ceil(h / loc_h), math.ceil(w / loc_w) + pad_h, pad_w = glb_h * loc_h - h, glb_w * loc_w - w + if pad_h > 0 or pad_w > 0: # pad if the size is not divisible + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2) + x = F.pad(x, padding) + + # global relation + x = x.view(n, c, glb_h, loc_h, glb_w, loc_w) + # do permutation to gather global group + x = x.permute(0, 3, 5, 1, 2, 4) # (n, loc_h, loc_w, c, glb_h, glb_w) + x = x.reshape(-1, c, glb_h, glb_w) + # apply attention within each global group + x = self.global_relation(x) # (n * loc_h * loc_w, c, glb_h, glb_w) + + # local relation + x = x.view(n, loc_h, loc_w, c, glb_h, glb_w) + # do permutation to gather local group + x = x.permute(0, 4, 5, 3, 1, 2) # (n, glb_h, glb_w, c, loc_h, loc_w) + x = x.reshape(-1, c, loc_h, loc_w) + # apply attention within each local group + x = self.local_relation(x) # (n * glb_h * glb_w, c, loc_h, loc_w) + + # permute each pixel back to its original position + x = x.view(n, glb_h, glb_w, c, loc_h, loc_w) + x = x.permute(0, 3, 1, 4, 2, 5) # (n, c, glb_h, loc_h, glb_w, loc_w) + x = x.reshape(n, c, glb_h * loc_h, glb_w * loc_w) + if pad_h > 0 or pad_w > 0: # remove padding + x = x[:, :, pad_h // 2:pad_h // 2 + h, pad_w // 2:pad_w // 2 + w] + + x = self.out_conv(torch.cat([x, residual], dim=1)) + out = self.cls_seg(x) + + return out diff --git a/mmseg/models/decode_heads/knet_head.py b/mmseg/models/decode_heads/knet_head.py new file mode 100644 index 0000000000000000000000000000000000000000..82d3a2807685cdc896c881095f46fd50a450018e --- /dev/null +++ b/mmseg/models/decode_heads/knet_head.py @@ -0,0 +1,461 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer +from mmcv.cnn.bricks.transformer import (FFN, MultiheadAttention, + build_transformer_layer) +from mmengine.logging import print_log +from torch import Tensor + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.registry import MODELS +from mmseg.utils import SampleList + + +@MODELS.register_module() +class KernelUpdator(nn.Module): + """Dynamic Kernel Updator in Kernel Update Head. + + Args: + in_channels (int): The number of channels of input feature map. + Default: 256. + feat_channels (int): The number of middle-stage channels in + the kernel updator. Default: 64. + out_channels (int): The number of output channels. 
+ gate_sigmoid (bool): Whether use sigmoid function in gate + mechanism. Default: True. + gate_norm_act (bool): Whether add normalization and activation + layer in gate mechanism. Default: False. + activate_out: Whether add activation after gate mechanism. + Default: False. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='LN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + """ + + def __init__( + self, + in_channels=256, + feat_channels=64, + out_channels=None, + gate_sigmoid=True, + gate_norm_act=False, + activate_out=False, + norm_cfg=dict(type='LN'), + act_cfg=dict(type='ReLU', inplace=True), + ): + super().__init__() + self.in_channels = in_channels + self.feat_channels = feat_channels + self.out_channels_raw = out_channels + self.gate_sigmoid = gate_sigmoid + self.gate_norm_act = gate_norm_act + self.activate_out = activate_out + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.out_channels = out_channels if out_channels else in_channels + + self.num_params_in = self.feat_channels + self.num_params_out = self.feat_channels + self.dynamic_layer = nn.Linear( + self.in_channels, self.num_params_in + self.num_params_out) + self.input_layer = nn.Linear(self.in_channels, + self.num_params_in + self.num_params_out, + 1) + self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1) + self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1) + if self.gate_norm_act: + self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1] + + self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1] + + self.activation = build_activation_layer(act_cfg) + + self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1) + self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1] + + def forward(self, update_feature, input_feature): + """Forward function of KernelUpdator. + + Args: + update_feature (torch.Tensor): Feature map assembled from + each group. It would be reshaped with last dimension + shape: `self.in_channels`. + input_feature (torch.Tensor): Intermediate feature + with shape: (N, num_classes, conv_kernel_size**2, channels). + Returns: + Tensor: The output tensor of shape (N*C1/C2, K*K, C2), where N is + the number of classes, C1 and C2 are the feature map channels of + KernelUpdateHead and KernelUpdator, respectively. 
+ """ + + update_feature = update_feature.reshape(-1, self.in_channels) + num_proposals = update_feature.size(0) + # dynamic_layer works for + # phi_1 and psi_3 in Eq.(4) and (5) of K-Net paper + parameters = self.dynamic_layer(update_feature) + param_in = parameters[:, :self.num_params_in].view( + -1, self.feat_channels) + param_out = parameters[:, -self.num_params_out:].view( + -1, self.feat_channels) + + # input_layer works for + # phi_2 and psi_4 in Eq.(4) and (5) of K-Net paper + input_feats = self.input_layer( + input_feature.reshape(num_proposals, -1, self.feat_channels)) + input_in = input_feats[..., :self.num_params_in] + input_out = input_feats[..., -self.num_params_out:] + + # `gate_feats` is F^G in K-Net paper + gate_feats = input_in * param_in.unsqueeze(-2) + if self.gate_norm_act: + gate_feats = self.activation(self.gate_norm(gate_feats)) + + input_gate = self.input_norm_in(self.input_gate(gate_feats)) + update_gate = self.norm_in(self.update_gate(gate_feats)) + if self.gate_sigmoid: + input_gate = input_gate.sigmoid() + update_gate = update_gate.sigmoid() + param_out = self.norm_out(param_out) + input_out = self.input_norm_out(input_out) + + if self.activate_out: + param_out = self.activation(param_out) + input_out = self.activation(input_out) + + # Gate mechanism. Eq.(5) in original paper. + # param_out has shape (batch_size, feat_channels, out_channels) + features = update_gate * param_out.unsqueeze( + -2) + input_gate * input_out + + features = self.fc_layer(features) + features = self.fc_norm(features) + features = self.activation(features) + + return features + + +@MODELS.register_module() +class KernelUpdateHead(nn.Module): + """Kernel Update Head in K-Net. + + Args: + num_classes (int): Number of classes. Default: 150. + num_ffn_fcs (int): The number of fully-connected layers in + FFNs. Default: 2. + num_heads (int): The number of parallel attention heads. + Default: 8. + num_mask_fcs (int): The number of fully connected layers for + mask prediction. Default: 3. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 2048. + in_channels (int): The number of channels of input feature map. + Default: 256. + out_channels (int): The number of output channels. + Default: 256. + dropout (float): The Probability of an element to be + zeroed in MultiheadAttention and FFN. Default 0.0. + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + ffn_act_cfg (dict): Config of activation layers in FFN. + Default: dict(type='ReLU'). + conv_kernel_size (int): The kernel size of convolution in + Kernel Update Head for dynamic kernel updation. + Default: 1. + feat_transform_cfg (dict | None): Config of feature transform. + Default: None. + kernel_init (bool): Whether initiate mask kernel in mask head. + Default: False. + with_ffn (bool): Whether add FFN in kernel update head. + Default: True. + feat_gather_stride (int): Stride of convolution in feature transform. + Default: 1. + mask_transform_stride (int): Stride of mask transform. + Default: 1. + kernel_updator_cfg (dict): Config of kernel updator. + Default: dict( + type='DynamicConv', + in_channels=256, + feat_channels=64, + out_channels=256, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN')). 
+ """ + + def __init__(self, + num_classes=150, + num_ffn_fcs=2, + num_heads=8, + num_mask_fcs=3, + feedforward_channels=2048, + in_channels=256, + out_channels=256, + dropout=0.0, + act_cfg=dict(type='ReLU', inplace=True), + ffn_act_cfg=dict(type='ReLU', inplace=True), + conv_kernel_size=1, + feat_transform_cfg=None, + kernel_init=False, + with_ffn=True, + feat_gather_stride=1, + mask_transform_stride=1, + kernel_updator_cfg=dict( + type='DynamicConv', + in_channels=256, + feat_channels=64, + out_channels=256, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'))): + super().__init__() + self.num_classes = num_classes + self.in_channels = in_channels + self.out_channels = out_channels + self.fp16_enabled = False + self.dropout = dropout + self.num_heads = num_heads + self.kernel_init = kernel_init + self.with_ffn = with_ffn + self.conv_kernel_size = conv_kernel_size + self.feat_gather_stride = feat_gather_stride + self.mask_transform_stride = mask_transform_stride + + self.attention = MultiheadAttention(in_channels * conv_kernel_size**2, + num_heads, dropout) + self.attention_norm = build_norm_layer( + dict(type='LN'), in_channels * conv_kernel_size**2)[1] + self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg) + + if feat_transform_cfg is not None: + kernel_size = feat_transform_cfg.pop('kernel_size', 1) + transform_channels = in_channels + self.feat_transform = ConvModule( + transform_channels, + in_channels, + kernel_size, + stride=feat_gather_stride, + padding=int(feat_gather_stride // 2), + **feat_transform_cfg) + else: + self.feat_transform = None + + if self.with_ffn: + self.ffn = FFN( + in_channels, + feedforward_channels, + num_ffn_fcs, + act_cfg=ffn_act_cfg, + dropout=dropout) + self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1] + + self.mask_fcs = nn.ModuleList() + for _ in range(num_mask_fcs): + self.mask_fcs.append( + nn.Linear(in_channels, in_channels, bias=False)) + self.mask_fcs.append( + build_norm_layer(dict(type='LN'), in_channels)[1]) + self.mask_fcs.append(build_activation_layer(act_cfg)) + + self.fc_mask = nn.Linear(in_channels, out_channels) + + def init_weights(self): + """Use xavier initialization for all weight parameter and set + classification head bias as a specific value when use focal loss.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + else: + # adopt the default initialization for + # the weight and bias of the layer norm + pass + if self.kernel_init: + print_log( + 'mask kernel in mask head is normal initialized by std 0.01') + nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01) + + def forward(self, x, proposal_feat, mask_preds, mask_shape=None): + """Forward function of Dynamic Instance Interactive Head. + + Args: + x (Tensor): Feature map from FPN with shape + (batch_size, feature_dimensions, H , W). + proposal_feat (Tensor): Intermediate feature get from + diihead in last stage, has shape + (batch_size, num_proposals, feature_dimensions) + mask_preds (Tensor): mask prediction from the former stage in shape + (batch_size, num_proposals, H, W). + + Returns: + Tuple: The first tensor is predicted mask with shape + (N, num_classes, H, W), the second tensor is dynamic kernel + with shape (N, num_classes, channels, K, K). 
+ """ + N, num_proposals = proposal_feat.shape[:2] + if self.feat_transform is not None: + x = self.feat_transform(x) + + C, H, W = x.shape[-3:] + + mask_h, mask_w = mask_preds.shape[-2:] + if mask_h != H or mask_w != W: + gather_mask = F.interpolate( + mask_preds, (H, W), align_corners=False, mode='bilinear') + else: + gather_mask = mask_preds + + sigmoid_masks = gather_mask.softmax(dim=1) + + # Group Feature Assembling. Eq.(3) in original paper. + # einsum is faster than bmm by 30% + x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x) + + # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C] + proposal_feat = proposal_feat.reshape(N, num_proposals, + self.in_channels, + -1).permute(0, 1, 3, 2) + obj_feat = self.kernel_update_conv(x_feat, proposal_feat) + + # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C] + obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2) + obj_feat = self.attention_norm(self.attention(obj_feat)) + # [N, B, K*K*C] -> [B, N, K*K*C] + obj_feat = obj_feat.permute(1, 0, 2) + + # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C] + obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels) + + # FFN + if self.with_ffn: + obj_feat = self.ffn_norm(self.ffn(obj_feat)) + + mask_feat = obj_feat + + for reg_layer in self.mask_fcs: + mask_feat = reg_layer(mask_feat) + + # [B, N, K*K, C] -> [B, N, C, K*K] + mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2) + + if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1): + mask_x = F.interpolate( + x, scale_factor=0.5, mode='bilinear', align_corners=False) + H, W = mask_x.shape[-2:] + else: + mask_x = x + # group conv is 5x faster than unfold and uses about 1/5 memory + # Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms + # Group conv vs. unfold vs. concat batch, 278 : 1420 : 369 + # but in real training group conv is slower than concat batch + # so we keep using concat batch. + # fold_x = F.unfold( + # mask_x, + # self.conv_kernel_size, + # padding=int(self.conv_kernel_size // 2)) + # mask_feat = mask_feat.reshape(N, num_proposals, -1) + # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x) + # [B, N, C, K*K] -> [B*N, C, K, K] + mask_feat = mask_feat.reshape(N, num_proposals, C, + self.conv_kernel_size, + self.conv_kernel_size) + # [B, C, H, W] -> [1, B*C, H, W] + new_mask_preds = [] + for i in range(N): + new_mask_preds.append( + F.conv2d( + mask_x[i:i + 1], + mask_feat[i], + padding=int(self.conv_kernel_size // 2))) + + new_mask_preds = torch.cat(new_mask_preds, dim=0) + new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W) + if self.mask_transform_stride == 2: + new_mask_preds = F.interpolate( + new_mask_preds, + scale_factor=2, + mode='bilinear', + align_corners=False) + + if mask_shape is not None and mask_shape[0] != H: + new_mask_preds = F.interpolate( + new_mask_preds, + mask_shape, + align_corners=False, + mode='bilinear') + + return new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape( + N, num_proposals, self.in_channels, self.conv_kernel_size, + self.conv_kernel_size) + + +@MODELS.register_module() +class IterativeDecodeHead(BaseDecodeHead): + """K-Net: Towards Unified Image Segmentation. + + This head is the implementation of + `K-Net: `_. + + Args: + num_stages (int): The number of stages (kernel update heads) + in IterativeDecodeHead. Default: 3. + kernel_generate_head:(dict): Config of kernel generate head which + generate mask predictions, dynamic kernels and class predictions + for next kernel update heads. 
+ kernel_update_head (dict): Config of kernel update head which refine + dynamic kernels and class predictions iteratively. + + """ + + def __init__(self, num_stages, kernel_generate_head, kernel_update_head, + **kwargs): + # ``IterativeDecodeHead`` would skip initialization of + # ``BaseDecodeHead`` which would be called when building + # ``self.kernel_generate_head``. + super(BaseDecodeHead, self).__init__(**kwargs) + assert num_stages == len(kernel_update_head) + self.num_stages = num_stages + self.kernel_generate_head = MODELS.build(kernel_generate_head) + self.kernel_update_head = nn.ModuleList() + self.align_corners = self.kernel_generate_head.align_corners + self.num_classes = self.kernel_generate_head.num_classes + self.input_transform = self.kernel_generate_head.input_transform + self.ignore_index = self.kernel_generate_head.ignore_index + self.out_channels = self.num_classes + + for head_cfg in kernel_update_head: + self.kernel_update_head.append(MODELS.build(head_cfg)) + + def forward(self, inputs): + """Forward function.""" + feats = self.kernel_generate_head._forward_feature(inputs) + sem_seg = self.kernel_generate_head.cls_seg(feats) + seg_kernels = self.kernel_generate_head.conv_seg.weight.clone() + seg_kernels = seg_kernels[None].expand( + feats.size(0), *seg_kernels.size()) + + stage_segs = [sem_seg] + for i in range(self.num_stages): + sem_seg, seg_kernels = self.kernel_update_head[i](feats, + seg_kernels, + sem_seg) + stage_segs.append(sem_seg) + if self.training: + return stage_segs + # only return the prediction of the last stage during testing + return stage_segs[-1] + + def loss_by_feat(self, seg_logits: List[Tensor], + batch_data_samples: SampleList, **kwargs) -> dict: + losses = dict() + for i, logit in enumerate(seg_logits): + loss = self.kernel_generate_head.loss_by_feat( + logit, batch_data_samples) + for k, v in loss.items(): + losses[f'{k}.s{i}'] = v + + return losses diff --git a/mmseg/models/decode_heads/lraspp_head.py b/mmseg/models/decode_heads/lraspp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ba2465f27522e6ff106fcdf94a46aab42881260a --- /dev/null +++ b/mmseg/models/decode_heads/lraspp_head.py @@ -0,0 +1,91 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.utils import is_tuple_of + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +@MODELS.register_module() +class LRASPPHead(BaseDecodeHead): + """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3. + + This head is the improved implementation of `Searching for MobileNetV3 + `_. + + Args: + branch_channels (tuple[int]): The number of output channels in every + each branch. Default: (32, 64). + """ + + def __init__(self, branch_channels=(32, 64), **kwargs): + super().__init__(**kwargs) + if self.input_transform != 'multiple_select': + raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform ' + f'must be \'multiple_select\'. 
But received ' + f'\'{self.input_transform}\'') + assert is_tuple_of(branch_channels, int) + assert len(branch_channels) == len(self.in_channels) - 1 + self.branch_channels = branch_channels + + self.convs = nn.Sequential() + self.conv_ups = nn.Sequential() + for i in range(len(branch_channels)): + self.convs.add_module( + f'conv{i}', + nn.Conv2d( + self.in_channels[i], branch_channels[i], 1, bias=False)) + self.conv_ups.add_module( + f'conv_up{i}', + ConvModule( + self.channels + branch_channels[i], + self.channels, + 1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + bias=False)) + + self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1) + + self.aspp_conv = ConvModule( + self.in_channels[-1], + self.channels, + 1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + bias=False) + self.image_pool = nn.Sequential( + nn.AvgPool2d(kernel_size=49, stride=(16, 20)), + ConvModule( + self.in_channels[2], + self.channels, + 1, + act_cfg=dict(type='Sigmoid'), + bias=False)) + + def forward(self, inputs): + """Forward function.""" + inputs = self._transform_inputs(inputs) + + x = inputs[-1] + + x = self.aspp_conv(x) * resize( + self.image_pool(x), + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + x = self.conv_up_input(x) + + for i in range(len(self.branch_channels) - 1, -1, -1): + x = resize( + x, + size=inputs[i].size()[2:], + mode='bilinear', + align_corners=self.align_corners) + x = torch.cat([x, self.convs[i](inputs[i])], 1) + x = self.conv_ups[i](x) + + return self.cls_seg(x) diff --git a/mmseg/models/decode_heads/mask2former_head.py b/mmseg/models/decode_heads/mask2former_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0135af0645830f5cf98595318c4bb20220e64b0b --- /dev/null +++ b/mmseg/models/decode_heads/mask2former_head.py @@ -0,0 +1,163 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +try: + from mmdet.models.dense_heads import \ + Mask2FormerHead as MMDET_Mask2FormerHead +except ModuleNotFoundError: + MMDET_Mask2FormerHead = BaseModule + +from mmengine.structures import InstanceData +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.structures.seg_data_sample import SegDataSample +from mmseg.utils import ConfigType, SampleList + + +@MODELS.register_module() +class Mask2FormerHead(MMDET_Mask2FormerHead): + """Implements the Mask2Former head. + + See `Mask2Former: Masked-attention Mask Transformer for Universal Image + Segmentation `_ for details. + + Args: + num_classes (int): Number of classes. Default: 150. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + ignore_index (int): The label index to be ignored. Default: 255. + """ + + def __init__(self, + num_classes, + align_corners=False, + ignore_index=255, + **kwargs): + super().__init__(**kwargs) + + self.num_classes = num_classes + self.align_corners = align_corners + self.out_channels = num_classes + self.ignore_index = ignore_index + + feat_channels = kwargs['feat_channels'] + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + + def _seg_data_to_instance_data(self, batch_data_samples: SampleList): + """Perform forward propagation to convert paradigm from MMSegmentation + to MMDetection to ensure ``MMDET_Mask2FormerHead`` could be called + normally. Specifically, ``batch_gt_instances`` would be added. 
+ + Args: + batch_data_samples (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + + Returns: + tuple[Tensor]: A tuple contains two lists. + + - batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``labels``, each is + unique ground truth label id of images, with + shape (num_gt, ) and ``masks``, each is ground truth + masks of each instances of a image, shape (num_gt, h, w). + - batch_img_metas (list[dict]): List of image meta information. + """ + batch_img_metas = [] + batch_gt_instances = [] + + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + gt_sem_seg = data_sample.gt_sem_seg.data + classes = torch.unique( + gt_sem_seg, + sorted=False, + return_inverse=False, + return_counts=False) + + # remove ignored region + gt_labels = classes[classes != self.ignore_index] + + masks = [] + for class_id in gt_labels: + masks.append(gt_sem_seg == class_id) + + if len(masks) == 0: + gt_masks = torch.zeros( + (0, gt_sem_seg.shape[-2], + gt_sem_seg.shape[-1])).to(gt_sem_seg).long() + else: + gt_masks = torch.stack(masks).squeeze(1).long() + + instance_data = InstanceData(labels=gt_labels, masks=gt_masks) + batch_gt_instances.append(instance_data) + return batch_gt_instances, batch_img_metas + + def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList, + train_cfg: ConfigType) -> dict: + """Perform forward propagation and loss calculation of the decoder head + on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the upstream + network, each is a 4D-tensor. + batch_data_samples (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + train_cfg (ConfigType): Training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components. + """ + # batch SegDataSample to InstanceDataSample + batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data( + batch_data_samples) + + # forward + all_cls_scores, all_mask_preds = self(x, batch_data_samples) + + # loss + losses = self.loss_by_feat(all_cls_scores, all_mask_preds, + batch_gt_instances, batch_img_metas) + + return losses + + def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict], + test_cfg: ConfigType) -> Tuple[Tensor]: + """Test without augmentaton. + + Args: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_img_metas (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + test_cfg (ConfigType): Test config. + + Returns: + Tensor: A tensor of segmentation mask. 
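+
+        The semantic map comes from fusing per-query class scores with
+        per-query masks; a sketch of that last einsum (sizes assumed):
+
+            >>> import torch
+            >>> cls_score = torch.rand(2, 100, 150)     # (B, queries, classes)
+            >>> mask_pred = torch.rand(2, 100, 64, 64)  # (B, queries, H, W)
+            >>> torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred).shape
+            torch.Size([2, 150, 64, 64])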
+ """ + batch_data_samples = [ + SegDataSample(metainfo=metainfo) for metainfo in batch_img_metas + ] + + all_cls_scores, all_mask_preds = self(x, batch_data_samples) + mask_cls_results = all_cls_scores[-1] + mask_pred_results = all_mask_preds[-1] + if 'pad_shape' in batch_img_metas[0]: + size = batch_img_metas[0]['pad_shape'] + else: + size = batch_img_metas[0]['img_shape'] + # upsample mask + mask_pred_results = F.interpolate( + mask_pred_results, size=size, mode='bilinear', align_corners=False) + cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1] + mask_pred = mask_pred_results.sigmoid() + seg_logits = torch.einsum('bqc, bqhw->bchw', cls_score, mask_pred) + return seg_logits diff --git a/mmseg/models/decode_heads/maskformer_head.py b/mmseg/models/decode_heads/maskformer_head.py new file mode 100644 index 0000000000000000000000000000000000000000..6e61a7f63a33a508955a866e57c139ce8c40e0f6 --- /dev/null +++ b/mmseg/models/decode_heads/maskformer_head.py @@ -0,0 +1,174 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +try: + from mmdet.models.dense_heads import MaskFormerHead as MMDET_MaskFormerHead +except ModuleNotFoundError: + MMDET_MaskFormerHead = BaseModule + +from mmengine.structures import InstanceData +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.structures.seg_data_sample import SegDataSample +from mmseg.utils import ConfigType, SampleList + + +@MODELS.register_module() +class MaskFormerHead(MMDET_MaskFormerHead): + """Implements the MaskFormer head. + + See `Per-Pixel Classification is Not All You Need for Semantic Segmentation + `_ for details. + + Args: + num_classes (int): Number of classes. Default: 150. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + ignore_index (int): The label index to be ignored. Default: 255. + """ + + def __init__(self, + num_classes: int = 150, + align_corners: bool = False, + ignore_index: int = 255, + **kwargs) -> None: + super().__init__(**kwargs) + + self.out_channels = kwargs['out_channels'] + self.align_corners = True + self.num_classes = num_classes + self.align_corners = align_corners + self.out_channels = num_classes + self.ignore_index = ignore_index + + feat_channels = kwargs['feat_channels'] + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + + def _seg_data_to_instance_data(self, batch_data_samples: SampleList): + """Perform forward propagation to convert paradigm from MMSegmentation + to MMDetection to ensure ``MMDET_MaskFormerHead`` could be called + normally. Specifically, ``batch_gt_instances`` would be added. + + Args: + batch_data_samples (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + + Returns: + tuple[Tensor]: A tuple contains two lists. + + - batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``labels``, each is + unique ground truth label id of images, with + shape (num_gt, ) and ``masks``, each is ground truth + masks of each instances of a image, shape (num_gt, h, w). + - batch_img_metas (list[dict]): List of image meta information. + """ + batch_img_metas = [] + batch_gt_instances = [] + for data_sample in batch_data_samples: + # Add `batch_input_shape` in metainfo of data_sample, which would + # be used in MaskFormerHead of MMDetection. 
+ metainfo = data_sample.metainfo + metainfo['batch_input_shape'] = metainfo['img_shape'] + data_sample.set_metainfo(metainfo) + batch_img_metas.append(data_sample.metainfo) + gt_sem_seg = data_sample.gt_sem_seg.data + classes = torch.unique( + gt_sem_seg, + sorted=False, + return_inverse=False, + return_counts=False) + + # remove ignored region + gt_labels = classes[classes != self.ignore_index] + + masks = [] + for class_id in gt_labels: + masks.append(gt_sem_seg == class_id) + + if len(masks) == 0: + gt_masks = torch.zeros((0, gt_sem_seg.shape[-2], + gt_sem_seg.shape[-1])).to(gt_sem_seg) + else: + gt_masks = torch.stack(masks).squeeze(1) + + instance_data = InstanceData( + labels=gt_labels, masks=gt_masks.long()) + batch_gt_instances.append(instance_data) + return batch_gt_instances, batch_img_metas + + def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList, + train_cfg: ConfigType) -> dict: + """Perform forward propagation and loss calculation of the decoder head + on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the upstream + network, each is a 4D-tensor. + batch_data_samples (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + train_cfg (ConfigType): Training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components. + """ + # batch SegDataSample to InstanceDataSample + batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data( + batch_data_samples) + + # forward + all_cls_scores, all_mask_preds = self(x, batch_data_samples) + + # loss + losses = self.loss_by_feat(all_cls_scores, all_mask_preds, + batch_gt_instances, batch_img_metas) + + return losses + + def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict], + test_cfg: ConfigType) -> Tuple[Tensor]: + """Test without augmentaton. + + Args: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_img_metas (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + test_cfg (ConfigType): Test config. + + Returns: + Tensor: A tensor of segmentation mask. + """ + + batch_data_samples = [] + for metainfo in batch_img_metas: + metainfo['batch_input_shape'] = metainfo['img_shape'] + batch_data_samples.append(SegDataSample(metainfo=metainfo)) + # Forward function of MaskFormerHead from MMDetection needs + # 'batch_data_samples' as inputs, which is image shape actually. + all_cls_scores, all_mask_preds = self(x, batch_data_samples) + mask_cls_results = all_cls_scores[-1] + mask_pred_results = all_mask_preds[-1] + + # upsample masks + img_shape = batch_img_metas[0]['batch_input_shape'] + mask_pred_results = F.interpolate( + mask_pred_results, + size=img_shape, + mode='bilinear', + align_corners=False) + + # semantic inference + cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1] + mask_pred = mask_pred_results.sigmoid() + seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred) + return seg_logits diff --git a/mmseg/models/decode_heads/nl_head.py b/mmseg/models/decode_heads/nl_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0ffcc2a2f081127f109deb0ad5bd1be0d6f50493 --- /dev/null +++ b/mmseg/models/decode_heads/nl_head.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
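+# Usage sketch (illustrative only; sizes are assumed): NLHead drops a
+# NonLocal2d block between the two convs it inherits from FCNHead, e.g.
+#
+#     head = NLHead(in_channels=2048, channels=512, num_classes=19)
+#     out = head([torch.randn(2, 2048, 32, 32)])  # -> (2, 19, 32, 32)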
+import torch
+from mmcv.cnn import NonLocal2d
+
+from mmseg.registry import MODELS
+from .fcn_head import FCNHead
+
+
+@MODELS.register_module()
+class NLHead(FCNHead):
+    """Non-local Neural Networks.
+
+    This head is the implementation of `NLNet
+    <https://arxiv.org/abs/1711.07971>`_.
+
+    Args:
+        reduction (int): Reduction factor of projection transform. Default: 2.
+        use_scale (bool): Whether to scale pairwise_weight by
+            sqrt(1/inter_channels). Default: True.
+        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
+            'dot_product'. Default: 'embedded_gaussian'.
+    """
+
+    def __init__(self,
+                 reduction=2,
+                 use_scale=True,
+                 mode='embedded_gaussian',
+                 **kwargs):
+        super().__init__(num_convs=2, **kwargs)
+        self.reduction = reduction
+        self.use_scale = use_scale
+        self.mode = mode
+        self.nl_block = NonLocal2d(
+            in_channels=self.channels,
+            reduction=self.reduction,
+            use_scale=self.use_scale,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            mode=self.mode)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        output = self.convs[0](x)
+        output = self.nl_block(output)
+        output = self.convs[1](output)
+        if self.concat_input:
+            output = self.conv_cat(torch.cat([x, output], dim=1))
+        output = self.cls_seg(output)
+        return output
diff --git a/mmseg/models/decode_heads/ocr_head.py b/mmseg/models/decode_heads/ocr_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..9afe37bebd6c16ff184dc482ae358eb7ae9a093a
--- /dev/null
+++ b/mmseg/models/decode_heads/ocr_head.py
@@ -0,0 +1,127 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+
+from mmseg.registry import MODELS
+from ..utils import SelfAttentionBlock as _SelfAttentionBlock
+from ..utils import resize
+from .cascade_decode_head import BaseCascadeDecodeHead
+
+
+class SpatialGatherModule(nn.Module):
+    """Aggregate the context features according to the initial predicted
+    probability distribution.
+
+    Employ the soft-weighted method to aggregate the context.
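+
+    Example (shape sketch; sizes assumed):
+
+        >>> import torch
+        >>> gather = SpatialGatherModule(scale=1)
+        >>> feats = torch.randn(2, 64, 16, 16)  # pixel features
+        >>> probs = torch.randn(2, 19, 16, 16)  # coarse per-class logits
+        >>> gather(feats, probs).shape          # one feature vector per class
+        torch.Size([2, 64, 19, 1])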
+ """ + + def __init__(self, scale): + super().__init__() + self.scale = scale + + def forward(self, feats, probs): + """Forward function.""" + batch_size, num_classes, height, width = probs.size() + channels = feats.size(1) + probs = probs.view(batch_size, num_classes, -1) + feats = feats.view(batch_size, channels, -1) + # [batch_size, height*width, num_classes] + feats = feats.permute(0, 2, 1) + # [batch_size, channels, height*width] + probs = F.softmax(self.scale * probs, dim=2) + # [batch_size, channels, num_classes] + ocr_context = torch.matmul(probs, feats) + ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3) + return ocr_context + + +class ObjectAttentionBlock(_SelfAttentionBlock): + """Make a OCR used SelfAttentionBlock.""" + + def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg, + act_cfg): + if scale > 1: + query_downsample = nn.MaxPool2d(kernel_size=scale) + else: + query_downsample = None + super().__init__( + key_in_channels=in_channels, + query_in_channels=in_channels, + channels=channels, + out_channels=in_channels, + share_key_query=False, + query_downsample=query_downsample, + key_downsample=None, + key_query_num_convs=2, + key_query_norm=True, + value_out_num_convs=1, + value_out_norm=True, + matmul_norm=True, + with_out=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.bottleneck = ConvModule( + in_channels * 2, + in_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, query_feats, key_feats): + """Forward function.""" + context = super().forward(query_feats, key_feats) + output = self.bottleneck(torch.cat([context, query_feats], dim=1)) + if self.query_downsample is not None: + output = resize(query_feats) + + return output + + +@MODELS.register_module() +class OCRHead(BaseCascadeDecodeHead): + """Object-Contextual Representations for Semantic Segmentation. + + This head is the implementation of `OCRNet + `_. + + Args: + ocr_channels (int): The intermediate channels of OCR block. + scale (int): The scale of probability map in SpatialGatherModule in + Default: 1. + """ + + def __init__(self, ocr_channels, scale=1, **kwargs): + super().__init__(**kwargs) + self.ocr_channels = ocr_channels + self.scale = scale + self.object_context_block = ObjectAttentionBlock( + self.channels, + self.ocr_channels, + self.scale, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.spatial_gather_module = SpatialGatherModule(self.scale) + + self.bottleneck = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs, prev_output): + """Forward function.""" + x = self._transform_inputs(inputs) + feats = self.bottleneck(x) + context = self.spatial_gather_module(feats, prev_output) + object_context = self.object_context_block(feats, context) + output = self.cls_seg(object_context) + + return output diff --git a/mmseg/models/decode_heads/pid_head.py b/mmseg/models/decode_heads/pid_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c092cb32d07c279c1d6a45d2e02baccb8e5ffa33 --- /dev/null +++ b/mmseg/models/decode_heads/pid_head.py @@ -0,0 +1,183 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer +from mmengine.model import BaseModule +from torch import Tensor + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.models.losses import accuracy +from mmseg.models.utils import resize +from mmseg.registry import MODELS +from mmseg.utils import OptConfigType, SampleList + + +class BasePIDHead(BaseModule): + """Base class for PID head. + + Args: + in_channels (int): Number of input channels. + channels (int): Number of output channels. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU', inplace=True). + init_cfg (dict or list[dict], optional): Init config dict. + Default: None. + """ + + def __init__(self, + in_channels: int, + channels: int, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + self.conv = ConvModule( + in_channels, + channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + order=('norm', 'act', 'conv')) + _, self.norm = build_norm_layer(norm_cfg, num_features=channels) + self.act = build_activation_layer(act_cfg) + + def forward(self, x: Tensor, cls_seg: Optional[nn.Module]) -> Tensor: + """Forward function. + Args: + x (Tensor): Input tensor. + cls_seg (nn.Module, optional): The classification head. + + Returns: + Tensor: Output tensor. + """ + x = self.conv(x) + x = self.norm(x) + x = self.act(x) + if cls_seg is not None: + x = cls_seg(x) + return x + + +@MODELS.register_module() +class PIDHead(BaseDecodeHead): + """Decode head for PIDNet. + + Args: + in_channels (int): Number of input channels. + channels (int): Number of output channels. + num_classes (int): Number of classes. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU', inplace=True). + """ + + def __init__(self, + in_channels: int, + channels: int, + num_classes: int, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + **kwargs): + super().__init__( + in_channels, + channels, + num_classes=num_classes, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs) + self.i_head = BasePIDHead(in_channels, channels, norm_cfg, act_cfg) + self.p_head = BasePIDHead(in_channels // 2, channels, norm_cfg, + act_cfg) + self.d_head = BasePIDHead( + in_channels // 2, + in_channels // 4, + norm_cfg, + ) + self.p_cls_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1) + self.d_cls_seg = nn.Conv2d(in_channels // 4, 1, kernel_size=1) + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward( + self, + inputs: Union[Tensor, + Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]: + """Forward function. + Args: + inputs (Tensor | tuple[Tensor]): Input tensor or tuple of + Tensor. When training, the input is a tuple of three tensors, + (p_feat, i_feat, d_feat), and the output is a tuple of three + tensors, (p_seg_logit, i_seg_logit, d_seg_logit). 
+ When inference, only the head of integral branch is used, and + input is a tensor of integral feature map, and the output is + the segmentation logit. + + Returns: + Tensor | tuple[Tensor]: Output tensor or tuple of tensors. + """ + if self.training: + x_p, x_i, x_d = inputs + x_p = self.p_head(x_p, self.p_cls_seg) + x_i = self.i_head(x_i, self.cls_seg) + x_d = self.d_head(x_d, self.d_cls_seg) + return x_p, x_i, x_d + else: + return self.i_head(inputs, self.cls_seg) + + def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tuple[Tensor]: + gt_semantic_segs = [ + data_sample.gt_sem_seg.data for data_sample in batch_data_samples + ] + gt_edge_segs = [ + data_sample.gt_edge_map.data for data_sample in batch_data_samples + ] + gt_sem_segs = torch.stack(gt_semantic_segs, dim=0) + gt_edge_segs = torch.stack(gt_edge_segs, dim=0) + return gt_sem_segs, gt_edge_segs + + def loss_by_feat(self, seg_logits: Tuple[Tensor], + batch_data_samples: SampleList) -> dict: + loss = dict() + p_logit, i_logit, d_logit = seg_logits + sem_label, bd_label = self._stack_batch_gt(batch_data_samples) + p_logit = resize( + input=p_logit, + size=sem_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + i_logit = resize( + input=i_logit, + size=sem_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + d_logit = resize( + input=d_logit, + size=bd_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + sem_label = sem_label.squeeze(1) + bd_label = bd_label.squeeze(1) + loss['loss_sem_p'] = self.loss_decode[0]( + p_logit, sem_label, ignore_index=self.ignore_index) + loss['loss_sem_i'] = self.loss_decode[1](i_logit, sem_label) + loss['loss_bd'] = self.loss_decode[2](d_logit, bd_label) + filler = torch.ones_like(sem_label) * self.ignore_index + sem_bd_label = torch.where( + torch.sigmoid(d_logit[:, 0, :, :]) > 0.8, sem_label, filler) + loss['loss_sem_bd'] = self.loss_decode[3](i_logit, sem_bd_label) + loss['acc_seg'] = accuracy( + i_logit, sem_label, ignore_index=self.ignore_index) + return loss diff --git a/mmseg/models/decode_heads/point_head.py b/mmseg/models/decode_heads/point_head.py new file mode 100644 index 0000000000000000000000000000000000000000..e8e433d66249a4690cea3e33e95ec54d58ee3a07 --- /dev/null +++ b/mmseg/models/decode_heads/point_head.py @@ -0,0 +1,367 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +try: + from mmcv.ops import point_sample +except ModuleNotFoundError: + point_sample = None + +from typing import List + +from mmseg.registry import MODELS +from mmseg.utils import SampleList +from ..losses import accuracy +from ..utils import resize +from .cascade_decode_head import BaseCascadeDecodeHead + + +def calculate_uncertainty(seg_logits): + """Estimate uncertainty based on seg logits. + + For each location of the prediction ``seg_logits`` we estimate + uncertainty as the difference between top first and top second + predicted logits. + + Args: + seg_logits (Tensor): Semantic segmentation logits, + shape (batch_size, num_classes, height, width). 
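+
+    Note (worked example with made-up numbers): a location whose top-2
+    logits are (2.1, 1.9) scores -0.2, while a clear-cut location with
+    top-2 logits (5.0, 1.0) scores -4.0; scores are always <= 0 and the
+    closer to zero, the more uncertain the location.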
+
+    Returns:
+        scores (Tensor): The uncertainty scores, with the most uncertain
+            locations having the highest uncertainty score, shape (
+            batch_size, 1, height, width)
+    """
+    top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]
+    return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
+
+
+@MODELS.register_module()
+class PointHead(BaseCascadeDecodeHead):
+    """A mask point head used in PointRend.
+
+    This head is the implementation of `PointRend: Image Segmentation as
+    Rendering <https://arxiv.org/abs/1912.08193>`_.
+    ``PointHead`` uses a shared multi-layer perceptron (equivalent to
+    nn.Conv1d) to predict the logits of input points. The fine-grained
+    feature and coarse feature will be concatenated together for prediction.
+
+    Args:
+        num_fcs (int): Number of fc layers in the head. Default: 3.
+        in_channels (int): Number of input channels. Default: 256.
+        fc_channels (int): Number of fc channels. Default: 256.
+        num_classes (int): Number of classes for logits. Default: 80.
+        class_agnostic (bool): Whether to use class-agnostic classification.
+            If so, the output channels of logits will be 1. Default: False.
+        coarse_pred_each_layer (bool): Whether to concatenate the coarse
+            feature with the output of each fc layer. Default: True.
+        conv_cfg (dict|None): Dictionary to construct and config conv layer.
+            Default: dict(type='Conv1d').
+        norm_cfg (dict|None): Dictionary to construct and config norm layer.
+            Default: None.
+        loss_point (dict): Dictionary to construct and config loss layer of
+            point head. Default: dict(type='CrossEntropyLoss', use_mask=True,
+            loss_weight=1.0).
+    """
+
+    def __init__(self,
+                 num_fcs=3,
+                 coarse_pred_each_layer=True,
+                 conv_cfg=dict(type='Conv1d'),
+                 norm_cfg=None,
+                 act_cfg=dict(type='ReLU', inplace=False),
+                 **kwargs):
+        super().__init__(
+            input_transform='multiple_select',
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg,
+            init_cfg=dict(
+                type='Normal', std=0.01, override=dict(name='fc_seg')),
+            **kwargs)
+        if point_sample is None:
+            raise RuntimeError('Please install mmcv-full for '
+                               'point_sample ops')
+
+        self.num_fcs = num_fcs
+        self.coarse_pred_each_layer = coarse_pred_each_layer
+
+        fc_in_channels = sum(self.in_channels) + self.num_classes
+        fc_channels = self.channels
+        self.fcs = nn.ModuleList()
+        for k in range(num_fcs):
+            fc = ConvModule(
+                fc_in_channels,
+                fc_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+            self.fcs.append(fc)
+            fc_in_channels = fc_channels
+            fc_in_channels += self.num_classes if self.coarse_pred_each_layer \
+                else 0
+        self.fc_seg = nn.Conv1d(
+            fc_in_channels,
+            self.num_classes,
+            kernel_size=1,
+            stride=1,
+            padding=0)
+        if self.dropout_ratio > 0:
+            self.dropout = nn.Dropout(self.dropout_ratio)
+        delattr(self, 'conv_seg')
+
+    def cls_seg(self, feat):
+        """Classify each pixel with fc."""
+        if self.dropout is not None:
+            feat = self.dropout(feat)
+        output = self.fc_seg(feat)
+        return output
+
+    def forward(self, fine_grained_point_feats, coarse_point_feats):
+        x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)
+        for fc in self.fcs:
+            x = fc(x)
+            if self.coarse_pred_each_layer:
+                x = torch.cat((x, coarse_point_feats), dim=1)
+        return self.cls_seg(x)
+
+    def _get_fine_grained_point_feats(self, x, points):
+        """Sample from fine grained features.
+
+        Args:
+            x (list[Tensor]): Feature pyramid from neck or backbone.
+            points (Tensor): Point coordinates, shape (batch_size,
+                num_points, 2).
+ + Returns: + fine_grained_feats (Tensor): Sampled fine grained feature, + shape (batch_size, sum(channels of x), num_points). + """ + + fine_grained_feats_list = [ + point_sample(_, points, align_corners=self.align_corners) + for _ in x + ] + if len(fine_grained_feats_list) > 1: + fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1) + else: + fine_grained_feats = fine_grained_feats_list[0] + + return fine_grained_feats + + def _get_coarse_point_feats(self, prev_output, points): + """Sample from fine grained features. + + Args: + prev_output (list[Tensor]): Prediction of previous decode head. + points (Tensor): Point coordinates, shape (batch_size, + num_points, 2). + + Returns: + coarse_feats (Tensor): Sampled coarse feature, shape (batch_size, + num_classes, num_points). + """ + + coarse_feats = point_sample( + prev_output, points, align_corners=self.align_corners) + + return coarse_feats + + def loss(self, inputs, prev_output, batch_data_samples: SampleList, + train_cfg, **kwargs): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + batch_data_samples (list[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `img_metas` or `gt_semantic_seg`. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + x = self._transform_inputs(inputs) + with torch.no_grad(): + points = self.get_points_train( + prev_output, calculate_uncertainty, cfg=train_cfg) + fine_grained_point_feats = self._get_fine_grained_point_feats( + x, points) + coarse_point_feats = self._get_coarse_point_feats(prev_output, points) + point_logits = self.forward(fine_grained_point_feats, + coarse_point_feats) + + losses = self.loss_by_feat(point_logits, points, batch_data_samples) + + return losses + + def predict(self, inputs, prev_output, batch_img_metas: List[dict], + test_cfg, **kwargs): + """Forward function for testing. + + Args: + inputs (list[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Output segmentation map. 
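+
+        The refinement loop below, in brief (``scale_factor=2`` and
+        ``subdivision_num_points=8196`` are typical test-config values, not
+        requirements): upsample the coarse logits, re-classify the most
+        uncertain locations from point features, scatter the refined logits
+        back, and repeat for ``subdivision_steps`` iterations.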
+ """ + + x = self._transform_inputs(inputs) + refined_seg_logits = prev_output.clone() + for _ in range(test_cfg.subdivision_steps): + refined_seg_logits = resize( + refined_seg_logits, + scale_factor=test_cfg.scale_factor, + mode='bilinear', + align_corners=self.align_corners) + batch_size, channels, height, width = refined_seg_logits.shape + point_indices, points = self.get_points_test( + refined_seg_logits, calculate_uncertainty, cfg=test_cfg) + fine_grained_point_feats = self._get_fine_grained_point_feats( + x, points) + coarse_point_feats = self._get_coarse_point_feats( + prev_output, points) + point_logits = self.forward(fine_grained_point_feats, + coarse_point_feats) + + point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1) + refined_seg_logits = refined_seg_logits.reshape( + batch_size, channels, height * width) + refined_seg_logits = refined_seg_logits.scatter_( + 2, point_indices, point_logits) + refined_seg_logits = refined_seg_logits.view( + batch_size, channels, height, width) + + return self.predict_by_feat(refined_seg_logits, batch_img_metas, + **kwargs) + + def loss_by_feat(self, point_logits, points, batch_data_samples, **kwargs): + """Compute segmentation loss.""" + gt_semantic_seg = self._stack_batch_gt(batch_data_samples) + point_label = point_sample( + gt_semantic_seg.float(), + points, + mode='nearest', + align_corners=self.align_corners) + point_label = point_label.squeeze(1).long() + + loss = dict() + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_module in losses_decode: + loss['point' + loss_module.loss_name] = loss_module( + point_logits, point_label, ignore_index=self.ignore_index) + + loss['acc_point'] = accuracy( + point_logits, point_label, ignore_index=self.ignore_index) + return loss + + def get_points_train(self, seg_logits, uncertainty_func, cfg): + """Sample points for training. + + Sample points in [0, 1] x [0, 1] coordinate space based on their + uncertainty. The uncertainties are calculated for each point using + 'uncertainty_func' function that takes point's logit prediction as + input. + + Args: + seg_logits (Tensor): Semantic segmentation logits, shape ( + batch_size, num_classes, height, width). + uncertainty_func (func): uncertainty calculation function. + cfg (dict): Training config of point head. + + Returns: + point_coords (Tensor): A tensor of shape (batch_size, num_points, + 2) that contains the coordinates of ``num_points`` sampled + points. + """ + num_points = cfg.num_points + oversample_ratio = cfg.oversample_ratio + importance_sample_ratio = cfg.importance_sample_ratio + assert oversample_ratio >= 1 + assert 0 <= importance_sample_ratio <= 1 + batch_size = seg_logits.shape[0] + num_sampled = int(num_points * oversample_ratio) + point_coords = torch.rand( + batch_size, num_sampled, 2, device=seg_logits.device) + point_logits = point_sample(seg_logits, point_coords) + # It is crucial to calculate uncertainty based on the sampled + # prediction value for the points. Calculating uncertainties of the + # coarse predictions first and sampling them for points leads to + # incorrect results. To illustrate this: assume uncertainty func( + # logits)=-abs(logits), a sampled point between two coarse + # predictions with -1 and 1 logits has 0 logits, and therefore 0 + # uncertainty value. However, if we calculate uncertainties for the + # coarse predictions first, both will have -1 uncertainty, + # and sampled point will get -1 uncertainty. 
+ point_uncertainties = uncertainty_func(point_logits) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk( + point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_sampled * torch.arange( + batch_size, dtype=torch.long, device=seg_logits.device) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( + batch_size, num_uncertain_points, 2) + if num_random_points > 0: + rand_point_coords = torch.rand( + batch_size, num_random_points, 2, device=seg_logits.device) + point_coords = torch.cat((point_coords, rand_point_coords), dim=1) + return point_coords + + def get_points_test(self, seg_logits, uncertainty_func, cfg): + """Sample points for testing. + + Find ``num_points`` most uncertain points from ``uncertainty_map``. + + Args: + seg_logits (Tensor): A tensor of shape (batch_size, num_classes, + height, width) for class-specific or class-agnostic prediction. + uncertainty_func (func): uncertainty calculation function. + cfg (dict): Testing config of point head. + + Returns: + point_indices (Tensor): A tensor of shape (batch_size, num_points) + that contains indices from [0, height x width) of the most + uncertain points. + point_coords (Tensor): A tensor of shape (batch_size, num_points, + 2) that contains [0, 1] x [0, 1] normalized coordinates of the + most uncertain points from the ``height x width`` grid . + """ + + num_points = cfg.subdivision_num_points + uncertainty_map = uncertainty_func(seg_logits) + batch_size, _, height, width = uncertainty_map.shape + h_step = 1.0 / height + w_step = 1.0 / width + + uncertainty_map = uncertainty_map.view(batch_size, height * width) + num_points = min(height * width, num_points) + point_indices = uncertainty_map.topk(num_points, dim=1)[1] + point_coords = torch.zeros( + batch_size, + num_points, + 2, + dtype=torch.float, + device=seg_logits.device) + point_coords[:, :, 0] = w_step / 2.0 + (point_indices % + width).float() * w_step + point_coords[:, :, 1] = h_step / 2.0 + (point_indices // + width).float() * h_step + return point_indices, point_coords diff --git a/mmseg/models/decode_heads/psa_head.py b/mmseg/models/decode_heads/psa_head.py new file mode 100644 index 0000000000000000000000000000000000000000..13ee5c58a569bb46612625b85685cd61b7e9df3e --- /dev/null +++ b/mmseg/models/decode_heads/psa_head.py @@ -0,0 +1,197 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + +try: + from mmcv.ops import PSAMask +except ModuleNotFoundError: + PSAMask = None + + +@MODELS.register_module() +class PSAHead(BaseDecodeHead): + """Point-wise Spatial Attention Network for Scene Parsing. + + This head is the implementation of `PSANet + `_. + + Args: + mask_size (tuple[int]): The PSA mask size. It usually equals input + size. + psa_type (str): The type of psa module. Options are 'collect', + 'distribute', 'bi-direction'. Default: 'bi-direction' + compact (bool): Whether use compact map for 'collect' mode. + Default: True. + shrink_factor (int): The downsample factors of psa mask. Default: 2. + normalization_factor (float): The normalize factor of attention. + psa_softmax (bool): Whether use softmax for attention. 
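+
+    Example (illustrative; the mask size matches a 1/8-resolution feature
+    map of a 769x769 crop, and the channel settings are assumptions):
+        >>> head = PSAHead(mask_size=(97, 97), in_channels=2048,
+        ...                channels=512, num_classes=19)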
+ """ + + def __init__(self, + mask_size, + psa_type='bi-direction', + compact=False, + shrink_factor=2, + normalization_factor=1.0, + psa_softmax=True, + **kwargs): + if PSAMask is None: + raise RuntimeError('Please install mmcv-full for PSAMask ops') + super().__init__(**kwargs) + assert psa_type in ['collect', 'distribute', 'bi-direction'] + self.psa_type = psa_type + self.compact = compact + self.shrink_factor = shrink_factor + self.mask_size = mask_size + mask_h, mask_w = mask_size + self.psa_softmax = psa_softmax + if normalization_factor is None: + normalization_factor = mask_h * mask_w + self.normalization_factor = normalization_factor + + self.reduce = ConvModule( + self.in_channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.attention = nn.Sequential( + ConvModule( + self.channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.Conv2d( + self.channels, mask_h * mask_w, kernel_size=1, bias=False)) + if psa_type == 'bi-direction': + self.reduce_p = ConvModule( + self.in_channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.attention_p = nn.Sequential( + ConvModule( + self.channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.Conv2d( + self.channels, mask_h * mask_w, kernel_size=1, bias=False)) + self.psamask_collect = PSAMask('collect', mask_size) + self.psamask_distribute = PSAMask('distribute', mask_size) + else: + self.psamask = PSAMask(psa_type, mask_size) + self.proj = ConvModule( + self.channels * (2 if psa_type == 'bi-direction' else 1), + self.in_channels, + kernel_size=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.bottleneck = ConvModule( + self.in_channels * 2, + self.channels, + kernel_size=3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + identity = x + align_corners = self.align_corners + if self.psa_type in ['collect', 'distribute']: + out = self.reduce(x) + n, c, h, w = out.size() + if self.shrink_factor != 1: + if h % self.shrink_factor and w % self.shrink_factor: + h = (h - 1) // self.shrink_factor + 1 + w = (w - 1) // self.shrink_factor + 1 + align_corners = True + else: + h = h // self.shrink_factor + w = w // self.shrink_factor + align_corners = False + out = resize( + out, + size=(h, w), + mode='bilinear', + align_corners=align_corners) + y = self.attention(out) + if self.compact: + if self.psa_type == 'collect': + y = y.view(n, h * w, + h * w).transpose(1, 2).view(n, h * w, h, w) + else: + y = self.psamask(y) + if self.psa_softmax: + y = F.softmax(y, dim=1) + out = torch.bmm( + out.view(n, c, h * w), y.view(n, h * w, h * w)).view( + n, c, h, w) * (1.0 / self.normalization_factor) + else: + x_col = self.reduce(x) + x_dis = self.reduce_p(x) + n, c, h, w = x_col.size() + if self.shrink_factor != 1: + if h % self.shrink_factor and w % self.shrink_factor: + h = (h - 1) // self.shrink_factor + 1 + w = (w - 1) // self.shrink_factor + 1 + align_corners = True + else: + h = h // self.shrink_factor + w = w // self.shrink_factor + align_corners = False + x_col = resize( + x_col, + size=(h, w), + mode='bilinear', + align_corners=align_corners) + x_dis = resize( + x_dis, + size=(h, w), + 
mode='bilinear', + align_corners=align_corners) + y_col = self.attention(x_col) + y_dis = self.attention_p(x_dis) + if self.compact: + y_dis = y_dis.view(n, h * w, + h * w).transpose(1, 2).view(n, h * w, h, w) + else: + y_col = self.psamask_collect(y_col) + y_dis = self.psamask_distribute(y_dis) + if self.psa_softmax: + y_col = F.softmax(y_col, dim=1) + y_dis = F.softmax(y_dis, dim=1) + x_col = torch.bmm( + x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view( + n, c, h, w) * (1.0 / self.normalization_factor) + x_dis = torch.bmm( + x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view( + n, c, h, w) * (1.0 / self.normalization_factor) + out = torch.cat([x_col, x_dis], 1) + out = self.proj(out) + out = resize( + out, + size=identity.shape[2:], + mode='bilinear', + align_corners=align_corners) + out = self.bottleneck(torch.cat((identity, out), dim=1)) + out = self.cls_seg(out) + return out diff --git a/mmseg/models/decode_heads/psp_head.py b/mmseg/models/decode_heads/psp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a40ec41dec281e53815e9753ee2ba1a5da76bd05 --- /dev/null +++ b/mmseg/models/decode_heads/psp_head.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class PPM(nn.ModuleList): + """Pooling Pyramid Module used in PSPNet. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. + align_corners (bool): align_corners argument of F.interpolate. + """ + + def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, + act_cfg, align_corners, **kwargs): + super().__init__() + self.pool_scales = pool_scales + self.align_corners = align_corners + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + for pool_scale in pool_scales: + self.append( + nn.Sequential( + nn.AdaptiveAvgPool2d(pool_scale), + ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + **kwargs))) + + def forward(self, x): + """Forward function.""" + ppm_outs = [] + for ppm in self: + ppm_out = ppm(x) + upsampled_ppm_out = resize( + ppm_out, + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + ppm_outs.append(upsampled_ppm_out) + return ppm_outs + + +@MODELS.register_module() +class PSPHead(BaseDecodeHead): + """Pyramid Scene Parsing Network. + + This head is the implementation of + `PSPNet `_. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. Default: (1, 2, 3, 6). 
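+
+    Shape sketch (illustrative values): with an input feature map of shape
+    [2, 2048, 64, 64], ``pool_scales=(1, 2, 3, 6)`` and ``channels=512``,
+    each PPM branch pools to 1x1, 2x2, 3x3 or 6x6, projects to 512 channels
+    and upsamples back to 64x64, so the bottleneck receives
+    2048 + 4 * 512 = 4096 channels.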
+ """ + + def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): + super().__init__(**kwargs) + assert isinstance(pool_scales, (list, tuple)) + self.pool_scales = pool_scales + self.psp_modules = PPM( + self.pool_scales, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.bottleneck = ConvModule( + self.in_channels + len(pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + x = self._transform_inputs(inputs) + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = torch.cat(psp_outs, dim=1) + feats = self.bottleneck(psp_outs) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/san_head.py b/mmseg/models/decode_heads/san_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d20da801924080efeee30a246331af2e2e5df352 --- /dev/null +++ b/mmseg/models/decode_heads/san_head.py @@ -0,0 +1,736 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from functools import partial +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, build_norm_layer +from mmcv.cnn.bricks.transformer import BaseTransformerLayer +from mmcv.ops import point_sample +from mmengine.dist import all_reduce +from mmengine.model.weight_init import (caffe2_xavier_init, normal_init, + trunc_normal_) +from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict +from mmengine.structures import InstanceData +from torch import Tensor +from torch.nn import functional as F + +from mmseg.models.backbones.vit import TransformerEncoderLayer +from mmseg.registry import MODELS +from mmseg.utils import (ConfigType, MatchMasks, SampleList, + seg_data_to_instance_data) +from ..utils import (MLP, LayerNorm2d, PatchEmbed, cross_attn_layer, + get_uncertain_point_coords_with_randomness, resize) +from .decode_head import BaseDecodeHead + + +class MLPMaskDecoder(nn.Module): + """Module for decoding query and visual features with MLP layers to + generate the attention biases and the mask proposals.""" + + def __init__( + self, + *, + in_channels: int, + total_heads: int = 1, + total_layers: int = 1, + embed_channels: int = 256, + mlp_channels: int = 256, + mlp_num_layers: int = 3, + rescale_attn_bias: bool = False, + ): + super().__init__() + self.total_heads = total_heads + self.total_layers = total_layers + + dense_affine_func = partial(nn.Conv2d, kernel_size=1) + # Query Branch + self.query_mlp = MLP(in_channels, mlp_channels, embed_channels, + mlp_num_layers) + # Pixel Branch + self.pix_mlp = MLP( + in_channels, + mlp_channels, + embed_channels, + mlp_num_layers, + affine_func=dense_affine_func, + ) + # Attention Bias Branch + self.attn_mlp = MLP( + in_channels, + mlp_channels, + embed_channels * self.total_heads * self.total_layers, + mlp_num_layers, + affine_func=dense_affine_func, + ) + if rescale_attn_bias: + self.bias_scaling = nn.Linear(1, 1) + else: + 
self.bias_scaling = nn.Identity() + + def forward(self, query: torch.Tensor, + x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Forward function. + Args: + query (Tensor): Query Tokens [B,N,C]. + x (Tensor): Visual features [B,C,H,W] + + Return: + mask_preds (Tensor): Mask proposals. + attn_bias (List[Tensor]): List of attention bias. + """ + query = self.query_mlp(query) + pix = self.pix_mlp(x) + b, c, h, w = pix.shape + # preidict mask + mask_preds = torch.einsum('bqc,bchw->bqhw', query, pix) + # generate attn bias + attn = self.attn_mlp(x) + attn = attn.reshape(b, self.total_layers, self.total_heads, c, h, w) + attn_bias = torch.einsum('bqc,blnchw->blnqhw', query, attn) + attn_bias = self.bias_scaling(attn_bias[..., None]).squeeze(-1) + attn_bias = attn_bias.chunk(self.total_layers, dim=1) + attn_bias = [attn.squeeze(1) for attn in attn_bias] + return mask_preds, attn_bias + + +class SideAdapterNetwork(nn.Module): + """Side Adapter Network for predicting mask proposals and attention bias. + + Args: + in_channels (int): Number of input channels. Default: 3. + clip_channels (int): Number of channels of visual features. + Default: 768. + embed_dims (int): embedding dimension. Default: 240. + patch_size (int): The patch size. Default: 16. + patch_bias (bool): Whether use bias in patch embedding. + Default: True. + num_queries (int): Number of queries for mask proposals. + Default: 100. + fusion_index (List[int]): The layer number of the encode + transformer to fuse with the CLIP feature. + Default: [0, 1, 2, 3]. + cfg_encoder (ConfigType): Configs for the encode layers. + cfg_decoder (ConfigType): Configs for the decode layers. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + """ + + def __init__( + self, + in_channels: int = 3, + clip_channels: int = 768, + embed_dims: int = 240, + patch_size: int = 16, + patch_bias: bool = True, + num_queries: int = 100, + fusion_index: list = [0, 1, 2, 3], + cfg_encoder: ConfigType = ..., + cfg_decoder: ConfigType = ..., + norm_cfg: dict = dict(type='LN'), + ): + super().__init__() + + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + padding=0, + input_size=(640, 640), + bias=patch_bias, + norm_cfg=None, + init_cfg=None, + ) + ori_h, ori_w = self.patch_embed.init_out_size + num_patches = ori_h * ori_w + self.pos_embed = nn.Parameter( + torch.randn(1, num_patches, embed_dims) * .02) + self.query_pos_embed = nn.Parameter( + torch.zeros(1, num_queries, embed_dims)) + self.query_embed = nn.Parameter( + torch.zeros(1, num_queries, embed_dims)) + encode_layers = [] + for i in range(cfg_encoder.num_encode_layer): + encode_layers.append( + TransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=cfg_encoder.num_heads, + feedforward_channels=cfg_encoder.mlp_ratio * embed_dims, + norm_cfg=norm_cfg)) + self.encode_layers = nn.ModuleList(encode_layers) + conv_clips = [] + for i in range(len(fusion_index)): + conv_clips.append( + nn.Sequential( + LayerNorm2d(clip_channels), + ConvModule( + clip_channels, + embed_dims, + kernel_size=1, + norm_cfg=None, + act_cfg=None))) + self.conv_clips = nn.ModuleList(conv_clips) + self.fusion_index = fusion_index + self.mask_decoder = MLPMaskDecoder( + in_channels=embed_dims, + total_heads=cfg_decoder.num_heads, + total_layers=cfg_decoder.num_layers, + embed_channels=cfg_decoder.embed_channels, + mlp_channels=cfg_decoder.mlp_channels, + 
mlp_num_layers=cfg_decoder.num_mlp, + rescale_attn_bias=cfg_decoder.rescale) + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.query_embed, std=0.02) + nn.init.normal_(self.query_pos_embed, std=0.02) + for i in range(len(self.conv_clips)): + caffe2_xavier_init(self.conv_clips[i][1].conv) + + def fuse_clip(self, fused_index: int, x: torch.Tensor, + clip_feature: torch.Tensor, hwshape: Tuple[int, + int], L: int): + """Fuse CLIP feature and visual tokens.""" + fused_clip = (resize( + self.conv_clips[fused_index](clip_feature.contiguous()), + size=hwshape, + mode='bilinear', + align_corners=False)).permute(0, 2, 3, 1).reshape(x[:, -L:, + ...].shape) + x = torch.cat([x[:, :-L, ...], x[:, -L:, ...] + fused_clip], dim=1) + return x + + def encode_feature(self, image: torch.Tensor, + clip_features: List[torch.Tensor], + deep_supervision_idxs: List[int]) -> List[List]: + """Encode images by a lightweight vision transformer.""" + assert len(self.fusion_index) == len(clip_features) + x, hwshape = self.patch_embed(image) + ori_h, ori_w = self.patch_embed.init_out_size + pos_embed = self.pos_embed + if self.pos_embed.shape[1] != x.shape[1]: + # resize the position embedding + pos_embed = ( + resize( + self.pos_embed.reshape(1, ori_h, ori_w, + -1).permute(0, 3, 1, 2), + size=hwshape, + mode='bicubic', + align_corners=False, + ).flatten(2).permute(0, 2, 1)) + pos_embed = torch.cat([ + self.query_pos_embed.expand(pos_embed.shape[0], -1, -1), pos_embed + ], + dim=1) + x = torch.cat([self.query_embed.expand(x.shape[0], -1, -1), x], dim=1) + x = x + pos_embed + L = hwshape[0] * hwshape[1] + fused_index = 0 + if self.fusion_index[fused_index] == 0: + x = self.fuse_clip(fused_index, x, clip_features[0][0], hwshape, L) + fused_index += 1 + outs = [] + for index, block in enumerate(self.encode_layers, start=1): + x = block(x) + if index < len(self.fusion_index + ) and index == self.fusion_index[fused_index]: + x = self.fuse_clip(fused_index, x, + clip_features[fused_index][0], hwshape, L) + fused_index += 1 + x_query = x[:, :-L, ...] + x_feat = x[:, -L:, ...].permute(0, 2, 1)\ + .reshape(x.shape[0], x.shape[-1], hwshape[0], hwshape[1]) + + if index in deep_supervision_idxs or index == len( + self.encode_layers): + outs.append({'query': x_query, 'x': x_feat}) + + if index < len(self.encode_layers): + x = x + pos_embed + return outs + + def decode_feature(self, features): + mask_embeds = [] + attn_biases = [] + for feature in features: + mask_embed, attn_bias = self.mask_decoder(**feature) + mask_embeds.append(mask_embed) + attn_biases.append(attn_bias) + return mask_embeds, attn_biases + + def forward( + self, image: torch.Tensor, clip_features: List[torch.Tensor], + deep_supervision_idxs: List[int] + ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: + """Forward function.""" + features = self.encode_feature(image, clip_features, + deep_supervision_idxs) + mask_embeds, attn_biases = self.decode_feature(features) + return mask_embeds, attn_biases + + +class RecWithAttnbias(nn.Module): + """Mask recognition module by applying the attention biases to rest deeper + CLIP layers. + + Args: + sos_token_format (str): The format of sos token. It should be + chosen from ["cls_token", "learnable_token", "pos_embedding"]. + Default: 'cls_token'. + sos_token_num (int): Number of sos token. It should be equal to + the number of quries. Default: 100. + num_layers (int): Number of rest CLIP layers for mask recognition. + Default: 3. 
+        cross_attn (bool): Whether to use cross attention to update the sos
+            token. Default: False.
+        embed_dims (int): The feature dimension of CLIP layers.
+            Default: 768.
+        num_heads (int): Parallel attention heads of CLIP layers.
+            Default: 12.
+        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
+            Default: 4.
+        num_fcs (int): The number of fully-connected layers for FFNs.
+            Default: 2.
+        qkv_bias (bool): Whether to use bias in multihead-attention.
+            Default: True.
+        out_dims (int): Number of channels of the output mask proposals.
+            It should be equal to the out_dims of text_encoder.
+            Default: 512.
+        final_norm (bool): Whether to use a norm layer for the sos token.
+        act_cfg (dict): The activation config for FFNs.
+            Default: dict(type='GELU').
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='LN').
+        frozen_exclude (List): List of parameters that are not to be frozen.
+    """
+
+    def __init__(self,
+                 sos_token_format: str = 'cls_token',
+                 sos_token_num: int = 100,
+                 num_layers: int = 3,
+                 cross_attn: bool = False,
+                 embed_dims: int = 768,
+                 num_heads: int = 12,
+                 mlp_ratio: int = 4,
+                 num_fcs: int = 2,
+                 qkv_bias: bool = True,
+                 out_dims: int = 512,
+                 final_norm: bool = True,
+                 act_cfg: dict = dict(type='GELU'),
+                 norm_cfg: dict = dict(type='LN'),
+                 frozen_exclude: List = []):
+        super().__init__()
+
+        assert sos_token_format in [
+            'cls_token', 'learnable_token', 'pos_embedding'
+        ]
+        self.sos_token_format = sos_token_format
+        self.sos_token_num = sos_token_num
+        self.frozen_exclude = frozen_exclude
+        self.cross_attn = cross_attn
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        if sos_token_format in ['learnable_token', 'pos_embedding']:
+            # Size the token by ``embed_dims``: ``self.proj`` is only created
+            # further down and, being an nn.Linear, has no ``.shape``
+            # attribute; likewise the exclusion list is named
+            # ``frozen_exclude``, not ``frozen``.
+            self.sos_token = nn.Parameter(
+                torch.randn(sos_token_num, 1, embed_dims))
+            self.frozen_exclude.append('sos_token')
+
+        layers = []
+        for i in range(num_layers):
+            layers.append(
+                BaseTransformerLayer(
+                    attn_cfgs=dict(
+                        type='MultiheadAttention',
+                        embed_dims=embed_dims,
+                        num_heads=num_heads,
+                        batch_first=False,
+                        bias=qkv_bias),
+                    ffn_cfgs=dict(
+                        type='FFN',
+                        embed_dims=embed_dims,
+                        feedforward_channels=mlp_ratio * embed_dims,
+                        act_cfg=act_cfg),
+                    operation_order=('norm', 'self_attn', 'norm', 'ffn')))
+        self.layers = nn.ModuleList(layers)
+
+        self.ln_post = build_norm_layer(norm_cfg, embed_dims)[1]
+        self.proj = nn.Linear(embed_dims, out_dims, bias=False)
+
+        self.final_norm = final_norm
+        self._freeze()
+
+    def init_weights(self, rec_state_dict):
+        if hasattr(self, 'sos_token'):
+            normal_init(self.sos_token, std=0.02)
+        if rec_state_dict is not None:
+            load_state_dict(self, rec_state_dict, strict=False, logger=None)
+        else:
+            super().init_weights()
+
+    def _freeze(self):
+        if 'all' in self.frozen_exclude:
+            return
+        for name, param in self.named_parameters():
+            if not any([exclude in name for exclude in self.frozen_exclude]):
+                param.requires_grad = False
+
+    def _build_attn_biases(self, attn_biases, target_shape):
+        formatted_attn_biases = []
+        for attn_bias in attn_biases:
+            # convert it to proper format: N*num_head,L,L
+            # attn_bias: [N, num_head/1, num_sos,H,W]
+            n, num_head, num_sos, h, w = attn_bias.shape
+            # reshape and downsample
+            attn_bias = F.adaptive_max_pool2d(
+                attn_bias.reshape(n, num_head * num_sos, h, w),
+                output_size=target_shape)
+            attn_bias = attn_bias.reshape(n, num_head, num_sos, *target_shape)
+
+            true_num_head = self.num_heads
+            assert (num_head == 1 or num_head
+                    == true_num_head), f'num_head={num_head} is not supported.'
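+            # Layout sketch (shapes illustrative): a bias of shape
+            # [n, num_head, num_sos, h, w] is flattened to
+            # [n * num_head, num_sos, h * w]; without cross attention it is
+            # then embedded into a square mask over the token sequence
+            # [sos tokens, cls token, h * w image tokens], where -100
+            # effectively blocks attention to the sos tokens.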
+ if num_head == 1: + attn_bias = attn_bias.repeat(1, true_num_head, 1, 1, 1) + attn_bias = attn_bias.reshape(n * true_num_head, num_sos, -1) + L = attn_bias.shape[-1] + if self.cross_attn: + # [n*num_head, num_sos, L] + formatted_attn_biases.append(attn_bias) + else: + # [n*num_head, num_sos+1+L, num_sos+1+L] + new_attn_bias = attn_bias.new_zeros(num_sos + 1 + L, + num_sos + 1 + L) + new_attn_bias[:, :num_sos] = -100 + new_attn_bias[torch.arange(num_sos), torch.arange(num_sos)] = 0 + new_attn_bias[:num_sos, num_sos] = -100 + new_attn_bias = ( + new_attn_bias[None, ...].expand(n * true_num_head, -1, + -1).clone()) + new_attn_bias[..., :num_sos, -L:] = attn_bias + formatted_attn_biases.append(new_attn_bias) + + if len(formatted_attn_biases) == 1: + formatted_attn_biases = [ + formatted_attn_biases[0] for _ in range(self.num_layers) + ] + return formatted_attn_biases + + def forward(self, bias: List[Tensor], feature: List[Tensor]): + """Forward function to recognize the category of masks + Args: + bias (List[Tensor]): Attention bias for transformer layers + feature (List[Tensor]): Output of the image encoder, + including cls_token and img_feature. + """ + cls_token = feature[1].unsqueeze(0) + img_feature = feature[0] + b, c, h, w = img_feature.shape + # construct clip shadow features + x = torch.cat( + [cls_token, + img_feature.reshape(b, c, -1).permute(2, 0, 1)]) + + # construct sos token + if self.sos_token_format == 'cls_token': + sos_token = cls_token.repeat(self.sos_token_num, 1, 1) + elif self.sos_token_format == 'learnable_token': + sos_token = self.sos_token.expand(-1, b, -1) + elif self.sos_token_format == 'pos_embedding': + sos_token = self.sos_token.expand(-1, b, -1) + cls_token + + # construct attn bias + attn_biases = self._build_attn_biases(bias, target_shape=(h, w)) + + if self.cross_attn: + for i, block in enumerate(self.layers): + if self.cross_attn: + sos_token = cross_attn_layer( + block, + sos_token, + x[1:, ], + attn_biases[i], + ) + if i < len(self.layers) - 1: + x = block(x) + else: + x = torch.cat([sos_token, x], dim=0) + for i, block in enumerate(self.layers): + x = block(x, attn_masks=[attn_biases[i]]) + sos_token = x[:self.sos_token_num] + + sos_token = sos_token.permute(1, 0, 2) # LND -> NLD + sos_token = self.ln_post(sos_token) + sos_token = self.proj(sos_token) + if self.final_norm: + sos_token = F.normalize(sos_token, dim=-1) + return sos_token + + +@MODELS.register_module() +class SideAdapterCLIPHead(BaseDecodeHead): + """Side Adapter Network (SAN) for open-vocabulary semantic segmentation + with pre-trained vision-language model. + + This decode head is the implementation of `Side Adapter Network + for Open-Vocabulary Semantic Segmentation` + . + Modified from https://github.com/MendelXu/SAN/blob/main/san/model/side_adapter/side_adapter.py # noqa:E501 + Copyright (c) 2023 MendelXu. + Licensed under the MIT License + + Args: + num_classes (int): the number of classes. 
+ san_cfg (ConfigType): Configs for SideAdapterNetwork module + maskgen_cfg (ConfigType): Configs for RecWithAttnbias module + """ + + def __init__(self, num_classes: int, san_cfg: ConfigType, + maskgen_cfg: ConfigType, deep_supervision_idxs: List[int], + train_cfg: ConfigType, **kwargs): + super().__init__( + in_channels=san_cfg.in_channels, + channels=san_cfg.embed_dims, + num_classes=num_classes, + **kwargs) + assert san_cfg.num_queries == maskgen_cfg.sos_token_num, \ + 'num_queries in san_cfg should be equal to sos_token_num ' \ + 'in maskgen_cfg' + del self.conv_seg + self.side_adapter_network = SideAdapterNetwork(**san_cfg) + self.rec_with_attnbias = RecWithAttnbias(**maskgen_cfg) + self.deep_supervision_idxs = deep_supervision_idxs + self.train_cfg = train_cfg + if train_cfg: + self.match_masks = MatchMasks( + num_points=train_cfg.num_points, + num_queries=san_cfg.num_queries, + num_classes=num_classes, + assigner=train_cfg.assigner) + + def init_weights(self): + + rec_state_dict = None + if isinstance(self.init_cfg, dict) and \ + self.init_cfg.get('type') == 'Pretrained_Part': + checkpoint = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + + rec_state_dict = checkpoint.copy() + para_prefix = 'decode_head.rec_with_attnbias' + prefix_len = len(para_prefix) + 1 + for k, v in checkpoint.items(): + rec_state_dict.pop(k) + if para_prefix in k: + rec_state_dict[k[prefix_len:]] = v + + self.side_adapter_network.init_weights() + self.rec_with_attnbias.init_weights(rec_state_dict) + + def forward(self, inputs: Tuple[Tensor], + deep_supervision_idxs) -> Tuple[List]: + """Forward function. + + Args: + inputs (Tuple[Tensor]): A triplet including images, + list of multi-level visual features from image encoder and + class embeddings from text_encoder. + + Returns: + mask_props (List[Tensor]): Mask proposals predicted by SAN. + mask_logits (List[Tensor]): Class logits of mask proposals. + """ + imgs, clip_feature, class_embeds = inputs + # predict mask proposals and attention bias + mask_props, attn_biases = self.side_adapter_network( + imgs, clip_feature, deep_supervision_idxs) + + # mask recognition with attention bias + mask_embeds = [ + self.rec_with_attnbias(att_bias, clip_feature[-1]) + for att_bias in attn_biases + ] + # Obtain class prediction of masks by comparing the similarity + # between the image token and the text embedding of class names. + mask_logits = [ + torch.einsum('bqc,nc->bqn', mask_embed, class_embeds) + for mask_embed in mask_embeds + ] + return mask_props, mask_logits + + def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict], + test_cfg: ConfigType) -> Tensor: + """Forward function for prediction. + + Args: + inputs (Tuple[Tensor]): Images, visual features from image encoder + and class embedding from text encoder. + batch_img_metas (dict): List Image info where each dict may also + contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Outputs segmentation logits map. + """ + mask_props, mask_logits = self.forward(inputs, []) + + return self.predict_by_feat([mask_props[-1], mask_logits[-1]], + batch_img_metas) + + def predict_by_feat(self, seg_logits: List[Tensor], + batch_img_metas: List[dict]) -> Tensor: + """1. Transform a batch of mask proposals to the input shape. + 2. 
Generate segmentation map with mask proposals and class logits. + """ + mask_pred = seg_logits[0] + cls_score = seg_logits[1] + if isinstance(batch_img_metas[0]['img_shape'], torch.Size): + # slide inference + size = batch_img_metas[0]['img_shape'] + elif 'pad_shape' in batch_img_metas[0]: + size = batch_img_metas[0]['pad_shape'][:2] + else: + size = batch_img_metas[0]['img_shape'] + # upsample mask + mask_pred = F.interpolate( + mask_pred, size=size, mode='bilinear', align_corners=False) + + mask_cls = F.softmax(cls_score, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + seg_logits = torch.einsum('bqc,bqhw->bchw', mask_cls, mask_pred) + return seg_logits + + def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList, + train_cfg: ConfigType) -> dict: + """Perform forward propagation and loss calculation of the decoder head + on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the upstream + network, each is a 4D-tensor. + batch_data_samples (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + train_cfg (ConfigType): Training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components. + """ + # batch SegDataSample to InstanceDataSample + batch_gt_instances = seg_data_to_instance_data(self.ignore_index, + batch_data_samples) + + # forward + all_mask_props, all_mask_logits = self.forward( + x, self.deep_supervision_idxs) + + # loss + losses = self.loss_by_feat(all_mask_logits, all_mask_props, + batch_gt_instances) + + return losses + + def loss_by_feat( + self, all_cls_scores: Tensor, all_mask_preds: Tensor, + batch_gt_instances: List[InstanceData]) -> Dict[str, Tensor]: + """Loss function. + + Args: + all_cls_scores (Tensor): Classification scores for all decoder + layers with shape (num_decoder, batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. + all_mask_preds (Tensor): Mask scores for all decoder layers with + shape (num_decoder, batch_size, num_queries, h, w). + batch_gt_instances (list[obj:`InstanceData`]): each contains + ``labels`` and ``masks``. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
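+
+            Keys follow the deep-supervision convention applied below:
+            losses from the final decoder layer keep their plain names
+            (e.g. ``loss_cls_ce``), while earlier layers are prefixed with
+            ``d{idx}.`` for the matching entry of ``deep_supervision_idxs``.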
+ """ + num_dec_layers = len(all_cls_scores) + batch_gt_instances_list = [ + batch_gt_instances for _ in range(num_dec_layers) + ] + + losses = [] + for i in range(num_dec_layers): + cls_scores = all_cls_scores[i] + mask_preds = all_mask_preds[i] + # matching N mask predictions to K category labels + (labels, mask_targets, mask_weights, + avg_factor) = self.match_masks.get_targets( + cls_scores, mask_preds, batch_gt_instances_list[i]) + cls_scores = cls_scores.flatten(0, 1) + labels = labels.flatten(0, 1) + num_total_masks = cls_scores.new_tensor([avg_factor], + dtype=torch.float) + all_reduce(num_total_masks, op='mean') + num_total_masks = max(num_total_masks, 1) + + # extract positive ones + # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) + mask_preds = mask_preds[mask_weights > 0] + + if mask_targets.shape[0] != 0: + with torch.no_grad(): + points_coords = get_uncertain_point_coords_with_randomness( + mask_preds.unsqueeze(1), None, + self.train_cfg.num_points, + self.train_cfg.oversample_ratio, + self.train_cfg.importance_sample_ratio) + # shape (num_total_gts, h, w) + # -> (num_total_gts, num_points) + mask_point_targets = point_sample( + mask_targets.unsqueeze(1).float(), + points_coords).squeeze(1) + # shape (num_queries, h, w) -> (num_queries, num_points) + mask_point_preds = point_sample( + mask_preds.unsqueeze(1), points_coords).squeeze(1) + + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + loss = dict() + for loss_decode in losses_decode: + if 'loss_cls' in loss_decode.loss_name: + if loss_decode.loss_name == 'loss_cls_ce': + loss[loss_decode.loss_name] = loss_decode( + cls_scores, labels) + else: + assert False, "Only support 'CrossEntropyLoss' in" \ + ' classification loss' + + elif 'loss_mask' in loss_decode.loss_name: + if mask_targets.shape[0] == 0: + loss[loss_decode.loss_name] = mask_preds.sum() + elif loss_decode.loss_name == 'loss_mask_ce': + loss[loss_decode.loss_name] = loss_decode( + mask_point_preds, + mask_point_targets, + avg_factor=num_total_masks * + self.train_cfg.num_points) + elif loss_decode.loss_name == 'loss_mask_dice': + loss[loss_decode.loss_name] = loss_decode( + mask_point_preds, + mask_point_targets, + avg_factor=num_total_masks) + else: + assert False, "Only support 'CrossEntropyLoss' and" \ + " 'DiceLoss' in mask loss" + else: + assert False, "Only support for 'loss_cls' and 'loss_mask'" + + losses.append(loss) + + loss_dict = dict() + # loss from the last decoder layer + loss_dict.update(losses[-1]) + # loss from other decoder layers + for i, loss in enumerate(losses[:-1]): + for k, v in loss.items(): + loss_dict[f'd{self.deep_supervision_idxs[i]}.{k}'] = v + return loss_dict diff --git a/mmseg/models/decode_heads/segformer_head.py b/mmseg/models/decode_heads/segformer_head.py new file mode 100644 index 0000000000000000000000000000000000000000..f9eb0b320b4e7b892e0540cea5ba5ea7054f8008 --- /dev/null +++ b/mmseg/models/decode_heads/segformer_head.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.registry import MODELS +from ..utils import resize + + +@MODELS.register_module() +class SegformerHead(BaseDecodeHead): + """The all mlp Head of segformer. + + This head is the implementation of + `Segformer ` _. 
+ + Args: + interpolate_mode: The interpolate mode of MLP head upsample operation. + Default: 'bilinear'. + """ + + def __init__(self, interpolate_mode='bilinear', **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + + self.interpolate_mode = interpolate_mode + num_inputs = len(self.in_channels) + + assert num_inputs == len(self.in_index) + + self.convs = nn.ModuleList() + for i in range(num_inputs): + self.convs.append( + ConvModule( + in_channels=self.in_channels[i], + out_channels=self.channels, + kernel_size=1, + stride=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + self.fusion_conv = ConvModule( + in_channels=self.channels * num_inputs, + out_channels=self.channels, + kernel_size=1, + norm_cfg=self.norm_cfg) + + def forward(self, inputs): + # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32 + inputs = self._transform_inputs(inputs) + outs = [] + for idx in range(len(inputs)): + x = inputs[idx] + conv = self.convs[idx] + outs.append( + resize( + input=conv(x), + size=inputs[0].shape[2:], + mode=self.interpolate_mode, + align_corners=self.align_corners)) + + out = self.fusion_conv(torch.cat(outs, dim=1)) + + out = self.cls_seg(out) + + return out diff --git a/mmseg/models/decode_heads/segmenter_mask_head.py b/mmseg/models/decode_heads/segmenter_mask_head.py new file mode 100644 index 0000000000000000000000000000000000000000..85d27735ba8015772324177716b5e8d5f357295c --- /dev/null +++ b/mmseg/models/decode_heads/segmenter_mask_head.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmengine.model import ModuleList +from mmengine.model.weight_init import (constant_init, trunc_normal_, + trunc_normal_init) + +from mmseg.models.backbones.vit import TransformerEncoderLayer +from mmseg.registry import MODELS +from .decode_head import BaseDecodeHead + + +@MODELS.register_module() +class SegmenterMaskTransformerHead(BaseDecodeHead): + """Segmenter: Transformer for Semantic Segmentation. + + This head is the implementation of + `Segmenter: `_. + + Args: + backbone_cfg:(dict): Config of backbone of + Context Path. + in_channels (int): The number of channels of input image. + num_layers (int): The depth of transformer. + num_heads (int): The number of attention heads. + embed_dims (int): The number of embedding dimension. + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + drop_path_rate (float): stochastic depth rate. Default 0.1. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0 + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + init_std (float): The value of std in weight initialization. + Default: 0.02. 
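+
+    Decoding sketch (shapes illustrative): the patch tokens and
+    ``num_classes`` learned class tokens are refined jointly by the
+    transformer layers; the masks are the L2-normalized patch embeddings
+    multiplied by the transposed class embeddings,
+    [B, N, C] @ [B, C, K] -> [B, N, K], then reshaped to [B, K, H, W].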
+ """ + + def __init__( + self, + in_channels, + num_layers, + num_heads, + embed_dims, + mlp_ratio=4, + drop_path_rate=0.1, + drop_rate=0.0, + attn_drop_rate=0.0, + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_std=0.02, + **kwargs, + ): + super().__init__(in_channels=in_channels, **kwargs) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] + self.layers = ModuleList() + for i in range(num_layers): + self.layers.append( + TransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=mlp_ratio * embed_dims, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + num_fcs=num_fcs, + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + batch_first=True, + )) + + self.dec_proj = nn.Linear(in_channels, embed_dims) + + self.cls_emb = nn.Parameter( + torch.randn(1, self.num_classes, embed_dims)) + self.patch_proj = nn.Linear(embed_dims, embed_dims, bias=False) + self.classes_proj = nn.Linear(embed_dims, embed_dims, bias=False) + + self.decoder_norm = build_norm_layer( + norm_cfg, embed_dims, postfix=1)[1] + self.mask_norm = build_norm_layer( + norm_cfg, self.num_classes, postfix=2)[1] + + self.init_std = init_std + + delattr(self, 'conv_seg') + + def init_weights(self): + trunc_normal_(self.cls_emb, std=self.init_std) + trunc_normal_init(self.patch_proj, std=self.init_std) + trunc_normal_init(self.classes_proj, std=self.init_std) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=self.init_std, bias=0) + elif isinstance(m, nn.LayerNorm): + constant_init(m, val=1.0, bias=0.0) + + def forward(self, inputs): + x = self._transform_inputs(inputs) + b, c, h, w = x.shape + x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c) + + x = self.dec_proj(x) + cls_emb = self.cls_emb.expand(x.size(0), -1, -1) + x = torch.cat((x, cls_emb), 1) + for layer in self.layers: + x = layer(x) + x = self.decoder_norm(x) + + patches = self.patch_proj(x[:, :-self.num_classes]) + cls_seg_feat = self.classes_proj(x[:, -self.num_classes:]) + + patches = F.normalize(patches, dim=2, p=2) + cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2) + + masks = patches @ cls_seg_feat.transpose(1, 2) + masks = self.mask_norm(masks) + masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w) + + return masks diff --git a/mmseg/models/decode_heads/sep_aspp_head.py b/mmseg/models/decode_heads/sep_aspp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..9dba68c9ecc6909e47da4f2da6169d529910355d --- /dev/null +++ b/mmseg/models/decode_heads/sep_aspp_head.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .aspp_head import ASPPHead, ASPPModule + + +class DepthwiseSeparableASPPModule(ASPPModule): + """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable + conv.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + for i, dilation in enumerate(self.dilations): + if dilation > 1: + self[i] = DepthwiseSeparableConvModule( + self.in_channels, + self.channels, + 3, + dilation=dilation, + padding=dilation, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + +@MODELS.register_module() +class DepthwiseSeparableASPPHead(ASPPHead): + """Encoder-Decoder with Atrous Separable Convolution for Semantic Image + Segmentation. 
+
+    This head is the implementation of `DeepLabV3+
+    <https://arxiv.org/abs/1802.02611>`_.
+
+    Args:
+        c1_in_channels (int): The input channels of c1 decoder. If it is 0,
+            no c1 decoder will be used.
+        c1_channels (int): The intermediate channels of c1 decoder.
+    """
+
+    def __init__(self, c1_in_channels, c1_channels, **kwargs):
+        super().__init__(**kwargs)
+        assert c1_in_channels >= 0
+        self.aspp_modules = DepthwiseSeparableASPPModule(
+            dilations=self.dilations,
+            in_channels=self.in_channels,
+            channels=self.channels,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        if c1_in_channels > 0:
+            self.c1_bottleneck = ConvModule(
+                c1_in_channels,
+                c1_channels,
+                1,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg)
+        else:
+            self.c1_bottleneck = None
+        self.sep_bottleneck = nn.Sequential(
+            DepthwiseSeparableConvModule(
+                self.channels + c1_channels,
+                self.channels,
+                3,
+                padding=1,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg),
+            DepthwiseSeparableConvModule(
+                self.channels,
+                self.channels,
+                3,
+                padding=1,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg))
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        aspp_outs = [
+            resize(
+                self.image_pool(x),
+                size=x.size()[2:],
+                mode='bilinear',
+                align_corners=self.align_corners)
+        ]
+        aspp_outs.extend(self.aspp_modules(x))
+        aspp_outs = torch.cat(aspp_outs, dim=1)
+        output = self.bottleneck(aspp_outs)
+        if self.c1_bottleneck is not None:
+            c1_output = self.c1_bottleneck(inputs[0])
+            output = resize(
+                input=output,
+                size=c1_output.shape[2:],
+                mode='bilinear',
+                align_corners=self.align_corners)
+            output = torch.cat([output, c1_output], dim=1)
+        output = self.sep_bottleneck(output)
+        output = self.cls_seg(output)
+        return output
diff --git a/mmseg/models/decode_heads/sep_fcn_head.py b/mmseg/models/decode_heads/sep_fcn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b15983bceaeff48534bbceedfdf1c434a8d1d1f
--- /dev/null
+++ b/mmseg/models/decode_heads/sep_fcn_head.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.cnn import DepthwiseSeparableConvModule
+
+from mmseg.registry import MODELS
+from .fcn_head import FCNHead
+
+
+@MODELS.register_module()
+class DepthwiseSeparableFCNHead(FCNHead):
+    """Depthwise-Separable Fully Convolutional Network for Semantic
+    Segmentation.
+
+    This head is implemented according to `Fast-SCNN: Fast Semantic
+    Segmentation Network <https://arxiv.org/abs/1902.04502>`_.
+
+    Args:
+        in_channels (int): Number of output channels of FFM.
+        channels (int): Number of middle-stage channels in the decode head.
+        concat_input (bool): Whether to concatenate original decode input into
+            the result of several consecutive convolution layers.
+            Default: True.
+        num_classes (int): Used to determine the dimension of the
+            final prediction tensor.
+        in_index (int): Corresponds to 'out_indices' in the FastSCNN backbone.
+        norm_cfg (dict | None): Config of norm layers.
+        align_corners (bool): align_corners argument of F.interpolate.
+            Default: False.
+        loss_decode (dict): Config of loss type and some
+            relevant additional options.
+        dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
+            'default', it will be the same as `act_cfg`. Default: None.
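+
+    Example (illustrative only; the sizes are hypothetical, not from a
+    shipped Fast-SCNN config):
+
+        >>> import torch
+        >>> head = DepthwiseSeparableFCNHead(
+        ...     in_channels=128, channels=128, concat_input=False,
+        ...     num_classes=19, in_index=-1, norm_cfg=None)
+        >>> out = head([torch.randn(1, 128, 32, 64)])
+        >>> out.shape
+        torch.Size([1, 19, 32, 64])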
+    """
+
+    def __init__(self, dw_act_cfg=None, **kwargs):
+        super().__init__(**kwargs)
+        self.convs[0] = DepthwiseSeparableConvModule(
+            self.in_channels,
+            self.channels,
+            kernel_size=self.kernel_size,
+            padding=self.kernel_size // 2,
+            norm_cfg=self.norm_cfg,
+            dw_act_cfg=dw_act_cfg)
+
+        for i in range(1, self.num_convs):
+            self.convs[i] = DepthwiseSeparableConvModule(
+                self.channels,
+                self.channels,
+                kernel_size=self.kernel_size,
+                padding=self.kernel_size // 2,
+                norm_cfg=self.norm_cfg,
+                dw_act_cfg=dw_act_cfg)
+
+        if self.concat_input:
+            self.conv_cat = DepthwiseSeparableConvModule(
+                self.in_channels + self.channels,
+                self.channels,
+                kernel_size=self.kernel_size,
+                padding=self.kernel_size // 2,
+                norm_cfg=self.norm_cfg,
+                dw_act_cfg=dw_act_cfg)
diff --git a/mmseg/models/decode_heads/setr_mla_head.py b/mmseg/models/decode_heads/setr_mla_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..1975991a60cc720650b880060efe10753f213131
--- /dev/null
+++ b/mmseg/models/decode_heads/setr_mla_head.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+
+from mmseg.registry import MODELS
+from ..utils import Upsample
+from .decode_head import BaseDecodeHead
+
+
+@MODELS.register_module()
+class SETRMLAHead(BaseDecodeHead):
+    """Multi-level feature aggregation head of SETR.
+
+    MLA head of `SETR <https://arxiv.org/abs/2012.15840>`_.
+
+    Args:
+        mla_channels (int): Channels of conv-conv-4x of multi-level feature
+            aggregation. Default: 128.
+        up_scale (int): The scale factor of interpolate. Default: 4.
+    """
+
+    def __init__(self, mla_channels=128, up_scale=4, **kwargs):
+        super().__init__(input_transform='multiple_select', **kwargs)
+        self.mla_channels = mla_channels
+
+        num_inputs = len(self.in_channels)
+
+        # Refer to self.cls_seg settings of BaseDecodeHead
+        assert self.channels == num_inputs * mla_channels
+
+        self.up_convs = nn.ModuleList()
+        for i in range(num_inputs):
+            self.up_convs.append(
+                nn.Sequential(
+                    ConvModule(
+                        in_channels=self.in_channels[i],
+                        out_channels=mla_channels,
+                        kernel_size=3,
+                        padding=1,
+                        norm_cfg=self.norm_cfg,
+                        act_cfg=self.act_cfg),
+                    ConvModule(
+                        in_channels=mla_channels,
+                        out_channels=mla_channels,
+                        kernel_size=3,
+                        padding=1,
+                        norm_cfg=self.norm_cfg,
+                        act_cfg=self.act_cfg),
+                    Upsample(
+                        scale_factor=up_scale,
+                        mode='bilinear',
+                        align_corners=self.align_corners)))
+
+    def forward(self, inputs):
+        inputs = self._transform_inputs(inputs)
+        outs = []
+        for x, up_conv in zip(inputs, self.up_convs):
+            outs.append(up_conv(x))
+        out = torch.cat(outs, dim=1)
+        out = self.cls_seg(out)
+        return out
diff --git a/mmseg/models/decode_heads/setr_up_head.py b/mmseg/models/decode_heads/setr_up_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c796d8161088c2d7effe17f5ba71e43ff62e50c
--- /dev/null
+++ b/mmseg/models/decode_heads/setr_up_head.py
@@ -0,0 +1,81 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule, build_norm_layer
+
+from mmseg.registry import MODELS
+from ..utils import Upsample
+from .decode_head import BaseDecodeHead
+
+
+@MODELS.register_module()
+class SETRUPHead(BaseDecodeHead):
+    """Naive upsampling head and Progressive upsampling head of SETR.
+
+    Naive or PUP head of `SETR <https://arxiv.org/abs/2012.15840>`_.
+
+    Args:
+        norm_layer (dict): Config dict for input normalization.
+            Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True).
+        num_convs (int): Number of decoder convolutions. Default: 1.
+ up_scale (int): The scale factor of interpolate. Default:4. + kernel_size (int): The kernel size of convolution when decoding + feature information from backbone. Default: 3. + init_cfg (dict | list[dict] | None): Initialization config dict. + Default: dict( + type='Constant', val=1.0, bias=0, layer='LayerNorm'). + """ + + def __init__(self, + norm_layer=dict(type='LN', eps=1e-6, requires_grad=True), + num_convs=1, + up_scale=4, + kernel_size=3, + init_cfg=[ + dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'), + dict( + type='Normal', + std=0.01, + override=dict(name='conv_seg')) + ], + **kwargs): + + assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.' + + super().__init__(init_cfg=init_cfg, **kwargs) + + assert isinstance(self.in_channels, int) + + _, self.norm = build_norm_layer(norm_layer, self.in_channels) + + self.up_convs = nn.ModuleList() + in_channels = self.in_channels + out_channels = self.channels + for _ in range(num_convs): + self.up_convs.append( + nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=int(kernel_size - 1) // 2, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + Upsample( + scale_factor=up_scale, + mode='bilinear', + align_corners=self.align_corners))) + in_channels = out_channels + + def forward(self, x): + x = self._transform_inputs(x) + + n, c, h, w = x.shape + x = x.reshape(n, c, h * w).transpose(2, 1).contiguous() + x = self.norm(x) + x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() + + for up_conv in self.up_convs: + x = up_conv(x) + out = self.cls_seg(x) + return out diff --git a/mmseg/models/decode_heads/stdc_head.py b/mmseg/models/decode_heads/stdc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1c1c21e3083fcb5098d2458e44538c0cf5b8f0e4 --- /dev/null +++ b/mmseg/models/decode_heads/stdc_head.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F +from mmengine.structures import PixelData +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.structures import SegDataSample +from mmseg.utils import SampleList +from .fcn_head import FCNHead + + +@MODELS.register_module() +class STDCHead(FCNHead): + """This head is the implementation of `Rethinking BiSeNet For Real-time + Semantic Segmentation `_. + + Args: + boundary_threshold (float): The threshold of calculating boundary. + Default: 0.1. + """ + + def __init__(self, boundary_threshold=0.1, **kwargs): + super().__init__(**kwargs) + self.boundary_threshold = boundary_threshold + # Using register buffer to make laplacian kernel on the same + # device of `seg_label`. + self.register_buffer( + 'laplacian_kernel', + torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1], + dtype=torch.float32, + requires_grad=False).reshape((1, 1, 3, 3))) + self.fusion_kernel = torch.nn.Parameter( + torch.tensor([[6. / 10], [3. / 10], [1. / 10]], + dtype=torch.float32).reshape(1, 3, 1, 1), + requires_grad=False) + + def loss_by_feat(self, seg_logits: Tensor, + batch_data_samples: SampleList) -> dict: + """Compute Detail Aggregation Loss.""" + # Note: The paper claims `fusion_kernel` is a trainable 1x1 conv + # parameters. However, it is a constant in original repo and other + # codebase because it would not be added into computation graph + # after threshold operation. 
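+        # The code below: a fixed 3x3 Laplacian kernel extracts boundary
+        # maps from the ground truth at strides 1, 2 and 4; each map is
+        # binarized with `self.boundary_threshold`; the three maps are then
+        # fused by the constant 1x1 conv above (weights 0.6/0.3/0.1) and
+        # re-binarized to form the detail target for the ordinary FCN loss.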
+ seg_label = self._stack_batch_gt(batch_data_samples).to( + self.laplacian_kernel) + boundary_targets = F.conv2d( + seg_label, self.laplacian_kernel, padding=1) + boundary_targets = boundary_targets.clamp(min=0) + boundary_targets[boundary_targets > self.boundary_threshold] = 1 + boundary_targets[boundary_targets <= self.boundary_threshold] = 0 + + boundary_targets_x2 = F.conv2d( + seg_label, self.laplacian_kernel, stride=2, padding=1) + boundary_targets_x2 = boundary_targets_x2.clamp(min=0) + + boundary_targets_x4 = F.conv2d( + seg_label, self.laplacian_kernel, stride=4, padding=1) + boundary_targets_x4 = boundary_targets_x4.clamp(min=0) + + boundary_targets_x4_up = F.interpolate( + boundary_targets_x4, boundary_targets.shape[2:], mode='nearest') + boundary_targets_x2_up = F.interpolate( + boundary_targets_x2, boundary_targets.shape[2:], mode='nearest') + + boundary_targets_x2_up[ + boundary_targets_x2_up > self.boundary_threshold] = 1 + boundary_targets_x2_up[ + boundary_targets_x2_up <= self.boundary_threshold] = 0 + + boundary_targets_x4_up[ + boundary_targets_x4_up > self.boundary_threshold] = 1 + boundary_targets_x4_up[ + boundary_targets_x4_up <= self.boundary_threshold] = 0 + + boundary_targets_pyramids = torch.stack( + (boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up), + dim=1) + + boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2) + boudary_targets_pyramid = F.conv2d(boundary_targets_pyramids, + self.fusion_kernel) + + boudary_targets_pyramid[ + boudary_targets_pyramid > self.boundary_threshold] = 1 + boudary_targets_pyramid[ + boudary_targets_pyramid <= self.boundary_threshold] = 0 + + seg_labels = boudary_targets_pyramid.long() + batch_sample_list = [] + for label in seg_labels: + seg_data_sample = SegDataSample() + seg_data_sample.gt_sem_seg = PixelData(data=label) + batch_sample_list.append(seg_data_sample) + + loss = super().loss_by_feat(seg_logits, batch_sample_list) + return loss diff --git a/mmseg/models/decode_heads/uper_head.py b/mmseg/models/decode_heads/uper_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b1ccc3173c0f1193e89ad48861aa7b5ee3b329cc --- /dev/null +++ b/mmseg/models/decode_heads/uper_head.py @@ -0,0 +1,139 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead +from .psp_head import PPM + + +@MODELS.register_module() +class UPerHead(BaseDecodeHead): + """Unified Perceptual Parsing for Scene Understanding. + + This head is the implementation of `UPerNet + `_. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module applied on the last feature. Default: (1, 2, 3, 6). 
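+
+    Example (illustrative only; channel sizes are hypothetical):
+
+        >>> import torch
+        >>> head = UPerHead(
+        ...     in_channels=[64, 128, 256, 512], channels=128,
+        ...     num_classes=19, in_index=[0, 1, 2, 3], norm_cfg=None)
+        >>> feats = [torch.randn(1, c, 64 // 2**i, 64 // 2**i)
+        ...          for i, c in enumerate([64, 128, 256, 512])]
+        >>> head(feats).shape
+        torch.Size([1, 19, 64, 64])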
+ """ + + def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + # PSP Module + self.psp_modules = PPM( + pool_scales, + self.in_channels[-1], + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.bottleneck = ConvModule( + self.in_channels[-1] + len(pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + # FPN Module + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + for in_channels in self.in_channels[:-1]: # skip the top layer + l_conv = ConvModule( + in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + inplace=False) + fpn_conv = ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + inplace=False) + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + self.fpn_bottleneck = ConvModule( + len(self.in_channels) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def psp_forward(self, inputs): + """Forward function of PSP module.""" + x = inputs[-1] + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = torch.cat(psp_outs, dim=1) + output = self.bottleneck(psp_outs) + + return output + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + inputs = self._transform_inputs(inputs) + + # build laterals + laterals = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + laterals.append(self.psp_forward(inputs)) + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], + size=prev_shape, + mode='bilinear', + align_corners=self.align_corners) + + # build outputs + fpn_outs = [ + self.fpn_convs[i](laterals[i]) + for i in range(used_backbone_levels - 1) + ] + # append psp feature + fpn_outs.append(laterals[-1]) + + for i in range(used_backbone_levels - 1, 0, -1): + fpn_outs[i] = resize( + fpn_outs[i], + size=fpn_outs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) + fpn_outs = torch.cat(fpn_outs, dim=1) + feats = self.fpn_bottleneck(fpn_outs) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/vpd_depth_head.py b/mmseg/models/decode_heads/vpd_depth_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0c54c2da1b1e62b213f794a7d4e49cd3d753ca36 --- /dev/null +++ b/mmseg/models/decode_heads/vpd_depth_head.py @@ -0,0 +1,254 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Dict, List, Optional, Sequence, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer +from mmengine.model import BaseModule +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.utils import SampleList +from ..builder import build_loss +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class VPDDepthDecoder(BaseModule): + """VPD Depth Decoder class. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + num_deconv_layers (int): Number of deconvolution layers. + num_deconv_filters (List[int]): List of output channels for + deconvolution layers. + init_cfg (Optional[Union[Dict, List[Dict]]], optional): Configuration + for weight initialization. Defaults to Normal for Conv2d and + ConvTranspose2d layers. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + num_deconv_layers: int, + num_deconv_filters: List[int], + init_cfg: Optional[Union[Dict, List[Dict]]] = dict( + type='Normal', + std=0.001, + layer=['Conv2d', 'ConvTranspose2d'])): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + + self.deconv_layers = self._make_deconv_layer( + num_deconv_layers, + num_deconv_filters, + ) + + conv_layers = [] + conv_layers.append( + build_conv_layer( + dict(type='Conv2d'), + in_channels=num_deconv_filters[-1], + out_channels=out_channels, + kernel_size=3, + stride=1, + padding=1)) + conv_layers.append(build_norm_layer(dict(type='BN'), out_channels)[1]) + conv_layers.append(nn.ReLU(inplace=True)) + self.conv_layers = nn.Sequential(*conv_layers) + + self.up_sample = nn.Upsample( + scale_factor=2, mode='bilinear', align_corners=False) + + def forward(self, x): + """Forward pass through the decoder network.""" + out = self.deconv_layers(x) + out = self.conv_layers(out) + + out = self.up_sample(out) + out = self.up_sample(out) + + return out + + def _make_deconv_layer(self, num_layers, num_deconv_filters): + """Make deconv layers.""" + + layers = [] + in_channels = self.in_channels + for i in range(num_layers): + + num_channels = num_deconv_filters[i] + layers.append( + build_upsample_layer( + dict(type='deconv'), + in_channels=in_channels, + out_channels=num_channels, + kernel_size=2, + stride=2, + padding=0, + output_padding=0, + bias=False)) + layers.append(nn.BatchNorm2d(num_channels)) + layers.append(nn.ReLU(inplace=True)) + in_channels = num_channels + + return nn.Sequential(*layers) + + +@MODELS.register_module() +class VPDDepthHead(BaseDecodeHead): + """Depth Prediction Head for VPD. + + .. _`VPD`: https://arxiv.org/abs/2303.02153 + + Args: + max_depth (float): Maximum depth value. Defaults to 10.0. + in_channels (Sequence[int]): Number of input channels for each + convolutional layer. + embed_dim (int): Dimension of embedding. Defaults to 192. + feature_dim (int): Dimension of aggregated feature. Defaults to 1536. + num_deconv_layers (int): Number of deconvolution layers in the + decoder. Defaults to 3. + num_deconv_filters (Sequence[int]): Number of filters for each deconv + layer. Defaults to (32, 32, 32). + fmap_border (Union[int, Sequence[int]]): Feature map border for + cropping. Defaults to 0. + align_corners (bool): Flag for align_corners in interpolation. + Defaults to False. + loss_decode (dict): Configurations for the loss function. Defaults to + dict(type='SiLogLoss'). + init_cfg (dict): Initialization configurations. 
Defaults to + dict(type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']). + """ + + num_classes = 1 + out_channels = 1 + input_transform = None + + def __init__( + self, + max_depth: float = 10.0, + in_channels: Sequence[int] = [320, 640, 1280, 1280], + embed_dim: int = 192, + feature_dim: int = 1536, + num_deconv_layers: int = 3, + num_deconv_filters: Sequence[int] = (32, 32, 32), + fmap_border: Union[int, Sequence[int]] = 0, + align_corners: bool = False, + loss_decode: dict = dict(type='SiLogLoss'), + init_cfg=dict( + type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']), + ): + + super(BaseDecodeHead, self).__init__(init_cfg=init_cfg) + + # initialize parameters + self.in_channels = in_channels + self.max_depth = max_depth + self.align_corners = align_corners + + # feature map border + if isinstance(fmap_border, int): + fmap_border = (fmap_border, fmap_border) + self.fmap_border = fmap_border + + # define network layers + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1), + nn.GroupNorm(16, in_channels[0]), + nn.ReLU(), + nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1), + ) + self.conv2 = nn.Conv2d( + in_channels[1], in_channels[1], 3, stride=2, padding=1) + + self.conv_aggregation = nn.Sequential( + nn.Conv2d(sum(in_channels), feature_dim, 1), + nn.GroupNorm(16, feature_dim), + nn.ReLU(), + ) + + self.decoder = VPDDepthDecoder( + in_channels=embed_dim * 8, + out_channels=embed_dim, + num_deconv_layers=num_deconv_layers, + num_deconv_filters=num_deconv_filters) + + self.depth_pred_layer = nn.Sequential( + nn.Conv2d( + embed_dim, embed_dim, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=False), + nn.Conv2d(embed_dim, 1, kernel_size=3, stride=1, padding=1)) + + # build loss + if isinstance(loss_decode, dict): + self.loss_decode = build_loss(loss_decode) + elif isinstance(loss_decode, (list, tuple)): + self.loss_decode = nn.ModuleList() + for loss in loss_decode: + self.loss_decode.append(build_loss(loss)) + else: + raise TypeError(f'loss_decode must be a dict or sequence of dict,\ + but got {type(loss_decode)}') + + def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor: + gt_depth_maps = [ + data_sample.gt_depth_map.data for data_sample in batch_data_samples + ] + return torch.stack(gt_depth_maps, dim=0) + + def forward(self, x): + x = [ + x[0], x[1], + torch.cat([x[2], F.interpolate(x[3], scale_factor=2)], dim=1) + ] + x = torch.cat([self.conv1(x[0]), self.conv2(x[1]), x[2]], dim=1) + x = self.conv_aggregation(x) + + x = x[:, :, :x.size(2) - self.fmap_border[0], :x.size(3) - + self.fmap_border[1]].contiguous() + x = self.decoder(x) + out = self.depth_pred_layer(x) + + depth = torch.sigmoid(out) * self.max_depth + + return depth + + def loss_by_feat(self, pred_depth_map: Tensor, + batch_data_samples: SampleList) -> dict: + """Compute depth estimation loss. + + Args: + pred_depth_map (Tensor): The output from decode head forward + function. + batch_data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_dpeth_map`. 
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + gt_depth_map = self._stack_batch_gt(batch_data_samples) + loss = dict() + pred_depth_map = resize( + input=pred_depth_map, + size=gt_depth_map.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode( + pred_depth_map, gt_depth_map) + else: + loss[loss_decode.loss_name] += loss_decode( + pred_depth_map, gt_depth_map) + + return loss diff --git a/mmseg/models/losses/__init__.py b/mmseg/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0467cb3ad89b8c0c57f7f8eb58cbc2e23f50cdb4 --- /dev/null +++ b/mmseg/models/losses/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .accuracy import Accuracy, accuracy +from .boundary_loss import BoundaryLoss +from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, + cross_entropy, mask_cross_entropy) +from .dice_loss import DiceLoss +from .focal_loss import FocalLoss +from .huasdorff_distance_loss import HuasdorffDisstanceLoss +from .lovasz_loss import LovaszLoss +from .ohem_cross_entropy_loss import OhemCrossEntropy +from .silog_loss import SiLogLoss +from .tversky_loss import TverskyLoss +from .utils import reduce_loss, weight_reduce_loss, weighted_loss + +__all__ = [ + 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', + 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', + 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss', + 'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss', + 'HuasdorffDisstanceLoss', 'SiLogLoss' +] diff --git a/mmseg/models/losses/__pycache__/__init__.cpython-311.pyc b/mmseg/models/losses/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64a8bada502a523bbd3686b45e63e06b493cb11b Binary files /dev/null and b/mmseg/models/losses/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmseg/models/losses/__pycache__/accuracy.cpython-311.pyc b/mmseg/models/losses/__pycache__/accuracy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfe7489e47ec77125f2af6998cfe0ae441c642d0 Binary files /dev/null and b/mmseg/models/losses/__pycache__/accuracy.cpython-311.pyc differ diff --git a/mmseg/models/losses/__pycache__/boundary_loss.cpython-311.pyc b/mmseg/models/losses/__pycache__/boundary_loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfcf42b8179bccd501a1a7f7bb516701871cad9d Binary files /dev/null and b/mmseg/models/losses/__pycache__/boundary_loss.cpython-311.pyc differ diff --git a/mmseg/models/losses/__pycache__/cross_entropy_loss.cpython-311.pyc b/mmseg/models/losses/__pycache__/cross_entropy_loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18446b0722a16466835d932c85c31c7cc22493ad Binary files /dev/null and b/mmseg/models/losses/__pycache__/cross_entropy_loss.cpython-311.pyc differ diff --git a/mmseg/models/losses/__pycache__/dice_loss.cpython-311.pyc b/mmseg/models/losses/__pycache__/dice_loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77bba5d1c7ed5219f10e908561c3f945d2f737de Binary files /dev/null and 
b/mmseg/models/losses/__pycache__/dice_loss.cpython-311.pyc differ diff --git a/mmseg/models/losses/__pycache__/focal_loss.cpython-311.pyc b/mmseg/models/losses/__pycache__/focal_loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dbf4322cfadba2f090937f6b464fa5ade8ea6eb Binary files /dev/null and b/mmseg/models/losses/__pycache__/focal_loss.cpython-311.pyc differ diff --git a/mmseg/models/losses/__pycache__/huasdorff_distance_loss.cpython-311.pyc b/mmseg/models/losses/__pycache__/huasdorff_distance_loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09fd54294d9f1cbde10f28fae4f8b597444c4b0b Binary files /dev/null and b/mmseg/models/losses/__pycache__/huasdorff_distance_loss.cpython-311.pyc differ diff --git a/mmseg/models/losses/__pycache__/lovasz_loss.cpython-311.pyc b/mmseg/models/losses/__pycache__/lovasz_loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3920ef966d39faad84e7ce5bb03aac8d4cec7147 Binary files /dev/null and b/mmseg/models/losses/__pycache__/lovasz_loss.cpython-311.pyc differ diff --git a/mmseg/models/losses/__pycache__/ohem_cross_entropy_loss.cpython-311.pyc b/mmseg/models/losses/__pycache__/ohem_cross_entropy_loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7535672c8c1605fd830df012bf29a2a17d2351b0 Binary files /dev/null and b/mmseg/models/losses/__pycache__/ohem_cross_entropy_loss.cpython-311.pyc differ diff --git a/mmseg/models/losses/__pycache__/silog_loss.cpython-311.pyc b/mmseg/models/losses/__pycache__/silog_loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8eb1f8128acb7aac0516b7906a9be08c654a7fab Binary files /dev/null and b/mmseg/models/losses/__pycache__/silog_loss.cpython-311.pyc differ diff --git a/mmseg/models/losses/__pycache__/tversky_loss.cpython-311.pyc b/mmseg/models/losses/__pycache__/tversky_loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a72c5c1eba17980b4f2af4f0c394e67974eec45 Binary files /dev/null and b/mmseg/models/losses/__pycache__/tversky_loss.cpython-311.pyc differ diff --git a/mmseg/models/losses/__pycache__/utils.cpython-311.pyc b/mmseg/models/losses/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd385c852b553cdc92ce427012f051de36e4f93c Binary files /dev/null and b/mmseg/models/losses/__pycache__/utils.cpython-311.pyc differ diff --git a/mmseg/models/losses/accuracy.py b/mmseg/models/losses/accuracy.py new file mode 100644 index 0000000000000000000000000000000000000000..1d9e2d7701088adadd5b6bb71c718c986b87a066 --- /dev/null +++ b/mmseg/models/losses/accuracy.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + + +def accuracy(pred, target, topk=1, thresh=None, ignore_index=None): + """Calculate accuracy according to the prediction and target. + + Args: + pred (torch.Tensor): The model prediction, shape (N, num_class, ...) + target (torch.Tensor): The target of each prediction, shape (N, , ...) + ignore_index (int | None): The label index to be ignored. Default: None + topk (int | tuple[int], optional): If the predictions in ``topk`` + matches the target, the predictions will be regarded as + correct ones. Defaults to 1. + thresh (float, optional): If not None, predictions with scores under + this threshold are considered incorrect. Default to None. 
+ + Returns: + float | tuple[float]: If the input ``topk`` is a single integer, + the function will return a single float as accuracy. If + ``topk`` is a tuple containing multiple integers, the + function will return a tuple containing accuracies of + each ``topk`` number. + """ + assert isinstance(topk, (int, tuple)) + if isinstance(topk, int): + topk = (topk, ) + return_single = True + else: + return_single = False + + maxk = max(topk) + if pred.size(0) == 0: + accu = [pred.new_tensor(0.) for i in range(len(topk))] + return accu[0] if return_single else accu + assert pred.ndim == target.ndim + 1 + assert pred.size(0) == target.size(0) + assert maxk <= pred.size(1), \ + f'maxk {maxk} exceeds pred dimension {pred.size(1)}' + pred_value, pred_label = pred.topk(maxk, dim=1) + # transpose to shape (maxk, N, ...) + pred_label = pred_label.transpose(0, 1) + correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) + if thresh is not None: + # Only prediction values larger than thresh are counted as correct + correct = correct & (pred_value > thresh).t() + if ignore_index is not None: + correct = correct[:, target != ignore_index] + res = [] + eps = torch.finfo(torch.float32).eps + for k in topk: + # Avoid causing ZeroDivisionError when all pixels + # of an image are ignored + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps + if ignore_index is not None: + total_num = target[target != ignore_index].numel() + eps + else: + total_num = target.numel() + eps + res.append(correct_k.mul_(100.0 / total_num)) + return res[0] if return_single else res + + +class Accuracy(nn.Module): + """Accuracy calculation module.""" + + def __init__(self, topk=(1, ), thresh=None, ignore_index=None): + """Module to calculate the accuracy. + + Args: + topk (tuple, optional): The criterion used to calculate the + accuracy. Defaults to (1,). + thresh (float, optional): If not None, predictions with scores + under this threshold are considered incorrect. Default to None. + """ + super().__init__() + self.topk = topk + self.thresh = thresh + self.ignore_index = ignore_index + + def forward(self, pred, target): + """Forward function to calculate accuracy. + + Args: + pred (torch.Tensor): Prediction of models. + target (torch.Tensor): Target for each prediction. + + Returns: + tuple[float]: The accuracies under different topk criterions. + """ + return accuracy(pred, target, self.topk, self.thresh, + self.ignore_index) diff --git a/mmseg/models/losses/boundary_loss.py b/mmseg/models/losses/boundary_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e86b850d87e1d26be8cbb700758dae8dead82c58 --- /dev/null +++ b/mmseg/models/losses/boundary_loss.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from mmseg.registry import MODELS + + +@MODELS.register_module() +class BoundaryLoss(nn.Module): + """Boundary loss. + + This function is modified from + `PIDNet `_. # noqa + Licensed under the MIT License. + + + Args: + loss_weight (float): Weight of the loss. Defaults to 1.0. + loss_name (str): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_boundary'. 
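+
+    Example (illustrative only; the tensors are random stand-ins for a
+    boundary-head prediction and a binary boundary map):
+
+        >>> import torch
+        >>> criterion = BoundaryLoss(loss_weight=1.0)
+        >>> bd_pre = torch.randn(2, 1, 16, 16)
+        >>> bd_gt = (torch.rand(2, 16, 16) > 0.9).float()
+        >>> loss = criterion(bd_pre, bd_gt)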
+ """ + + def __init__(self, + loss_weight: float = 1.0, + loss_name: str = 'loss_boundary'): + super().__init__() + self.loss_weight = loss_weight + self.loss_name_ = loss_name + + def forward(self, bd_pre: Tensor, bd_gt: Tensor) -> Tensor: + """Forward function. + Args: + bd_pre (Tensor): Predictions of the boundary head. + bd_gt (Tensor): Ground truth of the boundary. + + Returns: + Tensor: Loss tensor. + """ + log_p = bd_pre.permute(0, 2, 3, 1).contiguous().view(1, -1) + target_t = bd_gt.view(1, -1).float() + + pos_index = (target_t == 1) + neg_index = (target_t == 0) + + weight = torch.zeros_like(log_p) + pos_num = pos_index.sum() + neg_num = neg_index.sum() + sum_num = pos_num + neg_num + weight[pos_index] = neg_num * 1.0 / sum_num + weight[neg_index] = pos_num * 1.0 / sum_num + + loss = F.binary_cross_entropy_with_logits( + log_p, target_t, weight, reduction='mean') + + return self.loss_weight * loss + + @property + def loss_name(self): + return self.loss_name_ diff --git a/mmseg/models/losses/cross_entropy_loss.py b/mmseg/models/losses/cross_entropy_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..988fb789c11eca9d002b2c02f227450d704aeaef --- /dev/null +++ b/mmseg/models/losses/cross_entropy_loss.py @@ -0,0 +1,311 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmseg.registry import MODELS +from .utils import get_class_weight, weight_reduce_loss + + +def cross_entropy(pred, + label, + weight=None, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=-100, + avg_non_ignore=False): + """cross_entropy. The wrapper function for :func:`F.cross_entropy` + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + Default: None. + class_weight (list[float], optional): The weight for each class. + Default: None. + reduction (str, optional): The method used to reduce the loss. + Options are 'none', 'mean' and 'sum'. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Default: None. + ignore_index (int): Specifies a target value that is ignored and + does not contribute to the input gradients. When + ``avg_non_ignore `` is ``True``, and the ``reduction`` is + ``''mean''``, the loss is averaged over non-ignored targets. + Defaults: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + """ + + # class_weight is a manual rescaling weight given to each class. 
+ # If given, has to be a Tensor of size C element-wise losses + loss = F.cross_entropy( + pred, + label, + weight=class_weight, + reduction='none', + ignore_index=ignore_index) + + # apply weights and do the reduction + # average loss over non-ignored elements + # pytorch's official cross_entropy average loss over non-ignored elements + # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa + if (avg_factor is None) and reduction == 'mean': + if class_weight is None: + if avg_non_ignore: + avg_factor = label.numel() - (label + == ignore_index).sum().item() + else: + avg_factor = label.numel() + + else: + # the average factor should take the class weights into account + label_weights = torch.stack([class_weight[cls] for cls in label + ]).to(device=class_weight.device) + + if avg_non_ignore: + label_weights[label == ignore_index] = 0 + avg_factor = label_weights.sum() + + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index): + """Expand onehot labels to match the size of prediction.""" + bin_labels = labels.new_zeros(target_shape) + valid_mask = (labels >= 0) & (labels != ignore_index) + inds = torch.nonzero(valid_mask, as_tuple=True) + + if inds[0].numel() > 0: + if labels.dim() == 3: + bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1 + else: + bin_labels[inds[0], labels[valid_mask]] = 1 + + valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float() + + if label_weights is None: + bin_label_weights = valid_mask + else: + bin_label_weights = label_weights.unsqueeze(1).expand(target_shape) + bin_label_weights = bin_label_weights * valid_mask + + return bin_labels, bin_label_weights, valid_mask + + +def binary_cross_entropy(pred, + label, + weight=None, + reduction='mean', + avg_factor=None, + class_weight=None, + ignore_index=-100, + avg_non_ignore=False, + **kwargs): + """Calculate the binary CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + Note: In bce loss, label < 0 is invalid. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (int): The label index to be ignored. Default: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + + Returns: + torch.Tensor: The calculated loss + """ + if pred.size(1) == 1: + # For binary class segmentation, the shape of pred is + # [N, 1, H, W] and that of label is [N, H, W]. 
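+        # (e.g. a hypothetical binary batch: pred of shape [2, 1, 64, 64]
+        # with label values drawn from {0, 1, 255})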
+        # As the ignore_index is often set to 255, the binary class
+        # label check should mask out the ignore_index first
+        assert label[label != ignore_index].max() <= 1, \
+            'For pred with shape [N, 1, H, W], its label must have at ' \
+            'most 2 classes'
+        pred = pred.squeeze(1)
+    if pred.dim() != label.dim():
+        assert (pred.dim() == 2 and label.dim() == 1) or (
+                pred.dim() == 4 and label.dim() == 3), \
+            'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
+            'H, W], label shape [N, H, W] are supported'
+        # `weight` returned from `_expand_onehot_labels`
+        # has been treated for valid (non-ignore) pixels
+        label, weight, valid_mask = _expand_onehot_labels(
+            label, weight, pred.shape, ignore_index)
+    else:
+        # should mask out the ignored elements
+        valid_mask = ((label >= 0) & (label != ignore_index)).float()
+        if weight is not None:
+            weight = weight * valid_mask
+        else:
+            weight = valid_mask
+        # average loss over non-ignored and valid elements
+        if reduction == 'mean' and avg_factor is None and avg_non_ignore:
+            avg_factor = valid_mask.sum().item()
+
+    loss = F.binary_cross_entropy_with_logits(
+        pred, label.float(), pos_weight=class_weight, reduction='none')
+    # do the reduction for the weighted loss
+    loss = weight_reduce_loss(
+        loss, weight, reduction=reduction, avg_factor=avg_factor)
+
+    return loss
+
+
+def mask_cross_entropy(pred,
+                       target,
+                       label,
+                       reduction='mean',
+                       avg_factor=None,
+                       class_weight=None,
+                       ignore_index=None,
+                       **kwargs):
+    """Calculate the CrossEntropy loss for masks.
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, C), C is the number
+            of classes.
+        target (torch.Tensor): The learning label of the prediction.
+        label (torch.Tensor): ``label`` indicates the class label of the
+            mask's corresponding object. It is used to select the mask of
+            the class the object belongs to when the mask prediction is
+            not class-agnostic.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are "none", "mean" and "sum".
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+        class_weight (list[float], optional): The weight for each class.
+        ignore_index (None): Placeholder, to be consistent with other loss.
+            Default: None.
+
+    Returns:
+        torch.Tensor: The calculated loss
+    """
+    assert ignore_index is None, 'BCE loss does not support ignore_index'
+    # TODO: handle these two reserved arguments
+    assert reduction == 'mean' and avg_factor is None
+    num_rois = pred.size()[0]
+    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
+    pred_slice = pred[inds, label].squeeze(1)
+    return F.binary_cross_entropy_with_logits(
+        pred_slice, target, weight=class_weight, reduction='mean')[None]
+
+
+@MODELS.register_module()
+class CrossEntropyLoss(nn.Module):
+    """CrossEntropyLoss.
+
+    Args:
+        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+            or softmax. Defaults to False.
+        use_mask (bool, optional): Whether to use mask cross entropy loss.
+            Defaults to False.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are "none", "mean" and "sum". Defaults to 'mean'.
+        class_weight (list[float] | str, optional): Weight of each class. If in
+            str format, read them from a file. Defaults to None.
+        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+        loss_name (str, optional): Name of the loss item. If you want this loss
+            item to be included into the backward graph, `loss_` must be the
+            prefix of the name. Defaults to 'loss_ce'.
+ avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + """ + + def __init__(self, + use_sigmoid=False, + use_mask=False, + reduction='mean', + class_weight=None, + loss_weight=1.0, + loss_name='loss_ce', + avg_non_ignore=False): + super().__init__() + assert (use_sigmoid is False) or (use_mask is False) + self.use_sigmoid = use_sigmoid + self.use_mask = use_mask + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + self.avg_non_ignore = avg_non_ignore + if not self.avg_non_ignore and self.reduction == 'mean': + warnings.warn( + 'Default ``avg_non_ignore`` is False, if you would like to ' + 'ignore the certain label and average loss over non-ignore ' + 'labels, which is the same with PyTorch official ' + 'cross_entropy, set ``avg_non_ignore=True``.') + + if self.use_sigmoid: + self.cls_criterion = binary_cross_entropy + elif self.use_mask: + self.cls_criterion = mask_cross_entropy + else: + self.cls_criterion = cross_entropy + self._loss_name = loss_name + + def extra_repr(self): + """Extra repr.""" + s = f'avg_non_ignore={self.avg_non_ignore}' + return s + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + ignore_index=-100, + **kwargs): + """Forward function.""" + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + # Note: for BCE loss, label < 0 is invalid. + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + avg_non_ignore=self.avg_non_ignore, + ignore_index=ignore_index, + **kwargs) + return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/mmseg/models/losses/dice_loss.py b/mmseg/models/losses/dice_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..fb2ffdba8daf867032b6d7b4e0d70a9b7a0c50fe --- /dev/null +++ b/mmseg/models/losses/dice_loss.py @@ -0,0 +1,202 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +import torch +import torch.nn as nn + +from mmseg.registry import MODELS +from .utils import weight_reduce_loss + + +def _expand_onehot_labels_dice(pred: torch.Tensor, + target: torch.Tensor) -> torch.Tensor: + """Expand onehot labels to match the size of prediction. + + Args: + pred (torch.Tensor): The prediction, has a shape (N, num_class, H, W). + target (torch.Tensor): The learning label of the prediction, + has a shape (N, H, W). + + Returns: + torch.Tensor: The target after one-hot encoding, + has a shape (N, num_class, H, W). 
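+
+    Example (illustrative):
+
+        >>> import torch
+        >>> pred = torch.randn(1, 3, 2, 2)
+        >>> target = torch.tensor([[[0, 1], [2, 0]]])
+        >>> _expand_onehot_labels_dice(pred, target).shape
+        torch.Size([1, 3, 2, 2])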
+    """
+    num_classes = pred.shape[1]
+    one_hot_target = torch.clamp(target, min=0, max=num_classes)
+    one_hot_target = torch.nn.functional.one_hot(one_hot_target,
+                                                 num_classes + 1)
+    one_hot_target = one_hot_target[..., :num_classes].permute(0, 3, 1, 2)
+    return one_hot_target
+
+
+def dice_loss(pred: torch.Tensor,
+              target: torch.Tensor,
+              weight: Union[torch.Tensor, None],
+              eps: float = 1e-3,
+              reduction: Union[str, None] = 'mean',
+              naive_dice: Union[bool, None] = False,
+              avg_factor: Union[int, None] = None,
+              ignore_index: Union[int, None] = 255) -> float:
+    """Calculate dice loss; two forms of dice loss are supported:
+
+    - the one proposed in `V-Net: Fully Convolutional Neural
+      Networks for Volumetric Medical Image Segmentation
+      <https://arxiv.org/abs/1606.04797>`_.
+    - the dice loss in which the power of the number in the
+      denominator is the first power instead of the second
+      power.
+
+    Args:
+        pred (torch.Tensor): The prediction, has a shape (n, *)
+        target (torch.Tensor): The learning label of the prediction,
+            shape (n, *), same shape as pred.
+        weight (torch.Tensor, optional): The weight of loss for each
+            prediction, has a shape (n,). Defaults to None.
+        eps (float): Avoid dividing by zero. Default: 1e-3.
+        reduction (str, optional): The method used to reduce the loss into
+            a scalar. Defaults to 'mean'.
+            Options are "none", "mean" and "sum".
+        naive_dice (bool, optional): If false, use the dice
+            loss defined in the V-Net paper, otherwise, use the
+            naive dice loss in which the power of the number in the
+            denominator is the first power instead of the second
+            power. Defaults to False.
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+        ignore_index (int, optional): The label index to be ignored.
+            Defaults to 255.
+    """
+    if ignore_index is not None:
+        num_classes = pred.shape[1]
+        pred = pred[:, torch.arange(num_classes) != ignore_index, :, :]
+        target = target[:, torch.arange(num_classes) != ignore_index, :, :]
+        assert pred.shape[1] != 0  # if the ignored index is the only class
+    input = pred.flatten(1)
+    target = target.flatten(1).float()
+    a = torch.sum(input * target, 1)
+    if naive_dice:
+        b = torch.sum(input, 1)
+        c = torch.sum(target, 1)
+        d = (2 * a + eps) / (b + c + eps)
+    else:
+        b = torch.sum(input * input, 1) + eps
+        c = torch.sum(target * target, 1) + eps
+        d = (2 * a) / (b + c)
+
+    loss = 1 - d
+    if weight is not None:
+        assert weight.ndim == loss.ndim
+        assert len(weight) == len(pred)
+    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+    return loss
+
+
+@MODELS.register_module()
+class DiceLoss(nn.Module):
+
+    def __init__(self,
+                 use_sigmoid=True,
+                 activate=True,
+                 reduction='mean',
+                 naive_dice=False,
+                 loss_weight=1.0,
+                 ignore_index=255,
+                 eps=1e-3,
+                 loss_name='loss_dice'):
+        """Compute dice loss.
+
+        Args:
+            use_sigmoid (bool, optional): Whether the prediction uses
+                sigmoid or softmax. Defaults to True.
+            activate (bool): Whether to apply the activation inside the
+                loss; if False, the internal sigmoid/softmax is skipped.
+                Defaults to True.
+            reduction (str, optional): The method used
+                to reduce the loss. Options are "none",
+                "mean" and "sum". Defaults to 'mean'.
+            naive_dice (bool, optional): If false, use the dice
+                loss defined in the V-Net paper, otherwise, use the
+                naive dice loss in which the power of the number in the
+                denominator is the first power instead of the second
+                power. Defaults to False.
+            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+ ignore_index (int, optional): The label index to be ignored. + Default: 255. + eps (float): Avoid dividing by zero. Defaults to 1e-3. + loss_name (str, optional): Name of the loss item. If you want this + loss item to be included into the backward graph, `loss_` must + be the prefix of the name. Defaults to 'loss_dice'. + """ + + super().__init__() + self.use_sigmoid = use_sigmoid + self.reduction = reduction + self.naive_dice = naive_dice + self.loss_weight = loss_weight + self.eps = eps + self.activate = activate + self.ignore_index = ignore_index + self._loss_name = loss_name + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + ignore_index=255, + **kwargs): + """Forward function. + + Args: + pred (torch.Tensor): The prediction, has a shape (n, *). + target (torch.Tensor): The label of the prediction, + shape (n, *), same shape of pred. + weight (torch.Tensor, optional): The weight of loss for each + prediction, has a shape (n,). Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Options are "none", "mean" and "sum". + + Returns: + torch.Tensor: The calculated loss + """ + one_hot_target = target + if (pred.shape != target.shape): + one_hot_target = _expand_onehot_labels_dice(pred, target) + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.activate: + if self.use_sigmoid: + pred = pred.sigmoid() + elif pred.shape[1] != 1: + # softmax does not work when there is only 1 class + pred = pred.softmax(dim=1) + loss = self.loss_weight * dice_loss( + pred, + one_hot_target, + weight, + eps=self.eps, + reduction=reduction, + naive_dice=self.naive_dice, + avg_factor=avg_factor, + ignore_index=self.ignore_index) + + return loss + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/mmseg/models/losses/focal_loss.py b/mmseg/models/losses/focal_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..6507ed7a9112993733ac25bc095da0b571e14363 --- /dev/null +++ b/mmseg/models/losses/focal_loss.py @@ -0,0 +1,337 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/open-mmlab/mmdetection +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss + +from mmseg.registry import MODELS +from .utils import weight_reduce_loss + + +# This method is used when cuda is not available +def py_sigmoid_focal_loss(pred, + target, + one_hot_target=None, + weight=None, + gamma=2.0, + alpha=0.5, + class_weight=None, + valid_mask=None, + reduction='mean', + avg_factor=None): + """PyTorch version of `Focal Loss `_. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the + number of classes + target (torch.Tensor): The learning label of the prediction with + shape (N, C) + one_hot_target (None): Placeholder. It should be None. 
+ weight (torch.Tensor, optional): Sample-wise loss weight. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float | list[float], optional): A balanced form for Focal Loss. + Defaults to 0.5. + class_weight (list[float], optional): Weight of each class. + Defaults to None. + valid_mask (torch.Tensor, optional): A mask uses 1 to mark the valid + samples and uses 0 to mark the ignored samples. Default: None. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + if isinstance(alpha, list): + alpha = pred.new_tensor(alpha) + pred_sigmoid = pred.sigmoid() + target = target.type_as(pred) + one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) + focal_weight = (alpha * target + (1 - alpha) * + (1 - target)) * one_minus_pt.pow(gamma) + + loss = F.binary_cross_entropy_with_logits( + pred, target, reduction='none') * focal_weight + final_weight = torch.ones(1, pred.size(1)).type_as(loss) + if weight is not None: + if weight.shape != loss.shape and weight.size(0) == loss.size(0): + # For most cases, weight is of shape (N, ), + # which means it does not have the second axis num_class + weight = weight.view(-1, 1) + assert weight.dim() == loss.dim() + final_weight = final_weight * weight + if class_weight is not None: + final_weight = final_weight * pred.new_tensor(class_weight) + if valid_mask is not None: + final_weight = final_weight * valid_mask + loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor) + return loss + + +def sigmoid_focal_loss(pred, + target, + one_hot_target, + weight=None, + gamma=2.0, + alpha=0.5, + class_weight=None, + valid_mask=None, + reduction='mean', + avg_factor=None): + r"""A wrapper of cuda version `Focal Loss + `_. + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + target (torch.Tensor): The learning label of the prediction. It's shape + should be (N, ) + one_hot_target (torch.Tensor): The learning label with shape (N, C) + weight (torch.Tensor, optional): Sample-wise loss weight. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float | list[float], optional): A balanced form for Focal Loss. + Defaults to 0.5. + class_weight (list[float], optional): Weight of each class. + Defaults to None. + valid_mask (torch.Tensor, optional): A mask uses 1 to mark the valid + samples and uses 0 to mark the ignored samples. Default: None. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + # Function.apply does not accept keyword arguments, so the decorator + # "weighted_loss" is not applicable + final_weight = torch.ones(1, pred.size(1)).type_as(pred) + if isinstance(alpha, list): + # _sigmoid_focal_loss doesn't accept alpha of list type. Therefore, if + # a list is given, we set the input alpha as 0.5. This means setting + # equal weight for foreground class and background class. By + # multiplying the loss by 2, the effect of setting alpha as 0.5 is + # undone. The alpha of type list is used to regulate the loss in the + # post-processing process. 
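+        # For instance, with a hypothetical alpha=[0.25, 0.75], foreground
+        # pixels of class 0 end up weighted by 0.25 and its background
+        # pixels by 0.75; this reweighting is applied below through
+        # `final_weight` rather than inside the CUDA op.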
+ loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), + gamma, 0.5, None, 'none') * 2 + alpha = pred.new_tensor(alpha) + final_weight = final_weight * ( + alpha * one_hot_target + (1 - alpha) * (1 - one_hot_target)) + else: + loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), + gamma, alpha, None, 'none') + if weight is not None: + if weight.shape != loss.shape and weight.size(0) == loss.size(0): + # For most cases, weight is of shape (N, ), + # which means it does not have the second axis num_class + weight = weight.view(-1, 1) + assert weight.dim() == loss.dim() + final_weight = final_weight * weight + if class_weight is not None: + final_weight = final_weight * pred.new_tensor(class_weight) + if valid_mask is not None: + final_weight = final_weight * valid_mask + loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor) + return loss + + +@MODELS.register_module() +class FocalLoss(nn.Module): + + def __init__(self, + use_sigmoid=True, + gamma=2.0, + alpha=0.5, + reduction='mean', + class_weight=None, + loss_weight=1.0, + loss_name='loss_focal'): + """`Focal Loss `_ + Args: + use_sigmoid (bool, optional): Whether to the prediction is + used for sigmoid or softmax. Defaults to True. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float | list[float], optional): A balanced form for Focal + Loss. Defaults to 0.5. When a list is provided, the length + of the list should be equal to the number of classes. + Please be careful that this parameter is not the + class-wise weight but the weight of a binary classification + problem. This binary classification problem regards the + pixels which belong to one class as the foreground + and the other pixels as the background, each element in + the list is the weight of the corresponding foreground class. + The value of alpha or each element of alpha should be a float + in the interval [0, 1]. If you want to specify the class-wise + weight, please use `class_weight` parameter. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. Options are "none", "mean" and + "sum". + class_weight (list[float], optional): Weight of each class. + Defaults to None. + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. If you want this + loss item to be included into the backward graph, `loss_` must + be the prefix of the name. Defaults to 'loss_focal'. + """ + super().__init__() + assert use_sigmoid is True, \ + 'AssertionError: Only sigmoid focal loss supported now.' 
+ assert reduction in ('none', 'mean', 'sum'), \ + "AssertionError: reduction should be 'none', 'mean' or " \ + "'sum'" + assert isinstance(alpha, (float, list)), \ + 'AssertionError: alpha should be of type float' + assert isinstance(gamma, float), \ + 'AssertionError: gamma should be of type float' + assert isinstance(loss_weight, float), \ + 'AssertionError: loss_weight should be of type float' + assert isinstance(loss_name, str), \ + 'AssertionError: loss_name should be of type str' + assert isinstance(class_weight, list) or class_weight is None, \ + 'AssertionError: class_weight must be None or of type list' + self.use_sigmoid = use_sigmoid + self.gamma = gamma + self.alpha = alpha + self.reduction = reduction + self.class_weight = class_weight + self.loss_weight = loss_weight + self._loss_name = loss_name + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + ignore_index=255, + **kwargs): + """Forward function. + + Args: + pred (torch.Tensor): The prediction with shape + (N, C) where C = number of classes, or + (N, C, d_1, d_2, ..., d_K) with K≥1 in the + case of K-dimensional loss. + target (torch.Tensor): The ground truth. If containing class + indices, shape (N) where each value is 0≤targets[i]≤C−1, + or (N, d_1, d_2, ..., d_K) with K≥1 in the case of + K-dimensional loss. If containing class probabilities, + same shape as the input. + weight (torch.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to + average the loss. Defaults to None. + reduction_override (str, optional): The reduction method used + to override the original reduction method of the loss. + Options are "none", "mean" and "sum". + ignore_index (int, optional): The label index to be ignored. + Default: 255 + Returns: + torch.Tensor: The calculated loss + """ + assert isinstance(ignore_index, int), \ + 'ignore_index must be of type int' + assert reduction_override in (None, 'none', 'mean', 'sum'), \ + "AssertionError: reduction should be 'none', 'mean' or " \ + "'sum'" + assert pred.shape == target.shape or \ + (pred.size(0) == target.size(0) and + pred.shape[2:] == target.shape[1:]), \ + "The shape of pred doesn't match the shape of target" + + original_shape = pred.shape + + # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k] + pred = pred.transpose(0, 1) + # [C, B, d_1, d_2, ..., d_k] -> [C, N] + pred = pred.reshape(pred.size(0), -1) + # [C, N] -> [N, C] + pred = pred.transpose(0, 1).contiguous() + + if original_shape == target.shape: + # target with shape [B, C, d_1, d_2, ...] + # transform it's shape into [N, C] + # [B, C, d_1, d_2, ...] -> [C, B, d_1, d_2, ..., d_k] + target = target.transpose(0, 1) + # [C, B, d_1, d_2, ..., d_k] -> [C, N] + target = target.reshape(target.size(0), -1) + # [C, N] -> [N, C] + target = target.transpose(0, 1).contiguous() + else: + # target with shape [B, d_1, d_2, ...] 
+            # transform its shape into [N, ]
+            target = target.view(-1).contiguous()
+            valid_mask = (target != ignore_index).view(-1, 1)
+            # avoid raising error when using F.one_hot()
+            target = torch.where(target == ignore_index, target.new_tensor(0),
+                                 target)
+
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+        if self.use_sigmoid:
+            num_classes = pred.size(1)
+            if torch.cuda.is_available() and pred.is_cuda:
+                if target.dim() == 1:
+                    one_hot_target = F.one_hot(
+                        target, num_classes=num_classes + 1)
+                    if num_classes == 1:
+                        one_hot_target = one_hot_target[:, 1]
+                        target = 1 - target
+                    else:
+                        one_hot_target = one_hot_target[:, :num_classes]
+                else:
+                    one_hot_target = target
+                    target = target.argmax(dim=1)
+                    valid_mask = (target != ignore_index).view(-1, 1)
+                calculate_loss_func = sigmoid_focal_loss
+            else:
+                one_hot_target = None
+                if target.dim() == 1:
+                    target = F.one_hot(target, num_classes=num_classes + 1)
+                    if num_classes == 1:
+                        target = target[:, 1]
+                    else:
+                        # keep the first `num_classes` columns, dropping the
+                        # extra column produced by `num_classes + 1` above
+                        target = target[:, :num_classes]
+                else:
+                    valid_mask = (target.argmax(dim=1) != ignore_index).view(
+                        -1, 1)
+                calculate_loss_func = py_sigmoid_focal_loss
+
+            loss_cls = self.loss_weight * calculate_loss_func(
+                pred,
+                target,
+                one_hot_target,
+                weight,
+                gamma=self.gamma,
+                alpha=self.alpha,
+                class_weight=self.class_weight,
+                valid_mask=valid_mask,
+                reduction=reduction,
+                avg_factor=avg_factor)
+
+            if reduction == 'none':
+                # [N, C] -> [C, N]
+                loss_cls = loss_cls.transpose(0, 1)
+                # [C, N] -> [C, B, d1, d2, ...]
+                # original_shape: [B, C, d1, d2, ...]
+                loss_cls = loss_cls.reshape(original_shape[1],
+                                            original_shape[0],
+                                            *original_shape[2:])
+                # [C, B, d1, d2, ...] -> [B, C, d1, d2, ...]
+                loss_cls = loss_cls.transpose(0, 1).contiguous()
+        else:
+            raise NotImplementedError
+        return loss_cls
+
+    @property
+    def loss_name(self):
+        """Loss Name.
+
+        This function must be implemented and will return the name of this
+        loss function. This name will be used to combine different loss items
+        by simple sum operation. In addition, if you want this loss item to be
+        included into the backward graph, `loss_` must be the prefix of the
+        name.
+        Returns:
+            str: The name of this loss item.
+        """
+        return self._loss_name
diff --git a/mmseg/models/losses/huasdorff_distance_loss.py b/mmseg/models/losses/huasdorff_distance_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..d950ba728f8d419ea2b291e2159b926aca44038c
--- /dev/null
+++ b/mmseg/models/losses/huasdorff_distance_loss.py
@@ -0,0 +1,160 @@
+# Copyright (c) OpenMMLab. All rights reserved.
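+# Overview (summarized from the code below): for each non-ignored foreground
+# class c, the loss is
+#     mean_x (s_c(x) - g(x))^2 * (DTM_s(x)^2 + DTM_g(x)^2),
+# where DTM is the Euclidean distance-transform map of the corresponding
+# mask, so errors far from the segmentation boundary are penalised more
+# heavily than errors near it.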
+"""Modified from https://github.com/JunMa11/SegWithDistMap/blob/ +master/code/train_LA_HD.py (Apache-2.0 License)""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy.ndimage import distance_transform_edt as distance +from torch import Tensor + +from mmseg.registry import MODELS +from .utils import get_class_weight, weighted_loss + + +def compute_dtm(img_gt: Tensor, pred: Tensor) -> Tensor: + """ + compute the distance transform map of foreground in mask + Args: + img_gt: Ground truth of the image, (b, h, w) + pred: Predictions of the segmentation head after softmax, (b, c, h, w) + + Returns: + output: the foreground Distance Map (SDM) + dtm(x) = 0; x in segmentation boundary + inf|x-y|; x in segmentation + """ + + fg_dtm = torch.zeros_like(pred) + out_shape = pred.shape + for b in range(out_shape[0]): # batch size + for c in range(1, out_shape[1]): # default 0 channel is background + posmask = img_gt[b].byte() + if posmask.any(): + posdis = distance(posmask) + fg_dtm[b][c] = torch.from_numpy(posdis) + + return fg_dtm + + +@weighted_loss +def hd_loss(seg_soft: Tensor, + gt: Tensor, + seg_dtm: Tensor, + gt_dtm: Tensor, + class_weight=None, + ignore_index=255) -> Tensor: + """ + compute huasdorff distance loss for segmentation + Args: + seg_soft: softmax results, shape=(b,c,x,y) + gt: ground truth, shape=(b,x,y) + seg_dtm: segmentation distance transform map, shape=(b,c,x,y) + gt_dtm: ground truth distance transform map, shape=(b,c,x,y) + + Returns: + output: hd_loss + """ + assert seg_soft.shape[0] == gt.shape[0] + total_loss = 0 + num_class = seg_soft.shape[1] + if class_weight is not None: + assert class_weight.ndim == num_class + for i in range(1, num_class): + if i != ignore_index: + delta_s = (seg_soft[:, i, ...] - gt.float())**2 + s_dtm = seg_dtm[:, i, ...]**2 + g_dtm = gt_dtm[:, i, ...]**2 + dtm = s_dtm + g_dtm + multiplied = torch.einsum('bxy, bxy->bxy', delta_s, dtm) + hd_loss = multiplied.mean() + if class_weight is not None: + hd_loss *= class_weight[i] + total_loss += hd_loss + + return total_loss / num_class + + +@MODELS.register_module() +class HuasdorffDisstanceLoss(nn.Module): + """HuasdorffDisstanceLoss. This loss is proposed in `How Distance Transform + Maps Boost Segmentation CNNs: An Empirical Study. + + `_. + Args: + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float): Weight of the loss. Defaults to 1.0. + ignore_index (int | None): The label index to be ignored. Default: 255. + loss_name (str): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_boundary'. + """ + + def __init__(self, + reduction='mean', + class_weight=None, + loss_weight=1.0, + ignore_index=255, + loss_name='loss_huasdorff_disstance', + **kwargs): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + self._loss_name = loss_name + self.ignore_index = ignore_index + + def forward(self, + pred: Tensor, + target: Tensor, + avg_factor=None, + reduction_override=None, + **kwargs) -> Tensor: + """Forward function. + + Args: + pred (Tensor): Predictions of the segmentation head. (B, C, H, W) + target (Tensor): Ground truth of the image. 
(B, H, W) + avg_factor (int, optional): Average factor that is used to + average the loss. Defaults to None. + reduction_override (str, optional): The reduction method used + to override the original reduction method of the loss. + Options are "none", "mean" and "sum". + Returns: + Tensor: Loss tensor. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = pred.new_tensor(self.class_weight) + else: + class_weight = None + + pred_soft = F.softmax(pred, dim=1) + valid_mask = (target != self.ignore_index).long() + target = target * valid_mask + + with torch.no_grad(): + gt_dtm = compute_dtm(target.cpu(), pred_soft) + gt_dtm = gt_dtm.float() + seg_dtm2 = compute_dtm( + pred_soft.argmax(dim=1, keepdim=False).cpu(), pred_soft) + seg_dtm2 = seg_dtm2.float() + + loss_hd = self.loss_weight * hd_loss( + pred_soft, + target, + seg_dtm=seg_dtm2, + gt_dtm=gt_dtm, + reduction=reduction, + avg_factor=avg_factor, + class_weight=class_weight, + ignore_index=self.ignore_index) + return loss_hd + + @property + def loss_name(self): + return self._loss_name diff --git a/mmseg/models/losses/kldiv_loss.py b/mmseg/models/losses/kldiv_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..496ef9713f085a36d46837ac0b51d4cb9f956fce --- /dev/null +++ b/mmseg/models/losses/kldiv_loss.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmseg.registry import MODELS + + +@MODELS.register_module() +class KLDivLoss(nn.Module): + + def __init__(self, + temperature: float = 1.0, + reduction: str = 'mean', + loss_name: str = 'loss_kld'): + """Kullback-Leibler divergence Loss. + + + + Args: + temperature (float, optional): Temperature param + reduction (str, optional): The method to reduce the loss into a + scalar. Default is "mean". Options are "none", "sum", + and "mean" + """ + + assert isinstance(temperature, (float, int)), \ + 'Expected temperature to be' \ + f'float or int, but got {temperature.__class__.__name__} instead' + assert temperature != 0., 'Temperature must not be zero' + + assert reduction in ['mean', 'none', 'sum'], \ + 'Reduction must be one of the options ("mean", ' \ + f'"sum", "none"), but got {reduction}' + + super().__init__() + self.temperature = temperature + self.reduction = reduction + self._loss_name = loss_name + + def forward(self, input: torch.Tensor, target: torch.Tensor): + """Forward function. Calculate KL divergence Loss. + + Args: + input (Tensor): Logit tensor, + the data type is float32 or float64. + The shape is (N, C) where N is batchsize and C is number of + channels. + If there more than 2 dimensions, shape is (N, C, D1, D2, ... + Dk), k>= 1 + target (Tensor): Logit tensor, + the data type is float32 or float64. + input and target must be with the same shape. + + Returns: + (Tensor): Reduced loss. 
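+
+        Example (an illustrative sketch; shapes and values are made up):
+
+            >>> import torch
+            >>> loss_fn = KLDivLoss(temperature=2.0, reduction='mean')
+            >>> student = torch.randn(4, 19, 32, 32)
+            >>> teacher = torch.randn(4, 19, 32, 32)
+            >>> loss_fn(student, teacher).shape  # one value per instance
+            torch.Size([4])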
+ """ + assert isinstance(input, torch.Tensor), 'Expected input to' \ + f'be Tensor, but got {input.__class__.__name__} instead' + assert isinstance(target, torch.Tensor), 'Expected target to' \ + f'be Tensor, but got {target.__class__.__name__} instead' + + assert input.shape == target.shape, 'Input and target ' \ + 'must have same shape,' \ + f'but got shapes {input.shape} and {target.shape}' + + input = F.softmax(input / self.temperature, dim=1) + target = F.softmax(target / self.temperature, dim=1) + + loss = F.kl_div(input, target, reduction='none', log_target=False) + loss = loss * self.temperature**2 + + batch_size = input.shape[0] + + if self.reduction == 'sum': + # Change view to calculate instance-wise sum + loss = loss.view(batch_size, -1) + return torch.sum(loss, dim=1) + + elif self.reduction == 'mean': + # Change view to calculate instance-wise mean + loss = loss.view(batch_size, -1) + return torch.mean(loss, dim=1) + + return loss + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/mmseg/models/losses/lovasz_loss.py b/mmseg/models/losses/lovasz_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b47f9d8a15330a45d0d2d25f3c18d9386e2b335e --- /dev/null +++ b/mmseg/models/losses/lovasz_loss.py @@ -0,0 +1,323 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor +ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim +Berman 2018 ESAT-PSI KU Leuven (MIT License)""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.utils import is_list_of + +from mmseg.registry import MODELS +from .utils import get_class_weight, weight_reduce_loss + + +def lovasz_grad(gt_sorted): + """Computes gradient of the Lovasz extension w.r.t sorted errors. + + See Alg. 1 in paper. + """ + p = len(gt_sorted) + gts = gt_sorted.sum() + intersection = gts - gt_sorted.float().cumsum(0) + union = gts + (1 - gt_sorted).float().cumsum(0) + jaccard = 1. - intersection / union + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def flatten_binary_logits(logits, labels, ignore_index=None): + """Flattens predictions in the batch (binary case) Remove labels equal to + 'ignore_index'.""" + logits = logits.view(-1) + labels = labels.view(-1) + if ignore_index is None: + return logits, labels + valid = (labels != ignore_index) + vlogits = logits[valid] + vlabels = labels[valid] + return vlogits, vlabels + + +def flatten_probs(probs, labels, ignore_index=None): + """Flattens predictions in the batch.""" + if probs.dim() == 3: + # assumes output of a sigmoid layer + B, H, W = probs.size() + probs = probs.view(B, 1, H, W) + B, C, H, W = probs.size() + probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C) # B*H*W, C=P,C + labels = labels.view(-1) + if ignore_index is None: + return probs, labels + valid = (labels != ignore_index) + vprobs = probs[valid.nonzero().squeeze()] + vlabels = labels[valid] + return vprobs, vlabels + + +def lovasz_hinge_flat(logits, labels): + """Binary Lovasz hinge loss. 
+ + Args: + logits (torch.Tensor): [P], logits at each prediction + (between -infty and +infty). + labels (torch.Tensor): [P], binary ground truth labels (0 or 1). + + Returns: + torch.Tensor: The calculated loss. + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0. + signs = 2. * labels.float() - 1. + errors = (1. - logits * signs) + errors_sorted, perm = torch.sort(errors, dim=0, descending=True) + perm = perm.data + gt_sorted = labels[perm] + grad = lovasz_grad(gt_sorted) + loss = torch.dot(F.relu(errors_sorted), grad) + return loss + + +def lovasz_hinge(logits, + labels, + classes='present', + per_image=False, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=255): + """Binary Lovasz hinge loss. + + Args: + logits (torch.Tensor): [B, H, W], logits at each pixel + (between -infty and +infty). + labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1). + classes (str | list[int], optional): Placeholder, to be consistent with + other loss. Default: None. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + class_weight (list[float], optional): Placeholder, to be consistent + with other loss. Default: None. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. This parameter only works when per_image is True. + Default: None. + ignore_index (int | None): The label index to be ignored. Default: 255. + + Returns: + torch.Tensor: The calculated loss. + """ + if per_image: + loss = [ + lovasz_hinge_flat(*flatten_binary_logits( + logit.unsqueeze(0), label.unsqueeze(0), ignore_index)) + for logit, label in zip(logits, labels) + ] + loss = weight_reduce_loss( + torch.stack(loss), None, reduction, avg_factor) + else: + loss = lovasz_hinge_flat( + *flatten_binary_logits(logits, labels, ignore_index)) + return loss + + +def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None): + """Multi-class Lovasz-Softmax loss. + + Args: + probs (torch.Tensor): [P, C], class probabilities at each prediction + (between 0 and 1). + labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1). + classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + class_weight (list[float], optional): The weight for each class. + Default: None. + + Returns: + torch.Tensor: The calculated loss. + """ + if probs.numel() == 0: + # only void pixels, the gradients should be 0 + return probs * 0. 
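+    # One loss term is accumulated per class considered: the absolute
+    # errors |fg - class_pred| are sorted in decreasing order, weighted by
+    # the Jaccard-extension gradient, and the per-class terms are averaged.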
+ C = probs.size(1) + losses = [] + class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes + for c in class_to_sum: + fg = (labels == c).float() # foreground for class c + if (classes == 'present' and fg.sum() == 0): + continue + if C == 1: + if len(classes) > 1: + raise ValueError('Sigmoid output possible only with 1 class') + class_pred = probs[:, 0] + else: + class_pred = probs[:, c] + errors = (fg - class_pred).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted)) + if class_weight is not None: + loss *= class_weight[c] + losses.append(loss) + return torch.stack(losses).mean() + + +def lovasz_softmax(probs, + labels, + classes='present', + per_image=False, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=255): + """Multi-class Lovasz-Softmax loss. + + Args: + probs (torch.Tensor): [B, C, H, W], class probabilities at each + prediction (between 0 and 1). + labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and + C - 1). + classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + class_weight (list[float], optional): The weight for each class. + Default: None. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. This parameter only works when per_image is True. + Default: None. + ignore_index (int | None): The label index to be ignored. Default: 255. + + Returns: + torch.Tensor: The calculated loss. + """ + + if per_image: + loss = [ + lovasz_softmax_flat( + *flatten_probs( + prob.unsqueeze(0), label.unsqueeze(0), ignore_index), + classes=classes, + class_weight=class_weight) + for prob, label in zip(probs, labels) + ] + loss = weight_reduce_loss( + torch.stack(loss), None, reduction, avg_factor) + else: + loss = lovasz_softmax_flat( + *flatten_probs(probs, labels, ignore_index), + classes=classes, + class_weight=class_weight) + return loss + + +@MODELS.register_module() +class LovaszLoss(nn.Module): + """LovaszLoss. + + This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate + for the optimization of the intersection-over-union measure in neural + networks `_. + + Args: + loss_type (str, optional): Binary or multi-class loss. + Default: 'multi_class'. Options are "binary" and "multi_class". + classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. 
If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_lovasz'. + """ + + def __init__(self, + loss_type='multi_class', + classes='present', + per_image=False, + reduction='mean', + class_weight=None, + loss_weight=1.0, + loss_name='loss_lovasz'): + super().__init__() + assert loss_type in ('binary', 'multi_class'), "loss_type should be \ + 'binary' or 'multi_class'." + + if loss_type == 'binary': + self.cls_criterion = lovasz_hinge + else: + self.cls_criterion = lovasz_softmax + assert classes in ('all', 'present') or is_list_of(classes, int) + if not per_image: + assert reduction == 'none', "reduction should be 'none' when \ + per_image is False." + + self.classes = classes + self.per_image = per_image + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + self._loss_name = loss_name + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + """Forward function.""" + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + + # if multi-class loss, transform logits to probs + if self.cls_criterion == lovasz_softmax: + cls_score = F.softmax(cls_score, dim=1) + + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + self.classes, + self.per_image, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/mmseg/models/losses/ohem_cross_entropy_loss.py b/mmseg/models/losses/ohem_cross_entropy_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a519b4d84e1dbf86ebc7ad07372ddbdfb0ff3d13 --- /dev/null +++ b/mmseg/models/losses/ohem_cross_entropy_loss.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from mmseg.registry import MODELS + + +@MODELS.register_module() +class OhemCrossEntropy(nn.Module): + """OhemCrossEntropy loss. + + This func is modified from + `PIDNet `_. # noqa + + Licensed under the MIT License. + + Args: + ignore_label (int): Labels to ignore when computing the loss. + Default: 255 + thresh (float, optional): The threshold for hard example selection. + Below which, are prediction with low confidence. If not + specified, the hard examples will be pixels of top ``min_kept`` + loss. Default: 0.7. + min_kept (int, optional): The minimum number of predictions to keep. + Default: 100000. + loss_weight (float): Weight of the loss. Defaults to 1.0. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_name (str): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. 
Defaults to 'loss_ohem'.
+    """
+
+    def __init__(self,
+                 ignore_label: int = 255,
+                 thres: float = 0.7,
+                 min_kept: int = 100000,
+                 loss_weight: float = 1.0,
+                 class_weight: Optional[Union[List[float], str]] = None,
+                 loss_name: str = 'loss_ohem'):
+        super().__init__()
+        self.thresh = thres
+        self.min_kept = max(1, min_kept)
+        self.ignore_label = ignore_label
+        self.loss_weight = loss_weight
+        self.loss_name_ = loss_name
+        self.class_weight = class_weight
+
+    def forward(self, score: Tensor, target: Tensor) -> Tensor:
+        """Forward function.
+
+        Args:
+            score (Tensor): Predictions of the segmentation head.
+            target (Tensor): Ground truth of the image.
+
+        Returns:
+            Tensor: Loss tensor.
+        """
+        # score: (N, C, H, W)
+        pred = F.softmax(score, dim=1)
+        if self.class_weight is not None:
+            class_weight = score.new_tensor(self.class_weight)
+        else:
+            class_weight = None
+
+        pixel_losses = F.cross_entropy(
+            score,
+            target,
+            weight=class_weight,
+            ignore_index=self.ignore_label,
+            reduction='none').contiguous().view(-1)  # (N*H*W)
+        mask = target.contiguous().view(-1) != self.ignore_label  # (N*H*W)
+
+        tmp_target = target.clone()  # (N, H, W)
+        tmp_target[tmp_target == self.ignore_label] = 0
+        # pred: (N, C, H, W) -> (N, 1, H, W), probability of the labelled class
+        pred = pred.gather(1, tmp_target.unsqueeze(1))
+        # flatten, keep valid pixels and sort ascending; ind: sort indices
+        pred, ind = pred.contiguous().view(-1, )[mask].contiguous().sort()
+        if pred.numel() > 0:
+            min_value = pred[min(self.min_kept, pred.numel() - 1)]
+        else:
+            return score.new_tensor(0.0)
+        threshold = max(min_value, self.thresh)
+
+        pixel_losses = pixel_losses[mask][ind]
+        pixel_losses = pixel_losses[pred < threshold]
+        return self.loss_weight * pixel_losses.mean()
+
+    @property
+    def loss_name(self):
+        return self.loss_name_
diff --git a/mmseg/models/losses/silog_loss.py b/mmseg/models/losses/silog_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecc07aac424a9308bce33e00c621369ac555f4ba
--- /dev/null
+++ b/mmseg/models/losses/silog_loss.py
@@ -0,0 +1,122 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from mmseg.registry import MODELS
+from .utils import weight_reduce_loss
+
+
+def silog_loss(pred: Tensor,
+               target: Tensor,
+               weight: Optional[Tensor] = None,
+               eps: float = 1e-4,
+               reduction: Union[str, None] = 'mean',
+               avg_factor: Optional[int] = None) -> Tensor:
+    """Computes the Scale-Invariant Logarithmic (SI-Log) loss between
+    prediction and target.
+
+    Args:
+        pred (Tensor): Predicted output.
+        target (Tensor): Ground truth.
+        weight (Optional[Tensor]): Optional weight to apply on the loss.
+        eps (float): Epsilon value to avoid division and log(0).
+        reduction (Union[str, None]): Specifies the reduction to apply to the
+            output: 'mean', 'sum' or None.
+        avg_factor (Optional[int]): Optional average factor for the loss.
+
+    Returns:
+        Tensor: The calculated SI-Log loss.
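+
+    Example (an illustrative sketch; values are made up):
+
+        >>> import torch
+        >>> pred = torch.full((2, 1, 4, 4), 2.0)
+        >>> target = torch.full((2, 1, 4, 4), 2.0)
+        >>> silog_loss(pred, target)  # identical maps give zero loss
+        tensor(0.)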
+ """ + pred, target = pred.flatten(1), target.flatten(1) + valid_mask = (target > eps).detach().float() + + diff_log = torch.log(target.clamp(min=eps)) - torch.log( + pred.clamp(min=eps)) + + valid_mask = (target > eps).detach() & (~torch.isnan(diff_log)) + diff_log[~valid_mask] = 0.0 + valid_mask = valid_mask.float() + + diff_log_sq_mean = (diff_log.pow(2) * valid_mask).sum( + dim=1) / valid_mask.sum(dim=1).clamp(min=eps) + diff_log_mean = (diff_log * valid_mask).sum(dim=1) / valid_mask.sum( + dim=1).clamp(min=eps) + + loss = torch.sqrt(diff_log_sq_mean - 0.5 * diff_log_mean.pow(2)) + + if weight is not None: + weight = weight.float() + + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@MODELS.register_module() +class SiLogLoss(nn.Module): + """Compute SiLog loss. + + Args: + reduction (str, optional): The method used + to reduce the loss. Options are "none", + "mean" and "sum". Defaults to 'mean'. + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + eps (float): Avoid dividing by zero. Defaults to 1e-3. + loss_name (str, optional): Name of the loss item. If you want this + loss item to be included into the backward graph, `loss_` must + be the prefix of the name. Defaults to 'loss_silog'. + """ + + def __init__(self, + reduction='mean', + loss_weight=1.0, + eps=1e-6, + loss_name='loss_silog'): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + self.eps = eps + self._loss_name = loss_name + + def forward( + self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + ): + + assert pred.shape == target.shape, 'the shapes of pred ' \ + f'({pred.shape}) and target ({target.shape}) are mismatch' + + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + + loss = self.loss_weight * silog_loss( + pred, + target, + weight, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor, + ) + + return loss + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/mmseg/models/losses/tversky_loss.py b/mmseg/models/losses/tversky_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..bfca1af6669e3ac328492da11758a084999ef906 --- /dev/null +++ b/mmseg/models/losses/tversky_loss.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+"""Modified from +https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/dice_loss.py#L333 +(Apache-2.0 License)""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import LOSSES +from .utils import get_class_weight, weighted_loss + + +@weighted_loss +def tversky_loss(pred, + target, + valid_mask, + alpha=0.3, + beta=0.7, + smooth=1, + class_weight=None, + ignore_index=255): + assert pred.shape[0] == target.shape[0] + total_loss = 0 + num_classes = pred.shape[1] + for i in range(num_classes): + if i != ignore_index: + tversky_loss = binary_tversky_loss( + pred[:, i], + target[..., i], + valid_mask=valid_mask, + alpha=alpha, + beta=beta, + smooth=smooth) + if class_weight is not None: + tversky_loss *= class_weight[i] + total_loss += tversky_loss + return total_loss / num_classes + + +@weighted_loss +def binary_tversky_loss(pred, + target, + valid_mask, + alpha=0.3, + beta=0.7, + smooth=1): + assert pred.shape[0] == target.shape[0] + pred = pred.reshape(pred.shape[0], -1) + target = target.reshape(target.shape[0], -1) + valid_mask = valid_mask.reshape(valid_mask.shape[0], -1) + + TP = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) + FP = torch.sum(torch.mul(pred, 1 - target) * valid_mask, dim=1) + FN = torch.sum(torch.mul(1 - pred, target) * valid_mask, dim=1) + tversky = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth) + + return 1 - tversky + + +@LOSSES.register_module() +class TverskyLoss(nn.Module): + """TverskyLoss. This loss is proposed in `Tversky loss function for image + segmentation using 3D fully convolutional deep networks. + + `_. + Args: + smooth (float): A float number to smooth loss, and avoid NaN error. + Default: 1. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Default to 1.0. + ignore_index (int | None): The label index to be ignored. Default: 255. + alpha(float, in [0, 1]): + The coefficient of false positives. Default: 0.3. + beta (float, in [0, 1]): + The coefficient of false negatives. Default: 0.7. + Note: alpha + beta = 1. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_tversky'. + """ + + def __init__(self, + smooth=1, + class_weight=None, + loss_weight=1.0, + ignore_index=255, + alpha=0.3, + beta=0.7, + loss_name='loss_tversky'): + super().__init__() + self.smooth = smooth + self.class_weight = get_class_weight(class_weight) + self.loss_weight = loss_weight + self.ignore_index = ignore_index + assert (alpha + beta == 1.0), 'Sum of alpha and beta but be 1.0!' + self.alpha = alpha + self.beta = beta + self._loss_name = loss_name + + def forward(self, pred, target, **kwargs): + if self.class_weight is not None: + class_weight = pred.new_tensor(self.class_weight) + else: + class_weight = None + + pred = F.softmax(pred, dim=1) + num_classes = pred.shape[1] + one_hot_target = F.one_hot( + torch.clamp(target.long(), 0, num_classes - 1), + num_classes=num_classes) + valid_mask = (target != self.ignore_index).long() + + loss = self.loss_weight * tversky_loss( + pred, + one_hot_target, + valid_mask=valid_mask, + alpha=self.alpha, + beta=self.beta, + smooth=self.smooth, + class_weight=class_weight, + ignore_index=self.ignore_index) + return loss + + @property + def loss_name(self): + """Loss Name. 
+ + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/mmseg/models/losses/utils.py b/mmseg/models/losses/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..047803473316ff5fc58de2b8e35ef0087bc3b624 --- /dev/null +++ b/mmseg/models/losses/utils.py @@ -0,0 +1,129 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.fileio import load + + +def get_class_weight(class_weight): + """Get class weight for loss function. + + Args: + class_weight (list[float] | str | None): If class_weight is a str, + take it as a file name and read from it. + """ + if isinstance(class_weight, str): + # take it as a file path + if class_weight.endswith('.npy'): + class_weight = np.load(class_weight) + else: + # pkl, json or yaml + class_weight = load(class_weight) + + return class_weight + + +def reduce_loss(loss, reduction) -> torch.Tensor: + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are "none", "mean" and "sum". + + Return: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + elif reduction_enum == 2: + return loss.sum() + + +def weight_reduce_loss(loss, + weight=None, + reduction='mean', + avg_factor=None) -> torch.Tensor: + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. + reduction (str): Same as built-in losses of PyTorch. + avg_factor (float): Average factor when computing the mean of losses. + + Returns: + Tensor: Processed loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + if weight.dim() > 1: + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + # Avoid causing ZeroDivisionError when avg_factor is 0.0, + # i.e., all labels of an image belong to ignore index. + eps = torch.finfo(torch.float32).eps + loss = loss.sum() / (avg_factor + eps) + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)`. 
+ + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, avg_factor=2) + tensor(1.5000) + """ + + @functools.wraps(loss_func) + def wrapper(pred, + target, + weight=None, + reduction='mean', + avg_factor=None, + **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + return wrapper diff --git a/mmseg/models/necks/__init__.py b/mmseg/models/necks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ff03186a92b78f942e79cff9eec9f5e2784c359a --- /dev/null +++ b/mmseg/models/necks/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .featurepyramid import Feature2Pyramid +from .fpn import FPN +from .ic_neck import ICNeck +from .jpu import JPU +from .mla_neck import MLANeck +from .multilevel_neck import MultiLevelNeck + +__all__ = [ + 'FPN', 'MultiLevelNeck', 'MLANeck', 'ICNeck', 'JPU', 'Feature2Pyramid' +] diff --git a/mmseg/models/necks/__pycache__/__init__.cpython-311.pyc b/mmseg/models/necks/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..612219444976656f04f5e2590aaebc5f185344c9 Binary files /dev/null and b/mmseg/models/necks/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmseg/models/necks/__pycache__/featurepyramid.cpython-311.pyc b/mmseg/models/necks/__pycache__/featurepyramid.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4d591979876eb40eac6beb9a057ffe81219633e Binary files /dev/null and b/mmseg/models/necks/__pycache__/featurepyramid.cpython-311.pyc differ diff --git a/mmseg/models/necks/__pycache__/fpn.cpython-311.pyc b/mmseg/models/necks/__pycache__/fpn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d8a9154d0ce66068d5b527325df5e5bf9c6762c Binary files /dev/null and b/mmseg/models/necks/__pycache__/fpn.cpython-311.pyc differ diff --git a/mmseg/models/necks/__pycache__/ic_neck.cpython-311.pyc b/mmseg/models/necks/__pycache__/ic_neck.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f66ccc1d1b36c79e9cdde9e44d98223cf88b91ea Binary files /dev/null and b/mmseg/models/necks/__pycache__/ic_neck.cpython-311.pyc differ diff --git a/mmseg/models/necks/__pycache__/jpu.cpython-311.pyc b/mmseg/models/necks/__pycache__/jpu.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5620c5d4e552344f2858c70e9477c7db952f906c Binary files /dev/null and b/mmseg/models/necks/__pycache__/jpu.cpython-311.pyc differ diff --git a/mmseg/models/necks/__pycache__/mla_neck.cpython-311.pyc b/mmseg/models/necks/__pycache__/mla_neck.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebe900a3a77008c45e7c0ab90a906b292b436137 Binary files /dev/null and b/mmseg/models/necks/__pycache__/mla_neck.cpython-311.pyc differ diff --git a/mmseg/models/necks/__pycache__/multilevel_neck.cpython-311.pyc b/mmseg/models/necks/__pycache__/multilevel_neck.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..9fa6d035873fae2c3e5b60b9e36620e0c2f6cf98 Binary files /dev/null and b/mmseg/models/necks/__pycache__/multilevel_neck.cpython-311.pyc differ diff --git a/mmseg/models/necks/featurepyramid.py b/mmseg/models/necks/featurepyramid.py new file mode 100644 index 0000000000000000000000000000000000000000..dc1250d39dafcf78880aa282bcba4215520ad94e --- /dev/null +++ b/mmseg/models/necks/featurepyramid.py @@ -0,0 +1,67 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import build_norm_layer + +from mmseg.registry import MODELS + + +@MODELS.register_module() +class Feature2Pyramid(nn.Module): + """Feature2Pyramid. + + A neck structure connect ViT backbone and decoder_heads. + + Args: + embed_dims (int): Embedding dimension. + rescales (list[float]): Different sampling multiples were + used to obtain pyramid features. Default: [4, 2, 1, 0.5]. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='SyncBN', requires_grad=True). + """ + + def __init__(self, + embed_dim, + rescales=[4, 2, 1, 0.5], + norm_cfg=dict(type='SyncBN', requires_grad=True)): + super().__init__() + self.rescales = rescales + self.upsample_4x = None + for k in self.rescales: + if k == 4: + self.upsample_4x = nn.Sequential( + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2), + build_norm_layer(norm_cfg, embed_dim)[1], + nn.GELU(), + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2), + ) + elif k == 2: + self.upsample_2x = nn.Sequential( + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2)) + elif k == 1: + self.identity = nn.Identity() + elif k == 0.5: + self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2) + elif k == 0.25: + self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4) + else: + raise KeyError(f'invalid {k} for feature2pyramid') + + def forward(self, inputs): + assert len(inputs) == len(self.rescales) + outputs = [] + if self.upsample_4x is not None: + ops = [ + self.upsample_4x, self.upsample_2x, self.identity, + self.downsample_2x + ] + else: + ops = [ + self.upsample_2x, self.identity, self.downsample_2x, + self.downsample_4x + ] + for i in range(len(inputs)): + outputs.append(ops[i](inputs[i])) + return tuple(outputs) diff --git a/mmseg/models/necks/fpn.py b/mmseg/models/necks/fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..ddab74c00a262a89031fda44824c5de0e2e9a362 --- /dev/null +++ b/mmseg/models/necks/fpn.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +@MODELS.register_module() +class FPN(BaseModule): + """Feature Pyramid Network. + + This neck is the implementation of `Feature Pyramid Networks for Object + Detection `_. + + Args: + in_channels (list[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + num_outs (int): Number of output scales. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Default to False. 
+ If True, its actual mode is specified by `extra_convs_on_inputs`. + If str, it specifies the source feature map of the extra convs. + Only the following options are allowed + + - 'on_input': Last feat map of neck inputs (i.e. backbone feature). + - 'on_lateral': Last feature map after lateral convs. + - 'on_output': The last output feature map after fpn convs. + extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs + on the original feature from the backbone. If True, + it is equivalent to `add_extra_convs='on_input'`. If False, it is + equivalent to set `add_extra_convs='on_output'`. Default to True. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Default: False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Default: False. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: None. + upsample_cfg (dict): Config dict for interpolate layer. + Default: dict(mode='nearest'). + init_cfg (dict or list[dict], optional): Initialization config dict. + + Example: + >>> import torch + >>> in_channels = [2, 3, 5, 7] + >>> scales = [340, 170, 84, 43] + >>> inputs = [torch.rand(1, c, s, s) + ... for c, s in zip(in_channels, scales)] + >>> self = FPN(in_channels, 11, len(in_channels)).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 11, 340, 340]) + outputs[1].shape = torch.Size([1, 11, 170, 170]) + outputs[2].shape = torch.Size([1, 11, 84, 84]) + outputs[3].shape = torch.Size([1, 11, 43, 43]) + """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + add_extra_convs=False, + extra_convs_on_inputs=False, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + upsample_cfg=dict(mode='nearest'), + init_cfg=dict( + type='Xavier', layer='Conv2d', distribution='uniform')): + super().__init__(init_cfg) + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.relu_before_extra_convs = relu_before_extra_convs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.upsample_cfg = upsample_cfg.copy() + + if end_level == -1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level < inputs, no extra level is allowed + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + assert num_outs == end_level - start_level + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' + assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') + elif add_extra_convs: # True + if extra_convs_on_inputs: + # For compatibility with previous release + # TODO: deprecate `extra_convs_on_inputs` + self.add_extra_convs = 'on_input' + else: + self.add_extra_convs = 'on_output' + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + 
out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False) + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + # add extra conv layers (e.g., RetinaNet) + extra_levels = num_outs - self.backbone_end_level + self.start_level + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == 'on_input': + in_channels = self.in_channels[self.backbone_end_level - 1] + else: + in_channels = out_channels + extra_fpn_conv = ConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpn_convs.append(extra_fpn_conv) + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if 'scale_factor' in self.upsample_cfg: + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], size=prev_shape, **self.upsample_cfg) + + # build outputs + # part 1: from original levels + outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + return tuple(outs) diff --git a/mmseg/models/necks/ic_neck.py b/mmseg/models/necks/ic_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..9763541e0980cb0ec53a342b656e64c99d87ed7e --- /dev/null +++ b/mmseg/models/necks/ic_neck.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +class CascadeFeatureFusion(BaseModule): + """Cascade Feature Fusion Unit in ICNet. + + Args: + low_channels (int): The number of input channels for + low resolution feature map. + high_channels (int): The number of input channels for + high resolution feature map. + out_channels (int): The number of output channels. 
+ conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN'). + act_cfg (dict): Dictionary to construct and config act layer. + Default: dict(type='ReLU'). + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Returns: + x (Tensor): The output tensor of shape (N, out_channels, H, W). + x_low (Tensor): The output tensor of shape (N, out_channels, H, W) + for Cascade Label Guidance in auxiliary heads. + """ + + def __init__(self, + low_channels, + high_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.align_corners = align_corners + self.conv_low = ConvModule( + low_channels, + out_channels, + 3, + padding=2, + dilation=2, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv_high = ConvModule( + high_channels, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x_low, x_high): + x_low = resize( + x_low, + size=x_high.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + # Note: Different from original paper, `x_low` is underwent + # `self.conv_low` rather than another 1x1 conv classifier + # before being used for auxiliary head. + x_low = self.conv_low(x_low) + x_high = self.conv_high(x_high) + x = x_low + x_high + x = F.relu(x, inplace=True) + return x, x_low + + +@MODELS.register_module() +class ICNeck(BaseModule): + """ICNet for Real-Time Semantic Segmentation on High-Resolution Images. + + This head is the implementation of `ICHead + `_. + + Args: + in_channels (int): The number of input image channels. Default: 3. + out_channels (int): The numbers of output feature channels. + Default: 128. + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN'). + act_cfg (dict): Dictionary to construct and config act layer. + Default: dict(type='ReLU'). + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels=(64, 256, 256), + out_channels=128, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert len(in_channels) == 3, 'Length of input channels \ + must be 3!' + + self.in_channels = in_channels + self.out_channels = out_channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.align_corners = align_corners + self.cff_24 = CascadeFeatureFusion( + self.in_channels[2], + self.in_channels[1], + self.out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + + self.cff_12 = CascadeFeatureFusion( + self.out_channels, + self.in_channels[0], + self.out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + + def forward(self, inputs): + assert len(inputs) == 3, 'Length of input feature \ + maps must be 3!' 
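+        # Fusion order, as implemented below: cff_24 first fuses the two
+        # coarser inputs, then cff_12 fuses that result with the finest one.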
+ + x_sub1, x_sub2, x_sub4 = inputs + x_cff_24, x_24 = self.cff_24(x_sub4, x_sub2) + x_cff_12, x_12 = self.cff_12(x_cff_24, x_sub1) + # Note: `x_cff_12` is used for decode_head, + # `x_24` and `x_12` are used for auxiliary head. + return x_24, x_12, x_cff_12 diff --git a/mmseg/models/necks/jpu.py b/mmseg/models/necks/jpu.py new file mode 100644 index 0000000000000000000000000000000000000000..3ea0fe2183377d3e3c1a87ca8a0df909b123cdfa --- /dev/null +++ b/mmseg/models/necks/jpu.py @@ -0,0 +1,131 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +@MODELS.register_module() +class JPU(BaseModule): + """FastFCN: Rethinking Dilated Convolution in the Backbone + for Semantic Segmentation. + + This Joint Pyramid Upsampling (JPU) neck is the implementation of + `FastFCN `_. + + Args: + in_channels (Tuple[int], optional): The number of input channels + for each convolution operations before upsampling. + Default: (512, 1024, 2048). + mid_channels (int): The number of output channels of JPU. + Default: 512. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + dilations (tuple[int]): Dilation rate of each Depthwise + Separable ConvModule. Default: (1, 2, 4, 8). + align_corners (bool, optional): The align_corners argument of + resize operation. Default: False. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
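+
+    Example (an illustrative sketch; the input sizes are assumed):
+
+        >>> import torch
+        >>> neck = JPU(in_channels=(512, 1024, 2048), mid_channels=512)
+        >>> feats = [torch.randn(1, 512, 64, 64),
+        ...          torch.randn(1, 1024, 32, 32),
+        ...          torch.randn(1, 2048, 16, 16)]
+        >>> outs = neck(feats)
+        >>> outs[-1].shape  # 4 dilation branches x 512 channels
+        torch.Size([1, 2048, 64, 64])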
diff --git a/mmseg/models/necks/jpu.py b/mmseg/models/necks/jpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ea0fe2183377d3e3c1a87ca8a0df909b123cdfa
--- /dev/null
+++ b/mmseg/models/necks/jpu.py
@@ -0,0 +1,131 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+from mmengine.model import BaseModule
+
+from mmseg.registry import MODELS
+from ..utils import resize
+
+
+@MODELS.register_module()
+class JPU(BaseModule):
+    """FastFCN: Rethinking Dilated Convolution in the Backbone
+    for Semantic Segmentation.
+
+    This Joint Pyramid Upsampling (JPU) neck is the implementation of
+    `FastFCN <https://arxiv.org/abs/1903.11816>`_.
+
+    Args:
+        in_channels (Tuple[int], optional): The number of input channels
+            for each convolution operation before upsampling.
+            Default: (512, 1024, 2048).
+        mid_channels (int): The number of output channels of JPU.
+            Default: 512.
+        start_level (int): Index of the start input backbone level used to
+            build the feature pyramid. Default: 0.
+        end_level (int): Index of the end input backbone level (exclusive) to
+            build the feature pyramid. Default: -1, which means the last level.
+        dilations (tuple[int]): Dilation rate of each Depthwise
+            Separable ConvModule. Default: (1, 2, 4, 8).
+        align_corners (bool, optional): The align_corners argument of
+            resize operation. Default: False.
+        conv_cfg (dict | None): Config of conv layers.
+            Default: None.
+        norm_cfg (dict | None): Config of norm layers.
+            Default: dict(type='BN').
+        act_cfg (dict): Config of activation layers.
+            Default: dict(type='ReLU').
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
+    """
+
+    def __init__(self,
+                 in_channels=(512, 1024, 2048),
+                 mid_channels=512,
+                 start_level=0,
+                 end_level=-1,
+                 dilations=(1, 2, 4, 8),
+                 align_corners=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
+        assert isinstance(in_channels, tuple)
+        assert isinstance(dilations, tuple)
+        self.in_channels = in_channels
+        self.mid_channels = mid_channels
+        self.start_level = start_level
+        self.num_ins = len(in_channels)
+        if end_level == -1:
+            self.backbone_end_level = self.num_ins
+        else:
+            self.backbone_end_level = end_level
+            assert end_level <= len(in_channels)
+
+        self.dilations = dilations
+        self.align_corners = align_corners
+
+        self.conv_layers = nn.ModuleList()
+        self.dilation_layers = nn.ModuleList()
+        for i in range(self.start_level, self.backbone_end_level):
+            conv_layer = nn.Sequential(
+                ConvModule(
+                    self.in_channels[i],
+                    self.mid_channels,
+                    kernel_size=3,
+                    padding=1,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+            self.conv_layers.append(conv_layer)
+        for i in range(len(dilations)):
+            dilation_layer = nn.Sequential(
+                DepthwiseSeparableConvModule(
+                    in_channels=(self.backbone_end_level - self.start_level) *
+                    self.mid_channels,
+                    out_channels=self.mid_channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=dilations[i],
+                    dilation=dilations[i],
+                    dw_norm_cfg=norm_cfg,
+                    dw_act_cfg=None,
+                    pw_norm_cfg=norm_cfg,
+                    pw_act_cfg=act_cfg))
+            self.dilation_layers.append(dilation_layer)
+
+    def forward(self, inputs):
+        """Forward function."""
+        assert len(inputs) == len(self.in_channels), 'Length of inputs must \
+            be the same as self.in_channels!'
+
+        feats = [
+            self.conv_layers[i - self.start_level](inputs[i])
+            for i in range(self.start_level, self.backbone_end_level)
+        ]
+
+        h, w = feats[0].shape[2:]
+        for i in range(1, len(feats)):
+            feats[i] = resize(
+                feats[i],
+                size=(h, w),
+                mode='bilinear',
+                align_corners=self.align_corners)
+
+        feat = torch.cat(feats, dim=1)
+        concat_feat = torch.cat([
+            self.dilation_layers[i](feat) for i in range(len(self.dilations))
+        ],
+                                dim=1)
+
+        outs = []
+
+        # Default: outs[2] is the output of JPU for decoder head, outs[1] is
+        # the feature map from backbone for auxiliary head. Additionally,
+        # outs[0] can also be used for auxiliary head.
+        for i in range(self.start_level, self.backbone_end_level - 1):
+            outs.append(inputs[i])
+        outs.append(concat_feat)
+        return tuple(outs)
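To see what the JPU neck returns, feed it C3/C4/C5-style backbone features: the untouched backbone maps come back for the auxiliary heads, and the last output concatenates the four dilated branches (4 x mid_channels) at the finest resolution. A minimal sketch, assuming mmseg is installed; the shapes are illustrative:

```python
import torch
from mmseg.models.necks import JPU  # registered by the file above

neck = JPU(in_channels=(512, 1024, 2048), mid_channels=512)
neck.eval()

c3 = torch.rand(1, 512, 64, 64)    # stride 8
c4 = torch.rand(1, 1024, 32, 32)   # stride 16
c5 = torch.rand(1, 2048, 16, 16)   # stride 32

with torch.no_grad():
    outs = neck((c3, c4, c5))

print([tuple(o.shape) for o in outs])
# [(1, 512, 64, 64), (1, 1024, 32, 32), (1, 2048, 64, 64)]
# outs[-1] has 4 * 512 = 2048 channels at the stride-8 resolution.
```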
diff --git a/mmseg/models/necks/mla_neck.py b/mmseg/models/necks/mla_neck.py
new file mode 100644
index 0000000000000000000000000000000000000000..db250aefbfa45beaa98855be79ddc7f5e7276cca
--- /dev/null
+++ b/mmseg/models/necks/mla_neck.py
@@ -0,0 +1,118 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule, build_norm_layer
+
+from mmseg.registry import MODELS
+
+
+class MLAModule(nn.Module):
+
+    def __init__(self,
+                 in_channels=[1024, 1024, 1024, 1024],
+                 out_channels=256,
+                 norm_cfg=None,
+                 act_cfg=None):
+        super().__init__()
+        self.channel_proj = nn.ModuleList()
+        for i in range(len(in_channels)):
+            self.channel_proj.append(
+                ConvModule(
+                    in_channels=in_channels[i],
+                    out_channels=out_channels,
+                    kernel_size=1,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+        self.feat_extract = nn.ModuleList()
+        for i in range(len(in_channels)):
+            self.feat_extract.append(
+                ConvModule(
+                    in_channels=out_channels,
+                    out_channels=out_channels,
+                    kernel_size=3,
+                    padding=1,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+
+    def forward(self, inputs):
+
+        # feat_list -> [p2, p3, p4, p5]
+        feat_list = []
+        for x, conv in zip(inputs, self.channel_proj):
+            feat_list.append(conv(x))
+
+        # feat_list -> [p5, p4, p3, p2]
+        # mid_list -> [m5, m4, m3, m2]
+        feat_list = feat_list[::-1]
+        mid_list = []
+        for feat in feat_list:
+            if len(mid_list) == 0:
+                mid_list.append(feat)
+            else:
+                mid_list.append(mid_list[-1] + feat)
+
+        # mid_list -> [m5, m4, m3, m2]
+        # out_list -> [o2, o3, o4, o5]
+        out_list = []
+        for mid, conv in zip(mid_list, self.feat_extract):
+            out_list.append(conv(mid))
+
+        return tuple(out_list)
+
+
+@MODELS.register_module()
+class MLANeck(nn.Module):
+    """Multi-level Feature Aggregation.
+
+    This neck is the Multi-level Feature Aggregation construction of
+    `SETR <https://arxiv.org/abs/2012.15840>`_.
+
+
+    Args:
+        in_channels (List[int]): Number of input channels per scale.
+        out_channels (int): Number of output channels (used at each scale).
+        norm_layer (dict): Config dict for input normalization.
+            Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True).
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        act_cfg (dict): Config dict for activation layer in ConvModule.
+            Default: None.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
+                 norm_cfg=None,
+                 act_cfg=None):
+        super().__init__()
+        assert isinstance(in_channels, list)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        # In order to build general vision transformer backbone, we have to
+        # move MLA to neck.
+        self.norm = nn.ModuleList([
+            build_norm_layer(norm_layer, in_channels[i])[1]
+            for i in range(len(in_channels))
+        ])
+
+        self.mla = MLAModule(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+
+    def forward(self, inputs):
+        assert len(inputs) == len(self.in_channels)
+
+        # Convert from nchw to nlc
+        outs = []
+        for i in range(len(inputs)):
+            x = inputs[i]
+            n, c, h, w = x.shape
+            x = x.reshape(n, c, h * w).transpose(2, 1).contiguous()
+            x = self.norm[i](x)
+            x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()
+            outs.append(x)
+
+        outs = self.mla(outs)
+        return tuple(outs)
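The aggregation in `MLAModule` is a running top-down sum: after the 1x1 projections, the deepest feature is accumulated into each shallower one (m5, m5+m4, m5+m4+m3, ...) before the 3x3 convs. A minimal sketch, assuming mmseg is installed, with SETR-style same-resolution ViT features:

```python
import torch
from mmseg.models.necks import MLANeck

neck = MLANeck(in_channels=[1024, 1024, 1024, 1024], out_channels=256)
neck.eval()

# Four ViT feature maps at the same resolution (illustrative shapes).
feats = [torch.rand(1, 1024, 32, 32) for _ in range(4)]

with torch.no_grad():
    outs = neck(feats)

print([tuple(o.shape) for o in outs])  # 4 x (1, 256, 32, 32)
```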
diff --git a/mmseg/models/necks/multilevel_neck.py b/mmseg/models/necks/multilevel_neck.py
new file mode 100644
index 0000000000000000000000000000000000000000..c997125f24791b1c01248c60a27fa37a986c6c82
--- /dev/null
+++ b/mmseg/models/necks/multilevel_neck.py
@@ -0,0 +1,79 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmengine.model.weight_init import xavier_init
+
+from mmseg.registry import MODELS
+from ..utils import resize
+
+
+@MODELS.register_module()
+class MultiLevelNeck(nn.Module):
+    """MultiLevelNeck.
+
+    A neck structure that connects a ViT backbone and decode heads.
+
+    Args:
+        in_channels (List[int]): Number of input channels per scale.
+        out_channels (int): Number of output channels (used at each scale).
+        scales (List[float]): Scale factors for each input feature map.
+            Default: [0.5, 1, 2, 4]
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        act_cfg (dict): Config dict for activation layer in ConvModule.
+            Default: None.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 scales=[0.5, 1, 2, 4],
+                 norm_cfg=None,
+                 act_cfg=None):
+        super().__init__()
+        assert isinstance(in_channels, list)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.scales = scales
+        self.num_outs = len(scales)
+        self.lateral_convs = nn.ModuleList()
+        self.convs = nn.ModuleList()
+        for in_channel in in_channels:
+            self.lateral_convs.append(
+                ConvModule(
+                    in_channel,
+                    out_channels,
+                    kernel_size=1,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+        for _ in range(self.num_outs):
+            self.convs.append(
+                ConvModule(
+                    out_channels,
+                    out_channels,
+                    kernel_size=3,
+                    padding=1,
+                    stride=1,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+
+    # default init_weights for conv(msra) and norm in ConvModule
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                xavier_init(m, distribution='uniform')
+
+    def forward(self, inputs):
+        assert len(inputs) == len(self.in_channels)
+        inputs = [
+            lateral_conv(inputs[i])
+            for i, lateral_conv in enumerate(self.lateral_convs)
+        ]
+        # when len(inputs) is not equal to self.num_outs,
+        # duplicate the single input for every output branch
+        if len(inputs) == 1:
+            inputs = [inputs[0] for _ in range(self.num_outs)]
+        outs = []
+        for i in range(self.num_outs):
+            x_resize = resize(
+                inputs[i], scale_factor=self.scales[i], mode='bilinear')
+            outs.append(self.convs[i](x_resize))
+        return tuple(outs)
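`MultiLevelNeck` turns the single-resolution output of a ViT into a feature pyramid by rescaling each projected copy with its `scales` factor. A minimal sketch, assuming mmseg is installed; the shapes are illustrative:

```python
import torch
from mmseg.models.necks import MultiLevelNeck

neck = MultiLevelNeck(in_channels=[768, 768, 768, 768], out_channels=256,
                      scales=[0.5, 1, 2, 4])
neck.eval()

feats = [torch.rand(1, 768, 32, 32) for _ in range(4)]
with torch.no_grad():
    outs = neck(feats)

print([tuple(o.shape[2:]) for o in outs])
# [(16, 16), (32, 32), (64, 64), (128, 128)]
```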
diff --git a/mmseg/models/segmentors/__init__.py b/mmseg/models/segmentors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..59b012f41725d26d099b8f890630d1dc04019ba5
--- /dev/null
+++ b/mmseg/models/segmentors/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseSegmentor
+from .cascade_encoder_decoder import CascadeEncoderDecoder
+from .depth_estimator import DepthEstimator
+from .encoder_decoder import EncoderDecoder
+from .multimodal_encoder_decoder import MultimodalEncoderDecoder
+from .seg_tta import SegTTAModel
+
+__all__ = [
+    'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel',
+    'MultimodalEncoderDecoder', 'DepthEstimator'
+]
diff --git a/mmseg/models/segmentors/__pycache__/__init__.cpython-311.pyc b/mmseg/models/segmentors/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ea45db55a5b21323b17f085fea715ac83e37f46
Binary files /dev/null and b/mmseg/models/segmentors/__pycache__/__init__.cpython-311.pyc differ
diff --git a/mmseg/models/segmentors/__pycache__/base.cpython-311.pyc b/mmseg/models/segmentors/__pycache__/base.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a931f1aab65ab387eeb686948cae39a2e700803
Binary files /dev/null and b/mmseg/models/segmentors/__pycache__/base.cpython-311.pyc differ
diff --git a/mmseg/models/segmentors/__pycache__/cascade_encoder_decoder.cpython-311.pyc b/mmseg/models/segmentors/__pycache__/cascade_encoder_decoder.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..04775e07002f5fb7f3da76f6822940ceacfc2605
Binary files /dev/null and b/mmseg/models/segmentors/__pycache__/cascade_encoder_decoder.cpython-311.pyc differ
diff --git a/mmseg/models/segmentors/__pycache__/depth_estimator.cpython-311.pyc b/mmseg/models/segmentors/__pycache__/depth_estimator.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..880a3496cadebf57be9a4af67a722aa06883b5de
Binary files /dev/null and b/mmseg/models/segmentors/__pycache__/depth_estimator.cpython-311.pyc differ
diff --git a/mmseg/models/segmentors/__pycache__/encoder_decoder.cpython-311.pyc b/mmseg/models/segmentors/__pycache__/encoder_decoder.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3bd4ae7f6210eda2cad29aee9d856e603eb8e96e
Binary files /dev/null and b/mmseg/models/segmentors/__pycache__/encoder_decoder.cpython-311.pyc differ
diff --git a/mmseg/models/segmentors/__pycache__/multimodal_encoder_decoder.cpython-311.pyc b/mmseg/models/segmentors/__pycache__/multimodal_encoder_decoder.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..914500d5ccf4377616dc502307e3216783d0d536
Binary files /dev/null and b/mmseg/models/segmentors/__pycache__/multimodal_encoder_decoder.cpython-311.pyc differ
diff --git a/mmseg/models/segmentors/__pycache__/seg_tta.cpython-311.pyc b/mmseg/models/segmentors/__pycache__/seg_tta.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa04704f970522bae46501feac3348c0fd4c2b48
Binary files /dev/null and b/mmseg/models/segmentors/__pycache__/seg_tta.cpython-311.pyc differ
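`base.py`, which follows, defines the unified `forward()` entry point that dispatches on `mode`. A minimal sketch of the three modes with a throwaway concrete subclass (everything here is a toy stand-in for illustration, not part of the diff):

```python
import torch
from torch import Tensor
from mmseg.models.segmentors import BaseSegmentor


class DummySegmentor(BaseSegmentor):
    """Toy concrete subclass, only to exercise forward() dispatch."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 2, 1)

    def extract_feat(self, inputs: Tensor):
        return self.conv(inputs)

    def encode_decode(self, inputs, batch_data_samples):
        return self.extract_feat(inputs)

    def loss(self, inputs, data_samples):
        return {'loss_dummy': self.extract_feat(inputs).mean()}

    def predict(self, inputs, data_samples=None):
        return self.postprocess_result(self.extract_feat(inputs))

    def _forward(self, inputs, data_samples=None):
        return self.extract_feat(inputs)


model = DummySegmentor()
x = torch.rand(1, 3, 8, 8)
print(model(x, mode='loss'))          # dict of loss tensors
print(model(x, mode='tensor').shape)  # torch.Size([1, 2, 8, 8])
print(model(x, mode='predict')[0])    # a SegDataSample with pred_sem_seg
```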
diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..17a0bb2b33e57684bccaaf892af69bcba69dd773
--- /dev/null
+++ b/mmseg/models/segmentors/base.py
@@ -0,0 +1,200 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+from typing import List, Tuple
+
+from mmengine.model import BaseModel
+from mmengine.structures import PixelData
+from torch import Tensor
+
+from mmseg.structures import SegDataSample
+from mmseg.utils import (ForwardResults, OptConfigType, OptMultiConfig,
+                         OptSampleList, SampleList)
+from ..utils import resize
+
+
+class BaseSegmentor(BaseModel, metaclass=ABCMeta):
+    """Base class for segmentors.
+
+    Args:
+        data_preprocessor (dict, optional): Model preprocessing config
+            for processing the input data. It usually includes
+            ``to_rgb``, ``pad_size_divisor``, ``pad_val``,
+            ``mean`` and ``std``. Defaults to None.
+        init_cfg (dict, optional): The config to control the
+            initialization. Defaults to None.
+    """
+
+    def __init__(self,
+                 data_preprocessor: OptConfigType = None,
+                 init_cfg: OptMultiConfig = None):
+        super().__init__(
+            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
+
+    @property
+    def with_neck(self) -> bool:
+        """bool: whether the segmentor has neck"""
+        return hasattr(self, 'neck') and self.neck is not None
+
+    @property
+    def with_auxiliary_head(self) -> bool:
+        """bool: whether the segmentor has auxiliary head"""
+        return hasattr(self,
+                       'auxiliary_head') and self.auxiliary_head is not None
+
+    @property
+    def with_decode_head(self) -> bool:
+        """bool: whether the segmentor has decode head"""
+        return hasattr(self, 'decode_head') and self.decode_head is not None
+
+    @abstractmethod
+    def extract_feat(self, inputs: Tensor) -> bool:
+        """Placeholder for extracting features from images."""
+        pass
+
+    @abstractmethod
+    def encode_decode(self, inputs: Tensor, batch_data_samples: SampleList):
+        """Placeholder for encoding images with backbone and decoding into a
+        semantic segmentation map of the same size as input."""
+        pass
+
+    def forward(self,
+                inputs: Tensor,
+                data_samples: OptSampleList = None,
+                mode: str = 'tensor') -> ForwardResults:
+        """The unified entry for a forward process in both training and test.
+
+        The method should accept three modes: "tensor", "predict" and "loss":
+
+        - "tensor": Forward the whole network and return tensor or tuple of
+          tensor without any post-processing, same as a common nn.Module.
+        - "predict": Forward and return the predictions, which are fully
+          processed to a list of :obj:`SegDataSample`.
+        - "loss": Forward and return a dict of losses according to the given
+          inputs and data samples.
+
+        Note that this method doesn't handle back propagation or optimizer
+        updating, which are done in :meth:`train_step`.
+
+        Args:
+            inputs (torch.Tensor): The input tensor with shape (N, C, ...) in
+                general.
+            data_samples (list[:obj:`SegDataSample`]): The seg data samples.
+                It usually includes information such as `metainfo` and
+                `gt_sem_seg`. Defaults to None.
+            mode (str): Return what kind of value. Defaults to 'tensor'.
+
+        Returns:
+            The return type depends on ``mode``.
+
+            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
+            - If ``mode="predict"``, return a list of :obj:`SegDataSample`.
+            - If ``mode="loss"``, return a dict of tensor.
+        """
+        if mode == 'loss':
+            return self.loss(inputs, data_samples)
+        elif mode == 'predict':
+            return self.predict(inputs, data_samples)
+        elif mode == 'tensor':
+            return self._forward(inputs, data_samples)
+        else:
+            raise RuntimeError(f'Invalid mode "{mode}". 
' + 'Only supports loss, predict and tensor mode') + + @abstractmethod + def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + pass + + @abstractmethod + def predict(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing.""" + pass + + @abstractmethod + def _forward(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> Tuple[List[Tensor]]: + """Network forward process. + + Usually includes backbone, neck and head forward without any post- + processing. + """ + pass + + def postprocess_result(self, + seg_logits: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """ Convert results list to `SegDataSample`. + Args: + seg_logits (Tensor): The segmentation results, seg_logits from + model of each input image. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_sem_seg`. Default to None. + Returns: + list[:obj:`SegDataSample`]: Segmentation results of the + input images. Each SegDataSample usually contain: + + - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. + - ``seg_logits``(PixelData): Predicted logits of semantic + segmentation before normalization. + """ + batch_size, C, H, W = seg_logits.shape + + if data_samples is None: + data_samples = [SegDataSample() for _ in range(batch_size)] + only_prediction = True + else: + only_prediction = False + + for i in range(batch_size): + if not only_prediction: + img_meta = data_samples[i].metainfo + # remove padding area + if 'img_padding_size' not in img_meta: + padding_size = img_meta.get('padding_size', [0] * 4) + else: + padding_size = img_meta['img_padding_size'] + padding_left, padding_right, padding_top, padding_bottom =\ + padding_size + # i_seg_logits shape is 1, C, H, W after remove padding + i_seg_logits = seg_logits[i:i + 1, :, + padding_top:H - padding_bottom, + padding_left:W - padding_right] + + flip = img_meta.get('flip', None) + if flip: + flip_direction = img_meta.get('flip_direction', None) + assert flip_direction in ['horizontal', 'vertical'] + if flip_direction == 'horizontal': + i_seg_logits = i_seg_logits.flip(dims=(3, )) + else: + i_seg_logits = i_seg_logits.flip(dims=(2, )) + + # resize as original shape + i_seg_logits = resize( + i_seg_logits, + size=img_meta['ori_shape'], + mode='bilinear', + align_corners=self.align_corners, + warning=False).squeeze(0) + else: + i_seg_logits = seg_logits[i] + + if C > 1: + i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True) + else: + i_seg_logits = i_seg_logits.sigmoid() + i_seg_pred = (i_seg_logits > + self.decode_head.threshold).to(i_seg_logits) + data_samples[i].set_data({ + 'seg_logits': + PixelData(**{'data': i_seg_logits}), + 'pred_sem_seg': + PixelData(**{'data': i_seg_pred}) + }) + + return data_samples diff --git a/mmseg/models/segmentors/cascade_encoder_decoder.py b/mmseg/models/segmentors/cascade_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0184a3533a18cbe96a28bbb645c3e73bbffcdeee --- /dev/null +++ b/mmseg/models/segmentors/cascade_encoder_decoder.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
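`CascadeEncoderDecoder`, defined in this file, runs its decode heads in a chain: stage 0 sees only the backbone features, and every later stage also receives the previous stage's output (see `encode_decode()` below). A toy sketch of that chaining, independent of mmseg; the stages here are stand-ins:

```python
import torch
from torch import nn


class ToyStage(nn.Module):
    def __init__(self, first=False):
        super().__init__()
        self.first = first
        self.proj = nn.Conv2d(8, 8, 1)

    def forward(self, x, prev=None):
        out = self.proj(x)
        return out if self.first else out + prev  # refine previous stage


stages = nn.ModuleList([ToyStage(first=True), ToyStage(), ToyStage()])
x = torch.rand(1, 8, 16, 16)  # backbone feature
out = stages[0](x)
for stage in stages[1:]:
    out = stage(x, out)
print(out.shape)  # torch.Size([1, 8, 16, 16])
```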
+from typing import List, Optional
+
+from torch import Tensor, nn
+
+from mmseg.registry import MODELS
+from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
+                         OptSampleList, SampleList, add_prefix)
+from .encoder_decoder import EncoderDecoder
+
+
+@MODELS.register_module()
+class CascadeEncoderDecoder(EncoderDecoder):
+    """Cascade Encoder Decoder segmentors.
+
+    CascadeEncoderDecoder is almost the same as EncoderDecoder, except that
+    the decoders of CascadeEncoderDecoder are cascaded. The output of the
+    previous decode_head will be the input of the next decode_head.
+
+    Args:
+
+        num_stages (int): How many stages will be cascaded.
+        backbone (ConfigType): The config for the backbone of segmentor.
+        decode_head (ConfigType): The config for the decode head of segmentor.
+        neck (OptConfigType): The config for the neck of segmentor.
+            Defaults to None.
+        auxiliary_head (OptConfigType): The config for the auxiliary head of
+            segmentor. Defaults to None.
+        train_cfg (OptConfigType): The config for training. Defaults to None.
+        test_cfg (OptConfigType): The config for testing. Defaults to None.
+        data_preprocessor (dict, optional): The pre-process config of
+            :class:`BaseDataPreprocessor`.
+        pretrained (str, optional): The path for pretrained model.
+            Defaults to None.
+        init_cfg (dict, optional): The weight initialized config for
+            :class:`BaseModule`.
+    """
+
+    def __init__(self,
+                 num_stages: int,
+                 backbone: ConfigType,
+                 decode_head: ConfigType,
+                 neck: OptConfigType = None,
+                 auxiliary_head: OptConfigType = None,
+                 train_cfg: OptConfigType = None,
+                 test_cfg: OptConfigType = None,
+                 data_preprocessor: OptConfigType = None,
+                 pretrained: Optional[str] = None,
+                 init_cfg: OptMultiConfig = None):
+        self.num_stages = num_stages
+        super().__init__(
+            backbone=backbone,
+            decode_head=decode_head,
+            neck=neck,
+            auxiliary_head=auxiliary_head,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            data_preprocessor=data_preprocessor,
+            pretrained=pretrained,
+            init_cfg=init_cfg)
+
+    def _init_decode_head(self, decode_head: ConfigType) -> None:
+        """Initialize ``decode_head``"""
+        assert isinstance(decode_head, list)
+        assert len(decode_head) == self.num_stages
+        self.decode_head = nn.ModuleList()
+        for i in range(self.num_stages):
+            self.decode_head.append(MODELS.build(decode_head[i]))
+        self.align_corners = self.decode_head[-1].align_corners
+        self.num_classes = self.decode_head[-1].num_classes
+        self.out_channels = self.decode_head[-1].out_channels
+
+    def encode_decode(self, inputs: Tensor,
+                      batch_img_metas: List[dict]) -> Tensor:
+        """Encode images with backbone and decode into a semantic segmentation
+        map of the same size as input."""
+        x = self.extract_feat(inputs)
+        out = self.decode_head[0].forward(x)
+        for i in range(1, self.num_stages - 1):
+            out = self.decode_head[i].forward(x, out)
+        seg_logits_list = self.decode_head[-1].predict(x, out, batch_img_metas,
+                                                       self.test_cfg)
+
+        return seg_logits_list
+
+    def _decode_head_forward_train(self, inputs: Tensor,
+                                   data_samples: SampleList) -> dict:
+        """Run forward function and calculate loss for decode head in
+        training."""
+        losses = dict()
+
+        loss_decode = self.decode_head[0].loss(inputs, data_samples,
+                                               self.train_cfg)
+
+        losses.update(add_prefix(loss_decode, 'decode_0'))
+        # get batch_img_metas
+        batch_size = len(data_samples)
+        batch_img_metas = []
+        for batch_index in range(batch_size):
+            metainfo = data_samples[batch_index].metainfo
+            batch_img_metas.append(metainfo)
+
+        for i in range(1, self.num_stages):
+            # forward test again, maybe unnecessary for most methods.
+            if i == 1:
+                prev_outputs = self.decode_head[0].forward(inputs)
+            else:
+                prev_outputs = self.decode_head[i - 1].forward(
+                    inputs, prev_outputs)
+            loss_decode = self.decode_head[i].loss(inputs, prev_outputs,
+                                                   data_samples,
+                                                   self.train_cfg)
+            losses.update(add_prefix(loss_decode, f'decode_{i}'))
+
+        return losses
+
+    def _forward(self,
+                 inputs: Tensor,
+                 data_samples: OptSampleList = None) -> Tensor:
+        """Network forward process.
+
+        Args:
+            inputs (Tensor): Inputs with shape (N, C, H, W).
+            data_samples (List[:obj:`SegDataSample`]): The seg data samples.
+                It usually includes information such as `metainfo` and
+                `gt_semantic_seg`.
+
+        Returns:
+            Tensor: Forward output of model without any post-processes.
+        """
+        x = self.extract_feat(inputs)
+
+        out = self.decode_head[0].forward(x)
+        for i in range(1, self.num_stages):
+            # TODO support PointRend tensor mode
+            out = self.decode_head[i].forward(x, out)
+
+        return out
diff --git a/mmseg/models/segmentors/depth_estimator.py b/mmseg/models/segmentors/depth_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..1020637e737a3c72ba6a48f2d1228717470ba862
--- /dev/null
+++ b/mmseg/models/segmentors/depth_estimator.py
@@ -0,0 +1,392 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.logging import print_log
+from mmengine.structures import PixelData
+from torch import Tensor
+
+from mmseg.registry import MODELS
+from mmseg.structures import SegDataSample
+from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
+                         OptSampleList, SampleList, add_prefix)
+from ..utils import resize
+from .encoder_decoder import EncoderDecoder
+
+
+@MODELS.register_module()
+class DepthEstimator(EncoderDecoder):
+    """Encoder Decoder depth estimator.
+
+    EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
+    Note that auxiliary_head is only used for deep supervision during training,
+    which can be discarded during inference.
+
+    1. The ``loss`` method is used to calculate the loss of model,
+    which includes two steps: (1) Extracts features to obtain the feature maps
+    (2) Call the decode head loss function to forward decode head model and
+    calculate losses.
+
+    .. code:: text
+
+     loss(): extract_feat() -> _decode_head_forward_train() -> _auxiliary_head_forward_train (optional)
+     _decode_head_forward_train(): decode_head.loss()
+     _auxiliary_head_forward_train(): auxiliary_head.loss (optional)
+
+    2. The ``predict`` method is used to predict depth estimation results,
+    which includes two steps: (1) Run inference function to obtain the list of
+    depth (2) Call post-processing function to obtain list of
+    ``SegDataSample`` including ``pred_depth_map``.
+
+    .. code:: text
+
+     predict(): inference() -> postprocess_result()
+     inference(): whole_inference()/slide_inference()
+     whole_inference()/slide_inference(): encode_decode()
+     encode_decode(): extract_feat() -> decode_head.predict()
+
+    3. The ``_forward`` method is used to output the tensor by running the model,
+    which includes two steps: (1) Extracts features to obtain the feature maps
+    (2) Call the decode head forward function to forward decode head model.
+
+    .. code:: text
+
+     _forward(): extract_feat() -> _decode_head.forward()
+
+    Args:
+
+        backbone (ConfigType): The config for the backbone of depth estimator.
+ decode_head (ConfigType): The config for the decode head of depth estimator. + neck (OptConfigType): The config for the neck of depth estimator. + Defaults to None. + auxiliary_head (OptConfigType): The config for the auxiliary head of + depth estimator. Defaults to None. + train_cfg (OptConfigType): The config for training. Defaults to None. + test_cfg (OptConfigType): The config for testing. Defaults to None. + data_preprocessor (dict, optional): The pre-process config of + :class:`BaseDataPreprocessor`. + pretrained (str, optional): The path for pretrained model. + Defaults to None. + init_cfg (dict, optional): The weight initialized config for + :class:`BaseModule`. + """ # noqa: E501 + + def __init__(self, + backbone: ConfigType, + decode_head: ConfigType, + neck: OptConfigType = None, + auxiliary_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + pretrained: Optional[str] = None, + init_cfg: OptMultiConfig = None): + super().__init__( + backbone=backbone, + decode_head=decode_head, + neck=neck, + auxiliary_head=auxiliary_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + pretrained=pretrained, + init_cfg=init_cfg) + + def extract_feat(self, + inputs: Tensor, + batch_img_metas: Optional[List[dict]] = None) -> Tensor: + """Extract features from images.""" + + if getattr(self.backbone, 'class_embed_select', False) and \ + isinstance(batch_img_metas, list) and \ + 'category_id' in batch_img_metas[0]: + cat_ids = [meta['category_id'] for meta in batch_img_metas] + cat_ids = torch.tensor(cat_ids).to(inputs.device) + inputs = (inputs, cat_ids) + + x = self.backbone(inputs) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Encode images with backbone and decode into a depth map of the same + size as input.""" + x = self.extract_feat(inputs, batch_img_metas) + depth = self.decode_head.predict(x, batch_img_metas, self.test_cfg) + + return depth + + def _decode_head_forward_train(self, inputs: List[Tensor], + data_samples: SampleList) -> dict: + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.loss(inputs, data_samples, + self.train_cfg) + + losses.update(add_prefix(loss_decode, 'decode')) + return losses + + def _auxiliary_head_forward_train(self, inputs: List[Tensor], + data_samples: SampleList) -> dict: + """Run forward function and calculate loss for auxiliary head in + training.""" + losses = dict() + if isinstance(self.auxiliary_head, nn.ModuleList): + for idx, aux_head in enumerate(self.auxiliary_head): + loss_aux = aux_head.loss(inputs, data_samples, self.train_cfg) + losses.update(add_prefix(loss_aux, f'aux_{idx}')) + else: + loss_aux = self.auxiliary_head.loss(inputs, data_samples, + self.train_cfg) + losses.update(add_prefix(loss_aux, 'aux')) + + return losses + + def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (Tensor): Input images. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_depth_map`. 
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+        if data_samples is not None:
+            batch_img_metas = [
+                data_sample.metainfo for data_sample in data_samples
+            ]
+        else:
+            batch_img_metas = [
+                dict(
+                    ori_shape=inputs.shape[2:],
+                    img_shape=inputs.shape[2:],
+                    pad_shape=inputs.shape[2:],
+                    padding_size=[0, 0, 0, 0])
+            ] * inputs.shape[0]
+
+        x = self.extract_feat(inputs, batch_img_metas)
+
+        losses = dict()
+
+        loss_decode = self._decode_head_forward_train(x, data_samples)
+        losses.update(loss_decode)
+
+        if self.with_auxiliary_head:
+            loss_aux = self._auxiliary_head_forward_train(x, data_samples)
+            losses.update(loss_aux)
+
+        return losses
+
+    def predict(self,
+                inputs: Tensor,
+                data_samples: OptSampleList = None) -> SampleList:
+        """Predict results from a batch of inputs and data samples with post-
+        processing.
+
+        Args:
+            inputs (Tensor): Inputs with shape (N, C, H, W).
+            data_samples (List[:obj:`SegDataSample`], optional): The seg data
+                samples. It usually includes information such as `metainfo`
+                and `gt_depth_map`.
+
+        Returns:
+            list[:obj:`SegDataSample`]: Depth estimation results of the
+                input images. Each SegDataSample usually contain:
+
+            - ``pred_depth_map``(PixelData): Prediction of depth estimation.
+        """
+        if data_samples is not None:
+            batch_img_metas = [
+                data_sample.metainfo for data_sample in data_samples
+            ]
+        else:
+            batch_img_metas = [
+                dict(
+                    ori_shape=inputs.shape[2:],
+                    img_shape=inputs.shape[2:],
+                    pad_shape=inputs.shape[2:],
+                    padding_size=[0, 0, 0, 0])
+            ] * inputs.shape[0]
+
+        depth = self.inference(inputs, batch_img_metas)
+
+        return self.postprocess_result(depth, data_samples)
+
+    def _forward(self,
+                 inputs: Tensor,
+                 data_samples: OptSampleList = None) -> Tensor:
+        """Network forward process.
+
+        Args:
+            inputs (Tensor): Inputs with shape (N, C, H, W).
+            data_samples (List[:obj:`SegDataSample`]): The seg
+                data samples. It usually includes information such
+                as `metainfo` and `gt_depth_map`.
+
+        Returns:
+            Tensor: Forward output of model without any post-processes.
+        """
+        x = self.extract_feat(inputs)
+        return self.decode_head.forward(x)
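The sliding-window methods below place crops with `grids = ceil((img - crop) / stride) + 1` windows per axis, clamping the last window to the image border so coverage is complete. A small check of that arithmetic with illustrative numbers:

```python
# Window placement arithmetic used by the sliding-window inference below.
h_img, h_crop, h_stride = 1024, 512, 341

h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
print(h_grids)  # 3

for h_idx in range(h_grids):
    y1 = h_idx * h_stride
    y2 = min(y1 + h_crop, h_img)
    y1 = max(y2 - h_crop, 0)        # clamp the last window to the border
    print(y1, y2)                   # (0, 512), (341, 853), (512, 1024)
```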
+
+    def slide_flip_inference(self, inputs: Tensor,
+                             batch_img_metas: List[dict]) -> Tensor:
+        """Inference by sliding-window with overlap and flip.
+
+        If h_crop > h_img or w_crop > w_img, the small patch will be used to
+        decode without padding.
+
+        Args:
+            inputs (tensor): the tensor should have a shape NxCxHxW,
+                which contains all images in the batch.
+            batch_img_metas (List[dict]): List of image metainfo where each may
+                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
+                'ori_shape', and 'pad_shape'.
+                For details on the values of these keys see
+                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
+
+        Returns:
+            Tensor: The depth estimation results.
+        """
+
+        h_stride, w_stride = self.test_cfg.stride
+        h_crop, w_crop = self.test_cfg.crop_size
+        batch_size, _, h_img, w_img = inputs.size()
+        out_channels = self.out_channels
+        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+        preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
+        count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
+        for h_idx in range(h_grids):
+            for w_idx in range(w_grids):
+                y1 = h_idx * h_stride
+                x1 = w_idx * w_stride
+                y2 = min(y1 + h_crop, h_img)
+                x2 = min(x1 + w_crop, w_img)
+                y1 = max(y2 - h_crop, 0)
+                x1 = max(x2 - w_crop, 0)
+                crop_img = inputs[:, :, y1:y2, x1:x2]
+                # change the image shape to patch shape
+                batch_img_metas[0]['img_shape'] = crop_img.shape[2:]
+                # the output of encode_decode is depth tensor map
+                # with shape [N, C, H, W]
+                crop_depth_map = self.encode_decode(crop_img, batch_img_metas)
+
+                # average out the original and flipped prediction
+                crop_depth_map_flip = self.encode_decode(
+                    crop_img.flip(dims=(3, )), batch_img_metas)
+                crop_depth_map_flip = crop_depth_map_flip.flip(dims=(3, ))
+                crop_depth_map = (crop_depth_map + crop_depth_map_flip) / 2.0
+
+                preds += F.pad(crop_depth_map,
+                               (int(x1), int(preds.shape[3] - x2), int(y1),
+                                int(preds.shape[2] - y2)))
+
+                count_mat[:, :, y1:y2, x1:x2] += 1
+        assert (count_mat == 0).sum() == 0
+        depth = preds / count_mat
+
+        return depth
+
+    def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
+        """Inference with slide/whole style.
+
+        Args:
+            inputs (Tensor): The input image of shape (N, 3, H, W).
+            batch_img_metas (List[dict]): List of image metainfo where each may
+                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
+                'ori_shape', 'pad_shape', and 'padding_size'.
+                For details on the values of these keys see
+                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
+
+        Returns:
+            Tensor: The depth estimation results.
+        """
+        assert self.test_cfg.get('mode', 'whole') in ['slide', 'whole',
+                                                      'slide_flip'], \
+            f'Only "slide", "slide_flip" or "whole" test modes are ' \
+            f'supported, but got {self.test_cfg["mode"]}.'
+        ori_shape = batch_img_metas[0]['ori_shape']
+        if not all(_['ori_shape'] == ori_shape for _ in batch_img_metas):
+            print_log(
+                'Image shapes are different in the batch.',
+                logger='current',
+                level=logging.WARN)
+        if self.test_cfg.mode == 'slide':
+            depth_map = self.slide_inference(inputs, batch_img_metas)
+        elif self.test_cfg.mode == 'slide_flip':
+            depth_map = self.slide_flip_inference(inputs, batch_img_metas)
+        else:
+            depth_map = self.whole_inference(inputs, batch_img_metas)
+
+        return depth_map
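The `slide_flip` mode above averages each window's prediction with the prediction of its horizontally flipped copy, which cancels direction-dependent biases. A self-contained toy illustration (the "model" is a stand-in with a deliberate left-to-right bias):

```python
import torch


def flip_average(model_fn, crop):
    # Average a prediction with its horizontally flipped counterpart,
    # mirroring the per-window logic in slide_flip_inference above.
    out = model_fn(crop)
    out_flip = model_fn(crop.flip(dims=(3, ))).flip(dims=(3, ))
    return (out + out_flip) / 2.0


ramp = torch.arange(4.).view(1, 1, 1, 4)   # left-to-right position bias
model_fn = lambda x: x + ramp

crop = torch.rand(1, 1, 4, 4)
avg = flip_average(model_fn, crop)
print((avg - crop).flatten())  # constant 1.5: the directional bias cancels
```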
+
+    def postprocess_result(self,
+                           depth: Tensor,
+                           data_samples: OptSampleList = None) -> SampleList:
+        """ Convert results list to `SegDataSample`.
+        Args:
+            depth (Tensor): The depth estimation results.
+            data_samples (list[:obj:`SegDataSample`]): The seg data samples.
+                It usually includes information such as `metainfo` and
+                `gt_depth_map`. Defaults to None.
+        Returns:
+            list[:obj:`SegDataSample`]: Depth estimation results of the
+                input images. Each SegDataSample usually contain:
+
+            - ``pred_depth_map``(PixelData): Prediction of depth estimation.
+        """
+        batch_size, C, H, W = depth.shape
+
+        if data_samples is None:
+            data_samples = [SegDataSample() for _ in range(batch_size)]
+            only_prediction = True
+        else:
+            only_prediction = False
+
+        for i in range(batch_size):
+            if not only_prediction:
+                img_meta = data_samples[i].metainfo
+                # remove padding area
+                if 'img_padding_size' not in img_meta:
+                    padding_size = img_meta.get('padding_size', [0] * 4)
+                else:
+                    padding_size = img_meta['img_padding_size']
+                padding_left, padding_right, padding_top, padding_bottom =\
+                    padding_size
+                # i_depth shape is 1, C, H, W after removing padding
+                i_depth = depth[i:i + 1, :, padding_top:H - padding_bottom,
+                                padding_left:W - padding_right]
+
+                flip = img_meta.get('flip', None)
+                if flip:
+                    flip_direction = img_meta.get('flip_direction', None)
+                    assert flip_direction in ['horizontal', 'vertical']
+                    if flip_direction == 'horizontal':
+                        i_depth = i_depth.flip(dims=(3, ))
+                    else:
+                        i_depth = i_depth.flip(dims=(2, ))
+
+                # resize as original shape
+                i_depth = resize(
+                    i_depth,
+                    size=img_meta['ori_shape'],
+                    mode='bilinear',
+                    align_corners=self.align_corners,
+                    warning=False).squeeze(0)
+            else:
+                i_depth = depth[i]
+
+            data_samples[i].set_data(
+                {'pred_depth_map': PixelData(**{'data': i_depth})})
+
+        return data_samples
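`postprocess_result` above undoes the test-time geometry in order: crop away padding, unflip if the pipeline flipped the input, then resize to `ori_shape`. The cropping and unflipping steps in isolation, with illustrative numbers:

```python
import torch

H, W = 10, 12
pred = torch.rand(1, 1, H, W)
pad_left, pad_right, pad_top, pad_bottom = 0, 2, 0, 2

# Remove the padding added by the data preprocessor.
cropped = pred[:, :, pad_top:H - pad_bottom, pad_left:W - pad_right]
print(cropped.shape)  # torch.Size([1, 1, 8, 10])

# If the input was horizontally flipped, flip the prediction back.
restored = cropped.flip(dims=(3, ))
print(torch.equal(restored.flip(dims=(3, )), cropped))  # True
```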
diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa4050e0b736f98c17629a93e2f70be1d7e84fbb
--- /dev/null
+++ b/mmseg/models/segmentors/encoder_decoder.py
@@ -0,0 +1,364 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+from typing import List, Optional
+
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.logging import print_log
+from torch import Tensor
+
+from mmseg.registry import MODELS
+from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
+                         OptSampleList, SampleList, add_prefix)
+from .base import BaseSegmentor
+
+
+@MODELS.register_module()
+class EncoderDecoder(BaseSegmentor):
+    """Encoder Decoder segmentors.
+
+    EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
+    Note that auxiliary_head is only used for deep supervision during training,
+    which can be discarded during inference.
+
+    1. The ``loss`` method is used to calculate the loss of model,
+    which includes two steps: (1) Extracts features to obtain the feature maps
+    (2) Call the decode head loss function to forward decode head model and
+    calculate losses.
+
+    .. code:: text
+
+     loss(): extract_feat() -> _decode_head_forward_train() -> _auxiliary_head_forward_train (optional)
+     _decode_head_forward_train(): decode_head.loss()
+     _auxiliary_head_forward_train(): auxiliary_head.loss (optional)
+
+    2. The ``predict`` method is used to predict segmentation results,
+    which includes two steps: (1) Run inference function to obtain the list of
+    seg_logits (2) Call post-processing function to obtain list of
+    ``SegDataSample`` including ``pred_sem_seg`` and ``seg_logits``.
+
+    .. code:: text
+
+     predict(): inference() -> postprocess_result()
+     inference(): whole_inference()/slide_inference()
+     whole_inference()/slide_inference(): encode_decode()
+     encode_decode(): extract_feat() -> decode_head.predict()
+
+    3. The ``_forward`` method is used to output the tensor by running the model,
+    which includes two steps: (1) Extracts features to obtain the feature maps
+    (2) Call the decode head forward function to forward decode head model.
+
+    .. code:: text
+
+     _forward(): extract_feat() -> _decode_head.forward()
+
+    Args:
+
+        backbone (ConfigType): The config for the backbone of segmentor.
+        decode_head (ConfigType): The config for the decode head of segmentor.
+        neck (OptConfigType): The config for the neck of segmentor.
+            Defaults to None.
+        auxiliary_head (OptConfigType): The config for the auxiliary head of
+            segmentor. Defaults to None.
+        train_cfg (OptConfigType): The config for training. Defaults to None.
+        test_cfg (OptConfigType): The config for testing. Defaults to None.
+        data_preprocessor (dict, optional): The pre-process config of
+            :class:`BaseDataPreprocessor`.
+        pretrained (str, optional): The path for pretrained model.
+            Defaults to None.
+        init_cfg (dict, optional): The weight initialized config for
+            :class:`BaseModule`.
+    """  # noqa: E501
+
+    def __init__(self,
+                 backbone: ConfigType,
+                 decode_head: ConfigType,
+                 neck: OptConfigType = None,
+                 auxiliary_head: OptConfigType = None,
+                 train_cfg: OptConfigType = None,
+                 test_cfg: OptConfigType = None,
+                 data_preprocessor: OptConfigType = None,
+                 pretrained: Optional[str] = None,
+                 init_cfg: OptMultiConfig = None):
+        super().__init__(
+            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
+        if pretrained is not None:
+            assert backbone.get('pretrained') is None, \
+                'both backbone and segmentor set pretrained weight'
+            backbone.pretrained = pretrained
+        self.backbone = MODELS.build(backbone)
+        if neck is not None:
+            self.neck = MODELS.build(neck)
+        self._init_decode_head(decode_head)
+        self._init_auxiliary_head(auxiliary_head)
+
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+
+        assert self.with_decode_head
+
+    def _init_decode_head(self, decode_head: ConfigType) -> None:
+        """Initialize ``decode_head``"""
+        self.decode_head = MODELS.build(decode_head)
+        self.align_corners = self.decode_head.align_corners
+        self.num_classes = self.decode_head.num_classes
+        self.out_channels = self.decode_head.out_channels
+
+    def _init_auxiliary_head(self, auxiliary_head: ConfigType) -> None:
+        """Initialize ``auxiliary_head``"""
+        if auxiliary_head is not None:
+            if isinstance(auxiliary_head, list):
+                self.auxiliary_head = nn.ModuleList()
+                for head_cfg in auxiliary_head:
+                    self.auxiliary_head.append(MODELS.build(head_cfg))
+            else:
+                self.auxiliary_head = MODELS.build(auxiliary_head)
+
+    def extract_feat(self, inputs: Tensor) -> List[Tensor]:
+        """Extract features from images."""
+        x = self.backbone(inputs)
+        if self.with_neck:
+            x = self.neck(x)
+        return x
+
+    def encode_decode(self, inputs: Tensor,
+                      batch_img_metas: List[dict]) -> Tensor:
+        """Encode images with backbone and decode into a semantic segmentation
+        map of the same size as input."""
+        x = self.extract_feat(inputs)
+        seg_logits = self.decode_head.predict(x, batch_img_metas,
+                                              self.test_cfg)
+
+        return seg_logits
+
+    def _decode_head_forward_train(self, inputs: List[Tensor],
+                                   data_samples: SampleList) -> dict:
+        """Run forward function and calculate loss for decode head in
+        training."""
+        losses = dict()
+        loss_decode = self.decode_head.loss(inputs, data_samples,
+                                            self.train_cfg)
+
+        losses.update(add_prefix(loss_decode, 'decode'))
+        return losses
+
+    def _auxiliary_head_forward_train(self, inputs: List[Tensor],
+                                      data_samples: SampleList) -> dict:
+        """Run forward function and calculate loss for auxiliary head in
+        training."""
+        losses = dict()
+        if isinstance(self.auxiliary_head, nn.ModuleList):
+            for idx, aux_head in enumerate(self.auxiliary_head):
+                loss_aux = aux_head.loss(inputs, data_samples,
self.train_cfg) + losses.update(add_prefix(loss_aux, f'aux_{idx}')) + else: + loss_aux = self.auxiliary_head.loss(inputs, data_samples, + self.train_cfg) + losses.update(add_prefix(loss_aux, 'aux')) + + return losses + + def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (Tensor): Input images. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(inputs) + + losses = dict() + + loss_decode = self._decode_head_forward_train(x, data_samples) + losses.update(loss_decode) + + if self.with_auxiliary_head: + loss_aux = self._auxiliary_head_forward_train(x, data_samples) + losses.update(loss_aux) + + return losses + + def predict(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`], optional): The seg data + samples. It usually includes information such as `metainfo` + and `gt_sem_seg`. + + Returns: + list[:obj:`SegDataSample`]: Segmentation results of the + input images. Each SegDataSample usually contain: + + - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. + - ``seg_logits``(PixelData): Predicted logits of semantic + segmentation before normalization. + """ + if data_samples is not None: + batch_img_metas = [ + data_sample.metainfo for data_sample in data_samples + ] + else: + batch_img_metas = [ + dict( + ori_shape=inputs.shape[2:], + img_shape=inputs.shape[2:], + pad_shape=inputs.shape[2:], + padding_size=[0, 0, 0, 0]) + ] * inputs.shape[0] + + seg_logits = self.inference(inputs, batch_img_metas) + + return self.postprocess_result(seg_logits, data_samples) + + def _forward(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> Tensor: + """Network forward process. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_sem_seg`. + + Returns: + Tensor: Forward output of model without any post-processes. + """ + x = self.extract_feat(inputs) + return self.decode_head.forward(x) + + def slide_inference(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + + Args: + inputs (tensor): the tensor should have a shape NxCxHxW, + which contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. 
+ """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = inputs.size() + out_channels = self.out_channels + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img)) + count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = inputs[:, :, y1:y2, x1:x2] + # change the image shape to patch shape + batch_img_metas[0]['img_shape'] = crop_img.shape[2:] + # the output of encode_decode is seg logits tensor map + # with shape [N, C, H, W] + crop_seg_logit = self.encode_decode(crop_img, batch_img_metas) + preds += F.pad(crop_seg_logit, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + seg_logits = preds / count_mat + + return seg_logits + + def whole_inference(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Inference with full image. + + Args: + inputs (Tensor): The tensor should have a shape NxCxHxW, which + contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + + seg_logits = self.encode_decode(inputs, batch_img_metas) + + return seg_logits + + def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor: + """Inference with slide/whole style. + + Args: + inputs (Tensor): The input image of shape (N, 3, H, W). + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', 'pad_shape', and 'padding_size'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + assert self.test_cfg.get('mode', 'whole') in ['slide', 'whole'], \ + f'Only "slide" or "whole" test mode are supported, but got ' \ + f'{self.test_cfg["mode"]}.' + ori_shape = batch_img_metas[0]['ori_shape'] + if not all(_['ori_shape'] == ori_shape for _ in batch_img_metas): + print_log( + 'Image shapes are different in the batch.', + logger='current', + level=logging.WARN) + if self.test_cfg.mode == 'slide': + seg_logit = self.slide_inference(inputs, batch_img_metas) + else: + seg_logit = self.whole_inference(inputs, batch_img_metas) + + return seg_logit + + def aug_test(self, inputs, batch_img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. 
+        """
+        # aug_test rescale all imgs back to ori_shape for now
+        assert rescale
+        # to save memory, we get augmented seg logit inplace
+        seg_logit = self.inference(inputs[0], batch_img_metas[0], rescale)
+        for i in range(1, len(inputs)):
+            cur_seg_logit = self.inference(inputs[i], batch_img_metas[i],
+                                           rescale)
+            seg_logit += cur_seg_logit
+        seg_logit /= len(inputs)
+        seg_pred = seg_logit.argmax(dim=1)
+        # unravel batch dim
+        seg_pred = list(seg_pred)
+        return seg_pred
diff --git a/mmseg/models/segmentors/multimodal_encoder_decoder.py b/mmseg/models/segmentors/multimodal_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..75aa8b9b17688cb5f54da08f9300af82b3339967
--- /dev/null
+++ b/mmseg/models/segmentors/multimodal_encoder_decoder.py
@@ -0,0 +1,350 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional
+
+import torch.nn.functional as F
+from torch import Tensor
+
+from mmseg.registry import MODELS
+from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
+                         OptSampleList, SampleList, add_prefix)
+from .base import BaseSegmentor
+
+
+@MODELS.register_module()
+class MultimodalEncoderDecoder(BaseSegmentor):
+    """Multimodal Encoder-Decoder segmentors.
+
+    The multimodal segmentation architecture is used for open-vocabulary
+    semantic segmentation by combining visual and language pretrained
+    models. It consists of an image_encoder (backbone) to extract visual
+    features, a text_encoder to extract text features, and a decode head
+    to generate semantic maps.
+    Note that the deep supervision during training is implemented in decode head.
+
+    1. The ``loss`` method is used to calculate the loss of model,
+    which includes two steps: (1) Extracts features to obtain the feature maps
+    (2) Call the decode head loss function to forward decode head model and
+    calculate losses.
+
+    .. code:: text
+
+     loss(): extract_feat() -> _decode_head_forward_train()
+     _decode_head_forward_train(): decode_head.loss()
+
+    2. The ``predict`` method is used to predict segmentation results,
+    which includes two steps: (1) Run inference function to obtain the list of
+    seg_logits (2) Call post-processing function to obtain list of
+    ``SegDataSample`` including ``pred_sem_seg`` and ``seg_logits``.
+
+    .. code:: text
+
+     predict(): inference() -> postprocess_result()
+     inference(): whole_inference()/slide_inference()
+     whole_inference()/slide_inference(): encode_decode()
+     encode_decode(): extract_feat() -> decode_head.predict()
+
+    3. The ``_forward`` method is used to output the tensor by running the model,
+    which includes two steps: (1) Extracts features to obtain the feature maps
+    (2) Call the decode head forward function to forward decode head model.
+
+    .. code:: text
+
+     _forward(): extract_feat() -> _decode_head.forward()
+
+    Args:
+
+        image_encoder (ConfigType): The config for the visual encoder of segmentor.
+        text_encoder (ConfigType): The config for the text encoder of segmentor.
+        decode_head (ConfigType): The config for the decode head of segmentor.
+        train_cfg (OptConfigType): The config for training. Defaults to None.
+        test_cfg (OptConfigType): The config for testing. Defaults to None.
+        data_preprocessor (dict, optional): The pre-process config of
+            :class:`BaseDataPreprocessor`.
+        pretrained (str, optional): The path for pretrained model.
+            Defaults to None.
+        asymetric_input (bool): whether to use different size of input for
+            image encoder and decode head. Defaults to True.
+ encoder_resolution (float): resize scale of input images for image encoder. + Defaults to None. + init_cfg (dict, optional): The weight initialized config for + :class:`BaseModule`. + """ # noqa: E501 + + def __init__(self, + image_encoder: ConfigType, + text_encoder: ConfigType, + decode_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + pretrained: Optional[str] = None, + asymetric_input: bool = True, + encoder_resolution: float = None, + init_cfg: OptMultiConfig = None): + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + if pretrained is not None: + image_encoder.init_cfg = dict( + type='Pretrained_Part', checkpoint=pretrained) + text_encoder.init_cfg = dict( + type='Pretrained_Part', checkpoint=pretrained) + decode_head.init_cfg = dict( + type='Pretrained_Part', checkpoint=pretrained) + + if asymetric_input: + assert encoder_resolution is not None, \ + 'if asymetric_input set True, ' \ + 'clip_resolution must be a certain value' + self.asymetric_input = asymetric_input + self.encoder_resolution = encoder_resolution + self.image_encoder = MODELS.build(image_encoder) + self.text_encoder = MODELS.build(text_encoder) + self._init_decode_head(decode_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + assert self.with_decode_head + + def _init_decode_head(self, decode_head: ConfigType) -> None: + """Initialize ``decode_head``""" + self.decode_head = MODELS.build(decode_head) + self.align_corners = self.decode_head.align_corners + self.num_classes = self.decode_head.num_classes + self.out_channels = self.decode_head.out_channels + + def extract_feat(self, inputs: Tensor) -> List[Tensor]: + """Extract visual features from images.""" + x = self.image_encoder(inputs) + return x + + def encode_decode(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Encode the name of classes with text_encoder and encode images with + image_encoder. + + Then decode the class embedding and visual feature into a semantic + segmentation map of the same size as input. + """ + classifier_embeds = self.text_encoder() + clip_inputs = inputs + if self.asymetric_input: + clip_inputs = F.interpolate( + inputs, scale_factor=self.encoder_resolution, mode='bilinear') + x = self.image_encoder(clip_inputs) + seg_logits = self.decode_head.predict([inputs, x, classifier_embeds], + batch_img_metas, self.test_cfg) + + return seg_logits + + def _decode_head_forward_train(self, inputs: List[Tensor], + data_samples: SampleList) -> dict: + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.loss(inputs, data_samples, + self.train_cfg) + + losses.update(add_prefix(loss_decode, 'decode')) + return losses + + def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (Tensor): Input images. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_sem_seg`. 
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + classifier_embeds = self.text_encoder() + clip_inputs = inputs + if self.asymetric_input: + clip_inputs = F.interpolate( + inputs, scale_factor=self.encoder_resolution, mode='bilinear') + x = self.image_encoder(clip_inputs) + + losses = dict() + + loss_decode = self._decode_head_forward_train( + [inputs, x, classifier_embeds], data_samples) + losses.update(loss_decode) + + return losses + + def predict(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`], optional): The seg data + samples. It usually includes information such as `metainfo` + and `gt_sem_seg`. + + Returns: + list[:obj:`SegDataSample`]: Segmentation results of the + input images. Each SegDataSample usually contain: + + - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. + - ``seg_logits``(PixelData): Predicted logits of semantic + segmentation before normalization. + """ + if data_samples is not None: + batch_img_metas = [ + data_sample.metainfo for data_sample in data_samples + ] + else: + batch_img_metas = [ + dict( + ori_shape=inputs.shape[2:], + img_shape=inputs.shape[2:], + pad_shape=inputs.shape[2:], + padding_size=[0, 0, 0, 0]) + ] * inputs.shape[0] + + seg_logits = self.inference(inputs, batch_img_metas) + + return self.postprocess_result(seg_logits, data_samples) + + def _forward(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> Tensor: + """Network forward process. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_sem_seg`. + + Returns: + Tensor: Forward output of model without any post-processes. + """ + x = self.extract_feat(inputs) + return self.decode_head.forward(x) + + def slide_inference(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + + Args: + inputs (tensor): the tensor should have a shape NxCxHxW, + which contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. 
+ """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = inputs.size() + out_channels = self.out_channels + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img)) + count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = inputs[:, :, y1:y2, x1:x2] + # change the image shape to patch shape + batch_img_metas[0]['img_shape'] = crop_img.shape[2:] + # the output of encode_decode is seg logits tensor map + # with shape [N, C, H, W] + crop_seg_logit = self.encode_decode(crop_img, batch_img_metas) + preds += F.pad(crop_seg_logit, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + seg_logits = preds / count_mat + + return seg_logits + + def whole_inference(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Inference with full image. + + Args: + inputs (Tensor): The tensor should have a shape NxCxHxW, which + contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + + seg_logits = self.encode_decode(inputs, batch_img_metas) + + return seg_logits + + def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor: + """Inference with slide/whole style. + + Args: + inputs (Tensor): The input image of shape (N, 3, H, W). + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', 'pad_shape', and 'padding_size'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + + assert self.test_cfg.mode in ['slide', 'whole'] + ori_shape = batch_img_metas[0]['ori_shape'] + assert all(_['ori_shape'] == ori_shape for _ in batch_img_metas) + if self.test_cfg.mode == 'slide': + seg_logit = self.slide_inference(inputs, batch_img_metas) + else: + seg_logit = self.whole_inference(inputs, batch_img_metas) + + return seg_logit + + def aug_test(self, inputs, batch_img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. 
+ """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented seg logit inplace + seg_logit = self.inference(inputs[0], batch_img_metas[0], rescale) + for i in range(1, len(inputs)): + cur_seg_logit = self.inference(inputs[i], batch_img_metas[i], + rescale) + seg_logit += cur_seg_logit + seg_logit /= len(inputs) + seg_pred = seg_logit.argmax(dim=1) + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred diff --git a/mmseg/models/segmentors/seg_tta.py b/mmseg/models/segmentors/seg_tta.py new file mode 100644 index 0000000000000000000000000000000000000000..63ef61d223a572dec4fc3e43e1550b98cd2e7302 --- /dev/null +++ b/mmseg/models/segmentors/seg_tta.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +from mmengine.model import BaseTTAModel +from mmengine.structures import PixelData + +from mmseg.registry import MODELS +from mmseg.utils import SampleList + + +@MODELS.register_module() +class SegTTAModel(BaseTTAModel): + + def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList: + """Merge predictions of enhanced data to one prediction. + + Args: + data_samples_list (List[SampleList]): List of predictions + of all enhanced data. + + Returns: + SampleList: Merged prediction. + """ + predictions = [] + for data_samples in data_samples_list: + seg_logits = data_samples[0].seg_logits.data + logits = torch.zeros(seg_logits.shape).to(seg_logits) + for data_sample in data_samples: + seg_logit = data_sample.seg_logits.data + if self.module.out_channels > 1: + logits += seg_logit.softmax(dim=0) + else: + logits += seg_logit.sigmoid() + logits /= len(data_samples) + if self.module.out_channels == 1: + seg_pred = (logits > self.module.decode_head.threshold + ).to(logits).squeeze(1) + else: + seg_pred = logits.argmax(dim=0) + data_sample.set_data({'pred_sem_seg': PixelData(data=seg_pred)}) + if hasattr(data_samples[0], 'gt_sem_seg'): + data_sample.set_data( + {'gt_sem_seg': data_samples[0].gt_sem_seg}) + data_sample.set_metainfo({'img_path': data_samples[0].img_path}) + predictions.append(data_sample) + return predictions diff --git a/mmseg/models/text_encoder/__init__.py b/mmseg/models/text_encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..199856d9d79800cbcd9aa7b77223a6528c6b7e0a --- /dev/null +++ b/mmseg/models/text_encoder/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .clip_text_encoder import CLIPTextEncoder + +__all__ = ['CLIPTextEncoder'] diff --git a/mmseg/models/text_encoder/__pycache__/__init__.cpython-311.pyc b/mmseg/models/text_encoder/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..094cd8c5d6dde6e35285f9a1c55275a4ee50140d Binary files /dev/null and b/mmseg/models/text_encoder/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmseg/models/text_encoder/__pycache__/clip_text_encoder.cpython-311.pyc b/mmseg/models/text_encoder/__pycache__/clip_text_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..face23dcc5636187ed0278bcec3e79febb9af0c9 Binary files /dev/null and b/mmseg/models/text_encoder/__pycache__/clip_text_encoder.cpython-311.pyc differ diff --git a/mmseg/models/text_encoder/clip_text_encoder.py b/mmseg/models/text_encoder/clip_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1a18b86395ebcf0443e9aab05f4454acada98990 --- /dev/null +++ b/mmseg/models/text_encoder/clip_text_encoder.py @@ -0,0 +1,229 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import BaseTransformerLayer +from mmengine.model import BaseModule, ModuleList +from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict +from torch.nn import functional as F + +from mmseg.registry import MODELS +from mmseg.utils import get_classes, get_predefined_templates, tokenizer + + +@MODELS.register_module() +class CLIPTextEncoder(BaseModule): + """A text encoder with transformer architecture to encode the label text. + + Modified from https://github.com/MendelXu/SAN/blob/main/san/model/clip_utils/classifier.py # noqa:E501 + Copyright (c) 2023 MendelXu. + Licensed under the MIT License + + Args: + dataset_name: (str|None): The name of the dataset to which + the data belongs. + vocabulary: (List[str]|None): The list of class names. Default: None. + templates: (List[str]|None): The prompt template used for labels. + Default: None. + total_vocab_size: (int): Number of all words used by the pre-trained + model. Default: 49408 (CLIP). + context_length: (int): The max length of prompt text. + Default: 77 (CLIP). + embed_dims: (int): Width of transformer model. Default: 512. + num_layers: (int): Depth of transformer. Default: 12, + num_heads: (int): Number of attention heads in transformer. + Default: 8, + mlp_ratio: (int) Ratio of mlp hidden dim to embedding dim in + transformer. Default: 4, + output_dims: (int) Dim of output text embeddings. Default: 512, + cache_feature: (bool) Whether to save class embeddings in cache. + Default: True, + cat_bg: (bool) Whether to add background embedding. Default: True. + norm_cfg (dict|None): Config for norm layer. Default: dict(type='LN') + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
+ """ + + def __init__(self, + dataset_name: str = None, + vocabulary: List[str] = None, + templates: str = 'vild', + total_vocab_size: int = 49408, + context_length: int = 77, + embed_dims: int = 512, + num_layers: int = 12, + num_heads: int = 8, + mlp_ratio: int = 4, + output_dims: int = 512, + cache_feature: bool = True, + cat_bg: bool = True, + norm_cfg: dict = dict(type='LN'), + init_cfg: dict = None): + super().__init__(init_cfg) + if isinstance(templates, List): + self.templates = templates + else: + self.templates = get_predefined_templates(templates) + + assert dataset_name is not None or vocabulary is not None, \ + "text_encoder required either 'dataset_name' or 'vocabulary'" + assert dataset_name is None or vocabulary is None, \ + "there is conflict between 'dataset_name' and 'vocabulary'" + self.dataset_name = dataset_name + self.vocabulary = vocabulary + self.num_pos = context_length + self.token_embedding = nn.Embedding(total_vocab_size, embed_dims) + self.positional_embedding = nn.Parameter( + torch.empty(context_length, embed_dims)) + self.text_projection = nn.Parameter( + torch.empty(embed_dims, output_dims)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.transformer = ModuleList() + self.register_buffer( + 'attn_mask', self.build_attention_mask(), persistent=False) + for i in range(num_layers): + self.transformer.append( + BaseTransformerLayer( + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=embed_dims, + num_heads=num_heads, + batch_first=False, + bias=True), + ffn_cfgs=dict( + type='FFN', + embed_dims=embed_dims, + feedforward_channels=mlp_ratio * embed_dims, + act_cfg=dict(type='QuickGELU')), + operation_order=('norm', 'self_attn', 'norm', 'ffn'))) + self.ln_final = build_norm_layer( + norm_cfg, embed_dims, postfix='_final')[1] + + self.cache_feature = cache_feature + if self.cache_feature: + self.cache = {} + + self._freeze() + + self.cat_bg = cat_bg + if self.cat_bg: + self.bg_embed = nn.Parameter( + torch.randn(1, self.text_projection.shape[1])) + + @property + def ln_final(self): + return getattr(self, self.final_name) + + def build_attention_mask(self): + """lazily create causal attention mask, with full attention between the + tokens. 
+ + pytorch uses additive attention mask; fill with -inf + """ + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + def _freeze(self): + for param in self.parameters(): + param.requires_grad = False + + def init_weights(self): + if self.cat_bg: + nn.init.normal_( + self.bg_embed, + std=self.bg_embed.shape[1]**-0.5, + ) + if isinstance(self.init_cfg, dict) and \ + self.init_cfg.get('type') == 'Pretrained_Part': + checkpoint = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + + state_dict = checkpoint.copy() + para_prefix = 'text_encoder' + prefix_len = len(para_prefix) + 1 + for k, v in checkpoint.items(): + state_dict.pop(k) + if para_prefix in k: + state_dict[k[prefix_len:]] = v + + load_state_dict(self, state_dict, strict=False, logger=None) + + else: + super().init_weights() + + @torch.no_grad() + def encode_text(self, text, normalize=False): + """encode class token.""" + + embed_device = self.token_embedding.weight.device + x = self.token_embedding( + text.to(embed_device)) # [batch_size, n_ctx, d_model] + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + for block in self.transformer: + x = block(query=x, attn_masks=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding + # (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), + text.argmax(dim=-1)] @ self.text_projection + return F.normalize(x, dim=-1) if normalize else x + + def template_encode(self, vocabulary): + """Prompt engineering.""" + text_embed_bucket = [] + for template in self.templates: + text_inputs = tokenizer.tokenize( + [template.format(noun) for noun in vocabulary]) + text_embed = self.encode_text(text_inputs, normalize=True) + text_embed_bucket.append(text_embed) + text_embed = torch.stack(text_embed_bucket).mean(dim=0) + text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) + return text_embed + + def forward(self): + """Forward function.""" + if self.dataset_name is None: # encoding vocabulary directly + class_names = self.vocabulary + if self.cache_feature: + new_classes = [ + word for word in class_names if word not in self.cache + ] + if len(new_classes) > 0: + class_embeds = self.template_encode(new_classes) + self.cache.update(dict(zip(new_classes, class_embeds))) + class_embeds = torch.stack( + [self.cache[word] for word in class_names]) + else: + class_embeds = self.template_encode(class_names) + + else: # encoding the classes of the dataset + class_names = get_classes(self.dataset_name) + if class_names[0] == 'background': + class_names = class_names[1:] + if self.cache_feature: + if self.dataset_name not in self.cache: + class_embeds = self.template_encode(class_names) + self.cache[self.dataset_name] = class_embeds + else: + class_embeds = self.cache[self.dataset_name] + else: + class_embeds = self.template_encode(class_names) + + if self.cat_bg: + class_embeds = torch.cat([class_embeds, self.bg_embed]) + class_embeds = F.normalize(class_embeds, p=2, dim=-1) + return self.logit_scale.exp() * class_embeds + + +@MODELS.register_module() +class QuickGELU(nn.Module): + # From https://github.com/openai/CLIP/blob/main/clip/model.py + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) diff --git a/mmseg/models/utils/__init__.py b/mmseg/models/utils/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..c0751b17c02de14e9bf1bfc02230d507a143e9c0 --- /dev/null +++ b/mmseg/models/utils/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .basic_block import BasicBlock, Bottleneck +from .embed import PatchEmbed +from .encoding import Encoding +from .inverted_residual import InvertedResidual, InvertedResidualV3 +from .make_divisible import make_divisible +from .point_sample import get_uncertain_point_coords_with_randomness +from .ppm import DAPPM, PAPPM +from .res_layer import ResLayer +from .se_layer import SELayer +from .self_attention_block import SelfAttentionBlock +from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc, + nlc_to_nchw) +from .up_conv_block import UpConvBlock + +# isort: off +from .wrappers import Upsample, resize +from .san_layers import MLP, LayerNorm2d, cross_attn_layer + +__all__ = [ + 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', + 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed', + 'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc', 'Encoding', + 'Upsample', 'resize', 'DAPPM', 'PAPPM', 'BasicBlock', 'Bottleneck', + 'cross_attn_layer', 'LayerNorm2d', 'MLP', + 'get_uncertain_point_coords_with_randomness' +] diff --git a/mmseg/models/utils/__pycache__/__init__.cpython-311.pyc b/mmseg/models/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c488a4f49148a58ecbe8d10a91d0931ecd97fc44 Binary files /dev/null and b/mmseg/models/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmseg/models/utils/__pycache__/basic_block.cpython-311.pyc b/mmseg/models/utils/__pycache__/basic_block.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..360e965dc55e2ce2ca96238e2be85d5e980d36c7 Binary files /dev/null and b/mmseg/models/utils/__pycache__/basic_block.cpython-311.pyc differ diff --git a/mmseg/models/utils/__pycache__/embed.cpython-311.pyc b/mmseg/models/utils/__pycache__/embed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d394ee3fbebefa72ff511ee91534b62b6920c0d3 Binary files /dev/null and b/mmseg/models/utils/__pycache__/embed.cpython-311.pyc differ diff --git a/mmseg/models/utils/__pycache__/encoding.cpython-311.pyc b/mmseg/models/utils/__pycache__/encoding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcc9751f623d4a46b7158796c5db884d0f08ad76 Binary files /dev/null and b/mmseg/models/utils/__pycache__/encoding.cpython-311.pyc differ diff --git a/mmseg/models/utils/__pycache__/inverted_residual.cpython-311.pyc b/mmseg/models/utils/__pycache__/inverted_residual.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..009473b1b15e003c25411f5c53beba9f00d7f336 Binary files /dev/null and b/mmseg/models/utils/__pycache__/inverted_residual.cpython-311.pyc differ diff --git a/mmseg/models/utils/__pycache__/make_divisible.cpython-311.pyc b/mmseg/models/utils/__pycache__/make_divisible.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..654b2839d7d8efaad1116065b4c595f43ea07c0c Binary files /dev/null and b/mmseg/models/utils/__pycache__/make_divisible.cpython-311.pyc differ diff --git a/mmseg/models/utils/__pycache__/point_sample.cpython-311.pyc b/mmseg/models/utils/__pycache__/point_sample.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edc443fe5b4385870ca7345906de28bc8b639c6b Binary 
files /dev/null and b/mmseg/models/utils/__pycache__/point_sample.cpython-311.pyc differ diff --git a/mmseg/models/utils/__pycache__/ppm.cpython-311.pyc b/mmseg/models/utils/__pycache__/ppm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37cf05fae5c63f48c9cc1ca5443bca744897d9e9 Binary files /dev/null and b/mmseg/models/utils/__pycache__/ppm.cpython-311.pyc differ diff --git a/mmseg/models/utils/__pycache__/res_layer.cpython-311.pyc b/mmseg/models/utils/__pycache__/res_layer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a277c680b214ad99644f6908079c79cc273ac3e Binary files /dev/null and b/mmseg/models/utils/__pycache__/res_layer.cpython-311.pyc differ diff --git a/mmseg/models/utils/__pycache__/san_layers.cpython-311.pyc b/mmseg/models/utils/__pycache__/san_layers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d10fd12d3ba3a09cc36655c2d5d6a0faf42e90db Binary files /dev/null and b/mmseg/models/utils/__pycache__/san_layers.cpython-311.pyc differ diff --git a/mmseg/models/utils/__pycache__/se_layer.cpython-311.pyc b/mmseg/models/utils/__pycache__/se_layer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2beaaba6e320ca9b76a604f72c10585bdc0b7ddd Binary files /dev/null and b/mmseg/models/utils/__pycache__/se_layer.cpython-311.pyc differ diff --git a/mmseg/models/utils/__pycache__/self_attention_block.cpython-311.pyc b/mmseg/models/utils/__pycache__/self_attention_block.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79e01843f05dd0b447dd659ba48e66286943a08b Binary files /dev/null and b/mmseg/models/utils/__pycache__/self_attention_block.cpython-311.pyc differ diff --git a/mmseg/models/utils/__pycache__/shape_convert.cpython-311.pyc b/mmseg/models/utils/__pycache__/shape_convert.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..159d74ff678c7c200c966e9315e790c6744070d1 Binary files /dev/null and b/mmseg/models/utils/__pycache__/shape_convert.cpython-311.pyc differ diff --git a/mmseg/models/utils/__pycache__/up_conv_block.cpython-311.pyc b/mmseg/models/utils/__pycache__/up_conv_block.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ca0ca157cd37552d574d56e45f003a98a081ae1 Binary files /dev/null and b/mmseg/models/utils/__pycache__/up_conv_block.cpython-311.pyc differ diff --git a/mmseg/models/utils/__pycache__/wrappers.cpython-311.pyc b/mmseg/models/utils/__pycache__/wrappers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28186b824389023469247406eb2d67a6ff2d889e Binary files /dev/null and b/mmseg/models/utils/__pycache__/wrappers.cpython-311.pyc differ diff --git a/mmseg/models/utils/basic_block.py b/mmseg/models/utils/basic_block.py new file mode 100644 index 0000000000000000000000000000000000000000..4e1ad8146dd200c5f1e543adf22ada654ee196a4 --- /dev/null +++ b/mmseg/models/utils/basic_block.py @@ -0,0 +1,143 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.utils import OptConfigType + + +class BasicBlock(BaseModule): + """Basic block from `ResNet `_. + + Args: + in_channels (int): Input channels. + channels (int): Output channels. + stride (int): Stride of the first block. Default: 1. 
+ downsample (nn.Module, optional): Downsample operation on identity. + Default: None. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict, optional): Config dict for activation layer in + ConvModule. Default: dict(type='ReLU', inplace=True). + act_cfg_out (dict, optional): Config dict for activation layer at the + last of the block. Default: None. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + expansion = 1 + + def __init__(self, + in_channels: int, + channels: int, + stride: int = 1, + downsample: nn.Module = None, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + act_cfg_out: OptConfigType = dict(type='ReLU', inplace=True), + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + self.conv1 = ConvModule( + in_channels, + channels, + kernel_size=3, + stride=stride, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv2 = ConvModule( + channels, + channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=None) + self.downsample = downsample + if act_cfg_out: + self.act = MODELS.build(act_cfg_out) + + def forward(self, x: Tensor) -> Tensor: + residual = x + out = self.conv1(x) + out = self.conv2(out) + + if self.downsample: + residual = self.downsample(x) + + out += residual + + if hasattr(self, 'act'): + out = self.act(out) + + return out + + +class Bottleneck(BaseModule): + """Bottleneck block from `ResNet `_. + + Args: + in_channels (int): Input channels. + channels (int): Output channels. + stride (int): Stride of the first block. Default: 1. + downsample (nn.Module, optional): Downsample operation on identity. + Default: None. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict, optional): Config dict for activation layer in + ConvModule. Default: dict(type='ReLU', inplace=True). + act_cfg_out (dict, optional): Config dict for activation layer at + the last of the block. Default: None. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + expansion = 2 + + def __init__(self, + in_channels: int, + channels: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + act_cfg_out: OptConfigType = None, + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + self.conv1 = ConvModule( + in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg) + self.conv2 = ConvModule( + channels, + channels, + 3, + stride, + 1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv3 = ConvModule( + channels, + channels * self.expansion, + 1, + norm_cfg=norm_cfg, + act_cfg=None) + if act_cfg_out: + self.act = MODELS.build(act_cfg_out) + self.downsample = downsample + + def forward(self, x: Tensor) -> Tensor: + residual = x + + out = self.conv1(x) + out = self.conv2(out) + out = self.conv3(out) + + if self.downsample: + residual = self.downsample(x) + + out += residual + + if hasattr(self, 'act'): + out = self.act(out) + + return out diff --git a/mmseg/models/utils/embed.py b/mmseg/models/utils/embed.py new file mode 100644 index 0000000000000000000000000000000000000000..aef0a40b0a87bb6616db96fe2c72c19cc6f5b366 --- /dev/null +++ b/mmseg/models/utils/embed.py @@ -0,0 +1,330 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
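
A usage sketch for the residual blocks above. With `stride=2` the main path halves the spatial size and `Bottleneck.expansion` doubles the channels, so the identity branch needs a matching 1x1 downsample; the import path assumes this file lives at `mmseg/models/utils/basic_block.py` as in the diff:

```python
import torch
from mmcv.cnn import ConvModule
from mmseg.models.utils.basic_block import Bottleneck

# identity branch must match the main path: stride 2, channels * expansion
downsample = ConvModule(64, 64 * Bottleneck.expansion, kernel_size=1,
                        stride=2, norm_cfg=dict(type='BN'), act_cfg=None)
block = Bottleneck(in_channels=64, channels=64, stride=2,
                   downsample=downsample)
x = torch.rand(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 128, 28, 28])
```
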
+import math +from typing import Sequence + +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule +from mmengine.utils import to_2tuple + + +class AdaptivePadding(nn.Module): + """Applies padding to input (if needed) so that input can get fully covered + by filter you specified. It support two modes "same" and "corner". The + "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around + input. The "corner" mode would pad zero to bottom right. + + Args: + kernel_size (int | tuple): Size of the kernel: + stride (int | tuple): Stride of the filter. Default: 1: + dilation (int | tuple): Spacing between kernel elements. + Default: 1. + padding (str): Support "same" and "corner", "corner" mode + would pad zero to bottom right, and "same" mode would + pad zero around input. Default: "corner". + Example: + >>> kernel_size = 16 + >>> stride = 16 + >>> dilation = 1 + >>> input = torch.rand(1, 1, 15, 17) + >>> adap_pad = AdaptivePadding( + >>> kernel_size=kernel_size, + >>> stride=stride, + >>> dilation=dilation, + >>> padding="corner") + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + >>> input = torch.rand(1, 1, 16, 17) + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + """ + + def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'): + + super().__init__() + + assert padding in ('same', 'corner') + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + self.padding = padding + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + + def get_pad_shape(self, input_shape): + input_h, input_w = input_shape + kernel_h, kernel_w = self.kernel_size + stride_h, stride_w = self.stride + output_h = math.ceil(input_h / stride_h) + output_w = math.ceil(input_w / stride_w) + pad_h = max((output_h - 1) * stride_h + + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) + pad_w = max((output_w - 1) * stride_w + + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) + return pad_h, pad_w + + def forward(self, x): + pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) + if pad_h > 0 or pad_w > 0: + if self.padding == 'corner': + x = F.pad(x, [0, pad_w, 0, pad_h]) + elif self.padding == 'same': + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2 + ]) + return x + + +class PatchEmbed(BaseModule): + """Image to Patch Embedding. + + We use a conv layer to implement PatchEmbed. + + Args: + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + conv_type (str): The config dict for embedding + conv layer type selection. Default: "Conv2d". + kernel_size (int): The kernel_size of embedding conv. Default: 16. + stride (int, optional): The slide stride of embedding conv. + Default: None (Would be set as `kernel_size`). + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int): The dilation rate of embedding conv. Default: 1. + bias (bool): Bias of embed conv. Default: True. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None. + input_size (int | tuple | None): The size of input, which will be + used to calculate the out size. Only work when `dynamic_size` + is False. Default: None. 
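
The padding arithmetic in `AdaptivePadding.get_pad_shape` reproduces TensorFlow-style SAME padding: just enough pixels are added so the kernel grid covers the input exactly. A worked example matching the docstring's 15x17 case:

```python
import math

def pad_shape(input_hw, kernel=16, stride=16, dilation=1):
    pads = []
    for size in input_hw:
        out = math.ceil(size / stride)
        pads.append(
            max((out - 1) * stride + (kernel - 1) * dilation + 1 - size, 0))
    return tuple(pads)  # (pad_h, pad_w)

# ceil(15/16) x ceil(17/16) = 1 x 2 patches -> a 16 x 32 padded map
print(pad_shape((15, 17)))  # (1, 15): 15+1=16 rows, 17+15=32 cols
```
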
+ init_cfg (`mmengine.ConfigDict`, optional): The Config for + initialization. Default: None. + """ + + def __init__(self, + in_channels=3, + embed_dims=768, + conv_type='Conv2d', + kernel_size=16, + stride=None, + padding='corner', + dilation=1, + bias=True, + norm_cfg=None, + input_size=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + if stride is None: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of conv + padding = 0 + else: + self.adap_padding = None + padding = to_2tuple(padding) + + self.projection = build_conv_layer( + dict(type=conv_type), + in_channels=in_channels, + out_channels=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + else: + self.norm = None + + if input_size: + input_size = to_2tuple(input_size) + # `init_out_size` would be used outside to + # calculate the num_patches + # when `use_abs_pos_embed` outside + self.init_input_size = input_size + if self.adap_padding: + pad_h, pad_w = self.adap_padding.get_pad_shape(input_size) + input_h, input_w = input_size + input_h = input_h + pad_h + input_w = input_w + pad_w + input_size = (input_h, input_w) + + # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + h_out = (input_size[0] + 2 * padding[0] - dilation[0] * + (kernel_size[0] - 1) - 1) // stride[0] + 1 + w_out = (input_size[1] + 2 * padding[1] - dilation[1] * + (kernel_size[1] - 1) - 1) // stride[1] + 1 + self.init_out_size = (h_out, w_out) + else: + self.init_input_size = None + self.init_out_size = None + + def forward(self, x): + """ + Args: + x (Tensor): Has shape (B, C, H, W). In most case, C is 3. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, out_h * out_w, embed_dims) + - out_size (tuple[int]): Spatial shape of x, arrange as + (out_h, out_w). + """ + + if self.adap_padding: + x = self.adap_padding(x) + + x = self.projection(x) + out_size = (x.shape[2], x.shape[3]) + x = x.flatten(2).transpose(1, 2) + if self.norm is not None: + x = self.norm(x) + return x, out_size + + +class PatchMerging(BaseModule): + """Merge patch feature map. + + This layer groups feature map by kernel_size, and applies norm and linear + layers to the grouped feature map. Our implementation uses `nn.Unfold` to + merge patch, which is about 25% faster than original implementation. + Instead, we need to modify pretrained models for compatibility. + + Args: + in_channels (int): The num of input channels. + out_channels (int): The num of output channels. + kernel_size (int | tuple, optional): the kernel size in the unfold + layer. Defaults to 2. + stride (int | tuple, optional): the stride of the sliding blocks in the + unfold layer. Default: None. (Would be set as `kernel_size`) + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int | tuple, optional): dilation parameter in the unfold + layer. Default: 1. + bias (bool, optional): Whether to add bias in linear layer or not. + Defaults: False. 
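
`PatchEmbed.init_out_size` applies the usual convolution output-size formula after adaptive padding. A quick sanity check, assuming the module is importable from this diff's `mmseg.models.utils`:

```python
import torch
from mmseg.models.utils import PatchEmbed  # as exported in this diff

# 16x16 patches at stride 16 over a 224x224 image -> a 14x14 token grid
embed = PatchEmbed(in_channels=3, embed_dims=768, kernel_size=16,
                   stride=16, padding='corner', input_size=(224, 224))
print(embed.init_out_size)  # (14, 14)

x = torch.rand(1, 3, 224, 224)
tokens, hw = embed(x)
print(tokens.shape, hw)     # torch.Size([1, 196, 768]) (14, 14)
```
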
+ norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (dict, optional): The extra config for initialization. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=2, + stride=None, + padding='corner', + dilation=1, + bias=False, + norm_cfg=dict(type='LN'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + if stride: + stride = stride + else: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of unfold + padding = 0 + else: + self.adap_padding = None + + padding = to_2tuple(padding) + self.sampler = nn.Unfold( + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride) + + sample_dim = kernel_size[0] * kernel_size[1] * in_channels + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + def forward(self, x, input_size): + """ + Args: + x (Tensor): Has shape (B, H*W, C_in). + input_size (tuple[int]): The spatial shape of x, arrange as (H, W). + Default: None. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) + - out_size (tuple[int]): Spatial shape of x, arrange as + (Merged_H, Merged_W). + """ + B, L, C = x.shape + assert isinstance(input_size, Sequence), f'Expect ' \ + f'input_size is ' \ + f'`Sequence` ' \ + f'but get {input_size}' + + H, W = input_size + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + + if self.adap_padding: + x = self.adap_padding(x) + H, W = x.shape[-2:] + + x = self.sampler(x) + # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) + + out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * + (self.sampler.kernel_size[0] - 1) - + 1) // self.sampler.stride[0] + 1 + out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * + (self.sampler.kernel_size[1] - 1) - + 1) // self.sampler.stride[1] + 1 + + output_size = (out_h, out_w) + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + x = self.norm(x) if self.norm else x + x = self.reduction(x) + return x, output_size diff --git a/mmseg/models/utils/encoding.py b/mmseg/models/utils/encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..ee4f0574fbc1957cf8da591a0e4befd6d8a125d3 --- /dev/null +++ b/mmseg/models/utils/encoding.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn +from torch.nn import functional as F + + +class Encoding(nn.Module): + """Encoding Layer: a learnable residual encoder. + + Input is of shape (batch_size, channels, height, width). + Output is of shape (batch_size, num_codes, channels). + + Args: + channels: dimension of the features or feature channels + num_codes: number of code words + """ + + def __init__(self, channels, num_codes): + super().__init__() + # init codewords and smoothing factor + self.channels, self.num_codes = channels, num_codes + std = 1. 
/ ((num_codes * channels)**0.5) + # [num_codes, channels] + self.codewords = nn.Parameter( + torch.empty(num_codes, channels, + dtype=torch.float).uniform_(-std, std), + requires_grad=True) + # [num_codes] + self.scale = nn.Parameter( + torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), + requires_grad=True) + + @staticmethod + def scaled_l2(x, codewords, scale): + num_codes, channels = codewords.size() + batch_size = x.size(0) + reshaped_scale = scale.view((1, 1, num_codes)) + expanded_x = x.unsqueeze(2).expand( + (batch_size, x.size(1), num_codes, channels)) + reshaped_codewords = codewords.view((1, 1, num_codes, channels)) + + scaled_l2_norm = reshaped_scale * ( + expanded_x - reshaped_codewords).pow(2).sum(dim=3) + return scaled_l2_norm + + @staticmethod + def aggregate(assignment_weights, x, codewords): + num_codes, channels = codewords.size() + reshaped_codewords = codewords.view((1, 1, num_codes, channels)) + batch_size = x.size(0) + + expanded_x = x.unsqueeze(2).expand( + (batch_size, x.size(1), num_codes, channels)) + encoded_feat = (assignment_weights.unsqueeze(3) * + (expanded_x - reshaped_codewords)).sum(dim=1) + return encoded_feat + + def forward(self, x): + assert x.dim() == 4 and x.size(1) == self.channels + # [batch_size, channels, height, width] + batch_size = x.size(0) + # [batch_size, height x width, channels] + x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() + # assignment_weights: [batch_size, channels, num_codes] + assignment_weights = F.softmax( + self.scaled_l2(x, self.codewords, self.scale), dim=2) + # aggregate + encoded_feat = self.aggregate(assignment_weights, x, self.codewords) + return encoded_feat + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ + f'x{self.channels})' + return repr_str diff --git a/mmseg/models/utils/inverted_residual.py b/mmseg/models/utils/inverted_residual.py new file mode 100644 index 0000000000000000000000000000000000000000..56190b3bfe7cc8fe98bf34c3812db18dd34a8f02 --- /dev/null +++ b/mmseg/models/utils/inverted_residual.py @@ -0,0 +1,213 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import ConvModule +from torch import nn +from torch.utils import checkpoint as cp + +from .se_layer import SELayer + + +class InvertedResidual(nn.Module): + """InvertedResidual block for MobileNetV2. + + Args: + in_channels (int): The input channels of the InvertedResidual block. + out_channels (int): The output channels of the InvertedResidual block. + stride (int): Stride of the middle (first) 3x3 convolution. + expand_ratio (int): Adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + dilation (int): Dilation rate of depthwise conv. Default: 1 + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + stride, + expand_ratio, + dilation=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + with_cp=False, + **kwargs): + super().__init__() + self.stride = stride + assert stride in [1, 2], f'stride must in [1, 2]. 
' \ + f'But received {stride}.' + self.with_cp = with_cp + self.use_res_connect = self.stride == 1 and in_channels == out_channels + hidden_dim = int(round(in_channels * expand_ratio)) + + layers = [] + if expand_ratio != 1: + layers.append( + ConvModule( + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs)) + layers.extend([ + ConvModule( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=hidden_dim, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs), + ConvModule( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None, + **kwargs) + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + + def _inner_forward(x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class InvertedResidualV3(nn.Module): + """Inverted Residual Block for MobileNetV3. + + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + mid_channels (int): The input channels of the depthwise convolution. + kernel_size (int): The kernel size of the depthwise convolution. + Default: 3. + stride (int): The stride of the depthwise convolution. Default: 1. + se_cfg (dict): Config dict for se layer. Default: None, which means no + se layer. + with_expand_conv (bool): Use expand conv or not. If set False, + mid_channels must be the same with in_channels. Default: True. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor. 
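
The `with_cp` branch above is plain activation checkpointing: intermediates inside `_inner_forward` are discarded in the forward pass and recomputed during backward, trading compute for memory. A minimal stand-alone sketch of the same pattern, with toy layer sizes:

```python
import torch
from torch.utils import checkpoint as cp

layer = torch.nn.Sequential(
    torch.nn.Conv2d(32, 192, 1), torch.nn.ReLU6(inplace=True),
    torch.nn.Conv2d(192, 32, 1))

def forward_block(x):
    # residual connection, as in InvertedResidual with stride == 1
    return x + layer(x)

x = torch.rand(2, 32, 56, 56, requires_grad=True)
out = cp.checkpoint(forward_block, x)  # same call pattern as above
out.sum().backward()                   # activations recomputed here
print(x.grad.shape)                    # torch.Size([2, 32, 56, 56])
```
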
+ """ + + def __init__(self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + stride=1, + se_cfg=None, + with_expand_conv=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False): + super().__init__() + self.with_res_shortcut = (stride == 1 and in_channels == out_channels) + assert stride in [1, 2] + self.with_cp = with_cp + self.with_se = se_cfg is not None + self.with_expand_conv = with_expand_conv + + if self.with_se: + assert isinstance(se_cfg, dict) + if not self.with_expand_conv: + assert mid_channels == in_channels + + if self.with_expand_conv: + self.expand_conv = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.depthwise_conv = ConvModule( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=mid_channels, + conv_cfg=dict( + type='Conv2dAdaptivePadding') if stride == 2 else conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + if self.with_se: + self.se = SELayer(**se_cfg) + + self.linear_conv = ConvModule( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, x): + + def _inner_forward(x): + out = x + + if self.with_expand_conv: + out = self.expand_conv(out) + + out = self.depthwise_conv(out) + + if self.with_se: + out = self.se(out) + + out = self.linear_conv(out) + + if self.with_res_shortcut: + return x + out + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out diff --git a/mmseg/models/utils/make_divisible.py b/mmseg/models/utils/make_divisible.py new file mode 100644 index 0000000000000000000000000000000000000000..ed42c2eeea2a6aed03a0be5516b8d1ef1139e486 --- /dev/null +++ b/mmseg/models/utils/make_divisible.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +def make_divisible(value, divisor, min_value=None, min_ratio=0.9): + """Make divisible function. + + This function rounds the channel number to the nearest value that can be + divisible by the divisor. It is taken from the original tf repo. It ensures + that all layers have a channel number that is divisible by divisor. It can + be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa + + Args: + value (int): The original channel number. + divisor (int): The divisor to fully divide the channel number. + min_value (int): The minimum value of the output channel. + Default: None, means that the minimum value equal to the divisor. + min_ratio (float): The minimum ratio of the rounded channel number to + the original channel number. Default: 0.9. + + Returns: + int: The modified output channel number. + """ + + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than (1-min_ratio). + if new_value < min_ratio * value: + new_value += divisor + return new_value diff --git a/mmseg/models/utils/point_sample.py b/mmseg/models/utils/point_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..1afc957f3da7d1dc030c21d40311c768c6952ea4 --- /dev/null +++ b/mmseg/models/utils/point_sample.py @@ -0,0 +1,88 @@ +# Copyright (c) OpenMMLab. 
All rights reserved. +import torch +from mmcv.ops import point_sample +from torch import Tensor + + +def get_uncertainty(mask_preds: Tensor, labels: Tensor) -> Tensor: + """Estimate uncertainty based on pred logits. + + We estimate uncertainty as L1 distance between 0.0 and the logits + prediction in 'mask_preds' for the foreground class in `classes`. + + Args: + mask_preds (Tensor): mask predication logits, shape (num_rois, + num_classes, mask_height, mask_width). + + labels (Tensor): Either predicted or ground truth label for + each predicted mask, of length num_rois. + + Returns: + scores (Tensor): Uncertainty scores with the most uncertain + locations having the highest uncertainty score, + shape (num_rois, 1, mask_height, mask_width) + """ + if mask_preds.shape[1] == 1: + gt_class_logits = mask_preds.clone() + else: + inds = torch.arange(mask_preds.shape[0], device=mask_preds.device) + gt_class_logits = mask_preds[inds, labels].unsqueeze(1) + return -torch.abs(gt_class_logits) + + +def get_uncertain_point_coords_with_randomness( + mask_preds: Tensor, labels: Tensor, num_points: int, + oversample_ratio: float, importance_sample_ratio: float) -> Tensor: + """Get ``num_points`` most uncertain points with random points during + train. + + Sample points in [0, 1] x [0, 1] coordinate space based on their + uncertainty. The uncertainties are calculated for each point using + 'get_uncertainty()' function that takes point's logit prediction as + input. + + Args: + mask_preds (Tensor): A tensor of shape (num_rois, num_classes, + mask_height, mask_width) for class-specific or class-agnostic + prediction. + labels (Tensor): The ground truth class for each instance. + num_points (int): The number of points to sample. + oversample_ratio (float): Oversampling parameter. + importance_sample_ratio (float): Ratio of points that are sampled + via importnace sampling. + + Returns: + point_coords (Tensor): A tensor of shape (num_rois, num_points, 2) + that contains the coordinates sampled points. + """ + assert oversample_ratio >= 1 + assert 0 <= importance_sample_ratio <= 1 + batch_size = mask_preds.shape[0] + num_sampled = int(num_points * oversample_ratio) + point_coords = torch.rand( + batch_size, num_sampled, 2, device=mask_preds.device) + point_logits = point_sample(mask_preds, point_coords) + # It is crucial to calculate uncertainty based on the sampled + # prediction value for the points. Calculating uncertainties of the + # coarse predictions first and sampling them for points leads to + # incorrect results. To illustrate this: assume uncertainty func( + # logits)=-abs(logits), a sampled point between two coarse + # predictions with -1 and 1 logits has 0 logits, and therefore 0 + # uncertainty value. However, if we calculate uncertainties for the + # coarse predictions first, both will have -1 uncertainty, + # and sampled point will get -1 uncertainty. 
+ point_uncertainties = get_uncertainty(point_logits, labels) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk( + point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_sampled * torch.arange( + batch_size, dtype=torch.long, device=mask_preds.device) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( + batch_size, num_uncertain_points, 2) + if num_random_points > 0: + rand_roi_coords = torch.rand( + batch_size, num_random_points, 2, device=mask_preds.device) + point_coords = torch.cat((point_coords, rand_roi_coords), dim=1) + return point_coords diff --git a/mmseg/models/utils/ppm.py b/mmseg/models/utils/ppm.py new file mode 100644 index 0000000000000000000000000000000000000000..5fe6ff26fae6869b989cecde96af3ceff1a37b38 --- /dev/null +++ b/mmseg/models/utils/ppm.py @@ -0,0 +1,193 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule, ModuleList, Sequential +from torch import Tensor + + +class DAPPM(BaseModule): + """DAPPM module in `DDRNet `_. + + Args: + in_channels (int): Input channels. + branch_channels (int): Branch channels. + out_channels (int): Output channels. + num_scales (int): Number of scales. + kernel_sizes (list[int]): Kernel sizes of each scale. + strides (list[int]): Strides of each scale. + paddings (list[int]): Paddings of each scale. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU', inplace=True). + conv_cfg (dict): Config dict for convolution layer in ConvModule. + Default: dict(order=('norm', 'act', 'conv'), bias=False). + upsample_mode (str): Upsample mode. Default: 'bilinear'. 
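
A tiny end-to-end run of the point sampling above: uncertainty is `-|logit|`, so locations near the decision boundary are picked first via `topk`, then topped up with purely random points. The import path is assumed from this diff's `mmseg.models.utils`:

```python
import torch
from mmseg.models.utils import get_uncertain_point_coords_with_randomness

mask_preds = torch.randn(4, 1, 32, 32)   # class-agnostic logits
labels = torch.zeros(4, dtype=torch.long)
coords = get_uncertain_point_coords_with_randomness(
    mask_preds, labels, num_points=64,
    oversample_ratio=3.0, importance_sample_ratio=0.75)
# 48 uncertain points + 16 random points per image, in [0, 1] coordinates
print(coords.shape)  # torch.Size([4, 64, 2])
```
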
+ """ + + def __init__(self, + in_channels: int, + branch_channels: int, + out_channels: int, + num_scales: int, + kernel_sizes: List[int] = [5, 9, 17], + strides: List[int] = [2, 4, 8], + paddings: List[int] = [2, 4, 8], + norm_cfg: Dict = dict(type='BN', momentum=0.1), + act_cfg: Dict = dict(type='ReLU', inplace=True), + conv_cfg: Dict = dict( + order=('norm', 'act', 'conv'), bias=False), + upsample_mode: str = 'bilinear'): + super().__init__() + + self.num_scales = num_scales + self.unsample_mode = upsample_mode + self.in_channels = in_channels + self.branch_channels = branch_channels + self.out_channels = out_channels + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.conv_cfg = conv_cfg + + self.scales = ModuleList([ + ConvModule( + in_channels, + branch_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + ]) + for i in range(1, num_scales - 1): + self.scales.append( + Sequential(*[ + nn.AvgPool2d( + kernel_size=kernel_sizes[i - 1], + stride=strides[i - 1], + padding=paddings[i - 1]), + ConvModule( + in_channels, + branch_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + ])) + self.scales.append( + Sequential(*[ + nn.AdaptiveAvgPool2d((1, 1)), + ConvModule( + in_channels, + branch_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + ])) + self.processes = ModuleList() + for i in range(num_scales - 1): + self.processes.append( + ConvModule( + branch_channels, + branch_channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg)) + + self.compression = ConvModule( + branch_channels * num_scales, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + + self.shortcut = ConvModule( + in_channels, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + + def forward(self, inputs: Tensor): + feats = [] + feats.append(self.scales[0](inputs)) + + for i in range(1, self.num_scales): + feat_up = F.interpolate( + self.scales[i](inputs), + size=inputs.shape[2:], + mode=self.unsample_mode) + feats.append(self.processes[i - 1](feat_up + feats[i - 1])) + + return self.compression(torch.cat(feats, + dim=1)) + self.shortcut(inputs) + + +class PAPPM(DAPPM): + """PAPPM module in `PIDNet `_. + + Args: + in_channels (int): Input channels. + branch_channels (int): Branch channels. + out_channels (int): Output channels. + num_scales (int): Number of scales. + kernel_sizes (list[int]): Kernel sizes of each scale. + strides (list[int]): Strides of each scale. + paddings (list[int]): Paddings of each scale. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN', momentum=0.1). + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU', inplace=True). + conv_cfg (dict): Config dict for convolution layer in ConvModule. + Default: dict(order=('norm', 'act', 'conv'), bias=False). + upsample_mode (str): Upsample mode. Default: 'bilinear'. 
+ """ + + def __init__(self, + in_channels: int, + branch_channels: int, + out_channels: int, + num_scales: int, + kernel_sizes: List[int] = [5, 9, 17], + strides: List[int] = [2, 4, 8], + paddings: List[int] = [2, 4, 8], + norm_cfg: Dict = dict(type='BN', momentum=0.1), + act_cfg: Dict = dict(type='ReLU', inplace=True), + conv_cfg: Dict = dict( + order=('norm', 'act', 'conv'), bias=False), + upsample_mode: str = 'bilinear'): + super().__init__(in_channels, branch_channels, out_channels, + num_scales, kernel_sizes, strides, paddings, norm_cfg, + act_cfg, conv_cfg, upsample_mode) + + self.processes = ConvModule( + self.branch_channels * (self.num_scales - 1), + self.branch_channels * (self.num_scales - 1), + kernel_size=3, + padding=1, + groups=self.num_scales - 1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + **self.conv_cfg) + + def forward(self, inputs: Tensor): + x_ = self.scales[0](inputs) + feats = [] + for i in range(1, self.num_scales): + feat_up = F.interpolate( + self.scales[i](inputs), + size=inputs.shape[2:], + mode=self.unsample_mode, + align_corners=False) + feats.append(feat_up + x_) + scale_out = self.processes(torch.cat(feats, dim=1)) + return self.compression(torch.cat([x_, scale_out], + dim=1)) + self.shortcut(inputs) diff --git a/mmseg/models/utils/res_layer.py b/mmseg/models/utils/res_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..3dd7a6f75a168f2f7e3c61f82d309b1cf0d502bc --- /dev/null +++ b/mmseg/models/utils/res_layer.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import Sequential +from torch import nn as nn + + +class ResLayer(Sequential): + """ResLayer to build ResNet style backbone. + + Args: + block (nn.Module): block used to build ResLayer. + inplanes (int): inplanes of block. + planes (int): planes of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + multi_grid (int | None): Multi grid dilation rates of last + stage. 
Default: None + contract_dilation (bool): Whether contract first dilation of each layer + Default: False + """ + + def __init__(self, + block, + inplanes, + planes, + num_blocks, + stride=1, + dilation=1, + avg_down=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + multi_grid=None, + contract_dilation=False, + **kwargs): + self.block = block + + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = [] + conv_stride = stride + if avg_down: + conv_stride = 1 + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=conv_stride, + bias=False), + build_norm_layer(norm_cfg, planes * block.expansion)[1] + ]) + downsample = nn.Sequential(*downsample) + + layers = [] + if multi_grid is None: + if dilation > 1 and contract_dilation: + first_dilation = dilation // 2 + else: + first_dilation = dilation + else: + first_dilation = multi_grid[0] + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + dilation=first_dilation, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + inplanes = planes * block.expansion + for i in range(1, num_blocks): + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=1, + dilation=dilation if multi_grid is None else multi_grid[i], + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + super().__init__(*layers) diff --git a/mmseg/models/utils/san_layers.py b/mmseg/models/utils/san_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..2267686daf62658c5dc81408e0a399c43aee83aa --- /dev/null +++ b/mmseg/models/utils/san_layers.py @@ -0,0 +1,418 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/MendelXu/SAN/blob/main/san/model/attn_helper.py # noqa: E501 +# Copyright (c) 2023 MendelXu. +# Licensed under the MIT License + +import warnings +from typing import Optional + +import torch +from mmcv.cnn.bricks.transformer import BaseTransformerLayer +from torch import Tensor, nn +from torch.nn import functional as F + + +def cross_attn_with_self_bias( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, +): + """Forward function of multi-head attention. Modified from + multi_head_attention_forward in + https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py. + + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. 
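
A usage sketch for `ResLayer`, paired with the ResNet-style block from mmseg's backbone (which takes `inplanes`/`planes` and defines `expansion`, unlike the `BasicBlock` in `basic_block.py` above); the backbone import path is an assumption based on stock mmseg:

```python
import torch
from mmseg.models.backbones.resnet import BasicBlock  # assumed available
from mmseg.models.utils import ResLayer

# first block downsamples; ResLayer builds the 1x1 identity branch itself
stage = ResLayer(block=BasicBlock, inplanes=64, planes=128,
                 num_blocks=2, stride=2)
x = torch.rand(1, 64, 56, 56)
print(stage(x).shape)  # torch.Size([1, 128, 28, 28])
```
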
+        dropout_p: probability of an element to be zeroed.
+        out_proj_weight, out_proj_bias: the output projection weight and bias.
+        training: apply dropout if ``True``.
+        key_padding_mask: if provided, specified padding elements in the key
+            will be ignored by the attention. This is a binary mask: where the
+            value is True, the corresponding entry of the attention map will
+            be filled with -inf.
+        need_weights: output attn_output_weights.
+            Default: `True`
+            Note: `need_weights` defaults to `True`, but it should be set to
+            `False` whenever the attention weights are not needed, since
+            computing them causes a significant performance degradation.
+        attn_mask: 2D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+            the batches while a 3D mask allows specifying a different mask for the entries of each batch.
+        use_separate_proj_weight: the function accepts separate projection
+            weights for query, key and value. If ``False``, in_proj_weight
+            will be used, which is a combination of q_proj_weight,
+            k_proj_weight and v_proj_weight.
+        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
+        static_k, static_v: static key and value used for attention operators.
+    """  # noqa: E501
+    tgt_len, bsz, embed_dim = query.size()
+    assert embed_dim == embed_dim_to_check
+    # allow MHA to have different sizes for the feature dimension
+    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+
+    head_dim = embed_dim // num_heads
+    assert head_dim * num_heads == embed_dim, \
+        'embed_dim must be divisible by num_heads'
+    scaling = float(head_dim)**-0.5
+
+    if not use_separate_proj_weight:
+        if (query is key or torch.equal(
+                query, key)) and (key is value or torch.equal(key, value)):
+            # self-attention
+            raise NotImplementedError('self-attention is not implemented')
+
+        elif key is value or torch.equal(key, value):
+            # encoder-decoder attention
+            # This is inline in_proj function
+            # with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = 0
+            _end = embed_dim
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            q = F.linear(query, _w, _b)
+
+            if key is None:
+                assert value is None
+                k = None
+                v = None
+                q_k = None
+                q_v = None
+            else:
+                # This is inline in_proj function with
+                # in_proj_weight and in_proj_bias
+                _b = in_proj_bias
+                _start = embed_dim
+                _end = None
+                _w = in_proj_weight[_start:, :]
+                if _b is not None:
+                    _b = _b[_start:]
+                k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
+                q_k, q_v = F.linear(query, _w, _b).chunk(2, dim=-1)
+        else:
+            # This is inline in_proj function with
+            # in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = 0
+            _end = embed_dim
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            q = F.linear(query, _w, _b)
+
+            # This is inline in_proj function with
+            # in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim
+            _end = embed_dim * 2
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            k = F.linear(key, _w, _b)
+            q_k = F.linear(query, _w, _b)
+            # This is inline in_proj function with
+            # in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim * 2
+            _end = None
+            _w = in_proj_weight[_start:, :]
+            if _b is not None:
+                _b = _b[_start:]
+            v = F.linear(value, _w, _b)
+            q_v = F.linear(query, _w, _b)
+    else:
+        q_proj_weight_non_opt = \
+            torch.jit._unwrap_optional(q_proj_weight)
+        len1, len2 = q_proj_weight_non_opt.size()
+
assert len1 == embed_dim and len2 == query.size(-1) + + k_proj_weight_non_opt = \ + torch.jit._unwrap_optional(k_proj_weight) + len1, len2 = k_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == key.size(-1) + + v_proj_weight_non_opt = \ + torch.jit._unwrap_optional(v_proj_weight) + len1, len2 = v_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == value.size(-1) + + if in_proj_bias is not None: + q = F.linear(query, q_proj_weight_non_opt, + in_proj_bias[0:embed_dim]) + k = F.linear(key, k_proj_weight_non_opt, + in_proj_bias[embed_dim:(embed_dim * 2)]) + v = F.linear(value, v_proj_weight_non_opt, + in_proj_bias[(embed_dim * 2):]) + else: + q = F.linear(query, q_proj_weight_non_opt, in_proj_bias) + k = F.linear(key, k_proj_weight_non_opt, in_proj_bias) + v = F.linear(value, v_proj_weight_non_opt, in_proj_bias) + q = q * scaling + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool + ), 'Only float, byte, and bool types are supported for ' \ + 'attn_mask, not {}'.format(attn_mask.dtype) + if attn_mask.dtype == torch.uint8: + warnings.warn('Byte tensor for attn_mask in nn.MultiheadAttention ' + 'is deprecated. Use bool tensor instead.') + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError( + 'The size of the 2D attn_mask is not correct.') + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), key.size(0) + ]: + raise RuntimeError( + 'The size of the 3D attn_mask is not correct.') + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim())) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + 'Byte tensor for key_padding_mask in nn.MultiheadAttention ' + 'is deprecated. Use bool tensor instead.') + key_padding_mask = key_padding_mask.to(torch.bool) + + if bias_k is not None and bias_v is not None: + if static_k is None and static_v is None: + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + else: + assert static_k is None, 'bias cannot be added to static key.' + assert static_v is None, 'bias cannot be added to static value.' 
+ else: + assert bias_k is None + assert bias_v is None + + q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + q_k = q_k.contiguous().view(tgt_len, bsz * num_heads, + head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + q_v = q_v.contiguous().view(tgt_len, bsz * num_heads, + head_dim).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * num_heads + assert static_v.size(2) == head_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat( + [ + k, + torch.zeros( + (k.size(0), 1) + k.size()[2:], + dtype=k.dtype, + device=k.device), + ], + dim=1, + ) + v = torch.cat( + [ + v, + torch.zeros( + (v.size(0), 1) + v.size()[2:], + dtype=v.dtype, + device=v.device), + ], + dim=1, + ) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list( + attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float('-inf')) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, + src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, + tgt_len, src_len) + # attn_out_weights: [bsz * num_heads, tgt_len, src_len] + # ->[bsz * num_heads, tgt_len, src_len+1] + self_weight = (q * q_k).sum( + dim=-1, keepdim=True) # [bsz * num_heads, tgt_len, 1] + total_attn_output_weights = torch.cat([attn_output_weights, self_weight], + dim=-1) + total_attn_output_weights = F.softmax(total_attn_output_weights, dim=-1) + total_attn_output_weights = F.dropout( + total_attn_output_weights, p=dropout_p, training=training) + attn_output_weights = \ + total_attn_output_weights[:, :, : -1] + # [bsz * num_heads, tgt_len, src_len] + self_weight = \ + total_attn_output_weights[:, :, -1:] # [bsz * num_heads, tgt_len, 1] + + attn_output = torch.bmm(attn_output_weights, + v) # [bsz * num_heads, tgt_len, head_dim] + attn_output = (attn_output + self_weight * q_v + ) # [bsz * num_heads, tgt_len, head_dim] + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = attn_output.transpose(0, 1).contiguous().view( + tgt_len, bsz, embed_dim) + attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, + src_len) + return attn_output, attn_output_weights # .sum(dim=1) / num_heads + else: + return attn_output, None + + +def cross_attn_layer(tf_layer: BaseTransformerLayer, x, mem, attn_bias): + """Implementation of transformer layer with cross attention. The cross + attention shares the embedding weights with self-attention of tf_layer. 
+
+    In addition to attending over ``mem``, every query position also attends
+    to itself: the query's own key/value projections contribute one extra
+    attention logit per position, appended before the softmax, and the
+    resulting self weight gates a ``q_v`` residual term in the output.
+
+    Args:
+        tf_layer (BaseTransformerLayer): The transformer layer whose
+            self-attention weights are reused.
+        x (Tensor): query [K,N,C]
+        mem (Tensor): key and value [L,N,C]
+        attn_bias (Tensor): attention bias [N*num_head,K,L]
+
+    Returns:
+        x (Tensor): cross attention output [K,N,C]
+    """
+    self_attn_layer = tf_layer.attentions[0].attn
+    attn_layer_paras = {
+        'embed_dim_to_check': self_attn_layer.embed_dim,
+        'num_heads': self_attn_layer.num_heads,
+        'in_proj_weight': self_attn_layer.in_proj_weight,
+        'in_proj_bias': self_attn_layer.in_proj_bias,
+        'bias_k': self_attn_layer.bias_k,
+        'bias_v': self_attn_layer.bias_v,
+        'add_zero_attn': self_attn_layer.add_zero_attn,
+        'dropout_p': self_attn_layer.dropout,
+        'out_proj_weight': self_attn_layer.out_proj.weight,
+        'out_proj_bias': self_attn_layer.out_proj.bias,
+        'training': self_attn_layer.training
+    }
+
+    q_x = tf_layer.norms[0](x)
+    k_x = v_x = tf_layer.norms[0](mem)
+    x = x + cross_attn_with_self_bias(
+        q_x,
+        k_x,
+        v_x,
+        attn_mask=attn_bias,
+        need_weights=False,
+        **attn_layer_paras)[0]
+    x = tf_layer.ffns[0](tf_layer.norms[1](x), identity=x)
+    return x
+
+
+class LayerNorm2d(nn.Module):
+    """A LayerNorm variant, popularized by Transformers, that performs
+    pointwise mean and variance normalization over the channel dimension for
+    inputs that have shape (batch_size, channels, height, width).
+
+    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950
+    """
+
+    def __init__(self, normalized_shape, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.normalized_shape = (normalized_shape, )
+
+    def forward(self, x: torch.Tensor):
+        u = x.mean(1, keepdim=True)
+        s = (x - u).pow(2).mean(1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
+
+
+class MLP(nn.Module):
+    """A very simple multi-layer perceptron (also called FFN)."""
+
+    def __init__(self,
+                 input_dim,
+                 hidden_dim,
+                 output_dim,
+                 num_layers,
+                 affine_func=nn.Linear):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(
+            affine_func(n, k)
+            for n, k in zip([input_dim] + h, h + [output_dim]))
+
+    def forward(self, x: torch.Tensor):
+        for i, layer in enumerate(self.layers):
+            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
diff --git a/mmseg/models/utils/se_layer.py b/mmseg/models/utils/se_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ff632cfea728a7ffd99f1578c828c588d78f3db
--- /dev/null
+++ b/mmseg/models/utils/se_layer.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmengine.utils import is_tuple_of
+
+from .make_divisible import make_divisible
+
+
+class SELayer(nn.Module):
+    """Squeeze-and-Excitation Module.
+
+    Args:
+        channels (int): The input (and output) channels of the SE layer.
+        ratio (int): Squeeze ratio in SELayer, the intermediate channel will be
+            ``int(channels/ratio)``. Default: 16.
+        conv_cfg (None or dict): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        act_cfg (dict or Sequence[dict]): Config dict for activation layer.
+            If act_cfg is a dict, two activation layers will be configured
+            by this dict.
+            If act_cfg is a sequence of dicts, the first activation layer
+            will be configured by the first dict and the second activation
+            layer will be configured by the second dict.
+            Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0,
+            divisor=6.0)).
+    """
+
+    def __init__(self,
+                 channels,
+                 ratio=16,
+                 conv_cfg=None,
+                 act_cfg=(dict(type='ReLU'),
+                          dict(type='HSigmoid', bias=3.0, divisor=6.0))):
+        super().__init__()
+        if isinstance(act_cfg, dict):
+            act_cfg = (act_cfg, act_cfg)
+        assert len(act_cfg) == 2
+        assert is_tuple_of(act_cfg, dict)
+        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+        self.conv1 = ConvModule(
+            in_channels=channels,
+            out_channels=make_divisible(channels // ratio, 8),
+            kernel_size=1,
+            stride=1,
+            conv_cfg=conv_cfg,
+            act_cfg=act_cfg[0])
+        self.conv2 = ConvModule(
+            in_channels=make_divisible(channels // ratio, 8),
+            out_channels=channels,
+            kernel_size=1,
+            stride=1,
+            conv_cfg=conv_cfg,
+            act_cfg=act_cfg[1])
+
+    def forward(self, x):
+        out = self.global_avgpool(x)
+        out = self.conv1(out)
+        out = self.conv2(out)
+        return x * out
diff --git a/mmseg/models/utils/self_attention_block.py b/mmseg/models/utils/self_attention_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bb6e8284e599637c12553e27199338a820709e3
--- /dev/null
+++ b/mmseg/models/utils/self_attention_block.py
@@ -0,0 +1,161 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.cnn import ConvModule
+from mmengine.model.weight_init import constant_init
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+class SelfAttentionBlock(nn.Module):
+    """General self-attention block/non-local block.
+
+    Please refer to https://arxiv.org/abs/1706.03762 for details about key,
+    query and value.
+
+    Args:
+        key_in_channels (int): Input channels of key feature.
+        query_in_channels (int): Input channels of query feature.
+        channels (int): Output channels of key/query transform.
+        out_channels (int): Output channels.
+        share_key_query (bool): Whether share projection weight between key
+            and query projection.
+        query_downsample (nn.Module): Query downsample module.
+        key_downsample (nn.Module): Key downsample module.
+        key_query_num_convs (int): Number of convs for key/query projection.
+        value_out_num_convs (int): Number of convs for value/out projection.
+        key_query_norm (bool): Whether to use ConvModule (with norm and
+            activation) for the key/query projection.
+        value_out_norm (bool): Whether to use ConvModule (with norm and
+            activation) for the value/out projection.
+        matmul_norm (bool): Whether to normalize the attention map by the
+            square root of channels.
+        with_out (bool): Whether to use the out projection.
+        conv_cfg (dict|None): Config of conv layers.
+        norm_cfg (dict|None): Config of norm layers.
+        act_cfg (dict|None): Config of activation layers.
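+
+    Example (an illustrative sketch; shapes and channel counts are
+    arbitrary):
+
+        >>> import torch
+        >>> from mmseg.models.utils import SelfAttentionBlock
+        >>> block = SelfAttentionBlock(
+        ...     key_in_channels=64, query_in_channels=64, channels=32,
+        ...     out_channels=64, share_key_query=False,
+        ...     query_downsample=None, key_downsample=None,
+        ...     key_query_num_convs=2, value_out_num_convs=1,
+        ...     key_query_norm=True, value_out_norm=True,
+        ...     matmul_norm=True, with_out=True, conv_cfg=None,
+        ...     norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'))
+        >>> out = block(torch.rand(2, 64, 16, 16), torch.rand(2, 64, 16, 16))
+        >>> tuple(out.shape)
+        (2, 64, 16, 16)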
+ """ + + def __init__(self, key_in_channels, query_in_channels, channels, + out_channels, share_key_query, query_downsample, + key_downsample, key_query_num_convs, value_out_num_convs, + key_query_norm, value_out_norm, matmul_norm, with_out, + conv_cfg, norm_cfg, act_cfg): + super().__init__() + if share_key_query: + assert key_in_channels == query_in_channels + self.key_in_channels = key_in_channels + self.query_in_channels = query_in_channels + self.out_channels = out_channels + self.channels = channels + self.share_key_query = share_key_query + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.key_project = self.build_project( + key_in_channels, + channels, + num_convs=key_query_num_convs, + use_conv_module=key_query_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if share_key_query: + self.query_project = self.key_project + else: + self.query_project = self.build_project( + query_in_channels, + channels, + num_convs=key_query_num_convs, + use_conv_module=key_query_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.value_project = self.build_project( + key_in_channels, + channels if with_out else out_channels, + num_convs=value_out_num_convs, + use_conv_module=value_out_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if with_out: + self.out_project = self.build_project( + channels, + out_channels, + num_convs=value_out_num_convs, + use_conv_module=value_out_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + else: + self.out_project = None + + self.query_downsample = query_downsample + self.key_downsample = key_downsample + self.matmul_norm = matmul_norm + + self.init_weights() + + def init_weights(self): + """Initialize weight of later layer.""" + if self.out_project is not None: + if not isinstance(self.out_project, ConvModule): + constant_init(self.out_project, 0) + + def build_project(self, in_channels, channels, num_convs, use_conv_module, + conv_cfg, norm_cfg, act_cfg): + """Build projection layer for key/query/value/out.""" + if use_conv_module: + convs = [ + ConvModule( + in_channels, + channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + ] + for _ in range(num_convs - 1): + convs.append( + ConvModule( + channels, + channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + else: + convs = [nn.Conv2d(in_channels, channels, 1)] + for _ in range(num_convs - 1): + convs.append(nn.Conv2d(channels, channels, 1)) + if len(convs) > 1: + convs = nn.Sequential(*convs) + else: + convs = convs[0] + return convs + + def forward(self, query_feats, key_feats): + """Forward function.""" + batch_size = query_feats.size(0) + query = self.query_project(query_feats) + if self.query_downsample is not None: + query = self.query_downsample(query) + query = query.reshape(*query.shape[:2], -1) + query = query.permute(0, 2, 1).contiguous() + + key = self.key_project(key_feats) + value = self.value_project(key_feats) + if self.key_downsample is not None: + key = self.key_downsample(key) + value = self.key_downsample(value) + key = key.reshape(*key.shape[:2], -1) + value = value.reshape(*value.shape[:2], -1) + value = value.permute(0, 2, 1).contiguous() + + sim_map = torch.matmul(query, key) + if self.matmul_norm: + sim_map = (self.channels**-.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + context = torch.matmul(sim_map, value) + context = context.permute(0, 2, 1).contiguous() + context = context.reshape(batch_size, -1, 
*query_feats.shape[2:])
+        if self.out_project is not None:
+            context = self.out_project(context)
+        return context
diff --git a/mmseg/models/utils/shape_convert.py b/mmseg/models/utils/shape_convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..cce1e220b645d4b02df1ec2d9ed3137c8acba707
--- /dev/null
+++ b/mmseg/models/utils/shape_convert.py
@@ -0,0 +1,107 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+def nlc_to_nchw(x, hw_shape):
+    """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.
+
+    Args:
+        x (Tensor): The input tensor of shape [N, L, C] before conversion.
+        hw_shape (Sequence[int]): The height and width of output feature map.
+
+    Returns:
+        Tensor: The output tensor of shape [N, C, H, W] after conversion.
+    """
+    H, W = hw_shape
+    assert len(x.shape) == 3
+    B, L, C = x.shape
+    assert L == H * W, 'The seq_len doesn\'t match H, W'
+    return x.transpose(1, 2).reshape(B, C, H, W)
+
+
+def nchw_to_nlc(x):
+    """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.
+
+    Args:
+        x (Tensor): The input tensor of shape [N, C, H, W] before conversion.
+
+    Returns:
+        Tensor: The output tensor of shape [N, L, C] after conversion.
+    """
+    assert len(x.shape) == 4
+    return x.flatten(2).transpose(1, 2).contiguous()
+
+
+def nchw2nlc2nchw(module, x, contiguous=False, **kwargs):
+    """Flatten a [N, C, H, W] shape tensor `x` to [N, L, C], use it as the
+    input of `module`, and convert the output of `module`, whose shape is
+    [N, L, C], back to [N, C, H, W].
+
+    Args:
+        module (Callable): A callable object that takes a tensor
+            with shape [N, L, C] as input.
+        x (Tensor): The input tensor of shape [N, C, H, W].
+        contiguous (Bool): Whether to make the tensor contiguous
+            after each shape transform.
+
+    Returns:
+        Tensor: The output tensor of shape [N, C, H, W].
+
+    Example:
+        >>> import torch
+        >>> import torch.nn as nn
+        >>> norm = nn.LayerNorm(4)
+        >>> feature_map = torch.rand(4, 4, 5, 5)
+        >>> output = nchw2nlc2nchw(norm, feature_map)
+    """
+    B, C, H, W = x.shape
+    if not contiguous:
+        x = x.flatten(2).transpose(1, 2)
+        x = module(x, **kwargs)
+        x = x.transpose(1, 2).reshape(B, C, H, W)
+    else:
+        x = x.flatten(2).transpose(1, 2).contiguous()
+        x = module(x, **kwargs)
+        x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
+    return x
+
+
+def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs):
+    """Convert a [N, L, C] shape tensor `x` to [N, C, H, W], use it as the
+    input of `module`, and convert the output of `module`, whose shape is
+    [N, C, H, W], back to [N, L, C].
+
+    Args:
+        module (Callable): A callable object that takes a tensor
+            with shape [N, C, H, W] as input.
+        x (Tensor): The input tensor of shape [N, L, C].
+        hw_shape (Sequence[int]): The height and width of the
+            feature map with shape [N, C, H, W].
+        contiguous (Bool): Whether to make the tensor contiguous
+            after each shape transform.
+
+    Returns:
+        Tensor: The output tensor of shape [N, L, C].
+
+    Example:
+        >>> import torch
+        >>> import torch.nn as nn
+        >>> conv = nn.Conv2d(16, 16, 3, 1, 1)
+        >>> feature_map = torch.rand(4, 25, 16)
+        >>> output = nlc2nchw2nlc(conv, feature_map, (5, 5))
+    """
+    H, W = hw_shape
+    assert len(x.shape) == 3
+    B, L, C = x.shape
+    assert L == H * W, 'The seq_len doesn\'t match H, W'
+    if not contiguous:
+        x = x.transpose(1, 2).reshape(B, C, H, W)
+        x = module(x, **kwargs)
+        x = x.flatten(2).transpose(1, 2)
+    else:
+        x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
+        x = module(x, **kwargs)
+        x = x.flatten(2).transpose(1, 2).contiguous()
+    return x
diff --git a/mmseg/models/utils/up_conv_block.py b/mmseg/models/utils/up_conv_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fa3b598de96d53c169232d9c89ac458f6921e8d
--- /dev/null
+++ b/mmseg/models/utils/up_conv_block.py
@@ -0,0 +1,102 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, build_upsample_layer
+
+
+class UpConvBlock(nn.Module):
+    """Upsample convolution block in decoder for UNet.
+
+    This upsample convolution block consists of one upsample module
+    followed by one convolution block. The upsample module expands the
+    high-level low-resolution feature map and the convolution block fuses
+    the upsampled high-level low-resolution feature map and the low-level
+    high-resolution feature map from encoder.
+
+    Args:
+        conv_block (nn.Sequential): Sequential of convolutional layers.
+        in_channels (int): Number of input channels of the high-level
+            (low-resolution) feature map from the decoder.
+        skip_channels (int): Number of input channels of the low-level
+            high-resolution feature map from encoder.
+        out_channels (int): Number of output channels.
+        num_convs (int): Number of convolutional layers in the conv_block.
+            Default: 2.
+        stride (int): Stride of convolutional layer in conv_block. Default: 1.
+        dilation (int): Dilation rate of convolutional layer in conv_block.
+            Default: 1.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        upsample_cfg (dict): The upsample config of the upsample module in
+            decoder. Default: dict(type='InterpConv'). If the size of the
+            high-level feature map is the same as that of the skip feature
+            map (low-level feature map from encoder), the high-level feature
+            map does not need upsampling and upsample_cfg should be None.
+        dcn (bool): Use deformable convolution in convolutional layer or not.
+            Default: None.
+        plugins (dict): plugins for convolutional layers. Default: None.
+    """
+
+    def __init__(self,
+                 conv_block,
+                 in_channels,
+                 skip_channels,
+                 out_channels,
+                 num_convs=2,
+                 stride=1,
+                 dilation=1,
+                 with_cp=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 upsample_cfg=dict(type='InterpConv'),
+                 dcn=None,
+                 plugins=None):
+        super().__init__()
+        assert dcn is None, 'Not implemented yet.'
+        assert plugins is None, 'Not implemented yet.'
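+        # The conv block fuses the upsampled high-level feature map with the
+        # encoder skip feature map; the upsample module maps in_channels to
+        # skip_channels, hence ``in_channels=2 * skip_channels`` below.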
+        self.conv_block = conv_block(
+            in_channels=2 * skip_channels,
+            out_channels=out_channels,
+            num_convs=num_convs,
+            stride=stride,
+            dilation=dilation,
+            with_cp=with_cp,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg,
+            dcn=None,
+            plugins=None)
+        if upsample_cfg is not None:
+            self.upsample = build_upsample_layer(
+                cfg=upsample_cfg,
+                in_channels=in_channels,
+                out_channels=skip_channels,
+                with_cp=with_cp,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+        else:
+            self.upsample = ConvModule(
+                in_channels,
+                skip_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+
+    def forward(self, skip, x):
+        """Forward function."""
+
+        x = self.upsample(x)
+        out = torch.cat([skip, x], dim=1)
+        out = self.conv_block(out)
+
+        return out
diff --git a/mmseg/models/utils/wrappers.py b/mmseg/models/utils/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..abbd0c029623b4f480a067e4b78adfec234ef8d0
--- /dev/null
+++ b/mmseg/models/utils/wrappers.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def resize(input,
+           size=None,
+           scale_factor=None,
+           mode='nearest',
+           align_corners=None,
+           warning=True):
+    if warning:
+        if size is not None and align_corners:
+            input_h, input_w = tuple(int(x) for x in input.shape[2:])
+            output_h, output_w = tuple(int(x) for x in size)
+            if output_h > input_h or output_w > input_w:
+                if ((output_h > 1 and output_w > 1 and input_h > 1
+                     and input_w > 1) and (output_h - 1) % (input_h - 1)
+                        and (output_w - 1) % (input_w - 1)):
+                    warnings.warn(
+                        f'When align_corners={align_corners}, '
+                        'the output would be more aligned if '
+                        f'input size {(input_h, input_w)} is `x+1` and '
+                        f'out size {(output_h, output_w)} is `nx+1`')
+    return F.interpolate(input, size, scale_factor, mode, align_corners)
+
+
+class Upsample(nn.Module):
+
+    def __init__(self,
+                 size=None,
+                 scale_factor=None,
+                 mode='nearest',
+                 align_corners=None):
+        super().__init__()
+        self.size = size
+        if isinstance(scale_factor, tuple):
+            self.scale_factor = tuple(float(factor) for factor in scale_factor)
+        else:
+            self.scale_factor = float(scale_factor) if scale_factor else None
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        if not self.size:
+            size = [int(t * self.scale_factor) for t in x.shape[-2:]]
+        else:
+            size = self.size
+        return resize(x, size, None, self.mode, self.align_corners)
diff --git a/mmseg/registry/__init__.py b/mmseg/registry/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee514d1a2a2bdd54a0a9b017ec227160ee502be5
--- /dev/null
+++ b/mmseg/registry/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
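+# Re-export the project-scoped registries defined in ``registry.py`` so that
+# e.g. ``from mmseg.registry import MODELS`` resolves to a child of the
+# corresponding MMEngine root registry.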
+from .registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, INFERENCERS, + LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS, + OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS, + PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, + TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS, + WEIGHT_INITIALIZERS) + +__all__ = [ + 'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', + 'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS', + 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', + 'VISBACKENDS', 'VISUALIZERS', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'LOOPS', + 'EVALUATOR', 'LOG_PROCESSORS', 'OPTIM_WRAPPERS', 'INFERENCERS' +] diff --git a/mmseg/registry/__pycache__/__init__.cpython-311.pyc b/mmseg/registry/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..605b742ec47c60898ab5f46cf2da0992d3095ffe Binary files /dev/null and b/mmseg/registry/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmseg/registry/__pycache__/registry.cpython-311.pyc b/mmseg/registry/__pycache__/registry.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84041cc367fafe48d7cec0562981311b09a854de Binary files /dev/null and b/mmseg/registry/__pycache__/registry.cpython-311.pyc differ diff --git a/mmseg/registry/registry.py b/mmseg/registry/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..37b6a776095856c2fab0101b5b0ec8ed7e8fa8f2 --- /dev/null +++ b/mmseg/registry/registry.py @@ -0,0 +1,118 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""MMSegmentation provides 21 registry nodes to support using modules across +projects. Each node is a child of the root registry in MMEngine. + +More details can be found at +https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html. 
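+
+Example (an illustrative sketch; ``TinyHead`` is a made-up module):
+
+    >>> import torch.nn as nn
+    >>> from mmseg.registry import MODELS
+    >>> @MODELS.register_module()
+    ... class TinyHead(nn.Module):
+    ...     def __init__(self, in_channels=8):
+    ...         super().__init__()
+    ...         self.conv = nn.Conv2d(in_channels, 2, kernel_size=1)
+    >>> head = MODELS.build(dict(type='TinyHead', in_channels=8))
+    >>> isinstance(head, TinyHead)
+    True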
+""" + +from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS +from mmengine.registry import DATASETS as MMENGINE_DATASETS +from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR +from mmengine.registry import HOOKS as MMENGINE_HOOKS +from mmengine.registry import INFERENCERS as MMENGINE_INFERENCERS +from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS +from mmengine.registry import LOOPS as MMENGINE_LOOPS +from mmengine.registry import METRICS as MMENGINE_METRICS +from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS +from mmengine.registry import MODELS as MMENGINE_MODELS +from mmengine.registry import \ + OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS +from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS +from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS +from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS +from mmengine.registry import \ + RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS +from mmengine.registry import RUNNERS as MMENGINE_RUNNERS +from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS +from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS +from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS +from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS +from mmengine.registry import \ + WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS +from mmengine.registry import Registry + +# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner` +RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS) +# manage runner constructors that define how to initialize runners +RUNNER_CONSTRUCTORS = Registry( + 'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS) +# manage all kinds of loops like `EpochBasedTrainLoop` +LOOPS = Registry('loop', parent=MMENGINE_LOOPS) +# manage all kinds of hooks like `CheckpointHook` +HOOKS = Registry( + 'hook', parent=MMENGINE_HOOKS, locations=['mmseg.engine.hooks']) + +# manage data-related modules +DATASETS = Registry( + 'dataset', parent=MMENGINE_DATASETS, locations=['mmseg.datasets']) +DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS) +TRANSFORMS = Registry( + 'transform', + parent=MMENGINE_TRANSFORMS, + locations=['mmseg.datasets.transforms']) + +# mangage all kinds of modules inheriting `nn.Module` +MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['mmseg.models']) +# mangage all kinds of model wrappers like 'MMDistributedDataParallel' +MODEL_WRAPPERS = Registry( + 'model_wrapper', + parent=MMENGINE_MODEL_WRAPPERS, + locations=['mmseg.models']) +# mangage all kinds of weight initialization modules like `Uniform` +WEIGHT_INITIALIZERS = Registry( + 'weight initializer', + parent=MMENGINE_WEIGHT_INITIALIZERS, + locations=['mmseg.models']) + +# mangage all kinds of optimizers like `SGD` and `Adam` +OPTIMIZERS = Registry( + 'optimizer', + parent=MMENGINE_OPTIMIZERS, + locations=['mmseg.engine.optimizers']) +# manage optimizer wrapper +OPTIM_WRAPPERS = Registry( + 'optim_wrapper', + parent=MMENGINE_OPTIM_WRAPPERS, + locations=['mmseg.engine.optimizers']) +# manage constructors that customize the optimization hyperparameters. 
+OPTIM_WRAPPER_CONSTRUCTORS = Registry(
+    'optimizer wrapper constructor',
+    parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS,
+    locations=['mmseg.engine.optimizers'])
+# manage all kinds of parameter schedulers like `MultiStepLR`
+PARAM_SCHEDULERS = Registry(
+    'parameter scheduler',
+    parent=MMENGINE_PARAM_SCHEDULERS,
+    locations=['mmseg.engine.schedulers'])
+
+# manage all kinds of metrics
+METRICS = Registry(
+    'metric', parent=MMENGINE_METRICS, locations=['mmseg.evaluation'])
+# manage evaluators
+EVALUATOR = Registry(
+    'evaluator', parent=MMENGINE_EVALUATOR, locations=['mmseg.evaluation'])
+
+# manage task-specific modules like ohem pixel sampler
+TASK_UTILS = Registry(
+    'task util', parent=MMENGINE_TASK_UTILS, locations=['mmseg.models'])
+
+# manage visualizers
+VISUALIZERS = Registry(
+    'visualizer',
+    parent=MMENGINE_VISUALIZERS,
+    locations=['mmseg.visualization'])
+# manage visualizer backends
+VISBACKENDS = Registry(
+    'vis_backend',
+    parent=MMENGINE_VISBACKENDS,
+    locations=['mmseg.visualization'])
+
+# manage log processors
+LOG_PROCESSORS = Registry(
+    'log_processor',
+    parent=MMENGINE_LOG_PROCESSORS,
+    locations=['mmseg.visualization'])
+
+# manage inferencers
+INFERENCERS = Registry('inferencer', parent=MMENGINE_INFERENCERS)
diff --git a/mmseg/structures/.DS_Store b/mmseg/structures/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..853229b5f69e701b9a238f12aa9c3fe02cca6798
Binary files /dev/null and b/mmseg/structures/.DS_Store differ
diff --git a/mmseg/structures/__init__.py b/mmseg/structures/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..63d118dca3ebcff30ca241f9378475bcce072627
--- /dev/null
+++ b/mmseg/structures/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .sampler import BasePixelSampler, OHEMPixelSampler, build_pixel_sampler
+from .seg_data_sample import SegDataSample
+
+__all__ = [
+    'SegDataSample', 'BasePixelSampler', 'OHEMPixelSampler',
+    'build_pixel_sampler'
+]
diff --git a/mmseg/structures/__pycache__/__init__.cpython-311.pyc b/mmseg/structures/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9053d2249c4b8a4f4f8f41267ef6cb019036ee45
Binary files /dev/null and b/mmseg/structures/__pycache__/__init__.cpython-311.pyc differ
diff --git a/mmseg/structures/__pycache__/seg_data_sample.cpython-311.pyc b/mmseg/structures/__pycache__/seg_data_sample.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..542e2ada6876ae2682594563d84c9c9aebb85464
Binary files /dev/null and b/mmseg/structures/__pycache__/seg_data_sample.cpython-311.pyc differ
diff --git a/mmseg/structures/sampler/__init__.py b/mmseg/structures/sampler/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..91d762d1b4552b391ece046fa3d094409011bcec
--- /dev/null
+++ b/mmseg/structures/sampler/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_pixel_sampler import BasePixelSampler +from .builder import build_pixel_sampler +from .ohem_pixel_sampler import OHEMPixelSampler + +__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler'] diff --git a/mmseg/structures/sampler/__pycache__/__init__.cpython-311.pyc b/mmseg/structures/sampler/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..725b493a9eeebcd208ee3c0370eeaf99cbafd81a Binary files /dev/null and b/mmseg/structures/sampler/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmseg/structures/sampler/__pycache__/base_pixel_sampler.cpython-311.pyc b/mmseg/structures/sampler/__pycache__/base_pixel_sampler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4192220333f9c3f407d027943177be11feefea32 Binary files /dev/null and b/mmseg/structures/sampler/__pycache__/base_pixel_sampler.cpython-311.pyc differ diff --git a/mmseg/structures/sampler/__pycache__/builder.cpython-311.pyc b/mmseg/structures/sampler/__pycache__/builder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d144623cb9c83a85b5ee2beba8475853ed2181c3 Binary files /dev/null and b/mmseg/structures/sampler/__pycache__/builder.cpython-311.pyc differ diff --git a/mmseg/structures/sampler/__pycache__/ohem_pixel_sampler.cpython-311.pyc b/mmseg/structures/sampler/__pycache__/ohem_pixel_sampler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df01eb34fa41628727addbeaa992b9d8ccc88fe6 Binary files /dev/null and b/mmseg/structures/sampler/__pycache__/ohem_pixel_sampler.cpython-311.pyc differ diff --git a/mmseg/structures/sampler/base_pixel_sampler.py b/mmseg/structures/sampler/base_pixel_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..03672cd478a2e464cc734ae92686c86f219da0a9 --- /dev/null +++ b/mmseg/structures/sampler/base_pixel_sampler.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + + +class BasePixelSampler(metaclass=ABCMeta): + """Base class of pixel sampler.""" + + def __init__(self, **kwargs): + pass + + @abstractmethod + def sample(self, seg_logit, seg_label): + """Placeholder for sample function.""" diff --git a/mmseg/structures/sampler/builder.py b/mmseg/structures/sampler/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..48e14790264a3d4c4ff54d84e5bab67b1623a1df --- /dev/null +++ b/mmseg/structures/sampler/builder.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from mmseg.registry import TASK_UTILS + +PIXEL_SAMPLERS = TASK_UTILS + + +def build_pixel_sampler(cfg, **default_args): + """Build pixel sampler for segmentation map.""" + warnings.warn( + '``build_pixel_sampler`` would be deprecated soon, please use ' + '``mmseg.registry.TASK_UTILS.build()`` ') + return TASK_UTILS.build(cfg, default_args=default_args) diff --git a/mmseg/structures/sampler/ohem_pixel_sampler.py b/mmseg/structures/sampler/ohem_pixel_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a974273cab504be269e7f391e23a521b97bd8588 --- /dev/null +++ b/mmseg/structures/sampler/ohem_pixel_sampler.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
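+# Usage note (illustrative config sketch): the sampler is typically attached
+# to a decode head via its config, e.g.
+#   sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)
+# and the head then weights its pixel-wise loss by the returned seg_weight.
+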
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_pixel_sampler import BasePixelSampler
+from .builder import PIXEL_SAMPLERS
+
+
+@PIXEL_SAMPLERS.register_module()
+class OHEMPixelSampler(BasePixelSampler):
+    """Online Hard Example Mining Sampler for segmentation.
+
+    Args:
+        context (nn.Module): The context of sampler, subclass of
+            :obj:`BaseDecodeHead`.
+        thresh (float, optional): The confidence threshold for hard example
+            selection: predictions with confidence below it are treated as
+            hard examples. If not specified, the pixels with the top
+            ``min_kept`` losses are selected instead. Default: None.
+        min_kept (int, optional): The minimum number of predictions to keep.
+            Default: 100000.
+    """
+
+    def __init__(self, context, thresh=None, min_kept=100000):
+        super().__init__()
+        self.context = context
+        assert min_kept > 1
+        self.thresh = thresh
+        self.min_kept = min_kept
+
+    def sample(self, seg_logit, seg_label):
+        """Sample pixels that have a high loss or a low prediction confidence.
+
+        Args:
+            seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
+            seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)
+
+        Returns:
+            torch.Tensor: segmentation weight, shape (N, H, W)
+        """
+        with torch.no_grad():
+            assert seg_logit.shape[2:] == seg_label.shape[2:]
+            assert seg_label.shape[1] == 1
+            seg_label = seg_label.squeeze(1).long()
+            batch_kept = self.min_kept * seg_label.size(0)
+            valid_mask = seg_label != self.context.ignore_index
+            seg_weight = seg_logit.new_zeros(size=seg_label.size())
+            valid_seg_weight = seg_weight[valid_mask]
+            if self.thresh is not None:
+                seg_prob = F.softmax(seg_logit, dim=1)
+
+                tmp_seg_label = seg_label.clone().unsqueeze(1)
+                tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
+                seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
+                sort_prob, sort_indices = seg_prob[valid_mask].sort()
+
+                if sort_prob.numel() > 0:
+                    min_threshold = sort_prob[min(batch_kept,
+                                                  sort_prob.numel() - 1)]
+                else:
+                    min_threshold = 0.0
+                threshold = max(min_threshold, self.thresh)
+                valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
+            else:
+                if not isinstance(self.context.loss_decode, nn.ModuleList):
+                    losses_decode = [self.context.loss_decode]
+                else:
+                    losses_decode = self.context.loss_decode
+                losses = 0.0
+                for loss_module in losses_decode:
+                    losses += loss_module(
+                        seg_logit,
+                        seg_label,
+                        weight=None,
+                        ignore_index=self.context.ignore_index,
+                        reduction_override='none')
+
+                # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
+                _, sort_indices = losses[valid_mask].sort(descending=True)
+                valid_seg_weight[sort_indices[:batch_kept]] = 1.
+
+            seg_weight[valid_mask] = valid_seg_weight
+
+            return seg_weight
diff --git a/mmseg/structures/seg_data_sample.py b/mmseg/structures/seg_data_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce68b5474330e2149d7d1c4de2d2406ae5b0345e
--- /dev/null
+++ b/mmseg/structures/seg_data_sample.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.structures import BaseDataElement, PixelData
+
+
+class SegDataSample(BaseDataElement):
+    """A data structure interface of MMSegmentation, used as the interface
+    between different components.
+
+    The attributes in ``SegDataSample`` are divided into several parts:
+
+    - ``gt_sem_seg`` (PixelData): Ground truth of semantic segmentation.
+    - ``pred_sem_seg`` (PixelData): Prediction of semantic segmentation.
+    - ``seg_logits`` (PixelData): Predicted logits of semantic segmentation.
+
+    Examples:
+        >>> import torch
+        >>> import numpy as np
+        >>> from mmengine.structures import PixelData
+        >>> from mmseg.structures import SegDataSample
+
+        >>> data_sample = SegDataSample()
+        >>> img_meta = dict(img_shape=(4, 4, 3),
+        ...                 pad_shape=(4, 4, 3))
+        >>> gt_segmentations = PixelData(metainfo=img_meta)
+        >>> gt_segmentations.data = torch.randint(0, 2, (1, 4, 4))
+        >>> data_sample.gt_sem_seg = gt_segmentations
+        >>> assert 'img_shape' in data_sample.gt_sem_seg.metainfo_keys()
+        >>> data_sample.gt_sem_seg.shape
+        (4, 4)
+        >>> print(data_sample)
+        <SegDataSample(
+            ...
+        ) at 0x1c2aae44d60>
+
+        >>> data_sample = SegDataSample()
+        >>> gt_sem_seg_data = dict(sem_seg=torch.rand(1, 4, 4))
+        >>> gt_sem_seg = PixelData(**gt_sem_seg_data)
+        >>> data_sample.gt_sem_seg = gt_sem_seg
+        >>> assert 'gt_sem_seg' in data_sample
+        >>> assert 'sem_seg' in data_sample.gt_sem_seg
+    """
+
+    @property
+    def gt_sem_seg(self) -> PixelData:
+        return self._gt_sem_seg
+
+    @gt_sem_seg.setter
+    def gt_sem_seg(self, value: PixelData) -> None:
+        self.set_field(value, '_gt_sem_seg', dtype=PixelData)
+
+    @gt_sem_seg.deleter
+    def gt_sem_seg(self) -> None:
+        del self._gt_sem_seg
+
+    @property
+    def pred_sem_seg(self) -> PixelData:
+        return self._pred_sem_seg
+
+    @pred_sem_seg.setter
+    def pred_sem_seg(self, value: PixelData) -> None:
+        self.set_field(value, '_pred_sem_seg', dtype=PixelData)
+
+    @pred_sem_seg.deleter
+    def pred_sem_seg(self) -> None:
+        del self._pred_sem_seg
+
+    @property
+    def seg_logits(self) -> PixelData:
+        return self._seg_logits
+
+    @seg_logits.setter
+    def seg_logits(self, value: PixelData) -> None:
+        self.set_field(value, '_seg_logits', dtype=PixelData)
+
+    @seg_logits.deleter
+    def seg_logits(self) -> None:
+        del self._seg_logits
diff --git a/mmseg/ttp/__init__.py b/mmseg/ttp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..acb5f0f785d06fc865d94066eef7136aa7ae6fea
--- /dev/null
+++ b/mmseg/ttp/__init__.py
@@ -0,0 +1,2 @@
+from .models import *
+from .metrics import *
\ No newline at end of file
diff --git a/mmseg/ttp/__pycache__/__init__.cpython-311.pyc b/mmseg/ttp/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..87bc299228fa885342bdaae9132b7578d313a46b
Binary files /dev/null and b/mmseg/ttp/__pycache__/__init__.cpython-311.pyc differ
diff --git a/mmseg/ttp/__pycache__/metrics.cpython-311.pyc b/mmseg/ttp/__pycache__/metrics.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..05ed327a45b78df312a35744967980fe26f9a112
Binary files /dev/null and b/mmseg/ttp/__pycache__/metrics.cpython-311.pyc differ
diff --git a/mmseg/ttp/__pycache__/models.cpython-311.pyc b/mmseg/ttp/__pycache__/models.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a77b7c2e0ed9dda217daba52a57213800686d8e7
Binary files /dev/null and b/mmseg/ttp/__pycache__/models.cpython-311.pyc differ
diff --git a/mmseg/ttp/loading.py b/mmseg/ttp/loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..4849f0544ad3b999a900badf1aab4d611ca86b11
--- /dev/null
+++ b/mmseg/ttp/loading.py
@@ -0,0 +1,75 @@
+from typing import Optional
+
+import mmcv
+import mmengine.fileio as fileio
+import numpy as np
+from mmcv.transforms import LoadImageFromFile as MMCV_LoadImageFromFile
+
+from opencd.registry import TRANSFORMS
+
+
+@TRANSFORMS.register_module()
+class MultiImgLoadImageFromFile(MMCV_LoadImageFromFile):
+    """Load an image pair from files.
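+
+    ``results['img_path']`` is expected to hold a sequence of file paths
+    (one per temporal phase); each path is decoded and the images are
+    stored as a list under ``results['img']``.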
+
+    Required Keys:
+
+    - img_path
+
+    Modified Keys:
+
+    - img
+    - img_shape
+    - ori_shape
+
+    """
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+
+    def transform(self, results: dict) -> Optional[dict]:
+        """Load the images referenced by ``results['img_path']``.
+
+        Args:
+            results (dict): Result dict from
+                :class:`mmengine.dataset.BaseDataset`.
+
+        Returns:
+            dict: The dict containing the loaded images and meta information.
+        """
+
+        filenames = results['img_path']
+        imgs = []
+        try:
+            for filename in filenames:
+                if self.file_client_args is not None:
+                    file_client = fileio.FileClient.infer_client(
+                        self.file_client_args, filename)
+                    img_bytes = file_client.get(filename)
+                else:
+                    img_bytes = fileio.get(
+                        filename, backend_args=self.backend_args)
+                img = mmcv.imfrombytes(
+                    img_bytes, flag=self.color_type, backend=self.imdecode_backend)
+                if self.to_float32:
+                    img = img.astype(np.float32)
+                imgs.append(img)
+        except Exception as e:
+            if self.ignore_empty:
+                return None
+            else:
+                raise e
+
+        results['img'] = imgs
+        results['img_shape'] = imgs[0].shape[:2]
+        results['ori_shape'] = imgs[0].shape[:2]
+        return results
+
+
+@TRANSFORMS.register_module()
+class LoadMultiImageFromNDArray(MultiImgLoadImageFromFile):
+    """Load an image from ``results['img']`` (an ndarray already in memory,
+    e.g. during inference) instead of from disk."""
+
+    def transform(self, results: dict) -> dict:
+
+        img = results['img']
+        if self.to_float32:
+            img = img.astype(np.float32)
+
+        results['img_path'] = None
+        results['img'] = img
+        results['img_shape'] = img.shape[:2]
+        results['ori_shape'] = img.shape[:2]
+        return results
+
diff --git a/mmseg/ttp/metrics.py b/mmseg/ttp/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..47f95e09b2c262faf375cac3596fb9ed5df3fe67
--- /dev/null
+++ b/mmseg/ttp/metrics.py
@@ -0,0 +1,80 @@
+from collections import OrderedDict
+from typing import Optional, Sequence, Dict
+import numpy as np
+import torch
+from mmengine import MMLogger, print_log
+from mmengine.evaluator import BaseMetric
+from prettytable import PrettyTable
+from torchmetrics.functional.classification import multiclass_precision, multiclass_recall, multiclass_f1_score, \
+    multiclass_jaccard_index, multiclass_accuracy, binary_accuracy
+from opencd.registry import METRICS
+
+
+@METRICS.register_module()
+class CDMetric(BaseMetric):
+    default_prefix: Optional[str] = 'cd'
+
+    def __init__(self,
+                 ignore_index: int = 255,
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None,
+                 **kwargs) -> None:
+        super().__init__(collect_device=collect_device, prefix=prefix)
+        self.ignore_index = ignore_index
+
+    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
+        for data_sample in data_samples:
+            pred_label = data_sample['pred_sem_seg']['data'].squeeze()
+            # ground truth is required here; this metric has no format_only
+            # mode for unannotated test data
+            gt_label = data_sample['gt_sem_seg']['data'].squeeze().to(pred_label)
+            self.results.append((pred_label, gt_label))
+
+    def compute_metrics(self, results: list) -> Dict[str, float]:
+        num_classes = len(self.dataset_meta['classes'])
+        class_names = self.dataset_meta['classes']
+
+        assert num_classes == 2, \
+            'Only binary change detection (2 classes) is supported in CDMetric.'
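+        # Stack the per-image (pred, gt) pairs gathered in ``process`` and
+        # compute per-class precision/recall/F1/IoU and accuracy with
+        # torchmetrics, ignoring pixels labelled ``ignore_index``.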
+ + logger: MMLogger = MMLogger.get_current_instance() + pred_label, label = zip(*results) + preds = torch.stack(pred_label, dim=0) + target = torch.stack(label, dim=0) + + multiclass_precision_ = multiclass_precision(preds, target, num_classes=num_classes, average=None, ignore_index=self.ignore_index) + multiclass_recall_ = multiclass_recall(preds, target, num_classes=num_classes, average=None, ignore_index=self.ignore_index) + multiclass_f1_score_ = multiclass_f1_score(preds, target, num_classes=num_classes, average=None, ignore_index=self.ignore_index) + multiclass_jaccard_index_ = multiclass_jaccard_index(preds, target, num_classes=num_classes, average=None, ignore_index=self.ignore_index) + accuracy_ = multiclass_accuracy(preds, target, num_classes=num_classes, average=None, ignore_index=self.ignore_index) + binary_accuracy_ = binary_accuracy(preds, target, ignore_index=self.ignore_index) + ret_metrics = OrderedDict({ + 'acc': accuracy_.cpu().numpy(), + 'p': multiclass_precision_.cpu().numpy(), + 'r': multiclass_recall_.cpu().numpy(), + 'f1': multiclass_f1_score_.cpu().numpy(), + 'iou': multiclass_jaccard_index_.cpu().numpy(), + 'macc': binary_accuracy_.cpu().numpy(), + }) + + metrics = dict() + for k, v in ret_metrics.items(): + if k == 'macc': + metrics[k] = v.item() + else: + for i in range(num_classes): + metrics[k + '_' + class_names[i]] = v[i].item() + + # each class table + ret_metrics.pop('macc', None) + ret_metrics_class = OrderedDict({ + ret_metric: np.round(ret_metric_value * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + }) + ret_metrics_class.update({'Class': class_names}) + ret_metrics_class.move_to_end('Class', last=False) + class_table_data = PrettyTable() + for key, val in ret_metrics_class.items(): + class_table_data.add_column(key, val) + + print_log('per class results:', logger) + print_log('\n' + class_table_data.get_string(), logger=logger) + return metrics diff --git a/mmseg/ttp/models.py b/mmseg/ttp/models.py new file mode 100644 index 0000000000000000000000000000000000000000..c2125e184f75c918b7841a9bd9b89550dba77267 --- /dev/null +++ b/mmseg/ttp/models.py @@ -0,0 +1,314 @@ +import copy +from typing import List, Tuple, Optional +import torch.nn.functional as F +import einops +import torch +from mmcv.cnn import ConvModule, build_norm_layer +from mmcv.cnn.bricks.transformer import PatchEmbed, FFN, build_transformer_layer +from mmengine.dist import is_main_process +from mmengine.model import BaseModule +from peft import get_peft_config, get_peft_model +from torch import Tensor, nn +# from mmdet.utils import OptConfigType, MultiConfig +from mmpretrain.models import resize_pos_embed +from mmpretrain.models.backbones.vit_sam import Attention, window_partition, window_unpartition +from mmseg.models import BaseSegmentor, EncoderDecoder +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.models.utils import resize +from mmseg.utils import OptConfigType, MultiConfig +from opencd.registry import MODELS + +from mmpretrain.models import build_norm_layer as build_norm_layer_mmpretrain + + +@MODELS.register_module() +class MMPretrainSamVisionEncoder(BaseModule): + def __init__( + self, + encoder_cfg, + peft_cfg=None, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + vision_encoder = MODELS.build(encoder_cfg) + vision_encoder.init_weights() + if peft_cfg is not None and isinstance(peft_cfg, dict): + config = { + "peft_type": "LORA", + "r": 16, + 'target_modules': ["qkv"], + "lora_alpha": 32, + "lora_dropout": 0.05, + 
"bias": "none", + "inference_mode": False, + } + config.update(peft_cfg) + peft_config = get_peft_config(config) + self.vision_encoder = get_peft_model(vision_encoder, peft_config) + if is_main_process(): + self.vision_encoder.print_trainable_parameters() + else: + self.vision_encoder = vision_encoder + # freeze the vision encoder + for param in self.vision_encoder.parameters(): + param.requires_grad = False + + def forward(self, x): + return self.vision_encoder(x) + + +@MODELS.register_module() +class MLPSegHead(BaseDecodeHead): + def __init__( + self, + out_size, + interpolate_mode='bilinear', + **kwargs + ): + super().__init__(input_transform='multiple_select', **kwargs) + + self.interpolate_mode = interpolate_mode + num_inputs = len(self.in_channels) + + assert num_inputs == len(self.in_index) + self.out_size = out_size + self.convs = nn.ModuleList() + for i in range(num_inputs): + self.convs.append( + ConvModule( + in_channels=self.in_channels[i], + out_channels=self.channels, + kernel_size=1, + stride=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + self.fusion_conv = ConvModule( + in_channels=self.channels * num_inputs, + out_channels=self.channels, + kernel_size=1, + norm_cfg=self.norm_cfg) + + def forward(self, inputs): + inputs = self._transform_inputs(inputs) + outs = [] + for idx in range(len(inputs)): + x = inputs[idx] + conv = self.convs[idx] + outs.append( + resize( + input=conv(x), + size=self.out_size, + mode=self.interpolate_mode, + align_corners=self.align_corners)) + + out = self.fusion_conv(torch.cat(outs, dim=1)) + out = self.cls_seg(out) + return out + + +@MODELS.register_module() +class LN2d(nn.Module): + """A LayerNorm variant, popularized by Transformers, that performs + pointwise mean and variance normalization over the channel dimension for + inputs that have shape (batch_size, channels, height, width).""" + + def __init__(self, normalized_shape, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + +@MODELS.register_module() +class SequentialNeck(BaseModule): + def __init__(self, necks): + super().__init__() + self.necks = nn.ModuleList() + for neck in necks: + self.necks.append(MODELS.build(neck)) + + def forward(self, *args, **kwargs): + for neck in self.necks: + args = neck(*args, **kwargs) + return args + + +@MODELS.register_module() +class SimpleFPN(BaseModule): + def __init__(self, + backbone_channel: int, + in_channels: List[int], + out_channels: int, + num_outs: int, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + act_cfg: OptConfigType = None, + init_cfg: MultiConfig = None) -> None: + super().__init__(init_cfg=init_cfg) + assert isinstance(in_channels, list) + self.backbone_channel = backbone_channel + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + + self.fpn1 = nn.Sequential( + nn.ConvTranspose2d(self.backbone_channel, + self.backbone_channel // 2, 2, 2), + build_norm_layer(norm_cfg, self.backbone_channel // 2)[1], + nn.GELU(), + nn.ConvTranspose2d(self.backbone_channel // 2, + self.backbone_channel // 4, 2, 2)) + self.fpn2 = nn.Sequential( + 
nn.ConvTranspose2d(self.backbone_channel, + self.backbone_channel // 2, 2, 2)) + self.fpn3 = nn.Sequential(nn.Identity()) + self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2)) + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.num_ins): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + def forward(self, input: Tensor) -> tuple: + # build FPN + inputs = [] + inputs.append(self.fpn1(input)) + inputs.append(self.fpn2(input)) + inputs.append(self.fpn3(input)) + inputs.append(self.fpn4(input)) + + # build laterals + laterals = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build outputs + # part 1: from original levels + outs = [self.fpn_convs[i](laterals[i]) for i in range(self.num_ins)] + + # part 2: add extra levels + if self.num_outs > len(outs): + for i in range(self.num_outs - self.num_ins): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + return tuple(outs) + + +@MODELS.register_module() +class TimeFusionTransformerEncoderLayer(BaseModule): + def __init__(self, + embed_dims: int, + num_heads: int, + feedforward_channels: int, + drop_rate: float = 0., + drop_path_rate: float = 0., + num_fcs: int = 2, + qkv_bias: bool = True, + act_cfg: dict = dict(type='GELU'), + norm_cfg: dict = dict(type='LN'), + use_rel_pos: bool = False, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + self.window_size = window_size + + self.ln1 = build_norm_layer_mmpretrain(norm_cfg, self.embed_dims) + + self.attn = Attention( + embed_dims=embed_dims, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + input_size=input_size if window_size == 0 else + (window_size, window_size), + ) + + self.ln2 = build_norm_layer_mmpretrain(norm_cfg, self.embed_dims) + + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + if self.window_size > 0: + in_channels = embed_dims * 2 + self.down_channel = nn.Conv2d(in_channels, 1, kernel_size=1, stride=1, bias=False) + self.down_channel.weight.data.fill_(1.0/in_channels) + + self.soft_ffn = nn.Sequential( + nn.Conv2d(embed_dims, embed_dims, kernel_size=1, stride=1), + nn.GELU(), + nn.Conv2d(embed_dims, embed_dims, kernel_size=1, stride=1), + ) + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + def forward(self, x): + shortcut = x + x = self.ln1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + x = shortcut + x + + x = self.ffn(self.ln2(x), identity=x) + # # time phase fusion + if self.window_size > 0: + x = einops.rearrange(x, 'b h w d -> b d h w') # 2B, C, H, W + x0 = x[:x.size(0)//2] + x1 = x[x.size(0)//2:] # B, C, H, W + x0_1 = torch.cat([x0, x1], dim=1) + activate_map = self.down_channel(x0_1) + 
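+            # Editor's walkthrough: the Siamese wrapper is assumed to stack
+            # both temporal phases along the batch axis (x[:B] = phase 1,
+            # x[B:] = phase 2), so `down_channel` squeezes the concatenated
+            # 2C channels into a one-channel map, and the sigmoid below turns
+            # it into a shared spatial gate for cross-phase exchange.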
activate_map = torch.sigmoid(activate_map) + x0 = x0 + self.soft_ffn(x1 * activate_map) + x1 = x1 + self.soft_ffn(x0 * activate_map) + x = torch.cat([x0, x1], dim=0) + x = einops.rearrange(x, 'b d h w -> b h w d') + return x \ No newline at end of file diff --git a/mmseg/utils/__init__.py b/mmseg/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a2af58c6e0316d6f961df81160f3fc61a8a29e3 --- /dev/null +++ b/mmseg/utils/__init__.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# yapf: disable +from .class_names import (ade_classes, ade_palette, bdd100k_classes, + bdd100k_palette, cityscapes_classes, + cityscapes_palette, cocostuff_classes, + cocostuff_palette, dataset_aliases, get_classes, + get_palette, isaid_classes, isaid_palette, + loveda_classes, loveda_palette, potsdam_classes, + potsdam_palette, stare_classes, stare_palette, + synapse_classes, synapse_palette, vaihingen_classes, + vaihingen_palette, voc_classes, voc_palette) +# yapf: enable +from .collect_env import collect_env +from .get_templates import get_predefined_templates +from .io import datafrombytes +from .misc import add_prefix, stack_batch +from .set_env import register_all_modules +from .tokenizer import tokenize +from .typing_utils import (ConfigType, ForwardResults, MultiConfig, + OptConfigType, OptMultiConfig, OptSampleList, + SampleList, TensorDict, TensorList) + +# isort: off +from .mask_classification import MatchMasks, seg_data_to_instance_data + +__all__ = [ + 'collect_env', + 'register_all_modules', + 'stack_batch', + 'add_prefix', + 'ConfigType', + 'OptConfigType', + 'MultiConfig', + 'OptMultiConfig', + 'SampleList', + 'OptSampleList', + 'TensorDict', + 'TensorList', + 'ForwardResults', + 'cityscapes_classes', + 'ade_classes', + 'voc_classes', + 'cocostuff_classes', + 'loveda_classes', + 'potsdam_classes', + 'vaihingen_classes', + 'isaid_classes', + 'stare_classes', + 'cityscapes_palette', + 'ade_palette', + 'voc_palette', + 'cocostuff_palette', + 'loveda_palette', + 'potsdam_palette', + 'vaihingen_palette', + 'isaid_palette', + 'stare_palette', + 'dataset_aliases', + 'get_classes', + 'get_palette', + 'datafrombytes', + 'synapse_palette', + 'synapse_classes', + 'get_predefined_templates', + 'tokenize', + 'seg_data_to_instance_data', + 'MatchMasks', + 'bdd100k_classes', + 'bdd100k_palette', +] diff --git a/mmseg/utils/__pycache__/__init__.cpython-311.pyc b/mmseg/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..296dbe121802bd67a009c5358ca77144cdbb2100 Binary files /dev/null and b/mmseg/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmseg/utils/__pycache__/class_names.cpython-311.pyc b/mmseg/utils/__pycache__/class_names.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce3ed1338f31ee1f590fb50a4b20fb68d8eb60dd Binary files /dev/null and b/mmseg/utils/__pycache__/class_names.cpython-311.pyc differ diff --git a/mmseg/utils/__pycache__/collect_env.cpython-311.pyc b/mmseg/utils/__pycache__/collect_env.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce5788dd02669f67ba3a5332a499bec4f7a08836 Binary files /dev/null and b/mmseg/utils/__pycache__/collect_env.cpython-311.pyc differ diff --git a/mmseg/utils/__pycache__/get_templates.cpython-311.pyc b/mmseg/utils/__pycache__/get_templates.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac4d08c7eaa2b745d13c72786bcbf94f83f26cb5 Binary 
files /dev/null and b/mmseg/utils/__pycache__/get_templates.cpython-311.pyc differ diff --git a/mmseg/utils/__pycache__/io.cpython-311.pyc b/mmseg/utils/__pycache__/io.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f73d5515229a55d8e4c68be96fb52b596184f258 Binary files /dev/null and b/mmseg/utils/__pycache__/io.cpython-311.pyc differ diff --git a/mmseg/utils/__pycache__/mask_classification.cpython-311.pyc b/mmseg/utils/__pycache__/mask_classification.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ea19c2416d672fdf76ccb60fed58e84901e904b Binary files /dev/null and b/mmseg/utils/__pycache__/mask_classification.cpython-311.pyc differ diff --git a/mmseg/utils/__pycache__/misc.cpython-311.pyc b/mmseg/utils/__pycache__/misc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd66d57fa1e83654fae7ce39bbbf6f1ec169f6ce Binary files /dev/null and b/mmseg/utils/__pycache__/misc.cpython-311.pyc differ diff --git a/mmseg/utils/__pycache__/set_env.cpython-311.pyc b/mmseg/utils/__pycache__/set_env.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9bb97db52b1654d412cdbcdd372e9cf93d43bed Binary files /dev/null and b/mmseg/utils/__pycache__/set_env.cpython-311.pyc differ diff --git a/mmseg/utils/__pycache__/tokenizer.cpython-311.pyc b/mmseg/utils/__pycache__/tokenizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76673ee9e4fdbd661a5a0a711daed25e7aa98f79 Binary files /dev/null and b/mmseg/utils/__pycache__/tokenizer.cpython-311.pyc differ diff --git a/mmseg/utils/__pycache__/typing_utils.cpython-311.pyc b/mmseg/utils/__pycache__/typing_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebb9d19743743250856832061d9edf6ef1b5cc14 Binary files /dev/null and b/mmseg/utils/__pycache__/typing_utils.cpython-311.pyc differ diff --git a/mmseg/utils/bpe_simple_vocab_16e6.txt.gz b/mmseg/utils/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/mmseg/utils/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/mmseg/utils/class_names.py b/mmseg/utils/class_names.py new file mode 100644 index 0000000000000000000000000000000000000000..5ab35f99dcabd886b40e88188d9395fff557ffc2 --- /dev/null +++ b/mmseg/utils/class_names.py @@ -0,0 +1,529 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
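+# Editor's note: each dataset gets a `<name>_classes()` / `<name>_palette()`
+# pair, and `get_classes` / `get_palette` at the end of the module resolve a
+# dataset alias to the matching helper, e.g. (hypothetical session):
+#   >>> get_classes('cityscapes')[:3]
+#   ['road', 'sidewalk', 'building']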
+from mmengine.utils import is_str + + +def cityscapes_classes(): + """Cityscapes class names for external use.""" + return [ + 'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', + 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', + 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', + 'bicycle' + ] + + +def ade_classes(): + """ADE20K class names for external use.""" + return [ + 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', + 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', + 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', + 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', + 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', + 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', + 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', + 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', + 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', + 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', + 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', + 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', + 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', + 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', + 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', + 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', + 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', + 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', + 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', + 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', + 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', + 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', + 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', + 'clock', 'flag' + ] + + +def voc_classes(): + """Pascal VOC class names for external use.""" + return [ + 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', + 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', + 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', + 'tvmonitor' + ] + + +def pcontext_classes(): + """Pascal Context class names for external use.""" + return [ + 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird', + 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet', 'car', 'cat', + 'ceiling', 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', + 'dog', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', + 'horse', 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', + 'person', 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep', + 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', 'track', + 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', + 'wood' + ] + + +def cocostuff_classes(): + """CocoStuff class names for external use.""" + return [ + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', + 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', + 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', + 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 
'bottle', 'wine glass', 'cup', 'fork', + 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', + 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', + 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', + 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet', + 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile', + 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain', + 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble', + 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower', + 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel', + 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal', + 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', 'paper', + 'pavement', 'pillow', 'plant-other', 'plastic', 'platform', + 'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof', + 'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper', + 'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other', + 'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable', + 'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel', + 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops', + 'window-blind', 'window-other', 'wood' + ] + + +def loveda_classes(): + """LoveDA class names for external use.""" + return [ + 'background', 'building', 'road', 'water', 'barren', 'forest', + 'agricultural' + ] + + +def potsdam_classes(): + """Potsdam class names for external use.""" + return [ + 'impervious_surface', 'building', 'low_vegetation', 'tree', 'car', + 'clutter' + ] + + +def vaihingen_classes(): + """Vaihingen class names for external use.""" + return [ + 'impervious_surface', 'building', 'low_vegetation', 'tree', 'car', + 'clutter' + ] + + +def isaid_classes(): + """iSAID class names for external use.""" + return [ + 'background', 'ship', 'store_tank', 'baseball_diamond', 'tennis_court', + 'basketball_court', 'Ground_Track_Field', 'Bridge', 'Large_Vehicle', + 'Small_Vehicle', 'Helicopter', 'Swimming_pool', 'Roundabout', + 'Soccer_ball_field', 'plane', 'Harbor' + ] + + +def stare_classes(): + """stare class names for external use.""" + return ['background', 'vessel'] + + +def mapillary_v1_classes(): + """mapillary_v1 class names for external use.""" + return [ + 'Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier', + 'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking', + 'Pedestrian Area', 'Rail Track', 'Road', 'Service Lane', 'Sidewalk', + 'Bridge', 'Building', 'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist', + 'Other Rider', 'Lane Marking - Crosswalk', 'Lane Marking - General', + 'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water', + 'Banner', 'Bench', 'Bike Rack', 'Billboard', 'Catch Basin', + 'CCTV Camera', 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', + 'Phone Booth', 'Pothole', 'Street Light', 'Pole', 'Traffic Sign Frame', + 'Utility Pole', 'Traffic Light', 'Traffic Sign (Back)', + 'Traffic Sign (Front)', 'Trash Can', 'Bicycle', 'Boat', 'Bus', 'Car', + 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer', + 'Truck', 'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled' + ] + + +def mapillary_v1_palette(): + """mapillary_v1_ palette for external use.""" + return 
[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153], + [180, 165, 180], [90, 120, 150], [102, 102, 156], [128, 64, 255], + [140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96], + [230, 150, 140], [128, 64, 128], [110, 110, 110], [244, 35, 232], + [150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60], + [255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128], + [255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180], + [190, 255, 255], [152, 251, 152], [107, 142, 35], [0, 170, 30], + [255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220], + [220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40], + [33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150], + [210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80], + [250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20], + [119, 11, 32], [150, 0, 255], [0, 60, 100], [0, 0, 142], + [0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110], + [0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0]] + + +def mapillary_v2_classes(): + """mapillary_v2 class names for external use.""" + return [ + 'Bird', 'Ground Animal', 'Ambiguous Barrier', 'Concrete Block', 'Curb', + 'Fence', 'Guard Rail', 'Barrier', 'Road Median', 'Road Side', + 'Lane Separator', 'Temporary Barrier', 'Wall', 'Bike Lane', + 'Crosswalk - Plain', 'Curb Cut', 'Driveway', 'Parking', + 'Parking Aisle', 'Pedestrian Area', 'Rail Track', 'Road', + 'Road Shoulder', 'Service Lane', 'Sidewalk', 'Traffic Island', + 'Bridge', 'Building', 'Garage', 'Tunnel', 'Person', 'Person Group', + 'Bicyclist', 'Motorcyclist', 'Other Rider', + 'Lane Marking - Dashed Line', 'Lane Marking - Straight Line', + 'Lane Marking - Zigzag Line', 'Lane Marking - Ambiguous', + 'Lane Marking - Arrow (Left)', 'Lane Marking - Arrow (Other)', + 'Lane Marking - Arrow (Right)', + 'Lane Marking - Arrow (Split Left or Straight)', + 'Lane Marking - Arrow (Split Right or Straight)', + 'Lane Marking - Arrow (Straight)', 'Lane Marking - Crosswalk', + 'Lane Marking - Give Way (Row)', 'Lane Marking - Give Way (Single)', + 'Lane Marking - Hatched (Chevron)', + 'Lane Marking - Hatched (Diagonal)', 'Lane Marking - Other', + 'Lane Marking - Stop Line', 'Lane Marking - Symbol (Bicycle)', + 'Lane Marking - Symbol (Other)', 'Lane Marking - Text', + 'Lane Marking (only) - Dashed Line', 'Lane Marking (only) - Crosswalk', + 'Lane Marking (only) - Other', 'Lane Marking (only) - Test', + 'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water', + 'Banner', 'Bench', 'Bike Rack', 'Catch Basin', 'CCTV Camera', + 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', 'Parking Meter', + 'Phone Booth', 'Pothole', 'Signage - Advertisement', + 'Signage - Ambiguous', 'Signage - Back', 'Signage - Information', + 'Signage - Other', 'Signage - Store', 'Street Light', 'Pole', + 'Pole Group', 'Traffic Sign Frame', 'Utility Pole', 'Traffic Cone', + 'Traffic Light - General (Single)', 'Traffic Light - Pedestrians', + 'Traffic Light - General (Upright)', + 'Traffic Light - General (Horizontal)', 'Traffic Light - Cyclists', + 'Traffic Light - Other', 'Traffic Sign - Ambiguous', + 'Traffic Sign (Back)', 'Traffic Sign - Direction (Back)', + 'Traffic Sign - Direction (Front)', 'Traffic Sign (Front)', + 'Traffic Sign - Parking', 'Traffic Sign - Temporary (Back)', + 'Traffic Sign - Temporary (Front)', 'Trash Can', 'Bicycle', 'Boat', + 'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', + 'Trailer', 'Truck', 'Vehicle Group', 'Wheeled Slow', 'Water Valve', + 
'Car Mount', 'Dynamic', 'Ego Vehicle', 'Ground', 'Static', 'Unlabeled' + ] + + +def mapillary_v2_palette(): + """mapillary_v2_ palette for external use.""" + return [[165, 42, 42], [0, 192, 0], [250, 170, 31], [250, 170, 32], + [196, 196, 196], [190, 153, 153], [180, 165, 180], [90, 120, 150], + [250, 170, 33], [250, 170, 34], [128, 128, 128], [250, 170, 35], + [102, 102, 156], [128, 64, 255], [140, 140, 200], [170, 170, 170], + [250, 170, 36], [250, 170, 160], [250, 170, 37], [96, 96, 96], + [230, 150, 140], [128, 64, 128], [110, 110, 110], [110, 110, 110], + [244, 35, 232], [128, 196, 128], [150, 100, 100], [70, 70, 70], + [150, 150, 150], [150, 120, 90], [220, 20, 60], [220, 20, 60], + [255, 0, 0], [255, 0, 100], [255, 0, 200], [255, 255, 255], + [255, 255, 255], [250, 170, 29], [250, 170, 28], [250, 170, 26], + [250, 170, 25], [250, 170, 24], [250, 170, 22], [250, 170, 21], + [250, 170, 20], [255, 255, 255], [250, 170, 19], [250, 170, 18], + [250, 170, 12], [250, 170, 11], [255, 255, 255], [255, 255, 255], + [250, 170, 16], [250, 170, 15], [250, 170, 15], [255, 255, 255], + [255, 255, 255], [255, 255, 255], [255, 255, 255], [64, 170, 64], + [230, 160, 50], [70, 130, 180], [190, 255, 255], [152, 251, 152], + [107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30], + [100, 140, 180], [220, 128, 128], [222, 40, 40], [100, 170, 30], + [40, 40, 40], [33, 33, 33], [100, 128, 160], [20, 20, 255], + [142, 0, 0], [70, 100, 150], [250, 171, 30], [250, 172, 30], + [250, 173, 30], [250, 174, 30], [250, 175, 30], [250, 176, 30], + [210, 170, 100], [153, 153, 153], [153, 153, 153], [128, 128, 128], + [0, 0, 80], [210, 60, 60], [250, 170, 30], [250, 170, 30], + [250, 170, 30], [250, 170, 30], [250, 170, 30], [250, 170, 30], + [192, 192, 192], [192, 192, 192], [192, 192, 192], [220, 220, 0], + [220, 220, 0], [0, 0, 196], [192, 192, 192], [220, 220, 0], + [140, 140, 20], [119, 11, 32], [150, 0, 255], [0, 60, 100], + [0, 0, 142], [0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], + [0, 0, 110], [0, 0, 70], [0, 0, 142], [0, 0, 192], [170, 170, 170], + [32, 32, 32], [111, 74, 0], [120, 10, 10], [81, 0, 81], + [111, 111, 0], [0, 0, 0]] + + +def cityscapes_palette(): + """Cityscapes palette for external use.""" + return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], + [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], + [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], + [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], + [0, 0, 230], [119, 11, 32]] + + +def ade_palette(): + """ADE20K palette for external use.""" + return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], + [11, 200, 
200], [255, 82, 0], [0, 255, 245], [0, 61, 255], + [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], + [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], + [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], + [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], + [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], + [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], + [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], + [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], + [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], + [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], + [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], + [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], + [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], + [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], + [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], + [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], + [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], + [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], + [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], + [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], + [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], + [102, 255, 0], [92, 0, 255]] + + +def voc_palette(): + """Pascal VOC palette for external use.""" + return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], + [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], + [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], + [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], + [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] + + +def pcontext_palette(): + """Pascal Context palette for external use.""" + return [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], + [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230], + [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], + [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140], + [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], + [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71], + [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], + [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6], + [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], + [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8], + [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], + [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140], + [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], + [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0], + [0, 235, 255], [0, 173, 255], [31, 0, 255]] + + +def cocostuff_palette(): + """CocoStuff palette for external use.""" + return [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], + [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64], + [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], + [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192], + [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192], + [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], + [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160], [0, 32, 0], + [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], + [192, 128, 32], [128, 96, 128], [0, 0, 128], [64, 0, 32], + [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128], + [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64], + [192, 0, 32], [128, 96, 0], [128, 0, 192], [0, 128, 
32], + [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0], + [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64], + [128, 128, 32], [192, 32, 128], [0, 64, 192], [0, 0, 32], + [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128], + [128, 192, 192], [0, 0, 160], [192, 160, 128], [128, 192, 0], + [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96], + [64, 160, 0], [0, 64, 0], [192, 128, 224], [64, 32, 0], + [0, 192, 128], [64, 128, 224], [192, 160, 0], [0, 192, 0], + [192, 128, 96], [192, 96, 128], [0, 64, 128], [64, 0, 96], + [64, 224, 128], [128, 64, 0], [192, 0, 224], [64, 96, 128], + [128, 192, 128], [64, 0, 224], [192, 224, 128], [128, 192, 64], + [192, 0, 96], [192, 96, 0], [128, 64, 192], [0, 128, 96], + [0, 224, 0], [64, 64, 64], [128, 128, 224], [0, 96, 0], + [64, 192, 192], [0, 128, 224], [128, 224, 0], [64, 192, 64], + [128, 128, 96], [128, 32, 128], [64, 0, 192], [0, 64, 96], + [0, 160, 128], [192, 0, 64], [128, 64, 224], [0, 32, 128], + [192, 128, 192], [0, 64, 224], [128, 160, 128], [192, 128, 0], + [128, 64, 32], [128, 32, 64], [192, 0, 128], [64, 192, 32], + [0, 160, 64], [64, 0, 0], [192, 192, 160], [0, 32, 64], + [64, 128, 128], [64, 192, 160], [128, 160, 64], [64, 128, 0], + [192, 192, 32], [128, 96, 192], [64, 0, 128], [64, 64, 32], + [0, 224, 192], [192, 0, 0], [192, 64, 160], [0, 96, 192], + [192, 128, 128], [64, 64, 160], [128, 224, 192], [192, 128, 64], + [192, 64, 32], [128, 96, 64], [192, 0, 192], [0, 192, 32], + [64, 224, 64], [64, 0, 64], [128, 192, 160], [64, 96, 64], + [64, 128, 192], [0, 192, 160], [192, 224, 64], [64, 128, 64], + [128, 192, 32], [192, 32, 192], [64, 64, 192], [0, 64, 32], + [64, 160, 192], [192, 64, 64], [128, 64, 160], [64, 32, 192], + [192, 192, 192], [0, 64, 160], [192, 160, 192], [192, 192, 0], + [128, 64, 96], [192, 32, 64], [192, 64, 128], [64, 192, 96], + [64, 160, 64], [64, 64, 0]] + + +def loveda_palette(): + """LoveDA palette for external use.""" + return [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255], + [159, 129, 183], [0, 255, 0], [255, 195, 128]] + + +def potsdam_palette(): + """Potsdam palette for external use.""" + return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], + [255, 255, 0], [255, 0, 0]] + + +def vaihingen_palette(): + """Vaihingen palette for external use.""" + return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], + [255, 255, 0], [255, 0, 0]] + + +def isaid_palette(): + """iSAID palette for external use.""" + return [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127], + [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, + 127], [0, 0, 127], + [0, 0, 191], [0, 0, 255], [0, 191, 127], [0, 127, 191], + [0, 127, 255], [0, 100, 155]] + + +def stare_palette(): + """STARE palette for external use.""" + return [[120, 120, 120], [6, 230, 230]] + + +def synapse_palette(): + """Synapse palette for external use.""" + return [[0, 0, 0], [0, 0, 255], [0, 255, 0], [255, 0, 0], [0, 255, 255], + [255, 0, 255], [255, 255, 0], [60, 255, 255], [240, 240, 240]] + + +def synapse_classes(): + """Synapse class names for external use.""" + return [ + 'background', 'aorta', 'gallbladder', 'left_kidney', 'right_kidney', + 'liver', 'pancreas', 'spleen', 'stomach' + ] + + +def lip_classes(): + """LIP class names for external use.""" + return [ + 'background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes', + 'dress', 'coat', 'socks', 'pants', 'jumpsuits', 'scarf', 'skirt', + 'face', 'leftArm', 'rightArm', 'leftLeg', 'rightLeg', 'leftShoe', + 'rightShoe' + ] + + +def 
lip_palette():
+    """LIP palette for external use."""
+    return [
+        'Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'UpperClothes',
+        'Dress', 'Coat', 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt',
+        'Face', 'Left-arm', 'Right-arm', 'Left-leg', 'Right-leg', 'Left-shoe',
+        'Right-shoe'
+    ]
+
+
+def bdd100k_classes():
+    """BDD100K class names for external use (the class names are compatible
+    with Cityscapes)."""
+    return [
+        'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+        'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
+        'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+        'bicycle'
+    ]
+
+
+def bdd100k_palette():
+    """BDD100K palette for external use (same as Cityscapes)."""
+    return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+            [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
+            [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
+            [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
+            [0, 0, 230], [119, 11, 32]]
+
+
+dataset_aliases = {
+    'cityscapes': ['cityscapes'],
+    'ade': ['ade', 'ade20k'],
+    'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'],
+    'pcontext': ['pcontext', 'pascal_context', 'voc2010'],
+    'loveda': ['loveda'],
+    'potsdam': ['potsdam'],
+    'vaihingen': ['vaihingen'],
+    'cocostuff': [
+        'cocostuff', 'cocostuff10k', 'cocostuff164k', 'coco-stuff',
+        'coco-stuff10k', 'coco-stuff164k', 'coco_stuff', 'coco_stuff10k',
+        'coco_stuff164k'
+    ],
+    'isaid': ['isaid', 'iSAID'],
+    'stare': ['stare', 'STARE'],
+    'lip': ['LIP', 'lip'],
+    'mapillary_v1': ['mapillary_v1'],
+    'mapillary_v2': ['mapillary_v2'],
+    'bdd100k': ['bdd100k']
+}
+
+
+def get_classes(dataset):
+    """Get class names of a dataset."""
+    alias2name = {}
+    for name, aliases in dataset_aliases.items():
+        for alias in aliases:
+            alias2name[alias] = name
+
+    if is_str(dataset):
+        if dataset in alias2name:
+            labels = eval(alias2name[dataset] + '_classes()')
+        else:
+            raise ValueError(f'Unrecognized dataset: {dataset}')
+    else:
+        raise TypeError(f'dataset must be a str, but got {type(dataset)}')
+    return labels
+
+
+def get_palette(dataset):
+    """Get class palette (RGB) of a dataset."""
+    alias2name = {}
+    for name, aliases in dataset_aliases.items():
+        for alias in aliases:
+            alias2name[alias] = name
+
+    if is_str(dataset):
+        if dataset in alias2name:
+            labels = eval(alias2name[dataset] + '_palette()')
+        else:
+            raise ValueError(f'Unrecognized dataset: {dataset}')
+    else:
+        raise TypeError(f'dataset must be a str, but got {type(dataset)}')
+    return labels
diff --git a/mmseg/utils/collect_env.py b/mmseg/utils/collect_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5d6ea290283e3af2f29475f82d225072cf39d99
--- /dev/null
+++ b/mmseg/utils/collect_env.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
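+# Editor's note: this simply augments mmengine's base environment report;
+# the extra key renders like `MMSegmentation: 1.2.2+<short-git-hash>`
+# (format taken from the f-string below).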
+from mmengine.utils import get_git_hash +from mmengine.utils.dl_utils import collect_env as collect_base_env + +import mmseg + + +def collect_env(): + """Collect the information of the running environments.""" + env_info = collect_base_env() + env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' + + return env_info + + +if __name__ == '__main__': + for name, val in collect_env().items(): + print(f'{name}: {val}') diff --git a/mmseg/utils/get_templates.py b/mmseg/utils/get_templates.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9032ba96cbe750134676fe46fc26fb607779f5 --- /dev/null +++ b/mmseg/utils/get_templates.py @@ -0,0 +1,109 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +PREDEFINED_TEMPLATES = { + 'imagenet': [ + 'a bad photo of a {}.', + 'a photo of many {}.', + 'a sculpture of a {}.', + 'a photo of the hard to see {}.', + 'a low resolution photo of the {}.', + 'a rendering of a {}.', + 'graffiti of a {}.', + 'a bad photo of the {}.', + 'a cropped photo of the {}.', + 'a tattoo of a {}.', + 'the embroidered {}.', + 'a photo of a hard to see {}.', + 'a bright photo of a {}.', + 'a photo of a clean {}.', + 'a photo of a dirty {}.', + 'a dark photo of the {}.', + 'a drawing of a {}.', + 'a photo of my {}.', + 'the plastic {}.', + 'a photo of the cool {}.', + 'a close-up photo of a {}.', + 'a black and white photo of the {}.', + 'a painting of the {}.', + 'a painting of a {}.', + 'a pixelated photo of the {}.', + 'a sculpture of the {}.', + 'a bright photo of the {}.', + 'a cropped photo of a {}.', + 'a plastic {}.', + 'a photo of the dirty {}.', + 'a jpeg corrupted photo of a {}.', + 'a blurry photo of the {}.', + 'a photo of the {}.', + 'a good photo of the {}.', + 'a rendering of the {}.', + 'a {} in a video game.', + 'a photo of one {}.', + 'a doodle of a {}.', + 'a close-up photo of the {}.', + 'a photo of a {}.', + 'the origami {}.', + 'the {} in a video game.', + 'a sketch of a {}.', + 'a doodle of the {}.', + 'a origami {}.', + 'a low resolution photo of a {}.', + 'the toy {}.', + 'a rendition of the {}.', + 'a photo of the clean {}.', + 'a photo of a large {}.', + 'a rendition of a {}.', + 'a photo of a nice {}.', + 'a photo of a weird {}.', + 'a blurry photo of a {}.', + 'a cartoon {}.', + 'art of a {}.', + 'a sketch of the {}.', + 'a embroidered {}.', + 'a pixelated photo of a {}.', + 'itap of the {}.', + 'a jpeg corrupted photo of the {}.', + 'a good photo of a {}.', + 'a plushie {}.', + 'a photo of the nice {}.', + 'a photo of the small {}.', + 'a photo of the weird {}.', + 'the cartoon {}.', + 'art of the {}.', + 'a drawing of the {}.', + 'a photo of the large {}.', + 'a black and white photo of a {}.', + 'the plushie {}.', + 'a dark photo of a {}.', + 'itap of a {}.', + 'graffiti of the {}.', + 'a toy {}.', + 'itap of my {}.', + 'a photo of a cool {}.', + 'a photo of a small {}.', + 'a tattoo of the {}.', + ], + 'vild': [ + 'a photo of a {}.', + 'This is a photo of a {}', + 'There is a {} in the scene', + 'There is the {} in the scene', + 'a photo of a {} in the scene', + 'a photo of a small {}.', + 'a photo of a medium {}.', + 'a photo of a large {}.', + 'This is a photo of a small {}.', + 'This is a photo of a medium {}.', + 'This is a photo of a large {}.', + 'There is a small {} in the scene.', + 'There is a medium {} in the scene.', + 'There is a large {} in the scene.', + ], +} + + +def get_predefined_templates(template_set_name: str) -> List[str]: + if template_set_name not in 
PREDEFINED_TEMPLATES:
+        raise ValueError(f'Template set {template_set_name} not found')
+    return PREDEFINED_TEMPLATES[template_set_name]
diff --git a/mmseg/utils/io.py b/mmseg/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..7029c3cddda02c89cbb50cee9f8b7e7fa57378d9
--- /dev/null
+++ b/mmseg/utils/io.py
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import gzip
+import io
+import pickle
+
+import cv2
+import numpy as np
+
+
+def datafrombytes(content: bytes, backend: str = 'numpy') -> np.ndarray:
+    """Data decoding from bytes.
+
+    Args:
+        content (bytes): The data bytes got from files or other streams.
+        backend (str): The data decoding backend type. Options are 'numpy',
+            'nifti', 'cv2' and 'pickle'. Defaults to 'numpy'.
+
+    Returns:
+        numpy.ndarray: Loaded data array.
+    """
+    if backend == 'pickle':
+        data = pickle.loads(content)
+    else:
+        with io.BytesIO(content) as f:
+            if backend == 'nifti':
+                f = gzip.open(f)
+                try:
+                    from nibabel import FileHolder, Nifti1Image
+                except ImportError:
+                    raise ImportError(
+                        'nifti file io depends on nibabel, please run '
+                        '`pip install nibabel` to install it')
+                fh = FileHolder(fileobj=f)
+                data = Nifti1Image.from_file_map({'header': fh, 'image': fh})
+                data = Nifti1Image.from_bytes(data.to_bytes()).get_fdata()
+            elif backend == 'numpy':
+                data = np.load(f)
+            elif backend == 'cv2':
+                data = np.frombuffer(f.read(), dtype=np.uint8)
+                data = cv2.imdecode(data, cv2.IMREAD_UNCHANGED)
+            else:
+                raise ValueError(f'Unsupported backend: {backend}')
+    return data
diff --git a/mmseg/utils/mask_classification.py b/mmseg/utils/mask_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..205d5259754abfe07e0d84ae0739cf08043815ff
--- /dev/null
+++ b/mmseg/utils/mask_classification.py
@@ -0,0 +1,205 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple
+
+import torch
+from mmcv.ops import point_sample
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmseg.registry import TASK_UTILS
+from mmseg.utils import ConfigType, SampleList
+
+
+def seg_data_to_instance_data(ignore_index: int,
+                              batch_data_samples: SampleList):
+    """Convert the paradigm of ground truth from semantic segmentation to
+    instance segmentation.
+
+    Args:
+        ignore_index (int): The label index to be ignored.
+        batch_data_samples (List[SegDataSample]): The Data
+            Samples. It usually includes information such as
+            `gt_sem_seg`.
+
+    Returns:
+        list[InstanceData]: Batch of gt_instance. Each element usually
+            includes ``labels``, the unique ground-truth label ids of one
+            image with shape (num_gt, ), and ``masks``, one binary mask per
+            present class with shape (num_gt, h, w).
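+        (Editor's example: a binary change map containing classes 0 and 1
+        with ignore_index=255 yields two labels and a (2, h, w) mask
+        stack.)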
+ """ + batch_gt_instances = [] + + for data_sample in batch_data_samples: + gt_sem_seg = data_sample.gt_sem_seg.data + classes = torch.unique( + gt_sem_seg, + sorted=False, + return_inverse=False, + return_counts=False) + + # remove ignored region + gt_labels = classes[classes != ignore_index] + + masks = [] + for class_id in gt_labels: + masks.append(gt_sem_seg == class_id) + + if len(masks) == 0: + gt_masks = torch.zeros( + (0, gt_sem_seg.shape[-2], + gt_sem_seg.shape[-1])).to(gt_sem_seg).long() + else: + gt_masks = torch.stack(masks).squeeze(1).long() + + instance_data = InstanceData(labels=gt_labels, masks=gt_masks) + batch_gt_instances.append(instance_data) + return batch_gt_instances + + +class MatchMasks: + """Match the predictions to category labels. + + Args: + num_points (int): the number of sampled points to compute cost. + num_queries (int): the number of prediction masks. + num_classes (int): the number of classes. + assigner (BaseAssigner): the assigner to compute matching. + """ + + def __init__(self, + num_points: int, + num_queries: int, + num_classes: int, + assigner: ConfigType = None): + assert assigner is not None, "\'assigner\' in decode_head.train_cfg" \ + 'cannot be None' + assert num_points > 0, 'num_points should be a positive integer.' + self.num_points = num_points + self.num_queries = num_queries + self.num_classes = num_classes + self.assigner = TASK_UTILS.build(assigner) + + def get_targets(self, cls_scores: List[Tensor], mask_preds: List[Tensor], + batch_gt_instances: List[InstanceData]) -> Tuple: + """Compute best mask matches for all images for a decoder layer. + + Args: + cls_scores (List[Tensor]): Mask score logits from a single + decoder layer for all images. Each with shape (num_queries, + cls_out_channels). + mask_preds (List[Tensor]): Mask logits from a single decoder + layer for all images. Each with shape (num_queries, h, w). + batch_gt_instances (List[InstanceData]): each contains + ``labels`` and ``masks``. + + Returns: + tuple: a tuple containing the following targets. + + - labels (List[Tensor]): Labels of all images.\ + Each with shape (num_queries, ). + - mask_targets (List[Tensor]): Mask targets of\ + all images. Each with shape (num_queries, h, w). + - mask_weights (List[Tensor]): Mask weights of\ + all images. Each with shape (num_queries, ). + - avg_factor (int): Average factor that is used to + average the loss. `avg_factor` is usually equal + to the number of positive priors. + """ + batch_size = cls_scores.shape[0] + results = dict({ + 'labels': [], + 'mask_targets': [], + 'mask_weights': [], + }) + for i in range(batch_size): + labels, mask_targets, mask_weights\ + = self._get_targets_single(cls_scores[i], + mask_preds[i], + batch_gt_instances[i]) + results['labels'].append(labels) + results['mask_targets'].append(mask_targets) + results['mask_weights'].append(mask_weights) + + # shape (batch_size, num_queries) + labels = torch.stack(results['labels'], dim=0) + # shape (batch_size, num_gts, h, w) + mask_targets = torch.cat(results['mask_targets'], dim=0) + # shape (batch_size, num_queries) + mask_weights = torch.stack(results['mask_weights'], dim=0) + + avg_factor = sum( + [len(gt_instances.labels) for gt_instances in batch_gt_instances]) + + res = (labels, mask_targets, mask_weights, avg_factor) + + return res + + def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor, + gt_instances: InstanceData) \ + -> Tuple[Tensor, Tensor, Tensor]: + """Compute a set of best mask matches for one image. 
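+        (Editor's note: instead of comparing full-resolution masks, the
+        matcher samples `self.num_points` random point coordinates and
+        builds the assignment cost from mask logits at those points only,
+        which keeps the Hungarian matching in `self.assigner` cheap.)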
+
+        Args:
+            cls_score (Tensor): Mask score logits from a single decoder layer
+                for one image. Shape (num_queries, cls_out_channels).
+            mask_pred (Tensor): Mask logits for a single decoder layer for one
+                image. Shape (num_queries, h, w).
+            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
+                ``masks``.
+
+        Returns:
+            tuple[Tensor]: A tuple containing the following for one image.
+
+                - labels (Tensor): Labels of each image. \
+                    shape (num_queries, ).
+                - mask_targets (Tensor): Mask targets of each image. \
+                    shape (num_queries, h, w).
+                - mask_weights (Tensor): Mask weights of each image. \
+                    shape (num_queries, ).
+        """
+        gt_labels = gt_instances.labels
+        gt_masks = gt_instances.masks
+        # when "gt_labels" is empty, classify all queries to background
+        if len(gt_labels) == 0:
+            labels = gt_labels.new_full((self.num_queries, ),
+                                        self.num_classes,
+                                        dtype=torch.long)
+            mask_targets = gt_labels
+            mask_weights = gt_labels.new_zeros((self.num_queries, ))
+            return labels, mask_targets, mask_weights
+        # sample points
+        num_queries = cls_score.shape[0]
+        num_gts = gt_labels.shape[0]
+
+        point_coords = torch.rand((1, self.num_points, 2),
+                                  device=cls_score.device)
+        # shape (num_queries, num_points)
+        mask_points_pred = point_sample(
+            mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
+                                                        1)).squeeze(1)
+        # shape (num_gts, num_points)
+        gt_points_masks = point_sample(
+            gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
+                                                               1)).squeeze(1)
+
+        sampled_gt_instances = InstanceData(
+            labels=gt_labels, masks=gt_points_masks)
+        sampled_pred_instances = InstanceData(
+            scores=cls_score, masks=mask_points_pred)
+        # assign and sample
+        matched_query_inds, matched_label_inds = self.assigner.assign(
+            pred_instances=sampled_pred_instances,
+            gt_instances=sampled_gt_instances)
+        labels = gt_labels.new_full((self.num_queries, ),
+                                    self.num_classes,
+                                    dtype=torch.long)
+        labels[matched_query_inds] = gt_labels[matched_label_inds]
+
+        mask_weights = gt_labels.new_zeros((self.num_queries, ))
+        mask_weights[matched_query_inds] = 1
+        mask_targets = gt_masks[matched_label_inds]
+
+        return labels, mask_targets, mask_weights
diff --git a/mmseg/utils/misc.py b/mmseg/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfc469e8320d375135846cfb0474a0fc8d9b15d0
--- /dev/null
+++ b/mmseg/utils/misc.py
@@ -0,0 +1,128 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from .typing_utils import SampleList
+
+
+def add_prefix(inputs, prefix):
+    """Add prefix for dict.
+
+    Args:
+        inputs (dict): The input dict with str keys.
+        prefix (str): The prefix to add.
+
+    Returns:
+        dict: The dict with keys updated with ``prefix``.
+    """
+
+    outputs = dict()
+    for name, value in inputs.items():
+        outputs[f'{prefix}.{name}'] = value
+
+    return outputs
+
+
+def stack_batch(inputs: List[torch.Tensor],
+                data_samples: Optional[SampleList] = None,
+                size: Optional[tuple] = None,
+                size_divisor: Optional[int] = None,
+                pad_val: Union[int, float] = 0,
+                seg_pad_val: Union[int, float] = 255) -> torch.Tensor:
+    """Stack multiple inputs to form a batch and pad the images and
+    gt_sem_segs to the max shape, using right-bottom padding.
+
+    Args:
+        inputs (List[Tensor]): The input multiple tensors. Each is a
+            CHW 3D-tensor.
+        data_samples (list[:obj:`SegDataSample`]): The list of data samples.
+            It usually includes information such as `gt_sem_seg`.
+ size (tuple, optional): Fixed padding size. + size_divisor (int, optional): The divisor of padded size. + pad_val (int, float): The padding value. Defaults to 0 + seg_pad_val (int, float): The padding value. Defaults to 255 + + Returns: + Tensor: The 4D-tensor. + List[:obj:`SegDataSample`]: After the padding of the gt_seg_map. + """ + assert isinstance(inputs, list), \ + f'Expected input type to be list, but got {type(inputs)}' + assert len({tensor.ndim for tensor in inputs}) == 1, \ + f'Expected the dimensions of all inputs must be the same, ' \ + f'but got {[tensor.ndim for tensor in inputs]}' + assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \ + f'but got {inputs[0].ndim}' + assert len({tensor.shape[0] for tensor in inputs}) == 1, \ + f'Expected the channels of all inputs must be the same, ' \ + f'but got {[tensor.shape[0] for tensor in inputs]}' + + # only one of size and size_divisor should be valid + assert (size is not None) ^ (size_divisor is not None), \ + 'only one of size and size_divisor should be valid' + + padded_inputs = [] + padded_samples = [] + inputs_sizes = [(img.shape[-2], img.shape[-1]) for img in inputs] + max_size = np.stack(inputs_sizes).max(0) + if size_divisor is not None and size_divisor > 1: + # the last two dims are H,W, both subject to divisibility requirement + max_size = (max_size + + (size_divisor - 1)) // size_divisor * size_divisor + + for i in range(len(inputs)): + tensor = inputs[i] + if size is not None: + width = max(size[-1] - tensor.shape[-1], 0) + height = max(size[-2] - tensor.shape[-2], 0) + # (padding_left, padding_right, padding_top, padding_bottom) + padding_size = (0, width, 0, height) + elif size_divisor is not None: + width = max(max_size[-1] - tensor.shape[-1], 0) + height = max(max_size[-2] - tensor.shape[-2], 0) + padding_size = (0, width, 0, height) + else: + padding_size = [0, 0, 0, 0] + + # pad img + pad_img = F.pad(tensor, padding_size, value=pad_val) + padded_inputs.append(pad_img) + # pad gt_sem_seg + if data_samples is not None: + data_sample = data_samples[i] + pad_shape = None + if 'gt_sem_seg' in data_sample: + gt_sem_seg = data_sample.gt_sem_seg.data + del data_sample.gt_sem_seg.data + data_sample.gt_sem_seg.data = F.pad( + gt_sem_seg, padding_size, value=seg_pad_val) + pad_shape = data_sample.gt_sem_seg.shape + if 'gt_edge_map' in data_sample: + gt_edge_map = data_sample.gt_edge_map.data + del data_sample.gt_edge_map.data + data_sample.gt_edge_map.data = F.pad( + gt_edge_map, padding_size, value=seg_pad_val) + pad_shape = data_sample.gt_edge_map.shape + if 'gt_depth_map' in data_sample: + gt_depth_map = data_sample.gt_depth_map.data + del data_sample.gt_depth_map.data + data_sample.gt_depth_map.data = F.pad( + gt_depth_map, padding_size, value=seg_pad_val) + pad_shape = data_sample.gt_depth_map.shape + data_sample.set_metainfo({ + 'img_shape': tensor.shape[-2:], + 'pad_shape': pad_shape, + 'padding_size': padding_size + }) + padded_samples.append(data_sample) + else: + padded_samples.append( + dict( + img_padding_size=padding_size, + pad_shape=pad_img.shape[-2:])) + + return torch.stack(padded_inputs, dim=0), padded_samples diff --git a/mmseg/utils/set_env.py b/mmseg/utils/set_env.py new file mode 100644 index 0000000000000000000000000000000000000000..c948950d62a7463295c1055a27a9a0ce881d9fad --- /dev/null +++ b/mmseg/utils/set_env.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
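+# Editor's sketch of intended use (hypothetical session):
+#   >>> from mmseg.utils import register_all_modules
+#   >>> register_all_modules()  # imports subpackages, sets 'mmseg' scope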
+import datetime
+import warnings
+
+from mmengine import DefaultScope
+
+
+def register_all_modules(init_default_scope: bool = True) -> None:
+    """Register all modules in mmseg into the registries.
+
+    Args:
+        init_default_scope (bool): Whether to initialize the mmseg default
+            scope. When `init_default_scope=True`, the global default scope
+            will be set to `mmseg`, and all registries will build modules from
+            mmseg's registry node. To understand more about the registry,
+            please refer to
+            https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
+            Defaults to True.
+    """  # noqa
+    import mmseg.datasets  # noqa: F401,F403
+    import mmseg.engine  # noqa: F401,F403
+    import mmseg.evaluation  # noqa: F401,F403
+    import mmseg.models  # noqa: F401,F403
+    import mmseg.structures  # noqa: F401,F403
+
+    if init_default_scope:
+        never_created = DefaultScope.get_current_instance() is None \
+            or not DefaultScope.check_instance_created('mmseg')
+        if never_created:
+            DefaultScope.get_instance('mmseg', scope_name='mmseg')
+            return
+        current_scope = DefaultScope.get_current_instance()
+        if current_scope.scope_name != 'mmseg':
+            warnings.warn('The current default scope '
+                          f'"{current_scope.scope_name}" is not "mmseg", '
+                          '`register_all_modules` will force the current '
+                          'default scope to be "mmseg". If this is not '
+                          'expected, please set `init_default_scope=False`.')
+            # avoid name conflict
+            new_instance_name = f'mmseg-{datetime.datetime.now()}'
+            DefaultScope.get_instance(new_instance_name, scope_name='mmseg')
diff --git a/mmseg/utils/tokenizer.py b/mmseg/utils/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d56f5fae602506a27b9ae8835415e8dea7b611b7
--- /dev/null
+++ b/mmseg/utils/tokenizer.py
@@ -0,0 +1,240 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""CLIP tokenizer.
+
+Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright
+(c) 2021 OpenAI.
+"""
+import gzip
+import html
+import os
+from functools import lru_cache
+from typing import List, Union
+
+import ftfy
+import regex as re
+import torch
+
+os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
+
+@lru_cache()
+def default_bpe():
+    return os.path.join(
+        os.path.dirname(os.path.abspath(__file__)),
+        'bpe_simple_vocab_16e6.txt.gz')
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """Returns list of utf-8 byte and a corresponding list of unicode strings.
+
+    The reversible bpe codes work on unicode strings. This means you need a
+    large # of unicode characters in your vocab if you want to avoid UNKs. When
+    you're at something like a 10B token dataset you end up needing around 5K
+    for decent coverage. This is a significant percentage of your normal, say,
+    32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and
+    unicode strings. And avoids mapping to whitespace/control characters the
+    bpe code barfs on.
+    """
+    bs = list(range(ord('!'),
+                    ord('~') + 1)) + list(range(
+                        ord('¡'),
+                        ord('¬') + 1)) + list(range(ord('®'),
+                                                    ord('ÿ') + 1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length
+    strings).
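+    Example (editor's addition):
+        >>> sorted(get_pairs(('h', 'e', 'l', 'l', 'o</w>')))
+        [('e', 'l'), ('h', 'e'), ('l', 'l'), ('l', 'o</w>')]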
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r'\s+', ' ', text)
+    text = text.strip()
+    return text
+
+
+class SimpleTokenizer:
+
+    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
+        merges = merges[1:49152 - 256 - 2 + 1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v + '</w>' for v in vocab]
+        for merge in merges:
+            vocab.append(''.join(merge))
+        if not special_tokens:
+            special_tokens = ['<start_of_text>', '<end_of_text>']
+        else:
+            special_tokens = ['<start_of_text>', '<end_of_text>'
+                              ] + special_tokens
+        vocab.extend(special_tokens)
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {t: t for t in special_tokens}
+        special = '|'.join(special_tokens)
+        self.pat = re.compile(
+            special +
+            r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+            re.IGNORECASE)
+
+        self.vocab_size = len(self.encoder)
+        self.all_special_ids = [self.encoder[t] for t in special_tokens]
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + '</w>', )
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token + '</w>'
+
+        while True:
+            bigram = min(
+                pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except:  # noqa: E722, E261
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word) - 1 and word[
+                        i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = ''.join(self.byte_encoder[b]
+                            for b in token.encode('utf-8'))
+            bpe_tokens.extend(self.encoder[bpe_token]
+                              for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode(
+            'utf-8', errors='replace').replace('</w>', ' ')
+        return text
+
+
+_tokenizer = SimpleTokenizer()
+
+
+def decode(output_ids: torch.Tensor):
+    output_ids = output_ids.cpu().numpy()
+    return _tokenizer.decode(output_ids)
+
+
+def tokenize(texts: Union[str, List[str]],
+             context_length: int = 77) -> torch.LongTensor:
+    """Returns the tokenized representation of given input string(s)
+
+    Parameters
+    ----------
+    texts : Union[str, List[str]]
+        An input string or a list of input strings to tokenize
+    context_length : int
+        The context length to use; all CLIP models use 77 as the context length
+
+    Returns
+    -------
+    A two-dimensional tensor containing the resulting tokens,
+    shape = [number of
input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[''] + eot_token = _tokenizer.encoder[''] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] + for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +class HFTokenizer: + """HuggingFace tokenizer wrapper.""" + + def __init__(self, tokenizer_name: str): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def save_pretrained(self, dest): + self.tokenizer.save_pretrained(dest) + + def __call__(self, + texts: Union[str, List[str]], + context_length: int = 77) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it + # more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + texts = [whitespace_clean(basic_clean(text)) for text in texts] + input_ids = self.tokenizer( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ).input_ids + return input_ids diff --git a/mmseg/utils/typing_utils.py b/mmseg/utils/typing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fba7d3b92bba8301171d2a0fffadfabfcd112976 --- /dev/null +++ b/mmseg/utils/typing_utils.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Collecting some commonly used type hint in mmflow.""" +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import torch +from mmengine.config import ConfigDict + +from mmseg.structures import SegDataSample + +# Type hint of config data +ConfigType = Union[ConfigDict, dict] +OptConfigType = Optional[ConfigType] +# Type hint of one or more config data +MultiConfig = Union[ConfigType, Sequence[ConfigType]] +OptMultiConfig = Optional[MultiConfig] + +SampleList = Sequence[SegDataSample] +OptSampleList = Optional[SampleList] + +# Type hint of Tensor +TensorDict = Dict[str, torch.Tensor] +TensorList = Sequence[torch.Tensor] + +ForwardResults = Union[Dict[str, torch.Tensor], List[SegDataSample], + Tuple[torch.Tensor], torch.Tensor] diff --git a/mmseg/version.py b/mmseg/version.py new file mode 100644 index 0000000000000000000000000000000000000000..b76bb4580ddfa0ba0ba13fa4896c49bac9cef65a --- /dev/null +++ b/mmseg/version.py @@ -0,0 +1,18 @@ +# Copyright (c) Open-MMLab. All rights reserved. + +__version__ = '1.2.2' + + +def parse_version_info(version_str): + version_info = [] + for x in version_str.split('.'): + if x.isdigit(): + version_info.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + version_info.append(int(patch_version[0])) + version_info.append(f'rc{patch_version[1]}') + return tuple(version_info) + + +version_info = parse_version_info(__version__) diff --git a/mmseg/visualization/__init__.py b/mmseg/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8cbb211e5243aafb4ab3d91f6a6f7ce0735b13a9 --- /dev/null +++ b/mmseg/visualization/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
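# Editor's sketch (not part of the commit): driving the CLIP tokenizer from
# mmseg/utils/tokenizer.py above. `tokenize` brackets each text with the
# <start_of_text>/<end_of_text> ids and zero-pads (or truncates) every row to
# `context_length`; it assumes the bundled bpe_simple_vocab_16e6.txt.gz sits
# next to the module.
from mmseg.utils.tokenizer import SimpleTokenizer, tokenize

ids = tokenize(['a photo of a building', 'an aerial image'])
print(ids.shape)  # torch.Size([2, 77]); trailing zeros are padding

tok = SimpleTokenizer()
print(tok.decode(tok.encode('a photo of a building')))  # round-trips the text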
+from .local_visualizer import SegLocalVisualizer + +__all__ = ['SegLocalVisualizer'] diff --git a/mmseg/visualization/__pycache__/__init__.cpython-311.pyc b/mmseg/visualization/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1461966940c43fa78d3450f5884f4237650b74b9 Binary files /dev/null and b/mmseg/visualization/__pycache__/__init__.cpython-311.pyc differ diff --git a/mmseg/visualization/__pycache__/local_visualizer.cpython-311.pyc b/mmseg/visualization/__pycache__/local_visualizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa0674e245d34d23b4971c9d4ec1c06b6ff1038c Binary files /dev/null and b/mmseg/visualization/__pycache__/local_visualizer.cpython-311.pyc differ diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3d652c7bbe9d93ca481fb7a7ed4bb976eec80d --- /dev/null +++ b/mmseg/visualization/local_visualizer.py @@ -0,0 +1,349 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional + +import cv2 +import mmcv +import numpy as np +import torch +from mmengine.dist import master_only +from mmengine.structures import PixelData +from mmengine.visualization import Visualizer + +from mmseg.registry import VISUALIZERS +from mmseg.structures import SegDataSample +from mmseg.utils import get_classes, get_palette + + +@VISUALIZERS.register_module() +class SegLocalVisualizer(Visualizer): + """Local Visualizer. + + Args: + name (str): Name of the instance. Defaults to 'visualizer'. + image (np.ndarray, optional): the original image to draw. The format + should be RGB. Defaults to None. + vis_backends (list, optional): Visual backend config list. + Defaults to None. + save_dir (str, optional): Save file dir for all storage backends. + If it is None, the backend storage will not save any data. + classes (list, optional): Input classes for result rendering, as the + prediction of segmentation model is a segment map with label + indices, `classes` is a list which includes items responding to the + label indices. If classes is not defined, visualizer will take + `cityscapes` classes by default. Defaults to None. + palette (list, optional): Input palette for result rendering, which is + a list of color palette responding to the classes. Defaults to None. + dataset_name (str, optional): Dataset name or alias. The + visualizer will use the meta information of the dataset i.e. classes + and palette, but the `classes` and `palette` have higher priority. + Defaults to None. + alpha (int, float): The transparency of segmentation mask. + Defaults to 0.8. + + Examples: + >>> import numpy as np + >>> import torch + >>> from mmengine.structures import PixelData + >>> from mmseg.structures import SegDataSample + >>> from mmseg.visualization import SegLocalVisualizer + + >>> seg_local_visualizer = SegLocalVisualizer() + >>> image = np.random.randint(0, 256, + ... size=(10, 12, 3)).astype('uint8') + >>> gt_sem_seg_data = dict(data=torch.randint(0, 2, (1, 10, 12))) + >>> gt_sem_seg = PixelData(**gt_sem_seg_data) + >>> gt_seg_data_sample = SegDataSample() + >>> gt_seg_data_sample.gt_sem_seg = gt_sem_seg + >>> seg_local_visualizer.dataset_meta = dict( + >>> classes=('background', 'foreground'), + >>> palette=[[120, 120, 120], [6, 230, 230]]) + >>> seg_local_visualizer.add_datasample('visualizer_example', + ...
image, gt_seg_data_sample) + >>> seg_local_visualizer.add_datasample( + ... 'visualizer_example', image, + ... gt_seg_data_sample, show=True) + """ # noqa + + def __init__(self, + name: str = 'visualizer', + image: Optional[np.ndarray] = None, + vis_backends: Optional[Dict] = None, + save_dir: Optional[str] = None, + classes: Optional[List] = None, + palette: Optional[List] = None, + dataset_name: Optional[str] = None, + alpha: float = 0.8, + **kwargs): + super().__init__(name, image, vis_backends, save_dir, **kwargs) + self.alpha: float = alpha + # keep the argument order in sync with set_dataset_meta's signature + self.set_dataset_meta(classes, palette, dataset_name) + + def _get_center_loc(self, mask: np.ndarray) -> np.ndarray: + """Get semantic seg center coordinate. + + Args: + mask (np.ndarray): Binary mask taken from sem_seg. + """ + loc = np.argwhere(mask == 1) + + loc_sort = np.array( + sorted(loc.tolist(), key=lambda row: (row[0], row[1]))) + y_list = loc_sort[:, 0] + unique, indices, counts = np.unique( + y_list, return_index=True, return_counts=True) + y_loc = unique[counts.argmax()] + y_most_freq_loc = loc[loc_sort[:, 0] == y_loc] + center_num = len(y_most_freq_loc) // 2 + x = y_most_freq_loc[center_num][1] + y = y_most_freq_loc[center_num][0] + return np.array([x, y]) + + def _draw_sem_seg(self, + image: np.ndarray, + sem_seg: PixelData, + classes: Optional[List], + palette: Optional[List], + with_labels: Optional[bool] = True) -> np.ndarray: + """Draw semantic seg of GT or prediction. + + Args: + image (np.ndarray): The image to draw. + sem_seg (:obj:`PixelData`): Data structure for pixel-level + annotations or predictions. + classes (list, optional): Input classes for result rendering, as + the prediction of segmentation model is a segment map with + label indices, `classes` is a list which includes items + responding to the label indices. If classes is not defined, + visualizer will take `cityscapes` classes by default. + Defaults to None. + palette (list, optional): Input palette for result rendering, which + is a list of color palette responding to the classes. + Defaults to None. + with_labels (bool, optional): Add semantic labels in visualization + result. Defaults to True. + + Returns: + np.ndarray: the drawn image whose channel order is RGB.
+ """ + num_classes = len(classes) + + sem_seg = sem_seg.cpu().data + ids = np.unique(sem_seg)[::-1] + legal_indices = ids < num_classes + ids = ids[legal_indices] + labels = np.array(ids, dtype=np.int64) + + colors = [palette[label] for label in labels] + + mask = np.zeros_like(image, dtype=np.uint8) + for label, color in zip(labels, colors): + mask[sem_seg[0] == label, :] = color + + if with_labels: + font = cv2.FONT_HERSHEY_SIMPLEX + # (0,1] to change the size of the text relative to the image + scale = 0.05 + fontScale = min(image.shape[0], image.shape[1]) / (25 / scale) + fontColor = (255, 255, 255) + if image.shape[0] < 300 or image.shape[1] < 300: + thickness = 1 + rectangleThickness = 1 + else: + thickness = 2 + rectangleThickness = 2 + lineType = 2 + + if isinstance(sem_seg[0], torch.Tensor): + masks = sem_seg[0].numpy() == labels[:, None, None] + else: + masks = sem_seg[0] == labels[:, None, None] + masks = masks.astype(np.uint8) + for mask_num in range(len(labels)): + classes_id = labels[mask_num] + classes_color = colors[mask_num] + loc = self._get_center_loc(masks[mask_num]) + text = classes[classes_id] + (label_width, label_height), baseline = cv2.getTextSize( + text, font, fontScale, thickness) + mask = cv2.rectangle(mask, loc, + (loc[0] + label_width + baseline, + loc[1] + label_height + baseline), + classes_color, -1) + mask = cv2.rectangle(mask, loc, + (loc[0] + label_width + baseline, + loc[1] + label_height + baseline), + (0, 0, 0), rectangleThickness) + mask = cv2.putText(mask, text, (loc[0], loc[1] + label_height), + font, fontScale, fontColor, thickness, + lineType) + color_seg = (image * (1 - self.alpha) + mask * self.alpha).astype( + np.uint8) + self.set_image(color_seg) + return color_seg + + def _draw_depth_map(self, image: np.ndarray, + depth_map: PixelData) -> np.ndarray: + """Draws a depth map on a given image. + + This function takes an image and a depth map as input, + renders the depth map, and concatenates it with the original image. + Finally, it updates the internal image state of the visualizer with + the concatenated result. + + Args: + image (np.ndarray): The original image where the depth map will + be drawn. The array should be in the format HxWx3 where H is + the height, W is the width. + + depth_map (PixelData): Depth map to be drawn. The depth map + should be in the form of a PixelData object. It will be + converted to a torch tensor if it is a numpy array. + + Returns: + np.ndarray: The concatenated image with the depth map drawn. + + Example: + >>> depth_map_data = PixelData(data=torch.rand(1, 10, 10)) + >>> image = np.random.randint(0, 256, + >>> size=(10, 10, 3)).astype('uint8') + >>> visualizer = SegLocalVisualizer() + >>> visualizer._draw_depth_map(image, depth_map_data) + """ + depth_map = depth_map.cpu().data + if isinstance(depth_map, np.ndarray): + depth_map = torch.from_numpy(depth_map) + if depth_map.ndim == 2: + depth_map = depth_map[None] + + depth_map = self.draw_featmap(depth_map, resize_shape=image.shape[:2]) + out_image = np.concatenate((image, depth_map), axis=0) + self.set_image(out_image) + return out_image + + def set_dataset_meta(self, + classes: Optional[List] = None, + palette: Optional[List] = None, + dataset_name: Optional[str] = None) -> None: + """Set meta information to visualizer. + + Args: + classes (list, optional): Input classes for result rendering, as + the prediction of segmentation model is a segment map with + label indices, `classes` is a list which includes items + responding to the label indices. 
If classes is not defined, + visualizer will take `cityscapes` classes by default. + Defaults to None. + palette (list, optional): Input palette for result rendering, which + is a list of color palette responding to the classes. + Defaults to None. + dataset_name (str, optional): Dataset name or alias. The + visualizer will use the meta information of the dataset i.e. + classes and palette, but the `classes` and `palette` have + higher priority. Defaults to None. + """ # noqa + # Set default value. When calling + # `SegLocalVisualizer().dataset_meta=xxx`, + # it will override the default value. + if dataset_name is None: + dataset_name = 'cityscapes' + classes = classes if classes else get_classes(dataset_name) + palette = palette if palette else get_palette(dataset_name) + assert len(classes) == len( + palette), 'The length of classes should be equal to palette' + self.dataset_meta: dict = {'classes': classes, 'palette': palette} + + @master_only + def add_datasample( + self, + name: str, + image: np.ndarray, + data_sample: Optional[SegDataSample] = None, + draw_gt: bool = True, + draw_pred: bool = True, + show: bool = False, + wait_time: float = 0, + # TODO: Supported in mmengine's Visualizer. + out_file: Optional[str] = None, + step: int = 0, + with_labels: Optional[bool] = True) -> None: + """Draw datasample and save to all backends. + + - If GT and prediction are plotted at the same time, they are + displayed in a stitched image where the left image is the + ground truth and the right image is the prediction. + - If ``show`` is True, all storage backends are ignored, and + the images will be displayed in a local window. + - If ``out_file`` is specified, the drawn image will be + saved to ``out_file``. It is usually used when the display + is not available. + + Args: + name (str): The image identifier. + image (np.ndarray): The image to draw. + data_sample (:obj:`SegDataSample`, optional): SegDataSample + carrying the GT and/or prediction results to draw. + Defaults to None. + draw_gt (bool): Whether to draw GT SegDataSample. Defaults to True. + draw_pred (bool): Whether to draw Prediction SegDataSample. + Defaults to True. + show (bool): Whether to display the drawn image. Defaults to False. + wait_time (float): The interval of show (s). Defaults to 0. + out_file (str): Path to output file. Defaults to None. + step (int): Global step value to record. Defaults to 0. + with_labels (bool, optional): Add semantic labels in visualization + result. Defaults to True. + """ + classes = self.dataset_meta.get('classes', None) + palette = self.dataset_meta.get('palette', None) + + gt_img_data = None + pred_img_data = None + + if draw_gt and data_sample is not None: + if 'gt_sem_seg' in data_sample: + assert classes is not None, 'class information is ' \ + 'not provided when ' \ + 'visualizing semantic ' \ + 'segmentation results.' + gt_img_data = self._draw_sem_seg(image, data_sample.gt_sem_seg, + classes, palette, with_labels) + + if 'gt_depth_map' in data_sample: + gt_img_data = gt_img_data if gt_img_data is not None else image + gt_img_data = self._draw_depth_map(gt_img_data, + data_sample.gt_depth_map) + + if draw_pred and data_sample is not None: + + if 'pred_sem_seg' in data_sample: + + assert classes is not None, 'class information is ' \ + 'not provided when ' \ + 'visualizing semantic ' \ + 'segmentation results.'
+ pred_img_data = self._draw_sem_seg(image, + data_sample.pred_sem_seg, + classes, palette, + with_labels) + + if 'pred_depth_map' in data_sample: + pred_img_data = pred_img_data if pred_img_data is not None \ + else image + pred_img_data = self._draw_depth_map( + pred_img_data, data_sample.pred_depth_map) + + if gt_img_data is not None and pred_img_data is not None: + drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1) + elif gt_img_data is not None: + drawn_img = gt_img_data + else: + drawn_img = pred_img_data + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + mmcv.imwrite(mmcv.rgb2bgr(drawn_img), out_file) + else: + self.add_image(name, drawn_img, step) diff --git a/opencd/.DS_Store b/opencd/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..626cdcea2d88f4777aa5a92bd043faf9025df5fa Binary files /dev/null and b/opencd/.DS_Store differ diff --git a/opencd/__init__.py b/opencd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..17cd20ada50ad7743bbc5c3577b7c4e6de220885 --- /dev/null +++ b/opencd/__init__.py @@ -0,0 +1,49 @@ +# Copyright (c) Open-CD. All rights reserved. +import mmcv +import mmdet +import mmengine +from mmengine.utils import digit_version + +import mmseg +from .version import __version__, version_info + +mmcv_minimum_version = '2.0.0rc4' +mmcv_maximum_version = '2.2.0' +mmcv_version = digit_version(mmcv.__version__) + +mmengine_minimum_version = '0.6.0' +mmengine_maximum_version = '1.0.0' +mmengine_version = digit_version(mmengine.__version__) + +mmseg_minimum_version = '1.0.0rc6' +mmseg_maximum_version = '1.3.0' +mmseg_version = digit_version(mmseg.__version__) + +mmdet_minimum_version = '3.0.0rc6' +mmdet_maximum_version = '4.0.0' +mmdet_version = digit_version(mmdet.__version__) + +assert (mmcv_version >= digit_version(mmcv_minimum_version) + and mmcv_version < digit_version(mmcv_maximum_version)), \ + f'MMCV=={mmcv.__version__} is used but incompatible. ' \ + f'Please install mmcv>={mmcv_minimum_version}, <{mmcv_maximum_version}.' + +assert (mmengine_version >= digit_version(mmengine_minimum_version) + and mmengine_version < digit_version(mmengine_maximum_version)), \ + f'MMEngine=={mmengine.__version__} is used but incompatible. ' \ + f'Please install mmengine>={mmengine_minimum_version}, ' \ + f'<{mmengine_maximum_version}.' + +assert (mmseg_version >= digit_version(mmseg_minimum_version) + and mmseg_version < digit_version(mmseg_maximum_version)), \ + f'MMSegmentation=={mmseg.__version__} is used but incompatible. ' \ + f'Please install mmseg>={mmseg_minimum_version}, ' \ + f'<{mmseg_maximum_version}.' + +assert (mmdet_version >= digit_version(mmdet_minimum_version) + and mmdet_version < digit_version(mmdet_maximum_version)), \ + f'MMDetection=={mmdet.__version__} is used but incompatible. ' \ + f'Please install mmdet>={mmdet_minimum_version}, ' \ + f'<{mmdet_maximum_version}.' 
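# Editor's sketch (not part of the commit): the guards above use mmengine's
# digit_version, which converts a version string into a comparable tuple and
# orders release candidates before the corresponding final release.
from mmengine.utils import digit_version

assert digit_version('2.0.0rc4') < digit_version('2.0.0') < digit_version('2.2.0')
# e.g. mmcv 2.1.0 falls inside the half-open window [2.0.0rc4, 2.2.0):
assert digit_version('2.0.0rc4') <= digit_version('2.1.0') < digit_version('2.2.0')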
+ +__all__ = ['__version__', 'version_info', 'digit_version'] diff --git a/opencd/__pycache__/__init__.cpython-311.pyc b/opencd/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aea718eb777ea847e715b1807185db7ccfc57f06 Binary files /dev/null and b/opencd/__pycache__/__init__.cpython-311.pyc differ diff --git a/opencd/__pycache__/registry.cpython-311.pyc b/opencd/__pycache__/registry.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68891d3ba989d2863b726240cd4afff709de06c4 Binary files /dev/null and b/opencd/__pycache__/registry.cpython-311.pyc differ diff --git a/opencd/__pycache__/version.cpython-311.pyc b/opencd/__pycache__/version.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e87450d9f5e7c8b4bcad98f4f6c72356c5cc4f4 Binary files /dev/null and b/opencd/__pycache__/version.cpython-311.pyc differ diff --git a/opencd/apis/__init__.py b/opencd/apis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bebfc72c8512990733d0341ab20330a3dac71cca --- /dev/null +++ b/opencd/apis/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Open-CD. All rights reserved. +from .opencd_inferencer import OpenCDInferencer + +__all__ = ['OpenCDInferencer'] \ No newline at end of file diff --git a/opencd/apis/__pycache__/__init__.cpython-311.pyc b/opencd/apis/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8656c3ff9dca1b75769452119af06d8dab307b33 Binary files /dev/null and b/opencd/apis/__pycache__/__init__.cpython-311.pyc differ diff --git a/opencd/apis/__pycache__/opencd_inferencer.cpython-311.pyc b/opencd/apis/__pycache__/opencd_inferencer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90460b8993680c5333108b17f46b01457c8e07ca Binary files /dev/null and b/opencd/apis/__pycache__/opencd_inferencer.cpython-311.pyc differ diff --git a/opencd/apis/opencd_inferencer.py b/opencd/apis/opencd_inferencer.py new file mode 100644 index 0000000000000000000000000000000000000000..63de7e3db06841f8662b9f442eef06c21f6f445c --- /dev/null +++ b/opencd/apis/opencd_inferencer.py @@ -0,0 +1,168 @@ +# Copyright (c) Open-CD. All rights reserved. +import os.path as osp +from typing import List, Optional, Union + +import mmcv +import mmengine +import numpy as np + +from mmcv.transforms import Compose + +from mmseg.utils import ConfigType +from mmseg.apis import MMSegInferencer + +class OpenCDInferencer(MMSegInferencer): + """Change Detection inferencer, provides inference and visualization + interfaces. Note: MMEngine >= 0.5.0 is required. + + Args: + classes (list, optional): Input classes for result rendering, as the + prediction of segmentation model is a segment map with label + indices, `classes` is a list which includes items responding to the + label indices. If classes is not defined, visualizer will take + `cityscapes` classes by default. Defaults to None. + palette (list, optional): Input palette for result rendering, which is + a list of color palette responding to the classes. If palette is + not defined, visualizer will take `cityscapes` palette by default. + Defaults to None. + dataset_name (str, optional): Dataset name or alias. The + visualizer will use the meta information of the dataset i.e. classes + and palette, but the `classes` and `palette` have higher priority. + Defaults to None. + scope (str, optional): The scope of the model. Defaults to 'opencd'.
+ """ # noqa + + def __init__(self, + classes: Optional[Union[str, List]] = None, + palette: Optional[Union[str, List]] = None, + dataset_name: Optional[str] = None, + scope: Optional[str] = 'opencd', + **kwargs) -> None: + super().__init__(scope=scope, **kwargs) + + classes = classes if classes else self.model.dataset_meta.classes + palette = palette if palette else self.model.dataset_meta.palette + self.visualizer.set_dataset_meta(classes, palette, dataset_name) + + def _inputs_to_list(self, inputs: Union[str, np.ndarray]) -> list: + """Preprocess the inputs to a list. + + Preprocess inputs to a list according to its type: + + - list or tuple: return inputs + - str: + - Directory path: return all files in the directory + - other cases: return a list containing the string. The string + could be a path to file, a url or other types of string according + to the task. + + Args: + inputs (InputsType): Inputs for the inferencer. + + Returns: + list: List of input for the :meth:`preprocess`. + """ + return list(inputs) + + def visualize(self, + inputs: list, + preds: List[dict], + return_vis: bool = False, + show: bool = False, + wait_time: int = 0, + img_out_dir: str = '', + opacity: float = 1.0) -> List[np.ndarray]: + """Visualize predictions. + + Args: + inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`. + preds (Any): Predictions of the model. + return_vis (bool): Whether to return the visualization results. + Defaults to False. + show (bool): Whether to display the image in a popup window. + Defaults to False. + wait_time (float): The interval of show (s). Defaults to 0. + img_out_dir (str): Output directory of rendering prediction i.e. + color segmentation mask. Defaults to ''. + opacity (int, float): The transparency of segmentation mask. + Defaults to 1.0. + + Returns: + List[np.ndarray]: Visualization results. + """ + if not show and img_out_dir == '' and not return_vis: + return None + if self.visualizer is None: + raise ValueError('Visualization needs the "visualizer" term ' + 'defined in the config, but got None.') + + self.visualizer.alpha = opacity + + results = [] + + for single_inputs, pred in zip(inputs, preds): + img_from_to = [] + for single_input in single_inputs: + if isinstance(single_input, str): + img_bytes = mmengine.fileio.get(single_input) + img = mmcv.imfrombytes(img_bytes) + img = img[:, :, ::-1] + img_name = osp.basename(single_input) + elif isinstance(single_input, np.ndarray): + img = single_input.copy() + img_num = str(self.num_visualized_imgs).zfill(8) + '_vis' + img_name = f'{img_num}.jpg' + else: + raise ValueError('Unsupported input type: ' + f'{type(single_input)}') + img_shape = img.shape + img_from_to.append(img) + + out_file = osp.join(img_out_dir, img_name) if img_out_dir != ''\ + else None + + img_zero_board = np.zeros(img_shape) + self.visualizer.add_datasample( + img_name, + img_zero_board, + img_from_to, + pred, + show=show, + wait_time=wait_time, + draw_gt=False, + draw_pred=True, + out_file=out_file) + if return_vis: + results.append(self.visualizer.get_image()) + self.num_visualized_imgs += 1 + + return results if return_vis else None + + def _init_pipeline(self, cfg: ConfigType) -> Compose: + """Initialize the test pipeline. + + Return a pipeline to handle various input data, such as ``str``, + ``np.ndarray``. It is an abstract method in BaseInferencer, and should + be implemented in subclasses. + + The returned pipeline will be used to process a single data. + It will be used in :meth:`preprocess` like this: + + .. code-block:: python + def preprocess(self, inputs, batch_size, **kwargs): + ...
+ dataset = map(self.pipeline, dataset) + ... + """ + pipeline_cfg = cfg.test_dataloader.dataset.pipeline + # Loading annotations is also not applicable + for transform in ('MultiImgLoadAnnotations', 'MultiImgLoadDepthAnnotation'): + idx = self._get_transform_idx(pipeline_cfg, transform) + if idx != -1: + del pipeline_cfg[idx] + + load_img_idx = self._get_transform_idx(pipeline_cfg, + 'MultiImgLoadImageFromFile') + if load_img_idx == -1: + raise ValueError( + 'MultiImgLoadImageFromFile is not found in the test pipeline') + pipeline_cfg[load_img_idx]['type'] = 'MultiImgLoadInferencerLoader' + return Compose(pipeline_cfg) diff --git a/opencd/datasets/.DS_Store b/opencd/datasets/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..949480cd5bf24688e41226857787ebac26df9c19 Binary files /dev/null and b/opencd/datasets/.DS_Store differ diff --git a/opencd/datasets/__init__.py b/opencd/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..95f3bd7322072b7e68588433b945a8731c48bae9 --- /dev/null +++ b/opencd/datasets/__init__.py @@ -0,0 +1,16 @@ +from .bandon import BANDON_Dataset +from .basecddataset import _BaseCDDataset +from .basescddataset import BaseSCDDataset +from .clcd import CLCD_Dataset +from .dsifn import DSIFN_Dataset +from .landsat import Landsat_Dataset +from .levir_cd import LEVIR_CD_Dataset +from .rsipac_cd import RSIPAC_CD_Dataset +from .s2looking import S2Looking_Dataset +from .second import SECOND_Dataset +from .svcd import SVCD_Dataset +from .whu_cd import WHU_CD_Dataset + +__all__ = ['_BaseCDDataset', 'BaseSCDDataset', 'LEVIR_CD_Dataset', 'S2Looking_Dataset', + 'SVCD_Dataset', 'RSIPAC_CD_Dataset', 'CLCD_Dataset', 'DSIFN_Dataset', + 'SECOND_Dataset', 'Landsat_Dataset', 'BANDON_Dataset', 'WHU_CD_Dataset'] diff --git a/opencd/datasets/__pycache__/__init__.cpython-311.pyc b/opencd/datasets/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cedd931981102e2931af72fe78593b40d8d4277c Binary files /dev/null and b/opencd/datasets/__pycache__/__init__.cpython-311.pyc differ diff --git a/opencd/datasets/__pycache__/bandon.cpython-311.pyc b/opencd/datasets/__pycache__/bandon.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72b61d79e9c868b4d6433cde0635d09dffcc6fc2 Binary files /dev/null and b/opencd/datasets/__pycache__/bandon.cpython-311.pyc differ diff --git a/opencd/datasets/__pycache__/basecddataset.cpython-311.pyc b/opencd/datasets/__pycache__/basecddataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bde3338dc4894c07cccae36c2506bbd28e3c35ae Binary files /dev/null and b/opencd/datasets/__pycache__/basecddataset.cpython-311.pyc differ diff --git a/opencd/datasets/__pycache__/basescddataset.cpython-311.pyc b/opencd/datasets/__pycache__/basescddataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9946670b955590bd1907e7ceed1af8a2a81979fa Binary files /dev/null and b/opencd/datasets/__pycache__/basescddataset.cpython-311.pyc differ diff --git a/opencd/datasets/__pycache__/clcd.cpython-311.pyc b/opencd/datasets/__pycache__/clcd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfd083c240e0fe9713ca2b3ff5b8f7f164d088ca Binary files /dev/null and b/opencd/datasets/__pycache__/clcd.cpython-311.pyc differ diff --git a/opencd/datasets/__pycache__/dsifn.cpython-311.pyc b/opencd/datasets/__pycache__/dsifn.cpython-311.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..0c0a4a4eb354199199c1364d60e06391bfb658f2 Binary files /dev/null and b/opencd/datasets/__pycache__/dsifn.cpython-311.pyc differ diff --git a/opencd/datasets/__pycache__/landsat.cpython-311.pyc b/opencd/datasets/__pycache__/landsat.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1b34153aff52849e00dab4473698acd8e62def7 Binary files /dev/null and b/opencd/datasets/__pycache__/landsat.cpython-311.pyc differ diff --git a/opencd/datasets/__pycache__/levir_cd.cpython-311.pyc b/opencd/datasets/__pycache__/levir_cd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88d6b2e2446d4a9edffc0e1c187ab5ddf3f7a570 Binary files /dev/null and b/opencd/datasets/__pycache__/levir_cd.cpython-311.pyc differ diff --git a/opencd/datasets/__pycache__/rsipac_cd.cpython-311.pyc b/opencd/datasets/__pycache__/rsipac_cd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d8dc154b5af7f9114f33450bc3e7c797ebf1bcd Binary files /dev/null and b/opencd/datasets/__pycache__/rsipac_cd.cpython-311.pyc differ diff --git a/opencd/datasets/__pycache__/s2looking.cpython-311.pyc b/opencd/datasets/__pycache__/s2looking.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..911953340e950ac651f0f046b39e8a4d730a4458 Binary files /dev/null and b/opencd/datasets/__pycache__/s2looking.cpython-311.pyc differ diff --git a/opencd/datasets/__pycache__/second.cpython-311.pyc b/opencd/datasets/__pycache__/second.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f71bf9b8c0afdbda6f8c5319ec166b8e0aae9d88 Binary files /dev/null and b/opencd/datasets/__pycache__/second.cpython-311.pyc differ diff --git a/opencd/datasets/__pycache__/svcd.cpython-311.pyc b/opencd/datasets/__pycache__/svcd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d73187e874d9c5d0808a812705ded6fefd87b90e Binary files /dev/null and b/opencd/datasets/__pycache__/svcd.cpython-311.pyc differ diff --git a/opencd/datasets/__pycache__/whu_cd.cpython-311.pyc b/opencd/datasets/__pycache__/whu_cd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f966e2b8635deb33125df52d7b3e1c3851e1d37 Binary files /dev/null and b/opencd/datasets/__pycache__/whu_cd.cpython-311.pyc differ diff --git a/opencd/datasets/bandon.py b/opencd/datasets/bandon.py new file mode 100644 index 0000000000000000000000000000000000000000..637b619f798bf482a5a1da227d71f3a6009a8919 --- /dev/null +++ b/opencd/datasets/bandon.py @@ -0,0 +1,30 @@ +# Copyright (c) Open-CD. All rights reserved. 
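# Editor's sketch (not part of the commit): every dataset module below
# decorates its class with @DATASETS.register_module(), so configs and code
# can resolve it by type name once the package has been imported.
import opencd.datasets  # noqa: F401  # importing the package registers the classes
from opencd.registry import DATASETS

assert DATASETS.get('BANDON_Dataset').__name__ == 'BANDON_Dataset'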
+from opencd.registry import DATASETS +from .basescddataset import BaseSCDDataset + + +@DATASETS.register_module() +class BANDON_Dataset(BaseSCDDataset): + """BANDON dataset + + Note: Use `tools/generate_txt/generate_bandon_txt.py` + to generate .txt files for the BANDON dataset. + + """ + METAINFO = dict( + classes=('unchanged', 'changed'), + palette=[[0, 0, 0], [255, 255, 255]], + semantic_classes=('background', 'roofs', 'facades'), + semantic_palette=[[0, 0, 0], [244, 177, 131], [143, 170, 220]]) + + def __init__(self, + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_semantic_zero_label=False, + **kwargs) -> None: + + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_semantic_zero_label=reduce_semantic_zero_label, + **kwargs) \ No newline at end of file diff --git a/opencd/datasets/basecddataset.py b/opencd/datasets/basecddataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2d93f2bb98b663e70e4d3ae98f5b25a83a32514d --- /dev/null +++ b/opencd/datasets/basecddataset.py @@ -0,0 +1,291 @@ +# Copyright (c) Open-CD. All rights reserved. +import copy +import os.path as osp +from typing import Callable, Dict, List, Optional, Sequence, Union + +import mmengine +import mmengine.fileio as fileio +import numpy as np +from mmengine.dataset import BaseDataset, Compose + +from mmseg.registry import DATASETS + + +@DATASETS.register_module() +class _BaseCDDataset(BaseDataset): + """Custom datasets for change detection. An example of file structure + is as follows. + .. code-block:: none + ├── data + │ ├── my_dataset + │ │ ├── train + │ │ │ ├── img_path_from/xxx{img_suffix} + │ │ │ ├── img_path_to/xxx{img_suffix} + │ │ │ ├── seg_map_path/xxx{img_suffix} + │ │ ├── val + │ │ │ ├── img_path_from/xxx{seg_map_suffix} + │ │ │ ├── img_path_to/xxx{seg_map_suffix} + │ │ │ ├── seg_map_path/xxx{seg_map_suffix} + + The img/gt_semantic_seg pairs of CustomDataset should be identical + except for the suffix. A valid img/gt_semantic_seg filename pair should be like + ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included + in the suffix). If split is given, then ``xxx`` is specified in the txt file. + Otherwise, all files in ``img_path_x/`` and ``seg_map_path`` will be loaded. + Please refer to ``docs/en/tutorials/new_dataset.md`` for more details. + + + Args: + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as + specify classes to load. Defaults to None. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to None. + data_prefix (dict, optional): Prefix for training data. Defaults to + dict(img_path=None, seg_map_path=None). + img_suffix (str): Suffix of images. Default: '.jpg' + seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' + format_seg_map (str): If `format_seg_map`='to_binary', the binary + change detection label will be formatted as 0 (<128) or 1 (>=128). + Default: None + filter_cfg (dict, optional): Config for filter data. Defaults to None. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Defaults to None which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. Defaults + to True. + pipeline (list, optional): Processing pipeline. Defaults to [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase. + Defaults to False. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip loading annotations to + save time by setting ``lazy_init=True``. Defaults to False. + max_refetch (int, optional): If ``Basedataset.prepare_data`` gets a + None img, the maximum number of extra cycles to get a valid + image. Defaults to 1000. + ignore_index (int): The label index to be ignored. Default: 255 + reduce_zero_label (bool): Whether to mark label zero as ignored. + Defaults to False. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.html + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + METAINFO: dict = dict() + + def __init__(self, + ann_file: str = '', + img_suffix='.jpg', + seg_map_suffix='.png', + format_seg_map=None, + metainfo: Optional[dict] = None, + data_root: Optional[str] = None, + data_prefix: dict = dict(img_path='', seg_map_path=''), + filter_cfg: Optional[dict] = None, + indices: Optional[Union[int, Sequence[int]]] = None, + serialize_data: bool = True, + pipeline: List[Union[dict, Callable]] = [], + test_mode: bool = False, + lazy_init: bool = False, + max_refetch: int = 1000, + ignore_index: int = 255, + reduce_zero_label: bool = False, + backend_args: Optional[dict] = None) -> None: + + self.img_suffix = img_suffix + self.seg_map_suffix = seg_map_suffix + self.format_seg_map = format_seg_map + self.ignore_index = ignore_index + self.reduce_zero_label = reduce_zero_label + self.backend_args = backend_args.copy() if backend_args else None + + self.data_root = data_root + self.data_prefix = copy.copy(data_prefix) + self.ann_file = ann_file + self.filter_cfg = copy.deepcopy(filter_cfg) + self._indices = indices + self.serialize_data = serialize_data + self.test_mode = test_mode + self.max_refetch = max_refetch + self.data_list: List[dict] = [] + self.data_bytes: np.ndarray + + # Set meta information. + self._metainfo = self._load_metainfo(copy.deepcopy(metainfo)) + + # Get label map for custom classes + new_classes = self._metainfo.get('classes', None) + self.label_map = self.get_label_map(new_classes) + self._metainfo.update( + dict( + label_map=self.label_map, + reduce_zero_label=self.reduce_zero_label)) + + # Update palette based on label map or generate palette + # if it is not defined + updated_palette = self._update_palette() + self._metainfo.update(dict(palette=updated_palette)) + + # Join paths. + if self.data_root is not None: + self._join_prefix() + + # Build pipeline. + self.pipeline = Compose(pipeline) + # Fully initialize the dataset. + if not lazy_init: + self.full_init() + + if test_mode: + assert self._metainfo.get('classes') is not None, \ + 'dataset metainfo `classes` should be specified when testing' + + @classmethod + def get_label_map(cls, + new_classes: Optional[Sequence] = None + ) -> Union[Dict, None]: + """Return the label mapping. + + The ``label_map`` is a dictionary, its keys are the old label ids and + its values are the new label ids, and is used for changing pixel + labels in load_annotations. If and only if old classes in cls.METAINFO + are not equal to new classes in self._metainfo and neither of them is + None, `label_map` is not None.
+ + Args: + new_classes (list, tuple, optional): The new classes name from + metainfo. Default to None. + + + Returns: + dict, optional: The mapping from old classes in cls.METAINFO to + new classes in self._metainfo + """ + old_classes = cls.METAINFO.get('classes', None) + if (new_classes is not None and old_classes is not None + and list(new_classes) != list(old_classes)): + + label_map = {} + if not set(new_classes).issubset(cls.METAINFO['classes']): + raise ValueError( + f'new classes {new_classes} is not a ' + f'subset of classes {old_classes} in METAINFO.') + for i, c in enumerate(old_classes): + if c not in new_classes: + label_map[i] = 255 + else: + label_map[i] = new_classes.index(c) + return label_map + else: + return None + + def _update_palette(self) -> list: + """Update palette after loading metainfo. + + If length of palette is equal to classes, just return the palette. + If palette is not defined, it will randomly generate a palette. + If classes is updated by customer, it will return the subset of + palette. + + Returns: + Sequence: Palette for current dataset. + """ + palette = self._metainfo.get('palette', []) + classes = self._metainfo.get('classes', []) + # palette does match classes + if len(palette) == len(classes): + return palette + + if len(palette) == 0: + # Get random state before set seed, and restore + # random state later. + # It will prevent loss of randomness, as the palette + # may be different in each iteration if not specified. + # See: https://github.com/open-mmlab/mmdetection/issues/5844 + state = np.random.get_state() + np.random.seed(42) + # random palette + new_palette = np.random.randint( + 0, 255, size=(len(classes), 3)).tolist() + np.random.set_state(state) + elif len(palette) >= len(classes) and self.label_map is not None: + new_palette = [] + # return subset of palette + for old_id, new_id in sorted( + self.label_map.items(), key=lambda x: x[1]): + if new_id != 255: + new_palette.append(palette[old_id]) + new_palette = type(palette)(new_palette) + else: + raise ValueError('palette does not match classes ' + f'as metainfo is {self._metainfo}.') + return new_palette + + def load_data_list(self) -> List[dict]: + """Load annotation from directory or annotation file. + + Returns: + list[dict]: All data info of dataset. 
+ """ + data_list = [] + img_dir_from = self.data_prefix.get('img_path_from', None) + img_dir_to = self.data_prefix.get('img_path_to', None) + ann_dir = self.data_prefix.get('seg_map_path', None) + + if osp.isfile(self.ann_file): + lines = mmengine.list_from_file( + self.ann_file, backend_args=self.backend_args) + for line in lines: + img_name = line.strip() + data_info = dict(img_path=\ + [osp.join(img_dir_from, img_name + self.img_suffix), \ + osp.join(img_dir_to, img_name + self.img_suffix)]) + if ann_dir is not None: + seg_map = img_name + self.seg_map_suffix + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['label_map'] = self.label_map + data_info['format_seg_map'] = self.format_seg_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + else: + file_list_from = fileio.list_dir_or_file( + dir_path=img_dir_from, + list_dir=False, + suffix=self.img_suffix, + recursive=True, + backend_args=self.backend_args) + file_list_to = fileio.list_dir_or_file( + dir_path=img_dir_to, + list_dir=False, + suffix=self.img_suffix, + recursive=True, + backend_args=self.backend_args) + + assert sorted(list(file_list_from)) == sorted(list(file_list_to)), \ + 'The images in `img_path_from` and `img_path_to` are not ' \ + 'one-to-one correspondence' + + for img in fileio.list_dir_or_file( + dir_path=img_dir_from, + list_dir=False, + suffix=self.img_suffix, + recursive=True, + backend_args=self.backend_args): + data_info = dict(img_path=\ + [osp.join(img_dir_from, img), \ + osp.join(img_dir_to, img)]) + if ann_dir is not None: + seg_map = img.replace(self.img_suffix, self.seg_map_suffix) + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['label_map'] = self.label_map + data_info['format_seg_map'] = self.format_seg_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + data_list = sorted(data_list, key=lambda x: x['img_path']) + return data_list diff --git a/opencd/datasets/basescddataset.py b/opencd/datasets/basescddataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a6bedc8dd8a49a727c2342f34c28f6e79c8951a5 --- /dev/null +++ b/opencd/datasets/basescddataset.py @@ -0,0 +1,206 @@ +# Copyright (c) Open-CD. All rights reserved. 
+import copy +import os.path as osp +from typing import Dict, List, Optional, Sequence, Union + +import mmengine +import mmengine.fileio as fileio +import numpy as np + +from mmseg.registry import DATASETS +from .basecddataset import _BaseCDDataset + + +@DATASETS.register_module() +class BaseSCDDataset(_BaseCDDataset): + def __init__(self, + lazy_init=False, + reduce_semantic_zero_label=False, + **kwargs): + super().__init__(lazy_init=True, **kwargs) + + self.reduce_semantic_zero_label = reduce_semantic_zero_label + + # Get label map for semantic custom classes + new_classes = self._metainfo.get('semantic_classes', None) + self.semantic_label_map = self.get_semantic_label_map(new_classes) + self._metainfo.update( + dict( + semantic_label_map=self.semantic_label_map, + reduce_semantic_zero_label=self.reduce_semantic_zero_label)) + + # Update palette based on label map or generate palette + # if it is not defined + updated_semantic_palette = self._update_semantic_palette() + self._metainfo.update(dict(semantic_palette=updated_semantic_palette)) + + if not lazy_init: + self.full_init() + + if self.test_mode: + assert self._metainfo.get('semantic_classes') is not None, \ + 'dataset metainfo `semantic_classes` should be specified when testing' + + @classmethod + def get_semantic_label_map(cls, + new_classes: Optional[Sequence] = None + ) -> Union[Dict, None]: + """Require semantic label mapping. + + The ``label_map`` is a dictionary, its keys are the old label ids and + its values are the new label ids, and is used for changing pixel + labels in load_annotations. If and only if old classes in cls.METAINFO + is not equal to new classes in self._metainfo and nether of them is not + None, `label_map` is not None. + + Args: + new_classes (list, tuple, optional): The new classes name from + metainfo. Default to None. + + + Returns: + dict, optional: The mapping from old classes in cls.METAINFO to + new classes in self._metainfo + """ + old_classes = cls.METAINFO.get('semantic_classes', None) + if (new_classes is not None and old_classes is not None + and list(new_classes) != list(old_classes)): + + label_map = {} + if not set(new_classes).issubset(cls.METAINFO['semantic_classes']): + raise ValueError( + f'new semantic_classes {new_classes} is not a ' + f'subset of semantic_classes {old_classes} in METAINFO.') + for i, c in enumerate(old_classes): + if c not in new_classes: + label_map[i] = 255 + else: + label_map[i] = new_classes.index(c) + return label_map + else: + return None + + def _update_semantic_palette(self) -> list: + """Update palette after loading metainfo. + + If length of palette is equal to classes, just return the palette. + If palette is not defined, it will randomly generate a palette. + If classes is updated by customer, it will return the subset of + palette. + + Returns: + Sequence: Palette for current dataset. + """ + palette = self._metainfo.get('semantic_palette', []) + classes = self._metainfo.get('semantic_classes', []) + # palette does match classes + if len(palette) == len(classes): + return palette + + if len(palette) == 0: + # Get random state before set seed, and restore + # random state later. + # It will prevent loss of randomness, as the palette + # may be different in each iteration if not specified. 
+ # See: https://github.com/open-mmlab/mmdetection/issues/5844 + state = np.random.get_state() + np.random.seed(42) + # random palette + new_palette = np.random.randint( + 0, 255, size=(len(classes), 3)).tolist() + np.random.set_state(state) + elif len(palette) >= len(classes) and self.semantic_label_map is not None: + new_palette = [] + # return subset of palette + for old_id, new_id in sorted( + self.semantic_label_map.items(), key=lambda x: x[1]): + if new_id != 255: + new_palette.append(palette[old_id]) + new_palette = type(palette)(new_palette) + else: + raise ValueError('palette does not match classes ' + f'as metainfo is {self._metainfo}.') + return new_palette + + + def load_data_list(self) -> List[dict]: + """Load annotation from directory or annotation file. + + Returns: + list[dict]: All data info of dataset. + """ + data_list = [] + img_dir_from = self.data_prefix.get('img_path_from', None) + img_dir_to = self.data_prefix.get('img_path_to', None) + ann_dir = self.data_prefix.get('seg_map_path', None) + ann_dir_from = self.data_prefix.get('seg_map_path_from', None) + ann_dir_to = self.data_prefix.get('seg_map_path_to', None) + + if osp.isfile(self.ann_file): + lines = mmengine.list_from_file( + self.ann_file, backend_args=self.backend_args) + for line in lines: + data_names = line.strip().split(' ') + # img_name: img1, img2, binary label, semantic_from label, \ + # semantic_to label + img_name_from, img_name_to, ann_name, ann_name_from, \ + ann_name_to = data_names + + data_info = dict(img_path=\ + [osp.join(img_dir_from, img_name_from + self.img_suffix), \ + osp.join(img_dir_to, img_name_to + self.img_suffix)]) + if ann_dir is not None: + seg_map = ann_name + self.seg_map_suffix + seg_map_from = ann_name_from + self.seg_map_suffix + seg_map_to = ann_name_to + self.seg_map_suffix + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['seg_map_path_from'] = osp.join(ann_dir_from, seg_map_from) + data_info['seg_map_path_to'] = osp.join(ann_dir_to, seg_map_to) + data_info['label_map'] = self.label_map + data_info['format_seg_map'] = self.format_seg_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['semantic_label_map'] = self.semantic_label_map + data_info['reduce_semantic_zero_label'] = self.reduce_semantic_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + else: + file_list_from = fileio.list_dir_or_file( + dir_path=img_dir_from, + list_dir=False, + suffix=self.img_suffix, + recursive=True, + backend_args=self.backend_args) + file_list_to = fileio.list_dir_or_file( + dir_path=img_dir_to, + list_dir=False, + suffix=self.img_suffix, + recursive=True, + backend_args=self.backend_args) + + assert sorted(list(file_list_from)) == sorted(list(file_list_to)), \ + 'The images in `img_path_from` and `img_path_to` are not ' \ + 'one-to-one correspondence' + + for img in fileio.list_dir_or_file( + dir_path=img_dir_from, + list_dir=False, + suffix=self.img_suffix, + recursive=True, + backend_args=self.backend_args): + data_info = dict(img_path=\ + [osp.join(img_dir_from, img), \ + osp.join(img_dir_to, img)]) + if ann_dir is not None: + seg_map = img.replace(self.img_suffix, self.seg_map_suffix) + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['seg_map_path_from'] = osp.join(ann_dir_from, seg_map) + data_info['seg_map_path_to'] = osp.join(ann_dir_to, seg_map) + data_info['label_map'] = self.label_map + data_info['format_seg_map'] = self.format_seg_map + data_info['reduce_zero_label'] = self.reduce_zero_label + 
data_info['semantic_label_map'] = self.semantic_label_map + data_info['reduce_semantic_zero_label'] = self.reduce_semantic_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + data_list = sorted(data_list, key=lambda x: x['img_path']) + return data_list diff --git a/opencd/datasets/clcd.py b/opencd/datasets/clcd.py new file mode 100644 index 0000000000000000000000000000000000000000..21a6e477d53a8ce27363bb72d7feace064afe0df --- /dev/null +++ b/opencd/datasets/clcd.py @@ -0,0 +1,22 @@ +# Copyright (c) Open-CD. All rights reserved. +from opencd.registry import DATASETS +from .basecddataset import _BaseCDDataset + + +@DATASETS.register_module() +class CLCD_Dataset(_BaseCDDataset): + """CLCD dataset""" + METAINFO = dict( + classes=('unchanged', 'changed'), + palette=[[0, 0, 0], [255, 255, 255]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='.png', + format_seg_map='to_binary', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + format_seg_map=format_seg_map, + **kwargs) diff --git a/opencd/datasets/dsifn.py b/opencd/datasets/dsifn.py new file mode 100644 index 0000000000000000000000000000000000000000..25fcd4094cc6cb0b71948e04dcd750c4d39db079 --- /dev/null +++ b/opencd/datasets/dsifn.py @@ -0,0 +1,16 @@ +# Copyright (c) Open-CD. All rights reserved. +from opencd.registry import DATASETS +from .basecddataset import _BaseCDDataset + + +@DATASETS.register_module() +class DSIFN_Dataset(_BaseCDDataset): + """DSIFN dataset""" + METAINFO = dict( + classes=('unchanged', 'changed'), + palette=[[0, 0, 0], [255, 255, 255]]) + + def __init__(self, + img_suffix='.jpg', + **kwargs) -> None: + super().__init__(img_suffix=img_suffix, **kwargs) diff --git a/opencd/datasets/landsat.py b/opencd/datasets/landsat.py new file mode 100644 index 0000000000000000000000000000000000000000..7a33051e3a1cc0abf5c22b99324b2a59fe93d258 --- /dev/null +++ b/opencd/datasets/landsat.py @@ -0,0 +1,27 @@ +# Copyright (c) Open-CD. All rights reserved. +from opencd.registry import DATASETS +from .basescddataset import BaseSCDDataset + + +@DATASETS.register_module() +class Landsat_Dataset(BaseSCDDataset): + """Landsat dataset""" + METAINFO = dict( + classes=('unchanged', 'changed'), + palette=[[0, 0, 0], [255, 255, 255]], + semantic_classes=('unchanged', 'farmland', 'desert', + 'building', 'water'), + semantic_palette=[[255, 255, 255], [128, 128, 128], [130, 87, 87], + [255, 0, 0], [0, 0, 255]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='.png', + reduce_semantic_zero_label=True, + **kwargs) -> None: + + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_semantic_zero_label=reduce_semantic_zero_label, + **kwargs) \ No newline at end of file diff --git a/opencd/datasets/levir_cd.py b/opencd/datasets/levir_cd.py new file mode 100644 index 0000000000000000000000000000000000000000..02ea035b87f84102fa774081a2815b5dda322df7 --- /dev/null +++ b/opencd/datasets/levir_cd.py @@ -0,0 +1,22 @@ +# Copyright (c) Open-CD. All rights reserved. 
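# Editor's sketch (not part of the commit): instantiating the LEVIR_CD_Dataset
# defined just below. The paths follow the layout in the _BaseCDDataset
# docstring and are assumptions about a local checkout; without an ann_file,
# img_path_from is walked recursively and img_path_to/seg_map_path must
# mirror its file names one-to-one.
from opencd.datasets import LEVIR_CD_Dataset

ds = LEVIR_CD_Dataset(
    data_root='data/LEVIR-CD',  # hypothetical local layout
    data_prefix=dict(img_path_from='train/A',
                     img_path_to='train/B',
                     seg_map_path='train/label'),
    pipeline=[])
# ds.get_data_info(0)['img_path'] is a [from, to] pair of bi-temporal images.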
+from opencd.registry import DATASETS +from .basecddataset import _BaseCDDataset + + +@DATASETS.register_module() +class LEVIR_CD_Dataset(_BaseCDDataset): + """LEVIR-CD dataset""" + METAINFO = dict( + classes=('unchanged', 'changed'), + palette=[[0, 0, 0], [255, 255, 255]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='.png', + format_seg_map='to_binary', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + format_seg_map=format_seg_map, + **kwargs) diff --git a/opencd/datasets/rsipac_cd.py b/opencd/datasets/rsipac_cd.py new file mode 100644 index 0000000000000000000000000000000000000000..c25e3073f10ddaf7715c196978ae1fdca9abfe31 --- /dev/null +++ b/opencd/datasets/rsipac_cd.py @@ -0,0 +1,20 @@ +# Copyright (c) Open-CD. All rights reserved. +from opencd.registry import DATASETS +from .basecddataset import _BaseCDDataset + + +@DATASETS.register_module() +class RSIPAC_CD_Dataset(_BaseCDDataset): + """RSIPAC_CD dataset""" + METAINFO = dict( + classes=('unchanged', 'changed'), + palette=[[0, 0, 0], [255, 255, 255]]) + + def __init__(self, + img_suffix='.tif', + seg_map_suffix='.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + **kwargs) diff --git a/opencd/datasets/s2looking.py b/opencd/datasets/s2looking.py new file mode 100644 index 0000000000000000000000000000000000000000..60918bcc86472416497f12aa1e4b613f29e01ebc --- /dev/null +++ b/opencd/datasets/s2looking.py @@ -0,0 +1,22 @@ +# Copyright (c) Open-CD. All rights reserved. +from opencd.registry import DATASETS +from .basecddataset import _BaseCDDataset + + +@DATASETS.register_module() +class S2Looking_Dataset(_BaseCDDataset): + """S2Looking dataset""" + METAINFO = dict( + classes=('unchanged', 'changed'), + palette=[[0, 0, 0], [255, 255, 255]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='.png', + format_seg_map='to_binary', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + format_seg_map=format_seg_map, + **kwargs) diff --git a/opencd/datasets/second.py b/opencd/datasets/second.py new file mode 100644 index 0000000000000000000000000000000000000000..86dc278d508d08b3eaaac77b3f2264d1e3b9f552 --- /dev/null +++ b/opencd/datasets/second.py @@ -0,0 +1,29 @@ +# Copyright (c) Open-CD. All rights reserved. +from opencd.registry import DATASETS +from .basescddataset import BaseSCDDataset + + +@DATASETS.register_module() +class SECOND_Dataset(BaseSCDDataset): + """SECOND dataset""" + METAINFO = dict( + classes=('unchanged', 'changed'), + palette=[[0, 0, 0], [255, 255, 255]], + semantic_classes=('unchanged', 'water', 'ground', + 'low vegetation', 'tree', 'building', + 'sports field'), + semantic_palette=[[255, 255, 255], [0, 0, 255], [128, 128, 128], + [0, 128, 0], [0, 255, 0], [128, 0, 0], + [255, 0, 0]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='.png', + reduce_semantic_zero_label=True, + **kwargs) -> None: + + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_semantic_zero_label=reduce_semantic_zero_label, + **kwargs) \ No newline at end of file diff --git a/opencd/datasets/svcd.py b/opencd/datasets/svcd.py new file mode 100644 index 0000000000000000000000000000000000000000..6cf5e4eafbb870e6ebcad7dacbc5816bceb217fc --- /dev/null +++ b/opencd/datasets/svcd.py @@ -0,0 +1,22 @@ +# Copyright (c) Open-CD. All rights reserved. 
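# Editor's sketch (not part of the commit): the effect of the
# format_seg_map='to_binary' flag set by several datasets here. Per the
# _BaseCDDataset docstring, label pixels are binarized at 128 when loaded:
import numpy as np

seg_map = np.array([[0, 127], [128, 255]], dtype=np.uint8)
binary = (seg_map >= 128).astype(np.uint8)
print(binary)  # [[0 0] [1 1]] -> 0 = unchanged, 1 = changed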
+from opencd.registry import DATASETS +from .basecddataset import _BaseCDDataset + + +@DATASETS.register_module() +class SVCD_Dataset(_BaseCDDataset): + """SVCD dataset""" + METAINFO = dict( + classes=('unchanged', 'changed'), + palette=[[0, 0, 0], [255, 255, 255]]) + + def __init__(self, + img_suffix='.jpg', + seg_map_suffix='.jpg', + format_seg_map='to_binary', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + format_seg_map=format_seg_map, + **kwargs) diff --git a/opencd/datasets/transforms/__init__.py b/opencd/datasets/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..55f91d102caa1d6764e5161789c05634d482e4a5 --- /dev/null +++ b/opencd/datasets/transforms/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) Open-CD. All rights reserved. +from .formatting import MultiImgPackSegInputs +from .loading import (MultiImgLoadAnnotations, MultiImgLoadImageFromFile, + MultiImgLoadInferencerLoader, + MultiImgLoadLoadImageFromNDArray) +# yapf: disable +from .transforms import (MultiImgAdjustGamma, MultiImgAlbu, MultiImgCLAHE, + MultiImgExchangeTime, MultiImgNormalize, MultiImgPad, + MultiImgPhotoMetricDistortion, MultiImgRandomCrop, + MultiImgRandomCutOut, MultiImgRandomFlip, + MultiImgRandomResize, MultiImgRandomRotate, + MultiImgRandomRotFlip, MultiImgRerange, + MultiImgResize, MultiImgResizeShortestEdge, + MultiImgResizeToMultiple, MultiImgRGB2Gray) + +# yapf: enable +__all__ = [ + 'MultiImgPackSegInputs', 'MultiImgLoadImageFromFile', 'MultiImgLoadAnnotations', + 'MultiImgLoadLoadImageFromNDArray', 'MultiImgLoadInferencerLoader', + 'MultiImgResizeToMultiple', 'MultiImgRerange', 'MultiImgCLAHE', 'MultiImgRandomCrop', + 'MultiImgRandomRotate', 'MultiImgRGB2Gray', 'MultiImgAdjustGamma', + 'MultiImgPhotoMetricDistortion', 'MultiImgRandomCutOut', 'MultiImgRandomRotFlip', + 'MultiImgResizeShortestEdge', 'MultiImgExchangeTime', 'MultiImgResize', + 'MultiImgRandomResize', 'MultiImgNormalize', 'MultiImgRandomFlip', 'MultiImgPad', + 'MultiImgAlbu' +] diff --git a/opencd/datasets/transforms/__pycache__/__init__.cpython-311.pyc b/opencd/datasets/transforms/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1800f2721810b3b18d3e325171c5bdc37ffd725 Binary files /dev/null and b/opencd/datasets/transforms/__pycache__/__init__.cpython-311.pyc differ diff --git a/opencd/datasets/transforms/__pycache__/formatting.cpython-311.pyc b/opencd/datasets/transforms/__pycache__/formatting.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d8bb65ab5b6bac050d8b7d017d6606bd1058d48 Binary files /dev/null and b/opencd/datasets/transforms/__pycache__/formatting.cpython-311.pyc differ diff --git a/opencd/datasets/transforms/__pycache__/loading.cpython-311.pyc b/opencd/datasets/transforms/__pycache__/loading.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82270b8ea0bdd5a1c2f7b7f6b6e424bef4b5ae16 Binary files /dev/null and b/opencd/datasets/transforms/__pycache__/loading.cpython-311.pyc differ diff --git a/opencd/datasets/transforms/__pycache__/transforms.cpython-311.pyc b/opencd/datasets/transforms/__pycache__/transforms.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78fa493f7f2be22601b9146b80aa7bdb2c57ad8b Binary files /dev/null and b/opencd/datasets/transforms/__pycache__/transforms.cpython-311.pyc differ diff --git a/opencd/datasets/transforms/formatting.py 
b/opencd/datasets/transforms/formatting.py new file mode 100644 index 0000000000000000000000000000000000000000..67c90d3904f89c6b869fa5c85f6c79644988ac10 --- /dev/null +++ b/opencd/datasets/transforms/formatting.py @@ -0,0 +1,116 @@ +# Copyright (c) Open-CD. All rights reserved. +import numpy as np +import torch +from mmcv.transforms import to_tensor +from mmcv.transforms.base import BaseTransform +from mmengine.structures import PixelData + +from mmseg.structures import SegDataSample +from opencd.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class MultiImgPackSegInputs(BaseTransform): + """Pack the inputs data for the semantic segmentation. + + The ``img_meta`` item is always populated. The contents of the + ``img_meta`` dictionary depends on ``meta_keys``. By default this includes: + + - ``img_path``: filename of the image + + - ``ori_shape``: original shape of the image as a tuple (h, w, c) + + - ``img_shape``: shape of the image input to the network as a tuple \ + (h, w, c). Note that images may be zero padded on the \ + bottom/right if the batch tensor is larger than this shape. + + - ``pad_shape``: shape of padded images + + - ``scale_factor``: a float indicating the preprocessing scale + + - ``flip``: a boolean indicating if image flip transform was used + + - ``flip_direction``: the flipping direction + + Args: + meta_keys (Sequence[str], optional): Meta keys to be packed from + ``SegDataSample`` and collected in ``data[img_metas]``. + Default: ``('img_path', 'ori_shape', + 'img_shape', 'pad_shape', 'scale_factor', 'flip', + 'flip_direction')`` + """ + + def __init__(self, + meta_keys=('img_path', 'seg_map_path', 'seg_map_path_from', + 'seg_map_path_to', 'ori_shape','img_shape', + 'pad_shape', 'scale_factor', 'flip', + 'flip_direction')): + self.meta_keys = meta_keys + + def transform(self, results: dict) -> dict: + """Method to pack the input data. + + Args: + results (dict): Result dict from the data pipeline. + + Returns: + dict: + + - 'inputs' (obj:`torch.Tensor`): The forward data of models. + - 'data_sample' (obj:`SegDataSample`): The annotation info of the + sample. 
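As a sanity check on the packing contract documented above (the implementation follows), a minimal sketch assuming `opencd` is importable: the bi-temporal pair ends up concatenated along the channel axis.

```python
import numpy as np
from opencd.registry import TRANSFORMS

pack = TRANSFORMS.build(dict(type='MultiImgPackSegInputs'))
results = dict(
    # two fake RGB frames plus a binary change mask, as the loaders produce
    img=[np.zeros((256, 256, 3), np.uint8), np.ones((256, 256, 3), np.uint8)],
    gt_seg_map=np.zeros((256, 256), np.uint8))
packed = pack(results)
print(packed['inputs'].shape)                        # torch.Size([6, 256, 256])
print(packed['data_samples'].gt_sem_seg.data.shape)  # torch.Size([1, 256, 256])
```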
+ """ + packed_results = dict() + if 'img' in results: + def _transform_img(img): + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + if not img.flags.c_contiguous: + img = to_tensor(np.ascontiguousarray(img.transpose(2, 0, 1))) + else: + img = img.transpose(2, 0, 1) + img = to_tensor(img).contiguous() + return img + + imgs = [_transform_img(img) for img in results['img']] + imgs = torch.cat(imgs, axis=0) # -> (6, H, W) + packed_results['inputs'] = imgs + + data_sample = SegDataSample() + if 'gt_seg_map' in results: + gt_sem_seg_data = dict( + data=to_tensor(results['gt_seg_map'][None, + ...].astype(np.int64))) + data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data) + + if 'gt_edge_map' in results: + gt_edge_data = dict( + data=to_tensor(results['gt_edge_map'][None, + ...].astype(np.int64))) + data_sample.set_data(dict(gt_edge_map=PixelData(**gt_edge_data))) + + if 'gt_seg_map_from' in results: + gt_sem_seg_data_from = dict( + data=to_tensor(results['gt_seg_map_from'][None, + ...].astype(np.int64))) + data_sample.set_data(dict(gt_sem_seg_from=PixelData(**gt_sem_seg_data_from))) + + if 'gt_seg_map_to' in results: + gt_sem_seg_data_to = dict( + data=to_tensor(results['gt_seg_map_to'][None, + ...].astype(np.int64))) + data_sample.set_data(dict(gt_sem_seg_to=PixelData(**gt_sem_seg_data_to))) + + img_meta = {} + for key in self.meta_keys: + if key in results: + img_meta[key] = results[key] + data_sample.set_metainfo(img_meta) + packed_results['data_samples'] = data_sample + + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(meta_keys={self.meta_keys})' + return repr_str diff --git a/opencd/datasets/transforms/loading.py b/opencd/datasets/transforms/loading.py new file mode 100644 index 0000000000000000000000000000000000000000..b3ec9891fb62cb9681a8e89c66295e399d6cf4d0 --- /dev/null +++ b/opencd/datasets/transforms/loading.py @@ -0,0 +1,453 @@ +# Copyright (c) Open-CD. All rights reserved. +import warnings +from typing import Dict, Optional, Union + +import mmcv +import mmengine.fileio as fileio +import numpy as np +from mmcv.transforms import BaseTransform +from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations +from mmcv.transforms import LoadImageFromFile as MMCV_LoadImageFromFile + +from opencd.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class MultiImgLoadImageFromFile(MMCV_LoadImageFromFile): + """Load an image pair from files. + + Required Keys: + + - img_path + + Modified Keys: + + - img + - img_shape + - ori_shape + + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def transform(self, results: dict) -> Optional[dict]: + """Functions to load image. + + Args: + results (dict): Result dict from + :class:`mmengine.dataset.BaseDataset`. + + Returns: + dict: The dict contains loaded image and meta information. 
+ """ + + filenames = results['img_path'] + imgs = [] + try: + for filename in filenames: + if self.file_client_args is not None: + file_client = fileio.FileClient.infer_client( + self.file_client_args, filename) + img_bytes = file_client.get(filename) + else: + img_bytes = fileio.get( + filename, backend_args=self.backend_args) + img = mmcv.imfrombytes( + img_bytes, flag=self.color_type, backend=self.imdecode_backend) + if self.to_float32: + img = img.astype(np.float32) + imgs.append(img) + except Exception as e: + if self.ignore_empty: + return None + else: + raise e + + results['img'] = imgs + results['img_shape'] = imgs[0].shape[:2] + results['ori_shape'] = imgs[0].shape[:2] + return results + + +@TRANSFORMS.register_module() +class MultiImgLoadAnnotations(MMCV_LoadAnnotations): + """Load annotations for change detection provided by dataset. + + The annotation format is as the following: + + .. code-block:: python + + { + # Filename of change detection ground truth file. + 'seg_map_path': 'a/b/c' + } + + After this module, the annotation has been changed to the format below: + + .. code-block:: python + + { + # in str + 'seg_fields': List + # In uint8 type. + 'gt_seg_map': np.ndarray (H, W) + } + + Required Keys: + + - seg_map_path (str): Path of change detection ground truth file. + + Added Keys: + + - seg_fields (List) + - gt_seg_map (np.uint8) + + Args: + reduce_zero_label (bool, optional): Whether reduce all label value + by 255. Usually used for datasets where 0 is background label. + Defaults to None. + imdecode_backend (str): The image decoding backend type. The backend + argument for :func:``mmcv.imfrombytes``. + See :fun:``mmcv.imfrombytes`` for details. + Defaults to 'pillow'. + backend_args (dict): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + + def __init__( + self, + reduce_zero_label=None, + backend_args=None, + imdecode_backend='pillow', + ) -> None: + super().__init__( + with_bbox=False, + with_label=False, + with_seg=True, + with_keypoints=False, + imdecode_backend=imdecode_backend, + backend_args=backend_args) + self.reduce_zero_label = reduce_zero_label + if self.reduce_zero_label is not None: + warnings.warn('`reduce_zero_label` will be deprecated, ' + 'if you would like to ignore the zero label, please ' + 'set `reduce_zero_label=True` when dataset ' + 'initialized') + self.imdecode_backend = imdecode_backend + + def _load_seg_map(self, results: dict) -> None: + """Private function to load semantic segmentation annotations. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded semantic segmentation annotations. 
+ """ + + img_bytes = fileio.get( + results['seg_map_path'], backend_args=self.backend_args) + gt_semantic_seg = mmcv.imfrombytes( + img_bytes, flag='grayscale', # in mmseg: unchanged + backend=self.imdecode_backend).squeeze().astype(np.uint8) + + # reduce zero_label + if self.reduce_zero_label is None: + self.reduce_zero_label = results['reduce_zero_label'] + assert self.reduce_zero_label == results['reduce_zero_label'], \ + 'Initialize dataset with `reduce_zero_label` as ' \ + f'{results["reduce_zero_label"]} but when load annotation ' \ + f'the `reduce_zero_label` is {self.reduce_zero_label}' + if self.reduce_zero_label: + # avoid using underflow conversion + gt_semantic_seg[gt_semantic_seg == 0] = 255 + gt_semantic_seg = gt_semantic_seg - 1 + gt_semantic_seg[gt_semantic_seg == 254] = 255 + # modify to format ann + if results.get('format_seg_map', None) is not None: + if results['format_seg_map'] == 'to_binary': + gt_semantic_seg_copy = gt_semantic_seg.copy() + gt_semantic_seg[gt_semantic_seg_copy < 128] = 0 + gt_semantic_seg[gt_semantic_seg_copy >= 128] = 1 + else: + raise ValueError('Invalid value {}'.format(results['format_seg_map'])) + # modify if custom classes + if results.get('label_map', None) is not None: + # Add deep copy to solve bug of repeatedly + # replace `gt_semantic_seg`, which is reported in + # https://github.com/open-mmlab/mmsegmentation/pull/1445/ + gt_semantic_seg_copy = gt_semantic_seg.copy() + for old_id, new_id in results['label_map'].items(): + gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id + results['gt_seg_map'] = gt_semantic_seg + results['seg_fields'].append('gt_seg_map') + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(reduce_zero_label={self.reduce_zero_label}, ' + repr_str += f"imdecode_backend='{self.imdecode_backend}', " + repr_str += f'backend_args={self.backend_args})' + return repr_str + + +@TRANSFORMS.register_module() +class MultiImgMultiAnnLoadAnnotations(MMCV_LoadAnnotations): + """Load annotations for semantic change detection provided by dataset. + + The annotation format is as the following: + + .. code-block:: python + + { + # Filename of change detection ground truth file. + 'seg_map_path': 'a/b/c' + } + + After this module, the annotation has been changed to the format below: + + .. code-block:: python + + { + # in str + 'seg_fields': List + # In uint8 type. + 'gt_seg_map': np.ndarray (H, W) + } + + Required Keys: + + - seg_map_path (str): Path of change detection ground truth file. + + Added Keys: + + - seg_fields (List) + - gt_seg_map (np.uint8) + + Args: + reduce_semantic_zero_label (bool, optional): Whether reduce all label value + by 255. Usually used for datasets where 0 is background label. + Defaults to None. + imdecode_backend (str): The image decoding backend type. The backend + argument for :func:``mmcv.imfrombytes``. + See :fun:``mmcv.imfrombytes`` for details. + Defaults to 'pillow'. + backend_args (dict): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. 
+ """ + + def __init__( + self, + reduce_semantic_zero_label=None, + backend_args=None, + imdecode_backend='pillow', + ) -> None: + super().__init__( + with_bbox=False, + with_label=False, + with_seg=True, + with_keypoints=False, + imdecode_backend=imdecode_backend, + backend_args=backend_args) + self.reduce_semantic_zero_label = reduce_semantic_zero_label + if self.reduce_semantic_zero_label is not None: + warnings.warn('`reduce_semantic_zero_label` will be deprecated, ' + 'if you would like to ignore the zero label, please ' + 'set `reduce_semantic_zero_label=True` when dataset ' + 'initialized') + self.imdecode_backend = imdecode_backend + + def _load_seg_map(self, results: dict) -> None: + """Private function to load semantic segmentation annotations. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded semantic segmentation annotations. + """ + + img_bytes = fileio.get( + results['seg_map_path'], backend_args=self.backend_args) + gt_semantic_seg = mmcv.imfrombytes( + img_bytes, flag='grayscale', # in mmseg: unchanged + backend=self.imdecode_backend).squeeze().astype(np.uint8) + # for semantic anns + img_bytes_from = fileio.get( + results['seg_map_path_from'], backend_args=self.backend_args) + gt_semantic_seg_from = mmcv.imfrombytes( + img_bytes_from, flag='grayscale', + backend=self.imdecode_backend).squeeze().astype(np.uint8) + img_bytes_to = fileio.get( + results['seg_map_path_to'], backend_args=self.backend_args) + gt_semantic_seg_to = mmcv.imfrombytes( + img_bytes_to, flag='grayscale', + backend=self.imdecode_backend).squeeze().astype(np.uint8) + + # reduce zero_label + if self.reduce_semantic_zero_label is None: + self.reduce_semantic_zero_label = results['reduce_semantic_zero_label'] + assert self.reduce_semantic_zero_label == results['reduce_semantic_zero_label'], \ + 'Initialize dataset with `reduce_semantic_zero_label` as ' \ + f'{results["reduce_semantic_zero_label"]} but when load annotation ' \ + f'the `reduce_semantic_zero_label` is {self.reduce_semantic_zero_label}' + if self.reduce_semantic_zero_label: + # avoid using underflow conversion + gt_semantic_seg_from[gt_semantic_seg_from == 0] = 255 + gt_semantic_seg_from = gt_semantic_seg_from - 1 + gt_semantic_seg_from[gt_semantic_seg_from == 254] = 255 + gt_semantic_seg_to[gt_semantic_seg_to == 0] = 255 + gt_semantic_seg_to = gt_semantic_seg_to - 1 + gt_semantic_seg_to[gt_semantic_seg_to == 254] = 255 + # modify to format ann + if results.get('format_seg_map', None) is not None: + if results['format_seg_map'] == 'to_binary': + gt_semantic_seg_copy = gt_semantic_seg.copy() + gt_semantic_seg[gt_semantic_seg_copy < 128] = 0 + gt_semantic_seg[gt_semantic_seg_copy >= 128] = 1 + else: + raise ValueError('Invalid value {}'.format(results['format_seg_map'])) + # modify if custom classes + if results.get('label_map', None) is not None: + # Add deep copy to solve bug of repeatedly + # replace `gt_semantic_seg`, which is reported in + # https://github.com/open-mmlab/mmsegmentation/pull/1445/ + gt_semantic_seg_copy = gt_semantic_seg.copy() + for old_id, new_id in results['label_map'].items(): + gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id + if results.get('semantic_label_map', None) is not None: + ''' Just for semantic anns here ''' + # Add deep copy to solve bug of repeatedly + # replace `gt_semantic_seg`, which is reported in + # https://github.com/open-mmlab/mmsegmentation/pull/1445/ + gt_semantic_seg_from_copy = gt_semantic_seg_from.copy() + for old_id, 
new_id in results['semantic_label_map'].items(): + gt_semantic_seg_from[gt_semantic_seg_from_copy == old_id] = new_id + gt_semantic_seg_to_copy = gt_semantic_seg_to.copy() + for old_id, new_id in results['semantic_label_map'].items(): + gt_semantic_seg_to[gt_semantic_seg_to_copy == old_id] = new_id + + results['gt_seg_map'] = gt_semantic_seg + results['gt_seg_map_from'] = gt_semantic_seg_from + results['gt_seg_map_to'] = gt_semantic_seg_to + results['seg_fields'].extend(['gt_seg_map', + 'gt_seg_map_from', 'gt_seg_map_to']) + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(reduce_semantic_zero_label={self.reduce_semantic_zero_label}, ' + repr_str += f"imdecode_backend='{self.imdecode_backend}', " + repr_str += f'backend_args={self.backend_args})' + return repr_str + + +@TRANSFORMS.register_module() +class MultiImgLoadLoadImageFromNDArray(MultiImgLoadImageFromFile): + """Load an image pair from ``results['img']``. + + Similar to :obj:`LoadImageFromFile`, but the image has been loaded as + :obj:`np.ndarray` in ``results['img']``. Can be used when loading image + from webcam. + + Required Keys: + + - img + + Modified Keys: + + - img + - img_path + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is a uint8 array. + Defaults to False. + """ + + def transform(self, results: dict) -> dict: + """Transform function to add image meta information. + + Args: + results (dict): Result dict with Webcam read image in + ``results['img']``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + imgs = [] + for img in results['img']: + if self.to_float32: + img = img.astype(np.float32) + imgs.append(img) + + results['img_path'] = None + results['img'] = imgs + results['img_shape'] = imgs[0].shape[:2] + results['ori_shape'] = imgs[0].shape[:2] + return results + + +@TRANSFORMS.register_module() +class MultiImgLoadInferencerLoader(BaseTransform): + """Load an image pair from ``results['img']``. + + Similar to :obj:`LoadImageFromFile`, but the image has been loaded as + :obj:`np.ndarray` in ``results['img']``. Can be used when loading image + from webcam. + + Required Keys: + + - img + + Modified Keys: + + - img + - img_path + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is a uint8 array. + Defaults to False. + """ + + def __init__(self, **kwargs) -> None: + super().__init__() + self.from_file = TRANSFORMS.build( + dict(type='MultiImgLoadImageFromFile', **kwargs)) + self.from_ndarray = TRANSFORMS.build( + dict(type='MultiImgLoadLoadImageFromNDArray', **kwargs)) + + def transform(self, single_input: Union[str, np.ndarray, dict]) -> dict: + """Transform function to add image meta information. + + Args: + results (dict): Result dict with Webcam read image in + ``results['img']``. + + Returns: + dict: The dict contains loaded image and meta information.
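For the in-memory path, a short sketch (arrays stand in for webcam frames; assumes `opencd` is importable):

```python
import numpy as np
from opencd.registry import TRANSFORMS

load_nd = TRANSFORMS.build(dict(type='MultiImgLoadLoadImageFromNDArray'))
frames = [np.zeros((480, 640, 3), np.uint8), np.zeros((480, 640, 3), np.uint8)]
out = load_nd(dict(img=frames))
print(out['img_shape'], out['ori_shape'])  # (480, 640) (480, 640)
```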
+ """ + assert len(single_input) == 2, \ + 'In `MultiImgLoadInferencerLoader`,' \ + '`single_input` contains bi-temporal images' + if isinstance(single_input[0], str): + inputs = dict(img_path=single_input) + elif isinstance(single_input[0], Union[np.ndarray, list]): + inputs = dict(img=single_input) + elif isinstance(single_input[0], dict): + inputs = single_input + else: + raise NotImplementedError + + if 'img' in inputs: + return self.from_ndarray(inputs) + return self.from_file(inputs) diff --git a/opencd/datasets/transforms/transforms.py b/opencd/datasets/transforms/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..1dbd7dfd8bbb12854223be7310cf6f3de8d6d5a1 --- /dev/null +++ b/opencd/datasets/transforms/transforms.py @@ -0,0 +1,1867 @@ +# Copyright (c) Open-CD. All rights reserved. +import copy +import warnings +from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import cv2 +import mmcv +import numpy as np +from mmcv.image.geometric import _scale_size +from mmcv.transforms.base import BaseTransform +from mmcv.transforms.utils import cache_randomness +from mmengine.utils import is_list_of, is_seq_of, is_str, is_tuple_of +from numpy import random +from scipy.ndimage import gaussian_filter + +from mmseg.datasets.dataset_wrappers import MultiImageMixDataset +from opencd.registry import TRANSFORMS + +try: + import albumentations + from albumentations import Compose +except ImportError: + albumentations = None + Compose = None + + +@TRANSFORMS.register_module() +class MultiImgResizeToMultiple(BaseTransform): + """Resize images & seg to multiple of divisor. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - img_shape + - pad_shape + + Args: + size_divisor (int): images and gt seg maps need to resize to multiple + of size_divisor. Default: 32. + interpolation (str, optional): The interpolation mode of image resize. + Default: None + """ + + def __init__(self, size_divisor=32, interpolation=None): + self.size_divisor = size_divisor + self.interpolation = interpolation + + def transform(self, results: dict) -> dict: + """Call function to resize images, semantic segmentation map to + multiple of size divisor. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img_shape', 'pad_shape' keys are updated. + """ + # Align image to multiple of size divisor. + imgs = results['img'] + imgs = [ + mmcv.imresize_to_multiple( + img, + self.size_divisor, + scale_factor=1, + interpolation=self.interpolation + if self.interpolation else 'bilinear') for img in imgs] + + results['img'] = imgs + results['img_shape'] = imgs[0].shape + results['pad_shape'] = imgs[0].shape + + # Align segmentation map to multiple of size divisor. + for key in results.get('seg_fields', []): + gt_seg = results[key] + gt_seg = mmcv.imresize_to_multiple( + gt_seg, + self.size_divisor, + scale_factor=1, + interpolation='nearest') + results[key] = gt_seg + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(size_divisor={self.size_divisor}, ' + f'interpolation={self.interpolation})') + return repr_str + + +@TRANSFORMS.register_module() +class MultiImgRerange(BaseTransform): + """Rerange the image pixel value. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + min_value (float or int): Minimum value of the reranged image. + Default: 0. + max_value (float or int): Maximum value of the reranged image. + Default: 255. 
+ """ + + def __init__(self, min_value=0, max_value=255): + assert isinstance(min_value, float) or isinstance(min_value, int) + assert isinstance(max_value, float) or isinstance(max_value, int) + assert min_value < max_value + self.min_value = min_value + self.max_value = max_value + + def transform(self, results: dict) -> dict: + """Call function to rerange images. + + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Reranged results. + """ + + def _rerange(img): + img_min_value = np.min(img) + img_max_value = np.max(img) + + assert img_min_value < img_max_value + # rerange to [0, 1] + img = (img - img_min_value) / (img_max_value - img_min_value) + # rerange to [min_value, max_value] + img = img * (self.max_value - self.min_value) + self.min_value + return img + + results['img'] = [_rerange(img) for img in results['img']] + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(min_value={self.min_value}, max_value={self.max_value})' + return repr_str + + +@TRANSFORMS.register_module() +class MultiImgCLAHE(BaseTransform): + """Use CLAHE method to process the image. + + See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J]. + Graphics Gems, 1994:474-485.` for more information. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + clip_limit (float): Threshold for contrast limiting. Default: 40.0. + tile_grid_size (tuple[int]): Size of grid for histogram equalization. + Input image will be divided into equally sized rectangular tiles. + It defines the number of tiles in row and column. Default: (8, 8). + """ + + def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)): + assert isinstance(clip_limit, (float, int)) + self.clip_limit = clip_limit + assert is_tuple_of(tile_grid_size, int) + assert len(tile_grid_size) == 2 + self.tile_grid_size = tile_grid_size + + def transform(self, results: dict) -> dict: + """Call function to Use CLAHE method process images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + + def _clane(img): + for i in range(img.shape[2]): + img[:, :, i] = mmcv.clahe( + np.array(img[:, :, i], dtype=np.uint8), + self.clip_limit, self.tile_grid_size) + return img + + results['img'] = [_clane(img) for img in results['img']] + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(clip_limit={self.clip_limit}, '\ + f'tile_grid_size={self.tile_grid_size})' + return repr_str + + +@TRANSFORMS.register_module() +class MultiImgRandomCrop(BaseTransform): + """Random crop the image & seg. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - img_shape + - gt_seg_map + + + Args: + crop_size (Union[int, Tuple[int, int]]): Expected size after cropping + with the format of (h, w). If set to an integer, then cropping + width and height are equal to this integer. + cat_max_ratio (float): The maximum ratio that single category could + occupy. + ignore_index (int): The label index to be ignored. 
Default: 255 + """ + + def __init__(self, + crop_size: Union[int, Tuple[int, int]], + cat_max_ratio: float = 1., + ignore_index: int = 255): + super().__init__() + assert isinstance(crop_size, int) or ( + isinstance(crop_size, tuple) and len(crop_size) == 2 + ), 'The expected crop_size is an integer, or a tuple containing two ' + 'intergers' + + if isinstance(crop_size, int): + crop_size = (crop_size, crop_size) + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size + self.cat_max_ratio = cat_max_ratio + self.ignore_index = ignore_index + + @cache_randomness + def crop_bbox(self, results: dict) -> tuple: + """get a crop bounding box. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + tuple: Coordinates of the cropped image. + """ + + def generate_crop_bbox(img: np.ndarray) -> tuple: + """Randomly get a crop bounding box. + + Args: + img (np.ndarray): Original input image. + + Returns: + tuple: Coordinates of the cropped image. + """ + + margin_h = max(img.shape[0] - self.crop_size[0], 0) + margin_w = max(img.shape[1] - self.crop_size[1], 0) + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] + + return crop_y1, crop_y2, crop_x1, crop_x2 + + img = results['img'][0] + crop_bbox = generate_crop_bbox(img) + if self.cat_max_ratio < 1.: + # Repeat 10 times + for _ in range(10): + seg_temp = self.crop(results['gt_seg_map'], crop_bbox) + labels, cnt = np.unique(seg_temp, return_counts=True) + cnt = cnt[labels != self.ignore_index] + if len(cnt) > 1 and np.max(cnt) / np.sum( + cnt) < self.cat_max_ratio: + break + crop_bbox = generate_crop_bbox(img) + + return crop_bbox + + def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray: + """Crop from ``img`` + + Args: + img (np.ndarray): Original input image. + crop_bbox (tuple): Coordinates of the cropped image. + + Returns: + np.ndarray: The cropped image. + """ + + crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox + img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] + return img + + def transform(self, results: dict) -> dict: + """Transform function to randomly crop images, semantic segmentation + maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. + """ + + crop_bbox = self.crop_bbox(results) + + # crop the image + imgs = [self.crop(img, crop_bbox) for img in results['img']] + + # crop semantic seg + for key in results.get('seg_fields', []): + results[key] = self.crop(results[key], crop_bbox) + + results['img'] = imgs + results['img_shape'] = imgs[0].shape + return results + + def __repr__(self): + return self.__class__.__name__ + f'(crop_size={self.crop_size})' + + +@TRANSFORMS.register_module() +class MultiImgRandomRotate(BaseTransform): + """Rotate the image & seg. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + + Args: + prob (float): The rotation probability. + degree (float, tuple[float]): Range of degrees to select from. If + degree is a number instead of tuple like (min, max), + the range of degree will be (``-degree``, ``+degree``) + pad_val (float, optional): Padding value of image. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. 
+ center (tuple[float], optional): Center point (w, h) of the rotation in + the source image. If not specified, the center of the image will be + used. Default: None. + auto_bound (bool): Whether to adjust the image size to cover the whole + rotated image. Default: False + """ + + def __init__(self, + prob, + degree, + pad_val=0, + seg_pad_val=255, + center=None, + auto_bound=False): + self.prob = prob + assert prob >= 0 and prob <= 1 + if isinstance(degree, (float, int)): + assert degree > 0, f'degree {degree} should be positive' + self.degree = (-degree, degree) + else: + self.degree = degree + assert len(self.degree) == 2, f'degree {self.degree} should be a ' \ + f'tuple of (min, max)' + self.pal_val = pad_val + self.seg_pad_val = seg_pad_val + self.center = center + self.auto_bound = auto_bound + + @cache_randomness + def generate_degree(self): + return np.random.rand() < self.prob, np.random.uniform( + min(*self.degree), max(*self.degree)) + + def transform(self, results: dict) -> dict: + """Call function to rotate image, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Rotated results. + """ + + rotate, degree = self.generate_degree() + if rotate: + # rotate image + results['img'] = [ + mmcv.imrotate( + img, + angle=degree, + border_value=self.pal_val, + center=self.center, + auto_bound=self.auto_bound) for img in results['img']] + + # rotate segs + for key in results.get('seg_fields', []): + results[key] = mmcv.imrotate( + results[key], + angle=degree, + border_value=self.seg_pad_val, + center=self.center, + auto_bound=self.auto_bound, + interpolation='nearest') + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' \ + f'degree={self.degree}, ' \ + f'pad_val={self.pal_val}, ' \ + f'seg_pad_val={self.seg_pad_val}, ' \ + f'center={self.center}, ' \ + f'auto_bound={self.auto_bound})' + return repr_str + + +@TRANSFORMS.register_module() +class MultiImgRGB2Gray(BaseTransform): + """Convert RGB image to grayscale image. + + Required Keys: + + - img + + Modified Keys: + + - img + - img_shape + + This transform calculate the weighted mean of input image channels with + ``weights`` and then expand the channels to ``out_channels``. When + ``out_channels`` is None, the number of output channels is the same as + input channels. + + Args: + out_channels (int): Expected number of output channels after + transforming. Default: None. + weights (tuple[float]): The weights to calculate the weighted mean. + Default: (0.299, 0.587, 0.114). + """ + + def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)): + assert out_channels is None or out_channels > 0 + self.out_channels = out_channels + assert isinstance(weights, tuple) + for item in weights: + assert isinstance(item, (float, int)) + self.weights = weights + + def transform(self, results: dict) -> dict: + """Call function to convert RGB image to grayscale image. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with grayscale image. 
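The weighted-mean conversion described above (implemented in the `_rgb2gray` helper that follows) reduces to:

```python
import numpy as np

img = np.random.rand(4, 4, 3)
weights = np.array([0.299, 0.587, 0.114]).reshape(1, 1, -1)
gray = (img * weights).sum(2, keepdims=True)  # (4, 4, 1) weighted mean
gray = gray.repeat(3, axis=2)                 # expand back to 3 channels
print(gray.shape)  # (4, 4, 3)
```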
+ """ + + def _rgb2gray(img): + assert len(img.shape) == 3 + assert img.shape[2] == len(self.weights) + weights = np.array(self.weights).reshape((1, 1, -1)) + img = (img * weights).sum(2, keepdims=True) + if self.out_channels is None: + img = img.repeat(weights.shape[2], axis=2) + else: + img = img.repeat(self.out_channels, axis=2) + return img + + imgs = [_rgb2gray(img) for img in results['img']] + + results['img'] = imgs + results['img_shape'] = imgs[0].shape + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(out_channels={self.out_channels}, ' \ + f'weights={self.weights})' + return repr_str + + +@TRANSFORMS.register_module() +class MultiImgAdjustGamma(BaseTransform): + """Using gamma correction to process the image. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + gamma (float or int): Gamma value used in gamma correction. + Default: 1.0. + """ + + def __init__(self, gamma=1.0): + assert isinstance(gamma, float) or isinstance(gamma, int) + assert gamma > 0 + self.gamma = gamma + inv_gamma = 1.0 / gamma + self.table = np.array([(i / 255.0)**inv_gamma * 255 + for i in np.arange(256)]).astype('uint8') + + def transform(self, results: dict) -> dict: + """Call function to process the image with gamma correction. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + + results['img'] = [ + mmcv.lut_transform( + np.array(img, dtype=np.uint8), self.table) for img in results['img'] + ] + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(gamma={self.gamma})' + + +@TRANSFORMS.register_module() +class MultiImgPhotoMetricDistortion(BaseTransform): + """Apply photometric distortion to image sequentially, every transformation + is applied with a probability of 0.5. The position of random contrast is in + second or second to last. + + 1. random brightness + 2. random contrast (mode 0) + 3. convert color from BGR to HSV + 4. random saturation + 5. random hue + 6. convert color from HSV to BGR + 7. random contrast (mode 1) + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + brightness_delta (int): delta of brightness. + contrast_range (tuple): range of contrast. + saturation_range (tuple): range of saturation. + hue_delta (int): delta of hue. + consistent_contrast_mode (bool): Whether to + keep the contrast mode consistent. + """ + + def __init__(self, + brightness_delta: int = 32, + contrast_range: Sequence[float] = (0.5, 1.5), + saturation_range: Sequence[float] = (0.5, 1.5), + hue_delta: int = 18, + consistent_contrast_mode: bool = False): + self.brightness_delta = brightness_delta + self.contrast_lower, self.contrast_upper = contrast_range + self.saturation_lower, self.saturation_upper = saturation_range + self.hue_delta = hue_delta + self.consistent_contrast_mode = consistent_contrast_mode + + def convert(self, + img: np.ndarray, + alpha: int = 1, + beta: int = 0) -> np.ndarray: + """Multiple with alpha and add beat with clip. + + Args: + img (np.ndarray): The input image. + alpha (int): Image weights, change the contrast/saturation + of the image. Default: 1 + beta (int): Image bias, change the brightness of the + image. Default: 0 + + Returns: + np.ndarray: The transformed image. + """ + + img = img.astype(np.float32) * alpha + beta + img = np.clip(img, 0, 255) + return img.astype(np.uint8) + + def brightness(self, img: np.ndarray) -> np.ndarray: + """Brightness distortion. + + Args: + img (np.ndarray): The input image. 
+ Returns: + np.ndarray: Image after brightness change. + """ + + if random.randint(2): + return self.convert( + img, + beta=random.uniform(-self.brightness_delta, + self.brightness_delta)) + return img + + def contrast(self, img: np.ndarray) -> np.ndarray: + """Contrast distortion. + + Args: + img (np.ndarray): The input image. + Returns: + np.ndarray: Image after contrast change. + """ + + if random.randint(2): + return self.convert( + img, + alpha=random.uniform(self.contrast_lower, self.contrast_upper)) + return img + + def saturation(self, img: np.ndarray) -> np.ndarray: + """Saturation distortion. + + Args: + img (np.ndarray): The input image. + Returns: + np.ndarray: Image after saturation change. + """ + + if random.randint(2): + img = mmcv.bgr2hsv(img) + img[:, :, 1] = self.convert( + img[:, :, 1], + alpha=random.uniform(self.saturation_lower, + self.saturation_upper)) + img = mmcv.hsv2bgr(img) + return img + + def hue(self, img: np.ndarray) -> np.ndarray: + """Hue distortion. + + Args: + img (np.ndarray): The input image. + Returns: + np.ndarray: Image after hue change. + """ + + if random.randint(2): + img = mmcv.bgr2hsv(img) + img[:, :, + 0] = (img[:, :, 0].astype(int) + + random.randint(-self.hue_delta, self.hue_delta)) % 180 + img = mmcv.hsv2bgr(img) + return img + + def transform(self, results: dict) -> dict: + """Transform function to perform photometric distortion on images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images distorted. + """ + + def _photo_metric_distortion(img, contrast_mode=None): + # random brightness + img = self.brightness(img) + + # mode == 0 --> do random contrast first + # mode == 1 --> do random contrast last + mode = contrast_mode or random.randint(2) + if mode == 1: + img = self.contrast(img) + + # random saturation + img = self.saturation(img) + + # random hue + img = self.hue(img) + + # random contrast + if mode == 0: + img = self.contrast(img) + return img + + + contrast_mode = random.randint(2) \ + if self.consistent_contrast_mode else None + results['img'] = [_photo_metric_distortion(img, contrast_mode) \ + for img in results['img']] + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(brightness_delta={self.brightness_delta}, ' + f'contrast_range=({self.contrast_lower}, ' + f'{self.contrast_upper}), ' + f'saturation_range=({self.saturation_lower}, ' + f'{self.saturation_upper}), ' + f'hue_delta={self.hue_delta}), ' + f'consistent_contrast_mode={self.consistent_contrast_mode}') + return repr_str + + +@TRANSFORMS.register_module() +class MultiImgRandomCutOut(BaseTransform): + """CutOut operation. + + Randomly drop some regions of image used in + `Cutout `_. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + + Args: + prob (float): cutout probability. + n_holes (int | tuple[int, int]): Number of regions to be dropped. + If it is given as a list, number of holes will be randomly + selected from the closed interval [`n_holes[0]`, `n_holes[1]`]. + cutout_shape (tuple[int, int] | list[tuple[int, int]]): The candidate + shape of dropped regions. It can be `tuple[int, int]` to use a + fixed cutout shape, or `list[tuple[int, int]]` to randomly choose + shape from the list. + cutout_ratio (tuple[float, float] | list[tuple[float, float]]): The + candidate ratio of dropped regions. It can be `tuple[float, float]` + to use a fixed ratio or `list[tuple[float, float]]` to randomly + choose ratio from the list. 
Please note that `cutout_shape` + and `cutout_ratio` cannot be both given at the same time. + fill_in (tuple[float, float, float] | tuple[int, int, int]): The value + of pixel to fill in the dropped regions. Default: (0, 0, 0). + seg_fill_in (int): The labels of pixel to fill in the dropped regions. + If seg_fill_in is None, skip. Default: None. + """ + + def __init__(self, + prob, + n_holes, + cutout_shape=None, + cutout_ratio=None, + fill_in=(0, 0, 0), + seg_fill_in=None): + + assert 0 <= prob and prob <= 1 + assert (cutout_shape is None) ^ (cutout_ratio is None), \ + 'Either cutout_shape or cutout_ratio should be specified.' + assert (isinstance(cutout_shape, (list, tuple)) + or isinstance(cutout_ratio, (list, tuple))) + if isinstance(n_holes, tuple): + assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1] + else: + n_holes = (n_holes, n_holes) + if seg_fill_in is not None: + assert (isinstance(seg_fill_in, int) and 0 <= seg_fill_in + and seg_fill_in <= 255) + self.prob = prob + self.n_holes = n_holes + self.fill_in = fill_in + self.seg_fill_in = seg_fill_in + self.with_ratio = cutout_ratio is not None + self.candidates = cutout_ratio if self.with_ratio else cutout_shape + if not isinstance(self.candidates, list): + self.candidates = [self.candidates] + + @cache_randomness + def do_cutout(self): + return np.random.rand() < self.prob + + @cache_randomness + def generate_patches(self, results): + cutout = self.do_cutout() + + h, w, _ = results['img'][0].shape + if cutout: + n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1) + else: + n_holes = 0 + x1_lst = [] + y1_lst = [] + index_lst = [] + for _ in range(n_holes): + x1_lst.append(np.random.randint(0, w)) + y1_lst.append(np.random.randint(0, h)) + index_lst.append(np.random.randint(0, len(self.candidates))) + return cutout, n_holes, x1_lst, y1_lst, index_lst + + def transform(self, results: dict) -> dict: + """Call function to drop some regions of image.""" + cutout, n_holes, x1_lst, y1_lst, index_lst = self.generate_patches( + results) + if cutout: + h, w, c = results['img'][0].shape + for i in range(n_holes): + x1 = x1_lst[i] + y1 = y1_lst[i] + index = index_lst[i] + if not self.with_ratio: + cutout_w, cutout_h = self.candidates[index] + else: + cutout_w = int(self.candidates[index][0] * w) + cutout_h = int(self.candidates[index][1] * h) + + x2 = np.clip(x1 + cutout_w, 0, w) + y2 = np.clip(y1 + cutout_h, 0, h) + for idx in range(len(results['img'])): + results['img'][idx][y1:y2, x1:x2, :] = self.fill_in + + if self.seg_fill_in is not None: + for key in results.get('seg_fields', []): + results[key][y1:y2, x1:x2] = self.seg_fill_in + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'n_holes={self.n_holes}, ' + repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio + else f'cutout_shape={self.candidates}, ') + repr_str += f'fill_in={self.fill_in}, ' + repr_str += f'seg_fill_in={self.seg_fill_in})' + return repr_str + + +@TRANSFORMS.register_module() +class MultiImgRandomRotFlip(BaseTransform): + """Rotate and flip the image & seg or just rotate the image & seg. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + + Args: + rotate_prob (float): The probability of rotate image. + flip_prob (float): The probability of rotate&flip image. + degree (float, tuple[float]): Range of degrees to select from. 
If + degree is a number instead of tuple like (min, max), + the range of degree will be (``-degree``, ``+degree``) + pad_val (float, optional): Padding value of image. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. + """ + + def __init__(self, + rotate_prob=0.5, + flip_prob=0.5, + degree=(-20, 20), + pad_val=0, + seg_pad_val=255): + self.rotate_prob = rotate_prob + self.flip_prob = flip_prob + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + assert 0 <= rotate_prob <= 1 and 0 <= flip_prob <= 1 + if isinstance(degree, (float, int)): + assert degree > 0, f'degree {degree} should be positive' + self.degree = (-degree, degree) + else: + self.degree = degree + assert len(self.degree) == 2, f'degree {self.degree} should be a ' \ + f'tuple of (min, max)' + + def random_rot_flip(self, results: dict) -> dict: + k = np.random.randint(0, 4) + results['img'] = [np.rot90(img, k) for img in results['img']] + for key in results.get('seg_fields', []): + results[key] = np.rot90(results[key], k) + axis = np.random.randint(0, 2) + results['img'] = [ + np.flip(img, axis=axis).copy() for img in results['img']] + for key in results.get('seg_fields', []): + results[key] = np.flip(results[key], axis=axis).copy() + return results + + def random_rotate(self, results: dict) -> dict: + angle = np.random.uniform(min(*self.degree), max(*self.degree)) + results['img'] = [ + mmcv.imrotate(img, angle=angle, + border_value=self.pad_val) for img in results['img']] + for key in results.get('seg_fields', []): + results[key] = mmcv.imrotate(results[key], + angle=angle, + border_value=self.seg_pad_val, + interpolation='nearest') + return results + + def transform(self, results: dict) -> dict: + """Call function to rotate or rotate & flip image, semantic + segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Rotated or rotated & flipped results. + """ + rotate_flag = 0 + if random.random() < self.rotate_prob: + results = self.random_rotate(results) + rotate_flag = 1 + if random.random() < self.flip_prob and rotate_flag == 0: + results = self.random_rot_flip(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(rotate_prob={self.rotate_prob}, ' \ + f'flip_prob={self.flip_prob}, ' \ + f'degree={self.degree})' + return repr_str + + +@TRANSFORMS.register_module() +class MultiImgResizeShortestEdge(BaseTransform): + """Resize the image and mask while keeping the aspect ratio unchanged. + + Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/augmentation_impl.py#L130 # noqa:E501 + Copyright (c) Facebook, Inc. and its affiliates. + Licensed under the Apache-2.0 License + + This transform attempts to scale the shorter edge to the given + `scale`, as long as the longer edge does not exceed `max_size`. + If `max_size` is reached, then downscale so that the longer + edge does not exceed `max_size`. + + Required Keys: + + - img + - gt_seg_map (optional) + + Modified Keys: + + - img + - img_shape + - gt_seg_map (optional) + + Added Keys: + + - scale + - scale_factor + - keep_ratio + + + Args: + scale (Union[int, Tuple[int, int]]): The target short edge length. + If it's tuple, will select the min value as the short edge length. + max_size (int): The maximum allowed longest edge length. 
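A worked example of the rule above, mirroring the `_get_output_shape` computation that follows: scale the short edge to `scale`, then shrink if the long edge would exceed `max_size`.

```python
# (h, w) = (500, 1000) with scale = 512 and max_size = 768.
h, w, scale, max_size = 500, 1000, 512, 768
f = scale / min(h, w)                  # 1.024: short edge 500 -> 512
new_h, new_w = h * f, w * f            # (512.0, 1024.0)
if max(new_h, new_w) > max_size:       # long edge exceeds the cap
    f = max_size / max(new_h, new_w)   # 0.75
    new_h, new_w = new_h * f, new_w * f
print(int(new_h + 0.5), int(new_w + 0.5))  # 384 768
```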
+ """ + + def __init__(self, scale: Union[int, Tuple[int, int]], + max_size: int) -> None: + super().__init__() + self.scale = scale + self.max_size = max_size + + # Create a empty Resize object + self.resize = TRANSFORMS.build({ + 'type': 'MultiImgResize', + 'scale': 0, + 'keep_ratio': True + }) + + def _get_output_shape(self, img, short_edge_length) -> Tuple[int, int]: + """Compute the target image shape with the given `short_edge_length`. + + Args: + img (np.ndarray): The input image. + short_edge_length (Union[int, Tuple[int, int]]): The target short + edge length. If it's tuple, will select the min value as the + short edge length. + """ + h, w = img.shape[:2] + if isinstance(short_edge_length, int): + size = short_edge_length * 1.0 + elif isinstance(short_edge_length, tuple): + size = min(short_edge_length) * 1.0 + scale = size / min(h, w) + if h < w: + new_h, new_w = size, scale * w + else: + new_h, new_w = scale * h, size + + if max(new_h, new_w) > self.max_size: + scale = self.max_size * 1.0 / max(new_h, new_w) + new_h *= scale + new_w *= scale + + new_h = int(new_h + 0.5) + new_w = int(new_w + 0.5) + return (new_w, new_h) + + def transform(self, results: Dict) -> Dict: + self.resize.scale = self._get_output_shape(results['img'], self.scale) + return self.resize(results) + + +@TRANSFORMS.register_module() +class MultiImgExchangeTime(BaseTransform): + """Exchange images of different times. + Args: + prob (float): probability of applying the transform. Default: 0.5. + """ + def __init__(self, + prob: float = 0.5) -> None: + + assert 0 <= prob and prob <= 1 + self.prob = prob + + def transform(self, results: dict) -> dict: + """Call function to exchange images .""" + exchange = True if np.random.rand() < self.prob else False + if exchange: + results['img'].reverse() # list.reverse() + if 'gt_seg_map_from' in results['seg_fields'] and \ + 'gt_seg_map_to' in results['seg_fields']: + results['gt_seg_map_from'], results['gt_seg_map_to'] = \ + results['gt_seg_map_to'], results['gt_seg_map_from'] + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + return repr_str + + +@TRANSFORMS.register_module() +class MultiImgResize(BaseTransform): + """Resize images & seg. + This transform resizes the input image according to ``scale`` or + ``scale_factor``. Bboxes, seg map and keypoints are then resized with the + same scale factor. + if ``scale`` and ``scale_factor`` are both set, it will use ``scale`` to + resize. + Required Keys: + + - img + - gt_seg_map (optional) + + Modified Keys: + + - img + - gt_seg_map + - img_shape + + Added Keys: + + - scale + - scale_factor + - keep_ratio + + Args: + scale (int or tuple): Images scales for resizing. Defaults to None + scale_factor (float or tuple[float]): Scale factors for resizing. + Defaults to None. + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. Defaults to False. + clip_object_border (bool): Whether to clip the objects + outside the border of the image. In some dataset like MOT17, the gt + bboxes are allowed to cross the border of images. Therefore, we + don't need to clip the gt bboxes in these cases. Defaults to True. + backend (str): Image resize backend, choices are 'cv2' and 'pillow'. + These two backends generates slightly different results. Defaults + to 'cv2'. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. 
Defaults + to 'bilinear'. + """ + + def __init__(self, + scale: Optional[Union[int, Tuple[int, int]]] = None, + scale_factor: Optional[Union[float, Tuple[float, + float]]] = None, + keep_ratio: bool = False, + clip_object_border: bool = True, + backend: str = 'cv2', + interpolation='bilinear') -> None: + assert scale is not None or scale_factor is not None, ( + '`scale` and' + '`scale_factor` can not both be `None`') + if scale is None: + self.scale = None + else: + if isinstance(scale, int): + self.scale = (scale, scale) + else: + self.scale = scale + + self.backend = backend + self.interpolation = interpolation + self.keep_ratio = keep_ratio + self.clip_object_border = clip_object_border + if scale_factor is None: + self.scale_factor = None + elif isinstance(scale_factor, float): + self.scale_factor = (scale_factor, scale_factor) + elif isinstance(scale_factor, tuple): + assert (len(scale_factor)) == 2 + self.scale_factor = scale_factor + else: + raise TypeError( + f'expect scale_factor is float or Tuple(float), but' + f'get {type(scale_factor)}') + + def _resize_img(self, results: dict) -> None: + """Resize images with ``results['scale']``.""" + + if results.get('img', None) is not None: + res_imgs = [] + for img in results['img']: + if self.keep_ratio: + img, scale_factor = mmcv.imrescale( + img, + results['scale'], + interpolation=self.interpolation, + return_scale=True, + backend=self.backend) + # the w_scale and h_scale has minor difference + # a real fix should be done in the mmcv.imrescale in the future + new_h, new_w = img.shape[:2] + h, w = img.shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img, w_scale, h_scale = mmcv.imresize( + img, + results['scale'], + interpolation=self.interpolation, + return_scale=True, + backend=self.backend) + res_imgs.append(img) + results['img'] = res_imgs + results['img_shape'] = res_imgs[0].shape[:2] + results['scale_factor'] = (w_scale, h_scale) + results['keep_ratio'] = self.keep_ratio + + def _resize_seg(self, results: dict) -> None: + """Resize semantic segmentation map with ``results['scale']``.""" + for key in results.get('seg_fields', []): + if self.keep_ratio: + gt_seg = mmcv.imrescale( + results[key], + results['scale'], + interpolation='nearest', + backend=self.backend) + else: + gt_seg = mmcv.imresize( + results[key], + results['scale'], + interpolation='nearest', + backend=self.backend) + results[key] = gt_seg + + def transform(self, results: dict) -> dict: + """Transform function to resize images, semantic + segmentation map. + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Resized results, 'img', 'gt_seg_map', + 'scale', 'scale_factor', 'img_shape', + and 'keep_ratio' keys are updated in result dict. + """ + + if self.scale: + results['scale'] = self.scale + else: + img_shape = results['img'][0].shape[:2] + results['scale'] = _scale_size(img_shape[::-1], + self.scale_factor) # type: ignore + self._resize_img(results) + self._resize_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(scale={self.scale}, ' + repr_str += f'scale_factor={self.scale_factor}, ' + repr_str += f'keep_ratio={self.keep_ratio}, ' + repr_str += f'clip_object_border={self.clip_object_border}), ' + repr_str += f'backend={self.backend}), ' + repr_str += f'interpolation={self.interpolation})' + return repr_str + + +@TRANSFORMS.register_module() +class MultiImgRandomResize(BaseTransform): + """Random resize images. 
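Before the random variant below, a quick sketch of the deterministic `MultiImgResize` above (assumes `opencd` is importable): with `scale_factor`, both frames and every field listed in `seg_fields` are resized together.

```python
import numpy as np
from opencd.registry import TRANSFORMS

resize = TRANSFORMS.build(dict(type='MultiImgResize', scale_factor=0.5))
results = dict(
    img=[np.zeros((512, 512, 3), np.uint8)] * 2,
    gt_seg_map=np.zeros((512, 512), np.uint8),
    seg_fields=['gt_seg_map'])
out = resize(results)
print(out['img_shape'], out['gt_seg_map'].shape)  # (256, 256) (256, 256)
```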
+ How the target scale for resizing is chosen follows the rules + below: + - if ``scale`` is a sequence of tuple + .. math:: + target\\_scale[0] \\sim Uniform([scale[0][0], scale[1][0]]) + .. math:: + target\\_scale[1] \\sim Uniform([scale[0][1], scale[1][1]]) + Following the resize order of width and height in cv2, ``scale[i][0]`` + is for width, and ``scale[i][1]`` is for height. + - if ``scale`` is a tuple + .. math:: + target\\_scale[0] \\sim Uniform([ratio\\_range[0], ratio\\_range[1]]) + * scale[0] + .. math:: + target\\_scale[1] \\sim Uniform([ratio\\_range[0], ratio\\_range[1]]) + * scale[1] + Following the resize order of width and height in cv2, ``ratio_range[0]`` + is for width, and ``ratio_range[1]`` is for height. + - if ``keep_ratio`` is True, the minimum value of ``target_scale`` will be + used to set the shorter side and the maximum value will be used to + set the longer side. + - if ``keep_ratio`` is False, the value of ``target_scale`` will be used to + resize the width and height accordingly. + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + - img_shape + + Added Keys: + + - scale + - scale_factor + - keep_ratio + + Args: + scale (tuple or Sequence[tuple]): Images scales for resizing. + ratio_range (tuple[float], optional): (min_ratio, max_ratio). + Defaults to None. + resize_type (str): The type of resize class to use. Defaults to + "MultiImgResize". + **resize_kwargs: Other keyword arguments for the ``resize_type``. + Note: + By default, the ``resize_type`` is "MultiImgResize", the pair-aware + variant of :class:`mmcv.Resize` defined above. Therefore, + ``resize_kwargs`` accepts any keyword arguments of it, like + ``keep_ratio``, ``interpolation`` and so on. + If you want to use your custom resize class, the class should accept + a ``scale`` argument and have a ``scale`` attribute which determines the + resize shape. + """ + + def __init__( + self, + scale: Union[Tuple[int, int], Sequence[Tuple[int, int]]], + ratio_range: Tuple[float, float] = None, + resize_type: str = 'MultiImgResize', + **resize_kwargs, + ) -> None: + + self.scale = scale + self.ratio_range = ratio_range + + self.resize_cfg = dict(type=resize_type, **resize_kwargs) + # create an empty Resize object + self.resize = TRANSFORMS.build({'scale': 0, **self.resize_cfg}) + + @staticmethod + def _random_sample(scales: Sequence[Tuple[int, int]]) -> tuple: + """Private function to randomly sample a scale from a list of tuples. + Args: + scales (list[tuple]): Images scale range for sampling. + There must be two tuples in scales, which specify the lower + and upper bound of image scales. + Returns: + tuple: The targeted scale of the image to be resized. + """ + + assert is_list_of(scales, tuple) and len(scales) == 2 + scale_0 = [scales[0][0], scales[1][0]] + scale_1 = [scales[0][1], scales[1][1]] + edge_0 = np.random.randint(min(scale_0), max(scale_0) + 1) + edge_1 = np.random.randint(min(scale_1), max(scale_1) + 1) + scale = (edge_0, edge_1) + return scale + + @staticmethod + def _random_sample_ratio(scale: tuple, ratio_range: Tuple[float, + float]) -> tuple: + """Private function to randomly sample a scale from a tuple. + A ratio will be randomly sampled from the range specified by + ``ratio_range``. Then it would be multiplied with ``scale`` to + generate the sampled scale. + Args: + scale (tuple): Images scale base to multiply with ratio. + ratio_range (tuple[float]): The minimum and maximum ratio to scale + the ``scale``.
+ Returns: + tuple: The targeted scale of the image to be resized. + """ + + assert isinstance(scale, tuple) and len(scale) == 2 + min_ratio, max_ratio = ratio_range + assert min_ratio <= max_ratio + ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio + scale = int(scale[0] * ratio), int(scale[1] * ratio) + return scale + + @cache_randomness + def _random_scale(self) -> tuple: + """Private function to randomly sample an scale according to the type + of ``scale``. + Returns: + tuple: The targeted scale of the image to be resized. + """ + + if is_tuple_of(self.scale, int): + assert self.ratio_range is not None and len(self.ratio_range) == 2 + scale = self._random_sample_ratio( + self.scale, # type: ignore + self.ratio_range) + elif is_seq_of(self.scale, tuple): + scale = self._random_sample(self.scale) # type: ignore + else: + raise NotImplementedError('Do not support sampling function ' + f'for "{self.scale}"') + + return scale + + def transform(self, results: dict) -> dict: + """Transform function to resize images, bounding boxes, semantic + segmentation map. + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Resized results, ``img``, ``gt_semantic_seg``, + ``scale``, ``scale_factor``, ``img_shape``, and + ``keep_ratio`` keys are updated in result dict. + """ + results['scale'] = self._random_scale() + self.resize.scale = results['scale'] + results = self.resize(results) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(scale={self.scale}, ' + repr_str += f'ratio_range={self.ratio_range}, ' + repr_str += f'resize_cfg={self.resize_cfg})' + return repr_str + + +@TRANSFORMS.register_module() +class MultiImgNormalize(BaseTransform): + """Normalize the images. + Required Keys: + + - img + + Modified Keys: + + - img + + Added Keys: + + - img_norm_cfg + - mean + - std + - to_rgb + + Args: + mean (sequence): Mean values of 3 channels. + std (sequence): Std values of 3 channels. + to_rgb (bool): Whether to convert the image from BGR to RGB before + normlizing the image. If ``to_rgb=True``, the order of mean and std + should be RGB. If ``to_rgb=False``, the order of mean and std + should be the same order of the image. Defaults to True. + """ + + def __init__(self, + mean: Sequence[Union[int, float]], + std: Sequence[Union[int, float]], + to_rgb: bool = True) -> None: + self.mean = np.array(mean, dtype=np.float32) + self.std = np.array(std, dtype=np.float32) + self.to_rgb = to_rgb + + def transform(self, results: dict) -> dict: + """Function to normalize images. + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Normalized results, key 'img_norm_cfg' key is added in to + result dict. + """ + + results['img'] = [ + mmcv.imnormalize(img, self.mean, self.std, self.to_rgb) + for img in results['img']] + results['img_norm_cfg'] = dict( + mean=self.mean, std=self.std, to_rgb=self.to_rgb) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})' + return repr_str + + +@TRANSFORMS.register_module() +class MultiImgRandomFlip(BaseTransform): + """Flip the image & segmentation map. Added or Updated + keys: flip, flip_direction, img, gt_bboxes, gt_seg_map, and + gt_keypoints. There are 3 flip modes: + + - ``prob`` is float, ``direction`` is string: the image will be + ``direction``ly flipped with probability of ``prob`` . 
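Taken together, these transforms compose like any other MMEngine pipeline. A hedged, config-style sketch follows; the values are illustrative rather than taken from this repo's configs, and `MultiImgRandomFlip` and `MultiImgPad` are the classes defined further down in this file:

```python
# Illustrative pipeline: each transform reads and updates the shared
# `results` dict produced by the loading step.
train_pipeline = [
    dict(type='MultiImgRandomResize',
         scale=(1024, 1024), ratio_range=(0.8, 1.2)),
    dict(type='MultiImgRandomFlip', prob=0.5, direction='horizontal'),
    dict(type='MultiImgNormalize',
         mean=[123.675, 116.28, 103.53],
         std=[58.395, 57.12, 57.375],
         to_rgb=True),
    dict(type='MultiImgPad', size_divisor=32),
]
```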
+ E.g., ``prob=0.5``, ``direction='horizontal'``, + then image will be horizontally flipped with probability of 0.5. + + - ``prob`` is float, ``direction`` is list of string: the image will + be ``direction[i]``ly flipped with probability of + ``prob/len(direction)``. + E.g., ``prob=0.5``, ``direction=['horizontal', 'vertical']``, + then image will be horizontally flipped with probability of 0.25, + vertically with probability of 0.25. + + - ``prob`` is list of float, ``direction`` is list of string: + given ``len(prob) == len(direction)``, the image will + be ``direction[i]``ly flipped with probability of ``prob[i]``. + E.g., ``prob=[0.3, 0.5]``, ``direction=['horizontal', + 'vertical']``, then image will be horizontally flipped with + probability of 0.3, vertically with probability of 0.5. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + + Added Keys: + + - flip + - flip_direction + + Args: + prob (float | list[float], optional): The flipping probability. + Defaults to None. + direction(str | list[str]): The flipping direction. Options + If input is a list, the length must equal ``prob``. Each + element in ``prob`` indicates the flip probability of + corresponding direction. Defaults to 'horizontal'. + """ + + def __init__(self, + prob: Optional[Union[float, Iterable[float]]] = None, + direction: Union[str, Sequence[Optional[str]]] = 'horizontal') -> None: + + if isinstance(prob, list): + assert is_list_of(prob, float) + assert 0 <= sum(prob) <= 1 + elif isinstance(prob, float): + assert 0 <= prob <= 1 + else: + raise ValueError(f'probs must be float or list of float, but \ + got `{type(prob)}`.') + self.prob = prob + + valid_directions = ['horizontal', 'vertical', 'diagonal'] + if isinstance(direction, str): + assert direction in valid_directions + elif isinstance(direction, list): + assert is_list_of(direction, str) + assert set(direction).issubset(set(valid_directions)) + else: + raise ValueError(f'direction must be either str or list of str, \ + but got `{type(direction)}`.') + self.direction = direction + + if isinstance(prob, list): + assert len(prob) == len(self.direction) + + @cache_randomness + def _choose_direction(self) -> str: + """Choose the flip direction according to `prob` and `direction`""" + if isinstance(self.direction, + Sequence) and not isinstance(self.direction, str): + # None means non-flip + direction_list: list = list(self.direction) + [None] + elif isinstance(self.direction, str): + # None means non-flip + direction_list = [self.direction, None] + + if isinstance(self.prob, list): + non_prob: float = 1 - sum(self.prob) + prob_list = self.prob + [non_prob] + elif isinstance(self.prob, float): + non_prob = 1. - self.prob + # exclude non-flip + single_ratio = self.prob / (len(direction_list) - 1) + prob_list = [single_ratio] * (len(direction_list) - 1) + [non_prob] + + cur_dir = np.random.choice(direction_list, p=prob_list) + + return cur_dir + + def transform(self, results: dict) -> dict: + """Transform function to flip images, semantic + segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Flipped results, 'img', 'gt_seg_map', + 'flip', and 'flip_direction' keys are + updated in result dict. 
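The probability bookkeeping behind these three modes is easy to get wrong, so here is the arithmetic of `_choose_direction` in isolation: a single float `prob` is split evenly across the listed directions, and `None` stands for "no flip":

```python
# Direction/probability split used when `prob` is a single float.
import numpy as np

prob, directions = 0.5, ['horizontal', 'vertical']
direction_list = directions + [None]            # None means non-flip
single = prob / len(directions)                 # 0.25 per direction
prob_list = [single] * len(directions) + [1 - prob]
cur_dir = np.random.choice(direction_list, p=prob_list)
print(cur_dir)  # 'horizontal' (p=0.25), 'vertical' (p=0.25) or None (p=0.5)
```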
+ """ + + cur_dir = self._choose_direction() + if cur_dir is None: + results['flip'] = False + results['flip_direction'] = None + else: + results['flip'] = True + results['flip_direction'] = cur_dir + + # flip image + results['img'] = [ + mmcv.imflip(img, direction=results['flip_direction']) + for img in results['img'] + ] + + # flip segs + for key in results.get('seg_fields', []): + # use copy() to make numpy stride positive + results[key] = mmcv.imflip( + results[key], direction=results['flip_direction']).copy() + return results + + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'direction={self.direction})' + + return repr_str + + +@TRANSFORMS.register_module() +class MultiImgPad(BaseTransform): + """Pad the image & segmentation map. + + There are three padding modes: (1) pad to a fixed size and (2) pad to the + minimum size that is divisible by some number. and (3)pad to square. Also, + pad to square and pad to the minimum size can be used as the same time. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + - img_shape + + Added Keys: + + - pad_shape + - pad_fixed_size + - pad_size_divisor + + Args: + size (tuple, optional): Fixed padding size. + Expected padding shape (w, h). Defaults to None. + size_divisor (int, optional): The divisor of padded size. Defaults to + None. + pad_to_square (bool): Whether to pad the image into a square. + Currently only used for YOLOX. Defaults to False. + pad_val (Number | dict[str, Number], optional): Padding value for if + the pad_mode is "constant". If it is a single number, the value + to pad the image is the number and to pad the semantic + segmentation map is 255. If it is a dict, it should have the + following keys: + + - img: The value to pad the image. + - seg: The value to pad the semantic segmentation map. + + Defaults to dict(img=0, seg=255). + padding_mode (str): Type of padding. Should be: constant, edge, + reflect or symmetric. Defaults to 'constant'. + + - constant: pads with a constant value, this value is specified + with pad_val. + - edge: pads with the last value at the edge of the image. + - reflect: pads with reflection of image without repeating the last + value on the edge. For example, padding [1, 2, 3, 4] with 2 + elements on both sides in reflect mode will result in + [3, 2, 1, 2, 3, 4, 3, 2]. + - symmetric: pads with reflection of image repeating the last value + on the edge. 
+          For example, padding [1, 2, 3, 4] with 2 elements on
+          both sides in symmetric mode will result in
+          [2, 1, 1, 2, 3, 4, 4, 3].
+    """
+
+    def __init__(self,
+                 size: Optional[Tuple[int, int]] = None,
+                 size_divisor: Optional[int] = None,
+                 pad_to_square: bool = False,
+                 pad_val: Union[int, float, dict] = dict(img=0, seg=255),
+                 padding_mode: str = 'constant') -> None:
+        self.size = size
+        self.size_divisor = size_divisor
+        if isinstance(pad_val, int):
+            pad_val = dict(img=pad_val, seg=255)
+        assert isinstance(pad_val, dict), \
+            f'pad_val must be an int or a dict, but got {type(pad_val)}'
+        self.pad_val = pad_val
+        self.pad_to_square = pad_to_square
+
+        if pad_to_square:
+            assert size is None, \
+                'The size and size_divisor must be None ' \
+                'when pad_to_square is True'
+        else:
+            assert size is not None or size_divisor is not None, \
+                'only one of size and size_divisor should be valid'
+            assert size is None or size_divisor is None
+        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
+        self.padding_mode = padding_mode
+
+    def _pad_img(self, results: dict) -> None:
+        """Pad images according to ``self.size``."""
+        pad_val = self.pad_val.get('img', 0)
+
+        size = None
+        if self.pad_to_square:
+            max_size = max(results['img'][0].shape[:2])
+            size = (max_size, max_size)
+        if self.size_divisor is not None:
+            if size is None:
+                # note: index the first image of the list; `results['img']`
+                # itself is a list and has no `.shape`.
+                size = (results['img'][0].shape[0],
+                        results['img'][0].shape[1])
+            pad_h = int(np.ceil(
+                size[0] / self.size_divisor)) * self.size_divisor
+            pad_w = int(np.ceil(
+                size[1] / self.size_divisor)) * self.size_divisor
+            size = (pad_h, pad_w)
+        elif self.size is not None:
+            size = self.size[::-1]
+        if isinstance(pad_val, int) and results['img'][0].ndim == 3:
+            pad_val = tuple(pad_val for _ in range(results['img'][0].shape[2]))
+
+        padded_imgs = [
+            mmcv.impad(
+                img,
+                shape=size,
+                pad_val=pad_val,
+                padding_mode=self.padding_mode) for img in results['img']]
+
+        results['img'] = padded_imgs
+        results['pad_shape'] = padded_imgs[0].shape
+        results['pad_fixed_size'] = self.size
+        results['pad_size_divisor'] = self.size_divisor
+        results['img_shape'] = padded_imgs[0].shape[:2]
+
+    def _pad_seg(self, results: dict) -> None:
+        """Pad semantic segmentation map according to
+        ``results['pad_shape']``."""
+        pad_val = self.pad_val.get('seg', 255)
+        for key in results.get('seg_fields', []):
+            results[key] = mmcv.impad(
+                results[key],
+                shape=results['pad_shape'][:2],
+                pad_val=pad_val,
+                padding_mode=self.padding_mode)
+
+    def transform(self, results: dict) -> dict:
+        """Call function to pad images and semantic segmentation maps.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Updated result dict.
+        """
+        self._pad_img(results)
+        self._pad_seg(results)
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(size={self.size}, '
+        repr_str += f'size_divisor={self.size_divisor}, '
+        repr_str += f'pad_to_square={self.pad_to_square}, '
+        repr_str += f'pad_val={self.pad_val}, '
+        repr_str += f'padding_mode={self.padding_mode})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class MultiImgAlbu(BaseTransform):
+    """Albumentation augmentation. Adds custom transformations from the
+    Albumentations library. Please visit
+    `https://albumentations.readthedocs.io` to get more information. An
+    example of ``transforms`` is as follows:
+    ..
code-block:: + [ + dict( + type='ShiftScaleRotate', + shift_limit=0.0625, + scale_limit=0.0, + rotate_limit=0, + interpolation=1, + p=0.5), + dict( + type='RandomBrightnessContrast', + brightness_limit=[0.1, 0.3], + contrast_limit=[0.1, 0.3], + p=0.2), + dict(type='ChannelShuffle', p=0.1), + dict( + type='OneOf', + transforms=[ + dict(type='Blur', blur_limit=3, p=1.0), + dict(type='MedianBlur', blur_limit=3, p=1.0) + ], + p=0.1), + ] + Args: + transforms (list[dict]): A list of albu transformations + keymap (dict): Contains {'input key':'albumentation-style key'} + update_pad_shape (bool): Whether update final shape. + additional_targets (dict): Dict with keys - new target name, + values - old target name. ex: {'image2': 'image'}. + """ + def __init__(self, + transforms: List[dict], + keymap: dict = None, + update_pad_shape: bool = False, + additional_targets: dict = None) -> None: + # Args will be modified later, copying it will be safer + transforms = copy.deepcopy(transforms) + if keymap is not None: + keymap = copy.deepcopy(keymap) + self.transforms = transforms + self.filter_lost_elements = False + self.update_pad_shape = update_pad_shape + self.additional_targets = additional_targets + + self.aug = Compose([self.albu_builder(t) for t in self.transforms], \ + additional_targets=self.additional_targets) + + if not keymap: + self.keymap_to_albu = {'img': 'image', 'gt_semantic_seg': 'mask'} + else: + self.keymap_to_albu = keymap + self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()} + + def albu_builder(self, cfg): + """Import a module from albumentations. + + It inherits some of :func:`build_from_cfg` logic. + Args: + cfg (dict): Config dict. It should at least contain the key "type". + Returns: + obj: The constructed object. + """ + + assert isinstance(cfg, dict) and 'type' in cfg + args = cfg.copy() + + obj_type = args.pop('type') + if is_str(obj_type): + obj_cls = getattr(albumentations, obj_type) + else: + raise TypeError(f'type must be str, but got {type(obj_type)}') + + if 'transforms' in args: + args['transforms'] = [ + self.albu_builder(transform) + for transform in args['transforms'] + ] + + return obj_cls(**args) + + @staticmethod + def mapper(d: dict, keymap: dict) -> dict: + """Dictionary mapper. + + Renames keys according to keymap provided. + Args: + d (dict): old dict + keymap (dict): {'old_key':'new_key'} + Returns: + dict: new dict. 
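The reason `mapper` (completed just below) concatenates the two temporal images on the way in and splits them on the way out is that albumentations then sees a single H x W x 6 "image", which guarantees both frames receive identical geometric transforms. A standalone sketch of the trick, with no albumentations required:

```python
# Concatenate-then-split trick for bi-temporal augmentation.
import numpy as np

img_t1 = np.zeros((256, 256, 3), dtype=np.uint8)
img_t2 = np.ones((256, 256, 3), dtype=np.uint8)

stacked = np.concatenate([img_t1, img_t2], axis=-1)   # (256, 256, 6)
# ... albumentations transforms would run here on `stacked` ...
a, b = np.split(stacked, indices_or_sections=2, axis=-1)
assert a.shape == b.shape == (256, 256, 3)
```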
+ """ + + updated_dict = {} + for k, v in zip(d.keys(), d.values()): + new_k = keymap.get(k, k) + updated_dict[new_k] = d[k] + + if updated_dict.get('image', None) is not None: + updated_dict['image'] = np.concatenate(updated_dict['image'], axis=-1) + if updated_dict.get('img', None) is not None: + updated_dict['img'] = np.split(updated_dict['img'], indices_or_sections=2, axis=-1) + return updated_dict + + def transform(self, results: dict) -> dict: + # dict to albumentations format + results = self.mapper(results, self.keymap_to_albu) + + results = self.aug(**results) + # back to the original format + results = self.mapper(results, self.keymap_back) + + # update final shape + if self.update_pad_shape: + results['pad_shape'] = results['img'][0].shape + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(transforms={self.transforms}, ' + repr_str += f'(update_pad_shape={self.update_pad_shape}, ' + repr_str += f'(additional_targets={self.additional_targets})' + return repr_str \ No newline at end of file diff --git a/opencd/datasets/whu_cd.py b/opencd/datasets/whu_cd.py new file mode 100644 index 0000000000000000000000000000000000000000..3e12a9eab14e8b1e7cd09f6e6004ef57fc1e12d0 --- /dev/null +++ b/opencd/datasets/whu_cd.py @@ -0,0 +1,22 @@ +# Copyright (c) Open-CD. All rights reserved. +from opencd.registry import DATASETS +from .basecddataset import _BaseCDDataset + + +@DATASETS.register_module() +class WHU_CD_Dataset(_BaseCDDataset): + """WHU-CD dataset""" + METAINFO = dict( + classes=('unchanged', 'changed'), + palette=[[0, 0, 0], [255, 255, 255]]) + + def __init__(self, + img_suffix='', + seg_map_suffix='', + format_seg_map='to_binary', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + format_seg_map=format_seg_map, + **kwargs) diff --git a/opencd/engine/__init__.py b/opencd/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7621b0f4f43cec6b562a3dbf75968521f11b7bb5 --- /dev/null +++ b/opencd/engine/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Open-CD. All rights reserved. +from .hooks import CDVisualizationHook + +__all__ = ['CDVisualizationHook'] diff --git a/opencd/engine/hooks/__init__.py b/opencd/engine/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ce7416597d4678a9825d9588bd76f2cbef5f255f --- /dev/null +++ b/opencd/engine/hooks/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Open-CD. All rights reserved. +from .visualization_hook import CDVisualizationHook + +__all__ = ['CDVisualizationHook'] diff --git a/opencd/engine/hooks/visualization_hook.py b/opencd/engine/hooks/visualization_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..a19962243fd8050338ed0d63b8149cfa4122cdeb --- /dev/null +++ b/opencd/engine/hooks/visualization_hook.py @@ -0,0 +1,114 @@ +# Copyright (c) Open-CD. All rights reserved. +import os.path as osp +import warnings +from typing import Optional, Sequence + +import mmcv +import mmengine.fileio as fileio +import numpy as np +from mmengine.runner import Runner + +from mmseg.engine import SegVisualizationHook +from mmseg.structures import SegDataSample +from opencd.registry import HOOKS +from opencd.visualization import CDLocalVisualizer + + +@HOOKS.register_module() +class CDVisualizationHook(SegVisualizationHook): + """Change Detection Visualization Hook. Used to visualize validation and + testing process prediction results. 
+ + Args: + img_shape (tuple): if img_shape is given and `draw_on_from_to_img` is + False, the original images will not be read. + draw_on_from_to_img (bool): whether to draw semantic prediction results + on the original images. If it is False, it means that drawing on + the black board. Defaults to False. + + """ + def __init__(self, + img_shape: tuple = None, + draw_on_from_to_img: bool = False, + draw: bool = False, + interval: int = 50, + show: bool = False, + wait_time: float = 0., + backend_args: Optional[dict] = None): + self.img_shape = img_shape + self.draw_on_from_to_img = draw_on_from_to_img + if self.draw_on_from_to_img: + warnings.warn('`draw_on_from_to_img` works only in ' + 'semantic change detection.') + self._visualizer: CDLocalVisualizer = \ + CDLocalVisualizer.get_current_instance() + self.interval = interval + self.show = show + if self.show: + # No need to think about vis backends. + self._visualizer._vis_backends = {} + warnings.warn('The show is True, it means that only ' + 'the prediction results are visualized ' + 'without storing data, so vis_backends ' + 'needs to be excluded.') + + self.wait_time = wait_time + self.backend_args = backend_args.copy() if backend_args else None + self.draw = draw + if not self.draw: + warnings.warn('The draw is False, it means that the ' + 'hook for visualization will not take ' + 'effect. The results will NOT be ' + 'visualized or stored.') + + def _after_iter(self, + runner: Runner, + batch_idx: int, + data_batch: dict, + outputs: Sequence[SegDataSample], + mode: str = 'val') -> None: + """Run after every ``self.interval`` validation iterations. + + Args: + runner (:obj:`Runner`): The runner of the validation process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`SegDataSample`]): Outputs from model. + mode (str): mode (str): Current mode of runner. Defaults to 'val'. + """ + if self.draw is False or mode == 'train': + return + + if self.every_n_inner_iters(batch_idx, self.interval): + + for output in outputs: + img_path = output.img_path[0] + img_from_to = [] + window_name = osp.basename(img_path).split('.')[0] + if self.img_shape is not None: + assert len(self.img_shape) == 3, \ + '`img_shape` should be (H, W, C)' + else: + img_bytes = fileio.get( + img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + self.img_shape = img.shape + + if self.draw_on_from_to_img: + # for semantic change detection + for _img_path in output.img_path: + _img_bytes = fileio.get( + _img_path, backend_args=self.backend_args) + _img = mmcv.imfrombytes(_img_bytes, channel_order='rgb') + img_from_to.append(_img) + + img = np.zeros(self.img_shape) + self._visualizer.add_datasample( + window_name, + img, + img_from_to, + data_sample=output, + show=self.show, + wait_time=self.wait_time, + step=runner.iter, + draw_gt=False) diff --git a/opencd/evaluation/__init__.py b/opencd/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4fba3caf94e1d8e3e31aec9b7757d18a4be1c440 --- /dev/null +++ b/opencd/evaluation/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Open-CD. All rights reserved. 
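A hedged sketch of enabling `CDVisualizationHook` in a training config. With `img_shape` given, predictions are drawn onto a blank canvas of that shape instead of re-reading the source images from disk:

```python
# Illustrative hook config; values are examples, not prescriptions.
default_hooks = dict(
    visualization=dict(
        type='CDVisualizationHook',
        draw=True,                    # the hook is a no-op when draw=False
        interval=1,
        img_shape=(1024, 1024, 3)))
```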
+from .metrics import SCDMetric + +__all__ = ['SCDMetric'] diff --git a/opencd/evaluation/metrics/__init__.py b/opencd/evaluation/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..707e1f447fc396d1b74e63cdae3894ba5b46013d --- /dev/null +++ b/opencd/evaluation/metrics/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Open-CD. All rights reserved. +from .scd_metric import SCDMetric + +__all__ = ['SCDMetric'] diff --git a/opencd/evaluation/metrics/scd_metric.py b/opencd/evaluation/metrics/scd_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..bbe6f0e7b41f33f4940ea45d9ebd3db075f0abd3 --- /dev/null +++ b/opencd/evaluation/metrics/scd_metric.py @@ -0,0 +1,281 @@ +# Copyright (c) Open-CD. All rights reserved. +import copy +import logging +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Sequence, Union + +import numpy as np +import torch +from mmengine.dist import (broadcast_object_list, collect_results, + is_main_process) +from mmengine.evaluator.metric import _to_cpu +from mmengine.logging import MMLogger, print_log +from prettytable import PrettyTable + +from mmseg.evaluation import IoUMetric +from opencd.registry import METRICS + + +@METRICS.register_module() +class SCDMetric(IoUMetric): + """Change Detection evaluation metric. + + Args: + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to 'binary'. + semantic_prefix (str, optional): The prefix that will be added in the + metric names to disambiguate homonymous metrics of different + evaluators. Defaults to 'semantic'. + cal_sek bool: Whether to calculate the separated kappa (SeK) + coefficient. Defaults: False. + """ + + def __init__(self, + prefix: Optional[str] = 'binary', + semantic_prefix: Optional[str] = 'semantic', + cal_sek: bool = False, + **kwargs) -> None: + super().__init__(prefix=prefix, **kwargs) + + self.semantic_results: List[Any] = [] + self.semantic_prefix = semantic_prefix + self.cal_sek = cal_sek + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data and data_samples. + + The processed results should be stored in ``self.results``, which will + be used to compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. 
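`process` (below) stores one 4-tuple of per-class histograms per image for the binary change map, plus two more for the semantic "from"/"to" maps. To make those statistics concrete, here is a simplified re-implementation of `IoUMetric.intersect_and_union`; the real one lives in mmseg and handles more edge cases:

```python
# Simplified sketch of the per-image statistic accumulated by process().
import torch

def intersect_and_union(pred, label, num_classes, ignore_index=255):
    mask = label != ignore_index
    pred, label = pred[mask], label[mask]
    intersect = pred[pred == label]
    area_intersect = torch.histc(
        intersect.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_pred = torch.histc(
        pred.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_label = torch.histc(
        label.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_union = area_pred + area_label - area_intersect
    return area_intersect, area_union, area_pred, area_label

pred = torch.tensor([0, 1, 1, 0])
label = torch.tensor([0, 1, 0, 0])
print(intersect_and_union(pred, label, num_classes=2))
```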
+ """ + num_classes = len(self.dataset_meta['classes']) + num_semantic_classes = len(self.dataset_meta['semantic_classes']) + for data_sample in data_samples: + pred_label = data_sample['pred_sem_seg']['data'].squeeze() + label = data_sample['gt_sem_seg']['data'].squeeze().to(pred_label) + pred_label_from = data_sample['pred_sem_seg_from']['data'].squeeze() + label_from = data_sample['gt_sem_seg_from']['data'].squeeze().to(pred_label_from) + pred_label_to = data_sample['pred_sem_seg_to']['data'].squeeze() + label_to = data_sample['gt_sem_seg_to']['data'].squeeze().to(pred_label_to) + + self.results.append( + self.intersect_and_union(pred_label, label, num_classes, + self.ignore_index)) + # for semantic pred + self.semantic_results.append( + self.intersect_and_union(pred_label_from, label_from, num_semantic_classes, + self.ignore_index)) + self.semantic_results.append( + self.intersect_and_union(pred_label_to, label_to, num_semantic_classes, + self.ignore_index)) + + def get_sek(self, results: list) -> np.array: + """calculate the Sek value. + + Args: + pre_eval_results (list[tuple[torch.Tensor]]): per image eval results + for computing evaluation metric + + Returns: + [torch.tensor]: The Sek value. + """ + assert len(results) == 4 + + hist_00 = sum(results[0])[0] + + hist_00_list = torch.zeros(len(results[0][0])) + hist_00_list[0] = hist_00 + + total_area_intersect = sum(results[0]) - hist_00_list + total_area_pred_label = sum(results[2]) - hist_00_list + total_area_label = sum(results[3]) - hist_00_list + + # foreground + fg_intersect_sum = total_area_label[1:].sum( + ) - total_area_pred_label[0] + fg_area_union_sum = total_area_label.sum() + + po = total_area_intersect.sum() / total_area_label.sum() + pe = (total_area_label * total_area_pred_label).sum() / \ + total_area_pred_label.sum() ** 2 + + kappa0 = (po - pe) / (1 - pe) + # the `iou_fg` is equal to the binary `changed` iou. + iou_fg = fg_intersect_sum / fg_area_union_sum + sek = (kappa0 * torch.exp(iou_fg)) / torch.e + + return sek.numpy() # consistent with other metrics. + + def compute_metrics(self, binary_results: list, semantic_results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + binary_results (list): The processed results of each batch. + semantic_results (list): The semantic results of each batch + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. The key + mainly includes aAcc, mIoU, mAcc, mDice, mFscore, mPrecision, + mRecall. + """ + logger: MMLogger = MMLogger.get_current_instance() + + # convert list of tuples to tuple of lists, e.g. 
+ # [(A_1, B_1, C_1, D_1), ..., (A_n, B_n, C_n, D_n)] to + # ([A_1, ..., A_n], ..., [D_1, ..., D_n]) + binary_results = tuple(zip(*binary_results)) + semantic_results = tuple(zip(*semantic_results)) + assert len(binary_results) == 4 and len(semantic_results) == 4 + + # for binary results + binary_total_area_intersect = sum(binary_results[0]) + binary_total_area_union = sum(binary_results[1]) + binary_total_area_pred_label = sum(binary_results[2]) + binary_total_area_label = sum(binary_results[3]) + binary_ret_metrics = self.total_area_to_metrics( + binary_total_area_intersect, binary_total_area_union, binary_total_area_pred_label, + binary_total_area_label, self.metrics, self.nan_to_num, self.beta) + + binary_class_names = self.dataset_meta['classes'] + + # summary table + binary_ret_metrics_summary = OrderedDict({ + ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2) + for ret_metric, ret_metric_value in binary_ret_metrics.items() + }) + binary_metrics = dict() + for key, val in binary_ret_metrics_summary.items(): + if key == 'aAcc': + binary_metrics[key] = val + else: + binary_metrics['m' + key] = val + + # each class table + binary_ret_metrics.pop('aAcc', None) + binary_ret_metrics_class = OrderedDict({ + ret_metric: np.round(ret_metric_value * 100, 2) + for ret_metric, ret_metric_value in binary_ret_metrics.items() + }) + binary_ret_metrics_class.update({'Class': binary_class_names}) + binary_ret_metrics_class.move_to_end('Class', last=False) + binary_class_table_data = PrettyTable() + for key, val in binary_ret_metrics_class.items(): + binary_class_table_data.add_column(key, val) + + print_log('per binary class results:', logger) + print_log('\n' + binary_class_table_data.get_string(), logger=logger) + + # for semantic results + semantic_total_area_intersect = sum(semantic_results[0]) + semantic_total_area_union = sum(semantic_results[1]) + semantic_total_area_pred_label = sum(semantic_results[2]) + semantic_total_area_label = sum(semantic_results[3]) + semantic_ret_metrics = self.total_area_to_metrics( + semantic_total_area_intersect, semantic_total_area_union, semantic_total_area_pred_label, + semantic_total_area_label, self.metrics, self.nan_to_num, self.beta) + + semantic_class_names = self.dataset_meta['semantic_classes'] + + # summary table + semantic_ret_metrics_summary = OrderedDict({ + ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2) + for ret_metric, ret_metric_value in semantic_ret_metrics.items() + }) + # for semantic change detection + if self.cal_sek: + sek = self.get_sek(semantic_results) + semantic_ret_metrics_summary.update({'Sek': np.round(sek * 100, 2)}) + semantic_ret_metrics_summary.update({'SCD_Score': \ + np.round(0.3 * binary_ret_metrics_summary['IoU'] + 0.7 * sek * 100, 2)}) + + semantic_metrics = dict() + for key, val in semantic_ret_metrics_summary.items(): + if key in ['aAcc', 'Sek', 'SCD_Score']: + semantic_metrics[key] = val + else: + semantic_metrics['m' + key] = val + + # each class table + semantic_ret_metrics.pop('aAcc', None) + semantic_ret_metrics_class = OrderedDict({ + ret_metric: np.round(ret_metric_value * 100, 2) + for ret_metric, ret_metric_value in semantic_ret_metrics.items() + }) + semantic_ret_metrics_class.update({'Class': semantic_class_names}) + semantic_ret_metrics_class.move_to_end('Class', last=False) + semantic_class_table_data = PrettyTable() + for key, val in semantic_ret_metrics_class.items(): + semantic_class_table_data.add_column(key, val) + + print_log('per semantic class results:', logger) + print_log('\n' 
+ semantic_class_table_data.get_string(), logger=logger) + + return binary_metrics, semantic_metrics + + def evaluate(self, size: int) -> dict: + """Evaluate the model performance of the whole dataset after processing + all batches. + + Args: + size (int): Length of the entire validation dataset. When batch + size > 1, the dataloader may pad some data samples to make + sure all ranks have the same length of dataset slice. The + ``collect_results`` function will drop the padded data based on + this size. + + Returns: + dict: Evaluation metrics dict on the val dataset. The keys are the + names of the metrics, and the values are corresponding results. + """ + if len(self.results) == 0: + print_log( + f'{self.__class__.__name__} got empty `self.results`. Please ' + 'ensure that the processed results are properly added into ' + '`self.results` in `process` method.', + logger='current', + level=logging.WARNING) + if len(self.semantic_results) == 0: + print_log( + f'{self.__class__.__name__} got empty `self.semantic_results`. ' + 'Please ensure that the processed results are properly added ' + 'into `self.semantic_results` in `process` method.', + logger='current', + level=logging.WARNING) + + binary_results = collect_results(self.results, size, self.collect_device) + semantic_results = collect_results(self.semantic_results, \ + size * 2, self.collect_device) + + if is_main_process(): + # cast all tensors in results list to cpu + binary_results = _to_cpu(binary_results) + semantic_results = _to_cpu(semantic_results) + _binary_metrics, _semantic_metrics = \ + self.compute_metrics(binary_results, semantic_results) # type: ignore + # Add prefix to metric names + if self.prefix: + _binary_metrics = { + '/'.join((self.prefix, k)): v + for k, v in _binary_metrics.items() + } + _semantic_metrics = { + '/'.join((self.semantic_prefix, k)): v + for k, v in _semantic_metrics.items() + } + _metrics = {**_binary_metrics, **_semantic_metrics} + metrics = [_metrics] + else: + metrics = [None] # type: ignore + + broadcast_object_list(metrics) + + # reset the results list + self.results.clear() + self.semantic_results.clear() + return metrics[0] diff --git a/opencd/models/.DS_Store b/opencd/models/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..4ad0e586ef922813c08ea93e9315859589d31f0d Binary files /dev/null and b/opencd/models/.DS_Store differ diff --git a/opencd/models/__init__.py b/opencd/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d01125e722524f2849ade5e84763ee464ddd18e7 --- /dev/null +++ b/opencd/models/__init__.py @@ -0,0 +1,7 @@ +from .backbones import * +from .change_detectors import * +from .data_preprocessor import * +from .decode_heads import * +from .losses import * +from .necks import * +from .utils import * diff --git a/opencd/models/__pycache__/__init__.cpython-311.pyc b/opencd/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c0e3ae98b3c462b43f7b8870de8d9fbf59187a7 Binary files /dev/null and b/opencd/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/opencd/models/__pycache__/data_preprocessor.cpython-311.pyc b/opencd/models/__pycache__/data_preprocessor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3454ab9df0b49bb18dd57b67a47e49120b7a06a6 Binary files /dev/null and b/opencd/models/__pycache__/data_preprocessor.cpython-311.pyc differ diff --git a/opencd/models/backbones/__init__.py 
b/opencd/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..45f99fbfbdf4bd7853569bfbf69e3f62aeffcbb8 --- /dev/null +++ b/opencd/models/backbones/__init__.py @@ -0,0 +1,13 @@ +from .fcsn import FC_EF, FC_Siam_conc, FC_Siam_diff +from .ifn import IFN +from .interaction_resnest import IA_ResNeSt +from .interaction_resnet import IA_ResNetV1c +from .interaction_mit import IA_MixVisionTransformer +from .snunet import SNUNet_ECAM +from .tinycd import TinyCD +from .tinynet import TinyNet +from .hanet import HAN + +__all__ = ['IA_ResNetV1c', 'IA_ResNeSt', 'FC_EF', 'FC_Siam_diff', + 'FC_Siam_conc', 'SNUNet_ECAM', 'TinyCD', 'IFN', + 'TinyNet', 'IA_MixVisionTransformer', 'HAN'] \ No newline at end of file diff --git a/opencd/models/backbones/__pycache__/__init__.cpython-311.pyc b/opencd/models/backbones/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21a6c1aadec36224713a68b4a0ee563a4385ac07 Binary files /dev/null and b/opencd/models/backbones/__pycache__/__init__.cpython-311.pyc differ diff --git a/opencd/models/backbones/__pycache__/fcsn.cpython-311.pyc b/opencd/models/backbones/__pycache__/fcsn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..854be621c584b8c711ff1ea5f4276a590f04f611 Binary files /dev/null and b/opencd/models/backbones/__pycache__/fcsn.cpython-311.pyc differ diff --git a/opencd/models/backbones/__pycache__/hanet.cpython-311.pyc b/opencd/models/backbones/__pycache__/hanet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdd4183d36f4e12dfc51deb27b6245db25c01d40 Binary files /dev/null and b/opencd/models/backbones/__pycache__/hanet.cpython-311.pyc differ diff --git a/opencd/models/backbones/__pycache__/ifn.cpython-311.pyc b/opencd/models/backbones/__pycache__/ifn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9115ecebc8f6c643a7c3e7bac3a2872c689f134d Binary files /dev/null and b/opencd/models/backbones/__pycache__/ifn.cpython-311.pyc differ diff --git a/opencd/models/backbones/__pycache__/interaction_mit.cpython-311.pyc b/opencd/models/backbones/__pycache__/interaction_mit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7430c47a70e5d04c17f90fd64b29db526a409151 Binary files /dev/null and b/opencd/models/backbones/__pycache__/interaction_mit.cpython-311.pyc differ diff --git a/opencd/models/backbones/__pycache__/interaction_resnest.cpython-311.pyc b/opencd/models/backbones/__pycache__/interaction_resnest.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a568ebc90297313880104489f48e3974b67b833 Binary files /dev/null and b/opencd/models/backbones/__pycache__/interaction_resnest.cpython-311.pyc differ diff --git a/opencd/models/backbones/__pycache__/interaction_resnet.cpython-311.pyc b/opencd/models/backbones/__pycache__/interaction_resnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3138c3b8e87b487e2940b7f7d60e89ce8fb41f5 Binary files /dev/null and b/opencd/models/backbones/__pycache__/interaction_resnet.cpython-311.pyc differ diff --git a/opencd/models/backbones/__pycache__/snunet.cpython-311.pyc b/opencd/models/backbones/__pycache__/snunet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a10a53a2aad8488d730316f0d02340c2a627cf1 Binary files /dev/null and b/opencd/models/backbones/__pycache__/snunet.cpython-311.pyc differ diff --git 
a/opencd/models/backbones/__pycache__/tinycd.cpython-311.pyc b/opencd/models/backbones/__pycache__/tinycd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6eb6bb9fd4555750dc84ee880cb383ea9fa67e98 Binary files /dev/null and b/opencd/models/backbones/__pycache__/tinycd.cpython-311.pyc differ diff --git a/opencd/models/backbones/__pycache__/tinynet.cpython-311.pyc b/opencd/models/backbones/__pycache__/tinynet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2625002ef09c91136d6a6c12bd72fde8ec409f49 Binary files /dev/null and b/opencd/models/backbones/__pycache__/tinynet.cpython-311.pyc differ diff --git a/opencd/models/backbones/fcsn.py b/opencd/models/backbones/fcsn.py new file mode 100644 index 0000000000000000000000000000000000000000..753630492882dbe40d43c9d1b10d325a400ef309 --- /dev/null +++ b/opencd/models/backbones/fcsn.py @@ -0,0 +1,489 @@ +""" +Daudt, R. C., Le Saux, B., & Boulch, A. +"Fully convolutional siamese networks for change detection". +In 2018 25th IEEE International Conference on Image Processing (ICIP) +(pp. 4063-4067). IEEE. + +Some code in this file is borrowed from: +https://github.com/rcdaudt/fully_convolutional_change_detection +https://github.com/Bobholamovic/CDLab +https://github.com/likyoo/Siam-NestedUNet +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.padding import ReplicationPad2d + +from opencd.registry import MODELS + + +@MODELS.register_module() +class FC_EF(nn.Module): + """FC_EF segmentation network.""" + + def __init__(self, in_channels, base_channel=16): + super(FC_EF, self).__init__() + + filters = [base_channel, base_channel * 2, base_channel * 4, + base_channel * 8, base_channel * 16] + + self.conv11 = nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1) + self.bn11 = nn.BatchNorm2d(filters[0]) + self.do11 = nn.Dropout2d(p=0.2) + self.conv12 = nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1) + self.bn12 = nn.BatchNorm2d(filters[0]) + self.do12 = nn.Dropout2d(p=0.2) + + self.conv21 = nn.Conv2d(filters[0], filters[1], kernel_size=3, padding=1) + self.bn21 = nn.BatchNorm2d(filters[1]) + self.do21 = nn.Dropout2d(p=0.2) + self.conv22 = nn.Conv2d(filters[1], filters[1], kernel_size=3, padding=1) + self.bn22 = nn.BatchNorm2d(filters[1]) + self.do22 = nn.Dropout2d(p=0.2) + + self.conv31 = nn.Conv2d(filters[1], filters[2], kernel_size=3, padding=1) + self.bn31 = nn.BatchNorm2d(filters[2]) + self.do31 = nn.Dropout2d(p=0.2) + self.conv32 = nn.Conv2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn32 = nn.BatchNorm2d(filters[2]) + self.do32 = nn.Dropout2d(p=0.2) + self.conv33 = nn.Conv2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn33 = nn.BatchNorm2d(filters[2]) + self.do33 = nn.Dropout2d(p=0.2) + + self.conv41 = nn.Conv2d(filters[2], filters[3], kernel_size=3, padding=1) + self.bn41 = nn.BatchNorm2d(filters[3]) + self.do41 = nn.Dropout2d(p=0.2) + self.conv42 = nn.Conv2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn42 = nn.BatchNorm2d(filters[3]) + self.do42 = nn.Dropout2d(p=0.2) + self.conv43 = nn.Conv2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn43 = nn.BatchNorm2d(filters[3]) + self.do43 = nn.Dropout2d(p=0.2) + + self.upconv4 = nn.ConvTranspose2d(filters[3], filters[3], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv43d = nn.ConvTranspose2d(filters[4], filters[3], kernel_size=3, padding=1) + self.bn43d = nn.BatchNorm2d(filters[3]) + self.do43d = 
nn.Dropout2d(p=0.2) + self.conv42d = nn.ConvTranspose2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn42d = nn.BatchNorm2d(filters[3]) + self.do42d = nn.Dropout2d(p=0.2) + self.conv41d = nn.ConvTranspose2d(filters[3], filters[2], kernel_size=3, padding=1) + self.bn41d = nn.BatchNorm2d(filters[2]) + self.do41d = nn.Dropout2d(p=0.2) + + self.upconv3 = nn.ConvTranspose2d(filters[2], filters[2], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv33d = nn.ConvTranspose2d(filters[3], filters[2], kernel_size=3, padding=1) + self.bn33d = nn.BatchNorm2d(filters[2]) + self.do33d = nn.Dropout2d(p=0.2) + self.conv32d = nn.ConvTranspose2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn32d = nn.BatchNorm2d(filters[2]) + self.do32d = nn.Dropout2d(p=0.2) + self.conv31d = nn.ConvTranspose2d(filters[2], filters[1], kernel_size=3, padding=1) + self.bn31d = nn.BatchNorm2d(filters[1]) + self.do31d = nn.Dropout2d(p=0.2) + + self.upconv2 = nn.ConvTranspose2d(filters[1], filters[1], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv22d = nn.ConvTranspose2d(filters[2], filters[1], kernel_size=3, padding=1) + self.bn22d = nn.BatchNorm2d(filters[1]) + self.do22d = nn.Dropout2d(p=0.2) + self.conv21d = nn.ConvTranspose2d(filters[1], filters[0], kernel_size=3, padding=1) + self.bn21d = nn.BatchNorm2d(filters[0]) + self.do21d = nn.Dropout2d(p=0.2) + + self.upconv1 = nn.ConvTranspose2d(filters[0], filters[0], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv12d = nn.ConvTranspose2d(filters[1], filters[0], kernel_size=3, padding=1) + self.bn12d = nn.BatchNorm2d(filters[0]) + self.do12d = nn.Dropout2d(p=0.2) + self.conv11d = nn.ConvTranspose2d(filters[0], filters[0], kernel_size=3, padding=1) + + def forward(self, x1, x2): + """Forward method.""" + x = torch.cat((x1, x2), 1) + # Stage 1 + x11 = self.do11(F.relu(self.bn11(self.conv11(x)))) + x12 = self.do12(F.relu(self.bn12(self.conv12(x11)))) + x1p = F.max_pool2d(x12, kernel_size=2, stride=2) + + # Stage 2 + x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) + x22 = self.do22(F.relu(self.bn22(self.conv22(x21)))) + x2p = F.max_pool2d(x22, kernel_size=2, stride=2) + + # Stage 3 + x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) + x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) + x33 = self.do33(F.relu(self.bn33(self.conv33(x32)))) + x3p = F.max_pool2d(x33, kernel_size=2, stride=2) + + # Stage 4 + x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) + x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) + x43 = self.do43(F.relu(self.bn43(self.conv43(x42)))) + x4p = F.max_pool2d(x43, kernel_size=2, stride=2) + + # Stage 4d + x4d = self.upconv4(x4p) + pad4 = ReplicationPad2d((0, x43.size(3) - x4d.size(3), 0, x43.size(2) - x4d.size(2))) + x4d = torch.cat((pad4(x4d), x43), 1) + x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) + x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) + x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) + + # Stage 3d + x3d = self.upconv3(x41d) + pad3 = ReplicationPad2d((0, x33.size(3) - x3d.size(3), 0, x33.size(2) - x3d.size(2))) + x3d = torch.cat((pad3(x3d), x33), 1) + x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) + x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) + x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) + + # Stage 2d + x2d = self.upconv2(x31d) + pad2 = ReplicationPad2d((0, x22.size(3) - x2d.size(3), 0, x22.size(2) - x2d.size(2))) + x2d = torch.cat((pad2(x2d), x22), 1) + x22d = 
self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) + x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) + + # Stage 1d + x1d = self.upconv1(x21d) + pad1 = ReplicationPad2d((0, x12.size(3) - x1d.size(3), 0, x12.size(2) - x1d.size(2))) + x1d = torch.cat((pad1(x1d), x12), 1) + x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) + x11d = self.conv11d(x12d) + + return (x11d,) + + +@MODELS.register_module() +class FC_Siam_diff(nn.Module): + """FC_Siam_diff segmentation network.""" + + def __init__(self, in_channels, base_channel=16): + super(FC_Siam_diff, self).__init__() + + filters = [base_channel, base_channel * 2, base_channel * 4, + base_channel * 8, base_channel * 16] + + self.conv11 = nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1) + self.bn11 = nn.BatchNorm2d(filters[0]) + self.do11 = nn.Dropout2d(p=0.2) + self.conv12 = nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1) + self.bn12 = nn.BatchNorm2d(filters[0]) + self.do12 = nn.Dropout2d(p=0.2) + + self.conv21 = nn.Conv2d(filters[0], filters[1], kernel_size=3, padding=1) + self.bn21 = nn.BatchNorm2d(filters[1]) + self.do21 = nn.Dropout2d(p=0.2) + self.conv22 = nn.Conv2d(filters[1], filters[1], kernel_size=3, padding=1) + self.bn22 = nn.BatchNorm2d(filters[1]) + self.do22 = nn.Dropout2d(p=0.2) + + self.conv31 = nn.Conv2d(filters[1], filters[2], kernel_size=3, padding=1) + self.bn31 = nn.BatchNorm2d(filters[2]) + self.do31 = nn.Dropout2d(p=0.2) + self.conv32 = nn.Conv2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn32 = nn.BatchNorm2d(filters[2]) + self.do32 = nn.Dropout2d(p=0.2) + self.conv33 = nn.Conv2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn33 = nn.BatchNorm2d(filters[2]) + self.do33 = nn.Dropout2d(p=0.2) + + self.conv41 = nn.Conv2d(filters[2], filters[3], kernel_size=3, padding=1) + self.bn41 = nn.BatchNorm2d(filters[3]) + self.do41 = nn.Dropout2d(p=0.2) + self.conv42 = nn.Conv2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn42 = nn.BatchNorm2d(filters[3]) + self.do42 = nn.Dropout2d(p=0.2) + self.conv43 = nn.Conv2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn43 = nn.BatchNorm2d(filters[3]) + self.do43 = nn.Dropout2d(p=0.2) + + self.upconv4 = nn.ConvTranspose2d(filters[3], filters[3], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv43d = nn.ConvTranspose2d(filters[4], filters[3], kernel_size=3, padding=1) + self.bn43d = nn.BatchNorm2d(filters[3]) + self.do43d = nn.Dropout2d(p=0.2) + self.conv42d = nn.ConvTranspose2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn42d = nn.BatchNorm2d(filters[3]) + self.do42d = nn.Dropout2d(p=0.2) + self.conv41d = nn.ConvTranspose2d(filters[3], filters[2], kernel_size=3, padding=1) + self.bn41d = nn.BatchNorm2d(filters[2]) + self.do41d = nn.Dropout2d(p=0.2) + + self.upconv3 = nn.ConvTranspose2d(filters[2], filters[2], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv33d = nn.ConvTranspose2d(filters[3], filters[2], kernel_size=3, padding=1) + self.bn33d = nn.BatchNorm2d(filters[2]) + self.do33d = nn.Dropout2d(p=0.2) + self.conv32d = nn.ConvTranspose2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn32d = nn.BatchNorm2d(filters[2]) + self.do32d = nn.Dropout2d(p=0.2) + self.conv31d = nn.ConvTranspose2d(filters[2], filters[1], kernel_size=3, padding=1) + self.bn31d = nn.BatchNorm2d(filters[1]) + self.do31d = nn.Dropout2d(p=0.2) + + self.upconv2 = nn.ConvTranspose2d(filters[1], filters[1], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv22d 
= nn.ConvTranspose2d(filters[2], filters[1], kernel_size=3, padding=1) + self.bn22d = nn.BatchNorm2d(filters[1]) + self.do22d = nn.Dropout2d(p=0.2) + self.conv21d = nn.ConvTranspose2d(filters[1], filters[0], kernel_size=3, padding=1) + self.bn21d = nn.BatchNorm2d(filters[0]) + self.do21d = nn.Dropout2d(p=0.2) + + self.upconv1 = nn.ConvTranspose2d(filters[0], filters[0], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv12d = nn.ConvTranspose2d(filters[1], filters[0], kernel_size=3, padding=1) + self.bn12d = nn.BatchNorm2d(filters[0]) + self.do12d = nn.Dropout2d(p=0.2) + self.conv11d = nn.ConvTranspose2d(filters[0], filters[0], kernel_size=3, padding=1) + + def forward(self, x1, x2): + """Forward method.""" + # Stage 1 + x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) + x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) + x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) + + # Stage 2 + x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) + x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) + x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) + + # Stage 3 + x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) + x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) + x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) + x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) + + # Stage 4 + x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) + x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) + x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) + x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) + + #################################################### + # Stage 1 + x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) + x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) + x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) + + # Stage 2 + x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) + x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) + x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) + + # Stage 3 + x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) + x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) + x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) + x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) + + # Stage 4 + x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) + x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) + x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) + x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) + + # Stage 4d + x4d = self.upconv4(x4p) + pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) + x4d = torch.cat((pad4(x4d), torch.abs(x43_1 - x43_2)), 1) + x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) + x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) + x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) + + # Stage 3d + x3d = self.upconv3(x41d) + pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) + x3d = torch.cat((pad3(x3d), torch.abs(x33_1 - x33_2)), 1) + x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) + x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) + x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) + + # Stage 2d + x2d = self.upconv2(x31d) + pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) + x2d = torch.cat((pad2(x2d), torch.abs(x22_1 - x22_2)), 1) + x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) + x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) + + # Stage 1d + x1d = self.upconv1(x21d) + pad1 = 
ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) + x1d = torch.cat((pad1(x1d), torch.abs(x12_1 - x12_2)), 1) + x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) + x11d = self.conv11d(x12d) + + return (x11d,) + + +@MODELS.register_module() +class FC_Siam_conc(nn.Module): + """FC_Siam_conc segmentation network.""" + + def __init__(self, in_channels, base_channel=16): + super(FC_Siam_conc, self).__init__() + + filters = [base_channel, base_channel * 2, base_channel * 4, + base_channel * 8, base_channel * 16] + + self.conv11 = nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1) + self.bn11 = nn.BatchNorm2d(filters[0]) + self.do11 = nn.Dropout2d(p=0.2) + self.conv12 = nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1) + self.bn12 = nn.BatchNorm2d(filters[0]) + self.do12 = nn.Dropout2d(p=0.2) + + self.conv21 = nn.Conv2d(filters[0], filters[1], kernel_size=3, padding=1) + self.bn21 = nn.BatchNorm2d(filters[1]) + self.do21 = nn.Dropout2d(p=0.2) + self.conv22 = nn.Conv2d(filters[1], filters[1], kernel_size=3, padding=1) + self.bn22 = nn.BatchNorm2d(filters[1]) + self.do22 = nn.Dropout2d(p=0.2) + + self.conv31 = nn.Conv2d(filters[1], filters[2], kernel_size=3, padding=1) + self.bn31 = nn.BatchNorm2d(filters[2]) + self.do31 = nn.Dropout2d(p=0.2) + self.conv32 = nn.Conv2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn32 = nn.BatchNorm2d(filters[2]) + self.do32 = nn.Dropout2d(p=0.2) + self.conv33 = nn.Conv2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn33 = nn.BatchNorm2d(filters[2]) + self.do33 = nn.Dropout2d(p=0.2) + + self.conv41 = nn.Conv2d(filters[2], filters[3], kernel_size=3, padding=1) + self.bn41 = nn.BatchNorm2d(filters[3]) + self.do41 = nn.Dropout2d(p=0.2) + self.conv42 = nn.Conv2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn42 = nn.BatchNorm2d(filters[3]) + self.do42 = nn.Dropout2d(p=0.2) + self.conv43 = nn.Conv2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn43 = nn.BatchNorm2d(filters[3]) + self.do43 = nn.Dropout2d(p=0.2) + + self.upconv4 = nn.ConvTranspose2d(filters[3], filters[3], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv43d = nn.ConvTranspose2d(filters[3]+filters[4], filters[3], kernel_size=3, padding=1) + self.bn43d = nn.BatchNorm2d(filters[3]) + self.do43d = nn.Dropout2d(p=0.2) + self.conv42d = nn.ConvTranspose2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn42d = nn.BatchNorm2d(filters[3]) + self.do42d = nn.Dropout2d(p=0.2) + self.conv41d = nn.ConvTranspose2d(filters[3], filters[2], kernel_size=3, padding=1) + self.bn41d = nn.BatchNorm2d(filters[2]) + self.do41d = nn.Dropout2d(p=0.2) + + self.upconv3 = nn.ConvTranspose2d(filters[2], filters[2], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv33d = nn.ConvTranspose2d(filters[2]+filters[3], filters[2], kernel_size=3, padding=1) + self.bn33d = nn.BatchNorm2d(filters[2]) + self.do33d = nn.Dropout2d(p=0.2) + self.conv32d = nn.ConvTranspose2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn32d = nn.BatchNorm2d(filters[2]) + self.do32d = nn.Dropout2d(p=0.2) + self.conv31d = nn.ConvTranspose2d(filters[2], filters[1], kernel_size=3, padding=1) + self.bn31d = nn.BatchNorm2d(filters[1]) + self.do31d = nn.Dropout2d(p=0.2) + + self.upconv2 = nn.ConvTranspose2d(filters[1], filters[1], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv22d = nn.ConvTranspose2d(filters[1]+filters[2], filters[1], kernel_size=3, padding=1) + self.bn22d = 
nn.BatchNorm2d(filters[1]) + self.do22d = nn.Dropout2d(p=0.2) + self.conv21d = nn.ConvTranspose2d(filters[1], filters[0], kernel_size=3, padding=1) + self.bn21d = nn.BatchNorm2d(filters[0]) + self.do21d = nn.Dropout2d(p=0.2) + + self.upconv1 = nn.ConvTranspose2d(filters[0], filters[0], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv12d = nn.ConvTranspose2d(filters[0]+filters[1], filters[0], kernel_size=3, padding=1) + self.bn12d = nn.BatchNorm2d(filters[0]) + self.do12d = nn.Dropout2d(p=0.2) + self.conv11d = nn.ConvTranspose2d(filters[0], filters[0], kernel_size=3, padding=1) + + def forward(self, x1, x2): + """Forward method.""" + # Stage 1 + x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) + x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) + x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) + + # Stage 2 + x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) + x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) + x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) + + # Stage 3 + x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) + x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) + x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) + x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) + + # Stage 4 + x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) + x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) + x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) + x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) + + #################################################### + # Stage 1 + x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) + x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) + x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) + + # Stage 2 + x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) + x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) + x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) + + # Stage 3 + x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) + x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) + x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) + x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) + + # Stage 4 + x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) + x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) + x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) + x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) + + #################################################### + # Stage 4d + x4d = self.upconv4(x4p) + pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) + x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1) + x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) + x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) + x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) + + # Stage 3d + x3d = self.upconv3(x41d) + pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) + x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1) + x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) + x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) + x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) + + # Stage 2d + x2d = self.upconv2(x31d) + pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) + x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1) + x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) + x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) + + # Stage 1d + x1d = self.upconv1(x21d) + pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, 
x12_1.size(2) - x1d.size(2))) + x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1) + x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) + x11d = self.conv11d(x12d) + + return (x11d,) \ No newline at end of file diff --git a/opencd/models/backbones/hanet.py b/opencd/models/backbones/hanet.py new file mode 100644 index 0000000000000000000000000000000000000000..6a10ac3a3be198a1805a91f65b1bdecd71ca5cca --- /dev/null +++ b/opencd/models/backbones/hanet.py @@ -0,0 +1,289 @@ +""" +C. HAN, C. WU, H. GUO, M. HU, AND H. CHEN, +“HANET: A HIERARCHICAL ATTENTION NETWORK FOR CHANGE DETECTION WITH BI-TEMPORAL VERY-HIGH-RESOLUTION REMOTE SENSING IMAGES,” +IEEE J. SEL. TOP. APPL. EARTH OBS. REMOTE SENS., PP. 1-17, 2023, DOI: 10.1109/JSTARS.2023.3264802. + +Some code in this file is borrowed from: +https://github.com/ChengxiHAN/HANet-CD/blob/main/models/HANet.py +""" + +import torch +import torch.nn as nn + +from opencd.registry import MODELS + + +class CAM_Module(nn.Module): + """ Channel attention module""" + + def __init__(self, in_dim): + super(CAM_Module, self).__init__() + self.chanel_in = in_dim + + self.gamma = nn.Parameter(torch.zeros(1)) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x): + m_batchsize, C, height, width = x.size() + proj_query = x.view(m_batchsize, C, -1) + proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1) + + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, C, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, C, height, width) + out = self.gamma * out + x + return out + + +class Conv_CAM_Layer(nn.Module): + + def __init__(self, in_ch, out_in, use_pam=False): + super(Conv_CAM_Layer, self).__init__() + + self.attn = nn.Sequential( + nn.Conv2d(in_ch, 32, kernel_size=3, padding=1), + nn.BatchNorm2d(32), + nn.PReLU(), + CAM_Module(32), + nn.Conv2d(32, out_in, kernel_size=3, padding=1), + nn.BatchNorm2d(out_in), + nn.PReLU() + ) + + def forward(self, x): + return self.attn(x) + + +class FEC(nn.Module): + """feature extraction cell""" + #convolutional block + def __init__(self, in_ch, mid_ch, out_ch): + super(FEC, self).__init__() + self.activation = nn.ReLU(inplace=True) + self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1,bias=True) + self.bn1 = nn.BatchNorm2d(mid_ch) + self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=1, stride=1, bias=False) + self.bn2 = nn.BatchNorm2d(out_ch) + + def forward(self, x): + x = self.conv1(x) + identity = x + x = self.bn1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.bn2(x) + output = self.activation(x + identity) + return output + + +class RowAttention(nn.Module): + + def __init__(self, in_dim, q_k_dim, use_pam=False): + ''' + Parameters + ---------- + in_dim : int + channel of input img tensor + q_k_dim: int + channel of Q, K vector + device : torch.device + ''' + super(RowAttention, self).__init__() + self.in_dim = in_dim + self.q_k_dim = q_k_dim + + self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1) + self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1) + self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.in_dim, kernel_size=1) + self.softmax = nn.Softmax(dim=2) + self.gamma = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + ''' + Parameters + ---------- + x : Tensor + 4-D , (batch, in_dims, height, width) -- (b,c1,h,w) + ''' + b, _, h, w = 
x.size() + + Q = self.query_conv(x) # size = (b,c2, h,w) + K = self.key_conv(x) # size = (b, c2, h, w) + V = self.value_conv(x) # size = (b, c1,h,w) + + Q = Q.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w).permute(0, 2, 1) # size = (b*h,w,c2) + K = K.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w) # size = (b*h,c2,w) + V = V.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w) # size = (b*h, c1,w) + + row_attn = torch.bmm(Q, K) + row_attn = self.softmax(row_attn) + out = torch.bmm(V, row_attn.permute(0, 2, 1)) + out = out.view(b, h, -1, w).permute(0, 2, 1, 3) + out = self.gamma * out + x + return out + + +class ColAttention(nn.Module): + + def __init__(self, in_dim, q_k_dim, use_pam=False): + ''' + Parameters + ---------- + in_dim : int + channel of input img tensor + q_k_dim: int + channel of Q, K vector + device : torch.device + ''' + super(ColAttention, self).__init__() + self.in_dim = in_dim + self.q_k_dim = q_k_dim + + self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1) + self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1) + self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.in_dim, kernel_size=1) + self.softmax = nn.Softmax(dim=2) + self.gamma = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + ''' + Parameters + ---------- + x : Tensor + 4-D , (batch, in_dims, height, width) -- (b,c1,h,w) + ''' + + b, _, h, w = x.size() + + Q = self.query_conv(x) # size = (b,c2, h,w) + K = self.key_conv(x) # size = (b, c2, h, w) + V = self.value_conv(x) # size = (b, c1,h,w) + + Q = Q.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h).permute(0, 2, 1) # size = (b*w,h,c2) + K = K.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h) # size = (b*w,c2,h) + V = V.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h) # size = (b*w,c1,h) + + # size = (b*w,h,h) [:,i,j] + col_attn = torch.bmm(Q, K) + col_attn = self.softmax(col_attn) + out = torch.bmm(V, col_attn.permute(0, 2, 1)) + # size = (b,c1,h,w) + out = out.view(b, w, -1, h).permute(0, 2, 3, 1) + out = self.gamma * out + x + + return out + + +@MODELS.register_module() +class HAN(nn.Module): + """HANet""" + def __init__(self, in_channels, base_channel=40): + super(HAN, self).__init__() + torch.nn.Module.dump_patches = True + n1 = base_channel # the initial number of channels of feature map + filters = [n1, n1 * 2, n1 * 4, n1 * 8] + + self.conv0_0 = nn.Conv2d(in_channels, n1, kernel_size=5, padding=2, stride=1) + self.conv0 = FEC(filters[0], filters[0], filters[0]) + self.conv2 = FEC(filters[0], filters[1], filters[1]) + self.conv4 = FEC(filters[1], filters[2], filters[2]) + self.conv5 = FEC(filters[2], filters[3], filters[3]) + self.conv6 = nn.Conv2d(sum(filters), filters[1], kernel_size=1, stride=1) + + self.conv6_1_1 = nn.Conv2d(filters[0] * 2, filters[0], padding=1, kernel_size=3, groups=filters[0] // 2,dilation=1) + self.conv6_1_2 = nn.Conv2d(filters[0] * 2, filters[0], padding=2, kernel_size=3, groups=filters[0] // 2,dilation=2) + self.conv6_1_3 = nn.Conv2d(filters[0] * 2, filters[0], padding=3, kernel_size=3, groups=filters[0] // 2,dilation=3) + self.conv6_1_4 = nn.Conv2d(filters[0] * 2, filters[0], padding=4, kernel_size=3, groups=filters[0] // 2,dilation=4) + self.conv1_1 = nn.Conv2d(filters[0] * 4, filters[0], kernel_size=1, stride=1) + + self.conv6_2_1 = nn.Conv2d(filters[1] * 2, filters[1], padding=1, kernel_size=3, groups=filters[1] // 2, dilation=1) + self.conv6_2_2 = nn.Conv2d(filters[1] * 2, filters[1], padding=2, kernel_size=3, 
groups=filters[1] // 2, dilation=2) + self.conv6_2_3 = nn.Conv2d(filters[1] * 2, filters[1], padding=3, kernel_size=3, groups=filters[1] // 2, dilation=3) + self.conv6_2_4 = nn.Conv2d(filters[1] * 2, filters[1], padding=4, kernel_size=3, groups=filters[1] // 2, dilation=4) + self.conv2_1 = nn.Conv2d(filters[1] * 4, filters[1], kernel_size=1, stride=1) + + self.conv6_3_1 = nn.Conv2d(filters[2] * 2, filters[2], padding=1, kernel_size=3, groups=filters[2] // 2, dilation=1) + self.conv6_3_2 = nn.Conv2d(filters[2] * 2, filters[2], padding=2, kernel_size=3, groups=filters[2] // 2, dilation=2) + self.conv6_3_3 = nn.Conv2d(filters[2] * 2, filters[2], padding=3, kernel_size=3, groups=filters[2] // 2, dilation=3) + self.conv6_3_4 = nn.Conv2d(filters[2] * 2, filters[2], padding=4, kernel_size=3, groups=filters[2] // 2, dilation=4) + self.conv3_1 = nn.Conv2d(filters[2] * 4, filters[2], kernel_size=1, stride=1) + + self.conv6_4_1 = nn.Conv2d(filters[3]*2, filters[3], padding=1, kernel_size=3, groups=filters[3]//2, dilation=1) + self.conv6_4_2 = nn.Conv2d(filters[3]*2, filters[3], padding=2, kernel_size=3, groups=filters[3]//2, dilation=2) + self.conv6_4_3 = nn.Conv2d(filters[3]*2, filters[3], padding=3, kernel_size=3, groups=filters[3]//2, dilation=3) + self.conv6_4_4 = nn.Conv2d(filters[3]*2, filters[3], padding=4, kernel_size=3, groups=filters[3]//2, dilation=4) + self.conv4_1 = nn.Conv2d(filters[3]*4, filters[3], kernel_size=1, stride=1) + + # SA + self.cam_attention_1 = Conv_CAM_Layer(filters[0], filters[0], False) #SA4 + self.cam_attention_2 = Conv_CAM_Layer(filters[1], filters[1], False) #SA3 + self.cam_attention_3 = Conv_CAM_Layer(filters[2], filters[2], False) #SA2 + self.cam_attention_4 = Conv_CAM_Layer(filters[3], filters[3], False) #SA1 + + #Row Attention + self.row_attention_1 = RowAttention(filters[0], filters[0], False) # SA4 + self.row_attention_2 = RowAttention(filters[1], filters[1], False) # SA3 + self.row_attention_3 = RowAttention(filters[2], filters[2], False) # SA2 + self.row_attention_4 = RowAttention(filters[3], filters[3], False) # SA1 + + # Col Attention + self.col_attention_1 = ColAttention(filters[0], filters[0], False) # SA4 + self.col_attention_2 = ColAttention(filters[1], filters[1], False) # SA3 + self.col_attention_3 = ColAttention(filters[2], filters[2], False) # SA2 + self.col_attention_4 = ColAttention(filters[3], filters[3], False) # SA1 + + self.c4_conv = nn.Conv2d(filters[3], filters[1], kernel_size=3, padding=1) + self.c3_conv = nn.Conv2d(filters[2], filters[1], kernel_size=3, padding=1) + self.c2_conv = nn.Conv2d(filters[1], filters[1], kernel_size=3, padding=1) + self.c1_conv = nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1) + + self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + + self.Up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + self.Up2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False) + self.Up3 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x1, x2): + x1 = self.conv0(self.conv0_0(x1)) # Output of the first scale + x3 = self.conv2(self.pool(x1)) + x4 = self.conv4(self.pool(x3)) + A_F4 = self.conv5(self.pool(x4)) + + x2 = self.conv0(self.conv0_0(x2)) + x5 = self.conv2(self.pool(x2)) + x6 = 
self.conv4(self.pool(x5)) + A_F8 = self.conv5(self.pool(x6)) + + c4_1 = self.conv4_1( + torch.cat([self.conv6_4_1(torch.cat([A_F4, A_F8], 1)), self.conv6_4_2(torch.cat([A_F4, A_F8], 1)), + self.conv6_4_3(torch.cat([A_F4, A_F8], 1)), self.conv6_4_4(torch.cat([A_F4, A_F8], 1))], 1)) + c4 = self.cam_attention_4(c4_1) + self.row_attention_4(self.col_attention_4(c4_1)) + + c3_1 = (self.conv3_1(torch.cat( + [self.conv6_3_1(torch.cat([x4, x6], 1)), self.conv6_3_2(torch.cat([x4, x6], 1)), + self.conv6_3_3(torch.cat([x4, x6], 1)), self.conv6_3_4(torch.cat([x4, x6], 1))], 1))) + c3 = torch.cat([(self.cam_attention_3(c3_1)+self.row_attention_3(self.col_attention_3(c3_1))), self.Up1(c4)], 1) + + c2_1 = (self.conv2_1(torch.cat( + [self.conv6_2_1(torch.cat([x3, x5], 1)), self.conv6_2_2(torch.cat([x3, x5], 1)), + self.conv6_2_3(torch.cat([x3, x5], 1)), self.conv6_2_4(torch.cat([x3, x5], 1))], 1))) + c2 = torch.cat([(self.cam_attention_2(c2_1)+self.row_attention_2(self.col_attention_2(c2_1))), self.Up1(c3)], 1) + + c1_1 = (self.conv1_1(torch.cat( + [self.conv6_1_1(torch.cat([x1, x2], 1)), self.conv6_1_2(torch.cat([x1, x2], 1)), + self.conv6_1_3(torch.cat([x1, x2], 1)), self.conv6_1_4(torch.cat([x1, x2], 1))], 1))) + c1 = torch.cat([(self.cam_attention_1(c1_1)+self.row_attention_1(self.col_attention_1(c1_1))), self.Up1(c2)], 1) + out1 = self.conv6(c1) + + return (out1, ) diff --git a/opencd/models/backbones/ifn.py b/opencd/models/backbones/ifn.py new file mode 100644 index 0000000000000000000000000000000000000000..1e656e9cb3318e4a6346894f0602d11a2fa4002b --- /dev/null +++ b/opencd/models/backbones/ifn.py @@ -0,0 +1,235 @@ +# credits: https://github.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.models import vgg16 + +from opencd.registry import MODELS + + +def get_norm_layer(): + # TODO: select appropriate norm layer + return nn.BatchNorm2d + + +def get_act_layer(): + # TODO: select appropriate activation layer + return nn.ReLU + +def make_norm(*args, **kwargs): + norm_layer = get_norm_layer() + return norm_layer(*args, **kwargs) + + +def make_act(*args, **kwargs): + act_layer = get_act_layer() + return act_layer(*args, **kwargs) + +class BasicConv(nn.Module): + def __init__( + self, in_ch, out_ch, + kernel_size, pad_mode='Zero', + bias='auto', norm=False, act=False, + **kwargs + ): + super().__init__() + seq = [] + if kernel_size >= 2: + seq.append(getattr(nn, pad_mode.capitalize()+'Pad2d')(kernel_size//2)) + seq.append( + nn.Conv2d( + in_ch, out_ch, kernel_size, + stride=1, padding=0, + bias=(False if norm else True) if bias=='auto' else bias, + **kwargs + ) + ) + if norm: + if norm is True: + norm = make_norm(out_ch) + seq.append(norm) + if act: + if act is True: + act = make_act() + seq.append(act) + self.seq = nn.Sequential(*seq) + + def forward(self, x): + return self.seq(x) + +class Conv1x1(BasicConv): + def __init__(self, in_ch, out_ch, pad_mode='Zero', bias='auto', norm=False, act=False, **kwargs): + super().__init__(in_ch, out_ch, 1, pad_mode=pad_mode, bias=bias, norm=norm, act=act, **kwargs) + +class ChannelAttention(nn.Module): + def __init__(self, in_ch, ratio=8): + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + self.fc1 = Conv1x1(in_ch, in_ch//ratio, bias=False, act=True) + self.fc2 = Conv1x1(in_ch//ratio, in_ch, bias=False) + + def forward(self,x): + avg_out = self.fc2(self.fc1(self.avg_pool(x))) + 
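+        # The max-pooled descriptor below passes through the same shared
+        # 1x1-conv MLP (fc1/fc2) as the average-pooled one; the two are summed
+        # before the sigmoid gate (CBAM-style channel attention).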
max_out = self.fc2(self.fc1(self.max_pool(x))) + out = avg_out + max_out + return F.sigmoid(out) + + +class SpatialAttention(nn.Module): + def __init__(self, kernel_size=7): + super().__init__() + self.conv = BasicConv(2, 1, kernel_size, bias=False) + + def forward(self, x): + avg_out = torch.mean(x, dim=1, keepdim=True) + max_out = torch.max(x, dim=1, keepdim=True)[0] + x = torch.cat([avg_out, max_out], dim=1) + x = self.conv(x) + return F.sigmoid(x) + + +class VGG16FeaturePicker(nn.Module): + def __init__(self, indices=(3,8,15,22,29)): + super().__init__() + features = list(vgg16(pretrained=True).features)[:30] + self.features = nn.ModuleList(features).eval() + self.indices = set(indices) + + def forward(self, x): + picked_feats = [] + for idx, model in enumerate(self.features): + x = model(x) + if idx in self.indices: + picked_feats.append(x) + return picked_feats + + +def conv2d_bn(in_ch, out_ch, with_dropout=True): + lst = [ + nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1), + nn.PReLU(), + make_norm(out_ch), + ] + if with_dropout: + lst.append(nn.Dropout(p=0.6)) + return nn.Sequential(*lst) + + +@MODELS.register_module() +class IFN(nn.Module): + def __init__(self, use_dropout=False): + super().__init__() + + self.encoder1 = self.encoder2 = VGG16FeaturePicker() + + self.sa1 = SpatialAttention() + self.sa2= SpatialAttention() + self.sa3 = SpatialAttention() + self.sa4 = SpatialAttention() + self.sa5 = SpatialAttention() + + self.ca1 = ChannelAttention(in_ch=1024) + self.bn_ca1 = make_norm(1024) + self.o1_conv1 = conv2d_bn(1024, 512, use_dropout) + self.o1_conv2 = conv2d_bn(512, 512, use_dropout) + self.bn_sa1 = make_norm(512) + self.o1_conv3 = Conv1x1(512, 1) + self.trans_conv1 = nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2) + + self.ca2 = ChannelAttention(in_ch=1536) + self.bn_ca2 = make_norm(1536) + self.o2_conv1 = conv2d_bn(1536, 512, use_dropout) + self.o2_conv2 = conv2d_bn(512, 256, use_dropout) + self.o2_conv3 = conv2d_bn(256, 256, use_dropout) + self.bn_sa2 = make_norm(256) + self.o2_conv4 = Conv1x1(256, 1) + self.trans_conv2 = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2) + + self.ca3 = ChannelAttention(in_ch=768) + self.o3_conv1 = conv2d_bn(768, 256, use_dropout) + self.o3_conv2 = conv2d_bn(256, 128, use_dropout) + self.o3_conv3 = conv2d_bn(128, 128, use_dropout) + self.bn_sa3 = make_norm(128) + self.o3_conv4 = Conv1x1(128, 1) + self.trans_conv3 = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2) + + self.ca4 = ChannelAttention(in_ch=384) + self.o4_conv1 = conv2d_bn(384, 128, use_dropout) + self.o4_conv2 = conv2d_bn(128, 64, use_dropout) + self.o4_conv3 = conv2d_bn(64, 64, use_dropout) + self.bn_sa4 = make_norm(64) + self.o4_conv4 = Conv1x1(64, 1) + self.trans_conv4 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2) + + self.ca5 = ChannelAttention(in_ch=192) + self.o5_conv1 = conv2d_bn(192, 64, use_dropout) + self.o5_conv2 = conv2d_bn(64, 32, use_dropout) + self.o5_conv3 = conv2d_bn(32, 16, use_dropout) + self.bn_sa5 = make_norm(16) + self.o5_conv4 = Conv1x1(16, 1) + + def forward(self, t1, t2): + # Extract bi-temporal features + with torch.no_grad(): + self.encoder1.eval(), self.encoder2.eval() + t1_feats = self.encoder1(t1) + t2_feats = self.encoder2(t2) + + t1_f_l3, t1_f_l8, t1_f_l15, t1_f_l22, t1_f_l29 = t1_feats + t2_f_l3, t2_f_l8, t2_f_l15, t2_f_l22, t2_f_l29,= t2_feats + + # Multi-level decoding + x = torch.cat([t1_f_l29, t2_f_l29], dim=1) + x = self.o1_conv1(x) + x = self.o1_conv2(x) + x = self.sa1(x) * x + x = self.bn_sa1(x) + + 
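+        # Deep supervision: a 1x1 conv projects the fused features to a
+        # single-channel change map at this scale; the same pattern repeats
+        # after every decoder stage (out1 ... out5).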
out1 = self.o1_conv3(x) + + x = self.trans_conv1(x) + x = torch.cat([x, t1_f_l22, t2_f_l22], dim=1) + x = self.ca2(x)*x + x = self.o2_conv1(x) + x = self.o2_conv2(x) + x = self.o2_conv3(x) + x = self.sa2(x) *x + x = self.bn_sa2(x) + + out2 = self.o2_conv4(x) + + x = self.trans_conv2(x) + x = torch.cat([x, t1_f_l15, t2_f_l15], dim=1) + x = self.ca3(x)*x + x = self.o3_conv1(x) + x = self.o3_conv2(x) + x = self.o3_conv3(x) + x = self.sa3(x) *x + x = self.bn_sa3(x) + + out3 = self.o3_conv4(x) + + x = self.trans_conv3(x) + x = torch.cat([x, t1_f_l8, t2_f_l8], dim=1) + x = self.ca4(x)*x + x = self.o4_conv1(x) + x = self.o4_conv2(x) + x = self.o4_conv3(x) + x = self.sa4(x) *x + x = self.bn_sa4(x) + + out4 = self.o4_conv4(x) + + x = self.trans_conv4(x) + x = torch.cat([x, t1_f_l3, t2_f_l3], dim=1) + x = self.ca5(x)*x + x = self.o5_conv1(x) + x = self.o5_conv2(x) + x = self.o5_conv3(x) + x = self.sa5(x) *x + x = self.bn_sa5(x) + + out5 = self.o5_conv4(x) + + return (out1, out2, out3, out4, out5) \ No newline at end of file diff --git a/opencd/models/backbones/interaction_mit.py b/opencd/models/backbones/interaction_mit.py new file mode 100644 index 0000000000000000000000000000000000000000..572929cbb9ac973a487e727d7f8d10e667e4b967 --- /dev/null +++ b/opencd/models/backbones/interaction_mit.py @@ -0,0 +1,45 @@ +# Copyright (c) Open-CD. All rights reserved. +import torch +import torch.nn as nn + +from mmseg.models.utils import nlc_to_nchw +from mmseg.models.backbones import MixVisionTransformer + +from opencd.registry import MODELS + + +@MODELS.register_module() +class IA_MixVisionTransformer(MixVisionTransformer): + def __init__(self, + interaction_cfg=(None, None, None, None), + **kwargs): + super().__init__(**kwargs) + assert self.num_stages == len(interaction_cfg), \ + 'The length of the `interaction_cfg` should be same as the `num_stages`.' + # cross-correlation + self.ccs = [] + for ia_cfg in interaction_cfg: + if ia_cfg is None: + ia_cfg = dict(type='TwoIdentity') + self.ccs.append(MODELS.build(ia_cfg)) + self.ccs = nn.ModuleList(self.ccs) + + def forward(self, x1, x2): + outs = [] + + for i, layer in enumerate(self.layers): + x1, hw_shape = layer[0](x1) + x2, hw_shape = layer[0](x2) + for block in layer[1]: + x1 = block(x1, hw_shape) + x2 = block(x2, hw_shape) + x1 = layer[2](x1) + x2 = layer[2](x2) + + x1 = nlc_to_nchw(x1, hw_shape) + x2 = nlc_to_nchw(x2, hw_shape) + + x1, x2 = self.ccs[i](x1, x2) + if i in self.out_indices: + outs.append(torch.cat([x1, x2], dim=1)) + return outs \ No newline at end of file diff --git a/opencd/models/backbones/interaction_resnest.py b/opencd/models/backbones/interaction_resnest.py new file mode 100644 index 0000000000000000000000000000000000000000..e69e4b0d3fb6d0fd03258a73b079cec0af172167 --- /dev/null +++ b/opencd/models/backbones/interaction_resnest.py @@ -0,0 +1,54 @@ +# Copyright (c) Open-CD. All rights reserved. +from mmseg.models.backbones.resnest import Bottleneck +from mmseg.models.utils import ResLayer +from opencd.registry import MODELS +from .interaction_resnet import IA_ResNetV1d + + +@MODELS.register_module() +class IA_ResNeSt(IA_ResNetV1d): + """Interaction ResNeSt backbone. + This backbone is the implementation of `ResNeSt: + Split-Attention Networks `_. + Args: + groups (int): Number of groups of Bottleneck. Default: 1 + base_width (int): Base width of Bottleneck. Default: 4 + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels in + SplitAttentionConv2d. Default: 4. 
+ avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + kwargs (dict): Keyword arguments for ResNet. + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)), + 200: (Bottleneck, (3, 24, 36, 3)) + } + + def __init__(self, + groups=1, + base_width=4, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + self.groups = groups + self.base_width = base_width + self.radix = radix + self.reduction_factor = reduction_factor + self.avg_down_stride = avg_down_stride + super(IA_ResNeSt, self).__init__(**kwargs) + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer( + groups=self.groups, + base_width=self.base_width, + base_channels=self.base_channels, + radix=self.radix, + reduction_factor=self.reduction_factor, + avg_down_stride=self.avg_down_stride, + **kwargs) diff --git a/opencd/models/backbones/interaction_resnet.py b/opencd/models/backbones/interaction_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..b6f492e65a12ace1a4ec3b7d89c2938455b86789 --- /dev/null +++ b/opencd/models/backbones/interaction_resnet.py @@ -0,0 +1,151 @@ +# Copyright (c) Open-CD. All rights reserved. +import torch +import torch.nn as nn + +from mmseg.models.backbones import ResNet +from opencd.registry import MODELS + + +@MODELS.register_module() +class IA_ResNet(ResNet): + """Interaction ResNet backbone. + + Args: + interaction_cfg (Sequence[dict]): Interaction strategies for the stages. + The length should be the same as `num_stages`. The details can be + found in `opencd/models/utils/interaction_layer.py`. + Default: (None, None, None, None). + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Number of stem channels. Default: 64. + base_channels (int): Number of base channels of res layer. Default: 64. + num_stages (int): Resnet stages, normally 4. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: (1, 2, 2, 2). + dilations (Sequence[int]): Dilation of each stage. + Default: (1, 1, 1, 1). + out_indices (Sequence[int]): Output from which stages. + Default: (0, 1, 2, 3). + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. Default: 'pytorch'. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): Dictionary to construct and config conv layer. + When conv_cfg is None, cfg will be set to dict(type='Conv2d'). + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + dcn (dict | None): Dictionary to construct and config DCN conv layer. + When dcn is not None, conv_cfg must be None. Default: None. + stage_with_dcn (Sequence[bool]): Whether to set DCN conv for each + stage. The length of stage_with_dcn is equal to num_stages. 
+ Default: (False, False, False, False). + plugins (list[dict]): List of plugins for stages, each dict contains: + + - cfg (dict, required): Cfg dict to build plugin. + + - position (str, required): Position inside block to insert plugin, + options: 'after_conv1', 'after_conv2', 'after_conv3'. + + - stages (tuple[bool], optional): Stages to apply plugin, length + should be same as 'num_stages'. + Default: None. + multi_grid (Sequence[int]|None): Multi grid dilation rates of last + stage. Default: None. + contract_dilation (bool): Whether contract first dilation of each layer + Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + pretrained (str, optional): model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Example: + >>> from opencd.models import IA_ResNet + >>> import torch + >>> self = IA_ResNet(depth=18) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs, inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 128, 8, 8) + (1, 256, 4, 4) + (1, 512, 2, 2) + (1, 1024, 1, 1) + """ + def __init__(self, + interaction_cfg=(None, None, None, None), + **kwargs): + super().__init__(**kwargs) + assert self.num_stages == len(interaction_cfg), \ + 'The length of the `interaction_cfg` should be same as the `num_stages`.' + # cross-correlation + self.ccs = [] + for ia_cfg in interaction_cfg: + if ia_cfg is None: + ia_cfg = dict(type='TwoIdentity') + self.ccs.append(MODELS.build(ia_cfg)) + self.ccs = nn.ModuleList(self.ccs) + + def forward(self, x1, x2): + """Forward function.""" + def _stem_forward(x): + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + return x + + x1 = _stem_forward(x1) + x2 = _stem_forward(x2) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x1 = res_layer(x1) + x2 = res_layer(x2) + x1, x2 = self.ccs[i](x1, x2) + if i in self.out_indices: + outs.append(torch.cat([x1, x2], dim=1)) + return tuple(outs) + + +@MODELS.register_module() +class IA_ResNetV1c(IA_ResNet): + """ResNetV1c variant described in [1]_. + + Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv in + the input stem with three 3x3 convs. For more details please refer to `Bag + of Tricks for Image Classification with Convolutional Neural Networks + `_. + """ + + def __init__(self, **kwargs): + super(IA_ResNetV1c, self).__init__( + deep_stem=True, avg_down=False, **kwargs) + + +@MODELS.register_module() +class IA_ResNetV1d(IA_ResNet): + """ResNetV1d variant described in [1]_. + Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in + the input stem with three 3x3 convs. And in the downsampling block, a 2x2 + avg_pool with stride 2 is added before conv, whose stride is changed to 1. 
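+    For more details please refer to `Bag of Tricks for Image Classification
+    with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_.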
+ """ + + def __init__(self, **kwargs): + super(IA_ResNetV1d, self).__init__( + deep_stem=True, avg_down=True, **kwargs) \ No newline at end of file diff --git a/opencd/models/backbones/snunet.py b/opencd/models/backbones/snunet.py new file mode 100644 index 0000000000000000000000000000000000000000..1d673ac65c6f0dc61e229290c3dbd8ad922ce554 --- /dev/null +++ b/opencd/models/backbones/snunet.py @@ -0,0 +1,156 @@ +""" +S. Fang, K. Li, J. Shao, and Z. Li, +“SNUNet-CD: A Densely Connected Siamese Network for Change Detection of VHR Images,” +IEEE Geosci. Remote Sensing Lett., pp. 1-5, 2021, doi: 10.1109/LGRS.2021.3056416. +""" + +import torch +import torch.nn as nn + +from opencd.registry import MODELS + + +class conv_block_nested(nn.Module): + def __init__(self, in_ch, mid_ch, out_ch): + super(conv_block_nested, self).__init__() + self.activation = nn.ReLU(inplace=True) + self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True) + self.bn1 = nn.BatchNorm2d(mid_ch) + self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True) + self.bn2 = nn.BatchNorm2d(out_ch) + + def forward(self, x): + x = self.conv1(x) + identity = x + x = self.bn1(x) + x = self.activation(x) + + x = self.conv2(x) + x = self.bn2(x) + output = self.activation(x + identity) + return output + + +class up(nn.Module): + def __init__(self, in_ch, bilinear=False): + super(up, self).__init__() + + if bilinear: + self.up = nn.Upsample(scale_factor=2, + mode='bilinear', + align_corners=True) + else: + self.up = nn.ConvTranspose2d(in_ch, in_ch, 2, stride=2) + + def forward(self, x): + + x = self.up(x) + return x + + +class ChannelAttention(nn.Module): + def __init__(self, in_channels, ratio = 16): + super(ChannelAttention, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + self.fc1 = nn.Conv2d(in_channels,in_channels//ratio,1,bias=False) + self.relu1 = nn.ReLU() + self.fc2 = nn.Conv2d(in_channels//ratio, in_channels,1,bias=False) + self.sigmod = nn.Sigmoid() + def forward(self,x): + avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) + max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) + out = avg_out + max_out + return self.sigmod(out) + + +@MODELS.register_module() +class SNUNet_ECAM(nn.Module): + # SNUNet-CD with ECAM + def __init__(self, in_channels, base_channel=32): + super(SNUNet_ECAM, self).__init__() + torch.nn.Module.dump_patches = True + n1 = base_channel # the initial number of channels of feature map + filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] + + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + + self.conv0_0 = conv_block_nested(in_channels, filters[0], filters[0]) + self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1]) + self.Up1_0 = up(filters[1]) + self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2]) + self.Up2_0 = up(filters[2]) + self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3]) + self.Up3_0 = up(filters[3]) + self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4]) + self.Up4_0 = up(filters[4]) + + self.conv0_1 = conv_block_nested(filters[0] * 2 + filters[1], filters[0], filters[0]) + self.conv1_1 = conv_block_nested(filters[1] * 2 + filters[2], filters[1], filters[1]) + self.Up1_1 = up(filters[1]) + self.conv2_1 = conv_block_nested(filters[2] * 2 + filters[3], filters[2], filters[2]) + self.Up2_1 = up(filters[2]) + self.conv3_1 = conv_block_nested(filters[3] * 2 + filters[4], filters[3], filters[3]) + self.Up3_1 = up(filters[3]) + + self.conv0_2 
= conv_block_nested(filters[0] * 3 + filters[1], filters[0], filters[0]) + self.conv1_2 = conv_block_nested(filters[1] * 3 + filters[2], filters[1], filters[1]) + self.Up1_2 = up(filters[1]) + self.conv2_2 = conv_block_nested(filters[2] * 3 + filters[3], filters[2], filters[2]) + self.Up2_2 = up(filters[2]) + + self.conv0_3 = conv_block_nested(filters[0] * 4 + filters[1], filters[0], filters[0]) + self.conv1_3 = conv_block_nested(filters[1] * 4 + filters[2], filters[1], filters[1]) + self.Up1_3 = up(filters[1]) + + self.conv0_4 = conv_block_nested(filters[0] * 5 + filters[1], filters[0], filters[0]) + + self.ca = ChannelAttention(filters[0] * 4, ratio=16) + self.ca1 = ChannelAttention(filters[0], ratio=16 // 4) + + # self.conv_final = nn.Conv2d(filters[0] * 4, out_ch, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + + def forward(self, xA, xB): + '''xA''' + x0_0A = self.conv0_0(xA) + x1_0A = self.conv1_0(self.pool(x0_0A)) + x2_0A = self.conv2_0(self.pool(x1_0A)) + x3_0A = self.conv3_0(self.pool(x2_0A)) + # x4_0A = self.conv4_0(self.pool(x3_0A)) + '''xB''' + x0_0B = self.conv0_0(xB) + x1_0B = self.conv1_0(self.pool(x0_0B)) + x2_0B = self.conv2_0(self.pool(x1_0B)) + x3_0B = self.conv3_0(self.pool(x2_0B)) + x4_0B = self.conv4_0(self.pool(x3_0B)) + + x0_1 = self.conv0_1(torch.cat([x0_0A, x0_0B, self.Up1_0(x1_0B)], 1)) + x1_1 = self.conv1_1(torch.cat([x1_0A, x1_0B, self.Up2_0(x2_0B)], 1)) + x0_2 = self.conv0_2(torch.cat([x0_0A, x0_0B, x0_1, self.Up1_1(x1_1)], 1)) + + + x2_1 = self.conv2_1(torch.cat([x2_0A, x2_0B, self.Up3_0(x3_0B)], 1)) + x1_2 = self.conv1_2(torch.cat([x1_0A, x1_0B, x1_1, self.Up2_1(x2_1)], 1)) + x0_3 = self.conv0_3(torch.cat([x0_0A, x0_0B, x0_1, x0_2, self.Up1_2(x1_2)], 1)) + + x3_1 = self.conv3_1(torch.cat([x3_0A, x3_0B, self.Up4_0(x4_0B)], 1)) + x2_2 = self.conv2_2(torch.cat([x2_0A, x2_0B, x2_1, self.Up3_1(x3_1)], 1)) + x1_3 = self.conv1_3(torch.cat([x1_0A, x1_0B, x1_1, x1_2, self.Up2_2(x2_2)], 1)) + x0_4 = self.conv0_4(torch.cat([x0_0A, x0_0B, x0_1, x0_2, x0_3, self.Up1_3(x1_3)], 1)) + + out = torch.cat([x0_1, x0_2, x0_3, x0_4], 1) + + intra = torch.sum(torch.stack((x0_1, x0_2, x0_3, x0_4)), dim=0) + ca1 = self.ca1(intra) + out = self.ca(out) * (out + ca1.repeat(1, 4, 1, 1)) + # out = self.conv_final(out) + + return (out, ) diff --git a/opencd/models/backbones/tinycd.py b/opencd/models/backbones/tinycd.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b09d53ad7c79c03729b33011d5e28e6b77dd55 --- /dev/null +++ b/opencd/models/backbones/tinycd.py @@ -0,0 +1,204 @@ +""" +Codegoni A, Lombardi G, Ferrari A. +TINYCD: A (Not So) Deep Learning Model For Change Detection[J]. +arXiv preprint arXiv:2207.13159, 2022. 
+The code in this file is borrowed from: +https://github.com/AndreaCodegoni/Tiny_model_4_CD +""" +from typing import List, Optional + +import torchvision +from torch import Tensor, reshape, stack +from torch.nn import (Conv2d, InstanceNorm2d, Module, ModuleList, PReLU, + Sequential, Upsample) + +from opencd.registry import MODELS + + +class PixelwiseLinear(Module): + def __init__( + self, + fin: List[int], + fout: List[int], + last_activation: Module = None, + ) -> None: + assert len(fout) == len(fin) + super().__init__() + + n = len(fin) + self._linears = Sequential( + *[ + Sequential( + Conv2d(fin[i], fout[i], kernel_size=1, bias=True), + PReLU() + if i < n - 1 or last_activation is None + else last_activation, + ) + for i in range(n) + ] + ) + + def forward(self, x: Tensor) -> Tensor: + # Processing the tensor: + return self._linears(x) + + +class MixingBlock(Module): + def __init__( + self, + ch_in: int, + ch_out: int, + ): + super().__init__() + self._convmix = Sequential( + Conv2d(ch_in, ch_out, 3, groups=ch_out, padding=1), + PReLU(), + InstanceNorm2d(ch_out), + ) + + def forward(self, x: Tensor, y: Tensor) -> Tensor: + # Packing the tensors and interleaving the channels: + mixed = stack((x, y), dim=2) + mixed = reshape(mixed, (x.shape[0], -1, x.shape[2], x.shape[3])) + + # Mixing: + return self._convmix(mixed) + + +class MixingMaskAttentionBlock(Module): + """use the grouped convolution to make a sort of attention""" + + def __init__( + self, + ch_in: int, + ch_out: int, + fin: List[int], + fout: List[int], + generate_masked: bool = False, + ): + super().__init__() + self._mixing = MixingBlock(ch_in, ch_out) + self._linear = PixelwiseLinear(fin, fout) + self._final_normalization = InstanceNorm2d(ch_out) if generate_masked else None + self._mixing_out = MixingBlock(ch_in, ch_out) if generate_masked else None + + def forward(self, x: Tensor, y: Tensor) -> Tensor: + z_mix = self._mixing(x, y) + z = self._linear(z_mix) + z_mix_out = 0 if self._mixing_out is None else self._mixing_out(x, y) + + return ( + z + if self._final_normalization is None + else self._final_normalization(z_mix_out * z) + ) + + +class UpMask(Module): + def __init__( + self, + scale_factor: float, + nin: int, + nout: int, + ): + super().__init__() + self._upsample = Upsample( + scale_factor=scale_factor, mode="bilinear", align_corners=True + ) + self._convolution = Sequential( + Conv2d(nin, nin, 3, 1, groups=nin, padding=1), + PReLU(), + InstanceNorm2d(nin), + Conv2d(nin, nout, kernel_size=1, stride=1), + PReLU(), + InstanceNorm2d(nout), + ) + + def forward(self, x: Tensor, y: Optional[Tensor] = None) -> Tensor: + x = self._upsample(x) + if y is not None: + x = x * y + return self._convolution(x) + + +def _get_backbone( + bkbn_name, pretrained, output_layer_bkbn, freeze_backbone +) -> ModuleList: + # The whole model: + entire_model = getattr(torchvision.models, bkbn_name)( + pretrained=pretrained + ).features + + # Slicing it: + derived_model = ModuleList([]) + for name, layer in entire_model.named_children(): + derived_model.append(layer) + if name == output_layer_bkbn: + break + + # Freezing the backbone weights: + if freeze_backbone: + for param in derived_model.parameters(): + param.requires_grad = False + return derived_model + + +@MODELS.register_module() +class TinyCD(Module): + def __init__( + self, + in_channels, + bkbn_name="efficientnet_b4", + pretrained=True, + output_layer_bkbn="3", + freeze_backbone=False, + ): + super().__init__() + + # Load the pretrained backbone according to parameters: + 
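+        # (`_get_backbone` slices a torchvision classification model, e.g.
+        # efficientnet_b4, keeping only the layers up to `output_layer_bkbn`.)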
self._backbone = _get_backbone( + bkbn_name, pretrained, output_layer_bkbn, freeze_backbone + ) + + # Initialize mixing blocks: + self._first_mix = MixingMaskAttentionBlock(6, 3, [3, 10, 5], [10, 5, 1]) + self._mixing_mask = ModuleList( + [ + MixingMaskAttentionBlock(48, 24, [24, 12, 6], [12, 6, 1]), + MixingMaskAttentionBlock(64, 32, [32, 16, 8], [16, 8, 1]), + MixingBlock(112, 56), + ] + ) + + # Initialize Upsampling blocks: + self._up = ModuleList( + [ + UpMask(2, 56, 64), + UpMask(2, 64, 64), + UpMask(2, 64, 32), + ] + ) + + # Final classification layer: + self._classify = PixelwiseLinear([32, 16], [16, 1], None) # out_channels = 8 + + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: + features = self._encode(x1, x2) + latents = self._decode(features) + out = self._classify(latents) + + return (out,) + + def _encode(self, ref, test) -> List[Tensor]: + features = [self._first_mix(ref, test)] + for num, layer in enumerate(self._backbone): + ref, test = layer(ref), layer(test) + if num != 0: + features.append(self._mixing_mask[num - 1](ref, test)) + return features + + def _decode(self, features) -> Tensor: + upping = features[-1] + for i, j in enumerate(range(-2, -5, -1)): + upping = self._up[i](upping, features[j]) + return upping diff --git a/opencd/models/backbones/tinynet.py b/opencd/models/backbones/tinynet.py new file mode 100644 index 0000000000000000000000000000000000000000..cb66c9da06edb1bf70995e9530d8f725d12f4e8d --- /dev/null +++ b/opencd/models/backbones/tinynet.py @@ -0,0 +1,519 @@ +# Copyright (c) Open-CD. All rights reserved. +import warnings + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer +from mmengine.model import BaseModule +from torch.nn import functional as F +from torch.nn.modules.batchnorm import _BatchNorm +from torch.utils import checkpoint as cp + +from mmseg.models.utils import SELayer, make_divisible +from opencd.registry import MODELS + + +class AsymGlobalAttn(BaseModule): + def __init__(self, dim, strip_kernel_size=21): + super().__init__() + + self.norm = build_norm_layer(dict(type='mmpretrain.LN2d', eps=1e-6), dim)[1] + self.global_ = nn.Sequential( + nn.Conv2d(dim, dim, 1), + nn.Conv2d(dim, dim, (1, strip_kernel_size), padding=(0, (strip_kernel_size-1)//2), groups=dim), + nn.Conv2d(dim, dim, (strip_kernel_size, 1), padding=((strip_kernel_size-1)//2, 0), groups=dim) + ) + + self.v = nn.Conv2d(dim, dim, 1) + self.proj = nn.Conv2d(dim, dim, 1) + self.layer_scale = nn.Parameter(1e-6 * torch.ones((dim)), requires_grad=True) + + def forward(self, x): + B, C, H, W = x.shape + identity = x + + a = self.global_(x) + x = a * self.v(x) + x = self.proj(x) + x = self.norm(x) + x = self.layer_scale.unsqueeze(-1).unsqueeze(-1) * x + identity + + return x + + +class PriorAttention(BaseModule): + def __init__(self, + channels, + num_paths=2, + attn_channels=None, + act_cfg=dict(type='ReLU'), + norm_cfg=dict(type='BN', requires_grad=True)): + super(PriorAttention, self).__init__() + self.num_paths = num_paths # `2` is supported. 
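+        # Gating weights are derived from the pooled |x1 - x2| difference:
+        # reduce -> BN -> act -> expand into one sigmoid gate per path
+        # (see `forward` below).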
+ attn_channels = attn_channels or channels // 16 + attn_channels = max(attn_channels, 8) + + self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) + self.bn = build_norm_layer(norm_cfg, attn_channels)[1] + self.act = build_activation_layer(act_cfg) + self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + + def forward(self, x1, x2): + x = torch.abs(x1 - x2) + attn = x.mean((2, 3), keepdim=True) + attn = self.fc_reduce(attn) + attn = self.bn(attn) + attn = self.act(attn) + attn = self.fc_select(attn) + B, C, H, W = attn.shape + attn1, attn2 = attn.reshape(B, self.num_paths, C // self.num_paths, H, W).transpose(0, 1) + attn1 = torch.sigmoid(attn1) + attn2 = torch.sigmoid(attn2) + + return x1 * attn1 + x1, x2 * attn2 + x2 + + +class StemBlock(BaseModule): + """InvertedResidual block for MobileNetV2. + + Args: + in_channels (int): The input channels of the InvertedResidual block. + out_channels (int): The output channels of the InvertedResidual block. + stride (int): Stride of the middle (first) 3x3 convolution. + expand_ratio (int): Adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + dilation (int): Dilation rate of depthwise conv. Default: 1 + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + stride, + expand_ratio, + dilation=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + with_cp=False, + **kwargs): + super(StemBlock, self).__init__() + self.stride = stride + assert stride in [1, 2], f'stride must in [1, 2]. ' \ + f'But received {stride}.' 
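+        # A residual shortcut is only valid when the block preserves both the
+        # spatial size (stride 1) and the channel count.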
+ self.with_cp = with_cp + self.use_res_connect = self.stride == 1 and in_channels == out_channels + hidden_dim = int(round(in_channels * expand_ratio)) + + layers = [] + if expand_ratio != 1: + layers.append( + ConvModule( + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs)) + layers.extend([ + ConvModule( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=hidden_dim, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs), + ]) + + self.conv = nn.Sequential(*layers) + self.interact = PriorAttention(channels=hidden_dim) + self.post_conv = ConvModule( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None, + **kwargs) + + def forward(self, x): + x1, x2 = x + identity_x1 = x1 + identity_x2 = x2 + x1 = self.conv(x1) + x2 = self.conv(x2) + x1, x2 = self.interact(x1, x2) + x1 = self.post_conv(x1) + x2 = self.post_conv(x2) + + if self.use_res_connect: + x1 = x1 + identity_x1 + x2 = x2 + identity_x2 + + return x1, x2 + + +class PriorFusion(BaseModule): + def __init__(self, channels, stack_nums=2): + super().__init__() + + self.stem = nn.Sequential( + *[StemBlock( + in_channels=channels, + out_channels=channels, + stride=1, + expand_ratio=4) for _ in range(stack_nums)]) + + self.pseudo_fusion = nn.Sequential( + nn.Conv2d(channels * 2, channels * 2, 3, padding=1, groups=channels * 2), + build_norm_layer(dict(type='mmpretrain.LN2d', eps=1e-6), channels * 2)[1], + nn.GELU(), + nn.Conv2d(channels * 2, channels, 3, padding=1, groups=channels), + ) + + + def forward(self, x1, x2): + B, C, H, W = x1.shape + identity_x1 = x1 + identity_x2 = x2 + + x1, x2 = self.stem((x1, x2)) + x1 = x1 + identity_x1 + x2 = x2 + identity_x2 + + early_x = torch.cat([x1, x2], dim=1) + x = self.pseudo_fusion(early_x) + return early_x, x + + +class TinyBlock(BaseModule): + """InvertedResidual block for MobileNetV2. + + Args: + in_channels (int): The input channels of the InvertedResidual block. + out_channels (int): The output channels of the InvertedResidual block. + stride (int): Stride of the middle (first) 3x3 convolution. + expand_ratio (int): Adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + dilation (int): Dilation rate of depthwise conv. Default: 1 + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + stride, + expand_ratio, + dilation=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + with_cp=False, + with_se=False, + **kwargs): + super(TinyBlock, self).__init__() + self.stride = stride + assert stride in [1, 2], f'stride must in [1, 2]. ' \ + f'But received {stride}.' 
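+        # Same residual rule as StemBlock: shortcut only for stride 1 with
+        # matching in/out channels.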
+ self.with_cp = with_cp + self.use_res_connect = self.stride == 1 and in_channels == out_channels + hidden_dim = int(round(in_channels * expand_ratio)) + + layers = [] + Attention_Layer = SELayer(hidden_dim) if with_se else nn.Identity() + if expand_ratio != 1: + layers.append( + ConvModule( + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs)) + layers.extend([ + ConvModule( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=hidden_dim, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs), + Attention_Layer, + ConvModule( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None, + **kwargs) + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + + def _inner_forward(x): + if self.use_res_connect: + x = x + self.conv(x) + return x + else: + return self.conv(x) + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@MODELS.register_module() +class TinyNet(BaseModule): + """TinyNet backbone. + This backbone is the implementation of + + Args: + output_early_x (bool): output early features before fusion. + Defaults to 'False'. + arch='B' (str): The model's architecture. It should be + one of architecture in ``TinyNet.change_extractor_settings``. + Defaults to 'B'. + stem_stack_nums (int): The number of stacked stem blocks. + use_global: (Sequence[bool]): whether use `AsymGlobalAttn` after + stages. Defaults: (True, True, True, True). + strip_kernel_size: (Sequence[int]): The strip kernel size of + `AsymGlobalAttn`. Defaults: (41, 31, 21, 11). + widen_factor (float): Width multiplier, multiply number of + channels in each layer by this amount. Default: 1.0. + strides (Sequence[int], optional): Strides of the first block of each + layer. If not specified, default config in ``arch_setting`` will + be used. + dilations (Sequence[int]): Dilation of each layer. + out_indices (None or Sequence[int]): Output from which stages. + Default: (7, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + # Parameters to build layers. 3 parameters are needed to construct a + # layer, from left to right: expand_ratio, channel, num_blocks. 
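+    # 'S' and 'B' share identical stage settings here; 'L' deepens the third
+    # stage from 3 to 6 blocks.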
+ change_extractor_settings = { + 'S': [[4, 16, 2], [6, 24, 2], [6, 32, 3], [6, 48, 1]], + 'B': [[4, 16, 2], [6, 24, 2], [6, 32, 3], [6, 48, 1]], + 'L': [[4, 16, 2], [6, 24, 2], [6, 32, 6], [6, 48, 1]],} + + def __init__(self, + output_early_x=False, + arch='B', + stem_stack_nums=2, + use_global=(True, True, True, True), + strip_kernel_size=(41, 31, 21, 11), + widen_factor=1., + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg) + + self.arch_settings = self.change_extractor_settings[arch] + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + self.widen_factor = widen_factor + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == len(self.arch_settings) + self.out_indices = out_indices + for index in out_indices: + if index not in range(0, 7): + raise ValueError('the item in out_indices must in ' + f'range(0, 7). But received {index}') + + if frozen_stages not in range(-1, 7): + raise ValueError('frozen_stages must be in range(-1, 7). ' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.in_channels = make_divisible(16 * widen_factor, 8) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.fusion_block = PriorFusion(self.in_channels, stem_stack_nums) + + self.layers = [] + self.use_global = use_global + self.strip_kernel_size = strip_kernel_size + + for i, layer_cfg in enumerate(self.arch_settings): + expand_ratio, channel, num_blocks = layer_cfg + stride = self.strides[i] + dilation = self.dilations[i] + out_channels = make_divisible(channel * widen_factor, 8) + inverted_res_layer = self.make_layer( + out_channels=out_channels, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + expand_ratio=expand_ratio, + use_global=use_global[i], + strip_kernel_size=self.strip_kernel_size[i]) + layer_name = f'layer{i + 1}' + self.add_module(layer_name, inverted_res_layer) + self.layers.append(layer_name) + + self.output_early_x = output_early_x + + def make_layer(self, out_channels, num_blocks, stride, dilation, + expand_ratio, use_global, strip_kernel_size): + """Stack InvertedResidual blocks to build a layer for MobileNetV2. + Args: + out_channels (int): out_channels of block. + num_blocks (int): Number of blocks. + stride (int): Stride of the first block. + dilation (int): Dilation of the first block. + expand_ratio (int): Expand the number of channels of the + hidden layer in InvertedResidual by this ratio. 
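+            use_global (bool): Whether to append an `AsymGlobalAttn` block
+                after the stacked blocks of this stage.
+            strip_kernel_size (int): Strip kernel size of the appended
+                `AsymGlobalAttn`.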
+ """ + layers = [] + for i in range(num_blocks): + layers.append( + TinyBlock( + self.in_channels, + out_channels, + stride if i == 0 else 1, + expand_ratio=expand_ratio, + dilation=dilation if i == 0 else 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + # after stage + if use_global: + layers.append( + AsymGlobalAttn(out_channels, strip_kernel_size) + ) + + return nn.Sequential(*layers) + + def forward(self, x1, x2): + x1 = self.conv1(x1) + x2 = self.conv1(x2) + + early_x, x = self.fusion_block(x1, x2) + + if self.output_early_x: + outs = [early_x] + else: + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + if len(outs) == 1: + return outs[0] + else: + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(TinyNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() \ No newline at end of file diff --git a/opencd/models/change_detectors/__init__.py b/opencd/models/change_detectors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cf56c76830e7c408cd2460f224f920bb2f088c1a --- /dev/null +++ b/opencd/models/change_detectors/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Open-CD. All rights reserved. +from .dual_input_encoder_decoder import DIEncoderDecoder +from .siamencoder_decoder import SiamEncoderDecoder +from .siamencoder_multidecoder import SiamEncoderMultiDecoder + +__all__ = ['SiamEncoderDecoder', 'DIEncoderDecoder', 'SiamEncoderMultiDecoder'] diff --git a/opencd/models/change_detectors/__pycache__/__init__.cpython-311.pyc b/opencd/models/change_detectors/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4978e1b5ee3ddf0ff668d5597e501c5f6a4ea285 Binary files /dev/null and b/opencd/models/change_detectors/__pycache__/__init__.cpython-311.pyc differ diff --git a/opencd/models/change_detectors/__pycache__/dual_input_encoder_decoder.cpython-311.pyc b/opencd/models/change_detectors/__pycache__/dual_input_encoder_decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f0695ac1e7dec91f7f7d82a2227919b48b5b364 Binary files /dev/null and b/opencd/models/change_detectors/__pycache__/dual_input_encoder_decoder.cpython-311.pyc differ diff --git a/opencd/models/change_detectors/__pycache__/siamencoder_decoder.cpython-311.pyc b/opencd/models/change_detectors/__pycache__/siamencoder_decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ba9c9a39e7ec9d5994f507a7369810d0b3db16e Binary files /dev/null and b/opencd/models/change_detectors/__pycache__/siamencoder_decoder.cpython-311.pyc differ diff --git a/opencd/models/change_detectors/__pycache__/siamencoder_multidecoder.cpython-311.pyc b/opencd/models/change_detectors/__pycache__/siamencoder_multidecoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ca7ff968a0156d9f455d2a8143b6abe6b3d4f02 Binary files /dev/null and 
b/opencd/models/change_detectors/__pycache__/siamencoder_multidecoder.cpython-311.pyc differ diff --git a/opencd/models/change_detectors/dual_input_encoder_decoder.py b/opencd/models/change_detectors/dual_input_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..9f9b64e2634fe0af1f778218ef2fcf7834964c79 --- /dev/null +++ b/opencd/models/change_detectors/dual_input_encoder_decoder.py @@ -0,0 +1,27 @@ +# Copyright (c) Open-CD. All rights reserved. +from typing import List, Optional + +import torch +from torch import Tensor + +from opencd.registry import MODELS +from .siamencoder_decoder import SiamEncoderDecoder + + +@MODELS.register_module() +class DIEncoderDecoder(SiamEncoderDecoder): + """Dual Input Encoder Decoder segmentors. + + DIEncoderDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + """ + + def extract_feat(self, inputs: Tensor) -> List[Tensor]: + """Extract features from images.""" + # `in_channels` is not in the ATTRIBUTE for some backbone CLASS. + img_from, img_to = torch.split(inputs, self.backbone_inchannels, dim=1) + x = self.backbone(img_from, img_to) + if self.with_neck: + x = self.neck(x) + return x diff --git a/opencd/models/change_detectors/siamencoder_decoder.py b/opencd/models/change_detectors/siamencoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7b5a3496b54cb84e8bb89c7e02d2ad9ec8364ac6 --- /dev/null +++ b/opencd/models/change_detectors/siamencoder_decoder.py @@ -0,0 +1,449 @@ +# Copyright (c) Open-CD. All rights reserved. +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.structures import PixelData +from torch import Tensor + +from mmseg.models.segmentors.base import BaseSegmentor +from mmseg.models.utils import resize +from mmseg.structures import SegDataSample +from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig, + OptSampleList, SampleList, add_prefix) +from opencd.registry import MODELS + + +@MODELS.register_module() +class SiamEncoderDecoder(BaseSegmentor): + """SiamEncoder Decoder change detector. + + EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + + 1. The ``loss`` method is used to calculate the loss of model, + which includes two steps: (1) Extracts features to obtain the feature maps + (2) Call the decode head loss function to forward decode head model and + calculate losses. + + .. code:: text + + loss(): extract_feat() -> _decode_head_forward_train() -> _auxiliary_head_forward_train (optional) + _decode_head_forward_train(): decode_head.loss() + _auxiliary_head_forward_train(): auxiliary_head.loss (optional) + + 2. The ``predict`` method is used to predict segmentation results, + which includes two steps: (1) Run inference function to obtain the list of + seg_logits (2) Call post-processing function to obtain list of + ``SegDataSample`` including ``pred_sem_seg`` and ``seg_logits``. + + .. code:: text + + predict(): inference() -> postprocess_result() + infercen(): whole_inference()/slide_inference() + whole_inference()/slide_inference(): encoder_decoder() + encoder_decoder(): extract_feat() -> decode_head.predict() + + 3. 
The ``_forward`` method is used to output the tensor by running the model, + which includes two steps: (1) Extracts features to obtain the feature maps + (2) Call the decode head forward function to forward decode head model. + + .. code:: text + + _forward(): extract_feat() -> _decode_head.forward() + + Args: + + backbone (ConfigType): The config for the backbone of segmentor. + decode_head (ConfigType): The config for the decode head of segmentor. + neck (OptConfigType): The config for the neck of segmentor. + Defaults to None. + auxiliary_head (OptConfigType): The config for the auxiliary head of + segmentor. Defaults to None. + train_cfg (OptConfigType): The config for training. Defaults to None. + test_cfg (OptConfigType): The config for testing. Defaults to None. + data_preprocessor (dict, optional): The pre-process config of + :class:`BaseDataPreprocessor`. + pretrained (str, optional): The path for pretrained model. + Defaults to None. + init_cfg (dict, optional): The weight initialization config for + :class:`BaseModule`. + backbone_inchannels (int): The `in_channels` for backbone network. + Defaults to 3 for RGB images. + """ # noqa: E501 + + def __init__(self, + backbone: ConfigType, + decode_head: ConfigType, + neck: OptConfigType = None, + auxiliary_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + pretrained: Optional[str] = None, + init_cfg: OptMultiConfig = None, + backbone_inchannels: int = 3): + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + if pretrained is not None: + assert backbone.get('pretrained') is None, \ + 'both backbone and segmentor set pretrained weight' + backbone.pretrained = pretrained + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + self._init_decode_head(decode_head) + self._init_auxiliary_head(auxiliary_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.backbone_inchannels = backbone_inchannels # RGB: 3 + + assert self.with_decode_head + + def _init_decode_head(self, decode_head: ConfigType) -> None: + """Initialize ``decode_head``""" + self.decode_head = MODELS.build(decode_head) + self.align_corners = self.decode_head.align_corners + self.num_classes = self.decode_head.num_classes + self.out_channels = self.decode_head.out_channels + + def _init_auxiliary_head(self, auxiliary_head: ConfigType) -> None: + """Initialize ``auxiliary_head``""" + if auxiliary_head is not None: + if isinstance(auxiliary_head, list): + self.auxiliary_head = nn.ModuleList() + for head_cfg in auxiliary_head: + self.auxiliary_head.append(MODELS.build(head_cfg)) + else: + self.auxiliary_head = MODELS.build(auxiliary_head) + + def extract_feat(self, inputs: Tensor) -> List[Tensor]: + """Extract features from images.""" + # `in_channels` is not an attribute of some backbone classes.
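+        # A minimal sketch of the input convention assumed here (comments
+        # only, no extra runtime logic): the data preprocessor stacks the
+        # bi-temporal pair along the channel axis, e.g.
+        #   inputs = torch.cat([img_t1, img_t2], dim=1)  # (N, 6, H, W) for RGB pairs
+        # so splitting by `backbone_inchannels` recovers the two images below.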
+ img_from, img_to = torch.split(inputs, self.backbone_inchannels, dim=1) + img = torch.cat([img_from, img_to], dim=0) + img_feat = self.backbone(img)[0] + feat_from, feat_to = torch.split(img_feat, img_feat.shape[0] // 2, dim=0) + feat_from = [feat_from] + feat_to = [feat_to] + if self.with_neck: + x = self.neck(feat_from, feat_to) + else: + raise ValueError('`NECK` is needed for `SiamEncoderDecoder`.') + + return x + + def encode_decode(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(inputs) + seg_logits = self.decode_head.predict(x, batch_img_metas, + self.test_cfg) + + return seg_logits + + def _decode_head_forward_train(self, inputs: List[Tensor], + data_samples: SampleList) -> dict: + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.loss(inputs, data_samples, + self.train_cfg) + + losses.update(add_prefix(loss_decode, 'decode')) + return losses + + def _auxiliary_head_forward_train(self, inputs: List[Tensor], + data_samples: SampleList) -> dict: + """Run forward function and calculate loss for auxiliary head in + training.""" + losses = dict() + if isinstance(self.auxiliary_head, nn.ModuleList): + for idx, aux_head in enumerate(self.auxiliary_head): + loss_aux = aux_head.loss(inputs, data_samples, self.train_cfg) + losses.update(add_prefix(loss_aux, f'aux_{idx}')) + else: + loss_aux = self.auxiliary_head.loss(inputs, data_samples, + self.train_cfg) + losses.update(add_prefix(loss_aux, 'aux')) + + return losses + + def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (Tensor): Input images. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(inputs) + + losses = dict() + + loss_decode = self._decode_head_forward_train(x, data_samples) + losses.update(loss_decode) + + if self.with_auxiliary_head: + loss_aux = self._auxiliary_head_forward_train(x, data_samples) + losses.update(loss_aux) + + return losses + + def predict(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`], optional): The seg data + samples. It usually includes information such as `metainfo` + and `gt_sem_seg`. + + Returns: + list[:obj:`SegDataSample`]: Segmentation results of the + input images. Each SegDataSample usually contain: + + - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. + - ``seg_logits``(PixelData): Predicted logits of semantic + segmentation before normalization. 
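+        Example (an illustrative sketch, not from the test suite; it assumes
+        ``backbone_inchannels=3``, i.e. a stacked 6-channel bi-temporal input):
+
+        >>> inputs = torch.rand(1, 6, 256, 256)  # one stacked image pair
+        >>> results = model.predict(inputs)      # -> list[SegDataSample]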
+ """ + if data_samples is not None: + batch_img_metas = [ + data_sample.metainfo for data_sample in data_samples + ] + else: + batch_img_metas = [ + dict( + ori_shape=inputs.shape[2:], + img_shape=inputs.shape[2:], + pad_shape=inputs.shape[2:], + padding_size=[0, 0, 0, 0]) + ] * inputs.shape[0] + + seg_logits = self.inference(inputs, batch_img_metas) + + return self.postprocess_result(seg_logits, data_samples) + + def _forward(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> Tensor: + """Network forward process. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_sem_seg`. + + Returns: + Tensor: Forward output of model without any post-processes. + """ + x = self.extract_feat(inputs) + return self.decode_head.forward(x) + + def slide_inference(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + + Args: + inputs (tensor): the tensor should have a shape NxCxHxW, + which contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = inputs.size() + out_channels = self.out_channels + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img)) + count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = inputs[:, :, y1:y2, x1:x2] + # change the image shape to patch shape + batch_img_metas[0]['img_shape'] = crop_img.shape[2:] + # the output of encode_decode is seg logits tensor map + # with shape [N, C, H, W] + crop_seg_logit = self.encode_decode(crop_img, batch_img_metas) + preds += F.pad(crop_seg_logit, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + seg_logits = preds / count_mat + + return seg_logits + + def whole_inference(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Inference with full image. + + Args: + inputs (Tensor): The tensor should have a shape NxCxHxW, which + contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. 
+ """ + + seg_logits = self.encode_decode(inputs, batch_img_metas) + + return seg_logits + + def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor: + """Inference with slide/whole style. + + Args: + inputs (Tensor): The input image of shape (N, 3, H, W). + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', 'pad_shape', and 'padding_size'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + + assert self.test_cfg.mode in ['slide', 'whole'] + ori_shape = batch_img_metas[0]['ori_shape'] + assert all(_['ori_shape'] == ori_shape for _ in batch_img_metas) + if self.test_cfg.mode == 'slide': + seg_logit = self.slide_inference(inputs, batch_img_metas) + else: + seg_logit = self.whole_inference(inputs, batch_img_metas) + + return seg_logit + + def aug_test(self, inputs, batch_img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented seg logit inplace + seg_logit = self.inference(inputs[0], batch_img_metas[0], rescale) + for i in range(1, len(inputs)): + cur_seg_logit = self.inference(inputs[i], batch_img_metas[i], + rescale) + seg_logit += cur_seg_logit + seg_logit /= len(inputs) + seg_pred = seg_logit.argmax(dim=1) + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred + + def postprocess_result(self, + seg_logits: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """ Convert results list to `SegDataSample`. + Args: + seg_logits (Tensor): The segmentation results, seg_logits from + model of each input image. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_sem_seg`. Default to None. + Returns: + list[:obj:`SegDataSample`]: Segmentation results of the + input images. Each SegDataSample usually contain: + + - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. + - ``seg_logits``(PixelData): Predicted logits of semantic + segmentation before normalization. 
+ """ + batch_size, C, H, W = seg_logits.shape + + if data_samples is None: + data_samples = [SegDataSample() for _ in range(batch_size)] + only_prediction = True + else: + only_prediction = False + + for i in range(batch_size): + if not only_prediction: + img_meta = data_samples[i].metainfo + # remove padding area + if 'img_padding_size' not in img_meta: + padding_size = img_meta.get('padding_size', [0] * 4) + else: + padding_size = img_meta['img_padding_size'] + padding_left, padding_right, padding_top, padding_bottom =\ + padding_size + # i_seg_logits shape is 1, C, H, W after remove padding + i_seg_logits = seg_logits[i:i + 1, :, + padding_top:H - padding_bottom, + padding_left:W - padding_right] + + flip = img_meta.get('flip', None) + if flip: + flip_direction = img_meta.get('flip_direction', None) + assert flip_direction in ['horizontal', 'vertical'] + if flip_direction == 'horizontal': + i_seg_logits = i_seg_logits.flip(dims=(3, )) + else: + i_seg_logits = i_seg_logits.flip(dims=(2, )) + + # resize as original shape + i_seg_logits = resize( + i_seg_logits, + size=img_meta['ori_shape'], + mode='bilinear', + align_corners=self.align_corners, + warning=False).squeeze(0) + else: + i_seg_logits = seg_logits[i] + + if C > 1: + i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True) + else: + i_seg_logits = i_seg_logits.sigmoid() + i_seg_pred = (i_seg_logits > + self.decode_head.threshold).to(i_seg_logits) + data_samples[i].set_data({ + 'seg_logits': + PixelData(**{'data': i_seg_logits}), + 'pred_sem_seg': + PixelData(**{'data': i_seg_pred}) + }) + + return data_samples diff --git a/opencd/models/change_detectors/siamencoder_multidecoder.py b/opencd/models/change_detectors/siamencoder_multidecoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a085afca86d16ffdf8b07a68f6a6b5a0ff73a5ef --- /dev/null +++ b/opencd/models/change_detectors/siamencoder_multidecoder.py @@ -0,0 +1,324 @@ +# Copyright (c) Open-CD. All rights reserved. +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.structures import PixelData +from torch import Tensor + +from mmseg.models.utils import resize +from mmseg.structures import SegDataSample +from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig, + OptSampleList, SampleList, add_prefix) +from opencd.registry import MODELS +from .siamencoder_decoder import SiamEncoderDecoder + + +@MODELS.register_module() +class SiamEncoderMultiDecoder(SiamEncoderDecoder): + """SiamEncoder Multihead Decoder segmentors. + + SiamEncoderMultiDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + + Args: + postprocess_pred_and_label (str, optional): Whether to post-process the + `pred` and `label` when predicting. Defaults to None. 
+ """ + + def __init__(self, postprocess_pred_and_label=None, **kwargs): + super().__init__(**kwargs) + self.postprocess_pred_and_label = postprocess_pred_and_label + + def _init_decode_head(self, decode_head: ConfigType) -> None: + """Initialize ``decode_head``""" + # for binary branches + self.decode_head = MODELS.build(decode_head) + self.num_classes = self.decode_head.binary_cd_head.num_classes + self.out_channels = self.decode_head.binary_cd_head.out_channels + # for sementic branches + self.semantic_num_classes = self.decode_head.semantic_cd_head.num_classes + self.semantic_out_channels = self.decode_head.semantic_cd_head.out_channels + + self.align_corners = { + 'seg_logits': self.decode_head.binary_cd_head.align_corners, + 'seg_logits_from': self.decode_head.semantic_cd_head.align_corners, + 'seg_logits_to': self.decode_head.semantic_cd_head_aux.align_corners} + self.thresholds = { + 'seg_logits': self.decode_head.binary_cd_head.threshold, + 'seg_logits_from': self.decode_head.semantic_cd_head.threshold, + 'seg_logits_to': self.decode_head.semantic_cd_head_aux.threshold} + + def extract_feat(self, inputs: Tensor) -> List[Tensor]: + """Extract features from images.""" + # `in_channels` is not in the ATTRIBUTE for some backbone CLASS. + img_from, img_to = torch.split(inputs, self.backbone_inchannels, dim=1) + feat_from = self.backbone(img_from) + feat_to = self.backbone(img_to) + if self.with_neck: + feat_from = self.neck(feat_from) + feat_to = self.neck(feat_to) + x = (feat_from, feat_to) + + return x + + def slide_inference(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + + Args: + inputs (tensor): the tensor should have a shape NxCxHxW, + which contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. 
+ """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = inputs.size() + out_channels = self.out_channels + semantic_channels = self.semantic_out_channels + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = dict( + seg_logits=inputs.new_zeros((batch_size, out_channels, h_img, w_img)), + seg_logits_from=inputs.new_zeros((batch_size, semantic_channels, h_img, w_img)), + seg_logits_to=inputs.new_zeros((batch_size, semantic_channels, h_img, w_img)) + ) + count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = inputs[:, :, y1:y2, x1:x2] + # change the image shape to patch shape + batch_img_metas[0]['img_shape'] = crop_img.shape[2:] + # the output of encode_decode is seg logits tensor map + # with shape [N, C, H, W] + crop_seg_logits = self.encode_decode(crop_img, batch_img_metas) + for seg_name, crop_seg_logit in crop_seg_logits.items(): + preds[seg_name] += F.pad(crop_seg_logit, + (int(x1), int(preds[seg_name].shape[3] - x2), int(y1), + int(preds[seg_name].shape[2] - y2))) + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + for seg_name, pred in preds.items(): + preds[seg_name] = pred / count_mat + + return preds + + # def aug_test(self, inputs, batch_img_metas, rescale=True): + # """Test with augmentations. + + # Only rescale=True is supported. + # """ + # # aug_test rescale all imgs back to ori_shape for now + # assert rescale + # # to save memory, we get augmented seg logit inplace + # seg_logits = self.inference(inputs[0], batch_img_metas[0], rescale) + # for i in range(1, len(inputs)): + # cur_seg_logits = self.inference(inputs[i], batch_img_metas[i], rescale) + # for seg_name, cur_seg_logit in cur_seg_logits.items(): + # seg_logits[seg_name] += cur_seg_logit + # for seg_name, seg_logit in seg_logits.items(): + # seg_logits[seg_name] /= len(inputs) + + # seg_preds = [] + # for seg_name, seg_logit in seg_logits.items(): + # if (self.out_channels == 1 and seg_name == 'seg_logits') \ + # or (self.semantic_out_channels == 1 \ + # and ("from" in seg_name or "to" in seg_name)): + # seg_pred = (seg_logit > + # self.thresholds[seg_name]).to(seg_logit).squeeze(1) + # else: + # seg_pred = seg_logit.argmax(dim=1) + # # unravel batch dim + # seg_pred = list(seg_pred) + # seg_preds.append(seg_pred) + + # # (3, B, H, W) -> (B, 3, H, W) + # seg_preds = [list(pred) for pred in list(zip(*seg_preds))] + # return seg_preds + + def postprocess_result(self, + seg_logits: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """ Convert results list to `SegDataSample`. + Args: + seg_logits (Tensor): The segmentation results, seg_logits from + model of each input image. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_sem_seg`. Default to None. + Returns: + list[:obj:`SegDataSample`]: Segmentation results of the + input images. Each SegDataSample usually contain: + + - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. + - ``seg_logits``(PixelData): Predicted logits of semantic + segmentation before normalization. 
+ """ + + C = dict() + for seg_name, seg_logit in seg_logits.items(): + batch_size, _C, H, W = seg_logit.shape + C[seg_name] = _C + + if data_samples is None: + data_samples = [SegDataSample() for _ in range(batch_size)] + only_prediction = True + else: + only_prediction = False + + for i in range(batch_size): + for seg_name, seg_logit in seg_logits.items(): + if not only_prediction: + img_meta = data_samples[i].metainfo + # remove padding area + if 'img_padding_size' not in img_meta: + padding_size = img_meta.get('padding_size', [0] * 4) + else: + padding_size = img_meta['img_padding_size'] + padding_left, padding_right, padding_top, padding_bottom =\ + padding_size + # i_seg_logit shape is 1, C, H, W after remove padding + i_seg_logit = seg_logit[i:i + 1, :, + padding_top:H - padding_bottom, + padding_left:W - padding_right] + + flip = img_meta.get('flip', None) + if flip: + flip_direction = img_meta.get('flip_direction', None) + assert flip_direction in ['horizontal', 'vertical'] + if flip_direction == 'horizontal': + i_seg_logit = i_seg_logit.flip(dims=(3, )) + else: + i_seg_logit = i_seg_logit.flip(dims=(2, )) + + # resize as original shape + i_seg_logit = resize( + i_seg_logit, + size=img_meta['ori_shape'], + mode='bilinear', + align_corners=self.align_corners[seg_name], + warning=False).squeeze(0) + else: + i_seg_logit = seg_logit[i] + + if C[seg_name] > 1: + i_seg_pred = i_seg_logit.argmax(dim=0, keepdim=True) + else: + i_seg_logit = i_seg_logit.sigmoid() + i_seg_pred = (i_seg_logit > + self.thresholds[seg_name]).to(i_seg_logit) + + pred_name = '_' + seg_name.split('_')[-1] \ + if seg_name.split('_')[-1] in ['from', 'to'] else '' + pred_name = 'pred_sem_seg' + pred_name + data_samples[i].set_data({ + seg_name: + PixelData(**{'data': i_seg_logit}), + pred_name: + PixelData(**{'data': i_seg_pred}) + }) + + if self.postprocess_pred_and_label is not None: + if self.postprocess_pred_and_label == 'cover_semantic': + for data_sample in data_samples: + # postprocess_semantic_pred + data_sample.pred_sem_seg_from.data = data_sample.pred_sem_seg_from.data + 1 + data_sample.pred_sem_seg_to.data = data_sample.pred_sem_seg_to.data + 1 + data_sample.pred_sem_seg_from.data = data_sample.pred_sem_seg_from.data * \ + data_sample.pred_sem_seg.data + data_sample.pred_sem_seg_to.data = data_sample.pred_sem_seg_to.data * \ + data_sample.pred_sem_seg.data + + # postprocess_semantic_label + data_sample.gt_sem_seg_from.data[data_sample.gt_sem_seg_from.data == 255] = -1 + data_sample.gt_sem_seg_from.data = data_sample.gt_sem_seg_from.data + 1 + data_sample.gt_sem_seg_to.data[data_sample.gt_sem_seg_to.data == 255] = -1 + data_sample.gt_sem_seg_to.data = data_sample.gt_sem_seg_to.data + 1 + else: + raise ValueError( + f'`postprocess_pred_and_label` should be `cover_semantic` or None.') + + return data_samples + + + + # for seg_name, seg_logit in seg_logits.items(): + # batch_size, C, H, W = seg_logit.shape + + # if data_samples is None: + # data_samples = [SegDataSample() for _ in range(batch_size)] + # only_prediction = True + # else: + # only_prediction = False + + # for i in range(batch_size): + # if not only_prediction: + # img_meta = data_samples[i].metainfo + # # remove padding area + # if 'img_padding_size' not in img_meta: + # padding_size = img_meta.get('padding_size', [0] * 4) + # else: + # padding_size = img_meta['img_padding_size'] + # padding_left, padding_right, padding_top, padding_bottom =\ + # padding_size + # # i_seg_logit shape is 1, C, H, W after remove padding + # i_seg_logit = 
diff --git a/opencd/models/data_preprocessor.py b/opencd/models/data_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..7b59e019d62b954510650e09994596d744692d20 --- /dev/null +++ b/opencd/models/data_preprocessor.py @@ -0,0 +1,254 @@ +# Copyright (c) Open-CD. All rights reserved. +from numbers import Number +from typing import Any, Dict, List, Optional, Sequence, Union + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.model import BaseDataPreprocessor + +from mmseg.utils import SampleList +from opencd.registry import MODELS + + +def stack_batch(inputs: List[torch.Tensor], + data_samples: Optional[SampleList] = None, + size: Optional[tuple] = None, + size_divisor: Optional[int] = None, + pad_val: Union[int, float] = 0, + seg_pad_val: Union[int, float] = 255) -> torch.Tensor: + """Stack multiple inputs to form a batch and pad the images and gt_sem_segs + to the max shape using the right-bottom padding mode. + + Args: + inputs (List[Tensor]): The input multiple tensors. each is a + CHW 3D-tensor. + data_samples (list[:obj:`SegDataSample`]): The list of data samples. + It usually includes information such as `gt_sem_seg`. + size (tuple, optional): Fixed padding size. + size_divisor (int, optional): The divisor of padded size. + pad_val (int, float): The padding value. Defaults to 0 + seg_pad_val (int, float): The padding value. Defaults to 255 + + Returns: + Tensor: The 4D-tensor. + List[:obj:`SegDataSample`]: After the padding of the gt_seg_map.
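+    Example (shapes illustrative):
+
+    >>> a, b = torch.rand(3, 100, 120), torch.rand(3, 98, 121)
+    >>> batch, samples = stack_batch([a, b], size_divisor=32)
+    >>> batch.shape   # right/bottom padded to a shared divisible size
+    torch.Size([2, 3, 128, 128])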
+ """ + assert isinstance(inputs, list), \ + f'Expected input type to be list, but got {type(inputs)}' + assert len({tensor.ndim for tensor in inputs}) == 1, \ + f'Expected the dimensions of all inputs must be the same, ' \ + f'but got {[tensor.ndim for tensor in inputs]}' + assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \ + f'but got {inputs[0].ndim}' + assert len({tensor.shape[0] for tensor in inputs}) == 1, \ + f'Expected the channels of all inputs must be the same, ' \ + f'but got {[tensor.shape[0] for tensor in inputs]}' + + # only one of size and size_divisor should be valid + assert (size is not None) ^ (size_divisor is not None), \ + 'only one of size and size_divisor should be valid' + + padded_inputs = [] + padded_samples = [] + inputs_sizes = [(img.shape[-2], img.shape[-1]) for img in inputs] + max_size = np.stack(inputs_sizes).max(0) + if size_divisor is not None and size_divisor > 1: + # the last two dims are H,W, both subject to divisibility requirement + max_size = (max_size + + (size_divisor - 1)) // size_divisor * size_divisor + + for i in range(len(inputs)): + tensor = inputs[i] + if size is not None: + width = max(size[-1] - tensor.shape[-1], 0) + height = max(size[-2] - tensor.shape[-2], 0) + # (padding_left, padding_right, padding_top, padding_bottom) + padding_size = (0, width, 0, height) + elif size_divisor is not None: + width = max(max_size[-1] - tensor.shape[-1], 0) + height = max(max_size[-2] - tensor.shape[-2], 0) + padding_size = (0, width, 0, height) + else: + padding_size = [0, 0, 0, 0] + + # pad img + pad_img = F.pad(tensor, padding_size, value=pad_val) + padded_inputs.append(pad_img) + # pad gt_sem_seg + if data_samples is not None: + data_sample = data_samples[i] + gt_sem_seg = data_sample.gt_sem_seg.data + del data_sample.gt_sem_seg.data + data_sample.gt_sem_seg.data = F.pad( + gt_sem_seg, padding_size, value=seg_pad_val) + if 'gt_edge_map' in data_sample: + gt_edge_map = data_sample.gt_edge_map.data + del data_sample.gt_edge_map.data + data_sample.gt_edge_map.data = F.pad( + gt_edge_map, padding_size, value=seg_pad_val) + if 'gt_seg_map_from' in data_sample: + gt_seg_map_from = data_sample.gt_seg_map_from.data + del data_sample.gt_seg_map_from.data + data_sample.gt_seg_map_from.data = F.pad( + gt_seg_map_from, padding_size, value=seg_pad_val) + if 'gt_seg_map_to' in data_sample: + gt_seg_map_to = data_sample.gt_seg_map_to.data + del data_sample.gt_seg_map_to.data + data_sample.gt_seg_map_to.data = F.pad( + gt_seg_map_to, padding_size, value=seg_pad_val) + data_sample.set_metainfo({ + 'img_shape': tensor.shape[-2:], + 'pad_shape': data_sample.gt_sem_seg.shape, + 'padding_size': padding_size + }) + padded_samples.append(data_sample) + else: + padded_samples.append( + dict( + img_padding_size=padding_size, + pad_shape=pad_img.shape[-2:])) + + return torch.stack(padded_inputs, dim=0), padded_samples + + +@MODELS.register_module() +class DualInputSegDataPreProcessor(BaseDataPreprocessor): + """Image pre-processor for change detection tasks. + + Comparing with the :class:`mmengine.ImgDataPreprocessor`, + + 1. It won't do normalization if ``mean`` is not specified. + 2. It does normalization and color space conversion after stacking batch. + 3. It supports batch augmentations like mixup and cutmix. + + + It provides the data pre-processing as follows + + - Collate and move data to the target device. + - Pad inputs to the input size with defined ``pad_val``, and pad seg map + with defined ``seg_pad_val``. + - Stack inputs to batch_inputs. 
+ - Convert inputs from bgr to rgb if the shape of input is (3, H, W). + - Normalize image with defined std and mean. + - Do batch augmentations like Mixup and Cutmix during training. + + Args: + mean (Sequence[Number], optional): The pixel mean of R, G, B channels. + Defaults to None. + std (Sequence[Number], optional): The pixel standard deviation of + R, G, B channels. Defaults to None. + size (tuple, optional): Fixed padding size. + size_divisor (int, optional): The divisor of padded size. + pad_val (float, optional): Padding value. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. + padding_mode (str): Type of padding. Default: constant. + - constant: pads with a constant value, this value is specified + with pad_val. + bgr_to_rgb (bool): whether to convert image from BGR to RGB. + Defaults to False. + rgb_to_bgr (bool): whether to convert image from RGB to BGR. + Defaults to False. + batch_augments (list[dict], optional): Batch-level augmentations. + Defaults to None. + test_cfg (dict, optional): The padding size config in testing, if not + specified, will use `size` and `size_divisor` params as default. + Defaults to None, only supports keys `size` or `size_divisor`. + """ + + def __init__( + self, + mean: Sequence[Number] = None, + std: Sequence[Number] = None, + size: Optional[tuple] = None, + size_divisor: Optional[int] = None, + pad_val: Number = 0, + seg_pad_val: Number = 255, + bgr_to_rgb: bool = False, + rgb_to_bgr: bool = False, + batch_augments: Optional[List[dict]] = None, + test_cfg: dict = None, + ): + super().__init__() + self.size = size + self.size_divisor = size_divisor + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + + assert not (bgr_to_rgb and rgb_to_bgr), ( + '`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time') + self.channel_conversion = rgb_to_bgr or bgr_to_rgb + + if mean is not None: + assert std is not None, 'To enable the normalization in ' \ + 'preprocessing, please specify both ' \ + '`mean` and `std`.' + # Enable the normalization in preprocessing. + self._enable_normalize = True + self.register_buffer('mean', + torch.tensor(mean).view(-1, 1, 1), False) + self.register_buffer('std', + torch.tensor(std).view(-1, 1, 1), False) + else: + self._enable_normalize = False + + # TODO: support batch augmentations. + self.batch_augments = batch_augments + + # Support different padding methods in testing + self.test_cfg = test_cfg + + def forward(self, data: dict, training: bool = False) -> Dict[str, Any]: + """Perform normalization, padding and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. + + Returns: + Dict: Data in the same format as the model input. + """ + data = self.cast_data(data) # type: ignore + inputs = data['inputs'] + data_samples = data.get('data_samples', None) + # TODO: whether normalize should be after stack_batch + if self.channel_conversion and inputs[0].size(0) == 6: + inputs = [_input[[2, 1, 0, 5, 4, 3], ...]
for _input in inputs] + + inputs = [_input.float() for _input in inputs] + if self._enable_normalize: + inputs = [(_input - self.mean) / self.std for _input in inputs] + + if training: + assert data_samples is not None, ( + 'During training, `data_samples` must be defined.') + inputs, data_samples = stack_batch( + inputs=inputs, + data_samples=data_samples, + size=self.size, + size_divisor=self.size_divisor, + pad_val=self.pad_val, + seg_pad_val=self.seg_pad_val) + + if self.batch_augments is not None: + inputs, data_samples = self.batch_augments( + inputs, data_samples) + else: + assert len(inputs) == 1, ( + 'Batch inference is not supported currently, ' + 'as the image size might be different in a batch') + # pad images when testing + if self.test_cfg: + inputs, padded_samples = stack_batch( + inputs=inputs, + size=self.test_cfg.get('size', None), + size_divisor=self.test_cfg.get('size_divisor', None), + pad_val=self.pad_val, + seg_pad_val=self.seg_pad_val) + for data_sample, pad_info in zip(data_samples, padded_samples): + data_sample.set_metainfo({**pad_info}) + else: + inputs = torch.stack(inputs, dim=0) + + return dict(inputs=inputs, data_samples=data_samples) diff --git a/opencd/models/decode_heads/__init__.py b/opencd/models/decode_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..89bf0a3841e695b9934d53cefae9627a6ce19380 --- /dev/null +++ b/opencd/models/decode_heads/__init__.py @@ -0,0 +1,10 @@ +from .bit_head import BITHead +from .changer import Changer +from .general_scd_head import GeneralSCDHead +from .identity_head import DSIdentityHead, IdentityHead +from .multi_head import MultiHeadDecoder +from .sta_head import STAHead +from .tiny_head import TinyHead + +__all__ = ['BITHead', 'Changer', 'IdentityHead', 'DSIdentityHead', 'TinyHead', + 'STAHead', 'MultiHeadDecoder', 'GeneralSCDHead'] diff --git a/opencd/models/decode_heads/__pycache__/__init__.cpython-311.pyc b/opencd/models/decode_heads/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bf141b214d23c375da2b069e454ab664bf56901 Binary files /dev/null and b/opencd/models/decode_heads/__pycache__/__init__.cpython-311.pyc differ diff --git a/opencd/models/decode_heads/__pycache__/bit_head.cpython-311.pyc b/opencd/models/decode_heads/__pycache__/bit_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff5c0816ba28d7eae808297c704a043f49a161ab Binary files /dev/null and b/opencd/models/decode_heads/__pycache__/bit_head.cpython-311.pyc differ diff --git a/opencd/models/decode_heads/__pycache__/changer.cpython-311.pyc b/opencd/models/decode_heads/__pycache__/changer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..493d1618b2780b70c56236288e478992e93bf527 Binary files /dev/null and b/opencd/models/decode_heads/__pycache__/changer.cpython-311.pyc differ diff --git a/opencd/models/decode_heads/__pycache__/general_scd_head.cpython-311.pyc b/opencd/models/decode_heads/__pycache__/general_scd_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d33904be523de28387735b84cedabc32b592d036 Binary files /dev/null and b/opencd/models/decode_heads/__pycache__/general_scd_head.cpython-311.pyc differ diff --git a/opencd/models/decode_heads/__pycache__/identity_head.cpython-311.pyc b/opencd/models/decode_heads/__pycache__/identity_head.cpython-311.pyc new file mode 100644 index
0000000000000000000000000000000000000000..6c095cfc77a0ea79713c7506f100a706bcc83a7d Binary files /dev/null and b/opencd/models/decode_heads/__pycache__/identity_head.cpython-311.pyc differ diff --git a/opencd/models/decode_heads/__pycache__/multi_head.cpython-311.pyc b/opencd/models/decode_heads/__pycache__/multi_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47bae9104bc125f3b87a18929fd2a2d4261f957d Binary files /dev/null and b/opencd/models/decode_heads/__pycache__/multi_head.cpython-311.pyc differ diff --git a/opencd/models/decode_heads/__pycache__/sta_head.cpython-311.pyc b/opencd/models/decode_heads/__pycache__/sta_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7307f7381ef831e96619839d5dfd46c9d689095e Binary files /dev/null and b/opencd/models/decode_heads/__pycache__/sta_head.cpython-311.pyc differ diff --git a/opencd/models/decode_heads/__pycache__/tiny_head.cpython-311.pyc b/opencd/models/decode_heads/__pycache__/tiny_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a421eb5103d11650cce4cbbec91ae2195375b0a Binary files /dev/null and b/opencd/models/decode_heads/__pycache__/tiny_head.cpython-311.pyc differ diff --git a/opencd/models/decode_heads/bit_head.py b/opencd/models/decode_heads/bit_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a67d7e5324fd353e08d4a8518b286ffa6a059ab6 --- /dev/null +++ b/opencd/models/decode_heads/bit_head.py @@ -0,0 +1,306 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer +from mmengine.model import ModuleList, Sequential + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.models.utils import Upsample +from opencd.registry import MODELS + + +class CrossAttention(nn.Module): + def __init__(self, + in_dims, + embed_dims, + num_heads, + dropout_rate=0., + apply_softmax=True): + super(CrossAttention, self).__init__() + self.num_heads = num_heads + self.scale = in_dims ** -0.5 + + self.apply_softmax = apply_softmax + + self.to_q = nn.Linear(in_dims, embed_dims, bias=False) + self.to_k = nn.Linear(in_dims, embed_dims, bias=False) + self.to_v = nn.Linear(in_dims, embed_dims, bias=False) + + self.fc_out = nn.Sequential( + nn.Linear(embed_dims, in_dims), + nn.Dropout(dropout_rate) + ) + + def forward(self, x, ref): + b, n = x.shape[:2] + h = self.num_heads + + q = self.to_q(x) + k = self.to_k(ref) + v = self.to_v(ref) + + q = q.reshape((b, n, h, -1)).permute((0, 2, 1, 3)) + k = k.reshape((b, ref.shape[1], h, -1)).permute((0, 2, 1, 3)) + v = v.reshape((b, ref.shape[1], h, -1)).permute((0, 2, 1, 3)) + + mult = torch.matmul(q, k.transpose(-1,-2)) * self.scale + + if self.apply_softmax: + mult = F.softmax(mult, dim=-1) + + out = torch.matmul(mult, v) + out = out.permute((0,2,1,3)).flatten(2) + return self.fc_out(out) + + +class FeedForward(nn.Sequential): + def __init__(self, dim, hidden_dim, dropout_rate=0.): + super().__init__( + # TODO:to be more mmlab + nn.Linear(dim, hidden_dim), + nn.ReLU(), + nn.Dropout(dropout_rate), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout_rate) + ) + + +class TransformerEncoder(nn.Module): + def __init__(self, + in_dims, + embed_dims, + num_heads, + drop_rate, + norm_cfg, + apply_softmax=True): + super(TransformerEncoder, self).__init__() + self.attn = CrossAttention( + in_dims, + embed_dims, + num_heads, + dropout_rate=drop_rate, + 
apply_softmax=apply_softmax) + self.ff = FeedForward( + in_dims, + embed_dims, + drop_rate + ) + self.norm1 = build_norm_layer(norm_cfg, in_dims)[1] + self.norm2 = build_norm_layer(norm_cfg, in_dims)[1] + + def forward(self, x): + x_ = self.attn(self.norm1(x), self.norm1(x)) + x + y = self.ff(self.norm2(x_)) + x_ + return y + + +class TransformerDecoder(nn.Module): + def __init__( + self, + in_dims, + embed_dims, + num_heads, + drop_rate, + norm_cfg, + apply_softmax=True + ): + super(TransformerDecoder, self).__init__() + self.attn = CrossAttention( + in_dims, + embed_dims, + num_heads, + dropout_rate=drop_rate, + apply_softmax=apply_softmax) + self.ff = FeedForward( + in_dims, + embed_dims, + drop_rate + ) + self.norm1 = build_norm_layer(norm_cfg, in_dims)[1] + self.norm1_ = build_norm_layer(norm_cfg, in_dims)[1] + self.norm2 = build_norm_layer(norm_cfg, in_dims)[1] + + def forward(self, x, ref): + x_ = self.attn(self.norm1(x), self.norm1_(ref)) + x + y = self.ff(self.norm2(x_)) + x_ + return y + + +@MODELS.register_module() +class BITHead(BaseDecodeHead): + """BIT Head + + This head is the improved implementation of 'Remote Sensing Image + Change Detection With Transformers'. + + Args: + in_channels (int): Number of input feature channels (from backbone). Default: 512 + channels (int): Number of output channels of pre_process. Default: 32. + embed_dims (int): Number of expanded channels of Attention block. Default: 64. + enc_depth (int): Depth of block of transformer encoder. Default: 1. + enc_with_pos (bool): Using position embedding in transformer encoder. + Default: True + dec_depth (int): Depth of block of transformer decoder. Default: 8. + num_heads (int): Number of heads of Multi-Head Cross-Attention in the + transformer encoder. Default: 8. + use_tokenizer (bool): Whether to use semantic tokens. Default: True + token_len (int): Number of semantic tokens. Default: 4. + pre_upsample (int): Scale factor of upsample of pre_process. + (default upsample to 64x64) + Default: 2.
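+        Example config (a sketch built from the defaults above; the
+        ``num_classes`` value is illustrative only):
+
+        >>> decode_head = dict(
+        ...     type='BITHead', in_channels=256, channels=32, embed_dims=64,
+        ...     enc_depth=1, dec_depth=8, num_heads=8, token_len=4,
+        ...     num_classes=2)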
+ """ + + def __init__(self, + in_channels=256, + channels=32, + embed_dims=64, + enc_depth=1, + enc_with_pos=True, + dec_depth=8, + num_heads=8, + drop_rate=0., + pool_size=2, + pool_mode='max', + use_tokenizer=True, + token_len=4, + pre_upsample=2, + upsample_size=4, + norm_cfg=dict(type='LN'), + act_cfg=dict(type='ReLU', inplace=True), + **kwargs): + super().__init__(in_channels, channels, **kwargs) + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.embed_dims=embed_dims + self.use_tokenizer = use_tokenizer + self.num_heads=num_heads + if not use_tokenizer: + # If a tokenzier is not to be used,then downsample the feature maps + self.pool_size = pool_size + self.pool_mode = pool_mode + self.token_len = pool_size * pool_size + else: + self.token_len = token_len + self.conv_att = ConvModule( + self.channels, + self.token_len, + 1, + conv_cfg=self.conv_cfg, + ) + + self.enc_with_pos = enc_with_pos + if enc_with_pos: + self.enc_pos_embedding = nn.Parameter(torch.randn(1, self.token_len * 2, self.channels)) + + # pre_process to backbone feature + self.pre_process = Sequential( + Upsample(scale_factor=pre_upsample, mode='bilinear', align_corners=self.align_corners), + ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg + ) + ) + + # Transformer Encoder + self.encoder = ModuleList() + for _ in range(enc_depth): + block = TransformerEncoder( + self.channels, + self.embed_dims, + self.num_heads, + drop_rate=drop_rate, + norm_cfg=self.norm_cfg, + ) + self.encoder.append(block) + + # Transformer Decoder + self.decoder = ModuleList() + for _ in range(dec_depth): + block = TransformerDecoder( + self.channels, + self.embed_dims, + self.num_heads, + drop_rate=drop_rate, + norm_cfg=self.norm_cfg, + ) + self.decoder.append(block) + + self.upsample = Upsample(scale_factor=upsample_size,mode='bilinear',align_corners=self.align_corners) + + # Token + def _forward_semantic_tokens(self, x): + b, c = x.shape[:2] + att_map = self.conv_att(x) + att_map = att_map.reshape((b, self.token_len, 1, -1)) + att_map = F.softmax(att_map, dim=-1) + x = x.reshape((b, 1, c, -1)) + tokens = (x * att_map).sum(-1) + return tokens + + def _forward_reshaped_tokens(self, x): + if self.pool_mode == 'max': + x = F.adaptive_max_pool2d(x, (self.pool_size, self.pool_size)) + elif self.pool_mode == 'avg': + x = F.adaptive_avg_pool2d(x, (self.pool_size, self.pool_size)) + else: + x = x + tokens = x.permute((0, 2, 3, 1)).flatten(1, 2) + return tokens + + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. 
+ """ + inputs = self._transform_inputs(inputs) + x1, x2 = torch.chunk(inputs, 2, dim=1) + x1 = self.pre_process(x1) + x2 = self.pre_process(x2) + # Tokenization + if self.use_tokenizer: + token1 = self._forward_semantic_tokens(x1) + token2 = self._forward_semantic_tokens(x2) + else: + token1 = self._forward_reshaped_tokens(x1) + token2 = self._forward_reshaped_tokens(x2) + + # Transformer encoder forward + token = torch.cat([token1, token2], dim=1) + if self.enc_with_pos: + token += self.enc_pos_embedding + for i, _encoder in enumerate(self.encoder): + token = _encoder(token) + token1, token2 = torch.chunk(token, 2, dim=1) + + # Transformer decoder forward + for _decoder in self.decoder: + b, c, h, w = x1.shape + x1 = x1.permute((0, 2, 3, 1)).flatten(1, 2) + x2 = x2.permute((0, 2, 3, 1)).flatten(1, 2) + + x1 = _decoder(x1, token1) + x2 = _decoder(x2, token2) + + x1 = x1.transpose(1, 2).reshape((b, c, h, w)) + x2 = x2.transpose(1, 2).reshape((b, c, h, w)) + + # Feature differencing + y = torch.abs(x1 - x2) + y = self.upsample(y) + + return y + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/opencd/models/decode_heads/changer.py b/opencd/models/decode_heads/changer.py new file mode 100644 index 0000000000000000000000000000000000000000..be4f37c7e88df633d2ec0312e9a9cf1f6825057b --- /dev/null +++ b/opencd/models/decode_heads/changer.py @@ -0,0 +1,227 @@ +# Copyright (c) Open-CD. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import Conv2d, ConvModule, build_activation_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmengine.model import BaseModule, Sequential +from torch.nn import functional as F + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.models.utils import resize +from opencd.registry import MODELS +from ..necks.feature_fusion import FeatureFusionNeck + + +class FDAF(BaseModule): + """Flow Dual-Alignment Fusion Module. + + Args: + in_channels (int): Input channels of features. + conv_cfg (dict | None): Config of conv layers. + Default: None + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN') + act_cfg (dict): Config of activation layers. 
+            Default: dict(type='ReLU') + """ + + def __init__(self, + in_channels, + conv_cfg=None, + norm_cfg=dict(type='IN'), + act_cfg=dict(type='GELU')): + super(FDAF, self).__init__() + self.in_channels = in_channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + # TODO + conv_cfg = None + norm_cfg = dict(type='IN') + act_cfg = dict(type='GELU') + + kernel_size = 5 + self.flow_make = Sequential( + nn.Conv2d(in_channels*2, in_channels*2, kernel_size=kernel_size, padding=(kernel_size-1)//2, bias=True, groups=in_channels*2), + nn.InstanceNorm2d(in_channels*2), + nn.GELU(), + nn.Conv2d(in_channels*2, 4, kernel_size=1, padding=0, bias=False), + ) + + def forward(self, x1, x2, fusion_policy=None): + """Forward function.""" + + output = torch.cat([x1, x2], dim=1) + flow = self.flow_make(output) + f1, f2 = torch.chunk(flow, 2, dim=1) + x1_feat = self.warp(x1, f1) - x2 + x2_feat = self.warp(x2, f2) - x1 + + if fusion_policy is None: + return x1_feat, x2_feat + + output = FeatureFusionNeck.fusion(x1_feat, x2_feat, fusion_policy) + return output + + @staticmethod + def warp(x, flow): + n, c, h, w = x.size() + + norm = torch.tensor([[[[w, h]]]]).type_as(x).to(x.device) + col = torch.linspace(-1.0, 1.0, h).view(-1, 1).repeat(1, w) + row = torch.linspace(-1.0, 1.0, w).repeat(h, 1) + grid = torch.cat((row.unsqueeze(2), col.unsqueeze(2)), 2) + grid = grid.repeat(n, 1, 1, 1).type_as(x).to(x.device) + grid = grid + flow.permute(0, 2, 3, 1) / norm + + output = F.grid_sample(x, grid, align_corners=True) + return output + + +class MixFFN(BaseModule): + """An implementation of MixFFN of Segformer. \ + Here MixFFN is used as the projection head of Changer. + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. Defaults: 256. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='ReLU') + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + feedforward_channels, + act_cfg=dict(type='GELU'), + ffn_drop=0., + dropout_layer=None, + init_cfg=None): + super(MixFFN, self).__init__(init_cfg) + + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.act_cfg = act_cfg + self.activate = build_activation_layer(act_cfg) + + in_channels = embed_dims + fc1 = Conv2d( + in_channels=in_channels, + out_channels=feedforward_channels, + kernel_size=1, + stride=1, + bias=True) + # 3x3 depth wise conv to provide positional encode information + pe_conv = Conv2d( + in_channels=feedforward_channels, + out_channels=feedforward_channels, + kernel_size=3, + stride=1, + padding=(3 - 1) // 2, + bias=True, + groups=feedforward_channels) + fc2 = Conv2d( + in_channels=feedforward_channels, + out_channels=in_channels, + kernel_size=1, + stride=1, + bias=True) + drop = nn.Dropout(ffn_drop) + layers = [fc1, pe_conv, self.activate, drop, fc2, drop] + self.layers = Sequential(*layers) + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else torch.nn.Identity() + + def forward(self, x, identity=None): + out = self.layers(x) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +@MODELS.register_module() +class Changer(BaseDecodeHead): + """The Head of Changer.
+ + This head is the implementation of + `Changer <https://arxiv.org/abs/2209.08290>`_. + + Args: + interpolate_mode: The interpolate mode of MLP head upsample operation. + Default: 'bilinear'. + """ + + def __init__(self, interpolate_mode='bilinear', **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + + self.interpolate_mode = interpolate_mode + num_inputs = len(self.in_channels) + assert num_inputs == len(self.in_index) + + self.convs = nn.ModuleList() + for i in range(num_inputs): + self.convs.append( + ConvModule( + in_channels=self.in_channels[i], + out_channels=self.channels, + kernel_size=1, + stride=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + self.fusion_conv = ConvModule( + in_channels=self.channels * num_inputs, + out_channels=self.channels // 2, + kernel_size=1, + norm_cfg=self.norm_cfg) + + self.neck_layer = FDAF(in_channels=self.channels // 2) + + # projection head + self.discriminator = MixFFN( + embed_dims=self.channels, + feedforward_channels=self.channels, + ffn_drop=0., + dropout_layer=dict(type='DropPath', drop_prob=0.), + act_cfg=dict(type='GELU')) + + def base_forward(self, inputs): + outs = [] + for idx in range(len(inputs)): + x = inputs[idx] + conv = self.convs[idx] + outs.append( + resize( + input=conv(x), + size=inputs[0].shape[2:], + mode=self.interpolate_mode, + align_corners=self.align_corners)) + + out = self.fusion_conv(torch.cat(outs, dim=1)) + + return out + + def forward(self, inputs): + # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32 + inputs = self._transform_inputs(inputs) + inputs1 = [] + inputs2 = [] + for input in inputs: + f1, f2 = torch.chunk(input, 2, dim=1) + inputs1.append(f1) + inputs2.append(f2) + + out1 = self.base_forward(inputs1) + out2 = self.base_forward(inputs2) + out = self.neck_layer(out1, out2, 'concat') + + out = self.discriminator(out) + out = self.cls_seg(out) + + return out diff --git a/opencd/models/decode_heads/general_scd_head.py b/opencd/models/decode_heads/general_scd_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1d8314357da64d015f0f5fa852a871ad5d5c4b79 --- /dev/null +++ b/opencd/models/decode_heads/general_scd_head.py @@ -0,0 +1,26 @@ +# Copyright (c) Open-CD. All rights reserved. +from opencd.registry import MODELS +from .multi_head import MultiHeadDecoder + + +@MODELS.register_module() +class GeneralSCDHead(MultiHeadDecoder): + """The Head of General Semantic Change Detection Head.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def forward(self, inputs): + inputs1, inputs2 = inputs + out1 = self.semantic_cd_head(inputs1) + out2 = self.semantic_cd_head_aux(inputs2) + inputs_ = self.binary_cd_neck(inputs1, inputs2) + out = self.binary_cd_head(inputs_) + + out_dict = dict( + seg_logits=out, + seg_logits_from=out1, + seg_logits_to=out2 + ) + + return out_dict \ No newline at end of file diff --git a/opencd/models/decode_heads/identity_head.py b/opencd/models/decode_heads/identity_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3e5ad0868e600ca86f60c736e4c1ed8babaa6ea7 --- /dev/null +++ b/opencd/models/decode_heads/identity_head.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenCD. All rights reserved.
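+# IdentityHead/DSIdentityHead below return their (transformed) inputs
+# unchanged: ``conv_seg`` is deleted, so losses attach directly to logits
+# produced elsewhere, e.g. for deep supervision. A hypothetical
+# auxiliary-head config (field values illustrative only):
+#   auxiliary_head = dict(type='IdentityHead', in_channels=2, in_index=0,
+#                         num_classes=2)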
+import torch +import torch.nn as nn + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.models.losses import accuracy +from mmseg.models.utils import resize +from opencd.registry import MODELS + + +@MODELS.register_module() +class IdentityHead(BaseDecodeHead): + """Identity Head.""" + + def __init__(self, **kwargs): + super().__init__(channels=1, **kwargs) + delattr(self, 'conv_seg') + + def init_weights(self): + pass + + def _forward_feature(self, inputs): + """ + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + x = self._transform_inputs(inputs) + return x + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + return output + + +@MODELS.register_module() +class DSIdentityHead(BaseDecodeHead): + """Deep Supervision Identity Head.""" + + def __init__(self, **kwargs): + super().__init__(channels=1, **kwargs) + delattr(self, 'conv_seg') + + def init_weights(self): + pass + + def _forward_feature(self, inputs): + """ + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + x = self._transform_inputs(inputs) + return x + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + return output + + def loss_by_feat(self, seg_logits, batch_data_samples): + """Compute segmentation loss. + + Args: + seg_logits (Tensor): The output from decode head forward function. + batch_data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + seg_label = self._stack_batch_gt(batch_data_samples) + loss = dict() + seg_label_size = seg_label.shape[2:] + for seg_idx, single_seg_logit in enumerate(seg_logits): + single_seg_logit = resize( + input=single_seg_logit, + size=seg_label_size, + mode='bilinear', + align_corners=self.align_corners) + if self.sampler is not None: + seg_weight = self.sampler.sample(single_seg_logit, seg_label) + else: + seg_weight = None + seg_label = seg_label.squeeze(1) + + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + loss_name = f'aux_{seg_idx}_' + loss_decode.loss_name + if loss_name not in loss: + loss[loss_name] = loss_decode( + single_seg_logit, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + else: + loss[loss_name] += loss_decode( + single_seg_logit, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + + loss['acc_seg'] = accuracy( + single_seg_logit, seg_label, ignore_index=self.ignore_index) + return loss diff --git a/opencd/models/decode_heads/multi_head.py b/opencd/models/decode_heads/multi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..45230e4373d8a70aae57c0eabf85c9c35ff397f9 --- /dev/null +++ b/opencd/models/decode_heads/multi_head.py @@ -0,0 +1,170 @@ +# Copyright (c) Open-CD. All rights reserved.
+from abc import ABCMeta, abstractmethod +from typing import List, Tuple + +from mmengine.model import BaseModule +from mmengine.structures import PixelData +from torch import Tensor, nn + +# from mmseg.models import builder +from mmseg.models.utils import resize +from mmseg.structures import SegDataSample +from mmseg.utils import ConfigType, SampleList, add_prefix +from opencd.registry import MODELS + + +@MODELS.register_module() +class MultiHeadDecoder(BaseModule): + """Base class for MultiHeadDecoder. + + Args: + binary_cd_head (dict): The decode head for binary change detection branch. + binary_cd_neck (dict): The feature fusion part for binary \ + change detection branch + semantic_cd_head (dict): The decode head for semantic change \ + detection `from` branch. + semantic_cd_head_aux (dict): The decode head for semantic change \ + detection `to` branch. If None, the siamese semantic head will \ + be used. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, + binary_cd_head, + binary_cd_neck=None, + semantic_cd_head=None, + semantic_cd_head_aux=None, + init_cfg=None): + super().__init__(init_cfg) + self.binary_cd_head = MODELS.build(binary_cd_head) + self.siamese_semantic_head = True + if binary_cd_neck is not None: + self.binary_cd_neck = MODELS.build(binary_cd_neck) + if semantic_cd_head is not None: + self.semantic_cd_head = MODELS.build(semantic_cd_head) + if semantic_cd_head_aux is not None: + self.siamese_semantic_head = False + self.semantic_cd_head_aux = MODELS.build(semantic_cd_head_aux) + else: + self.semantic_cd_head_aux = self.semantic_cd_head + + @abstractmethod + def forward(self, inputs): + """Placeholder of forward function. + The return value should be a dict() containing: + `seg_logits`, `seg_logits_from` and `seg_logits_to`. + + For example: + return dict( + seg_logits=out, + seg_logits_from=out1, + seg_logits_to=out2) + """ + pass + + def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList, + train_cfg: ConfigType) -> dict: + """Forward function for training. + + Args: + inputs (Tuple[Tensor]): List of multi-level img features. + batch_data_samples (list[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `img_metas` or `gt_semantic_seg`. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + seg_logits = self.forward(inputs) + losses = self.loss_by_feat(seg_logits, batch_data_samples) + return losses + + def predict(self, inputs, batch_img_metas: List[dict], test_cfg, + **kwargs) -> List[Tensor]: + """Forward function for testing.""" + seg_logits = self.forward(inputs) + return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs) + + def predict_by_feat(self, seg_logits: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Transform a batch of output seg_logits to the input shape. + + Args: + seg_logits (Tensor): The output from decode head forward function. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + Tensor: Outputs segmentation logits map. + """ + assert ['seg_logits', 'seg_logits_from', 'seg_logits_to'] \ + == list(seg_logits.keys()), "`seg_logits`, `seg_logits_from` \ + and `seg_logits_to` should be contained." 
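+        # Each returned map may come from a sub-head with its own
+        # ``align_corners`` setting, so look it up per key before resizing.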
+
+        self.align_corners = {
+            'seg_logits': self.binary_cd_head.align_corners,
+            'seg_logits_from': self.semantic_cd_head.align_corners,
+            'seg_logits_to': self.semantic_cd_head_aux.align_corners}
+
+        for seg_name, seg_logit in seg_logits.items():
+            seg_logits[seg_name] = resize(
+                input=seg_logit,
+                size=batch_img_metas[0]['img_shape'],
+                mode='bilinear',
+                align_corners=self.align_corners[seg_name])
+        return seg_logits
+
+    def get_sub_batch_data_samples(self, batch_data_samples: SampleList,
+                                   sub_metainfo_name: str,
+                                   sub_data_name: str) -> list:
+        sub_batch_sample_list = []
+        for i in range(len(batch_data_samples)):
+            data_sample = SegDataSample()
+
+            gt_sem_seg_data = dict(
+                data=batch_data_samples[i].get(sub_data_name).data)
+            data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
+
+            img_meta = {}
+            seg_map_path = batch_data_samples[i].metainfo.get(sub_metainfo_name)
+            for key in batch_data_samples[i].metainfo.keys():
+                if 'seg_map_path' not in key:
+                    img_meta[key] = batch_data_samples[i].metainfo.get(key)
+            img_meta['seg_map_path'] = seg_map_path
+            data_sample.set_metainfo(img_meta)
+
+            sub_batch_sample_list.append(data_sample)
+        return sub_batch_sample_list
+
+    def loss_by_feat(self, seg_logits: dict,
+                     batch_data_samples: SampleList, **kwargs) -> dict:
+        """Compute segmentation loss."""
+        assert ['seg_logits', 'seg_logits_from', 'seg_logits_to'] \
+            == list(seg_logits.keys()), "`seg_logits`, `seg_logits_from` " \
+            "and `seg_logits_to` should be contained."
+
+        losses = dict()
+        binary_cd_loss_decode = self.binary_cd_head.loss_by_feat(
+            seg_logits['seg_logits'],
+            self.get_sub_batch_data_samples(batch_data_samples,
+                                            sub_metainfo_name='seg_map_path',
+                                            sub_data_name='gt_sem_seg'))
+        losses.update(add_prefix(binary_cd_loss_decode, 'binary_cd'))
+
+        if getattr(self, 'semantic_cd_head', None) is not None:
+            semantic_cd_loss_decode_from = self.semantic_cd_head.loss_by_feat(
+                seg_logits['seg_logits_from'],
+                self.get_sub_batch_data_samples(batch_data_samples,
+                                                sub_metainfo_name='seg_map_path_from',
+                                                sub_data_name='gt_sem_seg_from'))
+            losses.update(add_prefix(semantic_cd_loss_decode_from, 'semantic_cd_from'))
+
+            semantic_cd_loss_decode_to = self.semantic_cd_head_aux.loss_by_feat(
+                seg_logits['seg_logits_to'],
+                self.get_sub_batch_data_samples(batch_data_samples,
+                                                sub_metainfo_name='seg_map_path_to',
+                                                sub_data_name='gt_sem_seg_to'))
+            losses.update(add_prefix(semantic_cd_loss_decode_to, 'semantic_cd_to'))
+
+        return losses
\ No newline at end of file
diff --git a/opencd/models/decode_heads/sta_head.py b/opencd/models/decode_heads/sta_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..16ab3ef48c2584c132186bc464e9f538f02a4f75
--- /dev/null
+++ b/opencd/models/decode_heads/sta_head.py
@@ -0,0 +1,373 @@
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+
+from mmseg.models.decode_heads.decode_head import BaseDecodeHead
+from mmseg.models.utils import resize
+from opencd.registry import MODELS
+
+
+class BAM(nn.Module):
+    """ Basic self-attention module
+    """
+
+    def __init__(self, in_dim, ds=8, activation=nn.ReLU):
+        super(BAM, self).__init__()
+        self.channel_in = in_dim
+        self.key_channel = self.channel_in // 8
+        self.activation = activation
+        self.ds = ds  # downsampling scale
+        self.pool = nn.AvgPool2d(self.ds)
+        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
+        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
+        self.value_conv = 
nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.gamma = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) # + + def forward(self, input): + """ + inputs : + x : input feature maps( B X C X W X H) + returns : + out : self attention value + input feature + attention: B X N X N (N is Width*Height) + """ + x = self.pool(input) + m_batchsize, C, width, height = x.size() + proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B X C X (N)/(ds*ds) + proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) # B X C x (*W*H)/(ds*ds) + energy = torch.bmm(proj_query, proj_key) # transpose check + energy = (self.key_channel ** -.5) * energy + + attention = self.softmax(energy) # BX (N) X (N)/(ds*ds)/(ds*ds) + + proj_value = self.value_conv(x).view(m_batchsize, -1, width * height) # B X C X N + + out = torch.bmm(proj_value, attention.permute(0, 2, 1)) + out = out.view(m_batchsize, C, width, height) + + out = F.interpolate(out, [width * self.ds, height * self.ds]) + out = out + input + + return out + + +class _PAMBlock(nn.Module): + ''' + The basic implementation for self-attention block/non-local block + Input/Output: + N * C * H * (2*W) + Parameters: + in_channels : the dimension of the input feature map + key_channels : the dimension after the key/query transform + value_channels : the dimension after the value transform + scale : choose the scale to partition the input feature maps + ds : downsampling scale + ''' + + def __init__(self, in_channels, key_channels, value_channels, scale=1, ds=1): + super(_PAMBlock, self).__init__() + self.scale = scale + self.ds = ds + self.pool = nn.AvgPool2d(self.ds) + self.in_channels = in_channels + self.key_channels = key_channels + self.value_channels = value_channels + + self.f_key = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(self.key_channels) + ) + self.f_query = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(self.key_channels) + ) + self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels, + kernel_size=1, stride=1, padding=0) + + def forward(self, input): + x = input + if self.ds != 1: + x = self.pool(input) + # input shape: b,c,h,2w + batch_size, c, h, w = x.size(0), x.size(1), x.size(2), x.size(3) // 2 + + local_y = [] + local_x = [] + step_h, step_w = h // self.scale, w // self.scale + for i in range(0, self.scale): + for j in range(0, self.scale): + start_x, start_y = i * step_h, j * step_w + end_x, end_y = min(start_x + step_h, h), min(start_y + step_w, w) + if i == (self.scale - 1): + end_x = h + if j == (self.scale - 1): + end_y = w + local_x += [start_x, end_x] + local_y += [start_y, end_y] + + value = self.f_value(x) + query = self.f_query(x) + key = self.f_key(x) + + value = torch.stack([value[:, :, :, :w], value[:, :, :, w:]], 4) # B*N*H*W*2 + query = torch.stack([query[:, :, :, :w], query[:, :, :, w:]], 4) # B*N*H*W*2 + key = torch.stack([key[:, :, :, :w], key[:, :, :, w:]], 4) # B*N*H*W*2 + + local_block_cnt = 2 * self.scale * self.scale + + # self-attention func + def func(value_local, query_local, key_local): + batch_size_new = value_local.size(0) + h_local, w_local = value_local.size(2), value_local.size(3) + value_local = value_local.contiguous().view(batch_size_new, self.value_channels, -1) + + query_local = 
query_local.contiguous().view(batch_size_new, self.key_channels, -1) + query_local = query_local.permute(0, 2, 1) + key_local = key_local.contiguous().view(batch_size_new, self.key_channels, -1) + + sim_map = torch.bmm(query_local, key_local) # batch matrix multiplication + sim_map = (self.key_channels ** -.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + context_local = torch.bmm(value_local, sim_map.permute(0, 2, 1)) + # context_local = context_local.permute(0, 2, 1).contiguous() + context_local = context_local.view(batch_size_new, self.value_channels, h_local, w_local, 2) + return context_local + + # Parallel Computing to speed up + # reshape value_local, q, k + v_list = [value[:, :, local_x[i]:local_x[i + 1], local_y[i]:local_y[i + 1]] for i in + range(0, local_block_cnt, 2)] + v_locals = torch.cat(v_list, dim=0) + q_list = [query[:, :, local_x[i]:local_x[i + 1], local_y[i]:local_y[i + 1]] for i in + range(0, local_block_cnt, 2)] + q_locals = torch.cat(q_list, dim=0) + k_list = [key[:, :, local_x[i]:local_x[i + 1], local_y[i]:local_y[i + 1]] for i in range(0, local_block_cnt, 2)] + k_locals = torch.cat(k_list, dim=0) + context_locals = func(v_locals, q_locals, k_locals) + + context_list = [] + for i in range(0, self.scale): + row_tmp = [] + for j in range(0, self.scale): + left = batch_size * (j + i * self.scale) + right = batch_size * (j + i * self.scale) + batch_size + tmp = context_locals[left:right] + row_tmp.append(tmp) + context_list.append(torch.cat(row_tmp, 3)) + + context = torch.cat(context_list, 2) + context = torch.cat([context[:, :, :, :, 0], context[:, :, :, :, 1]], 3) + + if self.ds != 1: + context = F.interpolate(context, [h * self.ds, 2 * w * self.ds]) + + return context + + +class PAMBlock(_PAMBlock): + def __init__(self, in_channels, key_channels=None, value_channels=None, scale=1, ds=1): + if key_channels == None: + key_channels = in_channels // 8 + if value_channels == None: + value_channels = in_channels + super(PAMBlock, self).__init__(in_channels, key_channels, value_channels, scale, ds) + + +class PAM(nn.Module): + """ + PAM module + """ + + def __init__(self, in_channels, out_channels, sizes=([1]), ds=1): + super(PAM, self).__init__() + self.group = len(sizes) + self.stages = [] + self.ds = ds # output stride + self.value_channels = out_channels + self.key_channels = out_channels // 8 + + self.stages = nn.ModuleList( + [self._make_stage(in_channels, self.key_channels, self.value_channels, size, self.ds) + for size in sizes]) + self.conv_bn = nn.Sequential( + nn.Conv2d(in_channels * self.group, out_channels, kernel_size=1, padding=0, bias=False), + # nn.BatchNorm2d(out_channels), + ) + + def _make_stage(self, in_channels, key_channels, value_channels, size, ds): + return PAMBlock(in_channels, key_channels, value_channels, size, ds) + + def forward(self, feats): + priors = [stage(feats) for stage in self.stages] + + # concat + context = [] + for i in range(0, len(priors)): + context += [priors[i]] + output = self.conv_bn(torch.cat(context, 1)) + + return output + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +class CDSA(nn.Module): + """self attention module for change detection + """ + + def __init__(self, in_c, ds=1, mode='BAM'): + super(CDSA, self).__init__() + self.in_C = in_c + self.ds = ds + self.mode = mode + if self.mode == 'BAM': + 
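+            # 'BAM' applies one global self-attention block; 'PAM' applies
+            # pyramid attention over multi-scale partitions; 'None' is a
+            # pass-through.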
self.Self_Att = BAM(self.in_C, ds=self.ds) + elif self.mode == 'PAM': + self.Self_Att = PAM(in_channels=self.in_C, out_channels=self.in_C, sizes=[1, 2, 4, 8], ds=self.ds) + elif self.mode == 'None': + self.Self_Att = nn.Identity() + self.apply(weights_init) + + def forward(self, x1, x2): + height = x1.shape[3] + x = torch.cat((x1, x2), 3) + x = self.Self_Att(x) + return x[:, :, :, 0:height], x[:, :, :, height:] + + +@MODELS.register_module() +class STAHead(BaseDecodeHead): + """The Head of STANet. + + Args: + sa_mode: + interpolate_mode: The interpolate mode of MLP head upsample operation. + Default: 'bilinear'. + """ + + def __init__( + self, + sa_mode='PAM', + sa_in_channels=256, + sa_ds=1, + distance_threshold=1, + **kwargs): + super().__init__(input_transform='multiple_select', num_classes=1, **kwargs) + + num_inputs = len(self.in_channels) + assert num_inputs == len(self.in_index) + self.distance_threshold = distance_threshold + + self.fpn_convs = nn.ModuleList() + for in_channels in self.in_channels: + fpn_conv = ConvModule( + in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + inplace=False) + self.fpn_convs.append(fpn_conv) + + self.fpn_bottleneck = nn.Sequential( + ConvModule( + len(self.in_channels) * self.channels, + sa_in_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.Dropout(0.5), + ConvModule( + sa_in_channels, + sa_in_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + ) + + self.netA = CDSA(in_c=sa_in_channels, ds=sa_ds, mode=sa_mode) + self.calc_dist = nn.PairwiseDistance(keepdim=True) + self.conv_seg = nn.Identity() + + def base_forward(self, inputs): + fpn_outs = [ + self.fpn_convs[i](inputs[i]) + for i in range(len(self.in_channels)) + ] + + for i in range(len(self.in_channels)): + fpn_outs[i] = resize( + fpn_outs[i], + size=fpn_outs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) + fpn_outs = torch.cat(fpn_outs, dim=1) + feats = self.fpn_bottleneck(fpn_outs) + return feats + + def forward(self, inputs): + # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32 + inputs = self._transform_inputs(inputs) + inputs1 = [] + inputs2 = [] + for input in inputs: + f1, f2 = torch.chunk(input, 2, dim=1) + inputs1.append(f1) + inputs2.append(f2) + + f1 = self.base_forward(inputs1) + f2 = self.base_forward(inputs2) + f1, f2 = self.netA(f1, f2) + + # if you use PyTorch<=1.8, there may be some problems. + # see https://github.com/justchenhao/STANet/issues/85 + f1 = f1.permute(0, 2, 3, 1) + f2 = f2.permute(0, 2, 3, 1) + dist = self.calc_dist(f1, f2).permute(0, 3, 1, 2) + + dist = F.interpolate(dist, size=inputs[0].shape[2:], mode='bilinear', align_corners=True) + + return dist + + def predict_by_feat(self, seg_logits, batch_img_metas): + """Transform a batch of output seg_logits to the input shape. + + Args: + seg_logits (Tensor): The output from decode head forward function. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + Tensor: Outputs segmentation logits map. 
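+
+        Note (illustrative): the distance map is binarized against
+        ``distance_threshold`` into +/-100 pseudo-logits, so a pixel at
+        distance 1.3 with threshold 1 is mapped to +100 ("changed") and the
+        standard post-processing treats it as a confident prediction.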
+ """ + + seg_logits_copy = copy.deepcopy(seg_logits) + seg_logits[seg_logits_copy > self.distance_threshold] = 100 + seg_logits[seg_logits_copy <= self.distance_threshold] = -100 + + seg_logits = resize( + input=seg_logits, + size=batch_img_metas[0]['img_shape'], + mode='bilinear', + align_corners=self.align_corners) + return seg_logits diff --git a/opencd/models/decode_heads/tiny_head.py b/opencd/models/decode_heads/tiny_head.py new file mode 100644 index 0000000000000000000000000000000000000000..461fbd8f180c535cfb2fb4c3edf955efd464eade --- /dev/null +++ b/opencd/models/decode_heads/tiny_head.py @@ -0,0 +1,87 @@ +# Copyright (c) Open-CD. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.models.utils import resize +from opencd.registry import MODELS + + +@MODELS.register_module() +class TinyHead(BaseDecodeHead): + """ + This head is the implementation of `TinyCDv2 + `_. + Args: + feature_strides (tuple[int]): The strides for input feature maps. + stack_lateral. All strides suppose to be power of 2. The first + one is of largest resolution. + priori_attn (bool): Whether use Priori Guiding Connection. + Default to False. + """ + + def __init__(self, feature_strides, priori_attn=False, **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + if priori_attn: + attn_channels = self.in_channels[0] + self.in_channels = self.in_channels[1:] + feature_strides = feature_strides[1:] + self.feature_strides = feature_strides + self.priori_attn = priori_attn + + + self.scale_heads = nn.ModuleList() + for i in range(len(feature_strides)): + scale_head = [] + scale_head.append( + ConvModule( + in_channels=self.in_channels[i], + out_channels=self.channels, + kernel_size=1, + stride=1, + groups=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.scale_heads.append(nn.Sequential(*scale_head)) + + if self.priori_attn: + self.gen_diff_attn = ConvModule( + in_channels=attn_channels // 2, + out_channels=self.channels, + kernel_size=1, + stride=1, + groups=1, + norm_cfg=None, + act_cfg=None + ) + + def forward(self, inputs): + + x = self._transform_inputs(inputs) + + if self.priori_attn: + early_x = x[0] + x = x[1:] + + output = self.scale_heads[0](x[0]) + for i in range(1, len(self.feature_strides)): + # non inplace + output = output + resize( + self.scale_heads[i](x[i]), + size=output.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + + if self.priori_attn: + x1_, x2_ = torch.chunk(early_x, 2, dim=1) + diff_x = torch.abs(x1_ - x2_) + diff_x = self.gen_diff_attn(diff_x) + if diff_x.shape != output.shape: + output = resize(output, diff_x.shape[2:], mode='bilinear', align_corners=self.align_corners) + output = output * torch.sigmoid(diff_x) + output + + output = self.cls_seg(output) + return output \ No newline at end of file diff --git a/opencd/models/losses/__init__.py b/opencd/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..310eb9f978981a2a0d86f6978fee5b92371664d2 --- /dev/null +++ b/opencd/models/losses/__init__.py @@ -0,0 +1,3 @@ +from .bcl_loss import BCLLoss + +__all__ = ['BCLLoss'] \ No newline at end of file diff --git a/opencd/models/losses/__pycache__/__init__.cpython-311.pyc b/opencd/models/losses/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..11833400942f42b0e1ad580257edd98bcf53e571
Binary files /dev/null and b/opencd/models/losses/__pycache__/__init__.cpython-311.pyc differ
diff --git a/opencd/models/losses/__pycache__/bcl_loss.cpython-311.pyc b/opencd/models/losses/__pycache__/bcl_loss.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c32b5e7648ebad8038e558bb8e2c0d9584e86876
Binary files /dev/null and b/opencd/models/losses/__pycache__/bcl_loss.cpython-311.pyc differ
diff --git a/opencd/models/losses/bcl_loss.py b/opencd/models/losses/bcl_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7c0ad20cbce49c3b37507a287b68fda9341c360
--- /dev/null
+++ b/opencd/models/losses/bcl_loss.py
@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+
+from opencd.registry import MODELS
+
+
+def bcl_loss(
+        pred,
+        target,
+        margin=2.0,
+        eps=1e-4,
+        ignore_index=255,
+        **kwargs):
+    pred = pred.squeeze()
+    target = target.squeeze()
+    assert pred.size() == target.size() and target.numel() > 0
+    mask = (target != ignore_index).float()
+    target = target * mask
+    utarget = 1 - target
+    n_u = utarget.sum() + eps
+    n_c = target.sum() + eps
+    loss = torch.sum(utarget * torch.pow(pred, 2) * mask) / n_u + \
+        torch.sum(target * torch.pow(torch.clamp(margin - pred, min=0.), 2)) / n_c
+    return loss
+
+
+@MODELS.register_module()
+class BCLLoss(nn.Module):
+    """Batch-balanced Contrastive Loss"""
+
+    def __init__(
+            self,
+            margin=2.0,
+            loss_weight=1.0,
+            ignore_index=255,
+            loss_name='bcl_loss',
+            **kwargs):
+        super().__init__()
+        self.margin = margin
+        self.loss_weight = loss_weight
+        self.ignore_index = ignore_index
+        self._loss_name = loss_name
+
+    def forward(self,
+                pred,
+                target,
+                **kwargs):
+
+        # Pass keyword arguments so that `ignore_index` is not mistakenly
+        # bound to the `eps` parameter of `bcl_loss`.
+        loss = self.loss_weight * bcl_loss(
+            pred, target, margin=self.margin, ignore_index=self.ignore_index)
+        return loss
+
+    @property
+    def loss_name(self):
+        """Loss Name.
+
+        This function must be implemented and will return the name of this
+        loss function. This name will be used to combine different loss items
+        by simple sum operation. In addition, if you want this loss item to
+        be included into the backward graph, `loss_` must be the prefix of
+        the name.
+
+        Returns:
+            str: The name of this loss item.
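+
+        Example:
+            >>> BCLLoss().loss_name
+            'bcl_loss'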
+        """
+        return self._loss_name
\ No newline at end of file
diff --git a/opencd/models/necks/__init__.py b/opencd/models/necks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..100e4b4aaca2999904072dd2c0f7d13e17163bd2
--- /dev/null
+++ b/opencd/models/necks/__init__.py
@@ -0,0 +1,4 @@
+from .feature_fusion import FeatureFusionNeck
+from .tiny_fpn import TinyFPN
+
+__all__ = ['FeatureFusionNeck', 'TinyFPN']
\ No newline at end of file
diff --git a/opencd/models/necks/__pycache__/__init__.cpython-311.pyc b/opencd/models/necks/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5593d25b619930be6d0e2308275dd2ef0b54f62
Binary files /dev/null and b/opencd/models/necks/__pycache__/__init__.cpython-311.pyc differ
diff --git a/opencd/models/necks/__pycache__/feature_fusion.cpython-311.pyc b/opencd/models/necks/__pycache__/feature_fusion.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6114495e7c1ce2724c7f4f95b95a0aad51f55de2
Binary files /dev/null and b/opencd/models/necks/__pycache__/feature_fusion.cpython-311.pyc differ
diff --git a/opencd/models/necks/__pycache__/tiny_fpn.cpython-311.pyc b/opencd/models/necks/__pycache__/tiny_fpn.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..475d7a96c6d1f71347e7c452c92ac67e56261eb4
Binary files /dev/null and b/opencd/models/necks/__pycache__/tiny_fpn.cpython-311.pyc differ
diff --git a/opencd/models/necks/feature_fusion.py b/opencd/models/necks/feature_fusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d3f4103d5982ce8459df296d3a488901c41ffb
--- /dev/null
+++ b/opencd/models/necks/feature_fusion.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.model import BaseModule
+
+from opencd.registry import MODELS
+
+
+@MODELS.register_module()
+class FeatureFusionNeck(BaseModule):
+    """Feature Fusion Neck.
+
+    Args:
+        policy (str): The operation used to fuse features. Candidates
+            are `concat`, `sum`, `diff` and `abs_diff`.
+        in_channels (Sequence(int)): Input channels.
+        channels (int): Channels after modules, before conv_seg.
+        out_indices (tuple[int]): Output from which layer.
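+
+    Example (illustrative):
+        >>> import torch
+        >>> x1 = [torch.rand(2, 8, 32, 32)]
+        >>> x2 = [torch.rand(2, 8, 32, 32)]
+        >>> neck = FeatureFusionNeck(policy='concat', out_indices=(0, ))
+        >>> neck(x1, x2)[0].shape
+        torch.Size([2, 16, 32, 32])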
+ """ + + def __init__(self, + policy, + in_channels=None, + channels=None, + out_indices=(0, 1, 2, 3)): + super().__init__() + self.policy = policy + self.in_channels = in_channels + self.channels = channels + self.out_indices = out_indices + + @staticmethod + def fusion(x1, x2, policy): + """Specify the form of feature fusion""" + + _fusion_policies = ['concat', 'sum', 'diff', 'abs_diff'] + assert policy in _fusion_policies, 'The fusion policies {} are ' \ + 'supported'.format(_fusion_policies) + + if policy == 'concat': + x = torch.cat([x1, x2], dim=1) + elif policy == 'sum': + x = x1 + x2 + elif policy == 'diff': + x = x2 - x1 + elif policy == 'abs_diff': + x = torch.abs(x1 - x2) + + return x + + def forward(self, x1, x2): + """Forward function.""" + + assert len(x1) == len(x2), "The features x1 and x2 from the" \ + "backbone should be of equal length" + outs = [] + for i in range(len(x1)): + out = self.fusion(x1[i], x2[i], self.policy) + outs.append(out) + + outs = [outs[i] for i in self.out_indices] + return tuple(outs) \ No newline at end of file diff --git a/opencd/models/necks/tiny_fpn.py b/opencd/models/necks/tiny_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..9ee9dd87430413691227dc85405bd814b07819b1 --- /dev/null +++ b/opencd/models/necks/tiny_fpn.py @@ -0,0 +1,210 @@ +# Copyright (c) Open-CD. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmseg.models.utils import resize +from opencd.registry import MODELS +from ..backbones.tinynet import TinyBlock + + +@MODELS.register_module() +class TinyFPN(BaseModule): + """Feature Pyramid Network. + This neck is the implementation of `Feature Pyramid Networks for Object + Detection `_. + Args: + in_channels (list[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + num_outs (int): Number of output scales. + exist_early_x (bool): If True, the first feature in `inputs` will be + ignored and placed at the 0 index of the `output`. Default to False. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Default to False. + If True, its actual mode is specified by `extra_convs_on_inputs`. + If str, it specifies the source feature map of the extra convs. + Only the following options are allowed + - 'on_input': Last feat map of neck inputs (i.e. backbone feature). + - 'on_lateral': Last feature map after lateral convs. + - 'on_output': The last output feature map after fpn convs. + extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs + on the original feature from the backbone. If True, + it is equivalent to `add_extra_convs='on_input'`. If False, it is + equivalent to set `add_extra_convs='on_output'`. Default to True. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Default: False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Default: False. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: None. 
+ upsample_cfg (dict): Config dict for interpolate layer. + Default: dict(mode='nearest'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Example: + >>> import torch + >>> in_channels = [2, 3, 5, 7] + >>> scales = [340, 170, 84, 43] + >>> inputs = [torch.rand(1, c, s, s) + ... for c, s in zip(in_channels, scales)] + >>> self = FPN(in_channels, 11, len(in_channels)).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 11, 340, 340]) + outputs[1].shape = torch.Size([1, 11, 170, 170]) + outputs[2].shape = torch.Size([1, 11, 84, 84]) + outputs[3].shape = torch.Size([1, 11, 43, 43]) + """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + exist_early_x=False, + start_level=0, + end_level=-1, + add_extra_convs=False, + extra_convs_on_inputs=False, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + upsample_cfg=dict(mode='nearest'), + init_cfg=dict( + type='Xavier', layer='Conv2d', distribution='uniform')): + super().__init__(init_cfg) + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.relu_before_extra_convs = relu_before_extra_convs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.upsample_cfg = upsample_cfg.copy() + self.exist_early_x = exist_early_x + + if end_level == -1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level < inputs, no extra level is allowed + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + assert num_outs == end_level - start_level + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' + assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') + elif add_extra_convs: # True + if extra_convs_on_inputs: + # For compatibility with previous release + # TODO: deprecate `extra_convs_on_inputs` + self.add_extra_convs = 'on_input' + else: + self.add_extra_convs = 'on_output' + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False) + fpn_conv = TinyBlock( + in_channels=out_channels, + out_channels=out_channels, + stride=1, + expand_ratio=1) + + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + # add extra conv layers (e.g., RetinaNet) + extra_levels = num_outs - self.backbone_end_level + self.start_level + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == 'on_input': + in_channels = self.in_channels[self.backbone_end_level - 1] + else: + in_channels = out_channels + extra_fpn_conv = TinyBlock( + in_channels=in_channels, + out_channels=out_channels, + stride=2, + expand_ratio=1) + self.fpn_convs.append(extra_fpn_conv) + + def forward(self, inputs): + if self.exist_early_x: + early_x = inputs[0] + inputs = inputs[1:] + + assert len(inputs) == 
len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if 'scale_factor' in self.upsample_cfg: + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], size=prev_shape, **self.upsample_cfg) + + # build outputs + # part 1: from original levels + outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + if self.exist_early_x: + outs = [early_x] + outs + return tuple(outs) \ No newline at end of file diff --git a/opencd/models/utils/__init__.py b/opencd/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f5c8062af7251d69bfd06d3efcec8c5ea5b6d3d4 --- /dev/null +++ b/opencd/models/utils/__init__.py @@ -0,0 +1,8 @@ +from .builder import build_interaction_layer +from .interaction_layer import (Aggregation_distribution, ChannelExchange, + SpatialExchange, TwoIdentity) + +__all__ = [ + 'build_interaction_layer', 'Aggregation_distribution', 'ChannelExchange', + 'SpatialExchange', 'TwoIdentity' +] diff --git a/opencd/models/utils/__pycache__/__init__.cpython-311.pyc b/opencd/models/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1831055e69e34bc4f1d3b8d8c7e8ea8c7933f863 Binary files /dev/null and b/opencd/models/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/opencd/models/utils/__pycache__/builder.cpython-311.pyc b/opencd/models/utils/__pycache__/builder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cde8e5cb718fa85f108f0ffe068300eba6e2dff6 Binary files /dev/null and b/opencd/models/utils/__pycache__/builder.cpython-311.pyc differ diff --git a/opencd/models/utils/__pycache__/interaction_layer.cpython-311.pyc b/opencd/models/utils/__pycache__/interaction_layer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a465416684fa9284305b54c4f93d5782ecef986 Binary files /dev/null and b/opencd/models/utils/__pycache__/interaction_layer.cpython-311.pyc differ diff --git a/opencd/models/utils/builder.py b/opencd/models/utils/builder.py new file mode 100644 index 
0000000000000000000000000000000000000000..07ea91ea1e10c8f4d5cc568d4303dcde9574a91b --- /dev/null +++ b/opencd/models/utils/builder.py @@ -0,0 +1,11 @@ +import warnings + +from opencd.registry import MODELS + +ITERACTION_LAYERS = MODELS + +def build_interaction_layer(cfg): + """Build backbone.""" + warnings.warn('``build_interaction_layer`` would be deprecated soon, please use ' + '``opencd.registry.MODELS.build()`` ') + return ITERACTION_LAYERS.build(cfg) diff --git a/opencd/models/utils/interaction_layer.py b/opencd/models/utils/interaction_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..9418ef5675bd2893d99facfd8052663591306a34 --- /dev/null +++ b/opencd/models/utils/interaction_layer.py @@ -0,0 +1,103 @@ +# Copyright (c) Open-CD. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer +from mmengine.model import BaseModule + +from opencd.models.utils.builder import ITERACTION_LAYERS + + +@ITERACTION_LAYERS.register_module() +class ChannelExchange(BaseModule): + """ + channel exchange + Args: + p (float, optional): p of the features will be exchanged. + Defaults to 1/2. + """ + def __init__(self, p=1/2): + super().__init__() + assert p >= 0 and p <= 1 + self.p = int(1/p) + + def forward(self, x1, x2): + N, c, h, w = x1.shape + + exchange_map = torch.arange(c) % self.p == 0 + exchange_mask = exchange_map.unsqueeze(0).expand((N, -1)) + + out_x1, out_x2 = torch.zeros_like(x1), torch.zeros_like(x2) + out_x1[~exchange_mask, ...] = x1[~exchange_mask, ...] + out_x2[~exchange_mask, ...] = x2[~exchange_mask, ...] + out_x1[exchange_mask, ...] = x2[exchange_mask, ...] + out_x2[exchange_mask, ...] = x1[exchange_mask, ...] + + return out_x1, out_x2 + + +@ITERACTION_LAYERS.register_module() +class SpatialExchange(BaseModule): + """ + spatial exchange + Args: + p (float, optional): p of the features will be exchanged. + Defaults to 1/2. + """ + def __init__(self, p=1/2): + super().__init__() + assert p >= 0 and p <= 1 + self.p = int(1/p) + + def forward(self, x1, x2): + N, c, h, w = x1.shape + exchange_mask = torch.arange(w) % self.p == 0 + + out_x1, out_x2 = torch.zeros_like(x1), torch.zeros_like(x2) + out_x1[..., ~exchange_mask] = x1[..., ~exchange_mask] + out_x2[..., ~exchange_mask] = x2[..., ~exchange_mask] + out_x1[..., exchange_mask] = x2[..., exchange_mask] + out_x2[..., exchange_mask] = x1[..., exchange_mask] + + return out_x1, out_x2 + + +@ITERACTION_LAYERS.register_module() +class Aggregation_distribution(BaseModule): + # Aggregation_Distribution Layer (AD) + def __init__(self, + channels, + num_paths=2, + attn_channels=None, + act_cfg=dict(type='ReLU'), + norm_cfg=dict(type='BN', requires_grad=True)): + super().__init__() + self.num_paths = num_paths # `2` is supported. 
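+        # Squeeze-and-excitation style gating: pool and reduce the summed
+        # features, then emit one sigmoid attention map per path (one per
+        # temporal phase) in ``forward``.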
+ attn_channels = attn_channels or channels // 16 + attn_channels = max(attn_channels, 8) + + self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) + self.bn = build_norm_layer(norm_cfg, attn_channels)[1] + self.act = build_activation_layer(act_cfg) + self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + + def forward(self, x1, x2): + x = torch.stack([x1, x2], dim=1) + attn = x.sum(1).mean((2, 3), keepdim=True) + attn = self.fc_reduce(attn) + attn = self.bn(attn) + attn = self.act(attn) + attn = self.fc_select(attn) + B, C, H, W = attn.shape + attn1, attn2 = attn.reshape(B, self.num_paths, C // self.num_paths, H, W).transpose(0, 1) + attn1 = torch.sigmoid(attn1) + attn2 = torch.sigmoid(attn2) + return x1 * attn1, x2 * attn2 + + +@ITERACTION_LAYERS.register_module() +class TwoIdentity(BaseModule): + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, x1, x2): + return x1, x2 diff --git a/opencd/registry.py b/opencd/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..c0529149798740a2ba4a17ac2270f1132d8e69a8 --- /dev/null +++ b/opencd/registry.py @@ -0,0 +1,98 @@ +# Copyright (c) Open-CD. All rights reserved. +"""Open-CD provides 17 registry nodes to support using modules across projects. +Each node is a child of the root registry in MMEngine. +More details can be found at +https://mmengine.readthedocs.io/en/latest/tutorials/registry.html. +""" + +from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS +from mmengine.registry import DATASETS as MMENGINE_DATASETS +from mmengine.registry import HOOKS as MMENGINE_HOOKS +from mmengine.registry import INFERENCERS as MMENGINE_INFERENCERS +from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS +from mmengine.registry import LOOPS as MMENGINE_LOOPS +from mmengine.registry import METRICS as MMENGINE_METRICS +from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS +from mmengine.registry import MODELS as MMENGINE_MODELS +from mmengine.registry import \ + OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS +from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS +from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS +from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS +from mmengine.registry import \ + RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS +from mmengine.registry import RUNNERS as MMENGINE_RUNNERS +from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS +from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS +from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS +from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS +from mmengine.registry import \ + WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS +from mmengine.registry import Registry + +# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner` +RUNNERS = Registry( + 'runner', parent=MMENGINE_RUNNERS, locations=['opencd.engine']) +# manage runner constructors that define how to initialize runners +RUNNER_CONSTRUCTORS = Registry( + 'runner constructor', + parent=MMENGINE_RUNNER_CONSTRUCTORS, + locations=['opencd.engine']) +# manage all kinds of loops like `EpochBasedTrainLoop` +LOOPS = Registry('loop', parent=MMENGINE_LOOPS, locations=['opencd.engine']) +# manage all kinds of hooks like `CheckpointHook` +HOOKS = Registry( + 'hook', parent=MMENGINE_HOOKS, locations=['opencd.engine.hooks']) + +# 
manage data-related modules +DATASETS = Registry( + 'dataset', parent=MMENGINE_DATASETS, locations=['opencd.datasets']) +DATA_SAMPLERS = Registry( + 'data sampler', + parent=MMENGINE_DATA_SAMPLERS, + locations=['opencd.datasets']) +TRANSFORMS = Registry( + 'transform', + parent=MMENGINE_TRANSFORMS, + locations=['opencd.datasets.transforms']) + +# manage all kinds of modules inheriting `nn.Module` +MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['opencd.models']) +# manage all kinds of model wrappers like 'MMDistributedDataParallel' +MODEL_WRAPPERS = Registry( + 'model_wrapper', + parent=MMENGINE_MODEL_WRAPPERS, + locations=['opencd.models']) +# manage all kinds of weight initialization modules like `Uniform` +WEIGHT_INITIALIZERS = Registry( + 'weight initializer', + parent=MMENGINE_WEIGHT_INITIALIZERS, + locations=['opencd.models']) + +# manage all kinds of metrics +METRICS = Registry( + 'metric', parent=MMENGINE_METRICS, locations=['opencd.evaluation']) + +# manage task-specific modules like anchor generators and box coders +TASK_UTILS = Registry( + 'task util', parent=MMENGINE_TASK_UTILS, locations=['opencd.models']) + +# manage visualizer +VISUALIZERS = Registry( + 'visualizer', + parent=MMENGINE_VISUALIZERS, + locations=['opencd.visualization']) +# manage visualizer backend +VISBACKENDS = Registry( + 'vis_backend', + parent=MMENGINE_VISBACKENDS, + locations=['opencd.visualization']) + +# manage logprocessor +LOG_PROCESSORS = Registry( + 'log_processor', + parent=MMENGINE_LOG_PROCESSORS, + locations=['opencd.visualization']) + +# manage inferencer +INFERENCERS = Registry('inferencer', parent=MMENGINE_INFERENCERS) \ No newline at end of file diff --git a/opencd/version.py b/opencd/version.py new file mode 100644 index 0000000000000000000000000000000000000000..5612ce72c306e9cf7e0a105472d8b5d4fcfe5237 --- /dev/null +++ b/opencd/version.py @@ -0,0 +1,23 @@ +# Copyright (c) Open-CD. All rights reserved. 
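+# Note: ``parse_version_info`` below yields a mixed int/str tuple, e.g.
+#
+#     >>> parse_version_info('1.1.0')
+#     (1, 1, 0)
+#     >>> parse_version_info('1.1.0rc2')
+#     (1, 1, 0, 'rc2')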
+ +__version__ = '1.1.0' + +from typing import Tuple + +short_version = __version__ + + +def parse_version_info(version_str: str) -> Tuple: + """Parse version info of Open-CD.""" + version_info = [] + for x in version_str.split('.'): + if x.isdigit(): + version_info.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + version_info.append(int(patch_version[0])) + version_info.append(f'rc{patch_version[1]}') + return tuple(version_info) + + +version_info = parse_version_info(__version__) \ No newline at end of file diff --git a/opencd/visualization/__init__.py b/opencd/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..af00add480eb0c7d2da56af0ce5b31d02c10b60a --- /dev/null +++ b/opencd/visualization/__init__.py @@ -0,0 +1,4 @@ +from .cd_local_visualizer import CDLocalVisualizer +from .cd_vis_backend import CDLocalVisBackend + +__all__ = ['CDLocalVisBackend', 'CDLocalVisualizer'] diff --git a/opencd/visualization/__pycache__/__init__.cpython-311.pyc b/opencd/visualization/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c93150b7f6d7c861c66403dc712a69b641389d7 Binary files /dev/null and b/opencd/visualization/__pycache__/__init__.cpython-311.pyc differ diff --git a/opencd/visualization/__pycache__/cd_local_visualizer.cpython-311.pyc b/opencd/visualization/__pycache__/cd_local_visualizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d8eb3396fe4dd57110254c5d715451cb445b7ac Binary files /dev/null and b/opencd/visualization/__pycache__/cd_local_visualizer.cpython-311.pyc differ diff --git a/opencd/visualization/__pycache__/cd_vis_backend.cpython-311.pyc b/opencd/visualization/__pycache__/cd_vis_backend.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9dd1b3c4a38ab19a92c4ed76dd0411bc1e56b07 Binary files /dev/null and b/opencd/visualization/__pycache__/cd_vis_backend.cpython-311.pyc differ diff --git a/opencd/visualization/cd_local_visualizer.py b/opencd/visualization/cd_local_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e5c84882704e1864b62a1d24d51c9f5a8f50df9b --- /dev/null +++ b/opencd/visualization/cd_local_visualizer.py @@ -0,0 +1,215 @@ +from typing import Optional, Sequence + +import mmcv +import numpy as np +from mmengine.dist import master_only + +from mmseg.structures import SegDataSample +from mmseg.visualization import SegLocalVisualizer +from opencd.registry import VISUALIZERS + + +@VISUALIZERS.register_module() +class CDLocalVisualizer(SegLocalVisualizer): + """Change Detection Local Visualizer. """ + + @master_only + def add_datasample( + self, + name: str, + image: np.ndarray, + image_from_to: Sequence[np.array], + data_sample: Optional[SegDataSample] = None, + draw_gt: bool = True, + draw_pred: bool = True, + show: bool = False, + wait_time: float = 0, + # TODO: Supported in mmengine's Viusalizer. + out_file: Optional[str] = None, + step: int = 0, + with_labels: Optional[bool] = False) -> None: + """Draw datasample and save to all backends. + + - If GT and prediction are plotted at the same time, they are + displayed in a stitched image where the left image is the + ground truth and the right image is the prediction. + - If ``show`` is True, all storage backends are ignored, and + the images will be displayed in a local window. + - If ``out_file`` is specified, the drawn image will be + saved to ``out_file``. 
it is usually used when the display + is not available. + + Args: + name (str): The image identifier. + image (np.ndarray): The image to draw. + image_from_to (Sequence[np.array]): The image pairs to draw. + gt_sample (:obj:`SegDataSample`, optional): GT SegDataSample. + Defaults to None. + pred_sample (:obj:`SegDataSample`, optional): Prediction + SegDataSample. Defaults to None. + draw_gt (bool): Whether to draw GT SegDataSample. Default to True. + draw_pred (bool): Whether to draw Prediction SegDataSample. + Defaults to True. + show (bool): Whether to display the drawn image. Default to False. + wait_time (float): The interval of show (s). Defaults to 0. + out_file (str): Path to output file. Defaults to None. + step (int): Global step value to record. Defaults to 0. + with_labels(bool, optional): Add semantic labels in visualization + result, Defaults to True. + """ + exist_img_from_to = True if len(image_from_to) > 0 else False + if exist_img_from_to: + assert len(image_from_to) == 2, '`image_from_to` contains `from` ' \ + 'and `to` images' + + classes = self.dataset_meta.get('classes', None) + palette = self.dataset_meta.get('palette', None) + semantic_classes = self.dataset_meta.get('semantic_classes', None) + semantic_palette = self.dataset_meta.get('semantic_palette', None) + + gt_img_data = None + gt_img_data_from = None + gt_img_data_to = None + pred_img_data = None + pred_img_data_from = None + pred_img_data_to = None + + drawn_img_from = None + drawn_img_to = None + + if draw_gt and data_sample is not None and 'gt_sem_seg' in data_sample: + gt_img_data = image + assert classes is not None, 'class information is ' \ + 'not provided when ' \ + 'visualizing change ' \ + 'deteaction results.' + gt_img_data = self._draw_sem_seg(gt_img_data, data_sample.gt_sem_seg, + classes, palette, with_labels) + if draw_gt and data_sample is not None and 'gt_sem_seg_from' in data_sample \ + and 'gt_sem_seg_to' in data_sample: + if exist_img_from_to: + gt_img_data_from = image_from_to[0] + gt_img_data_to = image_from_to[1] + else: + gt_img_data_from = np.zeros_like(image) + gt_img_data_to = np.zeros_like(image) + assert semantic_classes is not None, 'class information is ' \ + 'not provided when ' \ + 'visualizing change ' \ + 'deteaction results.' + gt_img_data_from = self._draw_sem_seg(gt_img_data_from, + data_sample.gt_sem_seg_from, semantic_classes, + semantic_palette, with_labels) + gt_img_data_to = self._draw_sem_seg(gt_img_data_to, + data_sample.gt_sem_seg_to, semantic_classes, + semantic_palette, with_labels) + + if (draw_pred and data_sample is not None + and 'pred_sem_seg' in data_sample): + pred_img_data = image + assert classes is not None, 'class information is ' \ + 'not provided when ' \ + 'visualizing semantic ' \ + 'segmentation results.' + pred_img_data = self._draw_sem_seg(pred_img_data, + data_sample.pred_sem_seg, + classes, palette, + with_labels) + + if (draw_pred and data_sample is not None and 'pred_sem_seg_from' in data_sample \ + and 'pred_sem_seg_to' in data_sample): + if exist_img_from_to: + pred_img_data_from = image_from_to[0] + pred_img_data_to = image_from_to[1] + else: + pred_img_data_from = np.zeros_like(image) + pred_img_data_to = np.zeros_like(image) + assert semantic_classes is not None, 'class information is ' \ + 'not provided when ' \ + 'visualizing change ' \ + 'deteaction results.' 
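+                # Overlay the `from`/`to` semantic predictions with the
+                # semantic palette (on the paired images when given,
+                # otherwise on the zero canvases created above).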
+ pred_img_data_from = self._draw_sem_seg(pred_img_data_from, + data_sample.pred_sem_seg_from, semantic_classes, + semantic_palette, with_labels) + pred_img_data_to = self._draw_sem_seg(pred_img_data_to, + data_sample.pred_sem_seg_to, semantic_classes, + semantic_palette, with_labels) + + if gt_img_data is not None and pred_img_data is not None: + drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1) + elif gt_img_data is not None: + drawn_img = gt_img_data + else: + drawn_img = pred_img_data + + if gt_img_data_from is not None and pred_img_data_from is not None: + drawn_img_from = np.concatenate((gt_img_data_from, pred_img_data_from), axis=1) + elif gt_img_data_from is not None: + drawn_img_from = gt_img_data_from + else: + drawn_img_from = pred_img_data_from + + if gt_img_data_to is not None and pred_img_data_to is not None: + drawn_img_to = np.concatenate((gt_img_data_to, pred_img_data_to), axis=1) + elif gt_img_data_to is not None: + drawn_img_to = gt_img_data_to + else: + drawn_img_to = pred_img_data_to + + if show: + if drawn_img_from is not None and drawn_img_to is not None: + drawn_img_cat = np.concatenate((drawn_img, drawn_img_from, drawn_img_to), axis=0) + self.show(drawn_img_cat, win_name=name, wait_time=wait_time) + else: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + if drawn_img_from is not None and drawn_img_to is not None: + drawn_img_cat = np.concatenate((drawn_img, drawn_img_from, drawn_img_to), axis=0) + mmcv.imwrite(mmcv.bgr2rgb(drawn_img_cat), out_file) + else: + mmcv.imwrite(mmcv.bgr2rgb(drawn_img), out_file) + else: + self.add_image(name, drawn_img, drawn_img_from, drawn_img_to, step) + + @master_only + def add_image(self, name: str, + image: np.ndarray, + image_from: np.ndarray = None, + image_to: np.ndarray = None, + step: int = 0) -> None: + """Record the image. + + Args: + name (str): The image identifier. + image (np.ndarray, optional): The image to be saved. The format + should be RGB. Defaults to None. + step (int): Global step value to record. Defaults to 0. + """ + for vis_backend in self._vis_backends.values(): + vis_backend.add_image(name, image, image_from, image_to, step) # type: ignore + + @master_only + def set_image(self, image: np.ndarray) -> None: + """Set the image to draw. + + Args: + image (np.ndarray): The image to draw. 
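+                The array is cast to ``uint8`` and becomes the canvas for
+                subsequent draw calls.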
+ """ + assert image is not None + image = image.astype('uint8') + self._image = image + self.width, self.height = image.shape[1], image.shape[0] + # print(image.shape) + self._default_font_size = max( + np.sqrt(self.height * self.width) // 90, 10) + + self.fig_save.set_size_inches( # type: ignore + self.width / self.dpi, self.height / self.dpi) + # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) + self.ax_save.cla() + self.ax_save.axis(False) + self.ax_save.imshow( + image, + extent=(0, self.width, self.height, 0), + interpolation='none') diff --git a/opencd/visualization/cd_vis_backend.py b/opencd/visualization/cd_vis_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..f9fa9803b63562a8823ea808664ff7abeab0c541 --- /dev/null +++ b/opencd/visualization/cd_vis_backend.py @@ -0,0 +1,45 @@ +import os +import os.path as osp + +import cv2 +import numpy as np +from mmengine.registry import VISBACKENDS +from mmengine.visualization.vis_backend import LocalVisBackend, force_init_env + + +@VISBACKENDS.register_module() +class CDLocalVisBackend(LocalVisBackend): + + @force_init_env + def add_image(self, + name: str, + image: np.array, + image_from: np.array = None, + image_to: np.array = None, + step: int = 0, + **kwargs) -> None: + """Record the image to disk. + + Args: + name (str): The image identifier. + image (np.ndarray): The image to be saved. The format + should be RGB. Defaults to None. + step (int): Global step value to record. Defaults to 0. + """ + assert image.dtype == np.uint8 + drawn_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + os.makedirs(self._img_save_dir, exist_ok=True) + save_file_name = f'{name}.png' + + if image_from is not None and image_to is not None: + assert image_from.dtype == np.uint8 and image_to.dtype == np.uint8 + drawn_image_from = cv2.cvtColor(image_from, cv2.COLOR_RGB2BGR) + drawn_image_to = cv2.cvtColor(image_to, cv2.COLOR_RGB2BGR) + for sub_dir in ['binary', 'from', 'to']: + os.makedirs(osp.join(self._img_save_dir, sub_dir), exist_ok=True) + + cv2.imwrite(osp.join(self._img_save_dir, 'binary', save_file_name), drawn_image) + cv2.imwrite(osp.join(self._img_save_dir, 'from', save_file_name), drawn_image_from) + cv2.imwrite(osp.join(self._img_save_dir, 'to', save_file_name), drawn_image_to) + else: + cv2.imwrite(osp.join(self._img_save_dir, save_file_name), drawn_image) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..629f036fa1da4c0aa406f18447b73b6dcab8c2e0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +torch==2.1.2 +torchvision +torchaudio +-f https://download.openmmlab.com/mmcv/dist/cpu/torch2.1/index.html +mmcv==2.1.0 +wandb +einops +importlib +peft +scipy +ftfy +prettytable +torchmetrics \ No newline at end of file diff --git a/samples/.DS_Store b/samples/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..fb1e252e2fc83ed0885aa10196038184abf1f1c9 Binary files /dev/null and b/samples/.DS_Store differ diff --git a/samples/A/test_1.png b/samples/A/test_1.png new file mode 100644 index 0000000000000000000000000000000000000000..390f38bb8413229368b6592382f3bd93c00d0135 --- /dev/null +++ b/samples/A/test_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0b42012ec2f384fdffe7c9e3ef7c7a5bec77cbe57a30a1fc530833667a9b0d46 +size 2136706 diff --git a/samples/A/test_2.png b/samples/A/test_2.png new file mode 100644 index 
0000000000000000000000000000000000000000..7c3687f7017d8f91fec20db7fe6395a819dac3ae --- /dev/null +++ b/samples/A/test_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:100d08c2f080c13db395ea4732a9fa775143e451ff49d00771f1bb30db544c9f +size 2185259 diff --git a/samples/A/test_3.png b/samples/A/test_3.png new file mode 100644 index 0000000000000000000000000000000000000000..45d7ea0dda046234f654637019bc4061813f2992 --- /dev/null +++ b/samples/A/test_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f24f8e1ca576c6ecdd6b1ad5c81c05fde1ce75183a1c04069f571ad8e14297e +size 2459524 diff --git a/samples/A/test_4.png b/samples/A/test_4.png new file mode 100644 index 0000000000000000000000000000000000000000..014683897b3977964b226d7ede2fa532be7c7b70 --- /dev/null +++ b/samples/A/test_4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1de8e369f46983e47cb917717e4c8a385abe3b756cd8274bb0ac2bf06f2bda8a +size 2233532 diff --git a/samples/A/test_5.png b/samples/A/test_5.png new file mode 100644 index 0000000000000000000000000000000000000000..6bb514cc19e216b8ce8f91970ee14817319b0d8a --- /dev/null +++ b/samples/A/test_5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a7c025376639fcc4f80f724cb8803394f85d5f4130629ad6323394e3b6ecffeb +size 2344059 diff --git a/samples/B/test_1.png b/samples/B/test_1.png new file mode 100644 index 0000000000000000000000000000000000000000..28ff24ba7a62cc30beb7bb73498b3aa970840d94 --- /dev/null +++ b/samples/B/test_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:489850790cb52b59fc84b3d1f06cf365f9ceac890257e23c7ae646bae98e12ec +size 1925883 diff --git a/samples/B/test_2.png b/samples/B/test_2.png new file mode 100644 index 0000000000000000000000000000000000000000..d12620af924ef6fa62466c029cf1cd96c0be6c6d --- /dev/null +++ b/samples/B/test_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a63fb7b7498074c0448c1c7179ed2aedfd1fd790f1e8c7926cbf5e88727cd8c +size 1977928 diff --git a/samples/B/test_3.png b/samples/B/test_3.png new file mode 100644 index 0000000000000000000000000000000000000000..dea7af5fbd3e5bf6e550dc0de5ea4b23695c8f04 --- /dev/null +++ b/samples/B/test_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65f54da0b2b71feae7dbfece67151394115acded5834ba621a2e5b5ee3ba8dc9 +size 2254787 diff --git a/samples/B/test_4.png b/samples/B/test_4.png new file mode 100644 index 0000000000000000000000000000000000000000..92a1a2b1c974e28d1687cd5d0eb744851679acd9 --- /dev/null +++ b/samples/B/test_4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e86b6cf242405be37daf4f023586927c714d80339c5fb985ef54c128a8311461 +size 2015537 diff --git a/samples/B/test_5.png b/samples/B/test_5.png new file mode 100644 index 0000000000000000000000000000000000000000..2872fb8f8bd6aede62e789dce4387605f9d9a312 --- /dev/null +++ b/samples/B/test_5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cdd3a1b715c23048db363eacf077a053bf2b753179f0b26d1c6055a3d6a75a88 +size 2033893 diff --git a/samples/label/test_1.png b/samples/label/test_1.png new file mode 100644 index 0000000000000000000000000000000000000000..4cc129d9fd1c5c67a301a0801360b3e5d8904d8a Binary files /dev/null and b/samples/label/test_1.png differ diff --git a/samples/label/test_2.png b/samples/label/test_2.png new file mode 100644 index 
0000000000000000000000000000000000000000..e47a57cef45a79aa25053a7fc5acdeaf959cb0a8 Binary files /dev/null and b/samples/label/test_2.png differ diff --git a/samples/label/test_3.png b/samples/label/test_3.png new file mode 100644 index 0000000000000000000000000000000000000000..286e6a43433cb080ea8486d46a0be3f7681970da Binary files /dev/null and b/samples/label/test_3.png differ diff --git a/samples/label/test_4.png b/samples/label/test_4.png new file mode 100644 index 0000000000000000000000000000000000000000..3905715410463a2347f7383cca3fbc43fb95d9c3 Binary files /dev/null and b/samples/label/test_4.png differ diff --git a/samples/label/test_5.png b/samples/label/test_5.png new file mode 100644 index 0000000000000000000000000000000000000000..63deb8dfac140180aa16911118db878beb566b13 Binary files /dev/null and b/samples/label/test_5.png differ diff --git a/tools/.DS_Store b/tools/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..49853c96f63787624020a7b2a033026bde8b8977 Binary files /dev/null and b/tools/.DS_Store differ diff --git a/tools/analysis_tools/analyze_logs.py b/tools/analysis_tools/analyze_logs.py new file mode 100644 index 0000000000000000000000000000000000000000..7464d231621b17249ce69f358479bbba42757362 --- /dev/null +++ b/tools/analysis_tools/analyze_logs.py @@ -0,0 +1,130 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Modified from https://github.com/open- +mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py.""" +import argparse +import json +from collections import defaultdict + +import matplotlib.pyplot as plt +import seaborn as sns + + +def plot_curve(log_dicts, args): + if args.backend is not None: + plt.switch_backend(args.backend) + sns.set_style(args.style) + # if legend is None, use {filename}_{key} as legend + legend = args.legend + if legend is None: + legend = [] + for json_log in args.json_logs: + for metric in args.keys: + legend.append(f'{json_log}_{metric}') + assert len(legend) == (len(args.json_logs) * len(args.keys)) + metrics = args.keys + + num_metrics = len(metrics) + for i, log_dict in enumerate(log_dicts): + epochs = list(log_dict.keys()) + for j, metric in enumerate(metrics): + print(f'plot curve of {args.json_logs[i]}, metric is {metric}') + plot_epochs = [] + plot_iters = [] + plot_values = [] + # In some log files exist lines of validation, + # `mode` list is used to only collect iter number + # of training line. 
+ for epoch in epochs: + epoch_logs = log_dict[epoch] + if metric not in epoch_logs.keys(): + continue + if metric in ['mIoU', 'mAcc', 'aAcc']: + plot_epochs.append(epoch) + plot_values.append(epoch_logs[metric][0]) + else: + for idx in range(len(epoch_logs[metric])): + plot_iters.append(epoch_logs['step'][idx]) + plot_values.append(epoch_logs[metric][idx]) + ax = plt.gca() + label = legend[i * num_metrics + j] + if metric in ['mIoU', 'mAcc', 'aAcc']: + ax.set_xticks(plot_epochs) + plt.xlabel('step') + plt.plot(plot_epochs, plot_values, label=label, marker='o') + else: + plt.xlabel('iter') + plt.plot(plot_iters, plot_values, label=label, linewidth=0.5) + plt.legend() + if args.title is not None: + plt.title(args.title) + if args.out is None: + plt.show() + else: + print(f'save curve to: {args.out}') + plt.savefig(args.out) + plt.cla() + + +def parse_args(): + parser = argparse.ArgumentParser(description='Analyze Json Log') + parser.add_argument( + 'json_logs', + type=str, + nargs='+', + help='path of train log in json format') + parser.add_argument( + '--keys', + type=str, + nargs='+', + default=['mIoU'], + help='the metric that you want to plot') + parser.add_argument('--title', type=str, help='title of figure') + parser.add_argument( + '--legend', + type=str, + nargs='+', + default=None, + help='legend of each plot') + parser.add_argument( + '--backend', type=str, default=None, help='backend of plt') + parser.add_argument( + '--style', type=str, default='dark', help='style of plt') + parser.add_argument('--out', type=str, default=None) + args = parser.parse_args() + return args + + +def load_json_logs(json_logs): + # load and convert json_logs to log_dict, key is step, value is a sub dict + # keys of sub dict is different metrics + # value of sub dict is a list of corresponding values of all iterations + log_dicts = [dict() for _ in json_logs] + prev_step = 0 + for json_log, log_dict in zip(json_logs, log_dicts): + with open(json_log) as log_file: + for line in log_file: + log = json.loads(line.strip()) + # the final step in json file is 0. + if 'step' in log and log['step'] != 0: + step = log['step'] + prev_step = step + else: + step = prev_step + if step not in log_dict: + log_dict[step] = defaultdict(list) + for k, v in log.items(): + log_dict[step][k].append(v) + return log_dicts + + +def main(): + args = parse_args() + json_logs = args.json_logs + for json_log in json_logs: + assert json_log.endswith('.json') + log_dicts = load_json_logs(json_logs) + plot_curve(log_dicts, args) + + +if __name__ == '__main__': + main() diff --git a/tools/analysis_tools/benchmark.py b/tools/analysis_tools/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..afaeabac85fa642b03c006b8a920c0d95d4cb400 --- /dev/null +++ b/tools/analysis_tools/benchmark.py @@ -0,0 +1,121 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
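The log files consumed by analyze_logs.py above are JSONL: one JSON dict per line, each carrying a 'step' key alongside scalar metrics. A minimal sketch of that parsing convention, using two hypothetical log lines rather than a real log from this repo:

import json
from collections import defaultdict

# Two hypothetical training-log lines in the JSONL layout expected above.
raw_lines = [
    '{"step": 10, "loss": 0.52, "mIoU": 71.3}',
    '{"step": 20, "loss": 0.48, "mIoU": 72.9}',
]

log_dict = {}
for line in raw_lines:
    log = json.loads(line)
    step = log['step']
    if step not in log_dict:
        log_dict[step] = defaultdict(list)
    # Group every metric recorded at this step, as load_json_logs() does.
    for k, v in log.items():
        log_dict[step][k].append(v)

print(log_dict[10]['loss'])  # [0.52]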
+import argparse +import os.path as osp +import time + +import numpy as np +import torch +from mmengine import Config +from mmengine.fileio import dump +from mmengine.model.utils import revert_sync_batchnorm +from mmengine.registry import init_default_scope +from mmengine.runner import Runner, load_checkpoint +from mmengine.utils import mkdir_or_exist + +from mmseg.registry import MODELS + + +def parse_args(): + parser = argparse.ArgumentParser(description='MMSeg benchmark a model') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument( + '--log-interval', type=int, default=50, help='interval of logging') + parser.add_argument( + '--work-dir', + help=('if specified, the results will be dumped ' + 'into the directory as json')) + parser.add_argument('--repeat-times', type=int, default=1) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + cfg = Config.fromfile(args.config) + + init_default_scope(cfg.get('default_scope', 'mmseg')) + + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + if args.work_dir is not None: + mkdir_or_exist(osp.abspath(args.work_dir)) + json_file = osp.join(args.work_dir, f'fps_{timestamp}.json') + else: + # use config filename as default work_dir if cfg.work_dir is None + work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + mkdir_or_exist(osp.abspath(work_dir)) + json_file = osp.join(work_dir, f'fps_{timestamp}.json') + + repeat_times = args.repeat_times + # set cudnn_benchmark + torch.backends.cudnn.benchmark = False + cfg.model.pretrained = None + + benchmark_dict = dict(config=args.config, unit='img / s') + overall_fps_list = [] + cfg.test_dataloader.batch_size = 1 + for time_index in range(repeat_times): + print(f'Run {time_index + 1}:') + # build the dataloader + data_loader = Runner.build_dataloader(cfg.test_dataloader) + + # build the model and load checkpoint + cfg.model.train_cfg = None + model = MODELS.build(cfg.model) + + if 'checkpoint' in args and osp.exists(args.checkpoint): + load_checkpoint(model, args.checkpoint, map_location='cpu') + + if torch.cuda.is_available(): + model = model.cuda() + + model = revert_sync_batchnorm(model) + + model.eval() + + # the first several iterations may be very slow so skip them + num_warmup = 5 + pure_inf_time = 0 + total_iters = 200 + + # benchmark with 200 batches and take the average + for i, data in enumerate(data_loader): + data = model.data_preprocessor(data, True) + inputs = data['inputs'] + data_samples = data['data_samples'] + if torch.cuda.is_available(): + torch.cuda.synchronize() + start_time = time.perf_counter() + + with torch.no_grad(): + model(inputs, data_samples, mode='predict') + + if torch.cuda.is_available(): + torch.cuda.synchronize() + elapsed = time.perf_counter() - start_time + + if i >= num_warmup: + pure_inf_time += elapsed + if (i + 1) % args.log_interval == 0: + fps = (i + 1 - num_warmup) / pure_inf_time + print(f'Done image [{i + 1:<3}/ {total_iters}], ' + f'fps: {fps:.2f} img / s') + + if (i + 1) == total_iters: + fps = (i + 1 - num_warmup) / pure_inf_time + print(f'Overall fps: {fps:.2f} img / s\n') + benchmark_dict[f'overall_fps_{time_index + 1}'] = round(fps, 2) + overall_fps_list.append(fps) + break + benchmark_dict['average_fps'] = round(np.mean(overall_fps_list), 2) + benchmark_dict['fps_variance'] = round(np.var(overall_fps_list), 4) + print(f'Average fps of {repeat_times} evaluations: ' + 
f'{benchmark_dict["average_fps"]}') + print(f'The variance of {repeat_times} evaluations: ' + f'{benchmark_dict["fps_variance"]}') + dump(benchmark_dict, json_file, indent=4) + + +if __name__ == '__main__': + main() diff --git a/tools/analysis_tools/browse_dataset.py b/tools/analysis_tools/browse_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..925c14a8ab63b4e38950b6c6af58e37dba002a4c --- /dev/null +++ b/tools/analysis_tools/browse_dataset.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.utils import ProgressBar + +from mmseg.registry import DATASETS, VISUALIZERS +from mmseg.utils import register_all_modules + + +def parse_args(): + parser = argparse.ArgumentParser(description='Browse a dataset') + parser.add_argument('config', help='train config file path') + parser.add_argument( + '--output-dir', + default=None, + type=str, + help='If there is no display interface, you can save it') + parser.add_argument('--not-show', default=False, action='store_true') + parser.add_argument( + '--show-interval', + type=float, + default=2, + help='the interval of show (s)') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # register all modules in mmdet into the registries + register_all_modules() + + dataset = DATASETS.build(cfg.train_dataloader.dataset) + visualizer = VISUALIZERS.build(cfg.visualizer) + visualizer.dataset_meta = dataset.metainfo + + progress_bar = ProgressBar(len(dataset)) + for item in dataset: + img = item['inputs'].permute(1, 2, 0).numpy() + img = img[..., [2, 1, 0]] # bgr to rgb + data_sample = item['data_samples'].numpy() + img_path = osp.basename(item['data_samples'].img_path) + + out_file = osp.join( + args.output_dir, + osp.basename(img_path)) if args.output_dir is not None else None + + visualizer.add_datasample( + name=osp.basename(img_path), + image=img, + data_sample=data_sample, + draw_gt=True, + draw_pred=False, + wait_time=args.show_interval, + out_file=out_file, + show=not args.not_show) + progress_bar.update() + + +if __name__ == '__main__': + main() diff --git a/tools/analysis_tools/confusion_matrix.py b/tools/analysis_tools/confusion_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..39756cdfdd2341e7e02f9de24077da880b6021c3 --- /dev/null +++ b/tools/analysis_tools/confusion_matrix.py @@ -0,0 +1,197 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
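The timing pattern in benchmark.py above generalizes: skip the first iterations (allocator warm-up and cudnn autotuning would bias the average), time only the forward pass, and report images per second. A minimal sketch, with run_model() as a stand-in for a real inference call rather than an API from this repo:

import time

def run_model():
    time.sleep(0.01)  # placeholder for one inference step

num_warmup, total_iters = 5, 50
pure_inf_time = 0.0
for i in range(total_iters):
    start = time.perf_counter()
    run_model()
    elapsed = time.perf_counter() - start
    # Only iterations after the warm-up phase contribute to the average.
    if i >= num_warmup:
        pure_inf_time += elapsed

fps = (total_iters - num_warmup) / pure_inf_time
print(f'Overall fps: {fps:.2f} img / s')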
+import argparse +import os + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.ticker import MultipleLocator +from mmengine.config import Config, DictAction +from mmengine.registry import init_default_scope +from mmengine.utils import mkdir_or_exist, progressbar +from PIL import Image + +from mmseg.registry import DATASETS + +init_default_scope('mmseg') + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Generate confusion matrix from segmentation results') + parser.add_argument('config', help='test config file path') + parser.add_argument( + 'prediction_path', help='directory that holds the per-image prediction results') + parser.add_argument( + 'save_dir', help='directory where confusion matrix will be saved') + parser.add_argument( + '--show', action='store_true', help='show confusion matrix') + parser.add_argument( + '--color-theme', + default='winter', + help='theme of the matrix color map') + parser.add_argument( + '--title', + default='Normalized Confusion Matrix', + help='title of the confusion matrix plot') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + args = parser.parse_args() + return args + + +def calculate_confusion_matrix(dataset, results): + """Calculate the confusion matrix. + + Args: + dataset (Dataset): Test or val dataset. + results (list[ndarray]): A list of per-image segmentation results. + """ + n = len(dataset.METAINFO['classes']) + confusion_matrix = np.zeros(shape=[n, n]) + assert len(dataset) == len(results) + ignore_index = dataset.ignore_index + reduce_zero_label = dataset.reduce_zero_label + prog_bar = progressbar.ProgressBar(len(results)) + for idx, per_img_res in enumerate(results): + res_segm = per_img_res + gt_segm = dataset[idx]['data_samples'] \ + .gt_sem_seg.data.squeeze().numpy().astype(np.uint8) + gt_segm, res_segm = gt_segm.flatten(), res_segm.flatten() + if reduce_zero_label: + gt_segm = gt_segm - 1 + to_ignore = gt_segm == ignore_index + + gt_segm, res_segm = gt_segm[~to_ignore], res_segm[~to_ignore] + inds = n * gt_segm + res_segm + mat = np.bincount(inds, minlength=n**2).reshape(n, n) + confusion_matrix += mat + prog_bar.update() + return confusion_matrix + + +def plot_confusion_matrix(confusion_matrix, + labels, + save_dir=None, + show=True, + title='Normalized Confusion Matrix', + color_theme='OrRd'): + """Draw confusion matrix with matplotlib. + + Args: + confusion_matrix (ndarray): The confusion matrix. + labels (list[str]): List of class names. + save_dir (str, optional): If set, save the confusion matrix plot to the + given path. Default: None. + show (bool): Whether to show the plot. Default: True. + title (str): Title of the plot. Default: `Normalized Confusion Matrix`. + color_theme (str): Theme of the matrix color map. Default: `OrRd`.
+ """ + # normalize the confusion matrix + per_label_sums = confusion_matrix.sum(axis=1)[:, np.newaxis] + confusion_matrix = \ + confusion_matrix.astype(np.float32) / per_label_sums * 100 + + num_classes = len(labels) + fig, ax = plt.subplots( + figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=300) + cmap = plt.get_cmap(color_theme) + im = ax.imshow(confusion_matrix, cmap=cmap) + colorbar = plt.colorbar(mappable=im, ax=ax) + colorbar.ax.tick_params(labelsize=20) # 设置 colorbar 标签的字体大小 + + title_font = {'weight': 'bold', 'size': 20} + ax.set_title(title, fontdict=title_font) + label_font = {'size': 40} + plt.ylabel('Ground Truth Label', fontdict=label_font) + plt.xlabel('Prediction Label', fontdict=label_font) + + # draw locator + xmajor_locator = MultipleLocator(1) + xminor_locator = MultipleLocator(0.5) + ax.xaxis.set_major_locator(xmajor_locator) + ax.xaxis.set_minor_locator(xminor_locator) + ymajor_locator = MultipleLocator(1) + yminor_locator = MultipleLocator(0.5) + ax.yaxis.set_major_locator(ymajor_locator) + ax.yaxis.set_minor_locator(yminor_locator) + + # draw grid + ax.grid(True, which='minor', linestyle='-') + + # draw label + ax.set_xticks(np.arange(num_classes)) + ax.set_yticks(np.arange(num_classes)) + ax.set_xticklabels(labels, fontsize=20) + ax.set_yticklabels(labels, fontsize=20) + + ax.tick_params( + axis='x', bottom=False, top=True, labelbottom=False, labeltop=True) + plt.setp( + ax.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor') + + # draw confusion matrix value + for i in range(num_classes): + for j in range(num_classes): + ax.text( + j, + i, + '{}%'.format( + round(confusion_matrix[i, j], 2 + ) if not np.isnan(confusion_matrix[i, j]) else -1), + ha='center', + va='center', + color='k', + size=20) + + ax.set_ylim(len(confusion_matrix) - 0.5, -0.5) # matplotlib>3.1.1 + + fig.tight_layout() + if save_dir is not None: + mkdir_or_exist(save_dir) + plt.savefig( + os.path.join(save_dir, 'confusion_matrix.png'), format='png') + if show: + plt.show() + + +def main(): + args = parse_args() + + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + results = [] + for img in sorted(os.listdir(args.prediction_path)): + img = os.path.join(args.prediction_path, img) + image = Image.open(img) + image = np.copy(image) + results.append(image) + + assert isinstance(results, list) + if isinstance(results[0], np.ndarray): + pass + else: + raise TypeError('invalid type of prediction results') + + dataset = DATASETS.build(cfg.test_dataloader.dataset) + confusion_matrix = calculate_confusion_matrix(dataset, results) + plot_confusion_matrix( + confusion_matrix, + dataset.METAINFO['classes'], + save_dir=args.save_dir, + show=args.show, + title=args.title, + color_theme=args.color_theme) + + +if __name__ == '__main__': + main() diff --git a/tools/analysis_tools/get_flops.py b/tools/analysis_tools/get_flops.py new file mode 100644 index 0000000000000000000000000000000000000000..66b2d52fcd2cb0f19066cfa4dfbfe13bc1e682e2 --- /dev/null +++ b/tools/analysis_tools/get_flops.py @@ -0,0 +1,124 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import argparse +import tempfile +from pathlib import Path + +import torch +from mmengine import Config, DictAction +from mmengine.logging import MMLogger +from mmengine.model import revert_sync_batchnorm +from mmengine.registry import init_default_scope + +from mmseg.models import BaseSegmentor +from mmseg.registry import MODELS +from mmseg.structures import SegDataSample + +try: + from mmengine.analysis import get_model_complexity_info + from mmengine.analysis.print_helper import _format_size +except ImportError: + raise ImportError('Please upgrade mmengine >= 0.6.0 to use this script.') + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Get the FLOPs of a segmentor') + parser.add_argument('config', help='train config file path') + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=[2048, 1024], + help='input image size') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + args = parser.parse_args() + return args + + +def inference(args: argparse.Namespace, logger: MMLogger) -> dict: + config_name = Path(args.config) + + if not config_name.exists(): + logger.error(f'Config file {config_name} does not exist') + + cfg: Config = Config.fromfile(config_name) + cfg.work_dir = tempfile.TemporaryDirectory().name + cfg.log_level = 'WARN' + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + init_default_scope(cfg.get('scope', 'mmseg')) + + if len(args.shape) == 1: + input_shape = (3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = (3, ) + tuple(args.shape) + else: + raise ValueError('invalid input shape') + result = {} + + model: BaseSegmentor = MODELS.build(cfg.model) + if hasattr(model, 'auxiliary_head'): + model.auxiliary_head = None + if torch.cuda.is_available(): + model.cuda() + model = revert_sync_batchnorm(model) + result['ori_shape'] = input_shape[-2:] + result['pad_shape'] = input_shape[-2:] + data_batch = { + 'inputs': [torch.rand(input_shape)], + 'data_samples': [SegDataSample(metainfo=result)] + } + data = model.data_preprocessor(data_batch) + model.eval() + if cfg.model.decode_head.type in ['MaskFormerHead', 'Mask2FormerHead']: + # TODO: Support MaskFormer and Mask2Former + raise NotImplementedError('MaskFormer and Mask2Former are not ' + 'supported yet.') + outputs = get_model_complexity_info( + model, + input_shape, + inputs=data['inputs'], + show_table=False, + show_arch=False) + result['flops'] = _format_size(outputs['flops']) + result['params'] = _format_size(outputs['params']) + result['compute_type'] = 'direct: randomly generate a picture' + return result + + +def main(): + + args = parse_args() + logger = MMLogger.get_instance(name='MMLogger') + + result = inference(args, logger) + split_line = '=' * 30 + ori_shape = result['ori_shape'] + pad_shape = result['pad_shape'] + flops = result['flops'] + params = result['params'] + compute_type = result['compute_type'] + + if pad_shape != ori_shape: + print(f'{split_line}\nUse size divisor set input shape ' + f'from {ori_shape} to {pad_shape}') + print(f'{split_line}\nCompute type: {compute_type}\n' + f'Input shape: {pad_shape}\nFlops: 
{flops}\n' + f'Params: {params}\n{split_line}') + print('!!!Please be cautious if you use the results in papers. ' + 'You may need to check if all ops are supported and verify ' + 'that the flops computation is correct.') + + +if __name__ == '__main__': + main() diff --git a/tools/analysis_tools/visualization_cam.py b/tools/analysis_tools/visualization_cam.py new file mode 100644 index 0000000000000000000000000000000000000000..00cdb3e04ab1f9000844ace781bc138f230d4630 --- /dev/null +++ b/tools/analysis_tools/visualization_cam.py @@ -0,0 +1,127 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Use the pytorch-grad-cam tool to visualize Class Activation Maps (CAM). + +requirement: pip install grad-cam +""" + +from argparse import ArgumentParser + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine import Config +from mmengine.model import revert_sync_batchnorm +from PIL import Image +from pytorch_grad_cam import GradCAM +from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image + +from mmseg.apis import inference_model, init_model, show_result_pyplot +from mmseg.utils import register_all_modules + + +class SemanticSegmentationTarget: + """wrap the model. + + requirement: pip install grad-cam + + Args: + category (int): Visualization class. + mask (ndarray): Mask of class. + size (tuple): Image size. + """ + + def __init__(self, category, mask, size): + self.category = category + self.mask = torch.from_numpy(mask) + self.size = size + if torch.cuda.is_available(): + self.mask = self.mask.cuda() + + def __call__(self, model_output): + model_output = torch.unsqueeze(model_output, dim=0) + model_output = F.interpolate( + model_output, size=self.size, mode='bilinear') + model_output = torch.squeeze(model_output, dim=0) + + return (model_output[self.category, :, :] * self.mask).sum() + + +def main(): + parser = ArgumentParser() + parser.add_argument('img', help='Image file') + parser.add_argument('config', help='Config file') + parser.add_argument('checkpoint', help='Checkpoint file') + parser.add_argument( + '--out-file', + default='prediction.png', + help='Path to output prediction file') + parser.add_argument( + '--cam-file', default='vis_cam.png', help='Path to output cam file') + parser.add_argument( + '--target-layers', + default='backbone.layer4[2]', + help='Target layers to visualize CAM') + parser.add_argument( + '--category-index', default='7', help='Category to visualize CAM') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + args = parser.parse_args() + + # build the model from a config file and a checkpoint file + register_all_modules() + model = init_model(args.config, args.checkpoint, device=args.device) + if args.device == 'cpu': + model = revert_sync_batchnorm(model) + + # test a single image + result = inference_model(model, args.img) + + # show the results + show_result_pyplot( + model, + args.img, + result, + draw_gt=False, + show=False if args.out_file is not None else True, + out_file=args.out_file) + + # result data conversion + prediction_data = result.pred_sem_seg.data + pre_np_data = prediction_data.cpu().numpy().squeeze(0) + + target_layers = args.target_layers + target_layers = [eval(f'model.{target_layers}')] + + category = int(args.category_index) + mask_float = np.float32(pre_np_data == category) + + # data processing + image = np.array(Image.open(args.img).convert('RGB')) + height, width = image.shape[0], image.shape[1] + rgb_img = np.float32(image) / 255 + config = 
Config.fromfile(args.config) + image_mean = config.data_preprocessor['mean'] + image_std = config.data_preprocessor['std'] + input_tensor = preprocess_image( + rgb_img, + mean=[x / 255 for x in image_mean], + std=[x / 255 for x in image_std]) + + # Grad CAM(Class Activation Maps) + # Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM + targets = [ + SemanticSegmentationTarget(category, mask_float, (height, width)) + ] + with GradCAM( + model=model, + target_layers=target_layers, + use_cuda=torch.cuda.is_available()) as cam: + grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :] + cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True) + + # save cam file + Image.fromarray(cam_image).save(args.cam_file) + + +if __name__ == '__main__': + main() diff --git a/tools/dataset_converters/chase_db1.py b/tools/dataset_converters/chase_db1.py new file mode 100644 index 0000000000000000000000000000000000000000..f4fefbd77435c5745d290269cd00f67fda604455 --- /dev/null +++ b/tools/dataset_converters/chase_db1.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp +import tempfile +import zipfile + +import mmcv +from mmengine.utils import mkdir_or_exist + +CHASE_DB1_LEN = 28 * 3 +TRAINING_LEN = 60 + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert CHASE_DB1 dataset to mmsegmentation format') + parser.add_argument('dataset_path', help='path of CHASEDB1.zip') + parser.add_argument('--tmp_dir', help='path of the temporary directory') + parser.add_argument('-o', '--out_dir', help='output path') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + dataset_path = args.dataset_path + if args.out_dir is None: + out_dir = osp.join('data', 'CHASE_DB1') + else: + out_dir = args.out_dir + + print('Making directories...') + mkdir_or_exist(out_dir) + mkdir_or_exist(osp.join(out_dir, 'images')) + mkdir_or_exist(osp.join(out_dir, 'images', 'training')) + mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) + mkdir_or_exist(osp.join(out_dir, 'annotations')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + print('Extracting CHASEDB1.zip...') + zip_file = zipfile.ZipFile(dataset_path) + zip_file.extractall(tmp_dir) + + print('Generating training dataset...') + + assert len(os.listdir(tmp_dir)) == CHASE_DB1_LEN, \ + f'len(os.listdir(tmp_dir)) != {CHASE_DB1_LEN}' + + for img_name in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: + img = mmcv.imread(osp.join(tmp_dir, img_name)) + if osp.splitext(img_name)[1] == '.jpg': + mmcv.imwrite( + img, + osp.join(out_dir, 'images', 'training', + osp.splitext(img_name)[0] + '.png')) + else: + # The annotation img should be divided by 128, because some of + # the annotation imgs are not standard. We should set a + # threshold to convert the nonstandard annotation imgs. 
The + # value divided by 128 is equivalent to '1 if value >= 128 + # else 0' + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'training', + osp.splitext(img_name)[0] + '.png')) + + for img_name in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: + img = mmcv.imread(osp.join(tmp_dir, img_name)) + if osp.splitext(img_name)[1] == '.jpg': + mmcv.imwrite( + img, + osp.join(out_dir, 'images', 'validation', + osp.splitext(img_name)[0] + '.png')) + else: + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'validation', + osp.splitext(img_name)[0] + '.png')) + + print('Removing the temporary files...') + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/tools/dataset_converters/cityscapes.py b/tools/dataset_converters/cityscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..0d6a80135d906db7330a736ccbcc908e0a6309c6 --- /dev/null +++ b/tools/dataset_converters/cityscapes.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp + +from cityscapesscripts.preparation.json2labelImg import json2labelImg +from mmengine.utils import (mkdir_or_exist, scandir, track_parallel_progress, + track_progress) + + +def convert_json_to_label(json_file): + label_file = json_file.replace('_polygons.json', '_labelTrainIds.png') + json2labelImg(json_file, label_file, 'trainIds') + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert Cityscapes annotations to TrainIds') + parser.add_argument('cityscapes_path', help='cityscapes data path') + parser.add_argument('--gt-dir', default='gtFine', type=str) + parser.add_argument('-o', '--out-dir', help='output path') + parser.add_argument( + '--nproc', default=1, type=int, help='number of process') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + cityscapes_path = args.cityscapes_path + out_dir = args.out_dir if args.out_dir else cityscapes_path + mkdir_or_exist(out_dir) + + gt_dir = osp.join(cityscapes_path, args.gt_dir) + + poly_files = [] + for poly in scandir(gt_dir, '_polygons.json', recursive=True): + poly_file = osp.join(gt_dir, poly) + poly_files.append(poly_file) + if args.nproc > 1: + track_parallel_progress(convert_json_to_label, poly_files, args.nproc) + else: + track_progress(convert_json_to_label, poly_files) + + split_names = ['train', 'val', 'test'] + + for split in split_names: + filenames = [] + for poly in scandir( + osp.join(gt_dir, split), '_polygons.json', recursive=True): + filenames.append(poly.replace('_gtFine_polygons.json', '')) + with open(osp.join(out_dir, f'{split}.txt'), 'w') as f: + f.writelines(f + '\n' for f in filenames) + + +if __name__ == '__main__': + main() diff --git a/tools/dataset_converters/coco_stuff10k.py b/tools/dataset_converters/coco_stuff10k.py new file mode 100644 index 0000000000000000000000000000000000000000..920127ee10fc09b76f8e2344ecdf3b7800d51802 --- /dev/null +++ b/tools/dataset_converters/coco_stuff10k.py @@ -0,0 +1,308 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import argparse +import os.path as osp +import shutil +from functools import partial + +import numpy as np +from mmengine.utils import (mkdir_or_exist, track_parallel_progress, + track_progress) +from PIL import Image +from scipy.io import loadmat + +COCO_LEN = 10000 + +clsID_to_trID = { + 0: 0, + 1: 1, + 2: 2, + 3: 3, + 4: 4, + 5: 5, + 6: 6, + 7: 7, + 8: 8, + 9: 9, + 10: 10, + 11: 11, + 13: 12, + 14: 13, + 15: 14, + 16: 15, + 17: 16, + 18: 17, + 19: 18, + 20: 19, + 21: 20, + 22: 21, + 23: 22, + 24: 23, + 25: 24, + 27: 25, + 28: 26, + 31: 27, + 32: 28, + 33: 29, + 34: 30, + 35: 31, + 36: 32, + 37: 33, + 38: 34, + 39: 35, + 40: 36, + 41: 37, + 42: 38, + 43: 39, + 44: 40, + 46: 41, + 47: 42, + 48: 43, + 49: 44, + 50: 45, + 51: 46, + 52: 47, + 53: 48, + 54: 49, + 55: 50, + 56: 51, + 57: 52, + 58: 53, + 59: 54, + 60: 55, + 61: 56, + 62: 57, + 63: 58, + 64: 59, + 65: 60, + 67: 61, + 70: 62, + 72: 63, + 73: 64, + 74: 65, + 75: 66, + 76: 67, + 77: 68, + 78: 69, + 79: 70, + 80: 71, + 81: 72, + 82: 73, + 84: 74, + 85: 75, + 86: 76, + 87: 77, + 88: 78, + 89: 79, + 90: 80, + 92: 81, + 93: 82, + 94: 83, + 95: 84, + 96: 85, + 97: 86, + 98: 87, + 99: 88, + 100: 89, + 101: 90, + 102: 91, + 103: 92, + 104: 93, + 105: 94, + 106: 95, + 107: 96, + 108: 97, + 109: 98, + 110: 99, + 111: 100, + 112: 101, + 113: 102, + 114: 103, + 115: 104, + 116: 105, + 117: 106, + 118: 107, + 119: 108, + 120: 109, + 121: 110, + 122: 111, + 123: 112, + 124: 113, + 125: 114, + 126: 115, + 127: 116, + 128: 117, + 129: 118, + 130: 119, + 131: 120, + 132: 121, + 133: 122, + 134: 123, + 135: 124, + 136: 125, + 137: 126, + 138: 127, + 139: 128, + 140: 129, + 141: 130, + 142: 131, + 143: 132, + 144: 133, + 145: 134, + 146: 135, + 147: 136, + 148: 137, + 149: 138, + 150: 139, + 151: 140, + 152: 141, + 153: 142, + 154: 143, + 155: 144, + 156: 145, + 157: 146, + 158: 147, + 159: 148, + 160: 149, + 161: 150, + 162: 151, + 163: 152, + 164: 153, + 165: 154, + 166: 155, + 167: 156, + 168: 157, + 169: 158, + 170: 159, + 171: 160, + 172: 161, + 173: 162, + 174: 163, + 175: 164, + 176: 165, + 177: 166, + 178: 167, + 179: 168, + 180: 169, + 181: 170, + 182: 171 +} + + +def convert_to_trainID(tuple_path, in_img_dir, in_ann_dir, out_img_dir, + out_mask_dir, is_train): + imgpath, maskpath = tuple_path + shutil.copyfile( + osp.join(in_img_dir, imgpath), + osp.join(out_img_dir, 'train2014', imgpath) if is_train else osp.join( + out_img_dir, 'test2014', imgpath)) + annotate = loadmat(osp.join(in_ann_dir, maskpath)) + mask = annotate['S'].astype(np.uint8) + mask_copy = mask.copy() + for clsID, trID in clsID_to_trID.items(): + mask_copy[mask == clsID] = trID + seg_filename = osp.join(out_mask_dir, 'train2014', + maskpath.split('.')[0] + + '_labelTrainIds.png') if is_train else osp.join( + out_mask_dir, 'test2014', + maskpath.split('.')[0] + '_labelTrainIds.png') + Image.fromarray(mask_copy).save(seg_filename, 'PNG') + + +def generate_coco_list(folder): + train_list = osp.join(folder, 'imageLists', 'train.txt') + test_list = osp.join(folder, 'imageLists', 'test.txt') + train_paths = [] + test_paths = [] + + with open(train_list) as f: + for filename in f: + basename = filename.strip() + imgpath = basename + '.jpg' + maskpath = basename + '.mat' + train_paths.append((imgpath, maskpath)) + + with open(test_list) as f: + for filename in f: + basename = filename.strip() + imgpath = basename + '.jpg' + maskpath = basename + '.mat' + test_paths.append((imgpath, maskpath)) + + return train_paths, test_paths + + +def parse_args(): + parser = 
argparse.ArgumentParser( + description=\ + 'Convert COCO Stuff 10k annotations to mmsegmentation format') # noqa + parser.add_argument('coco_path', help='coco stuff path') + parser.add_argument('-o', '--out_dir', help='output path') + parser.add_argument( + '--nproc', default=16, type=int, help='number of process') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + coco_path = args.coco_path + nproc = args.nproc + + out_dir = args.out_dir or coco_path + out_img_dir = osp.join(out_dir, 'images') + out_mask_dir = osp.join(out_dir, 'annotations') + + mkdir_or_exist(osp.join(out_img_dir, 'train2014')) + mkdir_or_exist(osp.join(out_img_dir, 'test2014')) + mkdir_or_exist(osp.join(out_mask_dir, 'train2014')) + mkdir_or_exist(osp.join(out_mask_dir, 'test2014')) + + train_list, test_list = generate_coco_list(coco_path) + assert (len(train_list) + + len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format( + len(train_list), len(test_list)) + + if args.nproc > 1: + track_parallel_progress( + partial( + convert_to_trainID, + in_img_dir=osp.join(coco_path, 'images'), + in_ann_dir=osp.join(coco_path, 'annotations'), + out_img_dir=out_img_dir, + out_mask_dir=out_mask_dir, + is_train=True), + train_list, + nproc=nproc) + track_parallel_progress( + partial( + convert_to_trainID, + in_img_dir=osp.join(coco_path, 'images'), + in_ann_dir=osp.join(coco_path, 'annotations'), + out_img_dir=out_img_dir, + out_mask_dir=out_mask_dir, + is_train=False), + test_list, + nproc=nproc) + else: + track_progress( + partial( + convert_to_trainID, + in_img_dir=osp.join(coco_path, 'images'), + in_ann_dir=osp.join(coco_path, 'annotations'), + out_img_dir=out_img_dir, + out_mask_dir=out_mask_dir, + is_train=True), train_list) + track_progress( + partial( + convert_to_trainID, + in_img_dir=osp.join(coco_path, 'images'), + in_ann_dir=osp.join(coco_path, 'annotations'), + out_img_dir=out_img_dir, + out_mask_dir=out_mask_dir, + is_train=False), test_list) + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/tools/dataset_converters/coco_stuff164k.py b/tools/dataset_converters/coco_stuff164k.py new file mode 100644 index 0000000000000000000000000000000000000000..a13114ab1e0c37675369b2e9ba065cbfb2dca1e7 --- /dev/null +++ b/tools/dataset_converters/coco_stuff164k.py @@ -0,0 +1,265 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
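convert_to_trainID() above rewrites the mask once per class via boolean indexing, i.e. one full pass over the pixels per class ID. Since the raw IDs are small non-negative integers, a single-pass lookup table gives the same result; a sketch with a toy three-entry mapping instead of the full COCO table (note the LUT sends IDs missing from the mapping to 0, while the loop leaves them untouched, so the table must cover every ID that occurs):

import numpy as np

clsID_to_trID = {0: 0, 7: 1, 12: 2}  # toy mapping, for illustration only
mask = np.array([[0, 7], [12, 7]], dtype=np.uint8)

# Loop-based remap, as in the converter above.
remapped_loop = mask.copy()
for clsID, trID in clsID_to_trID.items():
    remapped_loop[mask == clsID] = trID

# Lookup-table remap: one fancy-indexing pass over the mask.
lut = np.zeros(256, dtype=np.uint8)
for clsID, trID in clsID_to_trID.items():
    lut[clsID] = trID
remapped_lut = lut[mask]

assert np.array_equal(remapped_loop, remapped_lut)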
+import argparse +import os.path as osp +import shutil +from functools import partial +from glob import glob + +import numpy as np +from mmengine.utils import (mkdir_or_exist, track_parallel_progress, + track_progress) +from PIL import Image + +COCO_LEN = 123287 + +clsID_to_trID = { + 0: 0, + 1: 1, + 2: 2, + 3: 3, + 4: 4, + 5: 5, + 6: 6, + 7: 7, + 8: 8, + 9: 9, + 10: 10, + 12: 11, + 13: 12, + 14: 13, + 15: 14, + 16: 15, + 17: 16, + 18: 17, + 19: 18, + 20: 19, + 21: 20, + 22: 21, + 23: 22, + 24: 23, + 26: 24, + 27: 25, + 30: 26, + 31: 27, + 32: 28, + 33: 29, + 34: 30, + 35: 31, + 36: 32, + 37: 33, + 38: 34, + 39: 35, + 40: 36, + 41: 37, + 42: 38, + 43: 39, + 45: 40, + 46: 41, + 47: 42, + 48: 43, + 49: 44, + 50: 45, + 51: 46, + 52: 47, + 53: 48, + 54: 49, + 55: 50, + 56: 51, + 57: 52, + 58: 53, + 59: 54, + 60: 55, + 61: 56, + 62: 57, + 63: 58, + 64: 59, + 66: 60, + 69: 61, + 71: 62, + 72: 63, + 73: 64, + 74: 65, + 75: 66, + 76: 67, + 77: 68, + 78: 69, + 79: 70, + 80: 71, + 81: 72, + 83: 73, + 84: 74, + 85: 75, + 86: 76, + 87: 77, + 88: 78, + 89: 79, + 91: 80, + 92: 81, + 93: 82, + 94: 83, + 95: 84, + 96: 85, + 97: 86, + 98: 87, + 99: 88, + 100: 89, + 101: 90, + 102: 91, + 103: 92, + 104: 93, + 105: 94, + 106: 95, + 107: 96, + 108: 97, + 109: 98, + 110: 99, + 111: 100, + 112: 101, + 113: 102, + 114: 103, + 115: 104, + 116: 105, + 117: 106, + 118: 107, + 119: 108, + 120: 109, + 121: 110, + 122: 111, + 123: 112, + 124: 113, + 125: 114, + 126: 115, + 127: 116, + 128: 117, + 129: 118, + 130: 119, + 131: 120, + 132: 121, + 133: 122, + 134: 123, + 135: 124, + 136: 125, + 137: 126, + 138: 127, + 139: 128, + 140: 129, + 141: 130, + 142: 131, + 143: 132, + 144: 133, + 145: 134, + 146: 135, + 147: 136, + 148: 137, + 149: 138, + 150: 139, + 151: 140, + 152: 141, + 153: 142, + 154: 143, + 155: 144, + 156: 145, + 157: 146, + 158: 147, + 159: 148, + 160: 149, + 161: 150, + 162: 151, + 163: 152, + 164: 153, + 165: 154, + 166: 155, + 167: 156, + 168: 157, + 169: 158, + 170: 159, + 171: 160, + 172: 161, + 173: 162, + 174: 163, + 175: 164, + 176: 165, + 177: 166, + 178: 167, + 179: 168, + 180: 169, + 181: 170, + 255: 255 +} + + +def convert_to_trainID(maskpath, out_mask_dir, is_train): + mask = np.array(Image.open(maskpath)) + mask_copy = mask.copy() + for clsID, trID in clsID_to_trID.items(): + mask_copy[mask == clsID] = trID + seg_filename = osp.join( + out_mask_dir, 'train2017', + osp.basename(maskpath).split('.')[0] + + '_labelTrainIds.png') if is_train else osp.join( + out_mask_dir, 'val2017', + osp.basename(maskpath).split('.')[0] + '_labelTrainIds.png') + Image.fromarray(mask_copy).save(seg_filename, 'PNG') + + +def parse_args(): + parser = argparse.ArgumentParser( + description=\ + 'Convert COCO Stuff 164k annotations to mmsegmentation format') # noqa + parser.add_argument('coco_path', help='coco stuff path') + parser.add_argument('-o', '--out_dir', help='output path') + parser.add_argument( + '--nproc', default=16, type=int, help='number of process') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + coco_path = args.coco_path + nproc = args.nproc + + out_dir = args.out_dir or coco_path + out_img_dir = osp.join(out_dir, 'images') + out_mask_dir = osp.join(out_dir, 'annotations') + + mkdir_or_exist(osp.join(out_mask_dir, 'train2017')) + mkdir_or_exist(osp.join(out_mask_dir, 'val2017')) + + if out_dir != coco_path: + shutil.copytree(osp.join(coco_path, 'images'), out_img_dir) + + train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png')) + train_list = 
[file for file in train_list if '_labelTrainIds' not in file] + test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png')) + test_list = [file for file in test_list if '_labelTrainIds' not in file] + assert (len(train_list) + + len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format( + len(train_list), len(test_list)) + + if args.nproc > 1: + track_parallel_progress( + partial( + convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), + train_list, + nproc=nproc) + track_parallel_progress( + partial( + convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), + test_list, + nproc=nproc) + else: + track_progress( + partial( + convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), + train_list) + track_progress( + partial( + convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), + test_list) + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/tools/dataset_converters/drive.py b/tools/dataset_converters/drive.py new file mode 100644 index 0000000000000000000000000000000000000000..076fd05a2029216e0f1a1494610181fdaa7fbef9 --- /dev/null +++ b/tools/dataset_converters/drive.py @@ -0,0 +1,114 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp +import tempfile +import zipfile + +import cv2 +import mmcv +from mmengine.utils import mkdir_or_exist + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert DRIVE dataset to mmsegmentation format') + parser.add_argument( + 'training_path', help='the training part of DRIVE dataset') + parser.add_argument( + 'testing_path', help='the testing part of DRIVE dataset') + parser.add_argument('--tmp_dir', help='path of the temporary directory') + parser.add_argument('-o', '--out_dir', help='output path') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + training_path = args.training_path + testing_path = args.testing_path + if args.out_dir is None: + out_dir = osp.join('data', 'DRIVE') + else: + out_dir = args.out_dir + + print('Making directories...') + mkdir_or_exist(out_dir) + mkdir_or_exist(osp.join(out_dir, 'images')) + mkdir_or_exist(osp.join(out_dir, 'images', 'training')) + mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) + mkdir_or_exist(osp.join(out_dir, 'annotations')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + print('Extracting training.zip...') + zip_file = zipfile.ZipFile(training_path) + zip_file.extractall(tmp_dir) + + print('Generating training dataset...') + now_dir = osp.join(tmp_dir, 'training', 'images') + for img_name in os.listdir(now_dir): + img = mmcv.imread(osp.join(now_dir, img_name)) + mmcv.imwrite( + img, + osp.join( + out_dir, 'images', 'training', + osp.splitext(img_name)[0].replace('_training', '') + + '.png')) + + now_dir = osp.join(tmp_dir, 'training', '1st_manual') + for img_name in os.listdir(now_dir): + cap = cv2.VideoCapture(osp.join(now_dir, img_name)) + ret, img = cap.read() + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'training', + osp.splitext(img_name)[0] + '.png')) + + print('Extracting test.zip...') + zip_file = zipfile.ZipFile(testing_path) + zip_file.extractall(tmp_dir) + + print('Generating validation dataset...') + now_dir = osp.join(tmp_dir, 'test', 'images') + for img_name in os.listdir(now_dir): + img = mmcv.imread(osp.join(now_dir, 
img_name)) + mmcv.imwrite( + img, + osp.join( + out_dir, 'images', 'validation', + osp.splitext(img_name)[0].replace('_test', '') + '.png')) + + now_dir = osp.join(tmp_dir, 'test', '1st_manual') + if osp.exists(now_dir): + for img_name in os.listdir(now_dir): + cap = cv2.VideoCapture(osp.join(now_dir, img_name)) + ret, img = cap.read() + # The annotation img should be divided by 128, because some of + # the annotation imgs are not standard. We should set a + # threshold to convert the nonstandard annotation imgs. The + # value divided by 128 is equivalent to '1 if value >= 128 + # else 0' + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'validation', + osp.splitext(img_name)[0] + '.png')) + + now_dir = osp.join(tmp_dir, 'test', '2nd_manual') + if osp.exists(now_dir): + for img_name in os.listdir(now_dir): + cap = cv2.VideoCapture(osp.join(now_dir, img_name)) + ret, img = cap.read() + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'validation', + osp.splitext(img_name)[0] + '.png')) + + print('Removing the temporary files...') + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/tools/dataset_converters/hrf.py b/tools/dataset_converters/hrf.py new file mode 100644 index 0000000000000000000000000000000000000000..3bfd80c9ee42e3b5cba4a12a6c8b32ddbb2f1f11 --- /dev/null +++ b/tools/dataset_converters/hrf.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp +import tempfile +import zipfile + +import mmcv +from mmengine.utils import mkdir_or_exist + +HRF_LEN = 15 +TRAINING_LEN = 5 + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert HRF dataset to mmsegmentation format') + parser.add_argument('healthy_path', help='the path of healthy.zip') + parser.add_argument( + 'healthy_manualsegm_path', help='the path of healthy_manualsegm.zip') + parser.add_argument('glaucoma_path', help='the path of glaucoma.zip') + parser.add_argument( + 'glaucoma_manualsegm_path', help='the path of glaucoma_manualsegm.zip') + parser.add_argument( + 'diabetic_retinopathy_path', + help='the path of diabetic_retinopathy.zip') + parser.add_argument( + 'diabetic_retinopathy_manualsegm_path', + help='the path of diabetic_retinopathy_manualsegm.zip') + parser.add_argument('--tmp_dir', help='path of the temporary directory') + parser.add_argument('-o', '--out_dir', help='output path') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + images_path = [ + args.healthy_path, args.glaucoma_path, args.diabetic_retinopathy_path + ] + annotations_path = [ + args.healthy_manualsegm_path, args.glaucoma_manualsegm_path, + args.diabetic_retinopathy_manualsegm_path + ] + if args.out_dir is None: + out_dir = osp.join('data', 'HRF') + else: + out_dir = args.out_dir + + print('Making directories...') + mkdir_or_exist(out_dir) + mkdir_or_exist(osp.join(out_dir, 'images')) + mkdir_or_exist(osp.join(out_dir, 'images', 'training')) + mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) + mkdir_or_exist(osp.join(out_dir, 'annotations')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) + + print('Generating images...') + for now_path in images_path: + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + zip_file = zipfile.ZipFile(now_path) + zip_file.extractall(tmp_dir) + + assert len(os.listdir(tmp_dir)) == HRF_LEN, \ + f'len(os.listdir(tmp_dir)) != {HRF_LEN}' 
+ + for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: + img = mmcv.imread(osp.join(tmp_dir, filename)) + mmcv.imwrite( + img, + osp.join(out_dir, 'images', 'training', + osp.splitext(filename)[0] + '.png')) + for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: + img = mmcv.imread(osp.join(tmp_dir, filename)) + mmcv.imwrite( + img, + osp.join(out_dir, 'images', 'validation', + osp.splitext(filename)[0] + '.png')) + + print('Generating annotations...') + for now_path in annotations_path: + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + zip_file = zipfile.ZipFile(now_path) + zip_file.extractall(tmp_dir) + + assert len(os.listdir(tmp_dir)) == HRF_LEN, \ + f'len(os.listdir(tmp_dir)) != {HRF_LEN}' + + for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: + img = mmcv.imread(osp.join(tmp_dir, filename)) + # The annotation img should be divided by 128, because some of + # the annotation imgs are not standard. We should set a + # threshold to convert the nonstandard annotation imgs. The + # value divided by 128 is equivalent to '1 if value >= 128 + # else 0' + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'training', + osp.splitext(filename)[0] + '.png')) + for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: + img = mmcv.imread(osp.join(tmp_dir, filename)) + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'validation', + osp.splitext(filename)[0] + '.png')) + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/tools/dataset_converters/isaid.py b/tools/dataset_converters/isaid.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5ccd9c776e9621c261e6d168bf6aa4f7b451f6 --- /dev/null +++ b/tools/dataset_converters/isaid.py @@ -0,0 +1,246 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
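drive.py above reads the DRIVE .gif vessel masks through cv2.VideoCapture rather than cv2.imread, since OpenCV's image reader does not decode GIFs; the capture interface treats the GIF as a one-frame video. A minimal sketch, where 'mask.gif' is a hypothetical path and successful decoding depends on the codecs in your OpenCV build:

import cv2

cap = cv2.VideoCapture('mask.gif')  # hypothetical input path
ret, img = cap.read()  # ret is False if the frame could not be decoded
cap.release()
if ret:
    # One channel plus the same // 128 threshold used by the converters.
    binary = img[:, :, 0] // 128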
+import argparse +import glob +import os +import os.path as osp +import shutil +import tempfile +import zipfile + +import mmcv +import numpy as np +from mmengine.utils import ProgressBar, mkdir_or_exist +from PIL import Image + +iSAID_palette = \ + { + 0: (0, 0, 0), + 1: (0, 0, 63), + 2: (0, 63, 63), + 3: (0, 63, 0), + 4: (0, 63, 127), + 5: (0, 63, 191), + 6: (0, 63, 255), + 7: (0, 127, 63), + 8: (0, 127, 127), + 9: (0, 0, 127), + 10: (0, 0, 191), + 11: (0, 0, 255), + 12: (0, 191, 127), + 13: (0, 127, 191), + 14: (0, 127, 255), + 15: (0, 100, 155) + } + +iSAID_invert_palette = {v: k for k, v in iSAID_palette.items()} + + +def iSAID_convert_from_color(arr_3d, palette=iSAID_invert_palette): + """RGB-color encoding to grayscale labels.""" + arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8) + + for c, i in palette.items(): + m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2) + arr_2d[m] = i + + return arr_2d + + +def slide_crop_image(src_path, out_dir, mode, patch_H, patch_W, overlap): + img = np.asarray(Image.open(src_path).convert('RGB')) + + img_H, img_W, _ = img.shape + + if img_H < patch_H and img_W > patch_W: + + img = mmcv.impad(img, shape=(patch_H, img_W), pad_val=0) + + img_H, img_W, _ = img.shape + + elif img_H > patch_H and img_W < patch_W: + + img = mmcv.impad(img, shape=(img_H, patch_W), pad_val=0) + + img_H, img_W, _ = img.shape + + elif img_H < patch_H and img_W < patch_W: + + img = mmcv.impad(img, shape=(patch_H, patch_W), pad_val=0) + + img_H, img_W, _ = img.shape + + for x in range(0, img_W, patch_W - overlap): + for y in range(0, img_H, patch_H - overlap): + x_str = x + x_end = x + patch_W + if x_end > img_W: + diff_x = x_end - img_W + x_str -= diff_x + x_end = img_W + y_str = y + y_end = y + patch_H + if y_end > img_H: + diff_y = y_end - img_H + y_str -= diff_y + y_end = img_H + + img_patch = img[y_str:y_end, x_str:x_end, :] + img_patch = Image.fromarray(img_patch.astype(np.uint8)) + image = osp.basename(src_path).split('.')[0] + '_' + str( + y_str) + '_' + str(y_end) + '_' + str(x_str) + '_' + str( + x_end) + '.png' + # print(image) + save_path_image = osp.join(out_dir, 'img_dir', mode, str(image)) + img_patch.save(save_path_image, format='BMP') + + +def slide_crop_label(src_path, out_dir, mode, patch_H, patch_W, overlap): + label = mmcv.imread(src_path, channel_order='rgb') + label = iSAID_convert_from_color(label) + img_H, img_W = label.shape + + if img_H < patch_H and img_W > patch_W: + + label = mmcv.impad(label, shape=(patch_H, img_W), pad_val=255) + + img_H = patch_H + + elif img_H > patch_H and img_W < patch_W: + + label = mmcv.impad(label, shape=(img_H, patch_W), pad_val=255) + + img_W = patch_W + + elif img_H < patch_H and img_W < patch_W: + + label = mmcv.impad(label, shape=(patch_H, patch_W), pad_val=255) + + img_H = patch_H + img_W = patch_W + + for x in range(0, img_W, patch_W - overlap): + for y in range(0, img_H, patch_H - overlap): + x_str = x + x_end = x + patch_W + if x_end > img_W: + diff_x = x_end - img_W + x_str -= diff_x + x_end = img_W + y_str = y + y_end = y + patch_H + if y_end > img_H: + diff_y = y_end - img_H + y_str -= diff_y + y_end = img_H + + lab_patch = label[y_str:y_end, x_str:x_end] + lab_patch = Image.fromarray(lab_patch.astype(np.uint8), mode='P') + + image = osp.basename(src_path).split('.')[0].split( + '_')[0] + '_' + str(y_str) + '_' + str(y_end) + '_' + str( + x_str) + '_' + str(x_end) + '_instance_color_RGB' + '.png' + lab_patch.save(osp.join(out_dir, 'ann_dir', mode, str(image))) + + +def parse_args(): + 
parser = argparse.ArgumentParser( + description='Convert iSAID dataset to mmsegmentation format') + parser.add_argument('dataset_path', help='iSAID folder path') + parser.add_argument('--tmp_dir', help='path of the temporary directory') + parser.add_argument('-o', '--out_dir', help='output path') + + parser.add_argument( + '--patch_width', + default=896, + type=int, + help='Width of the cropped image patch') + parser.add_argument( + '--patch_height', + default=896, + type=int, + help='Height of the cropped image patch') + parser.add_argument( + '--overlap_area', default=384, type=int, help='Overlap area') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + dataset_path = args.dataset_path + # image patch width and height + patch_H, patch_W = args.patch_width, args.patch_height + + overlap = args.overlap_area # overlap area + + if args.out_dir is None: + out_dir = osp.join('data', 'iSAID') + else: + out_dir = args.out_dir + + print('Making directories...') + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train')) + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val')) + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test')) + + mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train')) + mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val')) + mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'test')) + + assert os.path.exists(os.path.join(dataset_path, 'train')), \ + f'train is not in {dataset_path}' + assert os.path.exists(os.path.join(dataset_path, 'val')), \ + f'val is not in {dataset_path}' + assert os.path.exists(os.path.join(dataset_path, 'test')), \ + f'test is not in {dataset_path}' + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + for dataset_mode in ['train', 'val', 'test']: + + # for dataset_mode in [ 'test']: + print(f'Extracting {dataset_mode}ing.zip...') + img_zipp_list = glob.glob( + os.path.join(dataset_path, dataset_mode, 'images', '*.zip')) + print('Find the data', img_zipp_list) + for img_zipp in img_zipp_list: + zip_file = zipfile.ZipFile(img_zipp) + zip_file.extractall(os.path.join(tmp_dir, dataset_mode, 'img')) + src_path_list = glob.glob( + os.path.join(tmp_dir, dataset_mode, 'img', 'images', '*.png')) + + src_prog_bar = ProgressBar(len(src_path_list)) + for i, img_path in enumerate(src_path_list): + if dataset_mode != 'test': + slide_crop_image(img_path, out_dir, dataset_mode, patch_H, + patch_W, overlap) + + else: + shutil.move(img_path, + os.path.join(out_dir, 'img_dir', dataset_mode)) + src_prog_bar.update() + + if dataset_mode != 'test': + label_zipp_list = glob.glob( + os.path.join(dataset_path, dataset_mode, 'Semantic_masks', + '*.zip')) + for label_zipp in label_zipp_list: + zip_file = zipfile.ZipFile(label_zipp) + zip_file.extractall( + os.path.join(tmp_dir, dataset_mode, 'lab')) + + lab_path_list = glob.glob( + os.path.join(tmp_dir, dataset_mode, 'lab', 'images', + '*.png')) + lab_prog_bar = ProgressBar(len(lab_path_list)) + for i, lab_path in enumerate(lab_path_list): + slide_crop_label(lab_path, out_dir, dataset_mode, patch_H, + patch_W, overlap) + lab_prog_bar.update() + + print('Removing the temporary files...') + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/tools/dataset_converters/levircd.py b/tools/dataset_converters/levircd.py new file mode 100644 index 0000000000000000000000000000000000000000..8717f3e856ba3f171b511f34d0217e1fda87ccb6 --- /dev/null +++ b/tools/dataset_converters/levircd.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
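slide_crop_image() and slide_crop_label() in isaid.py above step through the image with stride patch - overlap and, whenever a window would overrun the border, slide it back inside instead of padding. The coordinate logic in isolation, reduced to one axis:

def window_spans(length, patch, overlap):
    """1-D version of the snapping loop used by the iSAID converter."""
    stride = patch - overlap
    spans = []
    for x in range(0, length, stride):
        x_str, x_end = x, x + patch
        if x_end > length:  # snap the last window back inside the image
            x_str -= x_end - length
            x_end = length
        spans.append((x_str, x_end))
    return spans

print(window_spans(10, 4, 1))
# [(0, 4), (3, 7), (6, 10), (6, 10)]
# The duplicated final span mirrors the original loop, which can emit the
# same border patch twice when a start point falls within one patch of the edge.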
+import argparse
+import glob
+import math
+import os
+import os.path as osp
+
+import mmcv
+import numpy as np
+from mmengine.utils import ProgressBar
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Convert LEVIR-CD dataset to mmsegmentation format')
+    parser.add_argument('--dataset_path', help='LEVIR-CD folder path')
+    parser.add_argument('-o', '--out_dir', help='output path')
+    parser.add_argument(
+        '--clip_size',
+        type=int,
+        help='clipped size of image after preparation',
+        default=256)
+    parser.add_argument(
+        '--stride_size',
+        type=int,
+        help='stride of clipping original images',
+        default=256)
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    input_folder = args.dataset_path
+    png_files = glob.glob(
+        os.path.join(input_folder, '**/*.png'), recursive=True)
+    output_folder = args.out_dir
+    prog_bar = ProgressBar(len(png_files))
+    for png_file in png_files:
+        new_path = os.path.join(
+            output_folder,
+            os.path.relpath(os.path.dirname(png_file), input_folder))
+        os.makedirs(os.path.dirname(new_path), exist_ok=True)
+        label = False
+        if 'label' in png_file:
+            label = True
+        clip_big_image(png_file, new_path, args, label)
+        prog_bar.update()
+
+
+def clip_big_image(image_path, clip_save_dir, args, to_label=False):
+    image = mmcv.imread(image_path)
+
+    h, w, c = image.shape
+    clip_size = args.clip_size
+    stride_size = args.stride_size
+
+    num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
+        (h - clip_size) /
+        stride_size) * stride_size + clip_size >= h else math.ceil(
+            (h - clip_size) / stride_size) + 1
+    num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
+        (w - clip_size) /
+        stride_size) * stride_size + clip_size >= w else math.ceil(
+            (w - clip_size) / stride_size) + 1
+
+    x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
+    xmin = x * clip_size
+    ymin = y * clip_size
+
+    xmin = xmin.ravel()
+    ymin = ymin.ravel()
+    xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
+                           np.zeros_like(xmin))
+    ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
+                           np.zeros_like(ymin))
+    boxes = np.stack([
+        xmin + xmin_offset, ymin + ymin_offset,
+        np.minimum(xmin + clip_size, w),
+        np.minimum(ymin + clip_size, h)
+    ],
+                     axis=1)
+
+    if to_label:
+        image[image == 255] = 1
+        image = image[:, :, 0]
+    for box in boxes:
+        start_x, start_y, end_x, end_y = box
+        clipped_image = image[start_y:end_y, start_x:end_x] \
+            if to_label else image[start_y:end_y, start_x:end_x, :]
+        idx = osp.basename(image_path).split('.')[0]
+        mmcv.imwrite(
+            clipped_image.astype(np.uint8),
+            osp.join(clip_save_dir,
+                     f'{idx}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/dataset_converters/loveda.py b/tools/dataset_converters/loveda.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b0ef4bb8bbd07f60dfc0397e9659f0200b96f5d
--- /dev/null
+++ b/tools/dataset_converters/loveda.py
@@ -0,0 +1,73 @@
+# Copyright (c) OpenMMLab. All rights reserved.
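+# Expected input layout (inferred from the assertions in main() below):
+#   <dataset_path>/Train.zip, <dataset_path>/Val.zip, <dataset_path>/Test.zip
+# Usage sketch (illustrative; the data root is assumed):
+#   python tools/dataset_converters/loveda.py /data/loveDA -o data/loveDA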
+import argparse +import os +import os.path as osp +import shutil +import tempfile +import zipfile + +from mmengine.utils import mkdir_or_exist + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert LoveDA dataset to mmsegmentation format') + parser.add_argument('dataset_path', help='LoveDA folder path') + parser.add_argument('--tmp_dir', help='path of the temporary directory') + parser.add_argument('-o', '--out_dir', help='output path') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + dataset_path = args.dataset_path + if args.out_dir is None: + out_dir = osp.join('data', 'loveDA') + else: + out_dir = args.out_dir + + print('Making directories...') + mkdir_or_exist(out_dir) + mkdir_or_exist(osp.join(out_dir, 'img_dir')) + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train')) + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val')) + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test')) + mkdir_or_exist(osp.join(out_dir, 'ann_dir')) + mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train')) + mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val')) + + assert 'Train.zip' in os.listdir(dataset_path), \ + f'Train.zip is not in {dataset_path}' + assert 'Val.zip' in os.listdir(dataset_path), \ + f'Val.zip is not in {dataset_path}' + assert 'Test.zip' in os.listdir(dataset_path), \ + f'Test.zip is not in {dataset_path}' + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + for dataset in ['Train', 'Val', 'Test']: + zip_file = zipfile.ZipFile( + os.path.join(dataset_path, dataset + '.zip')) + zip_file.extractall(tmp_dir) + data_type = dataset.lower() + for location in ['Rural', 'Urban']: + for image_type in ['images_png', 'masks_png']: + if image_type == 'images_png': + dst = osp.join(out_dir, 'img_dir', data_type) + else: + dst = osp.join(out_dir, 'ann_dir', data_type) + if dataset == 'Test' and image_type == 'masks_png': + continue + else: + src_dir = osp.join(tmp_dir, dataset, location, + image_type) + src_lst = os.listdir(src_dir) + for file in src_lst: + shutil.move(osp.join(src_dir, file), dst) + print('Removing the temporary files...') + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/tools/dataset_converters/nyu.py b/tools/dataset_converters/nyu.py new file mode 100644 index 0000000000000000000000000000000000000000..49e09e7af6844b709e681f6d9f4df14ed547a00c --- /dev/null +++ b/tools/dataset_converters/nyu.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +import shutil +import tempfile +import zipfile + +from mmengine.utils import mkdir_or_exist + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert NYU Depth dataset to mmsegmentation format') + parser.add_argument('raw_data', help='the path of raw data') + parser.add_argument( + '-o', '--out_dir', help='output path', default='./data/nyu') + args = parser.parse_args() + return args + + +def reorganize(raw_data_dir: str, out_dir: str): + """Reorganize NYU Depth dataset files into the required directory + structure. + + Args: + raw_data_dir (str): Path to the raw data directory. + out_dir (str): Output directory for the organized dataset. + """ + + def move_data(data_list, dst_prefix, fname_func): + """Move data files from source to destination directory. + + Args: + data_list (list): List of data file paths. + dst_prefix (str): Prefix to be added to destination paths. 
+ fname_func (callable): Function to process file names + """ + for data_item in data_list: + data_item = data_item.strip().strip('/') + new_item = fname_func(data_item) + shutil.move( + osp.join(raw_data_dir, data_item), + osp.join(out_dir, dst_prefix, new_item)) + + def process_phase(phase): + """Process a dataset phase (e.g., 'train' or 'test').""" + with open(osp.join(raw_data_dir, f'nyu_{phase}.txt')) as f: + data = filter(lambda x: len(x.strip()) > 0, f.readlines()) + data = map(lambda x: x.split()[:2], data) + images, annos = zip(*data) + + move_data(images, f'images/{phase}', + lambda x: x.replace('/rgb', '')) + move_data(annos, f'annotations/{phase}', + lambda x: x.replace('/sync_depth', '')) + + process_phase('train') + process_phase('test') + + +def main(): + args = parse_args() + + print('Making directories...') + mkdir_or_exist(args.out_dir) + for subdir in [ + 'images/train', 'images/test', 'annotations/train', + 'annotations/test' + ]: + mkdir_or_exist(osp.join(args.out_dir, subdir)) + + print('Generating images and annotations...') + + if args.raw_data.endswith('.zip'): + with tempfile.TemporaryDirectory() as tmp_dir: + zip_file = zipfile.ZipFile(args.raw_data) + zip_file.extractall(tmp_dir) + reorganize(osp.join(tmp_dir, 'nyu'), args.out_dir) + else: + assert osp.isdir( + args.raw_data + ), 'the argument --raw-data should be either a zip file or directory.' + reorganize(args.raw_data, args.out_dir) + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/tools/dataset_converters/pascal_context.py b/tools/dataset_converters/pascal_context.py new file mode 100644 index 0000000000000000000000000000000000000000..a92d1dc6411137b92fe67fbde0fc554060194085 --- /dev/null +++ b/tools/dataset_converters/pascal_context.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
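+# Illustrative sketch (not part of the original script) of how
+# _class_to_index below remaps raw Detail label ids to contiguous train ids
+# via np.digitize over the sorted _mapping array:
+#
+#   import numpy as np
+#   mapping = np.sort(np.array([0, 2, 9, 18]))     # toy subset of _mapping
+#   key = np.arange(len(mapping)).astype('uint8')
+#   mask = np.array([2, 18, 0])
+#   index = np.digitize(mask, mapping, right=True) # exact matches only
+#   print(key[index])                              # [1 3 0]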
+import argparse +import os.path as osp +from functools import partial + +import numpy as np +from detail import Detail +from mmengine.utils import mkdir_or_exist, track_progress +from PIL import Image + +_mapping = np.sort( + np.array([ + 0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284, + 158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59, + 440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355, + 85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115 + ])) +_key = np.array(range(len(_mapping))).astype('uint8') + + +def generate_labels(img_id, detail, out_dir): + + def _class_to_index(mask, _mapping, _key): + # assert the values + values = np.unique(mask) + for i in range(len(values)): + assert (values[i] in _mapping) + index = np.digitize(mask.ravel(), _mapping, right=True) + return _key[index].reshape(mask.shape) + + mask = Image.fromarray( + _class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key)) + filename = img_id['file_name'] + mask.save(osp.join(out_dir, filename.replace('jpg', 'png'))) + return osp.splitext(osp.basename(filename))[0] + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert PASCAL VOC annotations to mmsegmentation format') + parser.add_argument('devkit_path', help='pascal voc devkit path') + parser.add_argument('json_path', help='annoation json filepath') + parser.add_argument('-o', '--out_dir', help='output path') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + devkit_path = args.devkit_path + if args.out_dir is None: + out_dir = osp.join(devkit_path, 'VOC2010', 'SegmentationClassContext') + else: + out_dir = args.out_dir + json_path = args.json_path + mkdir_or_exist(out_dir) + img_dir = osp.join(devkit_path, 'VOC2010', 'JPEGImages') + + train_detail = Detail(json_path, img_dir, 'train') + train_ids = train_detail.getImgs() + + val_detail = Detail(json_path, img_dir, 'val') + val_ids = val_detail.getImgs() + + mkdir_or_exist( + osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext')) + + train_list = track_progress( + partial(generate_labels, detail=train_detail, out_dir=out_dir), + train_ids) + with open( + osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext', + 'train.txt'), 'w') as f: + f.writelines(line + '\n' for line in sorted(train_list)) + + val_list = track_progress( + partial(generate_labels, detail=val_detail, out_dir=out_dir), val_ids) + with open( + osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext', + 'val.txt'), 'w') as f: + f.writelines(line + '\n' for line in sorted(val_list)) + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/tools/dataset_converters/potsdam.py b/tools/dataset_converters/potsdam.py new file mode 100644 index 0000000000000000000000000000000000000000..f3c713ee2a08d2f6eaf68fb225899504b8f4e829 --- /dev/null +++ b/tools/dataset_converters/potsdam.py @@ -0,0 +1,158 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
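+# Illustrative sketch (not part of the original script) of the color
+# hashing used by clip_big_image below: each pixel color is reduced to a
+# scalar with a [2, 3, 4] dot product, which happens to be distinct for
+# every entry of the seven-color palette:
+#
+#   import numpy as np
+#   color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0],
+#                         [255, 255, 0], [0, 255, 0], [0, 255, 255],
+#                         [0, 0, 255]])
+#   hashes = color_map @ np.array([2, 3, 4])
+#   # -> 0, 2295, 510, 1275, 765, 1785, 1020 (all distinct)
+#   assert len(set(hashes.tolist())) == len(color_map)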
+import argparse +import glob +import math +import os +import os.path as osp +import tempfile +import zipfile + +import mmcv +import numpy as np +from mmengine.utils import ProgressBar, mkdir_or_exist + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert potsdam dataset to mmsegmentation format') + parser.add_argument('dataset_path', help='potsdam folder path') + parser.add_argument('--tmp_dir', help='path of the temporary directory') + parser.add_argument('-o', '--out_dir', help='output path') + parser.add_argument( + '--clip_size', + type=int, + help='clipped size of image after preparation', + default=512) + parser.add_argument( + '--stride_size', + type=int, + help='stride of clipping original images', + default=256) + args = parser.parse_args() + return args + + +def clip_big_image(image_path, clip_save_dir, args, to_label=False): + # Original image of Potsdam dataset is very large, thus pre-processing + # of them is adopted. Given fixed clip size and stride size to generate + # clipped image, the intersection of width and height is determined. + # For example, given one 5120 x 5120 original image, the clip size is + # 512 and stride size is 256, thus it would generate 20x20 = 400 images + # whose size are all 512x512. + image = mmcv.imread(image_path) + + h, w, c = image.shape + clip_size = args.clip_size + stride_size = args.stride_size + + num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil( + (h - clip_size) / + stride_size) * stride_size + clip_size >= h else math.ceil( + (h - clip_size) / stride_size) + 1 + num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil( + (w - clip_size) / + stride_size) * stride_size + clip_size >= w else math.ceil( + (w - clip_size) / stride_size) + 1 + + x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1)) + xmin = x * clip_size + ymin = y * clip_size + + xmin = xmin.ravel() + ymin = ymin.ravel() + xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size, + np.zeros_like(xmin)) + ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size, + np.zeros_like(ymin)) + boxes = np.stack([ + xmin + xmin_offset, ymin + ymin_offset, + np.minimum(xmin + clip_size, w), + np.minimum(ymin + clip_size, h) + ], + axis=1) + + if to_label: + color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0], + [255, 255, 0], [0, 255, 0], [0, 255, 255], + [0, 0, 255]]) + flatten_v = np.matmul( + image.reshape(-1, c), + np.array([2, 3, 4]).reshape(3, 1)) + out = np.zeros_like(flatten_v) + for idx, class_color in enumerate(color_map): + value_idx = np.matmul(class_color, + np.array([2, 3, 4]).reshape(3, 1)) + out[flatten_v == value_idx] = idx + image = out.reshape(h, w) + + for box in boxes: + start_x, start_y, end_x, end_y = box + clipped_image = image[start_y:end_y, + start_x:end_x] if to_label else image[ + start_y:end_y, start_x:end_x, :] + idx_i, idx_j = osp.basename(image_path).split('_')[2:4] + mmcv.imwrite( + clipped_image.astype(np.uint8), + osp.join( + clip_save_dir, + f'{idx_i}_{idx_j}_{start_x}_{start_y}_{end_x}_{end_y}.png')) + + +def main(): + args = parse_args() + splits = { + 'train': [ + '2_10', '2_11', '2_12', '3_10', '3_11', '3_12', '4_10', '4_11', + '4_12', '5_10', '5_11', '5_12', '6_10', '6_11', '6_12', '6_7', + '6_8', '6_9', '7_10', '7_11', '7_12', '7_7', '7_8', '7_9' + ], + 'val': [ + '5_15', '6_15', '6_13', '3_13', '4_14', '6_14', '5_14', '2_13', + '4_15', '2_14', '5_13', '4_13', '3_14', '7_13' + ] + } + + dataset_path = args.dataset_path + if args.out_dir is None: + 
out_dir = osp.join('data', 'potsdam')
+    else:
+        out_dir = args.out_dir
+
+    print('Making directories...')
+    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
+    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
+    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
+    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
+
+    zipp_list = glob.glob(os.path.join(dataset_path, '*.zip'))
+    print('Find the data', zipp_list)
+
+    for zipp in zipp_list:
+        with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
+            zip_file = zipfile.ZipFile(zipp)
+            zip_file.extractall(tmp_dir)
+            src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif'))
+            if not len(src_path_list):
+                sub_tmp_dir = os.path.join(tmp_dir, os.listdir(tmp_dir)[0])
+                src_path_list = glob.glob(os.path.join(sub_tmp_dir, '*.tif'))
+
+            prog_bar = ProgressBar(len(src_path_list))
+            for i, src_path in enumerate(src_path_list):
+                idx_i, idx_j = osp.basename(src_path).split('_')[2:4]
+                data_type = 'train' if f'{idx_i}_{idx_j}' in splits[
+                    'train'] else 'val'
+                if 'label' in src_path:
+                    dst_dir = osp.join(out_dir, 'ann_dir', data_type)
+                    clip_big_image(src_path, dst_dir, args, to_label=True)
+                else:
+                    dst_dir = osp.join(out_dir, 'img_dir', data_type)
+                    clip_big_image(src_path, dst_dir, args, to_label=False)
+                prog_bar.update()
+
+    print('Removing the temporary files...')
+
+    print('Done!')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/dataset_converters/refuge.py b/tools/dataset_converters/refuge.py
new file mode 100644
index 0000000000000000000000000000000000000000..1186866ab3fd58c4d72e5f573938053a8d7c80b2
--- /dev/null
+++ b/tools/dataset_converters/refuge.py
@@ -0,0 +1,110 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os
+import os.path as osp
+import tempfile
+import zipfile
+
+import mmcv
+import numpy as np
+from mmengine.utils import mkdir_or_exist
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Convert REFUGE dataset to mmsegmentation format')
+    parser.add_argument('--raw_data_root', help='the root path of raw data')
+
+    parser.add_argument('--tmp_dir', help='path of the temporary directory')
+    parser.add_argument('-o', '--out_dir', help='output path')
+    args = parser.parse_args()
+    return args
+
+
+def extract_img(root: str,
+                cur_dir: str,
+                out_dir: str,
+                mode: str = 'train',
+                file_type: str = 'img') -> None:
+    """Extract a REFUGE zip file and convert its contents to the output
+    layout.
+
+    Args:
+        root (str): root where the extracted data is saved
+        cur_dir (str): dir where the zip_file exists
+        out_dir (str): root dir where the data is saved
+
+        mode (str, optional): Defaults to 'train'.
+        file_type (str, optional): Defaults to 'img', else to 'mask'.
+ """ + zip_file = zipfile.ZipFile(cur_dir) + zip_file.extractall(root) + for cur_dir, dirs, files in os.walk(root): + # filter child dirs and directories with "Illustration" and "MACOSX" + if len(dirs) == 0 and \ + cur_dir.split('\\')[-1].find('Illustration') == -1 and \ + cur_dir.find('MACOSX') == -1: + + file_names = [ + file for file in files + if file.endswith('.jpg') or file.endswith('.bmp') + ] + for filename in sorted(file_names): + img = mmcv.imread(osp.join(cur_dir, filename)) + + if file_type == 'annotations': + img = img[:, :, 0] + img[np.where(img == 0)] = 1 + img[np.where(img == 128)] = 2 + img[np.where(img == 255)] = 0 + mmcv.imwrite( + img, + osp.join(out_dir, file_type, mode, + osp.splitext(filename)[0] + '.png')) + + +def main(): + args = parse_args() + + raw_data_root = args.raw_data_root + if args.out_dir is None: + out_dir = osp.join('./data', 'REFUGE') + + else: + out_dir = args.out_dir + + print('Making directories...') + mkdir_or_exist(out_dir) + mkdir_or_exist(osp.join(out_dir, 'images')) + mkdir_or_exist(osp.join(out_dir, 'images', 'training')) + mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) + mkdir_or_exist(osp.join(out_dir, 'images', 'test')) + mkdir_or_exist(osp.join(out_dir, 'annotations')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'test')) + + print('Generating images and annotations...') + # process data from the child dir on the first rank + cur_dir, dirs, files = list(os.walk(raw_data_root))[0] + print('====================') + + files = list(filter(lambda x: x.endswith('.zip'), files)) + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + for file in files: + # search data folders for training,validation,test + mode = list( + filter(lambda x: file.lower().find(x) != -1, + ['training', 'test', 'validation']))[0] + file_root = osp.join(tmp_dir, file[:-4]) + file_type = 'images' if file.find('Anno') == -1 and file.find( + 'GT') == -1 else 'annotations' + extract_img(file_root, osp.join(cur_dir, file), out_dir, mode, + file_type) + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/tools/dataset_converters/stare.py b/tools/dataset_converters/stare.py new file mode 100644 index 0000000000000000000000000000000000000000..4a23ba4dd8a4744bca9d1a506c79131c0e42c73d --- /dev/null +++ b/tools/dataset_converters/stare.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import argparse +import gzip +import os +import os.path as osp +import tarfile +import tempfile + +import mmcv +from mmengine.utils import mkdir_or_exist + +STARE_LEN = 20 +TRAINING_LEN = 10 + + +def un_gz(src, dst): + g_file = gzip.GzipFile(src) + with open(dst, 'wb+') as f: + f.write(g_file.read()) + g_file.close() + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert STARE dataset to mmsegmentation format') + parser.add_argument('image_path', help='the path of stare-images.tar') + parser.add_argument('labels_ah', help='the path of labels-ah.tar') + parser.add_argument('labels_vk', help='the path of labels-vk.tar') + parser.add_argument('--tmp_dir', help='path of the temporary directory') + parser.add_argument('-o', '--out_dir', help='output path') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + image_path = args.image_path + labels_ah = args.labels_ah + labels_vk = args.labels_vk + if args.out_dir is None: + out_dir = osp.join('data', 'STARE') + else: + out_dir = args.out_dir + + print('Making directories...') + mkdir_or_exist(out_dir) + mkdir_or_exist(osp.join(out_dir, 'images')) + mkdir_or_exist(osp.join(out_dir, 'images', 'training')) + mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) + mkdir_or_exist(osp.join(out_dir, 'annotations')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + mkdir_or_exist(osp.join(tmp_dir, 'gz')) + mkdir_or_exist(osp.join(tmp_dir, 'files')) + + print('Extracting stare-images.tar...') + with tarfile.open(image_path) as f: + f.extractall(osp.join(tmp_dir, 'gz')) + + for filename in os.listdir(osp.join(tmp_dir, 'gz')): + un_gz( + osp.join(tmp_dir, 'gz', filename), + osp.join(tmp_dir, 'files', + osp.splitext(filename)[0])) + + now_dir = osp.join(tmp_dir, 'files') + + assert len(os.listdir(now_dir)) == STARE_LEN, \ + f'len(os.listdir(now_dir)) != {STARE_LEN}' + + for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: + img = mmcv.imread(osp.join(now_dir, filename)) + mmcv.imwrite( + img, + osp.join(out_dir, 'images', 'training', + osp.splitext(filename)[0] + '.png')) + + for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]: + img = mmcv.imread(osp.join(now_dir, filename)) + mmcv.imwrite( + img, + osp.join(out_dir, 'images', 'validation', + osp.splitext(filename)[0] + '.png')) + + print('Removing the temporary files...') + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + mkdir_or_exist(osp.join(tmp_dir, 'gz')) + mkdir_or_exist(osp.join(tmp_dir, 'files')) + + print('Extracting labels-ah.tar...') + with tarfile.open(labels_ah) as f: + f.extractall(osp.join(tmp_dir, 'gz')) + + for filename in os.listdir(osp.join(tmp_dir, 'gz')): + un_gz( + osp.join(tmp_dir, 'gz', filename), + osp.join(tmp_dir, 'files', + osp.splitext(filename)[0])) + + now_dir = osp.join(tmp_dir, 'files') + + assert len(os.listdir(now_dir)) == STARE_LEN, \ + f'len(os.listdir(now_dir)) != {STARE_LEN}' + + for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: + img = mmcv.imread(osp.join(now_dir, filename)) + # The annotation img should be divided by 128, because some of + # the annotation imgs are not standard. We should set a threshold + # to convert the nonstandard annotation imgs. 
The value divided by + # 128 equivalent to '1 if value >= 128 else 0' + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'training', + osp.splitext(filename)[0] + '.png')) + + for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]: + img = mmcv.imread(osp.join(now_dir, filename)) + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'validation', + osp.splitext(filename)[0] + '.png')) + + print('Removing the temporary files...') + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + mkdir_or_exist(osp.join(tmp_dir, 'gz')) + mkdir_or_exist(osp.join(tmp_dir, 'files')) + + print('Extracting labels-vk.tar...') + with tarfile.open(labels_vk) as f: + f.extractall(osp.join(tmp_dir, 'gz')) + + for filename in os.listdir(osp.join(tmp_dir, 'gz')): + un_gz( + osp.join(tmp_dir, 'gz', filename), + osp.join(tmp_dir, 'files', + osp.splitext(filename)[0])) + + now_dir = osp.join(tmp_dir, 'files') + + assert len(os.listdir(now_dir)) == STARE_LEN, \ + f'len(os.listdir(now_dir)) != {STARE_LEN}' + + for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: + img = mmcv.imread(osp.join(now_dir, filename)) + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'training', + osp.splitext(filename)[0] + '.png')) + + for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]: + img = mmcv.imread(osp.join(now_dir, filename)) + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'validation', + osp.splitext(filename)[0] + '.png')) + + print('Removing the temporary files...') + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/tools/dataset_converters/synapse.py b/tools/dataset_converters/synapse.py new file mode 100644 index 0000000000000000000000000000000000000000..42dac6b7eff94107b8b3a59984622cb1fd2e7599 --- /dev/null +++ b/tools/dataset_converters/synapse.py @@ -0,0 +1,155 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp + +import nibabel as nib +import numpy as np +from mmengine.utils import mkdir_or_exist +from PIL import Image + + +def read_files_from_txt(txt_path): + with open(txt_path) as f: + files = f.readlines() + files = [file.strip() for file in files] + return files + + +def read_nii_file(nii_path): + img = nib.load(nii_path).get_fdata() + return img + + +def split_3d_image(img): + c, _, _ = img.shape + res = [] + for i in range(c): + res.append(img[i, :, :]) + return res + + +def label_mapping(label): + """Label mapping from TransUNet paper setting. It only has 9 classes, which + are 'background', 'aorta', 'gallbladder', 'left_kidney', 'right_kidney', + 'liver', 'pancreas', 'spleen', 'stomach', respectively. Other foreground + classes in original dataset are all set to background. 
+
+    More details can be found here: https://arxiv.org/abs/2102.04306
+    """
+    mapped_label = np.zeros_like(label)
+    mapped_label[label == 8] = 1
+    mapped_label[label == 4] = 2
+    mapped_label[label == 3] = 3
+    mapped_label[label == 2] = 4
+    mapped_label[label == 6] = 5
+    mapped_label[label == 11] = 6
+    mapped_label[label == 1] = 7
+    mapped_label[label == 7] = 8
+    return mapped_label
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Convert synapse dataset to mmsegmentation format')
+    parser.add_argument(
+        '--dataset-path', type=str, help='synapse dataset path.')
+    parser.add_argument(
+        '--save-path',
+        default='data/synapse',
+        type=str,
+        help='save path of the dataset.')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    dataset_path = args.dataset_path
+    save_path = args.save_path
+
+    if not osp.exists(dataset_path):
+        raise ValueError('The dataset path does not exist. '
+                         'Please enter a correct dataset path.')
+    if not osp.exists(osp.join(dataset_path, 'img')) \
+            or not osp.exists(osp.join(dataset_path, 'label')):
+        raise FileNotFoundError('The dataset structure is incorrect. '
+                                'Please check your dataset.')
+
+    train_id = read_files_from_txt(osp.join(dataset_path, 'train.txt'))
+    train_id = [idx[3:7] for idx in train_id]
+
+    test_id = read_files_from_txt(osp.join(dataset_path, 'val.txt'))
+    test_id = [idx[3:7] for idx in test_id]
+
+    mkdir_or_exist(osp.join(save_path, 'img_dir/train'))
+    mkdir_or_exist(osp.join(save_path, 'img_dir/val'))
+    mkdir_or_exist(osp.join(save_path, 'ann_dir/train'))
+    mkdir_or_exist(osp.join(save_path, 'ann_dir/val'))
+
+    # It follows the data preparation pipeline from:
+    # https://github.com/Beckschen/TransUNet/tree/main/datasets
+    for i, idx in enumerate(train_id):
+        img_3d = read_nii_file(
+            osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz'))
+        label_3d = read_nii_file(
+            osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz'))
+
+        img_3d = np.clip(img_3d, -125, 275)
+        img_3d = (img_3d + 125) / 400
+        img_3d *= 255
+        img_3d = np.transpose(img_3d, [2, 0, 1])
+        img_3d = np.flip(img_3d, 2)
+
+        label_3d = np.transpose(label_3d, [2, 0, 1])
+        label_3d = np.flip(label_3d, 2)
+        label_3d = label_mapping(label_3d)
+
+        for c in range(img_3d.shape[0]):
+            img = img_3d[c]
+            label = label_3d[c]
+
+            img = Image.fromarray(img).convert('RGB')
+            label = Image.fromarray(label).convert('L')
+            img.save(
+                osp.join(
+                    save_path, 'img_dir/train', 'case' + idx.zfill(4) +
+                    '_slice' + str(c).zfill(3) + '.jpg'))
+            label.save(
+                osp.join(
+                    save_path, 'ann_dir/train', 'case' + idx.zfill(4) +
+                    '_slice' + str(c).zfill(3) + '.png'))
+
+    for i, idx in enumerate(test_id):
+        img_3d = read_nii_file(
+            osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz'))
+        label_3d = read_nii_file(
+            osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz'))
+
+        img_3d = np.clip(img_3d, -125, 275)
+        img_3d = (img_3d + 125) / 400
+        img_3d *= 255
+        img_3d = np.transpose(img_3d, [2, 0, 1])
+        img_3d = np.flip(img_3d, 2)
+
+        label_3d = np.transpose(label_3d, [2, 0, 1])
+        label_3d = np.flip(label_3d, 2)
+        label_3d = label_mapping(label_3d)
+
+        for c in range(img_3d.shape[0]):
+            img = img_3d[c]
+            label = label_3d[c]
+
+            img = Image.fromarray(img).convert('RGB')
+            label = Image.fromarray(label).convert('L')
+            img.save(
+                osp.join(
+                    save_path, 'img_dir/val', 'case' + idx.zfill(4) +
+                    '_slice' + str(c).zfill(3) + '.jpg'))
+            label.save(
+                osp.join(
+                    save_path, 'ann_dir/val', 'case' + idx.zfill(4) +
+                    '_slice' + str(c).zfill(3) + '.png'))
+
+
+if __name__
== '__main__': + main() diff --git a/tools/dataset_converters/vaihingen.py b/tools/dataset_converters/vaihingen.py new file mode 100644 index 0000000000000000000000000000000000000000..db980144eb491846a844b0a374bb7a01d5509265 --- /dev/null +++ b/tools/dataset_converters/vaihingen.py @@ -0,0 +1,156 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import glob +import math +import os +import os.path as osp +import tempfile +import zipfile + +import mmcv +import numpy as np +from mmengine.utils import ProgressBar, mkdir_or_exist + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert vaihingen dataset to mmsegmentation format') + parser.add_argument('dataset_path', help='vaihingen folder path') + parser.add_argument('--tmp_dir', help='path of the temporary directory') + parser.add_argument('-o', '--out_dir', help='output path') + parser.add_argument( + '--clip_size', + type=int, + help='clipped size of image after preparation', + default=512) + parser.add_argument( + '--stride_size', + type=int, + help='stride of clipping original images', + default=256) + args = parser.parse_args() + return args + + +def clip_big_image(image_path, clip_save_dir, to_label=False): + # Original image of Vaihingen dataset is very large, thus pre-processing + # of them is adopted. Given fixed clip size and stride size to generate + # clipped image, the intersection of width and height is determined. + # For example, given one 5120 x 5120 original image, the clip size is + # 512 and stride size is 256, thus it would generate 20x20 = 400 images + # whose size are all 512x512. + image = mmcv.imread(image_path) + + h, w, c = image.shape + cs = args.clip_size + ss = args.stride_size + + num_rows = math.ceil((h - cs) / ss) if math.ceil( + (h - cs) / ss) * ss + cs >= h else math.ceil((h - cs) / ss) + 1 + num_cols = math.ceil((w - cs) / ss) if math.ceil( + (w - cs) / ss) * ss + cs >= w else math.ceil((w - cs) / ss) + 1 + + x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1)) + xmin = x * cs + ymin = y * cs + + xmin = xmin.ravel() + ymin = ymin.ravel() + xmin_offset = np.where(xmin + cs > w, w - xmin - cs, np.zeros_like(xmin)) + ymin_offset = np.where(ymin + cs > h, h - ymin - cs, np.zeros_like(ymin)) + boxes = np.stack([ + xmin + xmin_offset, ymin + ymin_offset, + np.minimum(xmin + cs, w), + np.minimum(ymin + cs, h) + ], + axis=1) + + if to_label: + color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0], + [255, 255, 0], [0, 255, 0], [0, 255, 255], + [0, 0, 255]]) + flatten_v = np.matmul( + image.reshape(-1, c), + np.array([2, 3, 4]).reshape(3, 1)) + out = np.zeros_like(flatten_v) + for idx, class_color in enumerate(color_map): + value_idx = np.matmul(class_color, + np.array([2, 3, 4]).reshape(3, 1)) + out[flatten_v == value_idx] = idx + image = out.reshape(h, w) + + for box in boxes: + start_x, start_y, end_x, end_y = box + clipped_image = image[start_y:end_y, + start_x:end_x] if to_label else image[ + start_y:end_y, start_x:end_x, :] + area_idx = osp.basename(image_path).split('_')[3].strip('.tif') + mmcv.imwrite( + clipped_image.astype(np.uint8), + osp.join(clip_save_dir, + f'{area_idx}_{start_x}_{start_y}_{end_x}_{end_y}.png')) + + +def main(): + splits = { + 'train': [ + 'area1', 'area11', 'area13', 'area15', 'area17', 'area21', + 'area23', 'area26', 'area28', 'area3', 'area30', 'area32', + 'area34', 'area37', 'area5', 'area7' + ], + 'val': [ + 'area6', 'area24', 'area35', 'area16', 'area14', 'area22', + 'area10', 'area4', 'area2', 'area20', 
'area8', 'area31', 'area33', + 'area27', 'area38', 'area12', 'area29' + ], + } + + dataset_path = args.dataset_path + if args.out_dir is None: + out_dir = osp.join('data', 'vaihingen') + else: + out_dir = args.out_dir + + print('Making directories...') + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train')) + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val')) + mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train')) + mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val')) + + zipp_list = glob.glob(os.path.join(dataset_path, '*.zip')) + print('Find the data', zipp_list) + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + for zipp in zipp_list: + zip_file = zipfile.ZipFile(zipp) + zip_file.extractall(tmp_dir) + src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif')) + if 'ISPRS_semantic_labeling_Vaihingen' in zipp: + src_path_list = glob.glob( + os.path.join(os.path.join(tmp_dir, 'top'), '*.tif')) + if 'ISPRS_semantic_labeling_Vaihingen_ground_truth_eroded_COMPLETE' in zipp: # noqa + src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif')) + # delete unused area9 ground truth + for area_ann in src_path_list: + if 'area9' in area_ann: + src_path_list.remove(area_ann) + prog_bar = ProgressBar(len(src_path_list)) + for i, src_path in enumerate(src_path_list): + area_idx = osp.basename(src_path).split('_')[3].strip('.tif') + data_type = 'train' if area_idx in splits['train'] else 'val' + if 'noBoundary' in src_path: + dst_dir = osp.join(out_dir, 'ann_dir', data_type) + clip_big_image(src_path, dst_dir, to_label=True) + else: + dst_dir = osp.join(out_dir, 'img_dir', data_type) + clip_big_image(src_path, dst_dir, to_label=False) + prog_bar.update() + + print('Removing the temporary files...') + + print('Done!') + + +if __name__ == '__main__': + args = parse_args() + main() diff --git a/tools/dataset_converters/voc_aug.py b/tools/dataset_converters/voc_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..a536f4290d06e4a6c3c9fa8dbadfda847fec583b --- /dev/null +++ b/tools/dataset_converters/voc_aug.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
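+# Illustrative sketch (not part of the original script) of the split
+# arithmetic in main() below: trainaug is the union of the original VOC
+# train list and the SBD train/val lists, minus the VOC val list, which is
+# expected to leave AUG_LEN = 10582 ids:
+#
+#   ori_train = {'a', 'b'}
+#   full_aug = {'b', 'c'}
+#   val = {'c'}
+#   print(set(ori_train | full_aug) - val)    # {'a', 'b'}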
+import argparse +import os.path as osp +from functools import partial + +import numpy as np +from mmengine.utils import mkdir_or_exist, scandir, track_parallel_progress +from PIL import Image +from scipy.io import loadmat + +AUG_LEN = 10582 + + +def convert_mat(mat_file, in_dir, out_dir): + data = loadmat(osp.join(in_dir, mat_file)) + mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8) + seg_filename = osp.join(out_dir, mat_file.replace('.mat', '.png')) + Image.fromarray(mask).save(seg_filename, 'PNG') + + +def generate_aug_list(merged_list, excluded_list): + return list(set(merged_list) - set(excluded_list)) + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert PASCAL VOC annotations to mmsegmentation format') + parser.add_argument('devkit_path', help='pascal voc devkit path') + parser.add_argument('aug_path', help='pascal voc aug path') + parser.add_argument('-o', '--out_dir', help='output path') + parser.add_argument( + '--nproc', default=1, type=int, help='number of process') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + devkit_path = args.devkit_path + aug_path = args.aug_path + nproc = args.nproc + if args.out_dir is None: + out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug') + else: + out_dir = args.out_dir + mkdir_or_exist(out_dir) + in_dir = osp.join(aug_path, 'dataset', 'cls') + + track_parallel_progress( + partial(convert_mat, in_dir=in_dir, out_dir=out_dir), + list(scandir(in_dir, suffix='.mat')), + nproc=nproc) + + full_aug_list = [] + with open(osp.join(aug_path, 'dataset', 'train.txt')) as f: + full_aug_list += [line.strip() for line in f] + with open(osp.join(aug_path, 'dataset', 'val.txt')) as f: + full_aug_list += [line.strip() for line in f] + + with open( + osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', + 'train.txt')) as f: + ori_train_list = [line.strip() for line in f] + with open( + osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', + 'val.txt')) as f: + val_list = [line.strip() for line in f] + + aug_train_list = generate_aug_list(ori_train_list + full_aug_list, + val_list) + assert len(aug_train_list) == AUG_LEN, 'len(aug_train_list) != {}'.format( + AUG_LEN) + + with open( + osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', + 'trainaug.txt'), 'w') as f: + f.writelines(line + '\n' for line in aug_train_list) + + aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list) + assert len(aug_list) == AUG_LEN - len( + ori_train_list), 'len(aug_list) != {}'.format(AUG_LEN - + len(ori_train_list)) + with open( + osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 'aug.txt'), + 'w') as f: + f.writelines(line + '\n' for line in aug_list) + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/tools/deployment/pytorch2torchscript.py b/tools/deployment/pytorch2torchscript.py new file mode 100644 index 0000000000000000000000000000000000000000..e69e705bb13ff3cca233534c34fcdaaeda02825b --- /dev/null +++ b/tools/deployment/pytorch2torchscript.py @@ -0,0 +1,185 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
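+# Illustrative note (not part of the original script): digit_version below
+# maps a version string to a comparable list, with release candidates
+# sorting below their final release:
+#
+#   digit_version('1.8.0')       # [1, 8, 0]
+#   digit_version('1.8.0rc1')    # [1, 8, -1, 1]
+#   assert digit_version('1.8.0rc1') < digit_version('1.8.0')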
+import argparse + +import numpy as np +import torch +import torch._C +import torch.serialization +from mmengine import Config +from mmengine.runner import load_checkpoint +from torch import nn + +from mmseg.models import build_segmentor + +torch.manual_seed(3) + + +def digit_version(version_str): + digit_version = [] + for x in version_str.split('.'): + if x.isdigit(): + digit_version.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + digit_version.append(int(patch_version[0]) - 1) + digit_version.append(int(patch_version[1])) + return digit_version + + +def check_torch_version(): + torch_minimum_version = '1.8.0' + torch_version = digit_version(torch.__version__) + + assert (torch_version >= digit_version(torch_minimum_version)), \ + f'Torch=={torch.__version__} is not support for converting to ' \ + f'torchscript. Please install pytorch>={torch_minimum_version}.' + + +def _convert_batchnorm(module): + module_output = module + if isinstance(module, torch.nn.SyncBatchNorm): + module_output = torch.nn.BatchNorm2d(module.num_features, module.eps, + module.momentum, module.affine, + module.track_running_stats) + if module.affine: + module_output.weight.data = module.weight.data.clone().detach() + module_output.bias.data = module.bias.data.clone().detach() + # keep requires_grad unchanged + module_output.weight.requires_grad = module.weight.requires_grad + module_output.bias.requires_grad = module.bias.requires_grad + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + for name, child in module.named_children(): + module_output.add_module(name, _convert_batchnorm(child)) + del module + return module_output + + +def _demo_mm_inputs(input_shape, num_classes): + """Create a superset of inputs needed to run test or train batches. + + Args: + input_shape (tuple): + input batch dimensions + num_classes (int): + number of semantic classes + """ + (N, C, H, W) = input_shape + rng = np.random.RandomState(0) + imgs = rng.rand(*input_shape) + segs = rng.randint( + low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8) + img_metas = [{ + 'img_shape': (H, W, C), + 'ori_shape': (H, W, C), + 'pad_shape': (H, W, C), + 'filename': '.png', + 'scale_factor': 1.0, + 'flip': False, + } for _ in range(N)] + mm_inputs = { + 'imgs': torch.FloatTensor(imgs).requires_grad_(True), + 'img_metas': img_metas, + 'gt_semantic_seg': torch.LongTensor(segs) + } + return mm_inputs + + +def pytorch2libtorch(model, + input_shape, + show=False, + output_file='tmp.pt', + verify=False): + """Export Pytorch model to TorchScript model and verify the outputs are + same between Pytorch and TorchScript. + + Args: + model (nn.Module): Pytorch model we want to export. + input_shape (tuple): Use this input shape to construct + the corresponding dummy input and execute the model. + show (bool): Whether print the computation graph. Default: False. + output_file (string): The path to where we store the + output TorchScript model. Default: `tmp.pt`. + verify (bool): Whether compare the outputs between + Pytorch and TorchScript. Default: False. 
+ """ + if isinstance(model.decode_head, nn.ModuleList): + num_classes = model.decode_head[-1].num_classes + else: + num_classes = model.decode_head.num_classes + + mm_inputs = _demo_mm_inputs(input_shape, num_classes) + + imgs = mm_inputs.pop('imgs') + + # replace the original forword with forward_dummy + model.forward = model.forward_dummy + model.eval() + traced_model = torch.jit.trace( + model, + example_inputs=imgs, + check_trace=verify, + ) + + if show: + print(traced_model.graph) + + traced_model.save(output_file) + print(f'Successfully exported TorchScript model: {output_file}') + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert MMSeg to TorchScript') + parser.add_argument('config', help='test config file path') + parser.add_argument('--checkpoint', help='checkpoint file', default=None) + parser.add_argument( + '--show', action='store_true', help='show TorchScript graph') + parser.add_argument( + '--verify', action='store_true', help='verify the TorchScript model') + parser.add_argument('--output-file', type=str, default='tmp.pt') + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=[512, 512], + help='input image size (height, width)') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + check_torch_version() + + if len(args.shape) == 1: + input_shape = (1, 3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = ( + 1, + 3, + ) + tuple(args.shape) + else: + raise ValueError('invalid input shape') + + cfg = Config.fromfile(args.config) + cfg.model.pretrained = None + + # build the model and load checkpoint + cfg.model.train_cfg = None + segmentor = build_segmentor( + cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg')) + # convert SyncBN to BN + segmentor = _convert_batchnorm(segmentor) + + if args.checkpoint: + load_checkpoint(segmentor, args.checkpoint, map_location='cpu') + + # convert the PyTorch model to LibTorch model + pytorch2libtorch( + segmentor, + input_shape, + show=args.show, + output_file=args.output_file, + verify=args.verify) diff --git a/tools/dist_test.sh b/tools/dist_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..89711fd5c02cfc1f0386e5354506d4b74ecac251 --- /dev/null +++ b/tools/dist_test.sh @@ -0,0 +1,20 @@ +CONFIG=$1 +CHECKPOINT=$2 +GPUS=$3 +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python -m torch.distributed.launch \ + --nnodes=$NNODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --nproc_per_node=$GPUS \ + --master_port=$PORT \ + $(dirname "$0")/test.py \ + $CONFIG \ + $CHECKPOINT \ + --launcher pytorch \ + ${@:4} diff --git a/tools/dist_train.sh b/tools/dist_train.sh new file mode 100644 index 0000000000000000000000000000000000000000..a857df78788edb8841b6f67d74dd0e6cfb77d8ab --- /dev/null +++ b/tools/dist_train.sh @@ -0,0 +1,17 @@ +CONFIG=$1 +GPUS=$2 +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python -m torch.distributed.launch \ + --nnodes=$NNODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --nproc_per_node=$GPUS \ + --master_port=$PORT \ + $(dirname "$0")/train.py \ + $CONFIG \ + --launcher pytorch ${@:3} diff --git a/tools/misc/browse_dataset.py b/tools/misc/browse_dataset.py new file mode 100644 index 
0000000000000000000000000000000000000000..7863eb74f2cab53d025afad347f7886a5ce29919 --- /dev/null +++ b/tools/misc/browse_dataset.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp + +from mmengine import Config, DictAction +from mmengine.registry import init_default_scope +from mmengine.utils import ProgressBar + +from mmseg.registry import DATASETS, VISUALIZERS + + +def parse_args(): + parser = argparse.ArgumentParser(description='Browse a dataset') + parser.add_argument('config', help='train config file path') + parser.add_argument( + '--output-dir', + default=None, + type=str, + help='If there is no display interface, you can save it') + parser.add_argument('--not-show', default=False, action='store_true') + parser.add_argument( + '--show-interval', + type=float, + default=2, + help='the interval of show (s)') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # register all modules in mmseg into the registries + init_default_scope('mmseg') + + dataset = DATASETS.build(cfg.train_dataloader.dataset) + cfg.visualizer['save_dir'] = args.output_dir + visualizer = VISUALIZERS.build(cfg.visualizer) + visualizer.dataset_meta = dataset.METAINFO + + progress_bar = ProgressBar(len(dataset)) + for item in dataset: + img = item['inputs'].permute(1, 2, 0).numpy() + data_sample = item['data_samples'].numpy() + img_path = osp.basename(item['data_samples'].img_path) + + img = img[..., [2, 1, 0]] # bgr to rgb + + visualizer.add_datasample( + osp.basename(img_path), + img, + data_sample, + show=not args.not_show, + wait_time=args.show_interval) + + progress_bar.update() + + +if __name__ == '__main__': + main() diff --git a/tools/misc/print_config.py b/tools/misc/print_config.py new file mode 100644 index 0000000000000000000000000000000000000000..2a1c024a6a44157a0b0d4d6213d18d67f57a33c5 --- /dev/null +++ b/tools/misc/print_config.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import warnings + +from mmengine import Config, DictAction + +from mmseg.apis import init_model + + +def parse_args(): + parser = argparse.ArgumentParser(description='Print the whole config') + parser.add_argument('config', help='config file path') + parser.add_argument( + '--graph', action='store_true', help='print the models graph') + parser.add_argument( + '--options', + nargs='+', + action=DictAction, + help="--options is deprecated in favor of --cfg_options' and it will " + 'not be supported in version v0.22.0. Override some settings in the ' + 'used config, the key-value pair in xxx=yyy format will be merged ' + 'into config file. If the value to be overwritten is a list, it ' + 'should be like key="[a,b]" or key=a,b It also allows nested ' + 'list/tuple values, e.g. 
key="[(a,b),(c,d)]" Note that the quotation ' + 'marks are necessary and that no white space is allowed.') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + args = parser.parse_args() + + if args.options and args.cfg_options: + raise ValueError( + '--options and --cfg-options cannot be both ' + 'specified, --options is deprecated in favor of --cfg-options. ' + '--options will not be supported in version v0.22.0.') + if args.options: + warnings.warn('--options is deprecated in favor of --cfg-options, ' + '--options will not be supported in version v0.22.0.') + args.cfg_options = args.options + + return args + + +def main(): + args = parse_args() + + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + print(f'Config:\n{cfg.pretty_text}') + # dump config + cfg.dump('example.py') + # dump models graph + if args.graph: + model = init_model(args.config, device='cpu') + print(f'Model graph:\n{str(model)}') + with open('example-graph.txt', 'w') as f: + f.writelines(str(model)) + + +if __name__ == '__main__': + main() diff --git a/tools/misc/publish_model.py b/tools/misc/publish_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e035ad90e85e0e03d8304c1d5b524c5ac322c644 --- /dev/null +++ b/tools/misc/publish_model.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import subprocess +from hashlib import sha256 + +import torch + +BLOCK_SIZE = 128 * 1024 + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Process a checkpoint to be published') + parser.add_argument('in_file', help='input checkpoint filename') + parser.add_argument('out_file', help='output checkpoint filename') + args = parser.parse_args() + return args + + +def sha256sum(filename: str) -> str: + """Compute SHA256 message digest from a file.""" + hash_func = sha256() + byte_array = bytearray(BLOCK_SIZE) + memory_view = memoryview(byte_array) + with open(filename, 'rb', buffering=0) as file: + for block in iter(lambda: file.readinto(memory_view), 0): + hash_func.update(memory_view[:block]) + return hash_func.hexdigest() + + +def process_checkpoint(in_file, out_file): + checkpoint = torch.load(in_file, map_location='cpu') + # remove optimizer for smaller file size + if 'optimizer' in checkpoint: + del checkpoint['optimizer'] + # if it is necessary to remove some sensitive data in checkpoint['meta'], + # add the code here. + torch.save(checkpoint, out_file) + sha = sha256sum(in_file) + final_file = out_file.rstrip('.pth') + f'-{sha[:8]}.pth' + subprocess.Popen(['mv', out_file, final_file]) + + +def main(): + args = parse_args() + process_checkpoint(args.in_file, args.out_file) + + +if __name__ == '__main__': + main() diff --git a/tools/model_converters/beit2mmseg.py b/tools/model_converters/beit2mmseg.py new file mode 100644 index 0000000000000000000000000000000000000000..20f8f0f4509f93291782ca152bf04ab019b0e0ff --- /dev/null +++ b/tools/model_converters/beit2mmseg.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import argparse
+import os.path as osp
+from collections import OrderedDict
+
+import mmengine
+import torch
+from mmengine.runner import CheckpointLoader
+
+
+def convert_beit(ckpt):
+    new_ckpt = OrderedDict()
+
+    for k, v in ckpt.items():
+        if k.startswith('patch_embed'):
+            new_key = k.replace('patch_embed.proj', 'patch_embed.projection')
+            new_ckpt[new_key] = v
+        # elif (rather than a second if) keeps 'patch_embed' keys from also
+        # being copied under their original names by the fallback branch
+        elif k.startswith('blocks'):
+            new_key = k.replace('blocks', 'layers')
+            if 'norm' in new_key:
+                new_key = new_key.replace('norm', 'ln')
+            elif 'mlp.fc1' in new_key:
+                new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0')
+            elif 'mlp.fc2' in new_key:
+                new_key = new_key.replace('mlp.fc2', 'ffn.layers.1')
+            new_ckpt[new_key] = v
+        else:
+            new_key = k
+            new_ckpt[new_key] = v
+
+    return new_ckpt
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Convert keys in official pretrained beit models to '
+        'MMSegmentation style.')
+    parser.add_argument('src', help='src model path or url')
+    # The dst path must be a full path of the new checkpoint.
+    parser.add_argument('dst', help='save path')
+    args = parser.parse_args()
+
+    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    elif 'model' in checkpoint:
+        state_dict = checkpoint['model']
+    else:
+        state_dict = checkpoint
+    weight = convert_beit(state_dict)
+    mmengine.mkdir_or_exist(osp.dirname(args.dst))
+    torch.save(weight, args.dst)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/model_converters/clip2mmseg.py b/tools/model_converters/clip2mmseg.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a97e4b04ab45740ee37149d30a85b67245868f5
--- /dev/null
+++ b/tools/model_converters/clip2mmseg.py
@@ -0,0 +1,163 @@
+# Copyright (c) OpenMMLab. All rights reserved.
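+# Illustrative note (not part of the original script): convert_key_name
+# below splits the CLIP visual transformer at visual_split (9 for ViT-B/16,
+# 18 for ViT-L/14). Blocks 0..visual_split-1 are renamed into the
+# image_encoder, while later blocks move to decode_head.rec_with_attnbias
+# with their index shifted down, e.g. with visual_split = 9 visual block 10
+# becomes decode head layer 10 - 9 = 1.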
+import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_vitlayer(paras): + new_para_name = '' + if paras[0] == 'ln_1': + new_para_name = '.'.join(['ln1'] + paras[1:]) + elif paras[0] == 'attn': + new_para_name = '.'.join(['attn.attn'] + paras[1:]) + elif paras[0] == 'ln_2': + new_para_name = '.'.join(['ln2'] + paras[1:]) + elif paras[0] == 'mlp': + if paras[1] == 'c_fc': + new_para_name = '.'.join(['ffn.layers.0.0'] + paras[-1:]) + else: + new_para_name = '.'.join(['ffn.layers.1'] + paras[-1:]) + else: + print(f'Wrong for {paras}') + return new_para_name + + +def convert_translayer(paras): + new_para_name = '' + if paras[0] == 'attn': + new_para_name = '.'.join(['attentions.0.attn'] + paras[1:]) + elif paras[0] == 'ln_1': + new_para_name = '.'.join(['norms.0'] + paras[1:]) + elif paras[0] == 'ln_2': + new_para_name = '.'.join(['norms.1'] + paras[1:]) + elif paras[0] == 'mlp': + if paras[1] == 'c_fc': + new_para_name = '.'.join(['ffns.0.layers.0.0'] + paras[2:]) + elif paras[1] == 'c_proj': + new_para_name = '.'.join(['ffns.0.layers.1'] + paras[2:]) + else: + print(f'Wrong for {paras}') + else: + print(f'Wrong for {paras}') + return new_para_name + + +def convert_key_name(ckpt, visual_split): + new_ckpt = OrderedDict() + for k, v in ckpt.items(): + key_list = k.split('.') + if key_list[0] == 'visual': + new_transform_name = 'image_encoder' + if key_list[1] == 'class_embedding': + new_name = '.'.join([new_transform_name, 'cls_token']) + elif key_list[1] == 'positional_embedding': + new_name = '.'.join([new_transform_name, 'pos_embed']) + elif key_list[1] == 'conv1': + new_name = '.'.join([ + new_transform_name, 'patch_embed.projection', key_list[2] + ]) + elif key_list[1] == 'ln_pre': + new_name = '.'.join( + [new_transform_name, key_list[1], key_list[2]]) + elif key_list[1] == 'transformer': + new_layer_name = 'layers' + layer_index = key_list[3] + paras = key_list[4:] + if int(layer_index) < visual_split: + new_para_name = convert_vitlayer(paras) + new_name = '.'.join([ + new_transform_name, new_layer_name, layer_index, + new_para_name + ]) + else: + new_para_name = convert_translayer(paras) + new_transform_name = 'decode_head.rec_with_attnbias' + new_layer_name = 'layers' + layer_index = str(int(layer_index) - visual_split) + new_name = '.'.join([ + new_transform_name, new_layer_name, layer_index, + new_para_name + ]) + elif key_list[1] == 'proj': + new_name = 'decode_head.rec_with_attnbias.proj.weight' + elif key_list[1] == 'ln_post': + new_name = k.replace('visual', 'decode_head.rec_with_attnbias') + else: + print(f'pop parameter: {k}') + continue + else: + text_encoder_name = 'text_encoder' + if key_list[0] == 'transformer': + layer_name = 'transformer' + layer_index = key_list[2] + paras = key_list[3:] + new_para_name = convert_translayer(paras) + new_name = '.'.join([ + text_encoder_name, layer_name, layer_index, new_para_name + ]) + elif key_list[0] in [ + 'positional_embedding', 'text_projection', 'bg_embed', + 'attn_mask', 'logit_scale', 'token_embedding', 'ln_final' + ]: + new_name = 'text_encoder.' 
+ k
+            else:
+                print(f'pop parameter: {k}')
+                continue
+        new_ckpt[new_name] = v
+
+    return new_ckpt
+
+
+def convert_tensor(ckpt):
+    cls_token = ckpt['image_encoder.cls_token']
+    new_cls_token = cls_token.unsqueeze(0).unsqueeze(0)
+    ckpt['image_encoder.cls_token'] = new_cls_token
+    pos_embed = ckpt['image_encoder.pos_embed']
+    new_pos_embed = pos_embed.unsqueeze(0)
+    ckpt['image_encoder.pos_embed'] = new_pos_embed
+    proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight']
+    new_proj_weight = proj_weight.transpose(1, 0)
+    ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight
+    return ckpt
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Convert keys in pretrained CLIP models to '
+        'MMSegmentation style.')
+    parser.add_argument('src', help='src model path or url')
+    # The dst path must be a full path of the new checkpoint.
+    parser.add_argument('dst', help='save path')
+    args = parser.parse_args()
+
+    if any([s in args.src for s in ['B-16', 'b16', 'base_patch16']]):
+        visual_split = 9
+    elif any([s in args.src for s in ['L-14', 'l14', 'large_patch14']]):
+        visual_split = 18
+    else:
+        raise ValueError(
+            'Make sure the CLIP model is ViT-B/16 or ViT-L/14!')
+    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+    if isinstance(checkpoint, torch.jit.RecursiveScriptModule):
+        state_dict = checkpoint.state_dict()
+    else:
+        if 'state_dict' in checkpoint:
+            # timm checkpoint
+            state_dict = checkpoint['state_dict']
+        elif 'model' in checkpoint:
+            # deit checkpoint
+            state_dict = checkpoint['model']
+        else:
+            state_dict = checkpoint
+    weight = convert_key_name(state_dict, visual_split)
+    weight = convert_tensor(weight)
+    mmengine.mkdir_or_exist(osp.dirname(args.dst))
+    torch.save(weight, args.dst)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/model_converters/mit2mmseg.py b/tools/model_converters/mit2mmseg.py
new file mode 100644
index 0000000000000000000000000000000000000000..f10cbbf9d40d3656be0d447460c12fc83771c14c
--- /dev/null
+++ b/tools/model_converters/mit2mmseg.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os.path as osp
+from collections import OrderedDict
+
+import mmengine
+import torch
+from mmengine.runner import CheckpointLoader
+
+
+def convert_mit(ckpt):
+    new_ckpt = OrderedDict()
+    # Process the concatenation of q linear weights and kv linear weights
+    for k, v in ckpt.items():
+        if k.startswith('head'):
+            continue
+        # patch embedding conversion
+        elif k.startswith('patch_embed'):
+            stage_i = int(k.split('.')[0].replace('patch_embed', ''))
+            new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0')
+            new_v = v
+            if 'proj.' in new_k:
+                new_k = new_k.replace('proj.', 'projection.')
+        # transformer encoder layer conversion
+        elif k.startswith('block'):
+            stage_i = int(k.split('.')[0].replace('block', ''))
+            new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1')
+            new_v = v
+            if 'attn.q.' in new_k:
+                sub_item_k = k.replace('q.', 'kv.')
+                new_k = new_k.replace('q.', 'attn.in_proj_')
+                new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
+            elif 'attn.kv.' in new_k:
+                continue
+            elif 'attn.proj.' in new_k:
+                new_k = new_k.replace('proj.', 'attn.out_proj.')
+            elif 'attn.sr.' in new_k:
+                pass  # spatial-reduction conv keys need no further renaming
+            elif 'mlp.'
in new_k: + string = f'{new_k}-' + new_k = new_k.replace('mlp.', 'ffn.layers.') + if 'fc1.weight' in new_k or 'fc2.weight' in new_k: + new_v = v.reshape((*v.shape, 1, 1)) + new_k = new_k.replace('fc1.', '0.') + new_k = new_k.replace('dwconv.dwconv.', '1.') + new_k = new_k.replace('fc2.', '4.') + string += f'{new_k} {v.shape}-{new_v.shape}' + # norm layer conversion + elif k.startswith('norm'): + stage_i = int(k.split('.')[0].replace('norm', '')) + new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2') + new_v = v + else: + new_k = k + new_v = v + new_ckpt[new_k] = new_v + return new_ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in official pretrained segformer to ' + 'MMSegmentation style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + weight = convert_mit(state_dict) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + +if __name__ == '__main__': + main() diff --git a/tools/model_converters/san2mmseg.py b/tools/model_converters/san2mmseg.py new file mode 100644 index 0000000000000000000000000000000000000000..301a46608e0f14df17138922ae3a747aee105372 --- /dev/null +++ b/tools/model_converters/san2mmseg.py @@ -0,0 +1,220 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_key_name(ckpt): + new_ckpt = OrderedDict() + + for k, v in ckpt.items(): + key_list = k.split('.') + if key_list[0] == 'clip_visual_extractor': + new_transform_name = 'image_encoder' + if key_list[1] == 'class_embedding': + new_name = '.'.join([new_transform_name, 'cls_token']) + elif key_list[1] == 'positional_embedding': + new_name = '.'.join([new_transform_name, 'pos_embed']) + elif key_list[1] == 'conv1': + new_name = '.'.join([ + new_transform_name, 'patch_embed.projection', key_list[2] + ]) + elif key_list[1] == 'ln_pre': + new_name = '.'.join( + [new_transform_name, key_list[1], key_list[2]]) + elif key_list[1] == 'resblocks': + new_layer_name = 'layers' + layer_index = key_list[2] + paras = key_list[3:] + if paras[0] == 'ln_1': + new_para_name = '.'.join(['ln1'] + key_list[4:]) + elif paras[0] == 'attn': + new_para_name = '.'.join(['attn.attn'] + key_list[4:]) + elif paras[0] == 'ln_2': + new_para_name = '.'.join(['ln2'] + key_list[4:]) + elif paras[0] == 'mlp': + if paras[1] == 'c_fc': + new_para_name = '.'.join(['ffn.layers.0.0'] + + key_list[-1:]) + else: + new_para_name = '.'.join(['ffn.layers.1'] + + key_list[-1:]) + new_name = '.'.join([ + new_transform_name, new_layer_name, layer_index, + new_para_name + ]) + elif key_list[0] == 'side_adapter_network': + decode_head_name = 'decode_head' + module_name = 'side_adapter_network' + if key_list[1] == 'vit_model': + if key_list[2] == 'blocks': + layer_name = 'encode_layers' + layer_index = key_list[3] + paras = key_list[4:] + if paras[0] == 'norm1': + new_para_name = '.'.join(['ln1'] + key_list[5:]) + elif paras[0] == 'attn': + new_para_name = '.'.join(key_list[4:]) + new_para_name = new_para_name.replace( + 'attn.qkv.', 
'attn.attn.in_proj_') + new_para_name = new_para_name.replace( + 'attn.proj', 'attn.attn.out_proj') + elif paras[0] == 'norm2': + new_para_name = '.'.join(['ln2'] + key_list[5:]) + elif paras[0] == 'mlp': + new_para_name = '.'.join(['ffn'] + key_list[5:]) + new_para_name = new_para_name.replace( + 'fc1', 'layers.0.0') + new_para_name = new_para_name.replace( + 'fc2', 'layers.1') + else: + print(f'Wrong for {k}') + new_name = '.'.join([ + decode_head_name, module_name, layer_name, layer_index, + new_para_name + ]) + elif key_list[2] == 'pos_embed': + new_name = '.'.join( + [decode_head_name, module_name, 'pos_embed']) + elif key_list[2] == 'patch_embed': + new_name = '.'.join([ + decode_head_name, module_name, 'patch_embed', + 'projection', key_list[4] + ]) + else: + print(f'Wrong for {k}') + elif key_list[1] == 'query_embed' or key_list[ + 1] == 'query_pos_embed': + new_name = '.'.join( + [decode_head_name, module_name, key_list[1]]) + elif key_list[1] == 'fusion_layers': + layer_name = 'conv_clips' + layer_index = key_list[2][-1] + paras = '.'.join(key_list[3:]) + new_para_name = paras.replace('input_proj.0', '0') + new_para_name = new_para_name.replace('input_proj.1', '1.conv') + new_name = '.'.join([ + decode_head_name, module_name, layer_name, layer_index, + new_para_name + ]) + elif key_list[1] == 'mask_decoder': + new_name = 'decode_head.' + k + else: + print(f'Wrong for {k}') + elif key_list[0] == 'clip_rec_head': + module_name = 'rec_with_attnbias' + if key_list[1] == 'proj': + new_name = '.'.join( + [decode_head_name, module_name, 'proj.weight']) + elif key_list[1] == 'ln_post': + new_name = '.'.join( + [decode_head_name, module_name, 'ln_post', key_list[2]]) + elif key_list[1] == 'resblocks': + new_layer_name = 'layers' + layer_index = key_list[2] + paras = key_list[3:] + if paras[0] == 'ln_1': + new_para_name = '.'.join(['norms.0'] + paras[1:]) + elif paras[0] == 'attn': + new_para_name = '.'.join(['attentions.0.attn'] + paras[1:]) + elif paras[0] == 'ln_2': + new_para_name = '.'.join(['norms.1'] + paras[1:]) + elif paras[0] == 'mlp': + if paras[1] == 'c_fc': + new_para_name = '.'.join(['ffns.0.layers.0.0'] + + paras[2:]) + elif paras[1] == 'c_proj': + new_para_name = '.'.join(['ffns.0.layers.1'] + + paras[2:]) + else: + print(f'Wrong for {k}') + new_name = '.'.join([ + decode_head_name, module_name, new_layer_name, layer_index, + new_para_name + ]) + else: + print(f'Wrong for {k}') + elif key_list[0] == 'ov_classifier': + text_encoder_name = 'text_encoder' + if key_list[1] == 'transformer': + layer_name = 'transformer' + layer_index = key_list[3] + paras = key_list[4:] + if paras[0] == 'attn': + new_para_name = '.'.join(['attentions.0.attn'] + paras[1:]) + elif paras[0] == 'ln_1': + new_para_name = '.'.join(['norms.0'] + paras[1:]) + elif paras[0] == 'ln_2': + new_para_name = '.'.join(['norms.1'] + paras[1:]) + elif paras[0] == 'mlp': + if paras[1] == 'c_fc': + new_para_name = '.'.join(['ffns.0.layers.0.0'] + + paras[2:]) + elif paras[1] == 'c_proj': + new_para_name = '.'.join(['ffns.0.layers.1'] + + paras[2:]) + else: + print(f'Wrong for {k}') + else: + print(f'Wrong for {k}') + new_name = '.'.join([ + text_encoder_name, layer_name, layer_index, new_para_name + ]) + elif key_list[1] in [ + 'positional_embedding', 'text_projection', 'bg_embed', + 'attn_mask', 'logit_scale', 'token_embedding', 'ln_final' + ]: + new_name = k.replace('ov_classifier', 'text_encoder') + else: + print(f'Wrong for {k}') + elif key_list[0] == 'criterion': + new_name = k + else: + print(f'Wrong for 
{k}')
+        new_ckpt[new_name] = v
+    return new_ckpt
+
+
+def convert_tensor(ckpt):
+    cls_token = ckpt['image_encoder.cls_token']
+    new_cls_token = cls_token.unsqueeze(0).unsqueeze(0)
+    ckpt['image_encoder.cls_token'] = new_cls_token
+    pos_embed = ckpt['image_encoder.pos_embed']
+    new_pos_embed = pos_embed.unsqueeze(0)
+    ckpt['image_encoder.pos_embed'] = new_pos_embed
+    proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight']
+    new_proj_weight = proj_weight.transpose(1, 0)
+    ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight
+    return ckpt
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Convert keys in pretrained SAN models to '
+        'MMSegmentation style.')
+    parser.add_argument('src', help='src model path or url')
+    # The dst path must be a full path of the new checkpoint.
+    parser.add_argument('dst', help='save path')
+    args = parser.parse_args()
+
+    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+    if 'state_dict' in checkpoint:
+        # timm checkpoint
+        state_dict = checkpoint['state_dict']
+    elif 'model' in checkpoint:
+        # deit checkpoint
+        state_dict = checkpoint['model']
+    else:
+        state_dict = checkpoint
+    weight = convert_key_name(state_dict)
+    weight = convert_tensor(weight)
+    mmengine.mkdir_or_exist(osp.dirname(args.dst))
+    torch.save(weight, args.dst)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/model_converters/stdc2mmseg.py b/tools/model_converters/stdc2mmseg.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ea3b8342f546692f50a8e3c0b740f881058229c
--- /dev/null
+++ b/tools/model_converters/stdc2mmseg.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os.path as osp
+
+import mmengine
+import torch
+from mmengine.runner import CheckpointLoader
+
+
+def convert_stdc(ckpt, stdc_type):
+    new_state_dict = {}
+    if stdc_type == 'STDC1':
+        stage_lst = ['0', '1', '2.0', '2.1', '3.0', '3.1', '4.0', '4.1']
+    else:
+        stage_lst = [
+            '0', '1', '2.0', '2.1', '2.2', '2.3', '3.0', '3.1', '3.2', '3.3',
+            '3.4', '4.0', '4.1', '4.2'
+        ]
+    for k, v in ckpt.items():
+        ori_k = k
+        flag = False
+        if 'cp.' in k:
+            k = k.replace('cp.', '')
+        if 'features.' in k:
+            num_layer = int(k.split('.')[1])
+            feature_key_lst = 'features.' + str(num_layer) + '.'
+            stages_key_lst = 'stages.' + stage_lst[num_layer] + '.'
+            k = k.replace(feature_key_lst, stages_key_lst)
+            flag = True
+        if 'conv_list' in k:
+            k = k.replace('conv_list', 'layers')
+            flag = True
+        if 'avd_layer.' in k:
+            if 'avd_layer.0' in k:
+                k = k.replace('avd_layer.0', 'downsample.conv')
+            elif 'avd_layer.1' in k:
+                k = k.replace('avd_layer.1', 'downsample.bn')
+            flag = True
+        if flag:
+            new_state_dict[k] = ckpt[ori_k]
+
+    return new_state_dict
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Convert keys in official pretrained STDC1/2 to '
+        'MMSegmentation style.')
+    parser.add_argument('src', help='src model path')
+    # The dst path must be a full path of the new checkpoint.
+    parser.add_argument('dst', help='save path')
+    parser.add_argument('type', help='model type: STDC1 or STDC2')
+    args = parser.parse_args()
+
+    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    elif 'model' in checkpoint:
+        state_dict = checkpoint['model']
+    else:
+        state_dict = checkpoint
+
+    assert args.type in ['STDC1',
+                         'STDC2'], 'STDC type should be STDC1 or STDC2!'
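Editorial aside, not part of the diff: the sketch below mirrors the index remapping `convert_stdc` performs for STDC1, where flat timm-style `features.N` indices are spread onto the nested `stages` layout. The input keys are invented for illustration.

```python
# Minimal sketch of convert_stdc's features->stages remapping (STDC1);
# the logic mirrors the function above but uses f-strings for brevity.
stage_lst = ['0', '1', '2.0', '2.1', '3.0', '3.1', '4.0', '4.1']
for key in ['cp.features.0.conv.weight',
            'cp.features.3.conv_list.1.conv.weight']:
    key = key.replace('cp.', '')
    num_layer = int(key.split('.')[1])
    key = key.replace(f'features.{num_layer}.',
                      f'stages.{stage_lst[num_layer]}.')
    key = key.replace('conv_list', 'layers')
    print(key)
# -> stages.0.conv.weight
# -> stages.2.1.layers.1.conv.weight
```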
+ weight = convert_stdc(state_dict, args.type) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + +if __name__ == '__main__': + main() diff --git a/tools/model_converters/swin2mmseg.py b/tools/model_converters/swin2mmseg.py new file mode 100644 index 0000000000000000000000000000000000000000..d434f9465bbdad6bebc7d5962e8bfaf63c7c9e72 --- /dev/null +++ b/tools/model_converters/swin2mmseg.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_swin(ckpt): + new_ckpt = OrderedDict() + + def correct_unfold_reduction_order(x): + out_channel, in_channel = x.shape + x = x.reshape(out_channel, 4, in_channel // 4) + x = x[:, [0, 2, 1, 3], :].transpose(1, + 2).reshape(out_channel, in_channel) + return x + + def correct_unfold_norm_order(x): + in_channel = x.shape[0] + x = x.reshape(4, in_channel // 4) + x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) + return x + + for k, v in ckpt.items(): + if k.startswith('head'): + continue + elif k.startswith('layers'): + new_v = v + if 'attn.' in k: + new_k = k.replace('attn.', 'attn.w_msa.') + elif 'mlp.' in k: + if 'mlp.fc1.' in k: + new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') + elif 'mlp.fc2.' in k: + new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') + else: + new_k = k.replace('mlp.', 'ffn.') + elif 'downsample' in k: + new_k = k + if 'reduction.' in k: + new_v = correct_unfold_reduction_order(v) + elif 'norm.' in k: + new_v = correct_unfold_norm_order(v) + else: + new_k = k + new_k = new_k.replace('layers', 'stages', 1) + elif k.startswith('patch_embed'): + new_v = v + if 'proj' in k: + new_k = k.replace('proj', 'projection') + else: + new_k = k + else: + new_v = v + new_k = k + + new_ckpt[new_k] = new_v + + return new_ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in official pretrained swin models to' + 'MMSegmentation style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + weight = convert_swin(state_dict) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + +if __name__ == '__main__': + main() diff --git a/tools/model_converters/twins2mmseg.py b/tools/model_converters/twins2mmseg.py new file mode 100644 index 0000000000000000000000000000000000000000..647d41784aa07468be4b3f2e183064ad55266ad1 --- /dev/null +++ b/tools/model_converters/twins2mmseg.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_twins(args, ckpt): + + new_ckpt = OrderedDict() + + for k, v in list(ckpt.items()): + new_v = v + if k.startswith('head'): + continue + elif k.startswith('patch_embeds'): + if 'proj.' in k: + new_k = k.replace('proj.', 'projection.') + else: + new_k = k + elif k.startswith('blocks'): + # Union + if 'attn.q.' 
in k: + new_k = k.replace('q.', 'attn.in_proj_') + new_v = torch.cat([v, ckpt[k.replace('attn.q.', 'attn.kv.')]], + dim=0) + elif 'mlp.fc1' in k: + new_k = k.replace('mlp.fc1', 'ffn.layers.0.0') + elif 'mlp.fc2' in k: + new_k = k.replace('mlp.fc2', 'ffn.layers.1') + # Only pcpvt + elif args.model == 'pcpvt': + if 'attn.proj.' in k: + new_k = k.replace('proj.', 'attn.out_proj.') + else: + new_k = k + + # Only svt + else: + if 'attn.proj.' in k: + k_lst = k.split('.') + if int(k_lst[2]) % 2 == 1: + new_k = k.replace('proj.', 'attn.out_proj.') + else: + new_k = k + else: + new_k = k + new_k = new_k.replace('blocks.', 'layers.') + elif k.startswith('pos_block'): + new_k = k.replace('pos_block', 'position_encodings') + if 'proj.0.' in new_k: + new_k = new_k.replace('proj.0.', 'proj.') + else: + new_k = k + if 'attn.kv.' not in k: + new_ckpt[new_k] = new_v + return new_ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in timm pretrained vit models to ' + 'MMSegmentation style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + parser.add_argument('model', help='model: pcpvt or svt') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + + if 'state_dict' in checkpoint: + # timm checkpoint + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + weight = convert_twins(args, state_dict) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + +if __name__ == '__main__': + main() diff --git a/tools/model_converters/vit2mmseg.py b/tools/model_converters/vit2mmseg.py new file mode 100644 index 0000000000000000000000000000000000000000..1d1f8a427e232290c6dcf490e33f777275dd238a --- /dev/null +++ b/tools/model_converters/vit2mmseg.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_vit(ckpt): + + new_ckpt = OrderedDict() + + for k, v in ckpt.items(): + if k.startswith('head'): + continue + if k.startswith('norm'): + new_k = k.replace('norm.', 'ln1.') + elif k.startswith('patch_embed'): + if 'proj' in k: + new_k = k.replace('proj', 'projection') + else: + new_k = k + elif k.startswith('blocks'): + if 'norm' in k: + new_k = k.replace('norm', 'ln') + elif 'mlp.fc1' in k: + new_k = k.replace('mlp.fc1', 'ffn.layers.0.0') + elif 'mlp.fc2' in k: + new_k = k.replace('mlp.fc2', 'ffn.layers.1') + elif 'attn.qkv' in k: + new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_') + elif 'attn.proj' in k: + new_k = k.replace('attn.proj', 'attn.attn.out_proj') + else: + new_k = k + new_k = new_k.replace('blocks.', 'layers.') + else: + new_k = k + new_ckpt[new_k] = v + + return new_ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in timm pretrained vit models to ' + 'MMSegmentation style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. 
+ parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + if 'state_dict' in checkpoint: + # timm checkpoint + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + # deit checkpoint + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + weight = convert_vit(state_dict) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + +if __name__ == '__main__': + main() diff --git a/tools/model_converters/vitjax2mmseg.py b/tools/model_converters/vitjax2mmseg.py new file mode 100644 index 0000000000000000000000000000000000000000..81bc2ea020e32d086fc4ce2153cc2bf51edd4d48 --- /dev/null +++ b/tools/model_converters/vitjax2mmseg.py @@ -0,0 +1,123 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp + +import mmengine +import numpy as np +import torch + + +def vit_jax_to_torch(jax_weights, num_layer=12): + torch_weights = dict() + + # patch embedding + conv_filters = jax_weights['embedding/kernel'] + conv_filters = conv_filters.permute(3, 2, 0, 1) + torch_weights['patch_embed.projection.weight'] = conv_filters + torch_weights['patch_embed.projection.bias'] = jax_weights[ + 'embedding/bias'] + + # pos embedding + torch_weights['pos_embed'] = jax_weights[ + 'Transformer/posembed_input/pos_embedding'] + + # cls token + torch_weights['cls_token'] = jax_weights['cls'] + + # head + torch_weights['ln1.weight'] = jax_weights['Transformer/encoder_norm/scale'] + torch_weights['ln1.bias'] = jax_weights['Transformer/encoder_norm/bias'] + + # transformer blocks + for i in range(num_layer): + jax_block = f'Transformer/encoderblock_{i}' + torch_block = f'layers.{i}' + + # attention norm + torch_weights[f'{torch_block}.ln1.weight'] = jax_weights[ + f'{jax_block}/LayerNorm_0/scale'] + torch_weights[f'{torch_block}.ln1.bias'] = jax_weights[ + f'{jax_block}/LayerNorm_0/bias'] + + # attention + query_weight = jax_weights[ + f'{jax_block}/MultiHeadDotProductAttention_1/query/kernel'] + query_bias = jax_weights[ + f'{jax_block}/MultiHeadDotProductAttention_1/query/bias'] + key_weight = jax_weights[ + f'{jax_block}/MultiHeadDotProductAttention_1/key/kernel'] + key_bias = jax_weights[ + f'{jax_block}/MultiHeadDotProductAttention_1/key/bias'] + value_weight = jax_weights[ + f'{jax_block}/MultiHeadDotProductAttention_1/value/kernel'] + value_bias = jax_weights[ + f'{jax_block}/MultiHeadDotProductAttention_1/value/bias'] + + qkv_weight = torch.from_numpy( + np.stack((query_weight, key_weight, value_weight), 1)) + qkv_weight = torch.flatten(qkv_weight, start_dim=1) + qkv_bias = torch.from_numpy( + np.stack((query_bias, key_bias, value_bias), 0)) + qkv_bias = torch.flatten(qkv_bias, start_dim=0) + + torch_weights[f'{torch_block}.attn.attn.in_proj_weight'] = qkv_weight + torch_weights[f'{torch_block}.attn.attn.in_proj_bias'] = qkv_bias + to_out_weight = jax_weights[ + f'{jax_block}/MultiHeadDotProductAttention_1/out/kernel'] + to_out_weight = torch.flatten(to_out_weight, start_dim=0, end_dim=1) + torch_weights[ + f'{torch_block}.attn.attn.out_proj.weight'] = to_out_weight + torch_weights[f'{torch_block}.attn.attn.out_proj.bias'] = jax_weights[ + f'{jax_block}/MultiHeadDotProductAttention_1/out/bias'] + + # mlp norm + torch_weights[f'{torch_block}.ln2.weight'] = jax_weights[ + f'{jax_block}/LayerNorm_2/scale'] + torch_weights[f'{torch_block}.ln2.bias'] = jax_weights[ + f'{jax_block}/LayerNorm_2/bias'] + + # mlp + 
torch_weights[f'{torch_block}.ffn.layers.0.0.weight'] = jax_weights[ + f'{jax_block}/MlpBlock_3/Dense_0/kernel'] + torch_weights[f'{torch_block}.ffn.layers.0.0.bias'] = jax_weights[ + f'{jax_block}/MlpBlock_3/Dense_0/bias'] + torch_weights[f'{torch_block}.ffn.layers.1.weight'] = jax_weights[ + f'{jax_block}/MlpBlock_3/Dense_1/kernel'] + torch_weights[f'{torch_block}.ffn.layers.1.bias'] = jax_weights[ + f'{jax_block}/MlpBlock_3/Dense_1/bias'] + + # transpose weights + for k, v in torch_weights.items(): + if 'weight' in k and 'patch_embed' not in k and 'ln' not in k: + v = v.permute(1, 0) + torch_weights[k] = v + + return torch_weights + + +def main(): + # stole refactoring code from Robin Strudel, thanks + parser = argparse.ArgumentParser( + description='Convert keys from jax official pretrained vit models to ' + 'MMSegmentation style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + jax_weights = np.load(args.src) + jax_weights_tensor = {} + for key in jax_weights.files: + value = torch.from_numpy(jax_weights[key]) + jax_weights_tensor[key] = value + if 'L_16-i21k' in args.src: + num_layer = 24 + else: + num_layer = 12 + torch_weights = vit_jax_to_torch(jax_weights_tensor, num_layer) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(torch_weights, args.dst) + + +if __name__ == '__main__': + main() diff --git a/tools/slurm_test.sh b/tools/slurm_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..4e6f7bf4e33267f269cf0f455924cb70166ccd4b --- /dev/null +++ b/tools/slurm_test.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +set -x + +PARTITION=$1 +JOB_NAME=$2 +CONFIG=$3 +CHECKPOINT=$4 +GPUS=${GPUS:-4} +GPUS_PER_NODE=${GPUS_PER_NODE:-4} +CPUS_PER_TASK=${CPUS_PER_TASK:-5} +PY_ARGS=${@:5} +SRUN_ARGS=${SRUN_ARGS:-""} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} diff --git a/tools/slurm_train.sh b/tools/slurm_train.sh new file mode 100644 index 0000000000000000000000000000000000000000..ab232105f0309c720ed81a522eca14b6fbd64afd --- /dev/null +++ b/tools/slurm_train.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +set -x + +PARTITION=$1 +JOB_NAME=$2 +CONFIG=$3 +GPUS=${GPUS:-4} +GPUS_PER_NODE=${GPUS_PER_NODE:-4} +CPUS_PER_TASK=${CPUS_PER_TASK:-5} +SRUN_ARGS=${SRUN_ARGS:-""} +PY_ARGS=${@:4} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + python -u tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS} diff --git a/tools/test.py b/tools/test.py new file mode 100644 index 0000000000000000000000000000000000000000..5bd129e8c47c0c97d08ab7aa3d21fbfeb32ff34b --- /dev/null +++ b/tools/test.py @@ -0,0 +1,124 @@ +import sys +sys.path.append(sys.path[0] + "/..") +import argparse +import os +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.runner import Runner + + +# TODO: support fuse_conv_bn, visualization, and format_only +def parse_args(): + parser = argparse.ArgumentParser( + 
description='MMSeg test (and eval) a model')
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument('checkpoint', help='checkpoint file')
+    parser.add_argument(
+        '--work-dir',
+        help=('if specified, the evaluation metric results will be dumped '
+              'into the directory as json'))
+    parser.add_argument(
+        '--out',
+        type=str,
+        help='The directory to save output prediction for offline evaluation')
+    parser.add_argument(
+        '--show', action='store_true', help='show prediction results')
+    parser.add_argument(
+        '--show-dir',
+        help='directory where painted images will be saved. '
+        'If specified, it will be automatically saved '
+        'to the work_dir/timestamp/show_dir')
+    parser.add_argument(
+        '--wait-time', type=float, default=2, help='the interval of show (s)')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch', 'slurm', 'mpi'],
+        default='none',
+        help='job launcher')
+    parser.add_argument(
+        '--tta', action='store_true', help='Test time augmentation')
+    # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
+    # will pass the `--local-rank` parameter to `tools/test.py` instead
+    # of `--local_rank`.
+    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
+    args = parser.parse_args()
+    if 'LOCAL_RANK' not in os.environ:
+        os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+    return args
+
+
+def trigger_visualization_hook(cfg, args):
+    default_hooks = cfg.default_hooks
+    if 'visualization' in default_hooks:
+        visualization_hook = default_hooks['visualization']
+        # Turn on visualization
+        visualization_hook['draw'] = True
+        if args.show:
+            visualization_hook['show'] = True
+            visualization_hook['wait_time'] = args.wait_time
+        if args.show_dir:
+            visualizer = cfg.visualizer
+            visualizer['save_dir'] = args.show_dir
+    else:
+        raise RuntimeError(
+            'VisualizationHook must be included in default_hooks. '
+ 'refer to usage ' + '"visualization=dict(type=\'VisualizationHook\')"') + + return cfg + + +def main(): + args = parse_args() + + # load config + cfg = Config.fromfile(args.config) + cfg.launcher = args.launcher + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + + cfg.load_from = args.checkpoint + + if args.show or args.show_dir: + cfg = trigger_visualization_hook(cfg, args) + + if args.tta: + cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline + cfg.tta_model.module = cfg.model + cfg.model = cfg.tta_model + + # add output_dir in metric + if args.out is not None: + cfg.test_evaluator['output_dir'] = args.out + cfg.test_evaluator['keep_results'] = True + + # build the runner from config + runner = Runner.from_cfg(cfg) + + # start testing + runner.test() + + +if __name__ == '__main__': + main() diff --git a/tools/torchserve/mmseg2torchserve.py b/tools/torchserve/mmseg2torchserve.py new file mode 100644 index 0000000000000000000000000000000000000000..23f99638e799fd0b37a6737cc833dd7d24f611f8 --- /dev/null +++ b/tools/torchserve/mmseg2torchserve.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from argparse import ArgumentParser, Namespace +from pathlib import Path +from tempfile import TemporaryDirectory + +from mmengine import Config +from mmengine.utils import mkdir_or_exist + +try: + from model_archiver.model_packaging import package_model + from model_archiver.model_packaging_utils import ModelExportUtils +except ImportError: + package_model = None + + +def mmseg2torchserve( + config_file: str, + checkpoint_file: str, + output_folder: str, + model_name: str, + model_version: str = '1.0', + force: bool = False, +): + """Converts mmsegmentation model (config + checkpoint) to TorchServe + `.mar`. + + Args: + config_file: + In MMSegmentation config format. + The contents vary for each task repository. + checkpoint_file: + In MMSegmentation checkpoint format. + The contents vary for each task repository. + output_folder: + Folder where `{model_name}.mar` will be created. + The file created will be in TorchServe archive format. + model_name: + If not None, used for naming the `{model_name}.mar` file + that will be created under `output_folder`. + If None, `{Path(checkpoint_file).stem}` will be used. + model_version: + Model's version. + force: + If True, if there is an existing `{model_name}.mar` + file under `output_folder` it will be overwritten. 
+ """ + mkdir_or_exist(output_folder) + + config = Config.fromfile(config_file) + + with TemporaryDirectory() as tmpdir: + config.dump(f'{tmpdir}/config.py') + + args = Namespace( + **{ + 'model_file': f'{tmpdir}/config.py', + 'serialized_file': checkpoint_file, + 'handler': f'{Path(__file__).parent}/mmseg_handler.py', + 'model_name': model_name or Path(checkpoint_file).stem, + 'version': model_version, + 'export_path': output_folder, + 'force': force, + 'requirements_file': None, + 'extra_files': None, + 'runtime': 'python', + 'archive_format': 'default' + }) + manifest = ModelExportUtils.generate_manifest_json(args) + package_model(args, manifest) + + +def parse_args(): + parser = ArgumentParser( + description='Convert mmseg models to TorchServe `.mar` format.') + parser.add_argument('config', type=str, help='config file path') + parser.add_argument('checkpoint', type=str, help='checkpoint file path') + parser.add_argument( + '--output-folder', + type=str, + required=True, + help='Folder where `{model_name}.mar` will be created.') + parser.add_argument( + '--model-name', + type=str, + default=None, + help='If not None, used for naming the `{model_name}.mar`' + 'file that will be created under `output_folder`.' + 'If None, `{Path(checkpoint_file).stem}` will be used.') + parser.add_argument( + '--model-version', + type=str, + default='1.0', + help='Number used for versioning.') + parser.add_argument( + '-f', + '--force', + action='store_true', + help='overwrite the existing `{model_name}.mar`') + args = parser.parse_args() + + return args + + +if __name__ == '__main__': + args = parse_args() + + if package_model is None: + raise ImportError('`torch-model-archiver` is required.' + 'Try: pip install torch-model-archiver') + + mmseg2torchserve(args.config, args.checkpoint, args.output_folder, + args.model_name, args.model_version, args.force) diff --git a/tools/torchserve/mmseg_handler.py b/tools/torchserve/mmseg_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..dbe5ded8482c1113a6adb45a22b650af71f6294e --- /dev/null +++ b/tools/torchserve/mmseg_handler.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import base64 +import os + +import cv2 +import mmcv +import torch +from mmengine.model.utils import revert_sync_batchnorm +from ts.torch_handler.base_handler import BaseHandler + +from mmseg.apis import inference_model, init_model + + +class MMsegHandler(BaseHandler): + + def initialize(self, context): + properties = context.system_properties + self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = torch.device(self.map_location + ':' + + str(properties.get('gpu_id')) if torch.cuda. 
+ is_available() else self.map_location) + self.manifest = context.manifest + + model_dir = properties.get('model_dir') + serialized_file = self.manifest['model']['serializedFile'] + checkpoint = os.path.join(model_dir, serialized_file) + self.config_file = os.path.join(model_dir, 'config.py') + + self.model = init_model(self.config_file, checkpoint, self.device) + self.model = revert_sync_batchnorm(self.model) + self.initialized = True + + def preprocess(self, data): + images = [] + + for row in data: + image = row.get('data') or row.get('body') + if isinstance(image, str): + image = base64.b64decode(image) + image = mmcv.imfrombytes(image) + images.append(image) + + return images + + def inference(self, data, *args, **kwargs): + results = [inference_model(self.model, img) for img in data] + return results + + def postprocess(self, data): + output = [] + + for image_result in data: + _, buffer = cv2.imencode('.png', image_result[0].astype('uint8')) + content = buffer.tobytes() + output.append(content) + return output diff --git a/tools/torchserve/test_torchserve.py b/tools/torchserve/test_torchserve.py new file mode 100644 index 0000000000000000000000000000000000000000..b015b6658556e5045af2daf5d998de0de61e1f6b --- /dev/null +++ b/tools/torchserve/test_torchserve.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from argparse import ArgumentParser +from io import BytesIO + +import matplotlib.pyplot as plt +import mmcv +import requests + +from mmseg.apis import inference_model, init_model + + +def parse_args(): + parser = ArgumentParser( + description='Compare result of torchserve and pytorch,' + 'and visualize them.') + parser.add_argument('img', help='Image file') + parser.add_argument('config', help='Config file') + parser.add_argument('checkpoint', help='Checkpoint file') + parser.add_argument('model_name', help='The model name in the server') + parser.add_argument( + '--inference-addr', + default='127.0.0.1:8080', + help='Address and port of the inference server') + parser.add_argument( + '--result-image', + type=str, + default=None, + help='save server output in result-image') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + + args = parser.parse_args() + return args + + +def main(args): + url = 'http://' + args.inference_addr + '/predictions/' + args.model_name + with open(args.img, 'rb') as image: + tmp_res = requests.post(url, image) + content = tmp_res.content + if args.result_image: + with open(args.result_image, 'wb') as out_image: + out_image.write(content) + plt.imshow(mmcv.imread(args.result_image, 'grayscale')) + plt.show() + else: + plt.imshow(plt.imread(BytesIO(content))) + plt.show() + model = init_model(args.config, args.checkpoint, args.device) + image = mmcv.imread(args.img) + result = inference_model(model, image) + plt.imshow(result[0]) + plt.show() + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/tools/train.py b/tools/train.py new file mode 100644 index 0000000000000000000000000000000000000000..7d951d54cd2dd8eb4bb304066e0b4b6ed7441951 --- /dev/null +++ b/tools/train.py @@ -0,0 +1,105 @@ +import sys +sys.path.append(sys.path[0] + "/..") +import argparse +import logging +import os +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.logging import print_log +from mmengine.runner import Runner + +from mmseg.registry import RUNNERS + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a segmentor') + 
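A brief editorial aside on the TorchServe tools above, before the rest of `tools/train.py`: once the `.mar` archive built by `mmseg2torchserve.py` is being served, a client posts raw image bytes and receives back the PNG mask that `MMsegHandler.postprocess` encodes. A hedged sketch; the address, port, and registered model name `ttp` are assumptions, and it requires a running TorchServe instance.

```python
# Hypothetical client call; adjust the address, model name, and image path.
import requests

with open('samples/A/test_1.png', 'rb') as f:  # any local PNG works
    resp = requests.post('http://127.0.0.1:8080/predictions/ttp', data=f.read())
resp.raise_for_status()
with open('pred_mask.png', 'wb') as out:
    out.write(resp.content)  # postprocess() returns PNG-encoded bytes
```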
parser.add_argument('config', help='train config file path') + parser.add_argument('--work-dir', help='the dir to save logs and models') + parser.add_argument( + '--resume', + action='store_true', + default=False, + help='resume from the latest checkpoint in the work_dir automatically') + parser.add_argument( + '--amp', + action='store_true', + default=False, + help='enable automatic-mixed-precision training') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` + # will pass the `--local-rank` parameter to `tools/train.py` instead + # of `--local_rank`. + parser.add_argument('--local_rank', '--local-rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + return args + + +def main(): + args = parse_args() + + # load config + cfg = Config.fromfile(args.config) + cfg.launcher = args.launcher + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + + # enable automatic-mixed-precision training + if args.amp is True: + optim_wrapper = cfg.optim_wrapper.type + if optim_wrapper == 'AmpOptimWrapper': + print_log( + 'AMP training is already enabled in your config.', + logger='current', + level=logging.WARNING) + else: + assert optim_wrapper == 'OptimWrapper', ( + '`--amp` is only supported when the optimizer wrapper type is ' + f'`OptimWrapper` but got {optim_wrapper}.') + cfg.optim_wrapper.type = 'AmpOptimWrapper' + cfg.optim_wrapper.loss_scale = 'dynamic' + + # resume training + cfg.resume = args.resume + + # build the runner from config + if 'runner_type' not in cfg: + # build the default runner + runner = Runner.from_cfg(cfg) + else: + # build customized runner from the registry + # if 'runner_type' is set in the cfg + runner = RUNNERS.build(cfg) + + # start training + runner.train() + + +if __name__ == '__main__': + main()
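Closing editorial note: both `tools/test.py` and `tools/train.py` funnel `--cfg-options` overrides through `Config.merge_from_dict`, and `--amp` performs an equivalent in-place rewrite of `optim_wrapper.type`. A toy sketch of that merge behavior; the config keys here are invented for illustration, not the repo's real config.

```python
# Toy demonstration of mmengine's Config.merge_from_dict, which backs
# the --cfg-options flag in both entry points.
from mmengine.config import Config

cfg = Config(dict(optim_wrapper=dict(type='OptimWrapper'),
                  train_dataloader=dict(batch_size=8)))
cfg.merge_from_dict({'train_dataloader.batch_size': 4,
                     'optim_wrapper.type': 'AmpOptimWrapper'})
print(cfg.optim_wrapper.type, cfg.train_dataloader.batch_size)
# AmpOptimWrapper 4
```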